Skip to main content

fib_quant/kv/
shape.rs

1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, FibQuantError, Result};
4
/// Schema marker stored in [`KvTensorShapeV1::schema_version`] and used as the
/// schema tag when digesting a shape.
pub const KV_SHAPE_SCHEMA: &str = "fib_quant_kv_tensor_shape_v1";
6
7/// KV tensor role.
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9#[non_exhaustive]
10#[serde(rename_all = "snake_case")]
11pub enum KvRole {
12    /// Attention key cache.
13    Key,
14    /// Attention value cache.
15    Value,
16}
17
18/// Key RoPE state.
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[non_exhaustive]
21#[serde(rename_all = "snake_case")]
22pub enum KvRopeState {
23    /// Key tensor captured before RoPE.
24    PreRope,
25    /// Key tensor captured after RoPE.
26    PostRope,
27    /// Value tensors and non-RoPE tensors.
28    NotApplicable,
29}
30
31/// Attention geometry family.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
33#[non_exhaustive]
34#[serde(rename_all = "snake_case")]
35pub enum KvAttentionKind {
36    /// Multi-head attention.
37    Mha,
38    /// Multi-query attention.
39    Mqa,
40    /// Grouped-query attention.
41    Gqa,
42}
43
44/// Source tensor dtype.
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46#[non_exhaustive]
47#[serde(rename_all = "snake_case")]
48pub enum KvDType {
49    /// IEEE fp16 source.
50    F16,
51    /// bfloat16 source.
52    Bf16,
53    /// f32 source.
54    F32,
55}
56
57/// Logical KV tensor shape.
58#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
59pub struct KvTensorShapeV1 {
60    /// Stable schema marker.
61    pub schema_version: String,
62    /// Key or value role.
63    pub role: KvRole,
64    /// Attention head sharing geometry.
65    pub attention_kind: KvAttentionKind,
66    /// Batch count.
67    pub batch: u32,
68    /// Layer count.
69    pub layers: u32,
70    /// KV head count.
71    pub kv_heads: u32,
72    /// Query head count.
73    pub query_heads: u32,
74    /// Token count.
75    pub tokens: u32,
76    /// Per-head channel dimension.
77    pub head_dim: u32,
78    /// Source dtype.
79    pub dtype: KvDType,
80    /// Key RoPE state or not-applicable for values.
81    pub rope_state: KvRopeState,
82}
83
84impl KvTensorShapeV1 {
85    /// Create a shape with the v1 schema marker.
86    #[allow(clippy::too_many_arguments)]
87    pub fn new(
88        role: KvRole,
89        attention_kind: KvAttentionKind,
90        batch: u32,
91        layers: u32,
92        kv_heads: u32,
93        query_heads: u32,
94        tokens: u32,
95        head_dim: u32,
96        dtype: KvDType,
97        rope_state: KvRopeState,
98    ) -> Self {
99        Self {
100            schema_version: KV_SHAPE_SCHEMA.into(),
101            role,
102            attention_kind,
103            batch,
104            layers,
105            kv_heads,
106            query_heads,
107            tokens,
108            head_dim,
109            dtype,
110            rope_state,
111        }
112    }
113
114    /// Validate shape invariants that are independent of compression profile.
115    pub fn validate(&self) -> Result<()> {
116        if self.schema_version != KV_SHAPE_SCHEMA {
117            return Err(FibQuantError::CorruptPayload(format!(
118                "kv shape schema_version {}, expected {KV_SHAPE_SCHEMA}",
119                self.schema_version
120            )));
121        }
122        for (name, value) in [
123            ("batch", self.batch),
124            ("layers", self.layers),
125            ("kv_heads", self.kv_heads),
126            ("query_heads", self.query_heads),
127            ("tokens", self.tokens),
128            ("head_dim", self.head_dim),
129        ] {
130            if value == 0 {
131                return Err(FibQuantError::CorruptPayload(format!(
132                    "kv shape {name} must be > 0"
133                )));
134            }
135        }
136        match self.attention_kind {
137            KvAttentionKind::Mha if self.query_heads != self.kv_heads => {
138                return Err(FibQuantError::CorruptPayload(
139                    "MHA requires query_heads == kv_heads".into(),
140                ));
141            }
142            KvAttentionKind::Mqa if self.kv_heads != 1 => {
143                return Err(FibQuantError::CorruptPayload(
144                    "MQA requires kv_heads == 1".into(),
145                ));
146            }
147            KvAttentionKind::Gqa | KvAttentionKind::Mqa
148                if self.query_heads % self.kv_heads != 0 =>
149            {
150                return Err(FibQuantError::CorruptPayload(
151                    "query_heads must be divisible by kv_heads".into(),
152                ));
153            }
154            _ => {}
155        }
156        match (self.role, self.rope_state) {
157            (KvRole::Key, KvRopeState::NotApplicable) => {
158                return Err(FibQuantError::CorruptPayload(
159                    "key tensors must declare pre_rope or post_rope".into(),
160                ));
161            }
162            (KvRole::Value, KvRopeState::PreRope | KvRopeState::PostRope) => {
163                return Err(FibQuantError::CorruptPayload(
164                    "value tensors must use not_applicable rope state".into(),
165                ));
166            }
167            _ => {}
168        }
169        let _ = self.element_count()?;
170        Ok(())
171    }
172
173    /// Validate that the head dimension can be directly compressed by a FibQuant block.
174    pub fn validate_block_dim(&self, block_dim: u32) -> Result<()> {
175        self.validate()?;
176        if block_dim == 0 || self.head_dim % block_dim != 0 {
177            return Err(FibQuantError::DimensionNotDivisible {
178                ambient_dim: self.head_dim as usize,
179                block_dim: block_dim as usize,
180            });
181        }
182        Ok(())
183    }
184
185    /// Number of `[batch, layer, kv_head, token]` vectors.
186    pub fn vector_count(&self) -> Result<usize> {
187        checked_product(&[
188            self.batch as usize,
189            self.layers as usize,
190            self.kv_heads as usize,
191            self.tokens as usize,
192        ])
193    }
194
195    /// Number of f32 scalar values in canonical contiguous form.
196    pub fn element_count(&self) -> Result<usize> {
197        checked_product(&[self.vector_count()?, self.head_dim as usize])
198    }
199
200    /// Stable digest over the explicit shape.
201    pub fn digest(&self) -> Result<String> {
202        self.validate()?;
203        json_digest(KV_SHAPE_SCHEMA, self)
204    }
205}
206
207pub(crate) fn checked_product(values: &[usize]) -> Result<usize> {
208    values.iter().try_fold(1usize, |acc, value| {
209        acc.checked_mul(*value)
210            .ok_or_else(|| FibQuantError::ResourceLimitExceeded("kv shape size overflow".into()))
211    })
212}