1use serde::{Deserialize, Serialize};
2
3use crate::{
4 digest::json_digest,
5 directions::directions_for_method,
6 lloyd::{refine_codebook, LloydReportV1},
7 profile::{FibQuantProfileV1, RadiusMethod},
8 rotation::StoredRotation,
9 spherical_beta::{radius_quantile, radius_quantile_k2_closed_form},
10 FibQuantError, Result,
11};
12
/// Schema identifier for [`FibCodebookV1`]: written into `schema_version` by
/// `FibCodebookV1::build`, required by `FibCodebookV1::validate`, and passed
/// as the first argument to `json_digest` when the codebook digest is computed.
pub const CODEBOOK_SCHEMA: &str = "fib_codebook_v1";
14
/// Version 1 codebook artifact: a profile, the codewords built for it, and
/// the digests that tie the pieces together.
///
/// All fields are public and the type is deserializable, so a value obtained
/// from outside `FibCodebookV1::build` should be checked with
/// `FibCodebookV1::validate` before it is trusted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FibCodebookV1 {
    /// Schema identifier; `validate` requires it to equal `CODEBOOK_SCHEMA`.
    pub schema_version: String,
    /// Profile the codebook was built from; supplies `codebook_size`,
    /// `block_dim`, `ambient_dim` and `rotation_seed`.
    pub profile: FibQuantProfileV1,
    /// Value of `profile.digest()`; recomputed and compared by `validate`.
    pub profile_digest: String,
    /// Value of `compute_digest()`; recomputed and compared by `validate`.
    pub codebook_digest: String,
    /// Digest of `StoredRotation::new(profile.ambient_dim, profile.rotation_seed)`;
    /// recomputed and compared by `validate`.
    pub rotation_digest: String,
    /// Flat codeword storage: `codebook_size` codewords of `block_dim` values
    /// each, stored back to back, so codeword `i` occupies
    /// `i * block_dim..(i + 1) * block_dim`.
    pub codewords: Vec<f32>,
    /// `init_mse` as reported by `refine_codebook`.
    /// NOTE(review): presumably the error of the initial (unrefined) codebook —
    /// confirm against the `lloyd` module.
    pub init_mse: f64,
    /// `training_mse` as reported by `refine_codebook`.
    /// NOTE(review): presumably the error after refinement — confirm against
    /// the `lloyd` module.
    pub training_mse: f64,
    /// Report returned by `refine_codebook`; `validate` checks it against
    /// `profile` via `validate_against_profile`.
    pub refinement_report: LloydReportV1,
}
37
38impl FibCodebookV1 {
39 pub fn build(profile: FibQuantProfileV1) -> Result<Self> {
41 profile.validate()?;
42 let profile_digest = profile.digest()?;
43 let rotation_digest =
44 StoredRotation::new(profile.ambient_dim as usize, profile.rotation_seed)?.digest()?;
45 let initial = build_initial_codebook(&profile)?;
46 let refined = refine_codebook(&profile, &initial)?;
47 let mut codebook = Self {
48 schema_version: CODEBOOK_SCHEMA.into(),
49 profile,
50 profile_digest,
51 codebook_digest: String::new(),
52 rotation_digest,
53 codewords: refined.codewords,
54 init_mse: refined.init_mse,
55 training_mse: refined.training_mse,
56 refinement_report: refined.report,
57 };
58 codebook.codebook_digest = codebook.compute_digest()?;
59 Ok(codebook)
60 }
61
62 pub fn validate(&self) -> Result<()> {
64 if self.schema_version != CODEBOOK_SCHEMA {
65 return Err(FibQuantError::CorruptPayload(format!(
66 "codebook schema_version {}, expected {CODEBOOK_SCHEMA}",
67 self.schema_version
68 )));
69 }
70 self.profile.validate()?;
71 self.refinement_report
72 .validate_against_profile(&self.profile)?;
73 let expected_profile = self.profile.digest()?;
74 if self.profile_digest != expected_profile {
75 return Err(FibQuantError::ProfileDigestMismatch {
76 expected: expected_profile,
77 actual: self.profile_digest.clone(),
78 });
79 }
80 let expected_rotation = StoredRotation::new(
81 self.profile.ambient_dim as usize,
82 self.profile.rotation_seed,
83 )?
84 .digest()?;
85 if self.rotation_digest != expected_rotation {
86 return Err(FibQuantError::RotationDigestMismatch {
87 expected: expected_rotation,
88 actual: self.rotation_digest.clone(),
89 });
90 }
91 let expected_codebook = self.compute_digest()?;
92 if self.codebook_digest != expected_codebook {
93 return Err(FibQuantError::CodebookDigestMismatch {
94 expected: expected_codebook,
95 actual: self.codebook_digest.clone(),
96 });
97 }
98 let expected_len = (self.profile.codebook_size as usize)
99 .checked_mul(self.profile.block_dim as usize)
100 .ok_or_else(|| {
101 FibQuantError::ResourceLimitExceeded("codebook value count overflow".into())
102 })?;
103 if self.codewords.len() != expected_len {
104 return Err(FibQuantError::CorruptPayload(format!(
105 "codebook has {} values, expected {expected_len}",
106 self.codewords.len()
107 )));
108 }
109 if self.codewords.iter().any(|value| !value.is_finite()) {
110 return Err(FibQuantError::CorruptPayload(
111 "codebook contains non-finite value".into(),
112 ));
113 }
114 Ok(())
115 }
116
117 pub fn codeword(&self, index: usize) -> Result<Vec<f64>> {
119 let n = self.profile.codebook_size as usize;
120 let k = self.profile.block_dim as usize;
121 if index >= n {
122 return Err(FibQuantError::IndexOutOfRange {
123 index: index as u32,
124 codebook_size: n as u32,
125 });
126 }
127 Ok(self.codewords[index * k..(index + 1) * k]
128 .iter()
129 .map(|value| f64::from(*value))
130 .collect())
131 }
132
133 pub fn compute_digest(&self) -> Result<String> {
135 #[derive(Serialize)]
136 struct DigestView<'a> {
137 schema_version: &'a str,
138 profile_digest: &'a str,
139 rotation_digest: &'a str,
140 codewords: &'a [f32],
141 init_mse: f64,
142 training_mse: f64,
143 refinement_report: &'a LloydReportV1,
144 }
145 json_digest(
146 CODEBOOK_SCHEMA,
147 &DigestView {
148 schema_version: &self.schema_version,
149 profile_digest: &self.profile_digest,
150 rotation_digest: &self.rotation_digest,
151 codewords: &self.codewords,
152 init_mse: self.init_mse,
153 training_mse: self.training_mse,
154 refinement_report: &self.refinement_report,
155 },
156 )
157 }
158}
159
160pub fn build_initial_codebook(profile: &FibQuantProfileV1) -> Result<Vec<f64>> {
162 profile.validate()?;
163 let d = profile.ambient_dim as usize;
164 let k = profile.block_dim as usize;
165 let n = profile.codebook_size as usize;
166 let directions = directions_for_method(k, n, &profile.direction_method)?;
167 let value_count = n.checked_mul(k).ok_or_else(|| {
168 FibQuantError::ResourceLimitExceeded("codebook value count overflow".into())
169 })?;
170 let mut codewords = Vec::with_capacity(value_count);
171 for (idx, direction) in directions.iter().enumerate() {
172 let radius = radius_for_method(profile, d, k, idx + 1, n)?;
173 for value in direction {
174 let code = radius * value;
175 if !code.is_finite() {
176 return Err(FibQuantError::NumericalFailure(
177 "non-finite initialized codeword".into(),
178 ));
179 }
180 codewords.push(code);
181 }
182 }
183 Ok(codewords)
184}
185
186fn radius_for_method(
187 profile: &FibQuantProfileV1,
188 d: usize,
189 k: usize,
190 idx: usize,
191 n: usize,
192) -> Result<f64> {
193 match profile.radius_method {
194 RadiusMethod::K2ClosedForm if k == 2 => {
195 let q = (idx as f64 - 0.5) / n as f64;
196 radius_quantile_k2_closed_form(d, q)
197 }
198 RadiusMethod::BetaQuantile if k >= 3 => radius_quantile(d, k, idx, n),
199 _ => Err(FibQuantError::CorruptPayload(format!(
200 "radius method {:?} is not supported for k={k}",
201 profile.radius_method
202 ))),
203 }
204}