commonware_runtime/tokio/
runtime.rs

1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "iouring-network")]
8use crate::{
9    iouring,
10    network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
11};
12use crate::{
13    network::metered::Network as MeteredNetwork, storage::metered::Storage as MeteredStorage,
14    telemetry::metrics::task::Label, utils::Signaler, Clock, Error, Handle, Signal, SinkOf,
15    StreamOf, METRICS_PREFIX,
16};
17use governor::clock::{Clock as GClock, ReasonablyRealtime};
18use prometheus_client::{
19    encoding::text::encode,
20    metrics::{counter::Counter, family::Family, gauge::Gauge},
21    registry::{Metric, Registry},
22};
23use rand::{rngs::OsRng, CryptoRng, RngCore};
24use std::{
25    env,
26    future::Future,
27    net::SocketAddr,
28    path::PathBuf,
29    sync::{Arc, Mutex},
30    time::{Duration, SystemTime},
31};
32use tokio::runtime::{Builder, Runtime};
33
/// Value used as `size` in the io_uring configuration of the network backend.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1_024;
/// Value used as `force_poll` in the io_uring configuration of the network backend.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Option<Duration> = Some(Duration::from_millis(100));
38
39#[derive(Debug)]
40struct Metrics {
41    tasks_spawned: Family<Label, Counter>,
42    tasks_running: Family<Label, Gauge>,
43}
44
45impl Metrics {
46    pub fn init(registry: &mut Registry) -> Self {
47        let metrics = Self {
48            tasks_spawned: Family::default(),
49            tasks_running: Family::default(),
50        };
51        registry.register(
52            "tasks_spawned",
53            "Total number of tasks spawned",
54            metrics.tasks_spawned.clone(),
55        );
56        registry.register(
57            "tasks_running",
58            "Number of tasks currently running",
59            metrics.tasks_running.clone(),
60        );
61        metrics
62    }
63}
64
/// Network-related settings of the `tokio` runtime.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// `Some(value)` explicitly sets TCP_NODELAY on the socket to `value`;
    /// `None` leaves the system default in place.
    tcp_nodelay: Option<bool>,

    /// Timeout applied to network reads and writes.
    read_write_timeout: Duration,
}

