1use half::f16;
2use serde::{Deserialize, Serialize};
3
4use crate::{
5 bitpack::{pack_indices, unpack_indices},
6 codebook::FibCodebookV1,
7 digest::{bytes_digest, json_digest},
8 lloyd::nearest_index,
9 metrics,
10 profile::{FibQuantProfileV1, NormFormat},
11 receipt::FibQuantCompressionReceiptV1,
12 rotation::StoredRotation,
13 FibQuantError, Result,
14};
15
/// Schema tag stamped into [`FibCodeV1::schema_version`] by the encoder,
/// required by the decoder, and passed as the schema argument when digesting
/// an encoded payload.
pub const CODE_SCHEMA: &str = "fib_code_v1";
17
18#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
20pub struct FibCodeV1 {
21 pub schema_version: String,
23 pub profile_digest: String,
25 pub codebook_digest: String,
27 pub rotation_digest: String,
29 pub ambient_dim: u32,
31 pub block_dim: u32,
33 pub norm_format: NormFormat,
35 pub norm_payload: Vec<u8>,
37 pub wire_index_bits: u8,
39 pub block_count: u32,
41 pub indices: Vec<u8>,
43}
44
45#[derive(Debug, Clone)]
47pub struct FibQuantizer {
48 profile: FibQuantProfileV1,
49 codebook: FibCodebookV1,
50 rotation: StoredRotation,
51}
52
53impl FibQuantizer {
54 pub fn new(profile: FibQuantProfileV1) -> Result<Self> {
56 let codebook = FibCodebookV1::build(profile)?;
57 Self::from_codebook(codebook)
58 }
59
60 pub fn from_codebook(codebook: FibCodebookV1) -> Result<Self> {
62 codebook.validate()?;
63 let profile = codebook.profile.clone();
64 let rotation = StoredRotation::new(profile.ambient_dim as usize, profile.rotation_seed)?;
65 Ok(Self {
66 profile,
67 codebook,
68 rotation,
69 })
70 }
71
72 pub fn profile(&self) -> &FibQuantProfileV1 {
74 &self.profile
75 }
76
77 pub fn codebook(&self) -> &FibCodebookV1 {
79 &self.codebook
80 }
81
82 pub fn encode(&self, x: &[f32]) -> Result<FibCodeV1> {
84 let d = self.profile.ambient_dim as usize;
85 let k = self.profile.block_dim as usize;
86 if x.len() != d {
87 return Err(FibQuantError::CorruptPayload(format!(
88 "input dimension {}, expected {d}",
89 x.len()
90 )));
91 }
92 check_finite(x)?;
93 let norm = l2_norm(x);
94 if norm == 0.0 {
95 return Err(FibQuantError::ZeroNorm);
96 }
97 let normalized: Vec<f64> = x.iter().map(|value| f64::from(*value) / norm).collect();
98 let rotated = self.rotation.apply(&normalized)?;
99 let codewords_f64: Vec<f64> = self
100 .codebook
101 .codewords
102 .iter()
103 .map(|value| f64::from(*value))
104 .collect();
105 let block_count = self.profile.block_count() as usize;
106 let mut indices = Vec::with_capacity(block_count);
107 for block in rotated.chunks_exact(k) {
108 indices.push(nearest_index(block, &codewords_f64, k).0 as u32);
109 }
110 Ok(FibCodeV1 {
111 schema_version: CODE_SCHEMA.into(),
112 profile_digest: self.profile.digest()?,
113 codebook_digest: self.codebook.codebook_digest.clone(),
114 rotation_digest: self.rotation.digest()?,
115 ambient_dim: self.profile.ambient_dim,
116 block_dim: self.profile.block_dim,
117 norm_format: self.profile.norm_format.clone(),
118 norm_payload: encode_norm(norm, &self.profile.norm_format)?,
119 wire_index_bits: self.profile.wire_index_bits,
120 block_count: self.profile.block_count(),
121 indices: pack_indices(&indices, self.profile.wire_index_bits)?,
122 })
123 }
124
125 pub fn decode(&self, code: &FibCodeV1) -> Result<Vec<f32>> {
127 self.validate_code_header(code)?;
128 let k = self.profile.block_dim as usize;
129 let block_count = self.profile.block_count() as usize;
130 let unpacked = unpack_indices(&code.indices, block_count, self.profile.wire_index_bits)?;
131 let mut rotated = Vec::with_capacity(self.profile.ambient_dim as usize);
132 for index in unpacked {
133 if index >= self.profile.codebook_size {
134 return Err(FibQuantError::IndexOutOfRange {
135 index,
136 codebook_size: self.profile.codebook_size,
137 });
138 }
139 rotated.extend(self.codebook.codeword(index as usize)?);
140 }
141 let expected_rotated_len = block_count.checked_mul(k).ok_or_else(|| {
142 FibQuantError::ResourceLimitExceeded("decoded rotated vector length overflow".into())
143 })?;
144 if rotated.len() != expected_rotated_len {
145 return Err(FibQuantError::CorruptPayload(
146 "decoded rotated vector length mismatch".into(),
147 ));
148 }
149 let norm = decode_norm(&code.norm_payload, &code.norm_format)?;
150 let reconstructed = self.rotation.apply_inverse(&rotated)?;
151 let out: Vec<f32> = reconstructed
152 .into_iter()
153 .map(|value| (value * norm) as f32)
154 .collect();
155 check_finite(&out)?;
156 Ok(out)
157 }
158
159 pub fn encode_with_receipt(
161 &self,
162 x: &[f32],
163 ) -> Result<(FibCodeV1, FibQuantCompressionReceiptV1)> {
164 let code = self.encode(x)?;
165 let source_vector_digest = source_vector_digest(x)?;
166 let mut receipt = FibQuantCompressionReceiptV1::new(
167 &self.profile,
168 code.profile_digest.clone(),
169 code.codebook_digest.clone(),
170 code.rotation_digest.clone(),
171 source_vector_digest,
172 encoded_digest(&code)?,
173 );
174 let decoded = self.decode(&code)?;
175 receipt.mse = Some(metrics::mse(x, &decoded)?);
176 receipt.cosine_similarity = Some(metrics::cosine_similarity(x, &decoded)?);
177 Ok((code, receipt))
178 }
179
180 pub fn reconstruction_mse(&self, x: &[f32]) -> Result<f64> {
182 let code = self.encode(x)?;
183 let decoded = self.decode(&code)?;
184 metrics::mse(x, &decoded)
185 }
186
187 pub fn cosine_similarity(&self, x: &[f32]) -> Result<f64> {
189 let code = self.encode(x)?;
190 let decoded = self.decode(&code)?;
191 metrics::cosine_similarity(x, &decoded)
192 }
193
194 fn validate_code_header(&self, code: &FibCodeV1) -> Result<()> {
195 if code.schema_version != CODE_SCHEMA {
196 return Err(FibQuantError::CorruptPayload(format!(
197 "code schema_version {}, expected {CODE_SCHEMA}",
198 code.schema_version
199 )));
200 }
201 let expected_profile = self.profile.digest()?;
202 if code.profile_digest != expected_profile {
203 return Err(FibQuantError::ProfileDigestMismatch {
204 expected: expected_profile,
205 actual: code.profile_digest.clone(),
206 });
207 }
208 if code.codebook_digest != self.codebook.codebook_digest {
209 return Err(FibQuantError::CodebookDigestMismatch {
210 expected: self.codebook.codebook_digest.clone(),
211 actual: code.codebook_digest.clone(),
212 });
213 }
214 let expected_rotation = self.rotation.digest()?;
215 if code.rotation_digest != expected_rotation
216 || code.rotation_digest != self.codebook.rotation_digest
217 {
218 return Err(FibQuantError::RotationDigestMismatch {
219 expected: expected_rotation,
220 actual: code.rotation_digest.clone(),
221 });
222 }
223 if code.ambient_dim != self.profile.ambient_dim
224 || code.block_dim != self.profile.block_dim
225 || code.block_count != self.profile.block_count()
226 || code.wire_index_bits != self.profile.wire_index_bits
227 || code.norm_format != self.profile.norm_format
228 {
229 return Err(FibQuantError::CorruptPayload(
230 "encoded header does not match profile".into(),
231 ));
232 }
233 Ok(())
234 }
235}
236
237pub fn encoded_digest(code: &FibCodeV1) -> Result<String> {
239 json_digest(CODE_SCHEMA, code)
240}
241
242fn source_vector_digest(x: &[f32]) -> Result<String> {
243 check_finite(x)?;
244 let mut bytes = Vec::with_capacity(32 + std::mem::size_of_val(x));
245 bytes.extend_from_slice(b"fib_quant_source_vector_v1");
246 bytes.push(0);
247 bytes.extend_from_slice(&(x.len() as u64).to_le_bytes());
248 for value in x {
249 bytes.extend_from_slice(&value.to_le_bytes());
250 }
251 Ok(bytes_digest(&bytes))
252}
253
254fn encode_norm(norm: f64, format: &NormFormat) -> Result<Vec<u8>> {
255 if !norm.is_finite() || norm <= 0.0 {
256 return Err(FibQuantError::CorruptPayload(
257 "norm must be finite and positive".into(),
258 ));
259 }
260 match format {
261 NormFormat::Fp16Paper => {
262 let narrowed = f16::from_f32(norm as f32);
263 if !narrowed.is_finite() || narrowed <= f16::ZERO {
264 return Err(FibQuantError::CorruptPayload(
265 "norm cannot be represented as finite positive fp16".into(),
266 ));
267 }
268 Ok(narrowed.to_le_bytes().to_vec())
269 }
270 NormFormat::F32Reference => {
271 let narrowed = norm as f32;
272 if !narrowed.is_finite() || narrowed <= 0.0 {
273 return Err(FibQuantError::CorruptPayload(
274 "norm cannot be represented as finite positive f32".into(),
275 ));
276 }
277 Ok(narrowed.to_le_bytes().to_vec())
278 }
279 }
280}
281
282fn decode_norm(bytes: &[u8], format: &NormFormat) -> Result<f64> {
283 match format {
284 NormFormat::Fp16Paper => {
285 let bytes: [u8; 2] = bytes
286 .try_into()
287 .map_err(|_| FibQuantError::CorruptPayload("fp16 norm length".into()))?;
288 let value = f16::from_le_bytes(bytes).to_f32() as f64;
289 if value.is_finite() && value > 0.0 {
290 Ok(value)
291 } else {
292 Err(FibQuantError::CorruptPayload("invalid fp16 norm".into()))
293 }
294 }
295 NormFormat::F32Reference => {
296 let bytes: [u8; 4] = bytes
297 .try_into()
298 .map_err(|_| FibQuantError::CorruptPayload("f32 norm length".into()))?;
299 let value = f32::from_le_bytes(bytes) as f64;
300 if value.is_finite() && value > 0.0 {
301 Ok(value)
302 } else {
303 Err(FibQuantError::CorruptPayload("invalid f32 norm".into()))
304 }
305 }
306 }
307}
308
309fn l2_norm(x: &[f32]) -> f64 {
310 x.iter()
311 .map(|value| {
312 let value = f64::from(*value);
313 value * value
314 })
315 .sum::<f64>()
316 .sqrt()
317}
318
319fn check_finite(x: &[f32]) -> Result<()> {
320 if let Some((idx, _)) = x.iter().enumerate().find(|(_, value)| !value.is_finite()) {
321 return Err(FibQuantError::NonFiniteInput(idx));
322 }
323 Ok(())
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `result` is a `CorruptPayload` error whose message contains `needle`.
    fn assert_corrupt_payload<T: std::fmt::Debug>(result: Result<T>, needle: &str) {
        let err = result.unwrap_err();
        assert!(matches!(err, FibQuantError::CorruptPayload(message) if message.contains(needle)));
    }

    #[test]
    fn f32_norm_overflow_rejects_before_payload_emit() {
        assert_corrupt_payload(encode_norm(f64::MAX, &NormFormat::F32Reference), "f32");
    }

    #[test]
    fn f32_norm_underflow_rejects_before_payload_emit() {
        assert_corrupt_payload(
            encode_norm(f64::from(f32::from_bits(1)) / 2.0, &NormFormat::F32Reference),
            "f32",
        );
    }

    #[test]
    fn fp16_norm_overflow_rejects_before_payload_emit() {
        // Above the largest finite f16 (65504), so narrowing yields infinity.
        assert_corrupt_payload(encode_norm(70_000.0, &NormFormat::Fp16Paper), "fp16");
    }

    #[test]
    fn fp16_norm_underflow_rejects_before_payload_emit() {
        // Far below the smallest positive f16 subnormal (~5.96e-8), so it rounds to zero.
        assert_corrupt_payload(encode_norm(1e-10, &NormFormat::Fp16Paper), "fp16");
    }

    #[test]
    fn non_positive_or_non_finite_norm_rejected() {
        assert_corrupt_payload(
            encode_norm(0.0, &NormFormat::F32Reference),
            "finite and positive",
        );
        assert_corrupt_payload(encode_norm(-1.0, &NormFormat::Fp16Paper), "finite and positive");
        assert_corrupt_payload(
            encode_norm(f64::NAN, &NormFormat::F32Reference),
            "finite and positive",
        );
    }

    #[test]
    fn norm_payload_lengths_match_format() {
        assert_eq!(
            encode_norm(1.0, &NormFormat::Fp16Paper).ok().map(|payload| payload.len()),
            Some(2)
        );
        assert_eq!(
            encode_norm(1.0, &NormFormat::F32Reference).ok().map(|payload| payload.len()),
            Some(4)
        );
    }

    #[test]
    fn norm_round_trips_when_exactly_representable() {
        for (format, norm) in [(NormFormat::F32Reference, 3.0), (NormFormat::Fp16Paper, 2.5)] {
            let decoded = encode_norm(norm, &format)
                .ok()
                .and_then(|payload| decode_norm(&payload, &format).ok());
            assert_eq!(decoded, Some(norm));
        }
    }

    #[test]
    fn decode_norm_rejects_wrong_length_and_non_positive_values() {
        assert_corrupt_payload(decode_norm(&[0u8; 3], &NormFormat::F32Reference), "f32 norm length");
        assert_corrupt_payload(decode_norm(&[0u8; 4], &NormFormat::Fp16Paper), "fp16 norm length");
        assert_corrupt_payload(
            decode_norm(&0.0f32.to_le_bytes(), &NormFormat::F32Reference),
            "invalid f32 norm",
        );
        assert_corrupt_payload(
            decode_norm(&(-1.0f32).to_le_bytes(), &NormFormat::F32Reference),
            "invalid f32 norm",
        );
        assert_corrupt_payload(decode_norm(&[0u8; 2], &NormFormat::Fp16Paper), "invalid fp16 norm");
    }

    #[test]
    fn check_finite_reports_first_non_finite_index() {
        let err = check_finite(&[1.0, f32::NAN, f32::INFINITY]).unwrap_err();
        assert!(matches!(err, FibQuantError::NonFiniteInput(1)));
        assert!(check_finite(&[0.0, -1.5]).is_ok());
    }

    #[test]
    fn l2_norm_matches_pythagoras() {
        assert_eq!(l2_norm(&[3.0, 4.0]), 5.0);
        assert_eq!(l2_norm(&[0.0, -0.0]), 0.0);
    }
}