1use serde::{Deserialize, Serialize};
2
3use crate::{FibQuantError, FibQuantizer, Result};
4
5use super::{
6 block::{KvBlockEncodingV1, KvEncodedBlockV1},
7 layout::KvCacheLayoutV1,
8 page::KvEncodedPageV1,
9 profile::{KvAxisPolicyV1, KvCompressionProfileV1, KvFallbackModeV1},
10 receipt::{
11 kv_tensor_digest, now_unix_seconds, KvCompressionReceiptV1, KvDecodeReceiptV1,
12 KvOperationKindV1, KV_RECEIPT_SCHEMA,
13 },
14 shape::KvTensorShapeV1,
15};
16
/// A KV-cache tensor compressed by [`encode_kv_tensor`], bundled with everything
/// [`decode_kv_pages`] re-validates and uses to reconstruct the dense values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvEncodedTensorV1 {
    /// Shape of the original dense tensor.
    pub shape: KvTensorShapeV1,
    /// Cache layout; validated against `shape` on both encode and decode.
    pub layout: KvCacheLayoutV1,
    /// Compression profile the quantizer and page geometry were built from.
    pub profile: KvCompressionProfileV1,
    /// Encoded pages; `encode_kv_tensor` emits them in ascending `page_id` order.
    pub pages: Vec<KvEncodedPageV1>,
    /// Receipt recorded at compression time (digests, page/block counts, fallback reasons).
    pub receipt: KvCompressionReceiptV1,
}
31
/// Dense tensor reconstructed by [`decode_kv_pages`], together with its decode receipt.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvDecodedTensorV1 {
    /// Decoded values, `shape.element_count()` long, in the same
    /// batch -> layer -> kv_head -> token -> head_dim order the encoder reads.
    pub values: Vec<f32>,
    /// Receipt describing this decode (decoded digest, page digests, raw block count).
    pub receipt: KvDecodeReceiptV1,
}
40
41pub fn encode_kv_tensor(
43 shape: KvTensorShapeV1,
44 layout: KvCacheLayoutV1,
45 profile: KvCompressionProfileV1,
46 values: &[f32],
47) -> Result<KvEncodedTensorV1> {
48 shape.validate()?;
49 layout.validate_for_shape(&shape)?;
50 profile.validate_for_shape(&shape)?;
51 if values.len() != shape.element_count()? {
52 return Err(FibQuantError::CorruptPayload(format!(
53 "kv input has {} values, expected {}",
54 values.len(),
55 shape.element_count()?
56 )));
57 }
58 if values.iter().any(|value| !value.is_finite()) {
59 return Err(FibQuantError::CorruptPayload(
60 "kv input contains non-finite value".into(),
61 ));
62 }
63
64 let quantizer = build_quantizer(&profile)?;
65 let source_digest = kv_tensor_digest(values)?;
66 let profile_digest = profile.digest(&shape)?;
67 let mut pages = Vec::new();
68 let mut compressed_blocks = 0u32;
69 let mut raw_fallback_blocks = 0u32;
70 let mut fallback_reasons = Vec::new();
71 let page_count = profile.page_geometry.page_count(&shape)?;
72
73 for page_id in 0..page_count {
74 let token_start = page_id * profile.page_geometry.tokens_per_page;
75 let token_end = (token_start + profile.page_geometry.tokens_per_page).min(shape.tokens);
76 let token_count = token_end - token_start;
77 let mut blocks = Vec::new();
78 for batch in 0..shape.batch {
79 for layer in 0..shape.layers {
80 for head in 0..shape.kv_heads {
81 for token in token_start..token_end {
82 let block_id = blocks.len() as u32;
83 let vector = vector_slice(values, &shape, batch, layer, head, token)?;
84 let protected = profile
85 .protected_policy
86 .is_protected(&shape, layer, head, token);
87 let block = if protected {
88 raw_block(
89 block_id,
90 batch,
91 layer,
92 head,
93 token,
94 vector,
95 profile.page_geometry.encoded_block_bytes,
96 "protected_region",
97 )
98 } else {
99 encode_vector_block(
100 &quantizer, &profile, block_id, batch, layer, head, token, vector,
101 )?
102 };
103 if block.raw_fallback {
104 raw_fallback_blocks += 1;
105 if !fallback_reasons.contains(&block.reason) {
106 fallback_reasons.push(block.reason.clone());
107 }
108 } else {
109 compressed_blocks += 1;
110 }
111 blocks.push(block);
112 }
113 }
114 }
115 }
116 pages.push(KvEncodedPageV1::new(
117 page_id,
118 token_start,
119 token_count,
120 source_digest.clone(),
121 profile_digest.clone(),
122 &shape,
123 profile.page_geometry.clone(),
124 blocks,
125 )?);
126 }
127
128 let page_digests = pages.iter().map(|page| page.page_digest.clone()).collect();
129 let receipt = KvCompressionReceiptV1 {
130 schema_version: KV_RECEIPT_SCHEMA.into(),
131 operation_kind: KvOperationKindV1::Compress,
132 source_digest,
133 profile_digest,
134 shape_digest: shape.digest()?,
135 page_digests,
136 codebook_digest: profile.codebook_digest.clone(),
137 rotation_digest: profile.rotation_digest.clone(),
138 encoded_pages: pages.len() as u32,
139 compressed_blocks,
140 raw_fallback_blocks,
141 fallback_reasons,
142 recorded_unix_seconds: now_unix_seconds(),
143 };
144 Ok(KvEncodedTensorV1 {
145 shape,
146 layout,
147 profile,
148 pages,
149 receipt,
150 })
151}
152
153pub fn decode_kv_pages(encoded: &KvEncodedTensorV1) -> Result<KvDecodedTensorV1> {
155 encoded.shape.validate()?;
156 encoded.layout.validate_for_shape(&encoded.shape)?;
157 encoded.profile.validate_for_shape(&encoded.shape)?;
158 encoded.receipt.validate()?;
159 let profile_digest = encoded.profile.digest(&encoded.shape)?;
160 if encoded.receipt.profile_digest != profile_digest {
161 return Err(FibQuantError::ProfileDigestMismatch {
162 expected: profile_digest,
163 actual: encoded.receipt.profile_digest.clone(),
164 });
165 }
166 let quantizer = build_quantizer(&encoded.profile)?;
167 let mut values = vec![0.0; encoded.shape.element_count()?];
168 let mut page_digests = Vec::with_capacity(encoded.pages.len());
169 let mut raw_fallback_blocks = 0u32;
170 for page in &encoded.pages {
171 page.validate(&encoded.shape)?;
172 if page.profile_digest != encoded.receipt.profile_digest {
173 return Err(FibQuantError::ProfileDigestMismatch {
174 expected: encoded.receipt.profile_digest.clone(),
175 actual: page.profile_digest.clone(),
176 });
177 }
178 page_digests.push(page.page_digest.clone());
179 for block in &page.encoded_blocks {
180 if block.batch >= encoded.shape.batch
181 || block.layer >= encoded.shape.layers
182 || block.kv_head >= encoded.shape.kv_heads
183 || block.token >= encoded.shape.tokens
184 {
185 return Err(FibQuantError::CorruptPayload(
186 "kv block index outside shape".into(),
187 ));
188 }
189 let decoded = match &block.encoding {
190 KvBlockEncodingV1::RawF32 { values } => {
191 raw_fallback_blocks += 1;
192 values.clone()
193 }
194 KvBlockEncodingV1::FibQuant { code } => quantizer.decode(code)?,
195 };
196 if decoded.len() != encoded.shape.head_dim as usize {
197 return Err(FibQuantError::CorruptPayload(
198 "decoded kv vector head_dim mismatch".into(),
199 ));
200 }
201 let out = vector_slice_mut(
202 &mut values,
203 &encoded.shape,
204 block.batch,
205 block.layer,
206 block.kv_head,
207 block.token,
208 )?;
209 out.copy_from_slice(&decoded);
210 }
211 }
212 let decoded_digest = kv_tensor_digest(&values)?;
213 Ok(KvDecodedTensorV1 {
214 values,
215 receipt: KvDecodeReceiptV1 {
216 schema_version: KV_RECEIPT_SCHEMA.into(),
217 operation_kind: KvOperationKindV1::Decode,
218 decoded_digest,
219 profile_digest: encoded.receipt.profile_digest.clone(),
220 shape_digest: encoded.shape.digest()?,
221 page_digests,
222 codebook_digest: encoded.profile.codebook_digest.clone(),
223 rotation_digest: encoded.profile.rotation_digest.clone(),
224 decoded_pages: encoded.pages.len() as u32,
225 raw_fallback_blocks,
226 recorded_unix_seconds: now_unix_seconds(),
227 },
228 })
229}
230
231fn build_quantizer(profile: &KvCompressionProfileV1) -> Result<FibQuantizer> {
232 let quantizer = FibQuantizer::new(profile.fib_profile.clone())?;
233 if quantizer.codebook().codebook_digest != profile.codebook_digest {
234 return Err(FibQuantError::CodebookDigestMismatch {
235 expected: quantizer.codebook().codebook_digest.clone(),
236 actual: profile.codebook_digest.clone(),
237 });
238 }
239 Ok(quantizer)
240}
241
242#[allow(clippy::too_many_arguments)]
243fn encode_vector_block(
244 quantizer: &FibQuantizer,
245 profile: &KvCompressionProfileV1,
246 block_id: u32,
247 batch: u32,
248 layer: u32,
249 head: u32,
250 token: u32,
251 vector: &[f32],
252) -> Result<KvEncodedBlockV1> {
253 match profile.axis_policy {
254 KvAxisPolicyV1::Raw => Ok(raw_block(
255 block_id,
256 batch,
257 layer,
258 head,
259 token,
260 vector,
261 profile.page_geometry.encoded_block_bytes,
262 "raw_axis_policy",
263 )),
264 KvAxisPolicyV1::PerToken => match quantizer.encode(vector) {
265 Ok(code) => Ok(KvEncodedBlockV1::fib_quant(
266 block_id,
267 batch,
268 layer,
269 head,
270 token,
271 code,
272 profile.page_geometry.encoded_block_bytes,
273 "fib_quant_per_token",
274 )),
275 Err(err) if profile.fallback_policy.mode == KvFallbackModeV1::KeepRaw => Ok(raw_block(
276 block_id,
277 batch,
278 layer,
279 head,
280 token,
281 vector,
282 profile.page_geometry.encoded_block_bytes,
283 format!("encode_fallback:{err}"),
284 )),
285 Err(err) => Err(err),
286 },
287 KvAxisPolicyV1::PerChannel | KvAxisPolicyV1::RoleAwareKiviStyle => {
288 if profile.fallback_policy.mode == KvFallbackModeV1::KeepRaw {
289 Ok(raw_block(
290 block_id,
291 batch,
292 layer,
293 head,
294 token,
295 vector,
296 profile.page_geometry.encoded_block_bytes,
297 "unsupported_axis_raw_fallback",
298 ))
299 } else {
300 Err(FibQuantError::DependencyUnsupported(
301 "CPU reference codec supports per-token FibQuant compression only".into(),
302 ))
303 }
304 }
305 }
306}
307
308#[allow(clippy::too_many_arguments)]
309fn raw_block(
310 block_id: u32,
311 batch: u32,
312 layer: u32,
313 head: u32,
314 token: u32,
315 vector: &[f32],
316 fixed_size_bytes: u32,
317 reason: impl Into<String>,
318) -> KvEncodedBlockV1 {
319 KvEncodedBlockV1::raw(
320 block_id,
321 batch,
322 layer,
323 head,
324 token,
325 vector.to_vec(),
326 fixed_size_bytes,
327 reason,
328 )
329}
330
331fn vector_offset(
332 shape: &KvTensorShapeV1,
333 batch: u32,
334 layer: u32,
335 head: u32,
336 token: u32,
337) -> Result<usize> {
338 if batch >= shape.batch
339 || layer >= shape.layers
340 || head >= shape.kv_heads
341 || token >= shape.tokens
342 {
343 return Err(FibQuantError::CorruptPayload(
344 "kv vector index outside shape".into(),
345 ));
346 }
347 let vectors_before = (((batch as usize * shape.layers as usize + layer as usize)
348 * shape.kv_heads as usize
349 + head as usize)
350 * shape.tokens as usize)
351 + token as usize;
352 vectors_before
353 .checked_mul(shape.head_dim as usize)
354 .ok_or_else(|| FibQuantError::ResourceLimitExceeded("kv vector offset overflow".into()))
355}
356
357fn vector_slice<'a>(
358 values: &'a [f32],
359 shape: &KvTensorShapeV1,
360 batch: u32,
361 layer: u32,
362 head: u32,
363 token: u32,
364) -> Result<&'a [f32]> {
365 let start = vector_offset(shape, batch, layer, head, token)?;
366 let end = start + shape.head_dim as usize;
367 values
368 .get(start..end)
369 .ok_or_else(|| FibQuantError::CorruptPayload("kv vector slice out of bounds".into()))
370}
371
372fn vector_slice_mut<'a>(
373 values: &'a mut [f32],
374 shape: &KvTensorShapeV1,
375 batch: u32,
376 layer: u32,
377 head: u32,
378 token: u32,
379) -> Result<&'a mut [f32]> {
380 let start = vector_offset(shape, batch, layer, head, token)?;
381 let end = start + shape.head_dim as usize;
382 values
383 .get_mut(start..end)
384 .ok_or_else(|| FibQuantError::CorruptPayload("kv vector slice out of bounds".into()))
385}