1use crate::storage::metered::Storage;
2use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
3use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
4use crate::{SinkOf, StreamOf};
5use governor::clock::{Clock as GClock, ReasonablyRealtime};
6use prometheus_client::{
7 encoding::{text::encode, EncodeLabelSet},
8 metrics::{counter::Counter, family::Family, gauge::Gauge},
9 registry::{Metric, Registry},
10};
11use rand::{rngs::OsRng, CryptoRng, RngCore};
12use std::{
13 env,
14 future::Future,
15 io,
16 net::SocketAddr,
17 path::PathBuf,
18 sync::{Arc, Mutex},
19 time::{Duration, SystemTime},
20};
21use tokio::{
22 io::{AsyncReadExt, AsyncWriteExt},
23 net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
24 runtime::{Builder, Runtime},
25 time::timeout,
26};
27use tracing::warn;
28
29#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
30struct Work {
31 label: String,
32}
33
34#[derive(Debug)]
35struct Metrics {
36 tasks_spawned: Family<Work, Counter>,
37 tasks_running: Family<Work, Gauge>,
38 blocking_tasks_spawned: Family<Work, Counter>,
39 blocking_tasks_running: Family<Work, Gauge>,
40
41 inbound_connections: Counter,
44 outbound_connections: Counter,
45 inbound_bandwidth: Counter,
46 outbound_bandwidth: Counter,
47}
48
49impl Metrics {
50 pub fn init(registry: &mut Registry) -> Self {
51 let metrics = Self {
52 tasks_spawned: Family::default(),
53 tasks_running: Family::default(),
54 blocking_tasks_spawned: Family::default(),
55 blocking_tasks_running: Family::default(),
56 inbound_connections: Counter::default(),
57 outbound_connections: Counter::default(),
58 inbound_bandwidth: Counter::default(),
59 outbound_bandwidth: Counter::default(),
60 };
61 registry.register(
62 "tasks_spawned",
63 "Total number of tasks spawned",
64 metrics.tasks_spawned.clone(),
65 );
66 registry.register(
67 "tasks_running",
68 "Number of tasks currently running",
69 metrics.tasks_running.clone(),
70 );
71 registry.register(
72 "blocking_tasks_spawned",
73 "Total number of blocking tasks spawned",
74 metrics.blocking_tasks_spawned.clone(),
75 );
76 registry.register(
77 "blocking_tasks_running",
78 "Number of blocking tasks currently running",
79 metrics.blocking_tasks_running.clone(),
80 );
81 registry.register(
82 "inbound_connections",
83 "Number of connections created by dialing us",
84 metrics.inbound_connections.clone(),
85 );
86 registry.register(
87 "outbound_connections",
88 "Number of connections created by dialing others",
89 metrics.outbound_connections.clone(),
90 );
91 registry.register(
92 "inbound_bandwidth",
93 "Bandwidth used by receiving data from others",
94 metrics.inbound_bandwidth.clone(),
95 );
96 registry.register(
97 "outbound_bandwidth",
98 "Bandwidth used by sending data to others",
99 metrics.outbound_bandwidth.clone(),
100 );
101 metrics
102 }
103}
104
105#[derive(Clone)]
107pub struct Config {
108 worker_threads: usize,
114
115 max_blocking_threads: usize,
123
124 catch_panics: bool,
126
127 read_timeout: Duration,
129
130 write_timeout: Duration,
132
133 tcp_nodelay: Option<bool>,
144
145 storage_directory: PathBuf,
147
148 maximum_buffer_size: usize,
152}
153
154impl Config {
155 pub fn new() -> Self {
157 let rng = OsRng.next_u64();
158 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
159 Self {
160 worker_threads: 2,
161 max_blocking_threads: 512,
162 catch_panics: true,
163 read_timeout: Duration::from_secs(60),
164 write_timeout: Duration::from_secs(30),
165 tcp_nodelay: None,
166 storage_directory,
167 maximum_buffer_size: 2 * 1024 * 1024, }
169 }
170
171 pub fn with_worker_threads(mut self, n: usize) -> Self {
174 self.worker_threads = n;
175 self
176 }
177 pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
179 self.max_blocking_threads = n;
180 self
181 }
182 pub fn with_catch_panics(mut self, b: bool) -> Self {
184 self.catch_panics = b;
185 self
186 }
187 pub fn with_read_timeout(mut self, d: Duration) -> Self {
189 self.read_timeout = d;
190 self
191 }
192 pub fn with_write_timeout(mut self, d: Duration) -> Self {
194 self.write_timeout = d;
195 self
196 }
197 pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
199 self.tcp_nodelay = n;
200 self
201 }
202 pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
204 self.storage_directory = p.into();
205 self
206 }
207 pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
209 self.maximum_buffer_size = n;
210 self
211 }
212
213 pub fn worker_threads(&self) -> usize {
216 self.worker_threads
217 }
218 pub fn max_blocking_threads(&self) -> usize {
220 self.max_blocking_threads
221 }
222 pub fn catch_panics(&self) -> bool {
224 self.catch_panics
225 }
226 pub fn read_timeout(&self) -> Duration {
228 self.read_timeout
229 }
230 pub fn write_timeout(&self) -> Duration {
232 self.write_timeout
233 }
234 pub fn tcp_nodelay(&self) -> Option<bool> {
236 self.tcp_nodelay
237 }
238 pub fn storage_directory(&self) -> &PathBuf {
240 &self.storage_directory
241 }
242 pub fn maximum_buffer_size(&self) -> usize {
244 self.maximum_buffer_size
245 }
246}
247
248impl Default for Config {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
/// Shared runtime state, referenced through an `Arc` by every `Context` created by a `Runner`.
///
/// NOTE: Rust drops fields in declaration order, so reordering these fields changes when the
/// Tokio `runtime` is dropped relative to the other state.
pub struct Executor {
    /// Configuration the runtime was started with.
    cfg: Config,
    /// Metrics registry. Runtime metrics live in a sub-registry prefixed with `METRICS_PREFIX`;
    /// `Context::register` adds user metrics and `Context::encode` serializes the whole registry.
    registry: Mutex<Registry>,
    /// Runtime-internal task, connection, and bandwidth metrics.
    metrics: Arc<Metrics>,
    /// Tokio multi-threaded runtime that all tasks are spawned onto.
    runtime: Runtime,
    /// Created together with `signal` by `Signaler::new`; `stop` calls `signal(value)` on it.
    signaler: Mutex<Signaler>,
    /// Signal cloned out by `stopped` (presumably resolves once `stop` is called — confirm in
    /// `utils::Signaler`).
    signal: Signal,
}
263
264pub struct Runner {
266 cfg: Config,
267}
268
269impl Default for Runner {
270 fn default() -> Self {
271 Self::new(Config::default())
272 }
273}
274
275impl Runner {
276 pub fn new(cfg: Config) -> Self {
278 Self { cfg }
279 }
280}
281
282impl crate::Runner for Runner {
283 type Context = Context;
284
285 fn start<F, Fut>(self, f: F) -> Fut::Output
286 where
287 F: FnOnce(Self::Context) -> Fut,
288 Fut: Future,
289 {
290 let mut registry = Registry::default();
292 let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
293
294 let metrics = Arc::new(Metrics::init(runtime_registry));
296 let runtime = Builder::new_multi_thread()
297 .worker_threads(self.cfg.worker_threads)
298 .max_blocking_threads(self.cfg.max_blocking_threads)
299 .enable_all()
300 .build()
301 .expect("failed to create Tokio runtime");
302 let (signaler, signal) = Signaler::new();
303
304 let storage = Storage::new(
305 TokioStorage::new(TokioStorageConfig::new(
306 self.cfg.storage_directory.clone(),
307 self.cfg.maximum_buffer_size,
308 )),
309 runtime_registry,
310 );
311
312 let executor = Arc::new(Executor {
313 cfg: self.cfg,
314 registry: Mutex::new(registry),
315 metrics,
316 runtime,
317 signaler: Mutex::new(signaler),
318 signal,
319 });
320
321 let context = Context {
322 storage,
323 label: String::new(),
324 spawned: false,
325 executor: executor.clone(),
326 };
327
328 executor.runtime.block_on(f(context))
329 }
330}
331
332pub struct Context {
336 label: String,
337 spawned: bool,
338 executor: Arc<Executor>,
339 storage: Storage<TokioStorage>,
340}
341
342impl Clone for Context {
343 fn clone(&self) -> Self {
344 Self {
345 label: self.label.clone(),
346 spawned: false,
347 executor: self.executor.clone(),
348 storage: self.storage.clone(),
349 }
350 }
351}
352
353impl crate::Spawner for Context {
354 fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
355 where
356 F: FnOnce(Self) -> Fut + Send + 'static,
357 Fut: Future<Output = T> + Send + 'static,
358 T: Send + 'static,
359 {
360 assert!(!self.spawned, "already spawned");
362
363 let work = Work {
365 label: self.label.clone(),
366 };
367 self.executor
368 .metrics
369 .tasks_spawned
370 .get_or_create(&work)
371 .inc();
372 let gauge = self
373 .executor
374 .metrics
375 .tasks_running
376 .get_or_create(&work)
377 .clone();
378
379 let catch_panics = self.executor.cfg.catch_panics;
381 let executor = self.executor.clone();
382 let future = f(self);
383 let (f, handle) = Handle::init(future, gauge, catch_panics);
384
385 executor.runtime.spawn(f);
387 handle
388 }
389
390 fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
391 where
392 F: Future<Output = T> + Send + 'static,
393 T: Send + 'static,
394 {
395 assert!(!self.spawned, "already spawned");
397 self.spawned = true;
398
399 let work = Work {
401 label: self.label.clone(),
402 };
403 self.executor
404 .metrics
405 .tasks_spawned
406 .get_or_create(&work)
407 .inc();
408 let gauge = self
409 .executor
410 .metrics
411 .tasks_running
412 .get_or_create(&work)
413 .clone();
414
415 let executor = self.executor.clone();
417 move |f: F| {
418 let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
419
420 executor.runtime.spawn(f);
422 handle
423 }
424 }
425
426 fn spawn_blocking<F, T>(self, f: F) -> Handle<T>
427 where
428 F: FnOnce() -> T + Send + 'static,
429 T: Send + 'static,
430 {
431 assert!(!self.spawned, "already spawned");
433
434 let work = Work {
436 label: self.label.clone(),
437 };
438 self.executor
439 .metrics
440 .blocking_tasks_spawned
441 .get_or_create(&work)
442 .inc();
443 let gauge = self
444 .executor
445 .metrics
446 .blocking_tasks_running
447 .get_or_create(&work)
448 .clone();
449
450 let (f, handle) = Handle::init_blocking(f, gauge, self.executor.cfg.catch_panics);
452
453 self.executor.runtime.spawn_blocking(f);
455 handle
456 }
457
458 fn stop(&self, value: i32) {
459 self.executor.signaler.lock().unwrap().signal(value);
460 }
461
462 fn stopped(&self) -> Signal {
463 self.executor.signal.clone()
464 }
465}
466
467impl crate::Metrics for Context {
468 fn with_label(&self, label: &str) -> Self {
469 let label = {
470 let prefix = self.label.clone();
471 if prefix.is_empty() {
472 label.to_string()
473 } else {
474 format!("{}_{}", prefix, label)
475 }
476 };
477 assert!(
478 !label.starts_with(METRICS_PREFIX),
479 "using runtime label is not allowed"
480 );
481 Self {
482 label,
483 spawned: false,
484 executor: self.executor.clone(),
485 storage: self.storage.clone(),
486 }
487 }
488
489 fn label(&self) -> String {
490 self.label.clone()
491 }
492
493 fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
494 let name = name.into();
495 let prefixed_name = {
496 let prefix = &self.label;
497 if prefix.is_empty() {
498 name
499 } else {
500 format!("{}_{}", *prefix, name)
501 }
502 };
503 self.executor
504 .registry
505 .lock()
506 .unwrap()
507 .register(prefixed_name, help, metric)
508 }
509
510 fn encode(&self) -> String {
511 let mut buffer = String::new();
512 encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
513 buffer
514 }
515}
516
517impl Clock for Context {
518 fn current(&self) -> SystemTime {
519 SystemTime::now()
520 }
521
522 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
523 tokio::time::sleep(duration)
524 }
525
526 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
527 let now = SystemTime::now();
528 let duration_until_deadline = match deadline.duration_since(now) {
529 Ok(duration) => duration,
530 Err(_) => Duration::from_secs(0), };
532 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
533 tokio::time::sleep_until(target_instant)
534 }
535}
536
537impl GClock for Context {
538 type Instant = SystemTime;
539
540 fn now(&self) -> Self::Instant {
541 self.current()
542 }
543}
544
545impl ReasonablyRealtime for Context {}
546
547impl crate::Network for Context {
548 type Listener = Listener;
549
550 async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
551 TcpListener::bind(socket)
552 .await
553 .map_err(|_| Error::BindFailed)
554 .map(|listener| Listener {
555 context: self.clone(),
556 listener,
557 })
558 }
559
560 async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
561 let stream = TcpStream::connect(socket)
563 .await
564 .map_err(|_| Error::ConnectionFailed)?;
565 self.executor.metrics.outbound_connections.inc();
566
567 if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
569 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
570 warn!(?err, "failed to set TCP_NODELAY");
571 }
572 }
573
574 let context = self.clone();
576 let (stream, sink) = stream.into_split();
577 Ok((
578 Sink {
579 context: context.clone(),
580 sink,
581 },
582 Stream { context, stream },
583 ))
584 }
585}
586
587pub struct Listener {
589 context: Context,
590 listener: TcpListener,
591}
592
593impl crate::Listener for Listener {
594 type Sink = Sink;
595 type Stream = Stream;
596
597 async fn accept(&mut self) -> Result<(SocketAddr, Self::Sink, Self::Stream), Error> {
598 let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
600 self.context.executor.metrics.inbound_connections.inc();
601
602 if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
604 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
605 warn!(?err, "failed to set TCP_NODELAY");
606 }
607 }
608
609 let context = self.context.clone();
611 let (stream, sink) = stream.into_split();
612 Ok((
613 addr,
614 Sink {
615 context: context.clone(),
616 sink,
617 },
618 Stream { context, stream },
619 ))
620 }
621}
622
623impl axum::serve::Listener for Listener {
624 type Io = TcpStream;
625 type Addr = SocketAddr;
626
627 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
628 let (stream, addr) = self.listener.accept().await.unwrap();
629 (stream, addr)
630 }
631
632 fn local_addr(&self) -> io::Result<Self::Addr> {
633 self.listener.local_addr()
634 }
635}
636
637pub struct Sink {
639 context: Context,
640 sink: OwnedWriteHalf,
641}
642
643impl crate::Sink for Sink {
644 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
645 let len = msg.len();
646 timeout(
647 self.context.executor.cfg.write_timeout,
648 self.sink.write_all(msg),
649 )
650 .await
651 .map_err(|_| Error::Timeout)?
652 .map_err(|_| Error::SendFailed)?;
653 self.context
654 .executor
655 .metrics
656 .outbound_bandwidth
657 .inc_by(len as u64);
658 Ok(())
659 }
660}
661
662pub struct Stream {
664 context: Context,
665 stream: OwnedReadHalf,
666}
667
668impl crate::Stream for Stream {
669 async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
670 timeout(
672 self.context.executor.cfg.read_timeout,
673 self.stream.read_exact(buf),
674 )
675 .await
676 .map_err(|_| Error::Timeout)?
677 .map_err(|_| Error::RecvFailed)?;
678
679 self.context
681 .executor
682 .metrics
683 .inbound_bandwidth
684 .inc_by(buf.len() as u64);
685
686 Ok(())
687 }
688}
689
690impl RngCore for Context {
691 fn next_u32(&mut self) -> u32 {
692 OsRng.next_u32()
693 }
694
695 fn next_u64(&mut self) -> u64 {
696 OsRng.next_u64()
697 }
698
699 fn fill_bytes(&mut self, dest: &mut [u8]) {
700 OsRng.fill_bytes(dest);
701 }
702
703 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
704 OsRng.try_fill_bytes(dest)
705 }
706}
707
708impl CryptoRng for Context {}
709
710impl crate::Storage for Context {
711 type Blob = <Storage<TokioStorage> as crate::Storage>::Blob;
712
713 async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
714 self.storage.open(partition, name).await
715 }
716
717 async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
718 self.storage.remove(partition, name).await
719 }
720
721 async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
722 self.storage.scan(partition).await
723 }
724}