use crate::{BackpressurePolicy, Subscriber, Subscription};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Default capacity value, `16`.
///
/// NOTE(review): presumably the bound used for the store's internal queue(s)
/// when no capacity is configured explicitly (the builder offers
/// `with_capacity`) — confirm at the use sites.
pub const DEFAULT_CAPACITY: usize = 16;

/// Default store name, `"store"`.
///
/// NOTE(review): presumably used when no name is supplied (the builder offers
/// `with_name`) — confirm at the use sites.
pub const DEFAULT_STORE_NAME: &str = "store";
#[derive(Debug, thiserror::Error)]
pub enum StoreError {
#[error("dispatch error: {0}")]
DispatchError(String),
#[error("reducer error: {0}")]
ReducerError(String),
#[error("subscription error: {0}")]
SubscriptionError(String),
#[error("middleware error: {0}")]
MiddlewareError(String),
#[error("initialization error: {0}")]
InitError(String),
#[error("state update failed: {context}, cause: {source}")]
StateUpdateError {
context: String,
source: Box<dyn std::error::Error + Send + Sync>,
},
}
/// A shareable (`Send + Sync`) container for a `State` value that is updated
/// by dispatching `Action` values.
///
/// `State` must be `Send + Sync + Clone + 'static`. `Action` must be
/// `Send + Sync + Clone + Debug + 'static`.
pub trait Store<State, Action>: Send + Sync
where
    State: Send + Sync + Clone + 'static,
    Action: Send + Sync + Clone + std::fmt::Debug + 'static,
{
    /// Returns the current state as an owned `State` value.
    fn get_state(&self) -> State;

    /// Submits `action` to the store.
    ///
    /// NOTE(review): the tests in this file sleep after dispatching before
    /// reading the state. That suggests actions are applied asynchronously —
    /// confirm in the implementation.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if the action is not accepted. The tests expect
    /// [`StoreError::DispatchError`] after [`Store::stop`] has been called.
    fn dispatch(&self, action: Action) -> Result<(), StoreError>;

    /// Registers a shared (`Arc`) subscriber and returns its subscription
    /// handle.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if the subscriber cannot be registered.
    fn add_subscriber(
        &self,
        subscriber: Arc<dyn Subscriber<State, Action> + Send + Sync>,
    ) -> Result<Box<dyn Subscription>, StoreError>;

    /// Registers an owned (boxed) subscriber and returns its subscription
    /// handle.
    ///
    /// NOTE(review): the tests refer to these as "channeled" subscribers.
    /// Presumably notifications are delivered through a dedicated channel using
    /// default capacity and backpressure settings — confirm in the
    /// implementation.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if the subscriber cannot be registered.
    fn subscribed(
        &self,
        subscriber: Box<dyn Subscriber<State, Action> + Send + Sync>,
    ) -> Result<Box<dyn Subscription>, StoreError>;

    /// Like [`Store::subscribed`], but with an explicit `capacity` and
    /// backpressure `policy`.
    ///
    /// The policy operates on `(Instant, State, Action)` tuples.
    /// NOTE(review): presumably the `Instant` records when the notification was
    /// queued — confirm.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if the subscriber cannot be registered.
    fn subscribed_with(
        &self,
        capacity: usize,
        policy: BackpressurePolicy<(Instant, State, Action)>,
        subscriber: Box<dyn Subscriber<State, Action> + Send + Sync>,
    ) -> Result<Box<dyn Subscription>, StoreError>;

    /// Stops the store. The tests expect later `dispatch` calls to fail.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if stopping fails.
    fn stop(&self) -> Result<(), StoreError>;

    /// Stops the store, waiting at most `timeout`.
    ///
    /// NOTE(review): the behaviour when the timeout elapses is not visible
    /// here — confirm in the implementation.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if stopping fails.
    fn stop_timeout(&self, timeout: Duration) -> Result<(), StoreError>;
}
#[cfg(test)]
mod tests {
    //! Tests for the `Store` trait, exercised through `StoreImpl` and `StoreBuilder`.
    //!
    //! NOTE(review): these tests wait for the store with fixed `thread::sleep`
    //! calls before asserting, which suggests actions and notifications are
    //! processed asynchronously. Fixed sleeps can be flaky on a loaded machine;
    //! prefer a deterministic completion signal if the store offers one.

