1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, FibQuantError, Result};
4
5use super::shape::{KvTensorShapeV1, KV_SHAPE_SCHEMA};
6
/// Schema identifier that `KvCacheLayoutV1::schema_version` must carry to pass
/// validation; also the tag handed to `json_digest` when a layout is digested.
pub const KV_LAYOUT_SCHEMA: &str = "fib_quant_kv_cache_layout_v1";
/// Schema identifier that `KvPageGeometryV1::schema_version` must carry to pass
/// validation; also the tag handed to `json_digest` when a page geometry is digested.
pub const KV_PAGE_GEOMETRY_SCHEMA: &str = "fib_quant_kv_page_geometry_v1";
9
/// Axis ordering of a KV cache tensor, named from the largest-stride axis to the
/// smallest-stride axis.
///
/// Variants are (de)serialized with snake_case names. The enum is
/// `#[non_exhaustive]`, so code outside this crate must include a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvLayoutOrder {
    /// Batch, layer, head, token, dim: the order whose strides
    /// `KvCacheLayoutV1::canonical` computes (dim has stride 1, batch the largest).
    BatchLayerHeadTokenDim,
}
18
/// Serializable description of how a KV cache tensor is laid out: an axis order
/// plus one stride per axis.
///
/// Only the layout produced by `KvCacheLayoutV1::canonical` passes
/// `validate_for_shape`; any other combination of field values is rejected.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvCacheLayoutV1 {
    /// Schema tag of this record; must equal `KV_LAYOUT_SCHEMA` to validate.
    pub schema_version: String,
    /// Schema tag of the shape record this layout refers to; must equal
    /// `KV_SHAPE_SCHEMA` to validate.
    pub shape_schema_version: String,
    /// Axis ordering the strides below correspond to.
    pub order: KvLayoutOrder,
    /// Stride of the batch axis; `layers * layer_stride` in the canonical layout.
    pub batch_stride: u64,
    /// Stride of the layer axis; `kv_heads * head_stride` in the canonical layout.
    pub layer_stride: u64,
    /// Stride of the head axis; `tokens * token_stride` in the canonical layout.
    pub head_stride: u64,
    /// Stride of the token axis; `head_dim` in the canonical layout.
    pub token_stride: u64,
    /// Stride of the innermost (dim) axis; `1` in the canonical layout.
    pub dim_stride: u64,
}
39
40impl KvCacheLayoutV1 {
41 pub fn canonical(shape: &KvTensorShapeV1) -> Result<Self> {
43 shape.validate()?;
44 let dim_stride = 1;
45 let token_stride = u64::from(shape.head_dim);
46 let head_stride = u64::from(shape.tokens) * token_stride;
47 let layer_stride = u64::from(shape.kv_heads) * head_stride;
48 let batch_stride = u64::from(shape.layers) * layer_stride;
49 Ok(Self {
50 schema_version: KV_LAYOUT_SCHEMA.into(),
51 shape_schema_version: KV_SHAPE_SCHEMA.into(),
52 order: KvLayoutOrder::BatchLayerHeadTokenDim,
53 batch_stride,
54 layer_stride,
55 head_stride,
56 token_stride,
57 dim_stride,
58 })
59 }
60
61 pub fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
63 shape.validate()?;
64 if self.schema_version != KV_LAYOUT_SCHEMA {
65 return Err(FibQuantError::CorruptPayload(format!(
66 "kv layout schema_version {}, expected {KV_LAYOUT_SCHEMA}",
67 self.schema_version
68 )));
69 }
70 if self.shape_schema_version != KV_SHAPE_SCHEMA {
71 return Err(FibQuantError::CorruptPayload(
72 "kv layout shape schema mismatch".into(),
73 ));
74 }
75 let expected = Self::canonical(shape)?;
76 if self != &expected {
77 return Err(FibQuantError::CorruptPayload(
78 "only canonical contiguous kv layout is supported by the CPU reference codec"
79 .into(),
80 ));
81 }
82 Ok(())
83 }
84
85 pub fn digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
87 self.validate_for_shape(shape)?;
88 json_digest(KV_LAYOUT_SCHEMA, self)
89 }
90}
91
/// Serializable description of how a KV cache is split into pages and blocks.
///
/// The constraints noted on each field are the ones enforced by
/// `KvPageGeometryV1::validate_for_shape`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvPageGeometryV1 {
    /// Schema tag of this record; must equal `KV_PAGE_GEOMETRY_SCHEMA` to validate.
    pub schema_version: String,
    /// Number of tokens per page; must be in `1..=shape.tokens`.
    pub tokens_per_page: u32,
    /// Number of vectors per encoded block; only `1` is accepted by validation.
    pub vectors_per_block: u32,
    /// Vector length; must match the shape's `head_dim`.
    pub head_dim: u32,
    /// Size in bytes of one encoded block; must be nonzero.
    pub encoded_block_bytes: u32,
    /// Size in bytes of one raw vector; must equal `head_dim * 4` (f32 elements).
    pub raw_vector_bytes: u32,
}
108
109impl KvPageGeometryV1 {
110 pub fn new(tokens_per_page: u32, head_dim: u32, encoded_block_bytes: u32) -> Self {
112 Self {
113 schema_version: KV_PAGE_GEOMETRY_SCHEMA.into(),
114 tokens_per_page,
115 vectors_per_block: 1,
116 head_dim,
117 encoded_block_bytes,
118 raw_vector_bytes: head_dim.saturating_mul(4),
119 }
120 }
121
122 pub fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
124 shape.validate()?;
125 if self.schema_version != KV_PAGE_GEOMETRY_SCHEMA {
126 return Err(FibQuantError::CorruptPayload(format!(
127 "kv page geometry schema_version {}, expected {KV_PAGE_GEOMETRY_SCHEMA}",
128 self.schema_version
129 )));
130 }
131 if self.tokens_per_page == 0 || self.tokens_per_page > shape.tokens {
132 return Err(FibQuantError::CorruptPayload(
133 "tokens_per_page must be in 1..=shape.tokens".into(),
134 ));
135 }
136 if self.vectors_per_block != 1 {
137 return Err(FibQuantError::DependencyUnsupported(
138 "CPU reference codec currently supports one vector per block".into(),
139 ));
140 }
141 if self.head_dim != shape.head_dim {
142 return Err(FibQuantError::CorruptPayload(
143 "page geometry head_dim must match shape".into(),
144 ));
145 }
146 if self.raw_vector_bytes != self.head_dim.saturating_mul(4) {
147 return Err(FibQuantError::CorruptPayload(
148 "raw_vector_bytes must equal head_dim * sizeof(f32)".into(),
149 ));
150 }
151 if self.encoded_block_bytes == 0 {
152 return Err(FibQuantError::CorruptPayload(
153 "encoded_block_bytes must be nonzero".into(),
154 ));
155 }
156 Ok(())
157 }
158
159 pub fn page_count(&self, shape: &KvTensorShapeV1) -> Result<u32> {
161 self.validate_for_shape(shape)?;
162 Ok(shape.tokens.div_ceil(self.tokens_per_page))
163 }
164
165 pub fn digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
167 self.validate_for_shape(shape)?;
168 json_digest(KV_PAGE_GEOMETRY_SCHEMA, self)
169 }
170}