1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "external")]
8use crate::Pacer;
9#[cfg(feature = "iouring-network")]
10use crate::{
11 iouring,
12 network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
13};
14use crate::{
15 network::metered::Network as MeteredNetwork,
16 process::metered::Metrics as MeteredProcess,
17 signal::Signal,
18 storage::metered::Storage as MeteredStorage,
19 telemetry::metrics::task::Label,
20 utils::{add_attribute, signal::Stopper, supervision::Tree, MetricEncoder, Panicker},
21 BufferPool, BufferPoolConfig, Clock, Error, Execution, Handle, Metrics as _, SinkOf,
22 Spawner as _, StreamOf, METRICS_PREFIX,
23};
24use commonware_macros::{select, stability};
25#[stability(BETA)]
26use commonware_parallel::ThreadPool;
27use futures::{future::BoxFuture, FutureExt};
28use governor::clock::{Clock as GClock, ReasonablyRealtime};
29use prometheus_client::{
30 encoding::text::encode,
31 metrics::{counter::Counter, family::Family, gauge::Gauge},
32 registry::{Metric, Registry},
33};
34use rand::{rngs::OsRng, CryptoRng, RngCore};
35#[stability(BETA)]
36use rayon::{ThreadPoolBuildError, ThreadPoolBuilder};
37use std::{
38 borrow::Cow,
39 env,
40 future::Future,
41 net::{IpAddr, SocketAddr},
42 num::NonZeroUsize,
43 path::PathBuf,
44 sync::{Arc, Mutex},
45 thread,
46 time::{Duration, SystemTime},
47};
48use tokio::runtime::{Builder, Runtime};
49use tracing::{info_span, Instrument};
50use tracing_opentelemetry::OpenTelemetrySpanExt;
51
/// Value passed as `iouring::Config::size` when starting the io_uring network backend
/// (presumably the ring's entry count — confirm against `iouring::Config`).
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1024;
/// Value passed as `iouring::Config::force_poll` when starting the io_uring network backend.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Duration = Duration::from_millis(100);
56
/// Task-level metrics for the runtime. `Runner::start` registers them in the registry scoped
/// under `METRICS_PREFIX`.
#[derive(Debug)]
struct Metrics {
    /// Counter of tasks spawned, one series per task [`Label`].
    tasks_spawned: Family<Label, Counter>,
    /// Gauge of tasks currently running, one series per task [`Label`].
    tasks_running: Family<Label, Gauge>,
}
62
63impl Metrics {
64 pub fn init(registry: &mut Registry) -> Self {
65 let metrics = Self {
66 tasks_spawned: Family::default(),
67 tasks_running: Family::default(),
68 };
69 registry.register(
70 "tasks_spawned",
71 "Total number of tasks spawned",
72 metrics.tasks_spawned.clone(),
73 );
74 registry.register(
75 "tasks_running",
76 "Number of tasks currently running",
77 metrics.tasks_running.clone(),
78 );
79 metrics
80 }
81}
82
/// Network settings forwarded to whichever network backend (Tokio or io_uring) is compiled in.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// Optional `TCP_NODELAY` setting handed to the backend (`None` presumably leaves the
    /// socket default untouched — confirm in the backend).
    tcp_nodelay: Option<bool>,

    /// Timeout used for socket reads and writes. With io_uring it is also used as the
    /// per-operation and shutdown timeout.
    read_write_timeout: Duration,
}
92
93impl Default for NetworkConfig {
94 fn default() -> Self {
95 Self {
96 tcp_nodelay: None,
97 read_write_timeout: Duration::from_secs(60),
98 }
99 }
100}
101
/// Configuration for the Tokio-based [`Runner`].
///
/// Build one with [`Config::new`] (or `Default`) and adjust it with the `with_*` methods.
#[derive(Clone)]
pub struct Config {
    /// Number of Tokio worker threads for the multi-threaded runtime.
    worker_threads: usize,