    use super::*;
    use crate::builder::StoreBuilder;
    use crate::BackpressurePolicy;
    use crate::Reducer;
    use crate::StoreImpl;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::thread;
    use std::time::Duration;

    /// State for the `TestAction` tests; the default is `counter: 0, message: ""`.
    #[derive(Debug, Clone, Default, PartialEq)]
    struct TestState {
        counter: i32,
        message: String,
    }

    /// Actions handled by `TestReducer`.
    #[derive(Debug, Clone)]
    enum TestAction {
        Increment,
        Decrement,
        SetMessage(String),
    }

    /// Applies a `TestAction` to a copy of the state. Always returns
    /// `DispatchOp::Dispatch` with an empty second field.
    struct TestReducer;

    impl Reducer<TestState, TestAction> for TestReducer {
        fn reduce(
            &self,
            state: &TestState,
            action: &TestAction,
        ) -> crate::DispatchOp<TestState, TestAction> {
            let mut new_state = state.clone();
            match action {
                TestAction::Increment => new_state.counter += 1,
                TestAction::Decrement => new_state.counter -= 1,
                TestAction::SetMessage(msg) => new_state.message = msg.clone(),
            }
            crate::DispatchOp::Dispatch(new_state, vec![])
        }
    }

    /// Creates a `TestState` store named `"test-store"` with `TestReducer`.
    ///
    /// It passes `16` and `BackpressurePolicy::default()` (presumably the
    /// capacity and backpressure policy) and an empty final argument
    /// (presumably middlewares — confirm against `StoreImpl::new_with`).
    fn create_test_store() -> Arc<StoreImpl<TestState, TestAction>> {
        StoreImpl::new_with(
            TestState::default(),
            vec![Box::new(TestReducer)],
            "test-store".into(),
            16,
            BackpressurePolicy::default(),
            vec![],
        )
        .expect("failed to create test store")
    }

    /// Logs success, or panics with the error, for the result of stopping a store.
    fn assert_stopped<T, E: std::fmt::Debug>(result: Result<T, E>) {
        match result {
            Ok(_) => println!("store stopped"),
            Err(e) => panic!("store stop failed : {:?}", e),
        }
    }

    /// Integer reducer: the new state is `state + action`.
    struct TestChanneledReducer;

    impl Reducer<i32, i32> for TestChanneledReducer {
        fn reduce(&self, state: &i32, action: &i32) -> crate::DispatchOp<i32, i32> {
            crate::DispatchOp::Dispatch(state + action, vec![])
        }
    }

    /// Subscriber that records every `(state, action)` notification it receives.
    struct TestChannelSubscriber {
        received: Arc<Mutex<Vec<(i32, i32)>>>,
    }

    impl TestChannelSubscriber {
        fn new(received: Arc<Mutex<Vec<(i32, i32)>>>) -> Self {
            Self { received }
        }
    }

    impl Subscriber<i32, i32> for TestChannelSubscriber {
        fn on_notify(&self, state: &i32, action: &i32) {
            self.received.lock().unwrap().push((*state, *action));
        }
    }

    /// Subscriber that sleeps for `delay` before recording each notification.
    /// The tests use it to build up backpressure.
    struct SlowSubscriber {
        received: Arc<Mutex<Vec<(i32, i32)>>>,
        delay: Duration,
    }

    impl SlowSubscriber {
        fn new(received: Arc<Mutex<Vec<(i32, i32)>>>, delay: Duration) -> Self {
            Self { received, delay }
        }
    }

    impl Subscriber<i32, i32> for SlowSubscriber {
        fn on_notify(&self, state: &i32, action: &i32) {
            thread::sleep(self.delay);
            self.received.lock().unwrap().push((*state, *action));
        }
    }

    #[test]
    fn test_store_get_state() {
        let store = create_test_store();
        let initial_state = store.get_state();
        assert_eq!(initial_state.counter, 0);
        assert_eq!(initial_state.message, "");
        assert_stopped(store.stop());
    }

    #[test]
    fn test_store_dispatch() {
        let store = create_test_store();

        store.dispatch(TestAction::Increment).unwrap();
        thread::sleep(Duration::from_millis(50));
        assert_eq!(store.get_state().counter, 1);

        store.dispatch(TestAction::SetMessage("Hello".into())).unwrap();
        thread::sleep(Duration::from_millis(50));
        assert_eq!(store.get_state().message, "Hello");

        assert_stopped(store.stop());
    }

