Skip to main content

commonware_runtime/tokio/
runtime.rs

1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "external")]
8use crate::Pacer;
9#[cfg(feature = "iouring-network")]
10use crate::{
11    iouring,
12    network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
13};
14use crate::{
15    network::metered::Network as MeteredNetwork,
16    process::metered::Metrics as MeteredProcess,
17    signal::Signal,
18    storage::metered::Storage as MeteredStorage,
19    telemetry::metrics::task::Label,
20    utils::{add_attribute, signal::Stopper, supervision::Tree, MetricEncoder, Panicker},
21    BufferPool, BufferPoolConfig, Clock, Error, Execution, Handle, Metrics as _, SinkOf,
22    Spawner as _, StreamOf, METRICS_PREFIX,
23};
24use commonware_macros::{select, stability};
25#[stability(BETA)]
26use commonware_parallel::ThreadPool;
27use futures::{future::BoxFuture, FutureExt};
28use governor::clock::{Clock as GClock, ReasonablyRealtime};
29use prometheus_client::{
30    encoding::text::encode,
31    metrics::{counter::Counter, family::Family, gauge::Gauge},
32    registry::{Metric, Registry},
33};
34use rand::{rngs::OsRng, CryptoRng, RngCore};
35#[stability(BETA)]
36use rayon::{ThreadPoolBuildError, ThreadPoolBuilder};
37use std::{
38    borrow::Cow,
39    env,
40    future::Future,
41    net::{IpAddr, SocketAddr},
42    num::NonZeroUsize,
43    path::PathBuf,
44    sync::{Arc, Mutex},
45    thread,
46    time::{Duration, SystemTime},
47};
48use tokio::runtime::{Builder, Runtime};
49use tracing::{info_span, Instrument};
50use tracing_opentelemetry::OpenTelemetrySpanExt;
51
/// Value passed as `size` to the io_uring network backend's [iouring::Config].
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1 << 10;
/// Value passed as `force_poll` to the io_uring network backend's [iouring::Config].
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Duration = Duration::from_millis(100);
56
57#[derive(Debug)]
58struct Metrics {
59    tasks_spawned: Family<Label, Counter>,
60    tasks_running: Family<Label, Gauge>,
61}
62
63impl Metrics {
64    pub fn init(registry: &mut Registry) -> Self {
65        let metrics = Self {
66            tasks_spawned: Family::default(),
67            tasks_running: Family::default(),
68        };
69        registry.register(
70            "tasks_spawned",
71            "Total number of tasks spawned",
72            metrics.tasks_spawned.clone(),
73        );
74        registry.register(
75            "tasks_running",
76            "Number of tasks currently running",
77            metrics.tasks_running.clone(),
78        );
79        metrics
80    }
81}
82
/// Network settings applied to sockets created by the runtime.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// When `Some`, TCP_NODELAY is explicitly set to this value on each socket;
    /// when `None`, the system default is left untouched.
    tcp_nodelay: Option<bool>,

    /// Timeout applied to network reads and writes.
    read_write_timeout: Duration,
}

