1use rand::SeedableRng;
2use rand_chacha::ChaCha8Rng;
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 profile::{EmptyCellPolicy, FibQuantProfileV1, SourceMode},
7 rotation::StoredRotation,
8 spherical_beta::{sample_reference_projection, sample_spherical_beta},
9 FibQuantError, Result,
10};
11
/// Schema tag written into every [`LloydReportV1`]; `validate_against_profile` rejects
/// reports carrying any other value.
pub const LLOYD_REPORT_SCHEMA: &str = "lloyd_report_v1";
/// Distance each of the two codewords is moved along the unit split direction when an empty
/// cell is re-seeded from a donor cell; copied into every repair event's `split_epsilon`.
const DONOR_SPLIT_EPSILON: f64 = 1e-6;
14
/// One empty-cell repair performed during Lloyd refinement.
///
/// Emitted by `update_centroids` when a cell receives no samples: a donor cell is split and
/// the empty codeword is re-seeded next to the donor's centroid.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LloydRepairEventV1 {
    /// Restart index (0-based) during which the repair happened.
    pub restart: u32,
    /// Lloyd iteration index (0-based) within that restart.
    pub iteration: u32,
    /// Index of the codeword that had no assigned samples and was re-seeded.
    pub empty_cell: u32,
    /// Index of the cell that was split to re-seed `empty_cell`.
    pub donor_cell: u32,
    /// Donor's sample count at the moment of this split (earlier repairs in the same pass
    /// may already have decremented it).
    pub donor_count_before: u32,
    /// Donor's accumulated squared-error distortion when it was selected.
    pub donor_distortion: f64,
    /// Euclidean norm of the vector from the donor's updated centroid to its farthest
    /// assigned sample; that vector gives the split direction.
    pub residual_norm: f64,
    /// Offset applied along the unit split direction (the `DONOR_SPLIT_EPSILON` constant).
    pub split_epsilon: f64,
}
35
/// Summary of a Lloyd codebook refinement run, as produced by `refine_codebook`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LloydReportV1 {
    /// Schema tag; must equal [`LLOYD_REPORT_SCHEMA`].
    pub schema_version: String,
    /// Number of restarts executed (`refine_codebook` records `lloyd_restarts.max(1)`).
    pub restarts: u32,
    /// Lloyd iterations run per restart (`profile.lloyd_iterations`).
    pub iterations: u32,
    /// Number of training samples used (`max(training_samples, codebook_size)`).
    pub training_samples: u32,
    /// Training MSE of the initial codebook.
    pub init_mse: f64,
    /// Training MSE of the codebook that was kept; never above `init_mse` for reports
    /// written by `refine_codebook`.
    pub best_mse: f64,
    /// Index of the restart whose refined codebook was selected. It is also 0 when no
    /// restart strictly improved on the initial codebook, in which case the initial
    /// codebook itself was kept.
    pub best_restart: u32,
    /// Total empty-cell repairs across all restarts; must equal `repair_events.len()`.
    pub empty_cells_repaired: u32,
    /// Every repair, in the order it was performed.
    pub repair_events: Vec<LloydRepairEventV1>,
    /// Codebook seed taken from the profile (`profile.codebook_seed`).
    pub seed: u64,
}
60
61impl LloydReportV1 {
62 pub(crate) fn validate_against_profile(&self, profile: &FibQuantProfileV1) -> Result<()> {
63 if self.schema_version != LLOYD_REPORT_SCHEMA {
64 return Err(FibQuantError::CorruptPayload(format!(
65 "lloyd report schema_version {}, expected {LLOYD_REPORT_SCHEMA}",
66 self.schema_version
67 )));
68 }
69 if self.restarts != profile.lloyd_restarts
70 || self.iterations != profile.lloyd_iterations
71 || self.training_samples != profile.training_samples
72 || self.seed != profile.codebook_seed
73 {
74 return Err(FibQuantError::CorruptPayload(
75 "lloyd report does not match profile settings".into(),
76 ));
77 }
78 if !self.init_mse.is_finite() || !self.best_mse.is_finite() {
79 return Err(FibQuantError::CorruptPayload(
80 "lloyd report contains non-finite mse".into(),
81 ));
82 }
83 if self.empty_cells_repaired as usize != self.repair_events.len() {
84 return Err(FibQuantError::CorruptPayload(
85 "lloyd repair event count mismatch".into(),
86 ));
87 }
88 Ok(())
89 }
90}
91
/// Result of [`refine_codebook`]: the selected codebook plus its training statistics.
pub(crate) struct RefinedCodebook {
    /// Flat codewords (`codebook_size * block_dim` values), narrowed from `f64` to `f32`.
    pub codewords: Vec<f32>,
    /// Training MSE of the initial codebook.
    pub init_mse: f64,
    /// Training MSE of the returned codebook (the report's `best_mse`).
    pub training_mse: f64,
    /// Full refinement report.
    pub report: LloydReportV1,
}
98
/// Sink for empty-cell repair events emitted by `update_centroids`.
struct RepairRecorder<'a> {
    /// Event log that repairs are appended to.
    events: &'a mut Vec<LloydRepairEventV1>,
    /// Restart index stamped on every recorded event.
    restart: u32,
    /// Iteration index stamped on every recorded event.
    iteration: u32,
}
104
105pub(crate) fn refine_codebook(
107 profile: &FibQuantProfileV1,
108 initial: &[f64],
109) -> Result<RefinedCodebook> {
110 profile.validate()?;
111 let k = profile.block_dim as usize;
112 let n = profile.codebook_size as usize;
113 if initial.len() != n * k {
114 return Err(FibQuantError::CorruptPayload(format!(
115 "initial codebook has {}, expected {}",
116 initial.len(),
117 n * k
118 )));
119 }
120 let samples = training_samples(profile)?;
121 let init_mse = mse_for_codebook(initial, k, &samples)?;
122 let restarts = profile.lloyd_restarts.max(1);
123 let iterations = profile.lloyd_iterations;
124 let mut best = initial.to_vec();
125 let mut best_mse = init_mse;
126 let mut best_restart = 0;
127 let mut total_repairs = 0u32;
128 let mut all_repair_events = Vec::new();
129
130 for restart in 0..restarts {
131 let mut codebook = rotated_initial(profile, initial, restart)?;
132 let mut restart_repairs = 0u32;
133 for iteration in 0..iterations {
134 let assignments = assign_samples(&codebook, k, &samples);
135 update_centroids(
136 &mut codebook,
137 k,
138 &samples,
139 &assignments,
140 profile.empty_cell_policy.clone(),
141 &mut restart_repairs,
142 RepairRecorder {
143 events: &mut all_repair_events,
144 restart,
145 iteration,
146 },
147 )?;
148 }
149 let mse = mse_for_codebook(&codebook, k, &samples)?;
150 total_repairs = total_repairs.saturating_add(restart_repairs);
151 if mse < best_mse || restart == 0 && init_mse.is_infinite() {
152 best_mse = mse;
153 best = codebook;
154 best_restart = restart;
155 }
156 }
157
158 if best_mse > init_mse {
159 best = initial.to_vec();
160 best_mse = init_mse;
161 best_restart = u32::MAX;
162 }
163
164 let report = LloydReportV1 {
165 schema_version: LLOYD_REPORT_SCHEMA.into(),
166 restarts,
167 iterations,
168 training_samples: samples.len() as u32,
169 init_mse,
170 best_mse,
171 best_restart,
172 empty_cells_repaired: total_repairs,
173 repair_events: all_repair_events,
174 seed: profile.codebook_seed,
175 };
176 Ok(RefinedCodebook {
177 codewords: best.into_iter().map(|value| value as f32).collect(),
178 init_mse,
179 training_mse: best_mse,
180 report,
181 })
182}
183
184fn training_samples(profile: &FibQuantProfileV1) -> Result<Vec<Vec<f64>>> {
185 let d = profile.ambient_dim as usize;
186 let k = profile.block_dim as usize;
187 let count = profile.training_samples.max(profile.codebook_size) as usize;
188 let mut rng = ChaCha8Rng::seed_from_u64(profile.codebook_seed ^ 0x4651_5541_4e54);
189 (0..count)
190 .map(|_| match profile.source_mode {
191 SourceMode::CanonicalSphericalBeta => sample_spherical_beta(d, k, &mut rng),
192 SourceMode::ReferenceGaussianProjection => sample_reference_projection(d, k, &mut rng),
193 })
194 .collect()
195}
196
197fn rotated_initial(profile: &FibQuantProfileV1, initial: &[f64], restart: u32) -> Result<Vec<f64>> {
198 let k = profile.block_dim as usize;
199 if restart == 0 {
200 return Ok(initial.to_vec());
201 }
202 let rotation = StoredRotation::new(
203 k,
204 profile
205 .codebook_seed
206 .wrapping_add(u64::from(restart) * 0x9e37_79b9),
207 )?;
208 let mut out = Vec::with_capacity(initial.len());
209 for codeword in initial.chunks_exact(k) {
210 out.extend(rotation.apply(codeword)?);
211 }
212 Ok(out)
213}
214
215fn assign_samples(codebook: &[f64], k: usize, samples: &[Vec<f64>]) -> Vec<usize> {
216 samples
217 .iter()
218 .map(|sample| nearest_index(sample, codebook, k).0)
219 .collect()
220}
221
/// Performs one Lloyd centroid update in place, then repairs any cell left without samples.
///
/// `codebook` is a flat list of `codebook.len() / k` codewords and `assignments[i]` is the
/// cell of `samples[i]`. Every non-empty cell's codeword becomes the mean of its samples.
///
/// Empty cells are handled according to `policy`:
/// - `FailClosed`: returns an error.
/// - Otherwise: each empty cell is re-seeded by splitting a donor cell, and one
///   [`LloydRepairEventV1`] per repair is pushed to `recorder.events` with `*repairs`
///   incremented (saturating).
///
/// # Errors
/// `EmptyCellRepairFailed` in three cases:
/// - `policy` is `FailClosed` and a cell is empty;
/// - no cell has more than one sample to donate;
/// - the donor's split direction is zero or non-finite.
///
/// In each case the centroid update, and any repairs made before the failure, have already
/// been applied to `codebook`.
///
/// # Panics
/// Panics if:
/// - `k == 0` (integer division);
/// - an assignment is out of range;
/// - a sample is shorter than `k`.
fn update_centroids(
    codebook: &mut [f64],
    k: usize,
    samples: &[Vec<f64>],
    assignments: &[usize],
    policy: EmptyCellPolicy,
    repairs: &mut u32,
    recorder: RepairRecorder<'_>,
) -> Result<()> {
    let n = codebook.len() / k;
    let mut sums = vec![0.0; codebook.len()];
    let mut counts = vec![0usize; n];
    let mut distortion = vec![0.0; n];
    let mut farthest_samples = vec![vec![0.0; k]; n];
    // Starts below any real squared distance so each cell's first sample is recorded.
    let mut farthest_distances = vec![-1.0; n];
    // Single pass over the samples accumulates, per cell:
    // - coordinate sums and counts;
    // - squared-error distortion against the current (pre-update) codeword;
    // - the farthest assigned sample.
    for (sample, &assignment) in samples.iter().zip(assignments) {
        counts[assignment] += 1;
        let mut sample_dist = 0.0;
        for dim in 0..k {
            sums[assignment * k + dim] += sample[dim];
            let delta = sample[dim] - codebook[assignment * k + dim];
            sample_dist += delta * delta;
        }
        distortion[assignment] += sample_dist;
        if sample_dist > farthest_distances[assignment] {
            farthest_distances[assignment] = sample_dist;
            farthest_samples[assignment].clone_from(sample);
        }
    }
    // Move every non-empty cell's codeword to the mean of its samples. Empty cells keep
    // their old codeword until repaired below.
    for idx in 0..n {
        if counts[idx] > 0 {
            for dim in 0..k {
                codebook[idx * k + dim] = sums[idx * k + dim] / counts[idx] as f64;
            }
        }
    }
    let empty: Vec<_> = counts
        .iter()
        .enumerate()
        .filter_map(|(idx, count)| (*count == 0).then_some(idx))
        .collect();
    if empty.is_empty() {
        return Ok(());
    }
    if policy == EmptyCellPolicy::FailClosed {
        return Err(FibQuantError::EmptyCellRepairFailed(format!(
            "{} empty cells",
            empty.len()
        )));
    }
    // Any policy other than `FailClosed` repairs by splitting a donor.
    for empty_idx in empty {
        // Donor: a cell with more than one sample and the largest accumulated distortion.
        // `max_by` returns the last cell among equal maxima.
        let donor = counts
            .iter()
            .enumerate()
            .filter(|(_, count)| **count > 1)
            .max_by(|(left, _), (right, _)| distortion[*left].total_cmp(&distortion[*right]))
            .map(|(idx, _)| idx)
            .ok_or_else(|| FibQuantError::EmptyCellRepairFailed("no donor cell".into()))?;
        let donor_count_before = counts[donor];
        let donor_distortion = distortion[donor];
        // Split direction: from the donor's updated centroid toward its farthest sample.
        let mut residual = vec![0.0; k];
        let mut residual_norm_sq = 0.0;
        for dim in 0..k {
            residual[dim] = farthest_samples[donor][dim] - codebook[donor * k + dim];
            residual_norm_sq += residual[dim] * residual[dim];
        }
        let residual_norm = residual_norm_sq.sqrt();
        // A (near-)zero or non-finite residual gives no usable direction; fail rather than
        // produce two coincident codewords.
        if !residual_norm.is_finite() || residual_norm <= f64::EPSILON {
            return Err(FibQuantError::EmptyCellRepairFailed(
                "donor residual has zero direction".into(),
            ));
        }
        // Push the donor back and the re-seeded cell forward by DONOR_SPLIT_EPSILON along the
        // unit direction, so the two codewords straddle the donor's centroid.
        for dim in 0..k {
            let direction = residual[dim] / residual_norm;
            let centroid = codebook[donor * k + dim];
            codebook[donor * k + dim] = centroid - DONOR_SPLIT_EPSILON * direction;
            codebook[empty_idx * k + dim] = centroid + DONOR_SPLIT_EPSILON * direction;
        }
        recorder.events.push(LloydRepairEventV1 {
            restart: recorder.restart,
            iteration: recorder.iteration,
            empty_cell: empty_idx as u32,
            donor_cell: donor as u32,
            donor_count_before: donor_count_before as u32,
            donor_distortion,
            residual_norm,
            split_epsilon: DONOR_SPLIT_EPSILON,
        });
        // Bookkeeping so later repairs in this pass see the split:
        // - the donor loses one count;
        // - the re-seeded cell counts as occupied (and is never a donor, which needs > 1);
        // - both distortions are zeroed, which deprioritizes this donor (it can still win a
        //   tie at zero).
        counts[donor] -= 1;
        counts[empty_idx] = 1;
        distortion[donor] = 0.0;
        distortion[empty_idx] = 0.0;
        *repairs = repairs.saturating_add(1);
    }
    Ok(())
}
318
/// Finds the codeword closest to `sample` by squared Euclidean distance.
///
/// `codebook` is scanned in chunks of `k` values. Returns `(index, squared_distance)`, and
/// the first codeword wins on ties. The distance sums over the overlapping prefix when
/// `sample` and a codeword differ in length.
///
/// When the codebook is empty, or every distance is NaN, the result is `(0, f64::INFINITY)`.
///
/// # Panics
/// Panics if `k == 0`.
pub(crate) fn nearest_index(sample: &[f64], codebook: &[f64], k: usize) -> (usize, f64) {
    let squared_distance = |codeword: &[f64]| -> f64 {
        sample
            .iter()
            .zip(codeword)
            .map(|(a, b)| {
                let diff = a - b;
                diff * diff
            })
            .sum()
    };
    codebook
        .chunks_exact(k)
        .map(squared_distance)
        .enumerate()
        .fold((0usize, f64::INFINITY), |(best_idx, best_dist), (idx, dist)| {
            // Strict `<`: keeps the earlier codeword on ties and never selects NaN.
            if dist < best_dist {
                (idx, dist)
            } else {
                (best_idx, best_dist)
            }
        })
}
338
339fn mse_for_codebook(codebook: &[f64], k: usize, samples: &[Vec<f64>]) -> Result<f64> {
340 if samples.is_empty() {
341 return Err(FibQuantError::NumericalFailure(
342 "empty Lloyd training set".into(),
343 ));
344 }
345 let sum: f64 = samples
346 .iter()
347 .map(|sample| nearest_index(sample, codebook, k).1)
348 .sum();
349 let mse = sum / samples.len() as f64;
350 if mse.is_finite() {
351 Ok(mse)
352 } else {
353 Err(FibQuantError::NumericalFailure(
354 "non-finite Lloyd MSE".into(),
355 ))
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `update_centroids` with a fresh repair counter and event log (restart 0,
    /// iteration 0). Returns the result together with the recorded count and events.
    fn run_update(
        codebook: &mut [f64],
        k: usize,
        samples: &[Vec<f64>],
        assignments: &[usize],
        policy: EmptyCellPolicy,
    ) -> (Result<()>, u32, Vec<LloydRepairEventV1>) {
        let mut repairs = 0;
        let mut events = Vec::new();
        let result = update_centroids(
            codebook,
            k,
            samples,
            assignments,
            policy,
            &mut repairs,
            RepairRecorder {
                events: &mut events,
                restart: 0,
                iteration: 0,
            },
        );
        (result, repairs, events)
    }

    #[test]
    fn empty_cell_repair_splits_highest_distortion_donor() {
        let mut codebook = vec![0.0, 0.0, 10.0, 0.0, 20.0, 0.0];
        let samples = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![2.0, 0.0],
            vec![10.0, 0.0],
            vec![10.1, 0.0],
        ];
        let assignments = vec![0, 0, 0, 1, 1];

        let (result, repairs, events) = run_update(
            &mut codebook,
            2,
            &samples,
            &assignments,
            EmptyCellPolicy::SplitHighestDistortion,
        );
        result.unwrap();

        assert_eq!(repairs, 1);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].empty_cell, 2);
        assert_eq!(events[0].donor_cell, 0);
        assert_eq!(events[0].donor_count_before, 3);
        assert!(events[0].donor_distortion > 1.0);
        assert!(events[0].residual_norm > 0.0);
        assert!(codebook[0] < 1.0);
        assert!(codebook[4] > 1.0);
    }

    #[test]
    fn full_cells_move_to_sample_means_without_repair() {
        let mut codebook = vec![0.0, 10.0];
        let samples = vec![vec![1.0], vec![3.0], vec![9.0], vec![11.0]];
        let assignments = vec![0, 0, 1, 1];

        let (result, repairs, events) = run_update(
            &mut codebook,
            1,
            &samples,
            &assignments,
            EmptyCellPolicy::SplitHighestDistortion,
        );
        result.unwrap();

        assert_eq!(codebook, vec![2.0, 10.0]);
        assert_eq!(repairs, 0);
        assert!(events.is_empty());
    }

    #[test]
    fn fail_closed_policy_rejects_empty_cells() {
        let mut codebook = vec![0.0, 0.0, 10.0, 0.0];
        let samples = vec![vec![0.0, 0.0], vec![1.0, 0.0]];
        let assignments = vec![0, 0];

        let (result, repairs, events) = run_update(
            &mut codebook,
            2,
            &samples,
            &assignments,
            EmptyCellPolicy::FailClosed,
        );

        assert!(matches!(
            result,
            Err(FibQuantError::EmptyCellRepairFailed(_))
        ));
        assert_eq!(repairs, 0);
        assert!(events.is_empty());
    }

    #[test]
    fn repair_without_donor_fails() {
        // Every occupied cell holds a single sample, so none can donate.
        let mut codebook = vec![0.0, 5.0, 10.0];
        let samples = vec![vec![0.0], vec![5.0]];
        let assignments = vec![0, 1];

        let (result, repairs, events) = run_update(
            &mut codebook,
            1,
            &samples,
            &assignments,
            EmptyCellPolicy::SplitHighestDistortion,
        );

        assert!(matches!(
            result,
            Err(FibQuantError::EmptyCellRepairFailed(_))
        ));
        assert_eq!(repairs, 0);
        assert!(events.is_empty());
    }

    #[test]
    fn nearest_index_prefers_first_codeword_on_ties() {
        assert_eq!(nearest_index(&[5.0], &[0.0, 10.0], 1), (0, 25.0));
        assert_eq!(nearest_index(&[1.0], &[], 1), (0, f64::INFINITY));
    }
}