rex/
notification.rs

1use std::{collections::HashMap, fmt, fmt::Debug, hash::Hash, sync::Arc};
2
3use bigerror::LogError;
4use tokio::{sync::mpsc::UnboundedSender, task::JoinSet};
5use tokio_stream::StreamExt;
6use tracing::{debug, trace, warn, Instrument};
7
8use crate::{
9    queue::{StreamReceiver, StreamableDeque},
10    HashKind, Rex, StateId,
11};
12
13// a PubSub message that is able to be sent to [`NotificationProcessor`]s that subscribe to one
14// or more [`RexTopic`]s
15pub trait RexMessage: GetTopic<Self::Topic> + Clone + fmt::Debug + Send + Sync + 'static
16where
17    Self: Send + Sync,
18{
19    type Topic: RexTopic;
20}
21
22/// Used to derive a marker used to route [`Notification`]s
23/// to [`NotificationProcessor`]s
24pub trait GetTopic<T: RexTopic>: fmt::Debug {
25    fn get_topic(&self) -> T;
26}
27
28/// This is the analogue to [`super::node_state_machine::Signal`]
29/// that is meant to send messages to anything that is _not_ a
30/// state machine
31#[derive(Debug, Clone)]
32pub struct Notification<M: RexMessage>(pub M);
33
34impl<M, T> GetTopic<T> for Notification<M>
35where
36    T: RexTopic,
37    M: RexMessage + GetTopic<T>,
38{
39    fn get_topic(&self) -> T {
40        self.0.get_topic()
41    }
42}
43
/// Bounds required of a routing key: hashable and comparable (so it can key
/// a `HashMap`), cheaply copyable, debuggable, and shareable across threads.
///
/// Every type meeting these bounds implements this trait automatically.
pub trait RexTopic
where
    Self: fmt::Debug + Hash + Eq + Copy + Send + Sync + 'static,
{
}

impl<T> RexTopic for T where T: fmt::Debug + Hash + Eq + Copy + Send + Sync + 'static {}
46
47// --------------------------------------
48
49pub type Subscriber<M> = UnboundedSender<Notification<M>>;
50/// [`NotificationManager`] routes [`Notifications`] to their desired
51/// destination
52pub struct NotificationManager<M>
53where
54    M: RexMessage,
55{
56    processors: Arc<HashMap<M::Topic, Vec<Subscriber<M>>>>,
57    notification_queue: NotificationQueue<M>,
58}
59
60#[derive(Default, Clone, Debug)]
61pub struct NotificationQueue<M: RexMessage>(pub(crate) Arc<StreamableDeque<Notification<M>>>);
62
63impl<M: RexMessage> NotificationQueue<M> {
64    #[must_use]
65    pub fn new() -> Self {
66        Self(Arc::new(StreamableDeque::new()))
67    }
68    pub fn send(&self, notif: Notification<M>) {
69        self.0.push_back(notif);
70    }
71
72    pub fn priority_send(&self, notif: Notification<M>) {
73        self.0.push_front(notif);
74    }
75
76    #[must_use]
77    pub fn stream(&self) -> StreamReceiver<Notification<M>> {
78        self.0.stream()
79    }
80}
81
82impl<M> NotificationManager<M>
83where
84    M: RexMessage,
85{
86    pub fn new(
87        processors: Vec<Box<dyn NotificationProcessor<M>>>,
88        join_set: &mut JoinSet<()>,
89        notification_queue: NotificationQueue<M>,
90    ) -> Self {
91        let processors: HashMap<M::Topic, Vec<UnboundedSender<Notification<M>>>> = processors
92            .into_iter()
93            .fold(HashMap::new(), |mut subscribers, mut processor| {
94                let subscriber_tx = processor.init(join_set);
95                for topic in processor.get_topics() {
96                    subscribers
97                        .entry(*topic)
98                        .or_default()
99                        .push(subscriber_tx.clone());
100                }
101                subscribers
102            });
103        Self {
104            processors: Arc::new(processors),
105            notification_queue,
106        }
107    }
108
109    pub fn init(&self, join_set: &mut JoinSet<()>) -> NotificationQueue<M> {
110        let stream_queue = self.notification_queue.clone();
111        let processors = self.processors.clone();
112        join_set.spawn(async move {
113            debug!(spawning = "NotificationManager.processors");
114            let mut stream = stream_queue.stream();
115            while let Some(notification) = stream.next().await {
116                trace!(?notification);
117                let topic = notification.get_topic();
118                if let Some(subscribers) = processors.get(&topic) {
119                    let Some((last, rest)) = subscribers.split_last() else {
120                        continue;
121                    };
122                    for tx in rest {
123                        tx.send(notification.clone()).log_attached_err(format!(
124                            "nm::processors send failed for topic {topic:?}"
125                        ));
126                    }
127                    last.send(notification).log_attached_err(format!(
128                        "nm::processors send last failed for topic {topic:?}"
129                    ));
130                } else {
131                    warn!(topic = ?notification.get_topic(), ?notification, "NotificationProcessor not found");
132                }
133            }
134        }.in_current_span());
135        self.notification_queue.clone()
136    }
137}
138
139pub trait NotificationProcessor<M>: Send + Sync
140where
141    M: RexMessage,
142{
143    fn init(&mut self, join_set: &mut JoinSet<()>) -> UnboundedSender<Notification<M>>;
144    fn get_topics(&self) -> &[M::Topic];
145}
146
147/// A message that is expected to return a result
148/// to the associated [`StateId`] that that did the initial request
149#[derive(Debug, Clone)]
150pub struct UnaryRequest<K, O>
151where
152    K: HashKind,
153    O: Operation,
154{
155    pub id: StateId<K>,
156    pub op: O,
157}
158
159impl<K, O> UnaryRequest<K, O>
160where
161    K: HashKind,
162    O: Operation,
163{
164    pub const fn new(id: StateId<K>, op: O) -> Self {
165        Self { id, op }
166    }
167}
168
169impl<K: HashKind, O: Operation + Copy> Copy for UnaryRequest<K, O> {}
170
/// The unit of work carried by a [`UnaryRequest`]: anything that can be
/// cloned and rendered with `Display`.
///
/// Implemented automatically for every type meeting those bounds.
pub trait Operation
where
    Self: std::fmt::Display + Clone,
{
}

