commonware_p2p/utils/
limited.rs

1//! Rate-limited [`UnlimitedSender`] wrapper.
2
3use crate::{Recipients, UnlimitedSender};
4use bytes::Bytes;
5use commonware_cryptography::PublicKey;
6use commonware_runtime::{Clock, KeyedRateLimiter, Quota};
7use commonware_utils::channels::ring;
8use futures::{lock::Mutex, Future, FutureExt, StreamExt};
9use std::{cmp, fmt, sync::Arc, time::SystemTime};
10
11/// Provides peer subscriptions for resolving [`Recipients::All`].
12///
13/// Implementations must be clonable so that each clone of [`LimitedSender`]
14/// can establish its own peer subscription.
15pub trait Connected: Clone + Send + Sync + 'static {
16    type PublicKey: PublicKey;
17
18    /// Subscribe to peer updates.
19    ///
20    /// Returns a receiver that yields the current set of known peers whenever it changes.
21    ///
22    /// It is assumed that when a new subscription is created, the current set of known peers
23    /// is sent immediately.
24    fn subscribe(&mut self) -> impl Future<Output = ring::Receiver<Vec<Self::PublicKey>>> + Send;
25}
26
27/// A wrapper around a [`UnlimitedSender`] that provides rate limiting with retry-time feedback.
28pub struct LimitedSender<E, S, P>
29where
30    E: Clock,
31    S: UnlimitedSender,
32    P: Connected<PublicKey = S::PublicKey>,
33{
34    sender: S,
35    rate_limit: Arc<Mutex<KeyedRateLimiter<S::PublicKey, E>>>,
36    peers: P,
37    peer_subscription: Option<ring::Receiver<Vec<S::PublicKey>>>,
38    known_peers: Vec<S::PublicKey>,
39}
40
41impl<E, S, P> Clone for LimitedSender<E, S, P>
42where
43    E: Clock,
44    S: UnlimitedSender,
45    P: Connected<PublicKey = S::PublicKey>,
46{
47    fn clone(&self) -> Self {
48        Self {
49            sender: self.sender.clone(),
50            rate_limit: self.rate_limit.clone(),
51            peers: self.peers.clone(),
52            peer_subscription: None,
53            known_peers: Vec::new(),
54        }
55    }
56}
57
58impl<E, S, P> fmt::Debug for LimitedSender<E, S, P>
59where
60    E: Clock,
61    S: UnlimitedSender,
62    P: Connected<PublicKey = S::PublicKey>,
63{
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        f.debug_struct("LimitedSender")
66            .field("known_peers", &self.known_peers.len())
67            .finish_non_exhaustive()
68    }
69}
70
71impl<E, S, P> LimitedSender<E, S, P>
72where
73    E: Clock,
74    S: UnlimitedSender,
75    P: Connected<PublicKey = S::PublicKey>,
76{
77    /// Create a new [`LimitedSender`] with the given sender, [`Quota`], and peer source.
78    pub fn new(sender: S, quota: Quota, clock: E, peers: P) -> Self {
79        let rate_limit = Arc::new(Mutex::new(KeyedRateLimiter::hashmap_with_clock(
80            quota, clock,
81        )));
82        Self {
83            sender,
84            rate_limit,
85            peers,
86            peer_subscription: None,
87            known_peers: Vec::new(),
88        }
89    }
90
91    /// Check that a given set of [`Recipients`] are within the rate limit.
92    ///
93    /// Returns a [`CheckedSender`] with only the recipients that are not
94    /// currently rate-limited. If _all_ recipients are rate-limited, returns
95    /// the earliest instant at which all recipients will be available.
96    pub async fn check(
97        &mut self,
98        recipients: Recipients<S::PublicKey>,
99    ) -> Result<CheckedSender<'_, S>, SystemTime> {
100        // Lazily establish peer subscription on first use
101        if self.peer_subscription.is_none() {
102            self.peer_subscription = Some(self.peers.subscribe().await);
103        }
104
105        let rate_limit = self.rate_limit.lock().await;
106
107        // Update known peers from subscription if available (non-blocking)
108        if let Some(ref mut subscription) = self.peer_subscription {
109            if let Some(peers) = subscription.next().now_or_never().flatten() {
110                self.known_peers = peers;
111                rate_limit.retain_recent();
112            }
113        }
114
115        let recipients = match recipients {
116            Recipients::One(ref peer) => match rate_limit.check_key(peer) {
117                Ok(()) => recipients,
118                Err(not_until) => return Err(not_until.earliest_possible()),
119            },
120            Recipients::Some(ref peers) => {
121                let (allowed, max_retry) = filter_rate_limited(peers.iter(), &rate_limit);
122                if allowed.is_empty() {
123                    match max_retry {
124                        Some(retry) => return Err(retry),
125                        None => recipients,
126                    }
127                } else {
128                    Recipients::Some(allowed)
129                }
130            }
131            Recipients::All => {
132                let (allowed, max_retry) =
133                    filter_rate_limited(self.known_peers.iter(), &rate_limit);
134                if allowed.is_empty() {
135                    match max_retry {
136                        Some(retry) => return Err(retry),
137                        None => Recipients::Some(Vec::new()),
138                    }
139                } else {
140                    Recipients::Some(allowed)
141                }
142            }
143        };
144
145        Ok(CheckedSender {
146            recipients,
147            sender: &mut self.sender,
148        })
149    }
150}
151
152/// Filters peers by rate limit, returning those that pass and the latest retry
153/// time among those that don't.
154pub(crate) fn filter_rate_limited<'a, K, C>(
155    peers: impl Iterator<Item = &'a K>,
156    rate_limit: &KeyedRateLimiter<K, C>,
157) -> (Vec<K>, Option<SystemTime>)
158where
159    K: PublicKey,
160    C: Clock,
161{
162    peers.fold(
163        (Vec::new(), None),
164        |(mut allowed, max_retry), p| match rate_limit.check_key(p) {
165            Ok(()) => {
166                allowed.push(p.clone());
167                (allowed, max_retry)
168            }
169            Err(not_until) => {
170                let earliest = not_until.earliest_possible();
171                let new_max = max_retry.map_or(earliest, |current| cmp::max(current, earliest));
172                (allowed, Some(new_max))
173            }
174        },
175    )
176}
177
178/// An exclusive reference to an [`UnlimitedSender`] with a pre-checked list of
179/// recipients that are not currently rate-limited.
180///
181/// A [`CheckedSender`] can only be acquired via [`LimitedSender::check`].
182#[derive(Debug)]
183pub struct CheckedSender<'a, S: UnlimitedSender> {
184    sender: &'a mut S,
185    recipients: Recipients<S::PublicKey>,
186}
187
188impl<'a, S: UnlimitedSender> CheckedSender<'a, S> {
189    /// Extracts the inner [`UnlimitedSender`] reference.
190    ///
191    /// # Warning
192    ///
193    /// Rate limiting has already been applied to the original recipients. Any
194    /// messages sent via the extracted sender will bypass the rate limiter.
195    pub(crate) fn into_inner(self) -> &'a mut S {
196        self.sender
197    }
198}
199
200impl<'a, S: UnlimitedSender> crate::CheckedSender for CheckedSender<'a, S> {
201    type PublicKey = S::PublicKey;
202    type Error = S::Error;
203
204    async fn send(
205        self,
206        message: Bytes,
207        priority: bool,
208    ) -> Result<Vec<Self::PublicKey>, Self::Error> {
209        self.sender.send(self.recipients, message, priority).await
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CheckedSender as _;
    use bytes::Bytes;
    use commonware_cryptography::{ed25519, Signer as _};
    use commonware_runtime::{deterministic::Runner, Quota, Runner as _};
    use commonware_utils::{channels::ring, NZUsize, NZU32};
    use thiserror::Error;

    type PublicKey = ed25519::PublicKey;
    type SentMessage = (Recipients<PublicKey>, Bytes, bool);

    #[derive(Debug, Error)]
    #[error("mock send error")]
    struct MockError;

    /// Sender that records every message it is asked to send.
    #[derive(Debug, Clone)]
    struct MockSender {
        sent: Arc<Mutex<Vec<SentMessage>>>,
    }

    impl MockSender {
        fn new() -> Self {
            Self {
                sent: Arc::new(Mutex::new(Vec::new())),
            }
        }

        async fn sent_messages(&self) -> Vec<SentMessage> {
            self.sent.lock().await.clone()
        }
    }

    impl UnlimitedSender for MockSender {
        type Error = MockError;
        type PublicKey = PublicKey;

        async fn send(
            &mut self,
            recipients: Recipients<Self::PublicKey>,
            message: Bytes,
            priority: bool,
        ) -> Result<Vec<Self::PublicKey>, Self::Error> {
            let sent_to = match &recipients {
                Recipients::One(pk) => vec![pk.clone()],
                Recipients::Some(pks) => pks.clone(),
                Recipients::All => Vec::new(),
            };
            self.sent.lock().await.push((recipients, message, priority));
            Ok(sent_to)
        }
    }

    /// Peer source that hands out a fresh ring channel per subscription.
    ///
    /// The sending halves are kept alive, so subscriptions never terminate. They
    /// are stored in a registry shared by every clone, which lets tests observe
    /// how many subscriptions were established across clones of a [`LimitedSender`].
    #[derive(Clone)]
    struct MockPeers {
        subscriptions: Arc<std::sync::Mutex<Vec<ring::Sender<Vec<PublicKey>>>>>,
    }

    impl MockPeers {
        fn new() -> Self {
            Self {
                subscriptions: Arc::new(std::sync::Mutex::new(Vec::new())),
            }
        }

        /// Number of subscriptions established so far, across all clones.
        fn subscription_count(&self) -> usize {
            self.subscriptions.lock().unwrap().len()
        }
    }

    impl Connected for MockPeers {
        type PublicKey = PublicKey;

        async fn subscribe(&mut self) -> ring::Receiver<Vec<Self::PublicKey>> {
            let (sender, receiver) = ring::channel(NZUsize!(16));
            self.subscriptions.lock().unwrap().push(sender);
            receiver
        }
    }

    fn key(seed: u64) -> PublicKey {
        ed25519::PrivateKey::from_seed(seed).public_key()
    }

    fn quota_per_second(n: u32) -> Quota {
        Quota::per_second(NZU32!(n))
    }

    #[test]
    fn check_one_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(10), context, peers);

            let peer = key(1);
            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            let sent_to = checked.send(Bytes::from("hello"), false).await.unwrap();
            assert_eq!(sent_to, vec![peer]);
        });
    }

    #[test]
    fn check_one_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context.clone(), peers);

            let peer = key(1);

            // First check should succeed and consume the quota
            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            checked.send(Bytes::from("first"), false).await.unwrap();

            // Second check should fail (rate limited) with a retry time in the future
            let now = context.current();
            match limited.check(Recipients::One(peer)).await {
                Ok(_) => panic!("expected peer to be rate limited"),
                Err(retry) => assert!(retry > now, "retry time should be in the future"),
            }
        });
    }

    #[test]
    fn check_some_all_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peers_list = vec![key(1), key(2), key(3)];
            let checked = limited
                .check(Recipients::Some(peers_list.clone()))
                .await
                .unwrap();
            let sent_to = checked.send(Bytes::from("hello"), false).await.unwrap();
            assert_eq!(sent_to.len(), 3);
        });
    }

    #[test]
    fn check_some_filters_rate_limited_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);
            let peer3 = key(3);

            // Rate limit peer1 by sending to it first
            let checked = limited.check(Recipients::One(peer1.clone())).await.unwrap();
            checked.send(Bytes::from("limit"), false).await.unwrap();

            // Now check with all three peers - peer1 should be filtered out
            let checked = limited
                .check(Recipients::Some(vec![
                    peer1.clone(),
                    peer2.clone(),
                    peer3.clone(),
                ]))
                .await
                .unwrap();
            let sent_to = checked.send(Bytes::from("filtered"), false).await.unwrap();

            // peer1 should be filtered out since it's rate limited
            assert_eq!(sent_to.len(), 2);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
            assert!(sent_to.contains(&peer3));
        });
    }

    #[test]
    fn check_some_all_rate_limited_returns_error() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Rate limit both peers
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit2"), false)
                .await
                .unwrap();

            // Now both are rate limited - should return error with retry time
            assert!(limited
                .check(Recipients::Some(vec![peer1, peer2]))
                .await
                .is_err());
        });
    }

    #[test]
    fn check_some_empty_returns_as_is() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            // Empty recipients should pass through unchanged
            let checked = limited.check(Recipients::Some(Vec::new())).await.unwrap();
            let sent_to = checked.send(Bytes::from("empty"), false).await.unwrap();
            assert!(sent_to.is_empty());

            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            assert!(matches!(&messages[0].0, Recipients::Some(pks) if pks.is_empty()));
        });
    }

    #[test]
    fn check_all_uses_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            // First call establishes subscription - no known peers yet
            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(Bytes::from("empty"), false).await.unwrap();
            assert!(sent_to.is_empty());

            // Verify that the sender received the message with empty Recipients::Some
            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            match &messages[0].0 {
                Recipients::Some(pks) => assert!(pks.is_empty()),
                _ => panic!("expected Recipients::Some"),
            }
        });
    }

    #[test]
    fn check_all_filters_rate_limited_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // First call to establish subscription
            let _ = limited.check(Recipients::All).await;

            // Manually set known peers (simulating peer updates)
            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            // Rate limit peer1
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit"), false)
                .await
                .unwrap();

            // Check All should filter out peer1
            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(Bytes::from("filtered"), false).await.unwrap();

            assert_eq!(sent_to.len(), 1);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
        });
    }

    #[test]
    fn check_all_returns_error_when_all_known_peers_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // First call to establish subscription
            let _ = limited.check(Recipients::All).await;

            // Set known peers
            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            // Rate limit both peers
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit2"), false)
                .await
                .unwrap();

            // Check All should fail since all known peers are rate limited
            assert!(limited.check(Recipients::All).await.is_err());
        });
    }

    #[test]
    fn clone_creates_independent_subscription() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited1 =
                LimitedSender::new(sender, quota_per_second(10), context, peers.clone());

            // Establish subscription on first instance
            let _ = limited1.check(Recipients::All).await;
            limited1.known_peers = vec![key(1)];
            assert_eq!(peers.subscription_count(), 1);

            // Clone should not have a subscription or known peers
            let mut limited2 = limited1.clone();
            assert!(limited2.peer_subscription.is_none());
            assert!(limited2.known_peers.is_empty());

            // The clone subscribes on its own first check
            let _ = limited2.check(Recipients::All).await;
            assert!(limited2.peer_subscription.is_some());
            assert_eq!(peers.subscription_count(), 2);
        });
    }

    #[test]
    fn checked_sender_sends_with_priority() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            let peer = key(1);
            limited
                .check(Recipients::One(peer))
                .await
                .unwrap()
                .send(Bytes::from("priority"), true)
                .await
                .unwrap();

            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            assert!(messages[0].2); // priority flag
        });
    }

    #[test]
    fn rate_limit_shared_across_clones() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let peers = MockPeers::new();
            let mut limited1 =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);
            let mut limited2 = limited1.clone();

            let peer = key(1);

            // Rate limit peer via first instance
            limited1
                .check(Recipients::One(peer.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit"), false)
                .await
                .unwrap();

            // Second instance should see the rate limit
            assert!(limited2.check(Recipients::One(peer)).await.is_err());
        });
    }
}