1#[cfg(not(feature = "iouring-network"))]
2use crate::network::tokio::{Config as TokioNetworkConfig, Network as TokioNetwork};
3#[cfg(feature = "iouring-storage")]
4use crate::storage::iouring::{Config as IoUringConfig, Storage as IoUringStorage};
5#[cfg(not(feature = "iouring-storage"))]
6use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
7#[cfg(feature = "iouring-network")]
8use crate::{
9 iouring,
10 network::iouring::{Config as IoUringNetworkConfig, Network as IoUringNetwork},
11};
12use crate::{
13 network::metered::Network as MeteredNetwork, process::metered::Metrics as MeteredProcess,
14 signal::Signal, storage::metered::Storage as MeteredStorage, telemetry::metrics::task::Label,
15 utils::signal::Stopper, Clock, Error, Handle, SinkOf, StreamOf, METRICS_PREFIX,
16};
17use commonware_macros::select;
18use governor::clock::{Clock as GClock, ReasonablyRealtime};
19use prometheus_client::{
20 encoding::text::encode,
21 metrics::{counter::Counter, family::Family, gauge::Gauge},
22 registry::{Metric, Registry},
23};
24use rand::{rngs::OsRng, CryptoRng, RngCore};
25use std::{
26 env,
27 future::Future,
28 net::SocketAddr,
29 path::PathBuf,
30 sync::{Arc, Mutex},
31 time::{Duration, SystemTime},
32};
33use tokio::runtime::{Builder, Runtime};
34
/// Value used for `iouring::Config::size` when the io_uring network backend is started.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_SIZE: u32 = 1024;
/// Value used for `iouring::Config::force_poll` when the io_uring network backend is
/// started (100ms).
// NOTE(review): presumably bounds how long the ring waits before polling for
// completions — confirm against `iouring::Config`.
#[cfg(feature = "iouring-network")]
const IOURING_NETWORK_FORCE_POLL: Option<Duration> = Some(Duration::from_millis(100));
39
40#[derive(Debug)]
41struct Metrics {
42 tasks_spawned: Family<Label, Counter>,
43 tasks_running: Family<Label, Gauge>,
44}
45
46impl Metrics {
47 pub fn init(registry: &mut Registry) -> Self {
48 let metrics = Self {
49 tasks_spawned: Family::default(),
50 tasks_running: Family::default(),
51 };
52 registry.register(
53 "tasks_spawned",
54 "Total number of tasks spawned",
55 metrics.tasks_spawned.clone(),
56 );
57 registry.register(
58 "tasks_running",
59 "Number of tasks currently running",
60 metrics.tasks_running.clone(),
61 );
62 metrics
63 }
64}
65
/// Network settings used by the runtime when it builds its network backend.
#[derive(Clone, Debug)]
pub struct NetworkConfig {
    /// Optional `TCP_NODELAY` setting forwarded to the network backend.
    tcp_nodelay: Option<bool>,

    /// Timeout applied to both network reads and network writes.
    read_write_timeout: Duration,
}

impl Default for NetworkConfig {
    /// Leaves `tcp_nodelay` unset and uses a 60 second read/write timeout.
    fn default() -> Self {
        let read_write_timeout = Duration::from_secs(60);
        Self {
            tcp_nodelay: None,
            read_write_timeout,
        }
    }
}
84
/// Configuration for the Tokio-backed runtime.
///
/// Fields are private: they are set through the `with_*` builder methods and read
/// through the accessor named after the field.
#[derive(Clone)]
pub struct Config {
    /// Number of worker threads given to the Tokio runtime builder.
    worker_threads: usize,

    /// Maximum number of blocking threads given to the Tokio runtime builder.
    max_blocking_threads: usize,

    /// Flag handed to `Handle::init_future` / `Handle::init_blocking` for spawned tasks.
    // NOTE(review): presumably `true` means a panicking task is caught rather than
    // propagated — confirm against `Handle`.
    catch_panics: bool,

    /// Directory handed to the storage backend.
    storage_directory: PathBuf,

    /// Buffer-size limit handed to the Tokio storage backend; the `iouring-storage`
    /// branch of `Runner::start` does not read it.
    maximum_buffer_size: usize,

