1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4 any::{Any, type_name},
5 cmp::Eq,
6 fmt::Debug,
7 hash::Hash,
8 pin::Pin,
9 sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13 select,
14 sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19#[derive(Clone, Debug)]
22pub struct StateMachine<G>
23where
24 G: Eq + Hash,
25{
26 sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
27 handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
28}
29
30impl<G> Default for StateMachine<G>
31where
32 G: Eq + Hash,
33{
34 fn default() -> Self {
35 Self {
36 sources: Default::default(),
37 handles: Default::default(),
38 }
39 }
40}
41
42impl<G> StateMachine<G>
43where
44 G: Clone + Debug + Eq + Hash,
45{
46 pub fn new() -> Self {
47 Default::default()
48 }
49
50 fn add_source<S>(&self, tag: G, source: Source<S>)
52 where
53 S: 'static + Send + Sync,
54 {
55 assert!(
56 !self.sources.contains_key(&tag),
57 "duplicate tag for source -- {:?}",
58 tag
59 );
60 self.sources.insert(tag, Box::new(source));
61 }
62
63 fn del_source(&self, tag: G) -> bool {
65 self.sources.remove(&tag).is_some()
66 }
67
68 async fn source<S>(&self, tag: G) -> Source<S>
70 where
71 S: 'static + Clone,
72 {
73 let opt_source_box = self.sources.get(&tag);
74 assert!(
75 opt_source_box.is_some(),
76 "state source does not exist, tag -- {:?}",
77 tag
78 );
79 let source_box = opt_source_box.unwrap();
80 let opt_source = source_box.downcast_ref::<Source<S>>();
81 assert!(
82 opt_source.is_some(),
83 "state source does not exist, tag -- {:?}, type -- {}",
84 tag,
85 type_name::<S>()
86 );
87 let source = opt_source.unwrap();
88 (*source).clone()
89 }
90
91 fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93 where
94 T: 'static + Send + Sync,
95 {
96 assert!(
97 !self.handles.contains_key(&tag),
98 "duplicate tag for handle -- {:?}",
99 tag
100 );
101 self.handles.insert(tag, Box::new(handle));
102 }
103
104 fn del_handle(&self, tag: G) -> bool {
106 self.handles.remove(&tag).is_some()
107 }
108
109 pub async fn source_value<S>(&self, tag: G) -> S
111 where
112 S: 'static + Clone + Default + PartialEq + Send,
113 {
114 self.source(tag).await.value().await
115 }
116
117 async fn handle<T>(&self, tag: G) -> Handle<T>
119 where
120 T: 'static + Clone,
121 {
122 let opt_handle_box = self.handles.get(&tag);
123 assert!(
124 opt_handle_box.is_some(),
125 "state handle does not exist, tag -- {:?}",
126 tag
127 );
128 let handle_box = opt_handle_box.unwrap();
129 let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130 assert!(
131 opt_handle.is_some(),
132 "state handle does not exist, tag -- {:?}, type -- {}",
133 tag,
134 type_name::<T>()
135 );
136 opt_handle.unwrap().clone()
137 }
138
139 async fn handle_value<T>(&self, tag: G) -> T
141 where
142 T: 'static + Clone + PartialEq,
143 {
144 self.handle(tag).await.value().await
145 }
146}
147
148#[async_trait]
150pub trait HasStateMachine<G>
151where
152 G: Clone + Debug + Eq + Hash,
153{
154 async fn lock(&self) -> MutexGuard<'_, ()>;
156
157 async fn state_machine(&self) -> StateMachine<G>;
159}
160
161#[async_trait]
163pub trait UseStateMachine<G>: HasStateMachine<G>
164where
165 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
166{
167 async fn add_source<S>(&self, tag: G)
169 where
170 S: 'static + Clone + Default + PartialEq + Send + Sync,
171 {
172 self.state_machine()
173 .await
174 .add_source(tag, Source::<S>::default());
175 }
176
177 async fn add_source_ex<S>(&self, tag: G, chan_capacity: usize, init_value: S)
179 where
180 S: 'static + Clone + Default + PartialEq + Send + Sync,
181 {
182 self.state_machine()
183 .await
184 .add_source(tag, Source::create(init_value, chan_capacity));
185 }
186
187 async fn del_source(&self, tag: G) -> bool {
189 self.state_machine().await.del_source(tag)
190 }
191
192 async fn num_of_subscriptions<S>(&self, tag: G) -> usize
194 where
195 S: 'static + Clone + Default + PartialEq + Send + Sync,
196 {
197 self.state_machine()
198 .await
199 .source::<S>(tag)
200 .await
201 .num_of_subscriptions()
202 .await
203 }
204
205 async fn source_value<S>(&self, tag: G) -> S
207 where
208 S: 'static + Clone + Default + PartialEq + Send + Sync,
209 {
210 self.state_machine().await.source_value(tag).await
211 }
212
213 async fn change<S>(&self, tag: G, s: S) -> Result<(), SourceChangeError>
215 where
216 S: 'static + Clone + Default + PartialEq + Send + Sync,
217 {
218 self.state_machine().await.source(tag).await.change(s).await
219 }
220
221 async fn wait_change<S>(&self, tag: G, s: S) -> Result<(), SourceChangeError>
223 where
224 S: 'static + Clone + Default + PartialEq + Send + Sync,
225 {
226 self.state_machine()
227 .await
228 .source(tag)
229 .await
230 .wait_change(s)
231 .await
232 }
233
234 async fn modify<S>(
236 &self,
237 tag: G,
238 func: impl Fn(S) -> S + Send + Sync + 'static,
239 ) -> Result<(), SourceChangeError>
240 where
241 S: 'static + Clone + Default + PartialEq + Send + Sync,
242 {
243 self.state_machine()
244 .await
245 .source(tag)
246 .await
247 .modify(func)
248 .await
249 }
250
251 async fn wait_modify<S>(
253 &self,
254 tag: G,
255 func: impl Fn(S) -> S + Send + Sync + 'static,
256 ) -> Result<(), SourceChangeError>
257 where
258 S: 'static + Clone + Default + PartialEq + Send + Sync,
259 {
260 self.state_machine()
261 .await
262 .source(tag)
263 .await
264 .wait_modify(func)
265 .await
266 }
267
268 async fn touch<S>(&self, tag: G) -> Result<(), SourceChangeError>
270 where
271 S: 'static + Clone + Default + PartialEq + Send + Sync,
272 {
273 self.state_machine()
274 .await
275 .source::<S>(tag)
276 .await
277 .touch()
278 .await
279 }
280
281 async fn handle_value<T>(&self, tag: G) -> T
283 where
284 T: 'static + Clone + PartialEq + Send + Sync,
285 {
286 self.state_machine().await.handle_value(tag).await
287 }
288
289 async fn reader<S>(&self, tag: G) -> Reader<S>
291 where
292 S: 'static + Clone + Default + PartialEq + Send,
293 {
294 self.state_machine().await.source::<S>(tag).await.reader()
295 }
296
297 async fn reader_ex<S, T>(
299 &self,
300 tag: G,
301 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
302 ) -> ReaderEx<S, T>
303 where
304 S: 'static + Clone + Default + PartialEq + Send,
305 {
306 self.state_machine()
307 .await
308 .source::<S>(tag)
309 .await
310 .reader_ex(func)
311 }
312
313 async fn unsubscribe<T>(&self, tag: G)
315 where
316 T: 'static + Clone + PartialEq + Send + Sync,
317 {
318 self.state_machine()
319 .await
320 .handle::<T>(tag)
321 .await
322 .unsubscribe();
323 }
324}
325
326#[async_trait]
327impl<T, G> UseStateMachine<G> for T
328where
329 T: HasStateMachine<G>,
330 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
331{
332}
333
/// Flag carried with every broadcast change. When `true` (set by `touch`), subscribers must
/// treat the message as a change even if the value compares equal to what they already hold.
type NotCheckEq = bool;
337
338#[derive(Clone, Debug)]
340struct Source<S> {
341 value: Arc<RwLock<S>>,
342 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
343}
344
345impl<S> Default for Source<S>
346where
347 S: 'static + Clone + Default + PartialEq + Send,
348{
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354impl<S> Source<S>
355where
356 S: 'static + Clone + Default + PartialEq + Send,
357{
358 fn new() -> Self {
360 Self::create(Default::default(), 100)
361 }
362
363 fn create(init_value: S, chan_capacity: usize) -> Self {
366 let (tx, _) = broadcast::channel(chan_capacity);
367 Self {
368 value: Arc::new(RwLock::new(init_value)),
369 sender: tx,
370 }
371 }
372
373 fn reader(&self) -> Reader<S> {
375 Reader {
376 value: self.value.clone(),
377 recver: self.sender.subscribe(),
378 }
379 }
380
381 fn reader_ex<T>(
383 &self,
384 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
385 ) -> ReaderEx<S, T> {
386 ReaderEx {
387 value: self.value.clone(),
388 recver: self.sender.subscribe(),
389 func: Arc::new(func),
390 }
391 }
392
393 async fn num_of_subscriptions(&self) -> usize {
395 self.sender.receiver_count()
396 }
397
398 async fn value(&self) -> S {
400 (*self.value.read().await).clone()
401 }
402
403 async fn change_ex(
404 &self,
405 wait_to_end: bool,
406 change: Change<S>,
407 ) -> Result<(), SourceChangeError> {
408 let mut guard = self.value.write().await;
409 let (s, not_check_eq) = match change {
410 Change::Value(v) => (v, false),
411 Change::Func(func) => (func((*guard).clone()), false),
412 Change::Touch => ((*guard).clone(), true),
413 };
414 if not_check_eq || *guard != s {
415 if wait_to_end {
416 let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
417 self.sender
418 .send((s.clone(), not_check_eq, Some(tx_w)))
419 .map_err(|_| SourceChangeError::SendErr)?;
420 loop {
421 select! {
422 res = rx_w.recv() => {
423 if res.is_none() {
424 break;
425 }
426 }
427 }
428 }
429 } else {
430 self.sender
431 .send((s.clone(), not_check_eq, None))
432 .map_err(|_| SourceChangeError::SendErr)?;
433 }
434 *guard = s;
435 Ok(())
436 } else {
437 Err(SourceChangeError::NotChange)
438 }
439 }
440
441 async fn change(&self, s: S) -> Result<(), SourceChangeError> {
443 self.change_ex(false, Change::Value(s)).await
444 }
445
446 async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
448 self.change_ex(true, Change::Value(s)).await
449 }
450
451 async fn modify(
453 &self,
454 func: impl Fn(S) -> S + Send + Sync + 'static,
455 ) -> Result<(), SourceChangeError> {
456 self.change_ex(false, Change::Func(Arc::new(func))).await
457 }
458
459 async fn wait_modify(
461 &self,
462 func: impl Fn(S) -> S + Send + Sync + 'static,
463 ) -> Result<(), SourceChangeError> {
464 self.change_ex(true, Change::Func(Arc::new(func))).await
465 }
466
467 async fn touch(&self) -> Result<(), SourceChangeError> {
469 self.change_ex(false, Change::Touch).await
470 }
471}
472
/// How `Source::change_ex` derives the next value.
enum Change<S> {
    /// Replace the current value with this one.
    Value(S),
    /// Compute the next value from a clone of the current one.
    Func(Arc<dyn Fn(S) -> S + Send + Sync>),
    /// Re-broadcast the current value, bypassing the equality check.
    Touch,
}
478
479#[derive(Debug, Error)]
480pub enum SourceChangeError {
481 #[error("Change of state failed to broadcast")]
482 SendErr,
483 #[error("State source not change, no change detected")]
484 NotChange,
485}
486
487pub struct Reader<S> {
489 value: Arc<RwLock<S>>,
490 recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
491}
492
493impl<S> Into<ReaderEx<S, S>> for Reader<S>
494where
495 S: 'static + Send,
496{
497 fn into(self) -> ReaderEx<S, S> {
498 ReaderEx {
499 value: self.value,
500 recver: self.recver,
501 func: Arc::new(|s| Box::pin(async move { s })),
502 }
503 }
504}
505
506impl<S> Reader<S> {
507 pub fn extend<T>(
508 self,
509 func: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync + 'static,
510 ) -> ReaderEx<S, T> {
511 ReaderEx {
512 value: self.value,
513 recver: self.recver,
514 func: Arc::new(func),
515 }
516 }
517}
518
519pub struct ReaderEx<S, T> {
521 value: Arc<RwLock<S>>,
522 recver: broadcast::Receiver<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
523 func: Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>,
524}
525
526impl<S, T> ReaderEx<S, T>
527where
528 S: Clone,
529{
530 async fn value(&self) -> T {
531 self.func.as_ref()((*self.value.read().await).clone()).await
532 }
533}
534
535#[derive(Clone, Debug)]
537struct Handle<T> {
538 cancel_token: CancellationToken,
539 value: Arc<RwLock<T>>,
540}
541
542impl<T> Handle<T>
543where
544 T: Clone + PartialEq,
545{
546 fn new(init_value: T) -> Self {
547 Self {
548 cancel_token: CancellationToken::new(),
549 value: Arc::new(RwLock::new(init_value)),
550 }
551 }
552
553 async fn store(&self, t: T, not_check_eq: bool) -> bool {
554 let changed = *self.value.read().await != t;
555 if changed {
556 *self.value.write().await = t;
557 }
558 not_check_eq || changed
559 }
560
561 async fn value(&self) -> T {
562 (*self.value.read().await).clone()
563 }
564
565 fn unsubscribe(&self) {
568 self.cancel_token.cancel();
569 }
570}
571
572#[async_trait]
579pub trait HasStateHandle<T, G>: HasStateMachine<G>
580where
581 T: Clone + Debug + PartialEq,
582 G: Clone + Debug + Eq + Hash,
583{
584 async fn on_change(
590 self: Arc<Self>,
591 tag: G,
592 new_value: T,
593 old_value: T,
594 ) -> Result<(), Box<dyn std::error::Error>>;
595}
596
597#[async_trait]
599pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
600where
601 T: 'static + Clone + Debug + PartialEq + Send + Sync,
602 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
603{
604 #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
610 async fn subscribe<S>(self: Arc<Self>, reader: impl Into<ReaderEx<S, T>> + Send, tag: G)
611 where
612 S: 'static + Clone + Debug + PartialEq + Send + Sync,
613 {
614 let reader_ex = reader.into();
615 let handle: Handle<T> = Handle::new(reader_ex.value().await);
616 self.state_machine()
617 .await
618 .add_handle(tag.clone(), handle.clone());
619 let mut rx_s = reader_ex.recver;
620 let (tx_t, mut rx_t) =
621 mpsc::unbounded_channel::<(T, T, Option<mpsc::UnboundedSender<()>>)>();
622 tokio::spawn(async move {
623 tracing::info!("Subscription start -- {:?}", tag);
624 loop {
625 select! {
626 _ = handle.cancel_token.cancelled() => {
627 break;
628 }
629 res = rx_s.recv() => {
630 match res {
631 Ok((s, not_check_eq, opt_feedback)) => {
632 let t = reader_ex.func.as_ref()(s).await;
633 let t_old = handle.value().await;
634 if handle.store(t.clone(), not_check_eq).await {
635 if let Err(e) = tx_t.send((t, t_old, opt_feedback)) {
636 tracing::error!("stage [2] | change event send error -- {}", e);
637 break;
638 }
639 }
640 },
641 Err(e) => match e {
642 broadcast::error::RecvError::Closed => {
643 _ = self.state_machine().await.del_source(tag.clone());
644 tracing::info!("state source channel closed");
645 break;
646 },
647 broadcast::error::RecvError::Lagged(_) => {
648 tracing::error!("stage [1] | change event recv lagged");
649 break;
650 },
651 },
652 }
653 }
654 res = rx_t.recv() => {
655 match res {
656 Some((t, t_old, opt_feedback)) => {
657 let _lock = self.lock().await;
658 if let Err(e) = self.clone().on_change(tag.clone(), t, t_old).await {
659 tracing::error!("stage [3] | change event proc error -- {}", e);
660 }
661 if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
662 tracing::error!("stage [4] | change event feedback error -- {}", e);
663 }
664 },
665 None => {
666 tracing::info!("state target channel closed");
667 break;
668 },
669 }
670 }
671 }
672 }
673 _ = self.state_machine().await.del_handle(tag.clone());
674 tracing::info!("Subscription end -- {:?}", tag);
675 });
676 }
677}
678
679impl<V, T, G> UseStateHandle<T, G> for V
680where
681 V: 'static + HasStateHandle<T, G>,
682 T: 'static + Clone + Debug + PartialEq + Send + Sync,
683 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
684{
685}