mediator/default_impls/
mediator_impl.rs

1#![allow(irrefutable_let_patterns)]
2
3use crate::error::{Error, ErrorKind};
4use crate::{Event, EventHandler, Mediator, Request, RequestHandler};
5use std::any::{Any, TypeId};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9#[cfg(feature = "interceptors")]
10use crate::Interceptor;
11
12#[cfg(feature = "streams")]
13use {
14    crate::futures::Stream,
15    crate::{StreamRequest, StreamRequestHandler},
16};
17
18#[cfg(all(feature = "interceptors", feature = "streams"))]
19use crate::StreamInterceptor;
20
/// A handler registry: a map from a message's `TypeId` to its handler entry `H`,
/// wrapped in `Arc<Mutex<..>>` so it can be shared and mutated through clones.
type SharedHandler<H> = Arc<Mutex<HashMap<TypeId, H>>>;
22
/// A type-erased wrapper around a request handler that handles the request and
/// returns the result.
///
/// To provide type safety without unsafe code everything is boxed as `dyn Any`:
/// the function, the params and the result.
#[derive(Clone)]
struct RequestHandlerWrapper {
    /// The erased handler. It receives the boxed request (or a boxed
    /// `(request, mediator)` tuple when `is_deferred` is set) and returns the
    /// boxed response.
    #[allow(clippy::type_complexity)]
    handler: Arc<Mutex<dyn FnMut(Box<dyn Any>) -> Box<dyn Any> + Send>>,
    /// `true` when the handler expects a `DefaultMediator` alongside the request.
    is_deferred: bool,
}
31
32impl RequestHandlerWrapper {
33    pub fn new<Req, Res, H>(mut handler: H) -> Self
34    where
35        Res: 'static,
36        Req: Request<Res> + 'static,
37        H: RequestHandler<Req, Res> + Send + 'static,
38    {
39        let f = move |req: Box<dyn Any>| -> Box<dyn Any> {
40            let req = *req.downcast::<Req>().unwrap();
41            Box::new(handler.handle(req))
42        };
43
44        RequestHandlerWrapper {
45            handler: Arc::new(Mutex::new(f)),
46            is_deferred: false,
47        }
48    }
49
50    pub fn from_fn<Req, Res, F>(mut handler: F) -> Self
51    where
52        Res: 'static,
53        Req: Request<Res> + 'static,
54        F: FnMut(Req) -> Res + Send + 'static,
55    {
56        let f = move |req: Box<dyn Any>| -> Box<dyn Any> {
57            let req = *req.downcast::<Req>().unwrap();
58            Box::new(handler(req))
59        };
60
61        RequestHandlerWrapper {
62            handler: Arc::new(Mutex::new(f)),
63            is_deferred: false,
64        }
65    }
66
67    pub fn from_deferred<Res, Req, F>(mut handler: F) -> Self
68    where
69        Res: 'static,
70        Req: Request<Res> + 'static,
71        F: FnMut(Req, DefaultMediator) -> Res + Send + 'static,
72    {
73        let f = move |args: Box<dyn Any>| -> Box<dyn Any> {
74            let (req, mediator) = *args.downcast::<(Req, DefaultMediator)>().unwrap();
75            Box::new(handler(req, mediator))
76        };
77
78        RequestHandlerWrapper {
79            handler: Arc::new(Mutex::new(f)),
80            is_deferred: true,
81        }
82    }
83
84    pub fn handle<Req, Res>(&mut self, req: Req, mediator: Option<DefaultMediator>) -> Option<Res>
85    where
86        Res: 'static,
87        Req: Request<Res> + 'static,
88    {
89        let mut handler = self.handler.lock().unwrap();
90        let req: Box<dyn Any> = match mediator {
91            Some(mediator) => Box::new((req, mediator)),
92            None => Box::new(req),
93        };
94
95        let res = (handler)(req);
96        res.downcast::<Res>().map(|res| *res).ok()
97    }
98}
99
/// A type-erased wrapper around an event handler.
///
/// To provide type safety without unsafe code everything is boxed as `dyn Any`:
/// the function and the params.
#[derive(Clone)]
struct EventHandlerWrapper {
    /// The erased handler. It receives the boxed event (or a boxed
    /// `(event, mediator)` tuple when `is_deferred` is set).
    #[allow(clippy::type_complexity)]
    handler: Arc<Mutex<dyn FnMut(Box<dyn Any>) + Send>>,
    /// `true` when the handler expects a `DefaultMediator` alongside the event.
    is_deferred: bool,
}
108
109impl EventHandlerWrapper {
110    pub fn new<E, H>(mut handler: H) -> Self
111    where
112        E: Event + 'static,
113        H: EventHandler<E> + Send + 'static,
114    {
115        let f = move |event: Box<dyn Any>| {
116            let event = *event.downcast::<E>().unwrap();
117            handler.handle(event);
118        };
119
120        EventHandlerWrapper {
121            handler: Arc::new(Mutex::new(f)),
122            is_deferred: false,
123        }
124    }
125
126    pub fn from_fn<E, F>(mut handler: F) -> Self
127    where
128        E: Event + 'static,
129        F: FnMut(E) + Send + 'static,
130    {
131        let f = move |event: Box<dyn Any>| {
132            let event = *event.downcast::<E>().unwrap();
133            handler(event);
134        };
135
136        EventHandlerWrapper {
137            handler: Arc::new(Mutex::new(f)),
138            is_deferred: false,
139        }
140    }
141
142    pub fn from_deferred<E, F>(mut handler: F) -> Self
143    where
144        E: Event + 'static,
145        F: FnMut(E, DefaultMediator) + Send + 'static,
146    {
147        let f = move |args: Box<dyn Any>| {
148            let (event, mediator) = *args.downcast::<(E, DefaultMediator)>().unwrap();
149            handler(event, mediator);
150        };
151
152        EventHandlerWrapper {
153            handler: Arc::new(Mutex::new(f)),
154            is_deferred: true,
155        }
156    }
157
158    pub fn handle<E>(&mut self, event: E, mediator: Option<DefaultMediator>)
159    where
160        E: Event + 'static,
161    {
162        let mut handler = self.handler.lock().unwrap();
163        let event: Box<dyn Any> = match mediator {
164            Some(mediator) => Box::new((event, mediator)),
165            None => Box::new(event),
166        };
167
168        (handler)(event);
169    }
170}
171
/// A type-erased wrapper around a stream request handler.
///
/// To provide type safety without unsafe code everything is boxed as `dyn Any`:
/// the function, the params and the result.
#[derive(Clone)]
#[cfg(feature = "streams")]
struct StreamRequestHandlerWrapper {
    /// The erased handler. It receives the boxed request (or a boxed
    /// `(request, mediator)` tuple when `is_deferred` is set) and returns the
    /// boxed stream.
    #[allow(clippy::type_complexity)]
    handler: Arc<Mutex<dyn FnMut(Box<dyn Any>) -> Box<dyn Any> + Send>>,
    /// `true` when the handler expects a `DefaultMediator` alongside the request.
    is_deferred: bool,
}
181
#[cfg(feature = "streams")]
impl StreamRequestHandlerWrapper {
    /// Wraps a [`StreamRequestHandler`] implementation in a type-erased closure.
    ///
    /// The resulting wrapper is not deferred: it expects the bare request.
    pub fn new<Req, S, T, H>(mut handler: H) -> Self
    where
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        H: StreamRequestHandler<Request = Req, Stream = S, Item = T> + Send + 'static,
        S: Stream<Item = T> + 'static,
        T: 'static,
    {
        let erased = move |payload: Box<dyn Any>| -> Box<dyn Any> {
            // Panics if the payload is not a `Req`.
            let request: Box<Req> = payload.downcast().unwrap();
            Box::new(handler.handle_stream(*request))
        };

        Self {
            is_deferred: false,
            handler: Arc::new(Mutex::new(erased)),
        }
    }