impl Default for NetworkConfig {
    /// Leaves TCP_NODELAY at the system default and uses a 60 second read/write timeout.
    fn default() -> Self {
        let read_write_timeout = Duration::from_secs(60);
        Self {
            tcp_nodelay: None,
            read_write_timeout,
        }
    }
}
83
84/// Configuration for the `tokio` runtime.
85#[derive(Clone)]
86pub struct Config {
87    /// Number of threads to use for handling async tasks.
88    ///
89    /// Worker threads are always active (waiting for work).
90    ///
91    /// Tokio sets the default value to the number of logical CPUs.
92    worker_threads: usize,
93
94    /// Maximum number of threads to use for blocking tasks.
95    ///
96    /// Unlike worker threads, blocking threads are created as needed and
97    /// exit if left idle for too long.
98    ///
99    /// Tokio sets the default value to 512 to avoid hanging on lower-level
100    /// operations that require blocking (like `fs` and writing to `Stdout`).
101    max_blocking_threads: usize,
102
103    /// Whether or not to catch panics.
104    catch_panics: bool,
105
106    /// Base directory for all storage operations.
107    storage_directory: PathBuf,
108
109    /// Maximum buffer size for operations on blobs.
110    ///
111    /// Tokio sets the default value to 2MB.
112    maximum_buffer_size: usize,
113
114    /// Network configuration.
115    network_cfg: NetworkConfig,
116}
117
118impl Config {
119    /// Returns a new [Config] with default values.
120    pub fn new() -> Self {
121        let rng = OsRng.next_u64();
122        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
123        Self {
124            worker_threads: 2,
125            max_blocking_threads: 512,
126            catch_panics: true,
127            storage_directory,
128            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
129            network_cfg: NetworkConfig::default(),
130        }
131    }
132
133    // Setters
134    /// See [Config]
135    pub fn with_worker_threads(mut self, n: usize) -> Self {
136        self.worker_threads = n;
137        self
138    }
139    /// See [Config]
140    pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
141        self.max_blocking_threads = n;
142        self
143    }
144    /// See [Config]
145    pub fn with_catch_panics(mut self, b: bool) -> Self {
146        self.catch_panics = b;
147        self
148    }
149    /// See [Config]
150    pub fn with_read_write_timeout(mut self, d: Duration) -> Self {
151        self.network_cfg.read_write_timeout = d;
152        self
153    }
154    /// See [Config]
155    pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
156        self.network_cfg.tcp_nodelay = n;
157        self
158    }
159    /// See [Config]
160    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
161        self.storage_directory = p.into();
162        self
163    }
164    /// See [Config]
165    pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
166        self.maximum_buffer_size = n;
167        self
168    }
169
170    // Getters
171    /// See [Config]
172    pub fn worker_threads(&self) -> usize {
173        self.worker_threads
174    }
175    /// See [Config]
176    pub fn max_blocking_threads(&self) -> usize {
177        self.max_blocking_threads
178    }
179    /// See [Config]
180    pub fn catch_panics(&self) -> bool {
181        self.catch_panics
182    }
183    /// See [Config]
184    pub fn read_write_timeout(&self) -> Duration {
185        self.network_cfg.read_write_timeout
186    }
187    /// See [Config]
188    pub fn tcp_nodelay(&self) -> Option<bool> {
189        self.network_cfg.tcp_nodelay
190    }
191    /// See [Config]
192    pub fn storage_directory(&self) -> &PathBuf {
193        &self.storage_directory
194    }
195    /// See [Config]
196    pub fn maximum_buffer_size(&self) -> usize {
197        self.maximum_buffer_size
198    }
199}
200
201impl Default for Config {
202    fn default() -> Self {
203        Self::new()
204    }
205}
206
207/// Runtime based on [Tokio](https://tokio.rs).
208pub struct Executor {
209    cfg: Config,
210    registry: Mutex<Registry>,
211    metrics: Arc<Metrics>,
212    runtime: Runtime,
213    signaler: Mutex<Signaler>,
214    signal: Signal,
215}
216
217/// Implementation of [crate::Runner] for the `tokio` runtime.
218pub struct Runner {
219    cfg: Config,
220}
221
222impl Default for Runner {
223    fn default() -> Self {
224        Self::new(Config::default())
225    }
226}
227
228impl Runner {
229    /// Initialize a new `tokio` runtime with the given number of threads.
230    pub fn new(cfg: Config) -> Self {
231        Self { cfg }
232    }
233}
234
235impl crate::Runner for Runner {
236    type Context = Context;
237
238    fn start<F, Fut>(self, f: F) -> Fut::Output
239    where
240        F: FnOnce(Self::Context) -> Fut,
241        Fut: Future,
242    {
243        // Create a new registry
244        let mut registry = Registry::default();
245        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
246
247        // Initialize runtime
248        let metrics = Arc::new(Metrics::init(runtime_registry));
249        let runtime = Builder::new_multi_thread()
250            .worker_threads(self.cfg.worker_threads)
251            .max_blocking_threads(self.cfg.max_blocking_threads)
252            .enable_all()
253            .build()
254            .expect("failed to create Tokio runtime");
255        let (signaler, signal) = Signaler::new();
256
257        cfg_if::cfg_if! {
258            if #[cfg(feature = "iouring-storage")] {
259                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
260                let storage = MeteredStorage::new(
261                    IoUringStorage::start(IoUringConfig {
262                        storage_directory: self.cfg.storage_directory.clone(),
263                        ring_config: Default::default(),
264                    }, iouring_registry),
265                    runtime_registry,
266                );
267            } else {
268                let storage = MeteredStorage::new(
269                    TokioStorage::new(TokioStorageConfig::new(
270                        self.cfg.storage_directory.clone(),
271                        self.cfg.maximum_buffer_size,
272                    )),
273                    runtime_registry,
274                );
275            }
276        }
277
278        cfg_if::cfg_if! {
279            if #[cfg(feature = "iouring-network")] {
280                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
281                let config = IoUringNetworkConfig {
282                    tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
283                    iouring_config: iouring::Config {
284                        // TODO (#1045): make `IOURING_NETWORK_SIZE` configurable
285                        size: IOURING_NETWORK_SIZE,
286                        op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
287                        force_poll: IOURING_NETWORK_FORCE_POLL,
288                        shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
289                        ..Default::default()
290                    },
291                };
292                let network = MeteredNetwork::new(
293                    IoUringNetwork::start(config, iouring_registry).unwrap(),
294                runtime_registry,
295            );
296        } else {
297            let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
298                .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
299                .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
300                let network = MeteredNetwork::new(
301                    TokioNetwork::from(config),
302                    runtime_registry,
303                );
304            }
305        }
306
307        // Initialize executor
308        let executor = Arc::new(Executor {
309            cfg: self.cfg,
310            registry: Mutex::new(registry),
311            metrics,
312            runtime,
313            signaler: Mutex::new(signaler),
314            signal,
315        });
316
317        // Get metrics
318        let label = Label::root();
319        executor.metrics.tasks_spawned.get_or_create(&label).inc();
320        let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
321
322        // Run the future
323        let context = Context {
324            storage,
325            name: label.name(),
326            spawned: false,
327            executor: executor.clone(),
328            network,
329        };
330        let output = executor.runtime.block_on(f(context));
331        gauge.dec();
332
333        output
334    }
335}
336
337cfg_if::cfg_if! {
338    if #[cfg(feature = "iouring-storage")] {
339        type Storage = MeteredStorage<IoUringStorage>;
340    } else {
341        type Storage = MeteredStorage<TokioStorage>;
342    }
343}
344
345cfg_if::cfg_if! {
346    if #[cfg(feature = "iouring-network")] {
347        type Network = MeteredNetwork<IoUringNetwork>;
348    } else {
349        type Network = MeteredNetwork<TokioNetwork>;
350    }
351}
352
353/// Implementation of [crate::Spawner], [crate::Clock],
354/// [crate::Network], and [crate::Storage] for the `tokio`
355/// runtime.
356pub struct Context {
357    name: String,
358    spawned: bool,
359    executor: Arc<Executor>,
360    storage: Storage,
361    network: Network,
362}
363
364impl Clone for Context {
365    fn clone(&self) -> Self {
366        Self {
367            name: self.name.clone(),
368            spawned: false,
369            executor: self.executor.clone(),
370            storage: self.storage.clone(),
371            network: self.network.clone(),
372        }
373    }
374}
375
376impl crate::Spawner for Context {
377    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
378    where
379        F: FnOnce(Self) -> Fut + Send + 'static,
380        Fut: Future<Output = T> + Send + 'static,
381        T: Send + 'static,
382    {
383        // Ensure a context only spawns one task
384        assert!(!self.spawned, "already spawned");
385
386        // Get metrics
387        let (_, gauge) = spawn_metrics!(self, future);
388
389        // Set up the task
390        let catch_panics = self.executor.cfg.catch_panics;
391        let executor = self.executor.clone();
392        let future = f(self);
393        let (f, handle) = Handle::init_future(future, gauge, catch_panics);
394
395        // Spawn the task
396        executor.runtime.spawn(f);
397        handle
398    }
399
400    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
401    where
402        F: Future<Output = T> + Send + 'static,
403        T: Send + 'static,
404    {
405        // Ensure a context only spawns one task
406        assert!(!self.spawned, "already spawned");
407        self.spawned = true;
408
409        // Get metrics
410        let (_, gauge) = spawn_metrics!(self, future);
411
412        // Set up the task
413        let executor = self.executor.clone();
414        move |f: F| {
415            let (f, handle) = Handle::init_future(f, gauge, executor.cfg.catch_panics);
416
417            // Spawn the task
418            executor.runtime.spawn(f);
419            handle
420        }
421    }
422
423    fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
424    where
425        F: FnOnce(Self) -> T + Send + 'static,
426        T: Send + 'static,
427    {
428        // Ensure a context only spawns one task
429        assert!(!self.spawned, "already spawned");
430
431        // Get metrics
432        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
433
434        // Set up the task
435        let executor = self.executor.clone();
436        let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);
437
438        // Spawn the blocking task
439        if dedicated {
440            std::thread::spawn(f);
441        } else {
442            executor.runtime.spawn_blocking(f);
443        }
444        handle
445    }
446
447    fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
448    where
449        F: FnOnce() -> T + Send + 'static,
450        T: Send + 'static,
451    {
452        // Ensure a context only spawns one task
453        assert!(!self.spawned, "already spawned");
454        self.spawned = true;
455
456        // Get metrics
457        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
458
459        // Set up the task
460        let executor = self.executor.clone();
461        move |f: F| {
462            let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);
463
464            // Spawn the blocking task
465            if dedicated {
466                std::thread::spawn(f);
467            } else {
468                executor.runtime.spawn_blocking(f);
469            }
470            handle
471        }
472    }
473
474    fn stop(&self, value: i32) {
475        self.executor.signaler.lock().unwrap().signal(value);
476    }
477
478    fn stopped(&self) -> Signal {
479        self.executor.signal.clone()
480    }
481}
482
483impl crate::Metrics for Context {
484    fn with_label(&self, label: &str) -> Self {
485        let name = {
486            let prefix = self.name.clone();
487            if prefix.is_empty() {
488                label.to_string()
489            } else {
490                format!("{prefix}_{label}")
491            }
492        };
493        assert!(
494            !name.starts_with(METRICS_PREFIX),
495            "using runtime label is not allowed"
496        );
497        Self {
498            name,
499            spawned: false,
500            executor: self.executor.clone(),
501            storage: self.storage.clone(),
502            network: self.network.clone(),
503        }
504    }
505
506    fn label(&self) -> String {
507        self.name.clone()
508    }
509
510    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
511        let name = name.into();
512        let prefixed_name = {
513            let prefix = &self.name;
514            if prefix.is_empty() {
515                name
516            } else {
517                format!("{}_{}", *prefix, name)
518            }
519        };
520        self.executor
521            .registry
522            .lock()
523            .unwrap()
524            .register(prefixed_name, help, metric)
525    }
526
527    fn encode(&self) -> String {
528        let mut buffer = String::new();
529        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
530        buffer
531    }
532}
533
534impl Clock for Context {
535    fn current(&self) -> SystemTime {
536        SystemTime::now()
537    }
538
539    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
540        tokio::time::sleep(duration)
541    }
542
543    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
544        let now = SystemTime::now();
545        let duration_until_deadline = match deadline.duration_since(now) {
546            Ok(duration) => duration,
547            Err(_) => Duration::from_secs(0), // Deadline is in the past
548        };
549        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
550        tokio::time::sleep_until(target_instant)
551    }
552}
553
554impl GClock for Context {
555    type Instant = SystemTime;
556
557    fn now(&self) -> Self::Instant {
558        self.current()
559    }
560}
561
562impl ReasonablyRealtime for Context {}
563
564impl crate::Network for Context {
565    type Listener = <Network as crate::Network>::Listener;
566
567    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
568        self.network.bind(socket).await
569    }
570
571    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
572        self.network.dial(socket).await
573    }
574}
575
576impl RngCore for Context {
577    fn next_u32(&mut self) -> u32 {
578        OsRng.next_u32()
579    }
580
581    fn next_u64(&mut self) -> u64 {
582        OsRng.next_u64()
583    }
584
585    fn fill_bytes(&mut self, dest: &mut [u8]) {
586        OsRng.fill_bytes(dest);
587    }
588
589    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
590        OsRng.try_fill_bytes(dest)
591    }
592}
593
594impl CryptoRng for Context {}
595
596impl crate::Storage for Context {
597    type Blob = <Storage as crate::Storage>::Blob;
598
599    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
600        self.storage.open(partition, name).await
601    }
602
603    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
604        self.storage.remove(partition, name).await
605    }
606
607    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
608        self.storage.scan(partition).await
609    }
610}