1use crate::{
7 bitpack,
8 error::{Result, TurboQuantError},
9 polar::PolarCode,
10 qjl::QjlSketch,
11 radius::{self, CompressedRadiiV1, RadiusCodecProfileV1},
12 rotation::RotationKind,
13 turbo::{TurboCode, TurboMode, TurboQuantizer},
14};
15
16pub const TURBO_CODE_WIRE_MAGIC: &[u8; 4] = b"TQW1";
18
19pub const TURBO_CODE_BATCHED_WIRE_MAGIC: &[u8; 4] = b"TQB1";
28
29pub const RADII_CODEC_F32: u8 = 0;
33pub const RADII_CODEC_BLOCK_LOG_U8: u8 = 1;
34
35const VERSION: u16 = 1;
36const VARIANT_TURBO_CODE: u8 = 1;
37const VARIANT_TURBO_BATCH: u8 = 2;
38
39pub struct TurboCodeWireV1;
41
42impl TurboCodeWireV1 {
43 pub fn encode(code: &TurboCode, profile: &TurboQuantizer) -> Result<Vec<u8>> {
45 code.validate_for(
46 profile.dim(),
47 profile.bits(),
48 profile.projections(),
49 profile.mode(),
50 )?;
51
52 let dim = checked_u32(profile.dim(), "dimension")?;
53 let polar_bits = code.polar_code.bits;
54 let qjl_projections = checked_u32(profile.projections(), "projection count")?;
55 let polar_block_count = checked_u32(code.polar_code.radii.len(), "polar block count")?;
56 let qjl_sign_count = checked_u32(
57 match profile.mode() {
58 TurboMode::PolarOnly => 0,
59 TurboMode::PolarWithQjl => code.residual_sketch.projections,
60 },
61 "qjl sign count",
62 )?;
63 let packed_angle_indices =
64 bitpack::pack_indices(&code.polar_code.angle_indices, polar_bits)?;
65 let packed_signs = match profile.mode() {
66 TurboMode::PolarOnly => Vec::new(),
67 TurboMode::PolarWithQjl => bitpack::pack_signs(&code.residual_sketch.signs)?,
68 };
69 let payload_len = checked_u64(
70 code.polar_code.radii.len() * 4 + packed_angle_indices.len() + packed_signs.len(),
71 "payload length",
72 )?;
73
74 let mut bytes = Vec::with_capacity(42 + payload_len as usize);
75 bytes.extend_from_slice(TURBO_CODE_WIRE_MAGIC);
76 bytes.extend_from_slice(&VERSION.to_le_bytes());
77 bytes.extend_from_slice(&rotation_flag(profile.rotation_kind()).to_le_bytes());
78 bytes.push(VARIANT_TURBO_CODE);
79 bytes.push(0);
80 bytes.extend_from_slice(&dim.to_le_bytes());
81 bytes.push(polar_bits);
82 bytes.extend_from_slice(&[0, 0, 0]);
83 bytes.extend_from_slice(&qjl_projections.to_le_bytes());
84 bytes.extend_from_slice(&profile.seed().to_le_bytes());
85 bytes.extend_from_slice(&polar_block_count.to_le_bytes());
86 bytes.extend_from_slice(&qjl_sign_count.to_le_bytes());
87 bytes.extend_from_slice(&payload_len.to_le_bytes());
88
89 for radius in &code.polar_code.radii {
90 bytes.extend_from_slice(&radius.to_le_bytes());
91 }
92 bytes.extend_from_slice(&packed_angle_indices);
93 bytes.extend_from_slice(&packed_signs);
94 Ok(bytes)
95 }
96
97 pub fn decode(bytes: &[u8], profile: &TurboQuantizer) -> Result<TurboCode> {
99 let mut cursor = WireCursor::new(bytes);
100 if cursor.read_exact(TURBO_CODE_WIRE_MAGIC.len())? != TURBO_CODE_WIRE_MAGIC {
101 return Err(TurboQuantError::MalformedCode {
102 reason: "wrong TurboQuant wire magic".into(),
103 });
104 }
105 let version = cursor.read_u16()?;
106 if version != VERSION {
107 return Err(TurboQuantError::MalformedCode {
108 reason: format!("unsupported TurboQuant wire version {version}"),
109 });
110 }
111 let wire_rotation_flag = cursor.read_u16()?;
112 let expected_rotation_flag = rotation_flag(profile.rotation_kind());
113 if wire_rotation_flag != expected_rotation_flag {
114 return Err(TurboQuantError::MalformedCode {
115 reason: format!(
116 "wire rotation flag {wire_rotation_flag} does not match quantizer profile flag {expected_rotation_flag}"
117 ),
118 });
119 }
120 let variant = cursor.read_u8()?;
121 if variant != VARIANT_TURBO_CODE {
122 return Err(TurboQuantError::MalformedCode {
123 reason: format!("unsupported TurboQuant wire variant {variant}"),
124 });
125 }
126 let reserved = cursor.read_u8()?;
127 if reserved != 0 {
128 return Err(TurboQuantError::MalformedCode {
129 reason: "nonzero TurboQuant wire reserved byte".into(),
130 });
131 }
132
133 let dim = cursor.read_u32()? as usize;
134 let polar_bits = cursor.read_u8()?;
135 let reserved2 = cursor.read_exact(3)?;
136 if reserved2 != [0, 0, 0] {
137 return Err(TurboQuantError::MalformedCode {
138 reason: "nonzero TurboQuant wire reserved bytes".into(),
139 });
140 }
141 let qjl_projections = cursor.read_u32()? as usize;
142 let seed = cursor.read_u64()?;
143 let polar_block_count = cursor.read_u32()? as usize;
144 let qjl_sign_count = cursor.read_u32()? as usize;
145 let payload_len = cursor.read_u64()?;
146 let payload_start = cursor.offset();
147
148 let expected_polar_bits = match profile.mode() {
149 TurboMode::PolarOnly => profile.bits(),
150 TurboMode::PolarWithQjl => profile.bits() - 1,
151 };
152 if dim != profile.dim()
153 || polar_bits != expected_polar_bits
154 || qjl_projections != profile.projections()
155 {
156 return Err(TurboQuantError::MalformedCode {
157 reason: "wire header does not match quantizer profile".into(),
158 });
159 }
160 if seed != profile.seed() {
161 return Err(TurboQuantError::MalformedCode {
162 reason: format!(
163 "wire seed {seed} does not match quantizer profile seed {}",
164 profile.seed()
165 ),
166 });
167 }
168 if polar_block_count != profile.dim() / 2 {
169 return Err(TurboQuantError::MalformedCode {
170 reason: format!(
171 "wire polar block count {polar_block_count} does not match dimension {}",
172 profile.dim()
173 ),
174 });
175 }
176 let expected_qjl_sign_count = match profile.mode() {
177 TurboMode::PolarOnly => 0,
178 TurboMode::PolarWithQjl => profile.projections(),
179 };
180 if qjl_sign_count != expected_qjl_sign_count {
181 return Err(TurboQuantError::MalformedCode {
182 reason: format!(
183 "wire sign count {qjl_sign_count} does not match expected {expected_qjl_sign_count}"
184 ),
185 });
186 }
187 let angle_bytes = bitpack::packed_len(polar_block_count, polar_bits)?;
188 let sign_bytes = match profile.mode() {
189 TurboMode::PolarOnly => 0,
190 TurboMode::PolarWithQjl => profile.projections().div_ceil(8),
191 };
192 let residual_bytes = sign_bytes;
193 let expected_payload_len = checked_u64(
194 polar_block_count * 4 + angle_bytes + residual_bytes,
195 "expected payload length",
196 )?;
197 if payload_len != expected_payload_len {
198 return Err(TurboQuantError::MalformedCode {
199 reason: format!(
200 "TurboQuant wire payload length {payload_len} does not match expected {expected_payload_len}"
201 ),
202 });
203 }
204 if payload_len > cursor.remaining_len() as u64 {
205 return Err(TurboQuantError::MalformedCode {
206 reason: "TurboQuant wire payload length exceeds remaining bytes".into(),
207 });
208 }
209
210 let mut radii = Vec::with_capacity(polar_block_count);
211 for _ in 0..polar_block_count {
212 radii.push(cursor.read_f32()?);
213 }
214 let packed_angle_indices = cursor.read_exact(angle_bytes)?.to_vec();
215 let angle_indices =
216 bitpack::unpack_indices(&packed_angle_indices, polar_block_count, polar_bits)?;
217 let residual_sketch = match profile.mode() {
218 TurboMode::PolarOnly => QjlSketch {
219 dim: profile.dim(),
220 projections: 0,
221 signs: Vec::new(),
222 },
223 TurboMode::PolarWithQjl => {
224 let packed_signs = cursor.read_exact(sign_bytes)?.to_vec();
225 let signs = bitpack::unpack_signs(&packed_signs, profile.projections())?;
226 QjlSketch {
227 dim: profile.dim(),
228 projections: profile.projections(),
229 signs,
230 }
231 }
232 };
233 if cursor.offset() - payload_start != payload_len as usize {
234 return Err(TurboQuantError::MalformedCode {
235 reason: "TurboQuant wire payload length mismatch".into(),
236 });
237 }
238 cursor.finish()?;
239
240 let code = TurboCode {
241 polar_code: PolarCode {
242 dim: profile.dim(),
243 bits: polar_bits,
244 radii,
245 angle_indices,
246 },
247 residual_sketch,
248 };
249 code.validate_for(
250 profile.dim(),
251 profile.bits(),
252 profile.projections(),
253 profile.mode(),
254 )?;
255 Ok(code)
256 }
257
258 pub fn encode_batch(codes: &[TurboCode], profile: &TurboQuantizer) -> Result<Vec<u8>> {
261 Self::encode_batch_with_radii(codes, profile, radius::RadiusCodecProfileV1::F32)
262 }
263
264 pub fn encode_batch_with_radii(
292 codes: &[TurboCode],
293 profile: &TurboQuantizer,
294 radii_codec: radius::RadiusCodecProfileV1,
295 ) -> Result<Vec<u8>> {
296 if codes.is_empty() {
297 return Err(TurboQuantError::MalformedCode {
298 reason: "empty batch".into(),
299 });
300 }
301 for code in codes {
302 code.validate_for(
303 profile.dim(),
304 profile.bits(),
305 profile.projections(),
306 profile.mode(),
307 )?;
308 }
309 let dim = checked_u32(profile.dim(), "dimension")?;
310 let polar_bits = codes[0].polar_code.bits;
311 let qjl_projections = checked_u32(profile.projections(), "projection count")?;
312 let polar_block_count = checked_u32(codes[0].polar_code.radii.len(), "polar block count")?;
313 let n_vectors = checked_u32(codes.len(), "vector count")?;
314 let angle_bytes_per_vec = bitpack::packed_len(polar_block_count as usize, polar_bits)?;
315 let sign_bytes_per_vec = match profile.mode() {
316 TurboMode::PolarOnly => 0,
317 TurboMode::PolarWithQjl => profile.projections().div_ceil(8),
318 };
319 let radii_bytes_per_vec: usize = match radii_codec {
320 radius::RadiusCodecProfileV1::F32 => polar_block_count as usize * 4,
321 radius::RadiusCodecProfileV1::BlockLinearU16 => {
322 polar_block_count as usize * 2 + 8
323 }
324 radius::RadiusCodecProfileV1::BlockLogU8 => polar_block_count as usize + 8,
325 };
326 let vector_payload_len = checked_u32(
327 radii_bytes_per_vec + angle_bytes_per_vec + sign_bytes_per_vec,
328 "vector payload length",
329 )?;
330 let total_payload = vector_payload_len as usize * codes.len();
331 let mut bytes = Vec::with_capacity(38 + total_payload);
332 bytes.extend_from_slice(TURBO_CODE_BATCHED_WIRE_MAGIC);
333 bytes.extend_from_slice(&VERSION.to_le_bytes());
334 bytes.extend_from_slice(&rotation_flag(profile.rotation_kind()).to_le_bytes());
335 bytes.push(VARIANT_TURBO_BATCH);
336 bytes.push(0);
337 bytes.extend_from_slice(&dim.to_le_bytes());
338 bytes.push(polar_bits);
339 bytes.push(match radii_codec {
341 radius::RadiusCodecProfileV1::F32 => RADII_CODEC_F32,
342 _ => RADII_CODEC_BLOCK_LOG_U8,
343 });
344 bytes.extend_from_slice(&[0, 0]);
345 bytes.extend_from_slice(&qjl_projections.to_le_bytes());
346 bytes.extend_from_slice(&profile.seed().to_le_bytes());
347 bytes.extend_from_slice(&n_vectors.to_le_bytes());
348 bytes.extend_from_slice(&vector_payload_len.to_le_bytes());
349 for code in codes {
350 let compressed = radius::CompressedRadiiV1::compress(&code.polar_code.radii, radii_codec)?;
352 if compressed.payload.len() + (if matches!(radii_codec, radius::RadiusCodecProfileV1::F32) { 0 } else { 8 }) != radii_bytes_per_vec {
353 return Err(TurboQuantError::MalformedCode {
354 reason: format!(
355 "compressed radii payload {} + header != expected {}",
356 compressed.payload.len(),
357 radii_bytes_per_vec
358 ),
359 });
360 }
361 bytes.extend_from_slice(&compressed.payload);
362 if !matches!(radii_codec, radius::RadiusCodecProfileV1::F32) {
363 bytes.extend_from_slice(&compressed.min.to_le_bytes());
364 bytes.extend_from_slice(&compressed.max.to_le_bytes());
365 }
366 let packed = bitpack::pack_indices(&code.polar_code.angle_indices, polar_bits)?;
367 if packed.len() != angle_bytes_per_vec {
368 return Err(TurboQuantError::MalformedCode {
369 reason: format!(
370 "angle packed length {} != expected {}",
371 packed.len(),
372 angle_bytes_per_vec
373 ),
374 });
375 }
376 bytes.extend_from_slice(&packed);
377 if matches!(profile.mode(), TurboMode::PolarWithQjl) {
378 let packed_signs = bitpack::pack_signs(&code.residual_sketch.signs)?;
379 if packed_signs.len() != sign_bytes_per_vec {
380 return Err(TurboQuantError::MalformedCode {
381 reason: format!(
382 "sign packed length {} != expected {}",
383 packed_signs.len(),
384 sign_bytes_per_vec
385 ),
386 });
387 }
388 bytes.extend_from_slice(&packed_signs);
389 }
390 }
391 Ok(bytes)
392 }
393
394 pub fn decode_batch(bytes: &[u8], profile: &TurboQuantizer) -> Result<Vec<TurboCode>> {
396 let mut cursor = WireCursor::new(bytes);
397 if cursor.read_exact(TURBO_CODE_BATCHED_WIRE_MAGIC.len())? != TURBO_CODE_BATCHED_WIRE_MAGIC
398 {
399 return Err(TurboQuantError::MalformedCode {
400 reason: "wrong TurboQuant batched wire magic".into(),
401 });
402 }
403 let version = cursor.read_u16()?;
404 if version != VERSION {
405 return Err(TurboQuantError::MalformedCode {
406 reason: format!("unsupported TurboQuant batched wire version {version}"),
407 });
408 }
409 let wire_rotation_flag = cursor.read_u16()?;
410 let expected_rotation_flag = rotation_flag(profile.rotation_kind());
411 if wire_rotation_flag != expected_rotation_flag {
412 return Err(TurboQuantError::MalformedCode {
413 reason: format!(
414 "batched wire rotation flag {wire_rotation_flag} does not match profile {expected_rotation_flag}"
415 ),
416 });
417 }
418 let variant = cursor.read_u8()?;
419 if variant != VARIANT_TURBO_BATCH {
420 return Err(TurboQuantError::MalformedCode {
421 reason: format!("unsupported TurboQuant batched wire variant {variant}"),
422 });
423 }
424 let _reserved = cursor.read_u8()?;
425 let dim = cursor.read_u32()? as usize;
426 let polar_bits = cursor.read_u8()?;
427 let radii_codec_byte = cursor.read_u8()?;
430 let radii_codec = match radii_codec_byte {
431 RADII_CODEC_F32 => radius::RadiusCodecProfileV1::F32,
432 RADII_CODEC_BLOCK_LOG_U8 => radius::RadiusCodecProfileV1::BlockLogU8,
433 other => {
434 return Err(TurboQuantError::MalformedCode {
435 reason: format!("unsupported batched wire radii_codec {other}"),
436 });
437 }
438 };
439 let _reserved2 = cursor.read_exact(2)?;
440 let qjl_projections = cursor.read_u32()? as usize;
441 let seed = cursor.read_u64()?;
442 let n_vectors = cursor.read_u32()? as usize;
443 let vector_payload_len = cursor.read_u32()? as usize;
444 if dim != profile.dim()
445 || qjl_projections != profile.projections()
446 || seed != profile.seed()
447 {
448 return Err(TurboQuantError::MalformedCode {
449 reason: "batched wire header does not match quantizer profile".into(),
450 });
451 }
452 let expected_polar_bits = match profile.mode() {
453 TurboMode::PolarOnly => profile.bits(),
454 TurboMode::PolarWithQjl => profile.bits() - 1,
455 };
456 if polar_bits != expected_polar_bits {
457 return Err(TurboQuantError::MalformedCode {
458 reason: format!(
459 "batched wire polar_bits {polar_bits} != expected {expected_polar_bits}"
460 ),
461 });
462 }
463 let polar_block_count = profile.dim() / 2;
464 let angle_bytes_per_vec = bitpack::packed_len(polar_block_count, polar_bits)?;
465 let sign_bytes_per_vec = match profile.mode() {
466 TurboMode::PolarOnly => 0,
467 TurboMode::PolarWithQjl => profile.projections().div_ceil(8),
468 };
469 let radii_bytes_per_vec: usize = match radii_codec {
470 radius::RadiusCodecProfileV1::F32 => polar_block_count * 4,
471 radius::RadiusCodecProfileV1::BlockLinearU16 => polar_block_count * 2 + 8,
472 radius::RadiusCodecProfileV1::BlockLogU8 => polar_block_count + 8,
473 };
474 let expected_vector_payload_len =
475 radii_bytes_per_vec + angle_bytes_per_vec + sign_bytes_per_vec;
476 if vector_payload_len != expected_vector_payload_len {
477 return Err(TurboQuantError::MalformedCode {
478 reason: format!(
479 "batched wire vector_payload_len {vector_payload_len} != expected {expected_vector_payload_len}"
480 ),
481 });
482 }
483 let expected_total = 38 + n_vectors * vector_payload_len;
484 if bytes.len() < expected_total {
485 return Err(TurboQuantError::MalformedCode {
486 reason: format!(
487 "batched wire buffer {} bytes < expected {} for {} vectors",
488 bytes.len(),
489 expected_total,
490 n_vectors
491 ),
492 });
493 }
494 let mut codes = Vec::with_capacity(n_vectors);
495 for _ in 0..n_vectors {
496 let radii_payload_len = match radii_codec {
503 radius::RadiusCodecProfileV1::F32 => polar_block_count * 4,
504 radius::RadiusCodecProfileV1::BlockLinearU16 => polar_block_count * 2,
505 radius::RadiusCodecProfileV1::BlockLogU8 => polar_block_count,
506 };
507 let radii_payload = cursor.read_exact(radii_payload_len)?.to_vec();
508 let compressed = radius::CompressedRadiiV1 {
509 profile: radii_codec,
510 count: polar_block_count,
511 min: if matches!(radii_codec, radius::RadiusCodecProfileV1::F32) {
512 0.0
513 } else {
514 cursor.read_f32()?
515 },
516 max: if matches!(radii_codec, radius::RadiusCodecProfileV1::F32) {
517 0.0
518 } else {
519 cursor.read_f32()?
520 },
521 payload: radii_payload,
522 };
523 let radii = compressed.decompress()?;
524 let packed_angles = cursor.read_exact(angle_bytes_per_vec)?.to_vec();
525 let angle_indices =
526 bitpack::unpack_indices(&packed_angles, polar_block_count, polar_bits)?;
527 let residual_sketch = match profile.mode() {
528 TurboMode::PolarOnly => QjlSketch {
529 dim: profile.dim(),
530 projections: 0,
531 signs: Vec::new(),
532 },
533 TurboMode::PolarWithQjl => {
534 let packed_signs = cursor.read_exact(sign_bytes_per_vec)?.to_vec();
535 let signs = bitpack::unpack_signs(&packed_signs, profile.projections())?;
536 QjlSketch {
537 dim: profile.dim(),
538 projections: profile.projections(),
539 signs,
540 }
541 }
542 };
543 let code = TurboCode {
544 polar_code: PolarCode {
545 dim: profile.dim(),
546 bits: polar_bits,
547 radii,
548 angle_indices,
549 },
550 residual_sketch,
551 };
552 code.validate_for(
553 profile.dim(),
554 profile.bits(),
555 profile.projections(),
556 profile.mode(),
557 )?;
558 codes.push(code);
559 }
560 cursor.finish()?;
561 Ok(codes)
562 }
563
564
565}
566
567fn checked_u32(value: usize, field: &str) -> Result<u32> {
568 u32::try_from(value).map_err(|_| TurboQuantError::MalformedCode {
569 reason: format!("{field} {value} does not fit u32 wire field"),
570 })
571}
572
573fn checked_u64(value: usize, field: &str) -> Result<u64> {
574 u64::try_from(value).map_err(|_| TurboQuantError::MalformedCode {
575 reason: format!("{field} {value} does not fit u64 wire field"),
576 })
577}
578
579fn rotation_flag(kind: RotationKind) -> u16 {
580 match kind {
581 RotationKind::Auto => 0,
582 RotationKind::FastHadamard => 1,
583 RotationKind::StoredQr => 2,
584 }
585}
586
/// Forward-only reader over a borrowed wire buffer.
struct WireCursor<'a> {
    /// Index of the next unread byte in `bytes`.
    offset: usize,
    /// The complete buffer being decoded.
    bytes: &'a [u8],
}
591
592#[derive(Debug, Clone, PartialEq, Eq)]
597pub struct TurboCodeWireHeader {
598 pub dim: usize,
600 pub polar_bits: u8,
602 pub qjl_projections: usize,
604 pub seed: u64,
606 pub polar_block_count: usize,
608 pub qjl_sign_count: usize,
610 pub payload_len: u64,
612 pub rotation_kind: RotationKind,
614}
615
616impl TurboCodeWireV1 {
617 pub fn parse_header(bytes: &[u8]) -> Result<TurboCodeWireHeader> {
621 if bytes.len() < 44 {
622 return Err(TurboQuantError::MalformedCode {
623 reason: format!("TurboQuant wire header is {} bytes, need 44", bytes.len()),
624 });
625 }
626 if &bytes[0..4] != TURBO_CODE_WIRE_MAGIC {
627 return Err(TurboQuantError::MalformedCode {
628 reason: "wrong TurboQuant wire magic".into(),
629 });
630 }
631 let version = u16::from_le_bytes(bytes[4..6].try_into().unwrap());
632 if version != VERSION {
633 return Err(TurboQuantError::MalformedCode {
634 reason: format!("unsupported TurboQuant wire version {version}"),
635 });
636 }
637 let wire_rotation_flag = u16::from_le_bytes(bytes[6..8].try_into().unwrap());
638 let rotation_kind = match wire_rotation_flag {
639 0 => RotationKind::Auto,
640 1 => RotationKind::FastHadamard,
641 2 => RotationKind::StoredQr,
642 _ => {
643 return Err(TurboQuantError::MalformedCode {
644 reason: format!("unknown TurboQuant rotation flag {wire_rotation_flag}"),
645 })
646 }
647 };
648 let variant = bytes[8];
649 if variant != VARIANT_TURBO_CODE {
650 return Err(TurboQuantError::MalformedCode {
651 reason: format!("unsupported TurboQuant wire variant {variant}"),
652 });
653 }
654 let reserved = bytes[9];
655 if reserved != 0 {
656 return Err(TurboQuantError::MalformedCode {
657 reason: "nonzero TurboQuant wire reserved byte".into(),
658 });
659 }
660 let dim = u32::from_le_bytes(bytes[10..14].try_into().unwrap()) as usize;
661 let polar_bits = bytes[14];
662 let reserved2: [u8; 3] = bytes[15..18].try_into().unwrap();
663 if reserved2 != [0, 0, 0] {
664 return Err(TurboQuantError::MalformedCode {
665 reason: "nonzero TurboQuant wire reserved bytes".into(),
666 });
667 }
668 let qjl_projections = u32::from_le_bytes(bytes[18..22].try_into().unwrap()) as usize;
669 let seed = u64::from_le_bytes(bytes[22..30].try_into().unwrap());
670 let polar_block_count = u32::from_le_bytes(bytes[30..34].try_into().unwrap()) as usize;
671 let qjl_sign_count = u32::from_le_bytes(bytes[34..38].try_into().unwrap()) as usize;
672 let payload_len = u64::from_le_bytes(bytes[38..46].try_into().unwrap());
673 Ok(TurboCodeWireHeader {
674 dim,
675 polar_bits,
676 qjl_projections,
677 seed,
678 polar_block_count,
679 qjl_sign_count,
680 payload_len,
681 rotation_kind,
682 })
683 }
684}
685
686impl<'a> WireCursor<'a> {
687 fn new(bytes: &'a [u8]) -> Self {
688 Self { bytes, offset: 0 }
689 }
690
691 fn offset(&self) -> usize {
692 self.offset
693 }
694
695 fn remaining_len(&self) -> usize {
696 self.bytes.len().saturating_sub(self.offset)
697 }
698
699 fn read_exact(&mut self, len: usize) -> Result<&'a [u8]> {
700 let end = self
701 .offset
702 .checked_add(len)
703 .ok_or_else(|| TurboQuantError::MalformedCode {
704 reason: "wire offset overflow".into(),
705 })?;
706 if end > self.bytes.len() {
707 return Err(TurboQuantError::MalformedCode {
708 reason: "truncated TurboQuant wire artifact".into(),
709 });
710 }
711 let out = &self.bytes[self.offset..end];
712 self.offset = end;
713 Ok(out)
714 }
715
716 fn read_u8(&mut self) -> Result<u8> {
717 Ok(self.read_exact(1)?[0])
718 }
719
720 fn read_u16(&mut self) -> Result<u16> {
721 let bytes = self.read_exact(2)?;
722 Ok(u16::from_le_bytes([bytes[0], bytes[1]]))
723 }
724
725 fn read_u32(&mut self) -> Result<u32> {
726 let bytes = self.read_exact(4)?;
727 Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
728 }
729
730 fn read_u64(&mut self) -> Result<u64> {
731 let bytes = self.read_exact(8)?;
732 Ok(u64::from_le_bytes([
733 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
734 ]))
735 }
736
737 fn read_f32(&mut self) -> Result<f32> {
738 let bytes = self.read_exact(4)?;
739 Ok(f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
740 }
741
742 fn finish(&self) -> Result<()> {
743 if self.offset != self.bytes.len() {
744 return Err(TurboQuantError::MalformedCode {
745 reason: "trailing bytes in TurboQuant wire artifact".into(),
746 });
747 }
748 Ok(())
749 }
750}
751
#[cfg(test)]
mod tests {
    use super::*;

