use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use crate::error::RuntimeError;
/// Name-based identifier for a tensor.
///
/// A thin newtype over `String`; it is hashable and comparable so it can be
/// used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorId(pub String);

impl TensorId {
    /// Builds an identifier from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        TensorId(owned)
    }
}

impl std::fmt::Display for TensorId {
    /// Writes the raw tensor name with no decoration; width and padding
    /// flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TensorId(name) = self;
        write!(f, "{}", name)
    }
}
/// A tensor's raw bytes held in memory.
pub struct ResidentTensor {
    /// The tensor payload, shared by reference count.
    pub data: Arc<[u8]>,
    /// Size in bytes recorded for this tensor when it was created.
    pub size_bytes: usize,
}

impl std::fmt::Debug for ResidentTensor {
    /// Prints the recorded size and the payload length rather than the
    /// payload bytes themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data_len = self.data.len();
        let mut builder = f.debug_struct("ResidentTensor");
        builder.field("size_bytes", &self.size_bytes);
        builder.field("data_len", &data_len);
        builder.finish()
    }
}
/// Location of one tensor's bytes inside a [`PagerSource`].
#[derive(Debug, Clone)]
pub struct TensorEntry {
/// Byte offset handed to `PagerSource::read_bytes_at` when the tensor is loaded.
pub file_offset: u64,
/// Number of bytes to read for the tensor; also the amount charged against
/// the pager's budget while the tensor is resident.
pub size_bytes: usize,
}
/// Random-access byte source that tensor data is read from.
///
/// The `Send + Sync` bound lets an implementation be shared behind an
/// `Arc<dyn PagerSource>`.
pub trait PagerSource: Send + Sync {
/// Fills all of `out` with the bytes that start at `offset`.
///
/// The implementations in this file return an error, rather than a partial
/// read, when the requested range extends past the end of the data.
fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError>;
/// Total size of the underlying data in bytes, as reported by the implementation.
fn total_size_bytes(&self) -> u64;
}
/// [`PagerSource`] backed by a file on disk, addressed by path.
pub struct FilePagerSource {
    /// Location of the backing file.
    path: std::path::PathBuf,
    /// File length taken from filesystem metadata when the source was opened.
    total_bytes: u64,
}

impl FilePagerSource {
    /// Creates a source for the file at `path`.
    ///
    /// Only the file's metadata is queried here, to record its length.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if the metadata lookup fails.
    pub fn open(path: impl Into<std::path::PathBuf>) -> Result<Self, RuntimeError> {
        let path: std::path::PathBuf = path.into();
        let total_bytes = std::fs::metadata(path.as_path())?.len();
        Ok(Self { path, total_bytes })
    }
}
impl PagerSource for FilePagerSource {
    /// Opens the file, seeks to `offset`, and fills `out` completely.
    ///
    /// A fresh file handle is opened for every call.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from opening, seeking, or reading; a file
    /// that ends before `out` is full makes `read_exact` fail.
    fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
        use std::io::{Read, Seek, SeekFrom};

        std::fs::File::open(self.path.as_path()).and_then(|mut handle| {
            handle.seek(SeekFrom::Start(offset))?;
            handle.read_exact(out)
        })?;
        Ok(())
    }

    /// Returns the file length recorded when the source was opened.
    fn total_size_bytes(&self) -> u64 {
        self.total_bytes
    }
}
/// [`PagerSource`] backed by a read-only memory mapping of a file.
#[cfg(feature = "mmap")]
pub struct MmapPagerSource {
    /// The mapping that reads are copied out of.
    mmap: Arc<memmap2::Mmap>,
}

#[cfg(feature = "mmap")]
impl MmapPagerSource {
    /// Opens the file at `path` and maps it into memory.
    ///
    /// # Errors
    ///
    /// Propagates the I/O error if the file cannot be opened or mapped.
    pub fn open(path: impl AsRef<std::path::Path>) -> Result<Self, RuntimeError> {
        let handle = std::fs::File::open(path.as_ref())?;
        // SAFETY: `Mmap::map` is unsafe because the mapped bytes are only
        // well-defined while the underlying file is not modified or truncated.
        // NOTE(review): nothing in this file enforces that; presumably the
        // mapped files are treated as immutable for the mapping's lifetime —
        // confirm with callers.
        let mapping = unsafe { memmap2::Mmap::map(&handle) }?;
        Ok(Self {
            mmap: Arc::new(mapping),
        })
    }
}
#[cfg(feature = "mmap")]
impl PagerSource for MmapPagerSource {
    /// Copies `out.len()` bytes starting at `offset` out of the mapping.
    ///
    /// # Errors
    ///
    /// Returns `RuntimeError::OffloadEof` when the requested range does not
    /// lie entirely inside the mapping. `available` is the number of mapped
    /// bytes at or after `offset`, or `0` when the range cannot be expressed
    /// as `usize` at all.
    fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
        let needed = out.len();
        let mapped_len = self.mmap.len();

