Skip to main content

fib_quant/
codec.rs

1use half::f16;
2use serde::{Deserialize, Serialize};
3
4use crate::{
5    bitpack::{pack_indices, unpack_indices},
6    codebook::FibCodebookV1,
7    digest::{bytes_digest, json_digest},
8    lloyd::nearest_index,
9    metrics,
10    profile::{FibQuantProfileV1, NormFormat},
11    receipt::FibQuantCompressionReceiptV1,
12    rotation::StoredRotation,
13    FibQuantError, Result,
14};
15
/// Schema marker for [`FibCodeV1`] artifacts.
///
/// Written into `FibCodeV1::schema_version` by the encoder, required to match on decode, and
/// passed as the first argument to `json_digest` when an encoded artifact is digested.
pub const CODE_SCHEMA: &str = "fib_code_v1";
17
/// Encoded fixed-rate FibQuant artifact.
///
/// Produced by `FibQuantizer::encode` and consumed by `FibQuantizer::decode`. The digest and
/// dimension fields form a header that the decoder compares against its own profile, codebook
/// and rotation before touching the payload.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FibCodeV1 {
    /// Stable schema marker; the encoder writes [`CODE_SCHEMA`] and the decoder rejects any
    /// other value.
    pub schema_version: String,
    /// Profile digest of the quantizer that produced this artifact.
    pub profile_digest: String,
    /// Codebook digest, copied from the encoder's codebook.
    pub codebook_digest: String,
    /// Rotation digest of the encoder's rotation.
    pub rotation_digest: String,
    /// Ambient dimension, i.e. the required length of the encoded input vector.
    pub ambient_dim: u32,
    /// Block dimension: the rotated vector is split into chunks of this many coordinates,
    /// one codebook index per chunk.
    pub block_dim: u32,
    /// Norm payload format; selects how `norm_payload` is laid out.
    pub norm_format: NormFormat,
    /// Little-endian bytes of the source vector's L2 norm (2 bytes for `Fp16Paper`,
    /// 4 bytes for `F32Reference`).
    pub norm_payload: Vec<u8>,
    /// Bits per fixed-rate index, as used when packing and unpacking `indices`.
    pub wire_index_bits: u8,
    /// Number of indices stored in `indices`.
    pub block_count: u32,
    /// Packed fixed-rate indices, `wire_index_bits` bits each.
    pub indices: Vec<u8>,
}
44
45/// FibQuant encoder/decoder bound to one profile and codebook.
46#[derive(Debug, Clone)]
47pub struct FibQuantizer {
48    profile: FibQuantProfileV1,
49    codebook: FibCodebookV1,
50    rotation: StoredRotation,
51}
52
53impl FibQuantizer {
54    /// Build a quantizer by constructing the profile codebook.
55    pub fn new(profile: FibQuantProfileV1) -> Result<Self> {
56        let codebook = FibCodebookV1::build(profile)?;
57        Self::from_codebook(codebook)
58    }
59
60    /// Build a quantizer from a validated codebook.
61    pub fn from_codebook(codebook: FibCodebookV1) -> Result<Self> {
62        codebook.validate()?;
63        let profile = codebook.profile.clone();
64        let rotation = StoredRotation::new(profile.ambient_dim as usize, profile.rotation_seed)?;
65        Ok(Self {
66            profile,
67            codebook,
68            rotation,
69        })
70    }
71
72    /// Access the profile.
73    pub fn profile(&self) -> &FibQuantProfileV1 {
74        &self.profile
75    }
76
77    /// Access the codebook.
78    pub fn codebook(&self) -> &FibCodebookV1 {
79        &self.codebook
80    }
81
82    /// Encode a vector into a fixed-rate artifact.
83    pub fn encode(&self, x: &[f32]) -> Result<FibCodeV1> {
84        let d = self.profile.ambient_dim as usize;
85        let k = self.profile.block_dim as usize;
86        if x.len() != d {
87            return Err(FibQuantError::CorruptPayload(format!(
88                "input dimension {}, expected {d}",
89                x.len()
90            )));
91        }
92        check_finite(x)?;
93        let norm = l2_norm(x);
94        if norm == 0.0 {
95            return Err(FibQuantError::ZeroNorm);
96        }
97        let normalized: Vec<f64> = x.iter().map(|value| f64::from(*value) / norm).collect();
98        let rotated = self.rotation.apply(&normalized)?;
99        let codewords_f64: Vec<f64> = self
100            .codebook
101            .codewords
102            .iter()
103            .map(|value| f64::from(*value))
104            .collect();
105        let block_count = self.profile.block_count() as usize;
106        let mut indices = Vec::with_capacity(block_count);
107        for block in rotated.chunks_exact(k) {
108            indices.push(nearest_index(block, &codewords_f64, k).0 as u32);
109        }
110        Ok(FibCodeV1 {
111            schema_version: CODE_SCHEMA.into(),
112            profile_digest: self.profile.digest()?,
113            codebook_digest: self.codebook.codebook_digest.clone(),
114            rotation_digest: self.rotation.digest()?,
115            ambient_dim: self.profile.ambient_dim,
116            block_dim: self.profile.block_dim,
117            norm_format: self.profile.norm_format.clone(),
118            norm_payload: encode_norm(norm, &self.profile.norm_format)?,
119            wire_index_bits: self.profile.wire_index_bits,
120            block_count: self.profile.block_count(),
121            indices: pack_indices(&indices, self.profile.wire_index_bits)?,
122        })
123    }
124
125    /// Decode a fixed-rate artifact.
126    pub fn decode(&self, code: &FibCodeV1) -> Result<Vec<f32>> {
127        self.validate_code_header(code)?;
128        let k = self.profile.block_dim as usize;
129        let block_count = self.profile.block_count() as usize;
130        let unpacked = unpack_indices(&code.indices, block_count, self.profile.wire_index_bits)?;
131        let mut rotated = Vec::with_capacity(self.profile.ambient_dim as usize);
132        for index in unpacked {
133            if index >= self.profile.codebook_size {
134                return Err(FibQuantError::IndexOutOfRange {
135                    index,
136                    codebook_size: self.profile.codebook_size,
137                });
138            }
139            rotated.extend(self.codebook.codeword(index as usize)?);
140        }
141        let expected_rotated_len = block_count.checked_mul(k).ok_or_else(|| {
142            FibQuantError::ResourceLimitExceeded("decoded rotated vector length overflow".into())
143        })?;
144        if rotated.len() != expected_rotated_len {
145            return Err(FibQuantError::CorruptPayload(
146                "decoded rotated vector length mismatch".into(),
147            ));
148        }
149        let norm = decode_norm(&code.norm_payload, &code.norm_format)?;
150        let reconstructed = self.rotation.apply_inverse(&rotated)?;
151        let out: Vec<f32> = reconstructed
152            .into_iter()
153            .map(|value| (value * norm) as f32)
154            .collect();
155        check_finite(&out)?;
156        Ok(out)
157    }
158
159    /// Encode and emit a receipt.
160    pub fn encode_with_receipt(
161        &self,
162        x: &[f32],
163    ) -> Result<(FibCodeV1, FibQuantCompressionReceiptV1)> {
164        let code = self.encode(x)?;
165        let source_vector_digest = source_vector_digest(x)?;
166        let mut receipt = FibQuantCompressionReceiptV1::new(
167            &self.profile,
168            code.profile_digest.clone(),
169            code.codebook_digest.clone(),
170            code.rotation_digest.clone(),
171            source_vector_digest,
172            encoded_digest(&code)?,
173        );
174        let decoded = self.decode(&code)?;
175        receipt.mse = Some(metrics::mse(x, &decoded)?);
176        receipt.cosine_similarity = Some(metrics::cosine_similarity(x, &decoded)?);
177        Ok((code, receipt))
178    }
179
180    /// Reconstruction MSE for one vector.
181    pub fn reconstruction_mse(&self, x: &[f32]) -> Result<f64> {
182        let code = self.encode(x)?;
183        let decoded = self.decode(&code)?;
184        metrics::mse(x, &decoded)
185    }
186
187    /// Reconstruction cosine similarity for one vector.
188    pub fn cosine_similarity(&self, x: &[f32]) -> Result<f64> {
189        let code = self.encode(x)?;
190        let decoded = self.decode(&code)?;
191        metrics::cosine_similarity(x, &decoded)
192    }
193
194    fn validate_code_header(&self, code: &FibCodeV1) -> Result<()> {
195        if code.schema_version != CODE_SCHEMA {
196            return Err(FibQuantError::CorruptPayload(format!(
197                "code schema_version {}, expected {CODE_SCHEMA}",
198                code.schema_version
199            )));
200        }
201        let expected_profile = self.profile.digest()?;
202        if code.profile_digest != expected_profile {
203            return Err(FibQuantError::ProfileDigestMismatch {
204                expected: expected_profile,
205                actual: code.profile_digest.clone(),
206            });
207        }
208        if code.codebook_digest != self.codebook.codebook_digest {
209            return Err(FibQuantError::CodebookDigestMismatch {
210                expected: self.codebook.codebook_digest.clone(),
211                actual: code.codebook_digest.clone(),
212            });
213        }
214        let expected_rotation = self.rotation.digest()?;
215        if code.rotation_digest != expected_rotation
216            || code.rotation_digest != self.codebook.rotation_digest
217        {
218            return Err(FibQuantError::RotationDigestMismatch {
219                expected: expected_rotation,
220                actual: code.rotation_digest.clone(),
221            });
222        }
223        if code.ambient_dim != self.profile.ambient_dim
224            || code.block_dim != self.profile.block_dim
225            || code.block_count != self.profile.block_count()
226            || code.wire_index_bits != self.profile.wire_index_bits
227            || code.norm_format != self.profile.norm_format
228        {
229            return Err(FibQuantError::CorruptPayload(
230                "encoded header does not match profile".into(),
231            ));
232        }
233        Ok(())
234    }
235}
236
237/// Stable digest over the encoded artifact fields.
238pub fn encoded_digest(code: &FibCodeV1) -> Result<String> {
239    json_digest(CODE_SCHEMA, code)
240}
241
242fn source_vector_digest(x: &[f32]) -> Result<String> {
243    check_finite(x)?;
244    let mut bytes = Vec::with_capacity(32 + std::mem::size_of_val(x));
245    bytes.extend_from_slice(b"fib_quant_source_vector_v1");
246    bytes.push(0);
247    bytes.extend_from_slice(&(x.len() as u64).to_le_bytes());
248    for value in x {
249        bytes.extend_from_slice(&value.to_le_bytes());
250    }
251    Ok(bytes_digest(&bytes))
252}
253
254fn encode_norm(norm: f64, format: &NormFormat) -> Result<Vec<u8>> {
255    if !norm.is_finite() || norm <= 0.0 {
256        return Err(FibQuantError::CorruptPayload(
257            "norm must be finite and positive".into(),
258        ));
259    }
260    match format {
261        NormFormat::Fp16Paper => {
262            let narrowed = f16::from_f32(norm as f32);
263            if !narrowed.is_finite() || narrowed <= f16::ZERO {
264                return Err(FibQuantError::CorruptPayload(
265                    "norm cannot be represented as finite positive fp16".into(),
266                ));
267            }
268            Ok(narrowed.to_le_bytes().to_vec())
269        }
270        NormFormat::F32Reference => {
271            let narrowed = norm as f32;
272            if !narrowed.is_finite() || narrowed <= 0.0 {
273                return Err(FibQuantError::CorruptPayload(
274                    "norm cannot be represented as finite positive f32".into(),
275                ));
276            }
277            Ok(narrowed.to_le_bytes().to_vec())
278        }
279    }
280}
281
282fn decode_norm(bytes: &[u8], format: &NormFormat) -> Result<f64> {
283    match format {
284        NormFormat::Fp16Paper => {
285            let bytes: [u8; 2] = bytes
286                .try_into()
287                .map_err(|_| FibQuantError::CorruptPayload("fp16 norm length".into()))?;
288            let value = f16::from_le_bytes(bytes).to_f32() as f64;
289            if value.is_finite() && value > 0.0 {
290                Ok(value)
291            } else {
292                Err(FibQuantError::CorruptPayload("invalid fp16 norm".into()))
293            }
294        }
295        NormFormat::F32Reference => {
296            let bytes: [u8; 4] = bytes
297                .try_into()
298                .map_err(|_| FibQuantError::CorruptPayload("f32 norm length".into()))?;
299            let value = f32::from_le_bytes(bytes) as f64;
300            if value.is_finite() && value > 0.0 {
301                Ok(value)
302            } else {
303                Err(FibQuantError::CorruptPayload("invalid f32 norm".into()))
304            }
305        }
306    }
307}
308
/// Euclidean (L2) norm of `x`, with each element widened to `f64` before squaring.
fn l2_norm(x: &[f32]) -> f64 {
    let sum_of_squares = x
        .iter()
        .copied()
        .map(f64::from)
        .map(|component| component * component)
        .sum::<f64>();
    sum_of_squares.sqrt()
}
318
319fn check_finite(x: &[f32]) -> Result<()> {
320    if let Some((idx, _)) = x.iter().enumerate().find(|(_, value)| !value.is_finite()) {
321        return Err(FibQuantError::NonFiniteInput(idx));
322    }
323    Ok(())
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// A norm far above the `f32` range must be refused rather than emitted as infinity.
    #[test]
    fn f32_norm_overflow_rejects_before_payload_emit() {
        let too_large = f64::MAX;
        let err = encode_norm(too_large, &NormFormat::F32Reference).unwrap_err();
        let mentions_f32 = match err {
            FibQuantError::CorruptPayload(message) => message.contains("f32"),
            _ => false,
        };
        assert!(mentions_f32);
    }

    /// Half of the smallest positive `f32` narrows to zero and must be refused.
    #[test]
    fn f32_norm_underflow_rejects_before_payload_emit() {
        let smallest_positive_f32 = f32::from_bits(1);
        let too_small = f64::from(smallest_positive_f32) / 2.0;
        let err = encode_norm(too_small, &NormFormat::F32Reference).unwrap_err();
        let mentions_f32 = match err {
            FibQuantError::CorruptPayload(message) => message.contains("f32"),
            _ => false,
        };
        assert!(mentions_f32);
    }
}