1use crate::{utils::Signaler, Clock, Error, Handle, Signal};
27use commonware_utils::{from_hex, hex};
28use governor::clock::{Clock as GClock, ReasonablyRealtime};
29use prometheus_client::{
30 encoding::EncodeLabelSet,
31 metrics::{counter::Counter, family::Family, gauge::Gauge},
32 registry::Registry,
33};
34use rand::{rngs::OsRng, CryptoRng, RngCore};
35use std::{
36 env,
37 future::Future,
38 io::SeekFrom,
39 net::SocketAddr,
40 path::PathBuf,
41 sync::{Arc, Mutex},
42 time::{Duration, SystemTime},
43};
44use tokio::{
45 fs,
46 io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
47 net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
48 runtime::{Builder, Runtime},
49 sync::Mutex as AsyncMutex,
50 task_local,
51 time::timeout,
52};
53use tracing::warn;
54
/// Label set attached to the per-task metric families.
///
/// Used as the key of `Metrics::tasks_spawned` and `Metrics::tasks_running`.
#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
struct Work {
    /// Label of the spawned task (already prefixed with the label of the
    /// spawning task when there is one).
    label: String,
}
59
/// All metrics recorded by the runtime.
#[derive(Debug)]
struct Metrics {
    /// Incremented once for every task spawned, keyed by task label.
    tasks_spawned: Family<Work, Counter>,
    /// Gauge handed to each task's `Handle`, keyed by task label.
    tasks_running: Family<Work, Gauge>,

    /// Incremented for every connection accepted by a `Listener`.
    inbound_connections: Counter,
    /// Incremented for every successful `dial`.
    outbound_connections: Counter,
    /// Bytes received through `Stream::recv`.
    inbound_bandwidth: Counter,
    /// Bytes sent through `Sink::send`.
    outbound_bandwidth: Counter,

    /// Number of live `Blob` values (clones included).
    open_blobs: Gauge,
    /// Number of successful `Blob::read_at` calls.
    storage_reads: Counter,
    /// Bytes returned by successful `Blob::read_at` calls.
    storage_read_bytes: Counter,
    /// Number of successful `Blob::write_at` calls.
    storage_writes: Counter,
    /// Bytes written by successful `Blob::write_at` calls.
    storage_write_bytes: Counter,
}
78
79impl Metrics {
80 pub fn init(registry: Arc<Mutex<Registry>>) -> Self {
81 let metrics = Self {
82 tasks_spawned: Family::default(),
83 tasks_running: Family::default(),
84 inbound_connections: Counter::default(),
85 outbound_connections: Counter::default(),
86 inbound_bandwidth: Counter::default(),
87 outbound_bandwidth: Counter::default(),
88 open_blobs: Gauge::default(),
89 storage_reads: Counter::default(),
90 storage_read_bytes: Counter::default(),
91 storage_writes: Counter::default(),
92 storage_write_bytes: Counter::default(),
93 };
94 {
95 let mut registry = registry.lock().unwrap();
96 registry.register(
97 "tasks_spawned",
98 "Total number of tasks spawned",
99 metrics.tasks_spawned.clone(),
100 );
101 registry.register(
102 "tasks_running",
103 "Number of tasks currently running",
104 metrics.tasks_running.clone(),
105 );
106 registry.register(
107 "inbound_connections",
108 "Number of connections created by dialing us",
109 metrics.inbound_connections.clone(),
110 );
111 registry.register(
112 "outbound_connections",
113 "Number of connections created by dialing others",
114 metrics.outbound_connections.clone(),
115 );
116 registry.register(
117 "inbound_bandwidth",
118 "Bandwidth used by receiving data from others",
119 metrics.inbound_bandwidth.clone(),
120 );
121 registry.register(
122 "outbound_bandwidth",
123 "Bandwidth used by sending data to others",
124 metrics.outbound_bandwidth.clone(),
125 );
126 registry.register(
127 "open_blobs",
128 "Number of open blobs",
129 metrics.open_blobs.clone(),
130 );
131 registry.register(
132 "storage_reads",
133 "Total number of disk reads",
134 metrics.storage_reads.clone(),
135 );
136 registry.register(
137 "storage_read_bytes",
138 "Total amount of data read from disk",
139 metrics.storage_read_bytes.clone(),
140 );
141 registry.register(
142 "storage_writes",
143 "Total number of disk writes",
144 metrics.storage_writes.clone(),
145 );
146 registry.register(
147 "storage_write_bytes",
148 "Total amount of data written to disk",
149 metrics.storage_write_bytes.clone(),
150 );
151 }
152 metrics
153 }
154}
155
/// Configuration for the Tokio-backed runtime.
#[derive(Clone)]
pub struct Config {
    /// Registry that all runtime metrics are registered with.
    pub registry: Arc<Mutex<Registry>>,

