use crate::types::{RegionId, TaskId, Time};
use std::collections::{HashMap, HashSet};
use std::fmt;
#[derive(Debug, Clone)]
pub struct TaskLeakViolation {
pub region: RegionId,
pub leaked_tasks: Vec<TaskId>,
pub region_close_time: Time,
}
impl fmt::Display for TaskLeakViolation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Region {:?} closed at {:?} with {} leaked task(s): {:?}",
self.region,
self.region_close_time,
self.leaked_tasks.len(),
self.leaked_tasks
)
}
}
impl std::error::Error for TaskLeakViolation {}
#[derive(Debug, Default)]
pub struct TaskLeakOracle {
tasks_by_region: HashMap<RegionId, HashSet<TaskId>>,
completed_tasks: HashSet<TaskId>,
region_closes: HashMap<RegionId, Time>,
violations: Vec<TaskLeakViolation>,
}
impl TaskLeakOracle {
#[must_use]
pub fn new() -> Self {
Self::default()
}
fn record_leaked_tasks(
&mut self,
region: RegionId,
region_close_time: Time,
leaked_tasks: impl IntoIterator<Item = TaskId>,
) {
if let Some(existing) = self
.violations
.iter_mut()
.find(|violation| violation.region == region)
{
for task in leaked_tasks {
if !existing.leaked_tasks.contains(&task) {
existing.leaked_tasks.push(task);
}
}
existing.leaked_tasks.sort();
return;
}
let mut leaked_tasks: Vec<TaskId> = leaked_tasks.into_iter().collect();
leaked_tasks.sort();
if !leaked_tasks.is_empty() {
self.violations.push(TaskLeakViolation {
region,
leaked_tasks,
region_close_time,
});
}
}
pub fn on_spawn(&mut self, task: TaskId, region: RegionId, _time: Time) {
self.tasks_by_region.entry(region).or_default().insert(task);
if let Some(®ion_close_time) = self.region_closes.get(®ion) {
self.record_leaked_tasks(region, region_close_time, [task]);
}
}
pub fn on_complete(&mut self, task: TaskId, _time: Time) {
self.completed_tasks.insert(task);
}
pub fn on_region_close(&mut self, region: RegionId, time: Time) {
self.region_closes.insert(region, time);
let Some(tasks) = self.tasks_by_region.get(®ion) else {
return;
};
let mut leaked: Vec<TaskId> = tasks
.iter()
.copied()
.filter(|task| !self.completed_tasks.contains(task))
.collect();
leaked.sort();
self.record_leaked_tasks(region, time, leaked);
}
pub fn check(&self, _now: Time) -> Result<(), TaskLeakViolation> {
if let Some(violation) = self
.violations
.iter()
.min_by_key(|violation| (violation.region, violation.region_close_time))
{
return Err(violation.clone());
}
Ok(())
}
pub fn reset(&mut self) {
self.tasks_by_region.clear();
self.completed_tasks.clear();
self.region_closes.clear();
self.violations.clear();
}
#[must_use]
pub fn task_count(&self) -> usize {
self.tasks_by_region.values().map(HashSet::len).sum()
}
#[must_use]
pub fn completed_count(&self) -> usize {
self.completed_tasks.len()
}
#[must_use]
pub fn closed_region_count(&self) -> usize {
self.region_closes.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::util::ArenaIndex;
fn task(n: u32) -> TaskId {
TaskId::from_arena(ArenaIndex::new(n, 0))
}
fn region(n: u32) -> RegionId {
RegionId::from_arena(ArenaIndex::new(n, 0))
}
fn t(nanos: u64) -> Time {
Time::from_nanos(nanos)
}
fn init_test(name: &str) {
crate::test_utils::init_test_logging();
crate::test_phase!(name);
}
fn scrub_task_leak_display(output: &str) -> String {
output
.replace("RegionId(0:0)", "[REGION_ID]")
.replace("TaskId(1:0)", "[TASK_ID_1]")
.replace("TaskId(2:0)", "[TASK_ID_2]")
}
#[test]
fn no_tasks_passes() {
init_test("no_tasks_passes");
let oracle = TaskLeakOracle::new();
let ok = oracle.check(t(100)).is_ok();
crate::assert_with_log!(ok, "ok", true, ok);
crate::test_complete!("no_tasks_passes");
}
#[test]
fn all_tasks_complete_passes() {
init_test("all_tasks_complete_passes");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_spawn(task(2), region(0), t(20));
oracle.on_complete(task(1), t(50));
oracle.on_complete(task(2), t(60));
oracle.on_region_close(region(0), t(100));
let ok = oracle.check(t(100)).is_ok();
crate::assert_with_log!(ok, "ok", true, ok);
crate::test_complete!("all_tasks_complete_passes");
}
#[test]
fn leaked_task_fails() {
init_test("leaked_task_fails");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_spawn(task(2), region(0), t(20));
oracle.on_complete(task(1), t(50));
oracle.on_region_close(region(0), t(100));
let result = oracle.check(t(100));
let err = result.is_err();
crate::assert_with_log!(err, "err", true, err);
let violation = result.unwrap_err();
crate::assert_with_log!(
violation.region == region(0),
"region",
region(0),
violation.region
);
crate::assert_with_log!(
violation.leaked_tasks == vec![task(2)],
"leaked_tasks",
vec![task(2)],
violation.leaked_tasks
);
crate::assert_with_log!(
violation.region_close_time == t(100),
"close_time",
t(100),
violation.region_close_time
);
crate::test_complete!("leaked_task_fails");
}
#[test]
fn no_tasks_complete_all_leak() {
init_test("no_tasks_complete_all_leak");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_spawn(task(2), region(0), t(20));
oracle.on_region_close(region(0), t(100));
let result = oracle.check(t(100));
let err = result.is_err();
crate::assert_with_log!(err, "err", true, err);
let violation = result.unwrap_err();
let len = violation.leaked_tasks.len();
crate::assert_with_log!(len == 2, "leaked_tasks len", 2, len);
crate::test_complete!("no_tasks_complete_all_leak");
}
#[test]
fn multiple_regions_independent() {
init_test("multiple_regions_independent");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_complete(task(1), t(50));
oracle.on_region_close(region(0), t(100));
oracle.on_spawn(task(2), region(1), t(20));
oracle.on_region_close(region(1), t(100));
let result = oracle.check(t(100));
let err = result.is_err();
crate::assert_with_log!(err, "err", true, err);
let violation = result.unwrap_err();
crate::assert_with_log!(
violation.region == region(1),
"region",
region(1),
violation.region
);
crate::assert_with_log!(
violation.leaked_tasks == vec![task(2)],
"leaked_tasks",
vec![task(2)],
violation.leaked_tasks
);
crate::test_complete!("multiple_regions_independent");
}
#[test]
fn region_without_close_not_checked() {
init_test("region_without_close_not_checked");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
let ok = oracle.check(t(100)).is_ok();
crate::assert_with_log!(ok, "ok", true, ok);
crate::test_complete!("region_without_close_not_checked");
}
#[test]
fn reset_clears_state() {
init_test("reset_clears_state");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_region_close(region(0), t(100));
let err = oracle.check(t(100)).is_err();
crate::assert_with_log!(err, "err", true, err);
oracle.reset();
let ok = oracle.check(t(100)).is_ok();
crate::assert_with_log!(ok, "ok", true, ok);
let task_count = oracle.task_count();
crate::assert_with_log!(task_count == 0, "task_count", 0, task_count);
let completed_count = oracle.completed_count();
crate::assert_with_log!(completed_count == 0, "completed_count", 0, completed_count);
let closed_count = oracle.closed_region_count();
crate::assert_with_log!(closed_count == 0, "closed_count", 0, closed_count);
crate::test_complete!("reset_clears_state");
}
#[test]
fn violation_display() {
init_test("violation_display");
let violation = TaskLeakViolation {
region: region(0),
leaked_tasks: vec![task(1), task(2)],
region_close_time: t(100),
};
let s = violation.to_string();
let has_leaked = s.contains("leaked");
crate::assert_with_log!(has_leaked, "leaked text", true, has_leaked);
let has_two = s.contains('2');
crate::assert_with_log!(has_two, "contains 2", true, has_two);
crate::test_complete!("violation_display");
}
#[test]
fn violation_display_snapshot_scrubbed() {
let violation = TaskLeakViolation {
region: region(0),
leaked_tasks: vec![task(1), task(2)],
region_close_time: t(100),
};
insta::assert_snapshot!(
"task_leak_violation_display_scrubbed",
scrub_task_leak_display(&violation.to_string())
);
}
#[test]
fn task_in_multiple_regions_ok() {
init_test("task_in_multiple_regions_ok");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_spawn(task(1), region(1), t(20));
oracle.on_complete(task(1), t(50));
oracle.on_region_close(region(0), t(100));
oracle.on_region_close(region(1), t(100));
let ok = oracle.check(t(100)).is_ok();
crate::assert_with_log!(ok, "ok", true, ok);
crate::test_complete!("task_in_multiple_regions_ok");
}
#[test]
fn many_tasks_some_leaked() {
init_test("many_tasks_some_leaked");
let mut oracle = TaskLeakOracle::new();
for i in 1..=5 {
oracle.on_spawn(task(i), region(0), t(u64::from(i) * 10));
}
oracle.on_complete(task(1), t(60));
oracle.on_complete(task(3), t(70));
oracle.on_complete(task(5), t(80));
oracle.on_region_close(region(0), t(100));
let result = oracle.check(t(100));
let err = result.is_err();
crate::assert_with_log!(err, "err", true, err);
let violation = result.unwrap_err();
let len = violation.leaked_tasks.len();
crate::assert_with_log!(len == 2, "leaked_tasks len", 2, len);
let has_two = violation.leaked_tasks.contains(&task(2));
crate::assert_with_log!(has_two, "contains task2", true, has_two);
let has_four = violation.leaked_tasks.contains(&task(4));
crate::assert_with_log!(has_four, "contains task4", true, has_four);
crate::test_complete!("many_tasks_some_leaked");
}
#[test]
fn completion_after_close_still_violates() {
init_test("completion_after_close_still_violates");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_region_close(region(0), t(100));
oracle.on_complete(task(1), t(110));
let violation = oracle
.check(t(110))
.expect_err("task completion after close must not erase the violation");
crate::assert_with_log!(
violation.region == region(0),
"region",
region(0),
violation.region
);
crate::assert_with_log!(
violation.leaked_tasks == vec![task(1)],
"leaked_tasks",
vec![task(1)],
violation.leaked_tasks
);
crate::assert_with_log!(
violation.region_close_time == t(100),
"close_time",
t(100),
violation.region_close_time
);
crate::test_complete!("completion_after_close_still_violates");
}
#[test]
fn task_spawned_after_region_close_is_violation_even_if_it_completes_later() {
init_test("task_spawned_after_region_close_is_violation_even_if_it_completes_later");
let mut oracle = TaskLeakOracle::new();
oracle.on_region_close(region(0), t(100));
oracle.on_spawn(task(1), region(0), t(110));
oracle.on_complete(task(1), t(120));
let violation = oracle
.check(t(120))
.expect_err("task spawned after region close must be reported as a leak");
crate::assert_with_log!(
violation.region == region(0),
"region",
region(0),
violation.region
);
crate::assert_with_log!(
violation.leaked_tasks == vec![task(1)],
"leaked_tasks",
vec![task(1)],
violation.leaked_tasks
);
crate::assert_with_log!(
violation.region_close_time == t(100),
"close_time",
t(100),
violation.region_close_time
);
crate::test_complete!(
"task_spawned_after_region_close_is_violation_even_if_it_completes_later"
);
}
#[test]
fn post_close_spawns_merge_into_existing_region_violation() {
init_test("post_close_spawns_merge_into_existing_region_violation");
let mut oracle = TaskLeakOracle::new();
oracle.on_spawn(task(1), region(0), t(10));
oracle.on_region_close(region(0), t(100));
oracle.on_spawn(task(2), region(0), t(110));
oracle.on_complete(task(2), t(120));
let violation = oracle
.check(t(120))
.expect_err("post-close spawns should be merged into the region's leak record");
crate::assert_with_log!(
violation.region == region(0),
"region",
region(0),
violation.region
);
crate::assert_with_log!(
violation.leaked_tasks == vec![task(1), task(2)],
"leaked_tasks",
vec![task(1), task(2)],
violation.leaked_tasks
);
crate::assert_with_log!(
violation.region_close_time == t(100),
"close_time",
t(100),
violation.region_close_time
);
crate::test_complete!("post_close_spawns_merge_into_existing_region_violation");
}
#[test]
fn task_leak_violation_debug_clone() {
let v = TaskLeakViolation {
region: region(5),
leaked_tasks: vec![task(1), task(2)],
region_close_time: t(999),
};
let cloned = v.clone();
assert_eq!(cloned.region, v.region);
assert_eq!(cloned.leaked_tasks.len(), 2);
let dbg = format!("{v:?}");
assert!(dbg.contains("TaskLeakViolation"));
}
#[test]
fn task_leak_oracle_debug_default() {
let oracle = TaskLeakOracle::default();
let dbg = format!("{oracle:?}");
assert!(dbg.contains("TaskLeakOracle"));
let oracle2 = TaskLeakOracle::new();
let dbg2 = format!("{oracle2:?}");
assert_eq!(dbg, dbg2);
}
}