use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Backing counter for [`generate_subscriber_id`]; the first ID it hands out is 1.
static SUBSCRIBER_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Returns a new subscriber ID, distinct from every ID previously returned in this
/// process (barring wrap-around after 2^64 calls).
///
/// `fetch_add` yields the value *before* the increment, so concurrent callers each
/// observe a different number.
pub fn generate_subscriber_id() -> u64 {
    AtomicU64::fetch_add(&SUBSCRIBER_ID_COUNTER, 1, Ordering::SeqCst)
}
/// Observer of `State` values paired with the `Action` that produced them.
///
/// Only [`Subscriber::on_notify`] is required; the subscribe and unsubscribe
/// hooks have empty default implementations.
pub trait Subscriber<State, Action>
where
    State: Send + Sync + Clone,
    Action: Send + Sync + Clone + 'static,
{
    /// Hook receiving a state on subscription. Does nothing by default.
    fn on_subscribe(&self, _state: &State) {}

    /// Receives a state together with an action.
    fn on_notify(&self, state: &State, action: &Action);

    /// Hook for unsubscription. Does nothing by default.
    fn on_unsubscribe(&self) {}
}
/// A shared [`Subscriber`] trait object tagged with a numeric identifier.
///
/// The derived `Clone` copies `id` and clones the `Arc`, so both copies point at
/// the same underlying subscriber.
#[derive(Clone)]
pub(crate) struct SubscriberWithId<
    State: Send + Sync + Clone,
    Action: Send + Sync + Clone + 'static,