    /// Number of worker threads given to the Tokio runtime builder.
    pub threads: usize,

    /// Forwarded to `Handle::init` for every spawned task.
    // NOTE(review): presumably controls whether a panicking task is caught
    // instead of propagating; confirm against `Handle`.
    pub catch_panics: bool,

    /// Maximum time a single `Stream::recv` may take before it fails with
    /// `Error::Timeout`.
    pub read_timeout: Duration,

    /// Maximum time a single `Sink::send` may take before it fails with
    /// `Error::Timeout`.
    pub write_timeout: Duration,

    /// When `Some`, the value is applied with `set_nodelay` to every dialed
    /// and accepted connection; when `None`, the socket option is left
    /// untouched.
    pub tcp_nodelay: Option<bool>,

    /// Base directory; each partition is a subdirectory of it.
    pub storage_directory: PathBuf,

    /// Passed to `set_max_buf_size` on every file opened by the runtime.
    pub maximum_buffer_size: usize,
}
194
195impl Default for Config {
196 fn default() -> Self {
197 let rng = OsRng.next_u64();
199 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
200
201 Self {
203 registry: Arc::new(Mutex::new(Registry::default())),
204 threads: 2,
205 catch_panics: true,
206 read_timeout: Duration::from_secs(60),
207 write_timeout: Duration::from_secs(30),
208 tcp_nodelay: None,
209 storage_directory,
210 maximum_buffer_size: 2 * 1024 * 1024, }
212 }
213}
214
/// State shared by the `Runner` and every `Context` created by `Executor::init`.
pub struct Executor {
    /// Configuration the executor was created with.
    cfg: Config,
    /// Metrics registered with `cfg.registry`.
    metrics: Arc<Metrics>,
    /// Tokio runtime on which all tasks are spawned.
    runtime: Runtime,
    /// Held by `open`, `remove`, and `scan` for their whole duration, so
    /// those operations never run concurrently.
    fs: AsyncMutex<()>,
    /// Used by `Context::stop` to deliver the stop value.
    signaler: Mutex<Signaler>,
    /// Cloned and returned by `Context::stopped`.
    signal: Signal,
}
224
225impl Executor {
226 pub fn init(cfg: Config) -> (Runner, Context) {
228 let metrics = Arc::new(Metrics::init(cfg.registry.clone()));
229 let runtime = Builder::new_multi_thread()
230 .worker_threads(cfg.threads)
231 .enable_all()
232 .build()
233 .expect("failed to create Tokio runtime");
234 let (signaler, signal) = Signaler::new();
235 let executor = Arc::new(Self {
236 cfg,
237 metrics,
238 runtime,
239 fs: AsyncMutex::new(()),
240 signaler: Mutex::new(signaler),
241 signal,
242 });
243 (
244 Runner {
245 executor: executor.clone(),
246 },
247 Context { executor },
248 )
249 }
250
251 #[allow(clippy::should_implement_trait)]
254 pub fn default() -> (Runner, Context) {
255 Self::init(Config::default())
256 }
257}
258
/// Entry point that drives a root future on the executor's Tokio runtime.
pub struct Runner {
    /// Executor shared with every `Context`.
    executor: Arc<Executor>,
}
263
264impl crate::Runner for Runner {
265 fn start<F>(self, f: F) -> F::Output
266 where
267 F: Future + Send + 'static,
268 F::Output: Send + 'static,
269 {
270 self.executor.runtime.block_on(f)
271 }
272}
273
/// Cloneable handle to the executor; implements the spawner, clock,
/// network, storage, and random-number interfaces of the runtime.
#[derive(Clone)]
pub struct Context {
    /// Executor shared with the `Runner` and all other contexts.
    executor: Arc<Executor>,
}
281
task_local! {
    // Label of the task currently running; `spawn` prepends it to the label
    // of any task spawned from inside that task.
    static PREFIX: String;
}
285
286impl crate::Spawner for Context {
287 fn spawn<F, T>(&self, label: &str, f: F) -> Handle<T>
288 where
289 F: Future<Output = T> + Send + 'static,
290 T: Send + 'static,
291 {
292 let label = PREFIX
293 .try_with(|prefix| format!("{}_{}", prefix, label))
294 .unwrap_or_else(|_| label.to_string());
295 let f = PREFIX.scope(label.clone(), f);
296 let work = Work { label };
297 self.executor
298 .metrics
299 .tasks_spawned
300 .get_or_create(&work)
301 .inc();
302 let gauge = self
303 .executor
304 .metrics
305 .tasks_running
306 .get_or_create(&work)
307 .clone();
308 let (f, handle) = Handle::init(f, gauge, self.executor.cfg.catch_panics);
309 self.executor.runtime.spawn(f);
310 handle
311 }
312
313 fn stop(&self, value: i32) {
314 self.executor.signaler.lock().unwrap().signal(value);
315 }
316
317 fn stopped(&self) -> Signal {
318 self.executor.signal.clone()
319 }
320}
321
322impl Clock for Context {
323 fn current(&self) -> SystemTime {
324 SystemTime::now()
325 }
326
327 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
328 tokio::time::sleep(duration)
329 }
330
331 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
332 let now = SystemTime::now();
333 let duration_until_deadline = match deadline.duration_since(now) {
334 Ok(duration) => duration,
335 Err(_) => Duration::from_secs(0), };
337 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
338 tokio::time::sleep_until(target_instant)
339 }
340}
341
342impl GClock for Context {
343 type Instant = SystemTime;
344
345 fn now(&self) -> Self::Instant {
346 self.current()
347 }
348}
349
// Marker implementation with no overridden items; `Context` reports
// `SystemTime::now()` through its `GClock` implementation.
impl ReasonablyRealtime for Context {}
351
352impl crate::Network<Listener, Sink, Stream> for Context {
353 async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
354 TcpListener::bind(socket)
355 .await
356 .map_err(|_| Error::BindFailed)
357 .map(|listener| Listener {
358 context: self.clone(),
359 listener,
360 })
361 }
362
363 async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
364 let stream = TcpStream::connect(socket)
366 .await
367 .map_err(|_| Error::ConnectionFailed)?;
368 self.executor.metrics.outbound_connections.inc();
369
370 if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
372 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
373 warn!(?err, "failed to set TCP_NODELAY");
374 }
375 }
376
377 let context = self.clone();
379 let (stream, sink) = stream.into_split();
380 Ok((
381 Sink {
382 context: context.clone(),
383 sink,
384 },
385 Stream { context, stream },
386 ))
387 }
388}
389
/// TCP listener returned by `Context::bind`.
pub struct Listener {
    /// Context cloned into every accepted `Sink` and `Stream`.
    context: Context,
    /// Underlying Tokio listener.
    listener: TcpListener,
}
395
396impl crate::Listener<Sink, Stream> for Listener {
397 async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
398 let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
400 self.context.executor.metrics.inbound_connections.inc();
401
402 if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
404 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
405 warn!(?err, "failed to set TCP_NODELAY");
406 }
407 }
408
409 let context = self.context.clone();
411 let (stream, sink) = stream.into_split();
412 Ok((
413 addr,
414 Sink {
415 context: context.clone(),
416 sink,
417 },
418 Stream { context, stream },
419 ))
420 }
421}
422
/// Write half of a TCP connection.
pub struct Sink {
    /// Source of the write timeout and the bandwidth metric.
    context: Context,
    /// Owned write half of the underlying stream.
    sink: OwnedWriteHalf,
}
428
429impl crate::Sink for Sink {
430 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
431 let len = msg.len();
432 timeout(
433 self.context.executor.cfg.write_timeout,
434 self.sink.write_all(msg),
435 )
436 .await
437 .map_err(|_| Error::Timeout)?
438 .map_err(|_| Error::SendFailed)?;
439 self.context
440 .executor
441 .metrics
442 .outbound_bandwidth
443 .inc_by(len as u64);
444 Ok(())
445 }
446}
447
/// Read half of a TCP connection.
pub struct Stream {
    /// Source of the read timeout and the bandwidth metric.
    context: Context,
    /// Owned read half of the underlying stream.
    stream: OwnedReadHalf,
}
453
454impl crate::Stream for Stream {
455 async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
456 timeout(
458 self.context.executor.cfg.read_timeout,
459 self.stream.read_exact(buf),
460 )
461 .await
462 .map_err(|_| Error::Timeout)?
463 .map_err(|_| Error::RecvFailed)?;
464
465 self.context
467 .executor
468 .metrics
469 .inbound_bandwidth
470 .inc_by(buf.len() as u64);
471
472 Ok(())
473 }
474}
475
476impl RngCore for Context {
477 fn next_u32(&mut self) -> u32 {
478 OsRng.next_u32()
479 }
480
481 fn next_u64(&mut self) -> u64 {
482 OsRng.next_u64()
483 }
484
485 fn fill_bytes(&mut self, dest: &mut [u8]) {
486 OsRng.fill_bytes(dest);
487 }
488
489 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
490 OsRng.try_fill_bytes(dest)
491 }
492}
493
// Marker implementation; all randomness produced by `Context` comes from `OsRng`.
impl CryptoRng for Context {}
495
/// Handle to a file stored inside a partition.
///
/// Clones share the same underlying file and tracked length.
pub struct Blob {
    /// Metrics updated by blob operations and by creating/dropping handles.
    metrics: Arc<Metrics>,

