//! vyre-wgpu 0.1.0
//!
//! wgpu backend for vyre IR — implements `VyreBackend`; owns the GPU runtime,
//! buffer pool, and pipeline cache.
use crate::runtime::cache::lru::AccessTracker;
use rustc_hash::FxHashMap;

/// Metadata describing a single cached entry.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct CacheEntry {
    /// Key that uniquely identifies this entry.
    pub key: u64,
    /// Entry size in bytes.
    pub size: u64,
    /// Index of the tier currently holding the entry.
    pub tier: usize,
}

/// One cache tier with a fixed byte capacity.
#[non_exhaustive]
pub struct CacheTier {
    /// Human-readable name for the tier.
    pub name: String,
    /// Total capacity of the tier in bytes.
    pub capacity: u64,
    /// Currently used bytes in the tier.
    pub used: u64,
    pub(crate) entries: FxHashMap<u64, CacheEntry>,
}

impl CacheTier {
    /// Create an empty tier with the given name and byte capacity.
    #[inline]
    pub fn new(name: impl Into<String>, capacity: u64) -> Self {
        let name = name.into();
        Self {
            name,
            capacity,
            used: 0,
            entries: FxHashMap::default(),
        }
    }
}

/// Per-entry access statistics consumed by [`TierPolicy`] implementations.
#[non_exhaustive]
pub struct AccessStats {
    /// How many accesses have been recorded for the entry.
    pub frequency: u32,
    /// Position in the recency queue (0 = most recent).
    pub recency_rank: usize,
    /// Entry size in bytes.
    pub size: u64,
}

/// Policy that decides promotion and eviction behavior.
///
/// Implementations must be `Send + Sync` so a cache can be shared across
/// threads.
pub trait TierPolicy: Send + Sync {
    /// Return `true` if the entry should be promoted to a faster tier.
    fn should_promote(&self, key: u64, stats: &AccessStats) -> bool;

    /// Select a candidate for eviction from the given tier.
    ///
    /// Returning `None` signals that nothing can be evicted; the cache then
    /// stops trying to make room in that tier.
    fn eviction_candidate(
        &self,
        tier: usize,
        entries: &FxHashMap<u64, CacheEntry>,
        tracker: &AccessTracker,
    ) -> Option<u64>;
}

/// LRU eviction policy with frequency-based promotion.
#[non_exhaustive]
pub struct LruPolicy {
    /// Minimum access frequency required for promotion.
    pub promote_threshold: u32,
}

impl LruPolicy {
    /// Default access threshold for promotion.
    pub const DEFAULT_THRESHOLD: u32 = 3;

    /// Create a policy that promotes entries accessed at least
    /// `promote_threshold` times.
    #[inline]
    pub fn new(promote_threshold: u32) -> Self {
        Self { promote_threshold }
    }
}

impl Default for LruPolicy {
    /// Equivalent to [`LruPolicy::new`] with [`LruPolicy::DEFAULT_THRESHOLD`].
    fn default() -> Self {
        Self::new(Self::DEFAULT_THRESHOLD)
    }
}

impl TierPolicy for LruPolicy {
    fn should_promote(&self, _key: u64, stats: &AccessStats) -> bool {
        self.promote_threshold <= stats.frequency
    }

    fn eviction_candidate(
        &self,
        _tier: usize,
        entries: &FxHashMap<u64, CacheEntry>,
        tracker: &AccessTracker,
    ) -> Option<u64> {
        // Prefer the coldest tracked key that actually resides in this tier;
        // fall back to an arbitrary resident key when none is tracked.
        tracker
            .iter_coldest()
            .find_map(|(key, _meta)| entries.contains_key(key).then(|| *key))
            .or_else(|| entries.keys().next().copied())
    }
}

/// Errors that can occur during cache operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum CacheError {
    /// The requested key does not exist in the cache.
    KeyNotFound,
    /// The entry is too large to fit in any tier.
    EntryTooLarge,
}

