1use crate::{Recipients, UnlimitedSender};
4use commonware_cryptography::PublicKey;
5use commonware_runtime::{Clock, IoBufMut, KeyedRateLimiter, Quota};
6use commonware_utils::channel::ring;
7use futures::{lock::Mutex, Future, FutureExt, StreamExt};
8use std::{cmp, fmt, sync::Arc, time::SystemTime};
9
/// Source of peer-list updates consumed by [`LimitedSender`].
///
/// [`LimitedSender`] subscribes through this trait and replaces its cached peer list
/// with every list the subscription yields; that cache is what `Recipients::All` is
/// expanded to during a rate-limit check.
pub trait Connected: Clone + Send + Sync + 'static {
    /// Key type identifying a peer.
    type PublicKey: PublicKey;

    /// Opens a subscription and resolves to its receiving half.
    ///
    /// Every item produced by the receiver is a list of peer public keys.
    fn subscribe(&mut self) -> impl Future<Output = ring::Receiver<Vec<Self::PublicKey>>> + Send;
}
25
/// Wraps an [`UnlimitedSender`] with per-peer rate limiting.
///
/// The rate limiter is stored behind an `Arc<Mutex<_>>`, so every clone of a
/// `LimitedSender` draws from the same limiter. The peer subscription and the cached
/// peer list are per-instance: a clone starts without either.
pub struct LimitedSender<E, S, P>
where
    E: Clock,
    S: UnlimitedSender,
    P: Connected<PublicKey = S::PublicKey>,
{
    /// Underlying sender that checked messages are handed to.
    sender: S,
    /// Keyed rate limiter shared by all clones of this sender.
    rate_limit: Arc<Mutex<KeyedRateLimiter<S::PublicKey, E>>>,
    /// Provider used to open the peer subscription.
    peers: P,
    /// Receiver of peer-list updates; `None` until the first call to `check`.
    peer_subscription: Option<ring::Receiver<Vec<S::PublicKey>>>,
    /// Last peer list taken from the subscription; `Recipients::All` expands to it.
    known_peers: Vec<S::PublicKey>,
}
39
40impl<E, S, P> Clone for LimitedSender<E, S, P>
41where
42 E: Clock,
43 S: UnlimitedSender,
44 P: Connected<PublicKey = S::PublicKey>,
45{
46 fn clone(&self) -> Self {
47 Self {
48 sender: self.sender.clone(),
49 rate_limit: self.rate_limit.clone(),
50 peers: self.peers.clone(),
51 peer_subscription: None,
52 known_peers: Vec::new(),
53 }
54 }
55}
56
57impl<E, S, P> fmt::Debug for LimitedSender<E, S, P>
58where
59 E: Clock,
60 S: UnlimitedSender,
61 P: Connected<PublicKey = S::PublicKey>,
62{
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 f.debug_struct("LimitedSender")
65 .field("known_peers", &self.known_peers.len())
66 .finish_non_exhaustive()
67 }
68}
69
70impl<E, S, P> LimitedSender<E, S, P>
71where
72 E: Clock,
73 S: UnlimitedSender,
74 P: Connected<PublicKey = S::PublicKey>,
75{
76 pub fn new(sender: S, quota: Quota, clock: E, peers: P) -> Self {
78 let rate_limit = Arc::new(Mutex::new(KeyedRateLimiter::hashmap_with_clock(
79 quota, clock,
80 )));
81 Self {
82 sender,
83 rate_limit,
84 peers,
85 peer_subscription: None,
86 known_peers: Vec::new(),
87 }
88 }
89
90 pub async fn check(
96 &mut self,
97 recipients: Recipients<S::PublicKey>,
98 ) -> Result<CheckedSender<'_, S>, SystemTime> {
99 if self.peer_subscription.is_none() {
101 self.peer_subscription = Some(self.peers.subscribe().await);
102 }
103
104 let rate_limit = self.rate_limit.lock().await;
105
106 if let Some(ref mut subscription) = self.peer_subscription {
108 if let Some(peers) = subscription.next().now_or_never().flatten() {
109 self.known_peers = peers;
110 rate_limit.retain_recent();
111 }
112 }
113
114 let recipients = match recipients {
115 Recipients::One(ref peer) => match rate_limit.check_key(peer) {
116 Ok(()) => recipients,
117 Err(not_until) => return Err(not_until.earliest_possible()),
118 },
119 Recipients::Some(ref peers) => {
120 let (allowed, max_retry) = filter_rate_limited(peers.iter(), &rate_limit);
121 if allowed.is_empty() {
122 match max_retry {
123 Some(retry) => return Err(retry),
124 None => recipients,
125 }
126 } else {
127 Recipients::Some(allowed)
128 }
129 }
130 Recipients::All => {
131 let (allowed, max_retry) =
132 filter_rate_limited(self.known_peers.iter(), &rate_limit);
133 if allowed.is_empty() {
134 match max_retry {
135 Some(retry) => return Err(retry),
136 None => Recipients::Some(Vec::new()),
137 }
138 } else {
139 Recipients::Some(allowed)
140 }
141 }
142 };
143
144 Ok(CheckedSender {
145 recipients,
146 sender: &mut self.sender,
147 })
148 }
149}
150
151pub(crate) fn filter_rate_limited<'a, K, C>(
154 peers: impl Iterator<Item = &'a K>,
155 rate_limit: &KeyedRateLimiter<K, C>,
156) -> (Vec<K>, Option<SystemTime>)
157where
158 K: PublicKey,
159 C: Clock,
160{
161 peers.fold(
162 (Vec::new(), None),
163 |(mut allowed, max_retry), p| match rate_limit.check_key(p) {
164 Ok(()) => {
165 allowed.push(p.clone());
166 (allowed, max_retry)
167 }
168 Err(not_until) => {
169 let earliest = not_until.earliest_possible();
170 let new_max = max_retry.map_or(earliest, |current| cmp::max(current, earliest));
171 (allowed, Some(new_max))
172 }
173 },
174 )
175}
176
/// Result of a successful rate-limit check: a borrowed sender paired with the
/// recipients that passed.
///
/// The sender is borrowed mutably for `'a`, so the value that produced this
/// `CheckedSender` cannot be used again until it is consumed or dropped.
#[derive(Debug)]
pub struct CheckedSender<'a, S: UnlimitedSender> {
    /// Underlying sender the message will be handed to.
    sender: &'a mut S,
    /// Recipients that passed the rate-limit check.
    recipients: Recipients<S::PublicKey>,
}
186
187impl<'a, S: UnlimitedSender> CheckedSender<'a, S> {
188 #[commonware_macros::stability(ALPHA)]
195 pub(crate) fn into_inner(self) -> &'a mut S {
196 self.sender
197 }
198}
199
200impl<'a, S: UnlimitedSender> crate::CheckedSender for CheckedSender<'a, S> {
201 type PublicKey = S::PublicKey;
202 type Error = S::Error;
203
204 async fn send(
205 self,
206 message: impl Into<IoBufMut> + Send,
207 priority: bool,
208 ) -> Result<Vec<Self::PublicKey>, Self::Error> {
209 self.sender.send(self.recipients, message, priority).await
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CheckedSender as _;
    use commonware_cryptography::{ed25519, Signer as _};
    use commonware_runtime::{deterministic::Runner, IoBuf, Quota, Runner as _};
    use commonware_utils::{channel::ring, NZUsize, NZU32};
    use thiserror::Error;

    type PublicKey = ed25519::PublicKey;
    /// One recorded call to `MockSender::send`: recipients, payload, priority flag.
    type SentMessage = (Recipients<PublicKey>, IoBuf, bool);

    /// Error type required by `UnlimitedSender`; the mock never returns it.
    #[derive(Debug, Error)]
    #[error("mock send error")]
    struct MockError;

    /// Sender that records every message instead of transmitting it.
    ///
    /// Clones share the same log, so a test can keep one clone for inspection.
    #[derive(Debug, Clone)]
    struct MockSender {
        sent: Arc<Mutex<Vec<SentMessage>>>,
    }

    impl MockSender {
        fn new() -> Self {
            Self {
                sent: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns a copy of everything sent so far.
        async fn sent_messages(&self) -> Vec<SentMessage> {
            self.sent.lock().await.clone()
        }
    }

    impl UnlimitedSender for MockSender {
        type Error = MockError;
        type PublicKey = PublicKey;

        /// Records the call and reports the explicitly named recipients as delivered
        /// (`Recipients::All` reports none).
        async fn send(
            &mut self,
            recipients: Recipients<Self::PublicKey>,
            message: impl Into<IoBufMut> + Send,
            priority: bool,
        ) -> Result<Vec<Self::PublicKey>, Self::Error> {
            let sent_to = match &recipients {
                Recipients::One(pk) => vec![pk.clone()],
                Recipients::Some(pks) => pks.clone(),
                Recipients::All => Vec::new(),
            };
            let message = message.into().freeze();
            self.sent.lock().await.push((recipients, message, priority));
            Ok(sent_to)
        }
    }

    /// Peer provider whose subscriptions never deliver an update in these tests.
    #[derive(Clone)]
    struct MockPeers {
        /// Kept so the channel handed out by `subscribe` stays open.
        sender: ring::Sender<Vec<PublicKey>>,
    }

    impl MockPeers {
        /// Builds the provider together with a sender for its initial channel.
        ///
        /// The initial channel's receiver is dropped right away and `subscribe`
        /// installs a fresh channel, so the returned sender cannot reach a
        /// `LimitedSender`. Tests that need known peers set the field directly.
        fn new() -> (Self, ring::Sender<Vec<PublicKey>>) {
            let (sender, _receiver) = ring::channel(NZUsize!(16));
            let peers = Self {
                sender: sender.clone(),
            };
            (peers, sender)
        }
    }

    impl Connected for MockPeers {
        type PublicKey = PublicKey;

        async fn subscribe(&mut self) -> ring::Receiver<Vec<Self::PublicKey>> {
            let (sender, receiver) = ring::channel(NZUsize!(16));
            self.sender = sender;
            receiver
        }
    }

    /// Deterministic public key derived from `seed`.
    fn key(seed: u64) -> PublicKey {
        ed25519::PrivateKey::from_seed(seed).public_key()
    }

    fn quota_per_second(n: u32) -> Quota {
        Quota::per_second(NZU32!(n))
    }

    #[test]
    fn check_one_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(10), context, peers);

            let peer = key(1);
            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            let sent_to = checked.send(IoBuf::from(b"hello"), false).await.unwrap();
            assert_eq!(sent_to, vec![peer]);
        });
    }

    #[test]
    fn check_one_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peer = key(1);

            // The first check uses up the peer's single permit.
            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            checked.send(IoBuf::from(b"first"), false).await.unwrap();

            // The second check happens without time advancing, so it must be refused.
            let result = limited.check(Recipients::One(peer)).await;
            assert!(result.is_err());
        });
    }

    #[test]
    fn check_some_all_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peers_list = vec![key(1), key(2), key(3)];
            let checked = limited
                .check(Recipients::Some(peers_list.clone()))
                .await
                .unwrap();
            let sent_to = checked.send(IoBuf::from(b"hello"), false).await.unwrap();
            assert_eq!(sent_to.len(), 3);
        });
    }

    #[test]
    fn check_some_filters_rate_limited_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);
            let peer3 = key(3);

            // Exhaust peer1's quota.
            let checked = limited.check(Recipients::One(peer1.clone())).await.unwrap();
            checked.send(IoBuf::from(b"limit"), false).await.unwrap();

            // peer1 must be filtered out while peer2 and peer3 still pass.
            let checked = limited
                .check(Recipients::Some(vec![
                    peer1.clone(),
                    peer2.clone(),
                    peer3.clone(),
                ]))
                .await
                .unwrap();
            let sent_to = checked.send(IoBuf::from(b"filtered"), false).await.unwrap();

            assert_eq!(sent_to.len(), 2);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
            assert!(sent_to.contains(&peer3));
        });
    }

    #[test]
    fn check_some_all_rate_limited_returns_error() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Exhaust both peers' quotas.
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(IoBuf::from(b"limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(IoBuf::from(b"limit2"), false)
                .await
                .unwrap();

            // With every listed peer limited, the check reports a retry time.
            assert!(limited
                .check(Recipients::Some(vec![peer1, peer2]))
                .await
                .is_err());
        });
    }

    #[test]
    fn check_some_empty_returns_as_is() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            // An empty list is not an error...
            let checked = limited.check(Recipients::Some(Vec::new())).await.unwrap();
            let sent_to = checked.send(IoBuf::from(b"empty"), false).await.unwrap();
            assert!(sent_to.is_empty());

            // ...and it reaches the underlying sender as an empty `Recipients::Some`.
            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            match &messages[0].0 {
                Recipients::Some(pks) => assert!(pks.is_empty()),
                _ => panic!("expected Recipients::Some"),
            }
        });
    }

    #[test]
    fn check_all_uses_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            // No peers are known yet, so `All` resolves to an empty recipient list.
            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(IoBuf::from(b"empty"), false).await.unwrap();
            assert!(sent_to.is_empty());

            // `All` must have been rewritten to `Some` before reaching the sender.
            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            match &messages[0].0 {
                Recipients::Some(pks) => assert!(pks.is_empty()),
                _ => panic!("expected Recipients::Some"),
            }
        });
    }

    #[test]
    fn check_all_filters_rate_limited_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Run one check so the subscription is established.
            let _ = limited.check(Recipients::All).await;

            // The mock never delivers updates, so install the peer list by hand.
            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            // Exhaust peer1's quota.
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(IoBuf::from(b"limit"), false)
                .await
                .unwrap();

            // `All` now resolves to peer2 only.
            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(IoBuf::from(b"filtered"), false).await.unwrap();

            assert_eq!(sent_to.len(), 1);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
        });
    }

    #[test]
    fn check_all_returns_error_when_all_known_peers_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Run one check so the subscription is established.
            let _ = limited.check(Recipients::All).await;

            // The mock never delivers updates, so install the peer list by hand.
            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            // Exhaust both peers' quotas.
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(IoBuf::from(b"limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(IoBuf::from(b"limit2"), false)
                .await
                .unwrap();

            // Every known peer is limited, so `All` reports a retry time.
            assert!(limited.check(Recipients::All).await.is_err());
        });
    }

    #[test]
    fn clone_creates_independent_subscription() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited1 = LimitedSender::new(sender, quota_per_second(10), context, peers);

            // Give the original a subscription and a non-empty peer list.
            let _ = limited1.check(Recipients::All).await;
            limited1.known_peers = vec![key(1)];

            // The clone must start without either.
            let limited2 = limited1.clone();
            assert!(limited2.peer_subscription.is_none());
            assert!(limited2.known_peers.is_empty());
        });
    }

    #[test]
    fn checked_sender_sends_with_priority() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            let peer = key(1);
            limited
                .check(Recipients::One(peer))
                .await
                .unwrap()
                .send(IoBuf::from(b"priority"), true)
                .await
                .unwrap();

            // The priority flag must reach the underlying sender unchanged.
            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            assert!(messages[0].2);
        });
    }

    #[test]
    fn rate_limit_shared_across_clones() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited1 = LimitedSender::new(sender, quota_per_second(1), context, peers);
            let mut limited2 = limited1.clone();

            let peer = key(1);

            // Spend the peer's permit through the first handle.
            limited1
                .check(Recipients::One(peer.clone()))
                .await
                .unwrap()
                .send(IoBuf::from(b"limit"), false)
                .await
                .unwrap();

            // The clone shares the limiter, so it must see the peer as limited.
            assert!(limited2.check(Recipients::One(peer)).await.is_err());
        });
    }
}