Skip to main content

state_m/
lib.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4    any::{Any, type_name},
5    cmp::Eq,
6    fmt::Debug,
7    hash::Hash,
8    pin::Pin,
9    sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13    select,
14    sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19/// State machine data structure to store state sources and handles.
20/// - G - to distinguish different initiators or responders.
21#[derive(Clone, Debug)]
22pub struct StateMachine<G>
23where
24    G: Eq + Hash,
25{
26    sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
27    handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
28}
29
30impl<G> Default for StateMachine<G>
31where
32    G: Eq + Hash,
33{
34    fn default() -> Self {
35        Self {
36            sources: Default::default(),
37            handles: Default::default(),
38        }
39    }
40}
41
42impl<G> StateMachine<G>
43where
44    G: Clone + Debug + Eq + Hash,
45{
46    pub fn new() -> Self {
47        Default::default()
48    }
49
50    /// Add state source to state machine.
51    pub(crate) fn add_source<S>(&self, tag: G, source: Source<S>)
52    where
53        S: 'static + Send + Sync,
54    {
55        assert!(
56            !self.sources.contains_key(&tag),
57            "duplicate tag for source -- {:?}",
58            tag
59        );
60        self.sources.insert(tag, Box::new(source));
61    }
62
63    /// Delete state source from state machine.
64    pub(crate) fn del_source(&self, tag: G) -> bool {
65        self.sources.remove(&tag).is_some()
66    }
67
68    /// Get source from state machine by tag.
69    pub async fn source<S>(&self, tag: G) -> Source<S>
70    where
71        S: 'static + Clone,
72    {
73        let opt_source_box = self.sources.get(&tag);
74        assert!(
75            opt_source_box.is_some(),
76            "state source does not exist, tag -- {:?}",
77            tag
78        );
79        let source_box = opt_source_box.unwrap();
80        let opt_source = source_box.downcast_ref::<Source<S>>();
81        assert!(
82            opt_source.is_some(),
83            "state source does not exist, tag -- {:?}, type -- {}",
84            tag,
85            type_name::<S>()
86        );
87        let source = opt_source.unwrap();
88        (*source).clone()
89    }
90
91    /// Add state handle to state machine.
92    pub(crate) fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93    where
94        T: 'static + Send + Sync,
95    {
96        assert!(
97            !self.handles.contains_key(&tag),
98            "duplicate tag for handle -- {:?}",
99            tag
100        );
101        self.handles.insert(tag, Box::new(handle));
102    }
103
104    /// Delete state handle from state machine.
105    pub(crate) fn del_handle(&self, tag: G) -> bool {
106        self.handles.remove(&tag).is_some()
107    }
108
109    /// Get current value of source from state machine by tag.
110    pub async fn source_value<S>(&self, tag: G) -> S
111    where
112        S: 'static + Clone + Default + PartialEq + Send,
113    {
114        self.source(tag).await.value().await
115    }
116
117    /// Get handle from state machine.
118    pub async fn handle<T>(&self, tag: G) -> Handle<T>
119    where
120        T: 'static + Clone,
121    {
122        let opt_handle_box = self.handles.get(&tag);
123        assert!(
124            opt_handle_box.is_some(),
125            "state handle does not exist, tag -- {:?}",
126            tag
127        );
128        let handle_box = opt_handle_box.unwrap();
129        let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130        assert!(
131            opt_handle.is_some(),
132            "state handle does not exist, tag -- {:?}, type -- {}",
133            tag,
134            type_name::<T>()
135        );
136        opt_handle.unwrap().clone()
137    }
138
139    /// Get current value of handle from state machine.
140    pub async fn handle_value<T>(&self, tag: G) -> Option<T>
141    where
142        T: 'static + Clone + PartialEq,
143    {
144        self.handle(tag).await.value().await
145    }
146}
147
/// Provides the lock held while a state change is being handled.
///
/// The subscription task spawned by `UseStateHandle::subscribe` acquires this
/// lock before calling `on_change` and keeps it through the feedback that
/// follows. Handlers of different subscriptions on the same object are
/// serialized by it, provided `lock` hands out guards of one and the same mutex.
#[async_trait]
pub trait HasLock {
    /// The mutex lock to use. The returned guard is held while one change
    /// event is being handled.
    async fn lock(&self) -> MutexGuard<'_, ()>;
}
154
/// Gives access to the object's [`StateMachine`]; this is the minimum a
/// participant has to implement.
#[async_trait]
pub trait HasStateMachine<G>: HasLock
where
    G: Clone + Debug + Eq + Hash,
{
    /// The state machine data structure. Returning a clone of a stored
    /// instance works, because `StateMachine` clones share the same maps
    /// through `Arc`.
    async fn state_machine(&self) -> StateMachine<G>;
}
164
165/// Some convenient methods to use state machine. The trait is auto implemented for types implemented HasStateMachine.
166#[async_trait]
167pub trait UseStateMachine<G>: HasStateMachine<G>
168where
169    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
170{
171    /// Get state source.
172    async fn source<S>(&self, tag: G) -> Source<S>
173    where
174        S: 'static + Clone,
175    {
176        self.state_machine().await.source(tag).await
177    }
178
179    /// Get current value of state source.
180    async fn source_value<S>(&self, tag: G) -> Option<S>
181    where
182        S: 'static + Clone + PartialEq + Send + Sync,
183    {
184        self.state_machine().await.source_value(tag).await
185    }
186
187    /// Get state handle.
188    async fn handle<T>(&self, tag: G) -> Handle<T>
189    where
190        T: 'static + Clone,
191    {
192        self.state_machine().await.handle(tag).await
193    }
194
195    /// Get current value of state handle.
196    async fn handle_value<T>(&self, tag: G) -> Option<T>
197    where
198        T: 'static + Clone + PartialEq + Send + Sync,
199    {
200        self.state_machine().await.handle_value(tag).await
201    }
202}
203
// Every `HasStateMachine` implementor gets the `UseStateMachine` accessors for free.
#[async_trait]
impl<T, G> UseStateMachine<G> for T
where
    T: HasStateMachine<G>,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}
