use crate::{Clock, Error, Handle};
use bytes::Bytes;
use futures::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use governor::clock::{Clock as GClock, ReasonablyRealtime};
use prometheus_client::{
encoding::EncodeLabelSet,
metrics::{counter::Counter, family::Family, gauge::Gauge},
registry::Registry,
};
use rand::{rngs::OsRng, CryptoRng, RngCore};
use std::{
future::Future,
net::SocketAddr,
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use tokio::{
net::{TcpListener, TcpStream},
runtime::{Builder, Runtime},
task_local,
time::timeout,
};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::warn;
/// Label set attached to per-task metrics: the (possibly prefixed) label a task was spawned with.
#[derive(Debug, Clone, PartialEq, Eq, Hash, EncodeLabelSet)]
struct Work {
    label: String,
}

/// Prometheus metrics maintained by the tokio runtime.
#[derive(Debug)]
struct Metrics {
    /// Number of tasks spawned, keyed by task label.
    tasks_spawned: Family<Work, Counter>,
    /// Gauge per task label, handed to `Handle::init` when a task is spawned.
    tasks_running: Family<Work, Gauge>,
    /// Incremented for every connection returned by `Listener::accept`.
    inbound_connections: Counter,
    /// Incremented for every connection established by `Network::dial`.
    outbound_connections: Counter,
    /// Incremented by the length of every frame payload received by `Stream::recv`.
    inbound_bandwidth: Counter,
    /// Incremented by the length of every message successfully sent by `Sink::send`.
    outbound_bandwidth: Counter,
}

impl Metrics {
    /// Creates every runtime metric and registers a clone of each with `registry`.
    ///
    /// # Panics
    ///
    /// Panics if the `registry` mutex is poisoned.
    pub fn init(registry: Arc<Mutex<Registry>>) -> Self {
        let tasks_spawned: Family<Work, Counter> = Family::default();
        let tasks_running: Family<Work, Gauge> = Family::default();
        let inbound_connections: Counter = Counter::default();
        let outbound_connections: Counter = Counter::default();
        let inbound_bandwidth: Counter = Counter::default();
        let outbound_bandwidth: Counter = Counter::default();

        // prometheus_client metrics are shared handles: the clones registered here observe
        // every update made through the fields of the returned `Metrics`.
        let mut guard = registry.lock().unwrap();
        guard.register(
            "tasks_spawned",
            "Total number of tasks spawned",
            tasks_spawned.clone(),
        );
        guard.register(
            "tasks_running",
            "Number of tasks currently running",
            tasks_running.clone(),
        );
        guard.register(
            "inbound_connections",
            "Number of connections created by dialing us",
            inbound_connections.clone(),
        );
        guard.register(
            "outbound_connections",
            "Number of connections created by dialing others",
            outbound_connections.clone(),
        );
        guard.register(
            "inbound_bandwidth",
            "Bandwidth used by receiving data from others",
            inbound_bandwidth.clone(),
        );
        guard.register(
            "outbound_bandwidth",
            "Bandwidth used by sending data to others",
            outbound_bandwidth.clone(),
        );
        // Release the registry before handing the metrics back.
        drop(guard);

        Self {
            tasks_spawned,
            tasks_running,
            inbound_connections,
            outbound_connections,
            inbound_bandwidth,
            outbound_bandwidth,
        }
    }
}
/// Configuration for the tokio-backed [`Executor`].
#[derive(Clone)]
pub struct Config {
    /// Registry into which the runtime's task, connection, and bandwidth metrics are registered.
    pub registry: Arc<Mutex<Registry>>,
    /// Number of worker threads of the multi-threaded tokio runtime.
    pub threads: usize,
    /// Forwarded to `Handle::init` for every spawned task (presumably controls whether a task
    /// panic is caught rather than propagated — see `Handle`).
    pub catch_panics: bool,
    /// Maximum length, in bytes, of a single framed network message (passed to [`codec`]).
    pub max_message_size: usize,
    /// Time `Stream::recv` waits for a frame before failing with `Error::ReadFailed`.
    pub read_timeout: Duration,
    /// Time `Sink::send` waits for a frame to be written before failing with `Error::WriteFailed`.
    pub write_timeout: Duration,
    /// When `Some`, `TCP_NODELAY` is set to this value on every dialed and accepted connection
    /// (failures are logged, not returned); when `None`, the OS default is left untouched.
    pub tcp_nodelay: Option<bool>,
}

impl Default for Config {
    /// Two worker threads, panic catching enabled, 1 MiB messages, a 60 s read timeout,
    /// a 30 s write timeout, the OS default for `TCP_NODELAY`, and a fresh metrics registry.
    fn default() -> Self {
        Self {
            registry: Arc::new(Mutex::new(Registry::default())),
            threads: 2,
            catch_panics: true,
            max_message_size: 1024 * 1024, // 1 MiB
            read_timeout: Duration::from_secs(60),
            write_timeout: Duration::from_secs(30),
            tcp_nodelay: None,
        }
    }
}
/// Tokio-backed runtime: owns the configuration, the metrics, and the tokio [`Runtime`].
///
/// NOTE(review): every `Context` (including clones moved into spawned tasks) holds an
/// `Arc<Executor>`, so a task that keeps a `Context` alive also keeps the runtime alive, and if
/// the last reference is released on a runtime worker thread the `Runtime` would be dropped
/// from inside an async context (which tokio forbids). Confirm callers never hit either case.
pub struct Executor {
    cfg: Config,
    metrics: Arc<Metrics>,
    runtime: Runtime,
}

impl Executor {
    /// Builds an executor from `cfg`.
    ///
    /// Returns a [`Runner`] for driving the root future and a [`Context`] for spawning tasks,
    /// networking, time, and randomness. Both share the same executor.
    ///
    /// # Panics
    ///
    /// Panics if the metrics registry mutex is poisoned, if the tokio runtime cannot be built,
    /// or (per tokio's `Builder::worker_threads`) if `cfg.threads` is zero.
    pub fn init(cfg: Config) -> (Runner, Context) {
        let metrics = Arc::new(Metrics::init(Arc::clone(&cfg.registry)));

        let mut builder = Builder::new_multi_thread();
        builder.worker_threads(cfg.threads).enable_all();
        let runtime = builder.build().expect("failed to create Tokio runtime");

        let shared = Arc::new(Self {
            cfg,
            metrics,
            runtime,
        });
        let runner = Runner {
            executor: Arc::clone(&shared),
        };
        let context = Context { executor: shared };
        (runner, context)
    }

    /// Builds an executor with [`Config::default`].
    // Allowed: this returns a `(Runner, Context)` pair rather than `Self`, so it cannot be
    // expressed as an implementation of the `Default` trait.
    #[allow(clippy::should_implement_trait)]
    pub fn default() -> (Runner, Context) {
        Self::init(Config::default())
    }
}
/// Entry point that drives a root future on an [`Executor`]'s tokio runtime.
pub struct Runner {
    executor: Arc<Executor>,
}

impl crate::Runner for Runner {
    /// Blocks the calling thread on `f` using the executor's runtime and returns its output.
    fn start<F>(self, f: F) -> F::Output
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let Self { executor } = self;
        executor.runtime.block_on(f)
    }
}
/// Cloneable handle to an [`Executor`], used to spawn tasks and to access time, networking,
/// and randomness. All clones share the same executor.
#[derive(Clone)]
pub struct Context {
    executor: Arc<Executor>,
}

