use serde::{Deserialize, Serialize};
use crate::{FibQuantError, Result};
use super::{
profile::{KvAxisPolicyV1, KvProtectedPolicyV1},
shape::{KvRole, KvTensorShapeV1},
};
/// Compression strategy selected by a [`KvCompressionPolicyV1`].
///
/// Variants are (de)serialized under snake_case names because of the
/// `#[serde(rename_all = "snake_case")]` attribute. The per-variant notes
/// describe what `decide_kv_compression` in this file returns for each
/// strategy once its protected-region and calibration checks have passed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvCompressionStrategyV1 {
/// No compression: the decision is `KeepRaw` with the `Raw` axis policy.
Raw,
/// Compress with the `PerToken` axis policy for both key and value tensors.
FibQuantPerToken,
/// Compress with the `PerChannel` axis policy for both key and value tensors.
FibQuantPerChannel,
/// Compress with a role-dependent axis: `PerChannel` for keys, `PerToken`
/// for values.
RoleAwareKiviStyleBaseline,
/// Compress with the same role-dependent axis mapping as
/// `RoleAwareKiviStyleBaseline`, but the decision is tagged with the
/// `ExperimentalRoleAware` reason instead of a per-role reason.
ExperimentalFibQuantRoleAware,
}
/// Policy consumed by `decide_kv_compression` to decide how one
/// (layer, head, token) position of a KV tensor is handled.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvCompressionPolicyV1 {
/// Strategy applied when the position is not protected and calibration
/// requirements are satisfied.
pub strategy: KvCompressionStrategyV1,
/// Regions that are always kept raw; consulted before the calibration
/// check and before `strategy`.
pub protected_policy: KvProtectedPolicyV1,
/// When `true`, a caller that reports no calibration gets a
/// `CalibrationMissing` decision instead of the strategy's decision.
pub require_calibration: bool,
/// Selects the action taken when required calibration is missing:
/// `true` yields `KeepRaw`, `false` yields `NeedCalibration`.
pub allow_raw_fallback: bool,
}
impl KvCompressionPolicyV1 {
    /// Builds a policy that uses the `Raw` strategy.
    ///
    /// The protected policy is whatever `KvProtectedPolicyV1::none()`
    /// returns, calibration is not required, and raw fallback is allowed.
    pub fn raw() -> Self {
        let protected_policy = KvProtectedPolicyV1::none();
        Self {
            strategy: KvCompressionStrategyV1::Raw,
            protected_policy,
            require_calibration: false,
            allow_raw_fallback: true,
        }
    }

