use crate::{Blocker, CheckedSender, Receiver, Recipients, Sender};
use commonware_codec::{Codec, Error};
use commonware_cryptography::PublicKey;
use commonware_macros::select_loop;
use commonware_parallel::Strategy;
use commonware_runtime::{iobuf::EncodeExt, spawn_cell, BufferPool, ContextCell, Handle, Spawner};
use commonware_utils::{
channel::{fallible::AsyncFallibleExt, mpsc},
futures::Pool,
};
use std::time::SystemTime;
/// Wraps a raw [`Sender`]/[`Receiver`] pair so that values of type `V` are
/// encoded on the way out and decoded on the way in.
///
/// `pool` supplies the buffers used to encode outgoing messages; `config` is the
/// codec configuration used to decode incoming payloads.
pub const fn wrap<S: Sender, R: Receiver, V: Codec>(
    config: V::Cfg,
    pool: BufferPool,
    sender: S,
    receiver: R,
) -> (WrappedSender<S, V>, WrappedReceiver<R, V>) {
    let outbound = WrappedSender::new(pool, sender);
    let inbound = WrappedReceiver::new(config, receiver);
    (outbound, inbound)
}
/// A received message: the sending peer's public key paired with the result of
/// decoding its payload.
///
/// Decode failures are returned as `Err` rather than dropped, so callers decide
/// how to treat peers that send malformed data.
pub type WrappedMessage<P, V> = (P, Result<V, Error>);

/// A [`Sender`] wrapper that encodes values of type `V` (using buffers from a
/// [`BufferPool`]) before handing them to the wrapped sender.
pub struct WrappedSender<S: Sender, V: Codec> {
    pool: BufferPool,
    sender: S,
    _phantom_v: std::marker::PhantomData<V>,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add a
// `V: Clone` bound (because of the `PhantomData<V>` field) even though no `V`
// value is stored, which needlessly excluded non-`Clone` codec types.
impl<S: Sender + Clone, V: Codec> Clone for WrappedSender<S, V> {
    fn clone(&self) -> Self {
        Self {
            pool: self.pool.clone(),
            sender: self.sender.clone(),
            _phantom_v: std::marker::PhantomData,
        }
    }
}

impl<S: Sender, V: Codec> WrappedSender<S, V> {
    /// Wraps `sender`, encoding outgoing messages with buffers from `pool`.
    pub const fn new(pool: BufferPool, sender: S) -> Self {
        Self {
            pool,
            sender,
            _phantom_v: std::marker::PhantomData,
        }
    }

    /// Encodes `message` and sends it to `recipients` through the wrapped sender.
    ///
    /// Returns whatever the wrapped sender's `send` returns. On success this is a
    /// list of public keys (presumably the peers the message was sent to —
    /// confirm against `Sender::send`).
    pub async fn send(
        &mut self,
        recipients: Recipients<S::PublicKey>,
        message: V,
        priority: bool,
    ) -> Result<Vec<S::PublicKey>, <S::Checked<'_> as CheckedSender>::Error> {
        let encoded = message.encode_with_pool(&self.pool);
        self.sender.send(recipients, encoded, priority).await
    }

    /// Runs the wrapped sender's `check` for `recipients`. On success, returns a
    /// [`CheckedWrappedSender`] that encodes and sends a single message later.
    ///
    /// # Errors
    ///
    /// Propagates the `SystemTime` error from the wrapped sender's `check`
    /// unchanged (presumably the time until which sending is not permitted —
    /// confirm against `Sender::check`).
    pub async fn check(
        &mut self,
        recipients: Recipients<S::PublicKey>,
    ) -> Result<CheckedWrappedSender<'_, S, V>, SystemTime> {
        self.sender
            .check(recipients)
            .await
            .map(|checked| CheckedWrappedSender {
                pool: &self.pool,
                sender: checked,
                _phantom_v: std::marker::PhantomData,
            })
    }
}
/// A checked sender produced by [`WrappedSender::check`].
///
/// It borrows the parent's [`BufferPool`] and holds the wrapped sender's checked
/// handle for lifetime `'a`. Calling [`CheckedWrappedSender::send`] consumes it.
pub struct CheckedWrappedSender<'a, S: Sender, V: Codec> {
    pool: &'a BufferPool,
    sender: S::Checked<'a>,
    _phantom_v: std::marker::PhantomData<V>,
}

