use pf_core::digest::Digest256;
use serde::{Deserialize, Serialize};
/// Layout identifier stored in [`PageManifest::layout`]; [`PageManifest::new`]
/// tags every freshly created manifest with this value.
pub const LAYOUT_V1: &str = "paged-batchinvariant-v1";
/// Element data type of the cached values; [`Dtype::bytes`] gives the size of
/// one element.
///
/// Serialized with lowercase variant names (`"bf16"`, `"f16"`, `"f32"`,
/// `"fp8e4m3"`) via `rename_all = "lowercase"`, so renaming a variant changes
/// the manifest wire format.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Dtype {
/// bf16, 2 bytes per element.
Bf16,
/// 16-bit float, 2 bytes per element.
F16,
/// 32-bit float, 4 bytes per element.
F32,
/// 8-bit float in the E4M3 layout, 1 byte per element.
Fp8E4m3,
}
impl Dtype {
#[must_use]
pub const fn bytes(self) -> usize {
match self {
Self::Bf16 | Self::F16 => 2,
Self::F32 => 4,
Self::Fp8E4m3 => 1,
}
}
}
/// Shape parameters and element type of the paged cache described by a
/// manifest; [`CacheMeta::page_bytes`] multiplies them into a per-page size.
///
/// Embedded in [`PageManifest`] with `#[serde(flatten)]`, so these field names
/// appear at the top level of the manifest JSON. Field declaration order is
/// the serialization order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CacheMeta {
/// Tokens per page.
pub page_size_tokens: u32,
/// Number of layers covered by each page.
pub n_layers: u32,
/// Number of heads (NOTE(review): presumably KV heads when they differ from
/// query heads — confirm against the producer).
pub n_heads: u32,
/// Dimension of each head.
pub head_dim: u32,
/// Element type of the stored values.
pub dtype: Dtype,
}
impl CacheMeta {
    /// Size in bytes of one page, or `None` if the product overflows `usize`.
    ///
    /// Computed as
    /// `n_layers * page_size_tokens * n_heads * head_dim * dtype.bytes()`.
    /// NOTE(review): this counts one element per (layer, token, head, dim);
    /// whether that covers K only or K and V together is not visible here.
    ///
    /// Prefer this over [`CacheMeta::page_bytes`] when the dimensions come
    /// from an untrusted source (e.g. a deserialized manifest): four `u32`
    /// factors can exceed `usize` on 64-bit targets and easily do on 32-bit.
    #[must_use]
    pub const fn checked_page_bytes(&self) -> Option<usize> {
        let factors: [usize; 4] = [
            self.page_size_tokens as usize,
            self.n_heads as usize,
            self.head_dim as usize,
            self.dtype.bytes(),
        ];
        let mut total = self.n_layers as usize;
        // `while` + explicit `match` because iterators, closures and `?` are
        // not available inside a `const fn`.
        let mut i = 0;
        while i < factors.len() {
            total = match total.checked_mul(factors[i]) {
                Some(next) => next,
                None => return None,
            };
            i += 1;
        }
        Some(total)
    }

    /// Size in bytes of one page:
    /// `n_layers * page_size_tokens * n_heads * head_dim * dtype.bytes()`.
    ///
    /// # Panics
    ///
    /// Panics if the product overflows `usize`. Previously this silently
    /// wrapped in release builds, yielding a bogus size. Use
    /// [`CacheMeta::checked_page_bytes`] to handle overflow without panicking.
    #[must_use]
    pub const fn page_bytes(&self) -> usize {
        match self.checked_page_bytes() {
            Some(bytes) => bytes,
            None => panic!("CacheMeta::page_bytes: page size overflows usize"),
        }
    }
}
/// One page entry of a [`PageManifest`], identified by index and carrying a
/// digest for its K and V data. Field order is the serialization order.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Page {
/// Page index; `PageManifest::canonicalize` orders pages by this value.
pub ix: u32,
/// Digest of this page's K data (NOTE(review): presumably the key-cache
/// bytes — confirm how the digest is produced).
pub k: Digest256,
/// Digest of this page's V data (NOTE(review): presumably the value-cache
/// bytes — confirm how the digest is produced).
pub v: Digest256,
}
/// A logical sequence mapped onto pages of a [`PageManifest`]. Field order is
/// the serialization order.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct LogicalSeq {
/// Sequence identifier; `PageManifest::canonicalize` orders sequences by it.
pub id: String,
/// Page indices used by this sequence (NOTE(review): presumably `Page::ix`
/// values listed in sequence order — confirm).
pub page_ixs: Vec<u32>,
/// NOTE(review): presumably the number of tokens occupied in the final page
/// of `page_ixs` — confirm against the producer.
pub fill_in_last_page: u32,
}
/// Serializable description of a paged cache: a layout tag, the cache
/// geometry, the pages, and the logical sequences referencing them.
///
/// Call `canonicalize` before serializing when the output must not depend on
/// the order in which pages and sequences were pushed.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PageManifest {
/// Layout identifier; `PageManifest::new` sets it to [`LAYOUT_V1`].
pub layout: String,
/// Cache geometry; its fields are written inline at the top level of the
/// serialized manifest rather than under a `meta` key.
#[serde(flatten)]
pub meta: CacheMeta,
/// Pages of the cache.
pub pages: Vec<Page>,
/// Logical sequences laid out over `pages`.
pub logical_seqs: Vec<LogicalSeq>,
}
impl PageManifest {
    /// Creates an empty manifest for `meta`, tagged with [`LAYOUT_V1`].
    #[must_use]
    pub fn new(meta: CacheMeta) -> Self {
        Self {
            layout: LAYOUT_V1.into(),
            meta,
            pages: Vec::new(),
            logical_seqs: Vec::new(),
        }
    }

