1use serde::{Deserialize, Serialize};
44use std::future::Future;
45use std::pin::Pin;
46
47pub trait State: Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync {
51 type Message: Send;
53
54 fn update(&mut self, msg: Self::Message) -> Command<Self::Message>;
58}
59
/// A side effect requested by [`State::update`].
///
/// `M` is the message type produced when an effect yields a result.
#[derive(Default)]
pub enum Command<M> {
    /// No side effect; this is the `Default`.
    #[default]
    None,
    /// Several commands grouped into one.
    Batch(Vec<Command<M>>),
    /// A boxed asynchronous computation whose output is a message of type `M`.
    Task(Pin<Box<dyn Future<Output = M> + Send + 'static>>),
    /// Request navigation to a route.
    Navigate {
        /// Target route.
        route: String,
    },
    /// Request that state be saved under a key.
    SaveState {
        /// Storage key.
        key: String,
    },
    /// Request the bytes stored under a key.
    LoadState {
        /// Storage key.
        key: String,
        /// Turns the loaded bytes into a message (the `Option` presumably
        /// signals a missing entry — confirm against the runtime).
        on_load: fn(Option<Vec<u8>>) -> M,
    },
}
93
94impl<M> std::fmt::Debug for Command<M> {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 match self {
97 Self::None => write!(f, "Command::None"),
98 Self::Batch(cmds) => f.debug_tuple("Command::Batch").field(&cmds.len()).finish(),
99 Self::Task(_) => write!(f, "Command::Task(..)"),
100 Self::Navigate { route } => f
101 .debug_struct("Command::Navigate")
102 .field("route", route)
103 .finish(),
104 Self::SaveState { key } => f
105 .debug_struct("Command::SaveState")
106 .field("key", key)
107 .finish(),
108 Self::LoadState { key, .. } => f
109 .debug_struct("Command::LoadState")
110 .field("key", key)
111 .finish_non_exhaustive(),
112 }
113 }
114}
115
116impl<M> Command<M> {
117 pub fn task<F>(future: F) -> Self
119 where
120 F: Future<Output = M> + Send + 'static,
121 {
122 Self::Task(Box::pin(future))
123 }
124
125 pub fn batch(commands: impl IntoIterator<Item = Self>) -> Self {
127 Self::Batch(commands.into_iter().collect())
128 }
129
130 #[must_use]
132 pub const fn is_none(&self) -> bool {
133 matches!(self, Self::None)
134 }
135
136 pub fn map<N, F>(self, f: F) -> Command<N>
138 where
139 F: Fn(M) -> N + Send + Sync + 'static,
140 M: Send + 'static,
141 N: Send + 'static,
142 {
143 let f: std::sync::Arc<dyn Fn(M) -> N + Send + Sync> = std::sync::Arc::new(f);
144 self.map_inner(&f)
145 }
146
147 fn map_inner<N>(self, f: &std::sync::Arc<dyn Fn(M) -> N + Send + Sync>) -> Command<N>
148 where
149 M: Send + 'static,
150 N: Send + 'static,
151 {
152 match self {
153 Self::None => Command::None,
154 Self::Batch(cmds) => Command::Batch(cmds.into_iter().map(|c| c.map_inner(f)).collect()),
155 Self::Task(fut) => {
156 let f = f.clone();
157 Command::Task(Box::pin(async move { f(fut.await) }))
158 }
159 Self::Navigate { route } => Command::Navigate { route },
160 Self::SaveState { key } => Command::SaveState { key },
161 Self::LoadState { .. } => {
162 Command::None
165 }
166 }
167 }
168}
169
170#[derive(Debug, Clone, Default, Serialize, Deserialize)]
172pub struct CounterState {
173 pub count: i32,
175}
176
/// Messages understood by [`CounterState`].
///
/// `PartialEq`/`Eq` are derived so messages can be compared and asserted on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CounterMessage {
    /// Add one to the count.
    Increment,
    /// Subtract one from the count.
    Decrement,
    /// Replace the count with the given value.
    Set(i32),
    /// Set the count back to zero.
    Reset,
}
189
190impl State for CounterState {
191 type Message = CounterMessage;
192
193 fn update(&mut self, msg: Self::Message) -> Command<Self::Message> {
194 match msg {
195 CounterMessage::Increment => self.count += 1,
196 CounterMessage::Decrement => self.count -= 1,
197 CounterMessage::Set(value) => self.count = value,
198 CounterMessage::Reset => self.count = 0,
199 }
200 Command::None
201 }
202}
203
/// Boxed callback that receives a reference to the store's state whenever the
/// store notifies its subscribers.
type Subscriber<S> = Box<dyn Fn(&S) + Send + Sync + 'static>;
206
207pub struct Store<S: State> {
209 state: S,
210 history: Vec<S>,
211 history_index: usize,
212 max_history: usize,
213 subscribers: Vec<Subscriber<S>>,
214}
215
216impl<S: State + std::fmt::Debug> std::fmt::Debug for Store<S> {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 f.debug_struct("Store")
219 .field("state", &self.state)
220 .field("history_len", &self.history.len())
221 .field("history_index", &self.history_index)
222 .field("max_history", &self.max_history)
223 .field("subscribers_count", &self.subscribers.len())
224 .finish()
225 }
226}
227
228impl<S: State> Store<S> {
229 pub fn new(initial: S) -> Self {
231 Self {
232 state: initial,
233 history: Vec::new(),
234 history_index: 0,
235 max_history: 100,
236 subscribers: Vec::new(),
237 }
238 }
239
240 pub fn with_history_limit(initial: S, max_history: usize) -> Self {
242 Self {
243 state: initial,
244 history: Vec::new(),
245 history_index: 0,
246 max_history,
247 subscribers: Vec::new(),
248 }
249 }
250
251 pub const fn state(&self) -> &S {
253 &self.state
254 }
255
256 pub fn dispatch(&mut self, msg: S::Message) -> Command<S::Message> {
258 if self.max_history > 0 {
260 if self.history_index < self.history.len() {
262 self.history.truncate(self.history_index);
263 }
264
265 self.history.push(self.state.clone());
266
267 if self.history.len() > self.max_history {
269 self.history.remove(0);
270 } else {
271 self.history_index = self.history.len();
272 }
273 }
274
275 let cmd = self.state.update(msg);
277
278 self.notify_subscribers();
280
281 cmd
282 }
283
284 pub fn subscribe<F>(&mut self, callback: F)
286 where
287 F: Fn(&S) + Send + Sync + 'static,
288 {
289 self.subscribers.push(Box::new(callback));
290 }
291
292 pub fn history_len(&self) -> usize {
294 self.history.len()
295 }
296
297 pub const fn can_undo(&self) -> bool {
299 self.history_index > 0
300 }
301
302 pub fn can_redo(&self) -> bool {
304 self.history_index < self.history.len()
305 }
306
307 pub fn undo(&mut self) -> bool {
309 if self.can_undo() {
310 if self.history_index == self.history.len() {
312 self.history.push(self.state.clone());
313 }
314
315 self.history_index -= 1;
316 self.state = self.history[self.history_index].clone();
317 self.notify_subscribers();
318 true
319 } else {
320 false
321 }
322 }
323
324 pub fn redo(&mut self) -> bool {
326 if self.history_index < self.history.len().saturating_sub(1) {
327 self.history_index += 1;
328 self.state = self.history[self.history_index].clone();
329 self.notify_subscribers();
330 true
331 } else {
332 false
333 }
334 }
335
336 pub fn jump_to(&mut self, index: usize) -> bool {
338 if index < self.history.len() {
339 self.history_index = index;
340 self.state = self.history[index].clone();
341 self.notify_subscribers();
342 true
343 } else {
344 false
345 }
346 }
347
348 pub fn clear_history(&mut self) {
350 self.history.clear();
351 self.history_index = 0;
352 }
353
354 fn notify_subscribers(&self) {
355 for subscriber in &self.subscribers {
356 subscriber(&self.state);
357 }
358 }
359}
360
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::disallowed_methods)]
mod tests {
    use super::*;

