1use crate::{Consumer, Delivery};
10use commonware_utils::futures::{AbortablePool, Aborter};
11use futures::future::Aborted;
12use std::collections::{hash_map::Entry as HashMapEntry, HashMap};
13
/// The outcome of one delivery tracked by a [`Tracker`], as returned by
/// [`Tracker::next_completion`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Completion<K, S, Context = ()> {
    /// The context that was supplied with the delivered response: the argument
    /// passed to [`Tracker::deliver`], or the stored context of the accepted
    /// response when the delivery came from [`Tracker::redeliver`].
    pub context: Context,

    /// The delivery (key and subscribers) that completed.
    pub delivery: Delivery<K, S>,

    /// Whether the consumer reported the delivered value as valid. This is
    /// `false` when the consumer's result channel closed without an answer.
    pub valid: bool,
}
26
/// A response (value plus its context) recorded for a tracked key by
/// [`Tracker::deliver`].
struct Response<Context, V> {
    /// Context supplied alongside the value; reused on redelivery.
    context: Context,
    /// The delivered value; cloned again on redelivery.
    value: V,
    /// Set by [`Tracker::accept_response`]. Only accepted responses may be
    /// redelivered.
    accepted: bool,
}
33
/// Bookkeeping for the delivery currently in flight for a key.
struct ActiveDelivery {
    /// Generation assigned when the delivery was pushed into the pool. A pooled
    /// completion whose generation differs from this one is treated as stale.
    generation: u64,
    /// Held only so that dropping this value (e.g. when the entry is removed)
    /// aborts the corresponding pooled future.
    _aborter: Aborter,
}
39
/// Output of a future in the delivery pool: the completion, tagged with the
/// generation of the delivery that produced it.
struct PooledCompletion<K, S, Context> {
    /// Generation of the delivery that produced `completion`.
    generation: u64,
    /// The completion reported to callers if `generation` is still current.
    completion: Completion<K, S, Context>,
}
45
/// Per-key tracking state held by a [`Tracker`].
struct Entry<Context, V, State> {
    /// The delivery currently in flight for this key, if any.
    delivery: Option<ActiveDelivery>,
    /// The most recently recorded response, unless it was discarded.
    response: Option<Response<Context, V>>,
    /// Caller-defined state; becomes `None` once taken.
    state: Option<State>,
}
52
53impl<Context, V, State> Entry<Context, V, State> {
54 const fn new(state: State) -> Self {
55 Self {
56 delivery: None,
57 response: None,
58 state: Some(state),
59 }
60 }
61}
62
/// Tracks keys awaiting delivery to a [`Consumer`], the responses recorded for
/// them, and the consumer deliveries currently in flight.
///
/// Each key may carry caller-defined `State` and has at most one delivery in
/// flight at a time. Delivery outcomes are retrieved with
/// [`Tracker::next_completion`].
pub struct Tracker<Con, Context = (), State = ()>
where
    Con: Consumer,
    Con::Value: Clone + Send + 'static,
    Context: Clone + Send + 'static,
{
    /// Tracked keys and their per-key bookkeeping.
    entries: HashMap<Con::Key, Entry<Context, Con::Value, State>>,
    /// Futures awaiting the consumer's verdict for each dispatched delivery.
    deliveries: AbortablePool<PooledCompletion<Con::Key, Con::Subscriber, Context>>,
    /// Generation assigned to the next delivery; incremented on every push.
    next_generation: u64,
    /// The consumer that delivered values are handed to.
    consumer: Con,
}
81
82impl<Con, Context, State> Tracker<Con, Context, State>
83where
84 Con: Consumer,
85 Con::Value: Clone + Send + 'static,
86 Context: Clone + Send + 'static,
87{
88 pub fn new(consumer: Con) -> Self {
90 Self {
91 entries: HashMap::new(),
92 deliveries: AbortablePool::default(),
93 next_generation: 0,
94 consumer,
95 }
96 }
97
98 pub fn contains(&self, key: &Con::Key) -> bool {
100 self.entries.contains_key(key)
101 }
102
103 pub(crate) fn insert_with_state(&mut self, key: Con::Key, state: State) -> bool {
108 match self.entries.entry(key) {
109 HashMapEntry::Vacant(entry) => {
110 entry.insert(Entry::new(state));
111 true
112 }
113 HashMapEntry::Occupied(_) => false,
114 }
115 }
116
117 pub fn remove(&mut self, key: &Con::Key) -> bool {
122 self.entries.remove(key).is_some()
123 }
124
125 pub(crate) fn remove_with_state(&mut self, key: &Con::Key) -> Option<Option<State>> {
130 self.entries.remove(key).map(|entry| entry.state)
131 }
132
133 pub(crate) fn take_state(&mut self, key: &Con::Key) -> Option<State> {
137 self.entries
138 .get_mut(key)
139 .and_then(|entry| entry.state.take())
140 }
141
142 pub fn retain<F: FnMut(&Con::Key) -> bool>(&mut self, mut predicate: F) -> usize {
147 let removed: Vec<_> = self.entries.extract_if(|key, _| !predicate(key)).collect();
148 removed.len()
149 }
150
151 pub fn drain(&mut self) -> usize {
155 let count = self.entries.len();
156 self.entries.clear();
157 count
158 }
159
160 pub fn deliver(
166 &mut self,
167 delivery: Delivery<Con::Key, Con::Subscriber>,
168 context: Context,
169 value: Con::Value,
170 ) {
171 let key = delivery.key.clone();
172 let entry = self.entries.get_mut(&key).expect("delivery entry");
173 entry.response = Some(Response {
174 context: context.clone(),
175 value: value.clone(),
176 accepted: false,
177 });
178 self.push_delivery(delivery, context, value);
179 }
180
181 pub fn redeliver(&mut self, delivery: Delivery<Con::Key, Con::Subscriber>) {
187 let key = delivery.key.clone();
188 let (context, value) = {
189 let entry = self.entries.get(&key).expect("delivery entry");
190 let response = entry.response.as_ref().expect("response");
191 assert!(response.accepted, "accepted response");
192 (response.context.clone(), response.value.clone())
193 };
194 self.push_delivery(delivery, context, value);
195 }
196
197 pub fn response_accepted(&self, key: &Con::Key) -> bool {
199 self.entries
200 .get(key)
201 .and_then(|entry| entry.response.as_ref())
202 .is_some_and(|response| response.accepted)
203 }
204
205 pub fn accept_response(&mut self, key: &Con::Key) {
209 let entry = self.entries.get_mut(key).expect("delivery entry");
210 let response = entry.response.as_mut().expect("response");
211 response.accepted = true;
212 }
213
214 pub fn discard_response(&mut self, key: &Con::Key) {
219 if let Some(entry) = self.entries.get_mut(key) {
220 entry.response = None;
221 }
222 }
223
224 pub async fn next_completion(
231 &mut self,
232 ) -> Result<Completion<Con::Key, Con::Subscriber, Context>, Aborted> {
233 let completed = self.deliveries.next_completed().await?;
234 let Some(entry) = self.entries.get_mut(&completed.completion.delivery.key) else {
235 return Err(Aborted);
236 };
237 if entry
238 .delivery
239 .as_ref()
240 .is_none_or(|delivery| delivery.generation != completed.generation)
241 {
242 return Err(Aborted);
243 }
244 entry.delivery = None;
245 Ok(completed.completion)
246 }
247
248 fn push_delivery(
250 &mut self,
251 delivery: Delivery<Con::Key, Con::Subscriber>,
252 context: Context,
253 value: Con::Value,
254 ) {
255 let generation = self.next_generation;
256 self.next_generation = self
257 .next_generation
258 .checked_add(1)
259 .expect("delivery generation overflow");
260 let key = delivery.key.clone();
261 let completed = delivery.clone();
262 let mut consumer = self.consumer.clone();
263 let receiver = consumer.deliver(delivery, value);
264 let aborter = self.deliveries.push(async move {
265 PooledCompletion {
266 generation,
267 completion: Completion {
268 context,
269 delivery: completed,
270 valid: receiver.await.unwrap_or(false),
271 },
272 }
273 });
274 let entry = self.entries.get_mut(&key).expect("delivery entry");
275 assert!(entry
276 .delivery
277 .replace(ActiveDelivery {
278 generation,
279 _aborter: aborter,
280 })
281 .is_none());
282 }
283}
284
285impl<Con, Context> Tracker<Con, Context>
286where
287 Con: Consumer,
288 Con::Value: Clone + Send + 'static,
289 Context: Clone + Send + 'static,
290{
291 pub fn insert(&mut self, key: Con::Key) -> bool {
296 self.insert_with_state(key, ())
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::p2p::mocks::{Consumer as MockConsumer, Key as MockKey};
    use bytes::Bytes;
    use commonware_runtime::{deterministic::Runner, Runner as _};
    use commonware_utils::{
        channel::{fallible::FallibleExt, mpsc, oneshot},
        non_empty_vec,
    };

    type TestTracker = Tracker<MockConsumer<MockKey, Bytes>, u8>;

    /// Builds a delivery of `key` to a single unit subscriber.
    fn delivery(key: MockKey) -> Delivery<MockKey, ()> {
        Delivery {
            key,
            subscribers: non_empty_vec![()],
        }
    }

    /// Consumer whose deliveries stay pending until the test answers them
    /// through the `oneshot::Sender` forwarded on `sender`.
    #[derive(Clone)]
    struct PendingConsumer {
        sender: mpsc::UnboundedSender<oneshot::Sender<bool>>,
    }

    impl PendingConsumer {
        fn new() -> (Self, mpsc::UnboundedReceiver<oneshot::Sender<bool>>) {
            let (sender, receiver) = mpsc::unbounded_channel();
            (Self { sender }, receiver)
        }
    }

    impl Consumer for PendingConsumer {
        type Key = MockKey;
        type Value = Bytes;
        type Subscriber = ();

        fn deliver(
            &mut self,
            _delivery: Delivery<Self::Key, Self::Subscriber>,
            _value: Self::Value,
        ) -> oneshot::Receiver<bool> {
            let (sender, receiver) = oneshot::channel();
            self.sender.send_lossy(sender);
            receiver
        }
    }

    #[test]
    fn test_insert_contains_remove_round_trip() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let mut tracker = TestTracker::new(MockConsumer::dummy());

            assert!(!tracker.contains(&MockKey(1)));
            assert!(tracker.insert(MockKey(1)));
            assert!(tracker.contains(&MockKey(1)));

            assert!(!tracker.insert(MockKey(1)));
            assert!(tracker.remove(&MockKey(1)));
            assert!(!tracker.contains(&MockKey(1)));
            assert!(!tracker.remove(&MockKey(1)));
        });
    }

    #[test]
    fn test_retain_and_drain_report_removed_counts() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let mut tracker = TestTracker::new(MockConsumer::dummy());
            for key in [MockKey(1), MockKey(2), MockKey(3), MockKey(4)] {
                assert!(tracker.insert(key));
            }

            let removed = tracker.retain(|key| *key == MockKey(2) || *key == MockKey(4));
            assert_eq!(removed, 2);
            assert!(!tracker.contains(&MockKey(1)));
            assert!(tracker.contains(&MockKey(2)));
            assert!(!tracker.contains(&MockKey(3)));
            assert!(tracker.contains(&MockKey(4)));

            assert_eq!(tracker.retain(|_| true), 0);
            assert_eq!(tracker.drain(), 2);
            assert!(!tracker.contains(&MockKey(2)));
            assert!(!tracker.contains(&MockKey(4)));
            assert_eq!(tracker.drain(), 0);

            // Untracked keys are tolerated by the non-panicking accessors.
            assert!(!tracker.response_accepted(&MockKey(9)));
            tracker.discard_response(&MockKey(9));
            assert!(!tracker.contains(&MockKey(9)));
        });
    }

    #[test]
    fn test_retain_and_drain_abort_in_flight_deliveries() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let (consumer, _events) = MockConsumer::<MockKey, Bytes>::new();
            let mut tracker = TestTracker::new(consumer);

            tracker.insert(MockKey(1));
            tracker.insert(MockKey(2));
            tracker.deliver(delivery(MockKey(1)), 1, Bytes::from("a"));
            assert_eq!(tracker.retain(|key| *key != MockKey(1)), 1);
            assert!(matches!(tracker.next_completion().await, Err(Aborted)));

            tracker.deliver(delivery(MockKey(2)), 2, Bytes::from("b"));
            assert_eq!(tracker.drain(), 1);
            assert!(matches!(tracker.next_completion().await, Err(Aborted)));
        });
    }

    #[test]
    fn test_state_is_taken_once_and_returned_on_removal() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let mut tracker =
                Tracker::<MockConsumer<MockKey, Bytes>, u8, u32>::new(MockConsumer::dummy());
            let key = MockKey(1);

            assert!(tracker.insert_with_state(key.clone(), 10));
            // A duplicate insert is rejected and must not overwrite the state.
            assert!(!tracker.insert_with_state(key.clone(), 11));
            assert_eq!(tracker.take_state(&key), Some(10));
            assert_eq!(tracker.take_state(&key), None);
            assert!(tracker.contains(&key));
            assert_eq!(tracker.remove_with_state(&key), Some(None));
            assert_eq!(tracker.remove_with_state(&key), None);

            assert!(tracker.insert_with_state(key.clone(), 20));
            assert_eq!(tracker.remove_with_state(&key), Some(Some(20)));
            assert_eq!(tracker.take_state(&key), None);
        });
    }

    #[test]
    fn test_deliver_completes_with_context_and_consumer_result() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let (consumer, mut events) = MockConsumer::<MockKey, Bytes>::new();
            let mut tracker = TestTracker::new(consumer);
            let key = MockKey(7);
            let value = Bytes::from("data");

            tracker.insert(key.clone());
            tracker.deliver(delivery(key.clone()), 9, value.clone());

            let completed = tracker
                .next_completion()
                .await
                .expect("delivery should complete");
            assert_eq!(completed.context, 9);
            assert_eq!(completed.delivery.key, key);
            assert!(completed.valid);

            let (delivered_key, delivered_value) = events.recv().await.unwrap();
            assert_eq!(delivered_key, key);
            assert_eq!(delivered_value, value);
        });
    }

    #[test]
    fn test_remove_aborts_in_flight_delivery() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let (consumer, _events) = MockConsumer::<MockKey, Bytes>::new();
            let mut tracker = TestTracker::new(consumer);
            let key = MockKey(1);

            tracker.insert(key.clone());
            tracker.deliver(delivery(key.clone()), 2, Bytes::from("v"));
            assert!(tracker.remove(&key));

            assert!(matches!(tracker.next_completion().await, Err(Aborted)));
        });
    }

    #[test]
    fn test_stale_same_key_completion_does_not_clear_new_delivery() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let (consumer, mut senders) = PendingConsumer::new();
            let mut tracker = Tracker::<PendingConsumer, u8>::new(consumer);
            let key = MockKey(1);

            tracker.insert(key.clone());
            tracker.deliver(delivery(key.clone()), 1, Bytes::from("old"));
            let old_sender = senders.recv().await.unwrap();
            old_sender.send(true).unwrap();
            let stale = tracker.deliveries.next_completed().await.unwrap();

            assert!(tracker.remove(&key));
            tracker.insert(key.clone());
            tracker.deliver(delivery(key.clone()), 2, Bytes::from("new"));
            let new_sender = senders.recv().await.unwrap();

            let _stale_aborter = tracker.deliveries.push(async move { stale });
            assert!(matches!(tracker.next_completion().await, Err(Aborted)));

            new_sender.send(true).unwrap();
            let completed = tracker
                .next_completion()
                .await
                .expect("new delivery should complete");
            assert_eq!(completed.context, 2);
            assert_eq!(completed.delivery.key, key);
            assert!(completed.valid);
        });
    }

    #[test]
    fn test_redeliver_reuses_accepted_response_for_new_subscribers() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let (consumer, mut events) = MockConsumer::<MockKey, Bytes>::new();
            let mut tracker = TestTracker::new(consumer);
            let key = MockKey(5);
            let value = Bytes::from("first");

            tracker.insert(key.clone());
            tracker.deliver(delivery(key.clone()), 3, value.clone());

            let completed = tracker
                .next_completion()
                .await
                .expect("first delivery should complete");
            assert!(completed.valid);
            tracker.accept_response(&key);
            assert!(tracker.response_accepted(&key));

            tracker.redeliver(delivery(key.clone()));
            let redelivered = tracker
                .next_completion()
                .await
                .expect("redelivery should complete");
            assert_eq!(redelivered.context, 3);
            assert_eq!(redelivered.delivery.key, key);
            assert!(redelivered.valid);

            let first = events.recv().await.unwrap();
            let second = events.recv().await.unwrap();
            assert_eq!(first, (key.clone(), value.clone()));
            assert_eq!(second, (key, value));
        });
    }

    #[test]
    #[should_panic(expected = "accepted response")]
    fn test_redeliver_requires_accepted_response() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let (consumer, _events) = MockConsumer::<MockKey, Bytes>::new();
            let mut tracker = TestTracker::new(consumer);
            let key = MockKey(7);

            tracker.insert(key.clone());
            tracker.deliver(delivery(key.clone()), 3, Bytes::from("first"));
            let completed = tracker
                .next_completion()
                .await
                .expect("first delivery should complete");
            assert!(completed.valid);

            tracker.redeliver(delivery(key));
        });
    }

    #[test]
    fn test_rejected_response_can_be_discarded_and_replaced() {
        let runner = Runner::default();
        runner.start(|_| async move {
            let (mut consumer, _events) = MockConsumer::<MockKey, Bytes>::new();
            let key = MockKey(8);
            consumer.add_expected(key.clone(), Bytes::from("good"));
            let mut tracker = TestTracker::new(consumer);

            tracker.insert(key.clone());
            tracker.deliver(delivery(key.clone()), 1, Bytes::from("bad"));
            let rejected = tracker
                .next_completion()
                .await
                .expect("rejected delivery should complete");
            assert!(!rejected.valid);

            tracker.discard_response(&key);
            assert!(!tracker.response_accepted(&key));
            tracker.deliver(delivery(key.clone()), 2, Bytes::from("good"));

            let accepted = tracker
                .next_completion()
                .await
                .expect("accepted delivery should complete");
            assert_eq!(accepted.context, 2);
            assert!(accepted.valid);
        });
    }
}