shuttle 0.9.0

A library for testing concurrent Rust code
Documentation
use shuttle::{future, scheduler::RandomScheduler, thread, Runner};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use test_log::test;
use tracing::field::{Field, Visit};
use tracing::{trace, Event, Id, Metadata, Subscriber};

type SigCounterMap = Arc<Mutex<HashMap<u64, usize>>>;
#[derive(Clone)]
struct SignatureSubscriber {
    signatures: SigCounterMap,
    static_create_locations: SigCounterMap,
}

impl SignatureSubscriber {
    pub fn new() -> Self {
        Self {
            signatures: Arc::new(Mutex::new(HashMap::new())),
            static_create_locations: Arc::new(Mutex::new(HashMap::new())),
        }
    }
}

impl Subscriber for SignatureSubscriber {
    fn enabled(&self, _metadata: &Metadata<'_>) -> bool {
        true
    }

    fn new_span(&self, _span: &tracing::span::Attributes<'_>) -> Id {
        Id::from_u64(1)
    }

    fn record(&self, _span: &Id, _values: &tracing::span::Record<'_>) {}

    fn record_follows_from(&self, _span: &Id, _follows: &Id) {}

    fn event(&self, event: &Event<'_>) {
        let metadata = event.metadata();
        if metadata.target() == "shuttle::runtime::task" && metadata.level() == &tracing::Level::DEBUG {
            struct SignatureVisitor {
                task_id: Option<String>,
                signature: Option<u64>,
                static_create_location: Option<u64>,
            }
            impl Visit for SignatureVisitor {
                fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) {
                    if field.name() == "task_id" {
                        self.task_id = Some(format!("{value:?}"));
                    }
                }
                fn record_u64(&mut self, field: &Field, value: u64) {
                    if field.name() == "signature" {
                        self.signature = Some(value);
                    }
                    if field.name() == "static_create_location" {
                        self.static_create_location = Some(value);
                    }
                }
            }
            let mut visitor = SignatureVisitor {
                task_id: None,
                signature: None,
                static_create_location: None,
            };
            event.record(&mut visitor);

            if let Some(sig) = visitor.signature {
                self.signatures
                    .lock()
                    .unwrap()
                    .entry(sig)
                    .and_modify(|counter| *counter += 1)
                    .or_insert(1);
            }
            if let Some(loc) = visitor.static_create_location {
                self.static_create_locations
                    .lock()
                    .unwrap()
                    .entry(loc)
                    .and_modify(|counter| *counter += 1)
                    .or_insert(1);
            }
        }
    }

    fn enter(&self, _span: &Id) {}
    fn exit(&self, _span: &Id) {}
}

pub fn check_any_n_same_signatures(signatures: &SigCounterMap, expected_count: usize) {
    let signatures = signatures.lock().unwrap();
    trace!("Total signatures captured: {}", signatures.len());

    trace!("Signature counts: {:?}", signatures);

    let worker_signatures: Vec<u64> = signatures
        .iter()
        .filter(|(_, count)| **count == expected_count)
        .map(|(sig, _)| *sig)
        .collect();

    assert_eq!(
        worker_signatures.len(),
        1,
        "Should have exactly one signature appearing {expected_count} times"
    );
    trace!(
        "{} tasks have the same signature: {}",
        expected_count,
        worker_signatures[0]
    );
}

pub fn check_exactly_n_different_signatures(signatures: &SigCounterMap, expected_count: usize) {
    let signatures = signatures.lock().unwrap();
    trace!("Total signatures captured: {}", signatures.len());

    trace!("Signature counts: {:?}", signatures);

    let unique_signatures: Vec<u64> = signatures.keys().cloned().collect();
    assert_eq!(
        unique_signatures.len(),
        expected_count,
        "Should have {expected_count} different signatures"
    );
    trace!(
        "All {} tasks have different signatures: {:?}",
        expected_count,
        unique_signatures
    );
}

pub fn run_test_n_iterations_with_subscriber<F>(test_fn: F, iterations: usize) -> (SigCounterMap, SigCounterMap)
where
    F: Fn() + Send + Sync + 'static,
{
    let subscriber = SignatureSubscriber::new();
    let signatures = Arc::clone(&subscriber.signatures);
    let static_create_locations = Arc::clone(&subscriber.static_create_locations);
    let _guard = tracing::subscriber::set_default(subscriber);

    let scheduler = RandomScheduler::new(iterations);
    let runner = Runner::new(scheduler, Default::default());
    runner.run(test_fn);

    (signatures, static_create_locations)
}

#[test]
fn sync_tasks_created_at_same_src_location_have_different_signatures() {
    fn worker_function() {}

    let (signatures, static_create_locations) = run_test_n_iterations_with_subscriber(
        || {
            let mut handles = Vec::new();
            for _ in 0..10 {
                handles.push(thread::spawn(worker_function));
            }
            for handle in handles {
                handle.join().unwrap();
            }
        },
        1,
    );

    check_any_n_same_signatures(&static_create_locations, 10);

    check_exactly_n_different_signatures(&signatures, 11);
}

#[test]
fn async_tasks_created_at_same_src_location_have_different_signatures() {
    async fn async_worker_function() {}

    let (signatures, static_create_locations) = run_test_n_iterations_with_subscriber(
        || {
            let mut handles = Vec::new();
            for _ in 0..10 {
                handles.push(future::spawn(async_worker_function()));
            }
            future::block_on(async {
                for handle in handles {
                    handle.await.unwrap();
                }
            });
        },
        1,
    );

    check_any_n_same_signatures(&static_create_locations, 10);
    check_exactly_n_different_signatures(&signatures, 11);
}

#[test]
fn sync_tasks_created_in_different_src_locations_have_different_signatures() {
    fn worker_function() {}

    let (signatures, _) = run_test_n_iterations_with_subscriber(
        || {
            let handle1 = thread::spawn(worker_function);
            let handle2 = thread::spawn(worker_function);
            handle1.join().unwrap();
            handle2.join().unwrap();
        },
        1,
    );

    check_exactly_n_different_signatures(&signatures, 3);
}

#[test]
fn async_tasks_created_in_different_src_locations_have_different_signatures() {
    async fn async_worker_function() {}

    let (signatures, _) = run_test_n_iterations_with_subscriber(
        || {
            let handle1 = future::spawn(async_worker_function());
            let handle2 = future::spawn(async_worker_function());
            future::block_on(async {
                handle1.await.unwrap();
                handle2.await.unwrap();
            });
        },
        1,
    );

    check_exactly_n_different_signatures(&signatures, 3);
}

#[test]
fn task_signatures_consistent_across_shuttle_iterations() {
    fn worker_with_nested_spawn() {
        let handle = thread::spawn(|| {});
        handle.join().unwrap();
    }

    let (signatures, _) = run_test_n_iterations_with_subscriber(
        || {
            let handle1 = thread::spawn(worker_with_nested_spawn);
            let handle2 = thread::spawn(worker_with_nested_spawn);
            handle1.join().unwrap();
            handle2.join().unwrap();
        },
        100,
    );

    check_exactly_n_different_signatures(&signatures, 5);
}