Skip to main content

fib_quant/kv/
policy.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{FibQuantError, Result};
4
5use super::{
6    profile::{KvAxisPolicyV1, KvProtectedPolicyV1},
7    shape::{KvRole, KvTensorShapeV1},
8};
9
10/// Named compression strategy.
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[non_exhaustive]
13#[serde(rename_all = "snake_case")]
14pub enum KvCompressionStrategyV1 {
15    /// Raw reference path.
16    Raw,
17    /// FibQuant each token/head vector.
18    FibQuantPerToken,
19    /// FibQuant channel/group vectors. Declared for policy baselines.
20    FibQuantPerChannel,
21    /// KIVI-style role split: keys channel-wise, values token-wise.
22    RoleAwareKiviStyleBaseline,
23    /// Experimental FibQuant role-aware profile.
24    ExperimentalFibQuantRoleAware,
25}
26
27/// Policy configuration.
28#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
29pub struct KvCompressionPolicyV1 {
30    /// Strategy.
31    pub strategy: KvCompressionStrategyV1,
32    /// Protected raw regions.
33    pub protected_policy: KvProtectedPolicyV1,
34    /// Require a calibration digest/budget before compression.
35    pub require_calibration: bool,
36    /// Allow raw fallback decisions.
37    pub allow_raw_fallback: bool,
38}
39
40impl KvCompressionPolicyV1 {
41    /// Conservative raw policy.
42    pub fn raw() -> Self {
43        Self {
44            strategy: KvCompressionStrategyV1::Raw,
45            protected_policy: KvProtectedPolicyV1::none(),
46            require_calibration: false,
47            allow_raw_fallback: true,
48        }
49    }
50
51    /// Role-aware baseline policy.
52    pub fn role_aware_baseline() -> Self {
53        Self {
54            strategy: KvCompressionStrategyV1::RoleAwareKiviStyleBaseline,
55            protected_policy: KvProtectedPolicyV1 {
56                first_tokens_raw: 0,
57                last_tokens_raw: 1,
58                raw_layers: Vec::new(),
59                raw_heads: Vec::new(),
60            },
61            require_calibration: false,
62            allow_raw_fallback: true,
63        }
64    }
65}
66
67/// Decision action.
68#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
69#[non_exhaustive]
70#[serde(rename_all = "snake_case")]
71pub enum KvDecisionActionV1 {
72    /// Store raw.
73    KeepRaw,
74    /// Compress using the selected axis.
75    Compress,
76    /// Calibration is needed before compression.
77    NeedCalibration,
78    /// Quarantine stale or mismatched artifacts.
79    Quarantine,
80}
81
82/// Decision reason.
83#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
84#[non_exhaustive]
85#[serde(rename_all = "snake_case")]
86pub enum KvDecisionReasonV1 {
87    /// Explicit raw strategy.
88    RawStrategy,
89    /// Protected token, layer, or head.
90    ProtectedRegion,
91    /// Unsupported shape/layout for selected strategy.
92    UnsupportedShape,
93    /// Missing calibration or quality budget.
94    CalibrationMissing,
95    /// Key role selects key-oriented axis.
96    KeyRoleAxis,
97    /// Value role selects value-oriented axis.
98    ValueRoleAxis,
99    /// Experimental role-aware FibQuant selection.
100    ExperimentalRoleAware,
101}
102
103/// Policy decision for one vector/block.
104#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
105pub struct KvCompressionDecisionV1 {
106    /// Action to take.
107    pub action: KvDecisionActionV1,
108    /// Selected axis if compressing.
109    pub axis_policy: KvAxisPolicyV1,
110    /// Reasons for the action.
111    pub reasons: Vec<KvDecisionReasonV1>,
112}
113
114/// Decide compression for a single logical vector.
115pub fn decide_kv_compression(
116    policy: &KvCompressionPolicyV1,
117    shape: &KvTensorShapeV1,
118    layer: u32,
119    head: u32,
120    token: u32,
121    has_calibration: bool,
122) -> Result<KvCompressionDecisionV1> {
123    shape.validate()?;
124    policy.protected_policy.validate_for_shape(shape)?;
125    if layer >= shape.layers || head >= shape.kv_heads || token >= shape.tokens {
126        return Err(FibQuantError::CorruptPayload(
127            "kv policy index outside shape".into(),
128        ));
129    }
130    if policy
131        .protected_policy
132        .is_protected(shape, layer, head, token)
133    {
134        return Ok(KvCompressionDecisionV1 {
135            action: KvDecisionActionV1::KeepRaw,
136            axis_policy: KvAxisPolicyV1::Raw,
137            reasons: vec![KvDecisionReasonV1::ProtectedRegion],
138        });
139    }
140    if policy.require_calibration && !has_calibration {
141        return Ok(KvCompressionDecisionV1 {
142            action: if policy.allow_raw_fallback {
143                KvDecisionActionV1::KeepRaw
144            } else {
145                KvDecisionActionV1::NeedCalibration
146            },
147            axis_policy: KvAxisPolicyV1::Raw,
148            reasons: vec![KvDecisionReasonV1::CalibrationMissing],
149        });
150    }
151    match policy.strategy {
152        KvCompressionStrategyV1::Raw => Ok(KvCompressionDecisionV1 {
153            action: KvDecisionActionV1::KeepRaw,
154            axis_policy: KvAxisPolicyV1::Raw,
155            reasons: vec![KvDecisionReasonV1::RawStrategy],
156        }),
157        KvCompressionStrategyV1::FibQuantPerToken => Ok(KvCompressionDecisionV1 {
158            action: KvDecisionActionV1::Compress,
159            axis_policy: KvAxisPolicyV1::PerToken,
160            reasons: role_reason(shape.role),
161        }),
162        KvCompressionStrategyV1::FibQuantPerChannel => Ok(KvCompressionDecisionV1 {
163            action: KvDecisionActionV1::Compress,
164            axis_policy: KvAxisPolicyV1::PerChannel,
165            reasons: role_reason(shape.role),
166        }),
167        KvCompressionStrategyV1::RoleAwareKiviStyleBaseline => {
168            let axis_policy = match shape.role {
169                KvRole::Key => KvAxisPolicyV1::PerChannel,
170                KvRole::Value => KvAxisPolicyV1::PerToken,
171            };
172            Ok(KvCompressionDecisionV1 {
173                action: KvDecisionActionV1::Compress,
174                axis_policy,
175                reasons: role_reason(shape.role),
176            })
177        }
178        KvCompressionStrategyV1::ExperimentalFibQuantRoleAware => {
179            let axis_policy = match shape.role {
180                KvRole::Key => KvAxisPolicyV1::PerChannel,
181                KvRole::Value => KvAxisPolicyV1::PerToken,
182            };
183            Ok(KvCompressionDecisionV1 {
184                action: KvDecisionActionV1::Compress,
185                axis_policy,
186                reasons: vec![KvDecisionReasonV1::ExperimentalRoleAware],
187            })
188        }
189    }
190}
191
192fn role_reason(role: KvRole) -> Vec<KvDecisionReasonV1> {
193    match role {
194        KvRole::Key => vec![KvDecisionReasonV1::KeyRoleAxis],
195        KvRole::Value => vec![KvDecisionReasonV1::ValueRoleAxis],
196    }
197}