use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use super::events::HookEvent;
/// Deadline class of a [`HookEvent`]; selects which `*_CLASS_DEADLINE_MS`
/// constant `class_deadline` uses for it. `event_class` performs the mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventClass {
/// Store / delete / promote / link / consolidate / governance-decision /
/// archive / reflect / compaction events (`WRITE_CLASS_DEADLINE_MS`).
Write,
/// Recall and search events (`READ_CLASS_DEADLINE_MS`).
Read,
/// Index eviction (`INDEX_CLASS_DEADLINE_MS`).
Index,
/// Transcript store events (`TRANSCRIPT_CLASS_DEADLINE_MS`).
Transcript,
/// `PreRecallExpand` only; the tightest budget (`HOT_PATH_CLASS_DEADLINE_MS`).
HotPath,
}
/// Base deadline in milliseconds for [`EventClass::Write`], before the
/// test/debug timing multiplier applied in `class_deadline`.
pub const WRITE_CLASS_DEADLINE_MS: u64 = 5_000;
/// Base deadline in milliseconds for [`EventClass::Read`].
pub const READ_CLASS_DEADLINE_MS: u64 = 2_000;
/// Base deadline in milliseconds for [`EventClass::Index`].
pub const INDEX_CLASS_DEADLINE_MS: u64 = 1_000;
/// Base deadline in milliseconds for [`EventClass::Transcript`].
pub const TRANSCRIPT_CLASS_DEADLINE_MS: u64 = 5_000;
/// Base deadline in milliseconds for [`EventClass::HotPath`].
pub const HOT_PATH_CLASS_DEADLINE_MS: u64 = 50;
/// Classifies a hook event into the [`EventClass`] whose deadline applies to it.
///
/// The match is intentionally exhaustive with no wildcard arm: adding a new
/// `HookEvent` variant is a compile error until it is classified here.
#[must_use]
pub fn event_class(event: HookEvent) -> EventClass {
    match event {
        // Single-variant classes first.
        HookEvent::PreRecallExpand => EventClass::HotPath,
        HookEvent::OnIndexEviction => EventClass::Index,

        HookEvent::PreTranscriptStore | HookEvent::PostTranscriptStore => {
            EventClass::Transcript
        }

        HookEvent::PreRecall
        | HookEvent::PostRecall
        | HookEvent::PreSearch
        | HookEvent::PostSearch => EventClass::Read,

        // Everything else is a write-class event.
        HookEvent::PreStore
        | HookEvent::PostStore
        | HookEvent::PreDelete
        | HookEvent::PostDelete
        | HookEvent::PrePromote
        | HookEvent::PostPromote
        | HookEvent::PreLink
        | HookEvent::PostLink
        | HookEvent::PreConsolidate
        | HookEvent::PostConsolidate
        | HookEvent::PreGovernanceDecision
        | HookEvent::PostGovernanceDecision
        | HookEvent::PreArchive
        | HookEvent::PreReflect
        | HookEvent::PostReflect
        | HookEvent::PreCompaction
        | HookEvent::OnCompactionRollback => EventClass::Write,
    }
}
/// Returns the deadline for the given event class.
///
/// The base comes from the matching `*_CLASS_DEADLINE_MS` constant and is
/// scaled by `test_timing_budget_mult()` (always 1 in release builds). The
/// scaling uses saturating multiplication, so it can never overflow.
#[must_use]
pub fn class_deadline(class: EventClass) -> Duration {
    let base_millis: u64 = match class {
        EventClass::HotPath => HOT_PATH_CLASS_DEADLINE_MS,
        EventClass::Index => INDEX_CLASS_DEADLINE_MS,
        EventClass::Read => READ_CLASS_DEADLINE_MS,
        EventClass::Transcript => TRANSCRIPT_CLASS_DEADLINE_MS,
        EventClass::Write => WRITE_CLASS_DEADLINE_MS,
    };
    let scaled_millis = base_millis.saturating_mul(test_timing_budget_mult());
    Duration::from_millis(scaled_millis)
}
/// Deadline multiplier for test and debug builds.
///
/// Re-reads `AI_MEMORY_TEST_TIMING_BUDGET_MULT` on every call. The value is
/// used only if it parses as an integer in `1..=100`. A missing, unparseable,
/// or out-of-range value yields `1`, meaning no scaling.
#[cfg(any(test, debug_assertions))]
fn test_timing_budget_mult() -> u64 {
    match std::env::var("AI_MEMORY_TEST_TIMING_BUDGET_MULT") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(mult @ 1..=100) => mult,
            _ => 1,
        },
        Err(_) => 1,
    }
}