impl<T> Operation for T where T: std::fmt::Display + Clone {}
174
175pub trait Request<K>
176where
177    K: HashKind,
178    Self: Operation,
179{
180    fn request(self, id: StateId<K>) -> UnaryRequest<K, Self>;
181}
182
183impl<K: Rex, Op: Operation> Request<K> for Op
184where
185    K::Message: From<UnaryRequest<K, Op>>,
186{
187    fn request(self, id: StateId<K>) -> UnaryRequest<K, Op> {
188        UnaryRequest { id, op: self }
189    }
190}
191
192pub trait RequestInner<K>
193where
194    K: HashKind,
195    Self: Sized,
196{
197    fn request_inner<Op>(self, id: StateId<K>) -> UnaryRequest<K, Op>
198    where
199        Op: Operation + From<Self>;
200}
201
202impl<K, T> RequestInner<K> for T
203where
204    K: HashKind,
205{
206    fn request_inner<Op>(self, id: StateId<K>) -> UnaryRequest<K, Op>
207    where
208        Op: Operation + From<T>,
209    {
210        UnaryRequest {
211            id,
212            op: self.into(),
213        }
214    }
215}
216
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;
    use crate::{test_support::*, StateId};

    /// Registers two `TimeoutManager`s and sends a single `set_timeout`
    /// notification. Checks that each manager's signal queue yields a signal
    /// and that both signals carry the same id.
    #[tokio::test]
    async fn route_to_timeout_manager() {
        use crate::timeout::*;

        let manager_a = TimeoutManager::test_default();
        let signals_a = manager_a.signal_queue.clone();
        let manager_b = TimeoutManager::test_default();
        let signals_b = manager_b.signal_queue.clone();

        let mut join_set = JoinSet::new();
        let manager = NotificationManager::new(
            vec![Box::new(manager_a), Box::new(manager_b)],
            &mut join_set,
            NotificationQueue::new(),
        );
        let queue = manager.init(&mut join_set);

        // a 1ms timeout is expected to fire within the 10ms wait below
        let id = StateId::new_with_u128(TestKind, 1);
        let delay = Duration::from_millis(1);
        queue.send(Notification(TimeoutInput::set_timeout(id, delay).into()));

        tokio::time::sleep(Duration::from_millis(10)).await;

        let signal_a = signals_a.pop_front().expect("timeout one");
        let signal_b = signals_b.pop_front().expect("timeout two");
        assert_eq!(signal_a.id, signal_b.id);
    }
}