    /// Wraps a plain function or closure `Req -> S` in a type-erased closure.
    ///
    /// The resulting wrapper is not deferred: it expects the bare request.
    pub fn from_fn<Req, S, T, F>(mut handler: F) -> Self
    where
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        S: Stream<Item = T> + 'static,
        F: FnMut(Req) -> S + Send + 'static,
        T: 'static,
    {
        let erased = move |payload: Box<dyn Any>| -> Box<dyn Any> {
            // Panics if the payload is not a `Req`.
            let request: Box<Req> = payload.downcast().unwrap();
            Box::new(handler(*request))
        };

        Self {
            is_deferred: false,
            handler: Arc::new(Mutex::new(erased)),
        }
    }

    /// Wraps a function that receives the request plus a clone of `state`.
    ///
    /// `state` is stored in the wrapper and cloned for every invocation.
    /// The resulting wrapper is not deferred.
    pub fn from_fn_with<State, Req, S, T, F>(mut handler: F, state: State) -> Self
    where
        State: Send + Clone + 'static,
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        S: Stream<Item = T> + 'static,
        F: FnMut(Req, State) -> S + Send + 'static,
        T: 'static,
    {
        let erased = move |payload: Box<dyn Any>| -> Box<dyn Any> {
            // Panics if the payload is not a `Req`.
            let request: Box<Req> = payload.downcast().unwrap();
            Box::new(handler(*request, state.clone()))
        };

        Self {
            is_deferred: false,
            handler: Arc::new(Mutex::new(erased)),
        }
    }

    /// Wraps a function that also receives a [`DefaultMediator`].
    ///
    /// The resulting wrapper is deferred: it expects a boxed
    /// `(Req, DefaultMediator)` tuple as its payload.
    pub fn from_deferred<Req, S, T, F>(mut handler: F) -> Self
    where
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        S: Stream<Item = T> + 'static,
        F: FnMut(Req, DefaultMediator) -> S + Send + 'static,
        T: 'static,
    {
        let erased = move |payload: Box<dyn Any>| -> Box<dyn Any> {
            // Panics if the payload is not a `(Req, DefaultMediator)` tuple.
            let args: Box<(Req, DefaultMediator)> = payload.downcast().unwrap();
            let (request, mediator) = *args;
            Box::new(handler(request, mediator))
        };

        Self {
            is_deferred: true,
            handler: Arc::new(Mutex::new(erased)),
        }
    }

    /// Wraps a function that receives the request, a [`DefaultMediator`] and a
    /// clone of `state`.
    ///
    /// `state` is stored in the wrapper and cloned for every invocation.
    /// The resulting wrapper is deferred.
    pub fn from_deferred_with<State, Req, S, T, F>(mut handler: F, state: State) -> Self
    where
        State: Send + Clone + 'static,
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        S: Stream<Item = T> + 'static,
        F: FnMut(Req, DefaultMediator, State) -> S + Send + 'static,
        T: 'static,
    {
        let erased = move |payload: Box<dyn Any>| -> Box<dyn Any> {
            // Panics if the payload is not a `(Req, DefaultMediator)` tuple.
            let args: Box<(Req, DefaultMediator)> = payload.downcast().unwrap();
            let (request, mediator) = *args;
            Box::new(handler(request, mediator, state.clone()))
        };

        Self {
            is_deferred: true,
            handler: Arc::new(Mutex::new(erased)),
        }
    }

    /// Invokes the wrapped handler with `req`.
    ///
    /// When `mediator` is `Some`, the payload is the `(req, mediator)` tuple that
    /// deferred handlers expect; otherwise it is the bare request.
    ///
    /// Returns `None` when the handler's result is not an `S`.
    ///
    /// # Panics
    ///
    /// Panics if the handler mutex is poisoned or if the payload does not have
    /// the type the wrapped closure expects.
    pub fn handle<Req, S, T>(&mut self, req: Req, mediator: Option<DefaultMediator>) -> Option<S>
    where
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        S: Stream<Item = T> + 'static,
        T: 'static,
    {
        let mut guard = self.handler.lock().unwrap();
        let payload: Box<dyn Any> = if let Some(mediator) = mediator {
            Box::new((req, mediator))
        } else {
            Box::new(req)
        };

        let stream = (*guard)(payload);
        stream.downcast::<S>().ok().map(|boxed| *boxed)
    }
}
292
/// The type-erased "next" continuation handed to an interceptor closure.
/// The interceptor wrappers downcast it to `Box<dyn FnOnce(Req) -> Res>`
/// (or `-> S` for streams).
#[cfg(feature = "interceptors")]
type NextCallback = Box<dyn Any>;
295
/// Key under which interceptors are stored: the pair of the request type and
/// the response (or stream) type they intercept.
#[cfg(feature = "interceptors")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct InterceptorKey {
    /// `TypeId` of the request type.
    req_ty: TypeId,
    /// `TypeId` of the response type (the stream type for stream requests).
    res_ty: TypeId,
}
302
#[cfg(feature = "interceptors")]
impl InterceptorKey {
    /// Builds the key for the request type `Req` and the response type `Res`.
    pub fn of<Req: 'static, Res: 'static>() -> Self {
        let req_ty = TypeId::of::<Req>();
        let res_ty = TypeId::of::<Res>();
        Self { req_ty, res_ty }
    }
}
312
/// A type-erased interceptor.
///
/// Both variants hold a closure that receives the boxed request and the boxed
/// "next" continuation and returns the boxed result; the variant records
/// whether it was built from a request interceptor or a stream interceptor.
#[cfg(feature = "interceptors")]
#[derive(Clone)]
enum InterceptorWrapper {
    /// Interceptor for plain requests.
    Handler(Arc<Mutex<dyn FnMut(Box<dyn Any>, NextCallback) -> Box<dyn Any> + Send>>),