    #[test]
    fn test_store_multiple_actions() {
        let store = create_test_store();
        store.dispatch(TestAction::Increment).unwrap();
        store.dispatch(TestAction::Increment).unwrap();
        store.dispatch(TestAction::SetMessage("Test".into())).unwrap();
        store.dispatch(TestAction::Decrement).unwrap();
        thread::sleep(Duration::from_millis(100));

        let final_state = store.get_state();
        assert_eq!(final_state.counter, 1);
        assert_eq!(final_state.message, "Test");

        assert_stopped(store.stop());
    }

    #[test]
    fn test_store_after_stop() {
        let store = create_test_store();
        assert_stopped(store.stop());

        let result = store.dispatch(TestAction::Increment);
        assert!(
            matches!(result, Err(StoreError::DispatchError(_))),
            "Expected DispatchError"
        );
    }

    #[test]
    fn test_store_concurrent_access() {
        // `create_test_store` already returns an `Arc`. Clone it for the worker
        // thread rather than wrapping it in a second `Arc`.
        let store = create_test_store();
        let incrementer = {
            let store = Arc::clone(&store);
            thread::spawn(move || {
                for _ in 0..5 {
                    store.dispatch(TestAction::Increment).unwrap();
                    thread::sleep(Duration::from_millis(10));
                }
            })
        };

        for _ in 0..5 {
            store.dispatch(TestAction::Decrement).unwrap();
            thread::sleep(Duration::from_millis(10));
        }
        incrementer.join().unwrap();

        thread::sleep(Duration::from_millis(100));
        assert_eq!(store.get_state().counter, 0);
        assert_stopped(store.stop());
    }

    #[test]
    fn test_store_builder_configurations() {
        // NOTE(review): the `allow` presumably silences a deprecation warning on
        // one of the builder calls below. The test intentionally keeps
        // covering that API.
        #[allow(deprecated)]
        let store = StoreBuilder::new(TestState::default())
            .with_reducer(Box::new(TestReducer))
            .with_name("custom-store".into())
            .with_capacity(32)
            .with_policy(BackpressurePolicy::DropLatestIf(None))
            .build()
            .unwrap();

        store.dispatch(TestAction::Increment).unwrap();
        thread::sleep(Duration::from_millis(50));
        assert_eq!(store.get_state().counter, 1);

        assert_stopped(store.stop());
    }

    #[test]
    fn test_store_error_handling() {
        let store = create_test_store();
        assert_stopped(store.stop());

        let dispatch_result = store.dispatch(TestAction::Increment);
        assert!(matches!(dispatch_result, Err(StoreError::DispatchError(_))));
        // The rejected action must not have been applied.
        assert_eq!(store.get_state().counter, 0);
    }

    #[test]
    fn test_subscribed_basic_functionality() {
        let store =
            StoreBuilder::new_with_reducer(0, Box::new(TestChanneledReducer)).build().unwrap();
        let received = Arc::new(Mutex::new(Vec::new()));
        let subscriber = Box::new(TestChannelSubscriber::new(received.clone()));
        let subscription = store.subscribed(subscriber).unwrap();

        store.dispatch(1).unwrap();
        store.dispatch(2).unwrap();
        store.dispatch(3).unwrap();
        thread::sleep(Duration::from_millis(100));
        store.stop().unwrap();
        subscription.unsubscribe();

        // Each notification carries the accumulated state and the applied action.
        assert_eq!(*received.lock().unwrap(), vec![(1, 1), (3, 2), (6, 3)]);
    }

    #[test]
    fn test_subscribed_concurrent_subscribers() {
        let store =
            StoreBuilder::new_with_reducer(0, Box::new(TestChanneledReducer)).build().unwrap();
        let received: [Arc<Mutex<Vec<(i32, i32)>>>; 3] = Default::default();
        let subscriptions: Vec<_> = received
            .iter()
            .map(|recorder| {
                store
                    .subscribed(Box::new(TestChannelSubscriber::new(Arc::clone(recorder))))
                    .unwrap()
            })
            .collect();

        for i in 1..=10 {
            store.dispatch(i).unwrap();
        }
        thread::sleep(Duration::from_millis(200));
        store.stop().unwrap();
        for subscription in subscriptions {
            subscription.unsubscribe();
        }

        let states: Vec<Vec<(i32, i32)>> =
            received.iter().map(|recorder| recorder.lock().unwrap().clone()).collect();
        for recorded in &states {
            assert_eq!(recorded.len(), 10);
        }
        // Every subscriber must observe the same notifications in the same order.
        assert_eq!(states[0], states[1]);
        assert_eq!(states[1], states[2]);
    }

