1use fx_handle::Handle;
2use log::{debug, error, trace, warn};
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::sync::{Arc, Mutex};
6use std::time::Instant;
7use tokio::runtime::Runtime;
8use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
9
10pub type CallbackHandle = Handle;
12
13pub type Subscription<T> = UnboundedReceiver<Arc<T>>;
16
17pub type Subscriber<T> = UnboundedSender<Arc<T>>;
20
21pub trait Callback<T>: Debug
46where
47 T: Debug + Send + Sync,
48{
49 fn subscribe(&self) -> Subscription<T>;
76
77 fn subscribe_with(&self, subscriber: Subscriber<T>);
85}
86
87#[derive(Debug)]
125pub struct MultiThreadedCallback<T>
126where
127 T: Debug + Send + Sync,
128{
129 base: Arc<BaseCallback<T>>,
130 runtime: Arc<Mutex<Option<Runtime>>>,
131}
132
133impl<T> Callback<T> for MultiThreadedCallback<T>
134where
135 T: Debug + Send + Sync,
136{
137 fn subscribe(&self) -> Subscription<T> {
138 self.base.subscribe()
139 }
140
141 fn subscribe_with(&self, subscriber: Subscriber<T>) {
142 self.base.subscribe_with(subscriber)
143 }
144}
145
146impl<T> Clone for MultiThreadedCallback<T>
147where
148 T: Debug + Send + Sync,
149{
150 fn clone(&self) -> Self {
151 Self {
152 base: self.base.clone(),
153 runtime: self.runtime.clone(),
154 }
155 }
156}
157
158impl<T> MultiThreadedCallback<T>
159where
160 T: Debug + Send + Sync + 'static,
161{
162 pub fn new() -> Self {
164 Self {
165 base: Arc::new(BaseCallback::<T>::new()),
166 runtime: Arc::new(Mutex::new(None)),
167 }
168 }
169
170 pub fn invoke(&self, value: T) {
176 let inner = self.base.clone();
177 match tokio::runtime::Handle::try_current() {
178 Ok(_) => {
179 tokio::spawn(async move {
181 inner.invoke(value);
182 });
183 }
184 Err(_) => match self.runtime.lock() {
185 Ok(mut runtime) => {
186 runtime
187 .get_or_insert_with(|| Runtime::new().unwrap())
188 .spawn(async move {
189 inner.invoke(value);
190 });
191 }
192 Err(e) => error!("Failed to acquire lock: {}", e),
193 },
194 }
195 }
196}
197
198#[derive(Debug, Clone)]
202pub struct SingleThreadedCallback<T>
203where
204 T: Debug + Send + Sync,
205{
206 base: Arc<BaseCallback<T>>,
207}
208
209impl<T> SingleThreadedCallback<T>
210where
211 T: Debug + Send + Sync,
212{
213 pub fn new() -> Self {
215 Self {
216 base: Arc::new(BaseCallback::<T>::new()),
217 }
218 }
219
220 pub fn invoke(&self, value: T) {
226 self.base.invoke(value)
227 }
228}
229
230impl<T> Callback<T> for SingleThreadedCallback<T>
231where
232 T: Debug + Send + Sync,
233{
234 fn subscribe(&self) -> Subscription<T> {
235 self.base.subscribe()
236 }
237
238 fn subscribe_with(&self, subscriber: Subscriber<T>) {
239 self.base.subscribe_with(subscriber)
240 }
241}
242
243struct BaseCallback<T>
244where
245 T: Debug + Send + Sync,
246{
247 callbacks: Mutex<HashMap<CallbackHandle, UnboundedSender<Arc<T>>>>,
248}
249
250impl<T> BaseCallback<T>
251where
252 T: Debug + Send + Sync,
253{
254 fn new() -> Self {
255 Self {
256 callbacks: Mutex::new(HashMap::new()),
257 }
258 }
259
260 fn subscribe(&self) -> Subscription<T> {
261 let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
262 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
263 let handle = CallbackHandle::new();
264 mutex.insert(handle, tx);
265 drop(mutex);
266 trace!("Added callback {} to {:?}", handle, self);
267 rx
268 }
269
270 fn subscribe_with(&self, subscriber: Subscriber<T>) {
271 let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
272 let handle = CallbackHandle::new();
273 mutex.insert(handle, subscriber);
274 drop(mutex);
275 trace!("Added callback {} to {:?}", handle, self);
276 }
277
278 fn invoke(&self, value: T) {
279 let mut mutex = self.callbacks.lock().expect("failed to acquire lock");
280 let value = Arc::new(value);
281
282 trace!(
283 "Invoking a total of {} callbacks for {:?}",
284 mutex.len(),
285 *value
286 );
287
288 let handles_to_remove: Vec<CallbackHandle> = mutex
289 .iter()
290 .map(|(handle, callback)| {
291 BaseCallback::invoke_callback(handle, callback, value.clone())
292 })
293 .flat_map(|e| e)
294 .collect();
295
296 let total_handles = handles_to_remove.len();
297 for handle in handles_to_remove {
298 mutex.remove(&handle);
299 }
300
301 if total_handles > 0 {
302 debug!("Removed a total of {} callbacks", total_handles);
303 }
304 }
305
306 fn invoke_callback(
313 handle: &CallbackHandle,
314 callback: &UnboundedSender<Arc<T>>,
315 value: Arc<T>,
316 ) -> Option<CallbackHandle> {
317 let start_time = Instant::now();
318 if let Err(_) = callback.send(value) {
319 trace!("Callback {} has been dropped", handle);
320 return Some(handle.clone());
321 }
322 let elapsed = start_time.elapsed();
323 let message = format!(
324 "Callback {} took {}.{:03}ms to process the invocation",
325 handle,
326 elapsed.as_millis(),
327 elapsed.subsec_micros() % 1000
328 );
329 if elapsed.as_millis() >= 1000 {
330 warn!("{}", message);
331 } else {
332 trace!("{}", message);
333 }
334
335 None
336 }
337}
338
339impl<T> Debug for BaseCallback<T>
340where
341 T: Debug + Send + Sync,
342{
343 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 f.debug_struct("BaseCallback")
345 .field("callbacks", &self.callbacks.lock().unwrap().len())
346 .finish()
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::init_logger;
    use std::sync::mpsc::channel;
    use std::time::Duration;
    use tokio::time;

    /// Maximum time to wait for a delivered value. Generous on purpose so slow CI machines
    /// don't cause spurious failures; passing tests never wait this long.
    const TIMEOUT: Duration = Duration::from_millis(500);

    #[derive(Debug, Clone, PartialEq)]
    pub enum Event {
        Foo,
    }

    #[derive(Debug, PartialEq)]
    enum NoneCloneEvent {
        Bar,
    }

    /// Awaits the next value on `rx`, failing the test if none arrives within [`TIMEOUT`].
    async fn recv_within_timeout<V>(rx: &mut tokio::sync::mpsc::Receiver<V>) -> V {
        time::timeout(TIMEOUT, rx.recv())
            .await
            .expect("Callback invocation receiver timed out")
            .expect("Callback invocation receiver was closed")
    }

    #[tokio::test]
    async fn test_multi_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = recv_within_timeout(&mut rx).await;

        assert_eq!(expected_result, *result);
    }

    #[test]
    fn test_multi_threaded_invoke_without_runtime() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, rx) = channel();
        let runtime = Runtime::new().unwrap();
        let callback = MultiThreadedCallback::<Event>::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }

    #[tokio::test]
    async fn test_invoke_dropped_receiver() {
        init_logger!();
        let expected_result = Event::Foo;
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<Event>::new();

        // This subscription's receiver is dropped immediately.
        let _ = callback.subscribe();
        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(expected_result.clone());
        let result = recv_within_timeout(&mut rx).await;

        assert_eq!(expected_result, *result);
        // The value was sent while the registry lock was held, and the dropped subscriber is
        // removed under that same lock, so acquiring it here observes the completed cleanup.
        assert_eq!(
            1,
            callback.base.callbacks.lock().unwrap().len(),
            "the dropped subscriber should have been removed"
        );
    }

    #[tokio::test]
    async fn test_non_cloneable_type() {
        init_logger!();
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);
        let callback = MultiThreadedCallback::<NoneCloneEvent>::new();

        let mut receiver = callback.subscribe();
        tokio::spawn(async move {
            if let Some(e) = receiver.recv().await {
                let _ = tx.send(e).await;
            }
        });

        callback.invoke(NoneCloneEvent::Bar);
        let result = recv_within_timeout(&mut rx).await;

        assert_eq!(NoneCloneEvent::Bar, *result);
    }

    #[test]
    fn test_single_threaded_invoke() {
        init_logger!();
        let expected_result = Event::Foo;
        let runtime = Runtime::new().unwrap();
        let (tx, rx) = channel();
        let callback = SingleThreadedCallback::new();

        let mut receiver = callback.subscribe();
        runtime.spawn(async move {
            if let Some(e) = receiver.recv().await {
                tx.send(e).unwrap();
            }
        });

        callback.invoke(expected_result.clone());
        let result = rx.recv_timeout(TIMEOUT).unwrap();

        assert_eq!(expected_result, *result);
    }
}