// commonware_runtime/tokio/runtime.rs

use crate::storage::metered::Storage;
use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
use governor::clock::{Clock as GClock, ReasonablyRealtime};
use prometheus_client::{
    encoding::{text::encode, EncodeLabelSet},
    metrics::{counter::Counter, family::Family, gauge::Gauge},
    registry::{Metric, Registry},
};
use rand::{rngs::OsRng, CryptoRng, RngCore};
use std::{
    env,
    future::Future,
    io,
    net::SocketAddr,
    path::PathBuf,
    sync::{Arc, Mutex},
    time::{Duration, SystemTime},
};
use tokio::{
    io::{AsyncReadExt, AsyncWriteExt},
    net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
    runtime::{Builder, Runtime},
    time::timeout,
};
use tracing::warn;

28#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
29struct Work {
30    label: String,
31}
32
33#[derive(Debug)]
34struct Metrics {
35    tasks_spawned: Family<Work, Counter>,
36    tasks_running: Family<Work, Gauge>,
37    blocking_tasks_spawned: Family<Work, Counter>,
38    blocking_tasks_running: Family<Work, Gauge>,
39
40    // As nice as it would be to track each of these by socket address,
41    // it quickly becomes an OOM attack vector.
42    inbound_connections: Counter,
43    outbound_connections: Counter,
44    inbound_bandwidth: Counter,
45    outbound_bandwidth: Counter,
46}
47
48impl Metrics {
49    pub fn init(registry: &mut Registry) -> Self {
50        let metrics = Self {
51            tasks_spawned: Family::default(),
52            tasks_running: Family::default(),
53            blocking_tasks_spawned: Family::default(),
54            blocking_tasks_running: Family::default(),
55            inbound_connections: Counter::default(),
56            outbound_connections: Counter::default(),
57            inbound_bandwidth: Counter::default(),
58            outbound_bandwidth: Counter::default(),
59        };
60        registry.register(
61            "tasks_spawned",
62            "Total number of tasks spawned",
63            metrics.tasks_spawned.clone(),
64        );
65        registry.register(
66            "tasks_running",
67            "Number of tasks currently running",
68            metrics.tasks_running.clone(),
69        );
70        registry.register(
71            "blocking_tasks_spawned",
72            "Total number of blocking tasks spawned",
73            metrics.blocking_tasks_spawned.clone(),
74        );
75        registry.register(
76            "blocking_tasks_running",
77            "Number of blocking tasks currently running",
78            metrics.blocking_tasks_running.clone(),
79        );
80        registry.register(
81            "inbound_connections",
82            "Number of connections created by dialing us",
83            metrics.inbound_connections.clone(),
84        );
85        registry.register(
86            "outbound_connections",
87            "Number of connections created by dialing others",
88            metrics.outbound_connections.clone(),
89        );
90        registry.register(
91            "inbound_bandwidth",
92            "Bandwidth used by receiving data from others",
93            metrics.inbound_bandwidth.clone(),
94        );
95        registry.register(
96            "outbound_bandwidth",
97            "Bandwidth used by sending data to others",
98            metrics.outbound_bandwidth.clone(),
99        );
100        metrics
101    }
102}
103
/// Configuration for the `tokio` runtime.
#[derive(Clone, Debug)]
pub struct Config {
    /// Number of threads to use for handling async tasks.
    ///
    /// Worker threads are always active (waiting for work).
    ///
    /// Tokio's own runtime builder defaults this to the number of logical CPUs,
    /// but [`Config::default`] uses 2.
    pub worker_threads: usize,

    /// Maximum number of threads to use for blocking tasks.
    ///
    /// Unlike worker threads, blocking threads are created as needed and
    /// exit if left idle for too long.
    ///
    /// Tokio defaults this to 512 (as does [`Config::default`]) to avoid hanging
    /// on lower-level operations that require blocking (like `fs` and writing to
    /// `Stdout`).
    pub max_blocking_threads: usize,

    /// Whether or not to catch panics raised by spawned tasks.
    pub catch_panics: bool,

    /// Maximum time a single network read (`recv`) may take before it fails
    /// with a timeout error.
    pub read_timeout: Duration,

    /// Maximum time a single network write (`send`) may take before it fails
    /// with a timeout error.
    pub write_timeout: Duration,

