use crate::types::{RegionId, TaskId, Time};
use std::collections::BTreeMap;
use std::fmt;
/// Identifies a single RRef by the region that owns it and an allocation index.
///
/// The derived `Ord` compares `owner_region` first and `alloc_index` second (field
/// declaration order), so ordered collections keyed by `RRefId` group entries by owning
/// region. Reordering the fields would change that ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RRefId {
/// Region encoded in the id. `RRefAccessOracle` uses it as the owner when the RRef was
/// never registered through `on_rref_create`.
pub owner_region: RegionId,
/// Distinguishes RRefs that share an owner (presumably allocation order within the
/// region — confirm against the allocator).
pub alloc_index: u32,
}
impl RRefId {
#[must_use]
pub const fn new(owner_region: RegionId, alloc_index: u32) -> Self {
Self {
owner_region,
alloc_index,
}
}
}
/// The specific rule an RRef access broke, with the regions/times needed to explain it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RRefAccessViolationKind {
/// A task spawned in `task_region` accessed an RRef owned by a different region.
CrossRegionAccess {
/// Region that owns the accessed RRef.
rref_region: RegionId,
/// Region the accessing task was spawned in.
task_region: RegionId,
},
/// An RRef was accessed at or after the close time of its owning region.
PostCloseAccess {
/// Owning region of the RRef, already closed.
region: RegionId,
/// Time the region was recorded as closed.
close_time: Time,
/// Time of the offending access (`access_time >= close_time`).
access_time: Time,
},
/// An access presented a witness for a region other than the RRef's owner.
WitnessMismatch {
/// Region that owns the accessed RRef.
rref_region: RegionId,
/// Region the presented witness belongs to.
witness_region: RegionId,
},
}
/// One recorded breach of the RRef access rules.
///
/// `PartialEq`/`Eq` are derived (consistent with [`RRefAccessViolationKind`]) so whole
/// violations can be compared, e.g. in tests asserting the exact recorded sequence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RRefAccessViolation {
    /// The RRef that was accessed.
    pub rref: RRefId,
    /// The task that performed the access.
    pub task: TaskId,
    /// When the access happened.
    pub time: Time,
    /// Which rule was violated, with the details needed to report it.
    pub kind: RRefAccessViolationKind,
}
impl fmt::Display for RRefAccessViolation {
    /// Renders a single-line, human-readable description of the violation.
    ///
    /// Cross-region and witness-mismatch messages report `self.time`; the post-close message
    /// reports the `access_time` carried inside the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let task = &self.task;
        let time = &self.time;

        match &self.kind {
            RRefAccessViolationKind::CrossRegionAccess {
                rref_region,
                task_region,
            } => write!(
                f,
                "Cross-region RRef access: task {task:?} (region {task_region:?}) \
                 accessed RRef owned by region {rref_region:?} at {time:?}"
            ),
            RRefAccessViolationKind::PostCloseAccess {
                region,
                close_time,
                access_time,
            } => write!(
                f,
                "Post-close RRef access: task {task:?} accessed RRef in \
                 closed region {region:?} (closed at {close_time:?}) at {access_time:?}"
            ),
            RRefAccessViolationKind::WitnessMismatch {
                rref_region,
                witness_region,
            } => write!(
                f,
                "Witness mismatch: task {task:?} used witness for region {witness_region:?} \
                 to access RRef in region {rref_region:?} at {time:?}"
            ),
        }
    }
}
impl std::error::Error for RRefAccessViolation {}
/// Runtime checker that records every RRef access breaking region-ownership rules.
///
/// Callers feed it events (`on_rref_create`, `on_task_spawn`, `on_rref_access*`,
/// `on_region_close`); violations accumulate and are queried via `check`/`all_violations`.
#[derive(Debug, Default)]
pub struct RRefAccessOracle {
/// Owning region registered for each RRef at creation.
rrefs: BTreeMap<RRefId, RegionId>,
/// Close time recorded for each closed region.
closed_regions: BTreeMap<RegionId, Time>,
/// Region each known task was spawned into.
task_regions: BTreeMap<TaskId, RegionId>,
/// Every violation detected so far, in detection order.
violations: Vec<RRefAccessViolation>,
}
impl RRefAccessOracle {
    /// Creates an oracle with no registered RRefs, tasks, closed regions, or violations.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `rref` as owned by `owner_region`.
    ///
    /// Re-registering the same id replaces its previously recorded owner.
    pub fn on_rref_create(&mut self, rref: RRefId, owner_region: RegionId) {
        self.rrefs.insert(rref, owner_region);
    }

