1use std::collections::{HashMap, HashSet, VecDeque};
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::sync::{Arc, Mutex, RwLock};
17
18use crate::error::RuntimeError;
19
/// Name-based identifier for a tensor: a newtype wrapper around the name string.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct TensorId(pub String);

impl TensorId {
    /// Builds an identifier from any value convertible into an owned `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        TensorId(owned)
    }
}

impl std::fmt::Display for TensorId {
    /// Writes the wrapped name verbatim; width/precision flags are not applied.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(name) = self;
        formatter.write_str(name.as_str())
    }
}
43
/// A tensor whose bytes are currently held in memory.
pub struct ResidentTensor {
    /// Shared, immutable byte buffer holding the tensor payload.
    pub data: Arc<[u8]>,
    /// Byte size recorded for this tensor.
    pub size_bytes: usize,
}

impl std::fmt::Debug for ResidentTensor {
    /// Reports only the recorded size and the buffer length; the payload bytes
    /// themselves are never printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data_len = self.data.len();
        let mut summary = formatter.debug_struct("ResidentTensor");
        summary.field("size_bytes", &self.size_bytes);
        summary.field("data_len", &data_len);
        summary.finish()
    }
}
68
/// Location of one tensor's bytes inside the backing source.
#[derive(Debug, Clone)]
pub struct TensorEntry {
    /// Byte offset passed to `PagerSource::read_bytes_at` when the tensor is loaded.
    pub file_offset: u64,
    /// Number of bytes read for the tensor; also the amount charged against the
    /// pager's byte budget while the tensor is resident.
    pub size_bytes: usize,
}
81
/// Random-access byte source that tensors are paged in from.
///
/// The `Send + Sync` bound lets a single source be shared behind an `Arc`.
pub trait PagerSource: Send + Sync {
    /// Fills all of `out` with the bytes that start at `offset` in the source.
    ///
    /// # Errors
    /// The implementations in this file return an error, rather than a partial
    /// read, when the requested range runs past the end of the data.
    fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError>;

    /// Total length of the underlying data, in bytes.
    fn total_size_bytes(&self) -> u64;
}
101
102pub struct FilePagerSource {
114 path: std::path::PathBuf,
115 total_bytes: u64,
116}
117
118impl FilePagerSource {
119 pub fn open(path: impl Into<std::path::PathBuf>) -> Result<Self, RuntimeError> {
124 let path = path.into();
125 let meta = std::fs::metadata(&path)?;
126 Ok(Self {
127 total_bytes: meta.len(),
128 path,
129 })
130 }
131}
132
133impl PagerSource for FilePagerSource {
134 fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
135 use std::io::{Read, Seek, SeekFrom};
136 let mut file = std::fs::File::open(&self.path)?;
137 file.seek(SeekFrom::Start(offset))?;
138 file.read_exact(out)?;
139 Ok(())
140 }
141
142 fn total_size_bytes(&self) -> u64 {
143 self.total_bytes
144 }
145}
146
/// Pager source backed by a read-only memory mapping of a file.
#[cfg(feature = "mmap")]
pub struct MmapPagerSource {
    /// The mapping, shared behind an `Arc`.
    mmap: Arc<memmap2::Mmap>,
}

#[cfg(feature = "mmap")]
impl MmapPagerSource {
    /// Opens `path` and maps the whole file into memory.
    ///
    /// # Errors
    /// Propagates the failure from opening or mapping the file.
    pub fn open(path: impl AsRef<std::path::Path>) -> Result<Self, RuntimeError> {
        let handle = std::fs::File::open(path)?;
        // SAFETY: `Mmap::map` is only sound while the underlying file is not
        // truncated or modified for the lifetime of the mapping. Nothing in this
        // function can enforce that — NOTE(review): callers are presumably
        // expected to treat the file as immutable; confirm.
        let mapping = unsafe { memmap2::Mmap::map(&handle) }?;
        let mmap = Arc::new(mapping);
        Ok(Self { mmap })
    }
}
177
#[cfg(feature = "mmap")]
impl PagerSource for MmapPagerSource {
    /// Copies `out.len()` bytes starting at `offset` out of the mapping.
    ///
    /// # Errors
    /// Returns `RuntimeError::OffloadEof` when the requested range is not fully
    /// inside the mapping. `available` is the number of mapped bytes from
    /// `offset` onwards, or `0` when the range cannot be expressed in `usize`.
    fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
        let needed = out.len();
        let mapped_len = self.mmap.len();

