// fib_quant/kv/profile.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{
4    digest::json_digest, profile::FibQuantProfileV1, rotation::StoredRotation, FibQuantError,
5    Result,
6};
7
8use super::{
9    layout::KvPageGeometryV1,
10    shape::{KvRole, KvTensorShapeV1},
11};
12
/// Schema marker for [`KvCompressionProfileV1`].
///
/// Stored in `schema_version` and required to match exactly by `validate_for_shape`. It is
/// also passed to `json_digest` together with the profile when computing the profile digest.
/// Changing it invalidates every previously stored profile and digest.
pub const KV_PROFILE_SCHEMA: &str = "fib_quant_kv_compression_profile_v1";
14
/// Compression axis policy.
///
/// Declares which vectors of a KV tensor are grouped together for quantization.
/// Serialized in `snake_case` (e.g. `"per_token"`, `"role_aware_kivi_style"`). The value is
/// embedded in [`KvCompressionProfileV1`], so renaming a variant changes the serialized
/// profile and therefore its digest. The enum is `#[non_exhaustive]`, so `match`es outside
/// this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvAxisPolicyV1 {
    /// Keep vectors raw.
    Raw,
    /// Compress each token/head vector independently.
    PerToken,
    /// Compress channel vectors across tokens. Planned for backend-specific paths.
    PerChannel,
    /// Key per-channel, value per-token baseline.
    RoleAwareKiviStyle,
}
29
/// Fallback mode for unsupported or rejected regions.
///
/// Serialized in `snake_case` (`"keep_raw"`, `"fail_closed"`). Used both as
/// [`KvFallbackPolicyV1::mode`] and as [`KvQualityBudgetV1::fallback_on_violation`].
/// The enum is `#[non_exhaustive]`, so `match`es outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvFallbackModeV1 {
    /// Store raw f32 blocks.
    KeepRaw,
    /// Fail the operation.
    FailClosed,
}
40
/// Protected raw regions.
///
/// Vectors inside any of these regions are kept raw (see `is_protected`). Token counts are
/// measured against the shape's `tokens`. Layer and head entries are indices that
/// `validate_for_shape` checks against the shape's `layers` and `kv_heads`.
///
/// NOTE(review): this struct is serialized inside `KvCompressionProfileV1`, whose digest is
/// `json_digest` over the whole profile. Renaming or reordering fields likely changes that
/// digest; confirm whether `json_digest` canonicalizes key order.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvProtectedPolicyV1 {
    /// First N tokens are raw.
    pub first_tokens_raw: u32,
    /// Last N tokens are raw.
    pub last_tokens_raw: u32,
    /// Layers kept raw (indices; each must be below the shape's `layers`).
    pub raw_layers: Vec<u32>,
    /// KV heads kept raw (indices; each must be below the shape's `kv_heads`).
    pub raw_heads: Vec<u32>,
}
53
54impl KvProtectedPolicyV1 {
55    /// No protected regions.
56    pub fn none() -> Self {
57        Self {
58            first_tokens_raw: 0,
59            last_tokens_raw: 0,
60            raw_layers: Vec::new(),
61            raw_heads: Vec::new(),
62        }
63    }
64
65    /// Whether a vector falls in a protected raw region.
66    pub fn is_protected(&self, shape: &KvTensorShapeV1, layer: u32, head: u32, token: u32) -> bool {
67        token < self.first_tokens_raw
68            || token.saturating_add(self.last_tokens_raw) >= shape.tokens
69            || self.raw_layers.contains(&layer)
70            || self.raw_heads.contains(&head)
71    }
72
73    pub(crate) fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
74        if self.first_tokens_raw > shape.tokens || self.last_tokens_raw > shape.tokens {
75            return Err(FibQuantError::CorruptPayload(
76                "protected token count exceeds shape tokens".into(),
77            ));
78        }
79        if self.raw_layers.iter().any(|layer| *layer >= shape.layers) {
80            return Err(FibQuantError::CorruptPayload(
81                "protected layer outside shape".into(),
82            ));
83        }
84        if self.raw_heads.iter().any(|head| *head >= shape.kv_heads) {
85            return Err(FibQuantError::CorruptPayload(
86                "protected head outside shape".into(),
87            ));
88        }
89        Ok(())
90    }
91}
92
/// Fallback declaration.
///
/// Pairs the fallback mode with a flag advertising whether raw fallback blocks are
/// permitted.
///
/// NOTE(review): nothing in this file checks that `mode` and `raw_fallback_available` agree
/// (e.g. `KeepRaw` together with `false`). Confirm whether that combination is meant to be
/// legal.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvFallbackPolicyV1 {
    /// Fallback mode.
    pub mode: KvFallbackModeV1,
    /// Whether raw fallback blocks are always allowed.
    pub raw_fallback_available: bool,
}
101
102impl KvFallbackPolicyV1 {
103    /// Conservative raw fallback.
104    pub fn keep_raw() -> Self {
105        Self {
106            mode: KvFallbackModeV1::KeepRaw,
107            raw_fallback_available: true,
108        }
109    }
110}
111
/// Quality budget used by policy and receipts.
///
/// Each metric field is an optional upper bound, and `None` means no limit is declared.
/// Present limits must be finite and nonnegative (checked by `validate`). The struct derives
/// only `PartialEq` (not `Eq`) because the limits are `f64`.
///
/// NOTE(review): total variation distance and disagreement rates cannot exceed 1, but
/// `validate` does not enforce an upper bound on them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvQualityBudgetV1 {
    /// Maximum key logit MSE.
    pub max_logit_mse: Option<f64>,
    /// Maximum attention total variation distance.
    pub max_attention_tv: Option<f64>,
    /// Maximum top-k disagreement rate.
    pub max_topk_disagreement: Option<f64>,
    /// Maximum value aggregation MSE.
    pub max_value_aggregation_mse: Option<f64>,
    /// Fallback mode on violation.
    pub fallback_on_violation: KvFallbackModeV1,
}
126
127impl KvQualityBudgetV1 {
128    /// Unknown budget; policies should prefer calibration or raw fallback.
129    pub fn unavailable() -> Self {
130        Self {
131            max_logit_mse: None,
132            max_attention_tv: None,
133            max_topk_disagreement: None,
134            max_value_aggregation_mse: None,
135            fallback_on_violation: KvFallbackModeV1::KeepRaw,
136        }
137    }
138
139    /// Whether any quantitative budget is present.
140    pub fn has_any_metric(&self) -> bool {
141        self.max_logit_mse.is_some()
142            || self.max_attention_tv.is_some()
143            || self.max_topk_disagreement.is_some()
144            || self.max_value_aggregation_mse.is_some()
145    }
146
147    pub(crate) fn validate(&self) -> Result<()> {
148        for (name, value) in [
149            ("max_logit_mse", self.max_logit_mse),
150            ("max_attention_tv", self.max_attention_tv),
151            ("max_topk_disagreement", self.max_topk_disagreement),
152            ("max_value_aggregation_mse", self.max_value_aggregation_mse),
153        ] {
154            if let Some(value) = value {
155                if !value.is_finite() || value < 0.0 {
156                    return Err(FibQuantError::CorruptPayload(format!(
157                        "{name} must be finite and nonnegative"
158                    )));
159                }
160            }
161        }
162        Ok(())
163    }
164}
165
/// KV compression profile binding shape, FibQuant artifacts, policy, and budgets.
///
/// The profile records digests of its inputs (shape, FibQuant profile, codebook, rotation),
/// so that `validate_for_shape` can detect a profile being reused against a different shape
/// or quantizer. `digest` hashes the whole struct after validating it.
///
/// NOTE(review): because `digest` serializes `self`, renaming or reordering fields likely
/// changes existing digests. Confirm whether `json_digest` canonicalizes key order before
/// touching the layout.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvCompressionProfileV1 {
    /// Stable schema marker; must equal `KV_PROFILE_SCHEMA`.
    pub schema_version: String,
    /// Operator-chosen profile identifier.
    pub profile_id: String,
    /// Digest of the logical KV shape; must equal `shape.digest()` for the validating shape.
    pub shape_digest: String,
    /// Embedded FibQuant vector profile.
    pub fib_profile: FibQuantProfileV1,
    /// Digest of `fib_profile`.
    pub fib_profile_digest: String,
    /// Digest of the matching codebook; only checked to be nonempty here.
    pub codebook_digest: String,
    /// Digest of the matching rotation, rebuilt from `ambient_dim` and `rotation_seed`.
    pub rotation_digest: String,
    /// Role this profile targets; must equal the shape's role.
    pub role_policy: KvRole,
    /// Axis/policy declaration.
    pub axis_policy: KvAxisPolicyV1,
    /// Fixed-size page geometry.
    pub page_geometry: KvPageGeometryV1,
    /// Protected raw regions.
    pub protected_policy: KvProtectedPolicyV1,
    /// Raw/fail fallback policy.
    pub fallback_policy: KvFallbackPolicyV1,
    /// Quality budget.
    pub quality_budget: KvQualityBudgetV1,
    /// Calibration artifact digest or a stable missing marker.
    pub calibration_digest: String,
}
198
199impl KvCompressionProfileV1 {
200    /// Build a profile from an already built quantizer identity.
201    pub fn from_parts(
202        profile_id: impl Into<String>,
203        shape: &KvTensorShapeV1,
204        fib_profile: FibQuantProfileV1,
205        codebook_digest: impl Into<String>,
206        axis_policy: KvAxisPolicyV1,
207        page_geometry: KvPageGeometryV1,
208    ) -> Result<Self> {
209        shape.validate_block_dim(fib_profile.block_dim)?;
210        fib_profile.validate()?;
211        if fib_profile.ambient_dim != shape.head_dim {
212            return Err(FibQuantError::CorruptPayload(
213                "fib profile ambient_dim must equal kv head_dim for CPU reference codec".into(),
214            ));
215        }
216        let rotation_digest =
217            StoredRotation::new(fib_profile.ambient_dim as usize, fib_profile.rotation_seed)?
218                .digest()?;
219        let profile = Self {
220            schema_version: KV_PROFILE_SCHEMA.into(),
221            profile_id: profile_id.into(),
222            shape_digest: shape.digest()?,
223            fib_profile_digest: fib_profile.digest()?,
224            role_policy: shape.role,
225            fib_profile,
226            codebook_digest: codebook_digest.into(),
227            rotation_digest,
228            axis_policy,
229            page_geometry,
230            protected_policy: KvProtectedPolicyV1::none(),
231            fallback_policy: KvFallbackPolicyV1::keep_raw(),
232            quality_budget: KvQualityBudgetV1::unavailable(),
233            calibration_digest: "missing:calibration".into(),
234        };
235        profile.validate_for_shape(shape)?;
236        Ok(profile)
237    }
238
239    /// Validate profile against the expected shape.
240    pub fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
241        if self.schema_version != KV_PROFILE_SCHEMA {
242            return Err(FibQuantError::CorruptPayload(format!(
243                "kv profile schema_version {}, expected {KV_PROFILE_SCHEMA}",
244                self.schema_version
245            )));
246        }
247        shape.validate_block_dim(self.fib_profile.block_dim)?;
248        if self.shape_digest != shape.digest()? {
249            return Err(FibQuantError::ProfileDigestMismatch {
250                expected: shape.digest()?,
251                actual: self.shape_digest.clone(),
252            });
253        }
254        if self.role_policy != shape.role {
255            return Err(FibQuantError::CorruptPayload(
256                "kv profile role does not match shape role".into(),
257            ));
258        }
259        self.fib_profile.validate()?;
260        if self.fib_profile.ambient_dim != shape.head_dim {
261            return Err(FibQuantError::CorruptPayload(
262                "fib profile ambient_dim must equal kv head_dim".into(),
263            ));
264        }
265        let expected_fib = self.fib_profile.digest()?;
266        if self.fib_profile_digest != expected_fib {
267            return Err(FibQuantError::ProfileDigestMismatch {
268                expected: expected_fib,
269                actual: self.fib_profile_digest.clone(),
270            });
271        }
272        let expected_rotation = StoredRotation::new(
273            self.fib_profile.ambient_dim as usize,
274            self.fib_profile.rotation_seed,
275        )?
276        .digest()?;
277        if self.rotation_digest != expected_rotation {
278            return Err(FibQuantError::RotationDigestMismatch {
279                expected: expected_rotation,
280                actual: self.rotation_digest.clone(),
281            });
282        }
283        if self.codebook_digest.is_empty() {
284            return Err(FibQuantError::CorruptPayload(
285                "kv profile codebook_digest must be nonempty".into(),
286            ));
287        }
288        self.page_geometry.validate_for_shape(shape)?;
289        self.protected_policy.validate_for_shape(shape)?;
290        self.quality_budget.validate()?;
291        Ok(())
292    }
293
294    /// Stable digest for the KV profile.
295    pub fn digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
296        self.validate_for_shape(shape)?;
297        json_digest(KV_PROFILE_SCHEMA, self)
298    }
299}