1use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
27use commonware_utils::{from_hex, hex};
28use governor::clock::{Clock as GClock, ReasonablyRealtime};
29use prometheus_client::{
30 encoding::{text::encode, EncodeLabelSet},
31 metrics::{counter::Counter, family::Family, gauge::Gauge},
32 registry::{Metric, Registry},
33};
34use rand::{rngs::OsRng, CryptoRng, RngCore};
35use std::{
36 env,
37 future::Future,
38 io::{self, SeekFrom},
39 net::SocketAddr,
40 path::PathBuf,
41 sync::{Arc, Mutex},
42 time::{Duration, SystemTime},
43};
44use tokio::{
45 fs,
46 io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
47 net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
48 runtime::{Builder, Runtime},
49 sync::Mutex as AsyncMutex,
50 time::timeout,
51};
52use tracing::warn;
53
54#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
55struct Work {
56 label: String,
57}
58
59#[derive(Debug)]
60struct Metrics {
61 tasks_spawned: Family<Work, Counter>,
62 tasks_running: Family<Work, Gauge>,
63
64 inbound_connections: Counter,
67 outbound_connections: Counter,
68 inbound_bandwidth: Counter,
69 outbound_bandwidth: Counter,
70
71 open_blobs: Gauge,
72 storage_reads: Counter,
73 storage_read_bytes: Counter,
74 storage_writes: Counter,
75 storage_write_bytes: Counter,
76}
77
78impl Metrics {
79 pub fn init(registry: &mut Registry) -> Self {
80 let metrics = Self {
81 tasks_spawned: Family::default(),
82 tasks_running: Family::default(),
83 inbound_connections: Counter::default(),
84 outbound_connections: Counter::default(),
85 inbound_bandwidth: Counter::default(),
86 outbound_bandwidth: Counter::default(),
87 open_blobs: Gauge::default(),
88 storage_reads: Counter::default(),
89 storage_read_bytes: Counter::default(),
90 storage_writes: Counter::default(),
91 storage_write_bytes: Counter::default(),
92 };
93 registry.register(
94 "tasks_spawned",
95 "Total number of tasks spawned",
96 metrics.tasks_spawned.clone(),
97 );
98 registry.register(
99 "tasks_running",
100 "Number of tasks currently running",
101 metrics.tasks_running.clone(),
102 );
103 registry.register(
104 "inbound_connections",
105 "Number of connections created by dialing us",
106 metrics.inbound_connections.clone(),
107 );
108 registry.register(
109 "outbound_connections",
110 "Number of connections created by dialing others",
111 metrics.outbound_connections.clone(),
112 );
113 registry.register(
114 "inbound_bandwidth",
115 "Bandwidth used by receiving data from others",
116 metrics.inbound_bandwidth.clone(),
117 );
118 registry.register(
119 "outbound_bandwidth",
120 "Bandwidth used by sending data to others",
121 metrics.outbound_bandwidth.clone(),
122 );
123 registry.register(
124 "open_blobs",
125 "Number of open blobs",
126 metrics.open_blobs.clone(),
127 );
128 registry.register(
129 "storage_reads",
130 "Total number of disk reads",
131 metrics.storage_reads.clone(),
132 );
133 registry.register(
134 "storage_read_bytes",
135 "Total amount of data read from disk",
136 metrics.storage_read_bytes.clone(),
137 );
138 registry.register(
139 "storage_writes",
140 "Total number of disk writes",
141 metrics.storage_writes.clone(),
142 );
143 registry.register(
144 "storage_write_bytes",
145 "Total amount of data written to disk",
146 metrics.storage_write_bytes.clone(),
147 );
148 metrics
149 }
150}
151
/// Configuration for the tokio-backed runtime.
#[derive(Clone)]
pub struct Config {
    /// Number of worker threads for the underlying multi-threaded tokio runtime.
    pub threads: usize,

    /// Forwarded to `Handle::init` for every spawned task; presumably controls
    /// whether task panics are caught instead of propagated (confirm in `Handle`).
    pub catch_panics: bool,

