1use std::{collections::HashMap, fmt, fmt::Debug, hash::Hash, sync::Arc};
2
3use bigerror::LogError;
4use tokio::{sync::mpsc::UnboundedSender, task::JoinSet};
5use tokio_stream::StreamExt;
6use tracing::{debug, trace, warn, Instrument};
7
8use crate::{
9 queue::{StreamReceiver, StreamableDeque},
10 HashKind, Rex, StateId,
11};
12
13pub trait RexMessage: GetTopic<Self::Topic> + Clone + fmt::Debug + Send + Sync + 'static
16where
17 Self: Send + Sync,
18{
19 type Topic: RexTopic;
20}
21
22pub trait GetTopic<T: RexTopic>: fmt::Debug {
25 fn get_topic(&self) -> T;
26}
27
28#[derive(Debug, Clone)]
32pub struct Notification<M: RexMessage>(pub M);
33
34impl<M, T> GetTopic<T> for Notification<M>
35where
36 T: RexTopic,
37 M: RexMessage + GetTopic<T>,
38{
39 fn get_topic(&self) -> T {
40 self.0.get_topic()
41 }
42}
43
/// Marker trait for values usable as routing keys.
///
/// Topics are used as `HashMap` keys by [`NotificationManager`] (hence `Hash + Eq`, which
/// also implies `PartialEq`) and are passed around by value (hence `Copy`). The trait is
/// blanket-implemented for every type meeting these bounds, so it never needs a manual impl.
pub trait RexTopic: fmt::Debug + Hash + Eq + Copy + Send + Sync + 'static {}

impl<T> RexTopic for T where T: fmt::Debug + Hash + Eq + Copy + Send + Sync + 'static {}
46
47pub type Subscriber<M> = UnboundedSender<Notification<M>>;
50pub struct NotificationManager<M>
53where
54 M: RexMessage,
55{
56 processors: Arc<HashMap<M::Topic, Vec<Subscriber<M>>>>,
57 notification_queue: NotificationQueue<M>,
58}
59
60#[derive(Default, Clone, Debug)]
61pub struct NotificationQueue<M: RexMessage>(pub(crate) Arc<StreamableDeque<Notification<M>>>);
62
63impl<M: RexMessage> NotificationQueue<M> {
64 #[must_use]
65 pub fn new() -> Self {
66 Self(Arc::new(StreamableDeque::new()))
67 }
68 pub fn send(&self, notif: Notification<M>) {
69 self.0.push_back(notif);
70 }
71
72 pub fn priority_send(&self, notif: Notification<M>) {
73 self.0.push_front(notif);
74 }
75
76 #[must_use]
77 pub fn stream(&self) -> StreamReceiver<Notification<M>> {
78 self.0.stream()
79 }
80}
81
82impl<M> NotificationManager<M>
83where
84 M: RexMessage,
85{
86 pub fn new(
87 processors: Vec<Box<dyn NotificationProcessor<M>>>,
88 join_set: &mut JoinSet<()>,
89 notification_queue: NotificationQueue<M>,
90 ) -> Self {
91 let processors: HashMap<M::Topic, Vec<UnboundedSender<Notification<M>>>> = processors
92 .into_iter()
93 .fold(HashMap::new(), |mut subscribers, mut processor| {
94 let subscriber_tx = processor.init(join_set);
95 for topic in processor.get_topics() {
96 subscribers
97 .entry(*topic)
98 .or_default()
99 .push(subscriber_tx.clone());
100 }
101 subscribers
102 });
103 Self {
104 processors: Arc::new(processors),
105 notification_queue,
106 }
107 }
108
109 pub fn init(&self, join_set: &mut JoinSet<()>) -> NotificationQueue<M> {
110 let stream_queue = self.notification_queue.clone();
111 let processors = self.processors.clone();
112 join_set.spawn(async move {
113 debug!(spawning = "NotificationManager.processors");
114 let mut stream = stream_queue.stream();
115 while let Some(notification) = stream.next().await {
116 trace!(?notification);
117 let topic = notification.get_topic();
118 if let Some(subscribers) = processors.get(&topic) {
119 let Some((last, rest)) = subscribers.split_last() else {
120 continue;
121 };
122 for tx in rest {
123 tx.send(notification.clone()).log_attached_err(format!(
124 "nm::processors send failed for topic {topic:?}"
125 ));
126 }
127 last.send(notification).log_attached_err(format!(
128 "nm::processors send last failed for topic {topic:?}"
129 ));
130 } else {
131 warn!(topic = ?notification.get_topic(), ?notification, "NotificationProcessor not found");
132 }
133 }
134 }.in_current_span());
135 self.notification_queue.clone()
136 }
137}
138
139pub trait NotificationProcessor<M>: Send + Sync
140where
141 M: RexMessage,
142{
143 fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<M>>;
144 fn get_topics(&self) -> &[M::Topic];
145}
146
147#[derive(Debug, Clone)]
150pub struct UnaryRequest<K, O>
151where
152 K: HashKind,
153 O: Operation,
154{
155 pub id: StateId<K>,
156 pub op: O,
157}
158
159impl<K, O> UnaryRequest<K, O>
160where
161 K: HashKind,
162 O: Operation,
163{
164 pub const fn new(id: StateId<K>, op: O) -> Self {
165 Self { id, op }
166 }
167}
168
169impl<K: HashKind, O: Operation + Copy> Copy for UnaryRequest<K, O> {}
170
/// Anything that can be displayed and cloned can serve as a request operation.
///
/// Blanket-implemented; there is no need to implement it by hand.
pub trait Operation: Clone + std::fmt::Display {}

