1use serde::{Deserialize, Serialize};
2
3use crate::{
4 digest::json_digest, profile::FibQuantProfileV1, rotation::StoredRotation, FibQuantError,
5 Result,
6};
7
8use super::{
9 layout::KvPageGeometryV1,
10 shape::{KvRole, KvTensorShapeV1},
11};
12
/// Schema identifier for [`KvCompressionProfileV1`].
///
/// `KvCompressionProfileV1::from_parts` stores it in `schema_version`, and
/// `KvCompressionProfileV1::validate_for_shape` rejects any other value.
/// `KvCompressionProfileV1::digest` also passes it as the first argument to
/// `json_digest`. Changing this string therefore changes every profile digest.
pub const KV_PROFILE_SCHEMA: &str = "fib_quant_kv_compression_profile_v1";
14
/// Axis policy recorded in a [`KvCompressionProfileV1`].
///
/// Serialized as a snake_case string (`"raw"`, `"per_token"`, `"per_channel"`,
/// `"role_aware_kivi_style"`). Because the profile digest is computed over the
/// serialized profile, renaming or removing a variant changes digests.
/// Marked `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvAxisPolicyV1 {
    /// Presumably: no quantization axis, data left uncompressed. TODO confirm.
    Raw,
    /// Presumably: quantization statistics grouped per token. TODO confirm.
    PerToken,
    /// Presumably: quantization statistics grouped per channel. TODO confirm.
    PerChannel,
    /// Presumably the KIVI-style split (keys per-channel, values per-token),
    /// chosen from the tensor's `KvRole`. NOTE(review): confirm against the codec.
    RoleAwareKiviStyle,
}
29
/// What to do when compression is not acceptable.
///
/// Used as [`KvFallbackPolicyV1::mode`] and as
/// [`KvQualityBudgetV1::fallback_on_violation`]. Serialized as a snake_case
/// string (`"keep_raw"`, `"fail_closed"`). Marked `#[non_exhaustive]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvFallbackModeV1 {
    /// Presumably: fall back to the uncompressed data. This is the default
    /// chosen by `KvFallbackPolicyV1::keep_raw` and
    /// `KvQualityBudgetV1::unavailable`.
    KeepRaw,
    /// Presumably: refuse to proceed (return an error) instead of falling
    /// back. NOTE(review): confirm in the codec that consumes this policy.
    FailClosed,
}
40
/// Selects KV elements that a profile marks as protected.
///
/// Presumably protected elements are kept uncompressed, as the `*_raw` field
/// names suggest. TODO confirm in the codec. Field declaration order is the
/// serde output order; the enclosing profile's digest is computed over that
/// serialization, so do not reorder fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvProtectedPolicyV1 {
    /// Leading-window length: `is_protected` is `true` for every
    /// `token < first_tokens_raw`. Must not exceed `shape.tokens`.
    pub first_tokens_raw: u32,
    /// Trailing-window length: `is_protected` is `true` when
    /// `token + last_tokens_raw >= shape.tokens` (saturating add).
    /// Must not exceed `shape.tokens`.
    pub last_tokens_raw: u32,
    /// Layer indices that are always protected; each must be `< shape.layers`.
    pub raw_layers: Vec<u32>,
    /// KV-head indices that are always protected; each must be `< shape.kv_heads`.
    pub raw_heads: Vec<u32>,
}
53
54impl KvProtectedPolicyV1 {
55 pub fn none() -> Self {
57 Self {
58 first_tokens_raw: 0,
59 last_tokens_raw: 0,
60 raw_layers: Vec::new(),
61 raw_heads: Vec::new(),
62 }
63 }
64
65 pub fn is_protected(&self, shape: &KvTensorShapeV1, layer: u32, head: u32, token: u32) -> bool {
67 token < self.first_tokens_raw
68 || token.saturating_add(self.last_tokens_raw) >= shape.tokens
69 || self.raw_layers.contains(&layer)
70 || self.raw_heads.contains(&head)
71 }
72
73 pub(crate) fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
74 if self.first_tokens_raw > shape.tokens || self.last_tokens_raw > shape.tokens {
75 return Err(FibQuantError::CorruptPayload(
76 "protected token count exceeds shape tokens".into(),
77 ));
78 }
79 if self.raw_layers.iter().any(|layer| *layer >= shape.layers) {
80 return Err(FibQuantError::CorruptPayload(
81 "protected layer outside shape".into(),
82 ));
83 }
84 if self.raw_heads.iter().any(|head| *head >= shape.kv_heads) {
85 return Err(FibQuantError::CorruptPayload(
86 "protected head outside shape".into(),
87 ));
88 }
89 Ok(())
90 }
91}
92
/// Fallback settings recorded in a [`KvCompressionProfileV1`].
///
/// Not validated by `KvCompressionProfileV1::validate_for_shape`.
/// NOTE(review): nothing here checks that `mode` and
/// `raw_fallback_available` are mutually consistent — confirm whether the
/// codec does. Field order is the serde output order and feeds the profile
/// digest.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvFallbackPolicyV1 {
    /// Fallback behaviour; `keep_raw()` sets [`KvFallbackModeV1::KeepRaw`].
    pub mode: KvFallbackModeV1,
    /// Whether uncompressed data is available for fallback; `keep_raw()`
    /// sets `true`.
    pub raw_fallback_available: bool,
}
101
102impl KvFallbackPolicyV1 {
103 pub fn keep_raw() -> Self {
105 Self {
106 mode: KvFallbackModeV1::KeepRaw,
107 raw_fallback_available: true,
108 }
109 }
110}
111
/// Optional quality thresholds attached to a [`KvCompressionProfileV1`].
///
/// Each `max_*` field is `None` when unset. `has_any_metric` reports whether
/// at least one is set. `validate` requires every set value to be finite and
/// nonnegative. Only `PartialEq` (not `Eq`) is derived because the fields are
/// `f64`. Field order is the serde output order and feeds the profile digest.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvQualityBudgetV1 {
    /// Presumably the maximum mean-squared error on logits. TODO confirm.
    pub max_logit_mse: Option<f64>,
    /// Presumably the maximum total-variation distance between attention
    /// distributions. TODO confirm.
    pub max_attention_tv: Option<f64>,
    /// Presumably the maximum top-k disagreement rate. TODO confirm.
    pub max_topk_disagreement: Option<f64>,
    /// Presumably the maximum MSE of attention-weighted value aggregation.
    /// TODO confirm.
    pub max_value_aggregation_mse: Option<f64>,
    /// Fallback to apply when a threshold is violated; `unavailable()` sets
    /// [`KvFallbackModeV1::KeepRaw`].
    pub fallback_on_violation: KvFallbackModeV1,
}
126
127impl KvQualityBudgetV1 {
128 pub fn unavailable() -> Self {
130 Self {
131 max_logit_mse: None,
132 max_attention_tv: None,
133 max_topk_disagreement: None,
134 max_value_aggregation_mse: None,
135 fallback_on_violation: KvFallbackModeV1::KeepRaw,
136 }
137 }
138
139 pub fn has_any_metric(&self) -> bool {
141 self.max_logit_mse.is_some()
142 || self.max_attention_tv.is_some()
143 || self.max_topk_disagreement.is_some()
144 || self.max_value_aggregation_mse.is_some()
145 }
146
147 pub(crate) fn validate(&self) -> Result<()> {
148 for (name, value) in [
149 ("max_logit_mse", self.max_logit_mse),
150 ("max_attention_tv", self.max_attention_tv),
151 ("max_topk_disagreement", self.max_topk_disagreement),
152 ("max_value_aggregation_mse", self.max_value_aggregation_mse),
153 ] {
154 if let Some(value) = value {
155 if !value.is_finite() || value < 0.0 {
156 return Err(FibQuantError::CorruptPayload(format!(
157 "{name} must be finite and nonnegative"
158 )));
159 }
160 }
161 }
162 Ok(())
163 }
164}
165
/// Complete KV-cache compression profile bound to one `KvTensorShapeV1`.
///
/// `validate_for_shape` checks that the recorded digests, role, and dimensions
/// match the given shape. `digest` hashes the whole profile with `json_digest`.
/// serde serializes fields in declaration order. Unless `json_digest`
/// canonicalizes key order (TODO confirm), reordering fields here changes
/// every profile digest.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvCompressionProfileV1 {
    /// Must equal [`KV_PROFILE_SCHEMA`].
    pub schema_version: String,
    /// Caller-supplied identifier; not validated in this file.
    pub profile_id: String,
    /// Must equal `shape.digest()` for the shape the profile is validated against.
    pub shape_digest: String,
    /// Quantizer profile; `block_dim` must satisfy `shape.validate_block_dim`
    /// and `ambient_dim` must equal `shape.head_dim`.
    pub fib_profile: FibQuantProfileV1,
    /// Must equal `fib_profile.digest()`.
    pub fib_profile_digest: String,
    /// Must be nonempty; its content is not otherwise checked here.
    pub codebook_digest: String,
    /// Must equal the digest of the `StoredRotation` built from
    /// `fib_profile.ambient_dim` and `fib_profile.rotation_seed`.
    pub rotation_digest: String,
    /// Must equal `shape.role`.
    pub role_policy: KvRole,
    /// Axis policy; not validated in this file.
    pub axis_policy: KvAxisPolicyV1,
    /// Paging layout; checked with `page_geometry.validate_for_shape`.
    pub page_geometry: KvPageGeometryV1,
    /// Protected elements; checked with `protected_policy.validate_for_shape`.
    /// `from_parts` sets `KvProtectedPolicyV1::none()`.
    pub protected_policy: KvProtectedPolicyV1,
    /// Fallback settings; not validated in this file.
    /// `from_parts` sets `KvFallbackPolicyV1::keep_raw()`.
    pub fallback_policy: KvFallbackPolicyV1,
    /// Quality thresholds; checked with `quality_budget.validate()`.
    /// `from_parts` sets `KvQualityBudgetV1::unavailable()`.
    pub quality_budget: KvQualityBudgetV1,
    /// Calibration digest; `from_parts` sets the placeholder
    /// `"missing:calibration"`. Not validated in this file.
    pub calibration_digest: String,
}
198
199impl KvCompressionProfileV1 {
200 pub fn from_parts(
202 profile_id: impl Into<String>,
203 shape: &KvTensorShapeV1,
204 fib_profile: FibQuantProfileV1,
205 codebook_digest: impl Into<String>,
206 axis_policy: KvAxisPolicyV1,
207 page_geometry: KvPageGeometryV1,
208 ) -> Result<Self> {
209 shape.validate_block_dim(fib_profile.block_dim)?;
210 fib_profile.validate()?;
211 if fib_profile.ambient_dim != shape.head_dim {
212 return Err(FibQuantError::CorruptPayload(
213 "fib profile ambient_dim must equal kv head_dim for CPU reference codec".into(),
214 ));
215 }
216 let rotation_digest =
217 StoredRotation::new(fib_profile.ambient_dim as usize, fib_profile.rotation_seed)?
218 .digest()?;
219 let profile = Self {
220 schema_version: KV_PROFILE_SCHEMA.into(),
221 profile_id: profile_id.into(),
222 shape_digest: shape.digest()?,
223 fib_profile_digest: fib_profile.digest()?,
224 role_policy: shape.role,
225 fib_profile,
226 codebook_digest: codebook_digest.into(),
227 rotation_digest,
228 axis_policy,
229 page_geometry,
230 protected_policy: KvProtectedPolicyV1::none(),
231 fallback_policy: KvFallbackPolicyV1::keep_raw(),
232 quality_budget: KvQualityBudgetV1::unavailable(),
233 calibration_digest: "missing:calibration".into(),
234 };
235 profile.validate_for_shape(shape)?;
236 Ok(profile)
237 }
238
239 pub fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
241 if self.schema_version != KV_PROFILE_SCHEMA {
242 return Err(FibQuantError::CorruptPayload(format!(
243 "kv profile schema_version {}, expected {KV_PROFILE_SCHEMA}",
244 self.schema_version
245 )));
246 }
247 shape.validate_block_dim(self.fib_profile.block_dim)?;
248 if self.shape_digest != shape.digest()? {
249 return Err(FibQuantError::ProfileDigestMismatch {
250 expected: shape.digest()?,
251 actual: self.shape_digest.clone(),
252 });
253 }
254 if self.role_policy != shape.role {
255 return Err(FibQuantError::CorruptPayload(
256 "kv profile role does not match shape role".into(),
257 ));
258 }
259 self.fib_profile.validate()?;
260 if self.fib_profile.ambient_dim != shape.head_dim {
261 return Err(FibQuantError::CorruptPayload(
262 "fib profile ambient_dim must equal kv head_dim".into(),
263 ));
264 }
265 let expected_fib = self.fib_profile.digest()?;
266 if self.fib_profile_digest != expected_fib {
267 return Err(FibQuantError::ProfileDigestMismatch {
268 expected: expected_fib,
269 actual: self.fib_profile_digest.clone(),
270 });
271 }
272 let expected_rotation = StoredRotation::new(
273 self.fib_profile.ambient_dim as usize,
274 self.fib_profile.rotation_seed,
275 )?
276 .digest()?;
277 if self.rotation_digest != expected_rotation {
278 return Err(FibQuantError::RotationDigestMismatch {
279 expected: expected_rotation,
280 actual: self.rotation_digest.clone(),
281 });
282 }
283 if self.codebook_digest.is_empty() {
284 return Err(FibQuantError::CorruptPayload(
285 "kv profile codebook_digest must be nonempty".into(),
286 ));
287 }
288 self.page_geometry.validate_for_shape(shape)?;
289 self.protected_policy.validate_for_shape(shape)?;
290 self.quality_budget.validate()?;
291 Ok(())
292 }
293
294 pub fn digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
296 self.validate_for_shape(shape)?;
297 json_digest(KV_PROFILE_SCHEMA, self)
298 }
299}