commonware_runtime/tokio/
runtime.rs

1use crate::storage::metered::Storage;
2use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
3use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
4use crate::{SinkOf, StreamOf};
5use governor::clock::{Clock as GClock, ReasonablyRealtime};
6use prometheus_client::{
7    encoding::{text::encode, EncodeLabelSet},
8    metrics::{counter::Counter, family::Family, gauge::Gauge},
9    registry::{Metric, Registry},
10};
11use rand::{rngs::OsRng, CryptoRng, RngCore};
12use std::{
13    env,
14    future::Future,
15    io,
16    net::SocketAddr,
17    path::PathBuf,
18    sync::{Arc, Mutex},
19    time::{Duration, SystemTime},
20};
21use tokio::{
22    io::{AsyncReadExt, AsyncWriteExt},
23    net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
24    runtime::{Builder, Runtime},
25    time::timeout,
26};
27use tracing::warn;
28
/// Label set attached to the per-task metric families in [Metrics].
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct Work {
    /// Label of the [Context] that spawned the task.
    label: String,
}
33
/// Runtime-level metrics recorded by the `tokio` runtime.
#[derive(Debug)]
struct Metrics {
    /// Incremented once per task handed to `spawn`/`spawn_ref`, keyed by context label.
    tasks_spawned: Family<Work, Counter>,
    /// Gauge handed to each spawned task's [Handle], keyed by context label.
    tasks_running: Family<Work, Gauge>,
    /// Incremented once per task handed to `spawn_blocking`, keyed by context label.
    blocking_tasks_spawned: Family<Work, Counter>,
    /// Gauge handed to each blocking task's [Handle], keyed by context label.
    blocking_tasks_running: Family<Work, Gauge>,

    // As nice as it would be to track each of these by socket address,
    // it quickly becomes an OOM attack vector.
    /// Incremented for every connection accepted by a [Listener].
    inbound_connections: Counter,
    /// Incremented for every successful dial.
    outbound_connections: Counter,
    /// Total bytes read via [Stream].
    inbound_bandwidth: Counter,
    /// Total bytes written via [Sink].
    outbound_bandwidth: Counter,
}
48
49impl Metrics {
50    pub fn init(registry: &mut Registry) -> Self {
51        let metrics = Self {
52            tasks_spawned: Family::default(),
53            tasks_running: Family::default(),
54            blocking_tasks_spawned: Family::default(),
55            blocking_tasks_running: Family::default(),
56            inbound_connections: Counter::default(),
57            outbound_connections: Counter::default(),
58            inbound_bandwidth: Counter::default(),
59            outbound_bandwidth: Counter::default(),
60        };
61        registry.register(
62            "tasks_spawned",
63            "Total number of tasks spawned",
64            metrics.tasks_spawned.clone(),
65        );
66        registry.register(
67            "tasks_running",
68            "Number of tasks currently running",
69            metrics.tasks_running.clone(),
70        );
71        registry.register(
72            "blocking_tasks_spawned",
73            "Total number of blocking tasks spawned",
74            metrics.blocking_tasks_spawned.clone(),
75        );
76        registry.register(
77            "blocking_tasks_running",
78            "Number of blocking tasks currently running",
79            metrics.blocking_tasks_running.clone(),
80        );
81        registry.register(
82            "inbound_connections",
83            "Number of connections created by dialing us",
84            metrics.inbound_connections.clone(),
85        );
86        registry.register(
87            "outbound_connections",
88            "Number of connections created by dialing others",
89            metrics.outbound_connections.clone(),
90        );
91        registry.register(
92            "inbound_bandwidth",
93            "Bandwidth used by receiving data from others",
94            metrics.inbound_bandwidth.clone(),
95        );
96        registry.register(
97            "outbound_bandwidth",
98            "Bandwidth used by sending data to others",
99            metrics.outbound_bandwidth.clone(),
100        );
101        metrics
102    }
103}
104
/// Configuration for the `tokio` runtime.
///
/// Fields are private: build a value with [Config::new] and the `with_*`
/// setters, and read it back with the same-named getters.
#[derive(Clone)]
pub struct Config {
    /// Number of threads to use for handling async tasks.
    ///
    /// Worker threads are always active (waiting for work).
    ///
    /// Tokio sets the default value to the number of logical CPUs.
    /// [Config::new] sets this to 2.
    worker_threads: usize,

