1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "iouring-network")]
8use crate::{
9 iouring,
10 network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
11};
12use crate::{
13 network::metered::Network as MeteredNetwork, process::metered::Metrics as MeteredProcess,
14 signal::Signal, storage::metered::Storage as MeteredStorage, telemetry::metrics::task::Label,
15 utils::signal::Stopper, Clock, Error, Handle, SinkOf, StreamOf, METRICS_PREFIX,
16};
17use commonware_macros::select;
18use futures::future::AbortHandle;
19use governor::clock::{Clock as GClock, ReasonablyRealtime};
20use prometheus_client::{
21 encoding::text::encode,
22 metrics::{counter::Counter, family::Family, gauge::Gauge},
23 registry::{Metric, Registry},
24};
25use rand::{rngs::OsRng, CryptoRng, RngCore};
26use std::{
27 env,
28 future::Future,
29 net::SocketAddr,
30 path::PathBuf,
31 sync::{Arc, Mutex},
32 time::{Duration, SystemTime},
33};
34use tokio::runtime::{Builder, Runtime};
35
/// Value used for `iouring::Config::size` when the io_uring network backend is started.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1024;

/// Value used for `iouring::Config::force_poll` when the io_uring network backend is
/// started (100 milliseconds).
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Option<Duration> = Some(Duration::from_millis(100));
40
41#[derive(Debug)]
42struct Metrics {
43 tasks_spawned: Family<Label, Counter>,
44 tasks_running: Family<Label, Gauge>,
45}
46
47impl Metrics {
48 pub fn init(registry: &mut Registry) -> Self {
49 let metrics = Self {
50 tasks_spawned: Family::default(),
51 tasks_running: Family::default(),
52 };
53 registry.register(
54 "tasks_spawned",
55 "Total number of tasks spawned",
56 metrics.tasks_spawned.clone(),
57 );
58 registry.register(
59 "tasks_running",
60 "Number of tasks currently running",
61 metrics.tasks_running.clone(),
62 );
63 metrics
64 }
65}
66
/// Network-related settings of the runtime.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// Forwarded unchanged as the `tcp_nodelay` setting of whichever network backend is
    /// compiled in.
    tcp_nodelay: Option<bool>,

    /// Used as both the read and the write timeout of the Tokio network backend, and as
    /// the operation and shutdown timeout of the io_uring network backend.
    read_write_timeout: Duration,
}