        // A plain `as usize` cast would silently wrap offsets above
        // `usize::MAX` on 32-bit targets and read from the wrong place;
        // convert with a check and report such offsets as past-the-end.
        let start = <usize as std::convert::TryFrom<u64>>::try_from(offset).map_err(|_| {
            RuntimeError::OffloadEof {
                offset,
                needed,
                available: 0,
            }
        })?;
        let end = start.checked_add(needed).ok_or(RuntimeError::OffloadEof {
            offset,
            needed,
            available: 0,
        })?;

        if end > mapped_len {
            return Err(RuntimeError::OffloadEof {
                offset,
                needed,
                available: mapped_len.saturating_sub(start),
            });
        }

        out.copy_from_slice(&self.mmap[start..end]);
        Ok(())
    }

    /// Returns the length of the mapping in bytes.
    fn total_size_bytes(&self) -> u64 {
        self.mmap.len() as u64
    }
}
/// Loads tensors from a [`PagerSource`] on demand and keeps them resident in
/// memory, evicting unpinned tensors to stay within a byte budget.
pub struct LayerPager {
/// Where tensor bytes are read from on a miss.
source: Arc<dyn PagerSource>,
/// Offset and size of every tensor this pager knows how to load.
tensor_map: HashMap<TensorId, TensorEntry>,
/// Tensors currently held in memory.
resident: RwLock<HashMap<TensorId, Arc<ResidentTensor>>>,
/// Tensors that eviction skips.
pinned: HashSet<TensorId>,
/// Access-order queue: ids are pushed to the back when loaded or re-acquired,
/// and eviction scans from the front.
lru: Mutex<VecDeque<TensorId>>,
/// Target upper bound for `resident_bytes`.
budget_bytes: u64,
/// Running total of the sizes of resident tensors.
resident_bytes: AtomicU64,
}
impl LayerPager {
    /// Creates an empty pager.
    ///
    /// * `source` - where tensor bytes are read from on a miss.
    /// * `tensor_map` - offset and size for every tensor that may be acquired.
    /// * `budget_bytes` - target upper bound on the total size of resident tensors.
    /// * `pinned` - tensors that are never chosen for eviction.
    pub fn new(
        source: Arc<dyn PagerSource>,
        tensor_map: HashMap<TensorId, TensorEntry>,
        budget_bytes: u64,
        pinned: HashSet<TensorId>,
    ) -> Self {
        Self {
            source,
            tensor_map,
            resident: RwLock::new(HashMap::new()),
            pinned,
            lru: Mutex::new(VecDeque::new()),
            budget_bytes,
            resident_bytes: AtomicU64::new(0),
        }
    }