    /// Network settings (`TCP_NODELAY` and the read/write timeout).
    network_cfg: NetworkConfig,
}
118
119impl Config {
120 pub fn new() -> Self {
122 let rng = OsRng.next_u64();
123 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{rng}"));
124 Self {
125 worker_threads: 2,
126 max_blocking_threads: 512,
127 catch_panics: true,
128 storage_directory,
129 maximum_buffer_size: 2 * 1024 * 1024, network_cfg: NetworkConfig::default(),
131 }
132 }
133
134 pub fn with_worker_threads(mut self, n: usize) -> Self {
137 self.worker_threads = n;
138 self
139 }
140 pub fn with_max_blocking_threads(mut self, n: usize) -> Self {
142 self.max_blocking_threads = n;
143 self
144 }
145 pub fn with_catch_panics(mut self, b: bool) -> Self {
147 self.catch_panics = b;
148 self
149 }
150 pub fn with_read_write_timeout(mut self, d: Duration) -> Self {
152 self.network_cfg.read_write_timeout = d;
153 self
154 }
155 pub fn with_tcp_nodelay(mut self, n: Option<bool>) -> Self {
157 self.network_cfg.tcp_nodelay = n;
158 self
159 }
160 pub fn with_storage_directory(mut self, p: impl Into<PathBuf>) -> Self {
162 self.storage_directory = p.into();
163 self
164 }
165 pub fn with_maximum_buffer_size(mut self, n: usize) -> Self {
167 self.maximum_buffer_size = n;
168 self
169 }
170
171 pub fn worker_threads(&self) -> usize {
174 self.worker_threads
175 }
176 pub fn max_blocking_threads(&self) -> usize {
178 self.max_blocking_threads
179 }
180 pub fn catch_panics(&self) -> bool {
182 self.catch_panics
183 }
184 pub fn read_write_timeout(&self) -> Duration {
186 self.network_cfg.read_write_timeout
187 }
188 pub fn tcp_nodelay(&self) -> Option<bool> {
190 self.network_cfg.tcp_nodelay
191 }
192 pub fn storage_directory(&self) -> &PathBuf {
194 &self.storage_directory
195 }
196 pub fn maximum_buffer_size(&self) -> usize {
198 self.maximum_buffer_size
199 }
200}
201
202impl Default for Config {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
/// Runtime state shared (through an `Arc`) by every [`Context`].
pub struct Executor {
    /// Configuration the runner was created with.
    cfg: Config,
    /// Metrics registry; locked to register new metrics and to encode them.
    registry: Mutex<Registry>,
    /// Task metrics (spawned counter and running gauge families).
    metrics: Arc<Metrics>,
    /// Tokio runtime that tasks are spawned onto.
    runtime: Runtime,
    /// Shutdown coordinator used by `stop` and `stopped`.
    shutdown: Mutex<Stopper>,
}
216
217pub struct Runner {
219 cfg: Config,
220}
221
222impl Default for Runner {
223 fn default() -> Self {
224 Self::new(Config::default())
225 }
226}
227
228impl Runner {
229 pub fn new(cfg: Config) -> Self {
231 Self { cfg }
232 }
233}
234
235impl crate::Runner for Runner {
236 type Context = Context;
237
238 fn start<F, Fut>(self, f: F) -> Fut::Output
239 where
240 F: FnOnce(Self::Context) -> Fut,
241 Fut: Future,
242 {
243 let mut registry = Registry::default();
245 let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
246
247 let metrics = Arc::new(Metrics::init(runtime_registry));
249 let runtime = Builder::new_multi_thread()
250 .worker_threads(self.cfg.worker_threads)
251 .max_blocking_threads(self.cfg.max_blocking_threads)
252 .enable_all()
253 .build()
254 .expect("failed to create Tokio runtime");
255
256 let process = MeteredProcess::init(runtime_registry);
261 runtime.spawn(process.collect(tokio::time::sleep));
262
263 cfg_if::cfg_if! {
265 if #[cfg(feature = "iouring-storage")] {
266 let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_storage");
267 let storage = MeteredStorage::new(
268 IoUringStorage::start(IoUringConfig {
269 storage_directory: self.cfg.storage_directory.clone(),
270 ring_config: Default::default(),
271 }, iouring_registry),
272 runtime_registry,
273 );
274 } else {
275 let storage = MeteredStorage::new(
276 TokioStorage::new(TokioStorageConfig::new(
277 self.cfg.storage_directory.clone(),
278 self.cfg.maximum_buffer_size,
279 )),
280 runtime_registry,
281 );
282 }
283 }
284
285 cfg_if::cfg_if! {
287 if #[cfg(feature = "iouring-network")] {
288 let iouring_registry = runtime_registry.sub_registry_with_prefix("iouring_network");
289 let config = IoUringNetworkConfig {
290 tcp_nodelay: self.cfg.network_cfg.tcp_nodelay,
291 iouring_config: iouring::Config {
292 size: IOURING_NETWORK_SIZE,
294 op_timeout: Some(self.cfg.network_cfg.read_write_timeout),
295 force_poll: IOURING_NETWORK_FORCE_POLL,
296 shutdown_timeout: Some(self.cfg.network_cfg.read_write_timeout),
297 ..Default::default()
298 },
299 };
300 let network = MeteredNetwork::new(
301 IoUringNetwork::start(config, iouring_registry).unwrap(),
302 runtime_registry,
303 );
304 } else {
305 let config = TokioNetworkConfig::default().with_read_timeout(self.cfg.network_cfg.read_write_timeout)
306 .with_write_timeout(self.cfg.network_cfg.read_write_timeout)
307 .with_tcp_nodelay(self.cfg.network_cfg.tcp_nodelay);
308 let network = MeteredNetwork::new(
309 TokioNetwork::from(config),
310 runtime_registry,
311 );
312 }
313 }
314
315 let executor = Arc::new(Executor {
317 cfg: self.cfg,
318 registry: Mutex::new(registry),
319 metrics,
320 runtime,
321 shutdown: Mutex::new(Stopper::default()),
322 });
323
324 let label = Label::root();
326 executor.metrics.tasks_spawned.get_or_create(&label).inc();
327 let gauge = executor.metrics.tasks_running.get_or_create(&label).clone();
328
329 let context = Context {
331 storage,
332 name: label.name(),
333 spawned: false,
334 executor: executor.clone(),
335 network,
336 };
337 let output = executor.runtime.block_on(f(context));
338 gauge.dec();
339
340 output
341 }
342}
343
344cfg_if::cfg_if! {
345 if #[cfg(feature = "iouring-storage")] {
346 type Storage = MeteredStorage<IoUringStorage>;
347 } else {
348 type Storage = MeteredStorage<TokioStorage>;
349 }
350}
351
352cfg_if::cfg_if! {
353 if #[cfg(feature = "iouring-network")] {
354 type Network = MeteredNetwork<IoUringNetwork>;
355 } else {
356 type Network = MeteredNetwork<TokioNetwork>;
357 }
358}
359
/// Handle to the Tokio-backed runtime that is passed to tasks.
///
/// It carries a metrics label, a reference to the shared [`Executor`], and the storage
/// and network backends.
pub struct Context {
    /// Label of this context; used as the prefix for metric names registered through it.
    name: String,
    /// Set by `spawn_ref` / `spawn_blocking_ref`; every spawn method asserts that it is
    /// still `false`.
    spawned: bool,
    /// Shared runtime state.
    executor: Arc<Executor>,
    /// Feature-selected storage backend.
    storage: Storage,
    /// Feature-selected network backend.
    network: Network,
}
370
371impl Clone for Context {
372 fn clone(&self) -> Self {
373 Self {
374 name: self.name.clone(),
375 spawned: false,
376 executor: self.executor.clone(),
377 storage: self.storage.clone(),
378 network: self.network.clone(),
379 }
380 }
381}
382
impl crate::Spawner for Context {
    /// Runs the future produced by `f(self)` as a task on the Tokio runtime and returns
    /// a [`Handle`] to it.
    ///
    /// # Panics
    ///
    /// Panics with "already spawned" if this context's `spawned` flag is set.
    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
    where
        F: FnOnce(Self) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns once
        assert!(!self.spawned, "already spawned");

