1use std::{
15 future::Future,
16 net::SocketAddr,
17 time::{Duration, SystemTime},
18};
19use thiserror::Error;
20
21pub mod deterministic;
22pub mod mocks;
23cfg_if::cfg_if! {
24 if #[cfg(not(target_arch = "wasm32"))] {
25 pub mod tokio;
26 }
27}
28
29mod utils;
30pub use utils::{reschedule, Handle, Signal, Signaler};
31
32#[derive(Error, Debug, PartialEq)]
34pub enum Error {
35 #[error("exited")]
36 Exited,
37 #[error("closed")]
38 Closed,
39 #[error("timeout")]
40 Timeout,
41 #[error("bind failed")]
42 BindFailed,
43 #[error("connection failed")]
44 ConnectionFailed,
45 #[error("write failed")]
46 WriteFailed,
47 #[error("read failed")]
48 ReadFailed,
49 #[error("send failed")]
50 SendFailed,
51 #[error("recv failed")]
52 RecvFailed,
53 #[error("partition creation failed: {0}")]
54 PartitionCreationFailed(String),
55 #[error("partition missing: {0}")]
56 PartitionMissing(String),
57 #[error("partition corrupt: {0}")]
58 PartitionCorrupt(String),
59 #[error("blob open failed: {0}/{1}")]
60 BlobOpenFailed(String, String),
61 #[error("blob missing: {0}/{1}")]
62 BlobMissing(String, String),
63 #[error("blob truncate failed: {0}/{1}")]
64 BlobTruncateFailed(String, String),
65 #[error("blob sync failed: {0}/{1}")]
66 BlobSyncFailed(String, String),
67 #[error("blob close failed: {0}/{1}")]
68 BlobCloseFailed(String, String),
69 #[error("blob insufficient length")]
70 BlobInsufficientLength,
71 #[error("offset overflow")]
72 OffsetOverflow,
73}
74
/// Entry point of a runtime: drives a root future to completion.
pub trait Runner {
    /// Runs `f` as the root task and returns its output once it resolves.
    ///
    /// `start` returns as soon as the root future completes, even if tasks it
    /// spawned are still running. A panic raised by `f` is expected to propagate
    /// out of `start`.
    fn start<Fut: Future + Send + 'static>(self, f: Fut) -> Fut::Output
    where
        Fut::Output: Send + 'static;
}
84
85pub trait Spawner: Clone + Send + Sync + 'static {
88 fn spawn<F, T>(&self, label: &str, f: F) -> Handle<T>
98 where
99 F: Future<Output = T> + Send + 'static,
100 T: Send + 'static;
101
102 fn stop(&self, value: i32);
109
110 fn stopped(&self) -> Signal;
115}
116
/// Source of time and timers for a runtime.
pub trait Clock: Clone + Send + Sync + 'static {
    /// Returns the runtime's notion of the current time.
    ///
    /// This need not match the host clock: a simulated runtime may keep its own
    /// time (for example, starting at `UNIX_EPOCH`).
    fn current(&self) -> SystemTime;

    /// Returns a future that completes once at least `duration` has passed as
    /// measured by [`Clock::current`].
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + 'static;

    /// Returns a future that completes once [`Clock::current`] has reached `deadline`.
    fn sleep_until(&self, deadline: SystemTime) -> impl Future<Output = ()> + Send + 'static;
}
132
133pub trait Network<L, Si, St>: Clone + Send + Sync + 'static
136where
137 L: Listener<Si, St>,
138 Si: Sink,
139 St: Stream,
140{
141 fn bind(&self, socket: SocketAddr) -> impl Future<Output = Result<L, Error>> + Send;
143
144 fn dial(&self, socket: SocketAddr) -> impl Future<Output = Result<(Si, St), Error>> + Send;
146}
147
148pub trait Listener<Si, St>: Sync + Send + 'static
151where
152 Si: Sink,
153 St: Stream,
154{
155 fn accept(&mut self) -> impl Future<Output = Result<(SocketAddr, Si, St), Error>> + Send;
157}
158
159pub trait Sink: Sync + Send + 'static {
162 fn send(&mut self, msg: &[u8]) -> impl Future<Output = Result<(), Error>> + Send;
164}
165
166pub trait Stream: Sync + Send + 'static {
169 fn recv(&mut self, buf: &mut [u8]) -> impl Future<Output = Result<(), Error>> + Send;
172}
173
174pub trait Storage<B>: Clone + Send + Sync + 'static
182where
183 B: Blob,
184{
185 fn open(&self, partition: &str, name: &[u8]) -> impl Future<Output = Result<B, Error>> + Send;
190
191 fn remove(
195 &self,
196 partition: &str,
197 name: Option<&[u8]>,
198 ) -> impl Future<Output = Result<(), Error>> + Send;
199
200 fn scan(&self, partition: &str) -> impl Future<Output = Result<Vec<Vec<u8>>, Error>> + Send;
202}
203
204#[allow(clippy::len_without_is_empty)]
215pub trait Blob: Clone + Send + Sync + 'static {
216 fn len(&self) -> impl Future<Output = Result<u64, Error>> + Send;
218
219 fn read_at(
224 &self,
225 buf: &mut [u8],
226 offset: u64,
227 ) -> impl Future<Output = Result<(), Error>> + Send;
228
229 fn write_at(&self, buf: &[u8], offset: u64) -> impl Future<Output = Result<(), Error>> + Send;
231
232 fn truncate(&self, len: u64) -> impl Future<Output = Result<(), Error>> + Send;
234
235 fn sync(&self) -> impl Future<Output = Result<(), Error>> + Send;
237
238 fn close(self) -> impl Future<Output = Result<(), Error>> + Send;
240}
241
#[cfg(test)]
mod tests {
    // `reschedule` comes in through this glob (re-exported at the crate root).
    use super::*;
    use commonware_macros::select;
    use futures::{channel::mpsc, future::ready, join, SinkExt, StreamExt};
    use prometheus_client::encoding::text::encode;
    use prometheus_client::registry::Registry;
    use std::panic::{catch_unwind, AssertUnwindSafe};
    use std::sync::{Arc, Mutex};

