commonware_runtime/tokio/
runtime.rs

1use crate::network::metered::{Listener as MeteredListener, Network as MeteredNetwork};
2use crate::network::tokio::{
3    Config as TokioNetworkConfig, Listener as TokioListener, Network as TokioNetwork,
4};
5use crate::storage::metered::Storage;
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
8use crate::{SinkOf, StreamOf};
9use governor::clock::{Clock as GClock, ReasonablyRealtime};
10use prometheus_client::{
11    encoding::{text::encode, EncodeLabelSet},
12    metrics::{counter::Counter, family::Family, gauge::Gauge},
13    registry::{Metric, Registry},
14};
15use rand::{rngs::OsRng, CryptoRng, RngCore};
16use std::{
17    env,
18    future::Future,
19    net::SocketAddr,
20    path::PathBuf,
21    sync::{Arc, Mutex},
22    time::{Duration, SystemTime},
23};
24use tokio::runtime::{Builder, Runtime};
25
/// Label set attached to every task metric tracked by the runtime.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct Work {
    /// Label of the context that spawned the task.
    label: String,
}
30
/// Task metrics maintained by the runtime, keyed by [Work].
#[derive(Debug)]
struct Metrics {
    /// Total number of tasks spawned.
    tasks_spawned: Family<Work, Counter>,
    /// Number of tasks currently running.
    tasks_running: Family<Work, Gauge>,
    /// Total number of blocking tasks spawned.
    blocking_tasks_spawned: Family<Work, Counter>,
    /// Number of blocking tasks currently running.
    blocking_tasks_running: Family<Work, Gauge>,
}
38
39impl Metrics {
40    pub fn init(registry: &mut Registry) -> Self {
41        let metrics = Self {
42            tasks_spawned: Family::default(),
43            tasks_running: Family::default(),
44            blocking_tasks_spawned: Family::default(),
45            blocking_tasks_running: Family::default(),
46        };
47        registry.register(
48            "tasks_spawned",
49            "Total number of tasks spawned",
50            metrics.tasks_spawned.clone(),
51        );
52        registry.register(
53            "tasks_running",
54            "Number of tasks currently running",
55            metrics.tasks_running.clone(),
56        );
57        registry.register(
58            "blocking_tasks_spawned",
59            "Total number of blocking tasks spawned",
60            metrics.blocking_tasks_spawned.clone(),
61        );
62        registry.register(
63            "blocking_tasks_running",
64            "Number of blocking tasks currently running",
65            metrics.blocking_tasks_running.clone(),
66        );
67        metrics
68    }
69}
70
/// Configuration for the `tokio` runtime.
#[derive(Clone)]
pub struct Config {
    /// Number of threads to use for handling async tasks.
    ///
    /// Worker threads are always active (waiting for work).
    ///
    /// Tokio sets the default value to the number of logical CPUs.
    worker_threads: usize,

    /// Maximum number of threads to use for blocking tasks.
    ///
    /// Unlike worker threads, blocking threads are created as needed and
    /// exit if left idle for too long.
    ///
    /// Tokio sets the default value to 512 to avoid hanging on lower-level
    /// operations that require blocking (like `fs` and writing to `Stdout`).
    max_blocking_threads: usize,

    /// Whether or not to catch panics.
    catch_panics: bool,

    /// Base directory for all storage operations.
    storage_directory: PathBuf,

    /// Maximum buffer size for operations on blobs.
    ///
    /// Tokio sets the default value to 2MB.
    maximum_buffer_size: usize,

