commonware_runtime/tokio/
runtime.rs

1#[cfg(feature = "iouring-storage")]
2use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
3
4#[cfg(feature = "iouring-network")]
5use crate::{
6    iouring,
7    network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
8};
9
10#[cfg(not(feature = "iouring-network"))]
11use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
12
13#[cfg(not(feature = "iouring-storage"))]
14use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
15
16use crate::network::metered::Network as MeteredNetwork;
17use crate::storage::metered::Storage as MeteredStorage;
18use crate::telemetry::metrics::task::Label;
19use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
20use crate::{SinkOf, StreamOf};
21use governor::clock::{Clock as GClock, ReasonablyRealtime};
22use prometheus_client::{
23    encoding::text::encode,
24    metrics::{counter::Counter, family::Family, gauge::Gauge},
25    registry::{Metric, Registry},
26};
27use rand::{rngs::OsRng, CryptoRng, RngCore};
28use std::{
29    env,
30    future::Future,
31    net::SocketAddr,
32    path::PathBuf,
33    sync::{Arc, Mutex},
34    time::{Duration, SystemTime},
35};
36use tokio::runtime::{Builder, Runtime};
37
/// Value used for `iouring::Config::force_poll` when [Runner::start] builds the
/// io_uring network backend (only compiled with the `iouring-network` feature).
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Option<Duration> = Some(Duration::from_millis(100));
40
41#[derive(Debug)]
42struct Metrics {
43    tasks_spawned: Family<Label, Counter>,
44    tasks_running: Family<Label, Gauge>,
45}
46
47impl Metrics {
48    pub fn init(registry: &mut Registry) -> Self {
49        let metrics = Self {
50            tasks_spawned: Family::default(),
51            tasks_running: Family::default(),
52        };
53        registry.register(
54            "tasks_spawned",
55            "Total number of tasks spawned",
56            metrics.tasks_spawned.clone(),
57        );
58        registry.register(
59            "tasks_running",
60            "Number of tasks currently running",
61            metrics.tasks_running.clone(),
62        );
63        metrics
64    }
65}
66
/// Network settings applied by the `tokio` runtime to whichever network
/// backend it is built with.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// When `Some`, TCP_NODELAY is set explicitly on the socket; when `None`,
    /// the system default is left in place.
    tcp_nodelay: Option<bool>,

    /// Timeout applied to network reads and writes.
    read_write_timeout: Duration,
}