    /// Interceptor for stream requests.
    #[cfg(feature = "streams")]
    Stream(Arc<Mutex<dyn FnMut(Box<dyn Any>, NextCallback) -> Box<dyn Any> + Send>>),
}
321
#[cfg(feature = "interceptors")]
impl InterceptorWrapper {
    /// Wraps an [`Interceptor`] implementation as a request interceptor.
    pub fn from_handler<Req, Res, H>(mut h: H) -> Self
    where
        Res: 'static,
        Req: Request<Res> + 'static,
        H: Interceptor<Req, Res> + Send + 'static,
    {
        let erased = move |payload: Box<dyn Any>, continuation: NextCallback| -> Box<dyn Any> {
            // Both downcasts panic when the boxed values have unexpected types.
            let request: Box<Req> = payload.downcast().unwrap();
            let next: Box<Box<dyn FnOnce(Req) -> Res>> = continuation.downcast().unwrap();
            Box::new(h.handle(*request, next))
        };

        InterceptorWrapper::Handler(Arc::new(Mutex::new(erased)))
    }

    /// Wraps a function or closure as a request interceptor.
    pub fn from_handler_fn<Req, Res, F>(mut f: F) -> Self
    where
        Res: 'static,
        Req: Request<Res> + 'static,
        F: FnMut(Req, Box<dyn FnOnce(Req) -> Res>) -> Res + Send + 'static,
    {
        let erased = move |payload: Box<dyn Any>, continuation: NextCallback| -> Box<dyn Any> {
            // Both downcasts panic when the boxed values have unexpected types.
            let request: Box<Req> = payload.downcast().unwrap();
            let next: Box<Box<dyn FnOnce(Req) -> Res>> = continuation.downcast().unwrap();
            Box::new(f(*request, next))
        };

        InterceptorWrapper::Handler(Arc::new(Mutex::new(erased)))
    }

    /// Wraps a [`StreamInterceptor`] implementation as a stream interceptor.
    #[cfg(feature = "streams")]
    pub fn from_stream<Req, T, S, H>(mut h: H) -> Self
    where
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        H: StreamInterceptor<Request = Req, Stream = S, Item = T> + Send + 'static,
        S: Stream<Item = T> + 'static,
        T: 'static,
    {
        let erased = move |payload: Box<dyn Any>, continuation: NextCallback| -> Box<dyn Any> {
            // Both downcasts panic when the boxed values have unexpected types.
            let request: Box<Req> = payload.downcast().unwrap();
            let next: Box<Box<dyn FnOnce(Req) -> S>> = continuation.downcast().unwrap();
            Box::new(h.handle_stream(*request, next))
        };

        InterceptorWrapper::Stream(Arc::new(Mutex::new(erased)))
    }

    /// Wraps a function or closure as a stream interceptor.
    #[cfg(feature = "streams")]
    pub fn from_stream_fn<Req, T, S, F>(mut f: F) -> Self
    where
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        F: FnMut(Req, Box<dyn FnOnce(Req) -> S>) -> S + Send + 'static,
        S: Stream<Item = T> + 'static,
        T: 'static,
    {
        let erased = move |payload: Box<dyn Any>, continuation: NextCallback| -> Box<dyn Any> {
            // Both downcasts panic when the boxed values have unexpected types.
            let request: Box<Req> = payload.downcast().unwrap();
            let next: Box<Box<dyn FnOnce(Req) -> S>> = continuation.downcast().unwrap();
            Box::new(f(*request, next))
        };

        InterceptorWrapper::Stream(Arc::new(Mutex::new(erased)))
    }

    /// Runs a request interceptor with `req` and the `next` continuation.
    ///
    /// Returns `None` when this wrapper is not a `Handler` interceptor or when
    /// the interceptor's result is not a `Res`.
    ///
    /// # Panics
    ///
    /// Panics if the interceptor mutex is poisoned or if the wrapped closure
    /// cannot downcast its arguments.
    pub fn handle<Req, Res>(&mut self, req: Req, next: Box<dyn FnOnce(Req) -> Res>) -> Option<Res>
    where
        Req: Request<Res> + 'static,
        Res: 'static,
    {
        match self {
            InterceptorWrapper::Handler(handler) => {
                let mut guard = handler.lock().unwrap();
                let payload: Box<dyn Any> = Box::new(req);
                let continuation: Box<dyn Any> = Box::new(next);
                let outcome: Box<dyn Any> = (*guard)(payload, continuation);
                outcome.downcast::<Res>().ok().map(|boxed| *boxed)
            }
            #[cfg(feature = "streams")]
            InterceptorWrapper::Stream(_) => None,
        }
    }

