Skip to main content

turbo_quant/
packed.rs

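//! Explicit packed sidecar payloads: bit-packed, serializable counterparts of `PolarCode`,
//! `QjlSketch`, and `TurboCode`, plus byte-size statistics for the packed form.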
2
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7    bitpack,
8    error::{Result, TurboQuantError},
9    polar::PolarCode,
10    qjl::QjlSketch,
11    radius::{CompressedRadiiV1, RadiusCodecProfileV1},
12    turbo::TurboCode,
13};
14
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
16pub struct PackedPolarCode {
17    pub dim: usize,
18    pub bits: u8,
19    pub radii: CompressedRadiiV1,
20    pub packed_angle_indices: Vec<u8>,
21}
22
23#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
24pub struct PackedQjlSketch {
25    pub dim: usize,
26    pub projections: usize,
27    pub packed_signs: Vec<u8>,
28}
29
30#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
31pub struct PackedTurboCode {
32    pub polar_code: PackedPolarCode,
33    pub residual_sketch: PackedQjlSketch,
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
37pub struct PackedCompressionStatsV1 {
38    pub raw_fp32_bytes: usize,
39    pub fp16_baseline_bytes: usize,
40    pub legacy_logical_bytes: usize,
41    pub packed_sidecar_bytes: usize,
42}
43
44impl PackedPolarCode {
45    pub fn from_polar(code: &PolarCode, radius_profile: RadiusCodecProfileV1) -> Result<Self> {
46        code.validate_for(code.dim, code.bits)?;
47        Ok(Self {
48            dim: code.dim,
49            bits: code.bits,
50            radii: CompressedRadiiV1::compress(&code.radii, radius_profile)?,
51            packed_angle_indices: bitpack::pack_indices(&code.angle_indices, code.bits)?,
52        })
53    }
54
55    pub fn unpack(&self) -> Result<PolarCode> {
56        let pairs = checked_pairs(self.dim)?;
57        let radii = self.radii.decompress()?;
58        if radii.len() != pairs {
59            return Err(TurboQuantError::MalformedCode {
60                reason: format!("packed polar has {} radii, expected {pairs}", radii.len()),
61            });
62        }
63        let angle_indices = bitpack::unpack_indices(&self.packed_angle_indices, pairs, self.bits)?;
64        let code = PolarCode {
65            dim: self.dim,
66            bits: self.bits,
67            radii,
68            angle_indices,
69        };
70        code.validate_for(self.dim, self.bits)?;
71        Ok(code)
72    }
73
74    pub fn encoded_bytes(&self) -> usize {
75        self.radii.encoded_bytes() + self.packed_angle_indices.len()
76    }
77}
78
79impl PackedQjlSketch {
80    pub fn from_qjl(sketch: &QjlSketch) -> Result<Self> {
81        sketch.validate_for(sketch.dim, sketch.projections)?;
82        Ok(Self {
83            dim: sketch.dim,
84            projections: sketch.projections,
85            packed_signs: bitpack::pack_signs(&sketch.signs)?,
86        })
87    }
88
89    pub fn unpack(&self) -> Result<QjlSketch> {
90        let signs = bitpack::unpack_signs(&self.packed_signs, self.projections)?;
91        let sketch = QjlSketch {
92            dim: self.dim,
93            projections: self.projections,
94            signs,
95        };
96        sketch.validate_for(self.dim, self.projections)?;
97        Ok(sketch)
98    }
99
100    pub fn encoded_bytes(&self) -> usize {
101        self.packed_signs.len()
102    }
103}
104
105impl PackedTurboCode {
106    pub fn from_turbo(code: &TurboCode, radius_profile: RadiusCodecProfileV1) -> Result<Self> {
107        Ok(Self {
108            polar_code: PackedPolarCode::from_polar(&code.polar_code, radius_profile)?,
109            residual_sketch: PackedQjlSketch::from_qjl(&code.residual_sketch)?,
110        })
111    }
112
113    pub fn unpack(&self) -> Result<TurboCode> {
114        Ok(TurboCode {
115            polar_code: self.polar_code.unpack()?,
116            residual_sketch: self.residual_sketch.unpack()?,
117        })
118    }
119
120    pub fn encoded_bytes(&self) -> usize {
121        self.polar_code.encoded_bytes() + self.residual_sketch.encoded_bytes()
122    }
123
124    pub fn stats(&self) -> PackedCompressionStatsV1 {
125        let dim = self.polar_code.dim;
126        let packed_sidecar_bytes = self.encoded_bytes();
127        PackedCompressionStatsV1 {
128            raw_fp32_bytes: dim * 4,
129            fp16_baseline_bytes: dim * 2,
130            legacy_logical_bytes: self
131                .unpack()
132                .map(|code| {
133                    code.polar_code.radii.len() * 4
134                        + code.polar_code.angle_indices.len() * 2
135                        + code.residual_sketch.signs.len()
136                })
137                .unwrap_or(usize::MAX),
138            packed_sidecar_bytes,
139        }
140    }
141}
142
143fn checked_pairs(dim: usize) -> Result<usize> {
144    if dim == 0 || dim % 2 != 0 {
145        return Err(TurboQuantError::MalformedCode {
146            reason: format!("packed polar dimension must be positive and even, got {dim}"),
147        });
148    }
149    Ok(dim / 2)
150}