    /// Records that `task` runs in `region`.
    ///
    /// Re-spawning the same task id replaces its previously recorded region.
    pub fn on_task_spawn(&mut self, task: TaskId, region: RegionId) {
        self.task_regions.insert(task, region);
    }

    /// Checks an access to `rref` by `task` at `time` and records any violations.
    ///
    /// Two independent checks run, so one access can record up to two violations (in this
    /// order):
    /// - [`RRefAccessViolationKind::CrossRegionAccess`] when `task` is known and was spawned
    ///   in a region other than the RRef's owner. Unknown tasks skip this check.
    /// - [`RRefAccessViolationKind::PostCloseAccess`] when the owner region was closed at or
    ///   before `time`.
    pub fn on_rref_access(&mut self, rref: RRefId, task: TaskId, time: Time) {
        let rref_region = self.owner_of(rref);

        if let Some(&task_region) = self.task_regions.get(&task) {
            if task_region != rref_region {
                self.violations.push(RRefAccessViolation {
                    rref,
                    task,
                    time,
                    kind: RRefAccessViolationKind::CrossRegionAccess {
                        rref_region,
                        task_region,
                    },
                });
            }
        }

        if let Some(&close_time) = self.closed_regions.get(&rref_region) {
            // Accessing exactly at the close instant already counts as post-close.
            if time >= close_time {
                self.violations.push(RRefAccessViolation {
                    rref,
                    task,
                    time,
                    kind: RRefAccessViolationKind::PostCloseAccess {
                        region: rref_region,
                        close_time,
                        access_time: time,
                    },
                });
            }
        }
    }

    /// Like [`Self::on_rref_access`], but first verifies the presented witness.
    ///
    /// Records a [`RRefAccessViolationKind::WitnessMismatch`] when `witness_region` differs
    /// from the RRef's owner, then performs all the regular access checks.
    pub fn on_rref_access_with_witness(
        &mut self,
        rref: RRefId,
        task: TaskId,
        witness_region: RegionId,
        time: Time,
    ) {
        let rref_region = self.owner_of(rref);
        if witness_region != rref_region {
            self.violations.push(RRefAccessViolation {
                rref,
                task,
                time,
                kind: RRefAccessViolationKind::WitnessMismatch {
                    rref_region,
                    witness_region,
                },
            });
        }
        self.on_rref_access(rref, task, time);
    }

    /// Records that `region` closed at `time`.
    ///
    /// If the region is reported closed more than once, the earliest close time is kept.
    /// Overwriting it with a later time would move the post-close boundary forward and hide
    /// accesses made after the region actually closed.
    pub fn on_region_close(&mut self, region: RegionId, time: Time) {
        let close_time = self.closed_regions.entry(region).or_insert(time);
        if time < *close_time {
            *close_time = time;
        }
    }

    /// Returns `Ok(())` when no violation has been recorded.
    ///
    /// # Errors
    ///
    /// Returns a clone of the first recorded violation; use [`Self::all_violations`] to see
    /// every one.
    pub fn check(&self) -> Result<(), RRefAccessViolation> {
        self.violations.first().map_or(Ok(()), |v| Err(v.clone()))
    }

    /// All violations recorded so far, in detection order.
    #[must_use]
    pub fn all_violations(&self) -> &[RRefAccessViolation] {
        &self.violations
    }

    /// Forgets every registered RRef, task, closed region, and recorded violation.
    pub fn reset(&mut self) {
        self.rrefs.clear();
        self.closed_regions.clear();
        self.task_regions.clear();
        self.violations.clear();
    }

    /// Number of distinct RRefs registered via [`Self::on_rref_create`].
    #[must_use]
    pub fn rref_count(&self) -> usize {
        self.rrefs.len()
    }

    /// Number of distinct tasks registered via [`Self::on_task_spawn`].
    #[must_use]
    pub fn task_count(&self) -> usize {
        self.task_regions.len()
    }

