vyre-foundation 0.4.1

Foundation layer: IR, type system, memory model, wire format. Zero application semantics. Part of the vyre GPU compiler.
Documentation
//! ROADMAP I1 — hot-path recording into optimizer hints.
//!
//! Foundation-side substrate. Backends record per-Region dispatch
//! latency at runtime; the optimizer reads the recorded hints to
//! decide which passes to prioritise on the hot path and which to
//! skip on the cold path.
//!
//! ## Contract
//!
//! `HotPathHints` is the canonical key/value store. Backends append
//! `(region_generator, RegionRecord)` rows after each dispatch.
//! The optimizer queries `is_hot(region_generator)` /
//! `dispatch_count(region_generator)` /
//! `mean_kernel_ns(region_generator)` to decide pass scheduling
//! per region. The default `HotPathHints::default()` is empty —
//! every region is assumed cold until recorded otherwise, so the
//! optimizer falls back to its default schedule when no PGO data
//! exists.
//!
//! Sample concentration: bounded LRU keyed by region generator,
//! capacity tunable via `with_capacity`. New samples accumulate
//! into the existing `RegionRecord` for the matching key (running
//! mean + max), so the table never grows beyond the capacity even
//! across long-running workloads.
//!
//! ## Why on-foundation, not on-driver
//!
//! The hint table is queried by the foundation_optimizer scheduler
//! before any backend is selected. Putting the table here lets
//! every backend feed into the same store without leaking backend-
//! specific types into the optimizer. Backends own *recording*
//! (each backend hooks its own dispatch finalize); the optimizer
//! owns *consuming*.
//!
//! ## Soundness for the optimizer
//!
//! The `is_hot` threshold is a heuristic — passes that consume the
//! hint must remain correct (just not optimal) when the hint is
//! absent or stale. The optimizer must NEVER turn a soundness gate
//! into a hot-path-only gate. The hint is a *prioritisation*
//! signal, not a *correctness* signal.

use rustc_hash::FxHashMap;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

/// Per-region performance record. Mean over recorded samples is a
/// stable proxy for the steady-state kernel cost; max captures the
/// worst-case dispatch the backend has seen.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RegionRecord {
    /// Number of dispatch samples observed for this region
    /// generator across the recording window.
    pub dispatch_count: u64,
    /// Sum of `kernel_execute_ns` across all samples; divide by
    /// `dispatch_count` for the running mean.
    pub kernel_ns_total: u64,
    /// Maximum `kernel_execute_ns` seen so far.
    pub kernel_ns_max: u64,
    /// Sum of `bytes_touched` across all samples.
    pub bytes_total: u64,
}

impl RegionRecord {
    /// Mean kernel-execute time across recorded samples, in
    /// nanoseconds. Returns `0` when no sample has been recorded.
    #[must_use]
    pub fn mean_kernel_ns(&self) -> u64 {
        if self.dispatch_count == 0 {
            0
        } else {
            self.kernel_ns_total / self.dispatch_count
        }
    }

    /// Mean bytes touched per dispatch.
    #[must_use]
    pub fn mean_bytes(&self) -> u64 {
        if self.dispatch_count == 0 {
            0
        } else {
            self.bytes_total / self.dispatch_count
        }
    }
}

/// Bounded LRU store of per-region performance records. Cheap to
/// clone (Arc-shared internal Mutex), Send + Sync.
#[derive(Clone)]
pub struct HotPathHints {
    inner: Arc<Mutex<HintsInner>>,
}

struct HintsInner {
    capacity: usize,
    /// Threshold (mean kernel-execute ns) above which a region is
    /// considered hot.
    hot_ns_threshold: u64,
    records: FxHashMap<String, RegionRecord>,
    recency: VecDeque<String>,
}