    /// Returns the tensor for `id`, loading it from the source if it is not
    /// already resident.
    ///
    /// On a hit the tensor is marked most-recently-used. On a miss, unpinned
    /// tensors are evicted from the cold end until the new tensor fits in the
    /// budget; if nothing more can be evicted the tensor is loaded anyway, so
    /// the budget can be exceeded when pinned tensors occupy it. Concurrent
    /// misses for different tensors are each checked against the budget
    /// before their reads, so they may also overshoot it together.
    ///
    /// # Errors
    ///
    /// * `RuntimeError::TensorNotFound` if `id` is not in the tensor map.
    /// * `RuntimeError::LockPoisoned` if an internal lock is poisoned.
    /// * Any error returned by the source while reading the tensor bytes.
    pub fn acquire(&self, id: &TensorId) -> Result<Arc<ResidentTensor>, RuntimeError> {
        // Fast path: clone the handle under the read lock and release the
        // lock before touching the LRU, so the two locks are never nested in
        // the opposite order from the miss path below.
        let cached = {
            let resident = self
                .resident
                .read()
                .map_err(|_| RuntimeError::LockPoisoned)?;
            resident.get(id).cloned()
        };
        if let Some(tensor) = cached {
            self.bump_lru(id)?;
            return Ok(tensor);
        }

        let entry = self
            .tensor_map
            .get(id)
            .ok_or_else(|| RuntimeError::TensorNotFound(id.0.clone()))?;

        self.evict_to_fit(entry.size_bytes)?;

        // The read happens with no lock held so other threads are not
        // blocked on I/O.
        let mut data = vec![0u8; entry.size_bytes];
        self.source.read_bytes_at(entry.file_offset, &mut data)?;
        let tensor = Arc::new(ResidentTensor {
            data: data.into(),
            size_bytes: entry.size_bytes,
        });

        // Publish under the LRU mutex (lock order: LRU, then resident map) so
        // the map, the queue and the byte counter change together.
        let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
        let mut resident = self
            .resident
            .write()
            .map_err(|_| RuntimeError::LockPoisoned)?;

        // Another thread may have loaded the same tensor while we were
        // reading. Keep its copy and do not charge the budget twice.
        let already_loaded = resident.get(id).cloned();
        if let Some(existing) = already_loaded {
            drop(resident);
            Self::move_to_back(&mut lru, id);
            return Ok(existing);
        }

        resident.insert(id.clone(), Arc::clone(&tensor));
        drop(resident);
        lru.push_back(id.clone());
        self.resident_bytes
            .fetch_add(entry.size_bytes as u64, Ordering::Relaxed);
        Ok(tensor)
    }

    /// Evicts unpinned tensors, coldest first, until `needed_bytes` more
    /// bytes fit within the budget or no evictable tensor remains.
    ///
    /// The LRU mutex is held for the whole pass, and every iteration removes
    /// the queue entry it selected, so the loop always terminates.
    fn evict_to_fit(&self, needed_bytes: usize) -> Result<(), RuntimeError> {
        let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
        loop {
            let current = self.resident_bytes.load(Ordering::Relaxed);
            if current.saturating_add(needed_bytes as u64) <= self.budget_bytes {
                break;
            }

            let victim_pos = lru.iter().position(|id| !self.pinned.contains(id));
            let victim_id = match victim_pos.and_then(|pos| lru.remove(pos)) {
                Some(victim_id) => victim_id,
                // Only pinned tensors (or nothing) are left.
                None => break,
            };

            let removed = self
                .resident
                .write()
                .map_err(|_| RuntimeError::LockPoisoned)?
                .remove(&victim_id);
            // A queue entry without a resident tensor is simply discarded;
            // only real evictions are credited back to the budget.
            if let Some(evicted) = removed {
                self.resident_bytes
                    .fetch_sub(evicted.size_bytes as u64, Ordering::Relaxed);
            }
        }
        Ok(())
    }

    /// Marks `id` as most-recently-used if it is currently tracked.
    ///
    /// An id that is no longer in the queue was evicted after the caller
    /// looked it up; it is deliberately not re-added, so the queue never
    /// lists a tensor that is not resident.
    fn bump_lru(&self, id: &TensorId) -> Result<(), RuntimeError> {
        let mut lru = self.lru.lock().map_err(|_| RuntimeError::LockPoisoned)?;
        Self::move_to_back(&mut lru, id);
        Ok(())
    }

    /// Moves the first occurrence of `id` to the back (hot end) of `lru`.
    fn move_to_back(lru: &mut VecDeque<TensorId>, id: &TensorId) {
        if let Some(pos) = lru.iter().position(|queued| queued == id) {
            if let Some(entry) = lru.remove(pos) {
                lru.push_back(entry);
            }
        }
    }

    /// Total size in bytes of the tensors currently counted as resident.
    pub fn resident_bytes(&self) -> u64 {
        self.resident_bytes.load(Ordering::Relaxed)
    }

    /// Number of resident tensors; reports `0` if the lock is poisoned.
    pub fn resident_count(&self) -> usize {
        self.resident.read().map(|g| g.len()).unwrap_or(0)
    }

    /// The byte budget this pager was created with.
    pub fn budget_bytes(&self) -> u64 {
        self.budget_bytes
    }

