use super::{Config, Error, Message};
use crate::authenticated::{
data::EncodedData,
discovery::{
actors::tracker,
channels::Channels,
metrics,
types::{self, InfoVerifier},
},
mailbox::UnboundedMailbox,
relay::{recv_prioritized, Prioritized, Relay},
Mailbox,
};
use commonware_codec::Decode;
use commonware_cryptography::PublicKey;
use commonware_macros::{select, select_loop};
use commonware_runtime::{
iobuf::EncodeExt, BufferPooler, Clock, Handle, IoBufs, Metrics, Quota, RateLimiter, Sink,
Spawner, Stream,
};
use commonware_stream::encrypted::{Receiver, Sender};
use commonware_utils::{
channel::mpsc::{self, error::TrySendError},
time::SYSTEM_TIME_PRECISION,
};
use prometheus_client::metrics::{counter::Counter, family::Family};
use rand_core::CryptoRngCore;
use std::{collections::HashMap, sync::Arc, time::Duration};
use tracing::debug;
/// Per-connection actor that owns the queues, decode limits and metrics for one remote peer.
///
/// Outbound traffic is read from a control queue plus high- and low-priority data queues;
/// inbound traffic is decoded with the limits stored here and accounted in the metric families.
pub struct Actor<E: Spawner + BufferPooler + Clock + Metrics, C: PublicKey> {
    /// Runtime handle used for spawning, labelling, timing and buffer-pool access.
    context: E,

    // ---- Queues ----
    /// Sending half of the control queue; cloned and handed to the tracker so it can reach us.
    mailbox: Mailbox<Message<C>>,
    /// Receiving half of the control queue (bit vectors, peer lists, kill requests).
    control: mpsc::Receiver<Message<C>>,
    /// High-priority outbound application data.
    high: mpsc::Receiver<EncodedData>,
    /// Low-priority outbound application data.
    low: mpsc::Receiver<EncodedData>,

    // ---- Timing and limits ----
    /// Interval between `tracker.construct` requests issued by the sender task; half of it is
    /// also the minimum spacing enforced on inbound bit vector and peer-list messages.
    gossip_bit_vec_frequency: Duration,
    /// Maximum number of payloads coalesced into one `send_many` call.
    send_batch_size: usize,
    /// Decode limit passed as `PayloadConfig::max_bit_vec`.
    max_bit_vec: u64,
    /// Decode limit passed as `PayloadConfig::max_peers`.
    max_peers: usize,
    /// Validator applied to every `Info` record received from the peer.
    info_verifier: InfoVerifier<C>,

