1use async_trait::async_trait;
2use std::marker::PhantomData;
3use tokio::task::JoinHandle;
4
5use crate::{
6 middleware::{MiddleWare, StoreApi, StoreWithMiddleware},
7 Reducer, Selector, Subscriber,
8};
9
10mod worker;
11use worker::{Address, Dispatch, Select, StateWorker, Subscribe};
12
13pub struct Store<State, Action, RootReducer>
19where
20 State: Send,
21 RootReducer: Send,
22{
23 worker_address: Address<State, Action, RootReducer>,
24 _worker_handle: JoinHandle<()>,
25
26 _types: PhantomData<RootReducer>,
27}
28
29impl<State, Action, RootReducer> Store<State, Action, RootReducer>
30where
31 Action: Send + 'static,
32 RootReducer: Reducer<State, Action> + Send + 'static,
33 State: Send + 'static,
34{
35 pub fn new(root_reducer: RootReducer) -> Self
37 where
38 State: Default,
39 {
40 Self::new_with_state(root_reducer, Default::default())
41 }
42
43 pub fn new_with_state(root_reducer: RootReducer, state: State) -> Self {
45 let mut worker = StateWorker::new(root_reducer, state);
46 let worker_address = worker.address();
47
48 let _worker_handle = tokio::spawn(async move {
49 worker.run().await;
50 });
51
52 Store {
53 worker_address,
54 _worker_handle,
55
56 _types: Default::default(),
57 }
58 }
59
60 pub async fn dispatch(&self, action: Action) {
65 self.worker_address.send(Dispatch::new(action)).await;
66 }
67
68 pub async fn select<S: Selector<State, Result = Result>, Result>(&self, selector: S) -> Result
71 where
72 S: Selector<State, Result = Result> + Send + 'static,
73 Result: Send + 'static,
74 {
75 self.worker_address.send(Select::new(selector)).await
76 }
77
78 pub async fn state_cloned(&self) -> State
81 where
82 State: Clone,
83 {
84 self.select(|state: &State| state.clone()).await
85 }
86
87 pub async fn subscribe<S: Subscriber<State> + Send + 'static>(&self, subscriber: S) {
90 self.worker_address.send(Subscribe::new(Box::new(subscriber))).await
91 }
92
93 pub async fn wrap<M, OuterAction>(self, middleware: M) -> StoreWithMiddleware<Self, M, State, Action, OuterAction>
95 where
96 M: MiddleWare<State, OuterAction, Self, Action> + Send + Sync,
97 OuterAction: Send + Sync + 'static,
98 State: Sync,
99 Action: Sync,
100 RootReducer: Sync,
101 {
102 StoreWithMiddleware::new(self, middleware).await
103 }
104}
105
106#[async_trait]
107impl<State, Action, RootReducer> StoreApi<State, Action> for Store<State, Action, RootReducer>
108where
109 Action: Send + Sync + 'static,
110 RootReducer: Reducer<State, Action> + Send + Sync + 'static,
111 State: Send + Sync + 'static,
112{
113 async fn dispatch(&self, action: Action) {
114 Store::dispatch(self, action).await
115 }
116
117 async fn select<S: Selector<State, Result = Result>, Result>(&self, selector: S) -> Result
118 where
119 S: Selector<State, Result = Result> + Send + 'static,
120 Result: Send + 'static,
121 {
122 Store::select(self, selector).await
123 }
124
125 async fn state_cloned(&self) -> State
126 where
127 State: Clone,
128 {
129 Store::state_cloned(self).await
130 }
131
132 async fn subscribe<S: Subscriber<State> + Send + 'static>(&self, subscriber: S) {
133 Store::subscribe(self, subscriber).await
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::sync::Arc;

    /// Minimal state used by the tests: a single signed counter.
    #[derive(Clone, Debug, PartialEq)]
    struct Counter {
        value: i32,
    }

    impl Counter {
        pub fn new(value: i32) -> Self {
            Self { value }
        }
    }

    /// Defaults to 42 (not 0); the tests below assert this starting value.
    impl Default for Counter {
        fn default() -> Self {
            Counter::new(42)
        }
    }

    /// Struct-based selector that projects the counter's value.
    struct ValueSelector;

    impl Selector<Counter> for ValueSelector {
        type Result = i32;

        fn select(&self, state: &Counter) -> Self::Result {
            let Counter { value } = state;
            *value
        }
    }

    enum CounterAction {
        Increment,
        Decrement,
    }

    /// Root reducer: each action moves the counter by one step.
    fn counter_reducer(state: Counter, action: CounterAction) -> Counter {
        let delta = match action {
            CounterAction::Increment => 1,
            CounterAction::Decrement => -1,
        };
        Counter::new(state.value + delta)
    }

    /// Actions shared by the sequence tests. Each action is paired with the
    /// counter value expected after dispatching it to a store that starts at 42.
    fn script() -> Vec<(CounterAction, i32)> {
        vec![
            (CounterAction::Increment, 43),
            (CounterAction::Increment, 44),
            (CounterAction::Decrement, 43),
        ]
    }

    #[tokio::test]
    async fn counter_default_state() {
        let store = Store::new(counter_reducer);
        let expected = Counter::default();
        assert_eq!(expected, store.state_cloned().await);
    }

    #[tokio::test]
    async fn counter_supplied_state() {
        let store = Store::new_with_state(counter_reducer, Counter::new(5));
        assert_eq!(Counter::new(5), store.state_cloned().await);
    }

    #[tokio::test]
    async fn counter_actions_cloned_state() {
        let store = Store::new(counter_reducer);
        assert_eq!(Counter::new(42), store.state_cloned().await);

        for (action, expected) in script() {
            store.dispatch(action).await;
            assert_eq!(Counter::new(expected), store.state_cloned().await);
        }
    }

    #[tokio::test]
    async fn counter_actions_selector_struct() {
        let store = Store::new(counter_reducer);
        assert_eq!(42, store.select(ValueSelector).await);

        for (action, expected) in script() {
            store.dispatch(action).await;
            assert_eq!(expected, store.select(ValueSelector).await);
        }
    }

    #[tokio::test]
    async fn counter_actions_selector_lambda() {
        let store = Store::new(counter_reducer);
        assert_eq!(42, store.select(|state: &Counter| state.value).await);

        for (action, expected) in script() {
            store.dispatch(action).await;
            assert_eq!(expected, store.select(|state: &Counter| state.value).await);
        }
    }

    #[tokio::test]
    async fn counter_subscribe() {
        let store = Store::new(counter_reducer);
        assert_eq!(42, store.select(|state: &Counter| state.value).await);

        let sum = Arc::new(AtomicI32::new(0));
        let sink = Arc::clone(&sum);
        store
            .subscribe(move |state: &Counter| {
                sink.fetch_add(state.value, Ordering::Relaxed);
            })
            .await;

        for (action, _) in script() {
            store.dispatch(action).await;
        }

        // Expected total is the sum of the post-dispatch values: 43 + 44 + 43.
        assert_eq!(130, sum.load(Ordering::Relaxed));
    }
}