redux_rs/
middleware.rs

1use crate::{Selector, Subscriber};
2use async_trait::async_trait;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6/// The store api offers an abstraction around all store functionality.
7///
8/// Both Store and StoreWithMiddleware implement StoreApi.
9/// This enables us to wrap multiple middlewares around each other.
10#[async_trait]
11pub trait StoreApi<State, Action>
12where
13    Action: Send + 'static,
14    State: Send + 'static,
15{
16    /// Dispatch a new action to the store
17    ///
18    /// Notice that this method takes &self and not &mut self,
19    /// this enables us to dispatch actions from multiple places at once without requiring locks.
20    async fn dispatch(&self, action: Action);
21
22    /// Select a part of the state, this is more efficient than copying the entire state all the time.
23    /// In case you still need a full copy of the state, use the state_cloned method.
24    async fn select<S: Selector<State, Result = Result>, Result>(&self, selector: S) -> Result
25    where
26        S: Selector<State, Result = Result> + Send + 'static,
27        Result: Send + 'static;
28
29    /// Returns a cloned version of the state.
30    /// This is not efficient, if you only need a part of the state use select instead
31    async fn state_cloned(&self) -> State
32    where
33        State: Clone,
34    {
35        self.select(|state: &State| state.clone()).await
36    }
37
38    /// Subscribe to state changes.
39    /// Every time an action is dispatched the subscriber will be notified after the state is updated
40    async fn subscribe<S: Subscriber<State> + Send + 'static>(&self, subscriber: S);
41}
42
43/// Middlewares are the way to introduce side effects to the redux store.
44///
45/// Some examples of middleware could be:
46/// - Logging middleware, log every action
47/// - Api call middleware, make an api call when a certain action is send
48///
49/// Notice that there's an Action and an InnerAction.
50/// This enables us to send actions which are not of the same type as the underlying store.
51///
52/// ## Logging middleware example
53/// ```
54/// use async_trait::async_trait;
55/// use std::sync::Arc;
56/// use redux_rs::{MiddleWare, Store, StoreApi};
57///
58/// #[derive(Default)]
59/// struct Counter(i8);
60///
61/// #[derive(Debug)]
62/// enum Action {
63///     Increment,
64///     Decrement
65/// }
66///
67/// fn counter_reducer(state: Counter, action: Action) -> Counter {
68///     match action {
69///         Action::Increment => Counter(state.0 + 1),
70///         Action::Decrement => Counter(state.0 - 1),
71///     }
72/// }
73///
74/// // Logger which logs every action before it's dispatched to the store
75/// struct LoggerMiddleware;
76/// #[async_trait]
77/// impl<Inner> MiddleWare<Counter, Action, Inner> for LoggerMiddleware
78///     where
79/// Inner: StoreApi<Counter, Action> + Send + Sync
80/// {
81///     async fn dispatch(&self, action: Action, inner: &Arc<Inner>)
82///     {
83///         // Print the action
84///         println!("Before action: {:?}", action);
85///
86///         // Dispatch the action to the underlying store
87///         inner.dispatch(action).await;
88///     }
89/// }
90///
91/// # #[tokio::main(flavor = "current_thread")]
92/// # async fn async_test() {
93/// // Create a new store and wrap it with out new LoggerMiddleware
94/// let store = Store::new(counter_reducer).wrap(LoggerMiddleware).await;
95///
96/// // Dispatch an increment action
97/// // The console should print our text
98/// store.dispatch(Action::Increment).await;
99///
100/// // Dispatch an decrement action
101/// // The console should print our text
102/// store.dispatch(Action::Decrement).await;
103/// # }
104/// ```
105#[async_trait]
106pub trait MiddleWare<State, Action, Inner, InnerAction = Action>
107where
108    Action: Send + 'static,
109    State: Send + 'static,
110    InnerAction: Send + 'static,
111    Inner: StoreApi<State, InnerAction> + Send + Sync,
112{
113    /// This method is called the moment the middleware is wrapped around an underlying store api.
114    /// Initialization could be done here.
115    ///
116    /// For example, you could launch an "application started" action
117    #[allow(unused_variables)]
118    async fn init(&mut self, inner: &Arc<Inner>) {}
119
120    /// This method is called every time an action is dispatched to the store.
121    ///
122    /// You have the possibility to modify/cancel the action entirely.
123    /// You could also do certain actions before or after launching a specific/every action.
124    ///
125    /// NOTE: In the middleware you need to call `inner.dispatch(action).await;` otherwise no actions will be send to the underlying StoreApi (and eventually store)
126    async fn dispatch(&self, action: Action, inner: &Arc<Inner>);
127}
128
/// Store which ties an underlying store and middleware together.
///
/// Dispatched actions are handed to `middleware` together with a shared handle to `inner`;
/// selecting and subscribing are forwarded straight to `inner`.
///
/// The type parameters carry no bounds here. The `impl` blocks state exactly the bounds
/// they need, which keeps the type itself free to name.
pub struct StoreWithMiddleware<Inner, M, State, InnerAction, OuterAction> {
    // Shared so the middleware receives `&Arc<Inner>` on every call.
    inner: Arc<Inner>,
    middleware: M,

