use crate::{Consumer, Delivery};
use commonware_utils::futures::{AbortablePool, Aborter};
use futures::future::Aborted;
use std::collections::{hash_map::Entry as HashMapEntry, HashMap};
/// Outcome of one delivery attempt, as returned by `Tracker::next_completion`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Completion<K, S, Context = ()> {
/// Context that accompanied the delivered value.
pub context: Context,
/// Copy of the delivery (key and subscribers) that was handed to the consumer.
pub delivery: Delivery<K, S>,
/// Result reported by the consumer. It is `false` when the consumer's
/// response channel resolves to an error instead of a value.
pub valid: bool,
}
/// Response retained for a key so that it can be delivered again later.
struct Response<Context, V> {
/// Context supplied together with the value in `Tracker::deliver`.
context: Context,
/// The value that was delivered.
value: V,
/// Starts as `false`; set to `true` by `Tracker::accept_response` and
/// required to be `true` by `Tracker::redeliver`.
accepted: bool,
}
/// Bookkeeping for the delivery currently in flight for an entry.
struct ActiveDelivery {
/// Generation assigned when the delivery was pushed. `next_completion`
/// compares it with the generation carried by a finished future and rejects
/// the completion when they differ.
generation: u64,
/// Handle returned by `AbortablePool::push`; never read, only owned.
/// NOTE(review): presumably dropping it aborts the pooled future — confirm
/// against `commonware_utils::futures::Aborter`.
_aborter: Aborter,
}
/// Output of a pooled delivery future: the completion plus the generation it
/// was pushed under.
struct PooledCompletion<K, S, Context> {
/// Generation assigned to the delivery when it was pushed.
generation: u64,
/// The completion handed to the caller once the generation has been checked.
completion: Completion<K, S, Context>,
}
/// Per-key record kept by `Tracker`.
struct Entry<Context, V, State> {
/// Delivery pushed for this key whose completion has not yet been returned
/// by `Tracker::next_completion`; `None` otherwise.
delivery: Option<ActiveDelivery>,
/// Response most recently stored by `Tracker::deliver`, until it is
/// discarded.
response: Option<Response<Context, V>>,
/// Caller-defined state: `Some` when the entry is created and `None` after
/// `Tracker::take_state` has moved it out.
state: Option<State>,
}
impl<Context, V, State> Entry<Context, V, State> {
    /// Builds an entry that owns `state`, with no response recorded and no
    /// delivery in flight.
    const fn new(state: State) -> Self {
        Self {
            state: Some(state),
            response: None,
            delivery: None,
        }
    }
}
/// Tracks keys, the response stored for each key, and the in-flight delivery
/// of that response to a [`Consumer`].
///
/// `Context` is an opaque value returned alongside every completion; `State`
/// is optional caller-defined data stored per key.
pub struct Tracker<Con, Context = (), State = ()>
where
Con: Consumer,
Con::Value: Clone + Send + 'static,
Context: Clone + Send + 'static,
{
// One record per tracked key.
entries: HashMap<Con::Key, Entry<Context, Con::Value, State>>,
// Pool of delivery futures; each resolves to a generation-tagged completion.
deliveries: AbortablePool<PooledCompletion<Con::Key, Con::Subscriber, Context>>,
// Generation handed to the next pushed delivery; incremented on every push.
next_generation: u64,
// Consumer that is cloned for each delivery.
consumer: Con,
}
impl<Con, Context, State> Tracker<Con, Context, State>
where
Con: Consumer,
Con::Value: Clone + Send + 'static,
Context: Clone + Send + 'static,
{
/// Creates a tracker with no entries and no pending deliveries that hands
/// values to `consumer`.
pub fn new(consumer: Con) -> Self {
    Self {
        consumer,
        next_generation: 0,
        deliveries: AbortablePool::default(),
        entries: HashMap::new(),
    }
}
/// Returns `true` when `key` currently has an entry in the tracker.
pub fn contains(&self, key: &Con::Key) -> bool {
    let entries = &self.entries;
    entries.contains_key(key)
}
/// Starts tracking `key` with the given `state`.
///
/// Returns `true` if a new entry was created. If `key` is already tracked the
/// existing entry is left untouched, `state` is dropped, and `false` is
/// returned.
pub(crate) fn insert_with_state(&mut self, key: Con::Key, state: State) -> bool {
    let HashMapEntry::Vacant(slot) = self.entries.entry(key) else {
        return false;
    };
    slot.insert(Entry::new(state));
    true
}
/// Stops tracking `key`, dropping its entry (stored response, state and any
/// in-flight delivery handle).
///
/// Returns `true` if an entry existed.
pub fn remove(&mut self, key: &Con::Key) -> bool {
    let removed = self.entries.remove(key);
    removed.is_some()
}
/// Stops tracking `key` and hands back its state.
///
/// The outer `Option` is `None` when `key` was not tracked. The inner
/// `Option` is `None` when the state had already been taken with
/// `take_state`.
pub(crate) fn remove_with_state(&mut self, key: &Con::Key) -> Option<Option<State>> {
    let removed = self.entries.remove(key)?;
    Some(removed.state)
}
/// Moves the state out of the entry for `key`, leaving the entry in place.
///
/// Returns `None` when `key` is not tracked or its state was already taken.
pub(crate) fn take_state(&mut self, key: &Con::Key) -> Option<State> {
    self.entries.get_mut(key)?.state.take()
}
/// Keeps only the entries whose key satisfies `predicate` and drops the rest.
///
/// Returns the number of entries that were removed.
pub fn retain<F: FnMut(&Con::Key) -> bool>(&mut self, mut predicate: F) -> usize {
    // Count the extracted entries directly; buffering them in a `Vec` just to
    // read its length is an unnecessary allocation.
    self.entries.extract_if(|key, _| !predicate(key)).count()
}
/// Removes every entry from the tracker.
///
/// Returns the number of entries that were removed.
pub fn drain(&mut self) -> usize {
    self.entries.drain().count()
}
/// Records `value` (with its `context`) as the response for `delivery.key` and
/// starts delivering it to the consumer.
///
/// The stored response starts out not accepted; see `accept_response`.
///
/// # Panics
///
/// Panics if `delivery.key` is not tracked, or if the entry already has a
/// delivery in flight.
pub fn deliver(
    &mut self,
    delivery: Delivery<Con::Key, Con::Subscriber>,
    context: Context,
    value: Con::Value,
) {
    // Borrow the key for the lookup; cloning it only to look it up is wasted work.
    let entry = self
        .entries
        .get_mut(&delivery.key)
        .expect("delivery entry");
    entry.response = Some(Response {
        context: context.clone(),
        value: value.clone(),
        accepted: false,
    });
    self.push_delivery(delivery, context, value);
}
/// Delivers the already accepted response stored for `delivery.key` again,
/// this time to the subscribers in `delivery`.
///
/// # Panics
///
/// Panics if `delivery.key` is not tracked, if no response is stored, if the
/// stored response has not been accepted, or if the entry already has a
/// delivery in flight.
pub fn redeliver(&mut self, delivery: Delivery<Con::Key, Con::Subscriber>) {
    // Borrow the key for the lookup; cloning it only to look it up is wasted work.
    let entry = self.entries.get(&delivery.key).expect("delivery entry");
    let response = entry.response.as_ref().expect("response");
    assert!(response.accepted, "accepted response");
    let context = response.context.clone();
    let value = response.value.clone();
    self.push_delivery(delivery, context, value);
}
pub fn response_accepted(&self, key: &Con::Key) -> bool {
self.entries
.get(key)
.and_then(|entry| entry.response.as_ref())
.is_some_and(|response| response.accepted)
}
/// Marks the response stored for `key` as accepted, which allows `redeliver`.
///
/// # Panics
///
/// Panics if `key` is not tracked or has no stored response.
pub fn accept_response(&mut self, key: &Con::Key) {
    self.entries
        .get_mut(key)
        .expect("delivery entry")
        .response
        .as_mut()
        .expect("response")
        .accepted = true;
}
/// Forgets the response stored for `key`, if any.
///
/// Does nothing when `key` is not tracked.
pub fn discard_response(&mut self, key: &Con::Key) {
    let Some(entry) = self.entries.get_mut(key) else {
        return;
    };
    entry.response = None;
}
pub async fn next_completion(
&mut self,
) -> Result<Completion<Con::Key, Con::Subscriber, Context>, Aborted> {
let completed = self.deliveries.next_completed().await?;
let Some(entry) = self.entries.get_mut(&completed.completion.delivery.key) else {
return Err(Aborted);
};
if entry
.delivery
.as_ref()
.is_none_or(|delivery| delivery.generation != completed.generation)
{
return Err(Aborted);
}
entry.delivery = None;
Ok(completed.completion)
}
/// Hands `value` to a clone of the consumer, pools a future that waits for the
/// consumer's result, and registers that future as the entry's in-flight
/// delivery under a fresh generation.
///
/// # Panics
///
/// Panics if the generation counter overflows, if `delivery.key` is not
/// tracked, or if the entry already has a delivery in flight.
fn push_delivery(
    &mut self,
    delivery: Delivery<Con::Key, Con::Subscriber>,
    context: Context,
    value: Con::Value,
) {
    // Reserve this delivery's generation and advance the counter.
    let generation = self.next_generation;
    self.next_generation = generation
        .checked_add(1)
        .expect("delivery generation overflow");

    // `delivery` is moved into the consumer, so copy what is needed afterwards.
    let key = delivery.key.clone();
    let completed = delivery.clone();

    let mut handle = self.consumer.clone();
    let receiver = handle.deliver(delivery, value);

    let pending = async move {
        // A receiver that resolves to an error counts as an invalid delivery.
        let valid = receiver.await.unwrap_or(false);
        PooledCompletion {
            generation,
            completion: Completion {
                context,
                delivery: completed,
                valid,
            },
        }
    };
    let aborter = self.deliveries.push(pending);

    let entry = self.entries.get_mut(&key).expect("delivery entry");
    // At most one delivery may be in flight per entry.
    assert!(entry
        .delivery
        .replace(ActiveDelivery {
            generation,
            _aborter: aborter,
        })
        .is_none());
}
}
impl<Con, Context> Tracker<Con, Context>
where
    Con: Consumer,
    Con::Value: Clone + Send + 'static,
    Context: Clone + Send + 'static,
{
    /// Starts tracking `key` for trackers that carry no per-key state.
    ///
    /// Returns `true` if a new entry was created and `false` if `key` was
    /// already tracked.
    pub fn insert(&mut self, key: Con::Key) -> bool {
        let state = ();
        self.insert_with_state(key, state)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::p2p::mocks::{Consumer as MockConsumer, Key as MockKey};
use bytes::Bytes;
use commonware_runtime::{deterministic::Runner, Runner as _};
use commonware_utils::{
channel::{fallible::FallibleExt, mpsc, oneshot},
non_empty_vec,
};
// Tracker over the mock consumer with a `u8` delivery context and the default `()` state.
type TestTracker = Tracker<MockConsumer<MockKey, Bytes>, u8>;
// Builds a delivery for `key` addressed to a single unit subscriber.
fn delivery(key: MockKey) -> Delivery<MockKey, ()> {
    let subscribers = non_empty_vec![()];
    Delivery { key, subscribers }
}
// Consumer whose deliveries stay pending until the test answers them: every
// call to `deliver` forwards the reply sender over this channel.
#[derive(Clone)]
struct PendingConsumer {
// Channel on which the per-delivery reply senders are handed to the test.
sender: mpsc::UnboundedSender<oneshot::Sender<bool>>,
}
impl PendingConsumer {
    // Returns the consumer together with the receiving end on which the test
    // obtains one reply sender per delivery.
    fn new() -> (Self, mpsc::UnboundedReceiver<oneshot::Sender<bool>>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let consumer = Self { sender: tx };
        (consumer, rx)
    }
}
impl Consumer for PendingConsumer {
    type Key = MockKey;
    type Value = Bytes;
    type Subscriber = ();

    // Ignores the delivered data; creates a reply channel, passes its sending
    // half to the test, and returns the receiving half.
    fn deliver(
        &mut self,
        _delivery: Delivery<Self::Key, Self::Subscriber>,
        _value: Self::Value,
    ) -> oneshot::Receiver<bool> {
        let (reply, pending) = oneshot::channel();
        self.sender.send_lossy(reply);
        pending
    }
}
// Inserting, querying and removing a key behaves like a set: a duplicate
// insert and a second remove both report `false`.
#[test]
fn test_insert_contains_remove_round_trip() {
    Runner::default().start(|_| async move {
        let key = MockKey(1);
        let mut tracker = TestTracker::new(MockConsumer::dummy());

        // Unknown key.
        assert!(!tracker.contains(&key));

        // First insert succeeds, second is rejected.
        assert!(tracker.insert(key.clone()));
        assert!(tracker.contains(&key));
        assert!(!tracker.insert(key.clone()));

        // First remove succeeds, second finds nothing.
        assert!(tracker.remove(&key));
        assert!(!tracker.contains(&key));
        assert!(!tracker.remove(&key));
    });
}
// A delivery completes with the caller's context, the delivered key and the
// consumer's verdict, and the consumer sees the key and value.
#[test]
fn test_deliver_completes_with_context_and_consumer_result() {
    Runner::default().start(|_| async move {
        let (consumer, mut events) = MockConsumer::<MockKey, Bytes>::new();
        let mut tracker = TestTracker::new(consumer);
        let key = MockKey(7);
        let payload = Bytes::from("data");

        tracker.insert(key.clone());
        tracker.deliver(delivery(key.clone()), 9, payload.clone());

        let completion = tracker
            .next_completion()
            .await
            .expect("delivery should complete");
        assert_eq!(completion.context, 9);
        assert_eq!(completion.delivery.key, key);
        assert!(completion.valid);

        let (seen_key, seen_value) = events.recv().await.unwrap();
        assert_eq!(seen_key, key);
        assert_eq!(seen_value, payload);
    });
}
// Removing a key while its delivery is in flight makes the next completion
// report `Aborted`.
#[test]
fn test_remove_aborts_in_flight_delivery() {
    Runner::default().start(|_| async move {
        let (consumer, _events) = MockConsumer::<MockKey, Bytes>::new();
        let mut tracker = TestTracker::new(consumer);
        let key = MockKey(1);

        tracker.insert(key.clone());
        tracker.deliver(delivery(key.clone()), 2, Bytes::from("v"));
        assert!(tracker.remove(&key));

        let outcome = tracker.next_completion().await;
        assert!(matches!(outcome, Err(Aborted)));
    });
}
// A completion from an earlier delivery of the same key must be rejected and
// must not clear the newer in-flight delivery.
#[test]
fn test_stale_same_key_completion_does_not_clear_new_delivery() {
let runner = Runner::default();
runner.start(|_| async move {
let (consumer, mut senders) = PendingConsumer::new();
let mut tracker = Tracker::<PendingConsumer, u8>::new(consumer);
let key = MockKey(1);
tracker.insert(key.clone());
// Start the first delivery and let the consumer answer it.
tracker.deliver(delivery(key.clone()), 1, Bytes::from("old"));
let old_sender = senders.recv().await.unwrap();
old_sender.send(true).unwrap();
// Pull the finished result straight from the pool, bypassing the tracker's
// bookkeeping, so it can be replayed later as a stale completion.
let stale = tracker.deliveries.next_completed().await.unwrap();
// Re-create the entry and start a second delivery for the same key.
assert!(tracker.remove(&key));
tracker.insert(key.clone());
tracker.deliver(delivery(key.clone()), 2, Bytes::from("new"));
let new_sender = senders.recv().await.unwrap();
// Replay the old result; its generation no longer matches the entry.
let _stale_aborter = tracker.deliveries.push(async move { stale });
assert!(matches!(tracker.next_completion().await, Err(Aborted)));
// The new delivery is still registered and completes normally.
new_sender.send(true).unwrap();
let completed = tracker
.next_completion()
.await
.expect("new delivery should complete");
assert_eq!(completed.context, 2);
assert_eq!(completed.delivery.key, key);
assert!(completed.valid);
});
}
// Once a response is accepted, `redeliver` sends the same context and value
// to the consumer a second time.
#[test]
fn test_redeliver_reuses_accepted_response_for_new_subscribers() {
    Runner::default().start(|_| async move {
        let (consumer, mut events) = MockConsumer::<MockKey, Bytes>::new();
        let mut tracker = TestTracker::new(consumer);
        let key = MockKey(5);
        let payload = Bytes::from("first");

        // Initial delivery.
        tracker.insert(key.clone());
        tracker.deliver(delivery(key.clone()), 3, payload.clone());
        let initial = tracker
            .next_completion()
            .await
            .expect("first delivery should complete");
        assert!(initial.valid);

        // Accept the response, then deliver it again.
        tracker.accept_response(&key);
        assert!(tracker.response_accepted(&key));
        tracker.redeliver(delivery(key.clone()));
        let repeated = tracker
            .next_completion()
            .await
            .expect("redelivery should complete");
        assert_eq!(repeated.context, 3);
        assert_eq!(repeated.delivery.key, key);
        assert!(repeated.valid);

        // The consumer observed the same key and value twice.
        let first = events.recv().await.unwrap();
        let second = events.recv().await.unwrap();
        let expected = (key, payload);
        assert_eq!(first, expected);
        assert_eq!(second, expected);
    });
}
// `redeliver` panics when the stored response was never accepted.
#[test]
#[should_panic(expected = "accepted response")]
fn test_redeliver_requires_accepted_response() {
    Runner::default().start(|_| async move {
        let (consumer, _events) = MockConsumer::<MockKey, Bytes>::new();
        let mut tracker = TestTracker::new(consumer);
        let key = MockKey(7);

        tracker.insert(key.clone());
        tracker.deliver(delivery(key.clone()), 3, Bytes::from("first"));
        let initial = tracker
            .next_completion()
            .await
            .expect("first delivery should complete");
        assert!(initial.valid);

        // No `accept_response` call was made, so this must panic.
        tracker.redeliver(delivery(key));
    });
}
// After a rejected delivery the response can be discarded and a new value
// delivered for the same key.
#[test]
fn test_rejected_response_can_be_discarded_and_replaced() {
    Runner::default().start(|_| async move {
        let (mut consumer, _events) = MockConsumer::<MockKey, Bytes>::new();
        let key = MockKey(8);
        consumer.add_expected(key.clone(), Bytes::from("good"));
        let mut tracker = TestTracker::new(consumer);
        tracker.insert(key.clone());

        // The unexpected value is reported as invalid.
        tracker.deliver(delivery(key.clone()), 1, Bytes::from("bad"));
        let first_attempt = tracker
            .next_completion()
            .await
            .expect("rejected delivery should complete");
        assert!(!first_attempt.valid);

        tracker.discard_response(&key);
        assert!(!tracker.response_accepted(&key));

        // The expected value is reported as valid.
        tracker.deliver(delivery(key.clone()), 2, Bytes::from("good"));
        let second_attempt = tracker
            .next_completion()
            .await
            .expect("accepted delivery should complete");
        assert_eq!(second_attempt.context, 2);
        assert!(second_attempt.valid);
    });
}
}