    /// Whether to set `TCP_NODELAY` on new connections (i.e. disable Nagle's
    /// algorithm).
    ///
    /// Nagle's algorithm combines a series of small network packets into a single
    /// packet before sending, to reduce the overhead of sending many small packets
    /// on slow, congested networks. To do so it introduces a slight delay while it
    /// waits to accumulate more data. Latency-sensitive networks should consider
    /// disabling it so packets are sent as soon as possible.
    ///
    /// - `Some(true)` disables Nagle's algorithm; `Some(false)` explicitly enables it.
    /// - `None` leaves the operating system's default in place.
    ///
    /// If the setting cannot be applied (e.g. unsupported on the target), a warning
    /// is logged and the connection is used anyway.
    pub tcp_nodelay: Option<bool>,

    /// Base directory for all storage operations.
    pub storage_directory: PathBuf,

    /// Maximum buffer size (in bytes) for operations on blobs.
    ///
    /// [`Config::default`] sets this to 2 MiB.
    pub maximum_buffer_size: usize,
}

153impl Default for Config {
154    fn default() -> Self {
155        // Generate a random directory name to avoid conflicts (used in tests, so we shouldn't need to reload)
156        let rng = OsRng.next_u64();
157        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
158
159        // Return the configuration
160        Self {
161            worker_threads: 2,
162            max_blocking_threads: 512,
163            catch_panics: true,
164            read_timeout: Duration::from_secs(60),
165            write_timeout: Duration::from_secs(30),
166            tcp_nodelay: None,
167            storage_directory,
168            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
169        }
170    }
171}
172
173/// Runtime based on [Tokio](https://tokio.rs).
174pub struct Executor {
175    cfg: Config,
176    registry: Mutex<Registry>,
177    metrics: Arc<Metrics>,
178    runtime: Runtime,
179    signaler: Mutex<Signaler>,
180    signal: Signal,
181}
182
183/// Implementation of [`crate::Runner`] for the `tokio` runtime.
184pub struct Runner {
185    cfg: Config,
186}
187
188impl Default for Runner {
189    fn default() -> Self {
190        Self::new(Config::default())
191    }
192}
193
194impl Runner {
195    /// Initialize a new `tokio` runtime with the given number of threads.
196    pub fn new(cfg: Config) -> Self {
197        Self { cfg }
198    }
199}
200
201impl crate::Runner for Runner {
202    type Context = Context;
203
204    fn start<F, Fut>(self, f: F) -> Fut::Output
205    where
206        F: FnOnce(Self::Context) -> Fut,
207        Fut: Future,
208    {
209        // Create a new registry
210        let mut registry = Registry::default();
211        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
212
213        // Initialize runtime
214        let metrics = Arc::new(Metrics::init(runtime_registry));
215        let runtime = Builder::new_multi_thread()
216            .worker_threads(self.cfg.worker_threads)
217            .max_blocking_threads(self.cfg.max_blocking_threads)
218            .enable_all()
219            .build()
220            .expect("failed to create Tokio runtime");
221        let (signaler, signal) = Signaler::new();
222
223        let storage = Storage::new(
224            TokioStorage::new(TokioStorageConfig::new(
225                self.cfg.storage_directory.clone(),
226                self.cfg.maximum_buffer_size,
227            )),
228            runtime_registry,
229        );
230
231        let executor = Arc::new(Executor {
232            cfg: self.cfg,
233            registry: Mutex::new(registry),
234            metrics,
235            runtime,
236            signaler: Mutex::new(signaler),
237            signal,
238        });
239
240        let context = Context {
241            storage,
242            label: String::new(),
243            spawned: false,
244            executor: executor.clone(),
245        };
246
247        executor.runtime.block_on(f(context))
248    }
249}
250
251/// Implementation of [`crate::Spawner`], [`crate::Clock`],
252/// [`crate::Network`], and [`crate::Storage`] for the `tokio`
253/// runtime.
254pub struct Context {
255    label: String,
256    spawned: bool,
257    executor: Arc<Executor>,
258    storage: Storage<TokioStorage>,
259}
260
261impl Clone for Context {
262    fn clone(&self) -> Self {
263        Self {
264            label: self.label.clone(),
265            spawned: false,
266            executor: self.executor.clone(),
267            storage: self.storage.clone(),
268        }
269    }
270}
271
272impl crate::Spawner for Context {
273    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
274    where
275        F: FnOnce(Self) -> Fut + Send + 'static,
276        Fut: Future<Output = T> + Send + 'static,
277        T: Send + 'static,
278    {
279        // Ensure a context only spawns one task
280        assert!(!self.spawned, "already spawned");
281
282        // Get metrics
283        let work = Work {
284            label: self.label.clone(),
285        };
286        self.executor
287            .metrics
288            .tasks_spawned
289            .get_or_create(&work)
290            .inc();
291        let gauge = self
292            .executor
293            .metrics
294            .tasks_running
295            .get_or_create(&work)
296            .clone();
297
298        // Set up the task
299        let catch_panics = self.executor.cfg.catch_panics;
300        let executor = self.executor.clone();
301        let future = f(self);
302        let (f, handle) = Handle::init(future, gauge, catch_panics);
303
304        // Spawn the task
305        executor.runtime.spawn(f);
306        handle
307    }
308
309    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
310    where
311        F: Future<Output = T> + Send + 'static,
312        T: Send + 'static,
313    {
314        // Ensure a context only spawns one task
315        assert!(!self.spawned, "already spawned");
316        self.spawned = true;
317
318        // Get metrics
319        let work = Work {
320            label: self.label.clone(),
321        };
322        self.executor
323            .metrics
324            .tasks_spawned
325            .get_or_create(&work)
326            .inc();
327        let gauge = self
328            .executor
329            .metrics
330            .tasks_running
331            .get_or_create(&work)
332            .clone();
333
334        // Set up the task
335        let executor = self.executor.clone();
336        move |f: F| {
337            let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
338
339            // Spawn the task
340            executor.runtime.spawn(f);
341            handle
342        }
343    }
344
345    fn spawn_blocking<F, T>(self, f: F) -> Handle<T>
346    where
347        F: FnOnce() -> T + Send + 'static,
348        T: Send + 'static,
349    {
350        // Ensure a context only spawns one task
351        assert!(!self.spawned, "already spawned");
352
353        // Get metrics
354        let work = Work {
355            label: self.label.clone(),
356        };
357        self.executor
358            .metrics
359            .blocking_tasks_spawned
360            .get_or_create(&work)
361            .inc();
362        let gauge = self
363            .executor
364            .metrics
365            .blocking_tasks_running
366            .get_or_create(&work)
367            .clone();
368
369        // Initialize the blocking task using the new function
370        let (f, handle) = Handle::init_blocking(f, gauge, self.executor.cfg.catch_panics);
371
372        // Spawn the blocking task
373        self.executor.runtime.spawn_blocking(f);
374        handle
375    }
376
377    fn stop(&self, value: i32) {
378        self.executor.signaler.lock().unwrap().signal(value);
379    }
380
381    fn stopped(&self) -> Signal {
382        self.executor.signal.clone()
383    }
384}
385
386impl crate::Metrics for Context {
387    fn with_label(&self, label: &str) -> Self {
388        let label = {
389            let prefix = self.label.clone();
390            if prefix.is_empty() {
391                label.to_string()
392            } else {
393                format!("{}_{}", prefix, label)
394            }
395        };
396        assert!(
397            !label.starts_with(METRICS_PREFIX),
398            "using runtime label is not allowed"
399        );
400        Self {
401            label,
402            spawned: false,
403            executor: self.executor.clone(),
404            storage: self.storage.clone(),
405        }
406    }
407
408    fn label(&self) -> String {
409        self.label.clone()
410    }
411
412    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
413        let name = name.into();
414        let prefixed_name = {
415            let prefix = &self.label;
416            if prefix.is_empty() {
417                name
418            } else {
419                format!("{}_{}", *prefix, name)
420            }
421        };
422        self.executor
423            .registry
424            .lock()
425            .unwrap()
426            .register(prefixed_name, help, metric)
427    }
428
429    fn encode(&self) -> String {
430        let mut buffer = String::new();
431        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
432        buffer
433    }
434}
435
436impl Clock for Context {
437    fn current(&self) -> SystemTime {
438        SystemTime::now()
439    }
440
441    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
442        tokio::time::sleep(duration)
443    }
444
445    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
446        let now = SystemTime::now();
447        let duration_until_deadline = match deadline.duration_since(now) {
448            Ok(duration) => duration,
449            Err(_) => Duration::from_secs(0), // Deadline is in the past
450        };
451        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
452        tokio::time::sleep_until(target_instant)
453    }
454}
455
456impl GClock for Context {
457    type Instant = SystemTime;
458
459    fn now(&self) -> Self::Instant {
460        self.current()
461    }
462}
463
464impl ReasonablyRealtime for Context {}
465
466impl crate::Network<Listener, Sink, Stream> for Context {
467    async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
468        TcpListener::bind(socket)
469            .await
470            .map_err(|_| Error::BindFailed)
471            .map(|listener| Listener {
472                context: self.clone(),
473                listener,
474            })
475    }
476
477    async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
478        // Create a new TCP stream
479        let stream = TcpStream::connect(socket)
480            .await
481            .map_err(|_| Error::ConnectionFailed)?;
482        self.executor.metrics.outbound_connections.inc();
483
484        // Set TCP_NODELAY if configured
485        if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
486            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
487                warn!(?err, "failed to set TCP_NODELAY");
488            }
489        }
490
491        // Return the sink and stream
492        let context = self.clone();
493        let (stream, sink) = stream.into_split();
494        Ok((
495            Sink {
496                context: context.clone(),
497                sink,
498            },
499            Stream { context, stream },
500        ))
501    }
502}
503
504/// Implementation of [`crate::Listener`] for the `tokio` runtime.
505pub struct Listener {
506    context: Context,
507    listener: TcpListener,
508}
509
510impl crate::Listener<Sink, Stream> for Listener {
511    async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
512        // Accept a new TCP stream
513        let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
514        self.context.executor.metrics.inbound_connections.inc();
515
516        // Set TCP_NODELAY if configured
517        if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
518            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
519                warn!(?err, "failed to set TCP_NODELAY");
520            }
521        }
522
523        // Return the sink and stream
524        let context = self.context.clone();
525        let (stream, sink) = stream.into_split();
526        Ok((
527            addr,
528            Sink {
529                context: context.clone(),
530                sink,
531            },
532            Stream { context, stream },
533        ))
534    }
535}
536
537impl axum::serve::Listener for Listener {
538    type Io = TcpStream;
539    type Addr = SocketAddr;
540
541    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
542        let (stream, addr) = self.listener.accept().await.unwrap();
543        (stream, addr)
544    }
545
546    fn local_addr(&self) -> io::Result<Self::Addr> {
547        self.listener.local_addr()
548    }
549}
550
551/// Implementation of [`crate::Sink`] for the `tokio` runtime.
552pub struct Sink {
553    context: Context,
554    sink: OwnedWriteHalf,
555}
556
557impl crate::Sink for Sink {
558    async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
559        let len = msg.len();
560        timeout(
561            self.context.executor.cfg.write_timeout,
562            self.sink.write_all(msg),
563        )
564        .await
565        .map_err(|_| Error::Timeout)?
566        .map_err(|_| Error::SendFailed)?;
567        self.context
568            .executor
569            .metrics
570            .outbound_bandwidth
571            .inc_by(len as u64);
572        Ok(())
573    }
574}
575
576/// Implementation of [`crate::Stream`] for the `tokio` runtime.
577pub struct Stream {
578    context: Context,
579    stream: OwnedReadHalf,
580}
581
582impl crate::Stream for Stream {
583    async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
584        // Wait for the stream to be readable
585        timeout(
586            self.context.executor.cfg.read_timeout,
587            self.stream.read_exact(buf),
588        )
589        .await
590        .map_err(|_| Error::Timeout)?
591        .map_err(|_| Error::RecvFailed)?;
592
593        // Record metrics
594        self.context
595            .executor
596            .metrics
597            .inbound_bandwidth
598            .inc_by(buf.len() as u64);
599
600        Ok(())
601    }
602}
603
604impl RngCore for Context {
605    fn next_u32(&mut self) -> u32 {
606        OsRng.next_u32()
607    }
608
609    fn next_u64(&mut self) -> u64 {
610        OsRng.next_u64()
611    }
612
613    fn fill_bytes(&mut self, dest: &mut [u8]) {
614        OsRng.fill_bytes(dest);
615    }
616
617    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
618        OsRng.try_fill_bytes(dest)
619    }
620}
621
622impl CryptoRng for Context {}
623
624impl crate::Storage for Context {
625    type Blob = <Storage<TokioStorage> as crate::Storage>::Blob;
626
627    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
628        self.storage.open(partition, name).await
629    }
630
631    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
632        self.storage.remove(partition, name).await
633    }
634
635    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
636        self.storage.scan(partition).await
637    }
638}