1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4 any::{Any, type_name},
5 cmp::Eq,
6 fmt::Debug,
7 hash::Hash,
8 pin::Pin,
9 sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13 select,
14 sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19#[derive(Clone, Debug)]
22pub struct StateMachine<G>
23where
24 G: Eq + Hash,
25{
26 sources: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
27 handles: Arc<DashMap<G, Box<dyn Any + Send + Sync>>>,
28}
29
30impl<G> Default for StateMachine<G>
31where
32 G: Eq + Hash,
33{
34 fn default() -> Self {
35 Self {
36 sources: Default::default(),
37 handles: Default::default(),
38 }
39 }
40}
41
42impl<G> StateMachine<G>
43where
44 G: Clone + Debug + Eq + Hash,
45{
46 pub fn new() -> Self {
47 Default::default()
48 }
49
50 pub(crate) fn add_source<S>(&self, tag: G, source: Source<S>)
52 where
53 S: 'static + Send + Sync,
54 {
55 assert!(
56 !self.sources.contains_key(&tag),
57 "duplicate tag for source -- {:?}",
58 tag
59 );
60 self.sources.insert(tag, Box::new(source));
61 }
62
63 pub(crate) fn del_source(&self, tag: G) -> bool {
65 self.sources.remove(&tag).is_some()
66 }
67
68 pub async fn source<S>(&self, tag: G) -> Source<S>
70 where
71 S: 'static + Clone,
72 {
73 let opt_source_box = self.sources.get(&tag);
74 assert!(
75 opt_source_box.is_some(),
76 "state source does not exist, tag -- {:?}",
77 tag
78 );
79 let source_box = opt_source_box.unwrap();
80 let opt_source = source_box.downcast_ref::<Source<S>>();
81 assert!(
82 opt_source.is_some(),
83 "state source does not exist, tag -- {:?}, type -- {}",
84 tag,
85 type_name::<S>()
86 );
87 let source = opt_source.unwrap();
88 (*source).clone()
89 }
90
91 pub(crate) fn add_handle<T>(&self, tag: G, handle: Handle<T>)
93 where
94 T: 'static + Send + Sync,
95 {
96 assert!(
97 !self.handles.contains_key(&tag),
98 "duplicate tag for handle -- {:?}",
99 tag
100 );
101 self.handles.insert(tag, Box::new(handle));
102 }
103
104 pub(crate) fn del_handle(&self, tag: G) -> bool {
106 self.handles.remove(&tag).is_some()
107 }
108
109 pub async fn source_value<S>(&self, tag: G) -> S
111 where
112 S: 'static + Clone + Default + PartialEq + Send,
113 {
114 self.source(tag).await.value().await
115 }
116
117 pub async fn handle<T>(&self, tag: G) -> Handle<T>
119 where
120 T: 'static + Clone,
121 {
122 let opt_handle_box = self.handles.get(&tag);
123 assert!(
124 opt_handle_box.is_some(),
125 "state handle does not exist, tag -- {:?}",
126 tag
127 );
128 let handle_box = opt_handle_box.unwrap();
129 let opt_handle = handle_box.downcast_ref::<Handle<T>>();
130 assert!(
131 opt_handle.is_some(),
132 "state handle does not exist, tag -- {:?}, type -- {}",
133 tag,
134 type_name::<T>()
135 );
136 opt_handle.unwrap().clone()
137 }
138
139 pub async fn handle_value<T>(&self, tag: G) -> Option<T>
141 where
142 T: 'static + Clone + PartialEq,
143 {
144 self.handle(tag).await.value().await
145 }
146}
147
148#[async_trait]
150pub trait HasLock {
151 async fn lock(&self) -> MutexGuard<'_, ()>;
153}
154
155#[async_trait]
157pub trait HasStateMachine<G>: HasLock
158where
159 G: Clone + Debug + Eq + Hash,
160{
161 async fn state_machine(&self) -> StateMachine<G>;
163}
164
165#[async_trait]
167pub trait UseStateMachine<G>: HasStateMachine<G>
168where
169 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
170{
171 async fn source<S>(&self, tag: G) -> Source<S>
173 where
174 S: 'static + Clone,
175 {
176 self.state_machine().await.source(tag).await
177 }
178
179 async fn source_value<S>(&self, tag: G) -> Option<S>
181 where
182 S: 'static + Clone + PartialEq + Send + Sync,
183 {
184 self.state_machine().await.source_value(tag).await
185 }
186
187 async fn handle<T>(&self, tag: G) -> Handle<T>
189 where
190 T: 'static + Clone,
191 {
192 self.state_machine().await.handle(tag).await
193 }
194
195 async fn handle_value<T>(&self, tag: G) -> Option<T>
197 where
198 T: 'static + Clone + PartialEq + Send + Sync,
199 {
200 self.state_machine().await.handle_value(tag).await
201 }
202}
203
204#[async_trait]
205impl<T, G> UseStateMachine<G> for T
206where
207 T: HasStateMachine<G>,
208 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
209{
210}
211
212#[async_trait]
214pub trait UseStateSource<G>: HasStateMachine<G>
215where
216 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
217{
218 async fn add_source<S>(&self, tag: G, source: Source<S>)
220 where
221 S: 'static + Send + Sync,
222 {
223 self.state_machine().await.add_source(tag, source);
224 }
225}
226
227impl<T, G> UseStateSource<G> for T
228where
229 T: HasStateMachine<G>,
230 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
231{
232}
233
/// Flag carried with every broadcast change.
///
/// When `true`, the change is reported to subscribers even if the new value equals the
/// one they already hold (used by `Source::touch`).
type NotCheckEq = bool;
237
238#[derive(Clone, Debug)]
240pub struct Source<S> {
241 value: Arc<RwLock<S>>,
242 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
243}
244
245impl<S> Default for Source<S>
246where
247 S: 'static + Clone + Default + PartialEq + Send,
248{
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254impl<S> Source<S>
255where
256 S: 'static + Clone + Default + PartialEq + Send,
257{
258 pub fn new() -> Self {
260 Self::create(Default::default(), 100)
261 }
262
263 pub fn create(init_value: S, capacity: usize) -> Self {
266 let (tx, _) = broadcast::channel(capacity);
267 Self {
268 value: Arc::new(RwLock::new(init_value)),
269 sender: tx,
270 }
271 }
272
273 pub fn reader(&self) -> Reader<S> {
275 Reader {
276 value: self.value.clone(),
277 sender: self.sender.clone(),
278 }
279 }
280
281 pub fn reader_ex<T>(&self, func: ConvertFunc<S, T>) -> ReaderEx<S, T> {
283 ReaderEx {
284 value: self.value.clone(),
285 sender: self.sender.clone(),
286 func,
287 }
288 }
289
290 pub async fn num_of_subs(&self) -> usize {
292 self.sender.receiver_count()
293 }
294
295 pub async fn value(&self) -> S {
297 (*self.value.read().await).clone()
298 }
299
300 async fn change_ex(
301 &self,
302 wait_to_end: bool,
303 change: Change<S>,
304 ) -> Result<(), SourceChangeError> {
305 let mut guard = self.value.write().await;
306 let (s, not_check_eq) = match change {
307 Change::Value(v) => (v, false),
308 Change::Func(func) => (func((*guard).clone()), false),
309 Change::Touch => ((*guard).clone(), true),
310 };
311 if not_check_eq || *guard != s {
312 if wait_to_end {
313 let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
314 self.sender
315 .send((s.clone(), not_check_eq, Some(tx_w)))
316 .map_err(|_| SourceChangeError::SendErr)?;
317 loop {
318 select! {
319 res = rx_w.recv() => {
320 if res.is_none() {
321 break;
322 }
323 }
324 }
325 }
326 } else {
327 self.sender
328 .send((s.clone(), not_check_eq, None))
329 .map_err(|_| SourceChangeError::SendErr)?;
330 }
331 *guard = s;
332 Ok(())
333 } else {
334 Err(SourceChangeError::NotChange)
335 }
336 }
337
338 pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
340 self.change_ex(false, Change::Value(s)).await
341 }
342
343 pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
345 self.change_ex(true, Change::Value(s)).await
346 }
347
348 pub async fn modify(&self, func: impl Fn(S) -> S + 'static) -> Result<(), SourceChangeError> {
350 self.change_ex(false, Change::Func(Box::new(func))).await
351 }
352
353 pub async fn wait_modify(
355 &self,
356 func: impl Fn(S) -> S + 'static,
357 ) -> Result<(), SourceChangeError> {
358 self.change_ex(true, Change::Func(Box::new(func))).await
359 }
360
361 pub async fn touch(&self) -> Result<(), SourceChangeError> {
363 self.change_ex(false, Change::Touch).await
364 }
365}
366
367enum Change<S> {
368 Value(S),
369 Func(Box<dyn Fn(S) -> S>),
370 Touch,
371}
372
373#[derive(Debug, Error)]
374pub enum SourceChangeError {
375 #[error("Change of state failed to broadcast")]
376 SendErr,
377 #[error("State source not change, no change detected")]
378 NotChange,
379}
380
381#[derive(Clone)]
383pub struct Reader<S> {
384 value: Arc<RwLock<S>>,
385 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
386}
387
388impl<S> Into<ReaderEx<S, S>> for Reader<S>
389where
390 S: 'static + Send,
391{
392 fn into(self) -> ReaderEx<S, S> {
393 ReaderEx {
394 value: self.value,
395 sender: self.sender,
396 func: Arc::new(|s| Box::pin(async move { s })),
397 }
398 }
399}
400
401impl<S> Reader<S> {
402 pub fn extend<T>(&self, func: ConvertFunc<S, T>) -> ReaderEx<S, T> {
403 ReaderEx {
404 value: self.value.clone(),
405 sender: self.sender.clone(),
406 func,
407 }
408 }
409}
410
/// Asynchronous conversion from a source value `S` to a subscriber value `T`.
///
/// Kept behind `Arc` so readers clone cheaply. Both the function and the future it
/// returns are `Send`, so the conversion can run inside a spawned task.
pub type ConvertFunc<S, T> =
    Arc<dyn Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + Sync>;
