commonware_runtime/tokio/runtime.rs

1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "iouring-network")]
8use crate::{
9    iouring,
10    network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
11};
12use crate::{
13    network::metered::Network as MeteredNetwork, process::metered::Metrics as MeteredProcess,
14    signal::Signal, storage::metered::Storage as MeteredStorage, telemetry::metrics::task::Label,
15    utils::signal::Stopper, Clock, Error, Handle, SinkOf, StreamOf, METRICS_PREFIX,
16};
17use commonware_macros::select;
18use futures::future::AbortHandle;
19use governor::clock::{Clock as GClock, ReasonablyRealtime};
20use prometheus_client::{
21    encoding::text::encode,
22    metrics::{counter::Counter, family::Family, gauge::Gauge},
23    registry::{Metric, Registry},
24};
25use rand::{rngs::OsRng, CryptoRng, RngCore};
26use std::{
27    env,
28    future::Future,
29    net::SocketAddr,
30    path::PathBuf,
31    sync::{Arc, Mutex},
32    time::{Duration, SystemTime},
33};
34use tokio::runtime::{Builder, Runtime};
35
/// Value passed as `size` in the io_uring configuration of the network backend.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1 << 10;

/// Value passed as `force_poll` in the io_uring configuration of the network backend.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Option<Duration> = Some(Duration::from_millis(100));
40
41#[derive(Debug)]
42struct Metrics {
43    tasks_spawned: Family<Label, Counter>,
44    tasks_running: Family<Label, Gauge>,
45}
46
47impl Metrics {
48    pub fn init(registry: &mut Registry) -> Self {
49        let metrics = Self {
50            tasks_spawned: Family::default(),
51            tasks_running: Family::default(),
52        };
53        registry.register(
54            "tasks_spawned",
55            "Total number of tasks spawned",
56            metrics.tasks_spawned.clone(),
57        );
58        registry.register(
59            "tasks_running",
60            "Number of tasks currently running",
61            metrics.tasks_running.clone(),
62        );
63        metrics
64    }
65}
66
/// Network settings that the runtime forwards to its network backend.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// `Some(v)` sets TCP_NODELAY to `v` explicitly on sockets;
    /// `None` leaves the system default untouched.
    tcp_nodelay: Option<bool>,

    /// Timeout applied to network reads and writes.
    read_write_timeout: Duration,
}