/// Release builds never scale deadlines; the env var is not consulted.
#[cfg(not(any(test, debug_assertions)))]
#[inline(always)]
fn test_timing_budget_mult() -> u64 {
    1
}
/// Convenience wrapper: classifies `event` with `event_class` and returns the
/// deadline `class_deadline` assigns to that class.
#[must_use]
pub fn class_deadline_for_event(event: HookEvent) -> Duration {
    let class = event_class(event);
    class_deadline(class)
}
/// Computes how many milliseconds the next hook in a chain may run.
///
/// The budget is the smaller of two values:
/// - the whole milliseconds remaining from `now` until `chain_deadline`;
/// - the hook's own `hook_timeout_ms`.
///
/// A remaining time larger than `u32::MAX` ms saturates to `u32::MAX`.
///
/// Returns `None` when the chain has no whole millisecond left: `now` is at or
/// past `chain_deadline`, or less than 1 ms remains. A sub-millisecond
/// remainder would truncate to a 0 ms budget. Depending on how the caller
/// interprets 0, that is either an immediate timeout or an unbounded wait.
/// Reporting it as expired keeps callers on the deadline-exceeded path instead.
#[must_use]
pub fn per_hook_budget_ms(
    chain_deadline: Instant,
    now: Instant,
    hook_timeout_ms: u32,
) -> Option<u32> {
    // `None` if `now` is later than the deadline; `Some(ZERO)` if equal.
    let remaining = chain_deadline.checked_duration_since(now)?;
    let remaining_ms = u32::try_from(remaining.as_millis()).unwrap_or(u32::MAX);
    // Covers both `now == chain_deadline` and a sub-millisecond remainder.
    if remaining_ms == 0 {
        return None;
    }
    Some(remaining_ms.min(hook_timeout_ms))
}
/// Process-wide tally of hook timeout violations.
static TIMEOUT_VIOLATIONS: AtomicU64 = AtomicU64::new(0);

/// Adds one to the timeout-violation tally.
///
/// `Relaxed` ordering suffices here. The counter is a standalone statistic and,
/// in this file, is not used to synchronise any other memory.
pub fn record_timeout_violation() {
    TIMEOUT_VIOLATIONS.fetch_add(1, Ordering::Relaxed);
}

/// Returns the current timeout-violation tally.
#[must_use]
pub fn timeout_violations_total() -> u64 {
    TIMEOUT_VIOLATIONS.load(Ordering::Relaxed)
}