413
414#[derive(Clone)]
416pub struct ReaderEx<S, T> {
417 value: Arc<RwLock<S>>,
418 sender: broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>,
419 func: ConvertFunc<S, T>,
420}
421
422impl<S, T> ReaderEx<S, T>
423where
424 S: Clone,
425{
426 async fn value(&self) -> T {
427 self.func.as_ref()((*self.value.read().await).clone()).await
428 }
429}
430
431#[derive(Clone, Debug)]
433pub struct Handle<T> {
434 cancel_token: CancellationToken,
435 value: Arc<RwLock<T>>,
436}
437
438impl<T> Handle<T>
439where
440 T: Clone + PartialEq,
441{
442 fn new(init_value: T) -> Self {
443 Self {
444 cancel_token: CancellationToken::new(),
445 value: Arc::new(RwLock::new(init_value)),
446 }
447 }
448
449 async fn store(&self, t: T, not_check_eq: bool) -> bool {
450 let changed = *self.value.read().await != t;
451 if changed {
452 *self.value.write().await = t;
453 }
454 not_check_eq || changed
455 }
456
457 async fn value(&self) -> T {
458 (*self.value.read().await).clone()
459 }
460
461 pub fn unsubscribe(&self) {
464 self.cancel_token.cancel();
465 }
466}
467
468#[async_trait]
475pub trait HasStateHandle<T, G>: HasStateMachine<G>
476where
477 T: Clone + Debug + PartialEq,
478 G: Clone + Debug + Eq + Hash,
479{
480 async fn on_change(
486 self: Arc<Self>,
487 tag: G,
488 new_value: T,
489 old_value: T,
490 ) -> Result<(), impl std::error::Error>;
491}
492
493#[async_trait]
495pub trait UseStateHandle<T, G>: HasStateHandle<T, G> + 'static
496where
497 T: 'static + Clone + Debug + PartialEq + Send + Sync,
498 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
499{
500 #[instrument(name = "UseStateHandle::subscribe", skip_all, fields(tag))]
506 async fn subscribe<S>(
507 self: Arc<Self>,
508 reader: impl Into<ReaderEx<S, T>> + Send,
509 tag: G,
510 ) -> Handle<T>
511 where
512 S: 'static + Clone + Debug + PartialEq + Send + Sync,
513 {
514 let reader_ex = reader.into();
515 let handle: Handle<T> = Handle::new(reader_ex.value().await);
516 self.state_machine()
517 .await
518 .add_handle(tag.clone(), handle.clone());
519 let mut rx_s = reader_ex.sender.subscribe();
520 let (tx_t, mut rx_t) =
521 mpsc::unbounded_channel::<(T, T, Option<mpsc::UnboundedSender<()>>)>();
522 let handle_c = handle.clone();
523 tokio::spawn(async move {
524 tracing::info!("Subscription start -- {:?}", tag);
525 loop {
526 select! {
527 _ = handle_c.cancel_token.cancelled() => {
528 break;
529 }
530 res = rx_s.recv() => {
531 match res {
532 Ok((s, not_check_eq, opt_feedback)) => {
533 let t = reader_ex.func.as_ref()(s).await;
534 let t_old = handle_c.value().await;
535 if handle_c.store(t.clone(), not_check_eq).await {
536 if let Err(e) = tx_t.send((t, t_old, opt_feedback)) {
537 tracing::error!("stage [2] | change event send error -- {}", e);
538 break;
539 }
540 }
541 },
542 Err(e) => match e {
543 broadcast::error::RecvError::Closed => {
544 _ = self.state_machine().await.del_source(tag.clone());
545 tracing::info!("state source channel closed");
546 break;
547 },
548 broadcast::error::RecvError::Lagged(_) => {
549 tracing::error!("stage [1] | change event recv lagged");
550 break;
551 },
552 },
553 }
554 }
555 res = rx_t.recv() => {
556 match res {
557 Some((t, t_old, opt_feedback)) => {
558 let _lock = self.lock().await;
559 if let Err(e) = self.clone().on_change(tag.clone(), t, t_old).await {
560 tracing::error!("stage [3] | change event proc error -- {}", e);
561 }
562 if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
563 tracing::error!("stage [4] | change event feedback error -- {}", e);
564 }
565 },
566 None => {
567 tracing::info!("state target channel closed");
568 break;
569 },
570 }
571 }
572 }
573 }
574 _ = self.state_machine().await.del_handle(tag.clone());
575 tracing::info!("Subscription end -- {:?}", tag);
576 });
577 handle
578 }
579}
580
581impl<V, T, G> UseStateHandle<T, G> for V
582where
583 V: 'static + HasStateHandle<T, G>,
584 T: 'static + Clone + Debug + PartialEq + Send + Sync,
585 G: 'static + Clone + Debug + Eq + Hash + Send + Sync,
586{
587}