    /// Upper bound on how long a single `Stream::recv` may take before it fails
    /// with `Error::Timeout`.
    pub read_timeout: Duration,

    /// Upper bound on how long a single `Sink::send` may take before it fails
    /// with `Error::Timeout`.
    pub write_timeout: Duration,

    /// When `Some`, `TCP_NODELAY` is set to this value on every dialed and accepted
    /// connection; when `None`, the OS default is left untouched.
    pub tcp_nodelay: Option<bool>,

    /// Root directory for storage: each partition is a subdirectory and each blob
    /// a file named by the hex encoding of its name.
    pub storage_directory: PathBuf,

    /// Maximum buffer size applied (via `set_max_buf_size`) to each opened blob file.
    pub maximum_buffer_size: usize,
}
187
188impl Default for Config {
189 fn default() -> Self {
190 let rng = OsRng.next_u64();
192 let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
193
194 Self {
196 threads: 2,
197 catch_panics: true,
198 read_timeout: Duration::from_secs(60),
199 write_timeout: Duration::from_secs(30),
200 tcp_nodelay: None,
201 storage_directory,
202 maximum_buffer_size: 2 * 1024 * 1024, }
204 }
205}
206
207pub struct Executor {
209 cfg: Config,
210 registry: Mutex<Registry>,
211 metrics: Arc<Metrics>,
212 runtime: Runtime,
213 fs: AsyncMutex<()>,
214 signaler: Mutex<Signaler>,
215 signal: Signal,
216}
217
218impl Executor {
219 pub fn init(cfg: Config) -> (Runner, Context) {
221 let mut registry = Registry::default();
223 let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
224
225 let metrics = Arc::new(Metrics::init(runtime_registry));
227 let runtime = Builder::new_multi_thread()
228 .worker_threads(cfg.threads)
229 .enable_all()
230 .build()
231 .expect("failed to create Tokio runtime");
232 let (signaler, signal) = Signaler::new();
233 let executor = Arc::new(Self {
234 cfg,
235 registry: Mutex::new(registry),
236 metrics,
237 runtime,
238 fs: AsyncMutex::new(()),
239 signaler: Mutex::new(signaler),
240 signal,
241 });
242 (
243 Runner {
244 executor: executor.clone(),
245 },
246 Context {
247 label: String::new(),
248 spawned: false,
249 executor,
250 },
251 )
252 }
253
254 #[allow(clippy::should_implement_trait)]
257 pub fn default() -> (Runner, Context) {
258 Self::init(Config::default())
259 }
260}
261
262pub struct Runner {
264 executor: Arc<Executor>,
265}
266
267impl crate::Runner for Runner {
268 fn start<F>(self, f: F) -> F::Output
269 where
270 F: Future + Send + 'static,
271 F::Output: Send + 'static,
272 {
273 self.executor.runtime.block_on(f)
274 }
275}
276
277pub struct Context {
281 label: String,
282 spawned: bool,
283 executor: Arc<Executor>,
284}
285
286impl Clone for Context {
287 fn clone(&self) -> Self {
288 Self {
289 label: self.label.clone(),
290 spawned: false,
291 executor: self.executor.clone(),
292 }
293 }
294}
295
296impl crate::Spawner for Context {
297 fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
298 where
299 F: FnOnce(Self) -> Fut + Send + 'static,
300 Fut: Future<Output = T> + Send + 'static,
301 T: Send + 'static,
302 {
303 assert!(!self.spawned, "already spawned");
305
306 let work = Work {
308 label: self.label.clone(),
309 };
310 self.executor
311 .metrics
312 .tasks_spawned
313 .get_or_create(&work)
314 .inc();
315 let gauge = self
316 .executor
317 .metrics
318 .tasks_running
319 .get_or_create(&work)
320 .clone();
321
322 let catch_panics = self.executor.cfg.catch_panics;
324 let executor = self.executor.clone();
325 let future = f(self);
326 let (f, handle) = Handle::init(future, gauge, catch_panics);
327
328 executor.runtime.spawn(f);
330 handle
331 }
332
333 fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
334 where
335 F: Future<Output = T> + Send + 'static,
336 T: Send + 'static,
337 {
338 assert!(!self.spawned, "already spawned");
340 self.spawned = true;
341
342 let work = Work {
344 label: self.label.clone(),
345 };
346 self.executor
347 .metrics
348 .tasks_spawned
349 .get_or_create(&work)
350 .inc();
351 let gauge = self
352 .executor
353 .metrics
354 .tasks_running
355 .get_or_create(&work)
356 .clone();
357
358 let executor = self.executor.clone();
360 move |f: F| {
361 let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
362
363 executor.runtime.spawn(f);
365 handle
366 }
367 }
368
369 fn stop(&self, value: i32) {
370 self.executor.signaler.lock().unwrap().signal(value);
371 }
372
373 fn stopped(&self) -> Signal {
374 self.executor.signal.clone()
375 }
376}
377
378impl crate::Metrics for Context {
379 fn with_label(&self, label: &str) -> Self {
380 let label = {
381 let prefix = self.label.clone();
382 if prefix.is_empty() {
383 label.to_string()
384 } else {
385 format!("{}_{}", prefix, label)
386 }
387 };
388 assert!(
389 !label.starts_with(METRICS_PREFIX),
390 "using runtime label is not allowed"
391 );
392 Self {
393 label,
394 spawned: false,
395 executor: self.executor.clone(),
396 }
397 }
398
399 fn label(&self) -> String {
400 self.label.clone()
401 }
402
403 fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
404 let name = name.into();
405 let prefixed_name = {
406 let prefix = &self.label;
407 if prefix.is_empty() {
408 name
409 } else {
410 format!("{}_{}", *prefix, name)
411 }
412 };
413 self.executor
414 .registry
415 .lock()
416 .unwrap()
417 .register(prefixed_name, help, metric)
418 }
419
420 fn encode(&self) -> String {
421 let mut buffer = String::new();
422 encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
423 buffer
424 }
425}
426
427impl Clock for Context {
428 fn current(&self) -> SystemTime {
429 SystemTime::now()
430 }
431
432 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
433 tokio::time::sleep(duration)
434 }
435
436 fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
437 let now = SystemTime::now();
438 let duration_until_deadline = match deadline.duration_since(now) {
439 Ok(duration) => duration,
440 Err(_) => Duration::from_secs(0), };
442 let target_instant = tokio::time::Instant::now() + duration_until_deadline;
443 tokio::time::sleep_until(target_instant)
444 }
445}
446
447impl GClock for Context {
448 type Instant = SystemTime;
449
450 fn now(&self) -> Self::Instant {
451 self.current()
452 }
453}
454
455impl ReasonablyRealtime for Context {}
456
457impl crate::Network<Listener, Sink, Stream> for Context {
458 async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
459 TcpListener::bind(socket)
460 .await
461 .map_err(|_| Error::BindFailed)
462 .map(|listener| Listener {
463 context: self.clone(),
464 listener,
465 })
466 }
467
468 async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
469 let stream = TcpStream::connect(socket)
471 .await
472 .map_err(|_| Error::ConnectionFailed)?;
473 self.executor.metrics.outbound_connections.inc();
474
475 if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
477 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
478 warn!(?err, "failed to set TCP_NODELAY");
479 }
480 }
481
482 let context = self.clone();
484 let (stream, sink) = stream.into_split();
485 Ok((
486 Sink {
487 context: context.clone(),
488 sink,
489 },
490 Stream { context, stream },
491 ))
492 }
493}
494
495pub struct Listener {
497 context: Context,
498 listener: TcpListener,
499}
500
501impl crate::Listener<Sink, Stream> for Listener {
502 async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
503 let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
505 self.context.executor.metrics.inbound_connections.inc();
506
507 if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
509 if let Err(err) = stream.set_nodelay(tcp_nodelay) {
510 warn!(?err, "failed to set TCP_NODELAY");
511 }
512 }
513
514 let context = self.context.clone();
516 let (stream, sink) = stream.into_split();
517 Ok((
518 addr,
519 Sink {
520 context: context.clone(),
521 sink,
522 },
523 Stream { context, stream },
524 ))
525 }
526}
527
528impl axum::serve::Listener for Listener {
529 type Io = TcpStream;
530 type Addr = SocketAddr;
531
532 async fn accept(&mut self) -> (Self::Io, Self::Addr) {
533 let (stream, addr) = self.listener.accept().await.unwrap();
534 (stream, addr)
535 }
536
537 fn local_addr(&self) -> io::Result<Self::Addr> {
538 self.listener.local_addr()
539 }
540}
541
542pub struct Sink {
544 context: Context,
545 sink: OwnedWriteHalf,
546}
547
548impl crate::Sink for Sink {
549 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
550 let len = msg.len();
551 timeout(
552 self.context.executor.cfg.write_timeout,
553 self.sink.write_all(msg),
554 )
555 .await
556 .map_err(|_| Error::Timeout)?
557 .map_err(|_| Error::SendFailed)?;
558 self.context
559 .executor
560 .metrics
561 .outbound_bandwidth
562 .inc_by(len as u64);
563 Ok(())
564 }
565}
566
567pub struct Stream {
569 context: Context,
570 stream: OwnedReadHalf,
571}
572
573impl crate::Stream for Stream {
574 async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
575 timeout(
577 self.context.executor.cfg.read_timeout,
578 self.stream.read_exact(buf),
579 )
580 .await
581 .map_err(|_| Error::Timeout)?
582 .map_err(|_| Error::RecvFailed)?;
583
584 self.context
586 .executor
587 .metrics
588 .inbound_bandwidth
589 .inc_by(buf.len() as u64);
590
591 Ok(())
592 }
593}
594
595impl RngCore for Context {
596 fn next_u32(&mut self) -> u32 {
597 OsRng.next_u32()
598 }
599
600 fn next_u64(&mut self) -> u64 {
601 OsRng.next_u64()
602 }
603
604 fn fill_bytes(&mut self, dest: &mut [u8]) {
605 OsRng.fill_bytes(dest);
606 }
607
608 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
609 OsRng.try_fill_bytes(dest)
610 }
611}
612
613impl CryptoRng for Context {}
614
615pub struct Blob {
617 metrics: Arc<Metrics>,
618
619 partition: String,
620 name: Vec<u8>,
621
622 file: Arc<AsyncMutex<(fs::File, u64)>>,
629}
630
631impl Blob {
632 fn new(
633 metrics: Arc<Metrics>,
634 partition: String,
635 name: &[u8],
636 file: fs::File,
637 len: u64,
638 ) -> Self {
639 metrics.open_blobs.inc();
640 Self {
641 metrics,
642 partition,
643 name: name.into(),
644 file: Arc::new(AsyncMutex::new((file, len))),
645 }
646 }
647}
648
649impl Clone for Blob {
650 fn clone(&self) -> Self {
651 self.metrics.open_blobs.inc();
653 Self {
654 metrics: self.metrics.clone(),
655 partition: self.partition.clone(),
656 name: self.name.clone(),
657 file: self.file.clone(),
658 }
659 }
660}
661
662impl crate::Storage<Blob> for Context {
663 async fn open(&self, partition: &str, name: &[u8]) -> Result<Blob, Error> {
664 let _guard = self.executor.fs.lock().await;
666
667 let path = self
669 .executor
670 .cfg
671 .storage_directory
672 .join(partition)
673 .join(hex(name));
674 let parent = match path.parent() {
675 Some(parent) => parent,
676 None => return Err(Error::PartitionCreationFailed(partition.into())),
677 };
678
679 fs::create_dir_all(parent)
681 .await
682 .map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
683
684 let mut file = fs::OpenOptions::new()
686 .read(true)
687 .write(true)
688 .create(true)
689 .truncate(false)
690 .open(&path)
691 .await
692 .map_err(|_| Error::BlobOpenFailed(partition.into(), hex(name)))?;
693
694 file.set_max_buf_size(self.executor.cfg.maximum_buffer_size);
696
697 let len = file.metadata().await.map_err(|_| Error::ReadFailed)?.len();
699
700 Ok(Blob::new(
702 self.executor.metrics.clone(),
703 partition.into(),
704 name,
705 file,
706 len,
707 ))
708 }
709
710 async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
711 let _guard = self.executor.fs.lock().await;
713
714 let path = self.executor.cfg.storage_directory.join(partition);
716 if let Some(name) = name {
717 let blob_path = path.join(hex(name));
718 fs::remove_file(blob_path)
719 .await
720 .map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
721 } else {
722 fs::remove_dir_all(path)
723 .await
724 .map_err(|_| Error::PartitionMissing(partition.into()))?;
725 }
726 Ok(())
727 }
728
729 async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
730 let _guard = self.executor.fs.lock().await;
732
733 let path = self.executor.cfg.storage_directory.join(partition);
735 let mut entries = fs::read_dir(path)
736 .await
737 .map_err(|_| Error::PartitionMissing(partition.into()))?;
738 let mut blobs = Vec::new();
739 while let Some(entry) = entries.next_entry().await.map_err(|_| Error::ReadFailed)? {
740 let file_type = entry.file_type().await.map_err(|_| Error::ReadFailed)?;
741 if !file_type.is_file() {
742 return Err(Error::PartitionCorrupt(partition.into()));
743 }
744 if let Some(name) = entry.file_name().to_str() {
745 let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
746 blobs.push(name);
747 }
748 }
749 Ok(blobs)
750 }
751}
752
753impl crate::Blob for Blob {
754 async fn len(&self) -> Result<u64, Error> {
755 let (_, len) = *self.file.lock().await;
756 Ok(len)
757 }
758
759 async fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<(), Error> {
760 let mut file = self.file.lock().await;
762 if offset + buf.len() as u64 > file.1 {
763 return Err(Error::BlobInsufficientLength);
764 }
765
766 file.0
768 .seek(SeekFrom::Start(offset))
769 .await
770 .map_err(|_| Error::ReadFailed)?;
771 file.0
772 .read_exact(buf)
773 .await
774 .map_err(|_| Error::ReadFailed)?;
775 self.metrics.storage_reads.inc();
776 self.metrics.storage_read_bytes.inc_by(buf.len() as u64);
777 Ok(())
778 }
779
780 async fn write_at(&self, buf: &[u8], offset: u64) -> Result<(), Error> {
781 let mut file = self.file.lock().await;
783 file.0
784 .seek(SeekFrom::Start(offset))
785 .await
786 .map_err(|_| Error::WriteFailed)?;
787 file.0
788 .write_all(buf)
789 .await
790 .map_err(|_| Error::WriteFailed)?;
791
792 let max_len = offset + buf.len() as u64;
794 if max_len > file.1 {
795 file.1 = max_len;
796 }
797 self.metrics.storage_writes.inc();
798 self.metrics.storage_write_bytes.inc_by(buf.len() as u64);
799 Ok(())
800 }
801
802 async fn truncate(&self, len: u64) -> Result<(), Error> {
803 let mut file = self.file.lock().await;
805 file.0
806 .set_len(len)
807 .await
808 .map_err(|_| Error::BlobTruncateFailed(self.partition.clone(), hex(&self.name)))?;
809
810 file.1 = len;
812 Ok(())
813 }
814
815 async fn sync(&self) -> Result<(), Error> {
816 let file = self.file.lock().await;
817 file.0
818 .sync_all()
819 .await
820 .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))
821 }
822
823 async fn close(self) -> Result<(), Error> {
824 let mut file = self.file.lock().await;
825 file.0
826 .sync_all()
827 .await
828 .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))?;
829 file.0
830 .shutdown()
831 .await
832 .map_err(|_| Error::BlobCloseFailed(self.partition.clone(), hex(&self.name)))
833 }
834}
835
836impl Drop for Blob {
837 fn drop(&mut self) {
838 self.metrics.open_blobs.dec();
839 }
840}