Skip to main content

kizzasi_model/
compression.rs

1//! Model Compression Utilities
2//!
3//! Provides techniques for reducing model size and computational cost while
4//! maintaining performance.
5//!
6//! # Techniques
7//!
8//! - **Pruning**: Remove less important weights
9//! - **Knowledge Distillation**: Transfer knowledge from large to small models
//! - **Weight Sharing**: Quantize a weight matrix to a small set of shared values (k-means clustering)
11//! - **Low-Rank Factorization**: Decompose weight matrices
12//!
13//! # Example
14//!
//! ```rust,ignore
//! use kizzasi_model::compression::{prune_magnitude, PruningConfig};
//!
//! let config = PruningConfig::magnitude_based(0.3); // Prune 30% of weights
//! let (pruned_weights, mask) = prune_magnitude(&weights, config.sparsity)?;
//! ```
21
22use crate::error::{ModelError, ModelResult};
23use scirs2_core::ndarray::{Array1, Array2};
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26
/// Pruning strategy
///
/// Selects how weights are chosen for removal. In this file,
/// `PruningConfig::magnitude_based` selects `Magnitude` and
/// `PruningConfig::structured` selects `Structured`.
///
/// NOTE(review): none of the visible functions dispatch on this enum, and
/// `Random` / `Movement` have no implementation here; confirm where the
/// strategy is actually consumed.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum PruningStrategy {
    /// Magnitude-based pruning (remove smallest weights)
    Magnitude,
    /// Random pruning
    Random,
    /// Structured pruning (entire neurons/channels)
    Structured,
    /// Movement pruning (based on weight updates)
    Movement,
}
39
40/// Pruning configuration
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct PruningConfig {
43    /// Pruning strategy
44    pub strategy: PruningStrategy,
45    /// Sparsity ratio (0.0 - 1.0)
46    pub sparsity: f32,
47    /// Whether to use global or layer-wise threshold
48    pub global_threshold: bool,
49    /// Minimum sparsity per layer
50    pub min_sparsity: f32,
51    /// Maximum sparsity per layer
52    pub max_sparsity: f32,
53}
54
55impl PruningConfig {
56    /// Create magnitude-based pruning configuration
57    pub fn magnitude_based(sparsity: f32) -> Self {
58        Self {
59            strategy: PruningStrategy::Magnitude,
60            sparsity,
61            global_threshold: true,
62            min_sparsity: 0.0,
63            max_sparsity: 0.95,
64        }
65    }
66
67    /// Create structured pruning configuration
68    pub fn structured(sparsity: f32) -> Self {
69        Self {
70            strategy: PruningStrategy::Structured,
71            sparsity,
72            global_threshold: false,
73            min_sparsity: 0.0,
74            max_sparsity: 0.9,
75        }
76    }
77
78    /// Set global threshold flag
79    pub fn global(mut self, global: bool) -> Self {
80        self.global_threshold = global;
81        self
82    }
83
84    /// Set sparsity bounds
85    pub fn bounds(mut self, min: f32, max: f32) -> Self {
86        self.min_sparsity = min;
87        self.max_sparsity = max;
88        self
89    }
90}
91
92/// Pruning statistics
93#[derive(Debug, Clone)]
94pub struct PruningStats {
95    /// Total number of parameters
96    pub total_params: usize,
97    /// Number of pruned parameters
98    pub pruned_params: usize,
99    /// Sparsity ratio achieved
100    pub sparsity: f32,
101    /// Compression ratio
102    pub compression_ratio: f32,
103    /// Per-layer statistics
104    pub layer_stats: HashMap<String, LayerPruningStats>,
105}
106
107/// Per-layer pruning statistics
108#[derive(Debug, Clone)]
109pub struct LayerPruningStats {
110    /// Total parameters in layer
111    pub total: usize,
112    /// Pruned parameters in layer
113    pub pruned: usize,
114    /// Layer sparsity
115    pub sparsity: f32,
116}
117
118impl PruningStats {
119    /// Create new pruning statistics
120    pub fn new() -> Self {
121        Self {
122            total_params: 0,
123            pruned_params: 0,
124            sparsity: 0.0,
125            compression_ratio: 1.0,
126            layer_stats: HashMap::new(),
127        }
128    }
129
130    /// Calculate final statistics
131    pub fn finalize(&mut self) {
132        if self.total_params > 0 {
133            self.sparsity = self.pruned_params as f32 / self.total_params as f32;
134            self.compression_ratio = 1.0 / (1.0 - self.sparsity);
135        }
136    }
137
138    /// Add layer statistics
139    pub fn add_layer(&mut self, name: String, total: usize, pruned: usize) {
140        self.total_params += total;
141        self.pruned_params += pruned;
142
143        let sparsity = if total > 0 {
144            pruned as f32 / total as f32
145        } else {
146            0.0
147        };
148
149        self.layer_stats.insert(
150            name,
151            LayerPruningStats {
152                total,
153                pruned,
154                sparsity,
155            },
156        );
157    }
158
159    /// Print summary
160    pub fn print_summary(&self) {
161        tracing::info!("=== Pruning Statistics ===");
162        tracing::info!("Total parameters: {}", self.total_params);
163        tracing::info!("Pruned parameters: {}", self.pruned_params);
164        tracing::info!("Sparsity: {:.2}%", self.sparsity * 100.0);
165        tracing::info!("Compression ratio: {:.2}x", self.compression_ratio);
166        tracing::info!("\nPer-layer statistics:");
167        for (name, stats) in &self.layer_stats {
168            tracing::info!(
169                "  {}: {}/{} ({:.2}%)",
170                name,
171                stats.pruned,
172                stats.total,
173                stats.sparsity * 100.0
174            );
175        }
176    }
177}
178
179impl Default for PruningStats {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185/// Prune a weight matrix using magnitude-based pruning
186pub fn prune_magnitude(
187    weights: &Array2<f32>,
188    sparsity: f32,
189) -> ModelResult<(Array2<f32>, Array2<bool>)> {
190    if !(0.0..=1.0).contains(&sparsity) {
191        return Err(ModelError::invalid_config(format!(
192            "Pruning: Sparsity must be between 0 and 1, got {}",
193            sparsity
194        )));
195    }
196
197    let total_elements = weights.len();
198    let num_to_prune = (total_elements as f32 * sparsity) as usize;
199
200    // Get absolute values and sort
201    let mut abs_weights: Vec<(f32, (usize, usize))> = weights
202        .indexed_iter()
203        .map(|(idx, &val)| (val.abs(), idx))
204        .collect();
205
206    abs_weights.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
207
208    // Create pruning mask
209    let mut mask = Array2::from_elem(weights.dim(), true);
210    for i in 0..num_to_prune {
211        if i < abs_weights.len() {
212            let (_, idx) = abs_weights[i];
213            mask[idx] = false;
214        }
215    }
216
217    // Apply mask
218    let pruned = weights * &mask.mapv(|x| if x { 1.0 } else { 0.0 });
219
220    Ok((pruned, mask))
221}
222
223/// Prune weights based on a global threshold
224pub fn prune_threshold(
225    weights: &Array2<f32>,
226    threshold: f32,
227) -> ModelResult<(Array2<f32>, Array2<bool>)> {
228    let mask = weights.mapv(|x| x.abs() >= threshold);
229    let pruned = weights * &mask.mapv(|x| if x { 1.0 } else { 0.0 });
230
231    Ok((pruned, mask))
232}
233
234/// Knowledge distillation configuration
235#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct DistillationConfig {
237    /// Temperature for softening probability distributions
238    pub temperature: f32,
239    /// Weight for distillation loss (0.0 - 1.0)
240    pub alpha: f32,
241    /// Weight for task loss (1-alpha typically)
242    pub task_weight: f32,
243}
244
245impl Default for DistillationConfig {
246    fn default() -> Self {
247        Self {
248            temperature: 3.0,
249            alpha: 0.7,
250            task_weight: 0.3,
251        }
252    }
253}
254
255impl DistillationConfig {
256    /// Create new distillation config
257    pub fn new(temperature: f32, alpha: f32) -> Self {
258        Self {
259            temperature,
260            alpha,
261            task_weight: 1.0 - alpha,
262        }
263    }
264
265    /// Set temperature
266    pub fn temperature(mut self, temp: f32) -> Self {
267        self.temperature = temp;
268        self
269    }
270
271    /// Set alpha (distillation weight)
272    pub fn alpha(mut self, alpha: f32) -> Self {
273        self.alpha = alpha;
274        self.task_weight = 1.0 - alpha;
275        self
276    }
277}
278
279/// Compute distillation loss between teacher and student outputs
280pub fn distillation_loss(
281    student_logits: &Array1<f32>,
282    teacher_logits: &Array1<f32>,
283    temperature: f32,
284) -> ModelResult<f32> {
285    if student_logits.len() != teacher_logits.len() {
286        return Err(ModelError::dimension_mismatch(
287            "distillation loss",
288            student_logits.len(),
289            teacher_logits.len(),
290        ));
291    }
292
293    // Apply temperature scaling
294    let student_scaled = student_logits.mapv(|x| x / temperature);
295    let teacher_scaled = teacher_logits.mapv(|x| x / temperature);
296
297    // Compute softmax
298    let student_max = student_scaled.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
299    let teacher_max = teacher_scaled.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
300
301    let student_exp = student_scaled.mapv(|x| (x - student_max).exp());
302    let teacher_exp = teacher_scaled.mapv(|x| (x - teacher_max).exp());
303
304    let student_sum = student_exp.sum();
305    let teacher_sum = teacher_exp.sum();
306
307    let student_probs = &student_exp / student_sum;
308    let teacher_probs = &teacher_exp / teacher_sum;
309
310    // KL divergence: sum(teacher * log(teacher / student))
311    let mut kl_div = 0.0;
312    for i in 0..student_probs.len() {
313        if teacher_probs[i] > 1e-10 && student_probs[i] > 1e-10 {
314            kl_div += teacher_probs[i] * (teacher_probs[i] / student_probs[i]).ln();
315        }
316    }
317
318    // Scale by temperature squared (as per Hinton et al.)
319    Ok(kl_div * temperature * temperature)
320}
321
322/// Low-rank factorization configuration
323#[derive(Debug, Clone, Serialize, Deserialize)]
324pub struct LowRankConfig {
325    /// Rank for factorization
326    pub rank: usize,
327    /// Whether to use SVD or other methods
328    pub use_svd: bool,
329}
330
331impl LowRankConfig {
332    /// Create new low-rank config
333    pub fn new(rank: usize) -> Self {
334        Self {
335            rank,
336            use_svd: true,
337        }
338    }
339
340    /// Set SVD flag
341    pub fn svd(mut self, use_svd: bool) -> Self {
342        self.use_svd = use_svd;
343        self
344    }
345}
346
/// Ratio `original_size / compressed_size`.
///
/// Returns `f32::INFINITY` whenever `compressed_size` is zero, regardless of
/// `original_size`.
pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f32 {
    match compressed_size {
        0 => f32::INFINITY,
        nonzero => original_size as f32 / nonzero as f32,
    }
}
354
355/// Weight sharing utilities
356pub mod weight_sharing {
357    use super::*;
358
359    /// K-means clustering for weight sharing
360    pub fn kmeans_cluster(weights: &Array2<f32>, num_clusters: usize) -> ModelResult<Array2<f32>> {
361        if num_clusters == 0 || num_clusters > weights.len() {
362            return Err(ModelError::invalid_config(format!(
363                "K-means clustering: Invalid number of clusters: {}",
364                num_clusters
365            )));
366        }
367
368        // Simple k-means implementation
369        // In production, use a proper clustering library
370        let flat_weights: Vec<f32> = weights.iter().copied().collect();
371
372        // Initialize centroids
373        let mut centroids = Vec::new();
374        let step = flat_weights.len() / num_clusters;
375        for i in 0..num_clusters {
376            if i * step < flat_weights.len() {
377                centroids.push(flat_weights[i * step]);
378            }
379        }
380
381        // Iterative refinement (simplified)
382        for _ in 0..10 {
383            let mut cluster_sums = vec![0.0; num_clusters];
384            let mut cluster_counts = vec![0usize; num_clusters];
385
386            for &weight in &flat_weights {
387                let mut min_dist = f32::INFINITY;
388                let mut cluster_id = 0;
389
390                for (i, &centroid) in centroids.iter().enumerate() {
391                    let dist = (weight - centroid).abs();
392                    if dist < min_dist {
393                        min_dist = dist;
394                        cluster_id = i;
395                    }
396                }
397
398                cluster_sums[cluster_id] += weight;
399                cluster_counts[cluster_id] += 1;
400            }
401
402            // Update centroids
403            for i in 0..num_clusters {
404                if cluster_counts[i] > 0 {
405                    centroids[i] = cluster_sums[i] / cluster_counts[i] as f32;
406                }
407            }
408        }
409
410        // Assign weights to nearest centroid
411        let mut quantized = Array2::zeros(weights.dim());
412        for (idx, &weight) in weights.indexed_iter() {
413            let mut min_dist = f32::INFINITY;
414            let mut best_centroid = centroids[0];
415
416            for &centroid in &centroids {
417                let dist = (weight - centroid).abs();
418                if dist < min_dist {
419                    min_dist = dist;
420                    best_centroid = centroid;
421                }
422            }
423
424            quantized[idx] = best_centroid;
425        }
426
427        Ok(quantized)
428    }
429}
430
431// ---------------------------------------------------------------------------
432// MagnitudePruner
433// ---------------------------------------------------------------------------
434
435/// Unstructured magnitude-based weight pruner.
436///
437/// Zeroes out all weight entries whose absolute value is strictly below
438/// `threshold`. Tracks cumulative pruning statistics across calls.
439#[derive(Debug, Clone)]
440pub struct MagnitudePruner {
441    /// Magnitude threshold: entries with |w| < threshold are zeroed
442    pub threshold: f32,
443    /// Total number of entries pruned so far
444    pub pruned_count: usize,
445    /// Total number of entries processed so far
446    pub total_count: usize,
447}
448
449impl MagnitudePruner {
450    /// Create a new `MagnitudePruner` with the given magnitude threshold.
451    pub fn new(threshold: f32) -> Self {
452        Self {
453            threshold,
454            pruned_count: 0,
455            total_count: 0,
456        }
457    }
458
459    /// Prune a 2D weight matrix in-place. Returns the sparsity fraction of
460    /// this call (not cumulative).
461    pub fn prune_matrix(&mut self, w: &mut Array2<f32>) -> f32 {
462        let total = w.len();
463        let mut pruned = 0usize;
464        for v in w.iter_mut() {
465            if v.abs() < self.threshold {
466                *v = 0.0;
467                pruned += 1;
468            }
469        }
470        self.total_count += total;
471        self.pruned_count += pruned;
472        if total == 0 {
473            0.0
474        } else {
475            pruned as f32 / total as f32
476        }
477    }
478
479    /// Prune a 1D weight vector in-place. Returns the sparsity fraction of
480    /// this call (not cumulative).
481    pub fn prune_vector(&mut self, v: &mut Array1<f32>) -> f32 {
482        let total = v.len();
483        let mut pruned = 0usize;
484        for x in v.iter_mut() {
485            if x.abs() < self.threshold {
486                *x = 0.0;
487                pruned += 1;
488            }
489        }
490        self.total_count += total;
491        self.pruned_count += pruned;
492        if total == 0 {
493            0.0
494        } else {
495            pruned as f32 / total as f32
496        }
497    }
498
499    /// Cumulative sparsity fraction across all processed entries.
500    pub fn sparsity(&self) -> f32 {
501        if self.total_count == 0 {
502            0.0
503        } else {
504            self.pruned_count as f32 / self.total_count as f32
505        }
506    }
507
508    /// Reset cumulative statistics (threshold is kept).
509    pub fn reset_stats(&mut self) {
510        self.pruned_count = 0;
511        self.total_count = 0;
512    }
513}
514
515// ---------------------------------------------------------------------------
516// StructuredPruner
517// ---------------------------------------------------------------------------
518
519/// Structured pruner: removes entire rows (neurons/channels) with the
520/// smallest L2 norms, keeping `keep_fraction` of rows.
521#[derive(Debug, Clone)]
522pub struct StructuredPruner {
523    /// Fraction of rows to retain (0.0 – 1.0)
524    pub keep_fraction: f32,
525}
526
527impl StructuredPruner {
528    /// Create a new `StructuredPruner` that keeps `keep_fraction` of rows.
529    pub fn new(keep_fraction: f32) -> Self {
530        Self { keep_fraction }
531    }
532
533    /// Compute a boolean mask over rows (true = keep).
534    ///
535    /// The top `ceil(keep_fraction * nrows)` rows by L2 norm are kept.
536    pub fn prune_rows(&self, w: &Array2<f32>) -> ModelResult<Vec<bool>> {
537        let nrows = w.nrows();
538        if nrows == 0 {
539            return Err(ModelError::invalid_config(
540                "StructuredPruner::prune_rows: empty matrix",
541            ));
542        }
543        let keep = ((self.keep_fraction * nrows as f32).ceil() as usize).min(nrows);
544
545        // Compute L2 norm per row
546        let mut row_norms: Vec<(usize, f32)> = (0..nrows)
547            .map(|i| {
548                let norm = w.row(i).iter().map(|&x| x * x).sum::<f32>().sqrt();
549                (i, norm)
550            })
551            .collect();
552
553        // Sort descending by norm
554        row_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
555
556        let mut mask = vec![false; nrows];
557        for (row_idx, _) in row_norms.iter().take(keep) {
558            mask[*row_idx] = true;
559        }
560        Ok(mask)
561    }
562
563    /// Return a new matrix with pruned rows removed.
564    pub fn compress_rows(&self, w: &Array2<f32>) -> ModelResult<Array2<f32>> {
565        let mask = self.prune_rows(w)?;
566        let kept_rows: Vec<usize> = mask
567            .iter()
568            .enumerate()
569            .filter_map(|(i, &keep)| if keep { Some(i) } else { None })
570            .collect();
571
572        if kept_rows.is_empty() {
573            return Err(ModelError::invalid_config(
574                "StructuredPruner::compress_rows: no rows kept",
575            ));
576        }
577
578        let ncols = w.ncols();
579        let mut out = Array2::<f32>::zeros((kept_rows.len(), ncols));
580        for (new_i, &old_i) in kept_rows.iter().enumerate() {
581            for j in 0..ncols {
582                out[(new_i, j)] = w[(old_i, j)];
583            }
584        }
585        Ok(out)
586    }
587}
588
589// ---------------------------------------------------------------------------
590// LowRankApprox
591// ---------------------------------------------------------------------------
592
593/// Low-rank approximation W ≈ U @ diag(S) @ V^T computed via pure-Rust
594/// power iteration (no LAPACK / C dependencies).
595#[derive(Debug, Clone)]
596pub struct LowRankApprox {
597    /// Target rank
598    pub rank: usize,
599    /// Left singular vectors — shape `(rows, rank)`
600    pub u: Array2<f32>,
601    /// Right singular vectors (transposed) — shape `(rank, cols)`
602    pub vt: Array2<f32>,
603    /// Singular values — shape `(rank,)`
604    pub singular_values: Array1<f32>,
605    /// Relative Frobenius reconstruction error `||W - approx||_F / ||W||_F`
606    pub reconstruction_error: f32,
607}
608
609impl LowRankApprox {
610    /// Compute a rank-`rank` approximation of `w` using power iteration.
611    ///
612    /// `num_iter` controls the number of power-iteration steps per component.
613    /// Higher values give more accurate singular vectors.
614    pub fn compute(w: &Array2<f32>, rank: usize, num_iter: usize) -> ModelResult<Self> {
615        let rows = w.nrows();
616        let cols = w.ncols();
617
618        if rank == 0 {
619            return Err(ModelError::invalid_config(
620                "LowRankApprox: rank must be > 0",
621            ));
622        }
623        let effective_rank = rank.min(rows.min(cols));
624
625        let mut u_cols: Vec<Array1<f32>> = Vec::with_capacity(effective_rank);
626        let mut vt_rows: Vec<Array1<f32>> = Vec::with_capacity(effective_rank);
627        let mut sigmas: Vec<f32> = Vec::with_capacity(effective_rank);
628
629        // Working copy for deflation
630        let mut residual = w.clone();
631
632        for k in 0..effective_rank {
633            // Initialise right singular vector (deterministic)
634            let mut v = Array1::<f32>::zeros(cols);
635            v[k % cols] = 1.0;
636
637            let iters = num_iter.max(1);
638            for _ in 0..iters {
639                // u = residual @ v  — shape (rows,)
640                let mut u_vec = Array1::<f32>::zeros(rows);
641                for i in 0..rows {
642                    u_vec[i] = (0..cols).map(|j| residual[(i, j)] * v[j]).sum();
643                }
644                // sigma = ||u||
645                let sigma = u_vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
646                if sigma < 1e-12 {
647                    break;
648                }
649                // u = u / sigma
650                let u_norm = u_vec.mapv(|x| x / sigma);
651
652                // v_new = residual^T @ u_norm  — shape (cols,)
653                let mut v_new = Array1::<f32>::zeros(cols);
654                for j in 0..cols {
655                    v_new[j] = (0..rows).map(|i| residual[(i, j)] * u_norm[i]).sum();
656                }
657                let v_norm_val = v_new.iter().map(|&x| x * x).sum::<f32>().sqrt();
658                if v_norm_val < 1e-12 {
659                    break;
660                }
661                v = v_new.mapv(|x| x / v_norm_val);
662            }
663
664            // Final computation of u and sigma
665            let mut u_vec = Array1::<f32>::zeros(rows);
666            for i in 0..rows {
667                u_vec[i] = (0..cols).map(|j| residual[(i, j)] * v[j]).sum();
668            }
669            let sigma = u_vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
670            if sigma < 1e-12 {
671                // No more signal — fill remaining components with zeros
672                u_cols.push(Array1::zeros(rows));
673                vt_rows.push(Array1::zeros(cols));
674                sigmas.push(0.0);
675            } else {
676                let u_final = u_vec.mapv(|x| x / sigma);
677
678                // Deflate
679                for i in 0..rows {
680                    for j in 0..cols {
681                        residual[(i, j)] -= sigma * u_final[i] * v[j];
682                    }
683                }
684
685                u_cols.push(u_final);
686                vt_rows.push(v);
687                sigmas.push(sigma);
688            }
689        }
690
691        // Assemble U (rows, rank) and Vt (rank, cols)
692        let mut u_mat = Array2::<f32>::zeros((rows, effective_rank));
693        let mut vt_mat = Array2::<f32>::zeros((effective_rank, cols));
694        for k in 0..effective_rank {
695            for i in 0..rows {
696                u_mat[(i, k)] = u_cols[k][i];
697            }
698            for j in 0..cols {
699                vt_mat[(k, j)] = vt_rows[k][j];
700            }
701        }
702        let singular_values = Array1::from_vec(sigmas);
703
704        // Reconstruction error
705        let w_frob: f32 = w.iter().map(|&x| x * x).sum::<f32>().sqrt();
706        let rec_error = if w_frob < 1e-12 {
707            0.0
708        } else {
709            // approx = U S Vt
710            let mut err_sq = 0.0_f32;
711            for i in 0..rows {
712                for j in 0..cols {
713                    let approx: f32 = (0..effective_rank)
714                        .map(|k| u_mat[(i, k)] * singular_values[k] * vt_mat[(k, j)])
715                        .sum();
716                    err_sq += (w[(i, j)] - approx).powi(2);
717                }
718            }
719            err_sq.sqrt() / w_frob
720        };
721
722        Ok(Self {
723            rank: effective_rank,
724            u: u_mat,
725            vt: vt_mat,
726            singular_values,
727            reconstruction_error: rec_error,
728        })
729    }
730
731    /// Reconstruct the full matrix: U @ diag(S) @ V^T.
732    pub fn reconstruct(&self) -> ModelResult<Array2<f32>> {
733        let rows = self.u.nrows();
734        let cols = self.vt.ncols();
735        let mut out = Array2::<f32>::zeros((rows, cols));
736        for i in 0..rows {
737            for j in 0..cols {
738                out[(i, j)] = (0..self.rank)
739                    .map(|k| self.u[(i, k)] * self.singular_values[k] * self.vt[(k, j)])
740                    .sum();
741            }
742        }
743        Ok(out)
744    }
745
746    /// Compression ratio: `(rows * cols) / (rows * rank + rank * cols)`.
747    pub fn compression_ratio(&self) -> f32 {
748        let rows = self.u.nrows();
749        let cols = self.vt.ncols();
750        let original = rows * cols;
751        let compressed = rows * self.rank + self.rank * cols;
752        if compressed == 0 {
753            return f32::INFINITY;
754        }
755        original as f32 / compressed as f32
756    }
757
758    /// Fast forward pass using factored form: `(U @ diag(S)) @ (V^T @ x)`.
759    ///
760    /// `x` must have length equal to the number of columns (original input dim).
761    pub fn forward(&self, x: &Array1<f32>) -> ModelResult<Array1<f32>> {
762        let cols = self.vt.ncols();
763        let rows = self.u.nrows();
764        if x.len() != cols {
765            return Err(ModelError::dimension_mismatch(
766                "LowRankApprox::forward",
767                cols,
768                x.len(),
769            ));
770        }
771        // intermediate = V^T @ x  — shape (rank,)
772        let mut intermediate = Array1::<f32>::zeros(self.rank);
773        for k in 0..self.rank {
774            intermediate[k] = (0..cols).map(|j| self.vt[(k, j)] * x[j]).sum();
775        }
776        // scale by singular values
777        for k in 0..self.rank {
778            intermediate[k] *= self.singular_values[k];
779        }
780        // output = U @ intermediate  — shape (rows,)
781        let mut out = Array1::<f32>::zeros(rows);
782        for i in 0..rows {
783            out[i] = (0..self.rank)
784                .map(|k| self.u[(i, k)] * intermediate[k])
785                .sum();
786        }
787        Ok(out)
788    }
789}
790
791// ---------------------------------------------------------------------------
792// CompressionReport
793// ---------------------------------------------------------------------------
794
795/// Summary report of compression applied to a set of weight matrices.
796#[derive(Debug, Clone)]
797pub struct CompressionReport {
798    /// Total number of parameters in original model
799    pub original_params: usize,
800    /// Total number of parameters after compression
801    pub compressed_params: usize,
802    /// Number of parameters set to zero (pruned)
803    pub pruned_params: usize,
804    /// `(layer_name, original_rank, compressed_rank)` per layer
805    pub rank_reductions: Vec<(String, usize, usize)>,
806    /// Overall compression ratio
807    pub overall_compression_ratio: f32,
808}
809
810impl CompressionReport {
811    /// Create a new, empty compression report.
812    pub fn new() -> Self {
813        Self {
814            original_params: 0,
815            compressed_params: 0,
816            pruned_params: 0,
817            rank_reductions: Vec::new(),
818            overall_compression_ratio: 1.0,
819        }
820    }
821
822    /// Register a layer's original and compressed weight matrices.
823    ///
824    /// Updates parameter counts and compression ratio automatically.
825    pub fn add_layer(&mut self, name: &str, original: &Array2<f32>, compressed: &Array2<f32>) {
826        let orig_params = original.nrows() * original.ncols();
827        let comp_params = compressed.nrows() * compressed.ncols();
828
829        // Count zero entries in original as "pruned"
830        let pruned = original.iter().filter(|&&x| x == 0.0).count();
831
832        self.original_params += orig_params;
833        self.compressed_params += comp_params;
834        self.pruned_params += pruned;
835
836        let orig_rank = original.nrows().min(original.ncols());
837        let comp_rank = compressed.nrows().min(compressed.ncols());
838        self.rank_reductions
839            .push((name.to_string(), orig_rank, comp_rank));
840
841        self.overall_compression_ratio = if self.compressed_params == 0 {
842            f32::INFINITY
843        } else {
844            self.original_params as f32 / self.compressed_params as f32
845        };
846    }
847
848    /// Generate a human-readable summary string.
849    pub fn summary(&self) -> String {
850        let mut lines = vec![
851            "=== Compression Report ===".to_string(),
852            format!("Original parameters : {}", self.original_params),
853            format!("Compressed parameters: {}", self.compressed_params),
854            format!("Pruned parameters   : {}", self.pruned_params),
855            format!(
856                "Overall compression ratio: {:.3}x",
857                self.overall_compression_ratio
858            ),
859            String::new(),
860            "Layer rank reductions:".to_string(),
861        ];
862        for (name, orig_rank, comp_rank) in &self.rank_reductions {
863            lines.push(format!("  {}: rank {} -> {}", name, orig_rank, comp_rank));
864        }
865        lines.join("\n")
866    }
867}
868
869impl Default for CompressionReport {
870    fn default() -> Self {
871        Self::new()
872    }
873}
874
875// ---------------------------------------------------------------------------
876// Tests
877// ---------------------------------------------------------------------------
878
#[cfg(test)]
mod tests {
    // Unit tests for pruning, distillation, weight sharing, structured pruning
    // and low-rank approximation utilities defined in the parent module.
    use super::*;