    #[test]
    fn test_subscribed_drop_latest_if_policy() {
        let store =
            StoreBuilder::new_with_reducer(0, Box::new(TestChanneledReducer)).build().unwrap();
        let received = Arc::new(Mutex::new(Vec::new()));
        let subscriber = Box::new(SlowSubscriber::new(
            received.clone(),
            Duration::from_millis(50),
        ));
        // Notifications whose action is below 5 are eligible to be dropped.
        let predicate = Box::new(|(_, _, action): &(Instant, i32, i32)| *action < 5);
        let policy = BackpressurePolicy::DropLatestIf(Some(predicate));
        let subscription = store.subscribed_with(2, policy, subscriber).unwrap();

        for i in 1..=10 {
            store.dispatch(i).unwrap();
        }
        thread::sleep(Duration::from_millis(300));
        store.stop().unwrap();
        subscription.unsubscribe();

        let states = received.lock().unwrap();
        assert!(
            states.len() < 10,
            "Expected fewer than 10 messages due to backpressure, got {}",
            states.len()
        );
        assert!(
            !states.is_empty(),
            "Expected at least some messages to be received"
        );
    }

    #[test]
    fn test_subscribed_error_handling() {
        // A subscription cancelled before any dispatch must receive nothing.
        let store =
            StoreBuilder::new_with_reducer(0, Box::new(TestChanneledReducer)).build().unwrap();
        let received = Arc::new(Mutex::new(Vec::new()));
        let subscriber = Box::new(TestChannelSubscriber::new(received.clone()));
        let subscription = store.subscribed(subscriber).unwrap();
        subscription.unsubscribe();

        store.dispatch(1).unwrap();
        thread::sleep(Duration::from_millis(50));

        // The guard is a temporary, so the lock is released before `stop`. A
        // late notification therefore cannot block on it if `stop` flushes
        // pending deliveries (not visible here).
        assert!(received.lock().unwrap().is_empty());
        store.stop().unwrap();
    }

    #[test]
    fn test_subscribed_thread_lifecycle() {
        let store =
            StoreBuilder::new_with_reducer(0, Box::new(TestChanneledReducer)).build().unwrap();
        let received = Arc::new(Mutex::new(Vec::new()));
        let subscriber = Box::new(TestChannelSubscriber::new(received.clone()));
        let subscription = store.subscribed(subscriber).unwrap();

        store.dispatch(1).unwrap();
        thread::sleep(Duration::from_millis(50));
        subscription.unsubscribe();
        store.dispatch(2).unwrap();
        thread::sleep(Duration::from_millis(50));

        // Only the notification sent before `unsubscribe` may be recorded. The
        // guard is dropped at the end of this statement, before `stop` runs.
        assert_eq!(*received.lock().unwrap(), vec![(1, 1)]);
        store.stop().unwrap();
    }

    #[test]
    fn test_subscribed_mixed_with_add_subscriber() {
        let store =
            StoreBuilder::new_with_reducer(0, Box::new(TestChanneledReducer)).build().unwrap();

        let received_main = Arc::new(Mutex::new(Vec::new()));
        let subscriber_main = Arc::new(TestChannelSubscriber::new(received_main.clone()));
        let _subscription_main = store.add_subscriber(subscriber_main).unwrap();

        let received_channeled = Arc::new(Mutex::new(Vec::new()));
        let subscriber_channeled = Box::new(TestChannelSubscriber::new(received_channeled.clone()));
        let subscription_channeled = store.subscribed(subscriber_channeled).unwrap();

        for i in 1..=5 {
            store.dispatch(i).unwrap();
        }
        thread::sleep(Duration::from_millis(100));
        store.stop().unwrap();
        subscription_channeled.unsubscribe();

        let states_main = received_main.lock().unwrap();
        let states_channeled = received_channeled.lock().unwrap();
        assert_eq!(states_main.len(), 5);
        assert_eq!(states_channeled.len(), 5);
        // Both delivery paths must report identical notifications in the same order.
        assert_eq!(*states_main, *states_channeled);
    }
}