impl Default for NetworkConfig {
    /// No explicit TCP_NODELAY and a 60-second read/write timeout.
    fn default() -> Self {
        let read_write_timeout = Duration::from_secs(60);
        Self {
            tcp_nodelay: None,
            read_write_timeout,
        }
    }
}
85
86/// Configuration for the `tokio` runtime.
87#[derive(Clone)]
88pub struct Config {
89    /// Number of threads to use for handling async tasks.
90    ///
91    /// Worker threads are always active (waiting for work).
92    ///
93    /// Tokio sets the default value to the number of logical CPUs.
94    worker_threads: usize,
95
96    /// Maximum number of threads to use for blocking tasks.
97    ///
98    /// Unlike worker threads, blocking threads are created as needed and
99    /// exit if left idle for too long.
100    ///
101    /// Tokio sets the default value to 512 to avoid hanging on lower-level
102    /// operations that require blocking (like `fs` and writing to `Stdout`).
103    max_blocking_threads: usize,
104
105    /// Whether or not to catch panics.
106    catch_panics: bool,
107
108    /// Base directory for all storage operations.
109    storage_directory: PathBuf,
110
111    /// Maximum buffer size for operations on blobs.
112    ///
113    /// Tokio sets the default value to 2MB.
114    maximum_buffer_size: usize,
115
116    /// Network configuration.
117    network_cfg: NetworkConfig,
118}
119
120impl Config {
121    /// Returns a new [Config] with default values.
122    pub fn new() -> Self {
123        let rng = OsRng.next_u64();
124        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
125        Self {
126            worker_threads: 2,
127            max_blocking_threads: 512,
128            catch_panics: true,
129            storage_directory,
130            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
131            network_cfg: NetworkConfig::default(),
132        }
133    }
134
135    // Setters
136    /// See [Config]
137    pub fn with_worker_threads(mut self, n: usize) -> Self {
138        self.worker_threads = n;
139        self
140    }
141    /// See [Config]
142    pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
143        self.max_blocking_threads = n;
144        self
145    }
146    /// See [Config]
147    pub fn with_catch_panics(mut self, b: bool) -> Self {
148        self.catch_panics = b;
149        self
150    }
151    /// See [Config]
152    pub fn with_read_write_timeout(mut self, d: Duration) -> Self {
153        self.network_cfg.read_write_timeout = d;
154        self
155    }
156    /// See [Config]
157    pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
158        self.network_cfg.tcp_nodelay = n;
159        self
160    }
161    /// See [Config]
162    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
163        self.storage_directory = p.into();
164        self
165    }
166    /// See [Config]
167    pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
168        self.maximum_buffer_size = n;
169        self
170    }
171
172    // Getters
173    /// See [Config]
174    pub fn worker_threads(&self) -> usize {
175        self.worker_threads
176    }
177    /// See [Config]
178    pub fn max_blocking_threads(&self) -> usize {
179        self.max_blocking_threads
180    }
181    /// See [Config]
182    pub fn catch_panics(&self) -> bool {
183        self.catch_panics
184    }
185    /// See [Config]
186    pub fn read_write_timeout(&self) -> Duration {
187        self.network_cfg.read_write_timeout
188    }
189    /// See [Config]
190    pub fn tcp_nodelay(&self) -> Option<bool> {
191        self.network_cfg.tcp_nodelay
192    }
193    /// See [Config]
194    pub fn storage_directory(&self) -> &PathBuf {
195        &self.storage_directory
196    }
197    /// See [Config]
198    pub fn maximum_buffer_size(&self) -> usize {
199        self.maximum_buffer_size
200    }
201}
202
203impl Default for Config {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
/// Runtime based on [Tokio](https://tokio.rs).
///
/// Shared (via `Arc`) by every [Context] created from the same runner.
pub struct Executor {
    /// Configuration the runtime was started with.
    cfg: Config,
    /// Metrics registry; locked by contexts when registering and encoding metrics.
    registry: Mutex<Registry>,
    /// Task spawn/running metrics.
    metrics: Arc<Metrics>,
    /// Multi-threaded Tokio runtime that executes spawned tasks.
    runtime: Runtime,
    /// Shutdown coordinator used by `stop` and `stopped`.
    shutdown: Mutex<Stopper>,
}
217
218/// Implementation of [crate::Runner] for the `tokio` runtime.
219pub struct Runner {
220    cfg: Config,
221}
222
223impl Default for Runner {
224    fn default() -> Self {
225        Self::new(Config::default())
226    }
227}
228
229impl Runner {
230    /// Initialize a new `tokio` runtime with the given number of threads.
231    pub fn new(cfg: Config) -> Self {
232        Self { cfg }
233    }
234}
235
236impl crate::Runner for Runner {
237    type Context = Context;
238
239    fn start<F, Fut>(self, f: F) -> Fut::Output
240    where
241        F: FnOnce(Self::Context) -> Fut,
242        Fut: Future,
243    {
244        // Create a new registry
245        let mut registry = Registry::default();
246        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
247
248        // Initialize runtime
249        let metrics = Arc::new(Metrics::init(runtime_registry));
250        let runtime = Builder::new_multi_thread()
251            .worker_threads(self.cfg.worker_threads)
252            .max_blocking_threads(self.cfg.max_blocking_threads)
253            .enable_all()
254            .build()
255            .expect("failed to create Tokio runtime");
256
257        // Collect process metrics.
258        //
259        // We prefer to collect process metrics outside of `Context` because
260        // we are using `runtime_registry` rather than the one provided by `Context`.
261        let process = MeteredProcess::init(runtime_registry);
262        runtime.spawn(process.collect(tokio::time::sleep));
263
264        // Initialize storage
265        cfg_if::cfg_if! {
266            if #[cfg(feature = "iouring-storage")] {
267                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
268                let storage = MeteredStorage::new(
269                    IoUringStorage::start(IoUringConfig {
270                        storage_directory: self.cfg.storage_directory.clone(),
271                        iouring_config: Default::default(),
272                    }, iouring_registry),
273                    runtime_registry,
274                );
275            } else {
276                let storage = MeteredStorage::new(
277                    TokioStorage::new(TokioStorageConfig::new(
278                        self.cfg.storage_directory.clone(),
279                        self.cfg.maximum_buffer_size,
280                    )),
281                    runtime_registry,
282                );
283            }
284        }
285
286        // Initialize network
287        cfg_if::cfg_if! {
288            if #[cfg(feature = "iouring-network")] {
289                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
290                let config = IoUringNetworkConfig {
291                    tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
292                    iouring_config: iouring::Config {
293                        // TODO (#1045): make `IOURING_NETWORK_SIZE` configurable
294                        size: IOURING_NETWORK_SIZE,
295                        op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
296                        force_poll: IOURING_NETWORK_FORCE_POLL,
297                        shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
298                        ..Default::default()
299                    },
300                };
301                let network = MeteredNetwork::new(
302                    IoUringNetwork::start(config, iouring_registry).unwrap(),
303                runtime_registry,
304            );
305        } else {
306            let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
307                .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
308                .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
309                let network = MeteredNetwork::new(
310                    TokioNetwork::from(config),
311                    runtime_registry,
312                );
313            }
314        }
315
316        // Initialize executor
317        let executor = Arc::new(Executor {
318            cfg: self.cfg,
319            registry: Mutex::new(registry),
320            metrics,
321            runtime,
322            shutdown: Mutex::new(Stopper::default()),
323        });
324
325        // Get metrics
326        let label = Label::root();
327        executor.metrics.tasks_spawned.get_or_create(&label).inc();
328        let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
329
330        // Run the future
331        let context = Context {
332            storage,
333            name: label.name(),
334            spawned: false,
335            executor: executor.clone(),
336            network,
337            children: Arc::new(Mutex::new(Vec::new())),
338        };
339        let output = executor.runtime.block_on(f(context));
340        gauge.dec();
341
342        output
343    }
344}
345
346cfg_if::cfg_if! {
347    if #[cfg(feature = "iouring-storage")] {
348        type Storage = MeteredStorage<IoUringStorage>;
349    } else {
350        type Storage = MeteredStorage<TokioStorage>;
351    }
352}
353
354cfg_if::cfg_if! {
355    if #[cfg(feature = "iouring-network")] {
356        type Network = MeteredNetwork<IoUringNetwork>;
357    } else {
358        type Network = MeteredNetwork<TokioNetwork>;
359    }
360}
361
/// Implementation of [crate::Spawner], [crate::Clock],
/// [crate::Network], and [crate::Storage] for the `tokio`
/// runtime (it also implements [crate::Metrics] and [RngCore]).
pub struct Context {
    /// Label of this context; used as the prefix for metrics it registers.
    name: String,
    /// Set to `true` by `spawn_ref`/`spawn_blocking_ref`; every spawn method
    /// asserts it is still `false`. Clones and labelled copies start at `false`.
    spawned: bool,
    /// Executor state shared by all contexts of this runtime.
    executor: Arc<Executor>,
    /// Storage backend (cloned into derived contexts).
    storage: Storage,
    /// Network backend (cloned into derived contexts).
    network: Network,
    /// Abort handles of children registered through `spawn_child`.
    children: Arc<Mutex<Vec<AbortHandle>>>,
}
373
374impl Clone for Context {
375    fn clone(&self) -> Self {
376        Self {
377            name: self.name.clone(),
378            spawned: false,
379            executor: self.executor.clone(),
380            storage: self.storage.clone(),
381            network: self.network.clone(),
382            children: self.children.clone(),
383        }
384    }
385}
386
impl crate::Spawner for Context {
    /// Spawns `f(self)` as a task on the Tokio runtime and returns its [Handle].
    ///
    /// The context handed to `f` gets a fresh, empty children list; the same
    /// list is also passed to `Handle::init_future`.
    ///
    /// # Panics
    ///
    /// Panics if this context was already marked spawned by `spawn_ref` or
    /// `spawn_blocking_ref`.
    fn spawn<F, Fut, T>(mut self, f: F) -> Handle<T>
    where
        F: FnOnce(Self) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns one task
        // (`self` is consumed here, so `spawned` does not need to be set.)
        assert!(!self.spawned, "already spawned");