    /// The root future's output (here an `Err`) is returned unchanged by `start`.
    fn test_error_future(runner: impl Runner) {
        async fn error_future() -> Result<&'static str, &'static str> {
            Err("An error occurred")
        }
        let result = runner.start(error_future());
        assert_eq!(result, Err("An error occurred"));
    }

    /// `sleep` advances the runtime clock by at least the requested duration.
    fn test_clock_sleep(runner: impl Runner, context: impl Spawner + Clock) {
        runner.start(async move {
            // Record the runtime's time before sleeping.
            let start = context.current();
            let sleep_duration = Duration::from_millis(10);
            context.sleep(sleep_duration).await;

            // Measure against the same (runtime) clock.
            let end = context.current();
            assert!(end.duration_since(start).unwrap() >= sleep_duration);
        });
    }

    /// `sleep_until` does not complete before the runtime clock reaches the deadline.
    fn test_clock_sleep_until(runner: impl Runner, context: impl Spawner + Clock) {
        runner.start(async move {
            let now = context.current();
            context.sleep_until(now + Duration::from_millis(100)).await;

            // Measure with the runtime clock, not `SystemTime::elapsed`: the latter reads
            // the host clock, which under a simulated clock starting at `UNIX_EPOCH` is
            // always decades ahead and would make this assertion pass vacuously.
            let elapsed = context.current().duration_since(now).unwrap();
            assert!(elapsed >= Duration::from_millis(100));
        });
    }

    /// `start` returns once the root future completes, even if spawned tasks never finish.
    fn test_root_finishes(runner: impl Runner, context: impl Spawner) {
        runner.start(async move {
            context.spawn("test", async move {
                loop {
                    reschedule().await;
                }
            });
        });
    }

    /// Awaiting the handle of an aborted task yields `Error::Closed`.
    fn test_spawn_abort(runner: impl Runner, context: impl Spawner) {
        runner.start(async move {
            let handle = context.spawn("test", async move {
                loop {
                    reschedule().await;
                }
            });
            handle.abort();
            assert_eq!(handle.await, Err(Error::Closed));
        });
    }

    /// A panic in the root future propagates out of `start`.
    fn test_panic_aborts_root(runner: impl Runner) {
        let result = catch_unwind(AssertUnwindSafe(|| {
            runner.start(async move {
                panic!("blah");
            });
        }));
        result.unwrap_err();
    }

