Skip to main content

state_m/
lib.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4    any::{Any, type_name},
5    cmp::Eq,
6    fmt::Debug,
7    hash::Hash,
8    pin::Pin,
9    sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13    select,
14    sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
/// State machine data structure to store state sources and handles.
///
/// Sources and handles live in two separate maps, both keyed by the tag type `G`.
/// Values are stored type-erased (`Box<dyn Any + Send + Sync>`) and are downcast
/// back to their concrete type on lookup. Both maps sit behind an `Arc`, so a
/// clone of the state machine refers to the same maps.
/// - G - tag type used to distinguish different initiators or responders.
#[derive(Clone, Debug)]
pub struct StateMachine<G>
where
    G: Eq + Hash,
{
    // Type-erased state sources, keyed by tag.
    sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
    // Type-erased state handles, keyed by tag.
    handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
}
29
30impl<G> Default for StateMachine<G>
31where
32    G: Eq + Hash,
33{
34    fn default() -> Self {
35        Self {
36            sources: Default::default(),
37            handles: Default::default(),
38        }
39    }
40}
41
42impl<G> StateMachine<G>
43where
44    G: Clone + Debug + Eq + Hash,
45{
46    pub fn new() -> Self {
47        Default::default()
48    }
49
50    /// Add state source to state machine.
51    pub(crate) fn add_source<S>(&self, tag: G, source: Source<S>)
52    where
53        S: 'static + Send + Sync,
54    {
55        assert!(
56            !self.sources.contains_key(&tag),
57            "duplicate tag for source -- {:?}",
58            tag
59        );
60        self.sources.insert(tag, Box::new(source));
61    }
62
63    /// Delete state source from state machine.
64    pub(crate) fn del_source(&self, tag: G) -> bool {
65        self.sources.remove(&tag).is_some()
66    }
67
68    /// Get source from state machine by tag.
69    pub async fn source<S>(&self, tag: G) -> Source<S>
70    where
71        S: 'static + Clone,
72    {
73        let opt_source_box = self.sources.get(&tag);
74        assert!(
75            opt_source_box.is_some(),
76            "state source does not exist, tag -- {:?}",
77            tag
78        );
79        let source_box = opt_source_box.unwrap();
80        let opt_source = source_box.downcast_ref::<Source<S>>();
81        assert!(
82            opt_source.is_some(),
83            "state source does not exist, tag -- {:?}, type -- {}",
84            tag,
85            type_name::<S>()
86        );
87        let source = opt_source.unwrap();
88        (*source).clone()
89    }
90
91    /// Add state handle to state machine.
92    pub(crate) fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93    where
94        T: 'static + Send + Sync,
95    {
96        assert!(
97            !self.handles.contains_key(&tag),
98            "duplicate tag for handle -- {:?}",
99            tag
100        );
101        self.handles.insert(tag, Box::new(handle));
102    }
103
104    /// Delete state handle from state machine.
105    pub(crate) fn del_handle(&self, tag: G) -> bool {
106        self.handles.remove(&tag).is_some()
107    }
108
109    /// Get current value of source from state machine by tag.
110    pub async fn source_value<S>(&self, tag: G) -> S
111    where
112        S: 'static + Clone + Default + PartialEq + Send,
113    {
114        self.source(tag).await.value().await
115    }
116
117    /// Get handle from state machine.
118    pub async fn handle<T>(&self, tag: G) -> Handle<T>
119    where
120        T: 'static + Clone,
121    {
122        let opt_handle_box = self.handles.get(&tag);
123        assert!(
124            opt_handle_box.is_some(),
125            "state handle does not exist, tag -- {:?}",
126            tag
127        );
128        let handle_box = opt_handle_box.unwrap();
129        let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130        assert!(
131            opt_handle.is_some(),
132            "state handle does not exist, tag -- {:?}, type -- {}",
133            tag,
134            type_name::<T>()
135        );
136        opt_handle.unwrap().clone()
137    }
138
139    /// Get current value of handle from state machine.
140    pub async fn handle_value<T>(&self, tag: G) -> T
141    where
142        T: 'static + Clone + PartialEq,
143    {
144        self.handle(tag).await.value().await
145    }
146}
147
/// Provides the mutex that is held while a state change is being handled.
///
/// The subscription task acquires this lock before it calls `on_change`.
#[async_trait]
pub trait HasLock {
    /// Acquire the mutex and return its guard.
    async fn lock(&self) -> MutexGuard<'_, ()>;
}
154
/// The minimum an implementor must provide: access to a state machine data structure.
/// - G - tag type used to distinguish different initiators or responders.
#[async_trait]
pub trait HasStateMachine<G>: HasLock
where
    G: Clone + Debug + Eq + Hash,
{
    /// The state machine data structure.
    async fn state_machine(&self) -> StateMachine<G>;
}
164
165/// Some convenient methods to use state machine. The trait is auto implemented for types implemented HasStateMachine.
166#[async_trait]
167pub trait UseStateMachine<G>: HasStateMachine<G>
168where
169    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
170{
171    /// Get state source.
172    async fn source<S>(&self, tag: G) -> Source<S>
173    where
174        S: 'static + Clone,
175    {
176        self.state_machine().await.source(tag).await
177    }
178
179    /// Get current value of state source.
180    async fn source_value<S>(&self, tag: G) -> S
181    where
182        S: 'static + Clone + Default + PartialEq + Send + Sync,
183    {
184        self.state_machine().await.source_value(tag).await
185    }
186
187    /// Get state handle.
188    async fn handle<T>(&self, tag: G) -> Handle<T>
189    where
190        T: 'static + Clone,
191    {
192        self.state_machine().await.handle(tag).await
193    }
194
195    /// Get current value of state handle.
196    async fn handle_value<T>(&self, tag: G) -> T
197    where
198        T: 'static + Clone + PartialEq + Send + Sync,
199    {
200        self.state_machine().await.handle_value(tag).await
201    }
202}
203
/// Blanket implementation: every type implementing `HasStateMachine<G>` gets
/// the default methods of `UseStateMachine<G>`.
#[async_trait]
impl<T, G> UseStateMachine<G> for T
where
    T: HasStateMachine<G>,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}
