commonware_runtime/
tokio.rs

1//! A production-focused runtime based on [Tokio](https://tokio.rs) with
2//! secure randomness and storage backed by the local filesystem.
3//!
4//! # Panics
5//!
6//! By default, the runtime will catch any panic and log the error. It is
7//! possible to override this behavior in the configuration.
8//!
9//! # Example
10//!
11//! ```rust
12//! use commonware_runtime::{Spawner, Runner, tokio::Executor};
13//!
14//! let (executor, runtime) = Executor::default();
15//! executor.start(async move {
16//!     println!("Parent started");
17//!     let result = runtime.spawn("child", async move {
18//!         println!("Child started");
19//!         "hello"
20//!     });
21//!     println!("Child result: {:?}", result.await);
22//!     println!("Parent exited");
23//! });
24//! ```
25
26use crate::{utils::Signaler, Clock, Error, Handle, Signal};
27use commonware_utils::{from_hex, hex};
28use governor::clock::{Clock as GClock, ReasonablyRealtime};
29use prometheus_client::{
30    encoding::EncodeLabelSet,
31    metrics::{counter::Counter, family::Family, gauge::Gauge},
32    registry::Registry,
33};
34use rand::{rngs::OsRng, CryptoRng, RngCore};
35use std::{
36    env,
37    future::Future,
38    io::SeekFrom,
39    net::SocketAddr,
40    path::PathBuf,
41    sync::{Arc, Mutex},
42    time::{Duration, SystemTime},
43};
44use tokio::{
45    fs,
46    io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
47    net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
48    runtime::{Builder, Runtime},
49    sync::Mutex as AsyncMutex,
50    task_local,
51    time::timeout,
52};
53use tracing::warn;
54
55#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
56struct Work {
57    label: String,
58}
59
60#[derive(Debug)]
61struct Metrics {
62    tasks_spawned: Family<Work, Counter>,
63    tasks_running: Family<Work, Gauge>,
64
65    // As nice as it would be to track each of these by socket address,
66    // it quickly becomes an OOM attack vector.
67    inbound_connections: Counter,
68    outbound_connections: Counter,
69    inbound_bandwidth: Counter,
70    outbound_bandwidth: Counter,
71
72    open_blobs: Gauge,
73    storage_reads: Counter,
74    storage_read_bytes: Counter,
75    storage_writes: Counter,
76    storage_write_bytes: Counter,
77}
78
79impl Metrics {
80    pub fn init(registry: Arc<Mutex<Registry>>) -> Self {
81        let metrics = Self {
82            tasks_spawned: Family::default(),
83            tasks_running: Family::default(),
84            inbound_connections: Counter::default(),
85            outbound_connections: Counter::default(),
86            inbound_bandwidth: Counter::default(),
87            outbound_bandwidth: Counter::default(),
88            open_blobs: Gauge::default(),
89            storage_reads: Counter::default(),
90            storage_read_bytes: Counter::default(),
91            storage_writes: Counter::default(),
92            storage_write_bytes: Counter::default(),
93        };
94        {
95            let mut registry = registry.lock().unwrap();
96            registry.register(
97                "tasks_spawned",
98                "Total number of tasks spawned",
99                metrics.tasks_spawned.clone(),
100            );
101            registry.register(
102                "tasks_running",
103                "Number of tasks currently running",
104                metrics.tasks_running.clone(),
105            );
106            registry.register(
107                "inbound_connections",
108                "Number of connections created by dialing us",
109                metrics.inbound_connections.clone(),
110            );
111            registry.register(
112                "outbound_connections",
113                "Number of connections created by dialing others",
114                metrics.outbound_connections.clone(),
115            );
116            registry.register(
117                "inbound_bandwidth",
118                "Bandwidth used by receiving data from others",
119                metrics.inbound_bandwidth.clone(),
120            );
121            registry.register(
122                "outbound_bandwidth",
123                "Bandwidth used by sending data to others",
124                metrics.outbound_bandwidth.clone(),
125            );
126            registry.register(
127                "open_blobs",
128                "Number of open blobs",
129                metrics.open_blobs.clone(),
130            );
131            registry.register(
132                "storage_reads",
133                "Total number of disk reads",
134                metrics.storage_reads.clone(),
135            );
136            registry.register(
137                "storage_read_bytes",
138                "Total amount of data read from disk",
139                metrics.storage_read_bytes.clone(),
140            );
141            registry.register(
142                "storage_writes",
143                "Total number of disk writes",
144                metrics.storage_writes.clone(),
145            );
146            registry.register(
147                "storage_write_bytes",
148                "Total amount of data written to disk",
149                metrics.storage_write_bytes.clone(),
150            );
151        }
152        metrics
153    }
154}
155
156/// Configuration for the `tokio` runtime.
157#[derive(Clone)]
158pub struct Config {
159    /// Registry for metrics.
160    pub registry: Arc<Mutex<Registry>>,
161
162    /// Number of threads to use for the runtime.
163    pub threads: usize,
164
165    /// Whether or not to catch panics.
166    pub catch_panics: bool,
167
168    /// Duration after which to close the connection if no message is read.
169    pub read_timeout: Duration,
170
171    /// Duration after which to close the connection if a message cannot be written.
172    pub write_timeout: Duration,
173
174    /// Whether or not to disable Nagle's algorithm.
175    ///
176    /// The algorithm combines a series of small network packets into a single packet
177    /// before sending to reduce overhead of sending multiple small packets which might not
178    /// be efficient on slow, congested networks. However, to do so the algorithm introduces
179    /// a slight delay as it waits to accumulate more data. Latency-sensitive networks should
180    /// consider disabling it to send the packets as soon as possible to reduce latency.
181    ///
182    /// Note: Make sure that your compile target has and allows this configuration otherwise
183    /// panics or unexpected behaviours are possible.
184    pub tcp_nodelay: Option<bool>,
185
186    /// Base directory for all storage operations.
187    pub storage_directory: PathBuf,
188
189    /// Maximum buffer size for operations on blobs.
190    ///
191    /// `tokio` defaults this value to 2MB.
192    pub maximum_buffer_size: usize,
193}
194
195impl Default for Config {
196    fn default() -> Self {
197        // Generate a random directory name to avoid conflicts (used in tests, so we shouldn't need to reload)
198        let rng = OsRng.next_u64();
199        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
200
201        // Return the configuration
202        Self {
203            registry: Arc::new(Mutex::new(Registry::default())),
204            threads: 2,
205            catch_panics: true,
206            read_timeout: Duration::from_secs(60),
207            write_timeout: Duration::from_secs(30),
208            tcp_nodelay: None,
209            storage_directory,
210            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
211        }
212    }
213}
214
215/// Runtime based on [Tokio](https://tokio.rs).
216pub struct Executor {
217    cfg: Config,
218    metrics: Arc<Metrics>,
219    runtime: Runtime,
220    fs: AsyncMutex<()>,
221    signaler: Mutex<Signaler>,
222    signal: Signal,
223}
224
225impl Executor {
226    /// Initialize a new `tokio` runtime with the given number of threads.
227    pub fn init(cfg: Config) -> (Runner, Context) {
228        let metrics = Arc::new(Metrics::init(cfg.registry.clone()));
229        let runtime = Builder::new_multi_thread()
230            .worker_threads(cfg.threads)
231            .enable_all()
232            .build()
233            .expect("failed to create Tokio runtime");
234        let (signaler, signal) = Signaler::new();
235        let executor = Arc::new(Self {
236            cfg,
237            metrics,
238            runtime,
239            fs: AsyncMutex::new(()),
240            signaler: Mutex::new(signaler),
241            signal,
242        });
243        (
244            Runner {
245                executor: executor.clone(),
246            },
247            Context { executor },
248        )
249    }
250
251    /// Initialize a new `tokio` runtime with default configuration.
252    // We'd love to implement the trait but we can't because of the return type.
253    #[allow(clippy::should_implement_trait)]
254    pub fn default() -> (Runner, Context) {
255        Self::init(Config::default())
256    }
257}
258
259/// Implementation of [`crate::Runner`] for the `tokio` runtime.
260pub struct Runner {
261    executor: Arc<Executor>,
262}
263
264impl crate::Runner for Runner {
265    fn start<F>(self, f: F) -> F::Output
266    where
267        F: Future + Send + 'static,
268        F::Output: Send + 'static,
269    {
270        self.executor.runtime.block_on(f)
271    }
272}
273
274/// Implementation of [`crate::Spawner`], [`crate::Clock`],
275/// [`crate::Network`], and [`crate::Storage`] for the `tokio`
276/// runtime.
277#[derive(Clone)]
278pub struct Context {
279    executor: Arc<Executor>,
280}
281
282task_local! {
283    static PREFIX: String;
284}
285
286impl crate::Spawner for Context {
287    fn spawn<F, T>(&self, label: &str, f: F) -> Handle<T>
288    where
289        F: Future<Output = T> + Send + 'static,
290        T: Send + 'static,
291    {
292        let label = PREFIX
293            .try_with(|prefix| format!("{}_{}", prefix, label))
294            .unwrap_or_else(|_| label.to_string());
295        let f = PREFIX.scope(label.clone(), f);
296        let work = Work { label };
297        self.executor
298            .metrics
299            .tasks_spawned
300            .get_or_create(&work)
301            .inc();
302        let gauge = self
303            .executor
304            .metrics
305            .tasks_running
306            .get_or_create(&work)
307            .clone();
308        let (f, handle) = Handle::init(f, gauge, self.executor.cfg.catch_panics);
309        self.executor.runtime.spawn(f);
310        handle
311    }
312
313    fn stop(&self, value: i32) {
314        self.executor.signaler.lock().unwrap().signal(value);
315    }
316
317    fn stopped(&self) -> Signal {
318        self.executor.signal.clone()
319    }
320}
321
322impl Clock for Context {
323    fn current(&self) -> SystemTime {
324        SystemTime::now()
325    }
326
327    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
328        tokio::time::sleep(duration)
329    }
330
331    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
332        let now = SystemTime::now();
333        let duration_until_deadline = match deadline.duration_since(now) {
334            Ok(duration) => duration,
335            Err(_) => Duration::from_secs(0), // Deadline is in the past
336        };
337        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
338        tokio::time::sleep_until(target_instant)
339    }
340}
341
342impl GClock for Context {
343    type Instant = SystemTime;
344
345    fn now(&self) -> Self::Instant {
346        self.current()
347    }
348}
349
350impl ReasonablyRealtime for Context {}
351
352impl crate::Network<Listener, Sink, Stream> for Context {
353    async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
354        TcpListener::bind(socket)
355            .await
356            .map_err(|_| Error::BindFailed)
357            .map(|listener| Listener {
358                context: self.clone(),
359                listener,
360            })
361    }
362
363    async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
364        // Create a new TCP stream
365        let stream = TcpStream::connect(socket)
366            .await
367            .map_err(|_| Error::ConnectionFailed)?;
368        self.executor.metrics.outbound_connections.inc();
369
370        // Set TCP_NODELAY if configured
371        if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
372            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
373                warn!(?err, "failed to set TCP_NODELAY");
374            }
375        }
376
377        // Return the sink and stream
378        let context = self.clone();
379        let (stream, sink) = stream.into_split();
380        Ok((
381            Sink {
382                context: context.clone(),
383                sink,
384            },
385            Stream { context, stream },
386        ))
387    }
388}
389
390/// Implementation of [`crate::Listener`] for the `tokio` runtime.
391pub struct Listener {
392    context: Context,
393    listener: TcpListener,
394}
395
396impl crate::Listener<Sink, Stream> for Listener {
397    async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
398        // Accept a new TCP stream
399        let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
400        self.context.executor.metrics.inbound_connections.inc();
401
402        // Set TCP_NODELAY if configured
403        if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
404            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
405                warn!(?err, "failed to set TCP_NODELAY");
406            }
407        }
408
409        // Return the sink and stream
410        let context = self.context.clone();
411        let (stream, sink) = stream.into_split();
412        Ok((
413            addr,
414            Sink {
415                context: context.clone(),
416                sink,
417            },
418            Stream { context, stream },
419        ))
420    }
421}
422
423/// Implementation of [`crate::Sink`] for the `tokio` runtime.
424pub struct Sink {
425    context: Context,
426    sink: OwnedWriteHalf,
427}
428
429impl crate::Sink for Sink {
430    async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
431        let len = msg.len();
432        timeout(
433            self.context.executor.cfg.write_timeout,
434            self.sink.write_all(msg),
435        )
436        .await
437        .map_err(|_| Error::Timeout)?
438        .map_err(|_| Error::SendFailed)?;
439        self.context
440            .executor
441            .metrics
442            .outbound_bandwidth
443            .inc_by(len as u64);
444        Ok(())
445    }
446}
447
448/// Implementation of [`crate::Stream`] for the `tokio` runtime.
449pub struct Stream {
450    context: Context,
451    stream: OwnedReadHalf,
452}
453
454impl crate::Stream for Stream {
455    async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
456        // Wait for the stream to be readable
457        timeout(
458            self.context.executor.cfg.read_timeout,
459            self.stream.read_exact(buf),
460        )
461        .await
462        .map_err(|_| Error::Timeout)?
463        .map_err(|_| Error::RecvFailed)?;
464
465        // Record metrics
466        self.context
467            .executor
468            .metrics
469            .inbound_bandwidth
470            .inc_by(buf.len() as u64);
471
472        Ok(())
473    }
474}
475
476impl RngCore for Context {
477    fn next_u32(&mut self) -> u32 {
478        OsRng.next_u32()
479    }
480
481    fn next_u64(&mut self) -> u64 {
482        OsRng.next_u64()
483    }
484
485    fn fill_bytes(&mut self, dest: &mut [u8]) {
486        OsRng.fill_bytes(dest);
487    }
488
489    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
490        OsRng.try_fill_bytes(dest)
491    }
492}
493
494impl CryptoRng for Context {}
495
496/// Implementation of [`crate::Blob`] for the `tokio` runtime.
497pub struct Blob {
498    metrics: Arc<Metrics>,
499
500    partition: String,
501    name: Vec<u8>,
502
503    // Files must be seeked prior to any read or write operation and are thus
504    // not safe to concurrently interact with. If we switched to mapping files
505    // we could remove this lock.
506    //
507    // We also track the virtual file size because metadata isn't updated until
508    // the file is synced (not to mention it is a lot less fs calls).
509    file: Arc<AsyncMutex<(fs::File, u64)>>,
510}
511
512impl Blob {
513    fn new(
514        metrics: Arc<Metrics>,
515        partition: String,
516        name: &[u8],
517        file: fs::File,
518        len: u64,
519    ) -> Self {
520        metrics.open_blobs.inc();
521        Self {
522            metrics,
523            partition,
524            name: name.into(),
525            file: Arc::new(AsyncMutex::new((file, len))),
526        }
527    }
528}
529
530impl Clone for Blob {
531    fn clone(&self) -> Self {
532        // We implement `Clone` manually to ensure the `open_blobs` gauge is updated.
533        self.metrics.open_blobs.inc();
534        Self {
535            metrics: self.metrics.clone(),
536            partition: self.partition.clone(),
537            name: self.name.clone(),
538            file: self.file.clone(),
539        }
540    }
541}
542
543impl crate::Storage<Blob> for Context {
544    async fn open(&self, partition: &str, name: &[u8]) -> Result<Blob, Error> {
545        // Acquire the filesystem lock
546        let _guard = self.executor.fs.lock().await;
547
548        // Construct the full path
549        let path = self
550            .executor
551            .cfg
552            .storage_directory
553            .join(partition)
554            .join(hex(name));
555        let parent = match path.parent() {
556            Some(parent) => parent,
557            None => return Err(Error::PartitionCreationFailed(partition.into())),
558        };
559
560        // Create the partition directory if it does not exist
561        fs::create_dir_all(parent)
562            .await
563            .map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
564
565        // Open the file in read-write mode, create if it does not exist
566        let mut file = fs::OpenOptions::new()
567            .read(true)
568            .write(true)
569            .create(true)
570            .truncate(false)
571            .open(&path)
572            .await
573            .map_err(|_| Error::BlobOpenFailed(partition.into(), hex(name)))?;
574
575        // Set the maximum buffer size
576        file.set_max_buf_size(self.executor.cfg.maximum_buffer_size);
577
578        // Get the file length
579        let len = file.metadata().await.map_err(|_| Error::ReadFailed)?.len();
580
581        // Construct the blob
582        Ok(Blob::new(
583            self.executor.metrics.clone(),
584            partition.into(),
585            name,
586            file,
587            len,
588        ))
589    }
590
591    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
592        // Acquire the filesystem lock
593        let _guard = self.executor.fs.lock().await;
594
595        // Remove all related files
596        let path = self.executor.cfg.storage_directory.join(partition);
597        if let Some(name) = name {
598            let blob_path = path.join(hex(name));
599            fs::remove_file(blob_path)
600                .await
601                .map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
602        } else {
603            fs::remove_dir_all(path)
604                .await
605                .map_err(|_| Error::PartitionMissing(partition.into()))?;
606        }
607        Ok(())
608    }
609
610    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
611        // Acquire the filesystem lock
612        let _guard = self.executor.fs.lock().await;
613
614        // Scan the partition directory
615        let path = self.executor.cfg.storage_directory.join(partition);
616        let mut entries = fs::read_dir(path)
617            .await
618            .map_err(|_| Error::PartitionMissing(partition.into()))?;
619        let mut blobs = Vec::new();
620        while let Some(entry) = entries.next_entry().await.map_err(|_| Error::ReadFailed)? {
621            let file_type = entry.file_type().await.map_err(|_| Error::ReadFailed)?;
622            if !file_type.is_file() {
623                return Err(Error::PartitionCorrupt(partition.into()));
624            }
625            if let Some(name) = entry.file_name().to_str() {
626                let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
627                blobs.push(name);
628            }
629        }
630        Ok(blobs)
631    }
632}
633
634impl crate::Blob for Blob {
635    async fn len(&self) -> Result<u64, Error> {
636        let (_, len) = *self.file.lock().await;
637        Ok(len)
638    }
639
640    async fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<(), Error> {
641        // Ensure the read is within bounds
642        let mut file = self.file.lock().await;
643        if offset + buf.len() as u64 > file.1 {
644            return Err(Error::BlobInsufficientLength);
645        }
646
647        // Perform the read
648        file.0
649            .seek(SeekFrom::Start(offset))
650            .await
651            .map_err(|_| Error::ReadFailed)?;
652        file.0
653            .read_exact(buf)
654            .await
655            .map_err(|_| Error::ReadFailed)?;
656        self.metrics.storage_reads.inc();
657        self.metrics.storage_read_bytes.inc_by(buf.len() as u64);
658        Ok(())
659    }
660
661    async fn write_at(&self, buf: &[u8], offset: u64) -> Result<(), Error> {
662        // Perform the write
663        let mut file = self.file.lock().await;
664        file.0
665            .seek(SeekFrom::Start(offset))
666            .await
667            .map_err(|_| Error::WriteFailed)?;
668        file.0
669            .write_all(buf)
670            .await
671            .map_err(|_| Error::WriteFailed)?;
672
673        // Update the virtual file size
674        let max_len = offset + buf.len() as u64;
675        if max_len > file.1 {
676            file.1 = max_len;
677        }
678        self.metrics.storage_writes.inc();
679        self.metrics.storage_write_bytes.inc_by(buf.len() as u64);
680        Ok(())
681    }
682
683    async fn truncate(&self, len: u64) -> Result<(), Error> {
684        // Perform the truncate
685        let mut file = self.file.lock().await;
686        file.0
687            .set_len(len)
688            .await
689            .map_err(|_| Error::BlobTruncateFailed(self.partition.clone(), hex(&self.name)))?;
690
691        // Update the virtual file size
692        file.1 = len;
693        Ok(())
694    }
695
696    async fn sync(&self) -> Result<(), Error> {
697        let file = self.file.lock().await;
698        file.0
699            .sync_all()
700            .await
701            .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))
702    }
703
704    async fn close(self) -> Result<(), Error> {
705        let mut file = self.file.lock().await;
706        file.0
707            .sync_all()
708            .await
709            .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))?;
710        file.0
711            .shutdown()
712            .await
713            .map_err(|_| Error::BlobCloseFailed(self.partition.clone(), hex(&self.name)))
714    }
715}
716
717impl Drop for Blob {
718    fn drop(&mut self) {
719        self.metrics.open_blobs.dec();
720    }
721}