    /// Runs a stream interceptor with `req` and the `next` continuation.
    ///
    /// Returns `None` when this wrapper is not a `Stream` interceptor or when
    /// the interceptor's result is not an `S`.
    ///
    /// # Panics
    ///
    /// Panics if the interceptor mutex is poisoned or if the wrapped closure
    /// cannot downcast its arguments.
    #[cfg(feature = "streams")]
    pub fn stream<Req, T, S>(&mut self, req: Req, next: Box<dyn FnOnce(Req) -> S>) -> Option<S>
    where
        Req: StreamRequest<Stream = S, Item = T> + 'static,
        S: Stream<Item = T> + 'static,
        T: 'static,
    {
        match self {
            InterceptorWrapper::Stream(handler) => {
                let mut guard = handler.lock().unwrap();
                let payload: Box<dyn Any> = Box::new(req);
                let continuation: Box<dyn Any> = Box::new(next);
                let outcome: Box<dyn Any> = (*guard)(payload, continuation);
                outcome.downcast::<S>().ok().map(|boxed| *boxed)
            }
            InterceptorWrapper::Handler(_) => None,
        }
    }
}
426
/// A default implementation for the [Mediator] trait.
///
/// All registries are stored behind `Arc<Mutex<..>>`, so cloning a
/// `DefaultMediator` yields a handle to the same set of handlers.
///
/// # Examples
///
/// ## Request handler
/// ```
/// use std::sync::atomic::AtomicU64;
/// use mediator::{DefaultMediator, Mediator, Request, RequestHandler};
///
/// struct GetNextId;
/// impl Request<u64> for GetNextId { }
///
/// struct GetNextIdHandler;
/// impl RequestHandler<GetNextId, u64> for GetNextIdHandler {
///   fn handle(&mut self, _: GetNextId) -> u64 {
///     static NEXT_ID : AtomicU64 = AtomicU64::new(1);
///     NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
///   }
/// }
///
/// let mut mediator = DefaultMediator::builder()
///     .add_handler(GetNextIdHandler)
///     .build();
///
/// assert_eq!(Ok(1), mediator.send(GetNextId));
/// assert_eq!(Ok(2), mediator.send(GetNextId));
/// assert_eq!(Ok(3), mediator.send(GetNextId));
/// ```
///
/// ## Event handler
/// ```
/// use mediator::{Event, DefaultMediator, Mediator};
///
/// #[derive(Clone)]
/// struct Product { name: String };
///
/// #[derive(Clone)]
/// struct ProductAddedEvent(Product);
/// impl Event for ProductAddedEvent { }
///
/// struct ProductService(Vec<Product>, DefaultMediator);
/// impl ProductService {
///     pub fn add<S: Into<String>>(&mut self, product: S) {
///         let product = Product { name: product.into() };
///         self.0.push(product.clone());
///         self.1.publish(ProductAddedEvent(product));
///     }
/// }
///
/// let mut mediator = DefaultMediator::builder()
///     .subscribe_fn(move |event: ProductAddedEvent| {
///         println!("Product added: {}", event.0.name);
///     })
///     .build();
///
/// let mut service = ProductService(vec![], mediator.clone());
///
/// service.add("Microwave");   // Product added: Microwave
/// service.add("Toaster");     // Product added: Toaster
/// ```
#[derive(Clone)]
pub struct DefaultMediator {
    /// Request handlers, keyed by the `TypeId` of the request type (one per type).
    request_handlers: SharedHandler<RequestHandlerWrapper>,
    /// Event handlers, keyed by the `TypeId` of the event type (any number per type).
    event_handlers: SharedHandler<Vec<EventHandlerWrapper>>,

    /// Interceptors, keyed by the (request type, response or stream type) pair.
    #[cfg(feature = "interceptors")]
    interceptors: Arc<Mutex<HashMap<InterceptorKey, Vec<InterceptorWrapper>>>>,

