use crate::{IoBufs, SinkOf, StreamOf};
use prometheus_client::{metrics::counter::Counter, registry::Registry};
use std::{net::SocketAddr, sync::Arc};
/// Counters shared by every metered network object created from one
/// [`Network`] wrapper.
#[derive(Debug)]
struct Metrics {
    /// Incremented once per connection returned by a metered listener's `accept`.
    inbound_connections: Counter,
    /// Incremented once per connection returned by the metered network's `dial`.
    outbound_connections: Counter,
    /// Total bytes handed back by metered streams' `recv`.
    inbound_bandwidth: Counter,
    /// Total bytes successfully passed to metered sinks' `send`.
    outbound_bandwidth: Counter,
}

impl Metrics {
    /// Builds a fresh set of zeroed counters and registers a handle to each one
    /// in `registry` (in the order listed below).
    fn new(registry: &mut Registry) -> Self {
        let metrics = Self {
            inbound_connections: Counter::default(),
            outbound_connections: Counter::default(),
            inbound_bandwidth: Counter::default(),
            outbound_bandwidth: Counter::default(),
        };

        // (metric name, help text, counter) — the registry receives a clone, which
        // shares the same underlying value as the field kept in `metrics`.
        let entries: [(&str, &str, &Counter); 4] = [
            (
                "inbound_connections",
                "Number of connections created by dialing us",
                &metrics.inbound_connections,
            ),
            (
                "outbound_connections",
                "Number of connections created by dialing others",
                &metrics.outbound_connections,
            ),
            (
                "inbound_bandwidth",
                "Bandwidth used by receiving data from others",
                &metrics.inbound_bandwidth,
            ),
            (
                "outbound_bandwidth",
                "Bandwidth used by sending data to others",
                &metrics.outbound_bandwidth,
            ),
        ];
        for (name, help, counter) in entries {
            registry.register(name, help, counter.clone());
        }

        metrics
    }
}
/// A [`crate::Sink`] wrapper that adds every successfully sent byte to the
/// shared `outbound_bandwidth` counter.
pub struct Sink<S: crate::Sink> {
    /// The wrapped sink that performs the actual I/O.
    inner: S,
    /// Counters shared with the network that created this sink.
    metrics: Arc<Metrics>,
}

impl<S: crate::Sink> crate::Sink for Sink<S> {
    /// Forwards `bufs` to the wrapped sink; the byte count is recorded only if
    /// the inner send succeeds.
    async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), crate::Error> {
        let payload: IoBufs = bufs.into();
        // Capture the size up front: `payload` is moved into the inner sink.
        let sent = payload.len() as u64;
        self.inner.send(payload).await?;
        self.metrics.outbound_bandwidth.inc_by(sent);
        Ok(())
    }
}
/// A [`crate::Stream`] wrapper that adds every received byte to the shared
/// `inbound_bandwidth` counter.
pub struct Stream<S: crate::Stream> {
    /// The wrapped stream that performs the actual I/O.
    inner: S,
    /// Counters shared with the network/listener that created this stream.
    metrics: Arc<Metrics>,
}

impl<S: crate::Stream> crate::Stream for Stream<S> {
    /// Receives from the wrapped stream and, on success, records the number of
    /// bytes actually returned.
    async fn recv(&mut self, len: usize) -> Result<IoBufs, crate::Error> {
        let bufs = self.inner.recv(len).await?;
        // Count what was actually handed back (`IoBufs::len`, the same measure
        // `Sink::send` uses) rather than the requested `len`, so the counter
        // reflects real traffic even if the two ever differ.
        self.metrics.inbound_bandwidth.inc_by(bufs.len() as u64);
        Ok(bufs)
    }

    /// Delegates directly to the wrapped stream without touching any counter.
    fn peek(&self, max_len: usize) -> &[u8] {
        self.inner.peek(max_len)
    }
}
/// A [`crate::Listener`] wrapper that counts accepted connections and hands out
/// metered sinks/streams for them.
pub struct Listener<L: crate::Listener> {
    /// The wrapped listener that performs the actual accept.
    inner: L,
    /// Counters shared with the network that bound this listener.
    metrics: Arc<Metrics>,
}

impl<L: crate::Listener> crate::Listener for Listener<L> {
    type Sink = Sink<L::Sink>;
    type Stream = Stream<L::Stream>;

