// fib_quant/kv/page.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, FibQuantError, Result};
4
5use super::{
6    block::KvEncodedBlockV1,
7    layout::KvPageGeometryV1,
8    shape::{KvTensorShapeV1, KV_SHAPE_SCHEMA},
9};
10
/// Schema marker for encoded KV pages.
///
/// Written into [`KvEncodedPageV1::schema_version`] by the constructor, checked by
/// `validate`, and passed to `json_digest` when the page digest is computed.
pub const KV_PAGE_SCHEMA: &str = "fib_quant_kv_encoded_page_v1";
12
13/// Fixed-size random-access encoded KV page.
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct KvEncodedPageV1 {
16    /// Stable schema marker.
17    pub schema_version: String,
18    /// Page id in token-page order.
19    pub page_id: u32,
20    /// First token covered by this page.
21    pub token_start: u32,
22    /// Number of tokens represented.
23    pub token_count: u32,
24    /// Source tensor digest.
25    pub source_tensor_digest: String,
26    /// KV compression profile digest.
27    pub profile_digest: String,
28    /// Shape digest.
29    pub shape_digest: String,
30    /// Shape schema marker.
31    pub shape_schema_version: String,
32    /// Page geometry.
33    pub page_geometry: KvPageGeometryV1,
34    /// Encoded blocks.
35    pub encoded_blocks: Vec<KvEncodedBlockV1>,
36    /// Count of raw fallback blocks.
37    pub raw_fallback_blocks: u32,
38    /// Stable digest/checksum for this page.
39    pub page_digest: String,
40}
41
42impl KvEncodedPageV1 {
43    /// Build a page and compute its digest.
44    #[allow(clippy::too_many_arguments)]
45    pub fn new(
46        page_id: u32,
47        token_start: u32,
48        token_count: u32,
49        source_tensor_digest: String,
50        profile_digest: String,
51        shape: &KvTensorShapeV1,
52        page_geometry: KvPageGeometryV1,
53        encoded_blocks: Vec<KvEncodedBlockV1>,
54    ) -> Result<Self> {
55        let raw_fallback_blocks = encoded_blocks
56            .iter()
57            .filter(|block| block.raw_fallback)
58            .count() as u32;
59        let mut page = Self {
60            schema_version: KV_PAGE_SCHEMA.into(),
61            page_id,
62            token_start,
63            token_count,
64            source_tensor_digest,
65            profile_digest,
66            shape_digest: shape.digest()?,
67            shape_schema_version: KV_SHAPE_SCHEMA.into(),
68            page_geometry,
69            encoded_blocks,
70            raw_fallback_blocks,
71            page_digest: String::new(),
72        };
73        page.page_digest = page.compute_digest(shape)?;
74        Ok(page)
75    }
76
77    /// Validate page fields and digest.
78    pub fn validate(&self, shape: &KvTensorShapeV1) -> Result<()> {
79        if self.schema_version != KV_PAGE_SCHEMA {
80            return Err(FibQuantError::CorruptPayload(format!(
81                "kv page schema_version {}, expected {KV_PAGE_SCHEMA}",
82                self.schema_version
83            )));
84        }
85        shape.validate()?;
86        if self.shape_schema_version != KV_SHAPE_SCHEMA || self.shape_digest != shape.digest()? {
87            return Err(FibQuantError::CorruptPayload(
88                "kv page shape digest mismatch".into(),
89            ));
90        }
91        self.page_geometry.validate_for_shape(shape)?;
92        if self.token_count == 0 || self.token_start >= shape.tokens {
93            return Err(FibQuantError::CorruptPayload(
94                "invalid kv page token span".into(),
95            ));
96        }
97        if self.token_start + self.token_count > shape.tokens {
98            return Err(FibQuantError::CorruptPayload(
99                "kv page token span exceeds shape tokens".into(),
100            ));
101        }
102        if self.token_count > self.page_geometry.tokens_per_page {
103            return Err(FibQuantError::CorruptPayload(
104                "kv page token_count exceeds geometry".into(),
105            ));
106        }
107        let expected_raw = self
108            .encoded_blocks
109            .iter()
110            .filter(|block| block.raw_fallback)
111            .count() as u32;
112        if self.raw_fallback_blocks != expected_raw {
113            return Err(FibQuantError::CorruptPayload(
114                "kv page raw fallback count mismatch".into(),
115            ));
116        }
117        for (idx, block) in self.encoded_blocks.iter().enumerate() {
118            block.validate(shape.head_dim)?;
119            if block.block_id as usize != idx {
120                return Err(FibQuantError::CorruptPayload(
121                    "kv page block ids must be contiguous".into(),
122                ));
123            }
124            if block.token < self.token_start || block.token >= self.token_start + self.token_count
125            {
126                return Err(FibQuantError::CorruptPayload(
127                    "kv page block token outside page span".into(),
128                ));
129            }
130        }
131        let expected_digest = self.compute_digest(shape)?;
132        if self.page_digest != expected_digest {
133            return Err(FibQuantError::CorruptPayload(
134                "kv page digest mismatch".into(),
135            ));
136        }
137        Ok(())
138    }
139
140    /// Compute page digest excluding the digest field itself.
141    pub fn compute_digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
142        self.page_geometry.validate_for_shape(shape)?;
143        #[derive(Serialize)]
144        struct DigestView<'a> {
145            schema_version: &'a str,
146            page_id: u32,
147            token_start: u32,
148            token_count: u32,
149            source_tensor_digest: &'a str,
150            profile_digest: &'a str,
151            shape_digest: &'a str,
152            shape_schema_version: &'a str,
153            page_geometry: &'a KvPageGeometryV1,
154            encoded_blocks: &'a [KvEncodedBlockV1],
155            raw_fallback_blocks: u32,
156        }
157        json_digest(
158            KV_PAGE_SCHEMA,
159            &DigestView {
160                schema_version: &self.schema_version,
161                page_id: self.page_id,
162                token_start: self.token_start,
163                token_count: self.token_count,
164                source_tensor_digest: &self.source_tensor_digest,
165                profile_digest: &self.profile_digest,
166                shape_digest: &self.shape_digest,
167                shape_schema_version: &self.shape_schema_version,
168                page_geometry: &self.page_geometry,
169                encoded_blocks: &self.encoded_blocks,
170                raw_fallback_blocks: self.raw_fallback_blocks,
171            },
172        )
173    }
174}