    /// Stream request handlers, keyed by the `TypeId` of the request type.
    #[cfg(feature = "streams")]
    stream_handlers: SharedHandler<StreamRequestHandlerWrapper>,
}
498
impl DefaultMediator {
    /// Gets a [DefaultMediator] builder.
    ///
    /// Equivalent to calling [`Builder::new`].
    pub fn builder() -> Builder {
        Builder::new()
    }
}
505
506impl Mediator for DefaultMediator {
507    fn send<Req, Res>(&mut self, req: Req) -> crate::Result<Res>
508    where
509        Res: 'static,
510        Req: Request<Res> + 'static,
511    {
512        let type_id = TypeId::of::<Req>();
513        let mut handlers_lock = self
514            .request_handlers
515            .try_lock()
516            .expect("Request handlers are locked");
517
518        if let Some(mut handler) = handlers_lock.get_mut(&type_id).cloned() {
519            // Drop the lock to avoid deadlocks
520            drop(handlers_lock);
521
522            let mediator = if handler.is_deferred {
523                Some(self.clone())
524            } else {
525                None
526            };
527
528            #[cfg(feature = "interceptors")]
529            {
530                let mut interceptors = self.interceptors.lock().expect("Interceptors are locked");
531
532                let key = InterceptorKey::of::<Req, Res>();
533                if let Some(interceptors) = interceptors.get_mut(&key).cloned() {
534                    let next_handler: Box<dyn FnOnce(Req) -> Res> =
535                        Box::new(move |req: Req| handler.handle(req, mediator).unwrap());
536
537                    let handler = interceptors.into_iter().fold(
538                        next_handler,
539                        move |next, mut interceptor| {
540                            let f = move |req: Req| {
541                                // SAFETY: this only fail if the downcast fails,
542                                // but we already checked that the type is correct
543                                interceptor.handle(req, next).unwrap()
544                            };
545                            Box::new(f) as Box<dyn FnOnce(Req) -> Res>
546                        },
547                    );
548
549                    let res = handler(req);
550                    return Ok(res);
551                }
552            }
553
554            if let Some(res) = handler.handle(req, mediator) {
555                return Ok(res);
556            }
557        }
558
559        Err(Error::from(ErrorKind::NotFound))
560    }
561
562    fn publish<E>(&mut self, event: E) -> crate::Result<()>
563    where
564        E: Event + 'static,
565    {
566        let type_id = TypeId::of::<E>();
567        let mut handlers_lock = self
568            .event_handlers
569            .try_lock()
570            .expect("Event handlers are locked");
571
572        // FIXME: Cloning the entire Vec may not be necessary, we could use something like Arc<Mutex<Vec<_>>>
573        if let Some(handlers) = handlers_lock.get_mut(&type_id).cloned() {
574            // Drop the lock to avoid deadlocks
575            drop(handlers_lock);
576
577            for mut handler in handlers {
578                let mediator = if handler.is_deferred {
579                    Some(self.clone())
580                } else {
581                    None
582                };
583
584                handler.handle(event.clone(), mediator);
585            }
586
587            Ok(())
588        } else {
589            Err(Error::from(ErrorKind::NotFound))
590        }
591    }
592
593    #[cfg(feature = "streams")]
594    fn stream<Req, S, T>(&mut self, req: Req) -> crate::Result<S>
595    where
596        Req: StreamRequest<Stream = S, Item = T> + 'static,
597        S: Stream<Item = T> + 'static,
598        T: 'static,
599    {
600        let type_id = TypeId::of::<Req>();
601        let mut handlers_lock = self
602            .stream_handlers
603            .try_lock()
604            .expect("Stream handlers are locked");
605
606        if let Some(mut handler) = handlers_lock.get_mut(&type_id).cloned() {
607            // Drop the lock to avoid deadlocks
608            drop(handlers_lock);
609
610            let mediator = if handler.is_deferred {
611                Some(self.clone())
612            } else {
613                None
614            };
615
616            #[cfg(feature = "interceptors")]
617            {
618                let mut interceptors = self.interceptors.lock().expect("Interceptors are locked");
619
620                let key = InterceptorKey::of::<Req, S>();
621                if let Some(interceptors) = interceptors.get_mut(&key).cloned() {
622                    let next_handler: Box<dyn FnOnce(Req) -> S> =
623                        Box::new(move |req: Req| handler.handle(req, mediator).unwrap());
624
625                    let handler = interceptors.into_iter().fold(
626                        next_handler,
627                        move |next, mut interceptor| {
628                            let f = move |req: Req| {
629                                // SAFETY: this only fail if the downcast fails,
630                                // but we already checked that the type is correct
631                                interceptor.stream(req, next).unwrap()
632                            };
633                            Box::new(f) as Box<dyn FnOnce(Req) -> S>
634                        },
635                    );
636
637                    let res = handler(req);
638                    return Ok(res);
639                }
640            }
641
642            if let Some(stream) = handler.handle(req, mediator) {
643                return Ok(stream);
644            }
645        }
646
647        Err(Error::from(ErrorKind::NotFound))
648    }
649}
650
/// A builder for the [DefaultMediator].
///
/// Every registration method takes the builder by value and returns it, so
/// calls can be chained.
pub struct Builder {
    /// The mediator being configured; registrations are written into its registries.
    inner: DefaultMediator,
}
655
656impl Builder {
657    /// Constructs a new `DefaultMediatorBuilder`.
658    pub fn new() -> Self {
659        Builder {
660            inner: DefaultMediator {
661                request_handlers: SharedHandler::default(),
662                event_handlers: SharedHandler::default(),
663
664                #[cfg(feature = "interceptors")]
665                interceptors: Default::default(),
666
667                #[cfg(feature = "streams")]
668                stream_handlers: SharedHandler::default(),
669            },
670        }
671    }
672
673    /// Registers a request handler.
674    pub fn add_handler<Req, Res, H>(self, handler: H) -> Self
675    where
676        Res: 'static,
677        Req: Request<Res> + 'static,
678        H: RequestHandler<Req, Res> + Send + 'static,
679    {
680        let mut handlers_lock = self.inner.request_handlers.lock().unwrap();
681        handlers_lock.insert(TypeId::of::<Req>(), RequestHandlerWrapper::new(handler));
682        drop(handlers_lock);
683        self
684    }
685
686    /// Registers a request handler from a function.
687    pub fn add_handler_fn<Req, Res, F>(self, handler: F) -> Self
688    where
689        Res: 'static,
690        Req: Request<Res> + 'static,
691        F: FnMut(Req) -> Res + Send + 'static,
692    {
693        let mut handlers_lock = self.inner.request_handlers.lock().unwrap();
694        handlers_lock.insert(TypeId::of::<Req>(), RequestHandlerWrapper::from_fn(handler));
695        drop(handlers_lock);
696        self
697    }
698
699    /// Register a request handler using a copy of the mediator.
700    pub fn add_handler_deferred<Req, Res, H, F>(self, f: F) -> Self
701    where
702        Res: 'static,
703        Req: Request<Res> + 'static,
704        H: RequestHandler<Req, Res> + Send + 'static,
705        F: Fn(DefaultMediator) -> H + Send,
706    {
707        let handler = f(self.inner.clone());
708        self.add_handler(handler)
709    }
710
711    /// Registers a request handler from a function using a copy of the mediator.
712    pub fn add_handler_fn_deferred<Req, Res, F>(self, f: F) -> Self
713    where
714        Res: 'static,
715        Req: Request<Res> + 'static,
716        F: FnMut(Req, DefaultMediator) -> Res + Send + 'static,
717    {
718        let mut handlers_lock = self.inner.request_handlers.lock().unwrap();
719        handlers_lock.insert(TypeId::of::<Req>(), RequestHandlerWrapper::from_deferred(f));
720        drop(handlers_lock);
721        self
722    }
723
724    /// Registers an event handler.
725    pub fn subscribe<E, H>(self, handler: H) -> Self
726    where
727        E: Event + 'static,
728        H: EventHandler<E> + Send + 'static,
729    {
730        let mut handlers_lock = self.inner.event_handlers.lock().unwrap();
731        let event_handlers = handlers_lock
732            .entry(TypeId::of::<E>())
733            .or_insert_with(Vec::new);
734        event_handlers.push(EventHandlerWrapper::new(handler));
735        drop(handlers_lock);
736        self
737    }
738
739    /// Registers an event handler from a function.
740    pub fn subscribe_fn<E, F>(self, handler: F) -> Self
741    where
742        E: Event + 'static,
743        F: FnMut(E) + Send + 'static,
744    {
745        let mut handlers_lock = self.inner.event_handlers.lock().unwrap();
746        let event_handlers = handlers_lock
747            .entry(TypeId::of::<E>())
748            .or_insert_with(Vec::new);
749        event_handlers.push(EventHandlerWrapper::from_fn(handler));
750        drop(handlers_lock);
751        self
752    }
753
754    /// Registers an event handler using a copy of the mediator.
755    pub fn subscribe_deferred<E, H, F>(self, f: F) -> Self
756    where
757        E: Event + 'static,
758        H: EventHandler<E> + Send + 'static,
759        F: Fn(DefaultMediator) -> H + Send,
760    {
761        let handler = f(self.inner.clone());
762        self.subscribe(handler)
763    }
764
765    /// Registers an event handler from a function using a copy of the mediator.
766    pub fn subscribe_fn_deferred<E, H, F>(self, f: F) -> Self
767    where
768        E: Event + 'static,
769        F: FnMut(E, DefaultMediator) + Send + 'static,
770    {
771        let mut handlers_lock = self.inner.event_handlers.lock().unwrap();
772        let event_handlers = handlers_lock
773            .entry(TypeId::of::<E>())
774            .or_insert_with(Vec::new);
775        event_handlers.push(EventHandlerWrapper::from_deferred(f));
776        drop(handlers_lock);
777        self
778    }
779
780    /// Registers a stream handler.
781    #[cfg(feature = "streams")]
782    pub fn add_stream_handler<Req, S, T, H>(self, handler: H) -> Self
783    where
784        Req: StreamRequest<Stream = S, Item = T> + 'static,
785        H: StreamRequestHandler<Request = Req, Stream = S, Item = T> + Send + 'static,
786        S: Stream<Item = T> + 'static,
787        T: 'static,
788    {
789        let mut handlers_lock = self.inner.stream_handlers.lock().unwrap();
790        handlers_lock.insert(
791            TypeId::of::<Req>(),
792            StreamRequestHandlerWrapper::new(handler),
793        );
794        drop(handlers_lock);
795        self
796    }
797
798    /// Registers a stream handler from a function.
799    #[cfg(feature = "streams")]
800    pub fn add_stream_handler_fn<Req, S, T, F>(self, f: F) -> Self
801    where
802        Req: StreamRequest<Stream = S, Item = T> + 'static,
803        F: FnMut(Req) -> S + Send + 'static,
804        S: Stream<Item = T> + 'static,
805        T: 'static,
806    {
807        let mut handlers_lock = self.inner.stream_handlers.lock().unwrap();
808        handlers_lock.insert(TypeId::of::<Req>(), StreamRequestHandlerWrapper::from_fn(f));
809        drop(handlers_lock);
810        self
811    }
812
813    #[cfg(feature = "streams")]
814    pub fn add_stream_handler_fn_with<State, Req, S, T, F>(self, state: State, f: F) -> Self
815    where
816        State: Send + Clone + 'static,
817        Req: StreamRequest<Stream = S, Item = T> + 'static,
818        F: FnMut(Req, State) -> S + Send + 'static,
819        S: Stream<Item = T> + 'static,
820        T: 'static,
821    {
822        let mut handlers_lock = self.inner.stream_handlers.lock().unwrap();
823        handlers_lock.insert(
824            TypeId::of::<Req>(),
825            StreamRequestHandlerWrapper::from_fn_with(f, state),
826        );
827        drop(handlers_lock);
828        self
829    }
830
831    /// Registers a stream handler using a copy of the mediator.
832    #[cfg(feature = "streams")]
833    pub fn add_stream_handler_deferred<Req, S, T, H, F>(self, f: F) -> Self
834    where
835        Req: StreamRequest<Stream = S, Item = T> + 'static,
836        H: StreamRequestHandler<Request = Req, Stream = S, Item = T> + Send + 'static,
837        S: Stream<Item = T> + 'static,
838        T: 'static,
839        F: Fn(DefaultMediator) -> H,
840    {
841        let handler = f(self.inner.clone());
842        self.add_stream_handler(handler)
843    }
844
845    /// Registers a stream handler from a function using a copy of the mediator.
846    #[cfg(feature = "streams")]
847    pub fn add_stream_handler_fn_deferred<Req, S, T, F>(self, f: F) -> Self
848    where
849        Req: StreamRequest<Stream = S, Item = T> + 'static,
850        F: FnMut(Req, DefaultMediator) -> S + Send + 'static,
851        S: Stream<Item = T> + 'static,
852        T: 'static,
853    {
854        let mut handlers_lock = self.inner.stream_handlers.lock().unwrap();
855        handlers_lock.insert(
856            TypeId::of::<Req>(),
857            StreamRequestHandlerWrapper::from_deferred(f),
858        );
859        drop(handlers_lock);
860        self
861    }
862
863    /// Register a stream handler from a function using a copy of the mediator and a state.
864    #[cfg(feature = "streams")]
865    pub fn add_stream_handler_fn_deferred_with<State, Req, S, T, F>(
866        self,
867        state: State,
868        f: F,
869    ) -> Self
870    where
871        State: Send + Clone + 'static,
872        Req: StreamRequest<Stream = S, Item = T> + 'static,
873        F: FnMut(Req, DefaultMediator, State) -> S + Send + 'static,
874        S: Stream<Item = T> + 'static,
875        T: 'static,
876    {
877        let mut handlers_lock = self.inner.stream_handlers.lock().unwrap();
878        handlers_lock.insert(
879            TypeId::of::<Req>(),
880            StreamRequestHandlerWrapper::from_deferred_with(f, state),
881        );
882        drop(handlers_lock);
883        self
884    }
885
886    /// Adds a request interceptor.
887    #[cfg(feature = "interceptors")]
888    pub fn add_interceptor<Req, Res, H>(self, handler: H) -> Self
889    where
890        Res: 'static,
891        Req: Request<Res> + 'static,
892        H: Interceptor<Req, Res> + Send + 'static,
893    {
894        let req_ty = TypeId::of::<Req>();
895        let res_ty = TypeId::of::<Res>();
896        let key = InterceptorKey { req_ty, res_ty };
897
898        let mut handlers_lock = self.inner.interceptors.lock().unwrap();
899        let interceptors = handlers_lock.entry(key).or_insert(Vec::new());
900        interceptors.push(InterceptorWrapper::from_handler(handler));
901        drop(handlers_lock);
902        self
903    }
904
905    /// Adds a request interceptor from a function.
906    #[cfg(feature = "interceptors")]
907    pub fn add_interceptor_fn<Req, Res, F>(self, f: F) -> Self
908    where
909        Res: 'static,
910        Req: Request<Res> + 'static,
911        F: FnMut(Req, Box<dyn FnOnce(Req) -> Res>) -> Res + Send + 'static,
912    {
913        let req_ty = TypeId::of::<Req>();
914        let res_ty = TypeId::of::<Res>();
915        let key = InterceptorKey { req_ty, res_ty };
916        let mut handlers_lock = self.inner.interceptors.lock().unwrap();
917        let interceptors = handlers_lock.entry(key).or_insert(Vec::new());
918        interceptors.push(InterceptorWrapper::from_handler_fn(f));
919        drop(handlers_lock);
920        self
921    }
922
923    /// Adds a stream request interceptor.
924    #[cfg(all(feature = "streams", feature = "interceptors"))]
925    pub fn add_interceptor_stream<Req, T, S, H>(self, handler: H) -> Self
926    where
927        Req: StreamRequest<Stream = S, Item = T> + 'static,
928        S: Stream<Item = T> + 'static,
929        T: 'static,
930        H: StreamInterceptor<Request = Req, Stream = S, Item = T> + Send + 'static,
931    {
932        let key = InterceptorKey::of::<Req, S>();
933
934        let mut handlers_lock = self.inner.interceptors.lock().unwrap();
935        let interceptors = handlers_lock.entry(key).or_insert(Vec::new());
936        interceptors.push(InterceptorWrapper::from_stream(handler));
937        drop(handlers_lock);
938        self
939    }
940
941    /// Adds a stream request interceptor from a function.
942    #[cfg(all(feature = "streams", feature = "interceptors"))]
943    pub fn add_interceptor_stream_fn<Req, T, S, F>(self, f: F) -> Self
944    where
945        Req: StreamRequest<Stream = S, Item = T> + 'static,
946        S: Stream<Item = T> + 'static,
947        T: 'static,
948        F: FnMut(Req, Box<dyn FnOnce(Req) -> S>) -> S + Send + 'static,
949    {
950        let key = InterceptorKey::of::<Req, S>();
951        let mut handlers_lock = self.inner.interceptors.lock().unwrap();
952        let interceptors = handlers_lock.entry(key).or_insert(Vec::new());
953        interceptors.push(InterceptorWrapper::from_stream_fn(f));
954        drop(handlers_lock);
955        self
956    }
957
958    /// Builds the `DefaultMediator`.
959    pub fn build(self) -> DefaultMediator {
960        self.inner
961    }
962}
963
964impl Default for Builder {
965    fn default() -> Self {
966        Builder::new()
967    }
968}
969
/// Compile-time check that `DefaultMediator` is `Send + Sync`.
/// ```rust
/// use mediator::DefaultMediator;
///
/// fn assert_send_sync<T: Send + Sync>(t: T) {
///     drop(t);
/// }
///
/// let mediator = DefaultMediator::builder().build();
/// assert_send_sync(mediator);
/// ```
#[cfg(test)]
fn _dummy() {
    // NOTE(review): this item only exists under `cfg(test)`, so the doctest above is
    // presumably not collected by rustdoc; the body below is what enforces the bound — confirm.
    fn assert_send_sync<T>(_: T)
    where
        T: Send + Sync,
    {
    }

    // Never called: type-checking this call is the whole assertion.
    let mediator = DefaultMediator::builder().build();
    assert_send_sync(mediator);
}
986
#[cfg(test)]
mod tests {
    // Every import is gated on the same features as the tests that use it, so this module
    // does not reference feature-specific items when those features are disabled.
    use crate::{DefaultMediator, Event, EventHandler, Mediator, Request, RequestHandler};
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[cfg(feature = "interceptors")]
    use crate::Interceptor;