> {
    /// Identifier stored at construction time.
    pub id: u64,
    /// The wrapped subscriber.
    pub subscriber: Arc<dyn Subscriber<State, Action> + Send + Sync>,
}
impl<State, Action> SubscriberWithId<State, Action>
where
    State: Send + Sync + Clone,
    Action: Send + Sync + Clone + 'static,
{
    /// Wraps `subscriber` under a fresh ID obtained from [`generate_subscriber_id`].
    pub fn new(subscriber: Arc<dyn Subscriber<State, Action> + Send + Sync>) -> Self {
        Self::with_id(generate_subscriber_id(), subscriber)
    }

    /// Wraps `subscriber` under the caller-supplied `id`; no uniqueness check is made.
    // The allow lets this constructor stay available without a dead-code warning
    // in builds where nothing calls it.
    #[allow(dead_code)]
    pub fn with_id(id: u64, subscriber: Arc<dyn Subscriber<State, Action> + Send + Sync>) -> Self {
        Self { id, subscriber }
    }
}
/// Makes the wrapper usable wherever a [`Subscriber`] is expected by passing every
/// hook straight through to the wrapped trait object.
impl<State, Action> Subscriber<State, Action> for SubscriberWithId<State, Action>
where
    State: Send + Sync + Clone,
    Action: Send + Sync + Clone + 'static,
{
    fn on_subscribe(&self, state: &State) {
        let target = &*self.subscriber;
        target.on_subscribe(state);
    }

    fn on_notify(&self, state: &State, action: &Action) {
        let target = &*self.subscriber;
        target.on_notify(state, action);
    }

    fn on_unsubscribe(&self) {
        let target = &*self.subscriber;
        target.on_unsubscribe();
    }
}
/// Handle for a registration that can be cancelled via [`Subscription::unsubscribe`].
///
/// Implementors must be `Send`.
pub trait Subscription
where
    Self: Send,
{
    /// Cancels the registration this handle represents.
    fn unsubscribe(&self);
}
/// Adapts a closure `F` into a [`Subscriber`] that only reacts to notifications.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than on
/// the type itself, so naming `FnSubscriber<..>` (in fields, aliases, or other
/// generic code) does not force callers to repeat them. Construct one with
/// `FnSubscriber::from(closure)`.
pub struct FnSubscriber<F, State, Action> {
    /// Callback invoked with each `(state, action)` notification.
    func: F,
    /// `State` and `Action` appear only in `F`'s call signature, not in any field,
    /// so this marker is what makes them used type parameters.
    _marker: std::marker::PhantomData<(State, Action)>,
}
/// Every notification is handed to the wrapped closure; the subscribe and
/// unsubscribe hooks keep the trait's empty defaults.
impl<F, State, Action> Subscriber<State, Action> for FnSubscriber<F, State, Action>
where
    F: Fn(&State, &Action),
    State: Send + Sync + Clone,
    Action: Send + Sync + Clone + 'static,
{
    fn on_notify(&self, state: &State, action: &Action) {
        let callback = &self.func;
        callback(state, action);
    }
}
/// Builds a [`FnSubscriber`] directly from a compatible closure.
impl<F, State, Action> From<F> for FnSubscriber<F, State, Action>
where
    F: Fn(&State, &Action),
    State: Send + Sync + Clone,
    Action: Send + Sync + Clone + 'static,
{
    fn from(func: F) -> Self {
        // `PhantomData` is zero-sized; `Default` is simply its only value.
        Self {
            _marker: Default::default(),
            func,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Small state value whose fields the tests inspect after each callback.
    #[derive(Default, Clone)]
    struct TestState {
        counter: i32,
        name: String,
    }

    /// Actions exercised by the tests; both variants are constructed below.
    #[derive(Clone, Debug)]
    enum TestAction {
        IncrementCounter,
        SetName(String),
    }

    /// Subscriber that records every state/action it receives and whether
    /// `on_subscribe` ran. Storage sits behind `Arc<Mutex<_>>`, which makes the
    /// type `Send + Sync` so it can also be wrapped as a trait object.
    struct TestSubscriberWithSubscribe {
        received_states: Arc<Mutex<Vec<TestState>>>,
        received_actions: Arc<Mutex<Vec<TestAction>>>,
        subscribe_called: Arc<Mutex<bool>>,
    }

    impl TestSubscriberWithSubscribe {
        fn new() -> Self {
            Self {
                received_states: Arc::new(Mutex::new(Vec::new())),
                received_actions: Arc::new(Mutex::new(Vec::new())),
                subscribe_called: Arc::new(Mutex::new(false)),
            }
        }

        fn get_received_states(&self) -> Vec<TestState> {
            self.received_states.lock().unwrap().clone()
        }

        fn get_received_actions(&self) -> Vec<TestAction> {
            self.received_actions.lock().unwrap().clone()
        }

        fn was_subscribe_called(&self) -> bool {
            *self.subscribe_called.lock().unwrap()
        }
    }

    impl Subscriber<TestState, TestAction> for TestSubscriberWithSubscribe {
        fn on_subscribe(&self, state: &TestState) {
            self.received_states.lock().unwrap().push(state.clone());
            *self.subscribe_called.lock().unwrap() = true;
        }

        fn on_notify(&self, state: &TestState, action: &TestAction) {
            self.received_states.lock().unwrap().push(state.clone());
            self.received_actions.lock().unwrap().push(action.clone());
        }
    }

    #[test]
    fn test_subscriber_on_subscribe() {
        let subscriber = TestSubscriberWithSubscribe::new();
        let state = TestState {
            counter: 42,
            name: "test".to_string(),
        };
        subscriber.on_subscribe(&state);
        assert!(subscriber.was_subscribe_called());
        let received_states = subscriber.get_received_states();
        assert_eq!(received_states.len(), 1);
        assert_eq!(received_states[0].counter, 42);
        assert_eq!(received_states[0].name, "test");
    }

    #[test]
    fn test_subscriber_on_notify() {
        let subscriber = TestSubscriberWithSubscribe::new();
        let state = TestState {
            counter: 100,
            name: "notify".to_string(),
        };
        let action = TestAction::IncrementCounter;
        subscriber.on_notify(&state, &action);
        let received_states = subscriber.get_received_states();
        let received_actions = subscriber.get_received_actions();
        assert_eq!(received_states.len(), 1);
        assert_eq!(received_actions.len(), 1);
        assert_eq!(received_states[0].counter, 100);
        assert_eq!(received_states[0].name, "notify");
        // Previously this slot duplicated the counter assertion; check the action instead.
        assert!(matches!(received_actions[0], TestAction::IncrementCounter));
    }

    #[test]
    fn test_fn_subscriber() {
        let received_states = Arc::new(Mutex::new(Vec::new()));
        let received_actions = Arc::new(Mutex::new(Vec::new()));
        let fn_subscriber = FnSubscriber::from(|state: &TestState, action: &TestAction| {
            received_states.lock().unwrap().push(state.clone());
            received_actions.lock().unwrap().push(action.clone());
        });
        let state = TestState {
            counter: 42,
            name: "test".to_string(),
        };
        let action = TestAction::IncrementCounter;
        fn_subscriber.on_notify(&state, &action);
        let states = received_states.lock().unwrap();
        let actions = received_actions.lock().unwrap();
        assert_eq!(states.len(), 1);
        assert_eq!(actions.len(), 1);
        assert_eq!(states[0].counter, 42);
        assert_eq!(states[0].name, "test");
        assert!(matches!(actions[0], TestAction::IncrementCounter));
    }

    #[test]
    fn test_multiple_on_subscribe_calls() {
        let subscriber = TestSubscriberWithSubscribe::new();
        let state1 = TestState {
            counter: 10,
            name: "first".to_string(),
        };
        let state2 = TestState {
            counter: 20,
            name: "second".to_string(),
        };
        subscriber.on_subscribe(&state1);
        assert!(subscriber.was_subscribe_called());
        assert_eq!(subscriber.get_received_states().len(), 1);
        assert_eq!(subscriber.get_received_states()[0].counter, 10);
        subscriber.on_subscribe(&state2);
        assert!(subscriber.was_subscribe_called());
        assert_eq!(subscriber.get_received_states().len(), 2);
        assert_eq!(subscriber.get_received_states()[1].counter, 20);
    }

    #[test]
    fn test_default_on_unsubscribe() {
        let subscriber = TestSubscriberWithSubscribe::new();
        subscriber.on_unsubscribe();
        assert!(!subscriber.was_subscribe_called());
        assert_eq!(subscriber.get_received_states().len(), 0);
        assert_eq!(subscriber.get_received_actions().len(), 0);
    }

    #[test]
    fn test_complex_state_and_action() {
        let subscriber = TestSubscriberWithSubscribe::new();
        let complex_state = TestState {
            counter: 999,
            name: "complex_test_state".to_string(),
        };
        let complex_action = TestAction::SetName("new_name".to_string());
        subscriber.on_subscribe(&complex_state);
        subscriber.on_notify(&complex_state, &complex_action);
        assert!(subscriber.was_subscribe_called());
        let states = subscriber.get_received_states();
        let actions = subscriber.get_received_actions();
        assert_eq!(states.len(), 2);
        assert_eq!(actions.len(), 1);
        assert_eq!(states[0].counter, 999);
        assert_eq!(states[0].name, "complex_test_state");
        assert_eq!(states[1].counter, 999);
        assert_eq!(states[1].name, "complex_test_state");
        match &actions[0] {
            TestAction::SetName(name) => assert_eq!(name, "new_name"),
            _ => panic!("Expected SetName action"),
        }
    }

    #[test]
    fn test_subscriber_with_id_delegates_and_assigns_ids() {
        let inner = Arc::new(TestSubscriberWithSubscribe::new());
        let shared: Arc<dyn Subscriber<TestState, TestAction> + Send + Sync> = Arc::clone(&inner);

        let first = SubscriberWithId::new(Arc::clone(&shared));
        let second = SubscriberWithId::new(Arc::clone(&shared));
        assert_ne!(first.id, second.id);
        let fixed = SubscriberWithId::with_id(7, shared);
        assert_eq!(fixed.id, 7);

        let state = TestState {
            counter: 5,
            name: "wrapped".to_string(),
        };
        first.on_subscribe(&state);
        first.on_notify(&state, &TestAction::IncrementCounter);
        first.on_unsubscribe();

        assert!(inner.was_subscribe_called());
        assert_eq!(inner.get_received_states().len(), 2);
        assert_eq!(inner.get_received_actions().len(), 1);
        assert!(matches!(
            inner.get_received_actions()[0],
            TestAction::IncrementCounter
        ));
    }
}