use async_trait::async_trait;
#[cfg(feature = "timestamp")]
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use std::{
any::{Any, type_name},
cmp::Eq,
fmt::Debug,
hash::Hash,
pin::Pin,
sync::Arc,
};
use thiserror::Error;
use tokio::{
select,
sync::{MutexGuard, RwLock, broadcast, mpsc},
};
use tokio_util::sync::CancellationToken;
use tracing::instrument;
/// Type-erased registry of sources and subscription handles, keyed by tag `G`.
///
/// Both maps live behind an `Arc`, so cloning a `StateMachine` yields another handle to
/// the same underlying maps rather than a copy of their contents.
#[derive(Clone, Debug)]
pub struct StateMachine<G>
where
    G: Eq + Hash,
{
    // Each value is a boxed `Source<S>` for some `S`; recovered by downcasting in `source`.
    sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
    // Each value is a boxed `Handle<T>` for some `T`; recovered by downcasting in `handle`.
    handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
}
impl<G> Default for StateMachine<G>
where
G: Eq + Hash,
{
fn default() -> Self {
Self {
sources: Default::default(),
handles: Default::default(),
}
}
}
impl<G> StateMachine<G>
where
G: Clone + Debug + Eq + Hash,
{
pub fn new() -> Self {
Default::default()
}
fn add_source<S>(&self, tag: G, source: Source<S>)
where
S: 'static + Send + Sync,
{
assert!(
!self.sources.contains_key(&tag),
"Source already exist, tag -- {:?}, type -- {:?}",
tag,
type_name::<S>()
);
self.sources.insert(tag, Box::new(source));
}
fn del_source(&self, tag: &G) -> bool {
self.sources.remove(tag).is_some()
}
fn has_source(&self, tag: &G) -> bool {
self.sources.contains_key(tag)
}
async fn source<S>(&self, tag: &G) -> Source<S>
where
S: 'static + Clone,
{
let opt_source_box = self.sources.get(tag);
assert!(
opt_source_box.is_some(),
"source does not exist, tag -- {:?}",
tag
);
let source_box = opt_source_box.unwrap();
let opt_source = source_box.downcast_ref::<Source<S>>();
assert!(
opt_source.is_some(),
"source does not exist, tag -- {:?}, type -- {}",
tag,
type_name::<S>()
);
let source = opt_source.unwrap();
(*source).clone()
}
async fn source_value<S>(&self, tag: &G) -> S
where
S: 'static + Clone + Default + PartialEq + Send,
{
self.source(tag).await.value().await
}
async fn source_value_ex<S>(&self, tag: &G) -> Value<S>
where
S: 'static + Clone + Default + PartialEq + Send,
{
self.source(tag).await.value_ex().await
}
fn add_handle<T>(&self, tag: G, handle: Handle<T>)
where
T: 'static + Send + Sync,
{
assert!(
!self.handles.contains_key(&tag),
"duplicate tag for handle -- {:?}",
tag
);
self.handles.insert(tag, Box::new(handle));
}
fn del_handle(&self, tag: &G) -> bool {
self.handles.remove(tag).is_some()
}
fn has_handle(&self, tag: &G) -> bool {
self.handles.contains_key(tag)
}
async fn handle<T>(&self, tag: &G) -> Handle<T>
where
T: 'static + Clone,
{
let opt_handle_box = self.handles.get(tag);
assert!(
opt_handle_box.is_some(),
"handle does not exist, tag -- {:?}",
tag
);
let handle_box = opt_handle_box.unwrap();
let opt_handle = handle_box.downcast_ref::<Handle<T>>();
assert!(
opt_handle.is_some(),
"handle does not exist, tag -- {:?}, type -- {}",
tag,
type_name::<T>()
);
opt_handle.unwrap().clone()
}
async fn handle_value<T>(&self, tag: &G) -> T
where
T: 'static + Clone + PartialEq,
{
self.handle(tag).await.value().await
}
async fn handle_value_ex<T>(&self, tag: &G) -> Value<T>
where
T: 'static + Clone + PartialEq,
{
self.handle(tag).await.value_ex().await
}
}
/// Implemented by types that expose a [`StateMachine`] together with a `()` mutex guard.
///
/// This is the only trait an application type must implement to gain the
/// `UseStateMachine` helpers (via the blanket impl).
#[async_trait]
pub trait HasStateMachine<G>
where
    G: Clone + Debug + Eq + Hash,
{
    /// Acquires the lock that `UseStateHandle::subscribe` holds while it runs
    /// `on_change`, which serializes change callbacks for this object.
    /// NOTE(review): presumably one mutex shared by all subscriptions — confirm in implementors.
    async fn lock(&self) -> MutexGuard<'_, ()>;
    /// Returns the state machine this object operates on.
    ///
    /// The `UseStateMachine` helpers call this on every operation, so implementations are
    /// expected to return a handle to the same shared maps each time (e.g. a clone).
    /// NOTE(review): confirm implementors never return a fresh, empty machine.
    async fn state_machine(&self) -> StateMachine<G>;
}
/// Convenience operations on sources and handles for any [`HasStateMachine`] implementor.
///
/// Every method fetches the machine via `state_machine()` and forwards to it. Methods
/// that look up a source or handle panic if it is missing or registered with a
/// different value type.
#[async_trait]
pub trait UseStateMachine<G>: HasStateMachine<G>
where
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
    /// Registers a source of type `S` under `tag`, initialised to `S::default()`.
    ///
    /// # Panics
    /// Panics if a source already exists under `tag`.
    async fn add_source<S>(&self, tag: G)
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .add_source(tag, Source::<S>::default());
    }

    /// Registers a source under `tag` with an explicit initial value and broadcast
    /// channel capacity.
    ///
    /// # Panics
    /// Panics if a source already exists under `tag`, or if `chan_capacity` is rejected
    /// by the broadcast channel (e.g. zero).
    async fn add_source_ex<S>(&self, tag: G, chan_capacity: usize, init_value: S)
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .add_source(tag, Source::create(init_value, chan_capacity));
    }

    /// Removes the source under `tag`; returns whether one was present.
    async fn del_source(&self, tag: &G) -> bool {
        self.state_machine().await.del_source(tag)
    }

    /// Returns whether a source is registered under `tag`.
    async fn has_source(&self, tag: &G) -> bool {
        self.state_machine().await.has_source(tag)
    }

    /// Returns the number of live broadcast receivers (readers) of the source under `tag`.
    async fn num_of_subscriptions<S>(&self, tag: &G) -> usize
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .source::<S>(tag)
            .await
            .num_of_subscriptions()
            .await
    }

    /// Returns the current value of the source under `tag`.
    async fn source_value<S>(&self, tag: &G) -> S
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine().await.source_value(tag).await
    }

    /// Returns the current [`Value`] (with timestamp, if enabled) of the source under `tag`.
    async fn source_value_ex<S>(&self, tag: &G) -> Value<S>
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine().await.source_value_ex(tag).await
    }

    /// Sets the source under `tag` to `s` and broadcasts it without waiting for subscribers.
    ///
    /// # Errors
    /// `SourceChangeError::NotChange` if `s` equals the current value;
    /// `SourceChangeError::SendErr` if the broadcast fails.
    async fn change<S>(&self, tag: &G, s: S) -> Result<(), SourceChangeError>
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine().await.source(tag).await.change(s).await
    }

    /// Like [`UseStateMachine::change`], but waits until every receiver of the change
    /// message has released its completion channel.
    async fn wait_change<S>(&self, tag: &G, s: S) -> Result<(), SourceChangeError>
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .source(tag)
            .await
            .wait_change(s)
            .await
    }

    /// Replaces the source value with `func(current)` and broadcasts it without waiting.
    ///
    /// # Errors
    /// Same as [`UseStateMachine::change`].
    async fn modify<S>(
        &self,
        tag: &G,
        func: impl Fn(S) -> S + Send + Sync + 'static,
    ) -> Result<(), SourceChangeError>
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .source(tag)
            .await
            .modify(func)
            .await
    }

    /// Like [`UseStateMachine::modify`], but waits for subscribers as `wait_change` does.
    async fn wait_modify<S>(
        &self,
        tag: &G,
        func: impl Fn(S) -> S + Send + Sync + 'static,
    ) -> Result<(), SourceChangeError>
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .source(tag)
            .await
            .wait_modify(func)
            .await
    }

    /// Re-broadcasts the current value of the source under `tag`, bypassing the
    /// equality check so subscribers are notified even though nothing changed.
    async fn touch<S>(&self, tag: &G) -> Result<(), SourceChangeError>
    where
        S: 'static + Clone + Default + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .source::<S>(tag)
            .await
            .touch()
            .await
    }

    /// Returns whether a subscription handle is registered under `tag`.
    async fn has_handle(&self, tag: &G) -> bool {
        self.state_machine().await.has_handle(tag)
    }

    /// Returns the last value stored by the handle under `tag`.
    async fn handle_value<T>(&self, tag: &G) -> T
    where
        T: 'static + Clone + PartialEq + Send + Sync,
    {
        self.state_machine().await.handle_value(tag).await
    }

    /// Returns the last [`Value`] (with timestamp, if enabled) stored by the handle under `tag`.
    async fn handle_value_ex<T>(&self, tag: &G) -> Value<T>
    where
        T: 'static + Clone + PartialEq + Send + Sync,
    {
        self.state_machine().await.handle_value_ex(tag).await
    }

    /// Creates a new [`Reader`] subscribed to the source under `tag`.
    async fn reader<S>(&self, tag: &G) -> Reader<S>
    where
        S: 'static + Clone + Default + PartialEq + Send,
    {
        self.state_machine().await.source::<S>(tag).await.reader()
    }

    /// Creates a new [`ReaderEx`] subscribed to the source under `tag`, mapping every
    /// value through the async `func`.
    async fn reader_ex<S, T>(
        &self,
        tag: &G,
        func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
    ) -> ReaderEx<S, T>
    where
        S: 'static + Clone + Default + PartialEq + Send,
    {
        self.state_machine()
            .await
            .source::<S>(tag)
            .await
            .reader_ex(func)
    }

    /// Cancels the subscription handle registered under `tag`; the subscription task
    /// then exits and removes the handle.
    async fn unsubscribe<T>(&self, tag: &G)
    where
        T: 'static + Clone + PartialEq + Send + Sync,
    {
        self.state_machine()
            .await
            .handle::<T>(tag)
            .await
            .unsubscribe();
    }
}
/// Blanket impl: every [`HasStateMachine`] implementor gets the `UseStateMachine` helpers
/// with their default method bodies.
#[async_trait]
impl<T, G> UseStateMachine<G> for T
where
    T: HasStateMachine<G>,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}