    #[cfg(all(feature = "interceptors", feature = "streams"))]
    use crate::StreamInterceptor;

    #[cfg(feature = "streams")]
    use {
        crate::futures::BoxStream,
        crate::{box_stream, StreamRequest, StreamRequestHandler},
        std::ops::Range,
        std::sync::{Arc, Mutex},
        tokio_stream::StreamExt,
    };

    /// Polls a stream to completion, discarding every item.
    ///
    /// Must be expanded inside an `async` context because it awaits each item.
    #[cfg(feature = "streams")]
    macro_rules! drain_stream {
        ($stream:expr) => {{
            let mut stream = $stream;
            while crate::futures::StreamExt::next(&mut stream).await.is_some() {}
        }};
    }

    /// A request sent through the mediator reaches its registered handler and the handler's
    /// result is returned to the caller.
    #[test]
    fn send_request_test() {
        struct TwoTimesRequest(i64);
        impl Request<i64> for TwoTimesRequest {}

        struct TwoTimesRequestHandler;
        impl RequestHandler<TwoTimesRequest, i64> for TwoTimesRequestHandler {
            fn handle(&mut self, request: TwoTimesRequest) -> i64 {
                request.0 * 2
            }
        }

        let mut mediator = DefaultMediator::builder()
            .add_handler(TwoTimesRequestHandler)
            .build();

        assert_eq!(4, mediator.send(TwoTimesRequest(2)).unwrap());
        assert_eq!(-6, mediator.send(TwoTimesRequest(-3)).unwrap());
    }

