use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
/// Message-driven application state that a [`Store`] can manage.
///
/// Implementors must be `Clone` (the [`Store`] snapshots the state into its
/// undo history), serde-serializable and deserializable, and `Send + Sync`.
pub trait State
where
    Self: Clone + Serialize + for<'de> Deserialize<'de> + Send + Sync,
{
    /// The message type consumed by [`State::update`]; must be `Send`.
    type Message: Send;

    /// Applies `msg` to `self` in place and returns a [`Command`] describing
    /// any follow-up effect ([`Command::None`] when there is nothing to do).
    fn update(&mut self, msg: Self::Message) -> Command<Self::Message>;
}
/// A follow-up effect returned from [`State::update`].
///
/// `M` is the message type that effects (e.g. a finished [`Command::Task`])
/// produce. Nothing in this module executes commands; presumably a runtime
/// elsewhere interprets them.
#[derive(Default)]
pub enum Command<M> {
    /// No effect. This is the [`Default`] variant.
    #[default]
    None,
    /// Several commands grouped together (may itself contain batches).
    Batch(Vec<Self>),
    /// A boxed, pinned, `Send` future that resolves to a message of type `M`.
    Task(Pin<Box<dyn Future<Output = M> + Send>>),
    /// Carries a `route`; presumably asks the host to navigate there.
    Navigate { route: String },
    /// Carries a storage `key`; presumably asks the host to persist state under it.
    SaveState { key: String },
    /// Carries a storage `key` and an `on_load` callback that turns the
    /// (possibly absent) raw bytes into a message.
    ///
    /// Because `on_load` is a plain `fn` pointer, `Command::map` cannot carry
    /// this variant over and replaces it with [`Command::None`].
    LoadState {
        key: String,
        on_load: fn(Option<Vec<u8>>) -> M,
    },
}
impl<M> Command<M> {
pub fn task<F>(future: F) -> Self
where
F: Future<Output = M> + Send + 'static,
{
Self::Task(Box::pin(future))
}
pub fn batch(commands: impl IntoIterator<Item = Self>) -> Self {
Self::Batch(commands.into_iter().collect())
}
#[must_use]
pub const fn is_none(&self) -> bool {
matches!(self, Self::None)
}
pub fn map<N, F>(self, f: F) -> Command<N>
where
F: Fn(M) -> N + Send + Sync + 'static,
M: Send + 'static,
N: Send + 'static,
{
let f: std::sync::Arc<dyn Fn(M) -> N + Send + Sync> = std::sync::Arc::new(f);
self.map_inner(&f)
}
fn map_inner<N>(self, f: &std::sync::Arc<dyn Fn(M) -> N + Send + Sync>) -> Command<N>
where
M: Send + 'static,
N: Send + 'static,
{
match self {
Self::None => Command::None,
Self::Batch(cmds) => Command::Batch(cmds.into_iter().map(|c| c.map_inner(f)).collect()),
Self::Task(fut) => {
let f = f.clone();
Command::Task(Box::pin(async move { f(fut.await) }))
}
Self::Navigate { route } => Command::Navigate { route },
Self::SaveState { key } => Command::SaveState { key },
Self::LoadState { .. } => {
Command::None
}
}
}
}
/// State for the example counter: a single signed integer.
///
/// Derives `PartialEq`/`Eq` so whole states can be compared (e.g. in tests
/// or change detection) instead of comparing fields one by one.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct CounterState {
    /// Current counter value; `0` for the [`Default`] state.
    pub count: i32,
}
/// Messages accepted by the [`CounterState`] implementation of [`State::update`].
///
/// Derives `PartialEq`/`Eq` so messages can be compared and asserted on directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CounterMessage {
    /// Add one to the counter.
    Increment,
    /// Subtract one from the counter.
    Decrement,
    /// Replace the counter with the given value.
    Set(i32),
    /// Set the counter back to `0`.
    Reset,
}
impl State for CounterState {
    type Message = CounterMessage;

    /// Computes the new count from `msg` and stores it.
    ///
    /// Counter messages never request side effects, so this always returns
    /// [`Command::None`].
    fn update(&mut self, msg: Self::Message) -> Command<Self::Message> {
        self.count = match msg {
            CounterMessage::Increment => self.count + 1,
            CounterMessage::Decrement => self.count - 1,
            CounterMessage::Set(value) => value,
            CounterMessage::Reset => 0,
        };
        Command::None
    }
}
/// Boxed callback that a [`Store`] invokes with its live state after it
/// dispatches a message or moves through its history.
type Subscriber<S> = Box<dyn Fn(&S) + Send + Sync>;