    /// Maximum number of threads to use for blocking tasks.
    ///
    /// Unlike worker threads, blocking threads are created as needed and
    /// exit if left idle for too long.
    ///
    /// Tokio sets the default value to 512 to avoid hanging on lower-level
    /// operations that require blocking (like `fs` and writing to `Stdout`).
    max_blocking_threads: usize,

    /// Whether or not to catch panics.
    ///
    /// The flag is forwarded to [Handle] whenever a task is spawned.
    catch_panics: bool,

    /// Duration after which to close the connection if no message is read.
    read_timeout: Duration,

    /// Duration after which to close the connection if a message cannot be written.
    write_timeout: Duration,

    /// Whether or not to disable Nagle's algorithm.
    ///
    /// The algorithm combines a series of small network packets into a single packet
    /// before sending to reduce overhead of sending multiple small packets which might not
    /// be efficient on slow, congested networks. However, to do so the algorithm introduces
    /// a slight delay as it waits to accumulate more data. Latency-sensitive networks should
    /// consider disabling it to send the packets as soon as possible to reduce latency.
    ///
    /// `None` means `set_nodelay` is never called; `Some(v)` calls `set_nodelay(v)`
    /// on every dialed and accepted connection (a failure is logged, not returned).
    ///
    /// Note: Make sure that your compile target has and allows this configuration otherwise
    /// panics or unexpected behaviours are possible.
    tcp_nodelay: Option<bool>,

    /// Base directory for all storage operations.
    storage_directory: PathBuf,

