use serde::{Deserialize, Serialize};
use crate::{
digest::json_digest, profile::FibQuantProfileV1, rotation::StoredRotation, FibQuantError,
Result,
};
use super::{
layout::KvPageGeometryV1,
shape::{KvRole, KvTensorShapeV1},
};
/// Schema identifier for KV compression profiles.
///
/// `KvCompressionProfileV1::from_parts` stamps this string into `schema_version`,
/// `KvCompressionProfileV1::validate_for_shape` rejects any profile whose
/// `schema_version` differs from it, and `KvCompressionProfileV1::digest` passes it
/// to `json_digest` together with the profile.
pub const KV_PROFILE_SCHEMA: &str = "fib_quant_kv_compression_profile_v1";
/// Axis policy recorded in the `axis_policy` field of `KvCompressionProfileV1`.
///
/// Variants are (de)serialized under their snake_case names. The enum is
/// `#[non_exhaustive]`, so `match` expressions outside this crate need a wildcard arm.
///
/// NOTE(review): this file only stores the policy; what each variant means for the
/// codec is decided elsewhere — confirm against the consumer before relying on the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvAxisPolicyV1 {
/// Serialized as `raw`.
Raw,
/// Serialized as `per_token`.
PerToken,
/// Serialized as `per_channel`.
PerChannel,
/// Serialized as `role_aware_kivi_style`.
RoleAwareKiviStyle,
}
/// Fallback mode referenced by `KvFallbackPolicyV1::mode` and
/// `KvQualityBudgetV1::fallback_on_violation`.
///
/// Variants are (de)serialized under their snake_case names. The enum is
/// `#[non_exhaustive]`, so `match` expressions outside this crate need a wildcard arm.
///
/// NOTE(review): the behavior attached to each mode is implemented outside this file;
/// the names suggest "keep the raw data" versus "refuse to proceed" — confirm with the codec.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvFallbackModeV1 {
/// Serialized as `keep_raw`. This is the mode chosen by `KvFallbackPolicyV1::keep_raw`
/// and `KvQualityBudgetV1::unavailable`.
KeepRaw,
/// Serialized as `fail_closed`.
FailClosed,
}
/// Describes which tokens, layers and heads `is_protected` reports as protected.
///
/// NOTE(review): the `_raw` naming suggests protected elements are kept uncompressed;
/// that decision is made by callers outside this file — confirm there.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvProtectedPolicyV1 {
/// Size of the leading window: tokens with index `< first_tokens_raw` are protected.
/// Validation requires this not to exceed the shape's token count.
pub first_tokens_raw: u32,
/// Size of the trailing window: a token is protected when
/// `token + last_tokens_raw >= shape.tokens` (saturating addition).
/// Validation requires this not to exceed the shape's token count.
pub last_tokens_raw: u32,
/// Layer indices that are protected in full. Validation requires each to be `< shape.layers`.
pub raw_layers: Vec<u32>,
/// Head indices that are protected in full. Validation requires each to be `< shape.kv_heads`.
pub raw_heads: Vec<u32>,
}
impl KvProtectedPolicyV1 {
    /// Returns the empty policy: both token windows are zero and no layers or heads are listed.
    pub fn none() -> Self {
        let (raw_layers, raw_heads) = (Vec::new(), Vec::new());
        Self {
            raw_layers,
            raw_heads,
            first_tokens_raw: 0,
            last_tokens_raw: 0,
        }
    }

    /// Reports whether the element at (`layer`, `head`, `token`) is covered by this policy.
    ///
    /// An element is protected when any of the following holds:
    /// - its token is in the leading window (`token < first_tokens_raw`);
    /// - its token is in the trailing window (`token + last_tokens_raw >= shape.tokens`,
    ///   where the addition saturates instead of wrapping);
    /// - its layer is listed in `raw_layers`;
    /// - its head is listed in `raw_heads`.
    pub fn is_protected(&self, shape: &KvTensorShapeV1, layer: u32, head: u32, token: u32) -> bool {
        let in_leading_window = token < self.first_tokens_raw;
        let in_trailing_window = token.saturating_add(self.last_tokens_raw) >= shape.tokens;
        if in_leading_window || in_trailing_window {
            return true;
        }
        self.raw_layers.iter().any(|&listed| listed == layer)
            || self.raw_heads.iter().any(|&listed| listed == head)
    }

