Skip to main content

linera_core/
notifier.rs

1// Copyright (c) Zefchain Labs, Inc.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use linera_base::identifiers::ChainId;
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8use tracing::trace;
9
10use crate::worker;
11
12// TODO(#2171): replace this with a Tokio broadcast channel
13
14/// A `Notifier` holds references to clients waiting to receive notifications
15/// from the validator.
16/// Clients will be evicted if their connections are terminated.
17pub struct ChannelNotifier<N> {
18    inner: papaya::HashMap<ChainId, Vec<UnboundedSender<N>>>,
19}
20
21impl<N> Default for ChannelNotifier<N> {
22    fn default() -> Self {
23        Self {
24            inner: papaya::HashMap::default(),
25        }
26    }
27}
28
29impl<N> ChannelNotifier<N> {
30    /// Registers a sender for notifications on the given chain IDs.
31    pub fn add_sender(&self, chain_ids: Vec<ChainId>, sender: &UnboundedSender<N>) {
32        let pinned = self.inner.pin();
33        for id in chain_ids {
34            pinned.update_or_insert_with(
35                id,
36                |senders| senders.iter().cloned().chain([sender.clone()]).collect(),
37                || vec![sender.clone()],
38            );
39        }
40    }
41
42    /// Creates a subscription given a collection of chain IDs and a sender to the client.
43    pub fn subscribe(&self, chain_ids: Vec<ChainId>) -> UnboundedReceiver<N> {
44        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
45        self.add_sender(chain_ids, &tx);
46        rx
47    }
48
49    /// Creates a subscription given a collection of chain IDs and a sender to the client.
50    /// Immediately posts a first notification as an ACK.
51    pub fn subscribe_with_ack(&self, chain_ids: Vec<ChainId>, ack: N) -> UnboundedReceiver<N> {
52        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
53        self.add_sender(chain_ids, &tx);
54        tx.send(ack)
55            .expect("pushing to a new channel should succeed");
56        rx
57    }
58}
59
60impl<N> ChannelNotifier<N>
61where
62    N: Clone,
63{
64    /// Notifies all the clients waiting for a notification from a given chain.
65    pub fn notify_chain(&self, chain_id: &ChainId, notification: &N) {
66        self.inner.pin().compute(*chain_id, |senders| {
67            let Some((_key, senders)) = senders else {
68                trace!("Chain {chain_id} has no subscribers.");
69                return papaya::Operation::Abort(());
70            };
71            let live_senders = senders
72                .iter()
73                .filter(|sender| sender.send(notification.clone()).is_ok())
74                .cloned()
75                .collect::<Vec<_>>();
76            if live_senders.is_empty() {
77                trace!("No more subscribers for chain {chain_id}. Removing entry.");
78                return papaya::Operation::Remove;
79            }
80            papaya::Operation::Insert(live_senders)
81        });
82    }
83}
84
85pub trait Notifier: Clone + Send + 'static {
86    fn notify(&self, notifications: &[worker::Notification]);
87}
88
89impl Notifier for Arc<ChannelNotifier<worker::Notification>> {
90    fn notify(&self, notifications: &[worker::Notification]) {
91        for notification in notifications {
92            self.notify_chain(&notification.chain_id, notification);
93        }
94    }
95}
96
97impl Notifier for () {
98    fn notify(&self, _notifications: &[worker::Notification]) {}
99}
100
/// Test-only notifier that records every notification in a shared vector.
#[cfg(with_testing)]
impl Notifier for Arc<std::sync::Mutex<Vec<worker::Notification>>> {
    /// Appends clones of `notifications` to the vector; panics if the mutex is poisoned.
    fn notify(&self, notifications: &[worker::Notification]) {
        let mut recorded = self.lock().unwrap();
        recorded.extend_from_slice(notifications);
    }
}
108
#[cfg(test)]
pub mod tests {
    use std::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        time::{Duration, Instant},
    };

    use linera_execution::test_utils::dummy_chain_description;

    use super::*;

    /// Upper bound on how long `test_concurrent` waits for the receiver threads to
    /// drain their channels. Only reached when notifications are actually missing.
    const DRAIN_TIMEOUT: Duration = Duration::from_secs(10);

    /// Spawns a thread that counts every message arriving on `receiver` and returns
    /// the shared counter. The thread exits once all senders have been dropped.
    fn spawn_counter(mut receiver: UnboundedReceiver<()>) -> Arc<AtomicUsize> {
        let counter = Arc::new(AtomicUsize::new(0));
        let thread_counter = counter.clone();
        std::thread::spawn(move || {
            while receiver.blocking_recv().is_some() {
                thread_counter.fetch_add(1, Ordering::Relaxed);
            }
        });
        counter
    }

    /// Blocks until `counter` has reached `expected` or `deadline` has passed.
    fn wait_for_count(counter: &AtomicUsize, expected: usize, deadline: Instant) {
        while counter.load(Ordering::Relaxed) < expected && Instant::now() < deadline {
            std::thread::sleep(Duration::from_millis(1));
        }
    }

    /// Notifying two chains from two threads delivers every notification to every
    /// subscriber of the respective chain.
    #[test]
    fn test_concurrent() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();

        let a_rec = spawn_counter(notifier.subscribe(vec![chain_a]));
        let b_rec = spawn_counter(notifier.subscribe(vec![chain_b]));
        let a_b_rec = spawn_counter(notifier.subscribe(vec![chain_a, chain_b]));

        let notifier = Arc::new(notifier);

        const NOTIFICATIONS_A: usize = 500;
        const NOTIFICATIONS_B: usize = 700;

        let a_notifier = notifier.clone();
        let handle_a = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_A {
                a_notifier.notify_chain(&chain_a, &());
            }
        });

        let handle_b = std::thread::spawn(move || {
            for _ in 0..NOTIFICATIONS_B {
                notifier.notify_chain(&chain_b, &());
            }
        });

        // finish sending all the messages
        handle_a.join().unwrap();
        handle_b.join().unwrap();

        // Wait (with a deadline instead of a fixed sleep) for the receiver threads
        // to drain their channels, so the test is not sensitive to machine load.
        let deadline = Instant::now() + DRAIN_TIMEOUT;
        wait_for_count(&a_rec, NOTIFICATIONS_A, deadline);
        wait_for_count(&b_rec, NOTIFICATIONS_B, deadline);
        wait_for_count(&a_b_rec, NOTIFICATIONS_A + NOTIFICATIONS_B, deadline);

        assert_eq!(a_rec.load(Ordering::Relaxed), NOTIFICATIONS_A);
        assert_eq!(b_rec.load(Ordering::Relaxed), NOTIFICATIONS_B);
        assert_eq!(
            a_b_rec.load(Ordering::Relaxed),
            NOTIFICATIONS_A + NOTIFICATIONS_B
        );
    }

    /// Closed receivers are evicted on the next notification for their chain, and a
    /// chain's entry disappears once it has no live subscriber left.
    #[test]
    fn test_eviction() {
        let notifier = ChannelNotifier::default();

        let chain_a = dummy_chain_description(0).id();
        let chain_b = dummy_chain_description(1).id();
        let chain_c = dummy_chain_description(2).id();
        let chain_d = dummy_chain_description(3).id();

        // Chain A -> Notify A, Notify B
        // Chain B -> Notify A, Notify B
        // Chain C -> Notify C
        // Chain D -> Notify A, Notify B, Notify C, Notify D

        let mut rx_a = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_b = notifier.subscribe(vec![chain_a, chain_b, chain_d]);
        let mut rx_c = notifier.subscribe(vec![chain_c, chain_d]);
        let mut rx_d = notifier.subscribe(vec![chain_d]);

        assert_eq!(notifier.inner.len(), 4);

        // C was the only subscriber of chain C: the entry is removed.
        rx_c.close();
        notifier.notify_chain(&chain_c, &());
        assert_eq!(notifier.inner.len(), 3);

        // Chain A still has subscriber B: the entry stays.
        rx_a.close();
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 3);

        // Both subscribers of chain B are now closed: the entry is removed.
        rx_b.close();
        notifier.notify_chain(&chain_b, &());
        assert_eq!(notifier.inner.len(), 2);

        // Chain A only had B left, which is closed by now.
        notifier.notify_chain(&chain_a, &());
        assert_eq!(notifier.inner.len(), 1);

        // With D closed, every subscriber of chain D is gone.
        rx_d.close();
        notifier.notify_chain(&chain_d, &());
        assert_eq!(notifier.inner.len(), 0);
    }
}