use std::collections::HashMap;
use std::hash::Hasher;
use crate::builder::TaskHandle;
use crate::runner::{DagRunner, PassThroughHasher};
/// Installs a TRACE-level `tracing_subscriber` formatter that writes through the
/// test harness's captured output.
///
/// Guarded by a process-wide `Once`, so calling it from every test is cheap and
/// only the first call does any work.
#[cfg(feature = "tracing")]
fn init_tracing() {
    static SUBSCRIBER_INSTALLED: std::sync::Once = std::sync::Once::new();
    SUBSCRIBER_INSTALLED.call_once(|| {
        // `try_init` errors if a global subscriber already exists; that case is
        // deliberately ignored so tests never fail because of logging setup.
        let _ = tracing_subscriber::fmt()
            .with_test_writer()
            .with_max_level(tracing::Level::TRACE)
            .try_init();
    });
}
/// No-op stand-in used when the `tracing` feature is disabled, so every test can
/// call `init_tracing()` unconditionally.
#[cfg(not(feature = "tracing"))]
fn init_tracing() {}
/// Minimal input-free task used throughout these tests.
struct TestTask {
    /// The value this task produces when run.
    value: i32,
}
// Task implementation for `TestTask`: no inputs, output is the stored `value`.
// `#[crate::task]` is what lets a `TestTask` be passed to `DagRunner::add_task`;
// plain `//` comments are used here so nothing extra reaches the proc macro.
#[crate::task]
impl TestTask {
    // Returns the configured `value` unchanged.
    async fn run(&self) -> i32 {
        self.value
    }
}
/// A freshly constructed `DagRunner` has no nodes, edges or dependents, and
/// registering a single task adds exactly one node.
#[test]
fn test_dag_runner_new() {
    let mut runner = DagRunner::new();
    assert_eq!(
        (runner.nodes.len(), runner.edges.len(), runner.dependents.len()),
        (0, 0, 0)
    );

    runner.add_task(TestTask { value: 42 });
    assert_eq!(runner.nodes.len(), 1);
}
/// `DagRunner::default()` must yield an empty runner, just like `new()`.
#[test]
fn test_dag_runner_default() {
    let runner = DagRunner::default();
    let sizes = (runner.nodes.len(), runner.edges.len(), runner.dependents.len());
    assert_eq!(sizes, (0, 0, 0));
}
/// Fetching a task's output through a handle whose type parameter does not match
/// the task's real output type (`String` vs. `i32`) must panic.
///
/// Building and running the DAG happens outside the panic check. An unrelated
/// panic, for example in `run` or its `unwrap`s, therefore fails the test
/// instead of satisfying a blanket `#[should_panic]`.
#[tokio::test]
async fn test_get_wrong_type() {
    init_tracing();
    let mut dag = DagRunner::new();
    let handle = dag.add_task(TestTask { value: 42 });
    let mut output = dag
        .run(|fut| async move { tokio::spawn(fut).await.unwrap() })
        .await
        .unwrap();

    // Forge a handle to the i32-producing task that claims a `String` output.
    let fake_handle: TaskHandle<String> = TaskHandle {
        id: handle.id,
        _phantom: std::marker::PhantomData,
    };

    // Only the mistyped lookup itself is expected to panic.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        let _result = output.get(fake_handle);
    }));
    assert!(
        outcome.is_err(),
        "get() with a mistyped handle should panic"
    );
}
#[tokio::test]
#[should_panic]
async fn test_invalid_node_id_in_get() {
init_tracing();
let dag = DagRunner::new();
let invalid_handle: TaskHandle<i32> = TaskHandle {
id: crate::builder::NodeId(999),
_phantom: std::marker::PhantomData,
};
let mut output = dag
.run(|fut| async move { tokio::spawn(fut).await.unwrap() })
.await
.unwrap();
let _result = output.get(invalid_handle);
}
/// Two tasks depend on the same source. When one of them panics and the other
/// succeeds, the whole run must report an error.
#[tokio::test]
async fn test_task_panic_in_multi_task_layer() {
    init_tracing();

    // Root task feeding both branches.
    struct Source;
    #[crate::task]
    impl Source {
        async fn run(&self) -> i32 {
            42
        }
    }

    // Branch that always panics.
    struct PanicTask;
    #[crate::task]
    impl PanicTask {
        async fn run(&self, _input: &i32) -> i32 {
            panic!("This task panics!");
        }
    }

    // Branch that completes normally.
    struct GoodTask;
    #[crate::task]
    impl GoodTask {
        async fn run(&self, input: &i32) -> i32 {
            input * 2
        }
    }

    let mut dag = DagRunner::new();
    let root = dag.add_task(Source);
    let _failing_branch = dag.add_task(PanicTask).depends_on(&root);
    let _healthy_branch = dag.add_task(GoodTask).depends_on(&root);

    let outcome = dag
        .run(|fut| async move { tokio::spawn(fut).await.unwrap() })
        .await;
    assert!(matches!(outcome, Err(_)));
}
/// `PassThroughHasher` must work as the hasher of a `HashMap` keyed by `u32`.
///
/// Checks that inserted entries can be found again, absent keys are not found,
/// and re-inserting a key replaces its value rather than adding a duplicate.
#[test]
fn test_passthrough_hasher() {
    let mut passthrough_hashmap = HashMap::with_hasher(PassThroughHasher::default());
    assert!(passthrough_hashmap.insert(64u32, "test string").is_none());
    assert!(passthrough_hashmap.insert(7u32, "other").is_none());

    assert_eq!(passthrough_hashmap.get(&64u32), Some(&"test string"));
    assert_eq!(passthrough_hashmap.get(&7u32), Some(&"other"));
    assert_eq!(passthrough_hashmap.get(&0u32), None);

    // Replacing an existing key must return the old value and keep len unchanged.
    assert_eq!(
        passthrough_hashmap.insert(64u32, "replaced"),
        Some("test string")
    );
    assert_eq!(passthrough_hashmap.get(&64u32), Some(&"replaced"));
    assert_eq!(passthrough_hashmap.len(), 2);
}
/// Feeding raw bytes through the generic `Hasher::write` must panic. As this
/// test expects, `PassThroughHasher` is meant only for `u32` input.
#[test]
#[should_panic]
fn test_passthrough_hasher_only_u32() {
    let raw_bytes = [0u8; 4];
    let mut hasher = PassThroughHasher::default();
    hasher.write(&raw_bytes);
}