    /// A panic in a spawned task surfaces as `Error::Exited` when its handle is awaited.
    /// Runtimes that instead propagate the panic are exercised with `#[should_panic]`.
    fn test_panic_aborts_spawn(runner: impl Runner, context: impl Spawner) {
        let result = runner.start(async move {
            let result = context.spawn("test", async move {
                panic!("blah");
            });
            assert_eq!(result.await, Err(Error::Exited));
            Result::<(), Error>::Ok(())
        });

        // The root future itself must complete successfully.
        result.unwrap();
    }

    /// `select!` takes the first ready branch in declaration order and skips pending ones.
    fn test_select(runner: impl Runner) {
        runner.start(async move {
            let output = Mutex::new(0);

            // Both branches are ready; the first one listed must win.
            select! {
                v1 = ready(1) => {
                    *output.lock().unwrap() = v1;
                },
                v2 = ready(2) => {
                    *output.lock().unwrap() = v2;
                },
            };
            assert_eq!(*output.lock().unwrap(), 1);

            // The first branch never completes, so the second must be taken.
            select! {
                v1 = std::future::pending::<i32>() => {
                    *output.lock().unwrap() = v1;
                },
                v2 = ready(2) => {
                    *output.lock().unwrap() = v2;
                },
            };
            assert_eq!(*output.lock().unwrap(), 2);
        });
    }

    /// `select!` inside loops: timeouts fire while the channel is empty, an earlier ready
    /// branch wins over a ready receiver, and queued messages are delivered in order.
    fn test_select_loop(runner: impl Runner, context: impl Clock) {
        runner.start(async move {
            // Nothing has been sent yet, so every iteration must hit the timeout.
            let (mut sender, mut receiver) = mpsc::unbounded();
            for _ in 0..2 {
                select! {
                    v = receiver.next() => {
                        panic!("unexpected value: {:?}", v);
                    },
                    _ = context.sleep(Duration::from_millis(100)) => {
                        continue;
                    },
                };
            }

            // Queue two messages.
            sender.send(0).await.unwrap();
            sender.send(1).await.unwrap();

            // The immediately-ready branch is listed first, so it wins over the ready receiver.
            select! {
                _ = async {} => {
                    // Deliberately leave the queued messages unread.
                },
                v = receiver.next() => {
                    panic!("unexpected value: {:?}", v);
                },
            };

            // The queued messages are still delivered, in order, before any timeout.
            for i in 0..2 {
                select! {
                    _ = context.sleep(Duration::from_millis(100)) => {
                        panic!("timeout");
                    },
                    v = receiver.next() => {
                        assert_eq!(v.unwrap(), i);
                    },
                };
            }
        });
    }

    /// End-to-end `Storage` round trip: open, write, sync, read, length, close, scan,
    /// reopen, remove the blob, then remove the whole partition.
    fn test_storage_operations<B>(runner: impl Runner, context: impl Spawner + Storage<B>)
    where
        B: Blob,
    {
        runner.start(async move {
            let partition = "test_partition";
            let name = b"test_blob";

            // Open a new blob.
            let blob = context
                .open(partition, name)
                .await
                .expect("Failed to open blob");

            // Write data to the blob.
            let data = b"Hello, Storage!";
            blob.write_at(data, 0)
                .await
                .expect("Failed to write to blob");

            // Sync the blob.
            blob.sync().await.expect("Failed to sync blob");

            // Read the data back.
            let mut buffer = vec![0u8; data.len()];
            blob.read_at(&mut buffer, 0)
                .await
                .expect("Failed to read from blob");
            assert_eq!(&buffer, data);

            // The length matches what was written.
            let length = blob.len().await.expect("Failed to get blob length");
            assert_eq!(length, data.len() as u64);

            blob.close().await.expect("Failed to close blob");

            // The blob shows up when scanning its partition.
            let blobs = context
                .scan(partition)
                .await
                .expect("Failed to scan partition");
            assert!(blobs.contains(&name.to_vec()));

            // Reopen and read a slice from the middle ("Hello, Storage!"[7..14]).
            let blob = context
                .open(partition, name)
                .await
                .expect("Failed to reopen blob");
            let mut buffer = vec![0u8; 7];
            blob.read_at(&mut buffer, 7)
                .await
                .expect("Failed to read data");
            assert_eq!(&buffer, b"Storage");

            blob.close().await.expect("Failed to close blob");

            // Removing the blob drops it from the scan.
            context
                .remove(partition, Some(name))
                .await
                .expect("Failed to remove blob");
            let blobs = context
                .scan(partition)
                .await
                .expect("Failed to scan partition");
            assert!(!blobs.contains(&name.to_vec()));

            // Removing the partition makes scanning it fail.
            context
                .remove(partition, None)
                .await
                .expect("Failed to remove partition");
            let result = context.scan(partition).await;
            assert!(matches!(result, Err(Error::PartitionMissing(_))));
        });
    }