    /// Upper bound on Tokio's blocking thread pool. Tasks spawned with `shared(true)` run on
    /// this pool.
    max_blocking_threads: usize,

    /// Passed to `Panicker::new` (presumably selects whether task panics are caught rather
    /// than propagated — confirm in `Panicker`).
    catch_panics: bool,

    /// Directory handed to the storage backend.
    storage_directory: PathBuf,

    /// Maximum buffer size; only consumed by the Tokio storage backend in `Runner::start`.
    maximum_buffer_size: usize,

    /// Network settings (TCP_NODELAY and read/write timeout).
    network_cfg: NetworkConfig,
}
135
136impl Config {
137 pub fn new() -> Self {
139 let rng = OsRng.next_u64();
140 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
141 Self {
142 worker_threads: 2,
143 max_blocking_threads: 512,
144 catch_panics: false,
145 storage_directory,
146 maximum_buffer_size: 2 * 1024 * 1024, network_cfg: NetworkConfig::default(),
148 }
149 }
150
151 pub const fn with_worker_threads(mut self, n: usize) -> Self {
154 self.worker_threads = n;
155 self
156 }
157 pub const fn with_max_blocking_threads(mut self, n: usize) -> Self {
159 self.max_blocking_threads = n;
160 self
161 }
162 pub const fn with_catch_panics(mut self, b: bool) -> Self {
164 self.catch_panics = b;
165 self
166 }
167 pub const fn with_read_write_timeout(mut self, d: Duration) -> Self {
169 self.network_cfg.read_write_timeout = d;
170 self
171 }
172 pub const fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
174 self.network_cfg.tcp_nodelay = n;
175 self
176 }
177 pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
179 self.storage_directory = p.into();
180 self
181 }
182 pub const fn with_maximum_buffer_size(mut self, n: usize) -> Self {
184 self.maximum_buffer_size = n;
185 self
186 }
187
188 pub const fn worker_threads(&self) -> usize {
191 self.worker_threads
192 }
193 pub const fn max_blocking_threads(&self) -> usize {
195 self.max_blocking_threads
196 }
197 pub const fn catch_panics(&self) -> bool {
199 self.catch_panics
200 }
201 pub const fn read_write_timeout(&self) -> Duration {
203 self.network_cfg.read_write_timeout
204 }
205 pub const fn tcp_nodelay(&self) -> Option<bool> {
207 self.network_cfg.tcp_nodelay
208 }
209 pub const fn storage_directory(&self) -> &PathBuf {
211 &self.storage_directory
212 }
213 pub const fn maximum_buffer_size(&self) -> usize {
215 self.maximum_buffer_size
216 }
217}
218
219impl Default for Config {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
/// Runtime-wide state shared (via `Arc`) by every [`Context`].
///
/// Fields drop in declaration order, so reordering them changes when the Tokio runtime is
/// torn down relative to the other state.
pub struct Executor {
    /// Metrics registry that `Context::register` writes into and `Context::encode` reads.
    registry: Mutex<Registry>,
    /// Task spawn/running metrics.
    metrics: Arc<Metrics>,
    /// Tokio runtime that drives every task.
    runtime: Runtime,
    /// Stop coordinator used by `Spawner::stop` and `Spawner::stopped`.
    shutdown: Mutex<Stopper>,
    /// Panic handler cloned into each task's `Handle`.
    panicker: Panicker,
}
233
/// Builds a Tokio runtime from a [`Config`] and runs a root future on it (see
/// `crate::Runner::start`).
pub struct Runner {
    /// Configuration consumed when the runtime is started.
    cfg: Config,
}
239impl Default for Runner {
240 fn default() -> Self {
241 Self::new(Config::default())
242 }
243}
244
245impl Runner {
246 pub const fn new(cfg: Config) -> Self {
248 Self { cfg }
249 }
250}
251
252impl crate::Runner for Runner {
253 type Context = Context;
254
255 fn start<F, Fut>(self, f: F) -> Fut::Output
256 where
257 F: FnOnce(Self::Context) -> Fut,
258 Fut: Future,
259 {
260 let mut registry = Registry::default();
262 let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
263
264 let metrics = Arc::new(Metrics::init(runtime_registry));
266 let runtime = Builder::new_multi_thread()
267 .worker_threads(self.cfg.worker_threads)
268 .max_blocking_threads(self.cfg.max_blocking_threads)
269 .enable_all()
270 .build()
271 .expect("failed to create Tokio runtime");
272
273 let (panicker, panicked) = Panicker::new(self.cfg.catch_panics);
275
276 let process = MeteredProcess::init(runtime_registry);
281 runtime.spawn(process.collect(tokio::time::sleep));
282
283 cfg_if::cfg_if! {
285 if #[cfg(feature = "iouring-storage")] {
286 let iouring_registry =
287 runtime_registry.sub_registry_with_prefix("iouring_storage");
288 let storage = MeteredStorage::new(
289 IoUringStorage::start(
290 IoUringConfig {
291 storage_directory: self.cfg.storage_directory.clone(),
292 iouring_config: Default::default(),
293 },
294 iouring_registry,
295 ),
296 runtime_registry,
297 );
298 } else {
299 let storage = MeteredStorage::new(
300 TokioStorage::new(TokioStorageConfig::new(
301 self.cfg.storage_directory.clone(),
302 self.cfg.maximum_buffer_size,
303 )),
304 runtime_registry,
305 );
306 }
307 }
308
309 let network_buffer_pool = BufferPool::new(
311 BufferPoolConfig::for_network(),
312 runtime_registry.sub_registry_with_prefix("network_buffer_pool"),
313 );
314 let storage_buffer_pool = BufferPool::new(
315 BufferPoolConfig::for_storage(),
316 runtime_registry.sub_registry_with_prefix("storage_buffer_pool"),
317 );
318
319 cfg_if::cfg_if! {
321 if #[cfg(feature = "iouring-network")] {
322 let iouring_registry =
323 runtime_registry.sub_registry_with_prefix("iouring_network");
324 let config = IoUringNetworkConfig {
325 tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
326 iouring_config: iouring::Config {
327 size: IOURING_NETWORK_SIZE,
329 op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
330 force_poll: IOURING_NETWORK_FORCE_POLL,
331 shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
332 ..Default::default()
333 },
334 ..Default::default()
335 };
336 let network = MeteredNetwork::new(
337 IoUringNetwork::start(
338 config,
339 iouring_registry,
340 network_buffer_pool.clone(),
341 )
342 .unwrap(),
343 runtime_registry,
344 );
345 } else {
346 let config = TokioNetworkConfig::default()
347 .with_read_timeout(self.cfg.network_cfg.read_write_timeout)
348 .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
349 .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
350 let network = MeteredNetwork::new(
351 TokioNetwork::new(config, network_buffer_pool.clone()),
352 runtime_registry,
353 );
354 }
355 }
356
357 let executor = Arc::new(Executor {
359 registry: Mutex::new(registry),
360 metrics,
361 runtime,
362 shutdown: Mutex::new(Stopper::default()),
363 panicker,
364 });
365
366 let label = Label::root();
368 executor.metrics.tasks_spawned.get_or_create(&label).inc();
369 let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
370
371 let context = Context {
373 storage,
374 name: label.name(),
375 attributes: Vec::new(),
376 executor: executor.clone(),
377 network,
378 network_buffer_pool,
379 storage_buffer_pool,
380 tree: Tree::root(),
381 execution: Execution::default(),
382 instrumented: false,
383 };
384 let output = executor.runtime.block_on(panicked.interrupt(f(context)));
385 gauge.dec();
386
387 output
388 }
389}
390
391cfg_if::cfg_if! {
392 if #[cfg(feature = "iouring-storage")] {
393 type Storage = MeteredStorage<IoUringStorage>;
394 } else {
395 type Storage = MeteredStorage<TokioStorage>;
396 }
397}
398
399cfg_if::cfg_if! {
400 if #[cfg(feature = "iouring-network")] {
401 type Network = MeteredNetwork<IoUringNetwork>;
402 } else {
403 type Network = MeteredNetwork<TokioNetwork>;
404 }
405}
406
/// Handle passed to every task.
///
/// Provides spawning, time, networking, DNS resolution, storage, metrics, buffer pools and
/// randomness on top of the shared [`Executor`]. Fields drop in declaration order.
pub struct Context {
    /// Metric name prefix for this context; extended by `with_label`.
    name: String,
    /// Key/value attributes added via `with_attribute`. They become labels on metrics
    /// registered through this context and span attributes on instrumented tasks.
    attributes: Vec<(String, String)>,
    /// Shared runtime state.
    executor: Arc<Executor>,
    /// Metered storage backend.
    storage: Storage,
    /// Metered network backend.
    network: Network,
    /// Buffer pool created with `BufferPoolConfig::for_network`.
    network_buffer_pool: BufferPool,
    /// Buffer pool created with `BufferPoolConfig::for_storage`.
    storage_buffer_pool: BufferPool,
    /// This context's node in the supervision tree.
    tree: Arc<Tree>,
    /// Execution mode applied to the next spawned task (reset after each spawn).
    execution: Execution,
    /// Whether the next spawned task is wrapped in a tracing span (reset after each spawn).
    instrumented: bool,
}
422
423impl Clone for Context {
424 fn clone(&self) -> Self {
425 let (child, _) = Tree::child(&self.tree);
426 Self {
427 name: self.name.clone(),
428 attributes: self.attributes.clone(),
429 executor: self.executor.clone(),
430 storage: self.storage.clone(),
431 network: self.network.clone(),
432 network_buffer_pool: self.network_buffer_pool.clone(),
433 storage_buffer_pool: self.storage_buffer_pool.clone(),
434 tree: child,
435 execution: Execution::default(),
436 instrumented: false,
437 }
438 }
439}
440
441impl Context {
442 fn metrics(&self) -> &Metrics {
444 &self.executor.metrics
445 }
446}
447
impl crate::Spawner for Context {
    /// Configures the next spawned task to run on its own OS thread.
    fn dedicated(mut self) -> Self {
        self.execution = Execution::Dedicated;
        self
    }

