Skip to main content

fib_quant/kv/
layout.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, FibQuantError, Result};
4
5use super::shape::{KvTensorShapeV1, KV_SHAPE_SCHEMA};
6
7pub const KV_LAYOUT_SCHEMA: &str = "fib_quant_kv_cache_layout_v1";
8pub const KV_PAGE_GEOMETRY_SCHEMA: &str = "fib_quant_kv_page_geometry_v1";
9
/// Canonical physical order for flat tensors.
///
/// Variants are serialized by serde using `snake_case` names. The enum is
/// marked `#[non_exhaustive]`, so code in other crates that matches on it
/// must include a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvLayoutOrder {
    /// `[batch][layer][kv_head][token][head_dim]`, listed from the outermost
    /// axis to the innermost one.
    BatchLayerHeadTokenDim,
}
18
/// KV cache layout declaration.
///
/// Describes how a logical KV tensor shape is laid out in a flat buffer:
/// a physical axis order plus one stride per axis. The type derives serde's
/// `Serialize`/`Deserialize`, and `KvCacheLayoutV1::digest` passes the whole
/// value to `json_digest`.
// NOTE(review): field names and declaration order presumably influence the
// serialized form (and therefore the digest) — confirm before renaming or
// reordering fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvCacheLayoutV1 {
    /// Stable schema marker; validation expects `KV_LAYOUT_SCHEMA`.
    pub schema_version: String,
    /// Shape schema bound into this layout; validation expects `KV_SHAPE_SCHEMA`.
    pub shape_schema_version: String,
    /// Physical order.
    pub order: KvLayoutOrder,
    /// Scalar stride for batch (`layers * layer_stride` in the canonical layout).
    pub batch_stride: u64,
    /// Scalar stride for layer (`kv_heads * head_stride` in the canonical layout).
    pub layer_stride: u64,
    /// Scalar stride for KV head (`tokens * token_stride` in the canonical layout).
    pub head_stride: u64,
    /// Scalar stride for token (`head_dim` in the canonical layout).
    pub token_stride: u64,
    /// Scalar stride for head dimension (`1` in the canonical layout).
    pub dim_stride: u64,
}
39
40impl KvCacheLayoutV1 {
41    /// Build the canonical contiguous layout for a shape.
42    pub fn canonical(shape: &KvTensorShapeV1) -> Result<Self> {
43        shape.validate()?;
44        let dim_stride = 1;
45        let token_stride = u64::from(shape.head_dim);
46        let head_stride = u64::from(shape.tokens) * token_stride;
47        let layer_stride = u64::from(shape.kv_heads) * head_stride;
48        let batch_stride = u64::from(shape.layers) * layer_stride;
49        Ok(Self {
50            schema_version: KV_LAYOUT_SCHEMA.into(),
51            shape_schema_version: KV_SHAPE_SCHEMA.into(),
52            order: KvLayoutOrder::BatchLayerHeadTokenDim,
53            batch_stride,
54            layer_stride,
55            head_stride,
56            token_stride,
57            dim_stride,
58        })
59    }
60
61    /// Validate this layout against a logical shape.
62    pub fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
63        shape.validate()?;
64        if self.schema_version != KV_LAYOUT_SCHEMA {
65            return Err(FibQuantError::CorruptPayload(format!(
66                "kv layout schema_version {}, expected {KV_LAYOUT_SCHEMA}",
67                self.schema_version
68            )));
69        }
70        if self.shape_schema_version != KV_SHAPE_SCHEMA {
71            return Err(FibQuantError::CorruptPayload(
72                "kv layout shape schema mismatch".into(),
73            ));
74        }
75        let expected = Self::canonical(shape)?;
76        if self != &expected {
77            return Err(FibQuantError::CorruptPayload(
78                "only canonical contiguous kv layout is supported by the CPU reference codec"
79                    .into(),
80            ));
81        }
82        Ok(())
83    }
84
85    /// Stable layout digest.
86    pub fn digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
87        self.validate_for_shape(shape)?;
88        json_digest(KV_LAYOUT_SCHEMA, self)
89    }
90}
91
/// Fixed-size page geometry.
///
/// Describes how tokens are grouped into pages and how many bytes each
/// encoded block and each raw vector occupy. The type derives serde's
/// `Serialize`/`Deserialize`, and `KvPageGeometryV1::digest` passes the whole
/// value to `json_digest`.
// NOTE(review): field names and declaration order presumably influence the
// serialized form (and therefore the digest) — confirm before renaming or
// reordering fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvPageGeometryV1 {
    /// Stable schema marker; validation expects `KV_PAGE_GEOMETRY_SCHEMA`.
    pub schema_version: String,
    /// Number of tokens per encoded page; validation requires `1..=shape.tokens`.
    pub tokens_per_page: u32,
    /// Number of logical vectors per encoded block; validation currently
    /// accepts only `1`.
    pub vectors_per_block: u32,
    /// Number of channels in each logical vector; must match the shape's `head_dim`.
    pub head_dim: u32,
    /// Fixed encoded bytes reserved for each block; must be nonzero.
    pub encoded_block_bytes: u32,
    /// Fixed raw f32 bytes per logical vector (`head_dim * 4`).
    pub raw_vector_bytes: u32,
}
108
109impl KvPageGeometryV1 {
110    /// Build a page geometry for one-vector blocks.
111    pub fn new(tokens_per_page: u32, head_dim: u32, encoded_block_bytes: u32) -> Self {
112        Self {
113            schema_version: KV_PAGE_GEOMETRY_SCHEMA.into(),
114            tokens_per_page,
115            vectors_per_block: 1,
116            head_dim,
117            encoded_block_bytes,
118            raw_vector_bytes: head_dim.saturating_mul(4),
119        }
120    }
121
122    /// Validate page geometry for a shape.
123    pub fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
124        shape.validate()?;
125        if self.schema_version != KV_PAGE_GEOMETRY_SCHEMA {
126            return Err(FibQuantError::CorruptPayload(format!(
127                "kv page geometry schema_version {}, expected {KV_PAGE_GEOMETRY_SCHEMA}",
128                self.schema_version
129            )));
130        }
131        if self.tokens_per_page == 0 || self.tokens_per_page > shape.tokens {
132            return Err(FibQuantError::CorruptPayload(
133                "tokens_per_page must be in 1..=shape.tokens".into(),
134            ));
135        }
136        if self.vectors_per_block != 1 {
137            return Err(FibQuantError::DependencyUnsupported(
138                "CPU reference codec currently supports one vector per block".into(),
139            ));
140        }
141        if self.head_dim != shape.head_dim {
142            return Err(FibQuantError::CorruptPayload(
143                "page geometry head_dim must match shape".into(),
144            ));
145        }
146        if self.raw_vector_bytes != self.head_dim.saturating_mul(4) {
147            return Err(FibQuantError::CorruptPayload(
148                "raw_vector_bytes must equal head_dim * sizeof(f32)".into(),
149            ));
150        }
151        if self.encoded_block_bytes == 0 {
152            return Err(FibQuantError::CorruptPayload(
153                "encoded_block_bytes must be nonzero".into(),
154            ));
155        }
156        Ok(())
157    }
158
159    /// Number of token pages in a shape.
160    pub fn page_count(&self, shape: &KvTensorShapeV1) -> Result<u32> {
161        self.validate_for_shape(shape)?;
162        Ok(shape.tokens.div_ceil(self.tokens_per_page))
163    }
164
165    /// Stable digest.
166    pub fn digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
167        self.validate_for_shape(shape)?;
168        json_digest(KV_PAGE_GEOMETRY_SCHEMA, self)
169    }
170}