    /// Partition the blob belongs to (used in error values).
    partition: String,
    /// Raw blob name (hex-encoded in error values).
    name: Vec<u8>,

    /// The open file paired with its tracked length in bytes; the length is
    /// updated by `write_at` and `truncate` and checked by `read_at`.
    file: Arc<AsyncMutex<(fs::File, u64)>>,
}
511
512impl Blob {
513 fn new(
514 metrics: Arc<Metrics>,
515 partition: String,
516 name: &[u8],
517 file: fs::File,
518 len: u64,
519 ) -> Self {
520 metrics.open_blobs.inc();
521 Self {
522 metrics,
523 partition,
524 name: name.into(),
525 file: Arc::new(AsyncMutex::new((file, len))),
526 }
527 }
528}
529
530impl Clone for Blob {
531 fn clone(&self) -> Self {
532 self.metrics.open_blobs.inc();
534 Self {
535 metrics: self.metrics.clone(),
536 partition: self.partition.clone(),
537 name: self.name.clone(),
538 file: self.file.clone(),
539 }
540 }
541}
542
543impl crate::Storage<Blob> for Context {
544 async fn open(&self, partition: &str, name: &[u8]) -> Result<Blob, Error> {
545 let _guard = self.executor.fs.lock().await;
547
548 let path = self
550 .executor
551 .cfg
552 .storage_directory
553 .join(partition)
554 .join(hex(name));
555 let parent = match path.parent() {
556 Some(parent) => parent,
557 None => return Err(Error::PartitionCreationFailed(partition.into())),
558 };
559
560 fs::create_dir_all(parent)
562 .await
563 .map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
564
565 let mut file = fs::OpenOptions::new()
567 .read(true)
568 .write(true)
569 .create(true)
570 .truncate(false)
571 .open(&path)
572 .await
573 .map_err(|_| Error::BlobOpenFailed(partition.into(), hex(name)))?;
574
575 file.set_max_buf_size(self.executor.cfg.maximum_buffer_size);
577
578 let len = file.metadata().await.map_err(|_| Error::ReadFailed)?.len();
580
581 Ok(Blob::new(
583 self.executor.metrics.clone(),
584 partition.into(),
585 name,
586 file,
587 len,
588 ))
589 }
590
591 async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
592 let _guard = self.executor.fs.lock().await;
594
595 let path = self.executor.cfg.storage_directory.join(partition);
597 if let Some(name) = name {
598 let blob_path = path.join(hex(name));
599 fs::remove_file(blob_path)
600 .await
601 .map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
602 } else {
603 fs::remove_dir_all(path)
604 .await
605 .map_err(|_| Error::PartitionMissing(partition.into()))?;
606 }
607 Ok(())
608 }
609
610 async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
611 let _guard = self.executor.fs.lock().await;
613
614 let path = self.executor.cfg.storage_directory.join(partition);
616 let mut entries = fs::read_dir(path)
617 .await
618 .map_err(|_| Error::PartitionMissing(partition.into()))?;
619 let mut blobs = Vec::new();
620 while let Some(entry) = entries.next_entry().await.map_err(|_| Error::ReadFailed)? {
621 let file_type = entry.file_type().await.map_err(|_| Error::ReadFailed)?;
622 if !file_type.is_file() {
623 return Err(Error::PartitionCorrupt(partition.into()));
624 }
625 if let Some(name) = entry.file_name().to_str() {
626 let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
627 blobs.push(name);
628 }
629 }
630 Ok(blobs)
631 }
632}
633
634impl crate::Blob for Blob {
635 async fn len(&self) -> Result<u64, Error> {
636 let (_, len) = *self.file.lock().await;
637 Ok(len)
638 }
639
640 async fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<(), Error> {
641 let mut file = self.file.lock().await;
643 if offset + buf.len() as u64 > file.1 {
644 return Err(Error::BlobInsufficientLength);
645 }
646
647 file.0
649 .seek(SeekFrom::Start(offset))
650 .await
651 .map_err(|_| Error::ReadFailed)?;
652 file.0
653 .read_exact(buf)
654 .await
655 .map_err(|_| Error::ReadFailed)?;
656 self.metrics.storage_reads.inc();
657 self.metrics.storage_read_bytes.inc_by(buf.len() as u64);
658 Ok(())
659 }
660
661 async fn write_at(&self, buf: &[u8], offset: u64) -> Result<(), Error> {
662 let mut file = self.file.lock().await;
664 file.0
665 .seek(SeekFrom::Start(offset))
666 .await
667 .map_err(|_| Error::WriteFailed)?;
668 file.0
669 .write_all(buf)
670 .await
671 .map_err(|_| Error::WriteFailed)?;
672
673 let max_len = offset + buf.len() as u64;
675 if max_len > file.1 {
676 file.1 = max_len;
677 }
678 self.metrics.storage_writes.inc();
679 self.metrics.storage_write_bytes.inc_by(buf.len() as u64);
680 Ok(())
681 }
682
683 async fn truncate(&self, len: u64) -> Result<(), Error> {
684 let mut file = self.file.lock().await;
686 file.0
687 .set_len(len)
688 .await
689 .map_err(|_| Error::BlobTruncateFailed(self.partition.clone(), hex(&self.name)))?;
690
691 file.1 = len;
693 Ok(())
694 }
695
696 async fn sync(&self) -> Result<(), Error> {
697 let file = self.file.lock().await;
698 file.0
699 .sync_all()
700 .await
701 .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))
702 }
703
704 async fn close(self) -> Result<(), Error> {
705 let mut file = self.file.lock().await;
706 file.0
707 .sync_all()
708 .await
709 .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))?;
710 file.0
711 .shutdown()
712 .await
713 .map_err(|_| Error::BlobCloseFailed(self.partition.clone(), hex(&self.name)))
714 }
715}
716
717impl Drop for Blob {
718 fn drop(&mut self) {
719 self.metrics.open_blobs.dec();
720 }
721}