use rustc_hash::FxHashMap;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// Running totals for one region, accumulated across every dispatch recorded for it.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RegionRecord {
    /// How many dispatches have been folded into this record.
    pub dispatch_count: u64,
    /// Sum of the kernel times reported for each dispatch (the `_ns` suffix suggests
    /// nanoseconds — the unit is whatever the caller passes).
    pub kernel_ns_total: u64,
    /// Largest single kernel time reported so far.
    pub kernel_ns_max: u64,
    /// Sum of the byte counts reported for each dispatch.
    pub bytes_total: u64,
}

impl RegionRecord {
    /// Average kernel time per dispatch, rounded down.
    ///
    /// Returns `0` when `dispatch_count` is `0` instead of dividing by zero.
    #[must_use]
    pub fn mean_kernel_ns(&self) -> u64 {
        self.kernel_ns_total
            .checked_div(self.dispatch_count)
            .unwrap_or(0)
    }

    /// Average bytes per dispatch, rounded down.
    ///
    /// Returns `0` when `dispatch_count` is `0` instead of dividing by zero.
    #[must_use]
    pub fn mean_bytes(&self) -> u64 {
        self.bytes_total
            .checked_div(self.dispatch_count)
            .unwrap_or(0)
    }
}
/// Thread-safe table of per-region dispatch statistics with least-recently-recorded eviction.
///
/// The table sits behind an `Arc<Mutex<_>>`, so cloning a `HotPathHints` yields another
/// handle to the *same* table: a record made through one clone is visible through all.
#[derive(Clone)]
pub struct HotPathHints {
    inner: Arc<Mutex<HintsInner>>,
}