    /// Appending, overwriting, truncating, and reading past the end of a blob.
    fn test_blob_read_write<B>(runner: impl Runner, context: impl Spawner + Storage<B>)
    where
        B: Blob,
    {
        runner.start(async move {
            let partition = "test_partition";
            let name = b"test_blob_rw";

            let blob = context
                .open(partition, name)
                .await
                .expect("Failed to open blob");

            // Two adjacent writes produce a 10-byte blob.
            let data1 = b"Hello";
            let data2 = b"World";
            blob.write_at(data1, 0)
                .await
                .expect("Failed to write data1");
            blob.write_at(data2, 5)
                .await
                .expect("Failed to write data2");

            let length = blob.len().await.expect("Failed to get blob length");
            assert_eq!(length, 10);

            let mut buffer = vec![0u8; 10];
            blob.read_at(&mut buffer, 0)
                .await
                .expect("Failed to read data");
            assert_eq!(&buffer[..5], data1);
            assert_eq!(&buffer[5..], data2);

            // Overwriting inside the blob leaves its length unchanged.
            let data3 = b"Store";
            blob.write_at(data3, 5)
                .await
                .expect("Failed to write data3");
            let length = blob.len().await.expect("Failed to get blob length");
            assert_eq!(length, 10);

            // Truncation shrinks the blob and keeps the prefix.
            blob.truncate(5).await.expect("Failed to truncate blob");
            let length = blob.len().await.expect("Failed to get blob length");
            assert_eq!(length, 5);
            let mut buffer = vec![0u8; 5];
            blob.read_at(&mut buffer, 0)
                .await
                .expect("Failed to read data");
            assert_eq!(&buffer[..5], data1);

            // Reading beyond the truncated length fails.
            let mut buffer = vec![0u8; 10];
            let result = blob.read_at(&mut buffer, 0).await;
            assert!(matches!(result, Err(Error::BlobInsufficientLength)));

            blob.close().await.expect("Failed to close blob");
        });
    }

    /// Blobs with the same name in different partitions are independent.
    fn test_many_partition_read_write<B>(runner: impl Runner, context: impl Spawner + Storage<B>)
    where
        B: Blob,
    {
        runner.start(async move {
            let partitions = ["partition1", "partition2", "partition3"];
            let name = b"test_blob_rw";

            // Give each partition's blob a different layout (gap of `additional` bytes).
            for (additional, partition) in partitions.iter().enumerate() {
                let blob = context
                    .open(partition, name)
                    .await
                    .expect("Failed to open blob");

                let data1 = b"Hello";
                let data2 = b"World";
                blob.write_at(data1, 0)
                    .await
                    .expect("Failed to write data1");
                blob.write_at(data2, 5 + additional as u64)
                    .await
                    .expect("Failed to write data2");

                blob.close().await.expect("Failed to close blob");
            }

            // Each blob still has the layout written to its own partition.
            for (additional, partition) in partitions.iter().enumerate() {
                let blob = context
                    .open(partition, name)
                    .await
                    .expect("Failed to open blob");

                let mut buffer = vec![0u8; 10 + additional];
                blob.read_at(&mut buffer, 0)
                    .await
                    .expect("Failed to read data");
                assert_eq!(&buffer[..5], b"Hello");
                assert_eq!(&buffer[5 + additional..], b"World");

                blob.close().await.expect("Failed to close blob");
            }
        });
    }

