Skip to main content

turbo_quant/
wire.rs

1//! Deterministic compact wire encoding for [`TurboCode`].
2//!
3//! The wire bytes are a derived acceleration artifact. They are bound to the
4//! quantizer profile on decode and are never authoritative over raw f32 vectors.
5
6use crate::{
7    bitpack,
8    error::{Result, TurboQuantError},
9    polar::PolarCode,
10    qjl::QjlSketch,
11    radius::{self, CompressedRadiiV1, RadiusCodecProfileV1},
12    rotation::RotationKind,
13    turbo::{TurboCode, TurboMode, TurboQuantizer},
14};
15
/// Magic bytes for TurboCode wire format v1 (one vector per artifact).
pub const TURBO_CODE_WIRE_MAGIC: &[u8; 4] = b"TQW1";

/// Magic bytes for the batched TurboCode wire format v1.
///
/// The batched format stores the profile (dim, projections, seed, bits) ONCE
/// in a single 38-byte header, then concatenates the per-vector payloads
/// (radii + packed angle indices + packed QJL signs) back-to-back with no
/// per-vector header. This drops the wire overhead from 46 bytes/vector (TQW1)
/// to effectively zero per-vector overhead, which is the difference between
/// "worse than fp16" and "20x fp16" for single-tier use.
pub const TURBO_CODE_BATCHED_WIRE_MAGIC: &[u8; 4] = b"TQB1";

/// Batched-wire radii codec flag (byte at offset 15) for lossless f32 radii.
///
/// This is the default, and the value found in legacy batched artifacts.
pub const RADII_CODEC_F32: u8 = 0;
/// Batched-wire radii codec flag (byte at offset 15) for lossy BlockLogU8
/// radii (opt-in via `RadiiCompression::Lossy` in the prove-kv policy).
pub const RADII_CODEC_BLOCK_LOG_U8: u8 = 1;