211
212/// Convenient method to add state source to state machine. The trait is auto implemented for types implemented HasStateMachine.
213#[async_trait]
214pub trait UseStateSource<G>: HasStateMachine<G>
215where
216    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
217{
218    /// Add state source to state machine, the state source is created by default.
219    async fn add_source<S>(&self, tag: G) -> Source<S>
220    where
221        S: 'static + Clone + Default + PartialEq + Send + Sync,
222    {
223        let source = Source::<S>::default();
224        self.state_machine().await.add_source(tag, source.clone());
225        source
226    }
227
228    /// Add state source to state machine.
229    async fn add_source_ex<S>(&self, tag: G, source: Source<S>) -> Source<S>
230    where
231        S: 'static + Clone + Send + Sync,
232    {
233        self.state_machine().await.add_source(tag, source.clone());
234        source
235    }
236}
237
/// Blanket implementation: every type implementing `HasStateMachine<G>` gets
/// the default methods of `UseStateSource<G>`.
impl<T, G> UseStateSource<G> for T
where
    T: HasStateMachine<G>,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}
244
/// Flag carried with every change event: when `true`, the new state is NOT
/// compared with the current value and the event is always delivered. By default
/// (`false`) a new state is compared with the current value, and if they are
/// equal no change event is triggered.
type NotCheckEq = bool;
248
/// State source, the initiator of state change.
///
/// Clones share the same stored value and the same broadcast channel.
#[derive(Clone, Debug)]
pub struct Source<S> {
    // Current state, shared with the readers created from this source.
    value: Arc<RwLock<S>>,
    // Broadcast channel for change events: (new state, skip-equality flag,
    // optional feedback sender used when the initiator waits for responders).
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
255
256impl<S> Default for Source<S>
257where
258    S: 'static + Clone + Default + PartialEq + Send,
259{
260    fn default() -> Self {
261        Self::new()
262    }
263}
264
265impl<S> Source<S>
266where
267    S: 'static + Clone + Default + PartialEq + Send,
268{
269    /// Create a state source, with broadcast channel capacity of 100.
270    pub fn new() -> Self {
271        Self::create(Default::default(), 100)
272    }
273
274    /// Create a state source with custom broadcast channel capacity.
275    /// - capacity: broadcast channel capacity
276    pub fn create(init_value: S, capacity: usize) -> Self {
277        let (tx, _) = broadcast::channel(capacity);
278        Self {
279            value: Arc::new(RwLock::new(init_value)),
280            sender: tx,
281        }
282    }
283
284    /// Get reader of state source, can be subscribed by responders.
285    pub fn reader(&self) -> Reader<S> {
286        Reader {
287            value: self.value.clone(),
288            sender: self.sender.clone(),
289        }
290    }
291
292    /// Get reader of state source, can be subscribed by responders.
293    pub fn reader_ex<T>(
294        &self,
295        func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
296    ) -> ReaderEx<S, T> {
297        ReaderEx {
298            value: self.value.clone(),
299            sender: self.sender.clone(),
300            func: Arc::new(func),
301        }
302    }
303
304    /// Num of subscriptions.
305    pub async fn num_of_subs(&self) -> usize {
306        self.sender.receiver_count()
307    }
308
309    /// Get current value of state source.
310    pub async fn value(&self) -> S {
311        (*self.value.read().await).clone()
312    }
313
314    async fn change_ex(
315        &self,
316        wait_to_end: bool,
317        change: Change<S>,
318    ) -> Result<(), SourceChangeError> {
319        let mut guard = self.value.write().await;
320        let (s, not_check_eq) = match change {
321            Change::Value(v) => (v, false),
322            Change::Func(func) => (func((*guard).clone()), false),
323            Change::Touch => ((*guard).clone(), true),
324        };
325        if not_check_eq || *guard != s {
326            if wait_to_end {
327                let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
328                self.sender
329                    .send((s.clone(), not_check_eq, Some(tx_w)))
330                    .map_err(|_| SourceChangeError::SendErr)?;
331                loop {
332                    select! {
333                        res = rx_w.recv()  => {
334                            if res.is_none() {
335                                break;
336                            }
337                        }
338                    }
339                }
340            } else {
341                self.sender
342                    .send((s.clone(), not_check_eq, None))
343                    .map_err(|_| SourceChangeError::SendErr)?;
344            }
345            *guard = s;
346            Ok(())
347        } else {
348            Err(SourceChangeError::NotChange)
349        }
350    }
351
352    /// Change state of source.
353    pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
354        self.change_ex(false, Change::Value(s)).await
355    }
356
357    /// Change state of source, and wait responders to finish actions upon the change event.
358    pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
359        self.change_ex(true, Change::Value(s)).await
360    }
361
362    /// Change state of source by modifying it with a func.
363    pub async fn modify(
364        &self,
365        func: impl Fn(S) -> S + Send + Sync + 'static,
366    ) -> Result<(), SourceChangeError> {
367        self.change_ex(false, Change::Func(Arc::new(func))).await
368    }
369
370    /// Change state of source by modifying it with a func, and wait responders to finish actions upon the change event.
371    pub async fn wait_modify(
372        &self,
373        func: impl Fn(S) -> S + Send + Sync + 'static,
374    ) -> Result<(), SourceChangeError> {
375        self.change_ex(true, Change::Func(Arc::new(func))).await
376    }
377
378    /// Create a change event without changing state of source really.
379    pub async fn touch(&self) -> Result<(), SourceChangeError> {
380        self.change_ex(false, Change::Touch).await
381    }
382}
383
/// The kind of change requested on a state source.
enum Change<S> {
    /// Replace the state with the given value.
    Value(S),
    /// Compute the new state by applying the function to a clone of the current state.
    Func(Arc<dyn Fn(S) -> S + Send + Sync>),
    /// Re-send the current state as a change event, skipping the equality check.
    Touch,
}
389
/// Errors returned by the change operations of a state source.
#[derive(Debug, Error)]
pub enum SourceChangeError {
    /// Sending the change event on the broadcast channel failed.
    #[error("Change of state failed to broadcast")]
    SendErr,
    /// The requested state equals the current value, so no event was triggered.
    #[error("State source not change, no change detected")]
    NotChange,
}
397
/// Data structure exposed to state change responders so that they can subscribe.
///
/// Holds clones of the value handle and broadcast sender it was created with.
#[derive(Clone)]
pub struct Reader<S> {
    // Shared current state.
    value: Arc<RwLock<S>>,
    // Broadcast sender for change events.
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
404
405impl<S> Into<ReaderEx<S, S>> for Reader<S>
406where
407    S: 'static + Send,
408{
409    fn into(self) -> ReaderEx<S, S> {
410        ReaderEx {
411            value: self.value,
412            sender: self.sender,
413            func: Arc::new(|s| Box::pin(async move { s })),
414        }
415    }
416}
417
418impl<S> Reader<S> {
419    pub fn extend<T>(
420        &self,
421        func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
422    ) -> ReaderEx<S, T> {
423        ReaderEx {
424            value: self.value.clone(),
425            sender: self.sender.clone(),
426            func: Arc::new(func),
427        }
428    }
429}
430
/// Data structure exposed to state change responders so that they can subscribe,
/// with the ability to convert the state to another type.
/// - S - state type of the source
/// - T - state type seen by the responder
#[derive(Clone)]
pub struct ReaderEx<S, T> {
    // Shared current state of the source.
    value: Arc<RwLock<S>>,
    // Broadcast sender for change events.
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
    // Async function converting a source state `S` into the responder's type `T`.
    func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
}
438
439impl<S, T> ReaderEx<S, T>
440where
441    S: Clone,
442{
443    async fn value(&self) -> T {
444        self.func.as_ref()((*self.value.read().await).clone()).await
445    }
446}
447
/// Data structure to store the latest state in responder's state machine, can be used to do unsubscription.
///
/// Clones share the same cancellation token and stored value.
#[derive(Clone, Debug)]
pub struct Handle<T> {
    // Cancelled by `unsubscribe`; the subscription task stops when it observes this.
    cancel_token: CancellationToken,
    // Latest state received by the responder.
    value: Arc<RwLock<T>>,
}
454
455impl<T> Handle<T>
456where
457    T: Clone + PartialEq,
458{
459    fn new(init_value: T) -> Self {
460        Self {
461            cancel_token: CancellationToken::new(),
462            value: Arc::new(RwLock::new(init_value)),
463        }
464    }
465
466    async fn store(&self, t: T, not_check_eq: bool) -> bool {
467        let changed = *self.value.read().await != t;
468        if changed {
469            *self.value.write().await = t;
470        }
471        not_check_eq || changed
472    }
473
474    async fn value(&self) -> T {
475        (*self.value.read().await).clone()
476    }
477
478    /// Unsubscribe operation, this is optional, after your state machine
479    /// is dropped, subscriptions are auto cleaned.
480    pub fn unsubscribe(&self) {
481        self.cancel_token.cancel();
482    }
483}
484
/// Define action upon state change event.
/// - T - type of state in handle,
/// - G - tag type to distinguish different initiators or responders:
///   all initiators must use different tag values, and all responders
///   must do the same; the same tag value can be used by an initiator
///   and a responder in the same state machine.
#[async_trait]
pub trait HasStateHandle<T, G>: HasStateMachine<G>
where
    T: Clone + Debug + PartialEq,
    G: Clone + Debug + Eq + Hash,
{
    /// Action upon state change event.
    /// - tag - the tag value
    /// - new_value - the new value just received
    /// - old_value - the value the handle held before this event; for the
    ///   first event this is the initial value the handle was created with
    ///   at subscription time.
    async fn on_change(
        self: Arc<Self>,
        tag: G,
        new_value: T,
        old_value: T,
    ) -> Result<(), Box<dyn std::error::Error>>;
}
509
/// Convenient method to do subscription with a state convert function. The trait is
/// implemented automatically for all types implementing `HasStateHandle`.
#[async_trait]
pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
where
    T: 'static + Clone + Debug + PartialEq + Send + Sync,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
    /// Do subscription with a state convert function.
    ///
    /// Creates a handle initialised with the reader's current (converted) value,
    /// registers it under `tag`, and spawns a task that runs until the handle is
    /// unsubscribed or one of the channels ends. The handle is removed from the
    /// state machine when the task finishes.
    /// - stage [1] -- receive from source's broadcast channel.
    /// - stage [2] -- convert to target type and send to mpsc channel.
    /// - stage [3] -- receive from mpsc channel and process it.
    /// - stage [4] -- (optional) feedback when the change event has been processed.
    ///
    /// # Panics
    /// Panics if a handle is already registered under `tag`.
    #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
    async fn subscribe<S>(
        self: Arc<Self>,
        reader: impl Into<ReaderEx<S, T>> + Send,
        tag: G,
    ) -> Handle<T>
    where
        S: 'static + Clone + Debug + PartialEq + Send + Sync,
    {
        let reader_ex = reader.into();
        // The handle starts with the converted current value of the source.
        let handle: Handle<T> = Handle::new(reader_ex.value().await);
        self.state_machine()
            .await
            .add_handle(tag.clone(), handle.clone());
        let mut rx_s = reader_ex.sender.subscribe();
        // Internal queue between stage [2] and stage [3]:
        // (new value, previous value, optional feedback sender).
        let (tx_t, mut rx_t) =
            mpsc::unbounded_channel::<(T, T, Option<mpsc::UnboundedSender<()>>)>();
        let handle_c = handle.clone();
        tokio::spawn(async move {
            tracing::info!("Subscription start -- {:?}", tag);
            loop {
                select! {
                    // `unsubscribe` was called on the handle.
                    _ = handle_c.cancel_token.cancelled() => {
                        break;
                    }
                    // Stage [1]: a change event from the source.
                    res = rx_s.recv() => {
                        match res {
                            Ok((s, not_check_eq, opt_feedback)) => {
                                let t = reader_ex.func.as_ref()(s).await;
                                let t_old = handle_c.value().await;
                                // Stage [2]: forward only when the converted value changed
                                // (or the event asked to skip the equality check).
                                if handle_c.store(t.clone(), not_check_eq).await {
                                    if let Err(e) = tx_t.send((t, t_old, opt_feedback)) {
                                        tracing::error!("stage [2] | change event send error -- {}", e);
                                        break;
                                    }
                                }
                            },
                            Err(e) => match e {
                                broadcast::error::RecvError::Closed => {
                                    // NOTE(review): this removes a *source* registered under the
                                    // responder's tag in this state machine — confirm this is the
                                    // intended entry. Also confirm this branch is reachable, since
                                    // the task itself appears to keep `reader_ex` alive.
                                    _ = self.state_machine().await.del_source(tag.clone());
                                    tracing::info!("state source channel closed");
                                    break;
                                },
                                // Falling behind the broadcast channel ends the subscription.
                                broadcast::error::RecvError::Lagged(_) => {
                                    tracing::error!("stage [1] | change event recv lagged");
                                    break;
                                },
                            },
                        }
                    }
                    // Stage [3]: process a forwarded event while holding the responder's lock.
                    res = rx_t.recv() => {
                        match res {
                            Some((t, t_old, opt_feedback)) => {
                                let _lock = self.lock().await;
                                if let Err(e) = self.clone().on_change(tag.clone(), t, t_old).await {
                                    tracing::error!("stage [3] | change event proc error -- {}", e);
                                }
                                // Stage [4]: notify a waiting initiator, if any.
                                if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
                                    tracing::error!("stage [4] | change event feedback error -- {}", e);
                                }
                            },
                            None => {
                                tracing::info!("state target channel closed");
                                break;
                            },
                        }
                    }
                }
            }
            // Unregister the handle once the subscription loop has ended.
            _ = self.state_machine().await.del_handle(tag.clone());
            tracing::info!("Subscription end -- {:?}", tag);
        });
        handle
    }
}
597
/// Blanket implementation: every `'static` type implementing `HasStateHandle<T, G>`
/// gets the default `subscribe` method of `UseStateHandle<T, G>`.
impl<V, T, G> UseStateHandle<T, G> for V
where
    V: 'static + HasStateHandle<T, G>,
    T: 'static + Clone + Debug + PartialEq + Send + Sync,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}