use crate::{Blocker, CheckedSender, Receiver, Recipients, Sender};
use commonware_actor::{mailbox, Feedback, Unreliable};
use commonware_codec::{Codec, Error};
use commonware_cryptography::PublicKey;
use commonware_macros::select_loop;
use commonware_parallel::Strategy;
use commonware_runtime::{
iobuf::EncodeExt, spawn_cell, BufferPool, ContextCell, Handle, Metrics, Spawner,
};
use commonware_utils::futures::Pool;
use std::{collections::VecDeque, num::NonZeroUsize, time::SystemTime};
/// Wraps a raw [`Sender`]/[`Receiver`] pair so callers exchange values of type
/// `V` instead of raw buffers.
///
/// * `config` - codec configuration handed to the receiver for decoding.
/// * `pool` - buffer pool handed to the sender for encoding.
/// * `sender` / `receiver` - the underlying channel halves to wrap.
///
/// Returns the wrapped sender and wrapped receiver, in that order.
pub const fn wrap<S: Sender, R: Receiver, V: Codec>(
    config: V::Cfg,
    pool: BufferPool,
    sender: S,
    receiver: R,
) -> (WrappedSender<S, V>, WrappedReceiver<R, V>) {
    let wrapped_sender = WrappedSender::new(pool, sender);
    let wrapped_receiver = WrappedReceiver::new(config, receiver);
    (wrapped_sender, wrapped_receiver)
}
/// A message received from peer `P`: the peer's public key paired with the
/// outcome of decoding its payload as `V` (the value, or the codec [`Error`]).
pub type WrappedMessage<P, V> = (P, Result<V, Error>);
/// Sender adapter that encodes values of type `V` before handing them to the
/// wrapped [`Sender`].
#[derive(Clone)]
pub struct WrappedSender<S: Sender, V: Codec> {
/// Buffer pool passed to `encode_with_pool` when encoding outgoing messages.
pool: BufferPool,
/// Underlying sender that transmits the encoded bytes.
sender: S,
/// Ties the wrapper to the message type `V` without storing a value of it.
_phantom_v: std::marker::PhantomData<V>,
}
impl<S: Sender, V: Codec> WrappedSender<S, V> {
    /// Creates a wrapper around `sender` that encodes messages using buffers
    /// from `pool`.
    pub const fn new(pool: BufferPool, sender: S) -> Self {
        Self {
            pool,
            sender,
            _phantom_v: std::marker::PhantomData,
        }
    }

    /// Encodes `message` with the configured buffer pool and forwards it to the
    /// underlying sender with the given `recipients` and `priority`.
    ///
    /// The return value is whatever the underlying sender returns.
    /// NOTE(review): presumably the peers the message was sent to; confirm
    /// against the `Sender` trait.
    pub fn send(
        &mut self,
        recipients: Recipients<S::PublicKey>,
        message: V,
        priority: bool,
    ) -> Vec<S::PublicKey> {
        self.sender
            .send(recipients, message.encode_with_pool(&self.pool), priority)
    }

    /// Runs the underlying sender's `check` for `recipients`.
    ///
    /// On success, returns a [`CheckedWrappedSender`] that borrows this
    /// wrapper's buffer pool and can encode and send a single message. On
    /// failure, the `SystemTime` reported by the underlying sender is passed
    /// through unchanged.
    pub fn check(
        &mut self,
        recipients: Recipients<S::PublicKey>,
    ) -> Result<CheckedWrappedSender<'_, S, V>, SystemTime> {
        let checked = self.sender.check(recipients)?;
        Ok(CheckedWrappedSender {
            pool: &self.pool,
            sender: checked,
            _phantom_v: std::marker::PhantomData,
        })
    }
}
/// A checked sender (as produced by [`WrappedSender::check`]) that encodes a
/// value of type `V` before sending it.
#[derive(Debug)]
pub struct CheckedWrappedSender<'a, S: Sender, V: Codec> {
/// Buffer pool borrowed from the originating [`WrappedSender`].
pool: &'a BufferPool,
/// Checked sender obtained from the underlying [`Sender`].
sender: S::Checked<'a>,
/// Ties the wrapper to the message type `V` without storing a value of it.
_phantom_v: std::marker::PhantomData<V>,
}
impl<'a, S: Sender, V: Codec> CheckedWrappedSender<'a, S, V> {
    /// Returns the recipients reported by the underlying checked sender.
    pub fn recipients(&self) -> Vec<S::PublicKey> {
        self.sender.recipients()
    }