    #[test]
    fn test_prune_magnitude() {
        let weights = Array2::from_shape_vec(
            (3, 3),
            vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0],
        )
        .expect("Failed to create test array");

        let (pruned, mask) = prune_magnitude(&weights, 0.5).expect("Failed to prune");

        // Should prune ~50% of smallest magnitude weights
        let num_zeros = pruned.iter().filter(|&&x| x == 0.0).count();
        assert!(num_zeros >= 4);
        assert_eq!(pruned.dim(), weights.dim());
        assert_eq!(mask.dim(), weights.dim());

        // Magnitude pruning must remove the smallest-|w| entries: every pruned
        // original weight is no larger in magnitude than any surviving one.
        // Squares are compared (instead of `abs`) so the check is independent
        // of the concrete float type.
        let pruned_sq: Vec<_> = weights
            .iter()
            .zip(pruned.iter())
            .filter(|&(_, &p)| p == 0.0)
            .map(|(&w, _)| w * w)
            .collect();
        let kept_sq: Vec<_> = weights
            .iter()
            .zip(pruned.iter())
            .filter(|&(_, &p)| p != 0.0)
            .map(|(&w, _)| w * w)
            .collect();
        for p in &pruned_sq {
            for k in &kept_sq {
                assert!(p <= k, "a smaller-magnitude weight survived pruning");
            }
        }
    }

    #[test]
    fn test_prune_threshold() {
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, 0.1, 2.0])
            .expect("Failed to create test array");

        let (pruned, mask) = prune_threshold(&weights, 0.6).expect("Failed to prune");

        assert_eq!(pruned[[0, 0]], 1.0);
        assert_eq!(pruned[[0, 1]], 0.0); // 0.5 < 0.6
        assert_eq!(pruned[[1, 0]], 0.0); // 0.1 < 0.6
        assert_eq!(pruned[[1, 1]], 2.0);

        assert!(mask[[0, 0]]);
        assert!(!mask[[0, 1]]);
        assert!(!mask[[1, 0]]);
        assert!(mask[[1, 1]]);
    }

    #[test]
    fn test_distillation_loss() {
        let student = Array1::from_vec(vec![2.0, 1.0, 0.1]);
        let teacher = Array1::from_vec(vec![2.5, 1.5, 0.5]);

        let loss = distillation_loss(&student, &teacher, 3.0).expect("Failed to compute loss");

        assert!(loss >= 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_pruning_config() {
        let config = PruningConfig::magnitude_based(0.3)
            .global(false)
            .bounds(0.1, 0.8);

        assert_eq!(config.strategy, PruningStrategy::Magnitude);
        assert_eq!(config.sparsity, 0.3);
        assert!(!config.global_threshold);
        assert_eq!(config.min_sparsity, 0.1);
        assert_eq!(config.max_sparsity, 0.8);
    }

    #[test]
    fn test_distillation_config() {
        let config = DistillationConfig::new(5.0, 0.8);

        assert_eq!(config.temperature, 5.0);
        assert_eq!(config.alpha, 0.8);
        assert!((config.task_weight - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = compression_ratio(1000, 250);
        assert_eq!(ratio, 4.0);

        let ratio = compression_ratio(1000, 1000);
        assert_eq!(ratio, 1.0);
    }

    #[test]
    fn test_pruning_stats() {
        let mut stats = PruningStats::new();
        stats.add_layer("layer1".to_string(), 1000, 300);
        stats.add_layer("layer2".to_string(), 2000, 800);
        stats.finalize();

        assert_eq!(stats.total_params, 3000);
        assert_eq!(stats.pruned_params, 1100);
        assert!((stats.sparsity - 0.366667).abs() < 1e-5);
        assert!(stats.compression_ratio > 1.0);
    }

    #[test]
    fn test_kmeans_weight_sharing() {
        let weights = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
            .expect("Failed to create test array");

        let quantized = weight_sharing::kmeans_cluster(&weights, 2).expect("Failed to cluster");

        assert_eq!(quantized.dim(), weights.dim());

        // Should have only 2 unique values
        let unique_vals: std::collections::HashSet<_> =
            quantized.iter().map(|&x| (x * 1000.0) as i32).collect();
        assert!(unique_vals.len() <= 2);
    }

    // -----------------------------------------------------------------------
    // MagnitudePruner tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_magnitude_pruner_basic() {
        let mut pruner = MagnitudePruner::new(0.5);
        // Values: 0.1, 0.2, 0.3, 0.4 are below threshold; 0.6, 0.7, 0.8, 0.9 are above
        let mut w =
            Array2::from_shape_vec((2, 4), vec![0.1_f32, 0.6, 0.2, 0.7, 0.3, 0.8, 0.4, 0.9])
                .expect("shape");
        let original = w.clone();

        let sparsity = pruner.prune_matrix(&mut w);
        // 4 out of 8 entries are zeroed
        assert!(sparsity > 0.0, "sparsity should be > 0");
        let zero_count = w.iter().filter(|&&x| x == 0.0).count();
        assert_eq!(zero_count, 4);
        assert!(pruner.pruned_count > 0);
        assert!(pruner.total_count > 0);

        // Exactly the entries with |w| < threshold must have been zeroed;
        // counting zeros alone would not catch the wrong entries being pruned.
        for (orig, now) in original.iter().zip(w.iter()) {
            assert_eq!(
                *now == 0.0,
                orig.abs() < 0.5,
                "entry {orig} was pruned incorrectly"
            );
        }
    }

    #[test]
    fn test_magnitude_pruner_zero_threshold() {
        let mut pruner = MagnitudePruner::new(0.0);
        let mut w = Array2::from_shape_vec((2, 2), vec![0.5_f32, 1.0, -0.3, 2.0]).expect("shape");

        let sparsity = pruner.prune_matrix(&mut w);
        // threshold = 0 means |w| < 0, which is never true → nothing pruned
        assert_eq!(sparsity, 0.0, "zero threshold should prune nothing");
        assert_eq!(pruner.pruned_count, 0);
    }

    // -----------------------------------------------------------------------
    // StructuredPruner tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_structured_pruner_row_mask_count() {
        let w = Array2::from_shape_fn((10, 4), |(i, j)| (i * 4 + j) as f32);
        let pruner = StructuredPruner::new(0.6);
        let mask = pruner.prune_rows(&w).expect("prune_rows failed");

        let keep_count = mask.iter().filter(|&&k| k).count();
        // ceil(0.6 * 10) = 6
        assert_eq!(keep_count, 6, "expected 6 kept rows, got {keep_count}");
        assert_eq!(mask.len(), 10);
    }

    #[test]
    fn test_structured_pruner_compress_reduces_rows() {
        let w = Array2::from_shape_fn((8, 3), |(i, j)| (i + j) as f32);
        let pruner = StructuredPruner::new(0.5);
        let compressed = pruner.compress_rows(&w).expect("compress_rows failed");

        assert!(
            compressed.nrows() < w.nrows(),
            "compressed rows {} should be < original {}",
            compressed.nrows(),
            w.nrows()
        );
        assert_eq!(compressed.ncols(), w.ncols());
    }

    // -----------------------------------------------------------------------
    // LowRankApprox tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_low_rank_approx_shapes() {
        let w = Array2::from_shape_fn((8, 6), |(i, j)| (i * j) as f32 * 0.1);
        let lra = LowRankApprox::compute(&w, 3, 50).expect("compute failed");

        assert_eq!(lra.u.nrows(), 8);
        assert_eq!(lra.u.ncols(), 3);
        assert_eq!(lra.vt.nrows(), 3);
        assert_eq!(lra.vt.ncols(), 6);
        assert_eq!(lra.singular_values.len(), 3);
    }

    #[test]
    fn test_low_rank_approx_reconstruction_error() {
        // Identity matrix: rank-4 approx should reconstruct perfectly
        let mut data = vec![0.0_f32; 16];
        for i in 0..4 {
            data[i * 4 + i] = 1.0;
        }
        let w = Array2::from_shape_vec((4, 4), data).expect("shape");

        let lra = LowRankApprox::compute(&w, 4, 100).expect("compute failed");
        assert!(
            lra.reconstruction_error < 0.01,
            "reconstruction_error {} should be < 0.01",
            lra.reconstruction_error
        );
    }

    #[test]
    fn test_low_rank_approx_compression_ratio() {
        // 10x10 at rank 2: the dense matrix stores 100 values, while the
        // factors store roughly 10*2 + 2*10 = 40 (plus the singular values),
        // so the ratio should be well above 1.
        let w = Array2::from_shape_fn((10, 10), |(i, j)| (i as f32).sin() + (j as f32).cos());
        let lra = LowRankApprox::compute(&w, 2, 20).expect("compute failed");

        assert!(
            lra.compression_ratio() > 1.0,
            "compression_ratio {} should be > 1.0",
            lra.compression_ratio()
        );
    }

    #[test]
    fn test_low_rank_forward_shape() {
        // w: 8x6, rank=3 → forward(x: 6) → output shape (8,)
        let w = Array2::from_shape_fn((8, 6), |(i, j)| ((i + j) as f32) * 0.1);
        let lra = LowRankApprox::compute(&w, 3, 30).expect("compute failed");

        let x = Array1::from_vec(vec![1.0_f32; 6]);
        let out = lra.forward(&x).expect("forward failed");
        assert_eq!(out.len(), 8, "expected output len 8, got {}", out.len());
    }

    #[test]
    fn test_distillation_loss_same_logits() {
        let logits = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
        let loss = distillation_loss(&logits, &logits, 1.0).expect("distillation_loss failed");
        // KL(p || p) = 0. Bound both sides: a one-sided `loss < 1e-5` would
        // also accept an arbitrarily negative (i.e. broken) loss.
        assert!(
            loss.abs() < 1e-5,
            "same logits should give loss ≈ 0, got {loss}"
        );
    }
}