/// Test-only: zeroes the timeout-violation tally.
#[cfg(test)]
pub fn reset_timeout_violations_for_test() {
    let _previous = TIMEOUT_VIOLATIONS.swap(0, Ordering::Relaxed);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Env var consulted by `test_timing_budget_mult`.
    const MULT_VAR: &str = "AI_MEMORY_TEST_TIMING_BUDGET_MULT";

    #[test]
    fn event_class_table_covers_all_25_variants() {
        let table = [
            (HookEvent::PreStore, EventClass::Write),
            (HookEvent::PostStore, EventClass::Write),
            (HookEvent::PreDelete, EventClass::Write),
            (HookEvent::PostDelete, EventClass::Write),
            (HookEvent::PrePromote, EventClass::Write),
            (HookEvent::PostPromote, EventClass::Write),
            (HookEvent::PreLink, EventClass::Write),
            (HookEvent::PostLink, EventClass::Write),
            (HookEvent::PreConsolidate, EventClass::Write),
            (HookEvent::PostConsolidate, EventClass::Write),
            (HookEvent::PreGovernanceDecision, EventClass::Write),
            (HookEvent::PostGovernanceDecision, EventClass::Write),
            (HookEvent::PreArchive, EventClass::Write),
            (HookEvent::PreReflect, EventClass::Write),
            (HookEvent::PostReflect, EventClass::Write),
            (HookEvent::PreCompaction, EventClass::Write),
            (HookEvent::OnCompactionRollback, EventClass::Write),
            (HookEvent::PreRecall, EventClass::Read),
            (HookEvent::PostRecall, EventClass::Read),
            (HookEvent::PreSearch, EventClass::Read),
            (HookEvent::PostSearch, EventClass::Read),
            (HookEvent::OnIndexEviction, EventClass::Index),
            (HookEvent::PreTranscriptStore, EventClass::Transcript),
            (HookEvent::PostTranscriptStore, EventClass::Transcript),
            (HookEvent::PreRecallExpand, EventClass::HotPath),
        ];
        assert_eq!(
            table.len(),
            25,
            "v0.7.0 L1-7 mapping must cover exactly the 25 HookEvent variants"
        );
        for (event, expected) in table {
            assert_eq!(
                event_class(event),
                expected,
                "event {event:?} mis-classified"
            );
        }
    }

    #[test]
    fn class_deadlines_match_epic_table() {
        // Pin the multiplier to its default. Otherwise an ambient
        // AI_MEMORY_TEST_TIMING_BUDGET_MULT, or a concurrently running
        // issue_1207 test, would scale the exact values asserted here.
        with_mult(None, || {
            assert_eq!(
                class_deadline(EventClass::Write),
                Duration::from_millis(5_000)
            );
            assert_eq!(
                class_deadline(EventClass::Read),
                Duration::from_millis(2_000)
            );
            assert_eq!(
                class_deadline(EventClass::Index),
                Duration::from_millis(1_000)
            );
            assert_eq!(
                class_deadline(EventClass::Transcript),
                Duration::from_millis(5_000)
            );
            assert_eq!(
                class_deadline(EventClass::HotPath),
                Duration::from_millis(50)
            );
        });
    }

    #[test]
    fn class_deadline_for_event_round_trips_through_class() {
        // Same reason as above: exact base values require an unscaled multiplier.
        with_mult(None, || {
            assert_eq!(
                class_deadline_for_event(HookEvent::PreStore),
                Duration::from_millis(WRITE_CLASS_DEADLINE_MS)
            );
            assert_eq!(
                class_deadline_for_event(HookEvent::PostRecall),
                Duration::from_millis(READ_CLASS_DEADLINE_MS)
            );
            assert_eq!(
                class_deadline_for_event(HookEvent::OnIndexEviction),
                Duration::from_millis(INDEX_CLASS_DEADLINE_MS)
            );
            assert_eq!(
                class_deadline_for_event(HookEvent::PostTranscriptStore),
                Duration::from_millis(TRANSCRIPT_CLASS_DEADLINE_MS)
            );
            assert_eq!(
                class_deadline_for_event(HookEvent::PreRecallExpand),
                Duration::from_millis(HOT_PATH_CLASS_DEADLINE_MS)
            );
        });
    }

    #[test]
    fn per_hook_budget_takes_minimum_of_chain_and_hook() {
        let now = Instant::now();
        let chain_deadline = now + Duration::from_millis(500);
        let budget = per_hook_budget_ms(chain_deadline, now, 200).expect("not yet expired");
        assert_eq!(budget, 200);
        let budget = per_hook_budget_ms(chain_deadline, now, 5_000).expect("not yet expired");
        assert!(
            (498..=500).contains(&budget),
            "expected ~500ms chain budget, got {budget}"
        );
    }

    #[test]
    fn per_hook_budget_returns_none_when_chain_deadline_passed() {
        // Move `now` forward rather than the deadline backward. Subtracting
        // from an `Instant` can panic if it would precede the clock's origin.
        let chain_deadline = Instant::now();
        let now = chain_deadline + Duration::from_millis(1);
        assert!(per_hook_budget_ms(chain_deadline, now, 1_000).is_none());
    }

    #[test]
    fn per_hook_budget_at_exact_deadline_is_none() {
        let now = Instant::now();
        assert!(per_hook_budget_ms(now, now, 1_000).is_none());
    }

    #[test]
    fn timeout_violations_counter_is_monotonic_and_resettable() {
        // NOTE(review): the counter is process-global. If tests in other modules
        // call `record_timeout_violation` concurrently, this test can flake.
        // Confirm no such callers exist, or serialise them.
        reset_timeout_violations_for_test();
        assert_eq!(timeout_violations_total(), 0);
        record_timeout_violation();
        record_timeout_violation();
        record_timeout_violation();
        assert_eq!(timeout_violations_total(), 3);
        reset_timeout_violations_for_test();
        assert_eq!(timeout_violations_total(), 0);
    }

    /// Serialises every test in this module that reads or writes `MULT_VAR`.
    fn timing_mult_lock() -> std::sync::MutexGuard<'static, ()> {
        use std::sync::{Mutex, OnceLock};
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets (`Some`) or removes (`None`) the timing multiplier env var.
    fn set_mult_env(value: Option<&str>) {
        // SAFETY: every caller holds `timing_mult_lock()`, so tests in this
        // module never mutate the environment concurrently with each other.
        // Readers outside this module are not covered by the lock.
        match value {
            Some(v) => unsafe { std::env::set_var(MULT_VAR, v) },
            None => unsafe { std::env::remove_var(MULT_VAR) },
        }
    }

    /// Restores the multiplier env var to its prior value when dropped.
    /// It runs even if the test body panics, so a failed assertion cannot
    /// leak an override into later tests.
    struct RestoreMult(Option<String>);

    impl Drop for RestoreMult {
        fn drop(&mut self) {
            set_mult_env(self.0.as_deref());
        }
    }

    /// Runs `body` with the multiplier env var set to `value` (or unset for
    /// `None`), holding the module lock throughout.
    fn with_mult<R>(value: Option<&str>, body: impl FnOnce() -> R) -> R {
        // Declaration order matters: `_restore` drops first, restoring the
        // variable, and only then does `_guard` release the lock.
        let _guard = timing_mult_lock();
        let _restore = RestoreMult(std::env::var(MULT_VAR).ok());
        set_mult_env(value);
        body()
    }

    #[test]
    fn issue_1207_timing_mult_unset_defaults_to_one() {
        with_mult(None, || {
            assert_eq!(test_timing_budget_mult(), 1);
            assert_eq!(
                class_deadline(EventClass::Index),
                Duration::from_millis(INDEX_CLASS_DEADLINE_MS),
            );
        });
    }

    #[test]
    fn issue_1207_timing_mult_valid_scales_class_deadline() {
        with_mult(Some("5"), || {
            assert_eq!(test_timing_budget_mult(), 5);
            assert_eq!(
                class_deadline(EventClass::Index),
                Duration::from_millis(INDEX_CLASS_DEADLINE_MS * 5),
            );
            assert_eq!(
                class_deadline(EventClass::Write),
                Duration::from_millis(WRITE_CLASS_DEADLINE_MS * 5),
            );
        });
    }

    #[test]
    fn issue_1207_timing_mult_unparseable_falls_back_to_one() {
        with_mult(Some("bogus-not-a-number"), || {
            assert_eq!(test_timing_budget_mult(), 1);
        });
    }

    #[test]
    fn issue_1207_timing_mult_below_range_falls_back_to_one() {
        with_mult(Some("0"), || {
            assert_eq!(test_timing_budget_mult(), 1);
        });
    }

    #[test]
    fn issue_1207_timing_mult_above_range_falls_back_to_one() {
        with_mult(Some("9999"), || {
            assert_eq!(test_timing_budget_mult(), 1);
        });
    }

    #[test]
    fn issue_1207_timing_mult_boundary_at_one_and_hundred() {
        with_mult(Some("1"), || assert_eq!(test_timing_budget_mult(), 1));
        with_mult(Some("100"), || assert_eq!(test_timing_budget_mult(), 100));
    }
}