#![allow(dead_code)]
use crate::store::{BlockKey, BlockMeta, ReconstructPolicy, StoreError, Tier, TieredStore};
/// Object-safe interface over a keyed, tiered store of `f32` tensor blocks.
///
/// Each block is addressed by a [`BlockKey`] and assigned to one [`Tier`].
/// Operations that count as an access take the caller's timestamp as `now`.
pub trait TensorStore {
    /// Stores `data` under `key` in `tier`; `now` is the caller's timestamp.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if the implementation rejects the write.
    fn put(&mut self, key: BlockKey, data: &[f32], tier: Tier, now: u64) -> Result<(), StoreError>;

    /// Reads the block under `key` into `out` and returns how many values were
    /// written; `now` is the caller's timestamp for this access.
    ///
    /// # Errors
    /// Returns a [`StoreError`] when the block cannot be read (`TieredStore`
    /// reports `StoreError::BlockNotFound` for an unknown key).
    fn get(&mut self, key: BlockKey, out: &mut [f32], now: u64) -> Result<usize, StoreError>;

    /// Records an access to `key` at time `now` without reading its data.
    fn touch(&mut self, key: BlockKey, now: u64);

    /// Evicts the block under `key`, keeping its metadata and recording
    /// `policy` as its reconstruction strategy. `TieredStore` moves evicted
    /// blocks to `Tier::Tier0`.
    ///
    /// # Errors
    /// Returns a [`StoreError`] if the block cannot be evicted.
    fn evict(&mut self, key: BlockKey, policy: ReconstructPolicy) -> Result<(), StoreError>;

    /// Returns the metadata of the block under `key`, or `None` if absent.
    fn meta(&self, key: BlockKey) -> Option<&BlockMeta>;

    /// Number of blocks tracked by the store (`TieredStore` still counts
    /// evicted blocks).
    fn block_count(&self) -> usize;

    /// Number of blocks currently assigned to `tier`.
    fn tier_count(&self, tier: Tier) -> usize;

    /// Total byte footprint as reported by the implementation.
    fn total_bytes(&self) -> usize;

    /// Returns `true` if the store holds metadata for `key`.
    ///
    /// The default derives this from [`meta`](Self::meta), so implementors only
    /// need to override it when they have a cheaper check.
    fn contains(&self, key: BlockKey) -> bool {
        self.meta(key).is_some()
    }