    /// Consumes the checked sender: encodes `message` with the borrowed buffer
    /// pool and sends it with the given `priority`, returning the underlying
    /// sender's result.
    pub fn send(self, message: V, priority: bool) -> Unreliable<Feedback> {
        let Self { pool, sender, .. } = self;
        let payload = message.encode_with_pool(pool);
        sender.send(payload, priority)
    }
}
/// Receiver adapter that decodes each incoming payload as a value of type `V`.
pub struct WrappedReceiver<R: Receiver, V: Codec> {
/// Codec configuration passed to `V::decode_cfg` for every message.
config: V::Cfg,
/// Underlying receiver that yields `(peer, bytes)` pairs.
receiver: R,
}
impl<R: Receiver, V: Codec> WrappedReceiver<R, V> {
    /// Creates a wrapper around `receiver` that decodes messages with `config`.
    pub const fn new(config: V::Cfg, receiver: R) -> Self {
        Self { config, receiver }
    }

    /// Waits for the next message and decodes it.
    ///
    /// The outer `Err` is an error from the underlying receiver. A decoding
    /// failure is not an outer error: it is returned as the inner `Err` next to
    /// the sending peer's key, so the caller can decide how to treat that peer.
    pub async fn recv(&mut self) -> Result<WrappedMessage<R::PublicKey, V>, R::Error> {
        let (peer, payload) = self.receiver.recv().await?;
        Ok((peer, V::decode_cfg(payload.as_ref(), &self.config)))
    }
}
/// A successfully decoded value together with the peer it came from; this is
/// the item type carried by the mailbox that feeds [`BackgroundReceiver`].
struct Decoded<P: PublicKey, V>(P, V);
/// Mailbox policy for decoded messages.
///
/// `handle` ignores both its arguments and always returns `false`.
/// NOTE(review): presumably `false` tells the mailbox the message was not
/// retained in the overflow queue, so it is dropped when the mailbox is full
/// (consistent with `test_decoded_messages_drop_when_receiver_full`); confirm
/// against the `mailbox::UnreliablePolicy` documentation.
impl<P: PublicKey, V> mailbox::UnreliablePolicy for Decoded<P, V> {
type Overflow = VecDeque<Self>;
fn handle(_overflow: &mut Self::Overflow, _message: Self) -> bool {
false
}
}
/// Consumer half paired with a [`WrappedBackgroundReceiver`]: yields messages
/// that have already been decoded.
pub struct BackgroundReceiver<P: PublicKey, V> {
/// Mailbox receiver fed by the background decoder.
receiver: mailbox::UnreliableReceiver<Decoded<P, V>>,
}
impl<P: PublicKey, V> BackgroundReceiver<P, V> {
    /// Waits for the next decoded message and returns it as `(peer, value)`.
    ///
    /// Returns `None` when the underlying mailbox receiver yields `None`.
    pub async fn recv(&mut self) -> Option<(P, V)> {
        let Decoded(peer, value) = self.receiver.recv().await?;
        Some((peer, value))
    }
}
/// Background task that pulls raw messages from a [`Receiver`], decodes them in
/// spawned tasks, forwards successfully decoded values to a
/// [`BackgroundReceiver`], and reports peers whose messages fail to decode to
/// the [`Blocker`].
pub struct WrappedBackgroundReceiver<E, P, B, R, V>
where
E: Spawner,
P: PublicKey,
B: Blocker<PublicKey = P>,
R: Receiver<PublicKey = P>,
V: Codec + Send,
{
/// Runtime context used to run the main loop and to spawn decode tasks.
context: ContextCell<E>,
/// Source of raw `(peer, bytes)` messages.
receiver: R,
/// Codec configuration cloned into every decode task.
codec_config: V::Cfg,
/// Used to block a peer whose message fails to decode.
blocker: B,
/// Mailbox sender that delivers decoded messages to the `BackgroundReceiver`.
sender: mailbox::UnreliableSender<Decoded<P, V>>,
/// Upper bound on decode tasks kept in flight at once (always at least 1).
max_concurrency: usize,
}
impl<E, P, B, R, V> WrappedBackgroundReceiver<E, P, B, R, V>
where
E: Spawner + Metrics,
P: PublicKey,
B: Blocker<PublicKey = P>,
R: Receiver<PublicKey = P>,
V: Codec + Send + 'static,
{
/// Builds the background decoder and the [`BackgroundReceiver`] it feeds.
///
/// * `context` - runtime context; a `"mailbox"` child is created for the
///   mailbox and the context itself is kept for the main loop.
/// * `receiver` - source of raw messages.
/// * `codec_config` - configuration used to decode every message.
/// * `blocker` - notified when a peer sends a message that fails to decode.
/// * `channel_capacity` - capacity of the mailbox carrying decoded messages.
/// * `strategy` - its `parallelism_hint()` (clamped to at least 1) bounds the
///   number of decode tasks in flight.
pub fn new(
    context: E,
    receiver: R,
    codec_config: V::Cfg,
    blocker: B,
    channel_capacity: NonZeroUsize,
    strategy: &impl Strategy,
) -> (Self, BackgroundReceiver<P, V>) {
    let (sender, inbox) = mailbox::new_unreliable(context.child("mailbox"), channel_capacity);
    let context = ContextCell::new(context);
    let max_concurrency = strategy.parallelism_hint().max(1);
    let worker = Self {
        context,
        receiver,
        codec_config,
        blocker,
        sender,
        max_concurrency,
    };
    (worker, BackgroundReceiver { receiver: inbox })
}
/// Consumes the decoder and spawns its main loop (`run`) on the stored context
/// via `spawn_cell!`, returning the handle of the spawned task.
pub fn start(mut self) -> Handle<()> {
spawn_cell!(self.context, self.run())
}
/// Main loop of the background decoder.
///
/// Pulls raw messages from the wrapped receiver, spawns one decode task per
/// message, and passes each finished decode to `handle_decode_result`. Before
/// more work is accepted, the pool of in-flight decodes is drained until it
/// holds fewer than `max_concurrency` tasks. Once the wrapped receiver returns
/// an error, the remaining in-flight decodes are drained and the loop exits.
/// The loop also exits if waiting on a decode task yields an error.
async fn run(mut self) {
// Decode tasks that have been spawned but not yet handled.
let mut decode_pool = Pool::default();
// Set once `self.receiver.recv()` has returned an error.
let mut receiver_closed = false;
select_loop! {
self.context,
// NOTE(review): `on_start` presumably runs at the top of every loop
// iteration, before the branches below are polled; confirm against the
// `select_loop!` documentation.
on_start => {
// A `break` inside the `while` only leaves the `while`, so a failed
// decode task is recorded here and turned into a loop exit below.
let mut saw_error = false;
// Drain while the pool is at capacity, or, after the receiver has
// closed, until every remaining decode has been handled.
while decode_pool.len() >= self.max_concurrency
|| (receiver_closed && !decode_pool.is_empty())
{
let Ok(result) = decode_pool.next_completed().await else {
saw_error = true;
break;
};
Self::handle_decode_result(&mut self.blocker, &mut self.sender, result);
}
// Exit on a decode task error, or once the receiver is closed and
// nothing is left in flight.
if saw_error || (receiver_closed && decode_pool.is_empty()) {
break;
}
},
// No work is done in this arm.
// NOTE(review): presumably it runs when the context signals shutdown; confirm.
on_stopped => {},
// A decode task finished: forward the value or block the peer. An error
// from the pool ends the loop.
Ok(result) = decode_pool.next_completed() else break => {
Self::handle_decode_result(&mut self.blocker, &mut self.sender, result);
},
// A raw message arrived: decode it in a spawned task. A receive error
// marks the receiver as closed and restarts the loop so `on_start` can
// drain the pool.
Ok((peer, bytes)) = self.receiver.recv() else {
receiver_closed = true;
continue;
} => {
// Each task gets its own copy of the codec configuration.
let config = self.codec_config.clone();
// NOTE(review): `.shared(true)` presumably selects a shared executor for
// the task; confirm against the runtime's `Spawner` documentation.
let handle = self
.context
.child("decode")
.shared(true)
.spawn(|_| async move {
let result = V::decode_cfg(bytes.as_ref(), &config);
(peer, result)
});
decode_pool.push(handle);
},
}
}
/// Dispatches the outcome of one decode task.
///
/// A decoded value is enqueued on the mailbox as a `Decoded` item; the result
/// of `enqueue` is deliberately ignored. A decode error blocks the sending peer
/// through `crate::block!`.
fn handle_decode_result(
    blocker: &mut B,
    sender: &mut mailbox::UnreliableSender<Decoded<P, V>>,
    result: (P, Result<V, commonware_codec::Error>),
) {
    match result {
        (peer, Ok(value)) => {
            let _ = sender.enqueue(Decoded(peer, value));
        }
        (peer, Err(err)) => {
            crate::block!(blocker, peer, ?err, "received invalid message");
        }
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
simulated::{self, Link, Network, Oracle},
Manager as _, Recipients,
};
use commonware_actor::Feedback;
use commonware_codec::Encode;
use commonware_cryptography::{
ed25519::{PrivateKey, PublicKey},
Signer,
};
use commonware_macros::test_traced;
use commonware_parallel::{Sequential, Strategy};
use commonware_runtime::{deterministic, Clock as _, IoBuf, Quota, Runner, Supervisor as _};
use commonware_utils::{channel::mpsc, ordered::Set, NZUsize};
use std::{io, num::NonZeroU32, time::Duration};
// Link used by every simulated connection in these tests: zero latency, zero
// jitter, and a success rate of 1.0.
const LINK: Link = Link {
latency: Duration::from_millis(0),
jitter: Duration::from_millis(0),
success_rate: 1.0,
};
// Quota passed to every `register` call: `NonZeroU32::MAX` per second.
const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);
/// Creates a simulated network on a `"network"` child of `context`, starts it,
/// and returns the oracle used to configure peers and links.
fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
    let config = simulated::Config {
        max_size: 1024 * 1024,
        disconnect_on_block: true,
        tracked_peer_sets: NZUsize!(1),
    };
    let (network, oracle) = Network::new(context.child("network"), config);
    network.start();
    oracle
}
/// Derives a public key from the private key generated for `seed`.
fn pk(seed: u64) -> PublicKey {
    let signer = PrivateKey::from_seed(seed);
    signer.public_key()
}
/// Registers `peers` (deduplicated into a `Set`) with the oracle's manager
/// under peer-set `index`.
fn track_peers<I>(oracle: &Oracle<PublicKey, deterministic::Context>, index: u64, peers: I)
where
    I: IntoIterator<Item = PublicKey>,
{
    let peer_set = Set::from_iter_dedup(peers);
    oracle.manager().track(index, peer_set);
}
/// Adds a [`LINK`] from `a` to `b` and then from `b` to `a`, panicking if
/// either link cannot be added.
async fn link_bidirectional(
    oracle: &mut Oracle<PublicKey, deterministic::Context>,
    a: PublicKey,
    b: PublicKey,
) {
    // Forward direction first, then the reverse, reusing the owned keys.
    for (from, to) in [(a.clone(), b.clone()), (b, a)] {
        oracle.add_link(from, to, LINK).await.unwrap();
    }
}
/// Test strategy that runs everything on the calling thread but reports an
/// arbitrary `parallelism_hint` (the wrapped `usize`).
#[derive(Clone, Copy, Debug)]
struct HintStrategy(usize);