    /// Maximum buffer size for operations on blobs.
    ///
    /// Tokio sets the default value to 2MB.
    maximum_buffer_size: usize,
}
153
154impl Config {
155    /// Returns a new [Config] with default values.
156    pub fn new() -> Self {
157        let rng = OsRng.next_u64();
158        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
159        Self {
160            worker_threads: 2,
161            max_blocking_threads: 512,
162            catch_panics: true,
163            read_timeout: Duration::from_secs(60),
164            write_timeout: Duration::from_secs(30),
165            tcp_nodelay: None,
166            storage_directory,
167            maximum_buffer_size: 2 * 1024 * 1024, // 2 MB
168        }
169    }
170
171    // Setters
172    /// See [Config]
173    pub fn with_worker_threads(mut self, n: usize) -> Self {
174        self.worker_threads = n;
175        self
176    }
177    /// See [Config]
178    pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
179        self.max_blocking_threads = n;
180        self
181    }
182    /// See [Config]
183    pub fn with_catch_panics(mut self, b: bool) -> Self {
184        self.catch_panics = b;
185        self
186    }
187    /// See [Config]
188    pub fn with_read_timeout(mut self, d: Duration) -> Self {
189        self.read_timeout = d;
190        self
191    }
192    /// See [Config]
193    pub fn with_write_timeout(mut self, d: Duration) -> Self {
194        self.write_timeout = d;
195        self
196    }
197    /// See [Config]
198    pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
199        self.tcp_nodelay = n;
200        self
201    }
202    /// See [Config]
203    pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
204        self.storage_directory = p.into();
205        self
206    }
207    /// See [Config]
208    pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
209        self.maximum_buffer_size = n;
210        self
211    }
212
213    // Getters
214    /// See [Config]
215    pub fn worker_threads(&self) -> usize {
216        self.worker_threads
217    }
218    /// See [Config]
219    pub fn max_blocking_threads(&self) -> usize {
220        self.max_blocking_threads
221    }
222    /// See [Config]
223    pub fn catch_panics(&self) -> bool {
224        self.catch_panics
225    }
226    /// See [Config]
227    pub fn read_timeout(&self) -> Duration {
228        self.read_timeout
229    }
230    /// See [Config]
231    pub fn write_timeout(&self) -> Duration {
232        self.write_timeout
233    }
234    /// See [Config]
235    pub fn tcp_nodelay(&self) -> Option<bool> {
236        self.tcp_nodelay
237    }
238    /// See [Config]
239    pub fn storage_directory(&self) -> &PathBuf {
240        &self.storage_directory
241    }
242    /// See [Config]
243    pub fn maximum_buffer_size(&self) -> usize {
244        self.maximum_buffer_size
245    }
246}
247
248impl Default for Config {
249    fn default() -> Self {
250        Self::new()
251    }
252}
253
/// Runtime based on [Tokio](https://tokio.rs).
///
/// A single instance is shared (behind an `Arc`) by every [Context] created
/// from the same [Runner].
pub struct Executor {
    /// Configuration the runtime was started with.
    cfg: Config,
    /// Root metrics registry; locked to register metrics and to encode them.
    registry: Mutex<Registry>,
    /// Runtime-level metrics (tasks, connections, bandwidth).
    metrics: Arc<Metrics>,
    /// Underlying Tokio runtime that tasks are spawned onto.
    runtime: Runtime,
    /// Sending side of the stop signal; locked when `stop` is called.
    signaler: Mutex<Signaler>,
    /// Stop signal that is cloned out to callers of `stopped`.
    signal: Signal,
}
263
/// Implementation of [crate::Runner] for the `tokio` runtime.
pub struct Runner {
    /// Configuration used to build the runtime when `start` is called.
    cfg: Config,
}
268
269impl Default for Runner {
270    fn default() -> Self {
271        Self::new(Config::default())
272    }
273}
274
impl Runner {
    /// Creates a new [Runner] that will use the given [Config].
    ///
    /// This only stores the configuration; no Tokio runtime is built here.
    pub fn new(cfg: Config) -> Self {
        Self { cfg }
    }
}
281
282impl crate::Runner for Runner {
283    type Context = Context;
284
285    fn start<F, Fut>(self, f: F) -> Fut::Output
286    where
287        F: FnOnce(Self::Context) -> Fut,
288        Fut: Future,
289    {
290        // Create a new registry
291        let mut registry = Registry::default();
292        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
293
294        // Initialize runtime
295        let metrics = Arc::new(Metrics::init(runtime_registry));
296        let runtime = Builder::new_multi_thread()
297            .worker_threads(self.cfg.worker_threads)
298            .max_blocking_threads(self.cfg.max_blocking_threads)
299            .enable_all()
300            .build()
301            .expect("failed to create Tokio runtime");
302        let (signaler, signal) = Signaler::new();
303
304        let storage = Storage::new(
305            TokioStorage::new(TokioStorageConfig::new(
306                self.cfg.storage_directory.clone(),
307                self.cfg.maximum_buffer_size,
308            )),
309            runtime_registry,
310        );
311
312        let executor = Arc::new(Executor {
313            cfg: self.cfg,
314            registry: Mutex::new(registry),
315            metrics,
316            runtime,
317            signaler: Mutex::new(signaler),
318            signal,
319        });
320
321        let context = Context {
322            storage,
323            label: String::new(),
324            spawned: false,
325            executor: executor.clone(),
326        };
327
328        executor.runtime.block_on(f(context))
329    }
330}
331
/// Implementation of [crate::Spawner], [crate::Clock],
/// [crate::Network], and [crate::Storage] for the `tokio`
/// runtime.
///
/// Also implements [crate::Metrics] and acts as an OS-backed random number generator.
pub struct Context {
    /// Label used for task metrics and as the prefix for registered metric names.
    label: String,
    /// Set by `spawn_ref`; the spawn methods assert it is still `false`.
    spawned: bool,
    /// Shared runtime state.
    executor: Arc<Executor>,
    /// Metered storage backing the [crate::Storage] implementation.
    storage: Storage<TokioStorage>,
}
341
342impl Clone for Context {
343    fn clone(&self) -> Self {
344        Self {
345            label: self.label.clone(),
346            spawned: false,
347            executor: self.executor.clone(),
348            storage: self.storage.clone(),
349        }
350    }
351}
352
353impl crate::Spawner for Context {
354    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
355    where
356        F: FnOnce(Self) -> Fut + Send + 'static,
357        Fut: Future<Output = T> + Send + 'static,
358        T: Send + 'static,
359    {
360        // Ensure a context only spawns one task
361        assert!(!self.spawned, "already spawned");
362
363        // Get metrics
364        let work = Work {
365            label: self.label.clone(),
366        };
367        self.executor
368            .metrics
369            .tasks_spawned
370            .get_or_create(&work)
371            .inc();
372        let gauge = self
373            .executor
374            .metrics
375            .tasks_running
376            .get_or_create(&work)
377            .clone();
378
379        // Set up the task
380        let catch_panics = self.executor.cfg.catch_panics;
381        let executor = self.executor.clone();
382        let future = f(self);
383        let (f, handle) = Handle::init(future, gauge, catch_panics);
384
385        // Spawn the task
386        executor.runtime.spawn(f);
387        handle
388    }
389
390    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
391    where
392        F: Future<Output = T> + Send + 'static,
393        T: Send + 'static,
394    {
395        // Ensure a context only spawns one task
396        assert!(!self.spawned, "already spawned");
397        self.spawned = true;
398
399        // Get metrics
400        let work = Work {
401            label: self.label.clone(),
402        };
403        self.executor
404            .metrics
405            .tasks_spawned
406            .get_or_create(&work)
407            .inc();
408        let gauge = self
409            .executor
410            .metrics
411            .tasks_running
412            .get_or_create(&work)
413            .clone();
414
415        // Set up the task
416        let executor = self.executor.clone();
417        move |f: F| {
418            let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
419
420            // Spawn the task
421            executor.runtime.spawn(f);
422            handle
423        }
424    }
425
426    fn spawn_blocking<F, T>(self, f: F) -> Handle<T>
427    where
428        F: FnOnce() -> T + Send + 'static,
429        T: Send + 'static,
430    {
431        // Ensure a context only spawns one task
432        assert!(!self.spawned, "already spawned");
433
434        // Get metrics
435        let work = Work {
436            label: self.label.clone(),
437        };
438        self.executor
439            .metrics
440            .blocking_tasks_spawned
441            .get_or_create(&work)
442            .inc();
443        let gauge = self
444            .executor
445            .metrics
446            .blocking_tasks_running
447            .get_or_create(&work)
448            .clone();
449
450        // Initialize the blocking task using the new function
451        let (f, handle) = Handle::init_blocking(f, gauge, self.executor.cfg.catch_panics);
452
453        // Spawn the blocking task
454        self.executor.runtime.spawn_blocking(f);
455        handle
456    }
457
458    fn stop(&self, value: i32) {
459        self.executor.signaler.lock().unwrap().signal(value);
460    }
461
462    fn stopped(&self) -> Signal {
463        self.executor.signal.clone()
464    }
465}
466
467impl crate::Metrics for Context {
468    fn with_label(&self, label: &str) -> Self {
469        let label = {
470            let prefix = self.label.clone();
471            if prefix.is_empty() {
472                label.to_string()
473            } else {
474                format!("{}_{}", prefix, label)
475            }
476        };
477        assert!(
478            !label.starts_with(METRICS_PREFIX),
479            "using runtime label is not allowed"
480        );
481        Self {
482            label,
483            spawned: false,
484            executor: self.executor.clone(),
485            storage: self.storage.clone(),
486        }
487    }
488
489    fn label(&self) -> String {
490        self.label.clone()
491    }
492
493    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
494        let name = name.into();
495        let prefixed_name = {
496            let prefix = &self.label;
497            if prefix.is_empty() {
498                name
499            } else {
500                format!("{}_{}", *prefix, name)
501            }
502        };
503        self.executor
504            .registry
505            .lock()
506            .unwrap()
507            .register(prefixed_name, help, metric)
508    }
509
510    fn encode(&self) -> String {
511        let mut buffer = String::new();
512        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
513        buffer
514    }
515}
516
517impl Clock for Context {
518    fn current(&self) -> SystemTime {
519        SystemTime::now()
520    }
521
522    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
523        tokio::time::sleep(duration)
524    }
525
526    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
527        let now = SystemTime::now();
528        let duration_until_deadline = match deadline.duration_since(now) {
529            Ok(duration) => duration,
530            Err(_) => Duration::from_secs(0), // Deadline is in the past
531        };
532        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
533        tokio::time::sleep_until(target_instant)
534    }
535}
536
537impl GClock for Context {
538    type Instant = SystemTime;
539
540    fn now(&self) -> Self::Instant {
541        self.current()
542    }
543}
544
// Marker impl from `governor`'s clock module; no trait items are overridden here.
impl ReasonablyRealtime for Context {}
546
547impl crate::Network for Context {
548    type Listener = Listener;
549
550    async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
551        TcpListener::bind(socket)
552            .await
553            .map_err(|_| Error::BindFailed)
554            .map(|listener| Listener {
555                context: self.clone(),
556                listener,
557            })
558    }
559
560    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
561        // Create a new TCP stream
562        let stream = TcpStream::connect(socket)
563            .await
564            .map_err(|_| Error::ConnectionFailed)?;
565        self.executor.metrics.outbound_connections.inc();
566
567        // Set TCP_NODELAY if configured
568        if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
569            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
570                warn!(?err, "failed to set TCP_NODELAY");
571            }
572        }
573
574        // Return the sink and stream
575        let context = self.clone();
576        let (stream, sink) = stream.into_split();
577        Ok((
578            Sink {
579                context: context.clone(),
580                sink,
581            },
582            Stream { context, stream },
583        ))
584    }
585}
586
/// Implementation of [crate::Listener] for the `tokio` runtime.
pub struct Listener {
    /// Context cloned into every accepted connection's [Sink] and [Stream].
    context: Context,
    /// Underlying bound TCP listener.
    listener: TcpListener,
}
592
593impl crate::Listener for Listener {
594    type Sink = Sink;
595    type Stream = Stream;
596
597    async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
598        // Accept a new TCP stream
599        let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
600        self.context.executor.metrics.inbound_connections.inc();
601
602        // Set TCP_NODELAY if configured
603        if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
604            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
605                warn!(?err, "failed to set TCP_NODELAY");
606            }
607        }
608
609        // Return the sink and stream
610        let context = self.context.clone();
611        let (stream, sink) = stream.into_split();
612        Ok((
613            addr,
614            Sink {
615                context: context.clone(),
616                sink,
617            },
618            Stream { context, stream },
619        ))
620    }
621}
622
623impl axum::serve::Listener for Listener {
624    type Io = TcpStream;
625    type Addr = SocketAddr;
626
627    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
628        let (stream, addr) = self.listener.accept().await.unwrap();
629        (stream, addr)
630    }
631
632    fn local_addr(&self) -> io::Result<Self::Addr> {
633        self.listener.local_addr()
634    }
635}
636
/// Implementation of [crate::Sink] for the `tokio` runtime.
pub struct Sink {
    /// Context providing the write timeout and bandwidth metrics.
    context: Context,
    /// Write half of the TCP connection.
    sink: OwnedWriteHalf,
}
642
643impl crate::Sink for Sink {
644    async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
645        let len = msg.len();
646        timeout(
647            self.context.executor.cfg.write_timeout,
648            self.sink.write_all(msg),
649        )
650        .await
651        .map_err(|_| Error::Timeout)?
652        .map_err(|_| Error::SendFailed)?;
653        self.context
654            .executor
655            .metrics
656            .outbound_bandwidth
657            .inc_by(len as u64);
658        Ok(())
659    }
660}
661
/// Implementation of [crate::Stream] for the `tokio` runtime.
pub struct Stream {
    /// Context providing the read timeout and bandwidth metrics.
    context: Context,
    /// Read half of the TCP connection.
    stream: OwnedReadHalf,
}
667
668impl crate::Stream for Stream {
669    async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
670        // Wait for the stream to be readable
671        timeout(
672            self.context.executor.cfg.read_timeout,
673            self.stream.read_exact(buf),
674        )
675        .await
676        .map_err(|_| Error::Timeout)?
677        .map_err(|_| Error::RecvFailed)?;
678
679        // Record metrics
680        self.context
681            .executor
682            .metrics
683            .inbound_bandwidth
684            .inc_by(buf.len() as u64);
685
686        Ok(())
687    }
688}
689
690impl RngCore for Context {
691    fn next_u32(&mut self) -> u32 {
692        OsRng.next_u32()
693    }
694
695    fn next_u64(&mut self) -> u64 {
696        OsRng.next_u64()
697    }
698
699    fn fill_bytes(&mut self, dest: &mut [u8]) {
700        OsRng.fill_bytes(dest);
701    }
702
703    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
704        OsRng.try_fill_bytes(dest)
705    }
706}
707
// Marker impl: every `RngCore` method on `Context` delegates to `OsRng`.
impl CryptoRng for Context {}
709
// Every method forwards to the `storage` field (the metered wrapper around the Tokio storage).
impl crate::Storage for Context {
    type Blob = <Storage<TokioStorage> as crate::Storage>::Blob;

    /// Forwards to the wrapped storage's `open` for `name` in `partition`.
    ///
    /// NOTE(review): the `u64` in the result is presumably the blob's length — confirm
    /// against the storage implementation.
    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
        self.storage.open(partition, name).await
    }

    /// Forwards to the wrapped storage's `remove`.
    ///
    /// NOTE(review): `name == None` presumably removes the whole partition — confirm
    /// against the storage implementation.
    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
        self.storage.remove(partition, name).await
    }

    /// Forwards to the wrapped storage's `scan` for `partition`.
    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
        self.storage.scan(partition).await
    }
}