task_local! {
    // Label of the spawned task currently being polled; nested spawns extend it.
    static PREFIX: String;
}

impl crate::Spawner for Context {
    /// Spawns `f` on the tokio runtime under `label` and returns a [`Handle`] to it.
    ///
    /// When called from within another spawned task, the effective label is
    /// `"{parent}_{label}"`. The `tasks_spawned` counter for that label is incremented, and
    /// the matching `tasks_running` gauge is passed to `Handle::init` together with the
    /// configured `catch_panics` flag.
    fn spawn<F, T>(&self, label: &str, f: F) -> Handle<T>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Outside any spawned task `PREFIX` is unset, so the label is used as-is.
        let full_label = PREFIX
            .try_with(|parent| format!("{}_{}", parent, label))
            .unwrap_or_else(|_| label.to_owned());
        let scoped = PREFIX.scope(full_label.clone(), f);

        let metrics = &self.executor.metrics;
        let work = Work { label: full_label };
        metrics.tasks_spawned.get_or_create(&work).inc();
        let running = metrics.tasks_running.get_or_create(&work).clone();

        let (task, handle) = Handle::init(scoped, running, self.executor.cfg.catch_panics);
        self.executor.runtime.spawn(task);
        handle
    }
}
impl Clock for Context {
    /// Returns the current wall-clock time.
    fn current(&self) -> SystemTime {
        SystemTime::now()
    }

