use crate::{utils::Signaler, Clock, Error, Handle, Signal};
use commonware_utils::{from_hex, hex};
use governor::clock::{Clock as GClock, ReasonablyRealtime};
use prometheus_client::{
encoding::EncodeLabelSet,
metrics::{counter::Counter, family::Family, gauge::Gauge},
registry::Registry,
};
use rand::{rngs::OsRng, CryptoRng, RngCore};
use std::{
env,
future::Future,
io::SeekFrom,
net::SocketAddr,
path::PathBuf,
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use tokio::{
fs,
io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
    net::{
        tcp::{OwnedReadHalf, OwnedWriteHalf},
        TcpListener, TcpStream,
    },
runtime::{Builder, Runtime},
sync::Mutex as AsyncMutex,
task_local,
time::timeout,
};
use tracing::warn;
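
/// Prometheus label set identifying which task a metric refers to.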
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct Work {
label: String,
}
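
/// Telemetry for tasks, networking, and storage, registered against the
/// configured Prometheus [`Registry`].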
#[derive(Debug)]
struct Metrics {
tasks_spawned: Family<Work, Counter>,
tasks_running: Family<Work, Gauge>,
inbound_connections: Counter,
outbound_connections: Counter,
inbound_bandwidth: Counter,
outbound_bandwidth: Counter,
open_blobs: Gauge,
storage_reads: Counter,
storage_read_bytes: Counter,
storage_writes: Counter,
storage_write_bytes: Counter,
}
impl Metrics {
pub fn init(registry: Arc<Mutex<Registry>>) -> Self {
let metrics = Self {
tasks_spawned: Family::default(),
tasks_running: Family::default(),
inbound_connections: Counter::default(),
outbound_connections: Counter::default(),
inbound_bandwidth: Counter::default(),
outbound_bandwidth: Counter::default(),
open_blobs: Gauge::default(),
storage_reads: Counter::default(),
storage_read_bytes: Counter::default(),
storage_writes: Counter::default(),
storage_write_bytes: Counter::default(),
};
{
let mut registry = registry.lock().unwrap();
registry.register(
"tasks_spawned",
"Total number of tasks spawned",
metrics.tasks_spawned.clone(),
);
registry.register(
"tasks_running",
"Number of tasks currently running",
metrics.tasks_running.clone(),
);
registry.register(
"inbound_connections",
"Number of connections created by dialing us",
metrics.inbound_connections.clone(),
);
registry.register(
"outbound_connections",
"Number of connections created by dialing others",
metrics.outbound_connections.clone(),
);
registry.register(
"inbound_bandwidth",
"Bandwidth used by receiving data from others",
metrics.inbound_bandwidth.clone(),
);
registry.register(
"outbound_bandwidth",
"Bandwidth used by sending data to others",
metrics.outbound_bandwidth.clone(),
);
registry.register(
"open_blobs",
"Number of open blobs",
metrics.open_blobs.clone(),
);
registry.register(
"storage_reads",
"Total number of disk reads",
metrics.storage_reads.clone(),
);
registry.register(
"storage_read_bytes",
"Total amount of data read from disk",
metrics.storage_read_bytes.clone(),
);
registry.register(
"storage_writes",
"Total number of disk writes",
metrics.storage_writes.clone(),
);
registry.register(
"storage_write_bytes",
"Total amount of data written to disk",
metrics.storage_write_bytes.clone(),
);
}
metrics
}
}
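
/// Configuration for the tokio-based runtime.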
#[derive(Clone)]
pub struct Config {
    /// Prometheus registry where all runtime metrics are registered.
    pub registry: Arc<Mutex<Registry>>,
    /// Number of worker threads for the tokio runtime.
    pub threads: usize,
    /// Whether spawned tasks catch panics (instead of aborting the process).
    pub catch_panics: bool,
    /// Maximum duration a single network read may take before timing out.
    pub read_timeout: Duration,
    /// Maximum duration a single network write may take before timing out.
    pub write_timeout: Duration,
    /// If set, applies `TCP_NODELAY` with the given value to all connections.
    pub tcp_nodelay: Option<bool>,
    /// Base directory under which partitions and blobs are stored.
    pub storage_directory: PathBuf,
    /// Maximum buffer size (in bytes) used for file reads and writes.
    pub maximum_buffer_size: usize,
}
impl Default for Config {
fn default() -> Self {
        // Random suffix so concurrent runtimes don't collide on a storage directory.
        let suffix = OsRng.next_u64();
        let storage_directory =
            env::temp_dir().join(format!("commonware_tokio_runtime_{}", suffix));
Self {
registry: Arc::new(Mutex::new(Registry::default())),
threads: 2,
catch_panics: true,
read_timeout: Duration::from_secs(60),
write_timeout: Duration::from_secs(30),
tcp_nodelay: None,
storage_directory,
            maximum_buffer_size: 2 * 1024 * 1024,
        }
}
}
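
/// Shared state behind the [`Runner`] and every [`Context`] clone:
/// configuration, metrics, the tokio runtime, a filesystem lock, and the
/// stop signaler.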
pub struct Executor {
cfg: Config,
metrics: Arc<Metrics>,
runtime: Runtime,
    /// Serializes filesystem metadata operations (`open`, `remove`, `scan`).
    fs: AsyncMutex<()>,
signaler: Mutex<Signaler>,
signal: Signal,
}
impl Executor {
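    /// Creates a [`Runner`] (to start the root future) and a [`Context`]
    /// (to spawn tasks, open connections, and access storage) that share the
    /// same underlying executor.
    ///
    /// Minimal usage sketch, assuming the crate's `Runner` trait is in scope:
    ///
    /// ```ignore
    /// let (runner, context) = Executor::default();
    /// runner.start(async move {
    ///     // use `context` to spawn tasks, dial peers, open blobs, etc.
    /// });
    /// ```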
pub fn init(cfg: Config) -> (Runner, Context) {
let metrics = Arc::new(Metrics::init(cfg.registry.clone()));
let runtime = Builder::new_multi_thread()
.worker_threads(cfg.threads)
.enable_all()
.build()
.expect("failed to create Tokio runtime");
let (signaler, signal) = Signaler::new();
let executor = Arc::new(Self {
cfg,
metrics,
runtime,
fs: AsyncMutex::new(()),
signaler: Mutex::new(signaler),
signal,
});
(
Runner {
executor: executor.clone(),
},
Context { executor },
)
}
#[allow(clippy::should_implement_trait)]
pub fn default() -> (Runner, Context) {
Self::init(Config::default())
}
}
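
/// Drives the root future of an application to completion on the tokio
/// runtime.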
pub struct Runner {
executor: Arc<Executor>,
}
impl crate::Runner for Runner {
fn start<F>(self, f: F) -> F::Output
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.executor.runtime.block_on(f)
}
}
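
/// A cloneable handle to the executor that tasks use to spawn work, read the
/// clock, open network connections, access storage, and source randomness.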
#[derive(Clone)]
pub struct Context {
executor: Arc<Executor>,
}
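
// Task-local label prefix used to namespace child tasks under their parent
// (spawning "child" inside "parent" yields the label "parent_child").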
task_local! {
static PREFIX: String;
}
impl crate::Spawner for Context {
fn spawn<F, T>(&self, label: &str, f: F) -> Handle<T>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
        // Namespace the label under the parent task's prefix (if any) and
        // propagate the combined label to any children via the task-local.
        let label = PREFIX
            .try_with(|prefix| format!("{}_{}", prefix, label))
            .unwrap_or_else(|_| label.to_string());
        let f = PREFIX.scope(label.clone(), f);
        let work = Work { label };
self.executor
.metrics
.tasks_spawned
.get_or_create(&work)
.inc();
let gauge = self
.executor
.metrics
.tasks_running
.get_or_create(&work)
.clone();
let (f, handle) = Handle::init(f, gauge, self.executor.cfg.catch_panics);
self.executor.runtime.spawn(f);
handle
}
fn stop(&self, value: i32) {
self.executor.signaler.lock().unwrap().signal(value);
}
fn stopped(&self) -> Signal {
self.executor.signal.clone()
}
}
impl Clock for Context {
fn current(&self) -> SystemTime {
SystemTime::now()
}
fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
tokio::time::sleep(duration)
}
fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
        // Deadlines in the past resolve immediately rather than erroring.
        let now = SystemTime::now();
        let duration_until_deadline = deadline.duration_since(now).unwrap_or(Duration::ZERO);
let target_instant = tokio::time::Instant::now() + duration_until_deadline;
tokio::time::sleep_until(target_instant)
}
}
impl GClock for Context {
type Instant = SystemTime;
fn now(&self) -> Self::Instant {
self.current()
}
}
impl ReasonablyRealtime for Context {}
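
// TCP networking with per-connection metrics and optional `TCP_NODELAY`.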
impl crate::Network<Listener, Sink, Stream> for Context {
async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
TcpListener::bind(socket)
.await
.map_err(|_| Error::BindFailed)
.map(|listener| Listener {
context: self.clone(),
listener,
})
}
async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
let stream = TcpStream::connect(socket)
.await
.map_err(|_| Error::ConnectionFailed)?;
self.executor.metrics.outbound_connections.inc();
if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let context = self.clone();
let (stream, sink) = stream.into_split();
Ok((
Sink {
context: context.clone(),
sink,
},
Stream { context, stream },
))
}
}
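
/// A bound TCP listener that instruments each accepted connection.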
pub struct Listener {
context: Context,
listener: TcpListener,
}
impl crate::Listener<Sink, Stream> for Listener {
async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
self.context.executor.metrics.inbound_connections.inc();
if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
if let Err(err) = stream.set_nodelay(tcp_nodelay) {
warn!(?err, "failed to set TCP_NODELAY");
}
}
let context = self.context.clone();
let (stream, sink) = stream.into_split();
Ok((
addr,
Sink {
context: context.clone(),
sink,
},
Stream { context, stream },
))
}
}
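
/// The write half of a TCP connection; writes are subject to `write_timeout`
/// and counted against `outbound_bandwidth`.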
pub struct Sink {
context: Context,
sink: OwnedWriteHalf,
}
impl crate::Sink for Sink {
async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
let len = msg.len();
timeout(
self.context.executor.cfg.write_timeout,
self.sink.write_all(msg),
)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::SendFailed)?;
self.context
.executor
.metrics
.outbound_bandwidth
.inc_by(len as u64);
Ok(())
}
}
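
/// The read half of a TCP connection; reads are subject to `read_timeout`
/// and counted against `inbound_bandwidth`.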
pub struct Stream {
context: Context,
stream: OwnedReadHalf,
}
impl crate::Stream for Stream {
async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
timeout(
self.context.executor.cfg.read_timeout,
self.stream.read_exact(buf),
)
.await
.map_err(|_| Error::Timeout)?
.map_err(|_| Error::RecvFailed)?;
self.context
.executor
.metrics
.inbound_bandwidth
.inc_by(buf.len() as u64);
Ok(())
}
}
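
// Randomness is delegated to the operating system's CSPRNG via `OsRng`.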
impl RngCore for Context {
fn next_u32(&mut self) -> u32 {
OsRng.next_u32()
}
fn next_u64(&mut self) -> u64 {
OsRng.next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
OsRng.fill_bytes(dest);
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
OsRng.try_fill_bytes(dest)
}
}
impl CryptoRng for Context {}
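
/// A file-backed blob within a `partition`, whose `name` is hex-encoded on
/// disk.
///
/// The file handle and the blob's tracked length live behind a shared async
/// mutex, so clones refer to (and serialize access to) the same file; the
/// `open_blobs` gauge counts every live clone.
///
/// Minimal usage sketch, assuming the crate's `Storage` and `Blob` traits are
/// in scope:
///
/// ```ignore
/// let blob = context.open("partition", b"name").await?;
/// blob.write_at(b"hello", 0).await?;
/// blob.sync().await?;
/// blob.close().await?;
/// ```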
pub struct Blob {
metrics: Arc<Metrics>,
partition: String,
name: Vec<u8>,
    /// Underlying file handle paired with the blob's tracked length in bytes.
    file: Arc<AsyncMutex<(fs::File, u64)>>,
}
impl Blob {
fn new(
metrics: Arc<Metrics>,
partition: String,
name: &[u8],
file: fs::File,
len: u64,
) -> Self {
metrics.open_blobs.inc();
Self {
metrics,
partition,
name: name.into(),
file: Arc::new(AsyncMutex::new((file, len))),
}
}
}
impl Clone for Blob {
fn clone(&self) -> Self {
self.metrics.open_blobs.inc();
Self {
metrics: self.metrics.clone(),
partition: self.partition.clone(),
name: self.name.clone(),
file: self.file.clone(),
}
}
}
impl crate::Storage<Blob> for Context {
async fn open(&self, partition: &str, name: &[u8]) -> Result<Blob, Error> {
let _guard = self.executor.fs.lock().await;
let path = self
.executor
.cfg
.storage_directory
.join(partition)
.join(hex(name));
let parent = match path.parent() {
Some(parent) => parent,
None => return Err(Error::PartitionCreationFailed(partition.into())),
};
fs::create_dir_all(parent)
.await
.map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
let mut file = fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&path)
.await
.map_err(|_| Error::BlobOpenFailed(partition.into(), hex(name)))?;
        // Cap the size of tokio's internal buffer for file reads and writes.
        file.set_max_buf_size(self.executor.cfg.maximum_buffer_size);
let len = file.metadata().await.map_err(|_| Error::ReadFailed)?.len();
Ok(Blob::new(
self.executor.metrics.clone(),
partition.into(),
name,
file,
len,
))
}
async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
let _guard = self.executor.fs.lock().await;
let path = self.executor.cfg.storage_directory.join(partition);
if let Some(name) = name {
let blob_path = path.join(hex(name));
fs::remove_file(blob_path)
.await
.map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
} else {
fs::remove_dir_all(path)
.await
.map_err(|_| Error::PartitionMissing(partition.into()))?;
}
Ok(())
}
async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
let _guard = self.executor.fs.lock().await;
let path = self.executor.cfg.storage_directory.join(partition);
let mut entries = fs::read_dir(path)
.await
.map_err(|_| Error::PartitionMissing(partition.into()))?;
let mut blobs = Vec::new();
while let Some(entry) = entries.next_entry().await.map_err(|_| Error::ReadFailed)? {
let file_type = entry.file_type().await.map_err(|_| Error::ReadFailed)?;
if !file_type.is_file() {
return Err(Error::PartitionCorrupt(partition.into()));
}
            // Names are hex-encoded on disk; non-UTF-8 entries are skipped.
            if let Some(name) = entry.file_name().to_str() {
let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
blobs.push(name);
}
}
Ok(blobs)
}
}
impl crate::Blob for Blob {
async fn len(&self) -> Result<u64, Error> {
let (_, len) = *self.file.lock().await;
Ok(len)
}
async fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<(), Error> {
let mut file = self.file.lock().await;
        // Use checked arithmetic so a huge offset cannot overflow the bound check.
        let end = offset
            .checked_add(buf.len() as u64)
            .ok_or(Error::BlobInsufficientLength)?;
        if end > file.1 {
            return Err(Error::BlobInsufficientLength);
        }
file.0
.seek(SeekFrom::Start(offset))
.await
.map_err(|_| Error::ReadFailed)?;
file.0
.read_exact(buf)
.await
.map_err(|_| Error::ReadFailed)?;
self.metrics.storage_reads.inc();
self.metrics.storage_read_bytes.inc_by(buf.len() as u64);
Ok(())
}
async fn write_at(&self, buf: &[u8], offset: u64) -> Result<(), Error> {
let mut file = self.file.lock().await;
file.0
.seek(SeekFrom::Start(offset))
.await
.map_err(|_| Error::WriteFailed)?;
file.0
.write_all(buf)
.await
.map_err(|_| Error::WriteFailed)?;
        // Grow the tracked length if this write extended the file.
        let max_len = offset + buf.len() as u64;
        if max_len > file.1 {
            file.1 = max_len;
        }
self.metrics.storage_writes.inc();
self.metrics.storage_write_bytes.inc_by(buf.len() as u64);
Ok(())
}
async fn truncate(&self, len: u64) -> Result<(), Error> {
let mut file = self.file.lock().await;
file.0
.set_len(len)
.await
.map_err(|_| Error::BlobTruncateFailed(self.partition.clone(), hex(&self.name)))?;
file.1 = len;
Ok(())
}
async fn sync(&self) -> Result<(), Error> {
let file = self.file.lock().await;
file.0
.sync_all()
.await
.map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))
}
async fn close(self) -> Result<(), Error> {
let mut file = self.file.lock().await;
file.0
.sync_all()
.await
.map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))?;
file.0
.shutdown()
.await
.map_err(|_| Error::BlobCloseFailed(self.partition.clone(), hex(&self.name)))
}
}
impl Drop for Blob {
fn drop(&mut self) {
self.metrics.open_blobs.dec();
}
}