1use crate::{Selector, Subscriber};
2use async_trait::async_trait;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6#[async_trait]
11pub trait StoreApi<State, Action>
12where
13 Action: Send + 'static,
14 State: Send + 'static,
15{
16 async fn dispatch(&self, action: Action);
21
22 async fn select<S: Selector<State, Result = Result>, Result>(&self, selector: S) -> Result
25 where
26 S: Selector<State, Result = Result> + Send + 'static,
27 Result: Send + 'static;
28
29 async fn state_cloned(&self) -> State
32 where
33 State: Clone,
34 {
35 self.select(|state: &State| state.clone()).await
36 }
37
38 async fn subscribe<S: Subscriber<State> + Send + 'static>(&self, subscriber: S);
41}
42
43#[async_trait]
106pub trait MiddleWare<State, Action, Inner, InnerAction = Action>
107where
108 Action: Send + 'static,
109 State: Send + 'static,
110 InnerAction: Send + 'static,
111 Inner: StoreApi<State, InnerAction> + Send + Sync,
112{
113 #[allow(unused_variables)]
118 async fn init(&mut self, inner: &Arc<Inner>) {}
119
120 async fn dispatch(&self, action: Action, inner: &Arc<Inner>);
127}
128
/// A store whose dispatched actions pass through middleware `M` before reaching `Inner`.
///
/// Type parameters:
/// - `Inner`: the wrapped store, which accepts `InnerAction`s.
/// - `M`: the middleware, which accepts `OuterAction`s and talks to `Inner`.
/// - `State`: the state type shared by the whole stack.
///
/// Trait bounds are intentionally placed on the `impl` blocks rather than on this definition.
/// Struct bounds are not implied at use sites, so repeating them here only forced every
/// mention of the type to restate them.
pub struct StoreWithMiddleware<Inner, M, State, InnerAction, OuterAction> {
    /// The wrapped store. It is kept in an `Arc` because the middleware receives it as
    /// `&Arc<Inner>` in both `init` and `dispatch`.
    inner: Arc<Inner>,
    /// Middleware that receives every action dispatched to this store.
    middleware: M,