    /// Reorders `pages` and `logical_seqs` into a deterministic order so that
    /// manifests with the same contents serialize identically regardless of
    /// insertion order.
    ///
    /// - `pages` are sorted by `ix`. The sort is stable, so pages sharing an
    ///   `ix` keep their input order. NOTE(review): such duplicates are not
    ///   tie-broken because no ordering on `Digest256` is visible here.
    /// - `logical_seqs` are sorted by `id`. Ties are broken by `page_ixs` and
    ///   then `fill_in_last_page`, so sequences with duplicate ids also end up
    ///   in a fixed order instead of their insertion order.
    pub fn canonicalize(&mut self) {
        self.pages.sort_by_key(|p| p.ix);
        self.logical_seqs.sort_by(|a, b| {
            a.id
                .cmp(&b.id)
                .then_with(|| a.page_ixs.cmp(&b.page_ixs))
                .then_with(|| a.fill_in_last_page.cmp(&b.fill_in_last_page))
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Geometry shared by the tests: 16-token pages, 80 layers, 64 heads of
    /// dimension 128, bf16 elements.
    fn meta() -> CacheMeta {
        CacheMeta {
            page_size_tokens: 16,
            n_layers: 80,
            n_heads: 64,
            head_dim: 128,
            dtype: Dtype::Bf16,
        }
    }

    /// Builds a page with index `ix` whose K and V digests are both `d`.
    fn page(ix: u32, d: &Digest256) -> Page {
        Page {
            ix,
            k: d.clone(),
            v: d.clone(),
        }
    }

    #[test]
    fn page_bytes_matches_spec() {
        assert_eq!(meta().page_bytes(), 80 * 16 * 64 * 128 * 2);
    }

    #[test]
    fn manifest_round_trips_through_json() {
        let d = Digest256::of(b"x");
        let mut m = PageManifest::new(meta());
        m.pages.push(page(1, &d));
        m.pages.push(page(0, &d));
        m.logical_seqs.push(LogicalSeq {
            id: "seq-A".into(),
            page_ixs: vec![0, 1],
            fill_in_last_page: 7,
        });
        m.canonicalize();

        let s = serde_json::to_string(&m).unwrap();
        let back: PageManifest = serde_json::from_str(&s).unwrap();

        // Compare the whole manifest so that any field which fails to
        // round-trip is caught, not just the few spot-checked below.
        assert_eq!(back, m);
        assert_eq!(back.layout, LAYOUT_V1);
        assert_eq!(back.pages[0].ix, 0);
        assert_eq!(back.meta.page_size_tokens, 16);
    }

    #[test]
    fn meta_fields_are_flattened_into_manifest_json() {
        let v = serde_json::to_value(PageManifest::new(meta())).unwrap();
        assert_eq!(v["layout"], LAYOUT_V1);
        assert_eq!(v["page_size_tokens"], 16);
        assert_eq!(v["n_layers"], 80);
        assert_eq!(v["dtype"], "bf16");
        assert!(v.get("meta").is_none());
    }

    #[test]
    fn canonicalize_makes_digest_order_invariant() {
        let d = Digest256::of(b"x");
        let mut a = PageManifest::new(meta());
        a.pages.push(page(0, &d));
        a.pages.push(page(1, &d));
        let mut b = PageManifest::new(meta());
        b.pages.push(page(1, &d));
        b.pages.push(page(0, &d));
        a.canonicalize();
        b.canonicalize();
        assert_eq!(
            serde_json::to_vec(&a).unwrap(),
            serde_json::to_vec(&b).unwrap()
        );
    }

    #[test]
    fn canonicalize_sorts_logical_seqs_by_id() {
        let mut m = PageManifest::new(meta());
        for id in ["seq-C", "seq-A", "seq-B"] {
            m.logical_seqs.push(LogicalSeq {
                id: id.into(),
                page_ixs: vec![0],
                fill_in_last_page: 1,
            });
        }
        m.canonicalize();
        let ids: Vec<&str> = m.logical_seqs.iter().map(|s| s.id.as_str()).collect();
        assert_eq!(ids, ["seq-A", "seq-B", "seq-C"]);
    }

    #[test]
    fn dtype_bytes_correct() {
        assert_eq!(Dtype::Bf16.bytes(), 2);
        assert_eq!(Dtype::F16.bytes(), 2);
        assert_eq!(Dtype::F32.bytes(), 4);
        assert_eq!(Dtype::Fp8E4m3.bytes(), 1);
    }

    #[test]
    fn dtype_serializes_as_lowercase_names() {
        // Pins the wire format produced by `rename_all = "lowercase"`.
        assert_eq!(serde_json::to_string(&Dtype::Bf16).unwrap(), "\"bf16\"");
        assert_eq!(serde_json::to_string(&Dtype::F16).unwrap(), "\"f16\"");
        assert_eq!(serde_json::to_string(&Dtype::F32).unwrap(), "\"f32\"");
        assert_eq!(
            serde_json::to_string(&Dtype::Fp8E4m3).unwrap(),
            "\"fp8e4m3\""
        );
        let back: Dtype = serde_json::from_str("\"fp8e4m3\"").unwrap();
        assert_eq!(back, Dtype::Fp8E4m3);
    }
}