commonware_runtime/tokio/
runtime.rs

1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "iouring-network")]
8use crate::{
9    iouring,
10    network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
11};
12use crate::{
13    network::metered::Network as MeteredNetwork, process::metered::Metrics as MeteredProcess,
14    signal::Signal, storage::metered::Storage as MeteredStorage, telemetry::metrics::task::Label,
15    utils::signal::Stopper, Clock, Error, Handle, SinkOf, StreamOf, METRICS_PREFIX,
16};
17use commonware_macros::select;
18use governor::clock::{Clock as GClock, ReasonablyRealtime};
19use prometheus_client::{
20    encoding::text::encode,
21    metrics::{counter::Counter, family::Family, gauge::Gauge},
22    registry::{Metric, Registry},
23};
24use rand::{rngs::OsRng, CryptoRng, RngCore};
25use std::{
26    env,
27    future::Future,
28    net::SocketAddr,
29    path::PathBuf,
30    sync::{Arc, Mutex},
31    time::{Duration, SystemTime},
32};
33use tokio::runtime::{Builder, Runtime};
34
/// Value passed as `iouring::Config::size` for the io_uring network backend
/// (presumably the ring's entry count — confirm against `iouring::Config`).
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1 << 10;

/// Value passed as `iouring::Config::force_poll` for the io_uring network backend.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Option<Duration> = Some(Duration::from_millis(100));
39
40#[derive(Debug)]
41struct Metrics {
42    tasks_spawned: Family<Label, Counter>,
43    tasks_running: Family<Label, Gauge>,
44}
45
46impl Metrics {
47    pub fn init(registry: &mut Registry) -> Self {
48        let metrics = Self {
49            tasks_spawned: Family::default(),
50            tasks_running: Family::default(),
51        };
52        registry.register(
53            "tasks_spawned",
54            "Total number of tasks spawned",
55            metrics.tasks_spawned.clone(),
56        );
57        registry.register(
58            "tasks_running",
59            "Number of tasks currently running",
60            metrics.tasks_running.clone(),
61        );
62        metrics
63    }
64}
65
/// Socket-level settings handed to the network backend.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// When `Some(v)`, `TCP_NODELAY` is explicitly set to `v` on the socket;
    /// when `None`, the system default is left in place.
    tcp_nodelay: Option<bool>,

    /// Timeout applied to network reads and writes.
    read_write_timeout: Duration,
}