/// Wire format version written by, and required from, both TQW1 and TQB1.
const VERSION: u16 = 1;
/// Variant byte (offset 8) identifying a single-vector TQW1 artifact.
const VARIANT_TURBO_CODE: u8 = 1;
/// Variant byte (offset 8) identifying a batched TQB1 artifact.
const VARIANT_TURBO_BATCH: u8 = 2;
38
/// Encoder/decoder for TurboCode wire format v1 (single-vector TQW1 and
/// batched TQB1 artifacts).
///
/// This is a stateless namespace type; all operations are associated functions.
#[derive(Debug, Clone, Copy)]
pub struct TurboCodeWireV1;
41
42impl TurboCodeWireV1 {
43    /// Encode a validated TurboCode using the supplied quantizer profile.
44    pub fn encode(code: &TurboCode, profile: &TurboQuantizer) -> Result<Vec<u8>> {
45        code.validate_for(
46            profile.dim(),
47            profile.bits(),
48            profile.projections(),
49            profile.mode(),
50        )?;
51
52        let dim = checked_u32(profile.dim(), "dimension")?;
53        let polar_bits = code.polar_code.bits;
54        let qjl_projections = checked_u32(profile.projections(), "projection count")?;
55        let polar_block_count = checked_u32(code.polar_code.radii.len(), "polar block count")?;
56        let qjl_sign_count = checked_u32(
57            match profile.mode() {
58                TurboMode::PolarOnly => 0,
59                TurboMode::PolarWithQjl => code.residual_sketch.projections,
60            },
61            "qjl sign count",
62        )?;
63        let packed_angle_indices =
64            bitpack::pack_indices(&code.polar_code.angle_indices, polar_bits)?;
65        let packed_signs = match profile.mode() {
66            TurboMode::PolarOnly => Vec::new(),
67            TurboMode::PolarWithQjl => bitpack::pack_signs(&code.residual_sketch.signs)?,
68        };
69        let payload_len = checked_u64(
70            code.polar_code.radii.len() * 4 + packed_angle_indices.len() + packed_signs.len(),
71            "payload length",
72        )?;
73
74        let mut bytes = Vec::with_capacity(42 + payload_len as usize);
75        bytes.extend_from_slice(TURBO_CODE_WIRE_MAGIC);
76        bytes.extend_from_slice(&VERSION.to_le_bytes());
77        bytes.extend_from_slice(&rotation_flag(profile.rotation_kind()).to_le_bytes());
78        bytes.push(VARIANT_TURBO_CODE);
79        bytes.push(0);
80        bytes.extend_from_slice(&dim.to_le_bytes());
81        bytes.push(polar_bits);
82        bytes.extend_from_slice(&[0, 0, 0]);
83        bytes.extend_from_slice(&qjl_projections.to_le_bytes());
84        bytes.extend_from_slice(&profile.seed().to_le_bytes());
85        bytes.extend_from_slice(&polar_block_count.to_le_bytes());
86        bytes.extend_from_slice(&qjl_sign_count.to_le_bytes());
87        bytes.extend_from_slice(&payload_len.to_le_bytes());
88
89        for radius in &code.polar_code.radii {
90            bytes.extend_from_slice(&radius.to_le_bytes());
91        }
92        bytes.extend_from_slice(&packed_angle_indices);
93        bytes.extend_from_slice(&packed_signs);
94        Ok(bytes)
95    }
96
97    /// Decode and validate TurboCode wire bytes against the supplied profile.
98    pub fn decode(bytes: &[u8], profile: &TurboQuantizer) -> Result<TurboCode> {
99        let mut cursor = WireCursor::new(bytes);
100        if cursor.read_exact(TURBO_CODE_WIRE_MAGIC.len())? != TURBO_CODE_WIRE_MAGIC {
101            return Err(TurboQuantError::MalformedCode {
102                reason: "wrong TurboQuant wire magic".into(),
103            });
104        }
105        let version = cursor.read_u16()?;
106        if version != VERSION {
107            return Err(TurboQuantError::MalformedCode {
108                reason: format!("unsupported TurboQuant wire version {version}"),
109            });
110        }
111        let wire_rotation_flag = cursor.read_u16()?;
112        let expected_rotation_flag = rotation_flag(profile.rotation_kind());
113        if wire_rotation_flag != expected_rotation_flag {
114            return Err(TurboQuantError::MalformedCode {
115                reason: format!(
116                    "wire rotation flag {wire_rotation_flag} does not match quantizer profile flag {expected_rotation_flag}"
117                ),
118            });
119        }
120        let variant = cursor.read_u8()?;
121        if variant != VARIANT_TURBO_CODE {
122            return Err(TurboQuantError::MalformedCode {
123                reason: format!("unsupported TurboQuant wire variant {variant}"),
124            });
125        }
126        let reserved = cursor.read_u8()?;
127        if reserved != 0 {
128            return Err(TurboQuantError::MalformedCode {
129                reason: "nonzero TurboQuant wire reserved byte".into(),
130            });
131        }
132
133        let dim = cursor.read_u32()? as usize;
134        let polar_bits = cursor.read_u8()?;
135        let reserved2 = cursor.read_exact(3)?;
136        if reserved2 != [0, 0, 0] {
137            return Err(TurboQuantError::MalformedCode {
138                reason: "nonzero TurboQuant wire reserved bytes".into(),
139            });
140        }
141        let qjl_projections = cursor.read_u32()? as usize;
142        let seed = cursor.read_u64()?;
143        let polar_block_count = cursor.read_u32()? as usize;
144        let qjl_sign_count = cursor.read_u32()? as usize;
145        let payload_len = cursor.read_u64()?;
146        let payload_start = cursor.offset();
147
148        let expected_polar_bits = match profile.mode() {
149            TurboMode::PolarOnly => profile.bits(),
150            TurboMode::PolarWithQjl => profile.bits() - 1,
151        };
152        if dim != profile.dim()
153            || polar_bits != expected_polar_bits
154            || qjl_projections != profile.projections()
155        {
156            return Err(TurboQuantError::MalformedCode {
157                reason: "wire header does not match quantizer profile".into(),
158            });
159        }
160        if seed != profile.seed() {
161            return Err(TurboQuantError::MalformedCode {
162                reason: format!(
163                    "wire seed {seed} does not match quantizer profile seed {}",
164                    profile.seed()
165                ),
166            });
167        }
168        if polar_block_count != profile.dim() / 2 {
169            return Err(TurboQuantError::MalformedCode {
170                reason: format!(
171                    "wire polar block count {polar_block_count} does not match dimension {}",
172                    profile.dim()
173                ),
174            });
175        }
176        let expected_qjl_sign_count = match profile.mode() {
177            TurboMode::PolarOnly => 0,
178            TurboMode::PolarWithQjl => profile.projections(),
179        };
180        if qjl_sign_count != expected_qjl_sign_count {
181            return Err(TurboQuantError::MalformedCode {
182                reason: format!(
183                    "wire sign count {qjl_sign_count} does not match expected {expected_qjl_sign_count}"
184                ),
185            });
186        }
187        let angle_bytes = bitpack::packed_len(polar_block_count, polar_bits)?;
188        let sign_bytes = match profile.mode() {
189            TurboMode::PolarOnly => 0,
190            TurboMode::PolarWithQjl => profile.projections().div_ceil(8),
191        };
192        let residual_bytes = sign_bytes;
193        let expected_payload_len = checked_u64(
194            polar_block_count * 4 + angle_bytes + residual_bytes,
195            "expected payload length",
196        )?;
197        if payload_len != expected_payload_len {
198            return Err(TurboQuantError::MalformedCode {
199                reason: format!(
200                    "TurboQuant wire payload length {payload_len} does not match expected {expected_payload_len}"
201                ),
202            });
203        }
204        if payload_len > cursor.remaining_len() as u64 {
205            return Err(TurboQuantError::MalformedCode {
206                reason: "TurboQuant wire payload length exceeds remaining bytes".into(),
207            });
208        }
209
210        let mut radii = Vec::with_capacity(polar_block_count);
211        for _ in 0..polar_block_count {
212            radii.push(cursor.read_f32()?);
213        }
214        let packed_angle_indices = cursor.read_exact(angle_bytes)?.to_vec();
215        let angle_indices =
216            bitpack::unpack_indices(&packed_angle_indices, polar_block_count, polar_bits)?;
217        let residual_sketch = match profile.mode() {
218            TurboMode::PolarOnly => QjlSketch {
219                dim: profile.dim(),
220                projections: 0,
221                signs: Vec::new(),
222            },
223            TurboMode::PolarWithQjl => {
224                let packed_signs = cursor.read_exact(sign_bytes)?.to_vec();
225                let signs = bitpack::unpack_signs(&packed_signs, profile.projections())?;
226                QjlSketch {
227                    dim: profile.dim(),
228                    projections: profile.projections(),
229                    signs,
230                }
231            }
232        };
233        if cursor.offset() - payload_start != payload_len as usize {
234            return Err(TurboQuantError::MalformedCode {
235                reason: "TurboQuant wire payload length mismatch".into(),
236            });
237        }
238        cursor.finish()?;
239
240        let code = TurboCode {
241            polar_code: PolarCode {
242                dim: profile.dim(),
243                bits: polar_bits,
244                radii,
245                angle_indices,
246            },
247            residual_sketch,
248        };
249        code.validate_for(
250            profile.dim(),
251            profile.bits(),
252            profile.projections(),
253            profile.mode(),
254        )?;
255        Ok(code)
256    }
257
258    /// Encode a batch of TurboCodes using the supplied shared quantizer profile.
259    /// Convenience wrapper that uses f32 (lossless) radii.
260    pub fn encode_batch(codes: &[TurboCode], profile: &TurboQuantizer) -> Result<Vec<u8>> {
261        Self::encode_batch_with_radii(codes, profile, radius::RadiusCodecProfileV1::F32)
262    }
263
264    /// Encode a batch of TurboCodes using the supplied shared quantizer
265    /// profile and the specified radii codec.
266    ///
267    /// For a batch of N vectors at dim D, projections P, bits B, the layout is:
268    ///
269    ///   offset  size  field
270    ///   ------  ----  -----
271    ///   0       4     magic "TQB1"
272    ///   4       2     version (LE u16)
273    ///   6       2     rotation flag (LE u16)
274    ///   8       1     variant (= 2 for BATCH)
275    ///   9       1     reserved (= 0)
276    ///   10      4     dim (LE u32)
277    ///   14      1     polar_bits (u8)
278    ///   15      1     radii_codec (0=f32, 1=BlockLogU8)  <-- NEW
279    ///   16      2     reserved (= 0)
280    ///   18      4     qjl_projections (LE u32)
281    ///   22      8     seed (LE u64)
282    ///   30      4     n_vectors (LE u32)
283    ///   34      4     vector_payload_len (LE u32) — size of EACH per-vector payload
284    ///   38      ...   vector 0 payload (vector_payload_len bytes)
285    ///   ...     ...   vector 1 payload
286    ///   ...     ...   ...
287    ///
288    /// Per-vector payload depends on `radii_codec`:
289    ///   - f32 (lossless):       [radii: 4*N bytes] [angles: A bytes] [signs: S bytes]
290    ///   - BlockLogU8 (lossy):   [radii: 1*N + 8 bytes (min/max)] [angles] [signs]
291    pub fn encode_batch_with_radii(
292        codes: &[TurboCode],
293        profile: &TurboQuantizer,
294        radii_codec: radius::RadiusCodecProfileV1,
295    ) -> Result<Vec<u8>> {
296        if codes.is_empty() {
297            return Err(TurboQuantError::MalformedCode {
298                reason: "empty batch".into(),
299            });
300        }
301        for code in codes {
302            code.validate_for(
303                profile.dim(),
304                profile.bits(),
305                profile.projections(),
306                profile.mode(),
307            )?;
308        }
309        let dim = checked_u32(profile.dim(), "dimension")?;
310        let polar_bits = codes[0].polar_code.bits;
311        let qjl_projections = checked_u32(profile.projections(), "projection count")?;
312        let polar_block_count = checked_u32(codes[0].polar_code.radii.len(), "polar block count")?;
313        let n_vectors = checked_u32(codes.len(), "vector count")?;
314        let angle_bytes_per_vec = bitpack::packed_len(polar_block_count as usize, polar_bits)?;
315        let sign_bytes_per_vec = match profile.mode() {
316            TurboMode::PolarOnly => 0,
317            TurboMode::PolarWithQjl => profile.projections().div_ceil(8),
318        };
319        let radii_bytes_per_vec: usize = match radii_codec {
320            radius::RadiusCodecProfileV1::F32 => polar_block_count as usize * 4,
321            radius::RadiusCodecProfileV1::BlockLinearU16 => {
322                polar_block_count as usize * 2 + 8
323            }
324            radius::RadiusCodecProfileV1::BlockLogU8 => polar_block_count as usize + 8,
325        };
326        let vector_payload_len = checked_u32(
327            radii_bytes_per_vec + angle_bytes_per_vec + sign_bytes_per_vec,
328            "vector payload length",
329        )?;
330        let total_payload = vector_payload_len as usize * codes.len();
331        let mut bytes = Vec::with_capacity(38 + total_payload);
332        bytes.extend_from_slice(TURBO_CODE_BATCHED_WIRE_MAGIC);
333        bytes.extend_from_slice(&VERSION.to_le_bytes());
334        bytes.extend_from_slice(&rotation_flag(profile.rotation_kind()).to_le_bytes());
335        bytes.push(VARIANT_TURBO_BATCH);
336        bytes.push(0);
337        bytes.extend_from_slice(&dim.to_le_bytes());
338        bytes.push(polar_bits);
339        // radii_codec at offset 15
340        bytes.push(match radii_codec {
341            radius::RadiusCodecProfileV1::F32 => RADII_CODEC_F32,
342            _ => RADII_CODEC_BLOCK_LOG_U8,
343        });
344        bytes.extend_from_slice(&[0, 0]);
345        bytes.extend_from_slice(&qjl_projections.to_le_bytes());
346        bytes.extend_from_slice(&profile.seed().to_le_bytes());
347        bytes.extend_from_slice(&n_vectors.to_le_bytes());
348        bytes.extend_from_slice(&vector_payload_len.to_le_bytes());
349        for code in codes {
350            // Compress radii with the requested profile and write the bytes.
351            let compressed = radius::CompressedRadiiV1::compress(&code.polar_code.radii, radii_codec)?;
352            if compressed.payload.len() + (if matches!(radii_codec, radius::RadiusCodecProfileV1::F32) { 0 } else { 8 }) != radii_bytes_per_vec {
353                return Err(TurboQuantError::MalformedCode {
354                    reason: format!(
355                        "compressed radii payload {} + header != expected {}",
356                        compressed.payload.len(),
357                        radii_bytes_per_vec
358                    ),
359                });
360            }
361            bytes.extend_from_slice(&compressed.payload);
362            if !matches!(radii_codec, radius::RadiusCodecProfileV1::F32) {
363                bytes.extend_from_slice(&compressed.min.to_le_bytes());
364                bytes.extend_from_slice(&compressed.max.to_le_bytes());
365            }
366            let packed = bitpack::pack_indices(&code.polar_code.angle_indices, polar_bits)?;
367            if packed.len() != angle_bytes_per_vec {
368                return Err(TurboQuantError::MalformedCode {
369                    reason: format!(
370                        "angle packed length {} != expected {}",
371                        packed.len(),
372                        angle_bytes_per_vec
373                    ),
374                });
375            }
376            bytes.extend_from_slice(&packed);
377            if matches!(profile.mode(), TurboMode::PolarWithQjl) {
378                let packed_signs = bitpack::pack_signs(&code.residual_sketch.signs)?;
379                if packed_signs.len() != sign_bytes_per_vec {
380                    return Err(TurboQuantError::MalformedCode {
381                        reason: format!(
382                            "sign packed length {} != expected {}",
383                            packed_signs.len(),
384                            sign_bytes_per_vec
385                        ),
386                    });
387                }
388                bytes.extend_from_slice(&packed_signs);
389            }
390        }
391        Ok(bytes)
392    }
393
394    /// Decode a batched TQB1 payload into a Vec<TurboCode>.
395    pub fn decode_batch(bytes: &[u8], profile: &TurboQuantizer) -> Result<Vec<TurboCode>> {
396        let mut cursor = WireCursor::new(bytes);
397        if cursor.read_exact(TURBO_CODE_BATCHED_WIRE_MAGIC.len())? != TURBO_CODE_BATCHED_WIRE_MAGIC
398        {
399            return Err(TurboQuantError::MalformedCode {
400                reason: "wrong TurboQuant batched wire magic".into(),
401            });
402        }
403        let version = cursor.read_u16()?;
404        if version != VERSION {
405            return Err(TurboQuantError::MalformedCode {
406                reason: format!("unsupported TurboQuant batched wire version {version}"),
407            });
408        }
409        let wire_rotation_flag = cursor.read_u16()?;
410        let expected_rotation_flag = rotation_flag(profile.rotation_kind());
411        if wire_rotation_flag != expected_rotation_flag {
412            return Err(TurboQuantError::MalformedCode {
413                reason: format!(
414                    "batched wire rotation flag {wire_rotation_flag} does not match profile {expected_rotation_flag}"
415                ),
416            });
417        }
418        let variant = cursor.read_u8()?;
419        if variant != VARIANT_TURBO_BATCH {
420            return Err(TurboQuantError::MalformedCode {
421                reason: format!("unsupported TurboQuant batched wire variant {variant}"),
422            });
423        }
424        let _reserved = cursor.read_u8()?;
425        let dim = cursor.read_u32()? as usize;
426        let polar_bits = cursor.read_u8()?;
427        // radii_codec at offset 15. 0 = f32 (default for legacy wire),
428        // 1 = BlockLogU8. 2 = BlockLinearU16 (not wired through yet).
429        let radii_codec_byte = cursor.read_u8()?;
430        let radii_codec = match radii_codec_byte {
431            RADII_CODEC_F32 => radius::RadiusCodecProfileV1::F32,
432            RADII_CODEC_BLOCK_LOG_U8 => radius::RadiusCodecProfileV1::BlockLogU8,
433            other => {
434                return Err(TurboQuantError::MalformedCode {
435                    reason: format!("unsupported batched wire radii_codec {other}"),
436                });
437            }
438        };
439        let _reserved2 = cursor.read_exact(2)?;
440        let qjl_projections = cursor.read_u32()? as usize;
441        let seed = cursor.read_u64()?;
442        let n_vectors = cursor.read_u32()? as usize;
443        let vector_payload_len = cursor.read_u32()? as usize;
444        if dim != profile.dim()
445            || qjl_projections != profile.projections()
446            || seed != profile.seed()
447        {
448            return Err(TurboQuantError::MalformedCode {
449                reason: "batched wire header does not match quantizer profile".into(),
450            });
451        }
452        let expected_polar_bits = match profile.mode() {
453            TurboMode::PolarOnly => profile.bits(),
454            TurboMode::PolarWithQjl => profile.bits() - 1,
455        };
456        if polar_bits != expected_polar_bits {
457            return Err(TurboQuantError::MalformedCode {
458                reason: format!(
459                    "batched wire polar_bits {polar_bits} != expected {expected_polar_bits}"
460                ),
461            });
462        }
463        let polar_block_count = profile.dim() / 2;
464        let angle_bytes_per_vec = bitpack::packed_len(polar_block_count, polar_bits)?;
465        let sign_bytes_per_vec = match profile.mode() {
466            TurboMode::PolarOnly => 0,
467            TurboMode::PolarWithQjl => profile.projections().div_ceil(8),
468        };
469        let radii_bytes_per_vec: usize = match radii_codec {
470            radius::RadiusCodecProfileV1::F32 => polar_block_count * 4,
471            radius::RadiusCodecProfileV1::BlockLinearU16 => polar_block_count * 2 + 8,
472            radius::RadiusCodecProfileV1::BlockLogU8 => polar_block_count + 8,
473        };
474        let expected_vector_payload_len =
475            radii_bytes_per_vec + angle_bytes_per_vec + sign_bytes_per_vec;
476        if vector_payload_len != expected_vector_payload_len {
477            return Err(TurboQuantError::MalformedCode {
478                reason: format!(
479                    "batched wire vector_payload_len {vector_payload_len} != expected {expected_vector_payload_len}"
480                ),
481            });
482        }
483        let expected_total = 38 + n_vectors * vector_payload_len;
484        if bytes.len() < expected_total {
485            return Err(TurboQuantError::MalformedCode {
486                reason: format!(
487                    "batched wire buffer {} bytes < expected {} for {} vectors",
488                    bytes.len(),
489                    expected_total,
490                    n_vectors
491                ),
492            });
493        }
494        let mut codes = Vec::with_capacity(n_vectors);
495        for _ in 0..n_vectors {
496            // Read the radii bytes for this vector. For f32 that's
497            // 4*N raw bytes; for BlockLogU8 it's N quantized u8s plus
498            // (min, max) f32. We then run CompressedRadiiV1::decompress
499            // to get back the original f32 radii. This means the lossless
500            // path goes through a trivial passthrough (f32 in / f32 out)
501            // and the lossy path runs the same inverse.
502            let radii_payload_len = match radii_codec {
503                radius::RadiusCodecProfileV1::F32 => polar_block_count * 4,
504                radius::RadiusCodecProfileV1::BlockLinearU16 => polar_block_count * 2,
505                radius::RadiusCodecProfileV1::BlockLogU8 => polar_block_count,
506            };
507            let radii_payload = cursor.read_exact(radii_payload_len)?.to_vec();
508            let compressed = radius::CompressedRadiiV1 {
509                profile: radii_codec,
510                count: polar_block_count,
511                min: if matches!(radii_codec, radius::RadiusCodecProfileV1::F32) {
512                    0.0
513                } else {
514                    cursor.read_f32()?
515                },
516                max: if matches!(radii_codec, radius::RadiusCodecProfileV1::F32) {
517                    0.0
518                } else {
519                    cursor.read_f32()?
520                },
521                payload: radii_payload,
522            };
523            let radii = compressed.decompress()?;
524            let packed_angles = cursor.read_exact(angle_bytes_per_vec)?.to_vec();
525            let angle_indices =
526                bitpack::unpack_indices(&packed_angles, polar_block_count, polar_bits)?;
527            let residual_sketch = match profile.mode() {
528                TurboMode::PolarOnly => QjlSketch {
529                    dim: profile.dim(),
530                    projections: 0,
531                    signs: Vec::new(),
532                },
533                TurboMode::PolarWithQjl => {
534                    let packed_signs = cursor.read_exact(sign_bytes_per_vec)?.to_vec();
535                    let signs = bitpack::unpack_signs(&packed_signs, profile.projections())?;
536                    QjlSketch {
537                        dim: profile.dim(),
538                        projections: profile.projections(),
539                        signs,
540                    }
541                }
542            };
543            let code = TurboCode {
544                polar_code: PolarCode {
545                    dim: profile.dim(),
546                    bits: polar_bits,
547                    radii,
548                    angle_indices,
549                },
550                residual_sketch,
551            };
552            code.validate_for(
553                profile.dim(),
554                profile.bits(),
555                profile.projections(),
556                profile.mode(),
557            )?;
558            codes.push(code);
559        }
560        cursor.finish()?;
561        Ok(codes)
562    }
563
564
565}
566
567fn checked_u32(value: usize, field: &str) -> Result<u32> {
568    u32::try_from(value).map_err(|_| TurboQuantError::MalformedCode {
569        reason: format!("{field} {value} does not fit u32 wire field"),
570    })
571}
572
573fn checked_u64(value: usize, field: &str) -> Result<u64> {
574    u64::try_from(value).map_err(|_| TurboQuantError::MalformedCode {
575        reason: format!("{field} {value} does not fit u64 wire field"),
576    })
577}
578
579fn rotation_flag(kind: RotationKind) -> u16 {
580    match kind {
581        RotationKind::Auto => 0,
582        RotationKind::FastHadamard => 1,
583        RotationKind::StoredQr => 2,
584    }
585}
586
/// Forward-only reader over a borrowed wire buffer.
///
/// `offset` is the index of the next unread byte in `bytes`. The reader
/// methods only ever move it forward.
struct WireCursor<'a> {
    offset: usize,
    bytes: &'a [u8],
}
591
592/// Decoded TurboQuant wire header. The wire format carries the full
593/// quantizer profile (dim, bits, projections, seed, mode, rotation kind)
594/// in the first 44 bytes, so a `TurboCode` can be reconstructed from
595/// the wire bytes alone — no external quantizer required.
596#[derive(Debug, Clone, PartialEq, Eq)]
597pub struct TurboCodeWireHeader {
598    /// Original vector dimension.
599    pub dim: usize,
600    /// Polar-code bits per angle (b in the paper; b-1 for PolarWithQjl mode).
601    pub polar_bits: u8,
602    /// QJL projection count for the residual sketch.
603    pub qjl_projections: usize,
604    /// Seed used to derive the projection state.
605    pub seed: u64,
606    /// Number of polar code blocks (≈ dim / 2).
607    pub polar_block_count: usize,
608    /// QJL sign count (0 for PolarOnly mode).
609    pub qjl_sign_count: usize,
610    /// Length of the payload section following the header.
611    pub payload_len: u64,
612    /// Rotation kind embedded in the wire.
613    pub rotation_kind: RotationKind,
614}
615
616impl TurboCodeWireV1 {
617    /// Parse just the 44-byte wire header. This is the public entry point
618    /// for callers that want to extract the quantizer profile from the
619    /// wire format without validating against a specific quantizer instance.
620    pub fn parse_header(bytes: &[u8]) -> Result<TurboCodeWireHeader> {
621        if bytes.len() < 44 {
622            return Err(TurboQuantError::MalformedCode {
623                reason: format!("TurboQuant wire header is {} bytes, need 44", bytes.len()),
624            });
625        }
626        if &bytes[0..4] != TURBO_CODE_WIRE_MAGIC {
627            return Err(TurboQuantError::MalformedCode {
628                reason: "wrong TurboQuant wire magic".into(),
629            });
630        }
631        let version = u16::from_le_bytes(bytes[4..6].try_into().unwrap());
632        if version != VERSION {
633            return Err(TurboQuantError::MalformedCode {
634                reason: format!("unsupported TurboQuant wire version {version}"),
635            });
636        }
637        let wire_rotation_flag = u16::from_le_bytes(bytes[6..8].try_into().unwrap());
638        let rotation_kind = match wire_rotation_flag {
639            0 => RotationKind::Auto,
640            1 => RotationKind::FastHadamard,
641            2 => RotationKind::StoredQr,
642            _ => {
643                return Err(TurboQuantError::MalformedCode {
644                    reason: format!("unknown TurboQuant rotation flag {wire_rotation_flag}"),
645                })
646            }
647        };
648        let variant = bytes[8];
649        if variant != VARIANT_TURBO_CODE {
650            return Err(TurboQuantError::MalformedCode {
651                reason: format!("unsupported TurboQuant wire variant {variant}"),
652            });
653        }
654        let reserved = bytes[9];
655        if reserved != 0 {
656            return Err(TurboQuantError::MalformedCode {
657                reason: "nonzero TurboQuant wire reserved byte".into(),
658            });
659        }
660        let dim = u32::from_le_bytes(bytes[10..14].try_into().unwrap()) as usize;
661        let polar_bits = bytes[14];
662        let reserved2: [u8; 3] = bytes[15..18].try_into().unwrap();
663        if reserved2 != [0, 0, 0] {
664            return Err(TurboQuantError::MalformedCode {
665                reason: "nonzero TurboQuant wire reserved bytes".into(),
666            });
667        }
668        let qjl_projections = u32::from_le_bytes(bytes[18..22].try_into().unwrap()) as usize;
669        let seed = u64::from_le_bytes(bytes[22..30].try_into().unwrap());
670        let polar_block_count = u32::from_le_bytes(bytes[30..34].try_into().unwrap()) as usize;
671        let qjl_sign_count = u32::from_le_bytes(bytes[34..38].try_into().unwrap()) as usize;
672        let payload_len = u64::from_le_bytes(bytes[38..46].try_into().unwrap());
673        Ok(TurboCodeWireHeader {
674            dim,
675            polar_bits,
676            qjl_projections,
677            seed,
678            polar_block_count,
679            qjl_sign_count,
680            payload_len,
681            rotation_kind,
682        })
683    }
684}
685
686impl<'a> WireCursor<'a> {
687    fn new(bytes: &'a [u8]) -> Self {
688        Self { bytes, offset: 0 }
689    }
690
691    fn offset(&self) -> usize {
692        self.offset
693    }
694
695    fn remaining_len(&self) -> usize {
696        self.bytes.len().saturating_sub(self.offset)
697    }
698
699    fn read_exact(&mut self, len: usize) -> Result<&'a [u8]> {
700        let end = self
701            .offset
702            .checked_add(len)
703            .ok_or_else(|| TurboQuantError::MalformedCode {
704                reason: "wire offset overflow".into(),
705            })?;
706        if end > self.bytes.len() {
707            return Err(TurboQuantError::MalformedCode {
708                reason: "truncated TurboQuant wire artifact".into(),
709            });
710        }
711        let out = &self.bytes[self.offset..end];
712        self.offset = end;
713        Ok(out)
714    }
715
716    fn read_u8(&mut self) -> Result<u8> {
717        Ok(self.read_exact(1)?[0])
718    }
719
720    fn read_u16(&mut self) -> Result<u16> {
721        let bytes = self.read_exact(2)?;
722        Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
723    }
724
725    fn read_u32(&mut self) -> Result<u32> {
726        let bytes = self.read_exact(4)?;
727        Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
728    }
729
730    fn read_u64(&mut self) -> Result<u64> {
731        let bytes = self.read_exact(8)?;
732        Ok(u64::from_le_bytes([
733            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
734        ]))
735    }
736
737    fn read_f32(&mut self) -> Result<f32> {
738        let bytes = self.read_exact(4)?;
739        Ok(f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
740    }
741
742    fn finish(&self) -> Result<()> {
743        if self.offset != self.bytes.len() {
744            return Err(TurboQuantError::MalformedCode {
745                reason: "trailing bytes in TurboQuant wire artifact".into(),
746            });
747        }
748        Ok(())
749    }
750}
751
#[cfg(test)]
mod tests {
    use super::*;

