use super::{Config, Error, Mailbox, Message};
use crate::authenticated::{
data::EncodedData,
lookup::{
channels::{self, Channels},
metrics, types,
},
relay::{try_recv, Message as RelayMessage, Prioritized, Relay},
};
use commonware_actor::mailbox;
use commonware_codec::Decode;
use commonware_cryptography::PublicKey;
use commonware_macros::{select, select_loop};
use commonware_runtime::{
iobuf::EncodeExt, telemetry::metrics::CounterFamily, BufferPooler, Clock, Handle, IoBufs,
Metrics, Quota, RateLimiter, Sink, Spawner, Stream,
};
use commonware_stream::encrypted::{Receiver, Sender};
use commonware_utils::{channel::ring, time::SYSTEM_TIME_PRECISION};
use futures::{FutureExt as _, StreamExt as _};
use rand_core::CryptoRngCore;
use std::{collections::HashMap, sync::Arc, time::Duration};
use tracing::debug;
/// State owned by the actor that manages a single peer connection.
///
/// The actor is created by `Actor::new` and consumed by `Actor::run`.
pub struct Actor<E, C>
where
    E: Spawner + BufferPooler + Clock + Metrics,
    C: PublicKey,
{
    /// Runtime context used to derive child contexts, spawn tasks and read the clock.
    context: E,

    /// Interval between outbound pings.
    ping_frequency: Duration,
    /// Upper bound on the number of messages handed to a single batched send.
    send_batch_size: usize,

    /// Receiver for control messages (e.g. `Message::Kill`).
    control: ring::Receiver<Message>,
    /// Receiver for relay data that is drained before `low`.
    high: mailbox::UnreliableReceiver<RelayMessage<EncodedData>>,
    /// Receiver for relay data that is drained only when `high` is empty.
    low: mailbox::UnreliableReceiver<RelayMessage<EncodedData>>,

    /// Counter of messages queued for sending, labelled per message kind.
    sent_messages: CounterFamily<metrics::Message<C>>,
    /// Counter of messages received (including invalid ones), labelled per message kind.
    received_messages: CounterFamily<metrics::Message<C>>,
    /// Counter of received messages that had to wait on a rate limiter.
    rate_limited: CounterFamily<metrics::Message<C>>,

    /// Marker tying the actor to the peer key type `C`.
    _phantom: std::marker::PhantomData<C>,
}
impl<E: Spawner + BufferPooler + Clock + CryptoRngCore + Metrics, C: PublicKey> Actor<E, C> {
pub fn new(context: E, cfg: Config<C>) -> (Self, Mailbox, Relay<EncodedData>) {
let (control_sender, control_receiver) = Mailbox::new(cfg.mailbox_size);
let (relay, receivers) = Relay::new(context.child("relay"), cfg.mailbox_size);
(
Self {
context,
ping_frequency: cfg.ping_frequency,
send_batch_size: cfg.send_batch_size.get(),
control: control_receiver,
high: receivers.high,
low: receivers.low,
sent_messages: cfg.sent_messages,
received_messages: cfg.received_messages,
rate_limited: cfg.rate_limited,
_phantom: std::marker::PhantomData,
},
control_sender,
relay,
)
}
/// Turns an outbound data message into the pair needed for sending: the
/// metric label for (`peer`, channel) and the payload buffers.
///
/// The message is first passed through `EncodedData::validate_channel` with
/// the set of known channels (the keys of `rate_limits`).
// NOTE(review): presumably `validate_channel` checks that the channel is registered; confirm in `EncodedData`.
fn prepare_data<V>(
    peer: &C,
    msg: EncodedData,
    rate_limits: &HashMap<u64, V>,
) -> (metrics::Message<C>, IoBufs) {
    let validated = msg.validate_channel(rate_limits);
    let label = metrics::Message::new_data(peer, validated.channel);
    (label, validated.payload)
}
/// Counts one sent message under `metric` and appends `payload` to `batch`.
///
/// The counter is bumped when the payload is queued into the batch, not when
/// the batch is actually written to the connection.
fn push_batched(
    sent_messages: &CounterFamily<metrics::Message<C>>,
    batch: &mut Vec<IoBufs>,
    metric: metrics::Message<C>,
    payload: IoBufs,
) {
    let counter = sent_messages.get_or_create(&metric);
    counter.inc();
    batch.push(payload);
}
/// Polls the control receiver exactly once without waiting.
///
/// Returns `Some(message)` only if a control message is immediately
/// available. Returns `None` both when nothing is ready yet and when the
/// control stream has ended.
fn try_recv_control(control: &mut ring::Receiver<Message>) -> Option<Message> {
    match control.next().now_or_never() {
        Some(Some(msg)) => Some(msg),
        Some(None) | None => None,
    }
}
/// Tops up `batch` with messages that are already queued, without waiting.
///
/// Keeps pulling until `batch` holds `batch_size` entries or neither data
/// receiver yields a message. Every iteration:
///
/// 1. checks the control receiver; a `Message::Kill` aborts with
///    `Error::PeerKilled`,
/// 2. tries `high`, and only if it yields nothing tries `low`.
///
/// Each pulled message is counted in `sent_messages` and its payload is
/// appended to `batch`.
// The helper is a free-standing associated function (no `&self`) so it can be
// called from the spawned sender task, hence the long argument list.
#[allow(clippy::too_many_arguments)]
fn extend_send_many<V>(
    peer: &C,
    batch_size: usize,
    batch: &mut Vec<IoBufs>,
    control: &mut ring::Receiver<Message>,
    high: &mut mailbox::UnreliableReceiver<RelayMessage<EncodedData>>,
    low: &mut mailbox::UnreliableReceiver<RelayMessage<EncodedData>>,
    rate_limits: &HashMap<u64, V>,
    sent_messages: &CounterFamily<metrics::Message<C>>,
) -> Result<(), Error> {
    while batch.len() < batch_size {
        // Control messages take precedence over any queued data.
        if let Some(ctrl) = Self::try_recv_control(control) {
            match ctrl {
                Message::Kill => return Err(Error::PeerKilled(peer.to_string())),
            }
        }

        // `high` is always retried before `low`; stop once both are empty.
        let Some(msg) = try_recv(high).or_else(|| try_recv(low)) else {
            break;
        };
        let (metric, payload) = Self::prepare_data(peer, msg, rate_limits);
        Self::push_batched(sent_messages, batch, metric, payload);
    }
    Ok(())
}
/// Waits for the next item from the control receiver or either data receiver.
///
/// A control message is returned as `Prioritized::Control`, a data message
/// (from `high` or `low`) as `Prioritized::Data`. If the source that completes
/// first has ended, `Prioritized::Closed` is returned.
// NOTE(review): the arm order (control, high, low) presumably defines priority when
// several sources are ready at once; confirm against the `select!` macro semantics.
async fn recv_prioritized(
    control: &mut ring::Receiver<Message>,
    high: &mut mailbox::UnreliableReceiver<RelayMessage<EncodedData>>,
    low: &mut mailbox::UnreliableReceiver<RelayMessage<EncodedData>>,
) -> Prioritized<Message, EncodedData> {
    select! {
        ctrl = control.next() => {
            match ctrl {
                Some(ctrl) => Prioritized::Control(ctrl),
                None => Prioritized::Closed,
            }
        },
        data = high.recv() => {
            match data {
                Some(data) => Prioritized::Data(data.into_inner()),
                None => Prioritized::Closed,
            }
        },
        data = low.recv() => {
            match data {
                Some(data) => Prioritized::Data(data.into_inner()),
                None => Prioritized::Closed,
            }
        },
    }
}
/// Runs the peer actor over an established connection until it terminates.
///
/// Two tasks are spawned:
///
/// - a **sender** that batches outbound data (and periodic pings) and writes
///   each batch with a single `send_many` call, and
/// - a **receiver** that decodes inbound messages, applies per-channel (or
///   ping) rate limits and forwards data to the channel's registered sender.
///
/// The call returns as soon as the context is stopped or either task finishes.
///
/// # Arguments
///
/// - `peer`: public key of the remote peer, used for metric labels and logs.
/// - `(conn_sender, conn_receiver)`: the two halves of the encrypted connection.
/// - `channels`: registered channels, each providing a rate and an inbound sender.
///
/// # Errors
///
/// - `Error::PeerKilled` if a `Message::Kill` control message is received.
/// - `Error::PeerDisconnected` if a control or data mailbox is closed.
/// - `Error::SendFailed` / `Error::ReceiveFailed` on connection I/O failure.
/// - `Error::DecodeFailed` if an inbound message cannot be decoded.
/// - `Error::InvalidChannel` if an inbound data message names an unregistered channel.
/// - `Error::UnexpectedFailure` if a spawned task's handle itself resolves to an error.
pub async fn run<Si: Sink, St: Stream>(
self,
peer: C,
(mut conn_sender, mut conn_receiver): (Sender<Si>, Receiver<St>),
channels: Channels<C>,
) -> Result<(), Error> {
// Build one rate limiter and keep one inbound sender per registered channel.
// Both maps share the same key set, which the receiver task relies on below.
let mut rate_limits = HashMap::new();
let mut senders = HashMap::new();
for (channel, (rate, sender)) in channels.collect() {
let rate_limiter = RateLimiter::direct_with_clock(
rate,
self.context
.child("rate_limiter")
.with_attribute("channel", channel),
);
rate_limits.insert(channel, rate_limiter);
senders.insert(channel, sender);
}
// Shared between the sender task (channel validation) and the receiver task (rate limiting).
let rate_limits = Arc::new(rate_limits);
let pool = self.context.network_buffer_pool().clone();
// Inbound pings are limited to one per half ping period (never below the system
// time precision, which also keeps the period non-zero for the `unwrap`).
// NOTE(review): presumably the halving tolerates timing jitter in the peer's pings; confirm.
let half = (self.ping_frequency / 2).max(SYSTEM_TIME_PRECISION);
let ping_rate = Quota::with_period(half).unwrap();
let ping_rate_limiter =
RateLimiter::direct_with_clock(ping_rate, self.context.child("ping_rate_limiter"));
// Sender task: owns the control/high/low receivers and the connection's send half.
let mut send_handler: Handle<Result<(), Error>> = self.context.child("sender").spawn({
let peer = peer.clone();
let rate_limits = rate_limits.clone();
move |context| async move {
// Time at which the next ping is due.
let mut deadline = context.current() + self.ping_frequency;
// Reused across iterations; drained after every send.
let mut batch = Vec::with_capacity(self.send_batch_size);
let (control, high, low) = &mut (self.control, self.high, self.low);
select_loop! {
context,
// NOTE(review): presumably ends the loop when the context is stopped, after which the task returns Ok(()).
on_stopped => {},
_ = context.sleep_until(deadline) => {
// Ping is due: queue it first, then piggyback any data that is already waiting.
Self::push_batched(
&self.sent_messages,
&mut batch,
metrics::Message::new_ping(&peer),
types::Message::Ping.encode_with_pool(&pool),
);
Self::extend_send_many(
&peer,
self.send_batch_size,
&mut batch,
control,
high,
low,
&rate_limits,
&self.sent_messages,
)?;
conn_sender
.send_many(batch.drain(..))
.await
.map_err(Error::SendFailed)?;
// The ping deadline is only reset here, i.e. after a ping has been sent.
deadline = context.current() + self.ping_frequency;
},
msg = Self::recv_prioritized(control, high, low) => {
match msg {
// A mailbox was closed: nothing more can be sent to this peer.
Prioritized::Closed => return Err(Error::PeerDisconnected),
Prioritized::Control(msg) => match msg {
Message::Kill => return Err(Error::PeerKilled(peer.to_string())),
},
Prioritized::Data(encoded) => {
let (metric, payload) =
Self::prepare_data(&peer, encoded, &rate_limits);
Self::push_batched(
&self.sent_messages,
&mut batch,
metric,
payload,
);
}
}
// Fill the rest of the batch with whatever is already queued, then write it in one call.
Self::extend_send_many(
&peer,
self.send_batch_size,
&mut batch,
control,
high,
low,
&rate_limits,
&self.sent_messages,
)?;
conn_sender
.send_many(batch.drain(..))
.await
.map_err(Error::SendFailed)?;
},
}
Ok(())
}
});
// Receiver task: owns the connection's receive half and the per-channel inbound senders.
let mut receive_handler: Handle<Result<(), Error>> =
self.context
.child("receiver")
.spawn(move |context| async move {
loop {
let msg = conn_receiver.recv().await.map_err(Error::ReceiveFailed)?;
// The decode limit for the data payload is the length of the received frame itself.
let max_data_length = msg.len(); let msg = match types::Message::decode_cfg(msg, &max_data_length) {
Ok(msg) => msg,
Err(err) => {
debug!(?err, ?peer, "failed to decode message");
self.received_messages
.get_or_create(&metrics::Message::new_invalid(&peer))
.inc();
return Err(Error::DecodeFailed(err));
}
};
// Pick the metric label and rate limiter that apply to this message.
let (metric, rate_limiter) = match &msg {
types::Message::Data(data) => match rate_limits.get(&data.channel) {
Some(rate_limit) => {
(metrics::Message::new_data(&peer, data.channel), rate_limit)
}
None => {
// Unknown channels are counted under the single "invalid" label rather
// than a label derived from the peer-supplied channel number.
debug!(?peer, channel = data.channel, "invalid channel");
self.received_messages
.get_or_create(&metrics::Message::new_invalid(&peer))
.inc();
return Err(Error::InvalidChannel);
}
},
types::Message::Ping => {
(metrics::Message::new_ping(&peer), &ping_rate_limiter)
}
};
self.received_messages.get_or_create(&metric).inc();
// When over the limit, the message is delayed (not dropped): sleep until the
// limiter would allow it, then continue processing below.
if let Err(wait_until) = rate_limiter.check() {
self.rate_limited.get_or_create(&metric).inc();
let wait_duration = wait_until.wait_time_from(context.now());
context.sleep(wait_duration).await;
}
match msg {
types::Message::Data(data) => {
// The channel was found in `rate_limits` above, and `senders` has the same keys.
let sender = senders.get_mut(&data.channel).unwrap();
// The enqueue result is deliberately ignored.
let _ =
sender.enqueue(channels::Inbound((peer.clone(), data.message)));
}
types::Message::Ping => {
// Pings carry no payload; counting and rate limiting above is all that is done.
}
}
}
});
// Wait for shutdown or for either task to finish, whichever comes first.
let mut shutdown = self.context.stopped();
let result = select! {
_ = &mut shutdown => {
debug!("context shutdown, stopping peer");
Ok(Ok(()))
},
send_result = &mut send_handler => send_result,
receive_result = &mut receive_handler => receive_result,
};
// Outer `Result` is the task handle's outcome, inner `Result` is the task's own return value.
match result {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(e),
Err(e) => Err(Error::UnexpectedFailure(e)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::authenticated::lookup::{actors::router, channels::Channels};
use commonware_codec::Encode;
use commonware_cryptography::{
ed25519::{PrivateKey, PublicKey},
Signer,
};
use commonware_runtime::{
deterministic, mocks, telemetry::metrics::MetricsExt as _, BufferPooler,
Error as RuntimeError, IoBuf, IoBufs, Runner, Spawner, Supervisor as _,
};
use commonware_stream::encrypted::Config as StreamConfig;
use commonware_utils::NZUsize;
use std::{
num::NonZeroU32,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
// Namespace passed to the encrypted stream configuration of both test endpoints.
const STREAM_NAMESPACE: &[u8] = b"test_lookup_peer_actor";
// Maximum message size used for the encrypted streams and for the test channels.
const MAX_MESSAGE_SIZE: u32 = 64 * 1024;
/// Test sink that wraps another sink and counts how many times `send` is called.
struct CountingSink<S> {
    /// Wrapped sink that performs the actual writes.
    inner: S,
    /// Shared counter, incremented once per `send` call.
    sends: Arc<AtomicUsize>,
}

impl<S> CountingSink<S> {
    /// Wraps `inner`, recording every send in the shared `sends` counter.
    fn new(inner: S, sends: Arc<AtomicUsize>) -> Self {
        CountingSink { inner, sends }
    }
}
impl<S> commonware_runtime::Sink for CountingSink<S>
where
    S: commonware_runtime::Sink,
{
    /// Counts the call, then forwards `bufs` unchanged to the wrapped sink.
    ///
    /// The counter is incremented before the inner send is attempted, so
    /// failed sends are counted as well.
    async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), RuntimeError> {
        self.sends.fetch_add(1, Ordering::Relaxed);
        self.inner.send(bufs).await
    }
}
/// Builds the baseline peer `Config` used by the tests: mailbox size 10,
/// send batch size 8, a 30 second ping frequency, and metric families
/// registered on `context`.
///
/// Tests override individual fields with struct-update syntax.
fn default_peer_config(context: impl Metrics) -> Config<PublicKey> {
    // Register the families in the same order as the fields they populate.
    let sent_messages = context.family("sent_messages", "test sent messages");
    let received_messages = context.family("received_messages", "test received messages");
    let rate_limited = context.family("rate_limited", "test rate limited messages");

    Config {
        mailbox_size: NZUsize!(10),
        send_batch_size: NZUsize!(8),
        ping_frequency: Duration::from_secs(30),
        sent_messages,
        received_messages,
        rate_limited,
    }
}
/// Builds the encrypted-stream configuration for a test endpoint signing with `key`.
///
/// Uses the shared test namespace and maximum message size, and a 10 second
/// value for the synchrony bound, maximum handshake age and handshake timeout.
fn stream_config<S: Signer>(key: S) -> StreamConfig<S> {
    let ten_seconds = Duration::from_secs(10);
    StreamConfig {
        signing_key: key,
        namespace: Vec::from(STREAM_NAMESPACE),
        max_message_size: MAX_MESSAGE_SIZE,
        synchrony_bound: ten_seconds,
        max_handshake_age: ten_seconds,
        handshake_timeout: ten_seconds,
    }
}
/// Creates an empty `Channels` registry for tests, backed by a router
/// messenger whose mailbox has capacity 10.
///
/// The receiving half of the router mailbox is not returned; it is dropped
/// when this function returns.
fn create_channels(context: impl BufferPooler + Metrics) -> Channels<PublicKey> {
    let mailbox_context = context.child("router_mailbox");
    let (router_sender, _router_receiver) = commonware_actor::mailbox::new_unreliable::<
        router::Message<PublicKey>,
    >(mailbox_context, NZUsize!(10));

    let pool = context.network_buffer_pool().clone();
    let router_mailbox = router::Mailbox::new(router_sender);
    let messenger = router::Messenger::new(pool, router_mailbox);
    Channels::new(messenger, MAX_MESSAGE_SIZE)
}
/// A data message on an unregistered channel must terminate the peer actor
/// with `Error::InvalidChannel` and be counted under the single "invalid"
/// metric label, not under a label derived from the peer-chosen channel.
#[test]
fn test_invalid_channel_no_unbounded_metric_cardinality() {
let executor = deterministic::Runner::default();
executor.start(|context| async move {
let local_key = PrivateKey::from_seed(1);
let remote_key = PrivateKey::from_seed(2);
let local_pk = local_key.public_key();
let remote_pk = remote_key.public_key();
// Two mock channels form a bidirectional link between the "local" and "remote" endpoints.
let (local_sink, remote_stream) = mocks::Channel::init();
let (remote_sink, local_stream) = mocks::Channel::init();
let local_config = stream_config(local_key.clone());
let remote_config = stream_config(remote_key.clone());
let local_pk_clone = local_pk.clone();
// The remote side listens in a separate task so the handshake can complete while we dial.
let listener_handle = context.child("listener").spawn({
move |ctx| async move {
commonware_stream::encrypted::listen(
ctx,
|_| async { true },
remote_config,
remote_stream,
remote_sink,
)
.await
.map(|(pk, sender, receiver)| {
assert_eq!(pk, local_pk_clone);
(sender, receiver)
})
}
});
let (mut local_sender, _local_receiver) = commonware_stream::encrypted::dial(
context.child("dialer"),
local_config,
remote_pk.clone(),
local_stream,
local_sink,
)
.await
.expect("dial failed");
let (remote_sender, remote_receiver) = listener_handle
.await
.expect("listen failed")
.expect("listen result failed");
// Keep our own handle to the received-messages family so it can be inspected afterwards.
let received_messages = context.family(
"received_messages_override",
"test received messages override",
);
let cfg = Config {
received_messages: received_messages.clone(),
..default_peer_config(context.child("config"))
};
let (peer_actor, _mailbox, _relay) =
Actor::<deterministic::Context, PublicKey>::new(context.child("actor"), cfg);
// Only channel 0 is registered.
let mut channels = create_channels(context.child("channels"));
let quota =
commonware_runtime::Quota::per_second(std::num::NonZeroU32::new(100).unwrap());
let (_sender, _receiver) = channels.register(0, quota, 10, context.child("channel"));
// Send a data message on a channel that was never registered.
let invalid_channel = 99999;
let msg = types::Message::Data(crate::authenticated::data::Data {
channel: invalid_channel,
message: commonware_runtime::IoBuf::from(b"attack"),
});
local_sender.send(msg.encode()).await.expect("send failed");
// The actor runs on the remote halves and treats `local_pk` as its peer.
let result = peer_actor
.run(local_pk.clone(), (remote_sender, remote_receiver), channels)
.await;
assert!(
matches!(result, Err(Error::InvalidChannel)),
"Expected InvalidChannel error, got: {result:?}"
);
// The label for the peer-chosen channel must never have been incremented.
let attacker_metric = metrics::Message::new_data(&local_pk, invalid_channel);
let attacker_count = received_messages.get_or_create(&attacker_metric).get();
assert_eq!(
attacker_count, 0,
"metric was created for attacker-controlled channel, unbounded cardinality bug"
);
// The message must have been counted exactly once under the "invalid" label instead.
let invalid_metric = metrics::Message::new_invalid(&local_pk);
let invalid_count = received_messages.get_or_create(&invalid_metric).get();
assert_eq!(
invalid_count, 1,
"invalid channel metric should be incremented"
);
});
}
/// Two data messages queued before the actor starts must reach the peer in
/// order and be written to the underlying sink with a single `send` call.
#[test]
fn test_batches_outbound_sends_into_single_runtime_write() {
let executor = deterministic::Runner::default();
executor.start(|context| async move {
let local_key = PrivateKey::from_seed(1);
let remote_key = PrivateKey::from_seed(2);
let local_pk = local_key.public_key();
let remote_pk = remote_key.public_key();
// Two mock channels form a bidirectional link between the "local" and "remote" endpoints.
let (local_sink, remote_stream) = mocks::Channel::init();
let (remote_sink, local_stream) = mocks::Channel::init();
// Counts writes made through the remote endpoint's sink (the one the actor will send on).
let sends = Arc::new(AtomicUsize::new(0));
let local_config = stream_config(local_key.clone());
let remote_config = stream_config(remote_key.clone());
let local_pk_clone = local_pk.clone();
// The remote side listens in a separate task so the handshake can complete while we dial.
let listener_handle = context.child("listener").spawn({
let sends = sends.clone();
move |ctx| async move {
commonware_stream::encrypted::listen(
ctx,
|_| async { true },
remote_config,
remote_stream,
CountingSink::new(remote_sink, sends),
)
.await
.map(|(pk, sender, receiver)| {
assert_eq!(pk, local_pk_clone);
(sender, receiver)
})
}
});
let (_local_sender, mut local_receiver) = commonware_stream::encrypted::dial(
context.child("dialer"),
local_config,
remote_pk.clone(),
local_stream,
local_sink,
)
.await
.expect("dial failed");
let (remote_sender, remote_receiver) = listener_handle
.await
.expect("listen failed")
.expect("listen result failed");
// Reset the counter after the handshake so only the actor's writes are counted.
sends.store(0, Ordering::Relaxed);
// A batch size of 2 lets both queued messages fit into one batch.
let cfg = Config {
send_batch_size: NZUsize!(2),
..default_peer_config(context.child("config"))
};
let (peer_actor, peer_mailbox, relay) =
Actor::<deterministic::Context, PublicKey>::new(context.child("actor"), cfg);
let mut channels = create_channels(context.child("channels"));
let quota = commonware_runtime::Quota::per_second(NonZeroU32::new(100).unwrap());
let (_sender, _receiver) = channels.register(0, quota, 10, context.child("channel"));
let pool = context.network_buffer_pool().clone();
// Queue both messages on channel 0 before the actor starts running.
// NOTE(review): the `false` argument is presumably the priority flag; confirm against `Relay::send`.
assert!(
relay
.send(
types::Message::encode_data(&pool, 0, IoBufs::from(IoBuf::from(b"first"))),
false,
)
.accepted(),
"first send failed"
);
assert!(
relay
.send(
types::Message::encode_data(&pool, 0, IoBufs::from(IoBuf::from(b"second"))),
false,
)
.accepted(),
"second send failed"
);
// Run the actor in the background so this task can read from the other end.
let peer_handle = context.child("task").spawn(move |_context| async move {
peer_actor
.run(local_pk.clone(), (remote_sender, remote_receiver), channels)
.await
});
// Both messages must arrive, in the order they were queued.
let first = local_receiver.recv().await.expect("recv failed");
let first_len = first.len();
let first = types::Message::decode_cfg(first, &first_len).expect("decode failed");
let types::Message::Data(first) = first else {
panic!("expected data message");
};
assert_eq!(first.message, IoBuf::from(b"first"));
let second = local_receiver.recv().await.expect("recv failed");
let second_len = second.len();
let second = types::Message::decode_cfg(second, &second_len).expect("decode failed");
let types::Message::Data(second) = second else {
panic!("expected data message");
};
assert_eq!(second.message, IoBuf::from(b"second"));
// Both messages were delivered with exactly one write to the underlying sink.
assert_eq!(sends.load(Ordering::Relaxed), 1);
// Killing the peer through its mailbox must end the actor with `PeerKilled`.
peer_mailbox.kill();
let result = peer_handle.await.expect("peer task failed");
assert!(
matches!(result, Err(Error::PeerKilled(_))),
"unexpected result: {result:?}"
);
});
}
}