    /// Builds a policy that uses the `RoleAwareKiviStyleBaseline` strategy.
    ///
    /// The protected policy sets `first_tokens_raw` to 0 and
    /// `last_tokens_raw` to 1, with no raw layers or heads. Calibration is
    /// not required and raw fallback is allowed.
    // NOTE(review): `last_tokens_raw: 1` presumably keeps the most recent
    // token uncompressed; confirm against `KvProtectedPolicyV1`.
    pub fn role_aware_baseline() -> Self {
        let protected_policy = KvProtectedPolicyV1 {
            first_tokens_raw: 0,
            last_tokens_raw: 1,
            raw_layers: vec![],
            raw_heads: vec![],
        };
        Self {
            strategy: KvCompressionStrategyV1::RoleAwareKiviStyleBaseline,
            protected_policy,
            require_calibration: false,
            allow_raw_fallback: true,
        }
    }
}
/// Action carried by a [`KvCompressionDecisionV1`].
///
/// Variants are (de)serialized under snake_case names because of the
/// `#[serde(rename_all = "snake_case")]` attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvDecisionActionV1 {
/// Leave the data uncompressed. Returned for protected positions, for the
/// `Raw` strategy, and when calibration is missing but raw fallback is
/// allowed.
KeepRaw,
/// Compress using the axis policy carried by the decision.
Compress,
/// Calibration is required but missing, and raw fallback is not allowed.
NeedCalibration,
/// Not produced by `decide_kv_compression` in this file.
// NOTE(review): presumably set by other code paths; confirm where.
Quarantine,
}
/// Reason attached to a [`KvCompressionDecisionV1`] explaining its action.
///
/// Variants are (de)serialized under snake_case names because of the
/// `#[serde(rename_all = "snake_case")]` attribute.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvDecisionReasonV1 {
/// The policy's strategy is `Raw`.
RawStrategy,
/// The position is covered by the policy's protected region.
ProtectedRegion,
/// Not produced by `decide_kv_compression` in this file.
// NOTE(review): presumably set by other code paths; confirm where.
UnsupportedShape,
/// The policy requires calibration and the caller reported none.
CalibrationMissing,
/// The tensor role is `Key` (produced by `role_reason`).
KeyRoleAxis,
/// The tensor role is `Value` (produced by `role_reason`).
ValueRoleAxis,
/// The `ExperimentalFibQuantRoleAware` strategy was applied.
ExperimentalRoleAware,
}
/// Outcome of `decide_kv_compression` for one (layer, head, token) position.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvCompressionDecisionV1 {
/// What to do with the data at this position.
pub action: KvDecisionActionV1,
/// Axis policy to apply; `Raw` whenever the action is not `Compress`
/// in the decisions built by `decide_kv_compression`.
pub axis_policy: KvAxisPolicyV1,
/// Why the action was chosen. `decide_kv_compression` always fills in
/// exactly one reason.
pub reasons: Vec<KvDecisionReasonV1>,
}
pub fn decide_kv_compression(
policy: &KvCompressionPolicyV1,
shape: &KvTensorShapeV1,
layer: u32,
head: u32,
token: u32,
has_calibration: bool,
) -> Result<KvCompressionDecisionV1> {
shape.validate()?;
policy.protected_policy.validate_for_shape(shape)?;
if layer >= shape.layers || head >= shape.kv_heads || token >= shape.tokens {
return Err(FibQuantError::CorruptPayload(
"kv policy index outside shape".into(),
));
}
if policy
.protected_policy
.is_protected(shape, layer, head, token)
{
return Ok(KvCompressionDecisionV1 {
action: KvDecisionActionV1::KeepRaw,
axis_policy: KvAxisPolicyV1::Raw,
reasons: vec![KvDecisionReasonV1::ProtectedRegion],
});
}
if policy.require_calibration && !has_calibration {
return Ok(KvCompressionDecisionV1 {
action: if policy.allow_raw_fallback {
KvDecisionActionV1::KeepRaw
} else {
KvDecisionActionV1::NeedCalibration
},
axis_policy: KvAxisPolicyV1::Raw,
reasons: vec![KvDecisionReasonV1::CalibrationMissing],
});
}
match policy.strategy {
KvCompressionStrategyV1::Raw => Ok(KvCompressionDecisionV1 {
action: KvDecisionActionV1::KeepRaw,
axis_policy: KvAxisPolicyV1::Raw,
reasons: vec![KvDecisionReasonV1::RawStrategy],
}),
KvCompressionStrategyV1::FibQuantPerToken => Ok(KvCompressionDecisionV1 {
action: KvDecisionActionV1::Compress,
axis_policy: KvAxisPolicyV1::PerToken,
reasons: role_reason(shape.role),
}),
KvCompressionStrategyV1::FibQuantPerChannel => Ok(KvCompressionDecisionV1 {
action: KvDecisionActionV1::Compress,
axis_policy: KvAxisPolicyV1::PerChannel,
reasons: role_reason(shape.role),
}),
KvCompressionStrategyV1::RoleAwareKiviStyleBaseline => {
let axis_policy = match shape.role {
KvRole::Key => KvAxisPolicyV1::PerChannel,
KvRole::Value => KvAxisPolicyV1::PerToken,
};
Ok(KvCompressionDecisionV1 {
action: KvDecisionActionV1::Compress,
axis_policy,
reasons: role_reason(shape.role),
})
}
KvCompressionStrategyV1::ExperimentalFibQuantRoleAware => {
let axis_policy = match shape.role {
KvRole::Key => KvAxisPolicyV1::PerChannel,
KvRole::Value => KvAxisPolicyV1::PerToken,
};
Ok(KvCompressionDecisionV1 {
action: KvDecisionActionV1::Compress,
axis_policy,
reasons: vec![KvDecisionReasonV1::ExperimentalRoleAware],
})
}
}
}
/// Returns a one-element reason list naming the tensor role:
/// `KeyRoleAxis` for keys, `ValueRoleAxis` for values.
fn role_reason(role: KvRole) -> Vec<KvDecisionReasonV1> {
    let reason = match role {
        KvRole::Key => KvDecisionReasonV1::KeyRoleAxis,
        KvRole::Value => KvDecisionReasonV1::ValueRoleAxis,
    };
    vec![reason]
}