impl Default for NetworkConfig {
    /// No explicit TCP_NODELAY setting and a 60 second read/write timeout.
    fn default() -> Self {
        const DEFAULT_READ_WRITE_TIMEOUT: Duration = Duration::from_secs(60);
        Self {
            read_write_timeout: DEFAULT_READ_WRITE_TIMEOUT,
            tcp_nodelay: None,
        }
    }
}
101
102/// Configuration for the `tokio` runtime.
103#[derive(Clone)]
104pub struct Config {
105    /// Number of threads to use for handling async tasks.
106    ///
107    /// Worker threads are always active (waiting for work).
108    ///
109    /// Tokio sets the default value to the number of logical CPUs.
110    worker_threads: usize,
111
112    /// Maximum number of threads to use for blocking tasks.
113    ///
114    /// Unlike worker threads, blocking threads are created as needed and
115    /// exit if left idle for too long.
116    ///
117    /// Tokio sets the default value to 512 to avoid hanging on lower-level
118    /// operations that require blocking (like `fs` and writing to `Stdout`).
119    max_blocking_threads: usize,
120
121    /// Whether or not to catch panics.
122    catch_panics: bool,
123
124    /// Base directory for all storage operations.
125    storage_directory: PathBuf,
126
127    /// Maximum buffer size for operations on blobs.
128    ///
129    /// Tokio sets the default value to 2MB.
130    maximum_buffer_size: usize,
131
132    /// Network configuration.
133    network_cfg: NetworkConfig,
134}
135
136impl Config {
137    /// Returns a new [Config] with default values.
138    pub fn new() -> Self {
139        let rng = OsRng.next_u64();
140        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
141        Self {
142            worker_threads: 2,
143            max_blocking_threads: 512,
144            catch_panics: false,
145            storage_directory,
146            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
147            network_cfg: NetworkConfig::default(),
148        }
149    }
150
151    // Setters
152    /// See [Config]
153    pub const fn with_worker_threads(mut self, n: usize) -> Self {
154        self.worker_threads = n;
155        self
156    }
157    /// See [Config]
158    pub const fn with_max_blocking_threads(mut self, n: usize) -> Self {
159        self.max_blocking_threads = n;
160        self
161    }
162    /// See [Config]
163    pub const fn with_catch_panics(mut self, b: bool) -> Self {
164        self.catch_panics = b;
165        self
166    }
167    /// See [Config]
168    pub const fn with_read_write_timeout(mut self, d: Duration) -> Self {
169        self.network_cfg.read_write_timeout = d;
170        self
171    }
172    /// See [Config]
173    pub const fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
174        self.network_cfg.tcp_nodelay = n;
175        self
176    }
177    /// See [Config]
178    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
179        self.storage_directory = p.into();
180        self
181    }
182    /// See [Config]
183    pub const fn with_maximum_buffer_size(mut self, n: usize) -> Self {
184        self.maximum_buffer_size = n;
185        self
186    }
187
188    // Getters
189    /// See [Config]
190    pub const fn worker_threads(&self) -> usize {
191        self.worker_threads
192    }
193    /// See [Config]
194    pub const fn max_blocking_threads(&self) -> usize {
195        self.max_blocking_threads
196    }
197    /// See [Config]
198    pub const fn catch_panics(&self) -> bool {
199        self.catch_panics
200    }
201    /// See [Config]
202    pub const fn read_write_timeout(&self) -> Duration {
203        self.network_cfg.read_write_timeout
204    }
205    /// See [Config]
206    pub const fn tcp_nodelay(&self) -> Option<bool> {
207        self.network_cfg.tcp_nodelay
208    }
209    /// See [Config]
210    pub const fn storage_directory(&self) -> &PathBuf {
211        &self.storage_directory
212    }
213    /// See [Config]
214    pub const fn maximum_buffer_size(&self) -> usize {
215        self.maximum_buffer_size
216    }
217}
218
219impl Default for Config {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225/// Runtime based on [Tokio](https://tokio.rs).
226pub struct Executor {
227    registry: Mutex<Registry>,
228    metrics: Arc<Metrics>,
229    runtime: Runtime,
230    shutdown: Mutex<Stopper>,
231    panicker: Panicker,
232}
233
234/// Implementation of [crate::Runner] for the `tokio` runtime.
235pub struct Runner {
236    cfg: Config,
237}
238
239impl Default for Runner {
240    fn default() -> Self {
241        Self::new(Config::default())
242    }
243}
244
245impl Runner {
246    /// Initialize a new `tokio` runtime with the given number of threads.
247    pub const fn new(cfg: Config) -> Self {
248        Self { cfg }
249    }
250}
251
252impl crate::Runner for Runner {
253    type Context = Context;
254
255    fn start<F, Fut>(self, f: F) -> Fut::Output
256    where
257        F: FnOnce(Self::Context) -> Fut,
258        Fut: Future,
259    {
260        // Create a new registry
261        let mut registry = Registry::default();
262        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
263
264        // Initialize runtime
265        let metrics = Arc::new(Metrics::init(runtime_registry));
266        let runtime = Builder::new_multi_thread()
267            .worker_threads(self.cfg.worker_threads)
268            .max_blocking_threads(self.cfg.max_blocking_threads)
269            .enable_all()
270            .build()
271            .expect("failed to create Tokio runtime");
272
273        // Initialize panicker
274        let (panicker, panicked) = Panicker::new(self.cfg.catch_panics);
275
276        // Collect process metrics.
277        //
278        // We prefer to collect process metrics outside of `Context` because
279        // we are using `runtime_registry` rather than the one provided by `Context`.
280        let process = MeteredProcess::init(runtime_registry);
281        runtime.spawn(process.collect(tokio::time::sleep));
282
283        // Initialize storage
284        cfg_if::cfg_if! {
285            if #[cfg(feature = "iouring-storage")] {
286                let iouring_registry =
287                    runtime_registry.sub_registry_with_prefix("iouring_storage");
288                let storage = MeteredStorage::new(
289                    IoUringStorage::start(
290                        IoUringConfig {
291                            storage_directory: self.cfg.storage_directory.clone(),
292                            iouring_config: Default::default(),
293                        },
294                        iouring_registry,
295                    ),
296                    runtime_registry,
297                );
298            } else {
299                let storage = MeteredStorage::new(
300                    TokioStorage::new(TokioStorageConfig::new(
301                        self.cfg.storage_directory.clone(),
302                        self.cfg.maximum_buffer_size,
303                    )),
304                    runtime_registry,
305                );
306            }
307        }
308
309        // Initialize buffer pools
310        let network_buffer_pool = BufferPool::new(
311            BufferPoolConfig::for_network(),
312            runtime_registry.sub_registry_with_prefix("network_buffer_pool"),
313        );
314        let storage_buffer_pool = BufferPool::new(
315            BufferPoolConfig::for_storage(),
316            runtime_registry.sub_registry_with_prefix("storage_buffer_pool"),
317        );
318
319        // Initialize network
320        cfg_if::cfg_if! {
321            if #[cfg(feature = "iouring-network")] {
322                let iouring_registry =
323                    runtime_registry.sub_registry_with_prefix("iouring_network");
324                let config = IoUringNetworkConfig {
325                    tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
326                    iouring_config: iouring::Config {
327                        // TODO (#1045): make `IOURING_NETWORK_SIZE` configurable
328                        size: IOURING_NETWORK_SIZE,
329                        op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
330                        force_poll: IOURING_NETWORK_FORCE_POLL,
331                        shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
332                        ..Default::default()
333                    },
334                    ..Default::default()
335                };
336                let network = MeteredNetwork::new(
337                    IoUringNetwork::start(
338                        config,
339                        iouring_registry,
340                        network_buffer_pool.clone(),
341                    )
342                    .unwrap(),
343                    runtime_registry,
344                );
345            } else {
346                let config = TokioNetworkConfig::default()
347                    .with_read_timeout(self.cfg.network_cfg.read_write_timeout)
348                    .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
349                    .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
350                let network = MeteredNetwork::new(
351                    TokioNetwork::new(config, network_buffer_pool.clone()),
352                    runtime_registry,
353                );
354            }
355        }
356
357        // Initialize executor
358        let executor = Arc::new(Executor {
359            registry: Mutex::new(registry),
360            metrics,
361            runtime,
362            shutdown: Mutex::new(Stopper::default()),
363            panicker,
364        });
365
366        // Get metrics
367        let label = Label::root();
368        executor.metrics.tasks_spawned.get_or_create(&label).inc();
369        let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
370
371        // Run the future
372        let context = Context {
373            storage,
374            name: label.name(),
375            attributes: Vec::new(),
376            executor: executor.clone(),
377            network,
378            network_buffer_pool,
379            storage_buffer_pool,
380            tree: Tree::root(),
381            execution: Execution::default(),
382            instrumented: false,
383        };
384        let output = executor.runtime.block_on(panicked.interrupt(f(context)));
385        gauge.dec();
386
387        output
388    }
389}
390
391cfg_if::cfg_if! {
392    if #[cfg(feature = "iouring-storage")] {
393        type Storage = MeteredStorage<IoUringStorage>;
394    } else {
395        type Storage = MeteredStorage<TokioStorage>;
396    }
397}
398
399cfg_if::cfg_if! {
400    if #[cfg(feature = "iouring-network")] {
401        type Network = MeteredNetwork<IoUringNetwork>;
402    } else {
403        type Network = MeteredNetwork<TokioNetwork>;
404    }
405}
406
407/// Implementation of [crate::Spawner], [crate::Clock],
408/// [crate::Network], and [crate::Storage] for the `tokio`
409/// runtime.
410pub struct Context {
411    name: String,
412    attributes: Vec<(String, String)>,
413    executor: Arc<Executor>,
414    storage: Storage,
415    network: Network,
416    network_buffer_pool: BufferPool,
417    storage_buffer_pool: BufferPool,
418    tree: Arc<Tree>,
419    execution: Execution,
420    instrumented: bool,
421}
422
423impl Clone for Context {
424    fn clone(&self) -> Self {
425        let (child, _) = Tree::child(&self.tree);
426        Self {
427            name: self.name.clone(),
428            attributes: self.attributes.clone(),
429            executor: self.executor.clone(),
430            storage: self.storage.clone(),
431            network: self.network.clone(),
432            network_buffer_pool: self.network_buffer_pool.clone(),
433            storage_buffer_pool: self.storage_buffer_pool.clone(),
434            tree: child,
435            execution: Execution::default(),
436            instrumented: false,
437        }
438    }
439}
440
441impl Context {
442    /// Access the [Metrics] of the runtime.
443    fn metrics(&self) -> &Metrics {
444        &self.executor.metrics
445    }
446}
447
448impl crate::Spawner for Context {
449    fn dedicated(mut self) -> Self {
450        self.execution = Execution::Dedicated;
451        self
452    }
453
454    fn shared(mut self, blocking: bool) -> Self {
455        self.execution = Execution::Shared(blocking);
456        self
457    }
458
459    fn instrumented(mut self) -> Self {
460        self.instrumented = true;
461        self
462    }
463
464    fn spawn<F, Fut, T>(mut self, f: F) -> Handle<T>
465    where
466        F: FnOnce(Self) -> Fut + Send + 'static,
467        Fut: Future<Output = T> + Send + 'static,
468        T: Send + 'static,
469    {
470        // Get metrics
471        let (label, metric) = spawn_metrics!(self);
472
473        // Track supervision before resetting configuration
474        let parent = Arc::clone(&self.tree);
475        let past = self.execution;
476        let is_instrumented = self.instrumented;
477        self.execution = Execution::default();
478        self.instrumented = false;
479        let (child, aborted) = Tree::child(&parent);
480        if aborted {
481            return Handle::closed(metric);
482        }
483        self.tree = child;
484
485        // Spawn the task
486        let executor = self.executor.clone();
487        let future: BoxFuture<'_, T> = if is_instrumented {
488            let span = info_span!("task", name = %label.name());
489            for (key, value) in &self.attributes {
490                span.set_attribute(key.clone(), value.clone());
491            }
492            f(self).instrument(span).boxed()
493        } else {
494            f(self).boxed()
495        };
496        let (f, handle) = Handle::init(
497            future,
498            metric,
499            executor.panicker.clone(),
500            Arc::clone(&parent),
501        );
502
503        if matches!(past, Execution::Dedicated) {
504            thread::spawn({
505                // Ensure the task can access the tokio runtime
506                let handle = executor.runtime.handle().clone();
507                move || {
508                    handle.block_on(f);
509                }
510            });
511        } else if matches!(past, Execution::Shared(true)) {
512            executor.runtime.spawn_blocking({
513                // Ensure the task can access the tokio runtime
514                let handle = executor.runtime.handle().clone();
515                move || {
516                    handle.block_on(f);
517                }
518            });
519        } else {
520            executor.runtime.spawn(f);
521        }
522
523        // Register the task on the parent
524        if let Some(aborter) = handle.aborter() {
525            parent.register(aborter);
526        }
527
528        handle
529    }
530
531    async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
532        let stop_resolved = {
533            let mut shutdown = self.executor.shutdown.lock().unwrap();
534            shutdown.stop(value)
535        };
536
537        // Wait for all tasks to complete or the timeout to fire
538        let timeout_future = timeout.map_or_else(
539            || futures::future::Either::Right(futures::future::pending()),
540            |duration| futures::future::Either::Left(self.sleep(duration)),
541        );
542        select! {
543            result = stop_resolved => {
544                result.map_err(|_| Error::Closed)?;
545                Ok(())
546            },
547            _ = timeout_future => Err(Error::Timeout),
548        }
549    }
550
551    fn stopped(&self) -> Signal {
552        self.executor.shutdown.lock().unwrap().stopped()
553    }
554}
555
556#[stability(BETA)]
557impl crate::ThreadPooler for Context {
558    fn create_thread_pool(
559        &self,
560        concurrency: NonZeroUsize,
561    ) -> Result<ThreadPool, ThreadPoolBuildError> {
562        ThreadPoolBuilder::new()
563            .num_threads(concurrency.get())
564            .spawn_handler(move |thread| {
565                // Tasks spawned in a thread pool are expected to run longer than any single
566                // task and thus should be provisioned as a dedicated thread.
567                self.with_label("rayon_thread")
568                    .dedicated()
569                    .spawn(move |_| async move { thread.run() });
570                Ok(())
571            })
572            .build()
573            .map(Arc::new)
574    }
575}
576
577impl crate::Metrics for Context {
578    fn with_label(&self, label: &str) -> Self {
579        // Construct the full label name
580        let name = {
581            let prefix = self.name.clone();
582            if prefix.is_empty() {
583                label.to_string()
584            } else {
585                format!("{prefix}_{label}")
586            }
587        };
588        Self {
589            name,
590            ..self.clone()
591        }
592    }
593
594    fn label(&self) -> String {
595        self.name.clone()
596    }
597
598    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
599        let name = name.into();
600        let prefixed_name = {
601            let prefix = &self.name;
602            if prefix.is_empty() {
603                name
604            } else {
605                format!("{}_{}", *prefix, name)
606            }
607        };
608
609        // Apply attributes to the registry (in sorted order)
610        let mut registry = self.executor.registry.lock().unwrap();
611        let sub_registry = self.attributes.iter().fold(&mut *registry, |reg, (k, v)| {
612            reg.sub_registry_with_label((Cow::Owned(k.clone()), Cow::Owned(v.clone())))
613        });
614        sub_registry.register(prefixed_name, help, metric);
615    }
616
617    fn encode(&self) -> String {
618        let mut encoder = MetricEncoder::new();
619        encode(&mut encoder, &self.executor.registry.lock().unwrap()).expect("encoding failed");
620        encoder.into_string()
621    }
622
623    fn with_attribute(&self, key: &str, value: impl std::fmt::Display) -> Self {
624        // Add the attribute to the list of attributes
625        let mut attributes = self.attributes.clone();
626        add_attribute(&mut attributes, key, value);
627        Self {
628            attributes,
629            ..self.clone()
630        }
631    }
632}
633
634impl Clock for Context {
635    fn current(&self) -> SystemTime {
636        SystemTime::now()
637    }
638
639    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
640        tokio::time::sleep(duration)
641    }
642
643    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
644        let now = SystemTime::now();
645        let duration_until_deadline = deadline.duration_since(now).unwrap_or_else(|_| {
646            // Deadline is in the past
647            Duration::from_secs(0)
648        });
649        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
650        tokio::time::sleep_until(target_instant)
651    }
652}
653
#[cfg(feature = "external")]
impl Pacer for Context {
    /// Returns `future` as-is: this runtime applies no pacing, so the requested
    /// latency is ignored.
    fn pace<'a, F, T>(
        &'a self,
        _latency: Duration,
        future: F,
    ) -> impl Future<Output = T> + Send + 'a
    where
        F: Future<Output = T> + Send + 'a,
        T: Send + 'a,
    {
        future
    }
}
669
670impl GClock for Context {
671    type Instant = SystemTime;
672
673    fn now(&self) -> Self::Instant {
674        self.current()
675    }
676}
677
678impl ReasonablyRealtime for Context {}
679
680impl crate::Network for Context {
681    type Listener = <Network as crate::Network>::Listener;
682
683    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
684        self.network.bind(socket).await
685    }
686
687    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
688        self.network.dial(socket).await
689    }
690}
691
692impl crate::Resolver for Context {
693    async fn resolve(&self, host: &str) -> Result<Vec<IpAddr>, Error> {
694        // Uses the host's DNS configuration (e.g. /etc/resolv.conf on Unix,
695        // registry on Windows). This delegates to the system's libc resolver.
696        //
697        // The `:0` port is required by lookup_host's API but is not used
698        // for DNS resolution.
699        let addrs = tokio::net::lookup_host(format!("{host}:0"))
700            .await
701            .map_err(|e| Error::ResolveFailed(e.to_string()))?;
702        Ok(addrs.map(|addr| addr.ip()).collect())
703    }
704}
705
706impl RngCore for Context {
707    fn next_u32(&mut self) -> u32 {
708        OsRng.next_u32()
709    }
710
711    fn next_u64(&mut self) -> u64 {
712        OsRng.next_u64()
713    }
714
715    fn fill_bytes(&mut self, dest: &mut [u8]) {
716        OsRng.fill_bytes(dest);
717    }
718
719    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
720        OsRng.try_fill_bytes(dest)
721    }
722}
723
724impl CryptoRng for Context {}
725
726impl crate::Storage for Context {
727    type Blob = <Storage as crate::Storage>::Blob;
728
729    async fn open_versioned(
730        &self,
731        partition: &str,
732        name: &[u8],
733        versions: std::ops::RangeInclusive<u16>,
734    ) -> Result<(Self::Blob, u64, u16), Error> {
735        self.storage.open_versioned(partition, name, versions).await
736    }
737
738    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
739        self.storage.remove(partition, name).await
740    }
741
742    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
743        self.storage.scan(partition).await
744    }
745}
746
747impl crate::BufferPooler for Context {
748    fn network_buffer_pool(&self) -> &BufferPool {
749        &self.network_buffer_pool
750    }
751
752    fn storage_buffer_pool(&self) -> &BufferPool {
753        &self.storage_buffer_pool
754    }
755}