1use crate::storage::metered::Storage;
2use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
3use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
4use governor::clock::{Clock as GClock, ReasonablyRealtime};
5use prometheus_client::{
6 encoding::{text::encode, EncodeLabelSet},
7 metrics::{counter::Counter, family::Family, gauge::Gauge},
8 registry::{Metric, Registry},
9};
10use rand::{rngs::OsRng, CryptoRng, RngCore};
11use std::{
12 env,
13 future::Future,
14 io,
15 net::SocketAddr,
16 path::PathBuf,
17 sync::{Arc, Mutex},
18 time::{Duration, SystemTime},
19};
20use tokio::{
21 io::{AsyncReadExt, AsyncWriteExt},
22 net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
23 runtime::{Builder, Runtime},
24 time::timeout,
25};
26use tracing::warn;
27
/// Label set attached to the per-task metric families.
///
/// `label` is the label of the [`Context`] that spawned the task, so task
/// counts can be broken down by the component that created them.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct Work {
    /// Label of the spawning context (empty for the root context).
    label: String,
}
32
/// Runtime-level metrics maintained by the Tokio executor.
///
/// Registered by [`Metrics::init`]; `Executor::init` registers them under the
/// `METRICS_PREFIX` sub-registry.
#[derive(Debug)]
struct Metrics {
    /// Async tasks spawned via `spawn`/`spawn_ref`, per spawning label.
    tasks_spawned: Family<Work, Counter>,
    /// Gauge handed to `Handle::init` for each async task, per spawning label
    /// (help text: "Number of tasks currently running").
    tasks_running: Family<Work, Gauge>,
    /// Blocking tasks spawned via `spawn_blocking`, per spawning label.
    blocking_tasks_spawned: Family<Work, Counter>,
    /// Gauge handed to `Handle::init_blocking` for each blocking task, per
    /// spawning label (help text: "Number of blocking tasks currently running").
    blocking_tasks_running: Family<Work, Gauge>,

    /// Connections accepted through `crate::Listener::accept`.
    inbound_connections: Counter,
    /// Connections established through `crate::Network::dial`.
    outbound_connections: Counter,
    /// Bytes successfully read by `Stream::recv`.
    inbound_bandwidth: Counter,
    /// Bytes successfully written by `Sink::send`.
    outbound_bandwidth: Counter,
}
47
48impl Metrics {
49 pub fn init(registry: &mut Registry) -> Self {
50 let metrics = Self {
51 tasks_spawned: Family::default(),
52 tasks_running: Family::default(),
53 blocking_tasks_spawned: Family::default(),
54 blocking_tasks_running: Family::default(),
55 inbound_connections: Counter::default(),
56 outbound_connections: Counter::default(),
57 inbound_bandwidth: Counter::default(),
58 outbound_bandwidth: Counter::default(),
59 };
60 registry.register(
61 "tasks_spawned",
62 "Total number of tasks spawned",
63 metrics.tasks_spawned.clone(),
64 );
65 registry.register(
66 "tasks_running",
67 "Number of tasks currently running",
68 metrics.tasks_running.clone(),
69 );
70 registry.register(
71 "blocking_tasks_spawned",
72 "Total number of blocking tasks spawned",
73 metrics.blocking_tasks_spawned.clone(),
74 );
75 registry.register(
76 "blocking_tasks_running",
77 "Number of blocking tasks currently running",
78 metrics.blocking_tasks_running.clone(),
79 );
80 registry.register(
81 "inbound_connections",
82 "Number of connections created by dialing us",
83 metrics.inbound_connections.clone(),
84 );
85 registry.register(
86 "outbound_connections",
87 "Number of connections created by dialing others",
88 metrics.outbound_connections.clone(),
89 );
90 registry.register(
91 "inbound_bandwidth",
92 "Bandwidth used by receiving data from others",
93 metrics.inbound_bandwidth.clone(),
94 );
95 registry.register(
96 "outbound_bandwidth",
97 "Bandwidth used by sending data to others",
98 metrics.outbound_bandwidth.clone(),
99 );
100 metrics
101 }
102}
103
/// Configuration for the Tokio-backed [`Executor`].
#[derive(Clone)]
pub struct Config {
    /// Number of Tokio worker threads (passed to `Builder::worker_threads`).
    pub worker_threads: usize,

    /// Cap on Tokio's blocking-thread pool (passed to
    /// `Builder::max_blocking_threads`); used by `spawn_blocking`.
    pub max_blocking_threads: usize,