// Implemented by hand instead of `#[derive(Debug)]`: the derive would require
// `S: Debug` and `V: Debug`, although only `pool` and the checked sender are
// formatted (`PhantomData<V>` is `Debug` for every `V`). The field output
// matches what the derive produced.
impl<'a, S: Sender, V: Codec> std::fmt::Debug for CheckedWrappedSender<'a, S, V>
where
    S::Checked<'a>: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CheckedWrappedSender")
            .field("pool", &self.pool)
            .field("sender", &self.sender)
            .field("_phantom_v", &self._phantom_v)
            .finish()
    }
}

impl<'a, S: Sender, V: Codec> CheckedWrappedSender<'a, S, V> {
    /// Encodes `message` with the borrowed buffer pool and sends it through the
    /// checked sender, consuming `self`.
    ///
    /// Returns whatever the checked sender's `send` returns.
    pub async fn send(
        self,
        message: V,
        priority: bool,
    ) -> Result<Vec<S::PublicKey>, <S::Checked<'a> as CheckedSender>::Error> {
        let Self { pool, sender, .. } = self;
        let encoded = message.encode_with_pool(pool);
        sender.send(encoded, priority).await
    }
}
/// A [`Receiver`] wrapper that decodes every incoming payload into a `V`
/// using a fixed codec configuration.
pub struct WrappedReceiver<R: Receiver, V: Codec> {
    config: V::Cfg,
    receiver: R,
}

impl<R: Receiver, V: Codec> WrappedReceiver<R, V> {
    /// Wraps `receiver`, decoding incoming payloads with `config`.
    pub const fn new(config: V::Cfg, receiver: R) -> Self {
        Self { receiver, config }
    }

