1use crate::storage::metered::Storage;
2use crate::storage::tokio::{Config as TokioStorageConfig, Storage as TokioStorage};
3use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
4use governor::clock::{Clock as GClock, ReasonablyRealtime};
5use prometheus_client::{
6 encoding::{text::encode, EncodeLabelSet},
7 metrics::{counter::Counter, family::Family, gauge::Gauge},
8 registry::{Metric, Registry},
9};
10use rand::{rngs::OsRng, CryptoRng, RngCore};
11use std::{
12 env,
13 future::Future,
14 io,
15 net::SocketAddr,
16 path::PathBuf,
17 sync::{Arc, Mutex},
18 time::{Duration, SystemTime},
19};
20use tokio::{
21 io::{AsyncReadExt, AsyncWriteExt},
22 net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
23 runtime::{Builder, Runtime},
24 time::timeout,
25};
26use tracing::warn;
27
/// Label set attached to the per-task metric families.
///
/// The task metric families in [`Metrics`] are keyed by this type, so each
/// distinct `label` gets its own counter/gauge within those families.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct Work {
    /// Label of the [`Context`] that spawned the task (empty for the root context).
    label: String,
}
32
33#[derive(Debug)]
34struct Metrics {
35 tasks_spawned: Family<Work, Counter>,
36 tasks_running: Family<Work, Gauge>,
37 blocking_tasks_spawned: Family<Work, Counter>,
38 blocking_tasks_running: Family<Work, Gauge>,
39
40 inbound_connections: Counter,
43 outbound_connections: Counter,
44 inbound_bandwidth: Counter,
45 outbound_bandwidth: Counter,
46}
47
48impl Metrics {
49 pub fn init(registry: &mut Registry) -> Self {
50 let metrics = Self {
51 tasks_spawned: Family::default(),
52 tasks_running: Family::default(),
53 blocking_tasks_spawned: Family::default(),
54 blocking_tasks_running: Family::default(),
55 inbound_connections: Counter::default(),
56 outbound_connections: Counter::default(),
57 inbound_bandwidth: Counter::default(),
58 outbound_bandwidth: Counter::default(),
59 };
60 registry.register(
61 "tasks_spawned",
62 "Total number of tasks spawned",
63 metrics.tasks_spawned.clone(),
64 );
65 registry.register(
66 "tasks_running",
67 "Number of tasks currently running",
68 metrics.tasks_running.clone(),
69 );
70 registry.register(
71 "blocking_tasks_spawned",
72 "Total number of blocking tasks spawned",
73 metrics.blocking_tasks_spawned.clone(),
74 );
75 registry.register(
76 "blocking_tasks_running",
77 "Number of blocking tasks currently running",
78 metrics.blocking_tasks_running.clone(),
79 );
80 registry.register(
81 "inbound_connections",
82 "Number of connections created by dialing us",
83 metrics.inbound_connections.clone(),
84 );
85 registry.register(
86 "outbound_connections",
87 "Number of connections created by dialing others",
88 metrics.outbound_connections.clone(),
89 );
90 registry.register(
91 "inbound_bandwidth",
92 "Bandwidth used by receiving data from others",
93 metrics.inbound_bandwidth.clone(),
94 );
95 registry.register(
96 "outbound_bandwidth",
97 "Bandwidth used by sending data to others",
98 metrics.outbound_bandwidth.clone(),
99 );
100 metrics
101 }
102}
103
104#[derive(Clone)]
106pub struct Config {
107 pub worker_threads: usize,
113
114 pub max_blocking_threads: usize,
122
123 pub catch_panics: bool,
125
126 pub read_timeout: Duration,
128
129 pub write_timeout: Duration,
131
132 pub tcp_nodelay: Option<bool>,
143
144 pub storage_directory: PathBuf,
146
147 pub maximum_buffer_size: usize,
151}
152
153impl Default for Config {
154 fn default() -> Self {
155 let rng = OsRng.next_u64();
157 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
158
159 Self {
161 worker_threads: 2,
162 max_blocking_threads: 512,
163 catch_panics: true,
164 read_timeout: Duration::from_secs(60),
165 write_timeout: Duration::from_secs(30),
166 tcp_nodelay: None,
167 storage_directory,
168 maximum_buffer_size: 2 * 1024 * 1024, }
170 }
171}
172
/// State shared (via `Arc`) by every [`Context`] created from one `Runner::start` call.
pub struct Executor {
    /// Configuration the runtime was started with.
    cfg: Config,
    /// Root metrics registry. Runtime-internal metrics live in a sub-registry
    /// prefixed with `METRICS_PREFIX`; `Context::register` adds user metrics here.
    registry: Mutex<Registry>,
    /// Runtime-internal task and network metrics.
    metrics: Arc<Metrics>,
    /// Multi-threaded Tokio runtime that executes all spawned tasks.
    runtime: Runtime,
    /// Signaling half used by `stop`.
    signaler: Mutex<Signaler>,
    /// Signal half cloned out by `stopped`.
    signal: Signal,
}
182
183pub struct Runner {
185 cfg: Config,
186}
187
188impl Default for Runner {
189 fn default() -> Self {
190 Self::new(Config::default())
191 }
192}
193
194impl Runner {
195 pub fn new(cfg: Config) -> Self {
197 Self { cfg }
198 }
199}
200
201impl crate::Runner for Runner {
202 type Context = Context;
203
204 fn start<F, Fut>(self, f: F) -> Fut::Output
205 where
206 F: FnOnce(Self::Context) -> Fut,
207 Fut: Future,
208 {
209 let mut registry = Registry::default();
211 let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
212
213 let metrics = Arc::new(Metrics::init(runtime_registry));
215 let runtime = Builder::new_multi_thread()
216 .worker_threads(self.cfg.worker_threads)
217 .max_blocking_threads(self.cfg.max_blocking_threads)
218 .enable_all()
219 .build()
220 .expect("failed to create Tokio runtime");
221 let (signaler, signal) = Signaler::new();
222
223 let storage = Storage::new(
224 TokioStorage::new(TokioStorageConfig::new(
225 self.cfg.storage_directory.clone(),
226 self.cfg.maximum_buffer_size,
227 )),
228 runtime_registry,
229 );
230
231 let executor = Arc::new(Executor {
232 cfg: self.cfg,
233 registry: Mutex::new(registry),
234 metrics,
235 runtime,
236 signaler: Mutex::new(signaler),
237 signal,
238 });
239
240 let context = Context {
241 storage,
242 label: String::new(),
243 spawned: false,
244 executor: executor.clone(),
245 };
246
247 executor.runtime.block_on(f(context))
248 }
249}
250
251pub struct Context {
255 label: String,
256 spawned: bool,
257 executor: Arc<Executor>,
258 storage: Storage<TokioStorage>,
259}
260
261impl Clone for Context {
262 fn clone(&self) -> Self {
263 Self {
264 label: self.label.clone(),
265 spawned: false,
266 executor: self.executor.clone(),
267 storage: self.storage.clone(),
268 }
269 }
270}
271
272impl crate::Spawner for Context {
273 fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
274 where
275 F: FnOnce(Self) -> Fut + Send + 'static,
276 Fut: Future<Output = T> + Send + 'static,
277 T: Send + 'static,
278 {
279 assert!(!self.spawned, "already spawned");
281
282 let work = Work {
284 label: self.label.clone(),
285 };
286 self.executor
287 .metrics
288 .tasks_spawned
289 .get_or_create(&work)
290 .inc();
291 let gauge = self
292 .executor
293 .metrics
294 .tasks_running
295 .get_or_create(&work)
296 .clone();
297
298 let catch_panics = self.executor.cfg.catch_panics;
300 let executor = self.executor.clone();
301 let future = f(self);
302 let (f, handle) = Handle::init(future, gauge, catch_panics);
303
304 executor.runtime.spawn(f);
306 handle
307 }
308
309 fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
310 where
311 F: Future<Output = T> + Send + 'static,
312 T: Send + 'static,
313 {
314 assert!(!self.spawned, "already spawned");
316 self.spawned = true;
317
318 let work = Work {
320 label: self.label.clone(),
321 };
322 self.executor
323 .metrics
324 .tasks_spawned
325 .get_or_create(&work)
326 .inc();
327 let gauge = self
328 .executor
329 .metrics
330 .tasks_running
331 .get_or_create(&work)
332 .clone();
333
334 let executor = self.executor.clone();
336 move |f: F| {
337 let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
338
339 executor.runtime.spawn(f);
341 handle
342 }
343 }
344
345 fn spawn_blocking<F, T>(self, f: F) -> Handle<T>
346 where
347 F: FnOnce() -> T + Send + 'static,
348 T: Send + 'static,
349 {
350 assert!(!self.spawned, "already spawned");
352
353 let work = Work {
355 label: self.label.clone(),
356 };
357 self.executor
358 .metrics
359 .blocking_tasks_spawned
360 .get_or_create(&work)
361 .inc();
362 let gauge = self
363 .executor
364 .metrics
365 .blocking_tasks_running
366 .get_or_create(&work)
367 .clone();
368
369 let (f, handle) = Handle::init_blocking(f, gauge, self.executor.cfg.catch_panics);
371
372 self.executor.runtime.spawn_blocking(f);
374 handle
375 }
376
377 fn stop(&self, value: i32) {
378 self.executor.signaler.lock().unwrap().signal(value);
379 }
380
381 fn stopped(&self) -> Signal {
382 self.executor.signal.clone()
383 }
384}
385
386impl crate::Metrics for Context {
387 fn with_label(&self, label: &str) -> Self {
388 let label = {
389 let prefix = self.label.clone();
390 if prefix.is_empty() {
391 label.to_string()
392 } else {
393 format!("{}_{}", prefix, label)
394 }
395 };
396 assert!(
397 !label.starts_with(METRICS_PREFIX),
398 "using runtime label is not allowed"
399 );
400 Self {
401 label,
402 spawned: false,
403 executor: self.executor.clone(),
404 storage: self.storage.clone(),
405 }
406 }
407
408 fn label(&self) -> String {
409 self.label.clone()
410 }
411
412 fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
413 let name = name.into();
414 let prefixed_name = {
415 let prefix = &self.label;
416 if prefix.is_empty() {
417 name
418 } else {
419 format!("{}_{}", *prefix, name)
420 }
421 };
422 self.executor
423 .registry
424 .lock()
425 .unwrap()
426 .register(prefixed_name, help, metric)
427 }
428
429 fn encode(&self) -> String {
430 let mut buffer = String::new();
431 encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
432 buffer
433 }
434}
435
436impl Clock for Context {
437 fn current(&self) -> SystemTime {
438 SystemTime::now()
439 }
440
441 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
442 tokio::time::sleep(duration)
443 }
444
445 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
446 let now = SystemTime::now();
447 let duration_until_deadline = match deadline.duration_since(now) {
448 Ok(duration) => duration,
449 Err(_) => Duration::from_secs(0), };
451 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
452 tokio::time::sleep_until(target_instant)
453 }
454}
455
456impl GClock for Context {
457 type Instant = SystemTime;
458
459 fn now(&self) -> Self::Instant {
460 self.current()
461 }
462}
463
464impl ReasonablyRealtime for Context {}
465
466impl crate::Network<Listener, Sink, Stream> for Context {
467 async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
468 TcpListener::bind(socket)
469 .await
470 .map_err(|_| Error::BindFailed)
471 .map(|listener| Listener {
472 context: self.clone(),
473 listener,
474 })
475 }
476
477 async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
478 let stream = TcpStream::connect(socket)
480 .await
481 .map_err(|_| Error::ConnectionFailed)?;
482 self.executor.metrics.outbound_connections.inc();
483
484 if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
486 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
487 warn!(?err, "failed to set TCP_NODELAY");
488 }
489 }
490
491 let context = self.clone();
493 let (stream, sink) = stream.into_split();
494 Ok((
495 Sink {
496 context: context.clone(),
497 sink,
498 },
499 Stream { context, stream },
500 ))
501 }
502}
503
504pub struct Listener {
506 context: Context,
507 listener: TcpListener,
508}
509
510impl crate::Listener<Sink, Stream> for Listener {
511 async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
512 let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
514 self.context.executor.metrics.inbound_connections.inc();
515
516 if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
518 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
519 warn!(?err, "failed to set TCP_NODELAY");
520 }
521 }
522
523 let context = self.context.clone();
525 let (stream, sink) = stream.into_split();
526 Ok((
527 addr,
528 Sink {
529 context: context.clone(),
530 sink,
531 },
532 Stream { context, stream },
533 ))
534 }
535}
536
537impl axum::serve::Listener for Listener {
538 type Io = TcpStream;
539 type Addr = SocketAddr;
540
541 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
542 let (stream, addr) = self.listener.accept().await.unwrap();
543 (stream, addr)
544 }
545
546 fn local_addr(&self) -> io::Result<Self::Addr> {
547 self.listener.local_addr()
548 }
549}
550
551pub struct Sink {
553 context: Context,
554 sink: OwnedWriteHalf,
555}
556
557impl crate::Sink for Sink {
558 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
559 let len = msg.len();
560 timeout(
561 self.context.executor.cfg.write_timeout,
562 self.sink.write_all(msg),
563 )
564 .await
565 .map_err(|_| Error::Timeout)?
566 .map_err(|_| Error::SendFailed)?;
567 self.context
568 .executor
569 .metrics
570 .outbound_bandwidth
571 .inc_by(len as u64);
572 Ok(())
573 }
574}
575
576pub struct Stream {
578 context: Context,
579 stream: OwnedReadHalf,
580}
581
582impl crate::Stream for Stream {
583 async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
584 timeout(
586 self.context.executor.cfg.read_timeout,
587 self.stream.read_exact(buf),
588 )
589 .await
590 .map_err(|_| Error::Timeout)?
591 .map_err(|_| Error::RecvFailed)?;
592
593 self.context
595 .executor
596 .metrics
597 .inbound_bandwidth
598 .inc_by(buf.len() as u64);
599
600 Ok(())
601 }
602}
603
604impl RngCore for Context {
605 fn next_u32(&mut self) -> u32 {
606 OsRng.next_u32()
607 }
608
609 fn next_u64(&mut self) -> u64 {
610 OsRng.next_u64()
611 }
612
613 fn fill_bytes(&mut self, dest: &mut [u8]) {
614 OsRng.fill_bytes(dest);
615 }
616
617 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
618 OsRng.try_fill_bytes(dest)
619 }
620}
621
622impl CryptoRng for Context {}
623
624impl crate::Storage for Context {
625 type Blob = <Storage<TokioStorage> as crate::Storage>::Blob;
626
627 async fn open(&self, partition: &str, name: &[u8]) -> Result<(Self::Blob, u64), Error> {
628 self.storage.open(partition, name).await
629 }
630
631 async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
632 self.storage.remove(partition, name).await
633 }
634
635 async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
636 self.storage.scan(partition).await
637 }
638}