    /// Whether `id` is currently resident; reports `false` if the lock is
    /// poisoned.
    pub fn is_resident(&self, id: &TensorId) -> bool {
        self.resident
            .read()
            .map(|g| g.contains_key(id))
            .unwrap_or(false)
    }

    /// Whether `id` is in the pinned set and therefore exempt from eviction.
    pub fn is_pinned(&self, id: &TensorId) -> bool {
        self.pinned.contains(id)
    }
}
// Unit tests for the pager sources and `LayerPager`.
#[cfg(test)]
mod tests {
use super::*;
use std::collections::{HashMap, HashSet};
use std::io::Write;
use std::sync::Arc;
// In-memory `PagerSource` over a byte vector, used instead of a real file.
struct VecPagerSource(Vec<u8>);
impl PagerSource for VecPagerSource {
// Copies the requested range out of the vector, or reports `OffloadEof`
// when the range overflows or extends past the end.
fn read_bytes_at(&self, offset: u64, out: &mut [u8]) -> Result<(), RuntimeError> {
let start = offset as usize;
let end = start
.checked_add(out.len())
.ok_or(RuntimeError::OffloadEof {
offset,
needed: out.len(),
available: 0,
})?;
if end > self.0.len() {
let available = self.0.len().saturating_sub(start);
return Err(RuntimeError::OffloadEof {
offset,
needed: out.len(),
available,
});
}
out.copy_from_slice(&self.0[start..end]);
Ok(())
}
fn total_size_bytes(&self) -> u64 {
self.0.len() as u64
}
}
// Builds a pager over an in-memory source.
//
// `tensors` lists `(name, size_bytes, file_offset)` triples. Each tensor's
// region of the backing buffer is filled with the pattern `i % 256`, and the
// buffer is returned alongside the pager so tests can compare against it.
fn make_pager(
tensors: &[(&str, usize, u64)],
budget: u64,
pinned: &[&str],
) -> (LayerPager, Vec<u8>) {
// Buffer length: the largest `offset + size` over all tensors, plus one.
let total: usize = tensors
.iter()
.map(|(_, sz, offset)| *offset as usize + *sz)
.max()
.unwrap_or(0)
+ 1;
let mut data = vec![0u8; total];
let mut tensor_map = HashMap::new();
for (id, size, offset) in tensors {
for i in 0..*size {
data[*offset as usize + i] = (i % 256) as u8;
}
tensor_map.insert(
TensorId(id.to_string()),
TensorEntry {
file_offset: *offset,
size_bytes: *size,
},
);
}
let pinned_set: HashSet<TensorId> =
pinned.iter().map(|s| TensorId(s.to_string())).collect();
let pager = LayerPager::new(
Arc::new(VecPagerSource(data.clone())),
tensor_map,
budget,
pinned_set,
);
(pager, data)
}
// Three 100-byte tensors under a 200-byte budget: acquiring the third must
// leave `resident_bytes` within the budget.
#[test]
fn offload_budget_evicts_coldest() {
let (pager, _) = make_pager(
&[
("layer_0", 100, 0),
("layer_1", 100, 100),
("layer_2", 100, 200),
],
200,
&[],
);
let _t0 = pager.acquire(&TensorId("layer_0".into())).expect("t0");
let _t1 = pager.acquire(&TensorId("layer_1".into())).expect("t1");
assert_eq!(pager.resident_count(), 2);
drop(_t0);
drop(_t1);
let _t2 = pager.acquire(&TensorId("layer_2".into())).expect("t2");
assert!(
pager.resident_bytes() <= 200,
"resident_bytes ({}) must be <= budget (200)",
pager.resident_bytes()
);
}
// The budget only fits the pinned tensor; acquiring a second tensor must
// still succeed and must leave the pinned one resident.
#[test]
fn offload_pinned_never_evicted() {
let (pager, _) = make_pager(
&[("layer_0", 100, 0), ("layer_1", 100, 100)],
100,
&["layer_0"],
);
let _t0 = pager.acquire(&TensorId("layer_0".into())).expect("pinned");
let _t1 = pager.acquire(&TensorId("layer_1".into())).expect("cold");
assert!(
pager.is_resident(&TensorId("layer_0".into())),
"pinned tensor must not be evicted"
);
}
// The acquired bytes must equal the backing buffer at the tensor's offset.
#[test]
fn offload_acquire_reads_correct_bytes() {
let (pager, data) = make_pager(&[("t0", 64, 128)], u64::MAX, &[]);
let tensor = pager.acquire(&TensorId("t0".into())).expect("t0");
assert_eq!(tensor.data.len(), 64);
assert_eq!(&tensor.data[..], &data[128..192]);
}
// An id missing from the tensor map yields `TensorNotFound`.
#[test]
fn offload_unknown_tensor_returns_error() {
let (pager, _) = make_pager(&[("t0", 10, 0)], u64::MAX, &[]);
let res = pager.acquire(&TensorId("nonexistent".into()));
assert!(
matches!(res, Err(RuntimeError::TensorNotFound(_))),
"expected TensorNotFound, got {res:?}"
);
}
// Acquiring the same tensor twice returns identical, correct bytes.
#[test]
fn offload_double_acquire_returns_same_bytes() {
let (pager, data) = make_pager(&[("t0", 32, 64)], u64::MAX, &[]);
let a = pager.acquire(&TensorId("t0".into())).expect("a");
let b = pager.acquire(&TensorId("t0".into())).expect("b");
assert_eq!(&a.data[..], &b.data[..]);
assert_eq!(&a.data[..], &data[64..96]);
}
// `FilePagerSource` reads the requested range from a real temporary file.
#[test]
fn offload_file_pager_source_reads_correctly() {
let mut tmp = tempfile::NamedTempFile::new().expect("temp file");
let payload: Vec<u8> = (0u8..=255u8).collect();
tmp.write_all(&payload).expect("write");
let source = FilePagerSource::open(tmp.path()).expect("open");
let mut buf = vec![0u8; 10];
source.read_bytes_at(5, &mut buf).expect("read");
assert_eq!(&buf, &payload[5..15]);
}
// Requesting more bytes than the file holds must fail.
#[test]
fn offload_file_source_eof_errors() {
let mut tmp = tempfile::NamedTempFile::new().expect("temp file");
tmp.write_all(b"short").expect("write");
let source = FilePagerSource::open(tmp.path()).expect("open");
let mut buf = vec![0u8; 100];
let res = source.read_bytes_at(0, &mut buf);
assert!(res.is_err(), "reading past end of file must return Err");
}
// `resident_count` grows with each load and stays bounded once the
// 100-byte budget forces an eviction.
#[test]
fn offload_resident_count_tracks_evictions() {
let (pager, _) = make_pager(&[("a", 50, 0), ("b", 50, 50), ("c", 50, 100)], 100, &[]);
assert_eq!(pager.resident_count(), 0);
let _a = pager.acquire(&TensorId("a".into())).expect("a");
assert_eq!(pager.resident_count(), 1);
let _b = pager.acquire(&TensorId("b".into())).expect("b");
assert_eq!(pager.resident_count(), 2);
let _c = pager.acquire(&TensorId("c".into())).expect("c");
assert!(pager.resident_count() <= 2, "budget limits to 2 tensors");
assert!(
pager.resident_bytes() <= 100,
"resident_bytes must not exceed budget"
);
}
// `is_pinned` is true only for ids passed in the pinned list.
#[test]
fn offload_is_pinned_reflects_set() {
let (pager, _) = make_pager(&[("a", 10, 0), ("b", 10, 10)], u64::MAX, &["a"]);
assert!(pager.is_pinned(&TensorId("a".into())));
assert!(!pager.is_pinned(&TensorId("b".into())));
}
// With a budget that fits exactly one tensor, `resident_bytes` must stay
// within the budget after every acquire.
#[test]
fn offload_budget_strictly_respected() {
let budget = 50u64;
let (pager, _) = make_pager(
&[
("t0", 50, 0),
("t1", 50, 50),
("t2", 50, 100),
("t3", 50, 150),
],
budget,
&[],
);
for name in ["t0", "t1", "t2", "t3"] {
let _ = pager.acquire(&TensorId(name.into())).expect(name);
assert!(
pager.resident_bytes() <= budget,
"after acquiring {name}, resident_bytes={} > budget={budget}",
pager.resident_bytes()
);
}
}
// `Display` for `TensorId` prints the bare name.
#[test]
fn tensor_id_display() {
let id = TensorId::new("blk.0.attn_q.weight");
assert_eq!(id.to_string(), "blk.0.attn_q.weight");
}
}