commonware_runtime/tokio/
runtime.rs

1#[cfg(feature = "iouring-storage")]
2use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
3
4#[cfg(feature = "iouring-network")]
5use crate::network::iouring::Network as IoUringNetwork;
6
7#[cfg(not(feature = "iouring-network"))]
8use crate::network::tokio::Network as TokioNetwork;
9
10#[cfg(not(feature = "iouring-storage"))]
11use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
12
13use crate::network::metered::Network as MeteredNetwork;
14use crate::network::tokio::Config as TokioNetworkConfig;
15use crate::storage::metered::Storage as MeteredStorage;
16use crate::telemetry::metrics::task::Label;
17use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
18use crate::{SinkOf, StreamOf};
19use governor::clock::{Clock as GClock, ReasonablyRealtime};
20use prometheus_client::{
21    encoding::text::encode,
22    metrics::{counter::Counter, family::Family, gauge::Gauge},
23    registry::{Metric, Registry},
24};
25use rand::{rngs::OsRng, CryptoRng, RngCore};
26use std::{
27    env,
28    future::Future,
29    net::SocketAddr,
30    path::PathBuf,
31    sync::{Arc, Mutex},
32    time::{Duration, SystemTime},
33};
34use tokio::runtime::{Builder, Runtime};
35
36#[derive(Debug)]
37struct Metrics {
38    tasks_spawned: Family<Label, Counter>,
39    tasks_running: Family<Label, Gauge>,
40}
41
42impl Metrics {
43    pub fn init(registry: &mut Registry) -> Self {
44        let metrics = Self {
45            tasks_spawned: Family::default(),
46            tasks_running: Family::default(),
47        };
48        registry.register(
49            "tasks_spawned",
50            "Total number of tasks spawned",
51            metrics.tasks_spawned.clone(),
52        );
53        registry.register(
54            "tasks_running",
55            "Number of tasks currently running",
56            metrics.tasks_running.clone(),
57        );
58        metrics
59    }
60}
61
62/// Configuration for the `tokio` runtime.
63#[derive(Clone)]
64pub struct Config {
65    /// Number of threads to use for handling async tasks.
66    ///
67    /// Worker threads are always active (waiting for work).
68    ///
69    /// Tokio sets the default value to the number of logical CPUs.
70    worker_threads: usize,
71
72    /// Maximum number of threads to use for blocking tasks.
73    ///
74    /// Unlike worker threads, blocking threads are created as needed and
75    /// exit if left idle for too long.
76    ///
77    /// Tokio sets the default value to 512 to avoid hanging on lower-level
78    /// operations that require blocking (like `fs` and writing to `Stdout`).
79    max_blocking_threads: usize,
80
81    /// Whether or not to catch panics.
82    catch_panics: bool,
83
84    /// Base directory for all storage operations.
85    storage_directory: PathBuf,
86
87    /// Maximum buffer size for operations on blobs.
88    ///
89    /// Tokio sets the default value to 2MB.
90    maximum_buffer_size: usize,
91
92    network_cfg: TokioNetworkConfig,
93}
94
95impl Config {
96    /// Returns a new [Config] with default values.
97    pub fn new() -> Self {
98        let rng = OsRng.next_u64();
99        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
100        Self {
101            worker_threads: 2,
102            max_blocking_threads: 512,
103            catch_panics: true,
104            storage_directory,
105            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
106            network_cfg: TokioNetworkConfig::default(),
107        }
108    }
109
110    // Setters
111    /// See [Config]
112    pub fn with_worker_threads(mut self, n: usize) -> Self {
113        self.worker_threads = n;
114        self
115    }
116    /// See [Config]
117    pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
118        self.max_blocking_threads = n;
119        self
120    }
121    /// See [Config]
122    pub fn with_catch_panics(mut self, b: bool) -> Self {
123        self.catch_panics = b;
124        self
125    }
126    /// See [Config]
127    pub fn with_read_timeout(mut self, d: Duration) -> Self {
128        self.network_cfg = self.network_cfg.with_read_timeout(d);
129        self
130    }
131    /// See [Config]
132    pub fn with_write_timeout(mut self, d: Duration) -> Self {
133        self.network_cfg = self.network_cfg.with_write_timeout(d);
134        self
135    }
136    /// See [Config]
137    pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
138        self.network_cfg = self.network_cfg.with_tcp_nodelay(n);
139        self
140    }
141    /// See [Config]
142    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
143        self.storage_directory = p.into();
144        self
145    }
146    /// See [Config]
147    pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
148        self.maximum_buffer_size = n;
149        self
150    }
151
152    // Getters
153    /// See [Config]
154    pub fn worker_threads(&self) -> usize {
155        self.worker_threads
156    }
157    /// See [Config]
158    pub fn max_blocking_threads(&self) -> usize {
159        self.max_blocking_threads
160    }
161    /// See [Config]
162    pub fn catch_panics(&self) -> bool {
163        self.catch_panics
164    }
165    /// See [Config]
166    pub fn read_timeout(&self) -> Duration {
167        self.network_cfg.read_timeout()
168    }
169    /// See [Config]
170    pub fn write_timeout(&self) -> Duration {
171        self.network_cfg.write_timeout()
172    }
173    /// See [Config]
174    pub fn tcp_nodelay(&self) -> Option<bool> {
175        self.network_cfg.tcp_nodelay()
176    }
177    /// See [Config]
178    pub fn storage_directory(&self) -> &PathBuf {
179        &self.storage_directory
180    }
181    /// See [Config]
182    pub fn maximum_buffer_size(&self) -> usize {
183        self.maximum_buffer_size
184    }
185}
186
187impl Default for Config {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
193/// Runtime based on [Tokio](https://tokio.rs).
194pub struct Executor {
195    cfg: Config,
196    registry: Mutex<Registry>,
197    metrics: Arc<Metrics>,
198    runtime: Runtime,
199    signaler: Mutex<Signaler>,
200    signal: Signal,
201}
202
203/// Implementation of [crate::Runner] for the `tokio` runtime.
204pub struct Runner {
205    cfg: Config,
206}
207
208impl Default for Runner {
209    fn default() -> Self {
210        Self::new(Config::default())
211    }
212}
213
214impl Runner {
215    /// Initialize a new `tokio` runtime with the given number of threads.
216    pub fn new(cfg: Config) -> Self {
217        Self { cfg }
218    }
219}
220
221impl crate::Runner for Runner {
222    type Context = Context;
223
224    fn start<F, Fut>(self, f: F) -> Fut::Output
225    where
226        F: FnOnce(Self::Context) -> Fut,
227        Fut: Future,
228    {
229        // Create a new registry
230        let mut registry = Registry::default();
231        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
232
233        // Initialize runtime
234        let metrics = Arc::new(Metrics::init(runtime_registry));
235        let runtime = Builder::new_multi_thread()
236            .worker_threads(self.cfg.worker_threads)
237            .max_blocking_threads(self.cfg.max_blocking_threads)
238            .enable_all()
239            .build()
240            .expect("failed to create Tokio runtime");
241        let (signaler, signal) = Signaler::new();
242
243        cfg_if::cfg_if! {
244            if #[cfg(feature = "iouring-storage")] {
245                let storage = MeteredStorage::new(
246                    IoUringStorage::start(IoUringConfig {
247                        storage_directory: self.cfg.storage_directory.clone(),
248                        ring_config: Default::default(),
249                    }),
250                    runtime_registry,
251                );
252            } else {
253                let storage = MeteredStorage::new(
254                    TokioStorage::new(TokioStorageConfig::new(
255                        self.cfg.storage_directory.clone(),
256                        self.cfg.maximum_buffer_size,
257                    )),
258                    runtime_registry,
259                );
260            }
261        }
262
263        cfg_if::cfg_if! {
264            if #[cfg(feature = "iouring-network")] {
265                let network = MeteredNetwork::new(
266                    IoUringNetwork::start(crate::iouring::Config::default()).unwrap(),
267                    runtime_registry,
268                );
269            } else {
270                let network = MeteredNetwork::new(
271                    TokioNetwork::from(self.cfg.network_cfg.clone()),
272                    runtime_registry,
273                );
274            }
275        }
276
277        // Initialize executor
278        let executor = Arc::new(Executor {
279            cfg: self.cfg,
280            registry: Mutex::new(registry),
281            metrics,
282            runtime,
283            signaler: Mutex::new(signaler),
284            signal,
285        });
286
287        // Get metrics
288        let label = Label::root();
289        executor.metrics.tasks_spawned.get_or_create(&label).inc();
290        let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
291
292        // Run the future
293        let context = Context {
294            storage,
295            name: label.name(),
296            spawned: false,
297            executor: executor.clone(),
298            network,
299        };
300        let output = executor.runtime.block_on(f(context));
301        gauge.dec();
302
303        output
304    }
305}
306
307cfg_if::cfg_if! {
308    if #[cfg(feature = "iouring-storage")] {
309        type Storage = MeteredStorage<IoUringStorage>;
310    } else {
311        type Storage = MeteredStorage<TokioStorage>;
312    }
313}
314
315cfg_if::cfg_if! {
316    if #[cfg(feature = "iouring-network")] {
317        type Network = MeteredNetwork<IoUringNetwork>;
318    } else {
319        type Network = MeteredNetwork<TokioNetwork>;
320    }
321}
322
323/// Implementation of [crate::Spawner], [crate::Clock],
324/// [crate::Network], and [crate::Storage] for the `tokio`
325/// runtime.
326pub struct Context {
327    name: String,
328    spawned: bool,
329    executor: Arc<Executor>,
330    storage: Storage,
331    network: Network,
332}
333
334impl Clone for Context {
335    fn clone(&self) -> Self {
336        Self {
337            name: self.name.clone(),
338            spawned: false,
339            executor: self.executor.clone(),
340            storage: self.storage.clone(),
341            network: self.network.clone(),
342        }
343    }
344}
345
346impl crate::Spawner for Context {
347    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
348    where
349        F: FnOnce(Self) -> Fut + Send + 'static,
350        Fut: Future<Output = T> + Send + 'static,
351        T: Send + 'static,
352    {
353        // Ensure a context only spawns one task
354        assert!(!self.spawned, "already spawned");
355
356        // Get metrics
357        let (_, gauge) = spawn_metrics!(self, future);
358
359        // Set up the task
360        let catch_panics = self.executor.cfg.catch_panics;
361        let executor = self.executor.clone();
362        let future = f(self);
363        let (f, handle) = Handle::init_future(future, gauge, catch_panics);
364
365        // Spawn the task
366        executor.runtime.spawn(f);
367        handle
368    }
369
370    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
371    where
372        F: Future<Output = T> + Send + 'static,
373        T: Send + 'static,
374    {
375        // Ensure a context only spawns one task
376        assert!(!self.spawned, "already spawned");
377        self.spawned = true;
378
379        // Get metrics
380        let (_, gauge) = spawn_metrics!(self, future);
381
382        // Set up the task
383        let executor = self.executor.clone();
384        move |f: F| {
385            let (f, handle) = Handle::init_future(f, gauge, executor.cfg.catch_panics);
386
387            // Spawn the task
388            executor.runtime.spawn(f);
389            handle
390        }
391    }
392
393    fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
394    where
395        F: FnOnce(Self) -> T + Send + 'static,
396        T: Send + 'static,
397    {
398        // Ensure a context only spawns one task
399        assert!(!self.spawned, "already spawned");
400
401        // Get metrics
402        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
403
404        // Set up the task
405        let executor = self.executor.clone();
406        let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);
407
408        // Spawn the blocking task
409        if dedicated {
410            std::thread::spawn(f);
411        } else {
412            executor.runtime.spawn_blocking(f);
413        }
414        handle
415    }
416
417    fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
418    where
419        F: FnOnce() -> T + Send + 'static,
420        T: Send + 'static,
421    {
422        // Ensure a context only spawns one task
423        assert!(!self.spawned, "already spawned");
424        self.spawned = true;
425
426        // Get metrics
427        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
428
429        // Set up the task
430        let executor = self.executor.clone();
431        move |f: F| {
432            let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);
433
434            // Spawn the blocking task
435            if dedicated {
436                std::thread::spawn(f);
437            } else {
438                executor.runtime.spawn_blocking(f);
439            }
440            handle
441        }
442    }
443
444    fn stop(&self, value: i32) {
445        self.executor.signaler.lock().unwrap().signal(value);
446    }
447
448    fn stopped(&self) -> Signal {
449        self.executor.signal.clone()
450    }
451}
452
453impl crate::Metrics for Context {
454    fn with_label(&self, label: &str) -> Self {
455        let name = {
456            let prefix = self.name.clone();
457            if prefix.is_empty() {
458                label.to_string()
459            } else {
460                format!("{}_{}", prefix, label)
461            }
462        };
463        assert!(
464            !name.starts_with(METRICS_PREFIX),
465            "using runtime label is not allowed"
466        );
467        Self {
468            name,
469            spawned: false,
470            executor: self.executor.clone(),
471            storage: self.storage.clone(),
472            network: self.network.clone(),
473        }
474    }
475
476    fn label(&self) -> String {
477        self.name.clone()
478    }
479
480    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
481        let name = name.into();
482        let prefixed_name = {
483            let prefix = &self.name;
484            if prefix.is_empty() {
485                name
486            } else {
487                format!("{}_{}", *prefix, name)
488            }
489        };
490        self.executor
491            .registry
492            .lock()
493            .unwrap()
494            .register(prefixed_name, help, metric)
495    }
496
497    fn encode(&self) -> String {
498        let mut buffer = String::new();
499        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
500        buffer
501    }
502}
503
504impl Clock for Context {
505    fn current(&self) -> SystemTime {
506        SystemTime::now()
507    }
508
509    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
510        tokio::time::sleep(duration)
511    }
512
513    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
514        let now = SystemTime::now();
515        let duration_until_deadline = match deadline.duration_since(now) {
516            Ok(duration) => duration,
517            Err(_) => Duration::from_secs(0), // Deadline is in the past
518        };
519        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
520        tokio::time::sleep_until(target_instant)
521    }
522}
523
524impl GClock for Context {
525    type Instant = SystemTime;
526
527    fn now(&self) -> Self::Instant {
528        self.current()
529    }
530}
531
532impl ReasonablyRealtime for Context {}
533
534impl crate::Network for Context {
535    type Listener = <Network as crate::Network>::Listener;
536
537    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
538        self.network.bind(socket).await
539    }
540
541    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
542        self.network.dial(socket).await
543    }
544}
545
546impl RngCore for Context {
547    fn next_u32(&mut self) -> u32 {
548        OsRng.next_u32()
549    }
550
551    fn next_u64(&mut self) -> u64 {
552        OsRng.next_u64()
553    }
554
555    fn fill_bytes(&mut self, dest: &mut [u8]) {
556        OsRng.fill_bytes(dest);
557    }
558
559    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
560        OsRng.try_fill_bytes(dest)
561    }
562}
563
564impl CryptoRng for Context {}
565
566impl crate::Storage for Context {
567    type Blob = <Storage as crate::Storage>::Blob;
568
569    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
570        self.storage.open(partition, name).await
571    }
572
573    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
574        self.storage.remove(partition, name).await
575    }
576
577    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
578        self.storage.scan(partition).await
579    }
580}