impl Strategy for HintStrategy {
    /// Folds `iter` sequentially with a single scratch value from `init` and a
    /// single accumulator from `identity`; `_reduce_op` is never called.
    fn fold_init<I, INIT, T, R, ID, F, RD>(
        &self,
        iter: I,
        init: INIT,
        identity: ID,
        fold_op: F,
        _reduce_op: RD,
    ) -> R
    where
        I: IntoIterator<IntoIter: Send, Item: Send> + Send,
        INIT: Fn() -> T + Send + Sync,
        T: Send,
        R: Send,
        ID: Fn() -> R + Send + Sync,
        F: Fn(R, &mut T, I::Item) -> R + Send + Sync,
        RD: Fn(R, R) -> R + Send + Sync,
    {
        let mut scratch = init();
        let items = iter.into_iter();
        let mut acc = identity();
        for item in items {
            acc = fold_op(acc, &mut scratch, item);
        }
        acc
    }

    /// Runs `a` and then `b` on the calling thread and returns both results.
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        let first = a();
        let second = b();
        (first, second)
    }

    /// Returns the hint this strategy was constructed with.
    fn parallelism_hint(&self) -> usize {
        let Self(hint) = *self;
        hint
    }
}
/// Test receiver backed by an unbounded channel of raw messages.
#[derive(Debug)]
struct MockReceiver<P: commonware_cryptography::PublicKey> {
    /// Channel the test pushes `(peer, bytes)` messages into.
    receiver: mpsc::UnboundedReceiver<crate::Message<P>>,
}

