// fx_callback/callback.rs

use fx_handle::Handle;
use log::{debug, error, trace, warn};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use std::time::Instant;
use tokio::runtime::Runtime;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};

10/// The unique identifier for a callback.
11pub type CallbackHandle = Handle;
12
13/// The subscription type for the interested event.
14/// Drop this subscription to remove the callback.
15pub type Subscription<T> = UnboundedReceiver<Arc<T>>;
16
17/// The subscriber type for the interested event.
18/// This can be used to send the interested event from multiple sources into one receiver.
19pub type Subscriber<T> = UnboundedSender<Arc<T>>;
20
21/// Allows adding callbacks to the struct.
22/// The struct will inform the [Subscription] when a certain event occurs.
23///
24/// # Example
25///
26/// ```rust,no_run
27/// use std::sync::Arc;
28/// use tokio::runtime::Runtime;
29/// use fx_callback::{Callback, MultiThreadedCallback};
30///
31/// #[derive(Debug)]
32/// pub enum MyEvent {
33///     Foo,
34///     Bar,
35/// }
36///
37/// async fn register_callback() {
38///     let callback = MultiThreadedCallback::<MyEvent>::new();
39///     let mut receiver = callback.subscribe();
40///
41///     let event = receiver.recv().await.unwrap();
42///     // do something with the event
43/// }
44/// ```
45pub trait Callback<T>: Debug
46where
47    T: Debug + Send + Sync,
48{
49    /// Subscribe to the interested event.
50    /// This creates a new [Subscription] that will be invoked with a shared instance of the event when the interested event occurs.
51    ///
52    /// # Example
53    ///
54    /// ```rust,no_run
55    /// use fx_callback::Callback;
56    ///
57    /// #[derive(Debug, Clone, PartialEq)]
58    /// pub enum MyEvent {
59    ///     Foo,
60    /// }
61    ///
62    /// async fn example(callback: &dyn Callback<MyEvent>) {
63    ///     let mut receiver = callback.subscribe();
64    ///     
65    ///     if let Some(event) = receiver.recv().await {
66    ///         // do something with the event
67    ///     }
68    /// }
69    ///
70    /// ```
71    ///
72    /// # Returns
73    ///
74    /// It returns a [Subscription] which can be dropped to remove the callback.
75    fn subscribe(&self) -> Subscription<T>;
76
77    /// Subscribe to the interested event with a [Subscriber].
78    /// This creates an underlying new subscription which will be invoked with the given subscriber when the interested event occurs.
79    ///
80    /// ## Remarks
81    ///
82    /// It is possible to grant multiple subscriptions from the same source to the same interested event,
83    /// as the [Callback] is only a holder for the [Subscription] and can't detect any duplicates.
84    fn subscribe_with(&self, subscriber: Subscriber<T>);
85}
86
87/// A multithreaded callback holder.
88///
89/// This callback holder will invoke the given events on a separate thread, thus unblocking the caller thread for other tasks.
90///
91/// # Example
92///
93/// ```rust,no_run
94/// use fx_callback::{Callback, MultiThreadedCallback, Subscriber, Subscription};
95///
96/// /// The events of the struct that informs subscribers about changes to the data within the struct.
97/// #[derive(Debug, Clone, PartialEq)]
98/// enum MyEvent {
99///     Foo,
100/// }
101///
102/// /// The struct to which an interested subscriber can subscribe to.
103/// #[derive(Debug)]
104/// struct Example {
105///     callbacks: MultiThreadedCallback<MyEvent>,
106/// }
107///
108/// impl Example {
109///     fn invoke_event(&self) {
110///         self.callbacks.invoke(MyEvent::Foo);
111///     }
112/// }
113///
114/// impl Callback<MyEvent> for Example {
115///     fn subscribe(&self) -> Subscription<MyEvent> {
116///         self.callbacks.subscribe()
117///     }
118///
119///     fn subscribe_with(&self, subscriber: Subscriber<MyEvent>) {
120///         self.callbacks.subscribe_with(subscriber)
121///     }
122/// }
123/// ```
124#[derive(Debug)]
125pub struct MultiThreadedCallback<T>
126where
127    T: Debug + Send + Sync,
128{
129    base: Arc<BaseCallback<T>>,
130    runtime: Arc<Mutex<Option<Runtime>>>,
131}
132
133impl<T> Callback<T> for MultiThreadedCallback<T>
134where
135    T: Debug + Send + Sync,
136{
137    fn subscribe(&self) -> Subscription<T> {
138        self.base.subscribe()
139    }
140
141    fn subscribe_with(&self, subscriber: Subscriber<T>) {
142        self.base.subscribe_with(subscriber)
143    }
144}
145
146impl<T> Clone for MultiThreadedCallback<T>
147where
148    T: Debug + Send + Sync,
149{
150    fn clone(&self) -> Self {
151        Self {
152            base: self.base.clone(),
153            runtime: self.runtime.clone(),
154        }
155    }
156}
157
158impl<T> MultiThreadedCallback<T>
159where
160    T: Debug + Send + Sync + 'static,
161{
162    /// Creates a new multithreaded callback.
163    pub fn new() -> Self {
164        Self {
165            base: Arc::new(BaseCallback::<T>::new()),
166            runtime: Arc::new(Mutex::new(None)),
167        }
168    }
169
170    /// Invoke the currently registered callbacks and inform them of the given value.
171    ///
172    /// # Arguments
173    ///
174    /// * `value` - The value to invoke the callbacks with.
175    pub fn invoke(&self, value: T) {
176        let inner = self.base.clone();
177        match tokio::runtime::Handle::try_current() {
178            Ok(_) => {
179                // spawn the invocation operation in a new thread
180                tokio::spawn(async move {
181                    inner.invoke(value);
182                });
183            }
184            Err(_) => match self.runtime.lock() {
185                Ok(mut runtime) => {
186                    runtime
187                        .get_or_insert_with(|| Runtime::new().unwrap())
188                        .spawn(async move {
189                            inner.invoke(value);
190                        });
191                }
192                Err(e) => error!("Failed to acquire lock: {}", e),
193            },
194        }
195    }
196}
197
198/// A single threaded or current threaded callback holder.
199///
200/// This callback holder will invoke the given events on the current thread, thus blocking the caller thread for other tasks.
201#[derive(Debug, Clone)]
202pub struct SingleThreadedCallback<T>
203where
204    T: Debug + Send + Sync,
205{
206    base: Arc<BaseCallback<T>>,
207}
208
209impl<T> SingleThreadedCallback<T>
210where
211    T: Debug + Send + Sync,
212{
213    /// Create a new single/current threaded callback holder.
214    pub fn new() -> Self {
215        Self {
216            base: Arc::new(BaseCallback::<T>::new()),
217        }
218    }
219
220    /// Invoke the currently registered callbacks and inform them of the given value.
221    ///
222    /// # Arguments
223    ///
224    /// * `value` - The value to invoke the callbacks with.
225    pub fn invoke(&self, value: T) {
226        self.base.invoke(value)
227    }
228}
229
230impl<T> Callback<T> for SingleThreadedCallback<T>
231where
232    T: Debug + Send + Sync,
233{
234    fn subscribe(&self) -> Subscription<T> {
235        self.base.subscribe()
236    }
237
238    fn subscribe_with(&self, subscriber: Subscriber<T>) {
239        self.base.subscribe_with(subscriber)
240    }
241}
242
243struct BaseCallback<T>
244where
245    T: Debug + Send + Sync,
246{
247    callbacks: Mutex<HashMap<CallbackHandle, UnboundedSender<Arc<T>>>>,
248}
249
250impl<T> BaseCallback<T>
251where
252    T: Debug + Send + Sync,
253{
254    fn new() -> Self {
255        Self {
256            callbacks: Mutex::new(HashMap::new()),
257        }
258    }
259
260    fn subscribe(&self) -> Subscription<T> {
261        let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
262        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
263        let handle = CallbackHandle::new();
264        mutex.insert(handle, tx);
265        drop(mutex);
266        trace!("Added callback {} to {:?}", handle, self);
267        rx
268    }
269
270    fn subscribe_with(&self, subscriber: Subscriber<T>) {
271        let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
272        let handle = CallbackHandle::new();
273        mutex.insert(handle, subscriber);
274        drop(mutex);
275        trace!("Added callback {} to {:?}", handle, self);
276    }
277
278    fn invoke(&self, value: T) {
279        let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
280        let value = Arc::new(value);
281
282        trace!(
283            "Invoking a total of {} callbacks for {:?}",
284            mutex.len(),
285            *value
286        );
287
288        let handles_to_remove: Vec<CallbackHandle> = mutex
289            .iter()
290            .map(|(handle, callback)| {
291                BaseCallback::invoke_callback(handle, callback, value.clone())
292            })
293            .flat_map(|e| e)
294            .collect();
295
296        let total_handles = handles_to_remove.len();
297        for handle in handles_to_remove {
298            mutex.remove(&handle);
299        }
300
301        if total_handles > 0 {
302            debug!("Removed a total of {} callbacks", total_handles);
303        }
304    }
305
306    /// Try to invoke the callback for the given value.
307    /// This is a convenience method for handling dropped callbacks.
308    ///
309    /// # Returns
310    ///
311    /// It returns the callback handle if the callback has been dropped.
312    fn invoke_callback(
313        handle: &CallbackHandle,
314        callback: &UnboundedSender<Arc<T>>,
315        value: Arc<T>,
316    ) -> Option<CallbackHandle> {
317        let start_time = Instant::now();
318        if let Err(_) = callback.send(value) {
319            trace!("Callback {} has been dropped", handle);
320            return Some(handle.clone());
321        }
322        let elapsed = start_time.elapsed();
323        let message = format!(
324            "Callback {} took {}.{:03}ms to process the invocation",
325            handle,
326            elapsed.as_millis(),
327            elapsed.subsec_micros() % 1000
328        );
329        if elapsed.as_millis() >= 1000 {
330            warn!("{}", message);
331        } else {
332            trace!("{}", message);
333        }
334
335        None
336    }
337}
338
339impl<T> Debug for BaseCallback<T>
340where
341    T: Debug + Send + Sync,
342{
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        f.debug_struct("BaseCallback")
345            .field("callbacks", &self.callbacks.lock().unwrap().len())
346            .finish()
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::init_logger;
    use std::sync::mpsc::channel;
    use std::time::Duration;
    use tokio::{select, time};

    /// The maximum time to wait for an event that is delivered asynchronously.
    /// Kept generous on purpose, as an invocation might first need to create its own runtime.
    const TIMEOUT: Duration = Duration::from_millis(500);

    #[derive(Debug, Clone, PartialEq)]
    pub enum Event {
        Foo,
    }

    #[derive(Debug, PartialEq)]
    enum NoneCloneEvent {
        Bar,
    }

    #[tokio::test]
    async fn test_multi_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = select! {
            _ = time::sleep(TIMEOUT) => {
                panic!("Callback invocation receiver timed out")
            },
            Some(result) = rx.recv() => result,
        };

        assert_eq!(expected_result, *result);
    }

    #[test]
    fn test_multi_threaded_invoke_without_runtime() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, rx) = channel();
        let runtime = Runtime::new().unwrap();
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }

    #[tokio::test]
    async fn test_invoke_dropped_receiver() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        let _ = callback.subscribe();
        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = select! {
            _ = time::sleep(TIMEOUT) => {
                panic!("Callback invocation receiver timed out")
            },
            Some(result) = rx.recv() => result,
        };

        assert_eq!(expected_result, *result);
    }

    #[test]
    fn test_invoke_removes_dropped_subscriptions() {
        init_logger!();
        let callback = SingleThreadedCallback::<Event>::new();

        drop(callback.subscribe());
        let _receiver = callback.subscribe();
        // the single threaded invocation is synchronous, so the cleanup is done once it returns
        callback.invoke(Event::Foo);

        let total_callbacks = callback.base.callbacks.lock().unwrap().len();
        assert_eq!(1, total_callbacks);
    }

    #[test]
    fn test_subscribe_with() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let callback = SingleThreadedCallback::<Event>::new();

        callback.subscribe_with(tx);
        callback.invoke(expected_result.clone());
        let result = rx
            .try_recv()
            .expect("expected the event to have been delivered");

        assert_eq!(expected_result, *result);
    }

    #[tokio::test]
    async fn test_non_cloneable_type() {
        init_logger!();
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<NoneCloneEvent>::new();

        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(NoneCloneEvent::Bar);
        let result = select! {
            _ = time::sleep(TIMEOUT) => {
                panic!("Callback invocation receiver timed out")
            },
            Some(result) = rx.recv() => result,
        };

        assert_eq!(NoneCloneEvent::Bar, *result);
    }

    #[test]
    fn test_single_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let runtime = Runtime::new().unwrap();
        let (tx, rx) = channel();
        let callback = SingleThreadedCallback::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }
}