    /// Ties the otherwise-unused type parameters to the struct.
    _types: PhantomData<(State, InnerAction, OuterAction)>,
}
143
144impl<Inner, M, State, InnerAction, OuterAction> StoreWithMiddleware<Inner, M, State, InnerAction, OuterAction>
145where
146 Inner: StoreApi<State, InnerAction> + Send + Sync,
147 M: MiddleWare<State, OuterAction, Inner, InnerAction> + Send + Sync,
148 State: Send + Sync + 'static,
149 InnerAction: Send + Sync + 'static,
150 OuterAction: Send + Sync + 'static,
151{
152 pub(crate) async fn new(inner: Inner, mut middleware: M) -> Self {
153 let inner = Arc::new(inner);
154
155 middleware.init(&inner).await;
156
157 StoreWithMiddleware {
158 inner,
159 middleware,
160 _types: Default::default(),
161 }
162 }
163
164 pub async fn wrap<MNew, NewOuterAction>(self, middleware: MNew) -> StoreWithMiddleware<Self, MNew, State, OuterAction, NewOuterAction>
166 where
167 MNew: MiddleWare<State, NewOuterAction, Self, OuterAction> + Send + Sync,
168 NewOuterAction: Send + Sync + 'static,
169 State: Sync,
170 {
171 StoreWithMiddleware::new(self, middleware).await
172 }
173}
174
175#[async_trait]
176impl<Inner, M, State, InnerAction, OuterAction> StoreApi<State, OuterAction> for StoreWithMiddleware<Inner, M, State, InnerAction, OuterAction>
177where
178 Inner: StoreApi<State, InnerAction> + Send + Sync,
179 M: MiddleWare<State, OuterAction, Inner, InnerAction> + Send + Sync,
180 State: Send + Sync + 'static,
181 InnerAction: Send + Sync + 'static,
182 OuterAction: Send + Sync + 'static,
183{
184 async fn dispatch(&self, action: OuterAction) {
185 self.middleware.dispatch(action, &self.inner).await
186 }
187
188 async fn select<S: Selector<State, Result = Result>, Result>(&self, selector: S) -> Result
189 where
190 S: Selector<State, Result = Result> + Send + 'static,
191 Result: Send + 'static,
192 {
193 self.inner.select(selector).await
194 }
195
196 async fn subscribe<S: Subscriber<State> + Send + 'static>(&self, subscriber: S) {
197 self.inner.subscribe(subscriber).await;
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Store;
    use std::sync::Mutex;

    /// Test state: every message the reducer has received, in arrival order.
    #[derive(Default)]
    struct LogStore {
        logs: Vec<String>,
    }

    /// Action that carries a single message to append.
    struct Log(String);

    /// Reducer that appends the action's message to the state's list.
    fn log_reducer(mut store: LogStore, action: Log) -> LogStore {
        store.logs.push(action.0);
        store
    }

    /// Middleware that records one line before and one line after forwarding each action.
    struct LoggerMiddleware {
        prefix: &'static str,
        logs: Arc<Mutex<Vec<String>>>,
    }

    impl LoggerMiddleware {
        pub fn new(prefix: &'static str, logs: Arc<Mutex<Vec<String>>>) -> Self {
            Self { prefix, logs }
        }

        /// Appends `message`, tagged with this middleware's prefix, to the shared log.
        pub fn log(&self, message: String) {
            let line = format!("[{}] {}", self.prefix, message);
            self.logs.lock().unwrap().push(line);
        }
    }

    #[async_trait]
    impl<Inner> MiddleWare<LogStore, Log, Inner> for LoggerMiddleware
    where
        Inner: StoreApi<LogStore, Log> + Send + Sync,
    {
        async fn dispatch(&self, action: Log, inner: &Arc<Inner>) {
            // `action` is moved into the inner store, so keep our own copy of the text.
            let message = action.0.clone();

            self.log(format!("Before dispatching log message: {:?}", message));
            inner.dispatch(action).await;
            self.log(format!("After dispatching log message: {:?}", message));
        }
    }

    /// Asserts that the shared log contains exactly `expected`, in order.
    fn assert_logs(logs: &Arc<Mutex<Vec<String>>>, expected: &[&str]) {
        let actual = logs.lock().unwrap();
        let expected: Vec<String> = expected.iter().map(|line| line.to_string()).collect();
        assert_eq!(*actual, expected);
    }

    #[tokio::test]
    async fn logger_middleware() {
        let logs = Arc::new(Mutex::new(Vec::new()));
        let middleware = LoggerMiddleware::new("log", Arc::clone(&logs));

        let store = Store::new(log_reducer).wrap(middleware).await;

        store.dispatch(Log("Log 1".to_string())).await;
        assert_logs(
            &logs,
            &[
                "[log] Before dispatching log message: \"Log 1\"",
                "[log] After dispatching log message: \"Log 1\"",
            ],
        );

        store.dispatch(Log("Log 2".to_string())).await;
        assert_logs(
            &logs,
            &[
                "[log] Before dispatching log message: \"Log 1\"",
                "[log] After dispatching log message: \"Log 1\"",
                "[log] Before dispatching log message: \"Log 2\"",
                "[log] After dispatching log message: \"Log 2\"",
            ],
        );
    }

    #[tokio::test]
    async fn logger_nested_middlewares() {
        let logs = Arc::new(Mutex::new(Vec::new()));
        let inner_layer = LoggerMiddleware::new("middleware_1", Arc::clone(&logs));
        let outer_layer = LoggerMiddleware::new("middleware_2", Arc::clone(&logs));

        // `outer_layer` is applied last, so it sees each action first and finishes last.
        let store = Store::new(log_reducer)
            .wrap(inner_layer)
            .await
            .wrap(outer_layer)
            .await;

        store.dispatch(Log("Log 1".to_string())).await;
        assert_logs(
            &logs,
            &[
                "[middleware_2] Before dispatching log message: \"Log 1\"",
                "[middleware_1] Before dispatching log message: \"Log 1\"",
                "[middleware_1] After dispatching log message: \"Log 1\"",
                "[middleware_2] After dispatching log message: \"Log 1\"",
            ],
        );

        store.dispatch(Log("Log 2".to_string())).await;
        assert_logs(
            &logs,
            &[
                "[middleware_2] Before dispatching log message: \"Log 1\"",
                "[middleware_1] Before dispatching log message: \"Log 1\"",
                "[middleware_1] After dispatching log message: \"Log 1\"",
                "[middleware_2] After dispatching log message: \"Log 1\"",
                "[middleware_2] Before dispatching log message: \"Log 2\"",
                "[middleware_1] Before dispatching log message: \"Log 2\"",
                "[middleware_1] After dispatching log message: \"Log 2\"",
                "[middleware_2] After dispatching log message: \"Log 2\"",
            ],
        );
    }
}