Skip to main content

fib_quant/kv/
codec.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{FibQuantError, FibQuantizer, Result};
4
5use super::{
6    block::{KvBlockEncodingV1, KvEncodedBlockV1},
7    layout::KvCacheLayoutV1,
8    page::KvEncodedPageV1,
9    profile::{KvAxisPolicyV1, KvCompressionProfileV1, KvFallbackModeV1},
10    receipt::{
11        kv_tensor_digest, now_unix_seconds, KvCompressionReceiptV1, KvDecodeReceiptV1,
12        KvOperationKindV1, KV_RECEIPT_SCHEMA,
13    },
14    shape::KvTensorShapeV1,
15};
16
17/// Encoded tensor artifact with pages and compression receipt.
18#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
19pub struct KvEncodedTensorV1 {
20    /// Logical shape.
21    pub shape: KvTensorShapeV1,
22    /// Physical layout.
23    pub layout: KvCacheLayoutV1,
24    /// Compression profile.
25    pub profile: KvCompressionProfileV1,
26    /// Encoded pages.
27    pub pages: Vec<KvEncodedPageV1>,
28    /// Compression receipt.
29    pub receipt: KvCompressionReceiptV1,
30}
31
32/// Decoded tensor and receipt.
33#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
34pub struct KvDecodedTensorV1 {
35    /// Canonical contiguous f32 values.
36    pub values: Vec<f32>,
37    /// Decode receipt.
38    pub receipt: KvDecodeReceiptV1,
39}
40
41/// Encode a canonical contiguous f32 KV tensor.
42pub fn encode_kv_tensor(
43    shape: KvTensorShapeV1,
44    layout: KvCacheLayoutV1,
45    profile: KvCompressionProfileV1,
46    values: &[f32],
47) -> Result<KvEncodedTensorV1> {
48    shape.validate()?;
49    layout.validate_for_shape(&shape)?;
50    profile.validate_for_shape(&shape)?;
51    if values.len() != shape.element_count()? {
52        return Err(FibQuantError::CorruptPayload(format!(
53            "kv input has {} values, expected {}",
54            values.len(),
55            shape.element_count()?
56        )));
57    }
58    if values.iter().any(|value| !value.is_finite()) {
59        return Err(FibQuantError::CorruptPayload(
60            "kv input contains non-finite value".into(),
61        ));
62    }
63
64    let quantizer = build_quantizer(&profile)?;
65    let source_digest = kv_tensor_digest(values)?;
66    let profile_digest = profile.digest(&shape)?;
67    let mut pages = Vec::new();
68    let mut compressed_blocks = 0u32;
69    let mut raw_fallback_blocks = 0u32;
70    let mut fallback_reasons = Vec::new();
71    let page_count = profile.page_geometry.page_count(&shape)?;
72
73    for page_id in 0..page_count {
74        let token_start = page_id * profile.page_geometry.tokens_per_page;
75        let token_end = (token_start + profile.page_geometry.tokens_per_page).min(shape.tokens);
76        let token_count = token_end - token_start;
77        let mut blocks = Vec::new();
78        for batch in 0..shape.batch {
79            for layer in 0..shape.layers {
80                for head in 0..shape.kv_heads {
81                    for token in token_start..token_end {
82                        let block_id = blocks.len() as u32;
83                        let vector = vector_slice(values, &shape, batch, layer, head, token)?;
84                        let protected = profile
85                            .protected_policy
86                            .is_protected(&shape, layer, head, token);
87                        let block = if protected {
88                            raw_block(
89                                block_id,
90                                batch,
91                                layer,
92                                head,
93                                token,
94                                vector,
95                                profile.page_geometry.encoded_block_bytes,
96                                "protected_region",
97                            )
98                        } else {
99                            encode_vector_block(
100                                &quantizer, &profile, block_id, batch, layer, head, token, vector,
101                            )?
102                        };
103                        if block.raw_fallback {
104                            raw_fallback_blocks += 1;
105                            if !fallback_reasons.contains(&block.reason) {
106                                fallback_reasons.push(block.reason.clone());
107                            }
108                        } else {
109                            compressed_blocks += 1;
110                        }
111                        blocks.push(block);
112                    }
113                }
114            }
115        }
116        pages.push(KvEncodedPageV1::new(
117            page_id,
118            token_start,
119            token_count,
120            source_digest.clone(),
121            profile_digest.clone(),
122            &shape,
123            profile.page_geometry.clone(),
124            blocks,
125        )?);
126    }
127
128    let page_digests = pages.iter().map(|page| page.page_digest.clone()).collect();
129    let receipt = KvCompressionReceiptV1 {
130        schema_version: KV_RECEIPT_SCHEMA.into(),
131        operation_kind: KvOperationKindV1::Compress,
132        source_digest,
133        profile_digest,
134        shape_digest: shape.digest()?,
135        page_digests,
136        codebook_digest: profile.codebook_digest.clone(),
137        rotation_digest: profile.rotation_digest.clone(),
138        encoded_pages: pages.len() as u32,
139        compressed_blocks,
140        raw_fallback_blocks,
141        fallback_reasons,
142        recorded_unix_seconds: now_unix_seconds(),
143    };
144    Ok(KvEncodedTensorV1 {
145        shape,
146        layout,
147        profile,
148        pages,
149        receipt,
150    })
151}
152
153/// Decode encoded pages into canonical contiguous f32 values.
154pub fn decode_kv_pages(encoded: &KvEncodedTensorV1) -> Result<KvDecodedTensorV1> {
155    encoded.shape.validate()?;
156    encoded.layout.validate_for_shape(&encoded.shape)?;
157    encoded.profile.validate_for_shape(&encoded.shape)?;
158    encoded.receipt.validate()?;
159    let profile_digest = encoded.profile.digest(&encoded.shape)?;
160    if encoded.receipt.profile_digest != profile_digest {
161        return Err(FibQuantError::ProfileDigestMismatch {
162            expected: profile_digest,
163            actual: encoded.receipt.profile_digest.clone(),
164        });
165    }
166    let quantizer = build_quantizer(&encoded.profile)?;
167    let mut values = vec![0.0; encoded.shape.element_count()?];
168    let mut page_digests = Vec::with_capacity(encoded.pages.len());
169    let mut raw_fallback_blocks = 0u32;
170    for page in &encoded.pages {
171        page.validate(&encoded.shape)?;
172        if page.profile_digest != encoded.receipt.profile_digest {
173            return Err(FibQuantError::ProfileDigestMismatch {
174                expected: encoded.receipt.profile_digest.clone(),
175                actual: page.profile_digest.clone(),
176            });
177        }
178        page_digests.push(page.page_digest.clone());
179        for block in &page.encoded_blocks {
180            if block.batch >= encoded.shape.batch
181                || block.layer >= encoded.shape.layers
182                || block.kv_head >= encoded.shape.kv_heads
183                || block.token >= encoded.shape.tokens
184            {
185                return Err(FibQuantError::CorruptPayload(
186                    "kv block index outside shape".into(),
187                ));
188            }
189            let decoded = match &block.encoding {
190                KvBlockEncodingV1::RawF32 { values } => {
191                    raw_fallback_blocks += 1;
192                    values.clone()
193                }
194                KvBlockEncodingV1::FibQuant { code } => quantizer.decode(code)?,
195            };
196            if decoded.len() != encoded.shape.head_dim as usize {
197                return Err(FibQuantError::CorruptPayload(
198                    "decoded kv vector head_dim mismatch".into(),
199                ));
200            }
201            let out = vector_slice_mut(
202                &mut values,
203                &encoded.shape,
204                block.batch,
205                block.layer,
206                block.kv_head,
207                block.token,
208            )?;
209            out.copy_from_slice(&decoded);
210        }
211    }
212    let decoded_digest = kv_tensor_digest(&values)?;
213    Ok(KvDecodedTensorV1 {
214        values,
215        receipt: KvDecodeReceiptV1 {
216            schema_version: KV_RECEIPT_SCHEMA.into(),
217            operation_kind: KvOperationKindV1::Decode,
218            decoded_digest,
219            profile_digest: encoded.receipt.profile_digest.clone(),
220            shape_digest: encoded.shape.digest()?,
221            page_digests,
222            codebook_digest: encoded.profile.codebook_digest.clone(),
223            rotation_digest: encoded.profile.rotation_digest.clone(),
224            decoded_pages: encoded.pages.len() as u32,
225            raw_fallback_blocks,
226            recorded_unix_seconds: now_unix_seconds(),
227        },
228    })
229}
230
231fn build_quantizer(profile: &KvCompressionProfileV1) -> Result<FibQuantizer> {
232    let quantizer = FibQuantizer::new(profile.fib_profile.clone())?;
233    if quantizer.codebook().codebook_digest != profile.codebook_digest {
234        return Err(FibQuantError::CodebookDigestMismatch {
235            expected: quantizer.codebook().codebook_digest.clone(),
236            actual: profile.codebook_digest.clone(),
237        });
238    }
239    Ok(quantizer)
240}
241
242#[allow(clippy::too_many_arguments)]
243fn encode_vector_block(
244    quantizer: &FibQuantizer,
245    profile: &KvCompressionProfileV1,
246    block_id: u32,
247    batch: u32,
248    layer: u32,
249    head: u32,
250    token: u32,
251    vector: &[f32],
252) -> Result<KvEncodedBlockV1> {
253    match profile.axis_policy {
254        KvAxisPolicyV1::Raw => Ok(raw_block(
255            block_id,
256            batch,
257            layer,
258            head,
259            token,
260            vector,
261            profile.page_geometry.encoded_block_bytes,
262            "raw_axis_policy",
263        )),
264        KvAxisPolicyV1::PerToken => match quantizer.encode(vector) {
265            Ok(code) => Ok(KvEncodedBlockV1::fib_quant(
266                block_id,
267                batch,
268                layer,
269                head,
270                token,
271                code,
272                profile.page_geometry.encoded_block_bytes,
273                "fib_quant_per_token",
274            )),
275            Err(err) if profile.fallback_policy.mode == KvFallbackModeV1::KeepRaw => Ok(raw_block(
276                block_id,
277                batch,
278                layer,
279                head,
280                token,
281                vector,
282                profile.page_geometry.encoded_block_bytes,
283                format!("encode_fallback:{err}"),
284            )),
285            Err(err) => Err(err),
286        },
287        KvAxisPolicyV1::PerChannel | KvAxisPolicyV1::RoleAwareKiviStyle => {
288            if profile.fallback_policy.mode == KvFallbackModeV1::KeepRaw {
289                Ok(raw_block(
290                    block_id,
291                    batch,
292                    layer,
293                    head,
294                    token,
295                    vector,
296                    profile.page_geometry.encoded_block_bytes,
297                    "unsupported_axis_raw_fallback",
298                ))
299            } else {
300                Err(FibQuantError::DependencyUnsupported(
301                    "CPU reference codec supports per-token FibQuant compression only".into(),
302                ))
303            }
304        }
305    }
306}
307
308#[allow(clippy::too_many_arguments)]
309fn raw_block(
310    block_id: u32,
311    batch: u32,
312    layer: u32,
313    head: u32,
314    token: u32,
315    vector: &[f32],
316    fixed_size_bytes: u32,
317    reason: impl Into<String>,
318) -> KvEncodedBlockV1 {
319    KvEncodedBlockV1::raw(
320        block_id,
321        batch,
322        layer,
323        head,
324        token,
325        vector.to_vec(),
326        fixed_size_bytes,
327        reason,
328    )
329}
330
331fn vector_offset(
332    shape: &KvTensorShapeV1,
333    batch: u32,
334    layer: u32,
335    head: u32,
336    token: u32,
337) -> Result<usize> {
338    if batch >= shape.batch
339        || layer >= shape.layers
340        || head >= shape.kv_heads
341        || token >= shape.tokens
342    {
343        return Err(FibQuantError::CorruptPayload(
344            "kv vector index outside shape".into(),
345        ));
346    }
347    let vectors_before = (((batch as usize * shape.layers as usize + layer as usize)
348        * shape.kv_heads as usize
349        + head as usize)
350        * shape.tokens as usize)
351        + token as usize;
352    vectors_before
353        .checked_mul(shape.head_dim as usize)
354        .ok_or_else(|| FibQuantError::ResourceLimitExceeded("kv vector offset overflow".into()))
355}
356
357fn vector_slice<'a>(
358    values: &'a [f32],
359    shape: &KvTensorShapeV1,
360    batch: u32,
361    layer: u32,
362    head: u32,
363    token: u32,
364) -> Result<&'a [f32]> {
365    let start = vector_offset(shape, batch, layer, head, token)?;
366    let end = start + shape.head_dim as usize;
367    values
368        .get(start..end)
369        .ok_or_else(|| FibQuantError::CorruptPayload("kv vector slice out of bounds".into()))
370}
371
372fn vector_slice_mut<'a>(
373    values: &'a mut [f32],
374    shape: &KvTensorShapeV1,
375    batch: u32,
376    layer: u32,
377    head: u32,
378    token: u32,
379) -> Result<&'a mut [f32]> {
380    let start = vector_offset(shape, batch, layer, head, token)?;
381    let end = start + shape.head_dim as usize;
382    values
383        .get_mut(start..end)
384        .ok_or_else(|| FibQuantError::CorruptPayload("kv vector slice out of bounds".into()))
385}