1use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 bitpack,
8 error::{Result, TurboQuantError},
9 polar::PolarCode,
10 qjl::QjlSketch,
11 radius::{CompressedRadiiV1, RadiusCodecProfileV1},
12 turbo::TurboCode,
13};
14
/// Compact, serializable form of a [`PolarCode`].
///
/// Radii are stored compressed as [`CompressedRadiiV1`] and angle indices are
/// bit-packed with `bitpack::pack_indices`. Use [`PackedPolarCode::unpack`] to
/// recover the [`PolarCode`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct PackedPolarCode {
    /// Dimension copied from `PolarCode::dim`; unpacking requires it to be
    /// positive and even.
    pub dim: usize,
    /// Bit width handed to `bitpack::pack_indices` / `bitpack::unpack_indices`
    /// for the angle indices (copied from `PolarCode::bits`).
    pub bits: u8,
    /// Compressed radii; unpacking expects `dim / 2` of them after decompression.
    pub radii: CompressedRadiiV1,
    /// Angle indices as produced by `bitpack::pack_indices`; `dim / 2` entries
    /// are read back when unpacking.
    pub packed_angle_indices: Vec<u8>,
}
22
/// Compact, serializable form of a [`QjlSketch`] whose sign vector is
/// bit-packed with `bitpack::pack_signs`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct PackedQjlSketch {
    /// Dimension copied from `QjlSketch::dim`.
    pub dim: usize,
    /// Projection count copied from `QjlSketch::projections`; this many signs
    /// are read back by `bitpack::unpack_signs` when unpacking.
    pub projections: usize,
    /// Sign vector as produced by `bitpack::pack_signs`.
    pub packed_signs: Vec<u8>,
}
29
/// Compact, serializable form of a [`TurboCode`]: a packed polar code plus a
/// packed residual QJL sketch.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct PackedTurboCode {
    /// Packed form of `TurboCode::polar_code`.
    pub polar_code: PackedPolarCode,
    /// Packed form of `TurboCode::residual_sketch`.
    pub residual_sketch: PackedQjlSketch,
}
35
/// Byte-size comparison for one [`PackedTurboCode`], as produced by
/// `PackedTurboCode::stats`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct PackedCompressionStatsV1 {
    /// Size of the vector stored as raw fp32 (4 bytes per dimension).
    pub raw_fp32_bytes: usize,
    /// Size of the vector stored as fp16 (2 bytes per dimension).
    pub fp16_baseline_bytes: usize,
    /// Size of the unpacked (pre-bitpacking) code as computed by
    /// `PackedTurboCode::stats`; `usize::MAX` when the packed code fails to unpack.
    pub legacy_logical_bytes: usize,
    /// Packed payload size, i.e. `PackedTurboCode::encoded_bytes`.
    pub packed_sidecar_bytes: usize,
}
43
44impl PackedPolarCode {
45 pub fn from_polar(code: &PolarCode, radius_profile: RadiusCodecProfileV1) -> Result<Self> {
46 code.validate_for(code.dim, code.bits)?;
47 Ok(Self {
48 dim: code.dim,
49 bits: code.bits,
50 radii: CompressedRadiiV1::compress(&code.radii, radius_profile)?,
51 packed_angle_indices: bitpack::pack_indices(&code.angle_indices, code.bits)?,
52 })
53 }
54
55 pub fn unpack(&self) -> Result<PolarCode> {
56 let pairs = checked_pairs(self.dim)?;
57 let radii = self.radii.decompress()?;
58 if radii.len() != pairs {
59 return Err(TurboQuantError::MalformedCode {
60 reason: format!("packed polar has {} radii, expected {pairs}", radii.len()),
61 });
62 }
63 let angle_indices = bitpack::unpack_indices(&self.packed_angle_indices, pairs, self.bits)?;
64 let code = PolarCode {
65 dim: self.dim,
66 bits: self.bits,
67 radii,
68 angle_indices,
69 };
70 code.validate_for(self.dim, self.bits)?;
71 Ok(code)
72 }
73
74 pub fn encoded_bytes(&self) -> usize {
75 self.radii.encoded_bytes() + self.packed_angle_indices.len()
76 }
77}
78
79impl PackedQjlSketch {
80 pub fn from_qjl(sketch: &QjlSketch) -> Result<Self> {
81 sketch.validate_for(sketch.dim, sketch.projections)?;
82 Ok(Self {
83 dim: sketch.dim,
84 projections: sketch.projections,
85 packed_signs: bitpack::pack_signs(&sketch.signs)?,
86 })
87 }
88
89 pub fn unpack(&self) -> Result<QjlSketch> {
90 let signs = bitpack::unpack_signs(&self.packed_signs, self.projections)?;
91 let sketch = QjlSketch {
92 dim: self.dim,
93 projections: self.projections,
94 signs,
95 };
96 sketch.validate_for(self.dim, self.projections)?;
97 Ok(sketch)
98 }
99
100 pub fn encoded_bytes(&self) -> usize {
101 self.packed_signs.len()
102 }
103}
104
105impl PackedTurboCode {
106 pub fn from_turbo(code: &TurboCode, radius_profile: RadiusCodecProfileV1) -> Result<Self> {
107 Ok(Self {
108 polar_code: PackedPolarCode::from_polar(&code.polar_code, radius_profile)?,
109 residual_sketch: PackedQjlSketch::from_qjl(&code.residual_sketch)?,
110 })
111 }
112
113 pub fn unpack(&self) -> Result<TurboCode> {
114 Ok(TurboCode {
115 polar_code: self.polar_code.unpack()?,
116 residual_sketch: self.residual_sketch.unpack()?,
117 })
118 }
119
120 pub fn encoded_bytes(&self) -> usize {
121 self.polar_code.encoded_bytes() + self.residual_sketch.encoded_bytes()
122 }
123
124 pub fn stats(&self) -> PackedCompressionStatsV1 {
125 let dim = self.polar_code.dim;
126 let packed_sidecar_bytes = self.encoded_bytes();
127 PackedCompressionStatsV1 {
128 raw_fp32_bytes: dim * 4,
129 fp16_baseline_bytes: dim * 2,
130 legacy_logical_bytes: self
131 .unpack()
132 .map(|code| {
133 code.polar_code.radii.len() * 4
134 + code.polar_code.angle_indices.len() * 2
135 + code.residual_sketch.signs.len()
136 })
137 .unwrap_or(usize::MAX),
138 packed_sidecar_bytes,
139 }
140 }
141}
142
143fn checked_pairs(dim: usize) -> Result<usize> {
144 if dim == 0 || dim % 2 != 0 {
145 return Err(TurboQuantError::MalformedCode {
146 reason: format!("packed polar dimension must be positive and even, got {dim}"),
147 });
148 }
149 Ok(dim / 2)
150}