1use serde::{Deserialize, Serialize};
2
3use crate::{digest::json_digest, FibQuantError, Result};
4
/// Schema identifier for [`KvTensorShapeV1`].
///
/// It is written into `schema_version` by `KvTensorShapeV1::new`, compared
/// against that field by `KvTensorShapeV1::validate`, and passed as the first
/// argument to `json_digest` by `KvTensorShapeV1::digest`.
pub const KV_SHAPE_SCHEMA: &str = "fib_quant_kv_tensor_shape_v1";
6
/// Role tag carried by a [`KvTensorShapeV1`].
///
/// Serialized by serde in `snake_case` (`key`, `value`). The role constrains
/// the allowed rope state: `KvTensorShapeV1::validate` rejects a `Key` shape
/// whose rope state is `NotApplicable`, and a `Value` shape whose rope state
/// is `PreRope` or `PostRope`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// Marked non-exhaustive so variants can be added without a breaking change;
// code outside this crate must keep a wildcard arm when matching.
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvRole {
    /// Key tensor; must declare `PreRope` or `PostRope` to pass validation.
    Key,
    /// Value tensor; must use `NotApplicable` rope state to pass validation.
    Value,
}
17
/// Rope state declared by a [`KvTensorShapeV1`].
///
/// Serialized by serde in `snake_case` (`pre_rope`, `post_rope`,
/// `not_applicable`). NOTE(review): "rope" presumably refers to rotary
/// position embeddings, with pre/post meaning before/after they are applied —
/// confirm against the producer of these tensors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// Marked non-exhaustive so variants can be added without a breaking change;
// code outside this crate must keep a wildcard arm when matching.
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvRopeState {
    /// Accepted by `KvTensorShapeV1::validate` only for `KvRole::Key`.
    PreRope,
    /// Accepted by `KvTensorShapeV1::validate` only for `KvRole::Key`.
    PostRope,
    /// Accepted by `KvTensorShapeV1::validate` only for `KvRole::Value`.
    NotApplicable,
}
30
/// Attention layout tag that determines which head-count relationship
/// `KvTensorShapeV1::validate` enforces.
///
/// Serialized by serde in `snake_case` (`mha`, `mqa`, `gqa`).
/// NOTE(review): the names presumably stand for multi-head, multi-query and
/// grouped-query attention — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// Marked non-exhaustive so variants can be added without a breaking change;
// code outside this crate must keep a wildcard arm when matching.
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvAttentionKind {
    /// Validation requires `query_heads == kv_heads`.
    Mha,
    /// Validation requires `kv_heads == 1`.
    Mqa,
    /// Validation requires `query_heads` to be divisible by `kv_heads`.
    Gqa,
}
43
/// Element data-type tag recorded on a [`KvTensorShapeV1`].
///
/// Serialized by serde in `snake_case` (`f16`, `bf16`, `f32`). Nothing in this
/// file derives a byte size from it: `element_count` counts elements, not
/// bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
// Marked non-exhaustive so variants can be added without a breaking change;
// code outside this crate must keep a wildcard arm when matching.
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvDType {
    /// Presumably IEEE 754 half precision (16-bit) — TODO confirm.
    F16,
    /// Presumably bfloat16 (16-bit) — TODO confirm.
    Bf16,
    /// Presumably IEEE 754 single precision (32-bit) — TODO confirm.
    F32,
}
56
/// Serializable description of a KV tensor's shape and associated tags.
///
/// All fields are public, so a value can be built or deserialized in an
/// inconsistent state; call `validate` before relying on it. `digest` and
/// `validate_block_dim` run `validate` themselves.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvTensorShapeV1 {
    /// Must equal [`KV_SHAPE_SCHEMA`] for `validate` to succeed.
    pub schema_version: String,
    /// Role tag; constrains which `rope_state` values are accepted.
    pub role: KvRole,
    /// Selects the `query_heads` / `kv_heads` rule enforced by `validate`.
    pub attention_kind: KvAttentionKind,
    /// Must be non-zero; a factor of `vector_count`.
    pub batch: u32,
    /// Must be non-zero; a factor of `vector_count`.
    pub layers: u32,
    /// Must be non-zero; a factor of `vector_count`.
    pub kv_heads: u32,
    /// Must be non-zero; only checked against `kv_heads`, never multiplied
    /// into `vector_count` or `element_count`.
    pub query_heads: u32,
    /// Must be non-zero; a factor of `vector_count`.
    pub tokens: u32,
    /// Must be non-zero; `element_count` is `vector_count * head_dim`, and
    /// `validate_block_dim` requires it to be a multiple of the block size.
    pub head_dim: u32,
    /// Element data-type tag; not inspected by any method in this file.
    pub dtype: KvDType,
    /// Rope state; must be compatible with `role` (see `validate`).
    pub rope_state: KvRopeState,
}
83
84impl KvTensorShapeV1 {
85 #[allow(clippy::too_many_arguments)]
87 pub fn new(
88 role: KvRole,
89 attention_kind: KvAttentionKind,
90 batch: u32,
91 layers: u32,
92 kv_heads: u32,
93 query_heads: u32,
94 tokens: u32,
95 head_dim: u32,
96 dtype: KvDType,
97 rope_state: KvRopeState,
98 ) -> Self {
99 Self {
100 schema_version: KV_SHAPE_SCHEMA.into(),
101 role,
102 attention_kind,
103 batch,
104 layers,
105 kv_heads,
106 query_heads,
107 tokens,
108 head_dim,
109 dtype,
110 rope_state,
111 }
112 }
113
114 pub fn validate(&self) -> Result<()> {
116 if self.schema_version != KV_SHAPE_SCHEMA {
117 return Err(FibQuantError::CorruptPayload(format!(
118 "kv shape schema_version {}, expected {KV_SHAPE_SCHEMA}",
119 self.schema_version
120 )));
121 }
122 for (name, value) in [
123 ("batch", self.batch),
124 ("layers", self.layers),
125 ("kv_heads", self.kv_heads),
126 ("query_heads", self.query_heads),
127 ("tokens", self.tokens),
128 ("head_dim", self.head_dim),
129 ] {
130 if value == 0 {
131 return Err(FibQuantError::CorruptPayload(format!(
132 "kv shape {name} must be > 0"
133 )));
134 }
135 }
136 match self.attention_kind {
137 KvAttentionKind::Mha if self.query_heads != self.kv_heads => {
138 return Err(FibQuantError::CorruptPayload(
139 "MHA requires query_heads == kv_heads".into(),
140 ));
141 }
142 KvAttentionKind::Mqa if self.kv_heads != 1 => {
143 return Err(FibQuantError::CorruptPayload(
144 "MQA requires kv_heads == 1".into(),
145 ));
146 }
147 KvAttentionKind::Gqa | KvAttentionKind::Mqa
148 if self.query_heads % self.kv_heads != 0 =>
149 {
150 return Err(FibQuantError::CorruptPayload(
151 "query_heads must be divisible by kv_heads".into(),
152 ));
153 }
154 _ => {}
155 }
156 match (self.role, self.rope_state) {
157 (KvRole::Key, KvRopeState::NotApplicable) => {
158 return Err(FibQuantError::CorruptPayload(
159 "key tensors must declare pre_rope or post_rope".into(),
160 ));
161 }
162 (KvRole::Value, KvRopeState::PreRope | KvRopeState::PostRope) => {
163 return Err(FibQuantError::CorruptPayload(
164 "value tensors must use not_applicable rope state".into(),
165 ));
166 }
167 _ => {}
168 }
169 let _ = self.element_count()?;
170 Ok(())
171 }
172
173 pub fn validate_block_dim(&self, block_dim: u32) -> Result<()> {
175 self.validate()?;
176 if block_dim == 0 || self.head_dim % block_dim != 0 {
177 return Err(FibQuantError::DimensionNotDivisible {
178 ambient_dim: self.head_dim as usize,
179 block_dim: block_dim as usize,
180 });
181 }
182 Ok(())
183 }
184
185 pub fn vector_count(&self) -> Result<usize> {
187 checked_product(&[
188 self.batch as usize,
189 self.layers as usize,
190 self.kv_heads as usize,
191 self.tokens as usize,
192 ])
193 }
194
195 pub fn element_count(&self) -> Result<usize> {
197 checked_product(&[self.vector_count()?, self.head_dim as usize])
198 }
199
200 pub fn digest(&self) -> Result<String> {
202 self.validate()?;
203 json_digest(KV_SHAPE_SCHEMA, self)
204 }
205}
206
207pub(crate) fn checked_product(values: &[usize]) -> Result<usize> {
208 values.iter().try_fold(1usize, |acc, value| {
209 acc.checked_mul(*value)
210 .ok_or_else(|| FibQuantError::ResourceLimitExceeded("kv shape size overflow".into()))
211 })
212}