use serde::{Deserialize, Serialize};
use crate::{digest::json_digest, FibQuantError, Result};
/// Schema identifier for [`KvTensorShapeV1`].
///
/// [`KvTensorShapeV1::new`] stores this value in `schema_version`,
/// [`KvTensorShapeV1::validate`] rejects any other value, and
/// [`KvTensorShapeV1::digest`] passes it to `json_digest` as the first argument.
pub const KV_SHAPE_SCHEMA: &str = "fib_quant_kv_tensor_shape_v1";
/// Whether a tensor described by [`KvTensorShapeV1`] is a key tensor or a
/// value tensor.
///
/// Serialized by serde as `key` / `value` (`rename_all = "snake_case"`).
/// [`KvTensorShapeV1::validate`] uses the role to decide which
/// [`KvRopeState`] values are acceptable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvRole {
/// Key tensor; validation rejects it when paired with
/// [`KvRopeState::NotApplicable`].
Key,
/// Value tensor; validation rejects it when paired with
/// [`KvRopeState::PreRope`] or [`KvRopeState::PostRope`].
Value,
}
/// Rope-state tag carried by a [`KvTensorShapeV1`].
///
/// Serialized by serde as `pre_rope` / `post_rope` / `not_applicable`.
/// [`KvTensorShapeV1::validate`] requires key tensors to use `PreRope` or
/// `PostRope` and value tensors to use `NotApplicable`.
///
/// NOTE(review): "rope" presumably refers to rotary position embeddings, with
/// pre/post saying whether they have been applied yet — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvRopeState {
/// Accepted for key tensors, rejected for value tensors.
PreRope,
/// Accepted for key tensors, rejected for value tensors.
PostRope,
/// Accepted for value tensors, rejected for key tensors.
NotApplicable,
}
/// Attention layout tag that determines how `query_heads` and `kv_heads` of a
/// [`KvTensorShapeV1`] must relate.
///
/// Serialized by serde as `mha` / `mqa` / `gqa`. The per-variant constraints
/// below are the ones enforced by [`KvTensorShapeV1::validate`].
///
/// NOTE(review): presumably multi-head, multi-query and grouped-query
/// attention respectively — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvAttentionKind {
/// Requires `query_heads == kv_heads`.
Mha,
/// Requires `kv_heads == 1`.
Mqa,
/// Requires `query_heads` to be a multiple of `kv_heads`.
Gqa,
}
/// Element data type tag carried by a [`KvTensorShapeV1`].
///
/// Serialized by serde as `f16` / `bf16` / `f32`.
///
/// NOTE(review): presumably IEEE half precision, bfloat16 and IEEE single
/// precision — confirm against the code that consumes this tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
pub enum KvDType {
/// Serialized as `f16`.
F16,
/// Serialized as `bf16`.
Bf16,
/// Serialized as `f32`.
F32,
}
/// Serializable shape descriptor for a key or value tensor.
///
/// All fields are public and the type implements `Deserialize`, so an instance
/// may exist without having gone through [`KvTensorShapeV1::new`]. Call
/// [`KvTensorShapeV1::validate`] before trusting the contents.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct KvTensorShapeV1 {
/// Schema tag; validation requires it to equal [`KV_SHAPE_SCHEMA`].
pub schema_version: String,
/// Key or value; constrains which `rope_state` values validate.
pub role: KvRole,
/// Selects the rule relating `query_heads` to `kv_heads` during validation.
pub attention_kind: KvAttentionKind,
/// Batch extent; must be non-zero. Factor of `vector_count`.
pub batch: u32,
/// Layer extent; must be non-zero. Factor of `vector_count`.
pub layers: u32,
/// KV head extent; must be non-zero. Factor of `vector_count`.
pub kv_heads: u32,
/// Query head extent; must be non-zero. Only checked against `kv_heads`;
/// it is not a factor of `vector_count`.
pub query_heads: u32,
/// Token extent; must be non-zero. Factor of `vector_count`.
pub tokens: u32,
/// Per-vector length; must be non-zero. `element_count` is
/// `vector_count * head_dim`, and `validate_block_dim` requires it to be a
/// multiple of the block dimension.
pub head_dim: u32,
/// Element data type tag; not inspected by validation or the count methods.
pub dtype: KvDType,
/// Rope-state tag; validated against `role`.
pub rope_state: KvRopeState,
}
impl KvTensorShapeV1 {
    /// Builds a shape descriptor stamped with [`KV_SHAPE_SCHEMA`].
    ///
    /// The arguments are stored as given; nothing is checked here. Use
    /// [`Self::validate`] to find out whether the combination is acceptable.
    // One argument per stored field (minus the schema tag), hence the lint exemption.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        role: KvRole,
        attention_kind: KvAttentionKind,
        batch: u32,
        layers: u32,
        kv_heads: u32,
        query_heads: u32,
        tokens: u32,
        head_dim: u32,
        dtype: KvDType,
        rope_state: KvRopeState,
    ) -> Self {
        let schema_version = String::from(KV_SHAPE_SCHEMA);
        Self {
            schema_version,
            role,
            attention_kind,
            batch,
            layers,
            kv_heads,
            query_heads,
            tokens,
            head_dim,
            dtype,
            rope_state,
        }
    }

    /// Checks that the descriptor is internally consistent.
    ///
    /// # Errors
    ///
    /// The checks run in this order and the first failure is returned:
    ///
    /// 1. `schema_version` differs from [`KV_SHAPE_SCHEMA`] — `CorruptPayload`.
    /// 2. Any of `batch`, `layers`, `kv_heads`, `query_heads`, `tokens`,
    ///    `head_dim` is zero (reported in that order) — `CorruptPayload`.
    /// 3. The head counts break the rule for `attention_kind` — `CorruptPayload`.
    /// 4. `rope_state` does not fit `role` — `CorruptPayload`.
    /// 5. [`Self::element_count`] fails; its error is returned unchanged.
    pub fn validate(&self) -> Result<()> {
        let corrupt = |message: &str| FibQuantError::CorruptPayload(message.into());

        if self.schema_version != KV_SHAPE_SCHEMA {
            return Err(FibQuantError::CorruptPayload(format!(
                "kv shape schema_version {}, expected {KV_SHAPE_SCHEMA}",
                self.schema_version
            )));
        }

        // Report the first zero extent, scanning in declaration order.
        let extents = [
            ("batch", self.batch),
            ("layers", self.layers),
            ("kv_heads", self.kv_heads),
            ("query_heads", self.query_heads),
            ("tokens", self.tokens),
            ("head_dim", self.head_dim),
        ];
        if let Some(&(name, _)) = extents.iter().find(|&&(_, extent)| extent == 0) {
            return Err(FibQuantError::CorruptPayload(format!(
                "kv shape {name} must be > 0"
            )));
        }

        // `kv_heads` is known to be non-zero here, so the remainder below cannot trap.
        let head_problem = match self.attention_kind {
            KvAttentionKind::Mha => {
                if self.query_heads == self.kv_heads {
                    None
                } else {
                    Some("MHA requires query_heads == kv_heads")
                }
            }
            KvAttentionKind::Mqa => {
                if self.kv_heads == 1 {
                    None
                } else {
                    Some("MQA requires kv_heads == 1")
                }
            }
            KvAttentionKind::Gqa => {
                if self.query_heads % self.kv_heads == 0 {
                    None
                } else {
                    Some("query_heads must be divisible by kv_heads")
                }
            }
        };
        if let Some(message) = head_problem {
            return Err(corrupt(message));
        }

        let rope_problem = match (self.role, self.rope_state) {
            (KvRole::Key, KvRopeState::NotApplicable) => {
                Some("key tensors must declare pre_rope or post_rope")
            }
            (KvRole::Value, KvRopeState::PreRope | KvRopeState::PostRope) => {
                Some("value tensors must use not_applicable rope state")
            }
            (KvRole::Key, KvRopeState::PreRope | KvRopeState::PostRope)
            | (KvRole::Value, KvRopeState::NotApplicable) => None,
        };
        if let Some(message) = rope_problem {
            return Err(corrupt(message));
        }

        // The count itself is not needed; only whether it can be computed.
        self.element_count().map(|_| ())
    }

    /// Validates the descriptor and additionally requires `head_dim` to be a
    /// whole multiple of `block_dim`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Self::validate`] reports, or `DimensionNotDivisible`
    /// when `block_dim` is zero or does not divide `head_dim`.
    pub fn validate_block_dim(&self, block_dim: u32) -> Result<()> {
        self.validate()?;

        // The zero test comes first so the remainder is never taken by zero.
        let evenly_divides = block_dim != 0 && self.head_dim % block_dim == 0;
        if evenly_divides {
            return Ok(());
        }
        Err(FibQuantError::DimensionNotDivisible {
            ambient_dim: self.head_dim as usize,
            block_dim: block_dim as usize,
        })
    }

    /// Returns `batch * layers * kv_heads * tokens`.
    ///
    /// Does not call [`Self::validate`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `checked_product` when the product does not
    /// fit in `usize`.
    pub fn vector_count(&self) -> Result<usize> {
        let factors =
            [self.batch, self.layers, self.kv_heads, self.tokens].map(|extent| extent as usize);
        checked_product(&factors)
    }

    /// Returns [`Self::vector_count`] multiplied by `head_dim`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Self::vector_count`] or from
    /// `checked_product` when the product does not fit in `usize`.
    pub fn element_count(&self) -> Result<usize> {
        let vectors = self.vector_count()?;
        checked_product(&[vectors, self.head_dim as usize])
    }

    /// Validates the descriptor, then returns `json_digest(KV_SHAPE_SCHEMA, self)`.
    ///
    /// # Errors
    ///
    /// Returns the validation error if [`Self::validate`] fails; otherwise
    /// whatever `json_digest` returns.
    pub fn digest(&self) -> Result<String> {
        self.validate()
            .and_then(|()| json_digest(KV_SHAPE_SCHEMA, self))
    }
}
/// Multiplies every entry of `values` together, refusing to wrap on overflow.
///
/// An empty slice yields `Ok(1)`. Factors are consumed left to right and the
/// function stops at the first multiplication that overflows, so a zero that
/// appears after the overflowing factor does not turn the result into `Ok(0)`.
///
/// # Errors
///
/// Returns `FibQuantError::ResourceLimitExceeded` when a running product does
/// not fit in `usize`.
pub(crate) fn checked_product(values: &[usize]) -> Result<usize> {
    let mut product = 1usize;
    for &factor in values {
        product = match product.checked_mul(factor) {
            Some(next) => next,
            None => {
                return Err(FibQuantError::ResourceLimitExceeded(
                    "kv shape size overflow".into(),
                ));
            }
        };
    }
    Ok(product)
}