impl<P: commonware_cryptography::PublicKey> crate::Receiver for MockReceiver<P> {
    type Error = io::Error;
    type PublicKey = P;

    /// Yields the next queued message, or a `BrokenPipe` error once the channel
    /// returns `None`.
    async fn recv(&mut self) -> Result<crate::Message<Self::PublicKey>, Self::Error> {
        match self.receiver.recv().await {
            Some(message) => Ok(message),
            None => Err(io::Error::from(io::ErrorKind::BrokenPipe)),
        }
    }
}
/// Blocker that ignores every request; used by tests that do not run a
/// simulated network.
#[derive(Clone, Default)]
struct NoopBlocker;
impl crate::Blocker for NoopBlocker {
type PublicKey = PublicKey;
// Discards the peer and always reports `Feedback::Ok`.
fn block(&mut self, _peer: Self::PublicKey) -> Feedback {
Feedback::Ok
}
}
// A validly encoded `u32` sent over the simulated network is decoded by the
// background receiver and delivered together with the sending peer's key.
#[test_traced]
fn test_valid_messages_forwarded() {
    let runner = deterministic::Runner::default();
    runner.start(|context| async move {
        // Two linked peers on a fresh simulated network.
        let mut oracle = start_network(context.child("network"));
        let origin = pk(0);
        let target = pk(1);
        let origin_control = oracle.control(origin.clone());
        let target_control = oracle.control(target.clone());
        track_peers(&oracle, 0, [origin.clone(), target.clone()]);
        link_bidirectional(&mut oracle, origin.clone(), target.clone()).await;

        // Only the origin's sender half and the target's receiver half are kept.
        let (mut origin_sender, _) = origin_control.register(0, TEST_QUOTA).await.unwrap();
        let (_, target_receiver) = target_control.register(0, TEST_QUOTA).await.unwrap();

        let (background, mut decoded) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
            context.child("bg"),
            target_receiver,
            (),
            target_control.clone(),
            NZUsize!(16),
            &Sequential,
        );
        let _handle = background.start();

        let payload = 42u32.encode();
        let _ = origin_sender.send(Recipients::One(target.clone()), payload, true);

        let (from, value) = decoded.recv().await.unwrap();
        assert_eq!(from, origin);
        assert_eq!(value, 42u32);
    });
}
// A peer that sends a payload which fails to decode as `u32` ends up blocked,
// while a valid message from a different peer is still delivered.
#[test_traced]
fn test_invalid_codec_blocks_peer() {
    let runner = deterministic::Runner::default();
    runner.start(|context| async move {
        let mut oracle = start_network(context.child("network"));
        let offender = pk(0);
        let target = pk(1);
        let honest = pk(2);
        let offender_control = oracle.control(offender.clone());
        let target_control = oracle.control(target.clone());
        track_peers(
            &oracle,
            0,
            [offender.clone(), target.clone(), honest.clone()],
        );
        link_bidirectional(&mut oracle, offender.clone(), target.clone()).await;
        let (mut offender_sender, _) = offender_control.register(0, TEST_QUOTA).await.unwrap();
        let (_, target_receiver) = target_control.register(0, TEST_QUOTA).await.unwrap();

        let (background, mut decoded) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
            context.child("bg"),
            target_receiver,
            (),
            target_control.clone(),
            NZUsize!(16),
            &Sequential,
        );
        let _handle = background.start();

        // A single byte, which this test expects to fail decoding as `u32`.
        let invalid = IoBuf::from(vec![0xFFu8]);
        let _ = offender_sender.send(Recipients::One(target.clone()), invalid, true);

        // The honest peer is connected only after the invalid message was sent.
        let honest_control = oracle.control(honest.clone());
        link_bidirectional(&mut oracle, honest.clone(), target.clone()).await;
        let (mut honest_sender, _) = honest_control.register(0, TEST_QUOTA).await.unwrap();
        let valid = 99u32.encode();
        let _ = honest_sender.send(Recipients::One(target.clone()), valid, true);

        let (from, value) = decoded.recv().await.unwrap();
        assert_eq!(from, honest);
        assert_eq!(value, 99u32);

        // Poll until the oracle reports that the target blocked the offender.
        while !oracle
            .blocked()
            .await
            .unwrap()
            .contains(&(target.clone(), offender.clone()))
        {
            context.sleep(Duration::from_millis(1)).await;
        }
    });
}
// Twenty valid messages from one peer are all delivered; arrival order is not
// asserted, only the sorted set of values.
#[test_traced]
fn test_multiple_valid_messages() {
    let runner = deterministic::Runner::default();
    runner.start(|context| async move {
        let mut oracle = start_network(context.child("network"));
        let origin = pk(0);
        let target = pk(1);
        let origin_control = oracle.control(origin.clone());
        let target_control = oracle.control(target.clone());
        track_peers(&oracle, 0, [origin.clone(), target.clone()]);
        link_bidirectional(&mut oracle, origin.clone(), target.clone()).await;
        let (mut origin_sender, _) = origin_control.register(0, TEST_QUOTA).await.unwrap();
        let (_, target_receiver) = target_control.register(0, TEST_QUOTA).await.unwrap();

        let count = 20u32;
        let (background, mut decoded) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
            context.child("bg"),
            target_receiver,
            (),
            target_control.clone(),
            NZUsize!(20),
            &Sequential,
        );
        let _handle = background.start();

        for value in 0..count {
            let _ = origin_sender.send(Recipients::One(target.clone()), value.encode(), true);
        }

        let mut received = Vec::new();
        for _ in 0..count {
            let (from, value) = decoded.recv().await.unwrap();
            assert_eq!(from, origin);
            received.push(value);
        }
        received.sort();
        let expected: Vec<u32> = (0..count).collect();
        assert_eq!(received, expected);
    });
}
// Fifty valid messages are all delivered when the decoder is built with the
// `Sequential` strategy.
// NOTE(review): nothing here observes the number of concurrent decode tasks
// directly; the test only checks that every message arrives.
#[test_traced]
fn test_concurrency_bounded_by_strategy() {
    let runner = deterministic::Runner::default();
    runner.start(|context| async move {
        let mut oracle = start_network(context.child("network"));
        let origin = pk(0);
        let target = pk(1);
        let origin_control = oracle.control(origin.clone());
        let target_control = oracle.control(target.clone());
        track_peers(&oracle, 0, [origin.clone(), target.clone()]);
        link_bidirectional(&mut oracle, origin.clone(), target.clone()).await;
        let (mut origin_sender, _) = origin_control.register(0, TEST_QUOTA).await.unwrap();
        let (_, target_receiver) = target_control.register(0, TEST_QUOTA).await.unwrap();

        let total = 50u32;
        let (background, mut decoded) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
            context.child("bg"),
            target_receiver,
            (),
            target_control.clone(),
            NZUsize!(50),
            &Sequential,
        );
        let _handle = background.start();

        for value in 0..total {
            let _ = origin_sender.send(Recipients::One(target.clone()), value.encode(), true);
        }

        let mut received = Vec::new();
        for _ in 0..total {
            let (from, value) = decoded.recv().await.unwrap();
            assert_eq!(from, origin);
            received.push(value);
        }
        received.sort();
        let expected: Vec<u32> = (0..total).collect();
        assert_eq!(received, expected);
    });
}
// With an invalid message from one peer interleaved between two valid messages
// from another, both valid messages are delivered, the offender is blocked, and
// the honest peer is never reported as blocked.
#[test_traced]
fn test_invalid_among_valid_only_blocks_offender() {
    let runner = deterministic::Runner::default();
    runner.start(|context| async move {
        let mut oracle = start_network(context.child("network"));
        let offender = pk(0);
        let target = pk(1);
        let honest = pk(2);
        let offender_control = oracle.control(offender.clone());
        let target_control = oracle.control(target.clone());
        let honest_control = oracle.control(honest.clone());
        track_peers(
            &oracle,
            0,
            [offender.clone(), target.clone(), honest.clone()],
        );
        link_bidirectional(&mut oracle, offender.clone(), target.clone()).await;
        link_bidirectional(&mut oracle, honest.clone(), target.clone()).await;
        let (mut offender_sender, _) = offender_control.register(0, TEST_QUOTA).await.unwrap();
        let (_, target_receiver) = target_control.register(0, TEST_QUOTA).await.unwrap();
        let (mut honest_sender, _) = honest_control.register(0, TEST_QUOTA).await.unwrap();

        let (background, mut decoded) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
            context.child("bg"),
            target_receiver,
            (),
            target_control.clone(),
            NZUsize!(16),
            &Sequential,
        );
        let _handle = background.start();

        // Valid, invalid, valid, in that send order.
        let _ = honest_sender.send(Recipients::One(target.clone()), 10u32.encode(), true);
        let _ = offender_sender.send(
            Recipients::One(target.clone()),
            IoBuf::from(vec![0xFF]),
            true,
        );
        let _ = honest_sender.send(Recipients::One(target.clone()), 20u32.encode(), true);

        let mut values = Vec::new();
        for _ in 0..2 {
            let (from, value) = decoded.recv().await.unwrap();
            assert_eq!(from, honest);
            values.push(value);
        }
        values.sort();
        assert_eq!(values, vec![10u32, 20]);

        // Poll until the offender shows up as blocked by the target; the honest
        // peer must not appear in any snapshot along the way.
        let offender_entry = (target.clone(), offender.clone());
        let honest_entry = (target.clone(), honest.clone());
        loop {
            let blocked = oracle.blocked().await.unwrap();
            assert!(!blocked.contains(&honest_entry));
            if blocked.contains(&offender_entry) {
                break;
            }
            context.sleep(Duration::from_millis(1)).await;
        }
    });
}
// With a mailbox capacity of one and nobody reading until the background task
// has finished, only the first of two decoded messages is delivered; the
// receiver then reports `None`.
#[test_traced]
fn test_decoded_messages_drop_when_receiver_full() {
    let runner = deterministic::Runner::default();
    runner.start(|context| async move {
        let peer = pk(0);

        // Queue both raw messages up front, then close the channel.
        let (raw_tx, raw_rx) = mpsc::unbounded_channel();
        for value in 0..2u32 {
            let message = (peer.clone(), IoBuf::from(value.encode()));
            raw_tx.send(message).expect("mock receiver should be open");
        }
        drop(raw_tx);

        let (background, mut decoded) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
            context.child("bg"),
            MockReceiver { receiver: raw_rx },
            (),
            NoopBlocker,
            NZUsize!(1),
            &Sequential,
        );

        // Let the background task run to completion before reading anything.
        let handle = background.start();
        handle.await.expect("background receiver should complete");

        let (from, value) = decoded.recv().await.unwrap();
        assert_eq!(from, peer);
        assert_eq!(value, 0);
        assert!(decoded.recv().await.is_none());
    });
}
// When the raw receiver is already closed but holds 64 queued messages, all 64
// are still decoded and delivered before the decoded stream ends.
#[test_traced]
fn test_drain_decode_pool_after_receiver_closure() {
    let runner = deterministic::Runner::default();
    runner.start(|context| async move {
        let peer = pk(0);
        let total = 64u32;

        // Queue every raw message up front, then close the channel.
        let (raw_tx, raw_rx) = mpsc::unbounded_channel();
        for value in 0..total {
            let message = (peer.clone(), IoBuf::from(value.encode()));
            raw_tx.send(message).expect("mock receiver should be open");
        }
        drop(raw_tx);

        let (background, mut decoded) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
            context.child("bg"),
            MockReceiver { receiver: raw_rx },
            (),
            NoopBlocker,
            NZUsize!(64),
            &HintStrategy(8),
        );
        let _handle = background.start();

        // Read until the decoded stream ends.
        let mut values = Vec::new();
        loop {
            let Some((from, value)) = decoded.recv().await else {
                break;
            };
            assert_eq!(from, peer);
            values.push(value);
        }
        values.sort_unstable();
        let expected: Vec<u32> = (0..total).collect();
        assert_eq!(values, expected);
    });
}
}