1use serde::{Deserialize, Serialize};
2
3use crate::{codec::FibCodeV1, digest::json_digest, FibQuantError, Result};
4
/// Schema identifier for encoded KV blocks.
///
/// The constructors on `KvEncodedBlockV1` copy this string into `schema_version`,
/// `KvEncodedBlockV1::validate` rejects any block whose `schema_version` differs from it,
/// and `KvEncodedBlockV1::digest` passes it as the first argument to `json_digest`.
pub const KV_BLOCK_SCHEMA: &str = "fib_quant_kv_encoded_block_v1";
6
/// Payload of a `KvEncodedBlockV1`: either the plain values or a quantized code.
///
/// Serialized through serde with variant names renamed to snake_case.
/// `KvEncodedBlockV1::validate` requires the block's `raw_fallback` flag to agree with
/// the variant stored here (`true` for `RawF32`, `false` for `FibQuant`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KvBlockEncodingV1 {
    /// Uncompressed values. Validation expects exactly `head_dim` entries, all finite.
    RawF32 { values: Vec<f32> },
    /// Quantized payload, boxed. Validation expects `code.ambient_dim` to equal `head_dim`.
    FibQuant { code: Box<FibCodeV1> },
}
16
/// One encoded KV block together with the metadata describing where it belongs.
///
/// Instances are normally built with `KvEncodedBlockV1::raw` or
/// `KvEncodedBlockV1::fib_quant` and checked with `KvEncodedBlockV1::validate`.
/// The struct is serialized with serde, so field names are part of the stored format.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct KvEncodedBlockV1 {
    /// Must equal `KV_BLOCK_SCHEMA` for the block to validate.
    pub schema_version: String,
    /// Caller-supplied block identifier (stored as given, not validated here).
    pub block_id: u32,
    /// Caller-supplied batch index (stored as given, not validated here).
    pub batch: u32,
    /// Caller-supplied layer index (stored as given, not validated here).
    pub layer: u32,
    /// Caller-supplied KV head index (stored as given, not validated here).
    pub kv_head: u32,
    /// Caller-supplied token index (stored as given, not validated here).
    pub token: u32,
    /// Number of vectors in the block; the constructors set 1 and validation requires 1.
    pub vector_count: u32,
    /// Caller-supplied size in bytes; validation only requires it to be nonzero.
    // NOTE(review): presumably the fixed per-block storage size — confirm against callers.
    pub fixed_size_bytes: u32,
    /// `true` when `encoding` is `RawF32`, `false` when it is `FibQuant` (enforced by validation).
    pub raw_fallback: bool,
    /// Free-form text supplied by the caller of the constructor.
    // NOTE(review): presumably explains why this encoding was chosen — confirm.
    pub reason: String,
    /// The payload itself.
    pub encoding: KvBlockEncodingV1,
}
43
44impl KvEncodedBlockV1 {
45 #[allow(clippy::too_many_arguments)]
47 pub fn raw(
48 block_id: u32,
49 batch: u32,
50 layer: u32,
51 kv_head: u32,
52 token: u32,
53 values: Vec<f32>,
54 fixed_size_bytes: u32,
55 reason: impl Into<String>,
56 ) -> Self {
57 Self {
58 schema_version: KV_BLOCK_SCHEMA.into(),
59 block_id,
60 batch,
61 layer,
62 kv_head,
63 token,
64 vector_count: 1,
65 fixed_size_bytes,
66 raw_fallback: true,
67 reason: reason.into(),
68 encoding: KvBlockEncodingV1::RawF32 { values },
69 }
70 }
71
72 #[allow(clippy::too_many_arguments)]
74 pub fn fib_quant(
75 block_id: u32,
76 batch: u32,
77 layer: u32,
78 kv_head: u32,
79 token: u32,
80 code: FibCodeV1,
81 fixed_size_bytes: u32,
82 reason: impl Into<String>,
83 ) -> Self {
84 Self {
85 schema_version: KV_BLOCK_SCHEMA.into(),
86 block_id,
87 batch,
88 layer,
89 kv_head,
90 token,
91 vector_count: 1,
92 fixed_size_bytes,
93 raw_fallback: false,
94 reason: reason.into(),
95 encoding: KvBlockEncodingV1::FibQuant {
96 code: Box::new(code),
97 },
98 }
99 }
100
101 pub fn validate(&self, head_dim: u32) -> Result<()> {
103 if self.schema_version != KV_BLOCK_SCHEMA {
104 return Err(FibQuantError::CorruptPayload(format!(
105 "kv block schema_version {}, expected {KV_BLOCK_SCHEMA}",
106 self.schema_version
107 )));
108 }
109 if self.vector_count != 1 {
110 return Err(FibQuantError::CorruptPayload(
111 "kv block vector_count must be 1".into(),
112 ));
113 }
114 if self.fixed_size_bytes == 0 {
115 return Err(FibQuantError::CorruptPayload(
116 "kv block fixed_size_bytes must be nonzero".into(),
117 ));
118 }
119 match &self.encoding {
120 KvBlockEncodingV1::RawF32 { values } => {
121 if !self.raw_fallback {
122 return Err(FibQuantError::CorruptPayload(
123 "raw block must set raw_fallback".into(),
124 ));
125 }
126 if values.len() != head_dim as usize {
127 return Err(FibQuantError::CorruptPayload(
128 "raw kv block head_dim mismatch".into(),
129 ));
130 }
131 if values.iter().any(|value| !value.is_finite()) {
132 return Err(FibQuantError::CorruptPayload(
133 "raw kv block contains non-finite value".into(),
134 ));
135 }
136 }
137 KvBlockEncodingV1::FibQuant { code } => {
138 if self.raw_fallback {
139 return Err(FibQuantError::CorruptPayload(
140 "compressed block cannot set raw_fallback".into(),
141 ));
142 }
143 if code.ambient_dim != head_dim {
144 return Err(FibQuantError::CorruptPayload(
145 "fib kv block ambient_dim mismatch".into(),
146 ));
147 }
148 }
149 }
150 Ok(())
151 }
152
153 pub fn digest(&self, head_dim: u32) -> Result<String> {
155 self.validate(head_dim)?;
156 json_digest(KV_BLOCK_SCHEMA, self)
157 }
158}