    /// Captures block counts and byte usage per tier in one value.
    fn snapshot(&self) -> TensorStoreSnapshot;
}
impl TensorStore for TieredStore {
fn put(&mut self, key: BlockKey, data: &[f32], tier: Tier, now: u64) -> Result<(), StoreError> {
TieredStore::put(self, key, data, tier, now)
}
fn get(&mut self, key: BlockKey, out: &mut [f32], now: u64) -> Result<usize, StoreError> {
TieredStore::get(self, key, out, now)
}
fn touch(&mut self, key: BlockKey, now: u64) {
TieredStore::touch(self, key, now);
}
fn evict(&mut self, key: BlockKey, policy: ReconstructPolicy) -> Result<(), StoreError> {
TieredStore::evict(self, key, policy)
}
fn meta(&self, key: BlockKey) -> Option<&BlockMeta> {
TieredStore::meta(self, key)
}
fn block_count(&self) -> usize {
TieredStore::block_count(self)
}
fn tier_count(&self, tier: Tier) -> usize {
TieredStore::tier_count(self, tier)
}
fn total_bytes(&self) -> usize {
TieredStore::total_bytes(self)
}
fn contains(&self, key: BlockKey) -> bool {
TieredStore::meta(self, key).is_some()
}
fn snapshot(&self) -> TensorStoreSnapshot {
let tier_counts = [
TieredStore::tier_count(self, Tier::Tier0),
TieredStore::tier_count(self, Tier::Tier1),
TieredStore::tier_count(self, Tier::Tier2),
TieredStore::tier_count(self, Tier::Tier3),
];
let metrics = TieredStore::metrics(self);
let tier_bytes = [
0, metrics.tier1_bytes as usize,
metrics.tier2_bytes as usize,
metrics.tier3_bytes as usize,
];
TensorStoreSnapshot {
block_count: TieredStore::block_count(self),
tier_counts,
total_bytes: TieredStore::total_bytes(self),
tier_bytes,
}
}
}
/// Point-in-time summary of how blocks and bytes are spread across tiers.
///
/// The per-tier arrays are indexed by `Tier as usize`: slot 0 holds `Tier0`
/// through slot 3 for `Tier3`. `Default` yields the all-zero snapshot of an
/// empty store.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct TensorStoreSnapshot {
    /// Total number of blocks known to the store.
    pub block_count: usize,
    /// Number of blocks in each tier, indexed by `Tier as usize`.
    pub tier_counts: [usize; 4],
    /// Total byte footprint reported by the store.
    pub total_bytes: usize,
    /// Bytes attributed to each tier, indexed by `Tier as usize`.
    pub tier_bytes: [usize; 4],
}
impl TensorStoreSnapshot {
    /// Share of `block_count` that sits in `tier`.
    ///
    /// An empty snapshot (`block_count == 0`) yields `0.0` rather than dividing
    /// by zero; the per-tier slot is only read once the denominator is non-zero.
    pub fn tier_fraction(&self, tier: Tier) -> f64 {
        match self.block_count {
            0 => 0.0,
            blocks => self.tier_counts[tier as usize] as f64 / blocks as f64,
        }
    }

    /// Share of `total_bytes` attributed to `tier`.
    ///
    /// A snapshot with `total_bytes == 0` yields `0.0` rather than dividing by
    /// zero; the per-tier slot is only read once the denominator is non-zero.
    pub fn byte_fraction(&self, tier: Tier) -> f64 {
        match self.total_bytes {
            0 => 0.0,
            bytes => self.tier_bytes[tier as usize] as f64 / bytes as f64,
        }
    }
}
/// Convenience helpers layered on top of any [`TensorStore`].
pub trait TensorStoreExt: TensorStore {
    /// Reads up to `len` values of the block under `key` into a fresh `Vec`,
    /// truncated to the number of values the store actually returned.
    ///
    /// # Errors
    /// Propagates any [`StoreError`] from the underlying read.
    fn get_vec(&mut self, key: BlockKey, len: usize, now: u64) -> Result<Vec<f32>, StoreError>;

    /// Shorthand for storing `data` under `key` directly in `Tier::Tier1`.
    ///
    /// # Errors
    /// Propagates any [`StoreError`] from the underlying write.
    fn put_tier1(&mut self, key: BlockKey, data: &[f32], now: u64) -> Result<(), StoreError>;

    /// Returns `true` only when `key` exists and currently sits in
    /// `Tier::Tier0`; unknown keys report `false`.
    fn is_evicted(&self, key: BlockKey) -> bool;
}
/// Blanket implementation: every [`TensorStore`] gets the [`TensorStoreExt`]
/// helpers. This includes unsized trait objects such as `dyn TensorStore`,
/// hence the `?Sized` bound.
impl<T: TensorStore + ?Sized> TensorStoreExt for T {
    fn get_vec(&mut self, key: BlockKey, len: usize, now: u64) -> Result<Vec<f32>, StoreError> {
        let mut buf = vec![0.0f32; len];
        let written = self.get(key, &mut buf, now)?;
        // `get` may fill fewer than `len` slots (e.g. a shorter block); drop the
        // unused zeroed tail so the caller only sees real values.
        buf.truncate(written);
        Ok(buf)
    }

    fn put_tier1(&mut self, key: BlockKey, data: &[f32], now: u64) -> Result<(), StoreError> {
        self.put(key, data, Tier::Tier1, now)
    }

