Skip to main content

fib_quant/
lloyd.rs

1use rand::SeedableRng;
2use rand_chacha::ChaCha8Rng;
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    profile::{EmptyCellPolicy, FibQuantProfileV1, SourceMode},
7    rotation::StoredRotation,
8    spherical_beta::{sample_reference_projection, sample_spherical_beta},
9    FibQuantError, Result,
10};
11
/// Schema marker stored in every Lloyd report's `schema_version` field.
pub const LLOYD_REPORT_SCHEMA: &str = "lloyd_report_v1";

/// Offset applied in each direction along the unit split direction when a
/// donor centroid is split to re-seed an empty cell.
const DONOR_SPLIT_EPSILON: f64 = 0.000_001;
14
/// Deterministic empty-cell repair event.
///
/// One event is recorded each time a codeword that received no training
/// samples is re-seeded by splitting a donor cell's centroid.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LloydRepairEventV1 {
    /// Restart index (0-based) where the repair occurred.
    pub restart: u32,
    /// Lloyd iteration index (0-based) within that restart.
    pub iteration: u32,
    /// Index of the empty cell (codeword with no assigned samples) that was repaired.
    pub empty_cell: u32,
    /// Donor cell whose centroid was split: a cell with more than one sample and
    /// the largest accumulated distortion at the time of the repair.
    pub donor_cell: u32,
    /// Number of samples assigned to the donor before splitting.
    pub donor_count_before: u32,
    /// Sum of squared distances from the donor's samples to its centroid as it
    /// was before this iteration's mean update.
    pub donor_distortion: f64,
    /// Euclidean norm of the vector from the donor's updated centroid to its
    /// farthest sample; that vector, normalized, is the split direction.
    pub residual_norm: f64,
    /// Offset applied along the split direction (subtracted for the donor,
    /// added for the empty cell).
    pub split_epsilon: f64,
}
35
/// Lloyd-Max refinement report.
///
/// Serialized alongside a refined codebook so the refinement run can be
/// audited and checked against the profile that produced it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LloydReportV1 {
    /// Stable schema marker (`LLOYD_REPORT_SCHEMA`).
    pub schema_version: String,
    /// Number of restarts run: the profile's `lloyd_restarts`, raised to at
    /// least 1 by `refine_codebook`.
    pub restarts: u32,
    /// Number of Lloyd iterations run per restart.
    pub iterations: u32,
    /// Number of training samples used; `refine_codebook` draws
    /// `max(training_samples, codebook_size)` from the profile.
    pub training_samples: u32,
    /// MSE of the initial (unrotated, unrefined) codebook on the training set.
    pub init_mse: f64,
    /// MSE of the returned codebook. `refine_codebook` only replaces the
    /// initial codebook on a strict improvement, so this never exceeds `init_mse`.
    pub best_mse: f64,
    /// Restart whose codebook was selected. It stays 0 when no restart strictly
    /// improved on `init_mse`; in that case the initial codebook itself is returned.
    pub best_restart: u32,
    /// Total number of empty cells repaired across all restarts; must equal
    /// `repair_events.len()`.
    pub empty_cells_repaired: u32,
    /// Detailed deterministic empty-cell repair events, in the order they occurred.
    pub repair_events: Vec<LloydRepairEventV1>,
    /// Deterministic seed used for refinement (the profile's `codebook_seed`).
    pub seed: u64,
}
60
61impl LloydReportV1 {
62    pub(crate) fn validate_against_profile(&self, profile: &FibQuantProfileV1) -> Result<()> {
63        if self.schema_version != LLOYD_REPORT_SCHEMA {
64            return Err(FibQuantError::CorruptPayload(format!(
65                "lloyd report schema_version {}, expected {LLOYD_REPORT_SCHEMA}",
66                self.schema_version
67            )));
68        }
69        if self.restarts != profile.lloyd_restarts
70            || self.iterations != profile.lloyd_iterations
71            || self.training_samples != profile.training_samples
72            || self.seed != profile.codebook_seed
73        {
74            return Err(FibQuantError::CorruptPayload(
75                "lloyd report does not match profile settings".into(),
76            ));
77        }
78        if !self.init_mse.is_finite() || !self.best_mse.is_finite() {
79            return Err(FibQuantError::CorruptPayload(
80                "lloyd report contains non-finite mse".into(),
81            ));
82        }
83        if self.empty_cells_repaired as usize != self.repair_events.len() {
84            return Err(FibQuantError::CorruptPayload(
85                "lloyd repair event count mismatch".into(),
86            ));
87        }
88        Ok(())
89    }
90}
91
92pub(crate) struct RefinedCodebook {
93    pub codewords: Vec<f32>,
94    pub init_mse: f64,
95    pub training_mse: f64,
96    pub report: LloydReportV1,
97}
98
99struct RepairRecorder<'a> {
100    events: &'a mut Vec<LloydRepairEventV1>,
101    restart: u32,
102    iteration: u32,
103}
104
105/// Run deterministic multi-restart Lloyd-Max refinement.
106pub(crate) fn refine_codebook(
107    profile: &FibQuantProfileV1,
108    initial: &[f64],
109) -> Result<RefinedCodebook> {
110    profile.validate()?;
111    let k = profile.block_dim as usize;
112    let n = profile.codebook_size as usize;
113    if initial.len() != n * k {
114        return Err(FibQuantError::CorruptPayload(format!(
115            "initial codebook has {}, expected {}",
116            initial.len(),
117            n * k
118        )));
119    }
120    let samples = training_samples(profile)?;
121    let init_mse = mse_for_codebook(initial, k, &samples)?;
122    let restarts = profile.lloyd_restarts.max(1);
123    let iterations = profile.lloyd_iterations;
124    let mut best = initial.to_vec();
125    let mut best_mse = init_mse;
126    let mut best_restart = 0;
127    let mut total_repairs = 0u32;
128    let mut all_repair_events = Vec::new();
129
130    for restart in 0..restarts {
131        let mut codebook = rotated_initial(profile, initial, restart)?;
132        let mut restart_repairs = 0u32;
133        for iteration in 0..iterations {
134            let assignments = assign_samples(&codebook, k, &samples);
135            update_centroids(
136                &mut codebook,
137                k,
138                &samples,
139                &assignments,
140                profile.empty_cell_policy.clone(),
141                &mut restart_repairs,
142                RepairRecorder {
143                    events: &mut all_repair_events,
144                    restart,
145                    iteration,
146                },
147            )?;
148        }
149        let mse = mse_for_codebook(&codebook, k, &samples)?;
150        total_repairs = total_repairs.saturating_add(restart_repairs);
151        if mse < best_mse || restart == 0 && init_mse.is_infinite() {
152            best_mse = mse;
153            best = codebook;
154            best_restart = restart;
155        }
156    }
157
158    if best_mse > init_mse {
159        best = initial.to_vec();
160        best_mse = init_mse;
161        best_restart = u32::MAX;
162    }
163
164    let report = LloydReportV1 {
165        schema_version: LLOYD_REPORT_SCHEMA.into(),
166        restarts,
167        iterations,
168        training_samples: samples.len() as u32,
169        init_mse,
170        best_mse,
171        best_restart,
172        empty_cells_repaired: total_repairs,
173        repair_events: all_repair_events,
174        seed: profile.codebook_seed,
175    };
176    Ok(RefinedCodebook {
177        codewords: best.into_iter().map(|value| value as f32).collect(),
178        init_mse,
179        training_mse: best_mse,
180        report,
181    })
182}
183
184fn training_samples(profile: &FibQuantProfileV1) -> Result<Vec<Vec<f64>>> {
185    let d = profile.ambient_dim as usize;
186    let k = profile.block_dim as usize;
187    let count = profile.training_samples.max(profile.codebook_size) as usize;
188    let mut rng = ChaCha8Rng::seed_from_u64(profile.codebook_seed ^ 0x4651_5541_4e54);
189    (0..count)
190        .map(|_| match profile.source_mode {
191            SourceMode::CanonicalSphericalBeta => sample_spherical_beta(d, k, &mut rng),
192            SourceMode::ReferenceGaussianProjection => sample_reference_projection(d, k, &mut rng),
193        })
194        .collect()
195}
196
197fn rotated_initial(profile: &FibQuantProfileV1, initial: &[f64], restart: u32) -> Result<Vec<f64>> {
198    let k = profile.block_dim as usize;
199    if restart == 0 {
200        return Ok(initial.to_vec());
201    }
202    let rotation = StoredRotation::new(
203        k,
204        profile
205            .codebook_seed
206            .wrapping_add(u64::from(restart) * 0x9e37_79b9),
207    )?;
208    let mut out = Vec::with_capacity(initial.len());
209    for codeword in initial.chunks_exact(k) {
210        out.extend(rotation.apply(codeword)?);
211    }
212    Ok(out)
213}
214
215fn assign_samples(codebook: &[f64], k: usize, samples: &[Vec<f64>]) -> Vec<usize> {
216    samples
217        .iter()
218        .map(|sample| nearest_index(sample, codebook, k).0)
219        .collect()
220}
221
/// Run one Lloyd update step in place and repair empty cells according to `policy`.
///
/// `codebook` holds `codebook.len() / k` flattened codewords of `k` values each.
/// `assignments[i]` is the cell index of `samples[i]`.
///
/// Per-cell statistics (sample count, summed squared distortion, farthest sample)
/// are measured against the centroids as they were *before* this call. Every
/// non-empty cell's centroid is then replaced by the mean of its samples.
///
/// Empty cells:
/// - `EmptyCellPolicy::FailClosed`: return an error; no repair is attempted.
/// - Any other policy: empty cells are repaired in ascending index order.
///   1. Each one takes a donor cell that has more than one sample and the
///      largest distortion (ties go to the highest index, per `Iterator::max_by`).
///   2. The donor's updated centroid `c` is split along the unit direction `u`
///      toward the donor's farthest sample. The donor moves to
///      `c - DONOR_SPLIT_EPSILON * u` and the empty cell to
///      `c + DONOR_SPLIT_EPSILON * u`.
///   3. The donor's count is decremented and its distortion zeroed, so later
///      empty cells favour other donors.
///
/// Every repair increments `*repairs` and appends a `LloydRepairEventV1` to
/// `recorder.events`.
///
/// # Errors
///
/// Returns `FibQuantError::EmptyCellRepairFailed` when:
/// - a cell is empty under `FailClosed`;
/// - no cell has more than one sample to donate;
/// - the donor's farthest sample lies (numerically) on its updated centroid.
///
/// Centroid updates already written to `codebook`, and any repairs completed
/// before the failure, are not rolled back.
///
/// # Panics
///
/// Panics if any of these hold:
/// - `k == 0`;
/// - an assignment is not a valid cell index;
/// - a sample has fewer than `k` coordinates.
fn update_centroids(
    codebook: &mut [f64],
    k: usize,
    samples: &[Vec<f64>],
    assignments: &[usize],
    policy: EmptyCellPolicy,
    repairs: &mut u32,
    recorder: RepairRecorder<'_>,
) -> Result<()> {
    let n = codebook.len() / k;
    let mut sums = vec![0.0; codebook.len()];
    let mut counts = vec![0usize; n];
    let mut distortion = vec![0.0; n];
    let mut farthest_samples = vec![vec![0.0; k]; n];
    // Negative start value so the first sample seen in each cell is recorded as its farthest.
    let mut farthest_distances = vec![-1.0; n];
    // Accumulate per-cell statistics against the pre-update centroids.
    for (sample, &assignment) in samples.iter().zip(assignments) {
        counts[assignment] += 1;
        let mut sample_dist = 0.0;
        for dim in 0..k {
            sums[assignment * k + dim] += sample[dim];
            let delta = sample[dim] - codebook[assignment * k + dim];
            sample_dist += delta * delta;
        }
        distortion[assignment] += sample_dist;
        // Strict `>` keeps the earliest sample among equally distant ones.
        if sample_dist > farthest_distances[assignment] {
            farthest_distances[assignment] = sample_dist;
            farthest_samples[assignment].clone_from(sample);
        }
    }
    // Move each populated cell to the mean of its samples; empty cells keep their
    // old codeword until they are repaired below.
    for idx in 0..n {
        if counts[idx] > 0 {
            for dim in 0..k {
                codebook[idx * k + dim] = sums[idx * k + dim] / counts[idx] as f64;
            }
        }
    }
    let empty: Vec<_> = counts
        .iter()
        .enumerate()
        .filter_map(|(idx, count)| (*count == 0).then_some(idx))
        .collect();
    if empty.is_empty() {
        return Ok(());
    }
    if policy == EmptyCellPolicy::FailClosed {
        return Err(FibQuantError::EmptyCellRepairFailed(format!(
            "{} empty cells",
            empty.len()
        )));
    }
    for empty_idx in empty {
        // Donor: the highest-distortion cell that can give up a sample and stay non-empty.
        let donor = counts
            .iter()
            .enumerate()
            .filter(|(_, count)| **count > 1)
            .max_by(|(left, _), (right, _)| distortion[*left].total_cmp(&distortion[*right]))
            .map(|(idx, _)| idx)
            .ok_or_else(|| FibQuantError::EmptyCellRepairFailed("no donor cell".into()))?;
        let donor_count_before = counts[donor];
        let donor_distortion = distortion[donor];
        // Split direction: from the donor's current (updated) centroid toward the
        // sample that was farthest from its pre-update centroid.
        let mut residual = vec![0.0; k];
        let mut residual_norm_sq = 0.0;
        for dim in 0..k {
            residual[dim] = farthest_samples[donor][dim] - codebook[donor * k + dim];
            residual_norm_sq += residual[dim] * residual[dim];
        }
        let residual_norm = residual_norm_sq.sqrt();
        if !residual_norm.is_finite() || residual_norm <= f64::EPSILON {
            return Err(FibQuantError::EmptyCellRepairFailed(
                "donor residual has zero direction".into(),
            ));
        }
        // Push the donor and the empty cell symmetrically apart around the donor's centroid.
        for dim in 0..k {
            let direction = residual[dim] / residual_norm;
            let centroid = codebook[donor * k + dim];
            codebook[donor * k + dim] = centroid - DONOR_SPLIT_EPSILON * direction;
            codebook[empty_idx * k + dim] = centroid + DONOR_SPLIT_EPSILON * direction;
        }
        recorder.events.push(LloydRepairEventV1 {
            restart: recorder.restart,
            iteration: recorder.iteration,
            empty_cell: empty_idx as u32,
            donor_cell: donor as u32,
            donor_count_before: donor_count_before as u32,
            donor_distortion,
            residual_norm,
            split_epsilon: DONOR_SPLIT_EPSILON,
        });
        // Bookkeeping for the next empty cell: the donor has nominally handed one
        // sample to the repaired cell, and its distortion is reset so other cells
        // are preferred as the next donor.
        counts[donor] -= 1;
        counts[empty_idx] = 1;
        distortion[donor] = 0.0;
        distortion[empty_idx] = 0.0;
        *repairs = repairs.saturating_add(1);
    }
    Ok(())
}
318
/// Find the codeword closest to `sample` under squared Euclidean distance.
///
/// Returns `(index, squared_distance)`.
/// - Ties resolve to the lowest index, and NaN distances never win.
/// - An empty codebook, or one where every distance is NaN, yields
///   `(0, f64::INFINITY)`.
/// - `codebook` is read in chunks of `k`, and any trailing partial chunk is ignored.
/// - Only the first `min(sample.len(), k)` coordinates contribute to the distance.
pub(crate) fn nearest_index(sample: &[f64], codebook: &[f64], k: usize) -> (usize, f64) {
    codebook
        .chunks_exact(k)
        .enumerate()
        .fold((0usize, f64::INFINITY), |best, (idx, codeword)| {
            let dist = sample
                .iter()
                .zip(codeword)
                .map(|(s, c)| (s - c) * (s - c))
                .sum::<f64>();
            // Strict `<` keeps the earliest index on ties and never selects NaN.
            if dist < best.1 {
                (idx, dist)
            } else {
                best
            }
        })
}
338
339fn mse_for_codebook(codebook: &[f64], k: usize, samples: &[Vec<f64>]) -> Result<f64> {
340    if samples.is_empty() {
341        return Err(FibQuantError::NumericalFailure(
342            "empty Lloyd training set".into(),
343        ));
344    }
345    let sum: f64 = samples
346        .iter()
347        .map(|sample| nearest_index(sample, codebook, k).1)
348        .sum();
349    let mse = sum / samples.len() as f64;
350    if mse.is_finite() {
351        Ok(mse)
352    } else {
353        Err(FibQuantError::NumericalFailure(
354            "non-finite Lloyd MSE".into(),
355        ))
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Run one `update_centroids` step on 2-D data, returning the result,
    /// the repair counter, and the recorded events.
    fn run_update(
        codebook: &mut [f64],
        samples: &[Vec<f64>],
        assignments: &[usize],
        policy: EmptyCellPolicy,
    ) -> (Result<()>, u32, Vec<LloydRepairEventV1>) {
        let mut repairs = 0;
        let mut events = Vec::new();
        let result = update_centroids(
            codebook,
            2,
            samples,
            assignments,
            policy,
            &mut repairs,
            RepairRecorder {
                events: &mut events,
                restart: 0,
                iteration: 0,
            },
        );
        (result, repairs, events)
    }

    #[test]
    fn empty_cell_repair_splits_highest_distortion_donor() {
        let mut codebook = vec![0.0, 0.0, 10.0, 0.0, 20.0, 0.0];
        let samples = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![2.0, 0.0],
            vec![10.0, 0.0],
            vec![10.1, 0.0],
        ];
        let assignments = vec![0, 0, 0, 1, 1];
        let mut repairs = 0;
        let mut events = Vec::new();

        update_centroids(
            &mut codebook,
            2,
            &samples,
            &assignments,
            EmptyCellPolicy::SplitHighestDistortion,
            &mut repairs,
            RepairRecorder {
                events: &mut events,
                restart: 0,
                iteration: 0,
            },
        )
        .unwrap();

        assert_eq!(repairs, 1);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].empty_cell, 2);
        assert_eq!(events[0].donor_cell, 0);
        assert_eq!(events[0].donor_count_before, 3);
        assert!(events[0].donor_distortion > 1.0);
        assert!(events[0].residual_norm > 0.0);
        assert!(codebook[0] < 1.0);
        assert!(codebook[4] > 1.0);
    }

    #[test]
    fn populated_cells_move_to_sample_means() {
        let mut codebook = vec![0.0, 0.0, 10.0, 0.0];
        let samples = vec![
            vec![1.0, 0.0],
            vec![3.0, 0.0],
            vec![9.0, 1.0],
            vec![11.0, -1.0],
        ];
        let (result, repairs, events) = run_update(
            &mut codebook,
            &samples,
            &[0, 0, 1, 1],
            EmptyCellPolicy::SplitHighestDistortion,
        );
        assert!(result.is_ok());
        assert_eq!(repairs, 0);
        assert!(events.is_empty());
        assert_eq!(codebook, vec![2.0, 0.0, 10.0, 0.0]);
    }

    #[test]
    fn fail_closed_rejects_empty_cells() {
        let mut codebook = vec![0.0, 0.0, 10.0, 0.0, 20.0, 0.0];
        let samples = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![10.0, 0.0]];
        let (result, repairs, events) =
            run_update(&mut codebook, &samples, &[0, 0, 1], EmptyCellPolicy::FailClosed);
        assert!(matches!(result, Err(FibQuantError::EmptyCellRepairFailed(_))));
        assert_eq!(repairs, 0);
        assert!(events.is_empty());
    }

    #[test]
    fn split_without_donor_fails() {
        // The only populated cell has a single sample, so it cannot donate.
        let mut codebook = vec![0.0, 0.0, 10.0, 0.0];
        let samples = vec![vec![1.0, 0.0]];
        let (result, repairs, events) = run_update(
            &mut codebook,
            &samples,
            &[0],
            EmptyCellPolicy::SplitHighestDistortion,
        );
        assert!(matches!(result, Err(FibQuantError::EmptyCellRepairFailed(_))));
        assert_eq!(repairs, 0);
        assert!(events.is_empty());
    }

    #[test]
    fn split_with_degenerate_donor_fails() {
        // Identical samples put the farthest sample exactly on the updated
        // centroid, leaving no split direction.
        let mut codebook = vec![0.0, 0.0, 10.0, 0.0];
        let samples = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
        let (result, repairs, events) = run_update(
            &mut codebook,
            &samples,
            &[0, 0],
            EmptyCellPolicy::SplitHighestDistortion,
        );
        assert!(matches!(result, Err(FibQuantError::EmptyCellRepairFailed(_))));
        assert_eq!(repairs, 0);
        assert!(events.is_empty());
    }

    #[test]
    fn nearest_index_prefers_first_codeword_on_ties() {
        let codebook = [0.0, 0.0, 10.0, 0.0];
        assert_eq!(nearest_index(&[5.0, 0.0], &codebook, 2), (0, 25.0));
        assert_eq!(nearest_index(&[9.0, 0.0], &codebook, 2), (1, 1.0));
    }
}