1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, FibQuantError, Result};
4
5use super::{
6 block::KvEncodedBlockV1,
7 layout::KvPageGeometryV1,
8 shape::{KvTensorShapeV1, KV_SHAPE_SCHEMA},
9};
10
/// Schema identifier for [`KvEncodedPageV1`].
///
/// `KvEncodedPageV1::new` writes it into `schema_version`, and
/// `KvEncodedPageV1::validate` rejects pages carrying any other value. It is
/// also passed as the first argument to `json_digest` when the page digest is
/// computed.
pub const KV_PAGE_SCHEMA: &str = "fib_quant_kv_encoded_page_v1";
12
/// A page of encoded blocks covering the token span
/// `[token_start, token_start + token_count)`, sealed by `page_digest`.
///
/// Build pages with `KvEncodedPageV1::new`, which derives the bookkeeping
/// fields and the digest. Check deserialized pages with
/// `KvEncodedPageV1::validate`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvEncodedPageV1 {
    /// Page schema identifier; `new` sets it to `KV_PAGE_SCHEMA` and
    /// `validate` rejects any other value.
    pub schema_version: String,
    /// Page identifier. It is covered by the page digest, but `validate` does
    /// not check it against the token span or the geometry.
    pub page_id: u32,
    /// Index of the first token in this page; `validate` requires it to be
    /// below `shape.tokens`.
    pub token_start: u32,
    /// Number of tokens in this page. `validate` requires it to be non-zero,
    /// to keep the span within `shape.tokens`, and to not exceed
    /// `page_geometry.tokens_per_page`.
    pub token_count: u32,
    /// Caller-supplied digest (presumably of the source tensor — confirm).
    /// It is included in the page digest but not otherwise checked here.
    pub source_tensor_digest: String,
    /// Caller-supplied digest (presumably of the quantization profile —
    /// confirm). It is included in the page digest but not otherwise checked
    /// here.
    pub profile_digest: String,
    /// `KvTensorShapeV1::digest` of the shape this page was built for.
    pub shape_digest: String,
    /// Shape schema identifier; `new` sets it to `KV_SHAPE_SCHEMA`.
    pub shape_schema_version: String,
    /// Page layout, checked against the shape via `validate_for_shape`.
    pub page_geometry: KvPageGeometryV1,
    /// Blocks in this page. `validate` requires each `block_id` to equal the
    /// block's position and each `token` to fall inside the page span.
    pub encoded_blocks: Vec<KvEncodedBlockV1>,
    /// Number of `encoded_blocks` with `raw_fallback` set.
    pub raw_fallback_blocks: u32,
    /// Digest over every other field, as produced by `compute_digest`.
    pub page_digest: String,
}
41
42impl KvEncodedPageV1 {
43 #[allow(clippy::too_many_arguments)]
45 pub fn new(
46 page_id: u32,
47 token_start: u32,
48 token_count: u32,
49 source_tensor_digest: String,
50 profile_digest: String,
51 shape: &KvTensorShapeV1,
52 page_geometry: KvPageGeometryV1,
53 encoded_blocks: Vec<KvEncodedBlockV1>,
54 ) -> Result<Self> {
55 let raw_fallback_blocks = encoded_blocks
56 .iter()
57 .filter(|block| block.raw_fallback)
58 .count() as u32;
59 let mut page = Self {
60 schema_version: KV_PAGE_SCHEMA.into(),
61 page_id,
62 token_start,
63 token_count,
64 source_tensor_digest,
65 profile_digest,
66 shape_digest: shape.digest()?,
67 shape_schema_version: KV_SHAPE_SCHEMA.into(),
68 page_geometry,
69 encoded_blocks,
70 raw_fallback_blocks,
71 page_digest: String::new(),
72 };
73 page.page_digest = page.compute_digest(shape)?;
74 Ok(page)
75 }
76
77 pub fn validate(&self, shape: &KvTensorShapeV1) -> Result<()> {
79 if self.schema_version != KV_PAGE_SCHEMA {
80 return Err(FibQuantError::CorruptPayload(format!(
81 "kv page schema_version {}, expected {KV_PAGE_SCHEMA}",
82 self.schema_version
83 )));
84 }
85 shape.validate()?;
86 if self.shape_schema_version != KV_SHAPE_SCHEMA || self.shape_digest != shape.digest()? {
87 return Err(FibQuantError::CorruptPayload(
88 "kv page shape digest mismatch".into(),
89 ));
90 }
91 self.page_geometry.validate_for_shape(shape)?;
92 if self.token_count == 0 || self.token_start >= shape.tokens {
93 return Err(FibQuantError::CorruptPayload(
94 "invalid kv page token span".into(),
95 ));
96 }
97 if self.token_start + self.token_count > shape.tokens {
98 return Err(FibQuantError::CorruptPayload(
99 "kv page token span exceeds shape tokens".into(),
100 ));
101 }
102 if self.token_count > self.page_geometry.tokens_per_page {
103 return Err(FibQuantError::CorruptPayload(
104 "kv page token_count exceeds geometry".into(),
105 ));
106 }
107 let expected_raw = self
108 .encoded_blocks
109 .iter()
110 .filter(|block| block.raw_fallback)
111 .count() as u32;
112 if self.raw_fallback_blocks != expected_raw {
113 return Err(FibQuantError::CorruptPayload(
114 "kv page raw fallback count mismatch".into(),
115 ));
116 }
117 for (idx, block) in self.encoded_blocks.iter().enumerate() {
118 block.validate(shape.head_dim)?;
119 if block.block_id as usize != idx {
120 return Err(FibQuantError::CorruptPayload(
121 "kv page block ids must be contiguous".into(),
122 ));
123 }
124 if block.token < self.token_start || block.token >= self.token_start + self.token_count
125 {
126 return Err(FibQuantError::CorruptPayload(
127 "kv page block token outside page span".into(),
128 ));
129 }
130 }
131 let expected_digest = self.compute_digest(shape)?;
132 if self.page_digest != expected_digest {
133 return Err(FibQuantError::CorruptPayload(
134 "kv page digest mismatch".into(),
135 ));
136 }
137 Ok(())
138 }
139
140 pub fn compute_digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
142 self.page_geometry.validate_for_shape(shape)?;
143 #[derive(Serialize)]
144 struct DigestView<'a> {
145 schema_version: &'a str,
146 page_id: u32,
147 token_start: u32,
148 token_count: u32,
149 source_tensor_digest: &'a str,
150 profile_digest: &'a str,
151 shape_digest: &'a str,
152 shape_schema_version: &'a str,
153 page_geometry: &'a KvPageGeometryV1,
154 encoded_blocks: &'a [KvEncodedBlockV1],
155 raw_fallback_blocks: u32,
156 }
157 json_digest(
158 KV_PAGE_SCHEMA,
159 &DigestView {
160 schema_version: &self.schema_version,
161 page_id: self.page_id,
162 token_start: self.token_start,
163 token_count: self.token_count,
164 source_tensor_digest: &self.source_tensor_digest,
165 profile_digest: &self.profile_digest,
166 shape_digest: &self.shape_digest,
167 shape_schema_version: &self.shape_schema_version,
168 page_geometry: &self.page_geometry,
169 encoded_blocks: &self.encoded_blocks,
170 raw_fallback_blocks: self.raw_fallback_blocks,
171 },
172 )
173 }
174}