    /// Number of distinct regions reported closed.
    #[must_use]
    pub fn closed_region_count(&self) -> usize {
        self.closed_regions.len()
    }

    /// Number of violations recorded so far.
    #[must_use]
    pub fn violation_count(&self) -> usize {
        self.violations.len()
    }

    /// Resolves the owning region of `rref`.
    ///
    /// Uses the region registered by [`Self::on_rref_create`]. For RRefs that were never
    /// registered, it falls back to the region encoded in the id itself.
    fn owner_of(&self, rref: RRefId) -> RegionId {
        self.rrefs.get(&rref).copied().unwrap_or(rref.owner_region)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `RRefAccessOracle` and the violation types.
    //!
    //! Every assertion here relies only on the oracle's current access/close semantics, so the
    //! suite pins exact violation kinds and their detection order rather than just counts.

    use super::*;
    use crate::util::ArenaIndex;

    /// Region id built from arena slot `n` (generation 0).
    fn region(n: u32) -> RegionId {
        RegionId::from_arena(ArenaIndex::new(n, 0))
    }

    /// Task id built from arena slot `n` (generation 0).
    fn task(n: u32) -> TaskId {
        TaskId::from_arena(ArenaIndex::new(n, 0))
    }

    fn t(nanos: u64) -> Time {
        Time::from_nanos(nanos)
    }

    /// RRef id whose encoded owner is `region(region_n)`.
    fn rref(region_n: u32, alloc: u32) -> RRefId {
        RRefId::new(region(region_n), alloc)
    }

    /// The violation kinds recorded so far, in detection order.
    fn kinds(oracle: &RRefAccessOracle) -> Vec<RRefAccessViolationKind> {
        oracle
            .all_violations()
            .iter()
            .map(|v| v.kind.clone())
            .collect()
    }

    #[test]
    fn same_region_access_no_violation() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_task_spawn(tid, r);
        oracle.on_rref_access(rref(0, 0), tid, t(10));
        assert!(oracle.check().is_ok());
        assert_eq!(oracle.violation_count(), 0);
    }

