use shuttle::{
check_dfs, check_random,
current::{get_label_for_task, me, set_label_for_task, set_name_for_task, ChildLabelFn, TaskName},
future, thread,
};
use std::collections::HashSet;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use test_log::test;
use tracing::field::{Field, Visit};
use tracing::span::{Attributes, Record};
use tracing::{Event, Id, Metadata, Subscriber};
/// Task label carrying a numeric identifier.
///
/// The tests in this file label the root task `Ident(1)` and have every spawned child
/// overwrite its own `Ident` with `2 * parent + index`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Ident(usize);
/// Task label recording the `Ident` value a spawned task found on itself when it started,
/// i.e. before it replaced that `Ident` with its own.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Parent(usize);
async fn spawn_tasks(counter: Arc<AtomicUsize>) -> HashSet<(usize, usize)> {
set_label_for_task(me(), Ident(1));
let handles = (0..2)
.map(|i| {
let counter2 = counter.clone();
future::spawn(async move {
let Ident(parent) = get_label_for_task(me()).unwrap();
set_label_for_task(me(), Parent(parent));
set_label_for_task(me(), Ident(2 * parent + i));
let handles = (0..2)
.map(|j| {
let counter3 = counter2.clone();
future::spawn(async move {
let Ident(parent) = get_label_for_task(me()).unwrap();
set_label_for_task(me(), Parent(parent));
set_label_for_task(me(), Ident(2 * parent + j));
counter3.fetch_add(1usize, Ordering::SeqCst);
let Parent(p) = get_label_for_task::<Parent>(me()).unwrap();
let Ident(c) = get_label_for_task::<Ident>(me()).unwrap();
(p, c)
})
})
.collect::<Vec<_>>();
let Parent(p) = get_label_for_task::<Parent>(me()).unwrap();
let Ident(c) = get_label_for_task::<Ident>(me()).unwrap();
(p, c, handles)
})
})
.collect::<Vec<_>>();
let mut values = HashSet::new();
for h in handles.into_iter() {
let (a, b, handles) = h.await.unwrap();
for h2 in handles.into_iter() {
let v2 = h2.await.unwrap();
assert!(values.insert(v2));
}
assert!(values.insert((a, b)));
}
let Ident(c) = get_label_for_task(me()).unwrap();
assert_eq!(c, 1);
assert_eq!(get_label_for_task::<Parent>(me()), None);
values
}
/// Runs `spawn_tasks` under 10,000 random schedules and checks the labels every task reported.
#[test]
fn task_inheritance() {
    check_random(
        || {
            let counter = Arc::new(AtomicUsize::new(0));
            let seen_values = future::block_on(spawn_tasks(Arc::clone(&counter)));

            // One increment per second-level task: two first-level tasks with two children each.
            let expected_leaves = 2usize.pow(2);
            assert_eq!(counter.load(Ordering::SeqCst), expected_leaves);

            // (Parent, Ident) pairs: the root is 1, children of `p` are `2 * p` and `2 * p + 1`.
            let expected_values = HashSet::from([(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]);
            assert_eq!(seen_values, expected_values);
        },
        10_000,
    );
}
fn spawn_threads(counter: Arc<AtomicUsize>) -> HashSet<(usize, usize)> {
set_label_for_task(me(), Ident(1));
let handles = (0..2)
.map(|i| {
let counter2 = counter.clone();
thread::spawn(move || {
let Ident(parent) = get_label_for_task(me()).unwrap();
set_label_for_task(me(), Parent(parent));
set_label_for_task(me(), Ident(2 * parent + i));
let handles: Vec<thread::JoinHandle<(usize, usize)>> = (0..2)
.map(|j| {
let counter3 = counter2.clone();
thread::spawn(move || {
let Ident(parent) = get_label_for_task(me()).unwrap();
set_label_for_task(me(), Parent(parent));
set_label_for_task(me(), Ident(2 * parent + j));
counter3.fetch_add(1usize, Ordering::SeqCst);
let Parent(p) = get_label_for_task::<Parent>(me()).unwrap();
let Ident(c) = get_label_for_task::<Ident>(me()).unwrap();
(p, c)
})
})
.collect::<Vec<_>>();
let Parent(p) = get_label_for_task::<Parent>(me()).unwrap();
let Ident(c) = get_label_for_task::<Ident>(me()).unwrap();
(p, c, handles)
})
})
.collect::<Vec<_>>();
let mut values = HashSet::new();
for h in handles.into_iter() {
let (a, b, handles) = h.join().unwrap();
for h2 in handles.into_iter() {
let (c, d) = h2.join().unwrap();
assert!(values.insert((c, d)));
}
assert!(values.insert((a, b)));
}
let Ident(c) = get_label_for_task(me()).unwrap();
assert_eq!(c, 1);
assert_eq!(get_label_for_task::<Parent>(me()), None);
values
}
/// Runs `spawn_threads` under 10,000 random schedules and checks the labels every thread reported.
#[test]
fn thread_inheritance() {
    check_random(
        || {
            let counter = Arc::new(AtomicUsize::new(0));
            let seen_values = spawn_threads(Arc::clone(&counter));

            // One increment per second-level thread: two first-level threads with two children each.
            let expected_leaves = 2usize.pow(2);
            assert_eq!(counter.load(Ordering::SeqCst), expected_leaves);

            // (Parent, Ident) pairs: the root is 1, children of `p` are `2 * p` and `2 * p + 1`.
            let expected_values = HashSet::from([(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]);
            assert_eq!(seen_values, expected_values);
        },
        10_000,
    );
}
/// Checks that a child overwriting its parent's label does not change the child's own label.
///
/// The parent labels itself `Ident(0)` and spawns a child; the child sets the *parent's*
/// label to `Ident(1)` and returns whatever `Ident` it finds on itself. The test expects the
/// parent to observe `Ident(1)` and the child to still carry `Ident(0)`.
#[test]
fn label_modify() {
    check_dfs(
        || {
            let parent_id = me();
            set_label_for_task(parent_id, Ident(0));

            let child = thread::spawn(move || {
                set_label_for_task(parent_id, Ident(1));
                assert_ne!(me(), parent_id);
                get_label_for_task::<Ident>(me()).unwrap()
            });

            let child_label = child.join().unwrap();
            let parent_label = get_label_for_task::<Ident>(me()).unwrap();
            assert_eq!(parent_label, Ident(1));
            assert_eq!(child_label, Ident(0));
        },
        // NOTE(review): presumably `None` means no cap on explored schedules — confirm
        // against the `check_dfs` signature.
        None,
    );
}
/// Spawns three child tasks, one after another, each of which ends up named "Child".
///
/// When `set_name_before_spawn` is true, the spawning task installs (before every spawn) a
/// `ChildLabelFn` that inserts `TaskName::from("Child")` into the labels it is handed.
/// Otherwise each child calls `set_name_for_task` on itself as its first action. Every child
/// then yields once and finishes.
///
/// # Panics
///
/// Panics if awaiting a child yields an error.
async fn label_fn_inner(set_name_before_spawn: bool) {
    // Each child is awaited to completion before the next one is spawned.
    for _ in 0..3 {
        if set_name_before_spawn {
            set_label_for_task(
                me(),
                ChildLabelFn(Arc::new(|_task_id, labels| {
                    labels.insert(TaskName::from("Child"));
                })),
            );
        }

        let child = future::spawn(async move {
            if !set_name_before_spawn {
                set_name_for_task(me(), TaskName::from("Child"));
            }
            shuttle::future::yield_now().await;
        });
        child.await.unwrap();
    }
}
/// Runs `label_fn_inner(true)` for 10 random schedules with `RunnableSubscriber` installed.
///
/// The children are named through a `ChildLabelFn`, and the subscriber's assertion on the
/// `runnable` field is expected to hold for every event.
#[test]
fn test_tracing_with_label_fn() {
    // The subscriber stays installed for as long as the guard is alive.
    let _guard = tracing::subscriber::set_default(RunnableSubscriber);
    check_random(
        || {
            future::block_on(label_fn_inner(true));
        },
        10,
    );
}
/// Runs `label_fn_inner(false)` for one random schedule and expects the subscriber to panic.
///
/// Here each child only names itself after it has started running.
/// NOTE(review): presumably the subscriber therefore sees a runnable task that is not yet
/// called "Child" and its assertion fires — confirm against shuttle's default task naming.
#[test]
#[should_panic(expected = "assertion failed")]
fn test_tracing_without_label_fn() {
    // The subscriber stays installed for as long as the guard is alive.
    let _guard = tracing::subscriber::set_default(RunnableSubscriber);
    check_random(
        || {
            future::block_on(label_fn_inner(false));
        },
        1,
    );
}
/// Stateless `tracing` subscriber used by the tracing tests in this file.
///
/// Its `event` handler inspects the `runnable` field of events whose target contains
/// "shuttle" and ends with "::runtime::execution", and asserts on the names listed there.
struct RunnableSubscriber;
impl Subscriber for RunnableSubscriber {
    /// Enables every callsite so that all events reach `event`.
    fn enabled(&self, _metadata: &Metadata<'_>) -> bool {
        true
    }

    /// Spans are not tracked: every span is handed the same fixed id.
    fn new_span(&self, _span: &Attributes<'_>) -> Id {
        Id::from_u64(1)
    }

    fn record(&self, _span: &Id, _values: &Record<'_>) {}

    fn record_follows_from(&self, _span: &Id, _follows: &Id) {}

    fn enter(&self, _span: &Id) {}

    fn exit(&self, _span: &Id) {}

    /// Asserts that every entry of the `runnable` field is "main-thread(0)" or starts with
    /// "Child(".
    ///
    /// Only events whose target contains "shuttle" and ends with "::runtime::execution", and
    /// which declare a `runnable` field, are inspected; all other events are ignored.
    fn event(&self, event: &Event<'_>) {
        /// Field visitor that performs the name check on the `runnable` field.
        struct RunnableNames;

        impl Visit for RunnableNames {
            fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
                if field.name() != "runnable" {
                    return;
                }
                // Strip the list brackets from the debug rendering, then check each
                // comma-separated entry.
                let rendered = format!("{:?}", value).replace('[', "").replace(']', "");
                let v1 = rendered.split(',');
                assert!(v1
                    .map(|s| s.trim())
                    .all(|s| (s == "main-thread(0)") || s.starts_with("Child(")));
            }
        }

        let metadata = event.metadata();
        let target = metadata.target();
        if !target.contains("shuttle") || !target.ends_with("::runtime::execution") {
            return;
        }
        if !metadata.fields().iter().any(|f| f.name() == "runnable") {
            return;
        }

        event.record(&mut RunnableNames);
    }
}