    /// Accepts a connection from the wrapped listener, increments
    /// `inbound_connections` on success, and wraps both halves so their traffic
    /// is metered.
    async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), crate::Error> {
        let (peer, raw_sink, raw_stream) = self.inner.accept().await?;
        self.metrics.inbound_connections.inc();

        let sink = Sink {
            inner: raw_sink,
            metrics: Arc::clone(&self.metrics),
        };
        let stream = Stream {
            inner: raw_stream,
            metrics: Arc::clone(&self.metrics),
        };
        Ok((peer, sink, stream))
    }

    /// Reports the wrapped listener's local address unchanged.
    fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
        self.inner.local_addr()
    }
}
/// A [`crate::Network`] wrapper that records connection counts and bandwidth.
///
/// Every listener, sink, and stream produced through this wrapper shares the
/// same set of counters (held behind an [`Arc`]).
#[derive(Debug, Clone)]
pub struct Network<N: crate::Network> {
    /// The wrapped network implementation.
    inner: N,
    /// Counters shared with everything this network creates.
    metrics: Arc<Metrics>,
}

impl<N: crate::Network> Network<N> {
    /// Wraps `inner`, registering the connection and bandwidth counters in
    /// `registry`.
    pub fn new(inner: N, registry: &mut Registry) -> Self {
        let metrics = Arc::new(Metrics::new(registry));
        Self { inner, metrics }
    }
}

impl<N: crate::Network> crate::Network for Network<N> {
    type Listener = Listener<N::Listener>;

    /// Binds the wrapped network at `socket` and returns a metered listener.
    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, crate::Error> {
        self.inner.bind(socket).await.map(|listener| Listener {
            inner: listener,
            metrics: Arc::clone(&self.metrics),
        })
    }

    /// Dials `socket` through the wrapped network, increments
    /// `outbound_connections` on success, and wraps both halves so their
    /// traffic is metered.
    async fn dial(
        &self,
        socket: SocketAddr,
    ) -> Result<(SinkOf<Self>, StreamOf<Self>), crate::Error> {
        let (raw_sink, raw_stream) = self.inner.dial(socket).await?;
        self.metrics.outbound_connections.inc();

        let metered_sink = Sink {
            inner: raw_sink,
            metrics: Arc::clone(&self.metrics),
        };
        let metered_stream = Stream {
            inner: raw_stream,
            metrics: Arc::clone(&self.metrics),
        };
        Ok((metered_sink, metered_stream))
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        network::{
            deterministic::Network as DeterministicNetwork, metered::Network as MeteredNetwork,
            tests,
        },
        Listener as _, Network as _, Sink as _, Stream as _,
    };
    use commonware_macros::test_group;
    use prometheus_client::registry::Registry;
    use std::net::SocketAddr;

    /// Runs the shared network-trait conformance suite against the metered wrapper.
    #[tokio::test]
    async fn test_trait() {
        tests::test_network_trait(|| {
            MeteredNetwork::new(DeterministicNetwork::default(), &mut Registry::default())
        })
        .await;
    }

    /// Runs the shared (slow) network-trait stress suite against the metered wrapper.
    #[test_group("slow")]
    #[tokio::test]
    async fn test_stress_trait() {
        tests::stress_test_network_trait(|| {
            MeteredNetwork::new(DeterministicNetwork::default(), &mut Registry::default())
        })
        .await;
    }

    /// Echoes one message through a single connection and checks every counter.
    #[tokio::test]
    async fn test_metrics() {
        const MSG_SIZE: usize = 100;

        let mut registry = Registry::default();
        let network = MeteredNetwork::new(DeterministicNetwork::default(), &mut registry);

        let addr = SocketAddr::from(([127, 0, 0, 1], 1234));
        let mut listener = network.bind(addr).await.unwrap();

        // Server: accept one connection and echo back exactly MSG_SIZE bytes.
        let server = tokio::spawn(async move {
            let (_, mut sink, mut stream) = listener.accept().await.unwrap();
            let received = stream.recv(MSG_SIZE).await.unwrap();
            sink.send(received).await.unwrap();
        });

        let (mut client_sink, mut client_stream) = network.dial(addr).await.unwrap();
        let msg = vec![42u8; MSG_SIZE];
        client_sink.send(msg.clone()).await.unwrap();

        let response = client_stream.recv(MSG_SIZE).await.unwrap().coalesce();
        assert_eq!(response.len(), MSG_SIZE);
        assert_eq!(response, msg.as_slice());

        server.await.unwrap();

        assert_eq!(network.metrics.inbound_connections.get(), 1);
        assert_eq!(network.metrics.outbound_connections.get(), 1);
        assert_eq!(
            network.metrics.inbound_bandwidth.get(),
            2 * MSG_SIZE as u64,
            "client and server should both have received MSG_SIZE"
        );
        assert_eq!(
            network.metrics.outbound_bandwidth.get(),
            2 * MSG_SIZE as u64,
            "client and server should both have sent MSG_SIZE"
        );
    }
}