impl Default for NetworkConfig {
    /// Leaves `tcp_nodelay` unset and uses a 60 second read/write timeout.
    fn default() -> Self {
        const READ_WRITE_TIMEOUT: Duration = Duration::from_secs(60);

        NetworkConfig {
            tcp_nodelay: None,
            read_write_timeout: READ_WRITE_TIMEOUT,
        }
    }
}
85
86#[derive(Clone)]
88pub struct Config {
89 worker_threads: usize,
95
96 max_blocking_threads: usize,
104
105 catch_panics: bool,
107
108 storage_directory: PathBuf,
110
111 maximum_buffer_size: usize,
115
116 network_cfg: NetworkConfig,
118}
119
120impl Config {
121 pub fn new() -> Self {
123 let rng = OsRng.next_u64();
124 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
125 Self {
126 worker_threads: 2,
127 max_blocking_threads: 512,
128 catch_panics: true,
129 storage_directory,
130 maximum_buffer_size: 2 * 1024 * 1024, network_cfg: NetworkConfig::default(),
132 }
133 }
134
135 pub fn with_worker_threads(mut self, n: usize) -> Self {
138 self.worker_threads = n;
139 self
140 }
141 pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
143 self.max_blocking_threads = n;
144 self
145 }
146 pub fn with_catch_panics(mut self, b: bool) -> Self {
148 self.catch_panics = b;
149 self
150 }
151 pub fn with_read_write_timeout(mut self, d: Duration) -> Self {
153 self.network_cfg.read_write_timeout = d;
154 self
155 }
156 pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
158 self.network_cfg.tcp_nodelay = n;
159 self
160 }
161 pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
163 self.storage_directory = p.into();
164 self
165 }
166 pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
168 self.maximum_buffer_size = n;
169 self
170 }
171
172 pub fn worker_threads(&self) -> usize {
175 self.worker_threads
176 }
177 pub fn max_blocking_threads(&self) -> usize {
179 self.max_blocking_threads
180 }
181 pub fn catch_panics(&self) -> bool {
183 self.catch_panics
184 }
185 pub fn read_write_timeout(&self) -> Duration {
187 self.network_cfg.read_write_timeout
188 }
189 pub fn tcp_nodelay(&self) -> Option<bool> {
191 self.network_cfg.tcp_nodelay
192 }
193 pub fn storage_directory(&self) -> &PathBuf {
195 &self.storage_directory
196 }
197 pub fn maximum_buffer_size(&self) -> usize {
199 self.maximum_buffer_size
200 }
201}
202
203impl Default for Config {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
209pub struct Executor {
211 cfg: Config,
212 registry: Mutex<Registry>,
213 metrics: Arc<Metrics>,
214 runtime: Runtime,
215 shutdown: Mutex<Stopper>,
216}
217
218pub struct Runner {
220 cfg: Config,
221}
222
223impl Default for Runner {
224 fn default() -> Self {
225 Self::new(Config::default())
226 }
227}
228
229impl Runner {
230 pub fn new(cfg: Config) -> Self {
232 Self { cfg }
233 }
234}
235
236impl crate::Runner for Runner {
237 type Context = Context;
238
239 fn start<F, Fut>(self, f: F) -> Fut::Output
240 where
241 F: FnOnce(Self::Context) -> Fut,
242 Fut: Future,
243 {
244 let mut registry = Registry::default();
246 let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
247
248 let metrics = Arc::new(Metrics::init(runtime_registry));
250 let runtime = Builder::new_multi_thread()
251 .worker_threads(self.cfg.worker_threads)
252 .max_blocking_threads(self.cfg.max_blocking_threads)
253 .enable_all()
254 .build()
255 .expect("failed to create Tokio runtime");
256
257 let process = MeteredProcess::init(runtime_registry);
262 runtime.spawn(process.collect(tokio::time::sleep));
263
264 cfg_if::cfg_if! {
266 if #[cfg(feature = "iouring-storage")] {
267 let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
268 let storage = MeteredStorage::new(
269 IoUringStorage::start(IoUringConfig {
270 storage_directory: self.cfg.storage_directory.clone(),
271 iouring_config: Default::default(),
272 }, iouring_registry),
273 runtime_registry,
274 );
275 } else {
276 let storage = MeteredStorage::new(
277 TokioStorage::new(TokioStorageConfig::new(
278 self.cfg.storage_directory.clone(),
279 self.cfg.maximum_buffer_size,
280 )),
281 runtime_registry,
282 );
283 }
284 }
285
286 cfg_if::cfg_if! {
288 if #[cfg(feature = "iouring-network")] {
289 let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
290 let config = IoUringNetworkConfig {
291 tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
292 iouring_config: iouring::Config {
293 size: IOURING_NETWORK_SIZE,
295 op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
296 force_poll: IOURING_NETWORK_FORCE_POLL,
297 shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
298 ..Default::default()
299 },
300 };
301 let network = MeteredNetwork::new(
302 IoUringNetwork::start(config, iouring_registry).unwrap(),
303 runtime_registry,
304 );
305 } else {
306 let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
307 .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
308 .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
309 let network = MeteredNetwork::new(
310 TokioNetwork::from(config),
311 runtime_registry,
312 );
313 }
314 }
315
316 let executor = Arc::new(Executor {
318 cfg: self.cfg,
319 registry: Mutex::new(registry),
320 metrics,
321 runtime,
322 shutdown: Mutex::new(Stopper::default()),
323 });
324
325 let label = Label::root();
327 executor.metrics.tasks_spawned.get_or_create(&label).inc();
328 let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
329
330 let context = Context {
332 storage,
333 name: label.name(),
334 spawned: false,
335 executor: executor.clone(),
336 network,
337 children: Arc::new(Mutex::new(Vec::new())),
338 };
339 let output = executor.runtime.block_on(f(context));
340 gauge.dec();
341
342 output
343 }
344}
345
346cfg_if::cfg_if! {
347 if #[cfg(feature = "iouring-storage")] {
348 type Storage = MeteredStorage<IoUringStorage>;
349 } else {
350 type Storage = MeteredStorage<TokioStorage>;
351 }
352}
353
354cfg_if::cfg_if! {
355 if #[cfg(feature = "iouring-network")] {
356 type Network = MeteredNetwork<IoUringNetwork>;
357 } else {
358 type Network = MeteredNetwork<TokioNetwork>;
359 }
360}
361
362pub struct Context {
366 name: String,
367 spawned: bool,
368 executor: Arc<Executor>,
369 storage: Storage,
370 network: Network,
371 children: Arc<Mutex<Vec<AbortHandle>>>,
372}
373
374impl Clone for Context {
375 fn clone(&self) -> Self {
376 Self {
377 name: self.name.clone(),
378 spawned: false,
379 executor: self.executor.clone(),
380 storage: self.storage.clone(),
381 network: self.network.clone(),
382 children: self.children.clone(),
383 }
384 }
385}
386
387impl crate::Spawner for Context {
388 fn spawn<F, Fut, T>(mut self, f: F) -> Handle<T>
389 where
390 F: FnOnce(Self) -> Fut + Send + 'static,
391 Fut: Future<Output = T> + Send + 'static,
392 T: Send + 'static,
393 {
394 assert!(!self.spawned, "already spawned");
396
397 let (_, gauge) = spawn_metrics!(self, future);
399
400 let catch_panics = self.executor.cfg.catch_panics;
402 let executor = self.executor.clone();
403
404 let children = Arc::new(Mutex::new(Vec::new()));
406 self.children = children.clone();
407
408 let future = f(self);
409 let (f, handle) = Handle::init_future(future, gauge, catch_panics, children);
410
411 executor.runtime.spawn(f);
413 handle
414 }
415
416 fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
417 where
418 F: Future<Output = T> + Send + 'static,
419 T: Send + 'static,
420 {
421 assert!(!self.spawned, "already spawned");
423 self.spawned = true;
424
425 let (_, gauge) = spawn_metrics!(self, future);
427
428 let executor = self.executor.clone();
430
431 move |f: F| {
432 let (f, handle) = Handle::init_future(
433 f,
434 gauge,
435 executor.cfg.catch_panics,
436 Arc::new(Mutex::new(Vec::new())),
438 );
439
440 executor.runtime.spawn(f);
442 handle
443 }
444 }
445
446 fn spawn_child<F, Fut, T>(self, f: F) -> Handle<T>
447 where
448 F: FnOnce(Self) -> Fut + Send + 'static,
449 Fut: Future<Output = T> + Send + 'static,
450 T: Send + 'static,
451 {
452 let parent_children = self.children.clone();
454
455 let child_handle = self.spawn(f);
457
458 if let Some(abort_handle) = child_handle.abort_handle() {
460 parent_children.lock().unwrap().push(abort_handle);
461 }
462
463 child_handle
464 }
465
466 fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
467 where
468 F: FnOnce(Self) -> T + Send + 'static,
469 T: Send + 'static,
470 {
471 assert!(!self.spawned, "already spawned");
473
474 let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
476
477 let executor = self.executor.clone();
479 let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);
480
481 if dedicated {
483 std::thread::spawn(f);
484 } else {
485 executor.runtime.spawn_blocking(f);
486 }
487 handle
488 }
489
490 fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
491 where
492 F: FnOnce() -> T + Send + 'static,
493 T: Send + 'static,
494 {
495 assert!(!self.spawned, "already spawned");
497 self.spawned = true;
498
499 let (_, gauge) = spawn_metrics!(self, blocking, dedicated);
501
502 let executor = self.executor.clone();
504 move |f: F| {
505 let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);
506
507 if dedicated {
509 std::thread::spawn(f);
510 } else {
511 executor.runtime.spawn_blocking(f);
512 }
513 handle
514 }
515 }
516
517 async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
518 let stop_resolved = {
519 let mut shutdown = self.executor.shutdown.lock().unwrap();
520 shutdown.stop(value)
521 };
522
523 let timeout_future = match timeout {
525 Some(duration) => futures::future::Either::Left(self.sleep(duration)),
526 None => futures::future::Either::Right(futures::future::pending()),
527 };
528 select! {
529 result = stop_resolved => {
530 result.map_err(|_| Error::Closed)?;
531 Ok(())
532 },
533 _ = timeout_future => {
534 Err(Error::Timeout)
535 }
536 }
537 }
538
539 fn stopped(&self) -> Signal {
540 self.executor.shutdown.lock().unwrap().stopped()
541 }
542}
543
544impl crate::Metrics for Context {
545 fn with_label(&self, label: &str) -> Self {
546 let name = {
547 let prefix = self.name.clone();
548 if prefix.is_empty() {
549 label.to_string()
550 } else {
551 format!("{prefix}_{label}")
552 }
553 };
554 assert!(
555 !name.starts_with(METRICS_PREFIX),
556 "using runtime label is not allowed"
557 );
558 Self {
559 name,
560 spawned: false,
561 executor: self.executor.clone(),
562 storage: self.storage.clone(),
563 network: self.network.clone(),
564 children: self.children.clone(),
565 }
566 }
567
568 fn label(&self) -> String {
569 self.name.clone()
570 }
571
572 fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
573 let name = name.into();
574 let prefixed_name = {
575 let prefix = &self.name;
576 if prefix.is_empty() {
577 name
578 } else {
579 format!("{}_{}", *prefix, name)
580 }
581 };
582 self.executor
583 .registry
584 .lock()
585 .unwrap()
586 .register(prefixed_name, help, metric)
587 }
588
589 fn encode(&self) -> String {
590 let mut buffer = String::new();
591 encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
592 buffer
593 }
594}
595
596impl Clock for Context {
597 fn current(&self) -> SystemTime {
598 SystemTime::now()
599 }
600
601 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
602 tokio::time::sleep(duration)
603 }
604
605 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
606 let now = SystemTime::now();
607 let duration_until_deadline = match deadline.duration_since(now) {
608 Ok(duration) => duration,
609 Err(_) => Duration::from_secs(0), };
611 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
612 tokio::time::sleep_until(target_instant)
613 }
614}
615
616impl GClock for Context {
617 type Instant = SystemTime;
618
619 fn now(&self) -> Self::Instant {
620 self.current()
621 }
622}
623
624impl ReasonablyRealtime for Context {}
625
626impl crate::Network for Context {
627 type Listener = <Network as crate::Network>::Listener;
628
629 async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
630 self.network.bind(socket).await
631 }
632
633 async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
634 self.network.dial(socket).await
635 }
636}
637
638impl RngCore for Context {
639 fn next_u32(&mut self) -> u32 {
640 OsRng.next_u32()
641 }
642
643 fn next_u64(&mut self) -> u64 {
644 OsRng.next_u64()
645 }
646
647 fn fill_bytes(&mut self, dest: &mut [u8]) {
648 OsRng.fill_bytes(dest);
649 }
650
651 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
652 OsRng.try_fill_bytes(dest)
653 }
654}
655
656impl CryptoRng for Context {}
657
658impl crate::Storage for Context {
659 type Blob = <Storage as crate::Storage>::Blob;
660
661 async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
662 self.storage.open(partition, name).await
663 }
664
665 async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
666 self.storage.remove(partition, name).await
667 }
668
669 async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
670 self.storage.scan(partition).await
671 }
672}