    /// Waits for the next message and attempts to decode it.
    ///
    /// # Errors
    ///
    /// Errors from the wrapped receiver are propagated as the outer `Err`.
    /// Decode failures are not errors at this level: they are reported in the
    /// inner `Result` of the returned [`WrappedMessage`], alongside the sender's key.
    pub async fn recv(&mut self) -> Result<WrappedMessage<R::PublicKey, V>, R::Error> {
        let (sender_key, payload) = self.receiver.recv().await?;
        let outcome = V::decode_cfg(payload.as_ref(), &self.config);
        Ok((sender_key, outcome))
    }
}
pub struct WrappedBackgroundReceiver<E, P, B, R, V>
where
E: Spawner,
P: PublicKey,
B: Blocker<PublicKey = P>,
R: Receiver<PublicKey = P>,
V: Codec + Send,
{
context: ContextCell<E>,
receiver: R,
codec_config: V::Cfg,
blocker: B,
sender: mpsc::Sender<(P, V)>,
max_concurrency: usize,
}
impl<E, P, B, R, V> WrappedBackgroundReceiver<E, P, B, R, V>
where
E: Spawner,
P: PublicKey,
B: Blocker<PublicKey = P>,
R: Receiver<PublicKey = P>,
V: Codec + Send + 'static,
{
pub fn new(
context: E,
receiver: R,
codec_config: V::Cfg,
blocker: B,
channel_capacity: usize,
strategy: &impl Strategy,
) -> (Self, mpsc::Receiver<(P, V)>) {
let (tx, rx) = mpsc::channel(channel_capacity);
(
Self {
context: ContextCell::new(context),
receiver,
codec_config,
blocker,
sender: tx,
max_concurrency: strategy.parallelism_hint().max(1),
},
rx,
)
}
pub fn start(mut self) -> Handle<()> {
spawn_cell!(self.context, self.run().await)
}
async fn run(mut self) {
let mut decode_pool = Pool::default();
let mut receiver_closed = false;
select_loop! {
self.context,
on_start => {
let mut saw_error = false;
while decode_pool.len() >= self.max_concurrency
|| (receiver_closed && !decode_pool.is_empty())
{
let Ok(result) = decode_pool.next_completed().await else {
saw_error = true;
break;
};
Self::handle_decode_result(&mut self.blocker, &mut self.sender, result).await;
}
if saw_error || (receiver_closed && decode_pool.is_empty()) {
break;
}
},
on_stopped => {},
Ok(result) = decode_pool.next_completed() else break => {
Self::handle_decode_result(&mut self.blocker, &mut self.sender, result).await;
},
Ok((peer, bytes)) = self.receiver.recv() else {
receiver_closed = true;
continue;
} => {
let config = self.codec_config.clone();
let handle = self.context.clone().shared(true).spawn(|_| async move {
let result = V::decode_cfg(bytes.as_ref(), &config);
(peer, result)
});
decode_pool.push(handle);
},
}
}
async fn handle_decode_result(
blocker: &mut B,
sender: &mut mpsc::Sender<(P, V)>,
result: (P, Result<V, commonware_codec::Error>),
) {
let (peer, decode_result) = result;
match decode_result {
Ok(value) => {
sender.send_lossy((peer, value)).await;
}
Err(err) => {
crate::block!(blocker, peer, ?err, "received invalid message");
}
}
}
}
// Tests for `WrappedBackgroundReceiver`. Most run over the simulated p2p
// network; the last one feeds an in-memory `MockReceiver` directly.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        simulated::{self, Link, Network, Oracle},
        Manager as _, Recipients,
    };
    use commonware_codec::Encode;
    use commonware_cryptography::{
        ed25519::{PrivateKey, PublicKey},
        Signer,
    };
    use commonware_macros::test_traced;
    use commonware_parallel::{Sequential, Strategy};
    use commonware_runtime::{deterministic, IoBuf, Metrics, Quota, Runner};
    use commonware_utils::{ordered::Set, NZUsize};
    use std::{io, num::NonZeroU32, time::Duration};

    /// A perfect link: no latency, no jitter, every message delivered.
    const LINK: Link = Link {
        latency: Duration::from_millis(0),
        jitter: Duration::from_millis(0),
        success_rate: 1.0,
    };

    /// The largest per-second quota representable, so rate limiting should not
    /// interfere with the tests.
    const TEST_QUOTA: Quota = Quota::per_second(NonZeroU32::MAX);

    /// Starts a simulated network and returns its oracle. The network allows
    /// 1 MiB messages, disconnects peers when they are blocked, and tracks a
    /// single peer set.
    fn start_network(context: deterministic::Context) -> Oracle<PublicKey, deterministic::Context> {
        let (network, oracle) = Network::new(
            context.with_label("network"),
            simulated::Config {
                max_size: 1024 * 1024,
                disconnect_on_block: true,
                tracked_peer_sets: NZUsize!(1),
            },
        );
        network.start();
        oracle
    }

    /// Deterministically derives an ed25519 public key from `seed`.
    fn pk(seed: u64) -> PublicKey {
        PrivateKey::from_seed(seed).public_key()
    }

    /// Registers the de-duplicated `peers` as peer set number `index` through the
    /// oracle's manager.
    async fn track_peers<I>(
        oracle: &Oracle<PublicKey, deterministic::Context>,
        index: u64,
        peers: I,
    ) where
        I: IntoIterator<Item = PublicKey>,
    {
        oracle
            .manager()
            .track(index, Set::from_iter_dedup(peers))
            .await;
    }

    /// Adds `LINK` in both directions between `a` and `b`, panicking if either
    /// link cannot be added.
    async fn link_bidirectional(
        oracle: &mut Oracle<PublicKey, deterministic::Context>,
        a: PublicKey,
        b: PublicKey,
    ) {
        oracle.add_link(a.clone(), b.clone(), LINK).await.unwrap();
        oracle.add_link(b, a, LINK).await.unwrap();
    }

    /// A strategy that runs everything sequentially but reports a configurable
    /// `parallelism_hint`. It lets tests raise `WrappedBackgroundReceiver`'s
    /// concurrency bound above 1.
    #[derive(Clone, Copy, Debug)]
    struct HintStrategy(usize);

    impl Strategy for HintStrategy {
        fn fold_init<I, INIT, T, R, ID, F, RD>(
            &self,
            iter: I,
            init: INIT,
            identity: ID,
            fold_op: F,
            _reduce_op: RD,
        ) -> R
        where
            I: IntoIterator<IntoIter: Send, Item: Send> + Send,
            INIT: Fn() -> T + Send + Sync,
            T: Send,
            R: Send,
            ID: Fn() -> R + Send + Sync,
            F: Fn(R, &mut T, I::Item) -> R + Send + Sync,
            RD: Fn(R, R) -> R + Send + Sync,
        {
            // A single sequential fold, so `_reduce_op` is never needed.
            let mut init_val = init();
            iter.into_iter()
                .fold(identity(), |acc, item| fold_op(acc, &mut init_val, item))
        }

        fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
        where
            A: FnOnce() -> RA + Send,
            B: FnOnce() -> RB + Send,
            RA: Send,
            RB: Send,
        {
            (a(), b())
        }

        fn parallelism_hint(&self) -> usize {
            self.0
        }
    }

    /// A `Receiver` backed by an unbounded channel. Once every sender is dropped
    /// and the queue is empty, `recv` returns a `BrokenPipe` error.
    #[derive(Debug)]
    struct MockReceiver<P: commonware_cryptography::PublicKey> {
        receiver: mpsc::UnboundedReceiver<crate::Message<P>>,
    }

    impl<P: commonware_cryptography::PublicKey> crate::Receiver for MockReceiver<P> {
        type Error = io::Error;
        type PublicKey = P;

        async fn recv(&mut self) -> Result<crate::Message<Self::PublicKey>, Self::Error> {
            self.receiver
                .recv()
                .await
                .ok_or_else(|| io::Error::from(io::ErrorKind::BrokenPipe))
        }
    }

    /// A `Blocker` that ignores every block request.
    #[derive(Clone, Default)]
    struct NoopBlocker;

    impl crate::Blocker for NoopBlocker {
        type PublicKey = PublicKey;

        async fn block(&mut self, _peer: Self::PublicKey) {}
    }

    // A single valid `u32` sent from pk1 to pk2 is decoded and forwarded,
    // tagged with pk1's key.
    #[test_traced]
    fn test_valid_messages_forwarded() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());
            let pk1 = pk(0);
            let pk2 = pk(1);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone()]).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.with_label("bg"),
                receiver2,
                (),
                control2.clone(),
                16,
                &Sequential,
            );
            let _handle = bg.start();
            let msg: u32 = 42;
            let _ = sender1
                .send(Recipients::One(pk2.clone()), msg.encode(), true)
                .await;
            let (from, value) = rx.recv().await.unwrap();
            assert_eq!(from, pk1);
            assert_eq!(value, 42u32);
        });
    }

    // An undecodable payload from pk1 leads pk2 to block pk1. A later valid
    // message from a different peer (pk3) is still forwarded.
    #[test_traced]
    fn test_invalid_codec_blocks_peer() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());
            let pk1 = pk(0);
            let pk2 = pk(1);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone()]).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.with_label("bg"),
                receiver2,
                (),
                control2.clone(),
                16,
                &Sequential,
            );
            let _handle = bg.start();
            // A single 0xFF byte, presumably too short to decode as a `u32`.
            let invalid = IoBuf::from(vec![0xFFu8]);
            let _ = sender1
                .send(Recipients::One(pk2.clone()), invalid, true)
                .await;
            // A second peer set adds pk3 as a fresh, well-behaved sender. Its
            // message arriving presumably also means pk1's earlier invalid message
            // has been handled: `Sequential` is assumed to hint a parallelism of 1,
            // so decodes are processed one at a time. TODO confirm.
            let pk3 = pk(2);
            let control3 = oracle.control(pk3.clone());
            track_peers(&oracle, 1, [pk2.clone(), pk3.clone()]).await;
            link_bidirectional(&mut oracle, pk3.clone(), pk2.clone()).await;
            let (mut sender3, _) = control3.register(0, TEST_QUOTA).await.unwrap();
            let msg: u32 = 99;
            let _ = sender3
                .send(Recipients::One(pk2.clone()), msg.encode(), true)
                .await;
            let (from, value) = rx.recv().await.unwrap();
            assert_eq!(from, pk3);
            assert_eq!(value, 99u32);
            // Entries are (blocker, blocked) pairs, judging by the assertion message.
            let blocked = oracle.blocked().await.unwrap();
            assert!(
                blocked.contains(&(pk2.clone(), pk1.clone())),
                "expected pk1 to be blocked by pk2, blocked list: {:?}",
                blocked
            );
        });
    }

    // All 20 valid messages from one peer arrive. Delivery order is not
    // asserted: values are sorted before comparison.
    #[test_traced]
    fn test_multiple_valid_messages() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());
            let pk1 = pk(0);
            let pk2 = pk(1);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone()]).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.with_label("bg"),
                receiver2,
                (),
                control2.clone(),
                16,
                &Sequential,
            );
            let _handle = bg.start();
            let count = 20;
            for i in 0..count {
                let msg: u32 = i;
                let _ = sender1
                    .send(Recipients::One(pk2.clone()), msg.encode(), true)
                    .await;
            }
            let mut received = Vec::new();
            for _ in 0..count {
                let (from, value) = rx.recv().await.unwrap();
                assert_eq!(from, pk1);
                received.push(value);
            }
            received.sort();
            assert_eq!(received, (0..count).collect::<Vec<u32>>());
        });
    }

    // Sends 50 messages with the `Sequential` strategy and checks that all of
    // them are delivered.
    // NOTE(review): despite the name, this test does not observe the number of
    // concurrent decodes; it only checks that every message arrives.
    #[test_traced]
    fn test_concurrency_bounded_by_strategy() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());
            let pk1 = pk(0);
            let pk2 = pk(1);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone()]).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.with_label("bg"),
                receiver2,
                (),
                control2.clone(),
                16,
                &Sequential,
            );
            let _handle = bg.start();
            let count = 50u32;
            for i in 0..count {
                let _ = sender1
                    .send(Recipients::One(pk2.clone()), i.encode(), true)
                    .await;
            }
            let mut received = Vec::new();
            for _ in 0..count {
                let (from, value) = rx.recv().await.unwrap();
                assert_eq!(from, pk1);
                received.push(value);
            }
            received.sort();
            assert_eq!(received, (0..count).collect::<Vec<u32>>());
        });
    }

    // An invalid message from pk1 is interleaved between two valid ones from
    // pk3. Both valid values are delivered, and only pk1 ends up blocked.
    #[test_traced]
    fn test_invalid_among_valid_only_blocks_offender() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut oracle = start_network(context.clone());
            let pk1 = pk(0);
            let pk2 = pk(1);
            let pk3 = pk(2);
            let control1 = oracle.control(pk1.clone());
            let control2 = oracle.control(pk2.clone());
            let control3 = oracle.control(pk3.clone());
            track_peers(&oracle, 0, [pk1.clone(), pk2.clone(), pk3.clone()]).await;
            link_bidirectional(&mut oracle, pk1.clone(), pk2.clone()).await;
            link_bidirectional(&mut oracle, pk3.clone(), pk2.clone()).await;
            let (mut sender1, _) = control1.register(0, TEST_QUOTA).await.unwrap();
            let (_, receiver2) = control2.register(0, TEST_QUOTA).await.unwrap();
            let (mut sender3, _) = control3.register(0, TEST_QUOTA).await.unwrap();
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.with_label("bg"),
                receiver2,
                (),
                control2.clone(),
                16,
                &Sequential,
            );
            let _handle = bg.start();
            let _ = sender3
                .send(Recipients::One(pk2.clone()), 10u32.encode(), true)
                .await;
            let _ = sender1
                .send(Recipients::One(pk2.clone()), IoBuf::from(vec![0xFF]), true)
                .await;
            let _ = sender3
                .send(Recipients::One(pk2.clone()), 20u32.encode(), true)
                .await;
            let mut values = Vec::new();
            for _ in 0..2 {
                let (from, value) = rx.recv().await.unwrap();
                assert_eq!(from, pk3);
                values.push(value);
            }
            values.sort();
            assert_eq!(values, vec![10u32, 20]);
            let blocked = oracle.blocked().await.unwrap();
            assert!(blocked.contains(&(pk2.clone(), pk1.clone())));
            assert!(!blocked.contains(&(pk2.clone(), pk3.clone())));
        });
    }

    // Queues 64 messages in a `MockReceiver` and closes it before the
    // background receiver starts, allowing up to 8 concurrent decodes.
    // After the receiver reports closure, every pending decode must still be
    // forwarded. The output channel must then close, which ends the `while let`.
    // The channel capacity equals `count`, so forwarding never waits for space.
    #[test_traced]
    fn test_drain_decode_pool_after_receiver_closure() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let sender = pk(0);
            let (tx, receiver) = mpsc::unbounded_channel();
            let count = 64u32;
            for i in 0..count {
                tx.send((sender.clone(), IoBuf::from(i.encode())))
                    .expect("mock receiver should be open");
            }
            // Dropping the only sender makes `MockReceiver::recv` fail once the
            // queued messages are consumed.
            drop(tx);
            let (bg, mut rx) = WrappedBackgroundReceiver::<_, _, _, _, u32>::new(
                context.with_label("bg"),
                MockReceiver { receiver },
                (),
                NoopBlocker,
                count as usize,
                &HintStrategy(8),
            );
            let _handle = bg.start();
            let mut values = Vec::new();
            while let Some((from, value)) = rx.recv().await {
                assert_eq!(from, sender);
                values.push(value);
            }
            values.sort_unstable();
            assert_eq!(values, (0..count).collect::<Vec<u32>>());
        });
    }
}