211
212/// Convenient method to add state source to state machine. The trait is auto implemented for types implemented HasStateMachine.
213#[async_trait]
214pub trait UseStateSource<G>: HasStateMachine<G>
215where
216    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
217{
218    /// Add state source to state machine.
219    async fn add_source<S>(&self, tag: G, source: Source<S>)
220    where
221        S: 'static + Send + Sync,
222    {
223        self.state_machine().await.add_source(tag, source);
224    }
225}
226
// Every `HasStateMachine` implementor gets `UseStateSource::add_source` for free.
impl<T, G> UseStateSource<G> for T
where
    T: HasStateMachine<G>,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}
233
/// Flag carried with each change event that controls the equality check.
///
/// `false` (the normal case) means the new state is compared with the current
/// value, and an equal value does not trigger a change event. `true` skips that
/// comparison, so the event always fires; `Source::touch` sends `true`.
type NotCheckEq = bool;
237
/// State source, the initiator of state change.
///
/// Holds the current state and the broadcast sender used to announce changes.
/// Clones share both, so any clone can change the state.
#[derive(Clone, Debug)]
pub struct Source<S> {
    // Current state, shared with every `Reader`/`ReaderEx` created from this source.
    value: Arc<RwLock<S>>,
    // Change events: (new state, skip-equality flag, optional feedback sender
    // that `wait_change`/`wait_modify` use to learn when responders are done).
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
244
245impl<S> Default for Source<S>
246where
247    S: 'static + Clone + Default + PartialEq + Send,
248{
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
254impl<S> Source<S>
255where
256    S: 'static + Clone + Default + PartialEq + Send,
257{
258    /// Create a state source, with broadcast channel capacity of 100.
259    pub fn new() -> Self {
260        Self::create(Default::default(), 100)
261    }
262
263    /// Create a state source with custom broadcast channel capacity.
264    /// - capacity: broadcast channel capacity
265    pub fn create(init_value: S, capacity: usize) -> Self {
266        let (tx, _) = broadcast::channel(capacity);
267        Self {
268            value: Arc::new(RwLock::new(init_value)),
269            sender: tx,
270        }
271    }
272
273    /// Get reader of state source, can be subscribed by responders.
274    pub fn reader(&self) -> Reader<S> {
275        Reader {
276            value: self.value.clone(),
277            sender: self.sender.clone(),
278        }
279    }
280
281    /// Get reader of state source, can be subscribed by responders.
282    pub fn reader_ex<T>(&self, func: ConvertFunc<S, T>) -> ReaderEx<S, T> {
283        ReaderEx {
284            value: self.value.clone(),
285            sender: self.sender.clone(),
286            func,
287        }
288    }
289
290    /// Num of subscriptions.
291    pub async fn num_of_subs(&self) -> usize {
292        self.sender.receiver_count()
293    }
294
295    /// Get current value of state source.
296    pub async fn value(&self) -> S {
297        (*self.value.read().await).clone()
298    }
299
300    async fn change_ex(
301        &self,
302        wait_to_end: bool,
303        change: Change<S>,
304    ) -> Result<(), SourceChangeError> {
305        let mut guard = self.value.write().await;
306        let (s, not_check_eq) = match change {
307            Change::Value(v) => (v, false),
308            Change::Func(func) => (func((*guard).clone()), false),
309            Change::Touch => ((*guard).clone(), true),
310        };
311        if not_check_eq || *guard != s {
312            if wait_to_end {
313                let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
314                self.sender
315                    .send((s.clone(), not_check_eq, Some(tx_w)))
316                    .map_err(|_| SourceChangeError::SendErr)?;
317                loop {
318                    select! {
319                        res = rx_w.recv()  => {
320                            if res.is_none() {
321                                break;
322                            }
323                        }
324                    }
325                }
326            } else {
327                self.sender
328                    .send((s.clone(), not_check_eq, None))
329                    .map_err(|_| SourceChangeError::SendErr)?;
330            }
331            *guard = s;
332            Ok(())
333        } else {
334            Err(SourceChangeError::NotChange)
335        }
336    }
337
338    /// Change state of source.
339    pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
340        self.change_ex(false, Change::Value(s)).await
341    }
342
343    /// Change state of source, and wait responders to finish actions upon the change event.
344    pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
345        self.change_ex(true, Change::Value(s)).await
346    }
347
348    /// Change state of source by modifying it with a func.
349    pub async fn modify(&self, func: impl Fn(S) -> S + 'static) -> Result<(), SourceChangeError> {
350        self.change_ex(false, Change::Func(Box::new(func))).await
351    }
352
353    /// Change state of source by modifying it with a func, and wait responders to finish actions upon the change event.
354    pub async fn wait_modify(
355        &self,
356        func: impl Fn(S) -> S + 'static,
357    ) -> Result<(), SourceChangeError> {
358        self.change_ex(true, Change::Func(Box::new(func))).await
359    }
360
361    /// Create a change event without changing state of source really.
362    pub async fn touch(&self) -> Result<(), SourceChangeError> {
363        self.change_ex(false, Change::Touch).await
364    }
365}
366
/// How `Source::change_ex` derives the new state.
enum Change<S> {
    /// Replace the state with this value.
    Value(S),
    /// Compute the new state from a clone of the current one.
    Func(Box<dyn Fn(S) -> S>),
    /// Re-send the current state unchanged, bypassing the equality check.
    Touch,
}
372
/// Why a state change on a [`Source`] was not (fully) carried out.
#[derive(Debug, Error)]
pub enum SourceChangeError {
    /// Broadcasting the change event failed (for tokio's broadcast channel
    /// this happens when there are no active receivers).
    #[error("Change of state failed to broadcast")]
    SendErr,
    /// The new state equals the current one, so no change event was sent.
    #[error("State source not change, no change detected")]
    NotChange,
}
380
/// Data structure to be exposed to do subscription by state change responders.
///
/// Shares the source's state and broadcast sender but offers no way to change
/// the state. Because it holds a sender clone, a live `Reader` keeps the
/// source's broadcast channel open (it only reports closed once every sender
/// is dropped).
#[derive(Clone)]
pub struct Reader<S> {
    value: Arc<RwLock<S>>,
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
}
387
388impl<S> Into<ReaderEx<S, S>> for Reader<S>
389where
390    S: 'static + Send,
391{
392    fn into(self) -> ReaderEx<S, S> {
393        ReaderEx {
394            value: self.value,
395            sender: self.sender,
396            func: Arc::new(|s| Box::pin(async move { s })),
397        }
398    }
399}
400
401impl<S> Reader<S> {
402    pub fn extend<T>(&self, func: ConvertFunc<S, T>) -> ReaderEx<S, T> {
403        ReaderEx {
404            value: self.value.clone(),
405            sender: self.sender.clone(),
406            func,
407        }
408    }
409}
410
/// Asynchronous, shareable conversion from the source's state `S` to the
/// responder-side state `T`.
pub type ConvertFunc<S, T> =
    Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>;