impl HotPathHints {
    /// Build a hint store with the given LRU capacity. `capacity == 0`
    /// disables recording (all queries return defaults). The default
    /// hot-ns threshold is `100_000` (100µs) — overrideable via
    /// [`with_hot_threshold_ns`](Self::with_hot_threshold_ns).
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            inner: Arc::new(Mutex::new(HintsInner {
                capacity,
                hot_ns_threshold: 100_000,
                records: FxHashMap::default(),
                recency: VecDeque::with_capacity(capacity.max(1)),
            })),
        }
    }

    /// Override the kernel-execute-ns threshold above which a
    /// region is considered hot. Default is 100µs.
    #[must_use]
    pub fn with_hot_threshold_ns(self, threshold_ns: u64) -> Self {
        {
            let mut inner = self.inner.lock().expect("hints mutex poisoned");
            inner.hot_ns_threshold = threshold_ns;
        }
        self
    }

    /// Record a dispatch sample for `region_generator`. Existing
    /// record gets accumulated; new region triggers LRU eviction
    /// when at capacity.
    pub fn record(&self, region_generator: &str, kernel_ns: u64, bytes_touched: u64) {
        let mut inner = self.inner.lock().expect("hints mutex poisoned");
        if inner.capacity == 0 {
            return;
        }
        let key = region_generator.to_owned();
        let entry = inner.records.entry(key.clone()).or_insert(RegionRecord {
            dispatch_count: 0,
            kernel_ns_total: 0,
            kernel_ns_max: 0,
            bytes_total: 0,
        });
        entry.dispatch_count = entry.dispatch_count.saturating_add(1);
        entry.kernel_ns_total = entry.kernel_ns_total.saturating_add(kernel_ns);
        if kernel_ns > entry.kernel_ns_max {
            entry.kernel_ns_max = kernel_ns;
        }
        entry.bytes_total = entry.bytes_total.saturating_add(bytes_touched);
        bump_recency(&mut inner.recency, &key);
        let cap = inner.capacity;
        while inner.records.len() > cap {
            if let Some(evicted) = inner.recency.pop_front() {
                inner.records.remove(&evicted);
            } else {
                break;
            }
        }
    }

    /// True iff `region_generator`'s recorded mean kernel-ns
    /// exceeds the hot threshold. Cold (or unrecorded) regions
    /// return false — passes that gate on `is_hot` must remain
    /// correct on the cold path.
    #[must_use]
    pub fn is_hot(&self, region_generator: &str) -> bool {
        let inner = self.inner.lock().expect("hints mutex poisoned");
        inner
            .records
            .get(region_generator)
            .map(|r| r.mean_kernel_ns() >= inner.hot_ns_threshold)
            .unwrap_or(false)
    }

    /// Number of dispatch samples recorded for the region. Returns
    /// `0` for unrecorded regions.
    #[must_use]
    pub fn dispatch_count(&self, region_generator: &str) -> u64 {
        let inner = self.inner.lock().expect("hints mutex poisoned");
        inner
            .records
            .get(region_generator)
            .map(|r| r.dispatch_count)
            .unwrap_or(0)
    }

    /// Mean kernel-execute time in nanoseconds. Returns `0` for
    /// unrecorded regions.
    #[must_use]
    pub fn mean_kernel_ns(&self, region_generator: &str) -> u64 {
        let inner = self.inner.lock().expect("hints mutex poisoned");
        inner
            .records
            .get(region_generator)
            .map(|r| r.mean_kernel_ns())
            .unwrap_or(0)
    }

    /// Snapshot of the full record for `region_generator`, or
    /// `None` if no samples have been recorded.
    #[must_use]
    pub fn record_for(&self, region_generator: &str) -> Option<RegionRecord> {
        let inner = self.inner.lock().expect("hints mutex poisoned");
        inner.records.get(region_generator).copied()
    }

    /// Total number of distinct regions currently recorded.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner
            .lock()
            .expect("hints mutex poisoned")
            .records
            .len()
    }

    /// `true` iff zero records.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl Default for HotPathHints {
    fn default() -> Self {
        Self::with_capacity(256)
    }
}

