1use fx_handle::Handle;
2use log::{debug, error, trace, warn};
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::sync::{Arc, Mutex};
6use std::time::Instant;
7use tokio::runtime::Runtime;
8use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
9
/// Identifier generated for each subscriber when it is registered.
pub type CallbackHandle = Handle;

/// Receiving end of a subscription; yields every invoked value as a shared `Arc<T>`.
pub type Subscription<T> = UnboundedReceiver<Arc<T>>;

/// Sending end of a subscription, as accepted by `subscribe_with`.
pub type Subscriber<T> = UnboundedSender<Arc<T>>;
20
21pub trait Callback<T>: Debug
46where
47 T: Debug + Send + Sync,
48{
49 fn subscribe(&self) -> Subscription<T>;
76
77 fn subscribe_with(&self, subscriber: Subscriber<T>);
85}
86
87#[derive(Debug, Clone)]
125pub struct MultiThreadedCallback<T>
126where
127 T: Debug + Send + Sync,
128{
129 base: Arc<BaseCallback<T>>,
130 runtime: Arc<Mutex<Option<Runtime>>>,
131}
132
133impl<T> Callback<T> for MultiThreadedCallback<T>
134where
135 T: Debug + Send + Sync,
136{
137 fn subscribe(&self) -> Subscription<T> {
138 self.base.subscribe()
139 }
140
141 fn subscribe_with(&self, subscriber: Subscriber<T>) {
142 self.base.subscribe_with(subscriber)
143 }
144}
145
146impl<T> MultiThreadedCallback<T>
147where
148 T: Debug + Send + Sync + 'static,
149{
150 pub fn new() -> Self {
152 Self {
153 base: Arc::new(BaseCallback::<T>::new()),
154 runtime: Arc::new(Mutex::new(None)),
155 }
156 }
157
158 pub fn invoke(&self, value: T) {
164 let inner = self.base.clone();
165 match tokio::runtime::Handle::try_current() {
166 Ok(_) => {
167 tokio::spawn(async move {
169 inner.invoke(value);
170 });
171 }
172 Err(_) => match self.runtime.lock() {
173 Ok(mut runtime) => {
174 runtime
175 .get_or_insert_with(|| Runtime::new().unwrap())
176 .spawn(async move {
177 inner.invoke(value);
178 });
179 }
180 Err(e) => error!("Failed to acquire lock: {}", e),
181 },
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
190pub struct SingleThreadedCallback<T>
191where
192 T: Debug + Send + Sync,
193{
194 base: Arc<BaseCallback<T>>,
195}
196
197impl<T> SingleThreadedCallback<T>
198where
199 T: Debug + Send + Sync,
200{
201 pub fn new() -> Self {
203 Self {
204 base: Arc::new(BaseCallback::<T>::new()),
205 }
206 }
207
208 pub fn invoke(&self, value: T) {
214 self.base.invoke(value)
215 }
216}
217
218impl<T> Callback<T> for SingleThreadedCallback<T>
219where
220 T: Debug + Send + Sync,
221{
222 fn subscribe(&self) -> Subscription<T> {
223 self.base.subscribe()
224 }
225
226 fn subscribe_with(&self, subscriber: Subscriber<T>) {
227 self.base.subscribe_with(subscriber)
228 }
229}
230
231struct BaseCallback<T>
232where
233 T: Debug + Send + Sync,
234{
235 callbacks: Mutex<HashMap<CallbackHandle, UnboundedSender<Arc<T>>>>,
236}
237
238impl<T> BaseCallback<T>
239where
240 T: Debug + Send + Sync,
241{
242 fn new() -> Self {
243 Self {
244 callbacks: Mutex::new(HashMap::new()),
245 }
246 }
247
248 fn subscribe(&self) -> Subscription<T> {
249 let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
250 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
251 let handle = CallbackHandle::new();
252 mutex.insert(handle, tx);
253 drop(mutex);
254 trace!("Added callback {} to {:?}", handle, self);
255 rx
256 }
257
258 fn subscribe_with(&self, subscriber: Subscriber<T>) {
259 let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
260 let handle = CallbackHandle::new();
261 mutex.insert(handle, subscriber);
262 drop(mutex);
263 trace!("Added callback {} to {:?}", handle, self);
264 }
265
266 fn invoke(&self, value: T) {
267 let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
268 let value = Arc::new(value);
269
270 trace!(
271 "Invoking a total of {} callbacks for {:?}",
272 mutex.len(),
273 *value
274 );
275
276 let handles_to_remove: Vec<CallbackHandle> = mutex
277 .iter()
278 .map(|(handle, callback)| {
279 BaseCallback::invoke_callback(handle, callback, value.clone())
280 })
281 .flat_map(|e| e)
282 .collect();
283
284 let total_handles = handles_to_remove.len();
285 for handle in handles_to_remove {
286 mutex.remove(&handle);
287 }
288
289 if total_handles > 0 {
290 debug!("Removed a total of {} callbacks", total_handles);
291 }
292 }
293
294 fn invoke_callback(
301 handle: &CallbackHandle,
302 callback: &UnboundedSender<Arc<T>>,
303 value: Arc<T>,
304 ) -> Option<CallbackHandle> {
305 let start_time = Instant::now();
306 if let Err(_) = callback.send(value) {
307 trace!("Callback {} has been dropped", handle);
308 return Some(handle.clone());
309 }
310 let elapsed = start_time.elapsed();
311 let message = format!(
312 "Callback {} took {}.{:03}ms to process the invocation",
313 handle,
314 elapsed.as_millis(),
315 elapsed.subsec_micros() % 1000
316 );
317 if elapsed.as_millis() >= 1000 {
318 warn!("{}", message);
319 } else {
320 trace!("{}", message);
321 }
322
323 None
324 }
325}
326
327impl<T> Debug for BaseCallback<T>
328where
329 T: Debug + Send + Sync,
330{
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 f.debug_struct("BaseCallback")
333 .field("callbacks", &self.callbacks.lock().unwrap().len())
334 .finish()
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::init_logger;
    use std::sync::mpsc::channel;
    use std::time::Duration;
    use tokio::{select, time};

    /// Upper bound on how long a test waits for a delivery. It is generous so
    /// that scheduling latency on a loaded CI machine does not cause spurious
    /// failures; passing runs finish far sooner.
    const RECEIVE_TIMEOUT: Duration = Duration::from_millis(500);

    #[derive(Debug, Clone, PartialEq)]
    pub enum Event {
        Foo,
    }

    #[tokio::test]
    async fn test_multi_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = select! {
            _ = time::sleep(RECEIVE_TIMEOUT) => {
                panic!("Callback invocation receiver timed out")
            },
            Some(result) = rx.recv() => result,
        };

        assert_eq!(expected_result, *result);
    }

    #[test]
    fn test_multi_threaded_invoke_without_runtime() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, rx) = channel();
        let runtime = Runtime::new().unwrap();
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        // Called outside any runtime context, so the callback uses its own
        // lazily created runtime.
        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(RECEIVE_TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }

    #[tokio::test]
    async fn test_invoke_dropped_receiver() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        // `let _ =` drops the receiver immediately, leaving a dead subscriber.
        let _ = callback.subscribe();
        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = select! {
            _ = time::sleep(RECEIVE_TIMEOUT) => {
                panic!("Callback invocation receiver timed out")
            },
            Some(result) = rx.recv() => result,
        };

        assert_eq!(expected_result, *result);
        // The value above was sent while `invoke` held the registry lock.
        // Acquiring the lock here therefore waits until the dead subscriber
        // has been pruned.
        assert_eq!(1, callback.base.callbacks.lock().unwrap().len());
    }

    #[test]
    fn test_single_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let runtime = Runtime::new().unwrap();
        let (tx, rx) = channel();
        let callback = SingleThreadedCallback::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(RECEIVE_TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }
}