    #[test]
    fn test_counter_increment() {
        let mut state = CounterState::default();
        state.update(CounterMessage::Increment);
        assert_eq!(state.count, 1);
    }

    #[test]
    fn test_counter_decrement() {
        let mut state = CounterState { count: 5 };
        state.update(CounterMessage::Decrement);
        assert_eq!(state.count, 4);
    }

    #[test]
    fn test_counter_set() {
        let mut state = CounterState::default();
        state.update(CounterMessage::Set(42));
        assert_eq!(state.count, 42);
    }

    #[test]
    fn test_counter_reset() {
        let mut state = CounterState { count: 100 };
        state.update(CounterMessage::Reset);
        assert_eq!(state.count, 0);
    }

    #[test]
    fn test_command_none() {
        let cmd: Command<()> = Command::None;
        assert!(cmd.is_none());
    }

    #[test]
    fn test_command_default() {
        let cmd: Command<()> = Command::default();
        assert!(cmd.is_none());
    }

    #[test]
    fn test_command_task() {
        let cmd: Command<i32> = Command::task(async { 7 });
        assert!(!cmd.is_none());
        assert!(matches!(cmd, Command::Task(_)));
    }

    #[test]
    fn test_command_batch() {
        let cmd: Command<i32> = Command::batch([
            Command::Navigate {
                route: "/a".to_string(),
            },
            Command::Navigate {
                route: "/b".to_string(),
            },
        ]);
        assert!(!cmd.is_none());
        if let Command::Batch(cmds) = cmd {
            assert_eq!(cmds.len(), 2);
        } else {
            panic!("Expected Batch command");
        }
    }