impl std::fmt::Display for CacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant maps to a static message, so pick it first and emit
        // it with a single write.
        let message = match self {
            Self::KeyNotFound => {
                "Key not found in cache. Fix: verify the key was inserted before operating on it."
            }
            Self::EntryTooLarge => {
                "Entry size exceeds the capacity of the largest tier. Fix: reduce the buffer size or increase the tier capacity."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for CacheError {}

/// Generic tiered cache for GPU buffers.
///
/// Tracks hot and cold buffers; the configured [`TierPolicy`] decides
/// promotion and eviction. This is the vyre primitive that helix builds
/// inference intelligence on top of.
#[non_exhaustive]
pub struct TieredCache<P: TierPolicy = LruPolicy> {
    pub(crate) tiers: Vec<CacheTier>,
    pub(crate) tracker: AccessTracker,
    pub(crate) policy: P,
}

impl TieredCache<LruPolicy> {
    /// Create a new cache over `tiers` using a default [`LruPolicy`].
    ///
    /// Shorthand for [`TieredCache::with_policy`] with
    /// [`LruPolicy::default`].
    #[inline]
    pub fn new(tiers: Vec<CacheTier>) -> Self {
        let policy = LruPolicy::default();
        Self::with_policy(tiers, policy)
    }
}

impl<P: TierPolicy> TieredCache<P> {
    /// Create a new cache with a custom policy implementation.
    #[inline]
    pub fn with_policy(tiers: Vec<CacheTier>, policy: P) -> Self {
        Self {
            tiers,
            tracker: AccessTracker::new(),
            policy,
        }
    }

    /// Return a reference to the entry with the given key, if it exists.
    ///
    /// Tiers are scanned in index order, so the first match wins.
    #[inline]
    pub fn get(&self, key: u64) -> Option<&CacheEntry> {
        self.tiers.iter().find_map(|tier| tier.entries.get(&key))
    }

    /// Insert a new entry into the lowest tier that can fit it.
    ///
    /// Re-inserting an existing key evicts the old entry (and its access
    /// history) first.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::EntryTooLarge`] when no tier can hold the entry.
    #[inline]
    pub fn insert(&mut self, key: u64, size: u64) -> Result<(), CacheError> {
        if self.get(key).is_some() {
            self.evict(key);
        }
        self.tracker.set_size(key, size);
        let result = self.insert_into_tier(key, size, 0);
        if result.is_err() {
            // The entry was never stored; drop its tracking state so the
            // tracker does not accumulate keys that no tier holds.
            self.tracker.remove(key);
        }
        result
    }

    /// Record an access for the given key. Unknown keys are ignored.
    #[inline]
    pub fn record_access(&mut self, key: u64) {
        if self.get(key).is_some() {
            self.tracker.record(key);
        }
    }

    /// Promote the entry to the next faster tier if the policy allows it.
    ///
    /// A no-op when the policy declines or the entry already resides in the
    /// last tier.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::KeyNotFound`] when the key does not exist, or
    /// [`CacheError::EntryTooLarge`] when the displaced entry can no longer
    /// be placed in any tier.
    #[inline]
    pub fn promote(&mut self, key: u64) -> Result<(), CacheError> {
        let entry = self.get(key).copied().ok_or(CacheError::KeyNotFound)?;
        let stats = self.tracker.stats(key).ok_or(CacheError::KeyNotFound)?;
        if !self.policy.should_promote(key, &stats) {
            return Ok(());
        }
        let target = entry.tier.saturating_add(1);
        if target >= self.tiers.len() {
            return Ok(());
        }
        let size = entry.size;
        self.remove_entry(key);
        self.move_into_tier(key, size, target, entry.tier)
    }

    /// Demote the entry to the next slower tier.
    ///
    /// A no-op when the entry is already in tier 0.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::KeyNotFound`] when the key does not exist, or
    /// [`CacheError::EntryTooLarge`] when the displaced entry can no longer
    /// be placed in any tier.
    #[inline]
    pub fn demote(&mut self, key: u64) -> Result<(), CacheError> {
        let entry = self.get(key).copied().ok_or(CacheError::KeyNotFound)?;
        if entry.tier == 0 {
            return Ok(());
        }
        let target = entry.tier - 1;
        let size = entry.size;
        self.remove_entry(key);
        self.move_into_tier(key, size, target, entry.tier)
    }

    /// Store `key` in the first tier at index `start` or above that can hold
    /// `size` bytes, evicting entries as allowed by the policy.
    fn insert_into_tier(
        &mut self,
        key: u64,
        size: u64,
        mut start: usize,
    ) -> Result<(), CacheError> {
        while start < self.tiers.len() {
            if size > self.tiers[start].capacity {
                // Entry can never fit in this tier; try the next one.
                start += 1;
                continue;
            }
            if self.make_room(start, size) {
                self.tiers[start].used = self.tiers[start].used.saturating_add(size);
                self.tiers[start].entries.insert(
                    key,
                    CacheEntry {
                        key,
                        size,
                        tier: start,
                    },
                );
                return Ok(());
            }
            start += 1;
        }
        Err(CacheError::EntryTooLarge)
    }

    /// Place `key` into tier `target`, falling back to `fallback` (and tiers
    /// above it) when `target` cannot make room.
    fn move_into_tier(
        &mut self,
        key: u64,
        size: u64,
        target: usize,
        fallback: usize,
    ) -> Result<(), CacheError> {
        if self.make_room(target, size) {
            self.tiers[target].used = self.tiers[target].used.saturating_add(size);
            self.tiers[target].entries.insert(
                key,
                CacheEntry {
                    key,
                    size,
                    tier: target,
                },
            );
            Ok(())
        } else {
            let result = self.insert_into_tier(key, size, fallback);
            if result.is_err() {
                // The entry could not be re-homed anywhere; remove its
                // tracking state so the tracker stays consistent with the
                // tiers.
                self.tracker.remove(key);
            }
            result
        }
    }

    /// Evict entries from `tier` until `size` additional bytes fit.
    ///
    /// Returns `false` without evicting anything when the entry exceeds the
    /// tier's total capacity, or `false` after partial eviction when the
    /// policy stops yielding candidates.
    fn make_room(&mut self, tier: usize, size: u64) -> bool {
        // Bail out early when the entry can never fit: without this guard the
        // loop below would pointlessly evict every resident entry from the
        // tier before failing (reachable via promote/demote, whose target
        // tier is not capacity-checked by the caller).
        if size > self.tiers[tier].capacity {
            return false;
        }
        loop {
            let used = self.tiers[tier].used;
            let cap = self.tiers[tier].capacity;
            if used.saturating_add(size) <= cap {
                return true;
            }
            let candidate = {
                let entries = &self.tiers[tier].entries;
                self.policy.eviction_candidate(tier, entries, &self.tracker)
            };
            if let Some(key) = candidate {
                self.evict_from_tier(key, tier);
            } else {
                return false;
            }
        }
    }

    /// Remove `key` from whichever tier holds it, keeping its access history.
    ///
    /// Used for promotion/demotion, where the entry is immediately re-stored.
    fn remove_entry(&mut self, key: u64) -> Option<CacheEntry> {
        for tier in &mut self.tiers {
            if let Some(entry) = tier.entries.remove(&key) {
                tier.used = tier.used.saturating_sub(entry.size);
                return Some(entry);
            }
        }
        None
    }

    /// Remove `key` from the cache entirely, including its access history.
    fn evict(&mut self, key: u64) -> Option<CacheEntry> {
        for tier in &mut self.tiers {
            if let Some(entry) = tier.entries.remove(&key) {
                tier.used = tier.used.saturating_sub(entry.size);
                self.tracker.remove(key);
                return Some(entry);
            }
        }
        None
    }

    /// Remove `key` from a specific tier, including its access history.
    fn evict_from_tier(&mut self, key: u64, tier: usize) -> Option<CacheEntry> {
        let tier = &mut self.tiers[tier];
        let entry = tier.entries.remove(&key)?;
        tier.used = tier.used.saturating_sub(entry.size);
        self.tracker.remove(key);
        Some(entry)
    }
}