    /// Checks that every entry of the policy fits inside `shape`.
    ///
    /// # Errors
    /// Returns `FibQuantError::CorruptPayload` when either token window is larger than
    /// `shape.tokens`, when a listed layer is not below `shape.layers`, or when a listed
    /// head is not below `shape.kv_heads`. The checks run in that order and the first
    /// failure is reported.
    pub(crate) fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
        // Comparing the larger window is equivalent to comparing each window separately.
        let widest_window = self.first_tokens_raw.max(self.last_tokens_raw);
        let problem: Option<&'static str> = if widest_window > shape.tokens {
            Some("protected token count exceeds shape tokens")
        } else if !self.raw_layers.iter().all(|&layer| layer < shape.layers) {
            Some("protected layer outside shape")
        } else if !self.raw_heads.iter().all(|&head| head < shape.kv_heads) {
            Some("protected head outside shape")
        } else {
            None
        };
        match problem {
            Some(message) => Err(FibQuantError::CorruptPayload(message.into())),
            None => Ok(()),
        }
    }
}
/// Fallback settings stored in the `fallback_policy` field of `KvCompressionProfileV1`.
///
/// NOTE(review): nothing in this file validates the two fields against each other
/// (e.g. `KeepRaw` with `raw_fallback_available == false`); confirm whether a consumer does.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvFallbackPolicyV1 {
/// The selected fallback mode.
pub mode: KvFallbackModeV1,
/// Flag recording whether a raw fallback is available; presumably consulted by the
/// codec when `mode` is `KeepRaw` — verify against the caller.
pub raw_fallback_available: bool,
}
impl KvFallbackPolicyV1 {
    /// Returns the policy with mode `KvFallbackModeV1::KeepRaw` and the raw fallback
    /// marked as available.
    pub fn keep_raw() -> Self {
        let mode = KvFallbackModeV1::KeepRaw;
        Self {
            raw_fallback_available: true,
            mode,
        }
    }
}
/// Optional upper limits on quality metrics, plus the fallback mode to use on violation.
///
/// Each limit is `None` when unset; when set, `validate` requires it to be finite and
/// nonnegative. The type derives `PartialEq` but not `Eq` because the limits are `f64`.
///
/// NOTE(review): the metrics themselves are computed and compared outside this file;
/// the field names suggest logit MSE, attention total variation, top-k disagreement and
/// value-aggregation MSE — confirm definitions and units with the evaluator.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvQualityBudgetV1 {
/// Optional limit named `max_logit_mse` in validation errors.
pub max_logit_mse: Option<f64>,
/// Optional limit named `max_attention_tv` in validation errors.
pub max_attention_tv: Option<f64>,
/// Optional limit named `max_topk_disagreement` in validation errors.
pub max_topk_disagreement: Option<f64>,
/// Optional limit named `max_value_aggregation_mse` in validation errors.
pub max_value_aggregation_mse: Option<f64>,
/// Fallback mode associated with a budget violation; not inspected by `validate`.
pub fallback_on_violation: KvFallbackModeV1,
}
impl KvQualityBudgetV1 {
    /// Returns a budget with every limit unset and `KvFallbackModeV1::KeepRaw` as the
    /// violation fallback.
    pub fn unavailable() -> Self {
        let unset: Option<f64> = None;
        Self {
            fallback_on_violation: KvFallbackModeV1::KeepRaw,
            max_logit_mse: unset,
            max_attention_tv: unset,
            max_topk_disagreement: unset,
            max_value_aggregation_mse: unset,
        }
    }

    /// Returns `true` when at least one of the four limits is set.
    pub fn has_any_metric(&self) -> bool {
        [
            self.max_logit_mse,
            self.max_attention_tv,
            self.max_topk_disagreement,
            self.max_value_aggregation_mse,
        ]
        .iter()
        .any(Option::is_some)
    }

