use crate::reducer::Reducer;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
/// Identifier returned by [`Store::subscribe`] and accepted by [`Store::unsubscribe`].
pub type SubscriptionId = usize;

/// State shared behind a mutex so each dispatch reads and replaces it atomically.
type SharedState<S> = Arc<Mutex<S>>;

/// A subscriber callback.
///
/// Reference-counted (rather than boxed) so `notify_subscribers` can take a cheap
/// snapshot of the registered callbacks and release the subscriber lock before
/// invoking any of them.
type Subscriber<State> = Arc<dyn Fn(&State) + Send + Sync>;

/// Registered subscribers keyed by their subscription id.
type SubscriberMap<State> = Arc<Mutex<HashMap<SubscriptionId, Subscriber<State>>>>;

/// A Redux-style store: holds the current state, applies actions to it through a
/// replaceable [`Reducer`], and notifies subscribers after every dispatch.
pub struct Store<State, Action> {
    /// Current state; replaced wholesale by each dispatch.
    state: SharedState<State>,
    /// Reducer used by `dispatch`/`dispatch_batch`; swappable via `replace_reducer`.
    reducer: Arc<Mutex<Box<dyn Reducer<State, Action> + Send + Sync>>>,
    /// Registered callbacks keyed by the id handed out by `subscribe`.
    subscribers: SubscriberMap<State>,
    /// Source of subscription ids; incremented on every `subscribe`.
    next_subscriber_id: AtomicUsize,
}

impl<State: Clone + Send + 'static, Action: Send + 'static> Store<State, Action> {
    /// Creates a store holding `initial_state` that reduces actions with `reducer`.
    pub fn new(
        initial_state: State,
        reducer: Box<dyn Reducer<State, Action> + Send + Sync>,
    ) -> Self {
        Self {
            state: Arc::new(Mutex::new(initial_state)),
            reducer: Arc::new(Mutex::new(reducer)),
            subscribers: Arc::new(Mutex::new(HashMap::new())),
            next_subscriber_id: AtomicUsize::new(0),
        }
    }

    /// Applies `action` to the current state and then notifies every subscriber
    /// with the resulting state.
    ///
    /// Notification happens synchronously, before this method returns, and after
    /// the state and reducer locks have been released. Subscribers may therefore
    /// call back into the store (`get_state`, `dispatch`, `subscribe`, ...).
    /// Because the locks are released before notifying, concurrent dispatches from
    /// different threads may deliver their notifications in a different order than
    /// the state transitions were applied.
    ///
    /// # Panics
    /// Panics if the state or reducer mutex is poisoned, i.e. another thread
    /// panicked while holding it (for example inside a reducer).
    pub fn dispatch(&self, action: Action) {
        let new_state = {
            let mut state = self.state.lock().unwrap();
            let reducer = self.reducer.lock().unwrap();
            let next = reducer.reduce(&state, &action);
            *state = next;
            state.clone()
        };
        self.notify_subscribers(&new_state);
    }

    /// Applies `actions` in order while holding the state and reducer locks for
    /// the whole batch. Subscribers are then notified once, with the final state.
    ///
    /// An empty batch is a no-op and sends no notification.
    ///
    /// # Panics
    /// Same conditions as [`Store::dispatch`].
    pub fn dispatch_batch(&self, actions: Vec<Action>) {
        if actions.is_empty() {
            return;
        }
        let new_state = {
            let mut state = self.state.lock().unwrap();
            let reducer = self.reducer.lock().unwrap();
            for action in &actions {
                let next = reducer.reduce(&state, action);
                *state = next;
            }
            state.clone()
        };
        self.notify_subscribers(&new_state);
    }

    /// Registers `f` to be called with the new state after every dispatch and
    /// returns an id that can later be passed to [`Store::unsubscribe`].
    pub fn subscribe<F>(&self, f: F) -> SubscriptionId
    where
        F: Fn(&State) + Send + Sync + 'static,
    {
        let id = self.next_subscriber_id.fetch_add(1, Ordering::SeqCst);
        self.subscribers.lock().unwrap().insert(id, Arc::new(f));
        id
    }

    /// Removes the subscriber registered under `id`.
    ///
    /// Returns `true` if such a subscriber existed. A notification round that
    /// is already running on another thread works from a snapshot, so it may
    /// still invoke the removed callback once.
    pub fn unsubscribe(&self, id: SubscriptionId) -> bool {
        self.subscribers.lock().unwrap().remove(&id).is_some()
    }

    /// Returns a clone of the current state.
    pub fn get_state(&self) -> State {
        self.state.lock().unwrap().clone()
    }

    /// Runs `f` against the current state without cloning it and returns its result.
    ///
    /// The state lock is held while `f` runs. `f` must not call `dispatch`,
    /// `dispatch_batch`, `get_state` or `with_state` on the same store, or it
    /// will deadlock (std's `Mutex` is not reentrant).
    pub fn with_state<R, F>(&self, f: F) -> R
    where
        F: FnOnce(&State) -> R,
    {
        let state = self.state.lock().unwrap();
        f(&state)
    }

    /// Swaps in `new_reducer` for subsequent dispatches.
    ///
    /// A dispatch already in progress keeps the reducer lock until it finishes,
    /// so it completes with the previous reducer.
    pub fn replace_reducer(&self, new_reducer: Box<dyn Reducer<State, Action> + Send + Sync>) {
        let mut reducer = self.reducer.lock().unwrap();
        *reducer = new_reducer;
    }