    /// Sleeps for `duration` on tokio's timer.
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
        tokio::time::sleep(duration)
    }

    /// Sleeps until the wall-clock `deadline`.
    ///
    /// The remaining wall-clock time is measured once, when this is called, and converted to a
    /// tokio `Instant`. A `deadline` that has already passed yields a zero-length sleep.
    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
        // `duration_since` errors when `deadline` is in the past; clamp that to zero.
        let remaining = deadline
            .duration_since(self.current())
            .unwrap_or(Duration::ZERO);
        tokio::time::sleep_until(tokio::time::Instant::now() + remaining)
    }
}
impl GClock for Context {
type Instant = SystemTime;
fn now(&self) -> Self::Instant {
self.current()
}
}
impl ReasonablyRealtime for Context {}
/// Builds the framing codec used on every TCP connection.
///
/// Each frame is prefixed with a 4-byte (`u32`) length field. Frames longer than
/// `max_frame_len` bytes are rejected by the codec.
pub fn codec(max_frame_len: usize) -> LengthDelimitedCodec {
    let mut builder = LengthDelimitedCodec::builder();
    builder.length_field_type::<u32>();
    builder.max_frame_length(max_frame_len);
    builder.new_codec()
}
impl crate::Network<Listener, Sink, Stream> for Context {
async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
TcpListener::bind(socket)
.await
.map_err(|_| Error::BindFailed)
.map(|listener| Listener {
context: self.clone(),
listener,
})
}
async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
let stream = TcpStream::connect(socket)
.await
.map_err(|_| Error::ConnectionFailed)?;
self.executor.metrics.outbound_connections.inc();
if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let context = self.clone();
let framed = Framed::new(stream, codec(self.executor.cfg.max_message_size));
let (sink, stream) = framed.split();
Ok((
Sink {
context: context.clone(),
sink,
},
Stream { context, stream },
))
}
}
pub struct Listener {
context: Context,
listener: TcpListener,
}
impl crate::Listener<Sink, Stream> for Listener {
async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
self.context.executor.metrics.inbound_connections.inc();
if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let framed = Framed::new(stream, codec(self.context.executor.cfg.max_message_size));
let (sink, stream) = framed.split();
let context = self.context.clone();
Ok((
addr,
Sink {
context: context.clone(),
sink,
},
Stream { context, stream },
))
}
}
pub struct Sink {
context: Context,
sink: SplitSink<Framed<TcpStream, LengthDelimitedCodec>, Bytes>,
}
impl crate::Sink for Sink {
async fn send(&mut self, msg: Bytes) -> Result<(), Error> {
let len = msg.len();
timeout(self.context.executor.cfg.write_timeout, self.sink.send(msg))
.await
.map_err(|_| Error::WriteFailed)?
.map_err(|_| Error::WriteFailed)?;
self.context
.executor
.metrics
.outbound_bandwidth
.inc_by(len as u64);
Ok(())
}
}
pub struct Stream {
context: Context,
stream: SplitStream<Framed<TcpStream, LengthDelimitedCodec>>,
}
impl crate::Stream for Stream {
async fn recv(&mut self) -> Result<Bytes, Error> {
let result = timeout(self.context.executor.cfg.read_timeout, self.stream.next())
.await
.map_err(|_| Error::ReadFailed)?
.ok_or(Error::Closed)?
.map_err(|_| Error::ReadFailed)?;
self.context
.executor
.metrics
.inbound_bandwidth
.inc_by(result.len() as u64);
Ok(result.freeze())
}
}
impl RngCore for Context {
fn next_u32(&mut self) -> u32 {
OsRng.next_u32()
}
fn next_u64(&mut self) -> u64 {
OsRng.next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
OsRng.fill_bytes(dest);
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
OsRng.try_fill_bytes(dest)
}
}
impl CryptoRng for Context {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::run_tasks;
    use crate::Runner;
    use std::io::Cursor;

    #[test]
    fn test_runs_tasks() {
        // `Executor::default` returns `(Runner, Context)`; bind them under matching names.
        let (runner, context) = Executor::default();
        run_tasks(10, runner, context);
    }

    #[test]
    fn test_codec_invalid_frame_len() {
        let (runner, _) = Executor::default();
        runner.start(async move {
            let max_frame_len = 10;
            let codec = codec(max_frame_len);
            let mut framed = Framed::new(Cursor::new(Vec::new()), codec);
            let message = vec![0; max_frame_len + 1];
            let message = Bytes::from(message);
            let result = framed.send(message).await;
            assert!(result.is_err());
        });
    }

    #[test]
    fn test_codec_valid_frame_len() {
        let (runner, _) = Executor::default();
        runner.start(async move {
            let max_frame_len = 10;
            let codec = codec(max_frame_len);
            let mut framed = Framed::new(Cursor::new(Vec::new()), codec);
            let message = vec![0; max_frame_len];
            let message = Bytes::from(message);
            let result = framed.send(message).await;
            assert!(result.is_ok());
        });
    }

    #[test]
    fn test_codec_frame_layout() {
        // The wire format is a 4-byte big-endian length prefix followed by the raw payload.
        let (runner, _) = Executor::default();
        runner.start(async move {
            let mut framed = Framed::new(Cursor::new(Vec::new()), codec(10));
            framed
                .send(Bytes::from_static(b"hello"))
                .await
                .expect("frame within limit must encode");
            let written = framed.get_ref().get_ref();
            assert_eq!(written.len(), 4 + 5);
            assert_eq!(&written[..4], &5u32.to_be_bytes());
            assert_eq!(&written[4..], b"hello");
        });
    }
}