use std::sync::{Arc, Mutex, MutexGuard, RwLock};
use futures::future::BoxFuture;
use tokio::sync::Notify;
/// Marker for any type that can be stored as a store's state.
///
/// Automatically implemented for every `Clone + Send + Sync + 'static` type.
pub trait StateTrait: Clone + Send + Sync + 'static {}
impl<S: Clone + Send + Sync + 'static> StateTrait for S {}

/// Marker for any type that can be dispatched as a store event.
///
/// Automatically implemented for every `Clone + Send + Sync + 'static` type.
pub trait EventTrait: Clone + Send + Sync + 'static {}
impl<E: Clone + Send + Sync + 'static> EventTrait for E {}
/// Named bound for a synchronous middleware over the [`Slice`] `T`.
///
/// The shape matches the closures accepted by `Store::add_middleware_sync`. The
/// arguments are a proxy to the store, the event being dispatched, and `next`,
/// the remainder of the dispatch chain, which the middleware may call with the
/// event.
///
/// NOTE(review): this trait is not referenced elsewhere in the visible code;
/// `add_middleware_sync` spells out the `Fn` bound directly.
pub trait Middleware<T: Slice>:
    Fn(&StoreProxy<T, T::Event>, T::Event, &Box<dyn Fn(T::Event)>) + 'static
{
}
// Blanket impl: any `'static` function or closure with the matching signature
// is a `Middleware<T>`.
impl<T, U> Middleware<T> for U
where
    T: Slice,
    U: Fn(&StoreProxy<T, T::Event>, T::Event, &Box<dyn Fn(T::Event)>) + 'static,
{
}
/// State container that applies events to `State` through a reducer, which
/// may be wrapped by synchronous and asynchronous middleware.
///
/// Events are queued, either directly with `queue` or by the async dispatch
/// chain, and applied as a batch by `process`.
pub struct Store<State, Event>
where
    State: StateTrait,
    Event: EventTrait,
{
    /// Current state, shared with every `StoreProxy` created from this store.
    state: Arc<RwLock<State>>,
    /// Events waiting to be applied by the next `process` call.
    events: Arc<Mutex<Vec<Event>>>,
    /// The reducer passed to `new`; the sync chain captures its own copy.
    reducer: fn(&mut State, Event),
    /// Synchronous dispatch chain: the reducer, wrapped by each sync middleware.
    dispatch_sync: Box<dyn Fn(Event)>,
    /// Asynchronous dispatch chain. Its innermost link queues the event and
    /// waits on `notify`; each async middleware wraps the previous chain.
    dispatch_async: Arc<dyn Fn(Event) -> BoxFuture<'static, ()> + Send + Sync>,
    /// Runtime owned by the store. Proxies get a handle to it and spawn
    /// async dispatches onto it.
    rt: tokio::runtime::Runtime,
    /// Signalled by `process` after a batch of events has been applied.
    notify: Arc<tokio::sync::Notify>,
}
/// Cloneable handle to a `Store`, created by `Store::proxy` and passed to
/// middleware.
#[derive(Clone)]
pub struct StoreProxy<State, Event>
where
    State: StateTrait,
    Event: EventTrait,
{
    /// The store's live state (same `Arc` as the store).
    state: Arc<RwLock<State>>,
    /// Copy of the state taken when this proxy was created. It is not
    /// refreshed afterwards.
    last_state: Arc<State>,
    /// The store's pending-event queue (same `Arc` as the store).
    events: Arc<Mutex<Vec<Event>>>,
    /// The store's async dispatch chain as it was when this proxy was created.
    dispatch_async: Arc<dyn Fn(Event) -> BoxFuture<'static, ()> + Send + Sync>,
    /// Handle to the store's runtime, used to spawn async dispatches.
    rt: tokio::runtime::Handle,
}
impl<State, Event> StoreProxy<State, Event>
where
    State: StateTrait,
    Event: EventTrait,
{
    /// Read-locks the store's live state.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn state(&self) -> std::sync::RwLockReadGuard<State> {
        let locked = self.state.read();
        locked.unwrap()
    }

    /// Returns a clone of the store's live state.
    pub fn get_state(&self) -> State {
        let guard = self.state();
        State::clone(&guard)
    }

    /// Returns the state snapshot taken when this proxy was created.
    pub fn last(&self) -> &State {
        &self.last_state
    }

    /// Alias of [`StoreProxy::dispatch`].
    pub fn queue(&self, event: impl Into<Event>) {
        self.dispatch(event)
    }

    /// Runs the async dispatch chain for `event` as a detached task on the
    /// store's runtime. Returns immediately without waiting for the task.
    pub fn dispatch(&self, event: impl Into<Event>) {
        let event: Event = event.into();
        let task = (self.dispatch_async)(event);
        self.rt.spawn(task);
    }
}
impl<State, Event> Store<State, Event>
where
    State: StateTrait,
    Event: EventTrait,
{
    /// Creates a store holding `s` whose events are applied by `reducer`.
    ///
    /// The store owns a multi-threaded tokio runtime (timer enabled). Async
    /// dispatches made through a [`StoreProxy`] are spawned onto it.
    ///
    /// # Panics
    ///
    /// Panics if the tokio runtime cannot be built.
    pub fn new(s: State, reducer: fn(&mut State, Event)) -> Self {
        let state = Arc::new(RwLock::new(s));
        let events: Arc<Mutex<Vec<Event>>> = Arc::new(Mutex::new(Vec::new()));
        let notify = Arc::new(Notify::new());
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_time()
            .build()
            .expect("failed to build the store's tokio runtime");

        // Innermost sync link: apply the event to the state under the write lock.
        let reduce_state = Arc::clone(&state);
        let dispatch_sync: Box<dyn Fn(Event)> = Box::new(move |event| {
            reducer(&mut *reduce_state.write().unwrap(), event);
        });

        // Innermost async link: queue the event, then wait for `process`.
        let queue_events = Arc::clone(&events);
        let queue_notify = Arc::clone(&notify);
        let dispatch_async: Arc<dyn Fn(Event) -> BoxFuture<'static, ()> + Send + Sync> =
            Arc::new(move |event| {
                let events = Arc::clone(&queue_events);
                let notify = Arc::clone(&queue_notify);
                Box::pin(async move {
                    // Register for the wakeup *before* queueing. `notify_waiters`
                    // only wakes `Notified` futures that already exist. If the
                    // future were created after the push, a `process` call in
                    // between would apply the event but miss this waiter, and it
                    // would hang until a later `process`.
                    let processed = notify.notified();
                    events.lock().unwrap().push(event);
                    processed.await;
                })
            });

        Self {
            state,
            events,
            reducer,
            dispatch_sync,
            dispatch_async,
            rt,
            notify,
        }
    }

    /// Replaces the current state with `state` and keeps all middleware.
    ///
    /// Proxies share the same state cell, so they see the new value through
    /// `StoreProxy::state`.
    pub fn with_state(self, state: State) -> Self {
        *self.state.write().unwrap() = state;
        self
    }

    /// Wraps the synchronous dispatch chain with `middleware`.
    ///
    /// The middleware receives a proxy created now, the event, and the
    /// previous chain as `next`. It chooses whether to call `next(event)`.
    /// Middleware added later runs before middleware added earlier.
    pub fn add_middleware_sync(
        self,
        middleware: impl Fn(&StoreProxy<State, Event>, Event, &Box<dyn Fn(Event)>) + 'static,
    ) -> Self {
        let proxy = self.proxy();
        // Move the current chain out explicitly so the new link owns it; the
        // remaining fields are carried over by the struct update below.
        let next = self.dispatch_sync;
        Self {
            dispatch_sync: Box::new(move |event| middleware(&proxy, event, &next)),
            ..self
        }
    }

    /// Boxed-closure variant of [`Store::add_middleware_sync`].
    pub fn add_middleware_sync_boxed(
        self,
        middleware: Box<dyn Fn(&StoreProxy<State, Event>, Event, &Box<dyn Fn(Event)>)>,
    ) -> Self {
        // `Box<dyn Fn(..)>` implements `Fn(..)` itself, so reuse the generic path.
        self.add_middleware_sync(middleware)
    }

    /// Wraps the asynchronous dispatch chain with `middleware`.
    ///
    /// The middleware receives a clone of a proxy created now, the event, and
    /// the previous async chain as `next`. It returns the future to run.
    pub fn add_middleware_async(
        self,
        middleware: impl Fn(
                StoreProxy<State, Event>,
                Event,
                Arc<dyn Fn(Event) -> BoxFuture<'static, ()> + Send + Sync + 'static>,
            ) -> BoxFuture<'static, ()>
            + 'static
            + Send
            + Sync,
    ) -> Self {
        let proxy = self.proxy();
        let next = Arc::clone(&self.dispatch_async);
        let middleware = Arc::new(middleware);
        Self {
            dispatch_async: Arc::new(move |event| {
                let mw = Arc::clone(&middleware);
                let proxy_for_call = proxy.clone();
                let next_for_call = Arc::clone(&next);
                // Call the middleware only when the future is first polled
                // (e.g. on the task spawned by `StoreProxy::dispatch`), not at the
                // moment the dispatch function is invoked.
                Box::pin(async move { mw(proxy_for_call, event, next_for_call).await })
            }),
            ..self
        }
    }

    /// Locks and returns the pending-event queue.
    fn events(&self) -> MutexGuard<Vec<Event>> {
        self.events.lock().unwrap()
    }

    /// Read-locks the current state.
    pub fn state(&self) -> std::sync::RwLockReadGuard<State> {
        self.state.read().unwrap()
    }

    /// Queues `e` to be applied by the next [`Store::process`] call.
    pub fn queue(&self, e: Event) {
        self.events().push(e);
    }

    /// Applies every queued event through the sync chain, then wakes all
    /// pending async dispatches.
    pub fn process(&self) {
        // Take the batch and release the queue lock before dispatching, so
        // concurrent async dispatches are not blocked while the batch runs.
        // Events queued meanwhile are applied on the next call.
        let pending: Vec<Event> = self.events().drain(..).collect();
        for event in pending {
            (self.dispatch_sync)(event);
        }
        // NOTE(review): this also wakes async dispatches whose event was queued
        // after the drain above (their event is applied on the next call).
        self.notify.notify_waiters();
    }

    /// Creates a proxy that shares this store's state, queue and runtime.
    ///
    /// The proxy keeps a copy of the state and the async dispatch chain as they
    /// are now. Middleware added later is not part of its chain, and its
    /// `last()` value is not refreshed.
    pub fn proxy(&self) -> StoreProxy<State, Event> {
        StoreProxy {
            state: Arc::clone(&self.state),
            last_state: Arc::new(self.get_state()),
            events: Arc::clone(&self.events),
            dispatch_async: Arc::clone(&self.dispatch_async),
            rt: self.rt.handle().clone(),
        }
    }

    /// Returns a clone of the current state.
    pub fn get_state(&self) -> State {
        self.state().clone()
    }
}
/// Builds a [`Store`] for the [`Slice`] `T`, using `T`'s own reducer.
pub fn create_store<T>(state: T) -> Store<T, T::Event>
where
    T: Slice + StateTrait,
{
    let reducer: fn(&mut T, T::Event) = T::reducer;
    Store::new(state, reducer)
}
/// A piece of state together with the events it reacts to and the function
/// that applies them.
pub trait Slice: Clone + Send + Sync {
    /// Event type accepted by [`Slice::reducer`].
    type Event: EventTrait;