/// Flag carried with every change message: when `true`, the message must be treated as
/// a change even if the value compares equal (set by `Change::Touch`).
type NotCheckEq = bool;
/// A value paired with the time it was last stored (`timestamp` feature enabled).
#[cfg(feature = "timestamp")]
pub type Value<S> = (S, DateTime<Utc>);
/// A bare value (`timestamp` feature disabled).
#[cfg(not(feature = "timestamp"))]
pub type Value<S> = S;
/// A shared value plus the broadcast channel used to announce changes to it.
#[derive(Clone, Debug)]
struct Source<S> {
    // Current value; the same `Arc` is handed to every `Reader`/`ReaderEx` created here.
    value: Arc<RwLock<Value<S>>>,
    // Change messages: (new value, skip-equality-check flag, optional completion channel
    // that `wait_*` changes block on until every receiver has dropped its clone).
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
impl<S> Default for Source<S>
where
    S: 'static + Clone + Default + PartialEq + Send,
{
    /// Builds a source holding `S::default()` with the standard channel capacity of 100.
    fn default() -> Self {
        Self::create(S::default(), 100)
    }
}
impl<S> Source<S>
where
    S: 'static + Clone + Default + PartialEq + Send,
{
    /// Creates a source holding `S::default()` with a broadcast capacity of 100.
    fn new() -> Self {
        Self::create(Default::default(), 100)
    }

    /// Creates a source holding `init_value` whose broadcast channel buffers up to
    /// `chan_capacity` change messages per receiver.
    ///
    /// # Panics
    /// Panics if `chan_capacity` is rejected by `tokio::sync::broadcast::channel`
    /// (zero, or larger than `usize::MAX / 2`).
    fn create(init_value: S, chan_capacity: usize) -> Self {
        let (tx, _) = broadcast::channel(chan_capacity);
        #[cfg(feature = "timestamp")]
        let v = (init_value, Utc::now());
        #[cfg(not(feature = "timestamp"))]
        let v = init_value;
        Self {
            value: Arc::new(RwLock::new(v)),
            sender: tx,
        }
    }

    /// Creates a new reader that shares this source's value and receives its changes.
    fn reader(&self) -> Reader<S> {
        Reader {
            value: self.value.clone(),
            recver: self.sender.subscribe(),
        }
    }

    /// Like [`Source::reader`], but maps every value through the async `func`.
    fn reader_ex<T>(
        &self,
        func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
    ) -> ReaderEx<S, T> {
        ReaderEx {
            value: self.value.clone(),
            recver: self.sender.subscribe(),
            func: Arc::new(func),
        }
    }

    /// Number of live broadcast receivers (readers, subscribed or not) of this source.
    async fn num_of_subscriptions(&self) -> usize {
        self.sender.receiver_count()
    }

    /// Returns a clone of the current value (without timestamp).
    async fn value(&self) -> S {
        #[cfg(feature = "timestamp")]
        {
            // Clone only the value, not the whole (value, timestamp) pair.
            self.value.read().await.0.clone()
        }
        #[cfg(not(feature = "timestamp"))]
        {
            (*self.value.read().await).clone()
        }
    }

    /// Returns a clone of the current [`Value`] (with timestamp, if enabled).
    async fn value_ex(&self) -> Value<S> {
        (*self.value.read().await).clone()
    }

    /// Applies `change` to the current value, commits it and broadcasts it.
    ///
    /// The new value is committed *before* broadcasting, so it is kept even when
    /// nobody is subscribed and subscribers reading the source back see the new state.
    /// The broadcast happens while the write lock is still held, so concurrent changes
    /// reach subscribers in the order they were committed. When `wait_to_end` is set,
    /// the lock is released before waiting for subscribers, so a subscriber that reads
    /// or changes this source while handling the message cannot deadlock.
    ///
    /// # Errors
    /// - `SourceChangeError::NotChange` if the new value equals the current one (not for
    ///   `Change::Touch`); nothing is committed or broadcast.
    /// - `SourceChangeError::SendErr` if the broadcast fails (no live receivers); the
    ///   new value has still been committed.
    async fn change_ex(
        &self,
        wait_to_end: bool,
        change: Change<S>,
    ) -> Result<(), SourceChangeError> {
        let mut guard = self.value.write().await;
        #[cfg(feature = "timestamp")]
        let current = guard.0.clone();
        #[cfg(not(feature = "timestamp"))]
        let current = (*guard).clone();
        let (next, not_check_eq) = match change {
            Change::Value(v) => (v, false),
            Change::Func(func) => (func(current.clone()), false),
            Change::Touch => (current.clone(), true),
        };
        if !not_check_eq && current == next {
            return Err(SourceChangeError::NotChange);
        }
        #[cfg(feature = "timestamp")]
        {
            *guard = (next.clone(), Utc::now());
        }
        #[cfg(not(feature = "timestamp"))]
        {
            *guard = next.clone();
        }
        let (feedback_tx, feedback_rx) = if wait_to_end {
            let (tx, rx) = mpsc::unbounded_channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };
        let sent = self.sender.send((next, not_check_eq, feedback_tx));
        // Release the lock before waiting on subscribers (see the doc comment).
        drop(guard);
        sent.map_err(|_| SourceChangeError::SendErr)?;
        if let Some(mut feedback_rx) = feedback_rx {
            // Each receiver holds a clone of the completion sender until it is done with
            // the message; `recv` yields `None` once every clone has been dropped.
            while feedback_rx.recv().await.is_some() {}
        }
        Ok(())
    }

    /// Sets the value to `s` and broadcasts it without waiting for subscribers.
    async fn change(&self, s: S) -> Result<(), SourceChangeError> {
        self.change_ex(false, Change::Value(s)).await
    }

    /// Sets the value to `s` and waits until every subscriber has handled it.
    async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
        self.change_ex(true, Change::Value(s)).await
    }