    fn make_quantizer(dim: usize, seed: u64) -> TurboQuantizer {
        // Use the simplest possible profile: PolarWithQjl, 8-bit, 32 projections.
        TurboQuantizer::new(dim, 8, 32, seed).expect("quantizer build")
    }

    /// Encodes `count` deterministic `dim`-length vectors with `q`.
    ///
    /// Component `j` of vector `i` is `sin((i * dim + j) * 0.013)`.
    fn encode_sample_batch(q: &TurboQuantizer, dim: usize, count: usize) -> Vec<TurboCode> {
        (0..count)
            .map(|i| {
                let vector: Vec<f32> = (0..dim)
                    .map(|j| ((i * dim + j) as f32 * 0.013).sin())
                    .collect();
                q.encode(&vector).expect("encode")
            })
            .collect()
    }

    #[test]
    fn parse_header_round_trips_encoded_code() {
        let q = make_quantizer(128, 42);
        let vector: Vec<f32> = (0..128).map(|i| (i as f32 / 128.0) - 0.5).collect();
        let code = q.encode(&vector).expect("encode");
        let wire = TurboCodeWireV1::encode(&code, &q).expect("wire encode");

        let header = TurboCodeWireV1::parse_header(&wire).expect("parse header");
        assert_eq!(header.dim, 128);
        assert_eq!(header.qjl_projections, 32);
        assert_eq!(header.seed, 42);
        assert!(header.polar_block_count > 0);
        assert!(header.payload_len > 0);
    }