impl Default for NetworkConfig {
    /// No TCP_NODELAY override and a 60-second read/write timeout.
    fn default() -> Self {
        let read_write_timeout = Duration::from_secs(60);
        Self {
            tcp_nodelay: None,
            read_write_timeout,
        }
    }
}
85
86/// Configuration for the `tokio` runtime.
87#[derive(Clone)]
88pub struct Config {
89    /// Number of threads to use for handling async tasks.
90    ///
91    /// Worker threads are always active (waiting for work).
92    ///
93    /// Tokio sets the default value to the number of logical CPUs.
94    worker_threads: usize,
95
96    /// Maximum number of threads to use for blocking tasks.
97    ///
98    /// Unlike worker threads, blocking threads are created as needed and
99    /// exit if left idle for too long.
100    ///
101    /// Tokio sets the default value to 512 to avoid hanging on lower-level
102    /// operations that require blocking (like `fs` and writing to `Stdout`).
103    max_blocking_threads: usize,
104
105    /// Whether or not to catch panics.
106    catch_panics: bool,
107
108    /// Base directory for all storage operations.
109    storage_directory: PathBuf,
110
111    /// Maximum buffer size for operations on blobs.
112    ///
113    /// Tokio sets the default value to 2MB.
114    maximum_buffer_size: usize,
115
116    /// Network configuration.
117    network_cfg: NetworkConfig,
118}
119
120impl Config {
121    /// Returns a new [Config] with default values.
122    pub fn new() -> Self {
123        let rng = OsRng.next_u64();
124        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
125        Self {
126            worker_threads: 2,
127            max_blocking_threads: 512,
128            catch_panics: true,
129            storage_directory,
130            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
131            network_cfg: NetworkConfig::default(),
132        }
133    }
134
135    // Setters
136    /// See [Config]
137    pub fn with_worker_threads(mut self, n: usize) -> Self {
138        self.worker_threads = n;
139        self
140    }
141    /// See [Config]
142    pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
143        self.max_blocking_threads = n;
144        self
145    }
146    /// See [Config]
147    pub fn with_catch_panics(mut self, b: bool) -> Self {
148        self.catch_panics = b;
149        self
150    }
151    /// See [Config]
152    pub fn with_read_write_timeout(mut self, d: Duration) -> Self {
153        self.network_cfg.read_write_timeout = d;
154        self
155    }
156    /// See [Config]
157    pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
158        self.network_cfg.tcp_nodelay = n;
159        self
160    }
161    /// See [Config]
162    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
163        self.storage_directory = p.into();
164        self
165    }
166    /// See [Config]
167    pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
168        self.maximum_buffer_size = n;
169        self
170    }
171
172    // Getters
173    /// See [Config]
174    pub fn worker_threads(&self) -> usize {
175        self.worker_threads
176    }
177    /// See [Config]
178    pub fn max_blocking_threads(&self) -> usize {
179        self.max_blocking_threads
180    }
181    /// See [Config]
182    pub fn catch_panics(&self) -> bool {
183        self.catch_panics
184    }
185    /// See [Config]
186    pub fn read_write_timeout(&self) -> Duration {
187        self.network_cfg.read_write_timeout
188    }
189    /// See [Config]
190    pub fn tcp_nodelay(&self) -> Option<bool> {
191        self.network_cfg.tcp_nodelay
192    }
193    /// See [Config]
194    pub fn storage_directory(&self) -> &PathBuf {
195        &self.storage_directory
196    }
197    /// See [Config]
198    pub fn maximum_buffer_size(&self) -> usize {
199        self.maximum_buffer_size
200    }
201}
202
203impl Default for Config {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
/// Runtime based on [Tokio](https://tokio.rs).
///
/// Created once per [crate::Runner::start] call and shared (behind an `Arc`)
/// by every [Context] derived from that call.
pub struct Executor {
    /// Configuration the runtime was started with.
    cfg: Config,
    /// Registry that metrics registered through any [Context] are added to.
    registry: Mutex<Registry>,
    /// Task spawn/running metrics.
    metrics: Arc<Metrics>,
    /// Multi-threaded Tokio runtime that tasks are spawned onto.
    runtime: Runtime,
    /// Used by `stop` to signal shutdown with a value.
    signaler: Mutex<Signaler>,
    /// Handed out (cloned) by `stopped`.
    signal: Signal,
}
218
219/// Implementation of [crate::Runner] for the `tokio` runtime.
220pub struct Runner {
221    cfg: Config,
222}
223
224impl Default for Runner {
225    fn default() -> Self {
226        Self::new(Config::default())
227    }
228}
229
230impl Runner {
231    /// Initialize a new `tokio` runtime with the given number of threads.
232    pub fn new(cfg: Config) -> Self {
233        Self { cfg }
234    }
235}
236
237impl crate::Runner for Runner {
238    type Context = Context;
239
240    fn start<F, Fut>(self, f: F) -> Fut::Output
241    where
242        F: FnOnce(Self::Context) -> Fut,
243        Fut: Future,
244    {
245        // Create a new registry
246        let mut registry = Registry::default();
247        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
248
249        // Initialize runtime
250        let metrics = Arc::new(Metrics::init(runtime_registry));
251        let runtime = Builder::new_multi_thread()
252            .worker_threads(self.cfg.worker_threads)
253            .max_blocking_threads(self.cfg.max_blocking_threads)
254            .enable_all()
255            .build()
256            .expect("failed to create Tokio runtime");
257        let (signaler, signal) = Signaler::new();
258
259        cfg_if::cfg_if! {
260            if #[cfg(feature = "iouring-storage")] {
261                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
262                let storage = MeteredStorage::new(
263                    IoUringStorage::start(IoUringConfig {
264                        storage_directory: self.cfg.storage_directory.clone(),
265                        ring_config: Default::default(),
266                    }, iouring_registry),
267                    runtime_registry,
268                );
269            } else {
270                let storage = MeteredStorage::new(
271                    TokioStorage::new(TokioStorageConfig::new(
272                        self.cfg.storage_directory.clone(),
273                        self.cfg.maximum_buffer_size,
274                    )),
275                    runtime_registry,
276                );
277            }
278        }
279
280        cfg_if::cfg_if! {
281            if #[cfg(feature = "iouring-network")] {
282                let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
283                let config = IoUringNetworkConfig {
284                    tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
285                    iouring_config: iouring::Config {
286                        op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
287                        force_poll: IOURING_NETWORK_FORCE_POLL,
288                        shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
289                        ..Default::default()
290                    },
291                };
292                let network = MeteredNetwork::new(
293                    IoUringNetwork::start(config, iouring_registry).unwrap(),
294                runtime_registry,
295            );
296        } else {
297            let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
298                .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
299                .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
300                let network = MeteredNetwork::new(
301                    TokioNetwork::from(config),
302                    runtime_registry,
303                );
304            }
305        }
306
307        // Initialize executor
308        let executor = Arc::new(Executor {
309            cfg: self.cfg,
310            registry: Mutex::new(registry),
311            metrics,
312            runtime,
313            signaler: Mutex::new(signaler),
314            signal,
315        });
316
317        // Get metrics
318        let label = Label::root();
319        executor.metrics.tasks_spawned.get_or_create(&label).inc();
320        let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
321
322        // Run the future
323        let context = Context {
324            storage,
325            name: label.name(),
326            spawned: false,
327            executor: executor.clone(),
328            network,
329        };
330        let output = executor.runtime.block_on(f(context));
331        gauge.dec();
332
333        output
334    }
335}
336
// Concrete storage backend held by every [Context]: the io_uring backend when
// the `iouring-storage` feature is enabled, the Tokio backend otherwise; both
// are wrapped in `MeteredStorage`.
cfg_if::cfg_if! {
    if #[cfg(feature = "iouring-storage")] {
        type Storage = MeteredStorage<IoUringStorage>;
    } else {
        type Storage = MeteredStorage<TokioStorage>;
    }
}
344
// Concrete network backend held by every [Context]: the io_uring backend when
// the `iouring-network` feature is enabled, the Tokio backend otherwise; both
// are wrapped in `MeteredNetwork`.
cfg_if::cfg_if! {
    if #[cfg(feature = "iouring-network")] {
        type Network = MeteredNetwork<IoUringNetwork>;
    } else {
        type Network = MeteredNetwork<TokioNetwork>;
    }
}
352
/// Implementation of [crate::Spawner], [crate::Clock],
/// [crate::Network], and [crate::Storage] for the `tokio`
/// runtime.
///
/// Also implements [crate::Metrics], and draws randomness from [OsRng].
pub struct Context {
    /// Label of this context (extended by `with_label`); metrics registered
    /// through this context are prefixed with it.
    name: String,
    /// Whether this context has already been used to spawn a task. Every spawn
    /// method asserts this is `false`; the `*_ref` variants set it to `true`.
    spawned: bool,
    /// Executor state shared by all contexts derived from the same runtime.
    executor: Arc<Executor>,
    /// Metered storage backend selected at compile time.
    storage: Storage,
    /// Metered network backend selected at compile time.
    network: Network,
}
363
364impl Clone for Context {
365    fn clone(&self) -> Self {
366        Self {
367            name: self.name.clone(),
368            spawned: false,
369            executor: self.executor.clone(),
370            storage: self.storage.clone(),
371            network: self.network.clone(),
372        }
373    }
374}
375
376impl crate::Spawner for Context {
377    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
378    where
379        F: FnOnce(Self) -> Fut + Send + 'static,
380        Fut: Future<Output = T> + Send + 'static,
381        T: Send + 'static,
382    {
383        // Ensure a context only spawns one task
384        assert!(!self.spawned, "already spawned");
385
386        // Get metrics
387        let (_, gauge) = spawn_metrics!(self, future);
388
389        // Set up the task
390        let catch_panics = self.executor.cfg.catch_panics;
391        let executor = self.executor.clone();
392        let future = f(self);
393        let (f, handle) = Handle::init_future(future, gauge, catch_panics);
394
395        // Spawn the task
396        executor.runtime.spawn(f);
397        handle
398    }
399
400    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
401    where
402        F: Future<Output = T> + Send + 'static,
403        T: Send + 'static,
404    {
405        // Ensure a context only spawns one task
406        assert!(!self.spawned, "already spawned");
407        self.spawned = true;
408
409        // Get metrics
410        let (_, gauge) = spawn_metrics!(self, future);
411
412        // Set up the task
413        let executor = self.executor.clone();
414        move |f: F| {
415            let (f, handle) = Handle::init_future(f, gauge, executor.cfg.catch_panics);
416
417            // Spawn the task
418            executor.runtime.spawn(f);
419            handle
420        }
421    }
422
423    fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
424    where
425        F: FnOnce(Self) -> T + Send + 'static,
426        T: Send + 'static,
427    {
428        // Ensure a context only spawns one task
429        assert!(!self.spawned, "already spawned");
430
431        // Get metrics
432        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
433
434        // Set up the task
435        let executor = self.executor.clone();
436        let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);
437
438        // Spawn the blocking task
439        if dedicated {
440            std::thread::spawn(f);
441        } else {
442            executor.runtime.spawn_blocking(f);
443        }
444        handle
445    }
446
447    fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
448    where
449        F: FnOnce() -> T + Send + 'static,
450        T: Send + 'static,
451    {
452        // Ensure a context only spawns one task
453        assert!(!self.spawned, "already spawned");
454        self.spawned = true;
455
456        // Get metrics
457        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
458
459        // Set up the task
460        let executor = self.executor.clone();
461        move |f: F| {
462            let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);
463
464            // Spawn the blocking task
465            if dedicated {
466                std::thread::spawn(f);
467            } else {
468                executor.runtime.spawn_blocking(f);
469            }
470            handle
471        }
472    }
473
474    fn stop(&self, value: i32) {
475        self.executor.signaler.lock().unwrap().signal(value);
476    }
477
478    fn stopped(&self) -> Signal {
479        self.executor.signal.clone()
480    }
481}
482
483impl crate::Metrics for Context {
484    fn with_label(&self, label: &str) -> Self {
485        let name = {
486            let prefix = self.name.clone();
487            if prefix.is_empty() {
488                label.to_string()
489            } else {
490                format!("{}_{}", prefix, label)
491            }
492        };
493        assert!(
494            !name.starts_with(METRICS_PREFIX),
495            "using runtime label is not allowed"
496        );
497        Self {
498            name,
499            spawned: false,
500            executor: self.executor.clone(),
501            storage: self.storage.clone(),
502            network: self.network.clone(),
503        }
504    }
505
506    fn label(&self) -> String {
507        self.name.clone()
508    }
509
510    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
511        let name = name.into();
512        let prefixed_name = {
513            let prefix = &self.name;
514            if prefix.is_empty() {
515                name
516            } else {
517                format!("{}_{}", *prefix, name)
518            }
519        };
520        self.executor
521            .registry
522            .lock()
523            .unwrap()
524            .register(prefixed_name, help, metric)
525    }
526
527    fn encode(&self) -> String {
528        let mut buffer = String::new();
529        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
530        buffer
531    }
532}
533
534impl Clock for Context {
535    fn current(&self) -> SystemTime {
536        SystemTime::now()
537    }
538
539    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
540        tokio::time::sleep(duration)
541    }
542
543    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
544        let now = SystemTime::now();
545        let duration_until_deadline = match deadline.duration_since(now) {
546            Ok(duration) => duration,
547            Err(_) => Duration::from_secs(0), // Deadline is in the past
548        };
549        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
550        tokio::time::sleep_until(target_instant)
551    }
552}
553
554impl GClock for Context {
555    type Instant = SystemTime;
556
557    fn now(&self) -> Self::Instant {
558        self.current()
559    }
560}
561
562impl ReasonablyRealtime for Context {}
563
564impl crate::Network for Context {
565    type Listener = <Network as crate::Network>::Listener;
566
567    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
568        self.network.bind(socket).await
569    }
570
571    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
572        self.network.dial(socket).await
573    }
574}
575
576impl RngCore for Context {
577    fn next_u32(&mut self) -> u32 {
578        OsRng.next_u32()
579    }
580
581    fn next_u64(&mut self) -> u64 {
582        OsRng.next_u64()
583    }
584
585    fn fill_bytes(&mut self, dest: &mut [u8]) {
586        OsRng.fill_bytes(dest);
587    }
588
589    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
590        OsRng.try_fill_bytes(dest)
591    }
592}
593
594impl CryptoRng for Context {}
595
596impl crate::Storage for Context {
597    type Blob = <Storage as crate::Storage>::Blob;
598
599    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
600        self.storage.open(partition, name).await
601    }
602
603    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
604        self.storage.remove(partition, name).await
605    }
606
607    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
608        self.storage.scan(partition).await
609    }
610}