commonware_runtime/tokio/
runtime.rs

1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "external")]
8use crate::Pacer;
9#[cfg(feature = "iouring-network")]
10use crate::{
11    iouring,
12    network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
13};
14use crate::{
15    network::metered::Network as MeteredNetwork,
16    process::metered::Metrics as MeteredProcess,
17    signal::Signal,
18    storage::metered::Storage as MeteredStorage,
19    telemetry::metrics::task::Label,
20    utils::{signal::Stopper, supervision::Tree, Panicker},
21    validate_label, Clock, Error, Execution, Handle, Metrics as _, SinkOf, Spawner as _, StreamOf,
22    METRICS_PREFIX,
23};
24use commonware_macros::select;
25use commonware_parallel::ThreadPool;
26use futures::{future::BoxFuture, FutureExt};
27use governor::clock::{Clock as GClock, ReasonablyRealtime};
28use prometheus_client::{
29    encoding::text::encode,
30    metrics::{counter::Counter, family::Family, gauge::Gauge},
31    registry::{Metric, Registry},
32};
33use rand::{rngs::OsRng, CryptoRng, RngCore};
34use rayon::{ThreadPoolBuildError, ThreadPoolBuilder};
35use std::{
36    env,
37    future::Future,
38    net::{IpAddr, SocketAddr},
39    num::NonZeroUsize,
40    path::PathBuf,
41    sync::{Arc, Mutex},
42    thread,
43    time::{Duration, SystemTime},
44};
45use tokio::runtime::{Builder, Runtime};
46use tracing::{info_span, Instrument};
47
/// Value supplied as the `size` of the `iouring::Config` used for the network
/// when the `iouring-network` feature is enabled.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1024;