/// Owns an application [`State`], an undo/redo history of snapshots, and a
/// list of change subscribers.
pub struct Store<S>
where
    S: State,
{
    /// The live state.
    state: S,
    /// Snapshots of earlier states, oldest first.
    history: Vec<S>,
    /// Cursor into `history`. It equals `history.len()` while the live state
    /// has not itself been snapshotted.
    history_index: usize,
    /// Snapshot cap applied by `dispatch` (`0` disables history recording).
    max_history: usize,
    /// Callbacks notified after state changes, in registration order.
    subscribers: Vec<Subscriber<S>>,
}
impl<S: State> Store<S> {
    /// Snapshot limit used by [`Store::new`].
    const DEFAULT_MAX_HISTORY: usize = 100;

    /// Creates a store around `initial` that keeps up to 100 undo snapshots.
    pub fn new(initial: S) -> Self {
        Self::with_history_limit(initial, Self::DEFAULT_MAX_HISTORY)
    }

    /// Creates a store around `initial` that keeps at most `max_history`
    /// undo snapshots. A limit of `0` disables history recording.
    pub fn with_history_limit(initial: S, max_history: usize) -> Self {
        Self {
            state: initial,
            history: Vec::new(),
            history_index: 0,
            max_history,
            subscribers: Vec::new(),
        }
    }

    /// Returns the live state.
    pub const fn state(&self) -> &S {
        &self.state
    }

    /// Applies `msg` to the live state and returns the [`Command`] produced by
    /// [`State::update`].
    ///
    /// When history is enabled, two things happen before the update. First,
    /// any redo entries ahead of the cursor are discarded. Second, the current
    /// state is snapshotted; if that exceeds the limit, the oldest snapshot is
    /// dropped. Subscribers are notified after the update.
    pub fn dispatch(&mut self, msg: S::Message) -> Command<S::Message> {
        if self.max_history > 0 {
            // Dispatching after an undo starts a new branch: forget the redo tail.
            // (`truncate` is a no-op when the cursor is already at the end.)
            self.history.truncate(self.history_index);
            self.history.push(self.state.clone());
            if self.history.len() > self.max_history {
                // O(n) shift, but n is bounded by `max_history`.
                self.history.remove(0);
            }
            // The live state (about to change) sits just past the newest snapshot.
            self.history_index = self.history.len();
        }
        let cmd = self.state.update(msg);
        self.notify_subscribers();
        cmd
    }

    /// Registers `callback`. It is called with the live state after every
    /// dispatch and after every successful undo, redo or jump.
    pub fn subscribe<F>(&mut self, callback: F)
    where
        F: Fn(&S) + Send + Sync + 'static,
    {
        self.subscribers.push(Box::new(callback));
    }

    /// Number of stored snapshots.
    ///
    /// After an undo or jump away from the newest state, this includes a
    /// snapshot of that newest state. It can therefore exceed the configured
    /// limit by one.
    pub fn history_len(&self) -> usize {
        self.history.len()
    }

    /// Whether [`Store::undo`] would succeed.
    pub const fn can_undo(&self) -> bool {
        self.history_index > 0
    }

    /// Whether [`Store::redo`] would succeed.
    ///
    /// When the cursor is inside the history, `history[history_index]` is the
    /// live state itself. A redo target therefore exists only if a snapshot
    /// is stored *after* the cursor.
    pub fn can_redo(&self) -> bool {
        // `history_index <= history.len()` always holds, so `+ 1` cannot overflow.
        self.history_index + 1 < self.history.len()
    }

    /// Steps back to the previous snapshot.
    ///
    /// Returns `false` and changes nothing if there is nothing to undo.
    pub fn undo(&mut self) -> bool {
        if !self.can_undo() {
            return false;
        }
        self.preserve_tip();
        let target = self.history_index - 1;
        self.move_to(target);
        true
    }

    /// Steps forward to the next snapshot.
    ///
    /// Returns `false` and changes nothing if there is nothing to redo.
    pub fn redo(&mut self) -> bool {
        if !self.can_redo() {
            return false;
        }
        let target = self.history_index + 1;
        self.move_to(target);
        true
    }

    /// Replaces the live state with snapshot `index`.
    ///
    /// Returns `false` and changes nothing if `index` is not a stored
    /// snapshot. Jumping away from the newest state first snapshots it, as
    /// [`Store::undo`] does, so it stays reachable via redo or `jump_to`.
    pub fn jump_to(&mut self, index: usize) -> bool {
        if index >= self.history.len() {
            return false;
        }
        self.preserve_tip();
        self.move_to(index);
        true
    }

