pf_cache/format.rs
1// SPDX-License-Identifier: MIT
2//! Wire format for the cache layer (`paged-batchinvariant-v1`).
3//!
4//! Mirrors `agent_docs/cache-layer.md` §"On-disk format" exactly. The page
5//! manifest is serialized as JSON for human-debuggability; the per-page K/V
6//! payloads are raw bytes (zstd-compressed at the [`pf_core::cas::FsBlobStore`]
7//! layer, not double-compressed here).
8
9use pf_core::digest::Digest256;
10use serde::{Deserialize, Serialize};
11
/// Value stored in [`PageManifest::layout`] by manifests written in the v1 layout.
pub const LAYOUT_V1: &str = "paged-batchinvariant-v1";
14
15/// Numeric dtype of cache entries. Matches the engine-side dtype 1:1 — we
16/// never convert here.
17#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum Dtype {
20 /// IEEE bfloat16 — vLLM and SGLang default for Llama-class models.
21 Bf16,
22 /// IEEE binary16.
23 F16,
24 /// IEEE binary32 (single-precision; rare in production).
25 F32,
26 /// 8-bit FP, E4M3 layout.
27 Fp8E4m3,
28}
29
30impl Dtype {
31 /// Bytes per element.
32 #[must_use]
33 pub const fn bytes(self) -> usize {
34 match self {
35 Self::Bf16 | Self::F16 => 2,
36 Self::F32 => 4,
37 Self::Fp8E4m3 => 1,
38 }
39 }
40}
41
42/// Static metadata describing a paged KV cache. Identical across pages of
43/// the same engine instance; embedded once in the [`PageManifest`].
44#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
45pub struct CacheMeta {
46 /// Tokens per page (vLLM default 16).
47 pub page_size_tokens: u32,
48 /// Number of transformer layers.
49 pub n_layers: u32,
50 /// Number of attention heads.
51 pub n_heads: u32,
52 /// Per-head dimension.
53 pub head_dim: u32,
54 /// Numeric dtype.
55 pub dtype: Dtype,
56}
57
58impl CacheMeta {
59 /// Bytes per K-page (or per V-page; they're the same shape).
60 #[must_use]
61 pub const fn page_bytes(&self) -> usize {
62 (self.n_layers as usize)
63 * (self.page_size_tokens as usize)
64 * (self.n_heads as usize)
65 * (self.head_dim as usize)
66 * self.dtype.bytes()
67 }
68}
69
70/// One physical page in the cache. K and V are content-addressed
71/// independently so a fork that only mutates V (e.g. via a single-token
72/// generation step) shares its K page with siblings.
73#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
74pub struct Page {
75 /// Physical-page index inside the engine's page table.
76 pub ix: u32,
77 /// Digest of the K-tensor bytes for this page.
78 pub k: Digest256,
79 /// Digest of the V-tensor bytes for this page.
80 pub v: Digest256,
81}
82
83/// One logical request (sequence) in the cache, mapping its token positions
84/// onto a list of physical pages. Preserved across snapshot/restore so
85/// prefix-sharing (vLLM PagedAttention, SGLang RadixAttention) survives.
86#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
87pub struct LogicalSeq {
88 /// Stable identifier for this sequence.
89 pub id: String,
90 /// Ordered list of physical-page indices the sequence occupies.
91 pub page_ixs: Vec<u32>,
92 /// How many of `page_size_tokens` slots in the LAST page are occupied.
93 /// `0` means the last page is full and the next token starts a new page.
94 pub fill_in_last_page: u32,
95}
96
97/// Top-level page manifest. Serialized as JSON; persisted as a single CAS
98/// blob whose digest goes into the `.pfimg` manifest's `cache.manifest`
99/// field.
100#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
101pub struct PageManifest {
102 /// Always [`LAYOUT_V1`] in this version.
103 pub layout: String,
104 /// Static metadata (page size, n_layers, dtype, etc.).
105 #[serde(flatten)]
106 pub meta: CacheMeta,
107 /// Pages sorted by `ix` for deterministic manifest digests.
108 pub pages: Vec<Page>,
109 /// Logical sequences sorted by `id`.
110 pub logical_seqs: Vec<LogicalSeq>,
111}
112
113impl PageManifest {
114 /// Construct a fresh manifest with the v1 layout tag pre-set.
115 #[must_use]
116 pub fn new(meta: CacheMeta) -> Self {
117 Self {
118 layout: LAYOUT_V1.into(),
119 meta,
120 pages: Vec::new(),
121 logical_seqs: Vec::new(),
122 }
123 }
124
125 /// Sort pages by `ix` and seqs by `id` so the JSON serialization (and
126 /// therefore the digest) is invariant w.r.t. iteration order.
127 pub fn canonicalize(&mut self) {
128 self.pages.sort_by_key(|p| p.ix);
129 self.logical_seqs.sort_by(|a, b| a.id.cmp(&b.id));
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative large-model cache shape shared by the tests below.
    fn meta() -> CacheMeta {
        CacheMeta {
            page_size_tokens: 16,
            n_layers: 80,
            n_heads: 64,
            head_dim: 128,
            dtype: Dtype::Bf16,
        }
    }

    /// A page whose K and V digests are both `d`.
    fn page(ix: u32, d: &Digest256) -> Page {
        Page {
            ix,
            k: d.clone(),
            v: d.clone(),
        }
    }

    #[test]
    fn page_bytes_matches_spec() {
        // 80 layers × 16 tokens × 64 heads × 128 head_dim × 2 B (bf16) per page
        // = 20,971,520 bytes = exactly 20 MiB per K-page (and per V-page).
        assert_eq!(meta().page_bytes(), 80 * 16 * 64 * 128 * 2);
    }

    #[test]
    fn manifest_round_trips_through_json() {
        let d = Digest256::of(b"x");
        let mut m = PageManifest::new(meta());
        m.pages.push(page(1, &d));
        m.pages.push(page(0, &d));
        m.logical_seqs.push(LogicalSeq {
            id: "seq-A".into(),
            page_ixs: vec![0, 1],
            fill_in_last_page: 7,
        });
        m.canonicalize();
        let s = serde_json::to_string(&m).unwrap();
        let back: PageManifest = serde_json::from_str(&s).unwrap();
        // Every field must survive the round trip, not just the spot checks.
        assert_eq!(back, m);
        assert_eq!(back.layout, LAYOUT_V1);
        assert_eq!(back.pages[0].ix, 0); // canonicalized order
        assert_eq!(back.meta.page_size_tokens, 16);
    }

    #[test]
    fn meta_is_flattened_into_manifest_json() {
        // Pins the on-disk shape: CacheMeta fields live at the top level,
        // not under a nested "meta" key.
        let v = serde_json::to_value(PageManifest::new(meta())).unwrap();
        assert_eq!(v["layout"], LAYOUT_V1);
        assert_eq!(v["page_size_tokens"], 16);
        assert_eq!(v["n_layers"], 80);
        assert_eq!(v["n_heads"], 64);
        assert_eq!(v["head_dim"], 128);
        assert_eq!(v["dtype"], "bf16");
        assert!(v.get("meta").is_none());
    }

    #[test]
    fn dtype_wire_names_are_lowercase() {
        for &(dt, name) in &[
            (Dtype::Bf16, "bf16"),
            (Dtype::F16, "f16"),
            (Dtype::F32, "f32"),
            (Dtype::Fp8E4m3, "fp8e4m3"),
        ] {
            assert_eq!(serde_json::to_value(dt).unwrap(), name);
            let back: Dtype = serde_json::from_value(serde_json::Value::from(name)).unwrap();
            assert_eq!(back, dt);
        }
    }

    #[test]
    fn canonicalize_makes_digest_order_invariant() {
        let d = Digest256::of(b"x");
        let mut a = PageManifest::new(meta());
        a.pages.push(page(0, &d));
        a.pages.push(page(1, &d));
        let mut b = PageManifest::new(meta());
        b.pages.push(page(1, &d));
        b.pages.push(page(0, &d));
        a.canonicalize();
        b.canonicalize();
        assert_eq!(
            serde_json::to_vec(&a).unwrap(),
            serde_json::to_vec(&b).unwrap()
        );
    }

    #[test]
    fn canonicalize_orders_logical_seqs_by_id() {
        let mut m = PageManifest::new(meta());
        for &id in &["seq-C", "seq-A", "seq-B"] {
            m.logical_seqs.push(LogicalSeq {
                id: id.into(),
                page_ixs: vec![],
                fill_in_last_page: 0,
            });
        }
        m.canonicalize();
        let ids: Vec<&str> = m.logical_seqs.iter().map(|s| s.id.as_str()).collect();
        assert_eq!(ids, ["seq-A", "seq-B", "seq-C"]);
    }

    #[test]
    fn dtype_bytes_correct() {
        assert_eq!(Dtype::Bf16.bytes(), 2);
        assert_eq!(Dtype::F16.bytes(), 2);
        assert_eq!(Dtype::F32.bytes(), 4);
        assert_eq!(Dtype::Fp8E4m3.bytes(), 1);
    }
}