        // A plain `offset as usize` wraps on targets where `usize` is narrower
        // than 64 bits, which could turn an out-of-range offset into a valid
        // looking one. Convert with an explicit check instead.
        let range = <usize as std::convert::TryFrom<u64>>::try_from(offset)
            .ok()
            .and_then(|start| start.checked_add(needed).map(|end| (start, end)));

        match range {
            Some((start, end)) if end <= mapped_len => {
                out.copy_from_slice(&self.mmap[start..end]);
                Ok(())
            }
            Some((start, _)) => Err(RuntimeError::OffloadEof {
                offset,
                needed,
                available: mapped_len.saturating_sub(start),
            }),
            None => Err(RuntimeError::OffloadEof {
                offset,
                needed,
                available: 0,
            }),
        }
    }

    /// Length of the mapping in bytes.
    fn total_size_bytes(&self) -> u64 {
        self.mmap.len() as u64
    }
}
205
/// Byte-budgeted cache that loads tensors from a `PagerSource` on demand and
/// evicts unpinned tensors in least-recently-used order.
pub struct LayerPager {
    /// Where tensor bytes are read from on a cache miss.
    source: Arc<dyn PagerSource>,
    /// Offset and size of every tensor that can be acquired.
    tensor_map: HashMap<TensorId, TensorEntry>,
    /// Tensors currently held in memory.
    resident: RwLock<HashMap<TensorId, Arc<ResidentTensor>>>,
    /// Tensors that eviction skips.
    pinned: HashSet<TensorId>,
    /// Recency queue: eviction scans from the front, hits move an id to the back.
    lru: Mutex<VecDeque<TensorId>>,
    /// Target upper bound for `resident_bytes`.
    budget_bytes: u64,
    /// Sum of `size_bytes` over the tensors counted as resident.
    resident_bytes: AtomicU64,
}
234
235impl LayerPager {
236 pub fn new(
245 source: Arc<dyn PagerSource>,
246 tensor_map: HashMap<TensorId, TensorEntry>,
247 budget_bytes: u64,
248 pinned: HashSet<TensorId>,
249 ) -> Self {
250 Self {
251 source,
252 tensor_map,
253 resident: RwLock::new(HashMap::new()),
254 pinned,
255 lru: Mutex::new(VecDeque::new()),
256 budget_bytes,
257 resident_bytes: AtomicU64::new(0),
258 }
259 }
260
261 pub fn acquire(&self, id: &TensorId) -> Result<Arc<ResidentTensor>, RuntimeError> {
273 {
275 let guard = self
276 .resident
277 .read()
278 .map_err(|_| RuntimeError::LockPoisoned)?;
279 if let Some(_tensor) = guard.get(id) {
280 drop(guard);
282 self.bump_lru(id)?;
283 let guard2 = self
286 .resident
287 .read()
288 .map_err(|_| RuntimeError::LockPoisoned)?;
289 if let Some(tensor2) = guard2.get(id) {
292 return Ok(Arc::clone(tensor2));
293 }
294 }
295 }
296
297 let entry = self
299 .tensor_map
300 .get(id)
301 .ok_or_else(|| RuntimeError::TensorNotFound(id.0.clone()))?;
302
303 self.evict_to_fit(entry.size_bytes)?;
305
306 let mut data = vec![0u8; entry.size_bytes];
308 self.source.read_bytes_at(entry.file_offset, &mut data)?;
309
310 let tensor = Arc::new(ResidentTensor {
311 data: data.into(),
312 size_bytes: entry.size_bytes,
313 });
314
315 {
317 let mut guard = self
318 .resident
319 .write()
320 .map_err(|_| RuntimeError::LockPoisoned)?;
321 guard.insert(id.clone(), Arc::clone(&tensor));
322 }
323 {
324 let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
325 lru.push_back(id.clone());
326 }
327 self.resident_bytes
328 .fetch_add(entry.size_bytes as u64, Ordering::Relaxed);
329
330 Ok(tensor)
331 }
332
333 fn evict_to_fit(&self, needed_bytes: usize) -> Result<(), RuntimeError> {
335 loop {
336 let current = self.resident_bytes.load(Ordering::Relaxed);
337 if current.saturating_add(needed_bytes as u64) <= self.budget_bytes {
338 break;
339 }
340
341 let victim = {
343 let lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
344 lru.iter().find(|id| !self.pinned.contains(*id)).cloned()
345 };
346
347 match victim {
348 None => break,
350 Some(victim_id) => {
351 let removed = {
352 let mut guard = self
353 .resident
354 .write()
355 .map_err(|_| RuntimeError::LockPoisoned)?;
356 guard.remove(&victim_id)
357 };
358 if let Some(evicted) = removed {
359 self.resident_bytes
360 .fetch_sub(evicted.size_bytes as u64, Ordering::Relaxed);
361 let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
362 lru.retain(|x| x != &victim_id);
363 }
364 }
365 }
366 }
367 Ok(())
368 }
369
370 fn bump_lru(&self, id: &TensorId) -> Result<(), RuntimeError> {
372 let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
373 if let Some(pos) = lru.iter().position(|x| x == id) {
374 lru.remove(pos);
375 }
376 lru.push_back(id.clone());
377 Ok(())
378 }
379
380 pub fn resident_bytes(&self) -> u64 {
382 self.resident_bytes.load(Ordering::Relaxed)
383 }
384
385 pub fn resident_count(&self) -> usize {
387 self.resident.read().map(|g| g.len()).unwrap_or(0)
388 }
389
390 pub fn budget_bytes(&self) -> u64 {
392 self.budget_bytes
393 }
394
395 pub fn is_resident(&self, id: &TensorId) -> bool {
397 self.resident
398 .read()
399 .map(|g| g.contains_key(id))
400 .unwrap_or(false)
401 }
402
403 pub fn is_pinned(&self, id: &TensorId) -> bool {
405 self.pinned.contains(id)
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{HashMap, HashSet};
    use std::io::Write;
    use std::sync::Arc;

    /// In-memory `PagerSource` over a byte vector, so pager tests need no files.
    struct VecPagerSource(Vec<u8>);

    impl PagerSource for VecPagerSource {
        fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
            let start = offset as usize;
            let end = start
                .checked_add(out.len())
                .ok_or(RuntimeError::OffloadEof {
                    offset,
                    needed: out.len(),
                    available: 0,
                })?;
            if end > self.0.len() {
                let available = self.0.len().saturating_sub(start);
                return Err(RuntimeError::OffloadEof {
                    offset,
                    needed: out.len(),
                    available,
                });
            }
            out.copy_from_slice(&self.0[start..end]);
            Ok(())
        }

        fn total_size_bytes(&self) -> u64 {
            self.0.len() as u64
        }
    }

    /// Shorthand for building a `TensorId` from a string literal.
    fn tid(name: &str) -> TensorId {
        TensorId(name.to_string())
    }

    /// Builds a pager over an in-memory source.
    ///
    /// `tensors` holds `(name, size_bytes, offset)` triples; each tensor's bytes
    /// are filled with the pattern `i % 256`. Returns the pager together with
    /// the backing bytes so tests can compare what was read.
    fn make_pager(
        tensors: &[(&str, usize, u64)],
        budget: u64,
        pinned: &[&str],
    ) -> (LayerPager, Vec<u8>) {
        let total: usize = tensors
            .iter()
            .map(|(_, sz, offset)| *offset as usize + *sz)
            .max()
            .unwrap_or(0)
            + 1;
        let mut data = vec![0u8; total];
        let mut tensor_map = HashMap::new();
        for (id, size, offset) in tensors {
            for i in 0..*size {
                data[*offset as usize + i] = (i % 256) as u8;
            }
            tensor_map.insert(
                TensorId(id.to_string()),
                TensorEntry {
                    file_offset: *offset,
                    size_bytes: *size,
                },
            );
        }
        let pinned_set: HashSet<TensorId> =
            pinned.iter().map(|s| TensorId(s.to_string())).collect();
        let pager = LayerPager::new(
            Arc::new(VecPagerSource(data.clone())),
            tensor_map,
            budget,
            pinned_set,
        );
        (pager, data)
    }

    #[test]
    fn offload_budget_evicts_coldest() {
        let (pager, _) = make_pager(
            &[
                ("layer_0", 100, 0),
                ("layer_1", 100, 100),
                ("layer_2", 100, 200),
            ],
            200,
            &[],
        );

        pager.acquire(&tid("layer_0")).expect("t0");
        pager.acquire(&tid("layer_1")).expect("t1");
        assert_eq!(pager.resident_count(), 2);

        pager.acquire(&tid("layer_2")).expect("t2");
        assert!(
            pager.resident_bytes() <= 200,
            "resident_bytes ({}) must be <= budget (200)",
            pager.resident_bytes()
        );
        // The oldest tensor, and only that one, must have been evicted.
        assert!(
            !pager.is_resident(&tid("layer_0")),
            "coldest tensor (layer_0) must be the one evicted"
        );
        assert!(pager.is_resident(&tid("layer_1")));
        assert!(pager.is_resident(&tid("layer_2")));
    }

    #[test]
    fn offload_hit_refreshes_lru_order() {
        let (pager, _) = make_pager(&[("a", 50, 0), ("b", 50, 50), ("c", 50, 100)], 100, &[]);
        pager.acquire(&tid("a")).expect("a");
        pager.acquire(&tid("b")).expect("b");
        // Touching `a` again makes `b` the least recently used tensor.
        pager.acquire(&tid("a")).expect("a again");
        pager.acquire(&tid("c")).expect("c");

        assert!(pager.is_resident(&tid("a")), "recently used tensor must stay");
        assert!(!pager.is_resident(&tid("b")), "least recently used tensor must go");
        assert!(pager.is_resident(&tid("c")));
        assert_eq!(pager.resident_bytes(), 100);
    }

    #[test]
    fn offload_pinned_never_evicted() {
        let (pager, _) = make_pager(
            &[("layer_0", 100, 0), ("layer_1", 100, 100)],
            100,
            &["layer_0"],
        );

        let _t0 = pager.acquire(&tid("layer_0")).expect("pinned");
        let _t1 = pager.acquire(&tid("layer_1")).expect("cold");

        assert!(
            pager.is_resident(&tid("layer_0")),
            "pinned tensor must not be evicted"
        );
    }

    #[test]
    fn offload_acquire_reads_correct_bytes() {
        let (pager, data) = make_pager(&[("t0", 64, 128)], u64::MAX, &[]);
        let tensor = pager.acquire(&tid("t0")).expect("t0");
        assert_eq!(tensor.data.len(), 64);
        assert_eq!(tensor.size_bytes, 64);
        assert_eq!(&tensor.data[..], &data[128..192]);
    }

    #[test]
    fn offload_unknown_tensor_returns_error() {
        let (pager, _) = make_pager(&[("t0", 10, 0)], u64::MAX, &[]);
        let res = pager.acquire(&tid("nonexistent"));
        assert!(
            matches!(res, Err(RuntimeError::TensorNotFound(_))),
            "expected TensorNotFound, got {res:?}"
        );
    }

    #[test]
    fn offload_double_acquire_returns_same_bytes() {
        let (pager, data) = make_pager(&[("t0", 32, 64)], u64::MAX, &[]);
        let a = pager.acquire(&tid("t0")).expect("a");
        let b = pager.acquire(&tid("t0")).expect("b");
        assert_eq!(&a.data[..], &b.data[..]);
        assert_eq!(&a.data[..], &data[64..96]);
        // A second acquire must not load or count the tensor twice.
        assert_eq!(pager.resident_count(), 1);
        assert_eq!(pager.resident_bytes(), 32);
    }

    #[test]
    fn offload_file_pager_source_reads_correctly() {
        let mut tmp = tempfile::NamedTempFile::new().expect("temp file");
        let payload: Vec<u8> = (0u8..=255u8).collect();
        tmp.write_all(&payload).expect("write");
        let source = FilePagerSource::open(tmp.path()).expect("open");
        let mut buf = vec![0u8; 10];
        source.read_bytes_at(5, &mut buf).expect("read");
        assert_eq!(&buf, &payload[5..15]);
    }

    #[test]
    fn offload_file_source_eof_errors() {
        let mut tmp = tempfile::NamedTempFile::new().expect("temp file");
        tmp.write_all(b"short").expect("write");
        let source = FilePagerSource::open(tmp.path()).expect("open");
        let mut buf = vec![0u8; 100];
        let res = source.read_bytes_at(0, &mut buf);
        assert!(res.is_err(), "reading past end of file must return Err");
    }

    #[test]
    fn offload_resident_count_tracks_evictions() {
        let (pager, _) = make_pager(&[("a", 50, 0), ("b", 50, 50), ("c", 50, 100)], 100, &[]);
        assert_eq!(pager.resident_count(), 0);
        let _a = pager.acquire(&tid("a")).expect("a");
        assert_eq!(pager.resident_count(), 1);
        let _b = pager.acquire(&tid("b")).expect("b");
        assert_eq!(pager.resident_count(), 2);
        let _c = pager.acquire(&tid("c")).expect("c");
        // Exactly one tensor (the oldest) is evicted to make room for `c`.
        assert_eq!(pager.resident_count(), 2, "budget limits to 2 tensors");
        assert_eq!(
            pager.resident_bytes(),
            100,
            "resident_bytes must not exceed budget"
        );
        assert!(!pager.is_resident(&tid("a")));
    }

    #[test]
    fn offload_is_pinned_reflects_set() {
        let (pager, _) = make_pager(&[("a", 10, 0), ("b", 10, 10)], u64::MAX, &["a"]);
        assert!(pager.is_pinned(&tid("a")));
        assert!(!pager.is_pinned(&tid("b")));
    }

    #[test]
    fn offload_budget_strictly_respected() {
        let budget = 50u64;
        let (pager, _) = make_pager(
            &[
                ("t0", 50, 0),
                ("t1", 50, 50),
                ("t2", 50, 100),
                ("t3", 50, 150),
            ],
            budget,
            &[],
        );
        for name in ["t0", "t1", "t2", "t3"] {
            let _ = pager.acquire(&tid(name)).expect(name);
            assert!(
                pager.resident_bytes() <= budget,
                "after acquiring {name}, resident_bytes={} > budget={budget}",
                pager.resident_bytes()
            );
            assert!(pager.is_resident(&tid(name)));
            assert_eq!(pager.resident_count(), 1);
        }
    }

    #[test]
    fn tensor_id_display() {
        let id = TensorId::new("blk.0.attn_q.weight");
        assert_eq!(id.to_string(), "blk.0.attn_q.weight");
    }
}