1use serde::{Deserialize, Serialize};
2
3use crate::{FibQuantError, Result};
4
5use super::{
6 profile::{KvAxisPolicyV1, KvProtectedPolicyV1},
7 shape::{KvRole, KvTensorShapeV1},
8};
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[non_exhaustive]
13#[serde(rename_all = "snake_case")]
14pub enum KvCompressionStrategyV1 {
15 Raw,
17 FibQuantPerToken,
19 FibQuantPerChannel,
21 RoleAwareKiviStyleBaseline,
23 ExperimentalFibQuantRoleAware,
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
29pub struct KvCompressionPolicyV1 {
30 pub strategy: KvCompressionStrategyV1,
32 pub protected_policy: KvProtectedPolicyV1,
34 pub require_calibration: bool,
36 pub allow_raw_fallback: bool,
38}
39
40impl KvCompressionPolicyV1 {
41 pub fn raw() -> Self {
43 Self {
44 strategy: KvCompressionStrategyV1::Raw,
45 protected_policy: KvProtectedPolicyV1::none(),
46 require_calibration: false,
47 allow_raw_fallback: true,
48 }
49 }
50
51 pub fn role_aware_baseline() -> Self {
53 Self {
54 strategy: KvCompressionStrategyV1::RoleAwareKiviStyleBaseline,
55 protected_policy: KvProtectedPolicyV1 {
56 first_tokens_raw: 0,
57 last_tokens_raw: 1,
58 raw_layers: Vec::new(),
59 raw_heads: Vec::new(),
60 },
61 require_calibration: false,
62 allow_raw_fallback: true,
63 }
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
69#[non_exhaustive]
70#[serde(rename_all = "snake_case")]
71pub enum KvDecisionActionV1 {
72 KeepRaw,
74 Compress,
76 NeedCalibration,
78 Quarantine,
80}
81
82#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
84#[non_exhaustive]
85#[serde(rename_all = "snake_case")]
86pub enum KvDecisionReasonV1 {
87 RawStrategy,
89 ProtectedRegion,
91 UnsupportedShape,
93 CalibrationMissing,
95 KeyRoleAxis,
97 ValueRoleAxis,
99 ExperimentalRoleAware,
101}
102
103#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub struct KvCompressionDecisionV1 {
106 pub action: KvDecisionActionV1,
108 pub axis_policy: KvAxisPolicyV1,
110 pub reasons: Vec<KvDecisionReasonV1>,
112}
113
114pub fn decide_kv_compression(
116 policy: &KvCompressionPolicyV1,
117 shape: &KvTensorShapeV1,
118 layer: u32,
119 head: u32,
120 token: u32,
121 has_calibration: bool,
122) -> Result<KvCompressionDecisionV1> {
123 shape.validate()?;
124 policy.protected_policy.validate_for_shape(shape)?;
125 if layer >= shape.layers || head >= shape.kv_heads || token >= shape.tokens {
126 return Err(FibQuantError::CorruptPayload(
127 "kv policy index outside shape".into(),
128 ));
129 }
130 if policy
131 .protected_policy
132 .is_protected(shape, layer, head, token)
133 {
134 return Ok(KvCompressionDecisionV1 {
135 action: KvDecisionActionV1::KeepRaw,
136 axis_policy: KvAxisPolicyV1::Raw,
137 reasons: vec![KvDecisionReasonV1::ProtectedRegion],
138 });
139 }
140 if policy.require_calibration && !has_calibration {
141 return Ok(KvCompressionDecisionV1 {
142 action: if policy.allow_raw_fallback {
143 KvDecisionActionV1::KeepRaw
144 } else {
145 KvDecisionActionV1::NeedCalibration
146 },
147 axis_policy: KvAxisPolicyV1::Raw,
148 reasons: vec![KvDecisionReasonV1::CalibrationMissing],
149 });
150 }
151 match policy.strategy {
152 KvCompressionStrategyV1::Raw => Ok(KvCompressionDecisionV1 {
153 action: KvDecisionActionV1::KeepRaw,
154 axis_policy: KvAxisPolicyV1::Raw,
155 reasons: vec![KvDecisionReasonV1::RawStrategy],
156 }),
157 KvCompressionStrategyV1::FibQuantPerToken => Ok(KvCompressionDecisionV1 {
158 action: KvDecisionActionV1::Compress,
159 axis_policy: KvAxisPolicyV1::PerToken,
160 reasons: role_reason(shape.role),
161 }),
162 KvCompressionStrategyV1::FibQuantPerChannel => Ok(KvCompressionDecisionV1 {
163 action: KvDecisionActionV1::Compress,
164 axis_policy: KvAxisPolicyV1::PerChannel,
165 reasons: role_reason(shape.role),
166 }),
167 KvCompressionStrategyV1::RoleAwareKiviStyleBaseline => {
168 let axis_policy = match shape.role {
169 KvRole::Key => KvAxisPolicyV1::PerChannel,
170 KvRole::Value => KvAxisPolicyV1::PerToken,
171 };
172 Ok(KvCompressionDecisionV1 {
173 action: KvDecisionActionV1::Compress,
174 axis_policy,
175 reasons: role_reason(shape.role),
176 })
177 }
178 KvCompressionStrategyV1::ExperimentalFibQuantRoleAware => {
179 let axis_policy = match shape.role {
180 KvRole::Key => KvAxisPolicyV1::PerChannel,
181 KvRole::Value => KvAxisPolicyV1::PerToken,
182 };
183 Ok(KvCompressionDecisionV1 {
184 action: KvDecisionActionV1::Compress,
185 axis_policy,
186 reasons: vec![KvDecisionReasonV1::ExperimentalRoleAware],
187 })
188 }
189 }
190}
191
192fn role_reason(role: KvRole) -> Vec<KvDecisionReasonV1> {
193 match role {
194 KvRole::Key => vec![KvDecisionReasonV1::KeyRoleAxis],
195 KvRole::Value => vec![KvDecisionReasonV1::ValueRoleAxis],
196 }
197}