    /// Drops all snapshots. The live state and subscribers are kept.
    pub fn clear_history(&mut self) {
        self.history.clear();
        self.history_index = 0;
    }

    /// If the cursor is at the tip (the live state has no snapshot yet), push
    /// a snapshot of it so that navigating backwards does not lose it.
    fn preserve_tip(&mut self) {
        if self.history_index == self.history.len() {
            self.history.push(self.state.clone());
        }
    }

    /// Moves the cursor to `index`, restores that snapshot as the live state,
    /// and notifies subscribers. Callers guarantee `index` is in bounds.
    fn move_to(&mut self, index: usize) {
        self.history_index = index;
        self.state.clone_from(&self.history[index]);
        self.notify_subscribers();
    }

    /// Calls every subscriber with the live state, in registration order.
    fn notify_subscribers(&self) {
        for subscriber in &self.subscribers {
            subscriber(&self.state);
        }
    }
}
// Unit tests for the counter example state, the `Command` helpers, and the
// `Store` history / undo / redo / subscription behavior.
#[cfg(test)]
mod tests {
    use super::*;

    // --- CounterState::update ---
    // `update` returns a `Command`; these tests ignore it because the counter
    // implementation always returns `Command::None`.

    #[test]
    fn test_counter_increment() {
        let mut state = CounterState::default();
        state.update(CounterMessage::Increment);
        assert_eq!(state.count, 1);
    }
    #[test]
    fn test_counter_decrement() {
        let mut state = CounterState { count: 5 };
        state.update(CounterMessage::Decrement);
        assert_eq!(state.count, 4);
    }
    #[test]
    fn test_counter_set() {
        let mut state = CounterState::default();
        state.update(CounterMessage::Set(42));
        assert_eq!(state.count, 42);
    }
    #[test]
    fn test_counter_reset() {
        let mut state = CounterState { count: 100 };
        state.update(CounterMessage::Reset);
        assert_eq!(state.count, 0);
    }

    // --- Command construction ---

    #[test]
    fn test_command_none() {
        let cmd: Command<()> = Command::None;
        assert!(cmd.is_none());
    }
    // `None` is the `#[default]` variant.
    #[test]
    fn test_command_default() {
        let cmd: Command<()> = Command::default();
        assert!(cmd.is_none());
    }
    #[test]
    fn test_command_batch() {
        let cmd: Command<i32> = Command::batch([
            Command::Navigate {
                route: "/a".to_string(),
            },
            Command::Navigate {
                route: "/b".to_string(),
            },
        ]);
        assert!(!cmd.is_none());
        if let Command::Batch(cmds) = cmd {
            assert_eq!(cmds.len(), 2);
        } else {
            panic!("Expected Batch command");
        }
    }
    #[test]
    fn test_command_navigate() {
        let cmd: Command<()> = Command::Navigate {
            route: "/home".to_string(),
        };
        if let Command::Navigate { route } = cmd {
            assert_eq!(route, "/home");
        } else {
            panic!("Expected Navigate command");
        }
    }
    #[test]
    fn test_command_save_state() {
        let cmd: Command<()> = Command::SaveState {
            key: "app_state".to_string(),
        };
        if let Command::SaveState { key } = cmd {
            assert_eq!(key, "app_state");
        } else {
            panic!("Expected SaveState command");
        }
    }
    // JSON round-trip; relies on `serde_json` (presumably a dev-dependency — not
    // imported at the top of this file).
    #[test]
    fn test_counter_serialization() {
        let state = CounterState { count: 42 };
        let json = serde_json::to_string(&state).unwrap();
        let loaded: CounterState = serde_json::from_str(&json).unwrap();
        assert_eq!(loaded.count, 42);
    }

    // --- Command::map ---
    // The mapping closures below are never invoked: none of these commands
    // contains a `Task`, which is the only variant whose output is mapped.