    /// Applies `event` to `state` in place.
    fn reducer(state: &mut Self, event: Self::Event);
}
#[cfg(test)]
mod test {
    use super::*;
    use std::{cell::Cell, rc::Rc};

    /// Minimal counter state used to exercise the store.
    #[derive(Clone, Debug)]
    struct TestState {
        count: i32,
    }

    /// Events understood by `TestState`'s reducer.
    #[derive(Clone, Debug)]
    enum TestEvent {
        Increment,
        Decrement,
    }

    impl Slice for TestState {
        type Event = TestEvent;
        fn reducer(state: &mut Self, event: Self::Event) {
            let delta = match event {
                TestEvent::Increment => 1,
                TestEvent::Decrement => -1,
            };
            state.count += delta;
        }
    }

    #[test]
    fn test_basic() {
        let store = create_store(TestState { count: 0 });
        for _ in 0..2 {
            store.queue(TestEvent::Increment);
        }
        store.process();
        assert_eq!(store.get_state().count, 2);
    }

    #[test]
    fn test_middleware() {
        let hits = Rc::new(Cell::new(0));

        // First middleware: +2 per event (applied as two +1 steps).
        let first = Rc::clone(&hits);
        let store = create_store(TestState { count: 0 }).add_middleware_sync(
            move |_proxy, evt, next| {
                first.set(first.get() + 1);
                first.set(first.get() + 1);
                next(evt)
            },
        );
        store.queue(TestEvent::Increment);
        store.queue(TestEvent::Increment);
        store.process();
        assert_eq!(store.get_state().count, 2);

        // Second middleware wraps the first: +3, then +2 from the first.
        let second = Rc::clone(&hits);
        let store = store.add_middleware_sync(move |_proxy, evt, next| {
            second.set(second.get() + 3);
            next(evt)
        });
        store.queue(TestEvent::Increment);
        store.process();
        assert_eq!(hits.get(), 9);
    }
}