Skip to main content

infernum_arbiter/
cache.rs

1//! Fragment cache for HoloTensor weights.
2//!
3//! Unified cache shared between Infernum and Dantalion for efficient
4//! fragment reuse across workloads.
5
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::Instant;
9
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12
/// Cache configuration.
///
/// Holds the per-tier byte budgets that [`FragmentCache`] compares against
/// its usage counters when deciding whether an insertion must evict first.
/// The [`Default`] implementation in this file uses 10 GiB of VRAM and
/// 32 GiB of RAM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// VRAM cache capacity in bytes (budget for [`CacheTier::Vram`] entries).
    pub vram_capacity: u64,
    /// RAM cache capacity in bytes (budget for [`CacheTier::Ram`] entries).
    pub ram_capacity: u64,
}
21
22impl Default for CacheConfig {
23    fn default() -> Self {
24        Self {
25            vram_capacity: 10 * 1024 * 1024 * 1024, // 10GB
26            ram_capacity: 32 * 1024 * 1024 * 1024,  // 32GB
27        }
28    }
29}
30
/// Cache tier for fragments.
///
/// Identifies which usage counter (if any) a fragment's bytes are charged
/// to inside [`FragmentCache`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CacheTier {
    /// GPU VRAM - fastest.
    Vram,
    /// System RAM - fast.
    Ram,
    /// Not cached.
    ///
    /// This is what tier lookups report for an unknown fragment id. An entry
    /// carrying this tier is not charged to either usage counter.
    None,
}
41
/// Statistics for the fragment cache.
///
/// A plain-data snapshot as produced by `FragmentCache::stats`; the values
/// do not update after the snapshot is taken.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// VRAM bytes used.
    pub vram_used: u64,
    /// RAM bytes used.
    pub ram_used: u64,
    /// Cache hits (counted by `FragmentCache::access` when the id is present).
    pub hits: u64,
    /// Cache misses (counted by `FragmentCache::access` when the id is absent).
    pub misses: u64,
    /// Evictions from VRAM performed by the LRU eviction path; explicit
    /// removals are not counted.
    pub vram_evictions: u64,
    /// Evictions from RAM performed by the LRU eviction path; explicit
    /// removals are not counted.
    pub ram_evictions: u64,
    /// Total fragments cached: the number of entries present in the cache
    /// when the snapshot was taken (not a cumulative count).
    pub fragments_cached: u64,
}
60
61impl CacheStats {
62    /// Returns cache hit rate (0.0 - 1.0).
63    pub fn hit_rate(&self) -> f64 {
64        let total = self.hits + self.misses;
65        if total == 0 {
66            return 0.0;
67        }
68        self.hits as f64 / total as f64
69    }
70
71    /// Returns VRAM utilization.
72    pub fn vram_utilization(&self, capacity: u64) -> f64 {
73        if capacity == 0 {
74            return 0.0;
75        }
76        self.vram_used as f64 / capacity as f64
77    }
78
79    /// Returns RAM utilization.
80    pub fn ram_utilization(&self, capacity: u64) -> f64 {
81        if capacity == 0 {
82            return 0.0;
83        }
84        self.ram_used as f64 / capacity as f64
85    }
86}
87
/// A cached fragment entry.
///
/// Bookkeeping record stored as the value in the cache's id-keyed map; it
/// holds metadata only, no fragment payload.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Fragment identifier (stored for diagnostics; keyed in HashMap).
    /// The leading underscore marks it as intentionally never read.
    _fragment_id: String,
    /// Size in bytes, as charged to the owning tier's usage counter.
    size: u64,
    /// Current tier.
    tier: CacheTier,
    /// Last access time: set at insertion and refreshed on each recorded
    /// access; used to order eviction candidates oldest-first.
    last_access: Instant,
    /// Access count: starts at 1 on insertion and is incremented on each
    /// recorded access.
    access_count: u64,
    /// Which systems use this fragment; entries with fewer users are
    /// evicted before entries with more.
    users: FragmentUsers,
}
104
/// Which systems use a fragment.
///
/// Insertion sets exactly one of the two flags; marking a fragment as shared
/// sets both. The derived `Default` has both flags cleared.
#[derive(Debug, Clone, Copy, Default)]
struct FragmentUsers {
    /// `true` when Infernum uses the fragment.
    infernum: bool,
    /// `true` when Dantalion uses the fragment.
    dantalion: bool,
}
111
112impl FragmentUsers {
113    fn count(&self) -> u32 {
114        self.infernum as u32 + self.dantalion as u32
115    }
116}
117
/// The unified fragment cache.
///
/// Tracks fragment metadata (not payloads) in a lock-protected map and keeps
/// per-tier byte usage and hit/miss/eviction statistics in atomic counters.
pub struct FragmentCache {
    /// Capacity limits supplied at construction.
    config: CacheConfig,
    /// Cached entries keyed by fragment id.
    entries: RwLock<HashMap<String, CacheEntry>>,
    /// Bytes currently charged to the VRAM tier.
    vram_used: AtomicU64,
    /// Bytes currently charged to the RAM tier.
    ram_used: AtomicU64,
    /// Number of recorded accesses that found their fragment.
    hits: AtomicU64,
    /// Number of recorded accesses that did not find their fragment.
    misses: AtomicU64,
    /// Number of entries evicted from the VRAM tier.
    vram_evictions: AtomicU64,
    /// Number of entries evicted from the RAM tier.
    ram_evictions: AtomicU64,
}
129
130impl FragmentCache {
131    /// Creates a new cache with the given configuration.
132    pub fn new(config: CacheConfig) -> Self {
133        Self {
134            config,
135            entries: RwLock::new(HashMap::new()),
136            vram_used: AtomicU64::new(0),
137            ram_used: AtomicU64::new(0),
138            hits: AtomicU64::new(0),
139            misses: AtomicU64::new(0),
140            vram_evictions: AtomicU64::new(0),
141            ram_evictions: AtomicU64::new(0),
142        }
143    }
144
145    /// Returns the configuration.
146    pub fn config(&self) -> &CacheConfig {
147        &self.config
148    }
149
150    /// Checks if a fragment is cached.
151    pub fn contains(&self, fragment_id: &str) -> bool {
152        self.entries.read().contains_key(fragment_id)
153    }
154
155    /// Gets the tier for a fragment.
156    pub fn get_tier(&self, fragment_id: &str) -> CacheTier {
157        self.entries
158            .read()
159            .get(fragment_id)
160            .map(|e| e.tier)
161            .unwrap_or(CacheTier::None)
162    }
163
164    /// Records a cache access, returning the tier.
165    pub fn access(&self, fragment_id: &str) -> CacheTier {
166        let mut entries = self.entries.write();
167        if let Some(entry) = entries.get_mut(fragment_id) {
168            entry.last_access = Instant::now();
169            entry.access_count += 1;
170            self.hits.fetch_add(1, Ordering::Relaxed);
171            entry.tier
172        } else {
173            self.misses.fetch_add(1, Ordering::Relaxed);
174            CacheTier::None
175        }
176    }
177
178    /// Inserts a fragment into the cache.
179    pub fn insert(
180        &self,
181        fragment_id: impl Into<String>,
182        size: u64,
183        tier: CacheTier,
184        for_infernum: bool,
185    ) {
186        let fragment_id = fragment_id.into();
187
188        // Evict if necessary
189        self.ensure_capacity(size, tier);
190
191        let entry = CacheEntry {
192            _fragment_id: fragment_id.clone(),
193            size,
194            tier,
195            last_access: Instant::now(),
196            access_count: 1,
197            users: FragmentUsers {
198                infernum: for_infernum,
199                dantalion: !for_infernum,
200            },
201        };
202
203        // Update usage tracking
204        match tier {
205            CacheTier::Vram => {
206                self.vram_used.fetch_add(size, Ordering::Relaxed);
207            },
208            CacheTier::Ram => {
209                self.ram_used.fetch_add(size, Ordering::Relaxed);
210            },
211            CacheTier::None => {},
212        }
213
214        self.entries.write().insert(fragment_id, entry);
215    }
216
217    /// Removes a fragment from the cache.
218    pub fn remove(&self, fragment_id: &str) {
219        let mut entries = self.entries.write();
220        if let Some(entry) = entries.remove(fragment_id) {
221            match entry.tier {
222                CacheTier::Vram => {
223                    self.vram_used.fetch_sub(entry.size, Ordering::Relaxed);
224                },
225                CacheTier::Ram => {
226                    self.ram_used.fetch_sub(entry.size, Ordering::Relaxed);
227                },
228                CacheTier::None => {},
229            }
230        }
231    }
232
233    /// Promotes a fragment to a higher tier.
234    pub fn promote(&self, fragment_id: &str, to_tier: CacheTier) {
235        let mut entries = self.entries.write();
236        if let Some(entry) = entries.get_mut(fragment_id) {
237            let from_tier = entry.tier;
238            if to_tier == from_tier {
239                return;
240            }
241
242            // Update usage
243            match from_tier {
244                CacheTier::Vram => {
245                    self.vram_used.fetch_sub(entry.size, Ordering::Relaxed);
246                },
247                CacheTier::Ram => {
248                    self.ram_used.fetch_sub(entry.size, Ordering::Relaxed);
249                },
250                CacheTier::None => {},
251            }
252
253            match to_tier {
254                CacheTier::Vram => {
255                    self.vram_used.fetch_add(entry.size, Ordering::Relaxed);
256                },
257                CacheTier::Ram => {
258                    self.ram_used.fetch_add(entry.size, Ordering::Relaxed);
259                },
260                CacheTier::None => {},
261            }
262
263            entry.tier = to_tier;
264        }
265    }
266
267    /// Demotes a fragment to a lower tier.
268    pub fn demote(&self, fragment_id: &str, to_tier: CacheTier) {
269        self.promote(fragment_id, to_tier);
270    }
271
272    /// Marks a fragment as used by both systems (shared).
273    pub fn mark_shared(&self, fragment_id: &str) {
274        let mut entries = self.entries.write();
275        if let Some(entry) = entries.get_mut(fragment_id) {
276            entry.users.infernum = true;
277            entry.users.dantalion = true;
278        }
279    }
280
281    /// Returns current statistics.
282    pub fn stats(&self) -> CacheStats {
283        CacheStats {
284            vram_used: self.vram_used.load(Ordering::Relaxed),
285            ram_used: self.ram_used.load(Ordering::Relaxed),
286            hits: self.hits.load(Ordering::Relaxed),
287            misses: self.misses.load(Ordering::Relaxed),
288            vram_evictions: self.vram_evictions.load(Ordering::Relaxed),
289            ram_evictions: self.ram_evictions.load(Ordering::Relaxed),
290            fragments_cached: self.entries.read().len() as u64,
291        }
292    }
293
294    /// Returns VRAM used.
295    pub fn vram_used(&self) -> u64 {
296        self.vram_used.load(Ordering::Relaxed)
297    }
298
299    /// Returns RAM used.
300    pub fn ram_used(&self) -> u64 {
301        self.ram_used.load(Ordering::Relaxed)
302    }
303
304    /// Clears all cached fragments.
305    pub fn clear(&self) {
306        self.entries.write().clear();
307        self.vram_used.store(0, Ordering::Relaxed);
308        self.ram_used.store(0, Ordering::Relaxed);
309    }
310
311    /// Ensures capacity for a new entry, evicting if necessary.
312    fn ensure_capacity(&self, size: u64, tier: CacheTier) {
313        let (capacity, used) = match tier {
314            CacheTier::Vram => (
315                self.config.vram_capacity,
316                self.vram_used.load(Ordering::Relaxed),
317            ),
318            CacheTier::Ram => (
319                self.config.ram_capacity,
320                self.ram_used.load(Ordering::Relaxed),
321            ),
322            CacheTier::None => return,
323        };
324
325        if used + size <= capacity {
326            return;
327        }
328
329        // Need to evict - use LRU
330        let needed = used + size - capacity;
331        self.evict_lru(tier, needed);
332    }
333
334    /// Evicts least recently used entries.
335    fn evict_lru(&self, tier: CacheTier, needed: u64) {
336        let mut entries = self.entries.write();
337        let mut candidates: Vec<_> = entries
338            .iter()
339            .filter(|(_, e)| e.tier == tier)
340            .map(|(id, e)| (id.clone(), e.last_access, e.size, e.users.count()))
341            .collect();
342
343        // Sort by: shared count (evict non-shared first), then access time
344        candidates.sort_by(|a, b| a.3.cmp(&b.3).then(a.1.cmp(&b.1)));
345
346        let mut freed = 0u64;
347        for (id, _, _size, _) in candidates {
348            if freed >= needed {
349                break;
350            }
351
352            if let Some(entry) = entries.remove(&id) {
353                freed += entry.size;
354                match tier {
355                    CacheTier::Vram => {
356                        self.vram_used.fetch_sub(entry.size, Ordering::Relaxed);
357                        self.vram_evictions.fetch_add(1, Ordering::Relaxed);
358                    },
359                    CacheTier::Ram => {
360                        self.ram_used.fetch_sub(entry.size, Ordering::Relaxed);
361                        self.ram_evictions.fetch_add(1, Ordering::Relaxed);
362                    },
363                    CacheTier::None => {},
364                }
365            }
366        }
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache with explicit per-tier capacities (in bytes).
    fn cache_with(vram_capacity: u64, ram_capacity: u64) -> FragmentCache {
        FragmentCache::new(CacheConfig {
            vram_capacity,
            ram_capacity,
        })
    }

    /// An inserted fragment is visible, reports its tier, and an access to it
    /// is recorded as a hit.
    #[test]
    fn test_cache_insert_and_access() {
        let cache = cache_with(1000, 1000);
        cache.insert("frag1", 100, CacheTier::Vram, true);

        assert!(cache.contains("frag1"));
        assert_eq!(cache.get_tier("frag1"), CacheTier::Vram);
        assert_eq!(cache.access("frag1"), CacheTier::Vram);

        let snapshot = cache.stats();
        assert_eq!(snapshot.hits, 1);
        assert_eq!(snapshot.vram_used, 100);
    }

    /// Accessing an unknown fragment yields `CacheTier::None` and counts as a
    /// miss.
    #[test]
    fn test_cache_miss() {
        let cache = FragmentCache::new(CacheConfig::default());

        assert_eq!(cache.access("nonexistent"), CacheTier::None);
        assert_eq!(cache.stats().misses, 1);
    }

    /// Filling VRAM beyond its capacity evicts entries and keeps usage within
    /// the limit.
    #[test]
    fn test_cache_eviction() {
        let cache = cache_with(200, 1000);

        // The third insertion no longer fits and should trigger eviction.
        for id in ["frag1", "frag2", "frag3"] {
            cache.insert(id, 100, CacheTier::Vram, true);
        }

        let snapshot = cache.stats();
        assert!(snapshot.vram_evictions >= 1);
        assert!(snapshot.vram_used <= 200);
    }

    /// Promoting a fragment moves its bytes from the RAM counter to the VRAM
    /// counter and updates its tier.
    #[test]
    fn test_cache_promote_demote() {
        let cache = cache_with(1000, 1000);

        cache.insert("frag1", 100, CacheTier::Ram, true);
        assert_eq!(cache.ram_used(), 100);
        assert_eq!(cache.vram_used(), 0);

        cache.promote("frag1", CacheTier::Vram);
        assert_eq!(cache.ram_used(), 0);
        assert_eq!(cache.vram_used(), 100);
        assert_eq!(cache.get_tier("frag1"), CacheTier::Vram);
    }

    /// 80 hits out of 100 accesses is a hit rate of 0.8.
    #[test]
    fn test_hit_rate() {
        let snapshot = CacheStats {
            hits: 80,
            misses: 20,
            ..Default::default()
        };

        let deviation = (snapshot.hit_rate() - 0.8).abs();
        assert!(deviation < 0.001);
    }
}