1use std::sync::Arc;
5
6use linera_base::identifiers::ChainId;
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8use tracing::trace;
9
10use crate::worker;
11
12pub struct ChannelNotifier<N> {
18 inner: papaya::HashMap<ChainId, Vec<UnboundedSender<N>>>,
19}
20
21impl<N> Default for ChannelNotifier<N> {
22 fn default() -> Self {
23 Self {
24 inner: papaya::HashMap::default(),
25 }
26 }
27}
28
29impl<N> ChannelNotifier<N> {
30 pub fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
32 let pinned = self.inner.pin();
33 for id in chain_ids {
34 pinned.update_or_insert_with(
35 id,
36 |senders| senders.iter().cloned().chain([sender.clone()]).collect(),
37 || vec![sender.clone()],
38 );
39 }
40 }
41
42 pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
44 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
45 self.add_sender(chain_ids, &tx);
46 rx
47 }
48
49 pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
52 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
53 self.add_sender(chain_ids, &tx);
54 tx.send(ack)
55 .expect("pushing to a new channel should succeed");
56 rx
57 }
58}
59
60impl<N> ChannelNotifier<N>
61where
62 N: Clone,
63{
64 pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
66 self.inner.pin().compute(*chain_id, |senders| {
67 let Some((_key, senders)) = senders else {
68 trace!("Chain {chain_id} has no subscribers.");
69 return papaya::Operation::Abort(());
70 };
71 let live_senders = senders
72 .iter()
73 .filter(|sender| sender.send(notification.clone()).is_ok())
74 .cloned()
75 .collect::<Vec<_>>();
76 if live_senders.is_empty() {
77 trace!("No more subscribers for chain {chain_id}. Removing entry.");
78 return papaya::Operation::Remove;
79 }
80 papaya::Operation::Insert(live_senders)
81 });
82 }
83}
84
85pub trait Notifier: Clone + Send + 'static {
86 fn notify(&self, notifications: &[worker::Notification]);
87}
88
89impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
90 fn notify(&self, notifications: &[worker::Notification]) {
91 for notification in notifications {
92 self.notify_chain(¬ification.chain_id, notification);
93 }
94 }
95}
96
97impl Notifier for () {
98 fn notify(&self, _notifications: &[worker::Notification]) {}
99}
100
/// Testing notifier that appends every notification it receives to a shared vector.
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    fn notify(&self, notifications: &[worker::Notification]) {
        // Panics if the mutex was poisoned by a panicking holder.
        self.lock().unwrap().extend_from_slice(notifications);
    }
}
108
#[cfg(test)]
pub mod tests {
    use std::sync::Arc;

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Counts (and consumes) the messages currently buffered in `receiver`, without blocking.
    fn drain(receiver: &mut UnboundedReceiver<()>) -> usize {
        std::iter::from_fn(|| receiver.try_recv().ok()).count()
    }

    #[test]
    fn test_concurrent() {
        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let notifier = Arc::new(ChannelNotifier::default());

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        let mut rx_a = notifier.subscribe(vec![chain_a]);
        let mut rx_b = notifier.subscribe(vec![chain_b]);
        let mut rx_a_b = notifier.subscribe(vec![chain_a, chain_b]);

        let a_notifier = Arc::clone(&notifier);
        let handle_a = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_A {
                a_notifier.notify_chain(&chain_a, &());
            }
        });

        let b_notifier = Arc::clone(&notifier);
        let handle_b = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_B {
                b_notifier.notify_chain(&chain_b, &());
            }
        });

        handle_a.join().unwrap();
        handle_b.join().unwrap();

        // Joining the notifier threads guarantees every send has completed, and unbounded
        // channels buffer all of them, so draining now yields exact counts (including any
        // duplicate deliveries) without relying on a sleep.
        assert_eq!(drain(&mut rx_a), NOTIFICATIONS_A);
        assert_eq!(drain(&mut rx_b), NOTIFICATIONS_B);
        assert_eq!(drain(&mut rx_a_b), NOTIFICATIONS_A + NOTIFICATIONS_B);
    }

    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        assert_eq!(notifier.inner.len(), 4);

        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(notifier.inner.len(), 3);

        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 3);

        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(notifier.inner.len(), 2);

        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 1);

        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(notifier.inner.len(), 0);
    }
}