        // Get metrics
        let (_, gauge) = spawn_metrics!(self, future);

        // Set up the task
        let catch_panics = self.executor.cfg.catch_panics;
        let executor = self.executor.clone();

        // Give spawned task its own empty children list
        let children = Arc::new(Mutex::new(Vec::new()));
        self.children = children.clone();

        let future = f(self);
        let (f, handle) = Handle::init_future(future, gauge, catch_panics, children);

        // Spawn the task
        executor.runtime.spawn(f);
        handle
    }

    /// Marks this context as spawned and returns a closure that spawns the
    /// future it is given on the Tokio runtime, returning its [Handle].
    ///
    /// Unlike `spawn`, the task's new children list is not stored on `self`.
    ///
    /// # Panics
    ///
    /// Panics if this context was already marked spawned.
    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns one task
        assert!(!self.spawned, "already spawned");
        self.spawned = true;

        // Get metrics
        let (_, gauge) = spawn_metrics!(self, future);

        // Set up the task
        let executor = self.executor.clone();

        move |f: F| {
            let (f, handle) = Handle::init_future(
                f,
                gauge,
                executor.cfg.catch_panics,
                // Give spawned task its own empty children list
                Arc::new(Mutex::new(Vec::new())),
            );

            // Spawn the task
            executor.runtime.spawn(f);
            handle
        }
    }

    /// Spawns `f` exactly like `spawn`, then records the child's abort handle
    /// (when the handle provides one) in this context's children list.
    ///
    /// NOTE(review): handles are only ever appended here, never removed, so a
    /// long-lived parent's list grows with every child it spawns.
    fn spawn_child<F, Fut, T>(self, f: F) -> Handle<T>
    where
        F: FnOnce(Self) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Store parent's children list
        // (taken before `spawn`, which replaces `self.children` for the child)
        let parent_children = self.children.clone();

        // Spawn the child
        let child_handle = self.spawn(f);

        // Register this child with the parent
        if let Some(abort_handle) = child_handle.abort_handle() {
            parent_children.lock().unwrap().push(abort_handle);
        }

        child_handle
    }

    /// Runs `f(self)` as a blocking task: on a new OS thread when `dedicated`
    /// is true, otherwise on Tokio's blocking thread pool.
    ///
    /// # Panics
    ///
    /// Panics if this context was already marked spawned.
    fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
    where
        F: FnOnce(Self) -> T + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns one task
        assert!(!self.spawned, "already spawned");

        // Get metrics
        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);

        // Set up the task
        let executor = self.executor.clone();
        let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);

        // Spawn the blocking task
        if dedicated {
            std::thread::spawn(f);
        } else {
            executor.runtime.spawn_blocking(f);
        }
        handle
    }

    /// Marks this context as spawned and returns a closure that runs the
    /// closure it is given as a blocking task (new OS thread when `dedicated`,
    /// otherwise Tokio's blocking pool), returning its [Handle].
    ///
    /// # Panics
    ///
    /// Panics if this context was already marked spawned.
    fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns one task
        assert!(!self.spawned, "already spawned");
        self.spawned = true;

        // Get metrics
        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);

        // Set up the task
        let executor = self.executor.clone();
        move |f: F| {
            let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);

            // Spawn the blocking task
            if dedicated {
                std::thread::spawn(f);
            } else {
                executor.runtime.spawn_blocking(f);
            }
            handle
        }
    }

    /// Signals shutdown with `value` and waits for the stop to resolve.
    ///
    /// # Errors
    ///
    /// Returns [Error::Timeout] if `timeout` elapses first, or [Error::Closed]
    /// if the stop future resolves to an error.
    async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
        // The shutdown lock is held only inside this scope; the guard is
        // dropped before anything is awaited.
        let stop_resolved = {
            let mut shutdown = self.executor.shutdown.lock().unwrap();
            shutdown.stop(value)
        };

        // Wait for all tasks to complete or the timeout to fire
        let timeout_future = match timeout {
            Some(duration) => futures::future::Either::Left(self.sleep(duration)),
            None => futures::future::Either::Right(futures::future::pending()),
        };
        select! {
            result = stop_resolved => {
                result.map_err(|_| Error::Closed)?;
                Ok(())
            },
            _ = timeout_future => {
                Err(Error::Timeout)
            }
        }
    }

    /// Returns a [Signal] obtained from the shared shutdown state.
    /// NOTE(review): presumably resolves once `stop` is called — confirm
    /// against `Stopper`.
    fn stopped(&self) -> Signal {
        self.executor.shutdown.lock().unwrap().stopped()
    }
}
543
544impl crate::Metrics for Context {
545    fn with_label(&self, label: &str) -> Self {
546        let name = {
547            let prefix = self.name.clone();
548            if prefix.is_empty() {
549                label.to_string()
550            } else {
551                format!("{prefix}_{label}")
552            }
553        };
554        assert!(
555            !name.starts_with(METRICS_PREFIX),
556            "using runtime label is not allowed"
557        );
558        Self {
559            name,
560            spawned: false,
561            executor: self.executor.clone(),
562            storage: self.storage.clone(),
563            network: self.network.clone(),
564            children: self.children.clone(),
565        }
566    }
567
568    fn label(&self) -> String {
569        self.name.clone()
570    }
571
572    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
573        let name = name.into();
574        let prefixed_name = {
575            let prefix = &self.name;
576            if prefix.is_empty() {
577                name
578            } else {
579                format!("{}_{}", *prefix, name)
580            }
581        };
582        self.executor
583            .registry
584            .lock()
585            .unwrap()
586            .register(prefixed_name, help, metric)
587    }
588
589    fn encode(&self) -> String {
590        let mut buffer = String::new();
591        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
592        buffer
593    }
594}
595
596impl Clock for Context {
597    fn current(&self) -> SystemTime {
598        SystemTime::now()
599    }
600
601    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
602        tokio::time::sleep(duration)
603    }
604
605    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
606        let now = SystemTime::now();
607        let duration_until_deadline = match deadline.duration_since(now) {
608            Ok(duration) => duration,
609            Err(_) => Duration::from_secs(0), // Deadline is in the past
610        };
611        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
612        tokio::time::sleep_until(target_instant)
613    }
614}
615
616impl GClock for Context {
617    type Instant = SystemTime;
618
619    fn now(&self) -> Self::Instant {
620        self.current()
621    }
622}
623
624impl ReasonablyRealtime for Context {}
625
626impl crate::Network for Context {
627    type Listener = <Network as crate::Network>::Listener;
628
629    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
630        self.network.bind(socket).await
631    }
632
633    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
634        self.network.dial(socket).await
635    }
636}
637
638impl RngCore for Context {
639    fn next_u32(&mut self) -> u32 {
640        OsRng.next_u32()
641    }
642
643    fn next_u64(&mut self) -> u64 {
644        OsRng.next_u64()
645    }
646
647    fn fill_bytes(&mut self, dest: &mut [u8]) {
648        OsRng.fill_bytes(dest);
649    }
650
651    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
652        OsRng.try_fill_bytes(dest)
653    }
654}
655
656impl CryptoRng for Context {}
657
658impl crate::Storage for Context {
659    type Blob = <Storage as crate::Storage>::Blob;
660
661    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
662        self.storage.open(partition, name).await
663    }
664
665    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
666        self.storage.remove(partition, name).await
667    }
668
669    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
670        self.storage.scan(partition).await
671    }
672}