impl<T: Clone + std::fmt::Display> Operation for T {}
174
175pub trait Request<K>
176where
177 K: HashKind,
178 Self: Operation,
179{
180 fn request(self, id: StateId<K>) -> UnaryRequest<K, Self>;
181}
182
183impl<K: Rex, Op: Operation> Request<K> for Op
184where
185 K::Message: From<UnaryRequest<K, Op>>,
186{
187 fn request(self, id: StateId<K>) -> UnaryRequest<K, Op> {
188 UnaryRequest { id, op: self }
189 }
190}
191
192pub trait RequestInner<K>
193where
194 K: HashKind,
195 Self: Sized,
196{
197 fn request_inner<Op>(self, id: StateId<K>) -> UnaryRequest<K, Op>
198 where
199 Op: Operation + From<Self>;
200}
201
202impl<K, T> RequestInner<K> for T
203where
204 K: HashKind,
205{
206 fn request_inner<Op>(self, id: StateId<K>) -> UnaryRequest<K, Op>
207 where
208 Op: Operation + From<T>,
209 {
210 UnaryRequest {
211 id,
212 op: self.into(),
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::{test_support::*, StateId};

    /// One notification whose topic is shared by two `TimeoutManager`s must be fanned out
    /// to both of them, and each must eventually emit a signal for the same id.
    #[tokio::test]
    async fn route_to_timeout_manager() {
        use crate::timeout::*;

        let timeout_manager = TimeoutManager::test_default();
        let sq1 = timeout_manager.signal_queue.clone();
        let timeout_manager_two = TimeoutManager::test_default();
        let sq2 = timeout_manager_two.signal_queue.clone();
        let mut join_set = JoinSet::new();
        let notification_manager = NotificationManager::new(
            vec![Box::new(timeout_manager), Box::new(timeout_manager_two)],
            &mut join_set,
            NotificationQueue::new(),
        );
        let notification_tx = notification_manager.init(&mut join_set);

        let test_id = StateId::new_with_u128(TestKind, 1);
        let timeout_duration = Duration::from_millis(1);

        let set_timeout = Notification(TimeoutInput::set_timeout(test_id, timeout_duration).into());
        notification_tx.send(set_timeout);

        // Poll with a generous deadline instead of a single fixed sleep, which is flaky on a
        // loaded machine. Each slot is filled at most once, so no popped signal is discarded.
        let deadline = tokio::time::Instant::now() + Duration::from_secs(1);
        let mut timeout_one = None;
        let mut timeout_two = None;
        loop {
            if timeout_one.is_none() {
                timeout_one = sq1.pop_front();
            }
            if timeout_two.is_none() {
                timeout_two = sq2.pop_front();
            }
            let done = timeout_one.is_some() && timeout_two.is_some();
            if done || tokio::time::Instant::now() >= deadline {
                break;
            }
            tokio::time::sleep(Duration::from_millis(1)).await;
        }

        let timeout_one = timeout_one.expect("timeout one");
        let timeout_two = timeout_two.expect("timeout two");
        assert_eq!(timeout_one.id, timeout_two.id);
    }
}