    /// Network configuration; holds the read timeout, write timeout, and
    /// `tcp_nodelay` settings exposed through this type's setters and getters.
    network_cfg: TokioNetworkConfig,
}
103
104impl Config {
105    /// Returns a new [Config] with default values.
106    pub fn new() -> Self {
107        let rng = OsRng.next_u64();
108        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
109        Self {
110            worker_threads: 2,
111            max_blocking_threads: 512,
112            catch_panics: true,
113            storage_directory,
114            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
115            network_cfg: TokioNetworkConfig::default(),
116        }
117    }
118
119    // Setters
120    /// See [Config]
121    pub fn with_worker_threads(mut self, n: usize) -> Self {
122        self.worker_threads = n;
123        self
124    }
125    /// See [Config]
126    pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
127        self.max_blocking_threads = n;
128        self
129    }
130    /// See [Config]
131    pub fn with_catch_panics(mut self, b: bool) -> Self {
132        self.catch_panics = b;
133        self
134    }
135    /// See [Config]
136    pub fn with_read_timeout(mut self, d: Duration) -> Self {
137        self.network_cfg = self.network_cfg.with_read_timeout(d);
138        self
139    }
140    /// See [Config]
141    pub fn with_write_timeout(mut self, d: Duration) -> Self {
142        self.network_cfg = self.network_cfg.with_write_timeout(d);
143        self
144    }
145    /// See [Config]
146    pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
147        self.network_cfg = self.network_cfg.with_tcp_nodelay(n);
148        self
149    }
150    /// See [Config]
151    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
152        self.storage_directory = p.into();
153        self
154    }
155    /// See [Config]
156    pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
157        self.maximum_buffer_size = n;
158        self
159    }
160
161    // Getters
162    /// See [Config]
163    pub fn worker_threads(&self) -> usize {
164        self.worker_threads
165    }
166    /// See [Config]
167    pub fn max_blocking_threads(&self) -> usize {
168        self.max_blocking_threads
169    }
170    /// See [Config]
171    pub fn catch_panics(&self) -> bool {
172        self.catch_panics
173    }
174    /// See [Config]
175    pub fn read_timeout(&self) -> Duration {
176        self.network_cfg.read_timeout()
177    }
178    /// See [Config]
179    pub fn write_timeout(&self) -> Duration {
180        self.network_cfg.write_timeout()
181    }
182    /// See [Config]
183    pub fn tcp_nodelay(&self) -> Option<bool> {
184        self.network_cfg.tcp_nodelay()
185    }
186    /// See [Config]
187    pub fn storage_directory(&self) -> &PathBuf {
188        &self.storage_directory
189    }
190    /// See [Config]
191    pub fn maximum_buffer_size(&self) -> usize {
192        self.maximum_buffer_size
193    }
194}
195
196impl Default for Config {
197    fn default() -> Self {
198        Self::new()
199    }
200}
201
/// Runtime based on [Tokio](https://tokio.rs).
pub struct Executor {
    /// Configuration the runtime was started with.
    cfg: Config,
    /// Root metrics registry; locked whenever metrics are registered or encoded.
    registry: Mutex<Registry>,
    /// Task metrics updated each time a context spawns a task.
    metrics: Arc<Metrics>,
    /// Underlying Tokio runtime that tasks are spawned onto.
    runtime: Runtime,
    /// Used by `Spawner::stop` to signal a stop value.
    signaler: Mutex<Signaler>,
    /// Cloned and handed out by `Spawner::stopped`.
    signal: Signal,
}
211
/// Implementation of [crate::Runner] for the `tokio` runtime.
pub struct Runner {
    /// Configuration applied when the runtime is built in `start`.
    cfg: Config,
}
216
217impl Default for Runner {
218    fn default() -> Self {
219        Self::new(Config::default())
220    }
221}
222
impl Runner {
    /// Creates a new [Runner] that will use the given [Config].
    ///
    /// No runtime is built here; `cfg` is only stored.
    pub fn new(cfg: Config) -> Self {
        Self { cfg }
    }
}
229
230impl crate::Runner for Runner {
231    type Context = Context;
232
233    fn start<F, Fut>(self, f: F) -> Fut::Output
234    where
235        F: FnOnce(Self::Context) -> Fut,
236        Fut: Future,
237    {
238        // Create a new registry
239        let mut registry = Registry::default();
240        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
241
242        // Initialize runtime
243        let metrics = Arc::new(Metrics::init(runtime_registry));
244        let runtime = Builder::new_multi_thread()
245            .worker_threads(self.cfg.worker_threads)
246            .max_blocking_threads(self.cfg.max_blocking_threads)
247            .enable_all()
248            .build()
249            .expect("failed to create Tokio runtime");
250        let (signaler, signal) = Signaler::new();
251
252        let storage = Storage::new(
253            TokioStorage::new(TokioStorageConfig::new(
254                self.cfg.storage_directory.clone(),
255                self.cfg.maximum_buffer_size,
256            )),
257            runtime_registry,
258        );
259
260        let network = TokioNetwork::from(self.cfg.network_cfg.clone());
261        let network = MeteredNetwork::new(network, runtime_registry);
262
263        let executor = Arc::new(Executor {
264            cfg: self.cfg,
265            registry: Mutex::new(registry),
266            metrics,
267            runtime,
268            signaler: Mutex::new(signaler),
269            signal,
270        });
271
272        let context = Context {
273            storage,
274            label: String::new(),
275            spawned: false,
276            executor: executor.clone(),
277            network,
278        };
279
280        executor.runtime.block_on(f(context))
281    }
282}
283
/// Implementation of [crate::Spawner], [crate::Clock],
/// [crate::Network], and [crate::Storage] for the `tokio`
/// runtime.
pub struct Context {
    /// Label used for task metrics and as the prefix of registered metric
    /// names (empty for the root context).
    label: String,
    /// Set once `spawn_ref` has been called; a context may only spawn one task.
    spawned: bool,
    /// Shared runtime state (configuration, registry, metrics, Tokio runtime).
    executor: Arc<Executor>,
    /// Metered storage backend.
    storage: Storage<TokioStorage>,
    /// Metered network backend.
    network: MeteredNetwork<TokioNetwork>,
}
294
295impl Clone for Context {
296    fn clone(&self) -> Self {
297        Self {
298            label: self.label.clone(),
299            spawned: false,
300            executor: self.executor.clone(),
301            storage: self.storage.clone(),
302            network: self.network.clone(),
303        }
304    }
305}
306
307impl crate::Spawner for Context {
308    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
309    where
310        F: FnOnce(Self) -> Fut + Send + 'static,
311        Fut: Future<Output = T> + Send + 'static,
312        T: Send + 'static,
313    {
314        // Ensure a context only spawns one task
315        assert!(!self.spawned, "already spawned");
316
317        // Get metrics
318        let work = Work {
319            label: self.label.clone(),
320        };
321        self.executor
322            .metrics
323            .tasks_spawned
324            .get_or_create(&work)
325            .inc();
326        let gauge = self
327            .executor
328            .metrics
329            .tasks_running
330            .get_or_create(&work)
331            .clone();
332
333        // Set up the task
334        let catch_panics = self.executor.cfg.catch_panics;
335        let executor = self.executor.clone();
336        let future = f(self);
337        let (f, handle) = Handle::init(future, gauge, catch_panics);
338
339        // Spawn the task
340        executor.runtime.spawn(f);
341        handle
342    }
343
344    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
345    where
346        F: Future<Output = T> + Send + 'static,
347        T: Send + 'static,
348    {
349        // Ensure a context only spawns one task
350        assert!(!self.spawned, "already spawned");
351        self.spawned = true;
352
353        // Get metrics
354        let work = Work {
355            label: self.label.clone(),
356        };
357        self.executor
358            .metrics
359            .tasks_spawned
360            .get_or_create(&work)
361            .inc();
362        let gauge = self
363            .executor
364            .metrics
365            .tasks_running
366            .get_or_create(&work)
367            .clone();
368
369        // Set up the task
370        let executor = self.executor.clone();
371        move |f: F| {
372            let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
373
374            // Spawn the task
375            executor.runtime.spawn(f);
376            handle
377        }
378    }
379
380    fn spawn_blocking<F, T>(self, f: F) -> Handle<T>
381    where
382        F: FnOnce() -> T + Send + 'static,
383        T: Send + 'static,
384    {
385        // Ensure a context only spawns one task
386        assert!(!self.spawned, "already spawned");
387
388        // Get metrics
389        let work = Work {
390            label: self.label.clone(),
391        };
392        self.executor
393            .metrics
394            .blocking_tasks_spawned
395            .get_or_create(&work)
396            .inc();
397        let gauge = self
398            .executor
399            .metrics
400            .blocking_tasks_running
401            .get_or_create(&work)
402            .clone();
403
404        // Initialize the blocking task using the new function
405        let (f, handle) = Handle::init_blocking(f, gauge, self.executor.cfg.catch_panics);
406
407        // Spawn the blocking task
408        self.executor.runtime.spawn_blocking(f);
409        handle
410    }
411
412    fn stop(&self, value: i32) {
413        self.executor.signaler.lock().unwrap().signal(value);
414    }
415
416    fn stopped(&self) -> Signal {
417        self.executor.signal.clone()
418    }
419}
420
421impl crate::Metrics for Context {
422    fn with_label(&self, label: &str) -> Self {
423        let label = {
424            let prefix = self.label.clone();
425            if prefix.is_empty() {
426                label.to_string()
427            } else {
428                format!("{}_{}", prefix, label)
429            }
430        };
431        assert!(
432            !label.starts_with(METRICS_PREFIX),
433            "using runtime label is not allowed"
434        );
435        Self {
436            label,
437            spawned: false,
438            executor: self.executor.clone(),
439            storage: self.storage.clone(),
440            network: self.network.clone(),
441        }
442    }
443
444    fn label(&self) -> String {
445        self.label.clone()
446    }
447
448    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
449        let name = name.into();
450        let prefixed_name = {
451            let prefix = &self.label;
452            if prefix.is_empty() {
453                name
454            } else {
455                format!("{}_{}", *prefix, name)
456            }
457        };
458        self.executor
459            .registry
460            .lock()
461            .unwrap()
462            .register(prefixed_name, help, metric)
463    }
464
465    fn encode(&self) -> String {
466        let mut buffer = String::new();
467        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
468        buffer
469    }
470}
471
472impl Clock for Context {
473    fn current(&self) -> SystemTime {
474        SystemTime::now()
475    }
476
477    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
478        tokio::time::sleep(duration)
479    }
480
481    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
482        let now = SystemTime::now();
483        let duration_until_deadline = match deadline.duration_since(now) {
484            Ok(duration) => duration,
485            Err(_) => Duration::from_secs(0), // Deadline is in the past
486        };
487        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
488        tokio::time::sleep_until(target_instant)
489    }
490}
491
492impl GClock for Context {
493    type Instant = SystemTime;
494
495    fn now(&self) -> Self::Instant {
496        self.current()
497    }
498}
499
// Opts `Context` into `governor`'s `ReasonablyRealtime` trait without
// overriding any of its items.
impl ReasonablyRealtime for Context {}
501
502impl crate::Network for Context {
503    type Listener = MeteredListener<TokioListener>;
504
505    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
506        self.network.bind(socket).await
507    }
508
509    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
510        self.network.dial(socket).await
511    }
512}
513
514impl RngCore for Context {
515    fn next_u32(&mut self) -> u32 {
516        OsRng.next_u32()
517    }
518
519    fn next_u64(&mut self) -> u64 {
520        OsRng.next_u64()
521    }
522
523    fn fill_bytes(&mut self, dest: &mut [u8]) {
524        OsRng.fill_bytes(dest);
525    }
526
527    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
528        OsRng.try_fill_bytes(dest)
529    }
530}
531
// Marker impl: every `RngCore` method on `Context` forwards to `OsRng`.
impl CryptoRng for Context {}
533
534impl crate::Storage for Context {
535    type Blob = <Storage<TokioStorage> as crate::Storage>::Blob;
536
537    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
538        self.storage.open(partition, name).await
539    }
540
541    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
542        self.storage.remove(partition, name).await
543    }
544
545    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
546        self.storage.scan(partition).await
547    }
548}