1use async_trait::async_trait;
2use dashmap::DashMap;
3use std::{
4 any::{Any, type_name},
5 cmp::Eq,
6 fmt::Debug,
7 hash::Hash,
8 pin::Pin,
9 sync::Arc,
10};
11use thiserror::Error;
12use tokio::{
13 select,
14 sync::{MutexGuard, RwLock, broadcast, mpsc},
15};
16use tokio_util::sync::CancellationToken;
17use tracing::instrument;
18
19#[derive(Clone, Debug)]
21pub struct StateMachine<Tag>
22where
23 Tag: Eq + Hash,
24{
25 sources: Arc<DashMap<Tag, Box<dyn Any + Send + Sync>>>,
26 handles: Arc<DashMap<Tag, Box<dyn Any + Send + Sync>>>,
27}
28
29impl<Tag> Default for StateMachine<Tag>
30where
31 Tag: Eq + Hash,
32{
33 fn default() -> Self {
34 Self {
35 sources: Default::default(),
36 handles: Default::default(),
37 }
38 }
39}
40
41impl<Tag> StateMachine<Tag>
42where
43 Tag: Clone + Debug + Eq + Hash,
44{
45 pub fn new() -> Self {
46 Default::default()
47 }
48
49 pub(crate) fn add_source<S>(&self, tag: Tag, source: Source<S>)
51 where
52 S: 'static + Send + Sync,
53 {
54 assert!(
55 !self.sources.contains_key(&tag),
56 "duplicate tag for source -- {:?}",
57 tag
58 );
59 self.sources.insert(tag, Box::new(source));
60 }
61
62 pub(crate) fn del_source(&self, tag: Tag) -> bool {
64 self.sources.remove(&tag).is_some()
65 }
66
67 pub async fn source<S>(&self, tag: Tag) -> Source<S>
69 where
70 S: 'static + Clone,
71 {
72 let opt_source_box = self.sources.get(&tag);
73 assert!(
74 opt_source_box.is_some(),
75 "state source does not exist, tag -- {:?}",
76 tag
77 );
78 let source_box = opt_source_box.unwrap();
79 let opt_source = source_box.downcast_ref::<Source<S>>();
80 assert!(
81 opt_source.is_some(),
82 "state source does not exist, tag -- {:?}, type -- {}",
83 tag,
84 type_name::<S>()
85 );
86 let source = opt_source.unwrap();
87 (*source).clone()
88 }
89
90 pub(crate) fn add_handle<T>(&self, tag: Tag, handle: Handle<T>)
92 where
93 T: 'static + Send + Sync,
94 {
95 assert!(
96 !self.handles.contains_key(&tag),
97 "duplicate tag for handle -- {:?}",
98 tag
99 );
100 self.handles.insert(tag, Box::new(handle));
101 }
102
103 pub(crate) fn del_handle(&self, tag: Tag) -> bool {
105 self.handles.remove(&tag).is_some()
106 }
107
108 pub async fn source_value<S>(&self, tag: Tag) -> Option<S>
110 where
111 S: 'static + Clone + PartialEq,
112 {
113 self.source(tag).await.value().await
114 }
115
116 pub async fn handle<T>(&self, tag: Tag) -> Handle<T>
118 where
119 T: 'static + Clone,
120 {
121 let opt_handle_box = self.handles.get(&tag);
122 assert!(
123 opt_handle_box.is_some(),
124 "state handle does not exist, tag -- {:?}",
125 tag
126 );
127 let handle_box = opt_handle_box.unwrap();
128 let opt_handle = handle_box.downcast_ref::<Handle<T>>();
129 assert!(
130 opt_handle.is_some(),
131 "state handle does not exist, tag -- {:?}, type -- {}",
132 tag,
133 type_name::<T>()
134 );
135 opt_handle.unwrap().clone()
136 }
137
138 pub async fn handle_value<T>(&self, tag: Tag) -> Option<T>
140 where
141 T: 'static + Clone + PartialEq,
142 {
143 self.handle(tag).await.value().await
144 }
145}
146
147#[async_trait]
149pub trait HasStateMachine<Tag>
150where
151 Tag: Eq + Hash,
152{
153 async fn lock(&self) -> MutexGuard<'_, ()>;
155
156 async fn state_machine(&self) -> StateMachine<Tag>;
158}
159
160#[async_trait]
162pub trait UseStateMachine<Tag>: HasStateMachine<Tag>
163where
164 Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
165{
166 async fn source<S>(&self, tag: Tag) -> Source<S>
168 where
169 S: 'static + Clone,
170 {
171 self.state_machine().await.source(tag).await
172 }
173
174 async fn source_value<S>(&self, tag: Tag) -> Option<S>
176 where
177 S: 'static + Clone + PartialEq + Send + Sync,
178 {
179 self.state_machine().await.source_value(tag).await
180 }
181
182 async fn handle<T>(&self, tag: Tag) -> Handle<T>
184 where
185 T: 'static + Clone,
186 {
187 self.state_machine().await.handle(tag).await
188 }
189
190 async fn handle_value<T>(&self, tag: Tag) -> Option<T>
192 where
193 T: 'static + Clone + PartialEq + Send + Sync,
194 {
195 self.state_machine().await.handle_value(tag).await
196 }
197}
198
199#[async_trait]
200impl<T, Tag> UseStateMachine<Tag> for T
201where
202 T: HasStateMachine<Tag>,
203 Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
204{
205}
206
207#[async_trait]
209pub trait UseStateSource<Tag>: HasStateMachine<Tag>
210where
211 Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
212{
213 async fn add_source<S>(&self, tag: Tag, source: Source<S>)
215 where
216 S: 'static + Send + Sync,
217 {
218 self.state_machine().await.add_source(tag, source);
219 }
220}
221
222impl<T, Tag> UseStateSource<Tag> for T
223where
224 T: HasStateMachine<Tag>,
225 Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
226{
227}
228
229type NotCheckEq = bool;
232
233#[derive(Clone, Debug)]
235pub struct Source<S> {
236 value: Arc<RwLock<Option<S>>>,
237 sender: Arc<broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>>,
238}
239
240impl<S> Source<S>
241where
242 S: Clone + PartialEq,
243{
244 pub fn new() -> Self {
246 Self::create(100)
247 }
248
249 pub fn create(capacity: usize) -> Self {
252 let (tx, _) = broadcast::channel(capacity);
253 Self {
254 value: Arc::new(RwLock::new(None)),
255 sender: Arc::new(tx),
256 }
257 }
258
259 pub fn reader(&self) -> Reader<S> {
261 Reader {
262 sender: self.sender.clone(),
263 }
264 }
265
266 pub async fn num_of_subs(&self) -> usize {
268 self.sender.receiver_count()
269 }
270
271 pub async fn value(&self) -> Option<S> {
273 (*self.value.read().await).clone()
274 }
275
276 async fn change_ex(
277 &self,
278 wait_to_end: bool,
279 change: Change<S>,
280 ) -> Result<(), SourceChangeError> {
281 let mut guard = self.value.write().await;
282 let (opt_s, not_check_eq) = match change {
283 Change::Value(v) => (Some(v), false),
284 Change::Func(func) => ((*guard).clone().map(|v| func(v)), false),
285 Change::Touch => ((*guard).clone(), true),
286 };
287 if not_check_eq || *guard != opt_s {
288 if let Some(s) = opt_s {
289 if wait_to_end {
290 let (tx_w, mut rx_w) = mpsc::unbounded_channel::<()>();
291 self.sender
292 .send((s.clone(), not_check_eq, Some(tx_w)))
293 .map_err(|_| SourceChangeError::SendErr)?;
294 loop {
295 select! {
296 res = rx_w.recv() => {
297 if res.is_none() {
298 break;
299 }
300 }
301 }
302 }
303 } else {
304 self.sender
305 .send((s.clone(), not_check_eq, None))
306 .map_err(|_| SourceChangeError::SendErr)?;
307 }
308 *guard = Some(s);
309 }
310 Ok(())
311 } else {
312 Err(SourceChangeError::NotChange)
313 }
314 }
315
316 pub async fn change(&self, s: S) -> Result<(), SourceChangeError> {
318 self.change_ex(false, Change::Value(s)).await
319 }
320
321 pub async fn wait_change(&self, s: S) -> Result<(), SourceChangeError> {
323 self.change_ex(true, Change::Value(s)).await
324 }
325
326 pub async fn modify(&self, func: impl Fn(S) -> S + 'static) -> Result<(), SourceChangeError> {
328 self.change_ex(false, Change::Func(Box::new(func))).await
329 }
330
331 pub async fn wait_modify(
333 &self,
334 func: impl Fn(S) -> S + 'static,
335 ) -> Result<(), SourceChangeError> {
336 self.change_ex(true, Change::Func(Box::new(func))).await
337 }
338
339 pub async fn touch(&self) -> Result<(), SourceChangeError> {
341 self.change_ex(false, Change::Touch).await
342 }
343}
344
/// A requested mutation of a [`Source`] value.
enum Change<S> {
    /// Replace the current value with this one.
    Value(S),
    /// Derive the new value from the current one (no-op on an empty source).
    Func(Box<dyn Fn(S) -> S>),
    /// Re-broadcast the current value, bypassing subscribers' equality checks.
    Touch,
}
350
351#[derive(Debug, Error)]
352pub enum SourceChangeError {
353 #[error("Change of state failed to broadcast")]
354 SendErr,
355 #[error("State source not change, no change detected")]
356 NotChange,
357}
358
359#[derive(Clone, Debug)]
361pub struct Reader<S> {
362 sender: Arc<broadcast::Sender<(S, NotCheckEq, Option<mpsc::UnboundedSender<()>>)>>,
363}
364
365#[derive(Clone, Debug)]
367pub struct Handle<T> {
368 cancel_token: CancellationToken,
369 value: Arc<RwLock<Option<T>>>,
370}
371
372impl<T> Handle<T>
373where
374 T: Clone + PartialEq,
375{
376 fn new() -> Self {
377 Self {
378 cancel_token: CancellationToken::new(),
379 value: Arc::new(RwLock::new(None)),
380 }
381 }
382
383 async fn store(&self, val: T, not_check_eq: bool) -> bool {
384 let opt_t = Some(val);
385 let res = *self.value.read().await != opt_t;
386 if res {
387 *self.value.write().await = opt_t;
388 }
389 not_check_eq || res
390 }
391
392 async fn value(&self) -> Option<T> {
393 (*self.value.read().await).clone()
394 }
395
396 pub fn unsubscribe(&self) {
399 self.cancel_token.cancel();
400 }
401}
402
403#[async_trait]
411pub trait HasStateHandle<S, T, Tag>: HasStateMachine<Tag>
412where
413 Tag: Eq + Hash,
414{
415 async fn on_change(
421 self: Arc<Self>,
422 tag: Tag,
423 new_value: T,
424 old_value: Option<T>,
425 ) -> anyhow::Result<()>;
426}
427
/// Subscribes a [`HasStateHandle`] implementor to a [`Source`] through an async
/// value conversion `S -> T`.
#[async_trait]
pub trait UseStateConvTarget<S, T, Tag>: HasStateHandle<S, T, Tag>
where
    Self: 'static,
    S: 'static + Clone + Debug + Send,
    T: 'static + Clone + Debug + PartialEq + Send + Sync,
    Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
{
    /// Registers a new [`Handle`] under `tag` and spawns a task that forwards the
    /// source's changes to `on_change`.
    ///
    /// For each broadcast `S`, the task:
    /// 1. runs `convert`;
    /// 2. stores the result in the handle;
    /// 3. if the store reports a change (or the change was a touch), calls
    ///    `on_change(tag, new, old)` while holding `self.lock()`;
    /// 4. signals the optional completion sender, then drops it.
    ///
    /// Dropping the completion sender is what releases a waiting `wait_change`
    /// caller. Changes filtered out in step 3 drop it immediately.
    ///
    /// The task ends, and removes the handle from the state machine, when:
    /// - the handle is cancelled via [`Handle::unsubscribe`];
    /// - the broadcast channel closes;
    /// - the receiver lags.
    ///
    /// # Panics
    /// Panics (via `StateMachine::add_handle`) if a handle is already registered
    /// under `tag`.
    ///
    /// NOTE(review): the span declares `tag` and `chan_cap` as fields, but neither
    /// is ever recorded in this body, so both stay empty.
    #[instrument(
        name = "UseStateConvTarget::convert_subscribe",
        skip_all,
        fields(tag, chan_cap)
    )]
    async fn convert_subscribe(
        self: Arc<Self>,
        reader: Reader<S>,
        tag: Tag,
        convert: impl Fn(S) -> Pin<Box<dyn Future<Output = T> + Send>> + Send + 'static,
    ) -> Handle<T> {
        let handle: Handle<T> = Handle::new();
        self.state_machine()
            .await
            .add_handle(tag.clone(), handle.clone());
        // Subscribe before spawning so no change sent after this call is missed.
        let mut rx_s = reader.sender.subscribe();
        // Internal queue between the "convert + store" stage and the `on_change`
        // stage. Both ends live in the spawned task, so the queue is unbounded.
        let (tx_t, mut rx_t) =
            mpsc::unbounded_channel::<(T, Option<T>, Option<mpsc::UnboundedSender<()>>)>();
        let handle_c = handle.clone();
        tokio::spawn(async move {
            tracing::info!("Subscription start -- {:?}", tag);
            loop {
                select! {
                    _ = handle_c.cancel_token.cancelled() => {
                        // Unsubscribed: queued but unprocessed changes are dropped,
                        // along with their completion senders.
                        break;
                    }
                    res = rx_s.recv() => {
                        match res {
                            Ok((s, not_check_eq, opt_feedback)) => {
                                let t = convert(s).await;
                                let opt_t_old = handle_c.value().await;
                                // Only enqueue for `on_change` when the store reports a
                                // dispatchable change; otherwise `opt_feedback` is
                                // dropped here, releasing any waiter.
                                if handle_c.store(t.clone(), not_check_eq).await {
                                    if let Err(e) = tx_t.send((t, opt_t_old, opt_feedback)) {
                                        tracing::error!("stage [2] | change event send error -- {}", e);
                                        break;
                                    }
                                }
                            },
                            Err(e) => match e {
                                broadcast::error::RecvError::Closed => {
                                    // Every Arc'd broadcast sender (held by `Source` /
                                    // `Reader` clones) is gone.
                                    // NOTE(review): this deletes the *source* registered
                                    // under the handle's tag, which assumes the source and
                                    // the handle share a tag. Confirm.
                                    _ = self.state_machine().await.del_source(tag.clone());
                                    tracing::info!("state source channel closed");
                                    break;
                                },
                                broadcast::error::RecvError::Lagged(_) => {
                                    // NOTE(review): lagging ends the whole subscription
                                    // rather than resuming from the oldest retained change.
                                    tracing::error!("stage [1] | change event recv lagged");
                                    break;
                                },
                            },
                        }
                    }
                    res = rx_t.recv() => {
                        match res {
                            Some((t, opt_t_old, opt_feedback)) => {
                                // Serialize callbacks through the implementor's lock.
                                let _lock = self.lock().await;
                                if let Err(e) = self.clone().on_change(tag.clone(), t, opt_t_old).await {
                                    tracing::error!("stage [3] | change event proc error -- {}", e);
                                }
                                // The `()` is only a signal. Waiters return once every
                                // clone of the sender is dropped, which happens at the
                                // end of this arm.
                                if let Some(feedback) = opt_feedback && let Err(e) = feedback.send(()) {
                                    tracing::error!("stage [4] | change event feedback error -- {}", e);
                                }
                            },
                            None => {
                                tracing::info!("state target channel closed");
                                break;
                            },
                        }
                    }
                }
            }
            // Every exit path unregisters the handle.
            _ = self.state_machine().await.del_handle(tag.clone());
            tracing::info!("Subscription end -- {:?}", tag);
        });
        handle
    }
}
518
519impl<V, S, T, Tag> UseStateConvTarget<S, T, Tag> for V
520where
521 V: 'static + HasStateHandle<S, T, Tag>,
522 S: 'static + Clone + Debug + Send,
523 T: 'static + Clone + Debug + PartialEq + Send + Sync,
524 Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
525{
526}
527
528#[async_trait]
530pub trait UseStateTarget<T, Tag>: UseStateConvTarget<T, T, Tag>
531where
532 Self: 'static,
533 T: 'static + Clone + Debug + PartialEq + Send + Sync,
534 Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
535{
536 #[instrument(name = "UseStateTarget::subscribe", skip_all, fields(tag, chan_cap))]
538 async fn subscribe(self: Arc<Self>, reader: Reader<T>, tag: Tag) -> Handle<T> {
539 UseStateConvTarget::convert_subscribe(self, reader, tag, |t| Box::pin(async move { t }))
540 .await
541 }
542}
543
544impl<V, T, Tag> UseStateTarget<T, Tag> for V
545where
546 V: 'static + UseStateConvTarget<T, T, Tag>,
547 T: 'static + Clone + Debug + PartialEq + Send + Sync,
548 Tag: 'static + Clone + Debug + Eq + Hash + Send + Sync,
549{
550}