//! Model compression utilities: pruning, knowledge distillation,
//! low-rank approximation, and weight sharing.

use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Strategy used to decide which weights to prune.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum PruningStrategy {
    /// Remove the weights with the smallest absolute values.
    Magnitude,
    /// Remove weights uniformly at random.
    Random,
    /// Remove whole structures (e.g. rows or neurons) rather than
    /// individual weights.
    Structured,
    /// Remove weights according to how they change during fine-tuning
    /// (movement pruning).
    Movement,
}

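/// Configuration describing how a model should be pruned.
///
/// # Example
///
/// A minimal builder-style sketch (illustrative values; marked `ignore`
/// since the crate path of this module is not shown here):
///
/// ```ignore
/// let config = PruningConfig::magnitude_based(0.5)
///     .global(false)
///     .bounds(0.1, 0.9);
/// assert_eq!(config.strategy, PruningStrategy::Magnitude);
/// ```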
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Strategy used to select weights for removal.
    pub strategy: PruningStrategy,
    /// Target fraction of weights to remove (0.0..=1.0).
    pub sparsity: f32,
    /// Use a single global threshold instead of per-layer thresholds.
    pub global_threshold: bool,
    /// Lower bound on the allowed sparsity.
    pub min_sparsity: f32,
    /// Upper bound on the allowed sparsity.
    pub max_sparsity: f32,
}

impl PruningConfig {
    /// Create a magnitude-based configuration with a global threshold.
    pub fn magnitude_based(sparsity: f32) -> Self {
        Self {
            strategy: PruningStrategy::Magnitude,
            sparsity,
            global_threshold: true,
            min_sparsity: 0.0,
            max_sparsity: 0.95,
        }
    }

    /// Create a structured-pruning configuration with per-layer thresholds.
    pub fn structured(sparsity: f32) -> Self {
        Self {
            strategy: PruningStrategy::Structured,
            sparsity,
            global_threshold: false,
            min_sparsity: 0.0,
            max_sparsity: 0.9,
        }
    }

    /// Toggle global vs. per-layer thresholding (builder style).
    pub fn global(mut self, global: bool) -> Self {
        self.global_threshold = global;
        self
    }

    /// Set the allowed sparsity bounds (builder style).
    pub fn bounds(mut self, min: f32, max: f32) -> Self {
        self.min_sparsity = min;
        self.max_sparsity = max;
        self
    }
}

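/// Aggregate statistics collected while pruning a model.
///
/// # Example
///
/// A minimal accounting sketch (illustrative layer name and counts; marked
/// `ignore` since crate paths are omitted):
///
/// ```ignore
/// let mut stats = PruningStats::new();
/// stats.add_layer("encoder.w1".to_string(), 1_000, 400);
/// stats.finalize();
/// assert!((stats.sparsity - 0.4).abs() < 1e-6);
/// ```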
#[derive(Debug, Clone)]
pub struct PruningStats {
    /// Total number of parameters considered.
    pub total_params: usize,
    /// Number of parameters that were pruned.
    pub pruned_params: usize,
    /// Overall fraction of pruned parameters.
    pub sparsity: f32,
    /// Effective compression ratio, `1 / (1 - sparsity)`.
    pub compression_ratio: f32,
    /// Per-layer statistics, keyed by layer name.
    pub layer_stats: HashMap<String, LayerPruningStats>,
}

/// Pruning statistics for a single layer.
#[derive(Debug, Clone)]
pub struct LayerPruningStats {
    /// Total parameters in the layer.
    pub total: usize,
    /// Pruned parameters in the layer.
    pub pruned: usize,
    /// Fraction of pruned parameters in the layer.
    pub sparsity: f32,
}

impl PruningStats {
    /// Create an empty statistics accumulator.
    pub fn new() -> Self {
        Self {
            total_params: 0,
            pruned_params: 0,
            sparsity: 0.0,
            compression_ratio: 1.0,
            layer_stats: HashMap::new(),
        }
    }

    /// Recompute the overall sparsity and compression ratio from the
    /// accumulated counts (the ratio is infinite if everything is pruned).
    pub fn finalize(&mut self) {
        if self.total_params > 0 {
            self.sparsity = self.pruned_params as f32 / self.total_params as f32;
            self.compression_ratio = 1.0 / (1.0 - self.sparsity);
        }
    }

    /// Record the counts for one layer and fold them into the totals.
    pub fn add_layer(&mut self, name: String, total: usize, pruned: usize) {
        self.total_params += total;
        self.pruned_params += pruned;

        let sparsity = if total > 0 {
            pruned as f32 / total as f32
        } else {
            0.0
        };

        self.layer_stats.insert(
            name,
            LayerPruningStats {
                total,
                pruned,
                sparsity,
            },
        );
    }

    /// Log a human-readable summary of the collected statistics.
    pub fn print_summary(&self) {
        tracing::info!("=== Pruning Statistics ===");
        tracing::info!("Total parameters: {}", self.total_params);
        tracing::info!("Pruned parameters: {}", self.pruned_params);
        tracing::info!("Sparsity: {:.2}%", self.sparsity * 100.0);
        tracing::info!("Compression ratio: {:.2}x", self.compression_ratio);
        tracing::info!("\nPer-layer statistics:");
        for (name, stats) in &self.layer_stats {
            tracing::info!(
                "  {}: {}/{} ({:.2}%)",
                name,
                stats.pruned,
                stats.total,
                stats.sparsity * 100.0
            );
        }
    }
}

impl Default for PruningStats {
    fn default() -> Self {
        Self::new()
    }
}

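/// Prune the `sparsity` fraction of weights with the smallest magnitudes.
///
/// Returns the pruned matrix together with a boolean keep-mask of the same
/// shape (`true` marks surviving weights).
///
/// # Example
///
/// A minimal sketch (illustrative values; marked `ignore` since crate
/// paths are omitted):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let w = Array2::from_shape_vec((2, 2), vec![0.1_f32, -2.0, 0.05, 3.0]).unwrap();
/// let (pruned, mask) = prune_magnitude(&w, 0.5).unwrap();
/// // The two smallest-magnitude entries (0.1 and 0.05) are zeroed.
/// assert_eq!(mask.iter().filter(|&&k| !k).count(), 2);
/// ```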
pub fn prune_magnitude(
    weights: &Array2<f32>,
    sparsity: f32,
) -> ModelResult<(Array2<f32>, Array2<bool>)> {
    if !(0.0..=1.0).contains(&sparsity) {
        return Err(ModelError::invalid_config(format!(
            "Pruning: Sparsity must be between 0 and 1, got {}",
            sparsity
        )));
    }

    let total_elements = weights.len();
    let num_to_prune = (total_elements as f32 * sparsity) as usize;

    // Sort all weights by absolute value, ascending.
    let mut abs_weights: Vec<(f32, (usize, usize))> = weights
        .indexed_iter()
        .map(|(idx, &val)| (val.abs(), idx))
        .collect();

    abs_weights.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

    // Mark the `num_to_prune` smallest weights for removal.
    let mut mask = Array2::from_elem(weights.dim(), true);
    for i in 0..num_to_prune {
        if i < abs_weights.len() {
            let (_, idx) = abs_weights[i];
            mask[idx] = false;
        }
    }

    // Zero out the masked weights.
    let pruned = weights * &mask.mapv(|x| if x { 1.0 } else { 0.0 });

    Ok((pruned, mask))
}

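/// Prune all weights whose absolute value falls below `threshold`.
///
/// # Example
///
/// A minimal sketch (illustrative values; marked `ignore` since crate
/// paths are omitted):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let w = Array2::from_shape_vec((1, 3), vec![0.2_f32, -0.9, 0.4]).unwrap();
/// let (pruned, mask) = prune_threshold(&w, 0.5).unwrap();
/// assert_eq!(pruned[[0, 0]], 0.0); // |0.2| < 0.5 is pruned
/// assert!(mask[[0, 1]]);           // |-0.9| >= 0.5 survives
/// ```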
pub fn prune_threshold(
    weights: &Array2<f32>,
    threshold: f32,
) -> ModelResult<(Array2<f32>, Array2<bool>)> {
    let mask = weights.mapv(|x| x.abs() >= threshold);
    let pruned = weights * &mask.mapv(|x| if x { 1.0 } else { 0.0 });

    Ok((pruned, mask))
}

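/// Configuration for knowledge distillation.
///
/// The total loss is a convex combination of the distillation term and the
/// task term: `alpha * distillation + task_weight * task`, with
/// `task_weight = 1 - alpha`.
///
/// # Example
///
/// A minimal sketch (illustrative values; marked `ignore` since crate
/// paths are omitted):
///
/// ```ignore
/// let config = DistillationConfig::new(4.0, 0.9);
/// assert!((config.task_weight - 0.1).abs() < 1e-6);
/// ```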
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillationConfig {
    /// Softmax temperature applied to both student and teacher logits.
    pub temperature: f32,
    /// Weight of the distillation (KL) term.
    pub alpha: f32,
    /// Weight of the original task loss, kept equal to `1 - alpha`.
    pub task_weight: f32,
}

impl Default for DistillationConfig {
    fn default() -> Self {
        Self {
            temperature: 3.0,
            alpha: 0.7,
            task_weight: 0.3,
        }
    }
}

impl DistillationConfig {
    /// Create a configuration from a temperature and distillation weight.
    pub fn new(temperature: f32, alpha: f32) -> Self {
        Self {
            temperature,
            alpha,
            task_weight: 1.0 - alpha,
        }
    }

    /// Set the temperature (builder style).
    pub fn temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Set the distillation weight and keep `task_weight` consistent
    /// (builder style).
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self.task_weight = 1.0 - alpha;
        self
    }
}

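/// Temperature-scaled knowledge-distillation loss between one pair of
/// logit vectors:
///
/// `loss = T^2 * KL(softmax(teacher / T) || softmax(student / T))`
///
/// The `T^2` factor keeps gradient magnitudes roughly constant as the
/// temperature changes, following Hinton et al., "Distilling the Knowledge
/// in a Neural Network" (2015).
///
/// # Example
///
/// A minimal sketch (illustrative logits; marked `ignore` since crate
/// paths are omitted):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let student = Array1::from_vec(vec![1.0_f32, 0.0, -1.0]);
/// let teacher = Array1::from_vec(vec![1.2_f32, 0.1, -0.9]);
/// let loss = distillation_loss(&student, &teacher, 2.0).unwrap();
/// assert!(loss >= 0.0);
/// ```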
pub fn distillation_loss(
    student_logits: &Array1<f32>,
    teacher_logits: &Array1<f32>,
    temperature: f32,
) -> ModelResult<f32> {
    if student_logits.len() != teacher_logits.len() {
        return Err(ModelError::dimension_mismatch(
            "distillation loss",
            student_logits.len(),
            teacher_logits.len(),
        ));
    }

    // Scale logits by the temperature.
    let student_scaled = student_logits.mapv(|x| x / temperature);
    let teacher_scaled = teacher_logits.mapv(|x| x / temperature);

    // Numerically stable softmax: subtract the max before exponentiating.
    let student_max = student_scaled.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let teacher_max = teacher_scaled.fold(f32::NEG_INFINITY, |a, &b| a.max(b));

    let student_exp = student_scaled.mapv(|x| (x - student_max).exp());
    let teacher_exp = teacher_scaled.mapv(|x| (x - teacher_max).exp());

    let student_sum = student_exp.sum();
    let teacher_sum = teacher_exp.sum();

    let student_probs = &student_exp / student_sum;
    let teacher_probs = &teacher_exp / teacher_sum;

    // KL(teacher || student), skipping terms that would divide by ~0.
    let mut kl_div = 0.0;
    for i in 0..student_probs.len() {
        if teacher_probs[i] > 1e-10 && student_probs[i] > 1e-10 {
            kl_div += teacher_probs[i] * (teacher_probs[i] / student_probs[i]).ln();
        }
    }

    // Scale by T^2, as is standard in knowledge distillation.
    Ok(kl_div * temperature * temperature)
}

/// Configuration for low-rank factorization of weight matrices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LowRankConfig {
    /// Target rank of the factorization.
    pub rank: usize,
    /// Whether to compute the factors via (approximate) SVD.
    pub use_svd: bool,
}

impl LowRankConfig {
    /// Create a configuration with the given target rank, using SVD.
    pub fn new(rank: usize) -> Self {
        Self {
            rank,
            use_svd: true,
        }
    }

    /// Toggle SVD-based factorization (builder style).
    pub fn svd(mut self, use_svd: bool) -> Self {
        self.use_svd = use_svd;
        self
    }
}

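/// Ratio of original to compressed size; returns `f32::INFINITY` when the
/// compressed size is zero.
///
/// # Example
///
/// ```ignore
/// assert_eq!(compression_ratio(1000, 250), 4.0);
/// ```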
pub fn compression_ratio(original_size: usize, compressed_size: usize) -> f32 {
    if compressed_size == 0 {
        return f32::INFINITY;
    }
    original_size as f32 / compressed_size as f32
}

/// Weight sharing: quantize a weight matrix onto a small set of shared
/// scalar values.
pub mod weight_sharing {
    use super::*;

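    /// Quantize a weight matrix by clustering its scalar values with
    /// k-means and replacing each weight by its nearest centroid.
    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative values; marked `ignore` since crate
    /// paths are omitted):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array2;
    ///
    /// let w = Array2::from_shape_vec((1, 4), vec![0.9_f32, 1.1, 5.0, 5.2]).unwrap();
    /// let q = weight_sharing::kmeans_cluster(&w, 2).unwrap();
    /// // All values collapse onto (at most) two shared centroids.
    /// ```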
    pub fn kmeans_cluster(weights: &Array2<f32>, num_clusters: usize) -> ModelResult<Array2<f32>> {
        if num_clusters == 0 || num_clusters > weights.len() {
            return Err(ModelError::invalid_config(format!(
                "K-means clustering: Invalid number of clusters: {}",
                num_clusters
            )));
        }

        let flat_weights: Vec<f32> = weights.iter().copied().collect();

        // Initialize centroids by sampling the flattened weights at a
        // fixed stride.
        let mut centroids = Vec::new();
        let step = flat_weights.len() / num_clusters;
        for i in 0..num_clusters {
            if i * step < flat_weights.len() {
                centroids.push(flat_weights[i * step]);
            }
        }

        // Run a fixed number of Lloyd iterations.
        for _ in 0..10 {
            let mut cluster_sums = vec![0.0; num_clusters];
            let mut cluster_counts = vec![0usize; num_clusters];

            // Assign each weight to its nearest centroid.
            for &weight in &flat_weights {
                let mut min_dist = f32::INFINITY;
                let mut cluster_id = 0;

                for (i, &centroid) in centroids.iter().enumerate() {
                    let dist = (weight - centroid).abs();
                    if dist < min_dist {
                        min_dist = dist;
                        cluster_id = i;
                    }
                }

                cluster_sums[cluster_id] += weight;
                cluster_counts[cluster_id] += 1;
            }

            // Move each centroid to the mean of its assigned weights.
            for i in 0..num_clusters {
                if cluster_counts[i] > 0 {
                    centroids[i] = cluster_sums[i] / cluster_counts[i] as f32;
                }
            }
        }

        // Replace every weight with its nearest centroid.
        let mut quantized = Array2::zeros(weights.dim());
        for (idx, &weight) in weights.indexed_iter() {
            let mut min_dist = f32::INFINITY;
            let mut best_centroid = centroids[0];

            for &centroid in &centroids {
                let dist = (weight - centroid).abs();
                if dist < min_dist {
                    min_dist = dist;
                    best_centroid = centroid;
                }
            }

            quantized[idx] = best_centroid;
        }

        Ok(quantized)
    }
}

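/// In-place magnitude pruner that zeroes weights below a fixed threshold
/// while tracking cumulative sparsity statistics across calls.
///
/// # Example
///
/// A minimal sketch (illustrative values; marked `ignore` since crate
/// paths are omitted):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let mut pruner = MagnitudePruner::new(0.5);
/// let mut w = Array2::from_shape_vec((1, 3), vec![0.1_f32, 0.9, -0.2]).unwrap();
/// let layer_sparsity = pruner.prune_matrix(&mut w);
/// assert!((layer_sparsity - 2.0 / 3.0).abs() < 1e-6);
/// assert_eq!(pruner.sparsity(), layer_sparsity);
/// ```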
#[derive(Debug, Clone)]
pub struct MagnitudePruner {
    /// Absolute-value threshold below which weights are zeroed.
    pub threshold: f32,
    /// Cumulative number of weights pruned so far.
    pub pruned_count: usize,
    /// Cumulative number of weights inspected so far.
    pub total_count: usize,
}

impl MagnitudePruner {
    /// Create a pruner with the given magnitude threshold.
    pub fn new(threshold: f32) -> Self {
        Self {
            threshold,
            pruned_count: 0,
            total_count: 0,
        }
    }

    /// Zero all entries of `w` with `|v| < threshold` in place and return
    /// the sparsity of this matrix.
    pub fn prune_matrix(&mut self, w: &mut Array2<f32>) -> f32 {
        let total = w.len();
        let mut pruned = 0usize;
        for v in w.iter_mut() {
            if v.abs() < self.threshold {
                *v = 0.0;
                pruned += 1;
            }
        }
        self.total_count += total;
        self.pruned_count += pruned;
        if total == 0 {
            0.0
        } else {
            pruned as f32 / total as f32
        }
    }

    /// Zero all entries of `v` with `|x| < threshold` in place and return
    /// the sparsity of this vector.
    pub fn prune_vector(&mut self, v: &mut Array1<f32>) -> f32 {
        let total = v.len();
        let mut pruned = 0usize;
        for x in v.iter_mut() {
            if x.abs() < self.threshold {
                *x = 0.0;
                pruned += 1;
            }
        }
        self.total_count += total;
        self.pruned_count += pruned;
        if total == 0 {
            0.0
        } else {
            pruned as f32 / total as f32
        }
    }

    /// Cumulative sparsity across everything pruned so far.
    pub fn sparsity(&self) -> f32 {
        if self.total_count == 0 {
            0.0
        } else {
            self.pruned_count as f32 / self.total_count as f32
        }
    }

    /// Reset the cumulative counters.
    pub fn reset_stats(&mut self) {
        self.pruned_count = 0;
        self.total_count = 0;
    }
}

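/// Structured pruner that keeps only the rows with the largest L2 norms.
///
/// # Example
///
/// A minimal sketch (illustrative values; marked `ignore` since crate
/// paths are omitted):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let w = Array2::from_shape_fn((4, 2), |(i, _)| i as f32);
/// let pruner = StructuredPruner::new(0.5);
/// let compressed = pruner.compress_rows(&w).unwrap();
/// assert_eq!(compressed.nrows(), 2); // ceil(0.5 * 4) rows kept
/// ```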
#[derive(Debug, Clone)]
pub struct StructuredPruner {
    /// Fraction of rows to keep (0.0..=1.0).
    pub keep_fraction: f32,
}

impl StructuredPruner {
    /// Create a pruner that keeps `keep_fraction` of the rows.
    pub fn new(keep_fraction: f32) -> Self {
        Self { keep_fraction }
    }

    /// Return a keep-mask over rows, marking the `ceil(keep_fraction * nrows)`
    /// rows with the largest L2 norms.
    pub fn prune_rows(&self, w: &Array2<f32>) -> ModelResult<Vec<bool>> {
        let nrows = w.nrows();
        if nrows == 0 {
            return Err(ModelError::invalid_config(
                "StructuredPruner::prune_rows: empty matrix",
            ));
        }
        let keep = ((self.keep_fraction * nrows as f32).ceil() as usize).min(nrows);

        // Rank rows by their L2 norm, descending.
        let mut row_norms: Vec<(usize, f32)> = (0..nrows)
            .map(|i| {
                let norm = w.row(i).iter().map(|&x| x * x).sum::<f32>().sqrt();
                (i, norm)
            })
            .collect();

        row_norms.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut mask = vec![false; nrows];
        for (row_idx, _) in row_norms.iter().take(keep) {
            mask[*row_idx] = true;
        }
        Ok(mask)
    }

    /// Build a smaller matrix containing only the kept rows.
    pub fn compress_rows(&self, w: &Array2<f32>) -> ModelResult<Array2<f32>> {
        let mask = self.prune_rows(w)?;
        let kept_rows: Vec<usize> = mask
            .iter()
            .enumerate()
            .filter_map(|(i, &keep)| if keep { Some(i) } else { None })
            .collect();

        if kept_rows.is_empty() {
            return Err(ModelError::invalid_config(
                "StructuredPruner::compress_rows: no rows kept",
            ));
        }

        let ncols = w.ncols();
        let mut out = Array2::<f32>::zeros((kept_rows.len(), ncols));
        for (new_i, &old_i) in kept_rows.iter().enumerate() {
            for j in 0..ncols {
                out[(new_i, j)] = w[(old_i, j)];
            }
        }
        Ok(out)
    }
}

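/// Rank-`r` approximation `W ≈ U Σ Vᵀ` computed by power iteration with
/// deflation, avoiding any external linear-algebra dependency.
///
/// # Example
///
/// A minimal sketch (illustrative values; marked `ignore` since crate
/// paths are omitted):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let w = Array2::from_shape_fn((8, 6), |(i, j)| (i + j) as f32);
/// let lra = LowRankApprox::compute(&w, 2, 30).unwrap();
/// let y = lra.forward(&Array1::from_vec(vec![1.0_f32; 6])).unwrap();
/// assert_eq!(y.len(), 8);
/// println!("relative error: {}", lra.reconstruction_error);
/// ```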
#[derive(Debug, Clone)]
pub struct LowRankApprox {
    /// Effective rank of the factorization.
    pub rank: usize,
    /// Left singular vectors, `rows x rank`.
    pub u: Array2<f32>,
    /// Right singular vectors (transposed), `rank x cols`.
    pub vt: Array2<f32>,
    /// Approximate singular values, length `rank`.
    pub singular_values: Array1<f32>,
    /// Relative Frobenius-norm reconstruction error.
    pub reconstruction_error: f32,
}

impl LowRankApprox {
    /// Compute a rank-`rank` approximation of `w` using up to `num_iter`
    /// power iterations per component, deflating the residual after each
    /// extracted component.
    pub fn compute(w: &Array2<f32>, rank: usize, num_iter: usize) -> ModelResult<Self> {
        let rows = w.nrows();
        let cols = w.ncols();

        if rank == 0 {
            return Err(ModelError::invalid_config(
                "LowRankApprox: rank must be > 0",
            ));
        }
        let effective_rank = rank.min(rows.min(cols));

        let mut u_cols: Vec<Array1<f32>> = Vec::with_capacity(effective_rank);
        let mut vt_rows: Vec<Array1<f32>> = Vec::with_capacity(effective_rank);
        let mut sigmas: Vec<f32> = Vec::with_capacity(effective_rank);

        // Work on a residual copy; each extracted component is subtracted.
        let mut residual = w.clone();

        for k in 0..effective_rank {
            // Start from a unit basis vector to break symmetry.
            let mut v = Array1::<f32>::zeros(cols);
            v[k % cols] = 1.0;

            let iters = num_iter.max(1);
            for _ in 0..iters {
                // u = residual * v
                let mut u_vec = Array1::<f32>::zeros(rows);
                for i in 0..rows {
                    u_vec[i] = (0..cols).map(|j| residual[(i, j)] * v[j]).sum();
                }
                let sigma = u_vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
                if sigma < 1e-12 {
                    break;
                }
                let u_norm = u_vec.mapv(|x| x / sigma);

                // v = residual^T * u
                let mut v_new = Array1::<f32>::zeros(cols);
                for j in 0..cols {
                    v_new[j] = (0..rows).map(|i| residual[(i, j)] * u_norm[i]).sum();
                }
                let v_norm_val = v_new.iter().map(|&x| x * x).sum::<f32>().sqrt();
                if v_norm_val < 1e-12 {
                    break;
                }
                v = v_new.mapv(|x| x / v_norm_val);
            }

            // Extract the converged component: sigma = |residual * v|.
            let mut u_vec = Array1::<f32>::zeros(rows);
            for i in 0..rows {
                u_vec[i] = (0..cols).map(|j| residual[(i, j)] * v[j]).sum();
            }
            let sigma = u_vec.iter().map(|&x| x * x).sum::<f32>().sqrt();
            if sigma < 1e-12 {
                // Residual is numerically rank-deficient; pad with zeros.
                u_cols.push(Array1::zeros(rows));
                vt_rows.push(Array1::zeros(cols));
                sigmas.push(0.0);
            } else {
                let u_final = u_vec.mapv(|x| x / sigma);

                // Deflate: residual -= sigma * u * v^T.
                for i in 0..rows {
                    for j in 0..cols {
                        residual[(i, j)] -= sigma * u_final[i] * v[j];
                    }
                }

                u_cols.push(u_final);
                vt_rows.push(v);
                sigmas.push(sigma);
            }
        }

        // Assemble the factor matrices.
        let mut u_mat = Array2::<f32>::zeros((rows, effective_rank));
        let mut vt_mat = Array2::<f32>::zeros((effective_rank, cols));
        for k in 0..effective_rank {
            for i in 0..rows {
                u_mat[(i, k)] = u_cols[k][i];
            }
            for j in 0..cols {
                vt_mat[(k, j)] = vt_rows[k][j];
            }
        }
        let singular_values = Array1::from_vec(sigmas);

        // Relative Frobenius-norm error of the reconstruction.
        let w_frob: f32 = w.iter().map(|&x| x * x).sum::<f32>().sqrt();
        let rec_error = if w_frob < 1e-12 {
            0.0
        } else {
            let mut err_sq = 0.0_f32;
            for i in 0..rows {
                for j in 0..cols {
                    let approx: f32 = (0..effective_rank)
                        .map(|k| u_mat[(i, k)] * singular_values[k] * vt_mat[(k, j)])
                        .sum();
                    err_sq += (w[(i, j)] - approx).powi(2);
                }
            }
            err_sq.sqrt() / w_frob
        };

        Ok(Self {
            rank: effective_rank,
            u: u_mat,
            vt: vt_mat,
            singular_values,
            reconstruction_error: rec_error,
        })
    }

    /// Reconstruct the full matrix `U Σ Vᵀ`.
    pub fn reconstruct(&self) -> ModelResult<Array2<f32>> {
        let rows = self.u.nrows();
        let cols = self.vt.ncols();
        let mut out = Array2::<f32>::zeros((rows, cols));
        for i in 0..rows {
            for j in 0..cols {
                out[(i, j)] = (0..self.rank)
                    .map(|k| self.u[(i, k)] * self.singular_values[k] * self.vt[(k, j)])
                    .sum();
            }
        }
        Ok(out)
    }

    /// Parameter-count compression ratio of the factorization:
    /// `rows * cols / (rows * rank + rank * cols)`.
    pub fn compression_ratio(&self) -> f32 {
        let rows = self.u.nrows();
        let cols = self.vt.ncols();
        let original = rows * cols;
        let compressed = rows * self.rank + self.rank * cols;
        if compressed == 0 {
            return f32::INFINITY;
        }
        original as f32 / compressed as f32
    }

    /// Apply the factorized map to `x` without materializing the full
    /// matrix: `y = U (Σ (Vᵀ x))`.
    pub fn forward(&self, x: &Array1<f32>) -> ModelResult<Array1<f32>> {
        let cols = self.vt.ncols();
        let rows = self.u.nrows();
        if x.len() != cols {
            return Err(ModelError::dimension_mismatch(
                "LowRankApprox::forward",
                cols,
                x.len(),
            ));
        }
        // intermediate = V^T x
        let mut intermediate = Array1::<f32>::zeros(self.rank);
        for k in 0..self.rank {
            intermediate[k] = (0..cols).map(|j| self.vt[(k, j)] * x[j]).sum();
        }
        // Scale by the singular values.
        for k in 0..self.rank {
            intermediate[k] *= self.singular_values[k];
        }
        // out = U * intermediate
        let mut out = Array1::<f32>::zeros(rows);
        for i in 0..rows {
            out[i] = (0..self.rank)
                .map(|k| self.u[(i, k)] * intermediate[k])
                .sum();
        }
        Ok(out)
    }
}

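/// Aggregated report over all compressed layers of a model.
///
/// # Example
///
/// A minimal sketch (illustrative layer name and shapes; marked `ignore`
/// since crate paths are omitted):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let original = Array2::<f32>::ones((8, 8));
/// let compressed = Array2::<f32>::ones((4, 8));
/// let mut report = CompressionReport::new();
/// report.add_layer("layer0", &original, &compressed);
/// println!("{}", report.summary());
/// ```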
#[derive(Debug, Clone)]
pub struct CompressionReport {
    /// Parameter count before compression.
    pub original_params: usize,
    /// Parameter count after compression.
    pub compressed_params: usize,
    /// Number of parameters that are exactly zero in the originals.
    pub pruned_params: usize,
    /// Per-layer `(name, original_rank, compressed_rank)` entries, where
    /// rank here means `min(nrows, ncols)`.
    pub rank_reductions: Vec<(String, usize, usize)>,
    /// Ratio of original to compressed parameter counts.
    pub overall_compression_ratio: f32,
}

impl CompressionReport {
    /// Create an empty report.
    pub fn new() -> Self {
        Self {
            original_params: 0,
            compressed_params: 0,
            pruned_params: 0,
            rank_reductions: Vec::new(),
            overall_compression_ratio: 1.0,
        }
    }

    /// Record one layer's original and compressed matrices and update the
    /// running totals.
    pub fn add_layer(&mut self, name: &str, original: &Array2<f32>, compressed: &Array2<f32>) {
        let orig_params = original.nrows() * original.ncols();
        let comp_params = compressed.nrows() * compressed.ncols();

        // Count exact zeros in the original as pruned parameters.
        let pruned = original.iter().filter(|&&x| x == 0.0).count();

        self.original_params += orig_params;
        self.compressed_params += comp_params;
        self.pruned_params += pruned;

        let orig_rank = original.nrows().min(original.ncols());
        let comp_rank = compressed.nrows().min(compressed.ncols());
        self.rank_reductions
            .push((name.to_string(), orig_rank, comp_rank));

        self.overall_compression_ratio = if self.compressed_params == 0 {
            f32::INFINITY
        } else {
            self.original_params as f32 / self.compressed_params as f32
        };
    }

    /// Render the report as a human-readable multi-line string.
    pub fn summary(&self) -> String {
        let mut lines = vec![
            "=== Compression Report ===".to_string(),
            format!("Original parameters  : {}", self.original_params),
            format!("Compressed parameters: {}", self.compressed_params),
            format!("Pruned parameters    : {}", self.pruned_params),
            format!(
                "Overall compression ratio: {:.3}x",
                self.overall_compression_ratio
            ),
            String::new(),
            "Layer rank reductions:".to_string(),
        ];
        for (name, orig_rank, comp_rank) in &self.rank_reductions {
            lines.push(format!("  {}: rank {} -> {}", name, orig_rank, comp_rank));
        }
        lines.join("\n")
    }
}

impl Default for CompressionReport {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prune_magnitude() {
        let weights = Array2::from_shape_vec(
            (3, 3),
            vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0],
        )
        .expect("Failed to create test array");

        let (pruned, mask) = prune_magnitude(&weights, 0.5).expect("Failed to prune");

        let num_zeros = pruned.iter().filter(|&&x| x == 0.0).count();
        assert!(num_zeros >= 4);
        assert_eq!(pruned.dim(), weights.dim());
        assert_eq!(mask.dim(), weights.dim());
    }

    #[test]
    fn test_prune_threshold() {
        let weights = Array2::from_shape_vec((2, 2), vec![1.0, 0.5, 0.1, 2.0])
            .expect("Failed to create test array");

        let (pruned, mask) = prune_threshold(&weights, 0.6).expect("Failed to prune");

        assert_eq!(pruned[[0, 0]], 1.0);
        assert_eq!(pruned[[0, 1]], 0.0);
        assert_eq!(pruned[[1, 0]], 0.0);
        assert_eq!(pruned[[1, 1]], 2.0);

        assert!(mask[[0, 0]]);
        assert!(!mask[[0, 1]]);
        assert!(!mask[[1, 0]]);
        assert!(mask[[1, 1]]);
    }

    #[test]
    fn test_distillation_loss() {
        let student = Array1::from_vec(vec![2.0, 1.0, 0.1]);
        let teacher = Array1::from_vec(vec![2.5, 1.5, 0.5]);

        let loss = distillation_loss(&student, &teacher, 3.0).expect("Failed to compute loss");

        assert!(loss >= 0.0);
        assert!(loss.is_finite());
    }

    #[test]
    fn test_pruning_config() {
        let config = PruningConfig::magnitude_based(0.3)
            .global(false)
            .bounds(0.1, 0.8);

        assert_eq!(config.strategy, PruningStrategy::Magnitude);
        assert_eq!(config.sparsity, 0.3);
        assert!(!config.global_threshold);
        assert_eq!(config.min_sparsity, 0.1);
        assert_eq!(config.max_sparsity, 0.8);
    }

    #[test]
    fn test_distillation_config() {
        let config = DistillationConfig::new(5.0, 0.8);

        assert_eq!(config.temperature, 5.0);
        assert_eq!(config.alpha, 0.8);
        assert!((config.task_weight - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_compression_ratio() {
        let ratio = compression_ratio(1000, 250);
        assert_eq!(ratio, 4.0);

        let ratio = compression_ratio(1000, 1000);
        assert_eq!(ratio, 1.0);
    }

    #[test]
    fn test_pruning_stats() {
        let mut stats = PruningStats::new();
        stats.add_layer("layer1".to_string(), 1000, 300);
        stats.add_layer("layer2".to_string(), 2000, 800);
        stats.finalize();

        assert_eq!(stats.total_params, 3000);
        assert_eq!(stats.pruned_params, 1100);
        assert!((stats.sparsity - 0.366667).abs() < 1e-5);
        assert!(stats.compression_ratio > 1.0);
    }

    #[test]
    fn test_kmeans_weight_sharing() {
        let weights = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
            .expect("Failed to create test array");

        let quantized = weight_sharing::kmeans_cluster(&weights, 2).expect("Failed to cluster");

        assert_eq!(quantized.dim(), weights.dim());

        let unique_vals: std::collections::HashSet<_> =
            quantized.iter().map(|&x| (x * 1000.0) as i32).collect();
        assert!(unique_vals.len() <= 2);
    }

    #[test]
    fn test_magnitude_pruner_basic() {
        let mut pruner = MagnitudePruner::new(0.5);
        let mut w =
            Array2::from_shape_vec((2, 4), vec![0.1_f32, 0.6, 0.2, 0.7, 0.3, 0.8, 0.4, 0.9])
                .expect("shape");

        let sparsity = pruner.prune_matrix(&mut w);
        assert!(sparsity > 0.0, "sparsity should be > 0");
        let zero_count = w.iter().filter(|&&x| x == 0.0).count();
        assert_eq!(zero_count, 4);
        assert!(pruner.pruned_count > 0);
        assert!(pruner.total_count > 0);
    }

    #[test]
    fn test_magnitude_pruner_zero_threshold() {
        let mut pruner = MagnitudePruner::new(0.0);
        let mut w = Array2::from_shape_vec((2, 2), vec![0.5_f32, 1.0, -0.3, 2.0]).expect("shape");

        let sparsity = pruner.prune_matrix(&mut w);
        assert_eq!(sparsity, 0.0, "zero threshold should prune nothing");
        assert_eq!(pruner.pruned_count, 0);
    }

    #[test]
    fn test_structured_pruner_row_mask_count() {
        let w = Array2::from_shape_fn((10, 4), |(i, j)| (i * 4 + j) as f32);
        let pruner = StructuredPruner::new(0.6);
        let mask = pruner.prune_rows(&w).expect("prune_rows failed");

        let keep_count = mask.iter().filter(|&&k| k).count();
        assert_eq!(keep_count, 6, "expected 6 kept rows, got {keep_count}");
        assert_eq!(mask.len(), 10);
    }

    #[test]
    fn test_structured_pruner_compress_reduces_rows() {
        let w = Array2::from_shape_fn((8, 3), |(i, j)| (i + j) as f32);
        let pruner = StructuredPruner::new(0.5);
        let compressed = pruner.compress_rows(&w).expect("compress_rows failed");

        assert!(
            compressed.nrows() < w.nrows(),
            "compressed rows {} should be < original {}",
            compressed.nrows(),
            w.nrows()
        );
        assert_eq!(compressed.ncols(), w.ncols());
    }

    #[test]
    fn test_low_rank_approx_shapes() {
        let w = Array2::from_shape_fn((8, 6), |(i, j)| (i * j) as f32 * 0.1);
        let lra = LowRankApprox::compute(&w, 3, 50).expect("compute failed");

        assert_eq!(lra.u.nrows(), 8);
        assert_eq!(lra.u.ncols(), 3);
        assert_eq!(lra.vt.nrows(), 3);
        assert_eq!(lra.vt.ncols(), 6);
        assert_eq!(lra.singular_values.len(), 3);
    }

    #[test]
    fn test_low_rank_approx_reconstruction_error() {
        // A 4x4 identity matrix: the full-rank approximation should be exact.
        let mut data = vec![0.0_f32; 16];
        for i in 0..4 {
            data[i * 4 + i] = 1.0;
        }
        let w = Array2::from_shape_vec((4, 4), data).expect("shape");

        let lra = LowRankApprox::compute(&w, 4, 100).expect("compute failed");
        assert!(
            lra.reconstruction_error < 0.01,
            "reconstruction_error {} should be < 0.01",
            lra.reconstruction_error
        );
    }

    #[test]
    fn test_low_rank_approx_compression_ratio() {
        let w = Array2::from_shape_fn((10, 10), |(i, j)| (i as f32).sin() + (j as f32).cos());
        let lra = LowRankApprox::compute(&w, 2, 20).expect("compute failed");

        assert!(
            lra.compression_ratio() > 1.0,
            "compression_ratio {} should be > 1.0",
            lra.compression_ratio()
        );
    }

    #[test]
    fn test_low_rank_forward_shape() {
        let w = Array2::from_shape_fn((8, 6), |(i, j)| ((i + j) as f32) * 0.1);
        let lra = LowRankApprox::compute(&w, 3, 30).expect("compute failed");

        let x = Array1::from_vec(vec![1.0_f32; 6]);
        let out = lra.forward(&x).expect("forward failed");
        assert_eq!(out.len(), 8, "expected output len 8, got {}", out.len());
    }

    #[test]
    fn test_distillation_loss_same_logits() {
        let logits = Array1::from_vec(vec![1.0_f32, 2.0, 3.0]);
        let loss = distillation_loss(&logits, &logits, 1.0).expect("distillation_loss failed");
        assert!(loss < 1e-5, "same logits should give loss ≈ 0, got {loss}");
    }
}