    /// Reads that extend past the end of a blob fail with `BlobInsufficientLength`.
    fn test_blob_read_past_length<B>(runner: impl Runner, context: impl Spawner + Storage<B>)
    where
        B: Blob,
    {
        runner.start(async move {
            let partition = "test_partition";
            let name = b"test_blob_rw";

            let blob = context
                .open(partition, name)
                .await
                .expect("Failed to open blob");

            // A freshly opened blob is empty.
            let mut buffer = vec![0u8; 10];
            let result = blob.read_at(&mut buffer, 0).await;
            assert!(matches!(result, Err(Error::BlobInsufficientLength)));

            // 15 bytes written, 20 requested.
            let data = b"Hello, Storage!";
            blob.write_at(data, 0)
                .await
                .expect("Failed to write to blob");

            let mut buffer = vec![0u8; 20];
            let result = blob.read_at(&mut buffer, 0).await;
            assert!(matches!(result, Err(Error::BlobInsufficientLength)));

            // Release the handle like every other storage test does.
            blob.close().await.expect("Failed to close blob");
        });
    }

    /// Cloned blob handles can be read concurrently from separate tasks.
    fn test_blob_clone_and_concurrent_read<B>(
        runner: impl Runner,
        context: impl Spawner + Storage<B>,
    ) where
        B: Blob,
    {
        runner.start(async move {
            let partition = "test_partition";
            let name = b"test_blob_rw";

            let blob = context
                .open(partition, name)
                .await
                .expect("Failed to open blob");

            let data = b"Hello, Storage!";
            blob.write_at(data, 0)
                .await
                .expect("Failed to write to blob");

            blob.sync().await.expect("Failed to sync blob");

            // Read through two clones from two spawned tasks.
            let check1 = context.spawn("test", {
                let blob = blob.clone();
                async move {
                    let mut buffer = vec![0u8; data.len()];
                    blob.read_at(&mut buffer, 0)
                        .await
                        .expect("Failed to read from blob");
                    assert_eq!(&buffer, data);
                }
            });
            let check2 = context.spawn("test", {
                let blob = blob.clone();
                async move {
                    let mut buffer = vec![0u8; data.len()];
                    blob.read_at(&mut buffer, 0)
                        .await
                        .expect("Failed to read from blob");
                    assert_eq!(&buffer, data);
                }
            });

            // Both readers must finish successfully.
            let result = join!(check1, check2);
            assert!(result.0.is_ok());
            assert!(result.1.is_ok());

            // The original handle is still usable.
            let mut buffer = vec![0u8; data.len()];
            blob.read_at(&mut buffer, 0)
                .await
                .expect("Failed to read from blob");
            assert_eq!(&buffer, data);

            let length = blob.len().await.expect("Failed to get blob length");
            assert_eq!(length, data.len() as u64);

            blob.close().await.expect("Failed to close blob");
        });
    }

    /// `stop` wakes every task waiting on `stopped`, delivering the stop value.
    fn test_shutdown(runner: impl Runner, context: impl Spawner + Clock) {
        let kill = 9;
        runner.start(async move {
            // A task that simply waits for the stop signal.
            let before = context.spawn("before", {
                let context = context.clone();
                async move {
                    let sig = context.stopped().await;
                    assert_eq!(sig.unwrap(), kill);
                }
            });

            // A task that keeps working (sleeping) until the stop signal arrives.
            let after = context.spawn("after", {
                let context = context.clone();
                async move {
                    let mut signal = context.stopped();
                    loop {
                        select! {
                            sig = &mut signal => {
                                assert_eq!(sig.unwrap(), kill);
                                break;
                            },
                            _ = context.sleep(Duration::from_millis(10)) => {
                            },
                        }
                    }
                }
            });

            // Let both tasks start waiting before stopping.
            context.sleep(Duration::from_millis(50)).await;

            context.stop(kill);

            let result = join!(before, after);
            assert!(result.0.is_ok());
            assert!(result.1.is_ok());
        });
    }

    #[test]
    fn test_deterministic_future() {
        let (runner, _, _) = deterministic::Executor::default();
        test_error_future(runner);
    }

    #[test]
    fn test_deterministic_clock_sleep() {
        let (executor, runtime, _) = deterministic::Executor::default();
        assert_eq!(runtime.current(), SystemTime::UNIX_EPOCH);
        test_clock_sleep(executor, runtime);
    }