    /// Checks that every limit that is set holds a usable number.
    ///
    /// # Errors
    /// Returns `FibQuantError::CorruptPayload` naming the first limit (in declaration
    /// order) that is NaN, infinite, or negative. Unset limits are skipped.
    pub(crate) fn validate(&self) -> Result<()> {
        let limits = [
            ("max_logit_mse", self.max_logit_mse),
            ("max_attention_tv", self.max_attention_tv),
            ("max_topk_disagreement", self.max_topk_disagreement),
            ("max_value_aggregation_mse", self.max_value_aggregation_mse),
        ];
        for (name, limit) in limits {
            match limit {
                // A NaN fails `is_finite`, so the negated conjunction rejects exactly
                // the non-finite and the negative values.
                Some(bound) if !(bound.is_finite() && bound >= 0.0) => {
                    return Err(FibQuantError::CorruptPayload(format!(
                        "{name} must be finite and nonnegative"
                    )));
                }
                _ => {}
            }
        }
        Ok(())
    }
}
/// A KV compression profile: a FibQuant profile bound to one tensor shape, together with
/// the digests and policies that `validate_for_shape` checks.
///
/// NOTE(review): `digest` serializes this struct through `json_digest`; reordering or
/// renaming fields presumably changes the resulting digest — confirm before editing the layout.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvCompressionProfileV1 {
/// Schema tag; validation requires it to equal `KV_PROFILE_SCHEMA`.
pub schema_version: String,
/// Caller-supplied identifier; not validated in this file.
pub profile_id: String,
/// Digest of the shape this profile was built for; validation requires it to equal
/// `shape.digest()`.
pub shape_digest: String,
/// The embedded FibQuant profile. Validation runs its own `validate` and requires
/// `ambient_dim` to equal the shape's `head_dim`.
pub fib_profile: FibQuantProfileV1,
/// Digest of `fib_profile`; validation requires it to equal `fib_profile.digest()`.
pub fib_profile_digest: String,
/// Caller-supplied codebook digest; validation only requires it to be nonempty.
pub codebook_digest: String,
/// Digest of the rotation built from `fib_profile.ambient_dim` and
/// `fib_profile.rotation_seed`; validation recomputes and compares it.
pub rotation_digest: String,
/// Role this profile applies to; validation requires it to equal `shape.role`.
pub role_policy: KvRole,
/// Axis policy; stored but not checked by validation in this file.
pub axis_policy: KvAxisPolicyV1,
/// Page geometry; validated against the shape through its own `validate_for_shape`.
pub page_geometry: KvPageGeometryV1,
/// Protected-region policy; validated against the shape.
pub protected_policy: KvProtectedPolicyV1,
/// Fallback policy; stored but not checked by validation in this file.
pub fallback_policy: KvFallbackPolicyV1,
/// Quality budget; validated for finite, nonnegative limits.
pub quality_budget: KvQualityBudgetV1,
/// Calibration digest; `from_parts` sets the placeholder `"missing:calibration"`.
/// Not checked by validation in this file.
pub calibration_digest: String,
}
impl KvCompressionProfileV1 {
    /// Builds a profile for `shape` from a FibQuant profile, a codebook digest and the
    /// layout choices, deriving the shape, FibQuant-profile and rotation digests.
    ///
    /// The remaining fields get defaults: no protected regions
    /// (`KvProtectedPolicyV1::none`), keep-raw fallback (`KvFallbackPolicyV1::keep_raw`),
    /// no quality limits (`KvQualityBudgetV1::unavailable`) and the placeholder
    /// calibration digest `"missing:calibration"`. `role_policy` is copied from `shape.role`.
    ///
    /// # Errors
    /// - Propagates failures from `shape.validate_block_dim`, `fib_profile.validate`,
    ///   `StoredRotation::new`, and the digest computations.
    /// - Returns `FibQuantError::CorruptPayload` when `fib_profile.ambient_dim` differs
    ///   from `shape.head_dim`.
    /// - Returns whatever `validate_for_shape` reports for the assembled profile,
    ///   e.g. an empty `codebook_digest` or a page geometry that does not fit `shape`.
    pub fn from_parts(
        profile_id: impl Into<String>,
        shape: &KvTensorShapeV1,
        fib_profile: FibQuantProfileV1,
        codebook_digest: impl Into<String>,
        axis_policy: KvAxisPolicyV1,
        page_geometry: KvPageGeometryV1,
    ) -> Result<Self> {
        // These checks are repeated by `validate_for_shape` below; running them first
        // reports a bad FibQuant profile before any rotation is constructed.
        shape.validate_block_dim(fib_profile.block_dim)?;
        fib_profile.validate()?;
        if fib_profile.ambient_dim != shape.head_dim {
            return Err(FibQuantError::CorruptPayload(
                "fib profile ambient_dim must equal kv head_dim for CPU reference codec".into(),
            ));
        }
        let rotation_digest =
            StoredRotation::new(fib_profile.ambient_dim as usize, fib_profile.rotation_seed)?
                .digest()?;
        let profile = Self {
            schema_version: KV_PROFILE_SCHEMA.into(),
            profile_id: profile_id.into(),
            shape_digest: shape.digest()?,
            // The digest must be taken before `fib_profile` is moved into the struct.
            fib_profile_digest: fib_profile.digest()?,
            role_policy: shape.role,
            fib_profile,
            codebook_digest: codebook_digest.into(),
            rotation_digest,
            axis_policy,
            page_geometry,
            protected_policy: KvProtectedPolicyV1::none(),
            fallback_policy: KvFallbackPolicyV1::keep_raw(),
            quality_budget: KvQualityBudgetV1::unavailable(),
            calibration_digest: "missing:calibration".into(),
        };
        profile.validate_for_shape(shape)?;
        Ok(profile)
    }