    #[test]
    fn test_command_map() {
        let cmd: Command<i32> = Command::Navigate {
            route: "/test".to_string(),
        };
        let mapped: Command<String> = cmd.map(|_i| "mapped".to_string());
        if let Command::Navigate { route } = mapped {
            assert_eq!(route, "/test");
        } else {
            panic!("Expected Navigate command after map");
        }
    }
    #[test]
    fn test_command_map_none() {
        let cmd: Command<i32> = Command::None;
        let mapped: Command<String> = cmd.map(|i| i.to_string());
        assert!(mapped.is_none());
    }
    #[test]
    fn test_command_batch_map() {
        let cmd: Command<i32> = Command::batch([
            Command::SaveState {
                key: "key1".to_string(),
            },
            Command::SaveState {
                key: "key2".to_string(),
            },
        ]);
        let mapped: Command<String> = cmd.map(|i| format!("val_{i}"));
        if let Command::Batch(cmds) = mapped {
            assert_eq!(cmds.len(), 2);
        } else {
            panic!("Expected Batch command after map");
        }
    }

    // --- Store ---

    #[test]
    fn test_store_new() {
        let store = Store::new(CounterState::default());
        assert_eq!(store.state().count, 0);
    }
    #[test]
    fn test_store_dispatch() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        assert_eq!(store.state().count, 1);
    }
    // Each dispatch snapshots the pre-update state, so three dispatches leave
    // three snapshots.
    #[test]
    fn test_store_history() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert_eq!(store.state().count, 3);
        assert_eq!(store.history_len(), 3);
    }
    #[test]
    fn test_store_undo() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert_eq!(store.state().count, 2);
        assert!(store.can_undo());
        assert!(store.undo());
        assert_eq!(store.state().count, 1);
        assert!(store.undo());
        assert_eq!(store.state().count, 0);
    }
    // Undo back to the initial state, then redo forward to the newest state.
    #[test]
    fn test_store_redo() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        store.undo();
        store.undo();
        assert_eq!(store.state().count, 0);
        assert!(store.can_redo());
        assert!(store.redo());
        assert_eq!(store.state().count, 1);
        assert!(store.redo());
        assert_eq!(store.state().count, 2);
    }
    #[test]
    fn test_store_undo_at_start() {
        let mut store = Store::new(CounterState::default());
        assert!(!store.can_undo());
        assert!(!store.undo());
    }
    #[test]
    fn test_store_redo_at_end() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        assert!(!store.can_redo());
        assert!(!store.redo());
    }
    // Dispatching after an undo discards the redo branch (Set(2)/Set(3)).
    #[test]
    fn test_store_history_truncation() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Set(1));
        store.dispatch(CounterMessage::Set(2));
        store.dispatch(CounterMessage::Set(3));
        store.undo();
        store.undo();
        assert_eq!(store.state().count, 1);
        store.dispatch(CounterMessage::Set(10));
        assert_eq!(store.state().count, 10);
        assert!(!store.redo());
    }
    // Snapshots are taken before each dispatch: index 0 is the initial state,
    // index 2 the state just before `Set(30)` was applied.
    #[test]
    fn test_store_jump_to() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Set(10));
        store.dispatch(CounterMessage::Set(20));
        store.dispatch(CounterMessage::Set(30));
        assert!(store.jump_to(0));
        assert_eq!(store.state().count, 0);
        assert!(store.jump_to(2));
        assert_eq!(store.state().count, 20);
    }
    #[test]
    fn test_store_jump_invalid() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        assert!(!store.jump_to(100));
    }
    #[test]
    fn test_store_clear_history() {
        let mut store = Store::new(CounterState::default());
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert!(store.history_len() > 0);
        store.clear_history();
        assert_eq!(store.history_len(), 0);
        assert!(!store.can_undo());
    }
    // With a limit of 3, older snapshots are evicted; only the upper bound is
    // asserted here.
    #[test]
    fn test_store_with_history_limit() {
        let mut store = Store::with_history_limit(CounterState::default(), 3);
        for i in 1..=10 {
            store.dispatch(CounterMessage::Set(i));
        }
        assert!(store.history_len() <= 3);
    }
    // Subscribers are called once per dispatch.
    #[test]
    fn test_store_subscribe() {
        use std::sync::atomic::{AtomicI32, Ordering};
        use std::sync::Arc;
        let call_count = Arc::new(AtomicI32::new(0));
        let call_count_clone = call_count.clone();
        let mut store = Store::new(CounterState::default());
        store.subscribe(move |_| {
            call_count_clone.fetch_add(1, Ordering::SeqCst);
        });
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
    // A history limit of 0 disables snapshotting entirely.
    #[test]
    fn test_store_no_history() {
        let mut store = Store::with_history_limit(CounterState::default(), 0);
        store.dispatch(CounterMessage::Increment);
        store.dispatch(CounterMessage::Increment);
        assert_eq!(store.history_len(), 0);
        assert!(!store.can_undo());
    }
}