/// Mutable state guarded by the [`HotPathHints`] mutex.
struct HintsInner {
    /// Maximum number of distinct regions kept; `0` turns recording off.
    capacity: usize,
    /// Mean kernel time at or above which a region is reported as hot.
    hot_ns_threshold: u64,
    /// Accumulated statistics keyed by region-generator name.
    records: FxHashMap<String, RegionRecord>,
    /// Region keys from least to most recently recorded; eviction pops from the front.
    recency: VecDeque<String>,
}
impl HotPathHints {
#[must_use]
pub fn with_capacity(capacity: usize) -> Self {
Self {
inner: Arc::new(Mutex::new(HintsInner {
capacity,
hot_ns_threshold: 100_000,
records: FxHashMap::default(),
recency: VecDeque::with_capacity(capacity.max(1)),
})),
}
}
#[must_use]
pub fn with_hot_threshold_ns(self, threshold_ns: u64) -> Self {
{
let mut inner = self.inner.lock().expect("hints mutex poisoned");
inner.hot_ns_threshold = threshold_ns;
}
self
}
pub fn record(&self, region_generator: &str, kernel_ns: u64, bytes_touched: u64) {
let mut inner = self.inner.lock().expect("hints mutex poisoned");
if inner.capacity == 0 {
return;
}
let key = region_generator.to_owned();
let entry = inner.records.entry(key.clone()).or_insert(RegionRecord {
dispatch_count: 0,
kernel_ns_total: 0,
kernel_ns_max: 0,
bytes_total: 0,
});
entry.dispatch_count = entry.dispatch_count.saturating_add(1);
entry.kernel_ns_total = entry.kernel_ns_total.saturating_add(kernel_ns);
if kernel_ns > entry.kernel_ns_max {
entry.kernel_ns_max = kernel_ns;
}
entry.bytes_total = entry.bytes_total.saturating_add(bytes_touched);
bump_recency(&mut inner.recency, &key);
let cap = inner.capacity;
while inner.records.len() > cap {
if let Some(evicted) = inner.recency.pop_front() {
inner.records.remove(&evicted);
} else {
break;
}
}
}
#[must_use]
pub fn is_hot(&self, region_generator: &str) -> bool {
let inner = self.inner.lock().expect("hints mutex poisoned");
inner
.records
.get(region_generator)
.map(|r| r.mean_kernel_ns() >= inner.hot_ns_threshold)
.unwrap_or(false)
}
#[must_use]
pub fn dispatch_count(&self, region_generator: &str) -> u64 {
let inner = self.inner.lock().expect("hints mutex poisoned");
inner
.records
.get(region_generator)
.map(|r| r.dispatch_count)
.unwrap_or(0)
}
#[must_use]
pub fn mean_kernel_ns(&self, region_generator: &str) -> u64 {
let inner = self.inner.lock().expect("hints mutex poisoned");
inner
.records
.get(region_generator)
.map(|r| r.mean_kernel_ns())
.unwrap_or(0)
}
#[must_use]
pub fn record_for(&self, region_generator: &str) -> Option<RegionRecord> {
let inner = self.inner.lock().expect("hints mutex poisoned");
inner.records.get(region_generator).copied()
}
#[must_use]
pub fn len(&self) -> usize {
self.inner
.lock()
.expect("hints mutex poisoned")
.records
.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for HotPathHints {
fn default() -> Self {
Self::with_capacity(256)
}
}
impl std::fmt::Debug for HotPathHints {
    /// Shows the configuration and the number of tracked regions, not the records themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `fmt` may only fail when the Formatter fails. Returning `fmt::Error` for a poisoned
        // lock would make `format!("{:?}", ..)` panic. The guarded data is plain counters,
        // so reading it after a poisoning panic is safe; recover the guard instead.
        let inner = self
            .inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        f.debug_struct("HotPathHints")
            .field("capacity", &inner.capacity)
            .field("hot_ns_threshold", &inner.hot_ns_threshold)
            .field("records_len", &inner.records.len())
            .finish()
    }
}
/// Moves `key` to the most-recently-recorded (back) end of `recency`, appending it if absent.
///
/// This function removes any existing copy before pushing, so keys in `recency` stay unique.
/// That means searching from the back finds the same entry as searching from the front,
/// and recently bumped keys, which sit near the back, are found in a few steps. When the
/// key is already present its `String` is moved rather than reallocated.
fn bump_recency(recency: &mut VecDeque<String>, key: &str) {
    let owned = match recency.iter().rposition(|k| k == key) {
        Some(pos) => recency.remove(pos).unwrap_or_else(|| key.to_owned()),
        None => key.to_owned(),
    };
    recency.push_back(owned);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_hints_returns_defaults() {
        let hints = HotPathHints::default();
        assert!(hints.is_empty());
        assert!(!hints.is_hot("any-region"));
        assert_eq!(hints.dispatch_count("any-region"), 0);
        assert_eq!(hints.mean_kernel_ns("any-region"), 0);
        assert!(hints.record_for("any-region").is_none());
    }

    #[test]
    fn record_accumulates_into_running_mean_and_max() {
        let hints = HotPathHints::default();
        hints.record("matmul", 1_000, 100);
        hints.record("matmul", 3_000, 300);
        hints.record("matmul", 2_000, 200);
        let rec = hints.record_for("matmul").expect("recorded");
        assert_eq!(rec.dispatch_count, 3);
        assert_eq!(rec.kernel_ns_total, 6_000);
        assert_eq!(rec.kernel_ns_max, 3_000);
        assert_eq!(rec.bytes_total, 600);
        assert_eq!(rec.mean_kernel_ns(), 2_000);
        assert_eq!(rec.mean_bytes(), 200);
    }

    #[test]
    fn record_saturates_instead_of_wrapping() {
        let hints = HotPathHints::default();
        hints.record("big", u64::MAX, u64::MAX);
        hints.record("big", u64::MAX, u64::MAX);
        let rec = hints.record_for("big").expect("recorded");
        assert_eq!(rec.dispatch_count, 2);
        assert_eq!(rec.kernel_ns_total, u64::MAX);
        assert_eq!(rec.kernel_ns_max, u64::MAX);
        assert_eq!(rec.bytes_total, u64::MAX);
    }

    #[test]
    fn is_hot_uses_recorded_mean() {
        let hints = HotPathHints::default().with_hot_threshold_ns(1_000);
        hints.record("region", 500, 0);
        assert!(!hints.is_hot("region"), "below threshold");
        hints.record("region", 2_000, 0);
        assert!(hints.is_hot("region"));
    }

    #[test]
    fn is_hot_threshold_is_inclusive() {
        let hints = HotPathHints::default().with_hot_threshold_ns(1_000);
        hints.record("edge", 1_000, 0);
        assert!(hints.is_hot("edge"), "mean equal to threshold counts as hot");
    }

    #[test]
    fn lru_evicts_oldest_when_capacity_reached() {
        let hints = HotPathHints::with_capacity(2);
        hints.record("a", 100, 10);
        hints.record("b", 200, 20);
        hints.record("c", 300, 30);
        assert_eq!(hints.len(), 2);
        assert!(hints.record_for("a").is_none(), "oldest evicted");
        assert!(hints.record_for("b").is_some());
        assert!(hints.record_for("c").is_some());
    }

    #[test]
    fn lru_recency_promotes_on_repeat_record() {
        let hints = HotPathHints::with_capacity(2);
        hints.record("a", 100, 10);
        hints.record("b", 200, 20);
        hints.record("a", 100, 10);
        hints.record("c", 300, 30);
        assert!(
            hints.record_for("a").is_some(),
            "a was bumped, must survive"
        );
        assert!(hints.record_for("b").is_none(), "b was oldest, evicted");
        assert!(hints.record_for("c").is_some());
    }

    #[test]
    fn capacity_zero_disables_recording() {
        let hints = HotPathHints::with_capacity(0);
        hints.record("a", 100, 10);
        assert!(hints.is_empty());
        assert!(!hints.is_hot("a"));
    }

    #[test]
    fn clones_share_state() {
        let original = HotPathHints::default();
        let handle = original.clone();
        handle.record("shared", 10, 1);
        assert_eq!(original.dispatch_count("shared"), 1);
        assert_eq!(original.len(), 1);
    }

    #[test]
    fn debug_reports_configuration() {
        let hints = HotPathHints::with_capacity(4).with_hot_threshold_ns(7);
        hints.record("a", 1, 1);
        assert_eq!(
            format!("{:?}", hints),
            "HotPathHints { capacity: 4, hot_ns_threshold: 7, records_len: 1 }"
        );
    }

    #[test]
    fn hints_are_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<HotPathHints>();
        assert_sync::<HotPathHints>();
    }
}