        // `spawn_metrics!` is defined outside this block; only its second result (the
        // gauge handed to `Handle`) is used here.
        let (_, gauge) = spawn_metrics!(self, future);

        // Read everything needed from `self` before it is moved into `f`
        let catch_panics = self.executor.cfg.catch_panics;
        let executor = self.executor.clone();
        let future = f(self);
        let (f, handle) = Handle::init_future(future, gauge, catch_panics);

        // Hand the wrapped future to Tokio
        executor.runtime.spawn(f);
        handle
    }

    /// Marks this context as spawned and returns a closure that, given a future,
    /// schedules it on the Tokio runtime and returns its [`Handle`].
    ///
    /// # Panics
    ///
    /// Panics with "already spawned" if this context's `spawned` flag is already set.
    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns once, and record that it now has
        assert!(!self.spawned, "already spawned");
        self.spawned = true;

        // Only the gauge returned by `spawn_metrics!` is used here
        let (_, gauge) = spawn_metrics!(self, future);

        // The returned closure owns its own reference to the executor
        let executor = self.executor.clone();
        move |f: F| {
            let (f, handle) = Handle::init_future(f, gauge, executor.cfg.catch_panics);

            // Hand the wrapped future to Tokio
            executor.runtime.spawn(f);
            handle
        }
    }

    /// Runs the blocking closure `f(self)` and returns a [`Handle`] to its result.
    ///
    /// When `dedicated` is `true` the closure runs on a new `std::thread`; otherwise it
    /// is submitted through the Tokio runtime's `spawn_blocking`.
    ///
    /// # Panics
    ///
    /// Panics with "already spawned" if this context's `spawned` flag is set.
    fn spawn_blocking<F, T>(self, dedicated: bool, f: F) -> Handle<T>
    where
        F: FnOnce(Self) -> T + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns once
        assert!(!self.spawned, "already spawned");

        // Only the gauge returned by `spawn_metrics!` is used here
        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);

        // Clone the executor first: `self` is moved into the closure below
        let executor = self.executor.clone();
        let (f, handle) = Handle::init_blocking(|| f(self), gauge, executor.cfg.catch_panics);

        // Pick where the blocking work runs
        if dedicated {
            std::thread::spawn(f);
        } else {
            executor.runtime.spawn_blocking(f);
        }
        handle
    }

    /// Marks this context as spawned and returns a closure that, given a blocking
    /// function, runs it and returns its [`Handle`].
    ///
    /// When `dedicated` is `true` the function runs on a new `std::thread`; otherwise it
    /// is submitted through the Tokio runtime's `spawn_blocking`.
    ///
    /// # Panics
    ///
    /// Panics with "already spawned" if this context's `spawned` flag is already set.
    fn spawn_blocking_ref<F, T>(&mut self, dedicated: bool) -> impl FnOnce(F) -> Handle<T> + 'static
    where
        F: FnOnce() -> T + Send + 'static,
        T: Send + 'static,
    {
        // Ensure a context only spawns once, and record that it now has
        assert!(!self.spawned, "already spawned");
        self.spawned = true;

        // Only the gauge returned by `spawn_metrics!` is used here
        let (_, gauge) = spawn_metrics!(self, blocking, dedicated);

        // The returned closure owns its own reference to the executor
        let executor = self.executor.clone();
        move |f: F| {
            let (f, handle) = Handle::init_blocking(f, gauge, executor.cfg.catch_panics);

            // Pick where the blocking work runs
            if dedicated {
                std::thread::spawn(f);
            } else {
                executor.runtime.spawn_blocking(f);
            }
            handle
        }
    }

    /// Signals shutdown with `value` and waits for the future returned by the stopper.
    ///
    /// Returns `Err(Error::Closed)` if that future resolves to an error and
    /// `Err(Error::Timeout)` if the timeout branch of `select!` fires. With
    /// `timeout == None` the timeout branch never completes.
    async fn stop(self, value: i32, timeout: Option<Duration>) -> Result<(), Error> {
        // Scope the lock guard so it is released before anything is awaited
        let stop_resolved = {
            let mut shutdown = self.executor.shutdown.lock().unwrap();
            shutdown.stop(value)
        };

        // `Either` gives both arms one concrete future type: a sleep, or a future that
        // never resolves
        let timeout_future = match timeout {
            Some(duration) => futures::future::Either::Left(self.sleep(duration)),
            None => futures::future::Either::Right(futures::future::pending()),
        };
        select! {
            result = stop_resolved => {
                result.map_err(|_| Error::Closed)?;
                Ok(())
            },
            _ = timeout_future => {
                Err(Error::Timeout)
            }
        }
    }

    /// Returns the [`Signal`] produced by the shared stopper's `stopped()`.
    fn stopped(&self) -> Signal {
        self.executor.shutdown.lock().unwrap().stopped()
    }
}
507
508impl crate::Metrics for Context {
509 fn with_label(&self, label: &str) -> Self {
510 let name = {
511 let prefix = self.name.clone();
512 if prefix.is_empty() {
513 label.to_string()
514 } else {
515 format!("{prefix}_{label}")
516 }
517 };
518 assert!(
519 !name.starts_with(METRICS_PREFIX),
520 "using runtime label is not allowed"
521 );
522 Self {
523 name,
524 spawned: false,
525 executor: self.executor.clone(),
526 storage: self.storage.clone(),
527 network: self.network.clone(),
528 }
529 }
530
531 fn label(&self) -> String {
532 self.name.clone()
533 }
534
535 fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
536 let name = name.into();
537 let prefixed_name = {
538 let prefix = &self.name;
539 if prefix.is_empty() {
540 name
541 } else {
542 format!("{}_{}", *prefix, name)
543 }
544 };
545 self.executor
546 .registry
547 .lock()
548 .unwrap()
549 .register(prefixed_name, help, metric)
550 }
551
552 fn encode(&self) -> String {
553 let mut buffer = String::new();
554 encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
555 buffer
556 }
557}
558
559impl Clock for Context {
560 fn current(&self) -> SystemTime {
561 SystemTime::now()
562 }
563
564 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
565 tokio::time::sleep(duration)
566 }
567
568 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
569 let now = SystemTime::now();
570 let duration_until_deadline = match deadline.duration_since(now) {
571 Ok(duration) => duration,
572 Err(_) => Duration::from_secs(0), };
574 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
575 tokio::time::sleep_until(target_instant)
576 }
577}
578
// `governor` clock integration: instants are `SystemTime` values.
impl GClock for Context {
    type Instant = SystemTime;