    /// Sets the value to `func(current)` without waiting for subscribers.
    async fn modify(
        &self,
        func: impl Fn(S) -> S + Send + Sync + 'static,
    ) -> Result<(), SourceChangeError> {
        self.change_ex(false, Change::Func(Arc::new(func))).await
    }

    /// Sets the value to `func(current)` and waits until every subscriber has handled it.
    async fn wait_modify(
        &self,
        func: impl Fn(S) -> S + Send + Sync + 'static,
    ) -> Result<(), SourceChangeError> {
        self.change_ex(true, Change::Func(Arc::new(func))).await
    }

    /// Re-broadcasts the current value, bypassing the equality check.
    async fn touch(&self) -> Result<(), SourceChangeError> {
        self.change_ex(false, Change::Touch).await
    }
}
/// How `Source::change_ex` derives the next value from the current one.
enum Change<S> {
    /// Replace the value with this one.
    Value(S),
    /// Compute the next value from a clone of the current value.
    Func(Arc<dyn Fn(S) -> S + Send + Sync>),
    /// Re-broadcast the current value, bypassing the equality check.
    Touch,
}
/// Reasons a change to a source is reported as failed.
#[derive(Debug, Error)]
pub enum SourceChangeError {
    /// Broadcasting the change failed (tokio's `broadcast::Sender::send` fails when there
    /// are no active receivers).
    #[error("Change of state failed to broadcast")]
    SendErr,
    /// The new value compared equal to the current one, so nothing was broadcast.
    #[error("source not change, no change detected")]
    NotChange,
}
/// A subscription to a source: shares the source's value and owns a broadcast receiver
/// for its change messages. Use `extend` (or convert into [`ReaderEx`]) to map values.
pub struct Reader<S> {
    // Same `Arc` as the source's value.
    value: Arc<RwLock<Value<S>>>,
    // Receives (new value, skip-equality-check flag, optional completion sender).
    recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
/// Turns a plain [`Reader`] into a [`ReaderEx`] whose mapping is the identity.
///
/// Implemented as `From` (rather than `Into`) so the standard blanket impl also provides
/// `Into<ReaderEx<S, S>> for Reader<S>`; existing `reader.into()` calls keep working.
impl<S> From<Reader<S>> for ReaderEx<S, S>
where
    S: 'static + Send,
{
    fn from(reader: Reader<S>) -> Self {
        ReaderEx {
            value: reader.value,
            recver: reader.recver,
            func: Arc::new(|s| Box::pin(async move { s })),
        }
    }
}
impl<S> Reader<S> {
pub fn extend<T>(
self,
func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
) -> ReaderEx<S, T> {
ReaderEx {
value: self.value,
recver: self.recver,
func: Arc::new(func),
}
}
}
/// A reader whose values are passed through an async mapping `func` (`S` -> `T`).
pub struct ReaderEx<S, T> {
    // Same `Arc` as the source's value.
    value: Arc<RwLock<Value<S>>>,
    // Receives (new value, skip-equality-check flag, optional completion sender).
    recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
    // Applied to the current value and to every value received from the channel.
    func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
}
impl<S, T> ReaderEx<S, T>
where
    S: 'static + Clone + Send,
    T: 'static,
{
    /// Returns the source's current value mapped through `func` (paired with the source's
    /// timestamp when the `timestamp` feature is enabled).
    ///
    /// The value is cloned and the read lock released *before* `func` is awaited, so
    /// arbitrary async mapping code never runs while this lock is held (a writer queued
    /// behind the read could otherwise deadlock with a `func` that touches the source).
    async fn value(&self) -> Value<T> {
        #[cfg(feature = "timestamp")]
        {
            let (s, t) = (*self.value.read().await).clone();
            (self.func.as_ref()(s).await, t)
        }
        #[cfg(not(feature = "timestamp"))]
        {
            let s = (*self.value.read().await).clone();
            self.func.as_ref()(s).await
        }
    }

    /// Chains the async mapping `func` after the existing one, yielding a reader of `U`.
    pub fn extend<U>(
        self,
        func: impl Fn(T) -> Pin<Box<dyn Future<Output = U> + Send>> + Send + Sync + 'static,
    ) -> ReaderEx<S, U> {
        let func_o = self.func.clone();
        let func_n = Arc::new(func);
        ReaderEx {
            value: self.value,
            recver: self.recver,
            func: Arc::new(move |s| {
                let func_a = func_o.clone();
                let func_b = func_n.clone();
                Box::pin(async move {
                    let t = func_a.as_ref()(s).await;
                    func_b.as_ref()(t).await
                })
            }),
        }
    }
}
/// Per-subscription state: the last mapped value and a token that stops the
/// subscription task.
#[derive(Clone, Debug)]
struct Handle<T> {
    // Cancelled by `unsubscribe`; the task spawned in `subscribe` exits when it fires.
    cancel_token: CancellationToken,
    // Last value stored by the subscription task.
    value: Arc<RwLock<Value<T>>>,
}
impl<T> Handle<T>
where
    T: Clone + PartialEq,
{
    /// Creates an uncancelled handle holding `init_value`.
    fn new(init_value: T) -> Self {
        #[cfg(feature = "timestamp")]
        let t = (init_value, Utc::now());
        #[cfg(not(feature = "timestamp"))]
        let t = init_value;
        Self {
            cancel_token: CancellationToken::new(),
            value: Arc::new(RwLock::new(t)),
        }
    }

    /// Stores `t` and reports whether the subscriber should be notified.
    ///
    /// Only the value part is compared, never the timestamp. Otherwise, with the
    /// `timestamp` feature, every store would look like a change because `Utc::now()`
    /// differs from the stored timestamp. The value is written when it differs, or when
    /// `not_check_eq` forces a notification (which refreshes the timestamp). The
    /// comparison and the write happen under a single write lock.
    ///
    /// Returns `not_check_eq || t != previous value`.
    async fn store(&self, t: T, not_check_eq: bool) -> bool {
        let mut guard = self.value.write().await;
        #[cfg(feature = "timestamp")]
        let changed = guard.0 != t;
        #[cfg(not(feature = "timestamp"))]
        let changed = *guard != t;
        let notify = not_check_eq || changed;
        if notify {
            #[cfg(feature = "timestamp")]
            {
                *guard = (t, Utc::now());
            }
            #[cfg(not(feature = "timestamp"))]
            {
                *guard = t;
            }
        }
        notify
    }

    /// Returns a clone of the stored value (without timestamp).
    async fn value(&self) -> T {
        #[cfg(feature = "timestamp")]
        {
            self.value.read().await.0.clone()
        }
        #[cfg(not(feature = "timestamp"))]
        {
            (*self.value.read().await).clone()
        }
    }

    /// Returns a clone of the stored [`Value`] (with timestamp, if enabled).
    async fn value_ex(&self) -> Value<T> {
        (*self.value.read().await).clone()
    }

    /// Signals the subscription task to stop.
    fn unsubscribe(&self) {
        self.cancel_token.cancel();
    }
}
/// Change-callback interface for objects that subscribe to sources via `UseStateHandle`.
#[async_trait]
pub trait HasStateHandle<T, G>: HasStateMachine<G>
where
    T: Clone + Debug + PartialEq,
    G: Clone + Debug + Eq + Hash,
{
    /// Called by the subscription task, while it holds `lock()`, after the handle
    /// registered under `tag` has stored `new_value`; `old_value` is what it held before.
    /// A returned error is logged by the caller and does not stop the subscription.
    async fn on_change(
        self: Arc<Self>,
        tag: G,
        new_value: T,
        old_value: T,
    ) -> Result<(), Box<dyn std::error::Error>>;
}
/// Subscription helper for [`HasStateHandle`] implementors.
#[async_trait]
pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
where
    T: 'static + Clone + Debug + PartialEq + Send + Sync,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
    /// Subscribes `self` to the source behind `reader`, registering a handle under `tag`.
    ///
    /// The handle starts with the reader's current (mapped) value. A spawned task then
    /// maps every received change. Whenever the handle reports a change, the task calls
    /// `on_change` while holding `lock()`, then acknowledges the message's completion
    /// channel if one was attached. The task stops when the handle is cancelled, the
    /// source channel closes, or the receiver lags; it then removes the handle.
    ///
    /// # Panics
    /// Panics if a handle is already registered under `tag`.
    #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag = ?tag))]
    async fn subscribe<S>(self: Arc<Self>, reader: impl Into<ReaderEx<S, T>> + Send, tag: G)
    where
        S: 'static + Clone + Debug + PartialEq + Send + Sync,
    {
        let reader_ex = reader.into();
        #[cfg(feature = "timestamp")]
        let init = reader_ex.value().await.0;
        #[cfg(not(feature = "timestamp"))]
        let init = reader_ex.value().await;
        let handle: Handle<T> = Handle::new(init);
        self.state_machine()
            .await
            .add_handle(tag.clone(), handle.clone());
        let mut rx_s = reader_ex.recver;
        tokio::spawn(async move {
            tracing::info!("Subscription start -- {:?}", tag);
            loop {
                select! {
                    _ = handle.cancel_token.cancelled() => {
                        break;
                    }
                    res = rx_s.recv() => {
                        match res {
                            Ok((s, not_check_eq, opt_feedback)) => {
                                let v = reader_ex.func.as_ref()(s).await;
                                let t_old = handle.value().await;
                                if handle.store(v, not_check_eq).await {
                                    let _lock = self.lock().await;
                                    let t_new = handle.value().await;
                                    if let Err(e) = self.clone().on_change(tag.clone(), t_new, t_old).await {
                                        tracing::error!("stage [2] | change event proc error -- {}", e);
                                    }
                                    if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
                                        tracing::error!("stage [3] | change event feedback error -- {}", e);
                                    }
                                }
                            },
                            Err(broadcast::error::RecvError::Closed) => {
                                // The channel only closes once every sender is dropped, i.e.
                                // after the source was already removed from the state
                                // machine. `tag` is the *handle* tag, so deleting a source
                                // under it here could only remove an unrelated source.
                                tracing::info!("source channel closed");
                                break;
                            },
                            Err(broadcast::error::RecvError::Lagged(_)) => {
                                // NOTE(review): ends the subscription on lag; consider
                                // continuing instead if missed intermediate values are acceptable.
                                tracing::error!("stage [1] | change event recv lagged");
                                break;
                            },
                        }
                    }
                }
            }
            _ = self.state_machine().await.del_handle(&tag);
            tracing::info!("Subscription end -- {:?}", tag);
        });
    }
}
/// Blanket impl: every `'static` [`HasStateHandle`] implementor gets `subscribe` with its
/// default body.
impl<V, T, G> UseStateHandle<T, G> for V
where
    V: 'static + HasStateHandle<T, G>,
    T: 'static + Clone + Debug + PartialEq + Send + Sync,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}