1use crate::{utils::Signaler, Clock, Error, Handle, Signal, METRICS_PREFIX};
27use commonware_utils::{from_hex, hex};
28use governor::clock::{Clock as GClock, ReasonablyRealtime};
29use prometheus_client::{
30    encoding::{text::encode, EncodeLabelSet},
31    metrics::{counter::Counter, family::Family, gauge::Gauge},
32    registry::{Metric, Registry},
33};
34use rand::{rngs::OsRng, CryptoRng, RngCore};
35use std::{
36    env,
37    future::Future,
38    io::{self, SeekFrom},
39    net::SocketAddr,
40    path::PathBuf,
41    sync::{Arc, Mutex},
42    time::{Duration, SystemTime},
43};
44use tokio::{
45    fs,
46    io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
47    net::{tcp::OwnedReadHalf, tcp::OwnedWriteHalf, TcpListener, TcpStream},
48    runtime::{Builder, Runtime},
49    sync::Mutex as AsyncMutex,
50    time::timeout,
51};
52use tracing::warn;
53
54#[derive(Clone, Debug, Hash, PartialEq, Eq, EncodeLabelSet)]
55struct Work {
56    label: String,
57}
58
59#[derive(Debug)]
60struct Metrics {
61    tasks_spawned: Family<Work, Counter>,
62    tasks_running: Family<Work, Gauge>,
63
64    inbound_connections: Counter,
67    outbound_connections: Counter,
68    inbound_bandwidth: Counter,
69    outbound_bandwidth: Counter,
70
71    open_blobs: Gauge,
72    storage_reads: Counter,
73    storage_read_bytes: Counter,
74    storage_writes: Counter,
75    storage_write_bytes: Counter,
76}
77
78impl Metrics {
79    pub fn init(registry: &mut Registry) -> Self {
80        let metrics = Self {
81            tasks_spawned: Family::default(),
82            tasks_running: Family::default(),
83            inbound_connections: Counter::default(),
84            outbound_connections: Counter::default(),
85            inbound_bandwidth: Counter::default(),
86            outbound_bandwidth: Counter::default(),
87            open_blobs: Gauge::default(),
88            storage_reads: Counter::default(),
89            storage_read_bytes: Counter::default(),
90            storage_writes: Counter::default(),
91            storage_write_bytes: Counter::default(),
92        };
93        registry.register(
94            "tasks_spawned",
95            "Total number of tasks spawned",
96            metrics.tasks_spawned.clone(),
97        );
98        registry.register(
99            "tasks_running",
100            "Number of tasks currently running",
101            metrics.tasks_running.clone(),
102        );
103        registry.register(
104            "inbound_connections",
105            "Number of connections created by dialing us",
106            metrics.inbound_connections.clone(),
107        );
108        registry.register(
109            "outbound_connections",
110            "Number of connections created by dialing others",
111            metrics.outbound_connections.clone(),
112        );
113        registry.register(
114            "inbound_bandwidth",
115            "Bandwidth used by receiving data from others",
116            metrics.inbound_bandwidth.clone(),
117        );
118        registry.register(
119            "outbound_bandwidth",
120            "Bandwidth used by sending data to others",
121            metrics.outbound_bandwidth.clone(),
122        );
123        registry.register(
124            "open_blobs",
125            "Number of open blobs",
126            metrics.open_blobs.clone(),
127        );
128        registry.register(
129            "storage_reads",
130            "Total number of disk reads",
131            metrics.storage_reads.clone(),
132        );
133        registry.register(
134            "storage_read_bytes",
135            "Total amount of data read from disk",
136            metrics.storage_read_bytes.clone(),
137        );
138        registry.register(
139            "storage_writes",
140            "Total number of disk writes",
141            metrics.storage_writes.clone(),
142        );
143        registry.register(
144            "storage_write_bytes",
145            "Total amount of data written to disk",
146            metrics.storage_write_bytes.clone(),
147        );
148        metrics
149    }
150}
151
152#[derive(Clone)]
154pub struct Config {
155    pub threads: usize,
157
158    pub catch_panics: bool,
160
161    pub read_timeout: Duration,
163
164    pub write_timeout: Duration,
166
167    pub tcp_nodelay: Option<bool>,
178
179    pub storage_directory: PathBuf,
181
182    pub maximum_buffer_size: usize,
186}
187
188impl Default for Config {
189    fn default() -> Self {
190        let rng = OsRng.next_u64();
192        let storage_directory = env::temp_dir().join(format!("commonware_tokio_runtime_{}", rng));
193
194        Self {
196            threads: 2,
197            catch_panics: true,
198            read_timeout: Duration::from_secs(60),
199            write_timeout: Duration::from_secs(30),
200            tcp_nodelay: None,
201            storage_directory,
202            maximum_buffer_size: 2 * 1024 * 1024, }
204    }
205}
206
207pub struct Executor {
209    cfg: Config,
210    registry: Mutex<Registry>,
211    metrics: Arc<Metrics>,
212    runtime: Runtime,
213    fs: AsyncMutex<()>,
214    signaler: Mutex<Signaler>,
215    signal: Signal,
216}
217
218impl Executor {
219    pub fn init(cfg: Config) -> (Runner, Context) {
221        let mut registry = Registry::default();
223        let runtime_registry = registry.sub_registry_with_prefix(METRICS_PREFIX);
224
225        let metrics = Arc::new(Metrics::init(runtime_registry));
227        let runtime = Builder::new_multi_thread()
228            .worker_threads(cfg.threads)
229            .enable_all()
230            .build()
231            .expect("failed to create Tokio runtime");
232        let (signaler, signal) = Signaler::new();
233        let executor = Arc::new(Self {
234            cfg,
235            registry: Mutex::new(registry),
236            metrics,
237            runtime,
238            fs: AsyncMutex::new(()),
239            signaler: Mutex::new(signaler),
240            signal,
241        });
242        (
243            Runner {
244                executor: executor.clone(),
245            },
246            Context {
247                label: String::new(),
248                spawned: false,
249                executor,
250            },
251        )
252    }
253
254    #[allow(clippy::should_implement_trait)]
257    pub fn default() -> (Runner, Context) {
258        Self::init(Config::default())
259    }
260}
261
262pub struct Runner {
264    executor: Arc<Executor>,
265}
266
267impl crate::Runner for Runner {
268    fn start<F>(self, f: F) -> F::Output
269    where
270        F: Future + Send + 'static,
271        F::Output: Send + 'static,
272    {
273        self.executor.runtime.block_on(f)
274    }
275}
276
277pub struct Context {
281    label: String,
282    spawned: bool,
283    executor: Arc<Executor>,
284}
285
286impl Clone for Context {
287    fn clone(&self) -> Self {
288        Self {
289            label: self.label.clone(),
290            spawned: false,
291            executor: self.executor.clone(),
292        }
293    }
294}
295
296impl crate::Spawner for Context {
297    fn spawn<F, Fut, T>(self, f: F) -> Handle<T>
298    where
299        F: FnOnce(Self) -> Fut + Send + 'static,
300        Fut: Future<Output = T> + Send + 'static,
301        T: Send + 'static,
302    {
303        assert!(!self.spawned, "already spawned");
305
306        let work = Work {
308            label: self.label.clone(),
309        };
310        self.executor
311            .metrics
312            .tasks_spawned
313            .get_or_create(&work)
314            .inc();
315        let gauge = self
316            .executor
317            .metrics
318            .tasks_running
319            .get_or_create(&work)
320            .clone();
321
322        let catch_panics = self.executor.cfg.catch_panics;
324        let executor = self.executor.clone();
325        let future = f(self);
326        let (f, handle) = Handle::init(future, gauge, catch_panics);
327
328        executor.runtime.spawn(f);
330        handle
331    }
332
333    fn spawn_ref<F, T>(&mut self) -> impl FnOnce(F) -> Handle<T> + 'static
334    where
335        F: Future<Output = T> + Send + 'static,
336        T: Send + 'static,
337    {
338        assert!(!self.spawned, "already spawned");
340        self.spawned = true;
341
342        let work = Work {
344            label: self.label.clone(),
345        };
346        self.executor
347            .metrics
348            .tasks_spawned
349            .get_or_create(&work)
350            .inc();
351        let gauge = self
352            .executor
353            .metrics
354            .tasks_running
355            .get_or_create(&work)
356            .clone();
357
358        let executor = self.executor.clone();
360        move |f: F| {
361            let (f, handle) = Handle::init(f, gauge, executor.cfg.catch_panics);
362
363            executor.runtime.spawn(f);
365            handle
366        }
367    }
368
369    fn stop(&self, value: i32) {
370        self.executor.signaler.lock().unwrap().signal(value);
371    }
372
373    fn stopped(&self) -> Signal {
374        self.executor.signal.clone()
375    }
376}
377
378impl crate::Metrics for Context {
379    fn with_label(&self, label: &str) -> Self {
380        let label = {
381            let prefix = self.label.clone();
382            if prefix.is_empty() {
383                label.to_string()
384            } else {
385                format!("{}_{}", prefix, label)
386            }
387        };
388        assert!(
389            !label.starts_with(METRICS_PREFIX),
390            "using runtime label is not allowed"
391        );
392        Self {
393            label,
394            spawned: false,
395            executor: self.executor.clone(),
396        }
397    }
398
399    fn label(&self) -> String {
400        self.label.clone()
401    }
402
403    fn register<N: Into<String>, H: Into<String>>(&self, name: N, help: H, metric: impl Metric) {
404        let name = name.into();
405        let prefixed_name = {
406            let prefix = &self.label;
407            if prefix.is_empty() {
408                name
409            } else {
410                format!("{}_{}", *prefix, name)
411            }
412        };
413        self.executor
414            .registry
415            .lock()
416            .unwrap()
417            .register(prefixed_name, help, metric)
418    }
419
420    fn encode(&self) -> String {
421        let mut buffer = String::new();
422        encode(&mut buffer, &self.executor.registry.lock().unwrap()).expect("encoding failed");
423        buffer
424    }
425}
426
427impl Clock for Context {
428    fn current(&self) -> SystemTime {
429        SystemTime::now()
430    }
431
432    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static {
433        tokio::time::sleep(duration)
434    }
435
436    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static {
437        let now = SystemTime::now();
438        let duration_until_deadline = match deadline.duration_since(now) {
439            Ok(duration) => duration,
440            Err(_) => Duration::from_secs(0), };
442        let target_instant = tokio::time::Instant::now() + duration_until_deadline;
443        tokio::time::sleep_until(target_instant)
444    }
445}
446
447impl GClock for Context {
448    type Instant = SystemTime;
449
450    fn now(&self) -> Self::Instant {
451        self.current()
452    }
453}
454
455impl ReasonablyRealtime for Context {}
456
457impl crate::Network<Listener, Sink, Stream> for Context {
458    async fn bind(&self, socket: SocketAddr) -> Result<Listener, Error> {
459        TcpListener::bind(socket)
460            .await
461            .map_err(|_| Error::BindFailed)
462            .map(|listener| Listener {
463                context: self.clone(),
464                listener,
465            })
466    }
467
468    async fn dial(&self, socket: SocketAddr) -> Result<(Sink, Stream), Error> {
469        let stream = TcpStream::connect(socket)
471            .await
472            .map_err(|_| Error::ConnectionFailed)?;
473        self.executor.metrics.outbound_connections.inc();
474
475        if let Some(tcp_nodelay) = self.executor.cfg.tcp_nodelay {
477            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
478                warn!(?err, "failed to set TCP_NODELAY");
479            }
480        }
481
482        let context = self.clone();
484        let (stream, sink) = stream.into_split();
485        Ok((
486            Sink {
487                context: context.clone(),
488                sink,
489            },
490            Stream { context, stream },
491        ))
492    }
493}
494
495pub struct Listener {
497    context: Context,
498    listener: TcpListener,
499}
500
501impl crate::Listener<Sink, Stream> for Listener {
502    async fn accept(&mut self) -> Result<(SocketAddr, Sink, Stream), Error> {
503        let (stream, addr) = self.listener.accept().await.map_err(|_| Error::Closed)?;
505        self.context.executor.metrics.inbound_connections.inc();
506
507        if let Some(tcp_nodelay) = self.context.executor.cfg.tcp_nodelay {
509            if let Err(err) = stream.set_nodelay(tcp_nodelay) {
510                warn!(?err, "failed to set TCP_NODELAY");
511            }
512        }
513
514        let context = self.context.clone();
516        let (stream, sink) = stream.into_split();
517        Ok((
518            addr,
519            Sink {
520                context: context.clone(),
521                sink,
522            },
523            Stream { context, stream },
524        ))
525    }
526}
527
528impl axum::serve::Listener for Listener {
529    type Io = TcpStream;
530    type Addr = SocketAddr;
531
532    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
533        let (stream, addr) = self.listener.accept().await.unwrap();
534        (stream, addr)
535    }
536
537    fn local_addr(&self) -> io::Result<Self::Addr> {
538        self.listener.local_addr()
539    }
540}
541
542pub struct Sink {
544    context: Context,
545    sink: OwnedWriteHalf,
546}
547
548impl crate::Sink for Sink {
549    async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
550        let len = msg.len();
551        timeout(
552            self.context.executor.cfg.write_timeout,
553            self.sink.write_all(msg),
554        )
555        .await
556        .map_err(|_| Error::Timeout)?
557        .map_err(|_| Error::SendFailed)?;
558        self.context
559            .executor
560            .metrics
561            .outbound_bandwidth
562            .inc_by(len as u64);
563        Ok(())
564    }
565}
566
567pub struct Stream {
569    context: Context,
570    stream: OwnedReadHalf,
571}
572
573impl crate::Stream for Stream {
574    async fn recv(&mut self, buf: &mut [u8]) -> Result<(), Error> {
575        timeout(
577            self.context.executor.cfg.read_timeout,
578            self.stream.read_exact(buf),
579        )
580        .await
581        .map_err(|_| Error::Timeout)?
582        .map_err(|_| Error::RecvFailed)?;
583
584        self.context
586            .executor
587            .metrics
588            .inbound_bandwidth
589            .inc_by(buf.len() as u64);
590
591        Ok(())
592    }
593}
594
595impl RngCore for Context {
596    fn next_u32(&mut self) -> u32 {
597        OsRng.next_u32()
598    }
599
600    fn next_u64(&mut self) -> u64 {
601        OsRng.next_u64()
602    }
603
604    fn fill_bytes(&mut self, dest: &mut [u8]) {
605        OsRng.fill_bytes(dest);
606    }
607
608    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand::Error> {
609        OsRng.try_fill_bytes(dest)
610    }
611}
612
613impl CryptoRng for Context {}
614
615pub struct Blob {
617    metrics: Arc<Metrics>,
618
619    partition: String,
620    name: Vec<u8>,
621
622    file: Arc<AsyncMutex<(fs::File, u64)>>,
629}
630
631impl Blob {
632    fn new(
633        metrics: Arc<Metrics>,
634        partition: String,
635        name: &[u8],
636        file: fs::File,
637        len: u64,
638    ) -> Self {
639        metrics.open_blobs.inc();
640        Self {
641            metrics,
642            partition,
643            name: name.into(),
644            file: Arc::new(AsyncMutex::new((file, len))),
645        }
646    }
647}
648
649impl Clone for Blob {
650    fn clone(&self) -> Self {
651        self.metrics.open_blobs.inc();
653        Self {
654            metrics: self.metrics.clone(),
655            partition: self.partition.clone(),
656            name: self.name.clone(),
657            file: self.file.clone(),
658        }
659    }
660}
661
662impl crate::Storage<Blob> for Context {
663    async fn open(&self, partition: &str, name: &[u8]) -> Result<Blob, Error> {
664        let _guard = self.executor.fs.lock().await;
666
667        let path = self
669            .executor
670            .cfg
671            .storage_directory
672            .join(partition)
673            .join(hex(name));
674        let parent = match path.parent() {
675            Some(parent) => parent,
676            None => return Err(Error::PartitionCreationFailed(partition.into())),
677        };
678
679        fs::create_dir_all(parent)
681            .await
682            .map_err(|_| Error::PartitionCreationFailed(partition.into()))?;
683
684        let mut file = fs::OpenOptions::new()
686            .read(true)
687            .write(true)
688            .create(true)
689            .truncate(false)
690            .open(&path)
691            .await
692            .map_err(|_| Error::BlobOpenFailed(partition.into(), hex(name)))?;
693
694        file.set_max_buf_size(self.executor.cfg.maximum_buffer_size);
696
697        let len = file.metadata().await.map_err(|_| Error::ReadFailed)?.len();
699
700        Ok(Blob::new(
702            self.executor.metrics.clone(),
703            partition.into(),
704            name,
705            file,
706            len,
707        ))
708    }
709
710    async fn remove(&self, partition: &str, name: Option<&[u8]>) -> Result<(), Error> {
711        let _guard = self.executor.fs.lock().await;
713
714        let path = self.executor.cfg.storage_directory.join(partition);
716        if let Some(name) = name {
717            let blob_path = path.join(hex(name));
718            fs::remove_file(blob_path)
719                .await
720                .map_err(|_| Error::BlobMissing(partition.into(), hex(name)))?;
721        } else {
722            fs::remove_dir_all(path)
723                .await
724                .map_err(|_| Error::PartitionMissing(partition.into()))?;
725        }
726        Ok(())
727    }
728
729    async fn scan(&self, partition: &str) -> Result<Vec<Vec<u8>>, Error> {
730        let _guard = self.executor.fs.lock().await;
732
733        let path = self.executor.cfg.storage_directory.join(partition);
735        let mut entries = fs::read_dir(path)
736            .await
737            .map_err(|_| Error::PartitionMissing(partition.into()))?;
738        let mut blobs = Vec::new();
739        while let Some(entry) = entries.next_entry().await.map_err(|_| Error::ReadFailed)? {
740            let file_type = entry.file_type().await.map_err(|_| Error::ReadFailed)?;
741            if !file_type.is_file() {
742                return Err(Error::PartitionCorrupt(partition.into()));
743            }
744            if let Some(name) = entry.file_name().to_str() {
745                let name = from_hex(name).ok_or(Error::PartitionCorrupt(partition.into()))?;
746                blobs.push(name);
747            }
748        }
749        Ok(blobs)
750    }
751}
752
753impl crate::Blob for Blob {
754    async fn len(&self) -> Result<u64, Error> {
755        let (_, len) = *self.file.lock().await;
756        Ok(len)
757    }
758
759    async fn read_at(&self, buf: &mut [u8], offset: u64) -> Result<(), Error> {
760        let mut file = self.file.lock().await;
762        if offset + buf.len() as u64 > file.1 {
763            return Err(Error::BlobInsufficientLength);
764        }
765
766        file.0
768            .seek(SeekFrom::Start(offset))
769            .await
770            .map_err(|_| Error::ReadFailed)?;
771        file.0
772            .read_exact(buf)
773            .await
774            .map_err(|_| Error::ReadFailed)?;
775        self.metrics.storage_reads.inc();
776        self.metrics.storage_read_bytes.inc_by(buf.len() as u64);
777        Ok(())
778    }
779
780    async fn write_at(&self, buf: &[u8], offset: u64) -> Result<(), Error> {
781        let mut file = self.file.lock().await;
783        file.0
784            .seek(SeekFrom::Start(offset))
785            .await
786            .map_err(|_| Error::WriteFailed)?;
787        file.0
788            .write_all(buf)
789            .await
790            .map_err(|_| Error::WriteFailed)?;
791
792        let max_len = offset + buf.len() as u64;
794        if max_len > file.1 {
795            file.1 = max_len;
796        }
797        self.metrics.storage_writes.inc();
798        self.metrics.storage_write_bytes.inc_by(buf.len() as u64);
799        Ok(())
800    }
801
802    async fn truncate(&self, len: u64) -> Result<(), Error> {
803        let mut file = self.file.lock().await;
805        file.0
806            .set_len(len)
807            .await
808            .map_err(|_| Error::BlobTruncateFailed(self.partition.clone(), hex(&self.name)))?;
809
810        file.1 = len;
812        Ok(())
813    }
814
815    async fn sync(&self) -> Result<(), Error> {
816        let file = self.file.lock().await;
817        file.0
818            .sync_all()
819            .await
820            .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))
821    }
822
823    async fn close(self) -> Result<(), Error> {
824        let mut file = self.file.lock().await;
825        file.0
826            .sync_all()
827            .await
828            .map_err(|_| Error::BlobSyncFailed(self.partition.clone(), hex(&self.name)))?;
829        file.0
830            .shutdown()
831            .await
832            .map_err(|_| Error::BlobCloseFailed(self.partition.clone(), hex(&self.name)))
833    }
834}
835
836impl Drop for Blob {
837    fn drop(&mut self) {
838        self.metrics.open_blobs.dec();
839    }
840}