    /// Returns the same value as `current()`.
    fn now(&self) -> Self::Instant {
        self.current()
    }
}

// Marker impl with no methods of its own.
impl ReasonablyRealtime for Context {}
588
// Networking is delegated entirely to the context's `network` backend.
impl crate::Network for Context {
    // The listener type is whatever the feature-selected `Network` alias provides.
    type Listener = <Network as crate::Network>::Listener;

    /// Binds to `socket` using the network backend and returns its listener.
    async fn bind(&self, socket: SocketAddr) -> Result<Self::Listener, Error> {
        self.network.bind(socket).await
    }

    /// Dials `socket` using the network backend and returns the sink/stream pair.
    async fn dial(&self, socket: SocketAddr) -> Result<(SinkOf<Self>, StreamOf<Self>), Error> {
        self.network.dial(socket).await
    }
}
600
// Every method forwards to `OsRng`; the context keeps no RNG state of its own.
impl RngCore for Context {
    /// Returns a `u32` from `OsRng`.
    fn next_u32(&mut self) -> u32 {
        OsRng.next_u32()
    }

    /// Returns a `u64` from `OsRng`.
    fn next_u64(&mut self) -> u64 {
        OsRng.next_u64()
    }

    /// Fills `dest` using `OsRng`.
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        OsRng.fill_bytes(dest);
    }

    /// Fills `dest` using `OsRng`, returning its error on failure.
    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
        OsRng.try_fill_bytes(dest)
    }
}

// Marker impl; all output comes from `OsRng` (see the `RngCore` impl above).
impl CryptoRng for Context {}
620
// Storage is delegated entirely to the context's `storage` backend.
impl crate::Storage for Context {
    // The blob type is whatever the feature-selected `Storage` alias provides.
    type Blob = <Storage as crate::Storage>::Blob;

    /// Opens `name` in `partition` using the storage backend.
    // NOTE(review): the `u64` in the result is presumably the blob's length — confirm
    // against `crate::Storage`.
    async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
        self.storage.open(partition, name).await
    }

    /// Forwards the removal request to the storage backend.
    // NOTE(review): `name == None` presumably removes the whole partition — confirm
    // against `crate::Storage`.
    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
        self.storage.remove(partition, name).await
    }

    /// Returns the result of the storage backend's `scan` for `partition`.
    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
        self.storage.scan(partition).await
    }
}
635}