fx_callback/
callback.rs

1use fx_handle::Handle;
2use log::{debug, error, trace, warn};
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::sync::{Arc, Mutex};
6use std::time::Instant;
7use tokio::runtime::Runtime;
8use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
9
10/// The unique identifier for a callback.
11pub type CallbackHandle = Handle;
12
13/// The subscription type for the interested event.
14/// Drop this subscription to remove the callback.
15pub type Subscription<T> = UnboundedReceiver<Arc<T>>;
16
17/// The subscriber type for the interested event.
18/// This can be used to send the interested event from multiple sources into one receiver.
19pub type Subscriber<T> = UnboundedSender<Arc<T>>;
20
21/// Allows adding callbacks to the struct.
22/// The struct will inform the [Subscription] when a certain event occurs.
23///
24/// # Example
25///
26/// ```rust,no_run
27/// use std::sync::Arc;
28/// use tokio::runtime::Runtime;
29/// use fx_callback::{Callback, MultiThreadedCallback};
30///
31/// #[derive(Debug)]
32/// pub enum MyEvent {
33///     Foo,
34///     Bar,
35/// }
36///
37/// async fn register_callback() {
38///     let callback = MultiThreadedCallback::<MyEvent>::new();
39///     let mut receiver = callback.subscribe();
40///
41///     let event = receiver.recv().await.unwrap();
42///     // do something with the event
43/// }
44/// ```
45pub trait Callback<T>: Debug
46where
47    T: Debug + Send + Sync,
48{
49    /// Subscribe to the interested event.
50    /// This creates a new [Subscription] that will be invoked with a shared instance of the event when the interested event occurs.
51    ///
52    /// # Example
53    ///
54    /// ```rust,no_run
55    /// use fx_callback::Callback;
56    ///
57    /// #[derive(Debug, Clone, PartialEq)]
58    /// pub enum MyEvent {
59    ///     Foo,
60    /// }
61    ///
62    /// async fn example(callback: &dyn Callback<MyEvent>) {
63    ///     let mut receiver = callback.subscribe();
64    ///     
65    ///     if let Some(event) = receiver.recv().await {
66    ///         // do something with the event
67    ///     }
68    /// }
69    ///
70    /// ```
71    ///
72    /// # Returns
73    ///
74    /// It returns a [Subscription] which can be dropped to remove the callback.
75    fn subscribe(&self) -> Subscription<T>;
76
77    /// Subscribe to the interested event with a [Subscriber].
78    /// This creates an underlying new subscription which will be invoked with the given subscriber when the interested event occurs.
79    ///
80    /// ## Remarks
81    ///
82    /// It is possible to grant multiple subscriptions from the same source to the same interested event,
83    /// as the [Callback] is only a holder for the [Subscription] and can't detect any duplicates.
84    fn subscribe_with(&self, subscriber: Subscriber<T>);
85}
86
87/// A multithreaded callback holder.
88///
89/// This callback holder will invoke the given events on a separate thread, thus unblocking the caller thread for other tasks.
90///
91/// # Example
92///
93/// ```rust,no_run
94/// use fx_callback::{Callback, MultiThreadedCallback, Subscriber, Subscription};
95///
96/// /// The events of the struct that informs subscribers about changes to the data within the struct.
97/// #[derive(Debug, Clone, PartialEq)]
98/// enum MyEvent {
99///     Foo,
100/// }
101///
102/// /// The struct to which an interested subscriber can subscribe to.
103/// #[derive(Debug)]
104/// struct Example {
105///     callbacks: MultiThreadedCallback<MyEvent>,
106/// }
107///
108/// impl Example {
109///     fn invoke_event(&self) {
110///         self.callbacks.invoke(MyEvent::Foo);
111///     }
112/// }
113///
114/// impl Callback<MyEvent> for Example {
115///     fn subscribe(&self) -> Subscription<MyEvent> {
116///         self.callbacks.subscribe()
117///     }
118///
119///     fn subscribe_with(&self, subscriber: Subscriber<MyEvent>) {
120///         self.callbacks.subscribe_with(subscriber)
121///     }
122/// }
123/// ```
124#[derive(Debug, Clone)]
125pub struct MultiThreadedCallback<T>
126where
127    T: Debug + Send + Sync,
128{
129    base: Arc<BaseCallback<T>>,
130    runtime: Arc<Mutex<Option<Runtime>>>,
131}
132
133impl<T> Callback<T> for MultiThreadedCallback<T>
134where
135    T: Debug + Send + Sync,
136{
137    fn subscribe(&self) -> Subscription<T> {
138        self.base.subscribe()
139    }
140
141    fn subscribe_with(&self, subscriber: Subscriber<T>) {
142        self.base.subscribe_with(subscriber)
143    }
144}
145
146impl<T> MultiThreadedCallback<T>
147where
148    T: Debug + Send + Sync + 'static,
149{
150    /// Creates a new multithreaded callback.
151    pub fn new() -> Self {
152        Self {
153            base: Arc::new(BaseCallback::<T>::new()),
154            runtime: Arc::new(Mutex::new(None)),
155        }
156    }
157
158    /// Invoke the currently registered callbacks and inform them of the given value.
159    ///
160    /// # Arguments
161    ///
162    /// * `value` - The value to invoke the callbacks with.
163    pub fn invoke(&self, value: T) {
164        let inner = self.base.clone();
165        match tokio::runtime::Handle::try_current() {
166            Ok(_) => {
167                // spawn the invocation operation in a new thread
168                tokio::spawn(async move {
169                    inner.invoke(value);
170                });
171            }
172            Err(_) => match self.runtime.lock() {
173                Ok(mut runtime) => {
174                    runtime
175                        .get_or_insert_with(|| Runtime::new().unwrap())
176                        .spawn(async move {
177                            inner.invoke(value);
178                        });
179                }
180                Err(e) => error!("Failed to acquire lock: {}", e),
181            },
182        }
183    }
184}
185
186/// A single threaded or current threaded callback holder.
187///
188/// This callback holder will invoke the given events on the current thread, thus blocking the caller thread for other tasks.
189#[derive(Debug, Clone)]
190pub struct SingleThreadedCallback<T>
191where
192    T: Debug + Send + Sync,
193{
194    base: Arc<BaseCallback<T>>,
195}
196
197impl<T> SingleThreadedCallback<T>
198where
199    T: Debug + Send + Sync,
200{
201    /// Create a new single/current threaded callback holder.
202    pub fn new() -> Self {
203        Self {
204            base: Arc::new(BaseCallback::<T>::new()),
205        }
206    }
207
208    /// Invoke the currently registered callbacks and inform them of the given value.
209    ///
210    /// # Arguments
211    ///
212    /// * `value` - The value to invoke the callbacks with.
213    pub fn invoke(&self, value: T) {
214        self.base.invoke(value)
215    }
216}
217
218impl<T> Callback<T> for SingleThreadedCallback<T>
219where
220    T: Debug + Send + Sync,
221{
222    fn subscribe(&self) -> Subscription<T> {
223        self.base.subscribe()
224    }
225
226    fn subscribe_with(&self, subscriber: Subscriber<T>) {
227        self.base.subscribe_with(subscriber)
228    }
229}
230
231struct BaseCallback<T>
232where
233    T: Debug + Send + Sync,
234{
235    callbacks: Mutex<HashMap<CallbackHandle, UnboundedSender<Arc<T>>>>,
236}
237
238impl<T> BaseCallback<T>
239where
240    T: Debug + Send + Sync,
241{
242    fn new() -> Self {
243        Self {
244            callbacks: Mutex::new(HashMap::new()),
245        }
246    }
247
248    fn subscribe(&self) -> Subscription<T> {
249        let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
250        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
251        let handle = CallbackHandle::new();
252        mutex.insert(handle, tx);
253        drop(mutex);
254        trace!("Added callback {} to {:?}", handle, self);
255        rx
256    }
257
258    fn subscribe_with(&self, subscriber: Subscriber<T>) {
259        let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
260        let handle = CallbackHandle::new();
261        mutex.insert(handle, subscriber);
262        drop(mutex);
263        trace!("Added callback {} to {:?}", handle, self);
264    }
265
266    fn invoke(&self, value: T) {
267        let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
268        let value = Arc::new(value);
269
270        trace!(
271            "Invoking a total of {} callbacks for {:?}",
272            mutex.len(),
273            *value
274        );
275
276        let handles_to_remove: Vec<CallbackHandle> = mutex
277            .iter()
278            .map(|(handle, callback)| {
279                BaseCallback::invoke_callback(handle, callback, value.clone())
280            })
281            .flat_map(|e| e)
282            .collect();
283
284        let total_handles = handles_to_remove.len();
285        for handle in handles_to_remove {
286            mutex.remove(&handle);
287        }
288
289        if total_handles > 0 {
290            debug!("Removed a total of {} callbacks", total_handles);
291        }
292    }
293
294    /// Try to invoke the callback for the given value.
295    /// This is a convenience method for handling dropped callbacks.
296    ///
297    /// # Returns
298    ///
299    /// It returns the callback handle if the callback has been dropped.
300    fn invoke_callback(
301        handle: &CallbackHandle,
302        callback: &UnboundedSender<Arc<T>>,
303        value: Arc<T>,
304    ) -> Option<CallbackHandle> {
305        let start_time = Instant::now();
306        if let Err(_) = callback.send(value) {
307            trace!("Callback {} has been dropped", handle);
308            return Some(handle.clone());
309        }
310        let elapsed = start_time.elapsed();
311        let message = format!(
312            "Callback {} took {}.{:03}ms to process the invocation",
313            handle,
314            elapsed.as_millis(),
315            elapsed.subsec_micros() % 1000
316        );
317        if elapsed.as_millis() >= 1000 {
318            warn!("{}", message);
319        } else {
320            trace!("{}", message);
321        }
322
323        None
324    }
325}
326
327impl<T> Debug for BaseCallback<T>
328where
329    T: Debug + Send + Sync,
330{
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        f.debug_struct("BaseCallback")
333            .field("callbacks", &self.callbacks.lock().unwrap().len())
334            .finish()
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::init_logger;
    use std::sync::mpsc::channel;
    use std::time::Duration;
    use tokio::{select, time};

    /// Upper bound for an event to travel from `invoke` to the test's receiver.
    ///
    /// Generous on purpose: delivery crosses task/thread boundaries and, in the runtime-less
    /// test, includes lazily starting a whole Tokio runtime, which can take well over a few
    /// milliseconds on a loaded CI machine. Passing tests finish as soon as the event arrives.
    const RECEIVE_TIMEOUT: Duration = Duration::from_secs(2);

    #[derive(Debug, Clone, PartialEq)]
    pub enum Event {
        Foo,
    }

    #[tokio::test]
    async fn test_multi_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = select! {
            _ = time::sleep(RECEIVE_TIMEOUT) => {
                panic!("Callback invocation receiver timed out")
            },
            Some(result) = rx.recv() => result,
        };

        assert_eq!(expected_result, *result);
    }

    #[test]
    fn test_multi_threaded_invoke_without_runtime() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, rx) = channel();
        let runtime = Runtime::new().unwrap();
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(RECEIVE_TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }

    #[tokio::test]
    async fn test_invoke_dropped_receiver() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        let _ = callback.subscribe();
        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = select! {
            _ = time::sleep(RECEIVE_TIMEOUT) => {
                panic!("Callback invocation receiver timed out")
            },
            Some(result) = rx.recv() => result,
        };

        assert_eq!(expected_result, *result);
    }

    #[test]
    fn test_invoke_removes_dropped_subscriptions() {
        init_logger!();
        let callback = SingleThreadedCallback::<Event>::new();

        drop(callback.subscribe());
        let _receiver = callback.subscribe();
        assert_eq!(2, callback.base.callbacks.lock().unwrap().len());

        // The single threaded invocation is synchronous, so pruning is done once it returns.
        callback.invoke(Event::Foo);

        assert_eq!(1, callback.base.callbacks.lock().unwrap().len());
    }

    #[test]
    fn test_single_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let runtime = Runtime::new().unwrap();
        let (tx, rx) = channel();
        let callback = SingleThreadedCallback::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(RECEIVE_TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }
}