1use crate::{Recipients, UnlimitedSender};
4use bytes::Buf;
5use commonware_cryptography::PublicKey;
6use commonware_runtime::{Clock, KeyedRateLimiter, Quota};
7use commonware_utils::channels::ring;
8use futures::{lock::Mutex, Future, FutureExt, StreamExt};
9use std::{cmp, fmt, sync::Arc, time::SystemTime};
10
11pub trait Connected: Clone + Send + Sync + 'static {
16 type PublicKey: PublicKey;
17
18 fn subscribe(&mut self) -> impl Future<Output = ring::Receiver<Vec<Self::PublicKey>>> + Send;
25}
26
27pub struct LimitedSender<E, S, P>
29where
30 E: Clock,
31 S: UnlimitedSender,
32 P: Connected<PublicKey = S::PublicKey>,
33{
34 sender: S,
35 rate_limit: Arc<Mutex<KeyedRateLimiter<S::PublicKey, E>>>,
36 peers: P,
37 peer_subscription: Option<ring::Receiver<Vec<S::PublicKey>>>,
38 known_peers: Vec<S::PublicKey>,
39}
40
41impl<E, S, P> Clone for LimitedSender<E, S, P>
42where
43 E: Clock,
44 S: UnlimitedSender,
45 P: Connected<PublicKey = S::PublicKey>,
46{
47 fn clone(&self) -> Self {
48 Self {
49 sender: self.sender.clone(),
50 rate_limit: self.rate_limit.clone(),
51 peers: self.peers.clone(),
52 peer_subscription: None,
53 known_peers: Vec::new(),
54 }
55 }
56}
57
58impl<E, S, P> fmt::Debug for LimitedSender<E, S, P>
59where
60 E: Clock,
61 S: UnlimitedSender,
62 P: Connected<PublicKey = S::PublicKey>,
63{
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("LimitedSender")
66 .field("known_peers", &self.known_peers.len())
67 .finish_non_exhaustive()
68 }
69}
70
71impl<E, S, P> LimitedSender<E, S, P>
72where
73 E: Clock,
74 S: UnlimitedSender,
75 P: Connected<PublicKey = S::PublicKey>,
76{
77 pub fn new(sender: S, quota: Quota, clock: E, peers: P) -> Self {
79 let rate_limit = Arc::new(Mutex::new(KeyedRateLimiter::hashmap_with_clock(
80 quota, clock,
81 )));
82 Self {
83 sender,
84 rate_limit,
85 peers,
86 peer_subscription: None,
87 known_peers: Vec::new(),
88 }
89 }
90
91 pub async fn check(
97 &mut self,
98 recipients: Recipients<S::PublicKey>,
99 ) -> Result<CheckedSender<'_, S>, SystemTime> {
100 if self.peer_subscription.is_none() {
102 self.peer_subscription = Some(self.peers.subscribe().await);
103 }
104
105 let rate_limit = self.rate_limit.lock().await;
106
107 if let Some(ref mut subscription) = self.peer_subscription {
109 if let Some(peers) = subscription.next().now_or_never().flatten() {
110 self.known_peers = peers;
111 rate_limit.retain_recent();
112 }
113 }
114
115 let recipients = match recipients {
116 Recipients::One(ref peer) => match rate_limit.check_key(peer) {
117 Ok(()) => recipients,
118 Err(not_until) => return Err(not_until.earliest_possible()),
119 },
120 Recipients::Some(ref peers) => {
121 let (allowed, max_retry) = filter_rate_limited(peers.iter(), &rate_limit);
122 if allowed.is_empty() {
123 match max_retry {
124 Some(retry) => return Err(retry),
125 None => recipients,
126 }
127 } else {
128 Recipients::Some(allowed)
129 }
130 }
131 Recipients::All => {
132 let (allowed, max_retry) =
133 filter_rate_limited(self.known_peers.iter(), &rate_limit);
134 if allowed.is_empty() {
135 match max_retry {
136 Some(retry) => return Err(retry),
137 None => Recipients::Some(Vec::new()),
138 }
139 } else {
140 Recipients::Some(allowed)
141 }
142 }
143 };
144
145 Ok(CheckedSender {
146 recipients,
147 sender: &mut self.sender,
148 })
149 }
150}
151
152pub(crate) fn filter_rate_limited<'a, K, C>(
155 peers: impl Iterator<Item = &'a K>,
156 rate_limit: &KeyedRateLimiter<K, C>,
157) -> (Vec<K>, Option<SystemTime>)
158where
159 K: PublicKey,
160 C: Clock,
161{
162 peers.fold(
163 (Vec::new(), None),
164 |(mut allowed, max_retry), p| match rate_limit.check_key(p) {
165 Ok(()) => {
166 allowed.push(p.clone());
167 (allowed, max_retry)
168 }
169 Err(not_until) => {
170 let earliest = not_until.earliest_possible();
171 let new_max = max_retry.map_or(earliest, |current| cmp::max(current, earliest));
172 (allowed, Some(new_max))
173 }
174 },
175 )
176}
177
178#[derive(Debug)]
183pub struct CheckedSender<'a, S: UnlimitedSender> {
184 sender: &'a mut S,
185 recipients: Recipients<S::PublicKey>,
186}
187
188impl<'a, S: UnlimitedSender> CheckedSender<'a, S> {
189 pub(crate) fn into_inner(self) -> &'a mut S {
196 self.sender
197 }
198}
199
200impl<'a, S: UnlimitedSender> crate::CheckedSender for CheckedSender<'a, S> {
201 type PublicKey = S::PublicKey;
202 type Error = S::Error;
203
204 async fn send(
205 self,
206 message: impl Buf + Send,
207 priority: bool,
208 ) -> Result<Vec<Self::PublicKey>, Self::Error> {
209 self.sender.send(self.recipients, message, priority).await
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CheckedSender as _;
    use bytes::Bytes;
    use commonware_cryptography::{ed25519, Signer as _};
    use commonware_runtime::{deterministic::Runner, Quota, Runner as _};
    use commonware_utils::{channels::ring, NZUsize, NZU32};
    use thiserror::Error;

    type PublicKey = ed25519::PublicKey;
    // (recipients, payload, priority) as observed by `MockSender`.
    type SentMessage = (Recipients<PublicKey>, Bytes, bool);

    #[derive(Debug, Error)]
    #[error("mock send error")]
    struct MockError;

    /// Sender that records every message it is asked to send. Clones share the record, so a
    /// test can keep one clone for inspection after moving another into `LimitedSender`.
    #[derive(Debug, Clone)]
    struct MockSender {
        sent: Arc<Mutex<Vec<SentMessage>>>,
    }

    impl MockSender {
        fn new() -> Self {
            Self {
                sent: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of every message sent so far.
        async fn sent_messages(&self) -> Vec<SentMessage> {
            self.sent.lock().await.clone()
        }
    }

    impl UnlimitedSender for MockSender {
        type Error = MockError;
        type PublicKey = PublicKey;

        /// Records the message and reports the explicit recipients as "sent to".
        /// `Recipients::All` reports an empty list because the mock has no peer set.
        async fn send(
            &mut self,
            recipients: Recipients<Self::PublicKey>,
            mut message: impl Buf + Send,
            priority: bool,
        ) -> Result<Vec<Self::PublicKey>, Self::Error> {
            let sent_to = match &recipients {
                Recipients::One(pk) => vec![pk.clone()],
                Recipients::Some(pks) => pks.clone(),
                Recipients::All => Vec::new(),
            };
            let message = message.copy_to_bytes(message.remaining());
            self.sent.lock().await.push((recipients, message, priority));
            Ok(sent_to)
        }
    }

    /// Peer source whose `subscribe` opens a fresh ring channel and stores its sender.
    #[derive(Clone)]
    struct MockPeers {
        sender: ring::Sender<Vec<PublicKey>>,
    }

    impl MockPeers {
        /// NOTE: the returned sender belongs to a channel whose receiver is dropped right here,
        /// and `subscribe` later replaces `self.sender` with a new channel. The returned sender
        /// therefore never reaches `LimitedSender`; tests that need known peers assign
        /// `limited.known_peers` directly instead.
        fn new() -> (Self, ring::Sender<Vec<PublicKey>>) {
            let (sender, _receiver) = ring::channel(NZUsize!(16));
            let peers = Self {
                sender: sender.clone(),
            };
            (peers, sender)
        }
    }

    impl Connected for MockPeers {
        type PublicKey = PublicKey;

        async fn subscribe(&mut self) -> ring::Receiver<Vec<Self::PublicKey>> {
            let (sender, receiver) = ring::channel(NZUsize!(16));
            self.sender = sender;
            receiver
        }
    }

    /// Deterministic ed25519 public key derived from `seed`.
    fn key(seed: u64) -> PublicKey {
        ed25519::PrivateKey::from_seed(seed).public_key()
    }

    fn quota_per_second(n: u32) -> Quota {
        Quota::per_second(NZU32!(n))
    }

    #[test]
    fn check_one_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(10), context, peers);

            let peer = key(1);
            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            let sent_to = checked.send(Bytes::from("hello"), false).await.unwrap();
            assert_eq!(sent_to, vec![peer]);
        });
    }

    #[test]
    fn check_one_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer = key(1);

            // First check spends the peer's single token for this second.
            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            checked.send(Bytes::from("first"), false).await.unwrap();

            // A second check at the same instant must be rejected.
            let result = limited.check(Recipients::One(peer)).await;
            assert!(result.is_err());
        });
    }

    #[test]
    fn check_some_all_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            // Quota is per peer, so three distinct peers each get their own token.
            let peers_list = vec![key(1), key(2), key(3)];
            let checked = limited
                .check(Recipients::Some(peers_list.clone()))
                .await
                .unwrap();
            let sent_to = checked.send(Bytes::from("hello"), false).await.unwrap();
            assert_eq!(sent_to.len(), 3);
        });
    }

    #[test]
    fn check_some_filters_rate_limited_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);
            let peer3 = key(3);

            // Exhaust peer1's quota.
            let checked = limited.check(Recipients::One(peer1.clone())).await.unwrap();
            checked.send(Bytes::from("limit"), false).await.unwrap();

            // peer1 should be filtered out while the others pass.
            let checked = limited
                .check(Recipients::Some(vec![
                    peer1.clone(),
                    peer2.clone(),
                    peer3.clone(),
                ]))
                .await
                .unwrap();
            let sent_to = checked.send(Bytes::from("filtered"), false).await.unwrap();

            assert_eq!(sent_to.len(), 2);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
            assert!(sent_to.contains(&peer3));
        });
    }

    #[test]
    fn check_some_all_rate_limited_returns_error() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Exhaust both peers' quotas.
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit2"), false)
                .await
                .unwrap();

            // With nobody left to send to, `check` reports a retry time instead.
            assert!(limited
                .check(Recipients::Some(vec![peer1, peer2]))
                .await
                .is_err());
        });
    }

    #[test]
    fn check_some_empty_returns_as_is() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(10), context, peers);

            // An empty list has no limited peers, so it must not be treated as an error.
            limited.check(Recipients::Some(Vec::new())).await.unwrap();
        });
    }

    #[test]
    fn check_all_uses_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            // No peer updates have arrived, so `All` expands to an empty explicit list.
            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(Bytes::from("empty"), false).await.unwrap();
            assert!(sent_to.is_empty());

            // The underlying sender must see `Some(vec![])`, never `All`.
            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            match &messages[0].0 {
                Recipients::Some(pks) => assert!(pks.is_empty()),
                _ => panic!("expected Recipients::Some"),
            }
        });
    }

    #[test]
    fn check_all_filters_rate_limited_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Initialize the subscription first, so that no later update can overwrite the
            // peer list set below.
            let _ = limited.check(Recipients::All).await;

            // Inject known peers directly (see the note on `MockPeers::new`).
            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit"), false)
                .await
                .unwrap();

            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(Bytes::from("filtered"), false).await.unwrap();

            assert_eq!(sent_to.len(), 1);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
        });
    }

    #[test]
    fn check_all_returns_error_when_all_known_peers_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Initialize the subscription before injecting known peers.
            let _ = limited.check(Recipients::All).await;

            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit2"), false)
                .await
                .unwrap();

            assert!(limited.check(Recipients::All).await.is_err());
        });
    }

    #[test]
    fn clone_creates_independent_subscription() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited1 = LimitedSender::new(sender, quota_per_second(10), context, peers);

            let _ = limited1.check(Recipients::All).await;
            limited1.known_peers = vec![key(1)];

            // Clones must not inherit the subscription or the known-peer view.
            let limited2 = limited1.clone();
            assert!(limited2.peer_subscription.is_none());
            assert!(limited2.known_peers.is_empty());
        });
    }

    #[test]
    fn checked_sender_sends_with_priority() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            let peer = key(1);
            limited
                .check(Recipients::One(peer))
                .await
                .unwrap()
                .send(Bytes::from("priority"), true)
                .await
                .unwrap();

            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            // The priority flag must be forwarded unchanged.
            assert!(messages[0].2);
        });
    }

    #[test]
    fn rate_limit_shared_across_clones() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited1 =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);
            let mut limited2 = limited1.clone();

            let peer = key(1);

            // Spend the peer's token through the first clone ...
            limited1
                .check(Recipients::One(peer.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit"), false)
                .await
                .unwrap();

            // ... and observe the limit through the second, since both share one limiter.
            assert!(limited2.check(Recipients::One(peer)).await.is_err());
        });
    }
}
599}