Skip to main content

fx_callback/
callback.rs

1use std::fmt::Debug;
2use std::sync::Arc;
3use tokio::sync::broadcast;
4use tokio::sync::broadcast::{Receiver, Sender};
5
6/// The subscription type for the interested event.
7/// Drop this subscription to remove the callback.
8pub type Subscription<T> = Receiver<Arc<T>>;
9
10/// Allows adding callbacks to the struct.
11/// The struct will inform the [Subscription] when a certain event occurs.
12///
13/// # Example
14///
15/// ```rust,no_run
16/// use std::sync::Arc;
17/// use tokio::runtime::Runtime;
18/// use fx_callback::{Callback, MultiThreadedCallback};
19///
20/// #[derive(Debug)]
21/// pub enum MyEvent {
22///     Foo,
23///     Bar,
24/// }
25///
26/// async fn register_callback() {
27///     let callback = MultiThreadedCallback::<MyEvent>::new();
28///     let mut receiver = callback.subscribe();
29///
30///     let event = receiver.recv().await.unwrap();
31///     // do something with the event
32/// }
33/// ```
34pub trait Callback<T>: Debug
35where
36    T: Debug,
37{
38    /// Subscribe to the interested event.
39    /// This creates a new [Subscription] that will be invoked with a shared instance of the event when the interested event occurs.
40    ///
41    /// # Example
42    ///
43    /// ```rust,no_run
44    /// use fx_callback::Callback;
45    ///
46    /// #[derive(Debug, Clone, PartialEq)]
47    /// pub enum MyEvent {
48    ///     Foo,
49    /// }
50    ///
51    /// async fn example(callback: &dyn Callback<MyEvent>) {
52    ///     let mut receiver = callback.subscribe();
53    ///     
54    ///     if let Some(event) = receiver.recv().await {
55    ///         // do something with the event
56    ///     }
57    /// }
58    ///
59    /// ```
60    ///
61    /// # Returns
62    ///
63    /// It returns a [Subscription] which can be dropped to remove the callback.
64    fn subscribe(&self) -> Subscription<T>;
65}
66
67/// A multithreaded callback holder.
68///
69/// This callback holder will invoke the given events on a separate thread, thus unblocking the caller thread for other tasks.
70///
71/// # Example
72///
73/// ```rust,no_run
74/// use fx_callback::{Callback, MultiThreadedCallback, Subscription};
75///
76/// /// The events of the struct that informs subscribers about changes to the data within the struct.
77/// #[derive(Debug, Clone, PartialEq)]
78/// enum MyEvent {
79///     Foo,
80/// }
81///
82/// /// The struct to which an interested subscriber can subscribe to.
83/// #[derive(Debug)]
84/// struct Example {
85///     callbacks: MultiThreadedCallback<MyEvent>,
86/// }
87///
88/// impl Example {
89///     fn invoke_event(&self) {
90///         self.callbacks.invoke(MyEvent::Foo);
91///     }
92/// }
93///
94/// impl Callback<MyEvent> for Example {
95///     fn subscribe(&self) -> Subscription<MyEvent> {
96///         self.callbacks.subscribe()
97///     }
98/// }
99/// ```
100#[derive(Debug)]
101pub struct MultiThreadedCallback<T>
102where
103    T: Debug,
104{
105    sender: Sender<Arc<T>>,
106}
107
108impl<T> Callback<T> for MultiThreadedCallback<T>
109where
110    T: Debug,
111{
112    fn subscribe(&self) -> Subscription<T> {
113        self.sender.subscribe()
114    }
115}
116
117impl<T> Clone for MultiThreadedCallback<T>
118where
119    T: Debug,
120{
121    fn clone(&self) -> Self {
122        Self {
123            sender: self.sender.clone(),
124        }
125    }
126}
127
128impl<T> MultiThreadedCallback<T>
129where
130    T: Debug + Send + Sync + 'static,
131{
132    /// Creates a new multithreaded callback.
133    pub fn new() -> Self {
134        Self::new_with_capacity(512)
135    }
136
137    /// Creates a new multithreaded callback with a specified capacity.
138    pub fn new_with_capacity(capacity: usize) -> Self {
139        let (sender, _) = broadcast::channel(capacity);
140        Self { sender }
141    }
142
143    /// Invoke the currently registered callbacks and inform them of the given value.
144    ///
145    /// # Arguments
146    ///
147    /// * `value` - The value to invoke the callbacks with.
148    pub fn invoke(&self, value: T) {
149        let _ = self.sender.send(Arc::new(value));
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::sync::mpsc;
    use tokio::time;

    /// Maximum time a test waits for an invoked event to reach its subscriber.
    const RECV_TIMEOUT: Duration = Duration::from_millis(150);

    #[derive(Debug, Clone, PartialEq)]
    pub enum Event {
        Foo,
    }

    #[derive(Debug, PartialEq)]
    enum NoneCloneEvent {
        Bar,
    }

    /// Spawns a task that waits for the first event on `receiver` and forwards it to the returned
    /// channel, proving the event can be consumed from another task.
    fn forward_first_event<E>(mut receiver: Subscription<E>) -> mpsc::Receiver<Arc<E>>
    where
        E: Send + Sync + 'static,
    {
        let (tx, rx) = mpsc::channel(1);
        tokio::spawn(async move {
            if let Ok(event) = receiver.recv().await {
                let _ = tx.send(event).await;
            }
        });
        rx
    }

    /// Waits at most `RECV_TIMEOUT` for the forwarded event and panics if it does not arrive.
    async fn expect_event<E>(rx: &mut mpsc::Receiver<Arc<E>>) -> Arc<E> {
        time::timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("Callback invocation receiver timed out")
            .expect("Callback forwarding task stopped without an event")
    }

    #[tokio::test]
    async fn test_multi_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let callback = MultiThreadedCallback::<Event>::new();

        let mut rx = forward_first_event(callback.subscribe());

        callback.invoke(expected_result.clone());
        let result = expect_event(&mut rx).await;

        assert_eq!(expected_result, *result);
    }

    #[tokio::test]
    async fn test_invoke_dropped_receiver() {
        init_logger!();
        let expected_result = Event::Foo;
        let callback = MultiThreadedCallback::<Event>::new();

        // A subscription that is dropped right away must not affect the live one.
        drop(callback.subscribe());
        let mut rx = forward_first_event(callback.subscribe());

        callback.invoke(expected_result.clone());
        let result = expect_event(&mut rx).await;

        assert_eq!(expected_result, *result);
    }

    #[tokio::test]
    async fn test_non_cloneable_type() {
        init_logger!();
        let callback = MultiThreadedCallback::<NoneCloneEvent>::new();

        let mut rx = forward_first_event(callback.subscribe());

        callback.invoke(NoneCloneEvent::Bar);
        let result = expect_event(&mut rx).await;

        assert_eq!(NoneCloneEvent::Bar, *result);
    }

    #[tokio::test]
    async fn test_invoke_without_subscribers() {
        init_logger!();
        let callback = MultiThreadedCallback::<Event>::new();

        // Nobody is listening: the event is discarded and the call must not panic.
        callback.invoke(Event::Foo);
    }

    #[tokio::test]
    async fn test_subscribers_share_event_instance() {
        init_logger!();
        let callback = MultiThreadedCallback::<NoneCloneEvent>::new();
        let mut first = callback.subscribe();
        let mut second = callback.subscribe();

        callback.invoke(NoneCloneEvent::Bar);
        let first_event = first
            .try_recv()
            .expect("expected the first subscription to receive the event");
        let second_event = second
            .try_recv()
            .expect("expected the second subscription to receive the event");

        assert_eq!(NoneCloneEvent::Bar, *first_event);
        assert!(Arc::ptr_eq(&first_event, &second_event));
    }
}