impl std::fmt::Debug for HotPathHints {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = self.inner.lock().map_err(|_| std::fmt::Error)?;
        f.debug_struct("HotPathHints")
            .field("capacity", &inner.capacity)
            .field("hot_ns_threshold", &inner.hot_ns_threshold)
            .field("records_len", &inner.records.len())
            .finish()
    }
}

fn bump_recency(recency: &mut VecDeque<String>, key: &str) {
    if let Some(pos) = recency.iter().position(|k| k == key) {
        recency.remove(pos);
    }
    recency.push_back(key.to_owned());
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Empty hints reject every is_hot query and return zero counts.
    #[test]
    fn empty_hints_returns_defaults() {
        let hints = HotPathHints::default();
        assert!(hints.is_empty());
        assert!(!hints.is_hot("any-region"));
        assert_eq!(hints.dispatch_count("any-region"), 0);
        assert_eq!(hints.mean_kernel_ns("any-region"), 0);
        assert!(hints.record_for("any-region").is_none());
    }

    /// Recording N samples accumulates into a single record with
    /// running mean + max.
    #[test]
    fn record_accumulates_into_running_mean_and_max() {
        let hints = HotPathHints::default();
        hints.record("matmul", 1_000, 100);
        hints.record("matmul", 3_000, 300);
        hints.record("matmul", 2_000, 200);
        let rec = hints.record_for("matmul").expect("recorded");
        assert_eq!(rec.dispatch_count, 3);
        assert_eq!(rec.kernel_ns_total, 6_000);
        assert_eq!(rec.kernel_ns_max, 3_000);
        assert_eq!(rec.bytes_total, 600);
        assert_eq!(rec.mean_kernel_ns(), 2_000);
        assert_eq!(rec.mean_bytes(), 200);
    }

    /// `is_hot` flips true once the recorded mean crosses the
    /// threshold (default 100µs).
    #[test]
    fn is_hot_uses_recorded_mean() {
        let hints = HotPathHints::default().with_hot_threshold_ns(1_000);
        hints.record("region", 500, 0);
        assert!(!hints.is_hot("region"), "below threshold");
        hints.record("region", 2_000, 0);
        // mean = (500 + 2000) / 2 = 1250; >= 1000 → hot.
        assert!(hints.is_hot("region"));
    }

    /// LRU eviction kicks the oldest entry when capacity is hit.
    #[test]
    fn lru_evicts_oldest_when_capacity_reached() {
        let hints = HotPathHints::with_capacity(2);
        hints.record("a", 100, 10);
        hints.record("b", 200, 20);
        hints.record("c", 300, 30);
        assert_eq!(hints.len(), 2);
        assert!(hints.record_for("a").is_none(), "oldest evicted");
        assert!(hints.record_for("b").is_some());
        assert!(hints.record_for("c").is_some());
    }

    /// Recording the same key bumps recency so subsequent eviction
    /// targets a different entry.
    #[test]
    fn lru_recency_promotes_on_repeat_record() {
        let hints = HotPathHints::with_capacity(2);
        hints.record("a", 100, 10);
        hints.record("b", 200, 20);
        hints.record("a", 100, 10); // bumps a to most recent
        hints.record("c", 300, 30);
        assert!(
            hints.record_for("a").is_some(),
            "a was bumped, must survive"
        );
        assert!(hints.record_for("b").is_none(), "b was oldest, evicted");
        assert!(hints.record_for("c").is_some());
    }

    /// Capacity 0 disables recording entirely.
    #[test]
    fn capacity_zero_disables_recording() {
        let hints = HotPathHints::with_capacity(0);
        hints.record("a", 100, 10);
        assert!(hints.is_empty());
        assert!(!hints.is_hot("a"));
    }

    /// Hints are Send + Sync so a backend on one thread can record
    /// while the optimizer on another thread queries.
    #[test]
    fn hints_are_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<HotPathHints>();
        assert_sync::<HotPathHints>();
    }
}