    #[test]
    fn test_deterministic_clock_sleep_until() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_clock_sleep_until(executor, runtime);
    }

    #[test]
    fn test_deterministic_root_finishes() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_root_finishes(executor, runtime);
    }

    #[test]
    fn test_deterministic_spawn_abort() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_spawn_abort(executor, runtime);
    }

    #[test]
    fn test_deterministic_panic_aborts_root() {
        let (runner, _, _) = deterministic::Executor::default();
        test_panic_aborts_root(runner);
    }

    #[test]
    #[should_panic(expected = "blah")]
    fn test_deterministic_panic_aborts_spawn() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_panic_aborts_spawn(executor, runtime);
    }

    #[test]
    fn test_deterministic_select() {
        let (executor, _, _) = deterministic::Executor::default();
        test_select(executor);
    }

    #[test]
    fn test_deterministic_select_loop() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_select_loop(executor, runtime);
    }

    #[test]
    fn test_deterministic_storage_operations() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_storage_operations(executor, runtime);
    }

    #[test]
    fn test_deterministic_blob_read_write() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_blob_read_write(executor, runtime);
    }

    #[test]
    fn test_deterministic_many_partition_read_write() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_many_partition_read_write(executor, runtime);
    }

    #[test]
    fn test_deterministic_blob_read_past_length() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_blob_read_past_length(executor, runtime);
    }

    #[test]
    fn test_deterministic_blob_clone_and_concurrent_read() {
        // Keep a handle on the registry so the open-blob metric can be inspected afterwards.
        let cfg = deterministic::Config {
            registry: Arc::new(Mutex::new(Registry::default())),
            ..Default::default()
        };
        let (executor, runtime, _) = deterministic::Executor::init(cfg.clone());
        test_blob_clone_and_concurrent_read(executor, runtime);

        // Every blob handle (including clones) must have been released.
        let mut buffer = String::new();
        encode(&mut buffer, &cfg.registry.lock().unwrap()).unwrap();
        assert!(buffer.contains("open_blobs 0"));
    }

    #[test]
    fn test_deterministic_shutdown() {
        let (executor, runtime, _) = deterministic::Executor::default();
        test_shutdown(executor, runtime);
    }

    // The `tokio` runtime module only exists off wasm32, so its tests carry the same gate.

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_error_future() {
        let (runner, _) = tokio::Executor::default();
        test_error_future(runner);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_clock_sleep() {
        let (executor, runtime) = tokio::Executor::default();
        test_clock_sleep(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_clock_sleep_until() {
        let (executor, runtime) = tokio::Executor::default();
        test_clock_sleep_until(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_root_finishes() {
        let (executor, runtime) = tokio::Executor::default();
        test_root_finishes(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_spawn_abort() {
        let (executor, runtime) = tokio::Executor::default();
        test_spawn_abort(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_panic_aborts_root() {
        let (runner, _) = tokio::Executor::default();
        test_panic_aborts_root(runner);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_panic_aborts_spawn() {
        let (executor, runtime) = tokio::Executor::default();
        test_panic_aborts_spawn(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_select() {
        let (executor, _) = tokio::Executor::default();
        test_select(executor);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_select_loop() {
        let (executor, runtime) = tokio::Executor::default();
        test_select_loop(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_storage_operations() {
        let (executor, runtime) = tokio::Executor::default();
        test_storage_operations(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_blob_read_write() {
        let (executor, runtime) = tokio::Executor::default();
        test_blob_read_write(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_many_partition_read_write() {
        let (executor, runtime) = tokio::Executor::default();
        test_many_partition_read_write(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_blob_read_past_length() {
        let (executor, runtime) = tokio::Executor::default();
        test_blob_read_past_length(executor, runtime);
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_blob_clone_and_concurrent_read() {
        // Keep a handle on the registry so the open-blob metric can be inspected afterwards.
        let cfg = tokio::Config {
            registry: Arc::new(Mutex::new(Registry::default())),
            ..Default::default()
        };
        let (executor, runtime) = tokio::Executor::init(cfg.clone());
        test_blob_clone_and_concurrent_read(executor, runtime);

        // Every blob handle (including clones) must have been released.
        let mut buffer = String::new();
        encode(&mut buffer, &cfg.registry.lock().unwrap()).unwrap();
        assert!(buffer.contains("open_blobs 0"));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_tokio_shutdown() {
        let (executor, runtime) = tokio::Executor::default();
        test_shutdown(executor, runtime);
    }
}