    #[test]
    fn test_command_navigate() {
        let cmd: Command<()> = Command::Navigate {
            route: "/home".to_string(),
        };
        if let Command::Navigate { route } = cmd {
            assert_eq!(route, "/home");
        } else {
            panic!("Expected Navigate command");
        }
    }

    #[test]
    fn test_command_save_state() {
        let cmd: Command<()> = Command::SaveState {
            key: "app_state".to_string(),
        };
        if let Command::SaveState { key } = cmd {
            assert_eq!(key, "app_state");
        } else {
            panic!("Expected SaveState command");
        }
    }

    #[test]
    fn test_command_debug() {
        let batch: Command<i32> = Command::batch([Command::None, Command::None]);
        assert_eq!(format!("{batch:?}"), "Command::Batch(2)");
        let load: Command<i32> = Command::LoadState {
            key: "k".to_string(),
            on_load: |_| 0,
        };
        assert_eq!(format!("{load:?}"), "Command::LoadState { key: \"k\", .. }");
    }

    #[test]
    fn test_counter_serialization() {
        let state = CounterState { count: 42 };
        let json = serde_json::to_string(&state).unwrap();
        let loaded: CounterState = serde_json::from_str(&json).unwrap();
        assert_eq!(loaded.count, 42);
    }

    #[test]
    fn test_command_map() {
        let cmd: Command<i32> = Command::Navigate {
            route: "/test".to_string(),
        };
        let mapped: Command<String> = cmd.map(|_i| "mapped".to_string());

        if let Command::Navigate { route } = mapped {
            assert_eq!(route, "/test");
        } else {
            panic!("Expected Navigate command after map");
        }
    }

    #[test]
    fn test_command_map_none() {
        let cmd: Command<i32> = Command::None;
        let mapped: Command<String> = cmd.map(|i| i.to_string());
        assert!(mapped.is_none());
    }

    #[test]
    fn test_command_map_load_state_is_dropped() {
        // `on_load` is a fn pointer and cannot be composed with the mapper, so
        // `map` replaces the request with `Command::None`.
        let cmd: Command<i32> = Command::LoadState {
            key: "k".to_string(),
            on_load: |_| 0,
        };
        let mapped: Command<String> = cmd.map(|i| i.to_string());
        assert!(mapped.is_none());
    }

    #[test]
    fn test_command_batch_map() {
        let cmd: Command<i32> = Command::batch([
            Command::SaveState {
                key: "key1".to_string(),
            },
            Command::SaveState {
                key: "key2".to_string(),
            },
        ]);

        let mapped: Command<String> = cmd.map(|i| format!("val_{i}"));

        if let Command::Batch(cmds) = mapped {
            assert_eq!(cmds.len(), 2);
        } else {
            panic!("Expected Batch command after map");
        }
    }