    #[test]
    fn access_before_close_no_violation() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_task_spawn(tid, r);
        oracle.on_rref_access(rref(0, 0), tid, t(10));
        oracle.on_region_close(r, t(100));
        assert!(oracle.check().is_ok());
    }

    #[test]
    fn witness_matching_no_violation() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_task_spawn(tid, r);
        oracle.on_rref_access_with_witness(rref(0, 0), tid, r, t(10));
        assert!(oracle.check().is_ok());
    }

    #[test]
    fn multiple_rrefs_same_region_no_violation() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_rref_create(rref(0, 1), r);
        oracle.on_rref_create(rref(0, 2), r);
        oracle.on_task_spawn(tid, r);
        oracle.on_rref_access(rref(0, 0), tid, t(10));
        oracle.on_rref_access(rref(0, 1), tid, t(20));
        oracle.on_rref_access(rref(0, 2), tid, t(30));
        assert!(oracle.check().is_ok());
    }

    #[test]
    fn unknown_task_skips_cross_region_check() {
        let mut oracle = RRefAccessOracle::new();
        oracle.on_rref_create(rref(0, 0), region(0));
        // task(9) was never spawned, so its region is unknown.
        oracle.on_rref_access(rref(0, 0), task(9), t(1));
        assert!(oracle.check().is_ok());
    }

    #[test]
    fn closing_unrelated_region_no_violation() {
        let mut oracle = RRefAccessOracle::new();
        let r_a = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r_a);
        oracle.on_task_spawn(tid, r_a);
        oracle.on_region_close(region(1), t(5));
        oracle.on_rref_access(rref(0, 0), tid, t(10));
        assert!(oracle.check().is_ok());
    }

    #[test]
    fn cross_region_access_detected() {
        let mut oracle = RRefAccessOracle::new();
        let r_a = region(0);
        let r_b = region(1);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r_a);
        oracle.on_task_spawn(tid, r_b);
        oracle.on_rref_access(rref(0, 0), tid, t(10));
        assert_eq!(oracle.violation_count(), 1);
        let err = oracle.check().unwrap_err();
        assert_eq!(
            err.kind,
            RRefAccessViolationKind::CrossRegionAccess {
                rref_region: r_a,
                task_region: r_b,
            }
        );
    }

    #[test]
    fn unregistered_rref_uses_owner_region_from_id() {
        let mut oracle = RRefAccessOracle::new();
        let tid = task(1);
        oracle.on_task_spawn(tid, region(1));
        // rref(0, 0) was never registered, so its owner falls back to region(0).
        oracle.on_rref_access(rref(0, 0), tid, t(5));
        assert_eq!(
            kinds(&oracle),
            vec![RRefAccessViolationKind::CrossRegionAccess {
                rref_region: region(0),
                task_region: region(1),
            }]
        );
    }

    #[test]
    fn registered_owner_takes_precedence_over_id() {
        let mut oracle = RRefAccessOracle::new();
        let tid = task(1);
        // The registration says region(1) owns it, even though the id encodes region(0).
        oracle.on_rref_create(rref(0, 0), region(1));
        oracle.on_task_spawn(tid, region(1));
        oracle.on_rref_access(rref(0, 0), tid, t(5));
        assert!(oracle.check().is_ok());
    }

    #[test]
    fn post_close_access_detected() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_task_spawn(tid, r);
        oracle.on_region_close(r, t(50));
        oracle.on_rref_access(rref(0, 0), tid, t(100));
        assert_eq!(oracle.violation_count(), 1);
        let err = oracle.check().unwrap_err();
        assert_eq!(
            err.kind,
            RRefAccessViolationKind::PostCloseAccess {
                region: r,
                close_time: t(50),
                access_time: t(100),
            }
        );
    }

    #[test]
    fn access_at_close_time_detected() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_task_spawn(tid, r);
        oracle.on_region_close(r, t(50));
        oracle.on_rref_access(rref(0, 0), tid, t(50));
        let err = oracle.check().unwrap_err();
        assert_eq!(
            err.kind,
            RRefAccessViolationKind::PostCloseAccess {
                region: r,
                close_time: t(50),
                access_time: t(50),
            }
        );
    }

    #[test]
    fn witness_mismatch_detected() {
        let mut oracle = RRefAccessOracle::new();
        let r_a = region(0);
        let r_b = region(1);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r_a);
        oracle.on_task_spawn(tid, r_a);
        oracle.on_rref_access_with_witness(rref(0, 0), tid, r_b, t(10));
        // The task itself is in the owning region, so only the witness is wrong.
        let violations = oracle.all_violations();
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].rref, rref(0, 0));
        assert_eq!(violations[0].task, tid);
        assert_eq!(violations[0].time, t(10));
        assert_eq!(
            violations[0].kind,
            RRefAccessViolationKind::WitnessMismatch {
                rref_region: r_a,
                witness_region: r_b,
            }
        );
    }

    #[test]
    fn multiple_violations_all_recorded() {
        let mut oracle = RRefAccessOracle::new();
        let r_a = region(0);
        let r_b = region(1);
        let t1 = task(1);
        let t2 = task(2);
        oracle.on_rref_create(rref(0, 0), r_a);
        oracle.on_rref_create(rref(1, 0), r_b);
        oracle.on_task_spawn(t1, r_a);
        oracle.on_task_spawn(t2, r_b);
        oracle.on_rref_access(rref(0, 0), t2, t(10));
        oracle.on_region_close(r_b, t(20));
        oracle.on_rref_access(rref(1, 0), t1, t(30));
        // The second access breaks two rules at once; cross-region is recorded first.
        assert_eq!(
            kinds(&oracle),
            vec![
                RRefAccessViolationKind::CrossRegionAccess {
                    rref_region: r_a,
                    task_region: r_b,
                },
                RRefAccessViolationKind::CrossRegionAccess {
                    rref_region: r_b,
                    task_region: r_a,
                },
                RRefAccessViolationKind::PostCloseAccess {
                    region: r_b,
                    close_time: t(20),
                    access_time: t(30),
                },
            ]
        );
    }

    #[test]
    fn reset_clears_all_state() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        let tid = task(1);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_task_spawn(tid, r);
        oracle.on_region_close(r, t(10));
        oracle.on_rref_access(rref(0, 0), tid, t(20));
        assert!(oracle.check().is_err());
        oracle.reset();
        assert!(oracle.check().is_ok());
        assert_eq!(oracle.rref_count(), 0);
        assert_eq!(oracle.task_count(), 0);
        assert_eq!(oracle.closed_region_count(), 0);
        assert_eq!(oracle.violation_count(), 0);
    }

    #[test]
    fn stats_track_entities() {
        let mut oracle = RRefAccessOracle::new();
        let r = region(0);
        oracle.on_rref_create(rref(0, 0), r);
        oracle.on_rref_create(rref(0, 1), r);
        oracle.on_task_spawn(task(1), r);
        oracle.on_task_spawn(task(2), r);
        oracle.on_task_spawn(task(3), r);
        oracle.on_region_close(r, t(100));
        assert_eq!(oracle.rref_count(), 2);
        assert_eq!(oracle.task_count(), 3);
        assert_eq!(oracle.closed_region_count(), 1);
    }

    #[test]
    fn violation_display_formats() {
        let r_a = region(0);
        let r_b = region(1);
        let tid = task(1);
        let cases = [
            (
                RRefAccessViolation {
                    rref: rref(0, 0),
                    task: tid,
                    time: t(10),
                    kind: RRefAccessViolationKind::CrossRegionAccess {
                        rref_region: r_a,
                        task_region: r_b,
                    },
                },
                "Cross-region RRef access: ",
            ),
            (
                RRefAccessViolation {
                    rref: rref(0, 0),
                    task: tid,
                    time: t(100),
                    kind: RRefAccessViolationKind::PostCloseAccess {
                        region: r_a,
                        close_time: t(50),
                        access_time: t(100),
                    },
                },
                "Post-close RRef access: ",
            ),
            (
                RRefAccessViolation {
                    rref: rref(0, 0),
                    task: tid,
                    time: t(10),
                    kind: RRefAccessViolationKind::WitnessMismatch {
                        rref_region: r_a,
                        witness_region: r_b,
                    },
                },
                "Witness mismatch: ",
            ),
        ];
        for (violation, prefix) in &cases {
            let msg = format!("{violation}");
            assert!(
                msg.starts_with(prefix),
                "expected {prefix:?} prefix, got {msg:?}"
            );
            assert!(!msg.contains('\n'), "display should be a single line");
        }
    }

    #[test]
    fn rref_id_debug_clone_copy_eq_ord_hash() {
        use std::collections::HashSet;
        let id = rref(0, 5);
        // Both copies leave `id` usable, proving `Copy`.
        let id2 = id;
        let id3 = id;
        assert_eq!(id, id2);
        assert_eq!(id, id3);
        assert_ne!(id, rref(0, 6));
        assert!(id < rref(0, 6));
        let dbg = format!("{id:?}");
        assert!(dbg.contains("RRefId"));
        let mut set = HashSet::new();
        set.insert(id);
        assert!(set.contains(&id2));
    }

    #[test]
    fn rref_access_violation_kind_debug_clone_eq() {
        let r_a = region(0);
        let r_b = region(1);
        let k = RRefAccessViolationKind::CrossRegionAccess {
            rref_region: r_a,
            task_region: r_b,
        };
        let k2 = k.clone();
        assert_eq!(k, k2);
        assert_ne!(
            k,
            RRefAccessViolationKind::WitnessMismatch {
                rref_region: r_a,
                witness_region: r_b,
            }
        );
        let dbg = format!("{k:?}");
        assert!(dbg.contains("CrossRegionAccess"));
    }

    #[test]
    fn rref_access_violation_debug_clone() {
        let v = RRefAccessViolation {
            rref: rref(0, 1),
            task: task(2),
            time: t(100),
            kind: RRefAccessViolationKind::PostCloseAccess {
                region: region(0),
                close_time: t(50),
                access_time: t(100),
            },
        };
        let v2 = v.clone();
        assert_eq!(v.rref, v2.rref);
        assert_eq!(v.task, v2.task);
        assert_eq!(v.time, v2.time);
        assert_eq!(v.kind, v2.kind);
        let dbg = format!("{v:?}");
        assert!(dbg.contains("RRefAccessViolation"));
    }
}