    /// Configures the next spawned task to run on the shared runtime. When `blocking` is true,
    /// it runs on Tokio's blocking thread pool instead of an async worker.
    fn shared(mut self, blocking: bool) -> Self {
        self.execution = Execution::Shared(blocking);
        self
    }

    /// Configures the next spawned task to be wrapped in a `task` tracing span.
    fn instrumented(mut self) -> Self {
        self.instrumented = true;
        self
    }

    /// Spawns `f` as a new task using the execution mode and instrumentation flag set on this
    /// context, and returns a [`Handle`] to its output.
    ///
    /// If the parent supervision-tree node has already been aborted, `f` is not run and a
    /// closed handle is returned.
    fn spawn<F, Fut, T>(mut self, f: F) -> Handle<T>
    where
        F: FnOnce(Self) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // `spawn_metrics!` (defined elsewhere) yields the task label and its metric handle.
        let (label, metric) = spawn_metrics!(self);

        // Capture the per-spawn settings, then reset them so the child context starts clean.
        let parent = Arc::clone(&self.tree);
        let past = self.execution;
        let is_instrumented = self.instrumented;
        self.execution = Execution::default();
        self.instrumented = false;
        let (child, aborted) = Tree::child(&parent);
        if aborted {
            return Handle::closed(metric);
        }
        self.tree = child;

