1use crate::{Recipients, UnlimitedSender};
4use bytes::Bytes;
5use commonware_cryptography::PublicKey;
6use commonware_runtime::{Clock, KeyedRateLimiter, Quota};
7use commonware_utils::channels::ring;
8use futures::{lock::Mutex, Future, FutureExt, StreamExt};
9use std::{cmp, fmt, sync::Arc, time::SystemTime};
10
11pub trait Connected: Clone + Send + Sync + 'static {
16 type PublicKey: PublicKey;
17
18 fn subscribe(&mut self) -> impl Future<Output = ring::Receiver<Vec<Self::PublicKey>>> + Send;
25}
26
27pub struct LimitedSender<E, S, P>
29where
30 E: Clock,
31 S: UnlimitedSender,
32 P: Connected<PublicKey = S::PublicKey>,
33{
34 sender: S,
35 rate_limit: Arc<Mutex<KeyedRateLimiter<S::PublicKey, E>>>,
36 peers: P,
37 peer_subscription: Option<ring::Receiver<Vec<S::PublicKey>>>,
38 known_peers: Vec<S::PublicKey>,
39}
40
41impl<E, S, P> Clone for LimitedSender<E, S, P>
42where
43 E: Clock,
44 S: UnlimitedSender,
45 P: Connected<PublicKey = S::PublicKey>,
46{
47 fn clone(&self) -> Self {
48 Self {
49 sender: self.sender.clone(),
50 rate_limit: self.rate_limit.clone(),
51 peers: self.peers.clone(),
52 peer_subscription: None,
53 known_peers: Vec::new(),
54 }
55 }
56}
57
58impl<E, S, P> fmt::Debug for LimitedSender<E, S, P>
59where
60 E: Clock,
61 S: UnlimitedSender,
62 P: Connected<PublicKey = S::PublicKey>,
63{
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("LimitedSender")
66 .field("known_peers", &self.known_peers.len())
67 .finish_non_exhaustive()
68 }
69}
70
71impl<E, S, P> LimitedSender<E, S, P>
72where
73 E: Clock,
74 S: UnlimitedSender,
75 P: Connected<PublicKey = S::PublicKey>,
76{
77 pub fn new(sender: S, quota: Quota, clock: E, peers: P) -> Self {
79 let rate_limit = Arc::new(Mutex::new(KeyedRateLimiter::hashmap_with_clock(
80 quota, clock,
81 )));
82 Self {
83 sender,
84 rate_limit,
85 peers,
86 peer_subscription: None,
87 known_peers: Vec::new(),
88 }
89 }
90
91 pub async fn check(
97 &mut self,
98 recipients: Recipients<S::PublicKey>,
99 ) -> Result<CheckedSender<'_, S>, SystemTime> {
100 if self.peer_subscription.is_none() {
102 self.peer_subscription = Some(self.peers.subscribe().await);
103 }
104
105 let rate_limit = self.rate_limit.lock().await;
106
107 if let Some(ref mut subscription) = self.peer_subscription {
109 if let Some(peers) = subscription.next().now_or_never().flatten() {
110 self.known_peers = peers;
111 rate_limit.retain_recent();
112 }
113 }
114
115 let recipients = match recipients {
116 Recipients::One(ref peer) => match rate_limit.check_key(peer) {
117 Ok(()) => recipients,
118 Err(not_until) => return Err(not_until.earliest_possible()),
119 },
120 Recipients::Some(ref peers) => {
121 let (allowed, max_retry) = filter_rate_limited(peers.iter(), &rate_limit);
122 if allowed.is_empty() {
123 match max_retry {
124 Some(retry) => return Err(retry),
125 None => recipients,
126 }
127 } else {
128 Recipients::Some(allowed)
129 }
130 }
131 Recipients::All => {
132 let (allowed, max_retry) =
133 filter_rate_limited(self.known_peers.iter(), &rate_limit);
134 if allowed.is_empty() {
135 match max_retry {
136 Some(retry) => return Err(retry),
137 None => Recipients::Some(Vec::new()),
138 }
139 } else {
140 Recipients::Some(allowed)
141 }
142 }
143 };
144
145 Ok(CheckedSender {
146 recipients,
147 sender: &mut self.sender,
148 })
149 }
150}
151
152pub(crate) fn filter_rate_limited<'a, K, C>(
155 peers: impl Iterator<Item = &'a K>,
156 rate_limit: &KeyedRateLimiter<K, C>,
157) -> (Vec<K>, Option<SystemTime>)
158where
159 K: PublicKey,
160 C: Clock,
161{
162 peers.fold(
163 (Vec::new(), None),
164 |(mut allowed, max_retry), p| match rate_limit.check_key(p) {
165 Ok(()) => {
166 allowed.push(p.clone());
167 (allowed, max_retry)
168 }
169 Err(not_until) => {
170 let earliest = not_until.earliest_possible();
171 let new_max = max_retry.map_or(earliest, |current| cmp::max(current, earliest));
172 (allowed, Some(new_max))
173 }
174 },
175 )
176}
177
178#[derive(Debug)]
183pub struct CheckedSender<'a, S: UnlimitedSender> {
184 sender: &'a mut S,
185 recipients: Recipients<S::PublicKey>,
186}
187
188impl<'a, S: UnlimitedSender> CheckedSender<'a, S> {
189 pub(crate) fn into_inner(self) -> &'a mut S {
196 self.sender
197 }
198}
199
200impl<'a, S: UnlimitedSender> crate::CheckedSender for CheckedSender<'a, S> {
201 type PublicKey = S::PublicKey;
202 type Error = S::Error;
203
204 async fn send(
205 self,
206 message: Bytes,
207 priority: bool,
208 ) -> Result<Vec<Self::PublicKey>, Self::Error> {
209 self.sender.send(self.recipients, message, priority).await
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CheckedSender as _;
    use bytes::Bytes;
    use commonware_cryptography::{ed25519, Signer as _};
    use commonware_runtime::{deterministic::Runner, Quota, Runner as _};
    use commonware_utils::{channels::ring, NZUsize, NZU32};
    use std::time::Duration;
    use thiserror::Error;

    type PublicKey = ed25519::PublicKey;
    /// A message captured by [`MockSender`]: (recipients, payload, priority).
    type SentMessage = (Recipients<PublicKey>, Bytes, bool);

    /// Error type required by [`UnlimitedSender`]; the mock never returns it.
    #[derive(Debug, Error)]
    #[error("mock send error")]
    struct MockError;

    /// Sender that records every message instead of delivering it.
    #[derive(Debug, Clone)]
    struct MockSender {
        sent: Arc<Mutex<Vec<SentMessage>>>,
    }

    impl MockSender {
        fn new() -> Self {
            Self {
                sent: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of every message recorded so far (shared across clones).
        async fn sent_messages(&self) -> Vec<SentMessage> {
            self.sent.lock().await.clone()
        }
    }

    impl UnlimitedSender for MockSender {
        type Error = MockError;
        type PublicKey = PublicKey;

        async fn send(
            &mut self,
            recipients: Recipients<Self::PublicKey>,
            message: Bytes,
            priority: bool,
        ) -> Result<Vec<Self::PublicKey>, Self::Error> {
            // Echo explicit recipients back; `All` is reported as nobody.
            let sent_to = match &recipients {
                Recipients::One(pk) => vec![pk.clone()],
                Recipients::Some(pks) => pks.clone(),
                Recipients::All => Vec::new(),
            };
            self.sent.lock().await.push((recipients, message, priority));
            Ok(sent_to)
        }
    }

    /// Peer source for tests.
    ///
    /// The sender returned by `MockPeers::new` is not connected to receivers
    /// created by `subscribe`, so no updates are delivered. Tests that need
    /// known peers assign `LimitedSender::known_peers` directly.
    #[derive(Clone)]
    struct MockPeers {
        sender: ring::Sender<Vec<PublicKey>>,
    }

    impl MockPeers {
        fn new() -> (Self, ring::Sender<Vec<PublicKey>>) {
            let (sender, _receiver) = ring::channel(NZUsize!(16));
            let peers = Self {
                sender: sender.clone(),
            };
            (peers, sender)
        }
    }

    impl Connected for MockPeers {
        type PublicKey = PublicKey;

        async fn subscribe(&mut self) -> ring::Receiver<Vec<Self::PublicKey>> {
            let (sender, receiver) = ring::channel(NZUsize!(16));
            self.sender = sender;
            receiver
        }
    }

    fn key(seed: u64) -> PublicKey {
        ed25519::PrivateKey::from_seed(seed).public_key()
    }

    fn quota_per_second(n: u32) -> Quota {
        Quota::per_second(NZU32!(n))
    }

    #[test]
    fn check_one_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(10), context, peers);

            let peer = key(1);
            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            let sent_to = checked.send(Bytes::from("hello"), false).await.unwrap();
            assert_eq!(sent_to, vec![peer]);
        });
    }

    #[test]
    fn check_one_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer = key(1);

            let checked = limited.check(Recipients::One(peer.clone())).await.unwrap();
            checked.send(Bytes::from("first"), false).await.unwrap();

            let result = limited.check(Recipients::One(peer)).await;
            assert!(result.is_err());
        });
    }

    #[test]
    fn check_some_all_not_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited = LimitedSender::new(sender, quota_per_second(1), context, peers);

            let peers_list = vec![key(1), key(2), key(3)];
            let checked = limited
                .check(Recipients::Some(peers_list.clone()))
                .await
                .unwrap();
            let sent_to = checked.send(Bytes::from("hello"), false).await.unwrap();
            assert_eq!(sent_to.len(), 3);
        });
    }

    #[test]
    fn check_some_filters_rate_limited_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);
            let peer3 = key(3);

            let checked = limited.check(Recipients::One(peer1.clone())).await.unwrap();
            checked.send(Bytes::from("limit"), false).await.unwrap();

            let checked = limited
                .check(Recipients::Some(vec![
                    peer1.clone(),
                    peer2.clone(),
                    peer3.clone(),
                ]))
                .await
                .unwrap();
            let sent_to = checked.send(Bytes::from("filtered"), false).await.unwrap();

            assert_eq!(sent_to.len(), 2);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
            assert!(sent_to.contains(&peer3));
        });
    }

    #[test]
    fn check_some_all_rate_limited_returns_error() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit2"), false)
                .await
                .unwrap();

            assert!(limited
                .check(Recipients::Some(vec![peer1, peer2]))
                .await
                .is_err());
        });
    }

    #[test]
    fn check_some_all_rate_limited_reports_latest_retry() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender, quota_per_second(1), context.clone(), peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Exhaust peer1 now and peer2 half a period later, so their retry
            // times differ by about 500ms.
            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit1"), false)
                .await
                .unwrap();
            context.sleep(Duration::from_millis(500)).await;
            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit2"), false)
                .await
                .unwrap();

            let now = context.current();
            let Err(retry) = limited.check(Recipients::Some(vec![peer1, peer2])).await else {
                panic!("expected every peer to be rate limited");
            };
            // The reported retry is the latest one (peer2, ~1s away), not the
            // earliest (peer1, ~500ms away).
            assert!(retry >= now + Duration::from_millis(900));
        });
    }

    #[test]
    fn check_some_empty_returns_as_is() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            let checked = limited.check(Recipients::Some(Vec::new())).await.unwrap();
            let sent_to = checked.send(Bytes::from("empty"), false).await.unwrap();
            assert!(sent_to.is_empty());

            // The empty recipient list reaches the underlying sender unchanged.
            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            assert!(matches!(&messages[0].0, Recipients::Some(pks) if pks.is_empty()));
        });
    }

    #[test]
    fn check_all_uses_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            // No known peers yet: `All` becomes an explicit empty list.
            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(Bytes::from("empty"), false).await.unwrap();
            assert!(sent_to.is_empty());

            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            match &messages[0].0 {
                Recipients::Some(pks) => assert!(pks.is_empty()),
                _ => panic!("expected Recipients::Some"),
            }
        });
    }

    #[test]
    fn check_all_filters_rate_limited_known_peers() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Force the lazy subscription, then inject the peer set directly.
            let _ = limited.check(Recipients::All).await;
            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit"), false)
                .await
                .unwrap();

            let checked = limited.check(Recipients::All).await.unwrap();
            let sent_to = checked.send(Bytes::from("filtered"), false).await.unwrap();

            assert_eq!(sent_to.len(), 1);
            assert!(!sent_to.contains(&peer1));
            assert!(sent_to.contains(&peer2));
        });
    }

    #[test]
    fn check_all_returns_error_when_all_known_peers_rate_limited() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);

            let peer1 = key(1);
            let peer2 = key(2);

            // Force the lazy subscription, then inject the peer set directly.
            let _ = limited.check(Recipients::All).await;
            limited.known_peers = vec![peer1.clone(), peer2.clone()];

            limited
                .check(Recipients::One(peer1.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit1"), false)
                .await
                .unwrap();

            limited
                .check(Recipients::One(peer2.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit2"), false)
                .await
                .unwrap();

            assert!(limited.check(Recipients::All).await.is_err());
        });
    }

    #[test]
    fn clone_creates_independent_subscription() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited1 = LimitedSender::new(sender, quota_per_second(10), context, peers);

            let _ = limited1.check(Recipients::All).await;
            limited1.known_peers = vec![key(1)];

            let limited2 = limited1.clone();
            assert!(limited2.peer_subscription.is_none());
            assert!(limited2.known_peers.is_empty());
        });
    }

    #[test]
    fn checked_sender_sends_with_priority() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _peer_sender) = MockPeers::new();
            let mut limited =
                LimitedSender::new(sender.clone(), quota_per_second(10), context, peers);

            let peer = key(1);
            limited
                .check(Recipients::One(peer))
                .await
                .unwrap()
                .send(Bytes::from("priority"), true)
                .await
                .unwrap();

            let messages = sender.sent_messages().await;
            assert_eq!(messages.len(), 1);
            assert!(messages[0].2);
        });
    }

    #[test]
    fn rate_limit_shared_across_clones() {
        Runner::default().start(|context| async move {
            let sender = MockSender::new();
            let (peers, _) = MockPeers::new();
            let mut limited1 =
                LimitedSender::new(sender.clone(), quota_per_second(1), context, peers);
            let mut limited2 = limited1.clone();

            let peer = key(1);

            limited1
                .check(Recipients::One(peer.clone()))
                .await
                .unwrap()
                .send(Bytes::from("limit"), false)
                .await
                .unwrap();

            assert!(limited2.check(Recipients::One(peer)).await.is_err());
        });
    }
}