    /// Quantizer used throughout these tests: 8 bits and 32 QJL projections.
    fn make_quantizer(dim: usize, seed: u64) -> TurboQuantizer {
        TurboQuantizer::new(dim, 8, 32, seed).expect("quantizer build")
    }

    #[test]
    fn parse_header_round_trips_encoded_code() {
        let q = make_quantizer(128, 42);
        let vector: Vec<f32> = (0..128).map(|i| (i as f32 / 128.0) - 0.5).collect();
        let code = q.encode(&vector).expect("encode");
        let wire = TurboCodeWireV1::encode(&code, &q).expect("wire encode");

        let header = TurboCodeWireV1::parse_header(&wire).expect("parse header");
        assert_eq!(header.dim, 128);
        assert_eq!(header.qjl_projections, 32);
        assert_eq!(header.seed, 42);
        assert!(header.polar_block_count > 0);
        assert!(header.payload_len > 0);
    }

    #[test]
    fn parse_header_rejects_short_buffer() {
        let bytes = vec![0u8; 10];
        let result = TurboCodeWireV1::parse_header(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn parse_header_rejects_bad_magic() {
        // A full-size (46-byte) header, so the failure comes from the magic
        // check rather than the length check.
        let mut bytes = vec![0u8; 46];
        bytes[0..4].copy_from_slice(b"XXXX");
        let result = TurboCodeWireV1::parse_header(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn single_wire_roundtrip_preserves_code() {
        let q = make_quantizer(128, 7);
        let vector: Vec<f32> = (0..128).map(|i| (i as f32 * 0.05).cos()).collect();
        let code = q.encode(&vector).expect("encode");
        let wire = TurboCodeWireV1::encode(&code, &q).expect("wire encode");

        // The fixed single-code header is 46 bytes, followed by the payload.
        let header = TurboCodeWireV1::parse_header(&wire).expect("parse header");
        assert_eq!(header.payload_len as usize + 46, wire.len());

        let back = TurboCodeWireV1::decode(&wire, &q).expect("wire decode");
        assert_eq!(code.polar_code.radii, back.polar_code.radii);
        assert_eq!(code.polar_code.angle_indices, back.polar_code.angle_indices);
        assert_eq!(code.residual_sketch.signs, back.residual_sketch.signs);
    }

    #[test]
    fn single_wire_decode_rejects_truncated_and_trailing_bytes() {
        let q = make_quantizer(64, 3);
        let vector: Vec<f32> = (0..64).map(|i| (i as f32 * 0.1).sin()).collect();
        let code = q.encode(&vector).expect("encode");
        let wire = TurboCodeWireV1::encode(&code, &q).expect("wire encode");

        assert!(TurboCodeWireV1::decode(&wire[..wire.len() - 1], &q).is_err());

        let mut padded = wire.clone();
        padded.push(0);
        assert!(TurboCodeWireV1::decode(&padded, &q).is_err());
    }

    #[test]
    fn single_wire_decode_rejects_seed_mismatch() {
        let q = make_quantizer(64, 1);
        let other = make_quantizer(64, 2);
        let vector: Vec<f32> = (0..64).map(|i| (i as f32 * 0.07).sin()).collect();
        let code = q.encode(&vector).expect("encode");
        let wire = TurboCodeWireV1::encode(&code, &q).expect("wire encode");
        assert!(TurboCodeWireV1::decode(&wire, &other).is_err());
    }

    #[test]
    fn batched_wire_roundtrip_matches_single() {
        let q = make_quantizer(64, 42);
        let vectors: Vec<Vec<f32>> = (0..16)
            .map(|i| (0..64).map(|j| ((i * 64 + j) as f32 * 0.013).sin()).collect())
            .collect();
        let codes: Vec<_> = vectors.iter().map(|v| q.encode(v).unwrap()).collect();
        let single_bytes: Vec<Vec<u8>> = codes
            .iter()
            .map(|c| TurboCodeWireV1::encode(c, &q).unwrap())
            .collect();
        let single_total: usize = single_bytes.iter().map(|b| b.len()).sum();
        let batched_bytes = TurboCodeWireV1::encode_batch(&codes, &q).unwrap();
        assert!(
            batched_bytes.len() < single_total,
            "batched {} >= single total {}",
            batched_bytes.len(),
            single_total
        );
        let decoded = TurboCodeWireV1::decode_batch(&batched_bytes, &q).unwrap();
        assert_eq!(decoded.len(), codes.len());
        for (i, (orig, back)) in codes.iter().zip(decoded.iter()).enumerate() {
            assert_eq!(
                orig.polar_code.radii, back.polar_code.radii,
                "radii mismatch at vec {i}"
            );
            assert_eq!(
                orig.polar_code.angle_indices, back.polar_code.angle_indices,
                "angles mismatch at vec {i}"
            );
            assert_eq!(
                orig.residual_sketch.signs, back.residual_sketch.signs,
                "signs mismatch at vec {i}"
            );
        }
    }

    #[test]
    fn batched_wire_rejects_wrong_magic() {
        let q = make_quantizer(64, 42);
        let mut bytes = vec![0u8; 64];
        bytes[0..4].copy_from_slice(b"XXXX");
        let r = TurboCodeWireV1::decode_batch(&bytes, &q);
        assert!(r.is_err());
    }

    #[test]
    fn batched_wire_rejects_buffer_too_short() {
        let q = make_quantizer(64, 42);
        let mut bytes = b"TQB1".to_vec();
        bytes.extend_from_slice(&1u16.to_le_bytes());
        bytes.extend_from_slice(&0u16.to_le_bytes());
        bytes.push(2);
        let r = TurboCodeWireV1::decode_batch(&bytes, &q);
        assert!(r.is_err());
    }

    #[test]
    fn batched_wire_lossy_roundtrip_is_within_tolerance() {
        let q = make_quantizer(64, 42);
        let vectors: Vec<Vec<f32>> = (0..16)
            .map(|i| (0..64).map(|j| ((i * 64 + j) as f32 * 0.013 + 0.1).sin().abs() + 0.1).collect())
            .collect();
        let codes: Vec<_> = vectors.iter().map(|v| q.encode(v).unwrap()).collect();

        let lossless_bytes = TurboCodeWireV1::encode_batch(&codes, &q).unwrap();
        let lossless_decoded = TurboCodeWireV1::decode_batch(&lossless_bytes, &q).unwrap();
        for (i, (orig, back)) in codes.iter().zip(lossless_decoded.iter()).enumerate() {
            assert_eq!(
                orig.polar_code.radii, back.polar_code.radii,
                "lossless radii mismatch at vec {i}"
            );
        }

        let lossy_bytes = TurboCodeWireV1::encode_batch_with_radii(
            &codes,
            &q,
            RadiusCodecProfileV1::BlockLogU8,
        )
        .unwrap();
        let ratio = lossless_bytes.len() as f64 / lossy_bytes.len() as f64;
        assert!(
            ratio > 1.5,
            "lossy should be at least 1.5x smaller, got {ratio:.2}x"
        );

        let lossy_decoded = TurboCodeWireV1::decode_batch(&lossy_bytes, &q).unwrap();
        for (i, (orig, back)) in codes.iter().zip(lossy_decoded.iter()).enumerate() {
            assert_eq!(orig.polar_code.radii.len(), back.polar_code.radii.len());
            for (j, (a, b)) in orig
                .polar_code
                .radii
                .iter()
                .zip(back.polar_code.radii.iter())
                .enumerate()
            {
                let rel = if *a > 0.0 { (a - b).abs() / a } else { 0.0 };
                assert!(
                    rel < 0.05,
                    "lossy radii rel error {rel:.4} at vec {i} radius {j}: orig={a} decoded={b}"
                );
            }
            assert_eq!(
                orig.polar_code.angle_indices, back.polar_code.angle_indices,
                "lossy angles mismatch at vec {i}"
            );
            assert_eq!(
                orig.residual_sketch.signs, back.residual_sketch.signs,
                "lossy signs mismatch at vec {i}"
            );
        }
    }

    #[test]
    fn batched_wire_lossy_magic_byte_is_one() {
        let q = make_quantizer(64, 42);
        let vectors: Vec<Vec<f32>> = (0..4)
            .map(|i| (0..64).map(|j| ((i * 64 + j) as f32 * 0.013).sin()).collect())
            .collect();
        let codes: Vec<_> = vectors.iter().map(|v| q.encode(v).unwrap()).collect();
        let lossy_bytes = TurboCodeWireV1::encode_batch_with_radii(
            &codes,
            &q,
            RadiusCodecProfileV1::BlockLogU8,
        )
        .unwrap();
        // Offset 15 is the radii codec byte in the batched header.
        assert_eq!(lossy_bytes[15], RADII_CODEC_BLOCK_LOG_U8);

        let lossless_bytes = TurboCodeWireV1::encode_batch(&codes, &q).unwrap();
        assert_eq!(lossless_bytes[15], RADII_CODEC_F32);
    }
}