    /// Verifies that this profile is internally consistent and matches `shape`.
    ///
    /// Checks run in this order and the first failure is returned:
    /// schema version, block dimension, shape digest, role, FibQuant profile validity,
    /// ambient dimension, FibQuant profile digest, rotation digest, nonempty codebook
    /// digest, page geometry, protected policy, quality budget.
    ///
    /// # Errors
    /// - `FibQuantError::CorruptPayload` for a wrong schema version, a role or ambient
    ///   dimension that does not match `shape`, or an empty `codebook_digest`.
    /// - `FibQuantError::ProfileDigestMismatch` when the stored shape digest or FibQuant
    ///   profile digest differs from the recomputed one.
    /// - `FibQuantError::RotationDigestMismatch` when the stored rotation digest differs
    ///   from the one recomputed from `ambient_dim` and `rotation_seed`.
    /// - Any error propagated from the nested validators and digest computations.
    pub fn validate_for_shape(&self, shape: &KvTensorShapeV1) -> Result<()> {
        if self.schema_version != KV_PROFILE_SCHEMA {
            return Err(FibQuantError::CorruptPayload(format!(
                "kv profile schema_version {}, expected {KV_PROFILE_SCHEMA}",
                self.schema_version
            )));
        }
        shape.validate_block_dim(self.fib_profile.block_dim)?;
        // Compute the shape digest once and reuse it in the error, instead of
        // recomputing it on the mismatch path.
        let expected_shape = shape.digest()?;
        if self.shape_digest != expected_shape {
            return Err(FibQuantError::ProfileDigestMismatch {
                expected: expected_shape,
                actual: self.shape_digest.clone(),
            });
        }
        if self.role_policy != shape.role {
            return Err(FibQuantError::CorruptPayload(
                "kv profile role does not match shape role".into(),
            ));
        }
        self.fib_profile.validate()?;
        if self.fib_profile.ambient_dim != shape.head_dim {
            return Err(FibQuantError::CorruptPayload(
                "fib profile ambient_dim must equal kv head_dim".into(),
            ));
        }
        let expected_fib = self.fib_profile.digest()?;
        if self.fib_profile_digest != expected_fib {
            return Err(FibQuantError::ProfileDigestMismatch {
                expected: expected_fib,
                actual: self.fib_profile_digest.clone(),
            });
        }
        let expected_rotation = StoredRotation::new(
            self.fib_profile.ambient_dim as usize,
            self.fib_profile.rotation_seed,
        )?
        .digest()?;
        if self.rotation_digest != expected_rotation {
            return Err(FibQuantError::RotationDigestMismatch {
                expected: expected_rotation,
                actual: self.rotation_digest.clone(),
            });
        }
        if self.codebook_digest.is_empty() {
            return Err(FibQuantError::CorruptPayload(
                "kv profile codebook_digest must be nonempty".into(),
            ));
        }
        self.page_geometry.validate_for_shape(shape)?;
        self.protected_policy.validate_for_shape(shape)?;
        self.quality_budget.validate()?;
        Ok(())
    }

    /// Returns the digest of this profile, computed by `json_digest` under
    /// `KV_PROFILE_SCHEMA`.
    ///
    /// The profile is validated against `shape` first, so a digest is only produced for
    /// a profile that passes `validate_for_shape`.
    ///
    /// # Errors
    /// Returns any error from `validate_for_shape` or from `json_digest`.
    pub fn digest(&self, shape: &KvTensorShapeV1) -> Result<String> {
        self.validate_for_shape(shape)?;
        json_digest(KV_PROFILE_SCHEMA, self)
    }
}