impl Default for NetworkConfig {
    /// Keeps the system default for `TCP_NODELAY` and uses a 60 second
    /// read/write timeout.
    fn default() -> Self {
        const DEFAULT_READ_WRITE_TIMEOUT: Duration = Duration::from_secs(60);
        Self {
            read_write_timeout: DEFAULT_READ_WRITE_TIMEOUT,
            tcp_nodelay: None,
        }
    }
}
84
85/// Configuration for the `tokio` runtime.
86#[derive(Clone)]
87pub struct Config {
88    /// Number of threads to use for handling async tasks.
89    ///
90    /// Worker threads are always active (waiting for work).
91    ///
92    /// Tokio sets the default value to the number of logical CPUs.
93    worker_threads: usize,
94
95    /// Maximum number of threads to use for blocking tasks.
96    ///
97    /// Unlike worker threads, blocking threads are created as needed and
98    /// exit if left idle for too long.
99    ///
100    /// Tokio sets the default value to 512 to avoid hanging on lower-level
101    /// operations that require blocking (like `fs` and writing to `Stdout`).
102    max_blocking_threads: usize,
103
104    /// Whether or not to catch panics.
105    catch_panics: bool,
106
107    /// Base directory for all storage operations.
108    storage_directory: PathBuf,
109
110    /// Maximum buffer size for operations on blobs.
111    ///
112    /// Tokio sets the default value to 2MB.
113    maximum_buffer_size: usize,
114
115    /// Network configuration.
116    network_cfg: NetworkConfig,
117}
118
119impl Config {
120    /// Returns a new [Config] with default values.
121    pub fn new() -> Self {
122        let rng = OsRng.next_u64();
123        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
124        Self {
125            worker_threads: 2,
126            max_blocking_threads: 512,
127            catch_panics: true,
128            storage_directory,
129            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
130            network_cfg: NetworkConfig::default(),
131        }
132    }
133
134    // Setters
135    /// See [Config]
136    pub fn with_worker_threads(mut self, n: usize) -> Self {
137        self.worker_threads = n;
138        self
139    }
140    /// See [Config]
141    pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
142        self.max_blocking_threads = n;
143        self
144    }
145    /// See [Config]
146    pub fn with_catch_panics(mut self, b: bool) -> Self {
147        self.catch_panics = b;
148        self
149    }
150    /// See [Config]
151    pub fn with_read_write_timeout(mut self, d: Duration) -> Self {
152        self.network_cfg.read_write_timeout = d;
153        self
154    }
155    /// See [Config]
156    pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
157        self.network_cfg.tcp_nodelay = n;
158        self
159    }
160    /// See [Config]
161    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
162        self.storage_directory = p.into();
163        self
164    }
165    /// See [Config]
166    pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
167        self.maximum_buffer_size = n;
168        self
169    }
170
171    // Getters
172    /// See [Config]
173    pub fn worker_threads(&self) -> usize {
174        self.worker_threads
175    }
176    /// See [Config]
177    pub fn max_blocking_threads(&self) -> usize {
178        self.max_blocking_threads
179    }
180    /// See [Config]
181    pub fn catch_panics(&self) -> bool {
182        self.catch_panics
183    }
184    /// See [Config]
185    pub fn read_write_timeout(&self) -> Duration {
186        self.network_cfg.read_write_timeout
187    }
188    /// See [Config]
189    pub fn tcp_nodelay(&self) -> Option<bool> {
190        self.network_cfg.tcp_nodelay
191    }
192    /// See [Config]
193    pub fn storage_directory(&self) -> &PathBuf {
194        &self.storage_directory
195    }
196    /// See [Config]
197    pub fn maximum_buffer_size(&self) -> usize {
198        self.maximum_buffer_size
199    }
200}
201
202impl Default for Config {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208/// Runtime based on [Tokio](https://tokio.rs).
209pub struct Executor {
210    cfg: Config,
211    registry: Mutex<Registry>,
212    metrics: Arc<Metrics>,
213    runtime: Runtime,
214    shutdown: Mutex<Stopper>,
215}
216
217/// Implementation of [crate::Runner] for the `tokio` runtime.
218pub struct Runner {
219    cfg: Config,
220}
221
222impl Default for Runner {
223    fn default() -> Self {
224        Self::new(Config::default())
225    }
226}
227
228impl Runner {
229    /// Initialize a new `tokio` runtime with the given number of threads.
230    pub fn new(cfg: Config) -> Self {
231        Self { cfg }
232    }
233}
234
235impl crate::Runner for Runner {
236    type Context = Context;
237
238    fn start<F, Fut>(self, f: F) -> Fut::Output
239    where
240        F: FnOnce(Self::Context) -> Fut,
241        Fut: Future,
242    {
243        // Create a new registry
244        let mut registry = Registry::default();
245        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
246
247        // Initialize runtime
248        let metrics = Arc::new(Metrics::init(runtime_registry));
249        let runtime = Builder::new_multi_thread()
250            .worker_threads(self.cfg.worker_threads)
251            .max_blocking_threads(self.cfg.max_blocking_threads)
252            .enable_all()
253            .build()
254            .expect("failed to create Tokio runtime");
255
256        // Collect process metrics.
257        //
258        // We prefer to collect process metrics outside of `Context` because
259        // we are using `runtime_registry` rather than the one provided by `Context`.
260        let process = MeteredProcess::init(runtime_registry);
261        runtime.spawn(process.collect(tokio::time::sleep));
262
263        // Initialize storage
264        cfg_if::cfg_if! {
265            if #[cfg(feature = "iouring-storage")] {
266                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
267                let storage = MeteredStorage::new(
268                    IoUringStorage::start(IoUringConfig {
269                        storage_directory: self.cfg.storage_directory.clone(),
270                        ring_config: Default::default(),
271                    }, iouring_registry),
272                    runtime_registry,
273                );
274            } else {
275                let storage = MeteredStorage::new(
276                    TokioStorage::new(TokioStorageConfig::new(
277                        self.cfg.storage_directory.clone(),
278                        self.cfg.maximum_buffer_size,
279                    )),
280                    runtime_registry,
281                );
282            }
283        }
284
285        // Initialize network
286        cfg_if::cfg_if! {
287            if #[cfg(feature = "iouring-network")] {
288                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
289                let config = IoUringNetworkConfig {
290                    tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
291                    iouring_config: iouring::Config {
292                        // TODO (#1045): make `IOURING_NETWORK_SIZE` configurable
293                        size: IOURING_NETWORK_SIZE,
294                        op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
295                        force_poll: IOURING_NETWORK_FORCE_POLL,
296                        shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
297                        ..Default::default()
298                    },
299                };
300                let network = MeteredNetwork::new(
301                    IoUringNetwork::start(config, iouring_registry).unwrap(),
302                runtime_registry,
303            );
304        } else {
305            let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
306                .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
307                .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
308                let network = MeteredNetwork::new(
309                    TokioNetwork::from(config),
310                    runtime_registry,
311                );
312            }
313        }
314
315        // Initialize executor
316        let executor = Arc::new(Executor {
317            cfg: self.cfg,
318            registry: Mutex::new(registry),
319            metrics,
320            runtime,
321            shutdown: Mutex::new(Stopper::default()),
322        });
323
324        // Get metrics
325        let label = Label::root();
326        executor.metrics.tasks_spawned.get_or_create(&label).inc();
327        let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
328
329        // Run the future
330        let context = Context {
331            storage,
332            name: label.name(),
333            spawned: false,
334            executor: executor.clone(),
335            network,
336        };
337        let output = executor.runtime.block_on(f(context));
338        gauge.dec();
339
340        output
341    }
342}
343
344cfg_if::cfg_if! {
345    if #[cfg(feature = "iouring-storage")] {
346        type Storage = MeteredStorage<IoUringStorage>;
347    } else {
348        type Storage = MeteredStorage<TokioStorage>;
349    }
350}
351
352cfg_if::cfg_if! {
353    if #[cfg(feature = "iouring-network")] {
354        type Network = MeteredNetwork<IoUringNetwork>;
355    } else {
356        type Network = MeteredNetwork<TokioNetwork>;
357    }
358}
359
360/// Implementation of [crate::Spawner], [crate::Clock],
361/// [crate::Network], and [crate::Storage] for the `tokio`
362/// runtime.
363pub struct Context {
364    name: String,
365    spawned: bool,
366    executor: Arc<Executor>,
367    storage: Storage,
368    network: Network,
369}
370
371impl Clone for Context {
372    fn clone(&self) -> Self {
373        Self {
374            name: self.name.clone(),
375            spawned: false,
376            executor: self.executor.clone(),
377            storage: self.storage.clone(),
378            network: self.network.clone(),
379        }
380    }
381}
382
383impl crate::Spawner for Context {
384    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
385    where
386        F: FnOnce(Self) -> Fut + Send + 'static,
387        Fut: Future<Output = T> + Send + 'static,
388        T: Send + 'static,
389    {
390        // Ensure a context only spawns one task
391        assert!(!self.spawned, "already spawned");
392
393        // Get metrics
394        let (_, gauge) = spawn_metrics!(self, future);
395
396        // Set up the task
397        let catch_panics = self.executor.cfg.catch_panics;
398        let executor = self.executor.clone();
399        let future = f(self);
400        let (f, handle) = Handle::init_future(future, gauge, catch_panics);
401
402        // Spawn the task
403        executor.runtime.spawn(f);
404        handle
405    }
406
407    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
408    where
409        F: Future<Output = T> + Send + 'static,
410        T: Send + 'static,
411    {
412        // Ensure a context only spawns one task
413        assert!(!self.spawned, "already spawned");
414        self.spawned = true;
415
416        // Get metrics
417        let (_, gauge) = spawn_metrics!(self, future);
418
419        // Set up the task
420        let executor = self.executor.clone();
421        move |f: F| {
422            let (f, handle) = Handle::init_future(f, gauge, executor.cfg.catch_panics);
423
424            // Spawn the task
425            executor.runtime.spawn(f);
426            handle
427        }
428    }
429
430    fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
431    where
432        F: FnOnce(Self) -> T + Send + 'static,
433        T: Send + 'static,
434    {
435        // Ensure a context only spawns one task
436        assert!(!self.spawned, "already spawned");
437
438        // Get metrics
439        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
440
441        // Set up the task
442        let executor = self.executor.clone();
443        let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);
444
445        // Spawn the blocking task
446        if dedicated {
447            std::thread::spawn(f);
448        } else {
449            executor.runtime.spawn_blocking(f);
450        }
451        handle
452    }
453
454    fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
455    where
456        F: FnOnce() -> T + Send + 'static,
457        T: Send + 'static,
458    {
459        // Ensure a context only spawns one task
460        assert!(!self.spawned, "already spawned");
461        self.spawned = true;
462
463        // Get metrics
464        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
465
466        // Set up the task
467        let executor = self.executor.clone();
468        move |f: F| {
469            let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);
470
471            // Spawn the blocking task
472            if dedicated {
473                std::thread::spawn(f);
474            } else {
475                executor.runtime.spawn_blocking(f);
476            }
477            handle
478        }
479    }
480
481    async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
482        let stop_resolved = {
483            let mut shutdown = self.executor.shutdown.lock().unwrap();
484            shutdown.stop(value)
485        };
486
487        // Wait for all tasks to complete or the timeout to fire
488        let timeout_future = match timeout {
489            Some(duration) => futures::future::Either::Left(self.sleep(duration)),
490            None => futures::future::Either::Right(futures::future::pending()),
491        };
492        select! {
493            result = stop_resolved => {
494                result.map_err(|_| Error::Closed)?;
495                Ok(())
496            },
497            _ = timeout_future => {
498                Err(Error::Timeout)
499            }
500        }
501    }
502
503    fn stopped(&self) -> Signal {
504        self.executor.shutdown.lock().unwrap().stopped()
505    }
506}
507
508impl crate::Metrics for Context {
509    fn with_label(&self, label: &str) -> Self {
510        let name = {
511            let prefix = self.name.clone();
512            if prefix.is_empty() {
513                label.to_string()
514            } else {
515                format!("{prefix}_{label}")
516            }
517        };
518        assert!(
519            !name.starts_with(METRICS_PREFIX),
520            "using runtime label is not allowed"
521        );
522        Self {
523            name,
524            spawned: false,
525            executor: self.executor.clone(),
526            storage: self.storage.clone(),
527            network: self.network.clone(),
528        }
529    }
530
531    fn label(&self) -> String {
532        self.name.clone()
533    }
534
535    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
536        let name = name.into();
537        let prefixed_name = {
538            let prefix = &self.name;
539            if prefix.is_empty() {
540                name
541            } else {
542                format!("{}_{}", *prefix, name)
543            }
544        };
545        self.executor
546            .registry
547            .lock()
548            .unwrap()
549            .register(prefixed_name, help, metric)
550    }
551
552    fn encode(&self) -> String {
553        let mut buffer = String::new();
554        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
555        buffer
556    }
557}
558
559impl Clock for Context {
560    fn current(&self) -> SystemTime {
561        SystemTime::now()
562    }
563
564    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
565        tokio::time::sleep(duration)
566    }
567
568    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
569        let now = SystemTime::now();
570        let duration_until_deadline = match deadline.duration_since(now) {
571            Ok(duration) => duration,
572            Err(_) => Duration::from_secs(0), // Deadline is in the past
573        };
574        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
575        tokio::time::sleep_until(target_instant)
576    }
577}
578
579impl GClock for Context {
580    type Instant = SystemTime;
581
582    fn now(&self) -> Self::Instant {
583        self.current()
584    }
585}
586
587impl ReasonablyRealtime for Context {}
588
589impl crate::Network for Context {
590    type Listener = <Network as crate::Network>::Listener;
591
592    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
593        self.network.bind(socket).await
594    }
595
596    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
597        self.network.dial(socket).await
598    }
599}
600
601impl RngCore for Context {
602    fn next_u32(&mut self) -> u32 {
603        OsRng.next_u32()
604    }
605
606    fn next_u64(&mut self) -> u64 {
607        OsRng.next_u64()
608    }
609
610    fn fill_bytes(&mut self, dest: &mut [u8]) {
611        OsRng.fill_bytes(dest);
612    }
613
614    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
615        OsRng.try_fill_bytes(dest)
616    }
617}
618
619impl CryptoRng for Context {}
620
621impl crate::Storage for Context {
622    type Blob = <Storage as crate::Storage>::Blob;
623
624    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
625        self.storage.open(partition, name).await
626    }
627
628    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
629        self.storage.remove(partition, name).await
630    }
631
632    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
633        self.storage.scan(partition).await
634    }
635}