    // ---- Metrics ----
    /// Messages written to the connection, labelled by kind.
    sent_messages: Family<metrics::Message, Counter>,
    /// Messages read from the connection, labelled by kind (or `invalid`).
    received_messages: Family<metrics::Message, Counter>,
    /// Inbound data dropped because the application channel was full.
    dropped_messages: Family<metrics::Message, Counter>,
    /// Inbound messages that had to wait on a rate limiter.
    rate_limited: Family<metrics::Message, Counter>,
}
impl<E: Spawner + BufferPooler + Clock + CryptoRngCore + Metrics, C: PublicKey> Actor<E, C> {
pub fn new(context: E, cfg: Config<C>) -> (Self, Relay<EncodedData>) {
let (control_sender, control_receiver) = Mailbox::new(cfg.mailbox_size);
let (high_sender, high_receiver) = mpsc::channel(cfg.mailbox_size);
let (low_sender, low_receiver) = mpsc::channel(cfg.mailbox_size);
(
Self {
context,
mailbox: control_sender,
gossip_bit_vec_frequency: cfg.gossip_bit_vec_frequency,
send_batch_size: cfg.send_batch_size.get(),
info_verifier: cfg.info_verifier,
max_bit_vec: cfg.max_peer_set_size,
max_peers: cfg.peer_gossip_max_count,
control: control_receiver,
high: high_receiver,
low: low_receiver,
sent_messages: cfg.sent_messages,
received_messages: cfg.received_messages,
dropped_messages: cfg.dropped_messages,
rate_limited: cfg.rate_limited,
},
Relay::new(low_sender, high_sender),
)
}
fn prepare_control(
peer: &C,
msg: Message<C>,
pool: &commonware_runtime::BufferPool,
) -> Result<(metrics::Message, IoBufs), Error> {
let (metric, payload) = match msg {
Message::BitVec(bit_vec) => (
metrics::Message::new_bit_vec(peer),
types::Payload::BitVec(bit_vec),
),
Message::Peers(peers) => (
metrics::Message::new_peers(peer),
types::Payload::Peers(peers),
),
Message::Kill => return Err(Error::PeerKilled(peer.to_string())),
};
Ok((metric, payload.encode_with_pool(pool)))
}
fn prepare_data<V>(
peer: &C,
msg: EncodedData,
rate_limits: &HashMap<u64, V>,
) -> (metrics::Message, IoBufs) {
let encoded = msg.validate_channel(rate_limits);
(
metrics::Message::new_data(peer, encoded.channel),
encoded.payload,
)
}
fn push_batched(
sent_messages: &Family<metrics::Message, Counter>,
batch: &mut Vec<IoBufs>,
metric: metrics::Message,
payload: IoBufs,
) {
sent_messages.get_or_create(&metric).inc();
batch.push(payload);
}
#[allow(clippy::too_many_arguments)]
fn extend_send_many<V>(
peer: &C,
batch_size: usize,
batch: &mut Vec<IoBufs>,
control: &mut mpsc::Receiver<Message<C>>,
pool: &commonware_runtime::BufferPool,
high: &mut mpsc::Receiver<EncodedData>,
low: &mut mpsc::Receiver<EncodedData>,
rate_limits: &HashMap<u64, V>,
sent_messages: &Family<metrics::Message, Counter>,
) -> Result<(), Error> {
while batch.len() < batch_size {
if let Ok(msg) = control.try_recv() {
let (metric, payload) = Self::prepare_control(peer, msg, pool)?;
Self::push_batched(sent_messages, batch, metric, payload);
continue;
}
if let Ok(msg) = high.try_recv() {
let (metric, payload) = Self::prepare_data(peer, msg, rate_limits);
Self::push_batched(sent_messages, batch, metric, payload);
continue;
}
if let Ok(msg) = low.try_recv() {
let (metric, payload) = Self::prepare_data(peer, msg, rate_limits);
Self::push_batched(sent_messages, batch, metric, payload);
continue;
}
break;
}
Ok(())
}
pub async fn run<O: Sink, I: Stream>(
self,
peer: C,
greeting: types::Info<C>,
(mut conn_sender, mut conn_receiver): (Sender<O>, Receiver<I>),
mut tracker: UnboundedMailbox<tracker::Message<C>>,
channels: Channels<C>,
) -> Result<(), Error> {
let mut rate_limits = HashMap::new();
let mut senders = HashMap::new();
for (channel, (rate, sender)) in channels.collect() {
let rate_limiter = RateLimiter::direct_with_clock(rate, self.context.clone());
rate_limits.insert(channel, rate_limiter);
senders.insert(channel, sender);
}
let rate_limits = Arc::new(rate_limits);
let pool = self.context.network_buffer_pool().clone();
self.sent_messages
.get_or_create(&metrics::Message::new_greeting(&peer))
.inc();
conn_sender
.send(types::Payload::Greeting(greeting).encode_with_pool(&pool))
.await
.map_err(Error::SendFailed)?;
let mut send_handler: Handle<Result<(), Error>> =
self.context.with_label("sender").spawn({
let peer = peer.clone();
let mut tracker = tracker.clone();
let mailbox = self.mailbox.clone();
let rate_limits = rate_limits.clone();
move |context| async move {
let mut deadline = context.current();
let mut batch = Vec::with_capacity(self.send_batch_size);
let (control, high, low) = &mut (self.control, self.high, self.low);
select_loop! {
context,
on_stopped => {},
_ = context.sleep_until(deadline) => {
tracker.construct(peer.clone(), mailbox.clone());
deadline = context.current() + self.gossip_bit_vec_frequency;
},
msg = recv_prioritized(control, high, low) => {
let (metric, payload) = match msg {
Prioritized::Closed => return Err(Error::PeerDisconnected),
Prioritized::Control(msg) => {
Self::prepare_control(&peer, msg, &pool)?
}
Prioritized::Data(encoded) => {
Self::prepare_data(&peer, encoded, &rate_limits)
}
};
Self::push_batched(&self.sent_messages, &mut batch, metric, payload);
Self::extend_send_many(
&peer,
self.send_batch_size,
&mut batch,
control,
&pool,
high,
low,
&rate_limits,
&self.sent_messages,
)?;
conn_sender
.send_many(batch.drain(..))
.await
.map_err(Error::SendFailed)?;
},
}
Ok(())
}
});
let mut receive_handler: Handle<Result<(), Error>> = self
.context
.with_label("receiver")
.spawn(move |context| async move {
let half = (self.gossip_bit_vec_frequency / 2).max(SYSTEM_TIME_PRECISION);
let rate = Quota::with_period(half).unwrap();
let bit_vec_rate_limiter =
RateLimiter::direct_with_clock(rate, context.clone());
let peers_rate_limiter =
RateLimiter::direct_with_clock(rate, context.clone());
let mut greeting_received = false;
let mut first_bit_vec_received = false;
let mut first_peers_received = false;
loop {
let msg = conn_receiver.recv().await.map_err(Error::ReceiveFailed)?;
let cfg = types::PayloadConfig {
max_bit_vec: self.max_bit_vec,
max_peers: self.max_peers,
max_data_length: msg.len(), };
let msg = match types::Payload::decode_cfg(msg, &cfg) {
Ok(msg) => msg,
Err(err) => {
debug!(?err, ?peer, "failed to decode message");
self.received_messages
.get_or_create(&metrics::Message::new_invalid(&peer))
.inc();
return Err(Error::DecodeFailed(err));
}
};
if let types::Payload::Greeting(info) = msg {
self.received_messages
.get_or_create(&metrics::Message::new_greeting(&peer))
.inc();
if greeting_received {
debug!(?peer, "received duplicate greeting");
return Err(Error::DuplicateGreeting);
}
greeting_received = true;
if info.public_key != peer {
debug!(?peer, greeting_pk = ?info.public_key, "greeting public key mismatch");
return Err(Error::GreetingMismatch);
}
self.info_verifier.validate(&context, std::slice::from_ref(&info)).map_err(Error::Types)?;
tracker.peers(vec![info]);
continue;
} else if !greeting_received {
debug!(?peer, "expected greeting as first message");
return Err(Error::MissingGreeting);
}
let (metric, rate_limiter) = match &msg {
types::Payload::Data(data) => match rate_limits.get(&data.channel) {
Some(rate_limit) => {
(metrics::Message::new_data(&peer, data.channel), Some(rate_limit))
}
None => {
debug!(?peer, channel = data.channel, "invalid channel");
self.received_messages
.get_or_create(&metrics::Message::new_invalid(&peer))
.inc();
return Err(Error::InvalidChannel);
}
},
types::Payload::Greeting(_) => unreachable!(),
types::Payload::BitVec(_) => {
let rate_limiter = if first_bit_vec_received {
Some(&bit_vec_rate_limiter)
} else {
first_bit_vec_received = true;
None
};
(metrics::Message::new_bit_vec(&peer), rate_limiter)
}
types::Payload::Peers(_) => {
let rate_limiter = if first_peers_received {
Some(&peers_rate_limiter)
} else {
first_peers_received = true;
None
};
(metrics::Message::new_peers(&peer), rate_limiter)
}
};
self.received_messages.get_or_create(&metric).inc();
if let Some(rate_limiter) = rate_limiter {
if let Err(wait_until) = rate_limiter.check() {
self.rate_limited.get_or_create(&metric).inc();
let wait_duration = wait_until.wait_time_from(context.now());
context.sleep(wait_duration).await;
}
}
match msg {
types::Payload::Data(data) => {
let sender = senders.get_mut(&data.channel).unwrap();
if let Err(e) = sender.try_send((peer.clone(), data.message)) {
if matches!(e, TrySendError::Full(_)) {
self.dropped_messages
.get_or_create(&metrics::Message::new_data(&peer, data.channel))
.inc();
}
debug!(err=?e, channel=data.channel, "failed to send message to client");
}
}
types::Payload::Greeting(_) => unreachable!(),
types::Payload::BitVec(bit_vec) => {
tracker.bit_vec(bit_vec, self.mailbox.clone());
}
types::Payload::Peers(peers) => {
self.info_verifier.validate(&context, &peers).map_err(Error::Types)?;
tracker.peers(peers);
}
}
}
});
let mut shutdown = self.context.stopped();
let result = select! {
_ = &mut shutdown => {
debug!("context shutdown, stopping peer");
Ok(Ok(()))
},
send_result = &mut send_handler => send_result,
receive_result = &mut receive_handler => receive_result,
};
match result {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(e),
Err(e) => Err(Error::UnexpectedFailure(e)),
}
}
}
// Unit tests for the peer actor: greeting protocol enforcement, inbound buffer overflow
// accounting, and metric-cardinality protection for unknown channels. Each test performs a
// real encrypted handshake over in-memory mock channels on the deterministic runtime.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::authenticated::{
        discovery::{
            actors::{router, tracker},
            channels::Channels,
        },
        mailbox::UnboundedMailbox,
        Mailbox,
    };
    use commonware_codec::Encode;
    use commonware_cryptography::{
        ed25519::{PrivateKey, PublicKey},
        Signer,
    };
    use commonware_runtime::{deterministic, mocks, BufferPooler, IoBuf, Runner, Spawner};
    use commonware_stream::encrypted::Config as StreamConfig;
    use commonware_utils::{bitmap::BitMap, NZUsize, SystemTimeExt};
    use prometheus_client::metrics::{counter::Counter, family::Family};
    use std::{
        net::{IpAddr, Ipv4Addr, SocketAddr},
        time::Duration,
    };

    const STREAM_NAMESPACE: &[u8] = b"test_peer_actor";
    const IP_NAMESPACE: &[u8] = b"test_peer_actor_IP";
    const MAX_MESSAGE_SIZE: u32 = 64 * 1024;

    /// Baseline actor config for tests; `me` is the identity handed to the `Info` verifier.
    /// Every metric family is fresh, so tests that inspect metrics override the relevant field.
    fn default_peer_config(me: PublicKey) -> Config<PublicKey> {
        Config {
            mailbox_size: 10,
            send_batch_size: NZUsize!(8),
            gossip_bit_vec_frequency: Duration::from_secs(30),
            max_peer_set_size: 128,
            peer_gossip_max_count: 10,
            info_verifier: types::Info::verifier(
                me,
                10,
                Duration::from_secs(60),
                IP_NAMESPACE.to_vec(),
            ),
            sent_messages: Family::<metrics::Message, Counter>::default(),
            received_messages: Family::<metrics::Message, Counter>::default(),
            dropped_messages: Family::<metrics::Message, Counter>::default(),
            rate_limited: Family::<metrics::Message, Counter>::default(),
        }
    }

    /// Encrypted-stream config shared by both sides of the test handshake.
    fn stream_config<S: Signer>(key: S) -> StreamConfig<S> {
        StreamConfig {
            signing_key: key,
            namespace: STREAM_NAMESPACE.to_vec(),
            max_message_size: MAX_MESSAGE_SIZE,
            synchrony_bound: Duration::from_secs(10),
            max_handshake_age: Duration::from_secs(10),
            handshake_timeout: Duration::from_secs(10),
        }
    }

    /// Builds an empty `Channels` registry backed by a router mailbox whose receiver is dropped.
    fn create_channels(context: &impl BufferPooler) -> Channels<PublicKey> {
        let (router_mailbox, _router_receiver) = Mailbox::<router::Message<PublicKey>>::new(10);
        let messenger =
            router::Messenger::new(context.network_buffer_pool().clone(), router_mailbox);
        Channels::new(messenger, MAX_MESSAGE_SIZE)
    }

    /// A peer whose first message is not a greeting must be rejected with `MissingGreeting`.
    #[test]
    fn test_missing_greeting_returns_error() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let local_key = PrivateKey::from_seed(1);
            let remote_key = PrivateKey::from_seed(2);
            let local_pk = local_key.public_key();
            let remote_pk = remote_key.public_key();
            // Two one-way mock pipes form a full-duplex link between "local" and "remote".
            let (local_sink, remote_stream) = mocks::Channel::init();
            let (remote_sink, local_stream) = mocks::Channel::init();
            let local_config = stream_config(local_key.clone());
            let remote_config = stream_config(remote_key.clone());
            let local_pk_clone = local_pk.clone();
            // The remote side listens in a separate task while the local side dials.
            let listener_handle = context.clone().spawn({
                move |ctx| async move {
                    commonware_stream::encrypted::listen(
                        ctx,
                        |_| async { true },
                        remote_config,
                        remote_stream,
                        remote_sink,
                    )
                    .await
                    .map(|(pk, sender, receiver)| {
                        assert_eq!(pk, local_pk_clone);
                        (sender, receiver)
                    })
                }
            });
            let (mut local_sender, _local_receiver) = commonware_stream::encrypted::dial(
                context.clone(),
                local_config,
                remote_pk.clone(),
                local_stream,
                local_sink,
            )
            .await
            .expect("dial failed");
            let (remote_sender, remote_receiver) = listener_handle
                .await
                .expect("listen failed")
                .expect("listen result failed");
            // The actor runs on the remote end and treats `local_pk` as its peer.
            let (peer_actor, _messenger) = Actor::<deterministic::Context, PublicKey>::new(
                context.clone(),
                default_peer_config(remote_pk),
            );
            let greeting = types::Info::sign(
                &local_key,
                IP_NAMESPACE,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                context.current().epoch().as_millis() as u64,
            );
            let (tracker_mailbox, _tracker_receiver) =
                UnboundedMailbox::<tracker::Message<PublicKey>>::new();
            let channels = create_channels(&context);
            // Queue a bit vector as the very first message, skipping the greeting.
            let bit_vec = types::Payload::<PublicKey>::BitVec(types::BitVec {
                index: 0,
                bits: BitMap::ones(10),
            });
            local_sender
                .send(bit_vec.encode())
                .await
                .expect("send failed");
            let result = peer_actor
                .run(
                    local_pk,
                    greeting,
                    (remote_sender, remote_receiver),
                    tracker_mailbox,
                    channels,
                )
                .await;
            assert!(
                matches!(result, Err(Error::MissingGreeting)),
                "Expected MissingGreeting error, got: {result:?}"
            );
        });
    }

    /// A second greeting on the same connection must be rejected with `DuplicateGreeting`.
    #[test]
    fn test_duplicate_greeting_returns_error() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let local_key = PrivateKey::from_seed(1);
            let remote_key = PrivateKey::from_seed(2);
            let local_pk = local_key.public_key();
            let remote_pk = remote_key.public_key();
            let (local_sink, remote_stream) = mocks::Channel::init();
            let (remote_sink, local_stream) = mocks::Channel::init();
            let local_config = stream_config(local_key.clone());
            let remote_config = stream_config(remote_key.clone());
            let local_pk_clone = local_pk.clone();
            let listener_handle = context.clone().spawn({
                move |ctx| async move {
                    commonware_stream::encrypted::listen(
                        ctx,
                        |_| async { true },
                        remote_config,
                        remote_stream,
                        remote_sink,
                    )
                    .await
                    .map(|(pk, sender, receiver)| {
                        assert_eq!(pk, local_pk_clone);
                        (sender, receiver)
                    })
                }
            });
            let (mut local_sender, _local_receiver) = commonware_stream::encrypted::dial(
                context.clone(),
                local_config,
                remote_pk.clone(),
                local_stream,
                local_sink,
            )
            .await
            .expect("dial failed");
            let (remote_sender, remote_receiver) = listener_handle
                .await
                .expect("listen failed")
                .expect("listen result failed");
            let (peer_actor, _messenger) = Actor::<deterministic::Context, PublicKey>::new(
                context.clone(),
                default_peer_config(remote_pk),
            );
            let greeting = types::Info::sign(
                &local_key,
                IP_NAMESPACE,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                context.current().epoch().as_millis() as u64,
            );
            let (tracker_mailbox, _tracker_receiver) =
                UnboundedMailbox::<tracker::Message<PublicKey>>::new();
            let channels = create_channels(&context);
            // The first greeting is valid and accepted; the identical second one must fail.
            let first_greeting = types::Payload::<PublicKey>::Greeting(greeting.clone());
            local_sender
                .send(first_greeting.encode())
                .await
                .expect("send failed");
            let second_greeting = types::Payload::<PublicKey>::Greeting(greeting.clone());
            local_sender
                .send(second_greeting.encode())
                .await
                .expect("send failed");
            let result = peer_actor
                .run(
                    local_pk,
                    greeting,
                    (remote_sender, remote_receiver),
                    tracker_mailbox,
                    channels,
                )
                .await;
            assert!(
                matches!(result, Err(Error::DuplicateGreeting)),
                "Expected DuplicateGreeting error, got: {result:?}"
            );
        });
    }

    /// A greeting whose `public_key` differs from the authenticated peer must yield
    /// `GreetingMismatch`. The actor compares keys before it validates the signature.
    #[test]
    fn test_greeting_public_key_mismatch_returns_error() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let local_key = PrivateKey::from_seed(1);
            let remote_key = PrivateKey::from_seed(2);
            let wrong_key = PrivateKey::from_seed(3);
            let local_pk = local_key.public_key();
            let remote_pk = remote_key.public_key();
            let wrong_pk = wrong_key.public_key();
            let (local_sink, remote_stream) = mocks::Channel::init();
            let (remote_sink, local_stream) = mocks::Channel::init();
            let local_config = stream_config(local_key.clone());
            let remote_config = stream_config(remote_key.clone());
            let local_pk_clone = local_pk.clone();
            let listener_handle = context.clone().spawn({
                move |ctx| async move {
                    commonware_stream::encrypted::listen(
                        ctx,
                        |_| async { true },
                        remote_config,
                        remote_stream,
                        remote_sink,
                    )
                    .await
                    .map(|(pk, sender, receiver)| {
                        assert_eq!(pk, local_pk_clone);
                        (sender, receiver)
                    })
                }
            });
            let (mut local_sender, _local_receiver) = commonware_stream::encrypted::dial(
                context.clone(),
                local_config,
                remote_pk.clone(),
                local_stream,
                local_sink,
            )
            .await
            .expect("dial failed");
            let (remote_sender, remote_receiver) = listener_handle
                .await
                .expect("listen failed")
                .expect("listen result failed");
            let (peer_actor, _messenger) = Actor::<deterministic::Context, PublicKey>::new(
                context.clone(),
                default_peer_config(remote_pk),
            );
            let greeting = types::Info::sign(
                &local_key,
                IP_NAMESPACE,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                context.current().epoch().as_millis() as u64,
            );
            let (tracker_mailbox, _tracker_receiver) =
                UnboundedMailbox::<tracker::Message<PublicKey>>::new();
            let channels = create_channels(&context);
            // Sign with the real key, then swap in a different public key.
            let mut wrong_greeting = types::Info::sign(
                &local_key,
                IP_NAMESPACE,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                context.current().epoch().as_millis() as u64,
            );
            wrong_greeting.public_key = wrong_pk;
            let greeting_payload = types::Payload::<PublicKey>::Greeting(wrong_greeting);
            local_sender
                .send(greeting_payload.encode())
                .await
                .expect("send failed");
            let result = peer_actor
                .run(
                    local_pk,
                    greeting,
                    (remote_sender, remote_receiver),
                    tracker_mailbox,
                    channels,
                )
                .await;
            assert!(
                matches!(result, Err(Error::GreetingMismatch)),
                "Expected GreetingMismatch error, got: {result:?}"
            );
        });
    }

    /// Flooding a channel whose application buffer is tiny must increment `dropped_messages`
    /// instead of blocking the receiver.
    #[test]
    fn test_dropped_messages_metric_on_full_buffer() {
        // Timed runner: the actor never exits on its own here, so the run is cut off after 10s.
        let executor = deterministic::Runner::timed(Duration::from_secs(10));
        executor.start(|context| async move {
            let local_key = PrivateKey::from_seed(1);
            let remote_key = PrivateKey::from_seed(2);
            let local_pk = local_key.public_key();
            let remote_pk = remote_key.public_key();
            let (local_sink, remote_stream) = mocks::Channel::init();
            let (remote_sink, local_stream) = mocks::Channel::init();
            let local_config = stream_config(local_key.clone());
            let remote_config = stream_config(remote_key.clone());
            let local_pk_clone = local_pk.clone();
            let listener_handle = context.clone().spawn({
                move |ctx| async move {
                    commonware_stream::encrypted::listen(
                        ctx,
                        |_| async { true },
                        remote_config,
                        remote_stream,
                        remote_sink,
                    )
                    .await
                    .map(|(pk, sender, receiver)| {
                        assert_eq!(pk, local_pk_clone);
                        (sender, receiver)
                    })
                }
            });
            let (mut local_sender, _local_receiver) = commonware_stream::encrypted::dial(
                context.clone(),
                local_config,
                remote_pk.clone(),
                local_stream,
                local_sink,
            )
            .await
            .expect("dial failed");
            let (remote_sender, remote_receiver) = listener_handle
                .await
                .expect("listen failed")
                .expect("listen result failed");
            // Keep a handle on the dropped-messages family so it can be inspected afterwards.
            let dropped_messages = Family::<metrics::Message, Counter>::default();
            let config = Config {
                mailbox_size: 10,
                send_batch_size: NZUsize!(8),
                gossip_bit_vec_frequency: Duration::from_secs(30),
                max_peer_set_size: 128,
                peer_gossip_max_count: 10,
                info_verifier: types::Info::verifier(
                    remote_pk.clone(),
                    10,
                    Duration::from_secs(60),
                    IP_NAMESPACE.to_vec(),
                ),
                sent_messages: Family::<metrics::Message, Counter>::default(),
                received_messages: Family::<metrics::Message, Counter>::default(),
                dropped_messages: dropped_messages.clone(),
                rate_limited: Family::<metrics::Message, Counter>::default(),
            };
            let (peer_actor, _messenger) =
                Actor::<deterministic::Context, PublicKey>::new(context.clone(), config);
            let greeting = types::Info::sign(
                &local_key,
                IP_NAMESPACE,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                context.current().epoch().as_millis() as u64,
            );
            let (tracker_mailbox, _tracker_receiver) =
                UnboundedMailbox::<tracker::Message<PublicKey>>::new();
            let (router_mailbox, _router_receiver) = Mailbox::<router::Message<PublicKey>>::new(10);
            let messenger =
                router::Messenger::new(context.network_buffer_pool().clone(), router_mailbox);
            let mut channels = Channels::new(messenger, MAX_MESSAGE_SIZE);
            let channel_id = 0u64;
            // `_receiver` is never read, so the channel fills up after the first message
            // (third argument presumably the client-side buffer size; here 1).
            let (_sender, _receiver) = channels.register(
                channel_id,
                Quota::per_second(std::num::NonZeroU32::new(100).unwrap()),
                1,
                context.clone(),
            );
            let local_pk_clone = local_pk.clone();
            // Remote peer: greet, then send five data messages on the registered channel.
            context.clone().spawn(move |_| async move {
                let greeting_payload = types::Payload::<PublicKey>::Greeting(types::Info::sign(
                    &local_key,
                    IP_NAMESPACE,
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                    0,
                ));
                local_sender
                    .send(greeting_payload.encode())
                    .await
                    .expect("send greeting failed");
                for i in 0..5 {
                    let data =
                        types::Payload::<PublicKey>::Data(crate::authenticated::data::Data {
                            channel: channel_id,
                            message: IoBuf::from(vec![i as u8; 100]),
                        });
                    let _ = local_sender.send(data.encode()).await;
                }
            });
            // The result is irrelevant; only the side effect on the metric is checked.
            let _ = peer_actor
                .run(
                    local_pk_clone.clone(),
                    greeting,
                    (remote_sender, remote_receiver),
                    tracker_mailbox,
                    channels,
                )
                .await;
            let metric_label = metrics::Message::new_data(&local_pk_clone, channel_id);
            let dropped_count = dropped_messages.get_or_create(&metric_label).get();
            assert!(
                dropped_count > 0,
                "Expected dropped_messages to be incremented when buffer is full, got {dropped_count}"
            );
        });
    }

    /// Data on an unregistered channel must be rejected and counted under the `invalid` label,
    /// never under a per-channel label the attacker could use to grow metric cardinality.
    #[test]
    fn test_invalid_channel_no_unbounded_metric_cardinality() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let local_key = PrivateKey::from_seed(1);
            let remote_key = PrivateKey::from_seed(2);
            let local_pk = local_key.public_key();
            let remote_pk = remote_key.public_key();
            let (local_sink, remote_stream) = mocks::Channel::init();
            let (remote_sink, local_stream) = mocks::Channel::init();
            let local_config = stream_config(local_key.clone());
            let remote_config = stream_config(remote_key.clone());
            let local_pk_clone = local_pk.clone();
            let listener_handle = context.clone().spawn({
                move |ctx| async move {
                    commonware_stream::encrypted::listen(
                        ctx,
                        |_| async { true },
                        remote_config,
                        remote_stream,
                        remote_sink,
                    )
                    .await
                    .map(|(pk, sender, receiver)| {
                        assert_eq!(pk, local_pk_clone);
                        (sender, receiver)
                    })
                }
            });
            let (mut local_sender, _local_receiver) = commonware_stream::encrypted::dial(
                context.clone(),
                local_config,
                remote_pk.clone(),
                local_stream,
                local_sink,
            )
            .await
            .expect("dial failed");
            let (remote_sender, remote_receiver) = listener_handle
                .await
                .expect("listen failed")
                .expect("listen result failed");
            let received_messages = Family::<metrics::Message, Counter>::default();
            let cfg = Config {
                received_messages: received_messages.clone(),
                ..default_peer_config(remote_pk)
            };
            let (peer_actor, _messenger) =
                Actor::<deterministic::Context, PublicKey>::new(context.clone(), cfg);
            let greeting = types::Info::sign(
                &local_key,
                IP_NAMESPACE,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                context.current().epoch().as_millis() as u64,
            );
            let (tracker_mailbox, _tracker_receiver) =
                UnboundedMailbox::<tracker::Message<PublicKey>>::new();
            // Only channel 0 is registered; the attacker targets channel 99999.
            let mut channels = create_channels(&context);
            let quota =
                commonware_runtime::Quota::per_second(std::num::NonZeroU32::new(100).unwrap());
            let (_sender, _receiver) = channels.register(0, quota, 10, context.clone());
            let local_pk_clone = local_pk.clone();
            context.clone().spawn(move |_ctx| async move {
                let greeting_payload = types::Payload::<PublicKey>::Greeting(types::Info::sign(
                    &local_key,
                    IP_NAMESPACE,
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080),
                    0,
                ));
                local_sender
                    .send(greeting_payload.encode())
                    .await
                    .expect("send greeting failed");
                let data = types::Payload::<PublicKey>::Data(crate::authenticated::data::Data {
                    channel: 99999,
                    message: IoBuf::from(b"attack"),
                });
                local_sender.send(data.encode()).await.expect("send failed");
            });
            let result = peer_actor
                .run(
                    local_pk_clone.clone(),
                    greeting,
                    (remote_sender, remote_receiver),
                    tracker_mailbox,
                    channels,
                )
                .await;
            assert!(
                matches!(result, Err(Error::InvalidChannel)),
                "Expected InvalidChannel error, got: {result:?}"
            );
            // `get_or_create` here creates the label in the test's copy if missing; a count
            // of 0 shows the actor itself never incremented it.
            let attacker_metric = metrics::Message::new_data(&local_pk_clone, 99999);
            let attacker_count = received_messages.get_or_create(&attacker_metric).get();
            assert_eq!(
                attacker_count, 0,
                "metric was created for attacker-controlled channel, unbounded cardinality bug"
            );
            let invalid_metric = metrics::Message::new_invalid(&local_pk_clone);
            let invalid_count = received_messages.get_or_create(&invalid_metric).get();
            assert_eq!(invalid_count, 1);
        });
    }
}