    fn is_evicted(&self, key: BlockKey) -> bool {
        // Unknown keys are "not evicted": only existing Tier0 blocks qualify.
        matches!(self.meta(key), Some(meta) if meta.tier == Tier::Tier0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn make_key(tid: u128, idx: u32) -> BlockKey {
        BlockKey {
            tensor_id: tid,
            block_index: idx,
        }
    }

    /// Asserts that `decoded` matches `original` element-wise. Reads from the
    /// store may be lossy, so values above `0.01` in magnitude allow a relative
    /// error of `rel_tol`, and smaller values an absolute error of `0.15`.
    fn assert_close(original: &[f32], decoded: &[f32], rel_tol: f32) {
        assert_eq!(original.len(), decoded.len());
        for (i, (&orig, &dec)) in original.iter().zip(decoded).enumerate() {
            let err = (orig - dec).abs();
            let tol = if orig.abs() > 0.01 {
                orig.abs() * rel_tol
            } else {
                0.15
            };
            assert!(err < tol, "i={i} orig={orig} dec={dec} err={err}");
        }
    }

    /// Builds a store holding four `len`-element blocks of ones: tensors 1 and
    /// 2 in `Tier1`, tensor 3 in `Tier2` and tensor 4 in `Tier3`.
    fn mixed_store(len: usize) -> TieredStore {
        let mut store = TieredStore::new(4096);
        let data = vec![1.0f32; len];
        let placements = [
            (1, Tier::Tier1),
            (2, Tier::Tier1),
            (3, Tier::Tier2),
            (4, Tier::Tier3),
        ];
        for (tid, tier) in placements {
            TensorStore::put(&mut store, make_key(tid, 0), &data, tier, 0).unwrap();
        }
        store
    }

    #[test]
    fn test_trait_put_get_roundtrip() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        let data: Vec<f32> = (0..64).map(|i| i as f32 * 0.25).collect();
        TensorStore::put(&mut store, key, &data, Tier::Tier1, 0).unwrap();
        assert_eq!(TensorStore::block_count(&store), 1);
        assert!(TensorStore::contains(&store, key));

        let mut out = vec![0.0f32; 64];
        let n = TensorStore::get(&mut store, key, &mut out, 1).unwrap();
        assert_eq!(n, 64);
        assert_close(&data, &out, 0.02);
    }

    #[test]
    fn test_trait_touch_updates_access() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        TensorStore::put(&mut store, key, &[1.0; 16], Tier::Tier1, 0).unwrap();
        let meta = TensorStore::meta(&store, key).unwrap();
        assert_eq!(meta.access_count, 1);

        TensorStore::touch(&mut store, key, 10);
        let meta = TensorStore::meta(&store, key).unwrap();
        assert_eq!(meta.access_count, 2);
        assert_eq!(meta.last_access_at, 10);
    }

    #[test]
    fn test_trait_evict() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        TensorStore::put(&mut store, key, &[1.0; 32], Tier::Tier1, 0).unwrap();
        assert_eq!(TensorStore::tier_count(&store, Tier::Tier1), 1);