    /// Returns the number of currently registered subscribers.
    pub fn subscriber_count(&self) -> usize {
        self.subscribers.lock().unwrap().len()
    }

    /// Invokes every registered subscriber with `new_state`.
    fn notify_subscribers(&self, new_state: &State) {
        // Snapshot the callbacks and release the lock before calling them. The
        // guard is a temporary, so it is dropped at the end of this statement.
        // Holding it during the calls would deadlock any subscriber that
        // re-enters the store (subscribe/unsubscribe/dispatch). It would also let
        // a panicking subscriber poison the map for every later caller.
        let snapshot: Vec<Subscriber<State>> =
            self.subscribers.lock().unwrap().values().cloned().collect();
        for subscriber in snapshot {
            subscriber(new_state);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::create_reducer;
    use std::thread;

    #[derive(Clone, Debug, PartialEq)]
    struct TestState {
        counter: i32,
    }

    #[derive(Clone)]
    enum TestAction {
        Increment,
        Decrement,
        SetValue(i32),
    }

    /// Builds a store starting at `counter: 0` whose reducer handles every
    /// `TestAction` variant.
    fn create_test_store() -> Store<TestState, TestAction> {
        let reducer = create_reducer(|state: &TestState, action: &TestAction| match action {
            TestAction::Increment => TestState {
                counter: state.counter + 1,
            },
            TestAction::Decrement => TestState {
                counter: state.counter - 1,
            },
            TestAction::SetValue(val) => TestState { counter: *val },
        });
        Store::new(TestState { counter: 0 }, Box::new(reducer))
    }

    /// Subscribes a callback that appends each notified counter value to a shared
    /// log. Returns that log together with the subscription id.
    fn record_notifications(
        store: &Store<TestState, TestAction>,
    ) -> (Arc<Mutex<Vec<i32>>>, SubscriptionId) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let id = store.subscribe(move |state: &TestState| {
            sink.lock().unwrap().push(state.counter);
        });
        (log, id)
    }

    #[test]
    fn test_basic_operations() {
        let store = create_test_store();
        assert_eq!(store.get_state().counter, 0);
        store.dispatch(TestAction::Increment);
        assert_eq!(store.get_state().counter, 1);
        store.dispatch(TestAction::Decrement);
        assert_eq!(store.get_state().counter, 0);
        store.dispatch(TestAction::SetValue(42));
        assert_eq!(store.get_state().counter, 42);
    }

    #[test]
    fn test_subscribe_and_unsubscribe() {
        let store = create_test_store();
        assert_eq!(store.subscriber_count(), 0);
        let (notifications, id) = record_notifications(&store);
        assert_eq!(store.subscriber_count(), 1);

        // Notification is synchronous, so no waiting is needed before checking.
        store.dispatch(TestAction::Increment);
        store.dispatch(TestAction::Increment);
        assert_eq!(*notifications.lock().unwrap(), vec![1, 2]);

        assert!(store.unsubscribe(id));
        assert_eq!(store.subscriber_count(), 0);
        assert!(!store.unsubscribe(id));

        store.dispatch(TestAction::Increment);
        assert_eq!(notifications.lock().unwrap().len(), 2);
    }

    #[test]
    fn test_dispatch_batch() {
        let store = create_test_store();
        let (notifications, _) = record_notifications(&store);
        store.dispatch_batch(vec![
            TestAction::Increment,
            TestAction::Increment,
            TestAction::Increment,
        ]);
        // A single notification for the whole batch, carrying the final state.
        assert_eq!(*notifications.lock().unwrap(), vec![3]);
        assert_eq!(store.get_state().counter, 3);
    }

    #[test]
    fn test_dispatch_empty_batch_is_noop() {
        let store = create_test_store();
        let (notifications, _) = record_notifications(&store);
        store.dispatch_batch(Vec::new());
        assert!(notifications.lock().unwrap().is_empty());
        assert_eq!(store.get_state().counter, 0);
    }

    #[test]
    fn test_with_state() {
        let store = create_test_store();
        store.dispatch(TestAction::SetValue(100));
        let result = store.with_state(|state| state.counter * 2);
        assert_eq!(result, 200);
        assert_eq!(store.get_state().counter, 100);
    }

    #[test]
    fn test_concurrent_access() {
        let store = Arc::new(create_test_store());
        let notified = Arc::new(AtomicUsize::new(0));
        let notified_sink = Arc::clone(&notified);
        store.subscribe(move |_: &TestState| {
            notified_sink.fetch_add(1, Ordering::SeqCst);
        });

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let store = Arc::clone(&store);
                thread::spawn(move || {
                    for _ in 0..100 {
                        store.dispatch(TestAction::Increment);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(store.get_state().counter, 1000);
        // Every dispatch notifies exactly once.
        assert_eq!(notified.load(Ordering::SeqCst), 1000);
    }

    #[test]
    fn test_replace_reducer() {
        let store = create_test_store();
        store.dispatch(TestAction::Increment);
        assert_eq!(store.get_state().counter, 1);
        let new_reducer = create_reducer(|state: &TestState, action: &TestAction| match action {
            TestAction::Increment => TestState {
                counter: state.counter + 10,
            },
            _ => state.clone(),
        });
        store.replace_reducer(Box::new(new_reducer));
        store.dispatch(TestAction::Increment);
        assert_eq!(store.get_state().counter, 11);
    }
}