/// Value supplied as the `force_poll` of the `iouring::Config` used for the
/// network when the `iouring-network` feature is enabled.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Duration = Duration::from_millis(100);
52
53#[derive(Debug)]
54struct Metrics {
55    tasks_spawned: Family<Label, Counter>,
56    tasks_running: Family<Label, Gauge>,
57}
58
59impl Metrics {
60    pub fn init(registry: &mut Registry) -> Self {
61        let metrics = Self {
62            tasks_spawned: Family::default(),
63            tasks_running: Family::default(),
64        };
65        registry.register(
66            "tasks_spawned",
67            "Total number of tasks spawned",
68            metrics.tasks_spawned.clone(),
69        );
70        registry.register(
71            "tasks_running",
72            "Number of tasks currently running",
73            metrics.tasks_running.clone(),
74        );
75        metrics
76    }
77}
78
/// Network-related settings of the `tokio` runtime.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// When `Some`, `TCP_NODELAY` is explicitly set to the given value on the
    /// socket; when `None`, the system default is left in place.
    tcp_nodelay: Option<bool>,

    /// Timeout applied to both reads and writes on the network.
    read_write_timeout: Duration,
}
88
89impl Default for NetworkConfig {
90    fn default() -> Self {
91        Self {
92            tcp_nodelay: None,
93            read_write_timeout: Duration::from_secs(60),
94        }
95    }
96}
97
98/// Configuration for the `tokio` runtime.
99#[derive(Clone)]
100pub struct Config {
101    /// Number of threads to use for handling async tasks.
102    ///
103    /// Worker threads are always active (waiting for work).
104    ///
105    /// Tokio sets the default value to the number of logical CPUs.
106    worker_threads: usize,
107
108    /// Maximum number of threads to use for blocking tasks.
109    ///
110    /// Unlike worker threads, blocking threads are created as needed and
111    /// exit if left idle for too long.
112    ///
113    /// Tokio sets the default value to 512 to avoid hanging on lower-level
114    /// operations that require blocking (like `fs` and writing to `Stdout`).
115    max_blocking_threads: usize,
116
117    /// Whether or not to catch panics.
118    catch_panics: bool,
119
120    /// Base directory for all storage operations.
121    storage_directory: PathBuf,
122
123    /// Maximum buffer size for operations on blobs.
124    ///
125    /// Tokio sets the default value to 2MB.
126    maximum_buffer_size: usize,
127
128    /// Network configuration.
129    network_cfg: NetworkConfig,
130}
131
132impl Config {
133    /// Returns a new [Config] with default values.
134    pub fn new() -> Self {
135        let rng = OsRng.next_u64();
136        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
137        Self {
138            worker_threads: 2,
139            max_blocking_threads: 512,
140            catch_panics: false,
141            storage_directory,
142            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
143            network_cfg: NetworkConfig::default(),
144        }
145    }
146
147    // Setters
148    /// See [Config]
149    pub const fn with_worker_threads(mut self, n: usize) -> Self {
150        self.worker_threads = n;
151        self
152    }
153    /// See [Config]
154    pub const fn with_max_blocking_threads(mut self, n: usize) -> Self {
155        self.max_blocking_threads = n;
156        self
157    }
158    /// See [Config]
159    pub const fn with_catch_panics(mut self, b: bool) -> Self {
160        self.catch_panics = b;
161        self
162    }
163    /// See [Config]
164    pub const fn with_read_write_timeout(mut self, d: Duration) -> Self {
165        self.network_cfg.read_write_timeout = d;
166        self
167    }
168    /// See [Config]
169    pub const fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
170        self.network_cfg.tcp_nodelay = n;
171        self
172    }
173    /// See [Config]
174    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
175        self.storage_directory = p.into();
176        self
177    }
178    /// See [Config]
179    pub const fn with_maximum_buffer_size(mut self, n: usize) -> Self {
180        self.maximum_buffer_size = n;
181        self
182    }
183
184    // Getters
185    /// See [Config]
186    pub const fn worker_threads(&self) -> usize {
187        self.worker_threads
188    }
189    /// See [Config]
190    pub const fn max_blocking_threads(&self) -> usize {
191        self.max_blocking_threads
192    }
193    /// See [Config]
194    pub const fn catch_panics(&self) -> bool {
195        self.catch_panics
196    }
197    /// See [Config]
198    pub const fn read_write_timeout(&self) -> Duration {
199        self.network_cfg.read_write_timeout
200    }
201    /// See [Config]
202    pub const fn tcp_nodelay(&self) -> Option<bool> {
203        self.network_cfg.tcp_nodelay
204    }
205    /// See [Config]
206    pub const fn storage_directory(&self) -> &PathBuf {
207        &self.storage_directory
208    }
209    /// See [Config]
210    pub const fn maximum_buffer_size(&self) -> usize {
211        self.maximum_buffer_size
212    }
213}
214
215impl Default for Config {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221/// Runtime based on [Tokio](https://tokio.rs).
222pub struct Executor {
223    registry: Mutex<Registry>,
224    metrics: Arc<Metrics>,
225    runtime: Runtime,
226    shutdown: Mutex<Stopper>,
227    panicker: Panicker,
228}
229
230/// Implementation of [crate::Runner] for the `tokio` runtime.
231pub struct Runner {
232    cfg: Config,
233}
234
235impl Default for Runner {
236    fn default() -> Self {
237        Self::new(Config::default())
238    }
239}
240
241impl Runner {
242    /// Initialize a new `tokio` runtime with the given number of threads.
243    pub const fn new(cfg: Config) -> Self {
244        Self { cfg }
245    }
246}
247
248impl crate::Runner for Runner {
249    type Context = Context;
250
251    fn start<F, Fut>(self, f: F) -> Fut::Output
252    where
253        F: FnOnce(Self::Context) -> Fut,
254        Fut: Future,
255    {
256        // Create a new registry
257        let mut registry = Registry::default();
258        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
259
260        // Initialize runtime
261        let metrics = Arc::new(Metrics::init(runtime_registry));
262        let runtime = Builder::new_multi_thread()
263            .worker_threads(self.cfg.worker_threads)
264            .max_blocking_threads(self.cfg.max_blocking_threads)
265            .enable_all()
266            .build()
267            .expect("failed to create Tokio runtime");
268
269        // Initialize panicker
270        let (panicker, panicked) = Panicker::new(self.cfg.catch_panics);
271
272        // Collect process metrics.
273        //
274        // We prefer to collect process metrics outside of `Context` because
275        // we are using `runtime_registry` rather than the one provided by `Context`.
276        let process = MeteredProcess::init(runtime_registry);
277        runtime.spawn(process.collect(tokio::time::sleep));
278
279        // Initialize storage
280        cfg_if::cfg_if! {
281            if #[cfg(feature = "iouring-storage")] {
282                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
283                let storage = MeteredStorage::new(
284                    IoUringStorage::start(IoUringConfig {
285                        storage_directory: self.cfg.storage_directory.clone(),
286                        iouring_config: Default::default(),
287                    }, iouring_registry),
288                    runtime_registry,
289                );
290            } else {
291                let storage = MeteredStorage::new(
292                    TokioStorage::new(TokioStorageConfig::new(
293                        self.cfg.storage_directory.clone(),
294                        self.cfg.maximum_buffer_size,
295                    )),
296                    runtime_registry,
297                );
298            }
299        }
300
301        // Initialize network
302        cfg_if::cfg_if! {
303            if #[cfg(feature = "iouring-network")] {
304                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
305                let config = IoUringNetworkConfig {
306                    tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
307                    iouring_config: iouring::Config {
308                        // TODO (#1045): make `IOURING_NETWORK_SIZE` configurable
309                        size: IOURING_NETWORK_SIZE,
310                        op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
311                        force_poll: IOURING_NETWORK_FORCE_POLL,
312                        shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
313                        ..Default::default()
314                    },
315                    ..Default::default()
316                };
317                let network = MeteredNetwork::new(
318                    IoUringNetwork::start(config, iouring_registry).unwrap(),
319                runtime_registry,
320            );
321        } else {
322            let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
323                .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
324                .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
325                let network = MeteredNetwork::new(
326                    TokioNetwork::from(config),
327                    runtime_registry,
328                );
329            }
330        }
331
332        // Initialize executor
333        let executor = Arc::new(Executor {
334            registry: Mutex::new(registry),
335            metrics,
336            runtime,
337            shutdown: Mutex::new(Stopper::default()),
338            panicker,
339        });
340
341        // Get metrics
342        let label = Label::root();
343        executor.metrics.tasks_spawned.get_or_create(&label).inc();
344        let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
345
346        // Run the future
347        let context = Context {
348            storage,
349            name: label.name(),
350            executor: executor.clone(),
351            network,
352            tree: Tree::root(),
353            execution: Execution::default(),
354            instrumented: false,
355        };
356        let output = executor.runtime.block_on(panicked.interrupt(f(context)));
357        gauge.dec();
358
359        output
360    }
361}
362
363cfg_if::cfg_if! {
364    if #[cfg(feature = "iouring-storage")] {
365        type Storage = MeteredStorage<IoUringStorage>;
366    } else {
367        type Storage = MeteredStorage<TokioStorage>;
368    }
369}
370
371cfg_if::cfg_if! {
372    if #[cfg(feature = "iouring-network")] {
373        type Network = MeteredNetwork<IoUringNetwork>;
374    } else {
375        type Network = MeteredNetwork<TokioNetwork>;
376    }
377}
378
379/// Implementation of [crate::Spawner], [crate::Clock],
380/// [crate::Network], and [crate::Storage] for the `tokio`
381/// runtime.
382pub struct Context {
383    name: String,
384    executor: Arc<Executor>,
385    storage: Storage,
386    network: Network,
387    tree: Arc<Tree>,
388    execution: Execution,
389    instrumented: bool,
390}
391
392impl Clone for Context {
393    fn clone(&self) -> Self {
394        let (child, _) = Tree::child(&self.tree);
395        Self {
396            name: self.name.clone(),
397            executor: self.executor.clone(),
398            storage: self.storage.clone(),
399            network: self.network.clone(),
400
401            tree: child,
402            execution: Execution::default(),
403            instrumented: false,
404        }
405    }
406}
407
408impl Context {
409    /// Access the [Metrics] of the runtime.
410    fn metrics(&self) -> &Metrics {
411        &self.executor.metrics
412    }
413}
414
415impl crate::Spawner for Context {
416    fn dedicated(mut self) -> Self {
417        self.execution = Execution::Dedicated;
418        self
419    }
420
421    fn shared(mut self, blocking: bool) -> Self {
422        self.execution = Execution::Shared(blocking);
423        self
424    }
425
426    fn instrumented(mut self) -> Self {
427        self.instrumented = true;
428        self
429    }
430
431    fn spawn<F, Fut, T>(mut self, f: F) -> Handle<T>
432    where
433        F: FnOnce(Self) -> Fut + Send + 'static,
434        Fut: Future<Output = T> + Send + 'static,
435        T: Send + 'static,
436    {
437        // Get metrics
438        let (label, metric) = spawn_metrics!(self);
439
440        // Track supervision before resetting configuration
441        let parent = Arc::clone(&self.tree);
442        let past = self.execution;
443        let is_instrumented = self.instrumented;
444        self.execution = Execution::default();
445        self.instrumented = false;
446        let (child, aborted) = Tree::child(&parent);
447        if aborted {
448            return Handle::closed(metric);
449        }
450        self.tree = child;
451
452        // Spawn the task
453        let executor = self.executor.clone();
454        let future: BoxFuture<'_, T> = if is_instrumented {
455            f(self)
456                .instrument(info_span!("task", name = %label.name()))
457                .boxed()
458        } else {
459            f(self).boxed()
460        };
461        let (f, handle) = Handle::init(
462            future,
463            metric,
464            executor.panicker.clone(),
465            Arc::clone(&parent),
466        );
467
468        if matches!(past, Execution::Dedicated) {
469            thread::spawn({
470                // Ensure the task can access the tokio runtime
471                let handle = executor.runtime.handle().clone();
472                move || {
473                    handle.block_on(f);
474                }
475            });
476        } else if matches!(past, Execution::Shared(true)) {
477            executor.runtime.spawn_blocking({
478                // Ensure the task can access the tokio runtime
479                let handle = executor.runtime.handle().clone();
480                move || {
481                    handle.block_on(f);
482                }
483            });
484        } else {
485            executor.runtime.spawn(f);
486        }
487
488        // Register the task on the parent
489        if let Some(aborter) = handle.aborter() {
490            parent.register(aborter);
491        }
492
493        handle
494    }
495
496    async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
497        let stop_resolved = {
498            let mut shutdown = self.executor.shutdown.lock().unwrap();
499            shutdown.stop(value)
500        };
501
502        // Wait for all tasks to complete or the timeout to fire
503        let timeout_future = timeout.map_or_else(
504            || futures::future::Either::Right(futures::future::pending()),
505            |duration| futures::future::Either::Left(self.sleep(duration)),
506        );
507        select! {
508            result = stop_resolved => {
509                result.map_err(|_| Error::Closed)?;
510                Ok(())
511            },
512            _ = timeout_future => {
513                Err(Error::Timeout)
514            }
515        }
516    }
517
518    fn stopped(&self) -> Signal {
519        self.executor.shutdown.lock().unwrap().stopped()
520    }
521}
522
523impl crate::RayonPoolSpawner for Context {
524    fn create_pool(&self, concurrency: NonZeroUsize) -> Result<ThreadPool, ThreadPoolBuildError> {
525        ThreadPoolBuilder::new()
526            .num_threads(concurrency.get())
527            .spawn_handler(move |thread| {
528                // Tasks spawned in a thread pool are expected to run longer than any single
529                // task and thus should be provisioned as a dedicated thread.
530                self.with_label("rayon_thread")
531                    .dedicated()
532                    .spawn(move |_| async move { thread.run() });
533                Ok(())
534            })
535            .build()
536            .map(Arc::new)
537    }
538}
539
540impl crate::Metrics for Context {
541    fn with_label(&self, label: &str) -> Self {
542        // Ensure the label is well-formatted
543        validate_label(label);
544
545        // Construct the full label name
546        let name = {
547            let prefix = self.name.clone();
548            if prefix.is_empty() {
549                label.to_string()
550            } else {
551                format!("{prefix}_{label}")
552            }
553        };
554        assert!(
555            !name.starts_with(METRICS_PREFIX),
556            "using runtime label is not allowed"
557        );
558        Self {
559            name,
560            ..self.clone()
561        }
562    }
563
564    fn label(&self) -> String {
565        self.name.clone()
566    }
567
568    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
569        let name = name.into();
570        let prefixed_name = {
571            let prefix = &self.name;
572            if prefix.is_empty() {
573                name
574            } else {
575                format!("{}_{}", *prefix, name)
576            }
577        };
578        self.executor
579            .registry
580            .lock()
581            .unwrap()
582            .register(prefixed_name, help, metric)
583    }
584
585    fn encode(&self) -> String {
586        let mut buffer = String::new();
587        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
588        buffer
589    }
590}
591
592impl Clock for Context {
593    fn current(&self) -> SystemTime {
594        SystemTime::now()
595    }
596
597    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
598        tokio::time::sleep(duration)
599    }
600
601    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
602        let now = SystemTime::now();
603        let duration_until_deadline = deadline.duration_since(now).unwrap_or_else(|_| {
604            // Deadline is in the past
605            Duration::from_secs(0)
606        });
607        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
608        tokio::time::sleep_until(target_instant)
609    }
610}
611
#[cfg(feature = "external")]
impl Pacer for Context {
    /// Returns `future` unchanged: this runtime ignores the supplied latency
    /// and adds no pacing.
    fn pace<'a, F, T>(
        &'a self,
        _latency: Duration,
        future: F,
    ) -> impl Future<Output = T> + Send + 'a
    where
        F: Future<Output = T> + Send + 'a,
        T: Send + 'a,
    {
        // Nothing to pace against, so the future is handed straight back.
        future
    }
}
627
628impl GClock for Context {
629    type Instant = SystemTime;
630
631    fn now(&self) -> Self::Instant {
632        self.current()
633    }
634}
635
636impl ReasonablyRealtime for Context {}
637
638impl crate::Network for Context {
639    type Listener = <Network as crate::Network>::Listener;
640
641    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
642        self.network.bind(socket).await
643    }
644
645    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
646        self.network.dial(socket).await
647    }
648}
649
650impl crate::Resolver for Context {
651    async fn resolve(&self, host: &str) -> Result<Vec<IpAddr>, Error> {
652        // Uses the host's DNS configuration (e.g. /etc/resolv.conf on Unix,
653        // registry on Windows). This delegates to the system's libc resolver.
654        //
655        // The `:0` port is required by lookup_host's API but is not used
656        // for DNS resolution.
657        let addrs = tokio::net::lookup_host(format!("{host}:0"))
658            .await
659            .map_err(|e| Error::ResolveFailed(e.to_string()))?;
660        Ok(addrs.map(|addr| addr.ip()).collect())
661    }
662}
663
664impl RngCore for Context {
665    fn next_u32(&mut self) -> u32 {
666        OsRng.next_u32()
667    }
668
669    fn next_u64(&mut self) -> u64 {
670        OsRng.next_u64()
671    }
672
673    fn fill_bytes(&mut self, dest: &mut [u8]) {
674        OsRng.fill_bytes(dest);
675    }
676
677    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
678        OsRng.try_fill_bytes(dest)
679    }
680}
681
682impl CryptoRng for Context {}
683
684impl crate::Storage for Context {
685    type Blob = <Storage as crate::Storage>::Blob;
686
687    async fn open_versioned(
688        &self,
689        partition: &str,
690        name: &[u8],
691        versions: std::ops::RangeInclusive<u16>,
692    ) -> Result<(Self::Blob, u64, u16), Error> {
693        self.storage.open_versioned(partition, name, versions).await
694    }
695
696    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
697        self.storage.remove(partition, name).await
698    }
699
700    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
701        self.storage.scan(partition).await
702    }
703}