        TensorStore::evict(&mut store, key, ReconstructPolicy::Delta).unwrap();
        // Eviction keeps the block's metadata, so the key stays addressable.
        assert!(TensorStore::contains(&store, key));
        let meta = TensorStore::meta(&store, key).unwrap();
        assert_eq!(meta.tier, Tier::Tier0);
        assert_eq!(meta.reconstruct, ReconstructPolicy::Delta);
        assert_eq!(TensorStore::tier_count(&store, Tier::Tier0), 1);
        assert_eq!(TensorStore::tier_count(&store, Tier::Tier1), 0);
    }

    #[test]
    fn test_trait_contains_false_for_missing() {
        let store = TieredStore::new(4096);
        assert!(!TensorStore::contains(&store, make_key(99, 0)));
    }

    #[test]
    fn test_trait_total_bytes() {
        let mut store = TieredStore::new(4096);
        assert_eq!(TensorStore::total_bytes(&store), 0);
        TensorStore::put(&mut store, make_key(1, 0), &[1.0; 64], Tier::Tier1, 0).unwrap();
        assert!(TensorStore::total_bytes(&store) > 0);
    }

    #[test]
    fn test_snapshot_empty_store() {
        let store = TieredStore::new(4096);
        let snap = TensorStore::snapshot(&store);
        assert_eq!(snap.block_count, 0);
        assert_eq!(snap.tier_counts, [0, 0, 0, 0]);
        assert_eq!(snap.total_bytes, 0);
        assert_eq!(snap.tier_bytes, [0, 0, 0, 0]);
    }

    #[test]
    fn test_snapshot_populated_store() {
        let snap = TensorStore::snapshot(&mixed_store(32));
        assert_eq!(snap.block_count, 4);
        assert_eq!(snap.tier_counts, [0, 2, 1, 1]);
        assert!(snap.total_bytes > 0);
        assert_eq!(snap.tier_bytes[0], 0);
        assert!(snap.tier_bytes[1] > 0);
        assert!(snap.tier_bytes[2] > 0);
        assert!(snap.tier_bytes[3] > 0);
    }

    #[test]
    fn test_snapshot_tier_fraction() {
        let snap = TensorStore::snapshot(&mixed_store(16));
        assert!((snap.tier_fraction(Tier::Tier1) - 0.5).abs() < 1e-10);
        assert!((snap.tier_fraction(Tier::Tier2) - 0.25).abs() < 1e-10);
        assert!((snap.tier_fraction(Tier::Tier3) - 0.25).abs() < 1e-10);
        assert_eq!(snap.tier_fraction(Tier::Tier0), 0.0);
    }

    #[test]
    fn test_snapshot_byte_fraction() {
        let snap = TensorStore::snapshot(&mixed_store(32));
        // Tier0 never reports bytes; every populated tier holds a positive share.
        assert_eq!(snap.byte_fraction(Tier::Tier0), 0.0);
        assert!(snap.byte_fraction(Tier::Tier1) > 0.0);
        assert!(snap.byte_fraction(Tier::Tier2) > 0.0);
        assert!(snap.byte_fraction(Tier::Tier3) > 0.0);
    }

    #[test]
    fn test_snapshot_tier_fraction_empty() {
        let snap = TensorStoreSnapshot {
            block_count: 0,
            tier_counts: [0; 4],
            total_bytes: 0,
            tier_bytes: [0; 4],
        };
        assert_eq!(snap.tier_fraction(Tier::Tier1), 0.0);
    }

    #[test]
    fn test_snapshot_byte_fraction_empty() {
        let snap = TensorStoreSnapshot {
            block_count: 0,
            tier_counts: [0; 4],
            total_bytes: 0,
            tier_bytes: [0; 4],
        };
        assert_eq!(snap.byte_fraction(Tier::Tier1), 0.0);
    }

    #[test]
    fn test_snapshot_after_eviction() {
        let mut store = TieredStore::new(4096);
        let data = vec![1.0f32; 32];
        TensorStore::put(&mut store, make_key(1, 0), &data, Tier::Tier1, 0).unwrap();
        TensorStore::put(&mut store, make_key(2, 0), &data, Tier::Tier2, 0).unwrap();
        TensorStore::evict(&mut store, make_key(1, 0), ReconstructPolicy::None).unwrap();

        let snap = TensorStore::snapshot(&store);
        assert_eq!(snap.block_count, 2);
        assert_eq!(snap.tier_counts, [1, 0, 1, 0]);
        assert_eq!(snap.tier_bytes[0], 0);
        assert_eq!(snap.tier_bytes[1], 0);
        assert!(snap.tier_bytes[2] > 0);
    }

    #[test]
    fn test_ext_get_vec() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        let data: Vec<f32> = (0..32).map(|i| i as f32 * 0.5).collect();
        TensorStore::put(&mut store, key, &data, Tier::Tier1, 0).unwrap();

        let result = TensorStoreExt::get_vec(&mut store, key, 32, 1).unwrap();
        assert_eq!(result.len(), 32);
        assert_close(&data, &result, 0.05);
    }

    #[test]
    fn test_ext_get_vec_truncates_to_actual() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        TensorStore::put(&mut store, key, &[1.0; 16], Tier::Tier1, 0).unwrap();
        let result = TensorStoreExt::get_vec(&mut store, key, 64, 1).unwrap();
        assert_eq!(result.len(), 16);
    }

    #[test]
    fn test_ext_get_vec_not_found() {
        let mut store = TieredStore::new(4096);
        let result = TensorStoreExt::get_vec(&mut store, make_key(99, 0), 16, 0);
        assert_eq!(result, Err(StoreError::BlockNotFound));
    }

    #[test]
    fn test_ext_put_tier1() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        let data = vec![2.0f32; 16];
        TensorStoreExt::put_tier1(&mut store, key, &data, 0).unwrap();
        let meta = TensorStore::meta(&store, key).unwrap();
        assert_eq!(meta.tier, Tier::Tier1);
        assert_eq!(meta.bits, 8);
    }

    #[test]
    fn test_ext_is_evicted_false_when_active() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        TensorStore::put(&mut store, key, &[1.0; 8], Tier::Tier1, 0).unwrap();
        assert!(!TensorStoreExt::is_evicted(&store, key));
    }

    #[test]
    fn test_ext_is_evicted_true_after_evict() {
        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        TensorStore::put(&mut store, key, &[1.0; 8], Tier::Tier1, 0).unwrap();
        TensorStore::evict(&mut store, key, ReconstructPolicy::None).unwrap();
        assert!(TensorStoreExt::is_evicted(&store, key));
    }

    #[test]
    fn test_ext_is_evicted_false_when_missing() {
        let store = TieredStore::new(4096);
        assert!(!TensorStoreExt::is_evicted(&store, make_key(99, 0)));
    }

    #[test]
    fn test_trait_object_usable() {
        fn use_store(s: &mut dyn TensorStore) -> usize {
            s.block_count()
        }

        let mut store = TieredStore::new(4096);
        let key = make_key(1, 0);
        TensorStore::put(&mut store, key, &[1.0; 8], Tier::Tier1, 0).unwrap();
        assert_eq!(use_store(&mut store), 1);
    }

    #[test]
    fn test_integration_mixed_usage() {
        let mut store = TieredStore::new(4096);
        let k1 = make_key(1, 0);
        let k2 = make_key(2, 0);
        let k3 = make_key(3, 0);
        TensorStoreExt::put_tier1(&mut store, k1, &[1.0; 32], 0).unwrap();
        TensorStore::put(&mut store, k2, &[2.0; 32], Tier::Tier2, 0).unwrap();
        TensorStore::put(&mut store, k3, &[3.0; 32], Tier::Tier3, 0).unwrap();
        assert_eq!(TensorStore::block_count(&store), 3);
        assert!(TensorStore::contains(&store, k1));
        assert!(TensorStore::contains(&store, k2));
        assert!(TensorStore::contains(&store, k3));

        TensorStore::evict(&mut store, k3, ReconstructPolicy::Delta).unwrap();
        assert!(TensorStoreExt::is_evicted(&store, k3));
        assert!(!TensorStoreExt::is_evicted(&store, k1));

        let v1 = TensorStoreExt::get_vec(&mut store, k1, 32, 10).unwrap();
        assert_eq!(v1.len(), 32);

        let snap = TensorStore::snapshot(&store);
        assert_eq!(snap.block_count, 3);
        assert_eq!(snap.tier_counts, [1, 1, 1, 0]);
    }
}