    #[test]
    fn test_store_new() {
        let store = Store::new(CounterState::default());
        assert_eq!(store.state().count, 0);
    }

    #[test]
    fn test_store_dispatch() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        assert_eq!(store.state().count, 1);
    }

    #[test]
    fn test_store_history() {
        let mut store = Store::new(CounterState::default());

        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);

        assert_eq!(store.state().count, 3);
        assert_eq!(store.history_len(), 3);
    }

    #[test]
    fn test_store_undo() {
        let mut store = Store::new(CounterState::default());

        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert_eq!(store.state().count, 2);

        assert!(store.can_undo());
        assert!(store.undo());
        assert_eq!(store.state().count, 1);

        assert!(store.undo());
        assert_eq!(store.state().count, 0);
    }

    #[test]
    fn test_store_redo() {
        let mut store = Store::new(CounterState::default());

        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        store.undo();
        store.undo();

        assert_eq!(store.state().count, 0);
        assert!(store.can_redo());

        assert!(store.redo());
        assert_eq!(store.state().count, 1);

        assert!(store.redo());
        assert_eq!(store.state().count, 2);
    }

    #[test]
    fn test_store_undo_at_start() {
        let mut store = Store::new(CounterState::default());
        assert!(!store.can_undo());
        assert!(!store.undo());
    }

    #[test]
    fn test_store_redo_at_end() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        assert!(!store.can_redo());
        assert!(!store.redo());
    }

    #[test]
    fn test_store_history_truncation() {
        let mut store = Store::new(CounterState::default());

        store.dispatch(CounterMessage::Set(1));
        store.dispatch(CounterMessage::Set(2));
        store.dispatch(CounterMessage::Set(3));

        // Rewind two steps, then branch off: the old redo states must be gone.
        store.undo();
        store.undo();
        assert_eq!(store.state().count, 1);

        store.dispatch(CounterMessage::Set(10));
        assert_eq!(store.state().count, 10);

        assert!(!store.redo());
    }

    #[test]
    fn test_store_jump_to() {
        let mut store = Store::new(CounterState::default());

        store.dispatch(CounterMessage::Set(10));
        store.dispatch(CounterMessage::Set(20));
        store.dispatch(CounterMessage::Set(30));

        assert!(store.jump_to(0));
        assert_eq!(store.state().count, 0);

        assert!(store.jump_to(2));
        assert_eq!(store.state().count, 20);
    }

    #[test]
    fn test_store_jump_invalid() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);

        assert!(!store.jump_to(100));
    }

    #[test]
    fn test_store_clear_history() {
        let mut store = Store::new(CounterState::default());

        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert!(store.history_len() > 0);

        store.clear_history();
        assert_eq!(store.history_len(), 0);
        assert!(!store.can_undo());
        assert_eq!(store.state().count, 2);
    }

    #[test]
    fn test_store_with_history_limit() {
        let mut store = Store::with_history_limit(CounterState::default(), 3);

        for i in 1..=10 {
            store.dispatch(CounterMessage::Set(i));
        }

        // Exactly the three most recent prior states are retained.
        assert_eq!(store.history_len(), 3);
        for expected in [9, 8, 7] {
            assert!(store.undo());
            assert_eq!(store.state().count, expected);
        }
        assert!(!store.undo());
    }

    #[test]
    fn test_store_subscribe() {
        use std::sync::atomic::{AtomicI32, Ordering};
        use std::sync::Arc;

        let call_count = Arc::new(AtomicI32::new(0));
        let call_count_clone = call_count.clone();

        let mut store = Store::new(CounterState::default());
        store.subscribe(move |_| {
            call_count_clone.fetch_add(1, Ordering::SeqCst);
        });

        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);

        // Undo replaces the state, so subscribers hear about it too.
        assert!(store.undo());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_store_no_history() {
        let mut store = Store::with_history_limit(CounterState::default(), 0);

        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);

        assert_eq!(store.history_len(), 0);
        assert!(!store.can_undo());
    }
}