    #[test]
    fn parse_header_rejects_short_buffer() {
        let bytes = vec![0u8; 10];
        let result = TurboCodeWireV1::parse_header(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn parse_header_rejects_bad_magic() {
        // Start from a genuinely valid artifact so the magic is the only thing
        // wrong. The TQW1 header is 46 bytes (parse_header reads up to offset 46),
        // so a shorter hand-built buffer could be rejected on length alone.
        let q = make_quantizer(128, 42);
        let vector: Vec<f32> = (0..128).map(|i| (i as f32 / 128.0) - 0.5).collect();
        let code = q.encode(&vector).expect("encode");
        let mut wire = TurboCodeWireV1::encode(&code, &q).expect("wire encode");
        assert!(TurboCodeWireV1::parse_header(&wire).is_ok());

        wire[0..4].copy_from_slice(b"XXXX");
        assert!(TurboCodeWireV1::parse_header(&wire).is_err());
    }

    // ---- Batched wire format (TQB1) ----
    //
    // The batched format stores the profile ONCE, then concatenates the
    // per-vector payloads. For a batch of N vectors at dim D, projections P,
    // bits B, the layout is:
    //
    //   offset  size  field
    //   ------  ----  -----
    //   0       4     magic "TQB1"
    //   4       2     version (LE u16)
    //   6       2     rotation flag (LE u16)
    //   8       1     variant (= 2 for BATCH)
    //   9       1     reserved (= 0)
    //   10      4     dim (LE u32)
    //   14      1     polar_bits (u8)
    //   15      1     radii codec (RADII_CODEC_F32 = 0, RADII_CODEC_BLOCK_LOG_U8 = 1)
    //   16      2     reserved (= 0)
    //   18      4     qjl_projections (LE u32)
    //   22      8     seed (LE u64)
    //   30      4     n_vectors (LE u32)
    //   34      4     vector_payload_len (LE u32) — size of EACH per-vector payload
    //   38      ...   vector 0 payload (vector_payload_len bytes)
    //   ...     ...   vector 1 payload
    //   ...     ...   ...
    //
    // The header is 38 bytes total, versus a 46-byte header on every TQW1
    // artifact (4600 bytes of headers for 100 vectors). Every per-vector payload
    // has the same length for a given profile and radii codec. With
    // RADII_CODEC_F32 it is radii_count * 4 + angle_packed_len + sign_packed_len.
    // BlockLogU8 stores radii more compactly. Either way, a single
    // vector_payload_len is sufficient.

    #[test]
    fn batched_wire_roundtrip_matches_single() {
        let q = make_quantizer(64, 42);
        let codes = encode_sample_batch(&q, 64, 16);
        let single_total: usize = codes
            .iter()
            .map(|c| TurboCodeWireV1::encode(c, &q).unwrap().len())
            .sum();
        let batched_bytes = TurboCodeWireV1::encode_batch(&codes, &q).unwrap();
        assert!(
            batched_bytes.len() < single_total,
            "batched {} >= single total {}",
            batched_bytes.len(),
            single_total
        );
        let decoded = TurboCodeWireV1::decode_batch(&batched_bytes, &q).unwrap();
        assert_eq!(decoded.len(), codes.len());
        for (i, (orig, back)) in codes.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(
                orig.polar_code.radii, back.polar_code.radii,
                "radii mismatch at vec {i}"
            );
            assert_eq!(
                orig.polar_code.angle_indices, back.polar_code.angle_indices,
                "angles mismatch at vec {i}"
            );
            assert_eq!(
                orig.residual_sketch.signs, back.residual_sketch.signs,
                "signs mismatch at vec {i}"
            );
        }
    }

    #[test]
    fn batched_wire_rejects_wrong_magic() {
        // Corrupt only the magic of a valid batch so the rejection cannot come
        // from some other header field (version, variant, lengths).
        let q = make_quantizer(64, 42);
        let codes = encode_sample_batch(&q, 64, 4);
        let mut bytes = TurboCodeWireV1::encode_batch(&codes, &q).unwrap();
        assert!(TurboCodeWireV1::decode_batch(&bytes, &q).is_ok());

        bytes[0..4].copy_from_slice(b"XXXX");
        let r = TurboCodeWireV1::decode_batch(&bytes, &q);
        assert!(r.is_err());
    }

    #[test]
    fn batched_wire_rejects_buffer_too_short() {
        let q = make_quantizer(64, 42);
        let mut bytes = b"TQB1".to_vec();
        bytes.extend_from_slice(&1u16.to_le_bytes());
        bytes.extend_from_slice(&0u16.to_le_bytes());
        bytes.push(2);
        // truncated well before the 38-byte header completes
        let r = TurboCodeWireV1::decode_batch(&bytes, &q);
        assert!(r.is_err());
    }

    /// TQB1-L (lossy BlockLogU8 radii) roundtrip. The decoded radii will
    /// be APPROXIMATELY equal (not bit-exact) to the originals because the
    /// u8 log quantization loses precision. We assert the per-radius
    /// relative error is < 5%, a wide bound chosen to tolerate the codec's
    /// quantization error.
    #[test]
    fn batched_wire_lossy_roundtrip_is_within_tolerance() {
        let q = make_quantizer(64, 42);
        let vectors: Vec<Vec<f32>> = (0..16)
            .map(|i| {
                (0..64)
                    .map(|j| ((i * 64 + j) as f32 * 0.013 + 0.1).sin().abs() + 0.1)
                    .collect()
            })
            .collect();
        let codes: Vec<_> = vectors.iter().map(|v| q.encode(v).unwrap()).collect();

        // Lossless encode and decode should be bit-exact.
        let lossless_bytes = TurboCodeWireV1::encode_batch(&codes, &q).unwrap();
        let lossless_decoded = TurboCodeWireV1::decode_batch(&lossless_bytes, &q).unwrap();
        for (i, (orig, back)) in codes.iter().zip(lossless_decoded.iter()).enumerate() {
            assert_eq!(
                orig.polar_code.radii, back.polar_code.radii,
                "lossless radii mismatch at vec {i}"
            );
        }

        // Lossy encode should be substantially smaller.
        let lossy_bytes =
            TurboCodeWireV1::encode_batch_with_radii(&codes, &q, RadiusCodecProfileV1::BlockLogU8)
                .unwrap();
        let ratio = lossless_bytes.len() as f64 / lossy_bytes.len() as f64;
        assert!(
            ratio > 1.5,
            "lossy should be at least 1.5x smaller, got {ratio:.2}x"
        );

        // Lossy decode should give back APPROXIMATELY equal radii.
        let lossy_decoded = TurboCodeWireV1::decode_batch(&lossy_bytes, &q).unwrap();
        for (i, (orig, back)) in codes.iter().zip(lossy_decoded.iter()).enumerate() {
            assert_eq!(orig.polar_code.radii.len(), back.polar_code.radii.len());
            for (j, (a, b)) in orig
                .polar_code
                .radii
                .iter()
                .zip(back.polar_code.radii.iter())
                .enumerate()
            {
                let rel = if *a > 0.0 { (a - b).abs() / a } else { 0.0 };
                assert!(
                    rel < 0.05,
                    "lossy radii rel error {rel:.4} at vec {i} radius {j}: orig={a} decoded={b}"
                );
            }
            // Angles and signs are unaffected — bit-exact.
            assert_eq!(
                orig.polar_code.angle_indices, back.polar_code.angle_indices,
                "lossy angles mismatch at vec {i}"
            );
            assert_eq!(
                orig.residual_sketch.signs, back.residual_sketch.signs,
                "lossy signs mismatch at vec {i}"
            );
        }
    }

    /// The batched-wire radii codec flag at offset 15 must record which codec
    /// was used. A lossy batch must also decode through `decode_batch`, which
    /// takes no codec argument: the flag on the wire is what selects the codec.
    #[test]
    fn batched_wire_lossy_magic_byte_is_one() {
        let q = make_quantizer(64, 42);
        let codes = encode_sample_batch(&q, 64, 4);
        let lossy_bytes =
            TurboCodeWireV1::encode_batch_with_radii(&codes, &q, RadiusCodecProfileV1::BlockLogU8)
                .unwrap();
        // Offset 15 holds the radii_codec flag.
        assert_eq!(lossy_bytes[15], RADII_CODEC_BLOCK_LOG_U8);

        let lossy_decoded = TurboCodeWireV1::decode_batch(&lossy_bytes, &q)
            .expect("lossy batch must decode without naming the codec");
        assert_eq!(lossy_decoded.len(), codes.len());

        let lossless_bytes = TurboCodeWireV1::encode_batch(&codes, &q).unwrap();
        assert_eq!(lossless_bytes[15], RADII_CODEC_F32);
    }
}