        let executor = self.executor.clone();
        // Optionally run inside a `task` span carrying this context's attributes as
        // OpenTelemetry attributes.
        let future: BoxFuture<'_, T> = if is_instrumented {
            let span = info_span!("task", name = %label.name());
            for (key, value) in &self.attributes {
                span.set_attribute(key.clone(), value.clone());
            }
            f(self).instrument(span).boxed()
        } else {
            f(self).boxed()
        };
        let (f, handle) = Handle::init(
            future,
            metric,
            executor.panicker.clone(),
            Arc::clone(&parent),
        );

        if matches!(past, Execution::Dedicated) {
            // New OS thread that blocks on the runtime handle until the task finishes.
            thread::spawn({
                let handle = executor.runtime.handle().clone();
                move || {
                    handle.block_on(f);
                }
            });
        } else if matches!(past, Execution::Shared(true)) {
            // Tokio's blocking pool, still driven via the runtime handle.
            executor.runtime.spawn_blocking({
                let handle = executor.runtime.handle().clone();
                move || {
                    handle.block_on(f);
                }
            });
        } else {
            executor.runtime.spawn(f);
        }

        // Hand the task's aborter to the parent node (presumably so aborting the parent
        // aborts this task too — see `Tree::register`).
        if let Some(aborter) = handle.aborter() {
            parent.register(aborter);
        }

        handle
    }

    /// Issues a runtime-wide stop with `value` and waits for it to resolve.
    ///
    /// Returns `Error::Closed` if the stop signal's receiver fails. Returns `Error::Timeout`
    /// if `timeout` elapses first. With `timeout == None` it waits indefinitely.
    async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
        // Hold the shutdown lock only while issuing the stop; it is released before awaiting.
        let stop_resolved = {
            let mut shutdown = self.executor.shutdown.lock().unwrap();
            shutdown.stop(value)
        };

        // No timeout => a future that never completes.
        let timeout_future = timeout.map_or_else(
            || futures::future::Either::Right(futures::future::pending()),
            |duration| futures::future::Either::Left(self.sleep(duration)),
        );
        select! {
            result = stop_resolved => {
                result.map_err(|_| Error::Closed)?;
                Ok(())
            },
            _ = timeout_future => Err(Error::Timeout),
        }
    }

    /// Returns the signal obtained from the shared stop coordinator.
    fn stopped(&self) -> Signal {
        self.executor.shutdown.lock().unwrap().stopped()
    }
}
555
556#[stability(BETA)]
557impl crate::ThreadPooler for Context {
558 fn create_thread_pool(
559 &self,
560 concurrency: NonZeroUsize,
561 ) -> Result<ThreadPool, ThreadPoolBuildError> {
562 ThreadPoolBuilder::new()
563 .num_threads(concurrency.get())
564 .spawn_handler(move |thread| {
565 self.with_label("rayon_thread")
568 .dedicated()
569 .spawn(move |_| async move { thread.run() });
570 Ok(())
571 })
572 .build()
573 .map(Arc::new)
574 }
575}
576
577impl crate::Metrics for Context {
578 fn with_label(&self, label: &str) -> Self {
579 let name = {
581 let prefix = self.name.clone();
582 if prefix.is_empty() {
583 label.to_string()
584 } else {
585 format!("{prefix}_{label}")
586 }
587 };
588 Self {
589 name,
590 ..self.clone()
591 }
592 }
593
594 fn label(&self) -> String {
595 self.name.clone()
596 }
597
598 fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
599 let name = name.into();
600 let prefixed_name = {
601 let prefix = &self.name;
602 if prefix.is_empty() {
603 name
604 } else {
605 format!("{}_{}", *prefix, name)
606 }
607 };
608
609 let mut registry = self.executor.registry.lock().unwrap();
611 let sub_registry = self.attributes.iter().fold(&mut *registry, |reg, (k, v)| {
612 reg.sub_registry_with_label((Cow::Owned(k.clone()), Cow::Owned(v.clone())))
613 });
614 sub_registry.register(prefixed_name, help, metric);
615 }
616
617 fn encode(&self) -> String {
618 let mut encoder = MetricEncoder::new();
619 encode(&mut encoder, &self.executor.registry.lock().unwrap()).expect("encoding failed");
620 encoder.into_string()
621 }
622
623 fn with_attribute(&self, key: &str, value: impl std::fmt::Display) -> Self {
624 let mut attributes = self.attributes.clone();
626 add_attribute(&mut attributes, key, value);
627 Self {
628 attributes,
629 ..self.clone()
630 }
631 }
632}
633
634impl Clock for Context {
635 fn current(&self) -> SystemTime {
636 SystemTime::now()
637 }
638
639 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
640 tokio::time::sleep(duration)
641 }
642
643 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
644 let now = SystemTime::now();
645 let duration_until_deadline = deadline.duration_since(now).unwrap_or_else(|_| {
646 Duration::from_secs(0)
648 });
649 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
650 tokio::time::sleep_until(target_instant)
651 }
652}
653
#[cfg(feature = "external")]
impl Pacer for Context {
    /// Returns `future` unchanged; `_latency` is ignored by this runtime.
    fn pace<'a, F, T>(
        &'a self,
        _latency: Duration,
        future: F,
    ) -> impl Future<Output = T> + Send + 'a
    where
        F: Future<Output = T> + Send + 'a,
        T: Send + 'a,
    {
        future
    }
}
669
670impl GClock for Context {
671 type Instant = SystemTime;
672
673 fn now(&self) -> Self::Instant {
674 self.current()
675 }
676}
677
// Marker impl: this context's `governor` clock reads the system wall clock (`GClock::now`).
impl ReasonablyRealtime for Context {}
679
680impl crate::Network for Context {
681 type Listener = <Network as crate::Network>::Listener;
682
683 async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
684 self.network.bind(socket).await
685 }
686
687 async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
688 self.network.dial(socket).await
689 }
690}
691
692impl crate::Resolver for Context {
693 async fn resolve(&self, host: &str) -> Result<Vec<IpAddr>, Error> {
694 let addrs = tokio::net::lookup_host(format!("{host}:0"))
700 .await
701 .map_err(|e| Error::ResolveFailed(e.to_string()))?;
702 Ok(addrs.map(|addr| addr.ip()).collect())
703 }
704}
705
706impl RngCore for Context {
707 fn next_u32(&mut self) -> u32 {
708 OsRng.next_u32()
709 }
710
711 fn next_u64(&mut self) -> u64 {
712 OsRng.next_u64()
713 }
714
715 fn fill_bytes(&mut self, dest: &mut [u8]) {
716 OsRng.fill_bytes(dest);
717 }
718
719 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
720 OsRng.try_fill_bytes(dest)
721 }
722}
723
// Marker impl: every `RngCore` method on `Context` draws from `OsRng`.
impl CryptoRng for Context {}
725
726impl crate::Storage for Context {
727 type Blob = <Storage as crate::Storage>::Blob;
728
729 async fn open_versioned(
730 &self,
731 partition: &str,
732 name: &[u8],
733 versions: std::ops::RangeInclusive<u16>,
734 ) -> Result<(Self::Blob, u64, u16), Error> {
735 self.storage.open_versioned(partition, name, versions).await
736 }
737
738 async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
739 self.storage.remove(partition, name).await
740 }
741
742 async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
743 self.storage.scan(partition).await
744 }
745}
746
impl crate::BufferPooler for Context {
    /// Pool built from `BufferPoolConfig::for_network()` and shared by all contexts of this runtime.
    fn network_buffer_pool(&self) -> &BufferPool {
        &self.network_buffer_pool
    }

    /// Pool built from `BufferPoolConfig::for_storage()` and shared by all contexts of this runtime.
    fn storage_buffer_pool(&self) -> &BufferPool {
        &self.storage_buffer_pool
    }
}