    /// Each published event invokes the subscribed handler exactly once.
    #[test]
    fn publish_event_test() {
        #[derive(Clone)]
        struct IncrementEvent;
        impl Event for IncrementEvent {}

        // Declared inside the test function, so no other test can touch this counter.
        static COUNTER: AtomicUsize = AtomicUsize::new(0);

        struct TestEventHandler;
        impl EventHandler<IncrementEvent> for TestEventHandler {
            fn handle(&mut self, _: IncrementEvent) {
                COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }

        let mut mediator = DefaultMediator::builder()
            .subscribe(TestEventHandler)
            .build();

        mediator.publish(IncrementEvent).unwrap();
        mediator.publish(IncrementEvent).unwrap();
        assert_eq!(2, COUNTER.load(Ordering::SeqCst));

        mediator.publish(IncrementEvent).unwrap();
        assert_eq!(3, COUNTER.load(Ordering::SeqCst));
    }

    /// A stream request yields the items produced by its registered stream handler, in order.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg(feature = "streams")]
    async fn stream_test() {
        struct CounterRequest(u32);
        impl StreamRequest for CounterRequest {
            type Stream = tokio_stream::Iter<Range<u32>>;
            type Item = u32;
        }

        struct CounterRequestHandler;
        impl StreamRequestHandler for CounterRequestHandler {
            type Request = CounterRequest;
            type Stream = tokio_stream::Iter<Range<u32>>;
            type Item = u32;

            fn handle_stream(&mut self, req: CounterRequest) -> Self::Stream {
                tokio_stream::iter(0..req.0)
            }
        }

        let mut mediator = DefaultMediator::builder()
            .add_stream_handler(CounterRequestHandler)
            .build();

        let mut counter_stream = mediator.stream(CounterRequest(5)).unwrap();
        for expected in 0..5_u32 {
            assert_eq!(expected, counter_stream.next().await.unwrap());
        }
        assert!(counter_stream.next().await.is_none());
    }

    /// A deferred stream handler receives both the shared state and a mediator copy, and can
    /// publish events through that copy while the stream is being driven.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg(feature = "streams")]
    async fn stream_deferred_fn_with_test() {
        struct CounterRequest(u32);
        impl StreamRequest for CounterRequest {
            type Stream = BoxStream<'static, u32>;
            type Item = u32;
        }

        #[derive(Clone)]
        struct RequestEvent(u32);
        impl Event for RequestEvent {}

        let counter = Arc::new(Mutex::new(0_u32));
        let request_history = Arc::new(Mutex::new(Vec::<u32>::new()));
        let request_history_copy = request_history.clone();

        let mut mediator = DefaultMediator::builder()
            .add_stream_handler_fn_deferred_with(
                counter.clone(),
                |req: CounterRequest, mut mediator, c| {
                    box_stream! { _yx move =>
                        let mut c = c.lock().unwrap();

                        for _ in 0..req.0 {
                            *c += 1;
                        }

                        mediator.publish(RequestEvent(req.0)).unwrap();
                    }
                },
            )
            .subscribe_fn(move |event: RequestEvent| {
                request_history_copy.lock().unwrap().push(event.0);
            })
            .build();

        drain_stream!(mediator.stream(CounterRequest(5)).unwrap());
        drain_stream!(mediator.stream(CounterRequest(3)).unwrap());

        // 5 + 3 increments were applied to the shared counter.
        assert_eq!(8, *counter.lock().unwrap());

        // One event per request was recorded, carrying the request's count.
        let history = request_history.lock().unwrap();
        assert_eq!(8, history.iter().copied().sum::<u32>());
    }

    /// A struct-based interceptor wraps the handler and can post-process its result.
    #[test]
    #[cfg(feature = "interceptors")]
    fn interceptor_test() {
        struct SumRequest(u32, u32);
        impl Request<u64> for SumRequest {}

        struct SumRequestInterceptor;
        impl Interceptor<SumRequest, u64> for SumRequestInterceptor {
            fn handle(&mut self, req: SumRequest, next: Box<dyn FnOnce(SumRequest) -> u64>) -> u64 {
                let result = next(req);
                result * 2
            }
        }

        let mut mediator = DefaultMediator::builder()
            // Widen before adding so the sum cannot overflow `u32`.
            .add_handler(|req: SumRequest| u64::from(req.0) + u64::from(req.1))
            .add_interceptor(SumRequestInterceptor)
            .build();

        let r1 = mediator.send(SumRequest(1, 2)).unwrap();
        assert_eq!(6, r1);

        let r2 = mediator.send(SumRequest(3, 4)).unwrap();
        assert_eq!(14, r2);
    }

    /// A closure-based interceptor wraps the handler and can post-process its result.
    #[test]
    #[cfg(feature = "interceptors")]
    fn interceptor_fn_test() {
        struct SumRequest(u32, u32);
        impl Request<u64> for SumRequest {}

        let mut mediator = DefaultMediator::builder()
            // Widen before adding so the sum cannot overflow `u32`.
            .add_handler(|req: SumRequest| u64::from(req.0) + u64::from(req.1))
            .add_interceptor_fn(
                |req: SumRequest, next: Box<dyn FnOnce(SumRequest) -> u64>| {
                    let result = next(req);
                    result * 2
                },
            )
            .build();

        let r1 = mediator.send(SumRequest(1, 2)).unwrap();
        assert_eq!(6, r1);

        let r2 = mediator.send(SumRequest(3, 4)).unwrap();
        assert_eq!(14, r2);
    }

    /// A struct-based stream interceptor can transform the items of the handler's stream.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg(all(feature = "interceptors", feature = "streams"))]
    async fn interceptor_stream_test() {
        struct CountRequest(u32);
        impl StreamRequest for CountRequest {
            type Stream = crate::futures::BoxStream<'static, u32>;
            type Item = u32;
        }

        struct CountRequestInterceptor;
        impl StreamInterceptor for CountRequestInterceptor {
            type Request = CountRequest;
            type Stream = crate::futures::BoxStream<'static, u32>;
            type Item = u32;

            fn handle_stream(
                &mut self,
                req: Self::Request,
                next: Box<dyn FnOnce(Self::Request) -> Self::Stream>,
            ) -> Self::Stream {
                let result = next(req);
                Box::pin(result.map(|x| x * 2))
            }
        }

        let mut mediator = DefaultMediator::builder()
            .add_stream_handler_fn(|req: CountRequest| Box::pin(crate::futures::iter(0..req.0)))
            .add_interceptor_stream(CountRequestInterceptor)
            .build();

        let result = mediator.stream(CountRequest(5)).unwrap();
        assert_eq!(vec![0, 2, 4, 6, 8], result.collect::<Vec<_>>().await);
    }

    /// A closure-based stream interceptor can transform the items of the handler's stream.
    #[tokio::test(flavor = "multi_thread")]
    #[cfg(all(feature = "interceptors", feature = "streams"))]
    async fn interceptor_stream_fn_test() {
        struct CountRequest(u32);
        impl StreamRequest for CountRequest {
            type Stream = crate::futures::BoxStream<'static, u32>;
            type Item = u32;
        }

        let mut mediator = DefaultMediator::builder()
            .add_stream_handler_fn(|req: CountRequest| Box::pin(crate::futures::iter(1..=req.0)))
            .add_interceptor_stream_fn(|req: CountRequest, next| {
                let result = next(req);
                Box::pin(result.map(|x| x * 3))
            })
            .build();

        let result = mediator.stream(CountRequest(3)).unwrap();
        assert_eq!(vec![3, 6, 9], result.collect::<Vec<_>>().await);
    }
}