413
/// Data structure to be exposed to do subscription by state change responders, with the ability to convert the state to another type.
#[derive(Clone)]
pub struct ReaderEx<S, T> {
    // State shared with the originating source.
    value: Arc<RwLock<S>>,
    // Clone of the source's broadcast sender; subscriptions call `subscribe()` on it.
    sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
    // Conversion applied to every state before it reaches the responder.
    func: ConvertFunc<S, T>,
}
421
422impl<S, T> ReaderEx<S, T>
423where
424    S: Clone,
425{
426    async fn value(&self) -> T {
427        self.func.as_ref()((*self.value.read().await).clone()).await
428    }
429}
430
/// Data structure to store the latest state in responder's state machine, can be used to do unsubscription.
///
/// Clones share the same value and cancellation token.
#[derive(Clone, Debug)]
pub struct Handle<T> {
    // Cancelled by `unsubscribe`; the subscription task stops when it fires.
    cancel_token: CancellationToken,
    // Latest converted state received from the source.
    value: Arc<RwLock<T>>,
}
437
438impl<T> Handle<T>
439where
440    T: Clone + PartialEq,
441{
442    fn new(init_value: T) -> Self {
443        Self {
444            cancel_token: CancellationToken::new(),
445            value: Arc::new(RwLock::new(init_value)),
446        }
447    }
448
449    async fn store(&self, t: T, not_check_eq: bool) -> bool {
450        let changed = *self.value.read().await != t;
451        if changed {
452            *self.value.write().await = t;
453        }
454        not_check_eq || changed
455    }
456
457    async fn value(&self) -> T {
458        (*self.value.read().await).clone()
459    }
460
461    /// Unsubscribe operation, this is optional, after your state machine
462    /// is dropped, subscriptions are auto cleaned.
463    pub fn unsubscribe(&self) {
464        self.cancel_token.cancel();
465    }
466}
467
/// Define action upon state change event.
/// - T - type of state in handle,
/// - G - tag type identifying sources and handles. Within one state machine,
///   sources must use distinct tags and handles must use distinct tags
///   (registering a duplicate panics). Sources and handles are kept in separate
///   maps, so the same tag value may name both a source and a handle.
#[async_trait]
pub trait HasStateHandle<T, G>: HasStateMachine<G>
where
    T: Clone + Debug + PartialEq,
    G: Clone + Debug + Eq + Hash,
{
    /// Action upon state change event.
    ///
    /// Called by the subscription task spawned in `UseStateHandle::subscribe`
    /// while it holds the guard returned by `HasLock::lock`. A returned error is
    /// logged by that task; the subscription keeps running.
    /// - tag - the tag the subscription was registered under
    /// - new_value - the new value just received (already stored in the handle)
    /// - old_value - the handle's value before this event; for the first event
    ///   this is the value read when the subscription was created.
    async fn on_change(
        self: Arc<Self>,
        tag: G,
        new_value: T,
        old_value: T,
    ) -> Result<(), impl std::error::Error>;
}
492
493/// Convenient method to do subscription with a state convert function. The trait is auto implemented for types implemented HasStateHandle.
494#[async_trait]
495pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
496where
497    T: 'static + Clone + Debug + PartialEq + Send + Sync,
498    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
499{
500    /// Do subscription with a state convert function.
501    /// - stage [1] -- receive from source's broadcast channel.
502    /// - stage [2] -- convert to target type and send to mpsc channel.
503    /// - stage [3] -- receive from mpsc channel and process it.
504    /// - stage [4] -- (optional) feedback when the change event has been processed.
505    #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
506    async fn subscribe<S>(
507        self: Arc<Self>,
508        reader: impl Into<ReaderEx<S, T>> + Send,
509        tag: G,
510    ) -> Handle<T>
511    where
512        S: 'static + Clone + Debug + PartialEq + Send + Sync,
513    {
514        let reader_ex = reader.into();
515        let handle: Handle<T> = Handle::new(reader_ex.value().await);
516        self.state_machine()
517            .await
518            .add_handle(tag.clone(), handle.clone());
519        let mut rx_s = reader_ex.sender.subscribe();
520        let (tx_t, mut rx_t) =
521            mpsc::unbounded_channel::<(T, T, Option<mpsc::UnboundedSender<()>>)>();
522        let handle_c = handle.clone();
523        tokio::spawn(async move {
524            tracing::info!("Subscription start -- {:?}", tag);
525            loop {
526                select! {
527                    _ = handle_c.cancel_token.cancelled() => {
528                        break;
529                    }
530                    res = rx_s.recv() => {
531                        match res {
532                            Ok((s, not_check_eq, opt_feedback)) => {
533                                let t = reader_ex.func.as_ref()(s).await;
534                                let t_old = handle_c.value().await;
535                                if handle_c.store(t.clone(), not_check_eq).await {
536                                    if let Err(e) = tx_t.send((t, t_old, opt_feedback)) {
537                                        tracing::error!("stage [2] | change event send error -- {}", e);
538                                        break;
539                                    }
540                                }
541                            },
542                            Err(e) => match e {
543                                broadcast::error::RecvError::Closed => {
544                                    _ = self.state_machine().await.del_source(tag.clone());
545                                    tracing::info!("state source channel closed");
546                                    break;
547                                },
548                                broadcast::error::RecvError::Lagged(_) => {
549                                    tracing::error!("stage [1] | change event recv lagged");
550                                    break;
551                                },
552                            },
553                        }
554                    }
555                    res = rx_t.recv() => {
556                        match res {
557                            Some((t, t_old, opt_feedback)) => {
558                                let _lock = self.lock().await;
559                                if let Err(e) = self.clone().on_change(tag.clone(), t, t_old).await {
560                                    tracing::error!("stage [3] | change event proc error -- {}", e);
561                                }
562                                if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
563                                    tracing::error!("stage [4] | change event feedback error -- {}", e);
564                                }
565                            },
566                            None => {
567                                tracing::info!("state target channel closed");
568                                break;
569                            },
570                        }
571                    }
572                }
573            }
574            _ = self.state_machine().await.del_handle(tag.clone());
575            tracing::info!("Subscription end -- {:?}", tag);
576        });
577        handle
578    }
579}
580
// Every `'static` `HasStateHandle` implementor gets `UseStateHandle::subscribe` for free.
impl<V, T, G> UseStateHandle<T, G> for V
where
    V: 'static + HasStateHandle<T, G>,
    T: 'static + Clone + Debug + PartialEq + Send + Sync,
    G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
}