#![cfg(not(feature = "shuttle"))]
use salsa::Database;
use crate::setup::{Knobs, KnobsDatabase};
use crate::sync::thread;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)]
struct CycleValue(u32);
const MIN: CycleValue = CycleValue(0);
const MAX: CycleValue = CycleValue(3);
#[salsa::tracked(cycle_initial=initial)]
fn query_a(db: &dyn KnobsDatabase) -> CycleValue {
query_b(db)
}
#[salsa::tracked(cycle_initial=initial)]
fn query_b(db: &dyn KnobsDatabase) -> CycleValue {
let c_value = query_c(db);
CycleValue(c_value.0 + 1).min(MAX)
}
#[salsa::tracked(cycle_initial=initial)]
fn query_c(db: &dyn KnobsDatabase) -> CycleValue {
let d_value = query_d(db);
let e_value = query_e(db);
let b_value = query_b(db);
let a_value = query_a(db);
CycleValue(d_value.0.max(e_value.0).max(b_value.0).max(a_value.0))
}
#[salsa::tracked(cycle_initial=initial)]
fn query_d(db: &dyn KnobsDatabase) -> CycleValue {
query_c(db)
}
#[salsa::tracked(cycle_initial=initial)]
fn query_e(db: &dyn KnobsDatabase) -> CycleValue {
query_c(db)
}
#[salsa::tracked]
fn query_f(db: &dyn KnobsDatabase) -> CycleValue {
let c = query_c(db);
query_h(db);
c
}
#[salsa::tracked]
fn query_h(db: &dyn KnobsDatabase) {
_ = db;
}
fn initial(db: &dyn KnobsDatabase, _id: salsa::Id) -> CycleValue {
db.signal(1);
db.wait_for(6);
MIN
}
#[test]
fn multi_threaded_cycle_completes_despite_cancellation() {
let db = Knobs::default();
let db_t1 = db.clone();
let db_t2 = db.clone();
let db_t3 = db.clone();
let db_t4 = db.clone();
let db_t5 = db.clone();
let db_signaler = db;
let token_t1 = db_t1.cancellation_token();
let token_t2 = db_t2.cancellation_token();
let token_t3 = db_t3.cancellation_token();
let token_t5 = db_t5.cancellation_token();
let t1 = thread::spawn(move || query_a(&db_t1));
db_signaler.wait_for(1);
db_signaler.signal_on_will_block(2);
let t2 = thread::spawn(move || query_b(&db_t2));
db_signaler.wait_for(2);
db_signaler.signal_on_will_block(3);
let t3 = thread::spawn(move || query_d(&db_t3));
db_signaler.wait_for(3);
db_signaler.signal_on_will_block(4);
let t4 = thread::spawn(move || query_e(&db_t4));
db_signaler.wait_for(4);
db_signaler.signal_on_will_block(5);
let t5 = thread::spawn(move || query_f(&db_t5));
db_signaler.wait_for(5);
token_t1.cancel();
token_t2.cancel();
token_t3.cancel();
token_t5.cancel();
db_signaler.signal(6);
let r_t1 = t1.join().unwrap();
let r_t2 = t2.join().unwrap();
let r_t3 = t3.join().unwrap();
let r_t4 = t4.join().unwrap();
let r_t5 = t5.join().unwrap_err();
assert_eq!(r_t1, MAX, "t1 should get MAX");
assert_eq!(r_t2, MAX, "t2 should get MAX");
assert_eq!(r_t3, MAX, "t3 should get MAX");
assert_eq!(r_t4, MAX, "t4 should get MAX");
assert!(
matches!(
*r_t5.downcast::<salsa::Cancelled>().unwrap(),
salsa::Cancelled::Local
),
"t5 should be cancelled as its blocked on the cycle, not participating in it and calling an uncomputed query after"
);
}