    /// Forwarded to `Handle::init` / `Handle::init_blocking` for every task.
    /// NOTE(review): presumably controls whether a task panic is caught rather
    /// than propagated — confirm against `Handle`.
    pub catch_panics: bool,

    /// Maximum time `Stream::recv` waits to fill its buffer before returning
    /// `Error::Timeout`.
    pub read_timeout: Duration,

    /// Maximum time `Sink::send` waits to write a message before returning
    /// `Error::Timeout`.
    pub write_timeout: Duration,

    /// When `Some(v)`, `TCP_NODELAY` is set to `v` on every dialed connection
    /// and every connection accepted via `crate::Listener::accept`; a failure
    /// to set it is only logged. `None` leaves the socket option untouched.
    pub tcp_nodelay: Option<bool>,

    /// Directory handed to the Tokio storage backend (`TokioStorageConfig`).
    pub storage_directory: PathBuf,

    /// Forwarded to `TokioStorageConfig::new` alongside `storage_directory`.
    /// NOTE(review): presumably the maximum I/O buffer size of the storage
    /// backend — confirm in the storage module.
    pub maximum_buffer_size: usize,
}
152
153impl Default for Config {
154 fn default() -> Self {
155 let rng = OsRng.next_u64();
157 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
158
159 Self {
161 worker_threads: 2,
162 max_blocking_threads: 512,
163 catch_panics: true,
164 read_timeout: Duration::from_secs(60),
165 write_timeout: Duration::from_secs(30),
166 tcp_nodelay: None,
167 storage_directory,
168 maximum_buffer_size: 2 * 1024 * 1024, }
170 }
171}
172
/// Shared state behind every [`Runner`] and [`Context`] produced by
/// [`Executor::init`].
pub struct Executor {
    /// Configuration the executor was built with.
    cfg: Config,
    /// Root metrics registry: runtime metrics live in its `METRICS_PREFIX`
    /// sub-registry, and `Context::register` adds user metrics to it.
    registry: Mutex<Registry>,
    /// Runtime-level task, connection and bandwidth metrics.
    metrics: Arc<Metrics>,
    /// Multi-threaded Tokio runtime that runs all spawned work.
    runtime: Runtime,
    /// Fires the stop signal; used by `Context::stop`.
    signaler: Mutex<Signaler>,
    /// Stop signal handed out (cloned) by `Context::stopped`.
    signal: Signal,
}
182
183impl Executor {
184 pub fn init(cfg: Config) -> (Runner, Context) {
186 let mut registry = Registry::default();
188 let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
189
190 let metrics = Arc::new(Metrics::init(runtime_registry));
192 let runtime = Builder::new_multi_thread()
193 .worker_threads(cfg.worker_threads)
194 .max_blocking_threads(cfg.max_blocking_threads)
195 .enable_all()
196 .build()
197 .expect("failed to create Tokio runtime");
198 let (signaler, signal) = Signaler::new();
199
200 let storage = Storage::new(
201 TokioStorage::new(TokioStorageConfig::new(
202 cfg.storage_directory.clone(),
203 cfg.maximum_buffer_size,
204 )),
205 runtime_registry,
206 );
207
208 let executor = Arc::new(Self {
209 cfg,
210 registry: Mutex::new(registry),
211 metrics,
212 runtime,
213 signaler: Mutex::new(signaler),
214 signal,
215 });
216 (
217 Runner {
218 executor: executor.clone(),
219 },
220 Context {
221 storage,
222 label: String::new(),
223 spawned: false,
224 executor,
225 },
226 )
227 }
228
229 #[allow(clippy::should_implement_trait)]
232 pub fn default() -> (Runner, Context) {
233 Self::init(Config::default())
234 }
235}
236
/// Entry point that drives a root future on the executor's Tokio runtime.
pub struct Runner {
    /// Executor whose runtime `start` blocks on.
    executor: Arc<Executor>,
}
241
242impl crate::Runner for Runner {
243 fn start<F>(self, f: F) -> F::Output
244 where
245 F: Future,
246 {
247 self.executor.runtime.block_on(f)
248 }
249}
250
/// Handle used by code on this runtime to:
/// - spawn tasks;
/// - register and encode metrics;
/// - read the clock;
/// - open network connections;
/// - draw randomness;
/// - access storage.
pub struct Context {
    /// Metric-name prefix for `register`, and the `label` recorded on
    /// per-task metrics.
    label: String,
    /// Set by `spawn_ref`; the spawning methods assert it is still `false`.
    spawned: bool,
    /// Shared executor state.
    executor: Arc<Executor>,
    /// Metered storage wrapping the Tokio storage backend.
    storage: Storage<TokioStorage>,
}
260
261impl Clone for Context {
262 fn clone(&self) -> Self {
263 Self {
264 label: self.label.clone(),
265 spawned: false,
266 executor: self.executor.clone(),
267 storage: self.storage.clone(),
268 }
269 }
270}
271
272impl crate::Spawner for Context {
273 fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
274 where
275 F: FnOnce(Self) -> Fut + Send + 'static,
276 Fut: Future<Output = T> + Send + 'static,
277 T: Send + 'static,
278 {
279 assert!(!self.spawned, "already spawned");
281
282 let work = Work {
284 label: self.label.clone(),
285 };
286 self.executor
287 .metrics
288 .tasks_spawned
289 .get_or_create(&work)
290 .inc();
291 let gauge = self
292 .executor
293 .metrics
294 .tasks_running
295 .get_or_create(&work)
296 .clone();
297
298 let catch_panics = self.executor.cfg.catch_panics;
300 let executor = self.executor.clone();
301 let future = f(self);
302 let (f, handle) = Handle::init(future, gauge, catch_panics);
303
304 executor.runtime.spawn(f);
306 handle
307 }
308
309 fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
310 where
311 F: Future<Output = T> + Send + 'static,
312 T: Send + 'static,
313 {
314 assert!(!self.spawned, "already spawned");
316 self.spawned = true;
317
318 let work = Work {
320 label: self.label.clone(),
321 };
322 self.executor
323 .metrics
324 .tasks_spawned
325 .get_or_create(&work)
326 .inc();
327 let gauge = self
328 .executor
329 .metrics
330 .tasks_running
331 .get_or_create(&work)
332 .clone();
333
334 let executor = self.executor.clone();
336 move |f: F| {
337 let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
338
339 executor.runtime.spawn(f);
341 handle
342 }
343 }
344
345 fn spawn_blocking<F, T>(self, f: F) -> Handle<T>
346 where
347 F: FnOnce() -> T + Send + 'static,
348 T: Send + 'static,
349 {
350 assert!(!self.spawned, "already spawned");
352
353 let work = Work {
355 label: self.label.clone(),
356 };
357 self.executor
358 .metrics
359 .blocking_tasks_spawned
360 .get_or_create(&work)
361 .inc();
362 let gauge = self
363 .executor
364 .metrics
365 .blocking_tasks_running
366 .get_or_create(&work)
367 .clone();
368
369 let (f, handle) = Handle::init_blocking(f, gauge, self.executor.cfg.catch_panics);
371
372 self.executor.runtime.spawn_blocking(f);
374 handle
375 }
376
377 fn stop(&self, value: i32) {
378 self.executor.signaler.lock().unwrap().signal(value);
379 }
380
381 fn stopped(&self) -> Signal {
382 self.executor.signal.clone()
383 }
384}
385
386impl crate::Metrics for Context {
387 fn with_label(&self, label: &str) -> Self {
388 let label = {
389 let prefix = self.label.clone();
390 if prefix.is_empty() {
391 label.to_string()
392 } else {
393 format!("{}_{}", prefix, label)
394 }
395 };
396 assert!(
397 !label.starts_with(METRICS_PREFIX),
398 "using runtime label is not allowed"
399 );
400 Self {
401 label,
402 spawned: false,
403 executor: self.executor.clone(),
404 storage: self.storage.clone(),
405 }
406 }
407
408 fn label(&self) -> String {
409 self.label.clone()
410 }
411
412 fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
413 let name = name.into();
414 let prefixed_name = {
415 let prefix = &self.label;
416 if prefix.is_empty() {
417 name
418 } else {
419 format!("{}_{}", *prefix, name)
420 }
421 };
422 self.executor
423 .registry
424 .lock()
425 .unwrap()
426 .register(prefixed_name, help, metric)
427 }
428
429 fn encode(&self) -> String {
430 let mut buffer = String::new();
431 encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
432 buffer
433 }
434}
435
436impl Clock for Context {
437 fn current(&self) -> SystemTime {
438 SystemTime::now()
439 }
440
441 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
442 tokio::time::sleep(duration)
443 }
444
445 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
446 let now = SystemTime::now();
447 let duration_until_deadline = match deadline.duration_since(now) {
448 Ok(duration) => duration,
449 Err(_) => Duration::from_secs(0), };
451 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
452 tokio::time::sleep_until(target_instant)
453 }
454}
455
456impl GClock for Context {
457 type Instant = SystemTime;
458
459 fn now(&self) -> Self::Instant {
460 self.current()
461 }
462}
463
464impl ReasonablyRealtime for Context {}
465
466impl crate::Network<Listener, Sink, Stream> for Context {
467 async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
468 TcpListener::bind(socket)
469 .await
470 .map_err(|_| Error::BindFailed)
471 .map(|listener| Listener {
472 context: self.clone(),
473 listener,
474 })
475 }
476
477 async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
478 let stream = TcpStream::connect(socket)
480 .await
481 .map_err(|_| Error::ConnectionFailed)?;
482 self.executor.metrics.outbound_connections.inc();
483
484 if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
486 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
487 warn!(?err, "failed to set TCP_NODELAY");
488 }
489 }
490
491 let context = self.clone();
493 let (stream, sink) = stream.into_split();
494 Ok((
495 Sink {
496 context: context.clone(),
497 sink,
498 },
499 Stream { context, stream },
500 ))
501 }
502}
503
/// TCP listener returned by `Context::bind`.
pub struct Listener {
    /// Context whose executor supplies config and metrics for accepted
    /// connections; cloned into each `Sink`/`Stream`.
    context: Context,
    /// Underlying Tokio listener.
    listener: TcpListener,
}
509
510impl crate::Listener<Sink, Stream> for Listener {
511 async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
512 let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
514 self.context.executor.metrics.inbound_connections.inc();
515
516 if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
518 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
519 warn!(?err, "failed to set TCP_NODELAY");
520 }
521 }
522
523 let context = self.context.clone();
525 let (stream, sink) = stream.into_split();
526 Ok((
527 addr,
528 Sink {
529 context: context.clone(),
530 sink,
531 },
532 Stream { context, stream },
533 ))
534 }
535}
536
537impl axum::serve::Listener for Listener {
538 type Io = TcpStream;
539 type Addr = SocketAddr;
540
541 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
542 let (stream, addr) = self.listener.accept().await.unwrap();
543 (stream, addr)
544 }
545
546 fn local_addr(&self) -> io::Result<Self::Addr> {
547 self.listener.local_addr()
548 }
549}
550
/// Write half of a TCP connection produced by `dial` or `accept`.
pub struct Sink {
    /// Supplies the write timeout and the `outbound_bandwidth` metric.
    context: Context,
    /// Underlying Tokio write half.
    sink: OwnedWriteHalf,
}
556
557impl crate::Sink for Sink {
558 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
559 let len = msg.len();
560 timeout(
561 self.context.executor.cfg.write_timeout,
562 self.sink.write_all(msg),
563 )
564 .await
565 .map_err(|_| Error::Timeout)?
566 .map_err(|_| Error::SendFailed)?;
567 self.context
568 .executor
569 .metrics
570 .outbound_bandwidth
571 .inc_by(len as u64);
572 Ok(())
573 }
574}
575
/// Read half of a TCP connection produced by `dial` or `accept`.
pub struct Stream {
    /// Supplies the read timeout and the `inbound_bandwidth` metric.
    context: Context,
    /// Underlying Tokio read half.
    stream: OwnedReadHalf,
}
581
582impl crate::Stream for Stream {
583 async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
584 timeout(
586 self.context.executor.cfg.read_timeout,
587 self.stream.read_exact(buf),
588 )
589 .await
590 .map_err(|_| Error::Timeout)?
591 .map_err(|_| Error::RecvFailed)?;
592
593 self.context
595 .executor
596 .metrics
597 .inbound_bandwidth
598 .inc_by(buf.len() as u64);
599
600 Ok(())
601 }
602}
603
604impl RngCore for Context {
605 fn next_u32(&mut self) -> u32 {
606 OsRng.next_u32()
607 }
608
609 fn next_u64(&mut self) -> u64 {
610 OsRng.next_u64()
611 }
612
613 fn fill_bytes(&mut self, dest: &mut [u8]) {
614 OsRng.fill_bytes(dest);
615 }
616
617 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
618 OsRng.try_fill_bytes(dest)
619 }
620}
621
622impl CryptoRng for Context {}
623
624impl crate::Storage for Context {
625 type Blob = <Storage<TokioStorage> as crate::Storage>::Blob;
626
627 async fn open(&self, partition: &str, name: &[u8]) -> Result<Self::Blob, Error> {
628 self.storage.open(partition, name).await
629 }
630
631 async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
632 self.storage.remove(partition, name).await
633 }
634
635 async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
636 self.storage.scan(partition).await
637 }
638}