    // Marker for the type parameters that no field stores directly.
    _types: PhantomData<(State, InnerAction, OuterAction)>,
}
143
144impl<Inner, M, State, InnerAction, OuterAction> StoreWithMiddleware<Inner, M, State, InnerAction, OuterAction>
145where
146    Inner: StoreApi<State, InnerAction> + Send + Sync,
147    M: MiddleWare<State, OuterAction, Inner, InnerAction> + Send + Sync,
148    State: Send + Sync + 'static,
149    InnerAction: Send + Sync + 'static,
150    OuterAction: Send + Sync + 'static,
151{
152    pub(crate) async fn new(inner: Inner, mut middleware: M) -> Self {
153        let inner = Arc::new(inner);
154
155        middleware.init(&inner).await;
156
157        StoreWithMiddleware {
158            inner,
159            middleware,
160            _types: Default::default(),
161        }
162    }
163
164    /// Wrap the store with middleware
165    pub async fn wrap<MNew, NewOuterAction>(self, middleware: MNew) -> StoreWithMiddleware<Self, MNew, State, OuterAction, NewOuterAction>
166    where
167        MNew: MiddleWare<State, NewOuterAction, Self, OuterAction> + Send + Sync,
168        NewOuterAction: Send + Sync + 'static,
169        State: Sync,
170    {
171        StoreWithMiddleware::new(self, middleware).await
172    }
173}
174
175#[async_trait]
176impl<Inner, M, State, InnerAction, OuterAction> StoreApi<State, OuterAction> for StoreWithMiddleware<Inner, M, State, InnerAction, OuterAction>
177where
178    Inner: StoreApi<State, InnerAction> + Send + Sync,
179    M: MiddleWare<State, OuterAction, Inner, InnerAction> + Send + Sync,
180    State: Send + Sync + 'static,
181    InnerAction: Send + Sync + 'static,
182    OuterAction: Send + Sync + 'static,
183{
184    async fn dispatch(&self, action: OuterAction) {
185        self.middleware.dispatch(action, &self.inner).await
186    }
187
188    async fn select<S: Selector<State, Result = Result>, Result>(&self, selector: S) -> Result
189    where
190        S: Selector<State, Result = Result> + Send + 'static,
191        Result: Send + 'static,
192    {
193        self.inner.select(selector).await
194    }
195
196    async fn subscribe<S: Subscriber<State> + Send + 'static>(&self, subscriber: S) {
197        self.inner.subscribe(subscriber).await;
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Store;
    use std::sync::Mutex;

    /// Test state: every dispatched `Log` message is appended to `logs`.
    #[derive(Default)]
    struct LogStore {
        logs: Vec<String>,
    }

    /// The only action understood by `log_reducer`.
    struct Log(String);

    /// Reducer that appends the action's message to the state's log list.
    fn log_reducer(store: LogStore, action: Log) -> LogStore {
        let LogStore { mut logs } = store;
        let Log(message) = action;
        logs.push(message);
        LogStore { logs }
    }

    /// Middleware that records one line before and one after forwarding each action.
    ///
    /// Lines go to a shared vector instead of the console so the tests can inspect them.
    struct LoggerMiddleware {
        prefix: &'static str,
        logs: Arc<Mutex<Vec<String>>>,
    }

    impl LoggerMiddleware {
        pub fn new(prefix: &'static str, logs: Arc<Mutex<Vec<String>>>) -> Self {
            Self { prefix, logs }
        }

        /// Appends `message`, tagged with this middleware's prefix, to the shared log.
        pub fn log(&self, message: String) {
            let line = format!("[{}] {}", self.prefix, message);
            self.logs.lock().unwrap().push(line);
        }
    }

    #[async_trait]
    impl<Inner> MiddleWare<LogStore, Log, Inner> for LoggerMiddleware
    where
        Inner: StoreApi<LogStore, Log> + Send + Sync,
    {
        async fn dispatch(&self, action: Log, inner: &Arc<Inner>) {
            // Keep a copy of the text: `action` itself is moved into the inner store.
            let message = action.0.clone();

            self.log(format!("Before dispatching log message: {:?}", message));
            inner.dispatch(action).await;
            self.log(format!("After dispatching log message: {:?}", message));
        }
    }

    /// Asserts that the shared log buffer holds exactly `expected`, in order.
    fn assert_logs(logs: &Arc<Mutex<Vec<String>>>, expected: &[&str]) {
        let recorded = logs.lock().unwrap();
        let expected: Vec<String> = expected.iter().map(|line| line.to_string()).collect();
        assert_eq!(*recorded, expected);
    }

    #[tokio::test]
    async fn logger_middleware() {
        let logs = Arc::new(Mutex::new(Vec::new()));
        let store = Store::new(log_reducer)
            .wrap(LoggerMiddleware::new("log", Arc::clone(&logs)))
            .await;

        store.dispatch(Log("Log 1".to_string())).await;
        let first = [
            "[log] Before dispatching log message: \"Log 1\"",
            "[log] After dispatching log message: \"Log 1\"",
        ];
        assert_logs(&logs, &first);

        store.dispatch(Log("Log 2".to_string())).await;
        let second = [
            "[log] Before dispatching log message: \"Log 2\"",
            "[log] After dispatching log message: \"Log 2\"",
        ];
        assert_logs(&logs, &[&first[..], &second[..]].concat());
    }

    #[tokio::test]
    async fn logger_nested_middlewares() {
        let logs = Arc::new(Mutex::new(Vec::new()));
        let store = Store::new(log_reducer)
            .wrap(LoggerMiddleware::new("middleware_1", Arc::clone(&logs)))
            .await
            .wrap(LoggerMiddleware::new("middleware_2", Arc::clone(&logs)))
            .await;

        // The outermost (last added) middleware sees the action first and finishes last.
        store.dispatch(Log("Log 1".to_string())).await;
        let first = [
            "[middleware_2] Before dispatching log message: \"Log 1\"",
            "[middleware_1] Before dispatching log message: \"Log 1\"",
            "[middleware_1] After dispatching log message: \"Log 1\"",
            "[middleware_2] After dispatching log message: \"Log 1\"",
        ];
        assert_logs(&logs, &first);

        store.dispatch(Log("Log 2".to_string())).await;
        let second = [
            "[middleware_2] Before dispatching log message: \"Log 2\"",
            "[middleware_1] Before dispatching log message: \"Log 2\"",
            "[middleware_1] After dispatching log message: \"Log 2\"",
            "[middleware_2] After dispatching log message: \"Log 2\"",
        ];
        assert_logs(&logs, &[&first[..], &second[..]].concat());
    }
}