#[cfg(not(feature = "iouring-network"))]
use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
#[cfg(feature = "iouring-storage")]
use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
#[cfg(not(feature = "iouring-storage"))]
use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
#[cfg(feature = "iouring-network")]
use crate::{
iouring,
network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
};
use crate::{
network::metered::Network as MeteredNetwork, process::metered::Metrics as MeteredProcess,
signal::Signal, storage::metered::Storage as MeteredStorage, telemetry::metrics::task::Label,
utils::signal::Stopper, Clock, Error, Handle, SinkOf, StreamOf, METRICS_PREFIX,
};
use commonware_macros::select;
use futures::future::AbortHandle;
use governor::clock::{Clock as GClock, ReasonablyRealtime};
use prometheus_client::{
encoding::text::encode,
metrics::{counter::Counter, family::Family, gauge::Gauge},
registry::{Metric, Registry},
};
use rand::{rngs::OsRng, CryptoRng, RngCore};
use std::{
env,
future::Future,
net::SocketAddr,
path::PathBuf,
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use tokio::runtime::{Builder, Runtime};
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1024;
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Option<Duration> = Some(Duration::from_millis(100));
#[derive(Debug)]
struct Metrics {
tasks_spawned: Family<Label, Counter>,
tasks_running: Family<Label, Gauge>,
}
impl Metrics {
pub fn init(registry: &mut Registry) -> Self {
let metrics = Self {
tasks_spawned: Family::default(),
tasks_running: Family::default(),
};
registry.register(
"tasks_spawned",
"Total number of tasks spawned",
metrics.tasks_spawned.clone(),
);
registry.register(
"tasks_running",
"Number of tasks currently running",
metrics.tasks_running.clone(),
);
metrics
}
}
#[derive(Clone, Debug)]
pub struct NetworkConfig {
tcp_nodelay: Option<bool>,
read_write_timeout: Duration,
}
impl Default for NetworkConfig {
fn default() -> Self {
Self {
tcp_nodelay: None,
read_write_timeout: Duration::from_secs(60),
}
}
}
#[derive(Clone)]
pub struct Config {
worker_threads: usize,
max_blocking_threads: usize,
catch_panics: bool,
storage_directory: PathBuf,
maximum_buffer_size: usize,
network_cfg: NetworkConfig,
}
impl Config {
pub fn new() -> Self {
let rng = OsRng.next_u64();
let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
Self {
worker_threads: 2,
max_blocking_threads: 512,
catch_panics: true,
storage_directory,
maximum_buffer_size: 2 * 1024 * 1024, network_cfg: NetworkConfig::default(),
}
}
pub fn with_worker_threads(mut self, n: usize) -> Self {
self.worker_threads = n;
self
}
pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
self.max_blocking_threads = n;
self
}
pub fn with_catch_panics(mut self, b: bool) -> Self {
self.catch_panics = b;
self
}
pub fn with_read_write_timeout(mut self, d: Duration) -> Self {
self.network_cfg.read_write_timeout = d;
self
}
pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
self.network_cfg.tcp_nodelay = n;
self
}
pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
self.storage_directory = p.into();
self
}
pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
self.maximum_buffer_size = n;
self
}
pub fn worker_threads(&self) -> usize {
self.worker_threads
}
pub fn max_blocking_threads(&self) -> usize {
self.max_blocking_threads
}
pub fn catch_panics(&self) -> bool {
self.catch_panics
}
pub fn read_write_timeout(&self) -> Duration {
self.network_cfg.read_write_timeout
}
pub fn tcp_nodelay(&self) -> Option<bool> {
self.network_cfg.tcp_nodelay
}
pub fn storage_directory(&self) -> &PathBuf {
&self.storage_directory
}
pub fn maximum_buffer_size(&self) -> usize {
self.maximum_buffer_size
}
}
impl Default for Config {
fn default() -> Self {
Self::new()
}
}
pub struct Executor {
cfg: Config,
registry: Mutex<Registry>,
metrics: Arc<Metrics>,
runtime: Runtime,
shutdown: Mutex<Stopper>,
}
pub struct Runner {
cfg: Config,
}
impl Default for Runner {
fn default() -> Self {
Self::new(Config::default())
}
}
impl Runner {
pub fn new(cfg: Config) -> Self {
Self { cfg }
}
}
impl crate::Runner for Runner {
type Context = Context;
fn start<F, Fut>(self, f: F) -> Fut::Output
where
F: FnOnce(Self::Context) -> Fut,
Fut: Future,
{
let mut registry = Registry::default();
let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
let metrics = Arc::new(Metrics::init(runtime_registry));
let runtime = Builder::new_multi_thread()
.worker_threads(self.cfg.worker_threads)
.max_blocking_threads(self.cfg.max_blocking_threads)
.enable_all()
.build()
.expect("failed to create Tokio runtime");
let process = MeteredProcess::init(runtime_registry);
runtime.spawn(process.collect(tokio::time::sleep));
cfg_if::cfg_if! {
if #[cfg(feature = "iouring-storage")] {
let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
let storage = MeteredStorage::new(
IoUringStorage::start(IoUringConfig {
storage_directory: self.cfg.storage_directory.clone(),
iouring_config: Default::default(),
}, iouring_registry),
runtime_registry,
);
} else {
let storage = MeteredStorage::new(
TokioStorage::new(TokioStorageConfig::new(
self.cfg.storage_directory.clone(),
self.cfg.maximum_buffer_size,
)),
runtime_registry,
);
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "iouring-network")] {
let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
let config = IoUringNetworkConfig {
tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
iouring_config: iouring::Config {
size: IOURING_NETWORK_SIZE,
op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
force_poll: IOURING_NETWORK_FORCE_POLL,
shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
..Default::default()
},
};
let network = MeteredNetwork::new(
IoUringNetwork::start(config, iouring_registry).unwrap(),
runtime_registry,
);
} else {
let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
.with_write_timeout(self.cfg.network_cfg.read_write_timeout)
.with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
let network = MeteredNetwork::new(
TokioNetwork::from(config),
runtime_registry,
);
}
}
let executor = Arc::new(Executor {
cfg: self.cfg,
registry: Mutex::new(registry),
metrics,
runtime,
shutdown: Mutex::new(Stopper::default()),
});
let label = Label::root();
executor.metrics.tasks_spawned.get_or_create(&label).inc();
let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
let context = Context {
storage,
name: label.name(),
spawned: false,
executor: executor.clone(),
network,
children: Arc::new(Mutex::new(Vec::new())),
};
let output = executor.runtime.block_on(f(context));
gauge.dec();
output
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "iouring-storage")] {
type Storage = MeteredStorage<IoUringStorage>;
} else {
type Storage = MeteredStorage<TokioStorage>;
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "iouring-network")] {
type Network = MeteredNetwork<IoUringNetwork>;
} else {
type Network = MeteredNetwork<TokioNetwork>;
}
}
pub struct Context {
name: String,
spawned: bool,
executor: Arc<Executor>,
storage: Storage,
network: Network,
children: Arc<Mutex<Vec<AbortHandle>>>,
}
impl Clone for Context {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
spawned: false,
executor: self.executor.clone(),
storage: self.storage.clone(),
network: self.network.clone(),
children: self.children.clone(),
}
}
}
impl crate::Spawner for Context {
fn spawn<F, Fut, T>(mut self, f: F) -> Handle<T>
where
F: FnOnce(Self) -> Fut + Send + 'static,
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
assert!(!self.spawned, "already spawned");
let (_, gauge) = spawn_metrics!(self, future);
let catch_panics = self.executor.cfg.catch_panics;
let executor = self.executor.clone();
let children = Arc::new(Mutex::new(Vec::new()));
self.children = children.clone();
let future = f(self);
let (f, handle) = Handle::init_future(future, gauge, catch_panics, children);
executor.runtime.spawn(f);
handle
}
fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
assert!(!self.spawned, "already spawned");
self.spawned = true;
let (_, gauge) = spawn_metrics!(self, future);
let executor = self.executor.clone();
move |f: F| {
let (f, handle) = Handle::init_future(
f,
gauge,
executor.cfg.catch_panics,
Arc::new(Mutex::new(Vec::new())),
);
executor.runtime.spawn(f);
handle
}
}
fn spawn_child<F, Fut, T>(self, f: F) -> Handle<T>
where
F: FnOnce(Self) -> Fut + Send + 'static,
Fut: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let parent_children = self.children.clone();
let child_handle = self.spawn(f);
if let Some(abort_handle) = child_handle.abort_handle() {
parent_children.lock().unwrap().push(abort_handle);
}
child_handle
}
fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
where
F: FnOnce(Self) -> T + Send + 'static,
T: Send + 'static,
{
assert!(!self.spawned, "already spawned");
let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
let executor = self.executor.clone();
let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);
if dedicated {
std::thread::spawn(f);
} else {
executor.runtime.spawn_blocking(f);
}
handle
}
fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
assert!(!self.spawned, "already spawned");
self.spawned = true;
let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
let executor = self.executor.clone();
move |f: F| {
let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);
if dedicated {
std::thread::spawn(f);
} else {
executor.runtime.spawn_blocking(f);
}
handle
}
}
async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
let stop_resolved = {
let mut shutdown = self.executor.shutdown.lock().unwrap();
shutdown.stop(value)
};
let timeout_future = match timeout {
Some(duration) => futures::future::Either::Left(self.sleep(duration)),
None => futures::future::Either::Right(futures::future::pending()),
};
select! {
result = stop_resolved => {
result.map_err(|_| Error::Closed)?;
Ok(())
},
_ = timeout_future => {
Err(Error::Timeout)
}
}
}
fn stopped(&self) -> Signal {
self.executor.shutdown.lock().unwrap().stopped()
}
}
impl crate::Metrics for Context {
fn with_label(&self, label: &str) -> Self {
let name = {
let prefix = self.name.clone();
if prefix.is_empty() {
label.to_string()
} else {
format!("{prefix}_{label}")
}
};
assert!(
!name.starts_with(METRICS_PREFIX),
"using runtime label is not allowed"
);
Self {
name,
spawned: false,
executor: self.executor.clone(),
storage: self.storage.clone(),
network: self.network.clone(),
children: self.children.clone(),
}
}
fn label(&self) -> String {
self.name.clone()
}
fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
let name = name.into();
let prefixed_name = {
let prefix = &self.name;
if prefix.is_empty() {
name
} else {
format!("{}_{}", *prefix, name)
}
};
self.executor
.registry
.lock()
.unwrap()
.register(prefixed_name, help, metric)
}
fn encode(&self) -> String {
let mut buffer = String::new();
encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
buffer
}
}
impl Clock for Context {
fn current(&self) -> SystemTime {
SystemTime::now()
}
fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
tokio::time::sleep(duration)
}
fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
let now = SystemTime::now();
let duration_until_deadline = match deadline.duration_since(now) {
Ok(duration) => duration,
Err(_) => Duration::from_secs(0), };
let target_instant = tokio::time::Instant::now() + duration_until_deadline;
tokio::time::sleep_until(target_instant)
}
}
impl GClock for Context {
type Instant = SystemTime;
fn now(&self) -> Self::Instant {
self.current()
}
}
impl ReasonablyRealtime for Context {}
impl crate::Network for Context {
type Listener = <Network as crate::Network>::Listener;
async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
self.network.bind(socket).await
}
async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
self.network.dial(socket).await
}
}
impl RngCore for Context {
fn next_u32(&mut self) -> u32 {
OsRng.next_u32()
}
fn next_u64(&mut self) -> u64 {
OsRng.next_u64()
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
OsRng.fill_bytes(dest);
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
OsRng.try_fill_bytes(dest)
}
}
impl CryptoRng for Context {}
impl crate::Storage for Context {
type Blob = <Storage as crate::Storage>::Blob;
async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
self.storage.open(partition, name).await
}
async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
self.storage.remove(partition, name).await
}
async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
self.storage.scan(partition).await
}
}