1use anyhow::{anyhow, Result};
7use scirs2_core::ndarray_ext::{Array1, Array2};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Configuration for weight quantization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Quantization algorithm to apply.
    pub method: QuantizationMethod,
    /// Bits per quantized value (e.g. 8 for INT8).
    pub bit_precision: u8,
    /// Number of calibration samples used to estimate ranges.
    pub calibration_size: usize,
    /// Quantize each channel independently rather than per tensor.
    pub per_channel: bool,
    /// Use a symmetric range around zero.
    pub symmetric: bool,
    /// Enable quantization-aware training.
    pub qat_enabled: bool,
    /// Objective the quantization should optimize for.
    pub target: OptimizationTarget,
}
29
30impl Default for QuantizationConfig {
31 fn default() -> Self {
32 Self {
33 method: QuantizationMethod::PostTrainingQuantization,
34 bit_precision: 8,
35 calibration_size: 1000,
36 per_channel: true,
37 symmetric: true,
38 qat_enabled: false,
39 target: OptimizationTarget::Speed,
40 }
41 }
42}
43
/// Supported quantization algorithms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationMethod {
    /// Quantize a trained model using calibration data only.
    PostTrainingQuantization,
    /// Simulate quantization during training for better accuracy.
    QuantizationAwareTraining,
    /// Quantize activations on the fly at inference time.
    DynamicQuantization,
    /// Constrain weights to {-1, +1}.
    BinaryNeuralNetworks,
    /// Assign different bit widths to different layers.
    MixedBitQuantization,
}
58
/// Primary objective to optimize when compressing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationTarget {
    /// Minimize inference latency.
    Speed,
    /// Minimize memory footprint.
    Memory,
    /// Minimize energy consumption.
    Energy,
    /// Trade off all objectives evenly.
    Balanced,
}
71
/// Configuration for weight pruning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PruningConfig {
    /// Pruning algorithm to apply.
    pub method: PruningMethod,
    /// Target fraction of weights to remove (0.0..=1.0).
    pub sparsity_ratio: f32,
    /// Remove whole structures (rows/filters) instead of single weights.
    pub structured: bool,
    /// How sparsity ramps up over training.
    pub schedule: PruningSchedule,
    /// Fine-tuning epochs after pruning.
    pub fine_tune_epochs: usize,
    /// Absolute-magnitude floor used by magnitude-based methods.
    pub magnitude_threshold: f32,
}
88
89impl Default for PruningConfig {
90 fn default() -> Self {
91 Self {
92 method: PruningMethod::MagnitudePruning,
93 sparsity_ratio: 0.5,
94 structured: false,
95 schedule: PruningSchedule::Gradual,
96 fine_tune_epochs: 10,
97 magnitude_threshold: 0.01,
98 }
99 }
100}
101
/// Supported pruning algorithms.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PruningMethod {
    /// Remove the smallest-magnitude weights.
    MagnitudePruning,
    /// Single-shot network pruning based on connection sensitivity.
    SNIP,
    /// Iterative magnitude pruning (lottery-ticket style).
    LotteryTicket,
    /// Prune using Fisher-information importance estimates.
    FisherInformation,
    /// Gradually increase magnitude-based sparsity over time.
    GradualMagnitude,
}
116
/// How the target sparsity is reached over time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PruningSchedule {
    /// Prune to the full target in a single step.
    OneShot,
    /// Increase sparsity step-by-step.
    Gradual,
    /// Sparsity follows a polynomial-decay curve.
    PolynomialDecay,
    /// Sparsity follows an exponential-decay curve.
    ExponentialDecay,
}
129
/// Configuration for knowledge distillation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistillationConfig {
    /// Identifier of the (large) teacher model.
    pub teacher_model: String,
    /// Identifier of the (small) student model.
    pub student_model: String,
    /// Softmax temperature used to soften teacher/student logits.
    pub temperature: f32,
    /// Weight balancing distillation loss against the hard-label loss.
    pub alpha: f32,
    /// Which kind of knowledge is transferred.
    pub distillation_type: DistillationType,
    /// Indices of intermediate layers used for feature matching.
    pub feature_layers: Vec<usize>,
    /// Also transfer attention maps from teacher to student.
    pub attention_transfer: bool,
}
148
149impl Default for DistillationConfig {
150 fn default() -> Self {
151 Self {
152 teacher_model: "large_transformer".to_string(),
153 student_model: "small_transformer".to_string(),
154 temperature: 4.0,
155 alpha: 0.3,
156 distillation_type: DistillationType::ResponseBased,
157 feature_layers: vec![6, 12],
158 attention_transfer: true,
159 }
160 }
161}
162
/// Kind of knowledge transferred during distillation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistillationType {
    /// Match the teacher's output distribution.
    ResponseBased,
    /// Match intermediate feature representations.
    FeatureBased,
    /// Match attention maps.
    AttentionBased,
    /// Match pairwise relations between samples.
    RelationBased,
    /// Distill from multiple teachers at once.
    MultiTeacher,
}
177
/// Configuration for neural architecture search (NAS).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NASConfig {
    /// Search strategy used to explore the space.
    pub strategy: SearchStrategy,
    /// Which portion of the design space is searched.
    pub search_space: SearchSpace,
    /// Population size / number of candidate architectures.
    pub num_architectures: usize,
    /// Search budget in hours.
    pub max_search_time: f32,
    /// Deployment-hardware limits candidates must respect.
    pub hardware_constraints: HardwareConstraints,
    /// Use a performance predictor instead of full training.
    pub use_predictor: bool,
}
194
195impl Default for NASConfig {
196 fn default() -> Self {
197 Self {
198 strategy: SearchStrategy::Evolutionary,
199 search_space: SearchSpace::MicroSearch,
200 num_architectures: 100,
201 max_search_time: 24.0,
202 hardware_constraints: HardwareConstraints::default(),
203 use_predictor: true,
204 }
205 }
206}
207
/// Strategy used to explore the architecture search space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchStrategy {
    /// Uniform random sampling.
    Random,
    /// Evolutionary / genetic search.
    Evolutionary,
    /// RL controller proposes architectures.
    ReinforcementLearning,
    /// Differentiable (gradient-based) search.
    GradientBased,
    /// Bayesian optimization over the space.
    BayesianOptimization,
}
222
/// Granularity of the architecture search space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchSpace {
    /// Search whole-network topologies.
    MacroSearch,
    /// Search repeatable cells/blocks.
    MicroSearch,
    /// Combine macro and micro levels.
    Hierarchical,
    /// Grow the space progressively during search.
    Progressive,
}
235
/// Deployment-hardware limits that candidate architectures must satisfy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareConstraints {
    /// Maximum model memory footprint in megabytes.
    pub max_memory_mb: usize,
    /// Maximum acceptable inference latency in milliseconds.
    pub max_inference_time_ms: f32,
    /// Maximum energy per inference in millijoules.
    pub max_energy_mj: f32,
    /// Target hardware platform.
    pub platform: HardwarePlatform,
}
248
249impl Default for HardwareConstraints {
250 fn default() -> Self {
251 Self {
252 max_memory_mb: 512,
253 max_inference_time_ms: 100.0,
254 max_energy_mj: 10.0,
255 platform: HardwarePlatform::CPU,
256 }
257 }
258}
259
/// Hardware platforms a compressed model can target.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HardwarePlatform {
    CPU,
    GPU,
    TPU,
    EdgeTPU,
    Mobile,
    FPGA,
}
270
/// Orchestrates the individual compression techniques (quantization,
/// pruning, distillation, NAS) into a single pipeline.
pub struct ModelCompressionManager {
    /// Weight quantization stage.
    pub quantization: QuantizationProcessor,
    /// Weight pruning stage.
    pub pruning: PruningProcessor,
    /// Knowledge distillation stage.
    pub distillation: DistillationProcessor,
    /// Neural architecture search component.
    pub nas: NASProcessor,
}
282
283impl Default for ModelCompressionManager {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289impl ModelCompressionManager {
290 pub fn new() -> Self {
292 Self {
293 quantization: QuantizationProcessor::new(QuantizationConfig::default()),
294 pruning: PruningProcessor::new(PruningConfig::default()),
295 distillation: DistillationProcessor::new(DistillationConfig::default()),
296 nas: NASProcessor::new(NASConfig::default()),
297 }
298 }
299
300 pub async fn compress_model(
302 &mut self,
303 model_weights: &HashMap<String, Array2<f32>>,
304 compression_target: CompressionTarget,
305 ) -> Result<CompressedModel> {
306 println!("🗜️ Starting model compression with target: {compression_target:?}");
307
308 let mut compressed_weights = model_weights.clone();
309 let mut compression_stats = CompressionStats::default();
310
311 println!("✂️ Applying pruning...");
313 let pruning_result = self.pruning.prune_weights(&compressed_weights).await?;
314 compressed_weights = pruning_result.pruned_weights;
315 compression_stats.sparsity_ratio = pruning_result.sparsity_achieved;
316
317 println!("📊 Applying quantization...");
319 let quantization_result = self
320 .quantization
321 .quantize_weights(&compressed_weights)
322 .await?;
323 let quantized_weights = quantization_result.quantized_weights;
324 compression_stats.quantization_ratio = quantization_result.compression_ratio;
325
326 let distilled_weights = if compression_target.enable_distillation {
328 println!("🎓 Applying knowledge distillation...");
329 let distillation_result = self
330 .distillation
331 .distill_knowledge(&compressed_weights)
332 .await?;
333 compression_stats.distillation_loss = distillation_result.final_loss;
334 distillation_result.student_weights
335 } else {
336 compressed_weights
337 };
338
339 let original_size = self.calculate_model_size(model_weights);
341 let compressed_size = self
342 .calculate_quantized_size(&quantized_weights, self.quantization.config.bit_precision);
343 compression_stats.size_reduction_ratio =
344 1.0 - (compressed_size as f32 / original_size as f32);
345 compression_stats.memory_savings_mb =
346 (original_size - compressed_size) as f32 / (1024.0 * 1024.0);
347
348 let compressed_model = CompressedModel {
349 original_weights: model_weights.clone(),
350 compressed_weights: distilled_weights,
351 quantized_weights,
352 compression_config: compression_target,
353 stats: compression_stats,
354 };
355
356 println!("✅ Model compression completed!");
357 println!(
358 " 📉 Size reduction: {:.1}%",
359 compressed_model.stats.size_reduction_ratio * 100.0
360 );
361 println!(
362 " 💾 Memory saved: {:.1}MB",
363 compressed_model.stats.memory_savings_mb
364 );
365 println!(
366 " 🕳️ Sparsity: {:.1}%",
367 compressed_model.stats.sparsity_ratio * 100.0
368 );
369
370 Ok(compressed_model)
371 }
372
373 fn calculate_model_size(&self, weights: &HashMap<String, Array2<f32>>) -> usize {
375 weights
376 .values()
377 .map(|w| w.len() * std::mem::size_of::<f32>())
378 .sum()
379 }
380
381 fn calculate_quantized_size(
383 &self,
384 weights: &HashMap<String, Array2<f32>>,
385 bit_precision: u8,
386 ) -> usize {
387 let bytes_per_element = (bit_precision as f32 / 8.0).ceil() as usize;
388 weights.values().map(|w| w.len() * bytes_per_element).sum()
389 }
390}
391
/// Applies (simulated) quantization to weight tensors.
pub struct QuantizationProcessor {
    /// Active quantization configuration.
    pub config: QuantizationConfig,
    /// Per-layer quantization parameters computed during the last run.
    pub layer_params: HashMap<String, QuantizationParams>,
}
398
/// Affine quantization parameters for one tensor: real = (q - zero_point) * scale.
#[derive(Debug, Clone)]
pub struct QuantizationParams {
    /// Step size between adjacent quantization levels.
    pub scale: f32,
    /// Integer level that maps to real value 0.
    pub zero_point: i32,
    /// Minimum observed value in the tensor.
    pub min_val: f32,
    /// Maximum observed value in the tensor.
    pub max_val: f32,
}
407
408impl QuantizationProcessor {
409 pub fn new(config: QuantizationConfig) -> Self {
411 Self {
412 config,
413 layer_params: HashMap::new(),
414 }
415 }
416
417 pub async fn quantize_weights(
419 &mut self,
420 weights: &HashMap<String, Array2<f32>>,
421 ) -> Result<QuantizationResult> {
422 let mut quantized_weights = HashMap::new();
423 let mut total_size_original = 0;
424 let mut total_size_quantized = 0;
425
426 for (layer_name, weight_tensor) in weights {
427 let params = self.calculate_quantization_params(weight_tensor)?;
429 self.layer_params.insert(layer_name.clone(), params.clone());
430
431 let quantized = self.apply_quantization(weight_tensor, ¶ms)?;
433
434 total_size_original += weight_tensor.len() * std::mem::size_of::<f32>();
435 total_size_quantized += weight_tensor.len() * (self.config.bit_precision as usize / 8);
436
437 quantized_weights.insert(layer_name.clone(), quantized);
438 }
439
440 let compression_ratio = 1.0 - (total_size_quantized as f32 / total_size_original as f32);
441
442 Ok(QuantizationResult {
443 quantized_weights,
444 compression_ratio,
445 bit_precision: self.config.bit_precision,
446 method: self.config.method.clone(),
447 })
448 }
449
450 fn calculate_quantization_params(&self, tensor: &Array2<f32>) -> Result<QuantizationParams> {
452 let min_val = tensor.iter().fold(f32::INFINITY, |a, &b| a.min(b));
453 let max_val = tensor.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
454
455 let qmin = 0i32;
456 let qmax = (1i32 << self.config.bit_precision) - 1;
457
458 let scale = if self.config.symmetric {
459 let abs_max = min_val.abs().max(max_val.abs());
460 abs_max / (qmax as f32 / 2.0)
461 } else {
462 (max_val - min_val) / (qmax - qmin) as f32
463 };
464
465 let zero_point = if self.config.symmetric {
466 (qmin + qmax) / 2
467 } else {
468 (qmin as f32 - min_val / scale).round() as i32
469 };
470
471 Ok(QuantizationParams {
472 scale,
473 zero_point,
474 min_val,
475 max_val,
476 })
477 }
478
479 fn apply_quantization(
481 &self,
482 tensor: &Array2<f32>,
483 params: &QuantizationParams,
484 ) -> Result<Array2<f32>> {
485 let quantized = tensor.mapv(|x| {
486 let quantized_val = (x / params.scale + params.zero_point as f32).round();
487 let clamped = quantized_val
488 .max(0.0)
489 .min((1 << self.config.bit_precision) as f32 - 1.0);
490 (clamped - params.zero_point as f32) * params.scale
491 });
492
493 Ok(quantized)
494 }
495
496 pub fn apply_binary_quantization(&self, tensor: &Array2<f32>) -> Result<Array2<f32>> {
498 let binary = tensor.mapv(|x| if x >= 0.0 { 1.0 } else { -1.0 });
500 Ok(binary)
501 }
502}
503
/// Applies sparsity-inducing pruning to weight tensors.
pub struct PruningProcessor {
    /// Active pruning configuration.
    pub config: PruningConfig,
    /// Boolean keep-masks (true = kept) per layer from the last run.
    pub pruning_masks: HashMap<String, Array2<bool>>,
}
510
511impl PruningProcessor {
512 pub fn new(config: PruningConfig) -> Self {
514 Self {
515 config,
516 pruning_masks: HashMap::new(),
517 }
518 }
519
520 pub async fn prune_weights(
522 &mut self,
523 weights: &HashMap<String, Array2<f32>>,
524 ) -> Result<PruningResult> {
525 let mut pruned_weights = HashMap::new();
526 let mut total_params = 0;
527 let mut pruned_params = 0;
528
529 for (layer_name, weight_tensor) in weights {
530 let mask = self.generate_pruning_mask(weight_tensor)?;
531 let pruned = self.apply_pruning_mask(weight_tensor, &mask);
532
533 total_params += weight_tensor.len();
534 pruned_params += mask.iter().filter(|&&x| !x).count();
535
536 self.pruning_masks.insert(layer_name.clone(), mask);
537 pruned_weights.insert(layer_name.clone(), pruned);
538 }
539
540 let sparsity_achieved = pruned_params as f32 / total_params as f32;
541
542 Ok(PruningResult {
543 pruned_weights,
544 sparsity_achieved,
545 method: self.config.method.clone(),
546 })
547 }
548
549 fn generate_pruning_mask(&self, tensor: &Array2<f32>) -> Result<Array2<bool>> {
551 match self.config.method {
552 PruningMethod::MagnitudePruning => {
553 let threshold = self.calculate_magnitude_threshold(tensor);
554 let mask = tensor.mapv(|x| x.abs() >= threshold);
555 Ok(mask)
556 }
557 PruningMethod::SNIP => {
558 self.snip_pruning(tensor)
560 }
561 PruningMethod::LotteryTicket => {
562 self.lottery_ticket_pruning(tensor)
564 }
565 _ => {
566 let threshold = self.calculate_magnitude_threshold(tensor);
568 let mask = tensor.mapv(|x| x.abs() >= threshold);
569 Ok(mask)
570 }
571 }
572 }
573
574 fn calculate_magnitude_threshold(&self, tensor: &Array2<f32>) -> f32 {
576 let mut abs_values: Vec<f32> = tensor.iter().copied().map(|x| x.abs()).collect();
577 abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
578
579 let percentile_index = (abs_values.len() as f32 * self.config.sparsity_ratio) as usize;
580 abs_values.get(percentile_index).copied().unwrap_or(0.0)
581 }
582
583 fn snip_pruning(&self, tensor: &Array2<f32>) -> Result<Array2<bool>> {
585 let importance_scores = tensor.mapv(|x| x.abs() * (1.0 - x.tanh().powi(2))); let threshold = self.calculate_snip_threshold(&importance_scores);
588 let mask = importance_scores.mapv(|x| x >= threshold);
589 Ok(mask)
590 }
591
592 fn calculate_snip_threshold(&self, importance_scores: &Array2<f32>) -> f32 {
594 let mut scores: Vec<f32> = importance_scores.iter().copied().collect();
595 scores.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)); let keep_index = ((scores.len() as f32) * (1.0 - self.config.sparsity_ratio)) as usize;
598 scores.get(keep_index).copied().unwrap_or(0.0)
599 }
600
601 fn lottery_ticket_pruning(&self, tensor: &Array2<f32>) -> Result<Array2<bool>> {
603 let mut current_tensor = tensor.clone();
605 let mut mask = Array2::from_elem(tensor.dim(), true);
606
607 let pruning_rate = 0.2; let iterations =
609 (self.config.sparsity_ratio.ln() / (1.0f32 - pruning_rate).ln()).ceil() as usize;
610
611 for _ in 0..iterations {
612 let threshold = self.calculate_percentile_threshold(¤t_tensor, pruning_rate);
613 let iteration_mask = current_tensor.mapv(|x| x.abs() >= threshold);
614
615 for ((i, j), &keep) in iteration_mask.indexed_iter() {
617 if !keep {
618 mask[[i, j]] = false;
619 current_tensor[[i, j]] = 0.0;
620 }
621 }
622 }
623
624 Ok(mask)
625 }
626
627 fn calculate_percentile_threshold(&self, tensor: &Array2<f32>, percentile: f32) -> f32 {
629 let mut abs_values: Vec<f32> = tensor
630 .iter()
631 .filter(|&&x| x != 0.0) .map(|&x| x.abs())
633 .collect();
634 abs_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
635
636 if abs_values.is_empty() {
637 return 0.0;
638 }
639
640 let index = (abs_values.len() as f32 * percentile) as usize;
641 abs_values.get(index).copied().unwrap_or(0.0)
642 }
643
644 fn apply_pruning_mask(&self, tensor: &Array2<f32>, mask: &Array2<bool>) -> Array2<f32> {
646 tensor * &mask.mapv(|x| if x { 1.0 } else { 0.0 })
647 }
648}
649
/// Performs (simulated) knowledge distillation from teacher to student weights.
pub struct DistillationProcessor {
    /// Active distillation configuration.
    pub config: DistillationConfig,
}
654
655impl DistillationProcessor {
656 pub fn new(config: DistillationConfig) -> Self {
658 Self { config }
659 }
660
661 pub async fn distill_knowledge(
663 &self,
664 teacher_weights: &HashMap<String, Array2<f32>>,
665 ) -> Result<DistillationResult> {
666 println!("🎓 Starting knowledge distillation...");
668
669 let mut student_weights = HashMap::new();
671 for (layer_name, teacher_tensor) in teacher_weights {
672 let (rows, cols) = teacher_tensor.dim();
673 let student_rows = rows / 2;
674 let student_cols = cols / 2;
675
676 let student_tensor = Array2::from_shape_fn((student_rows, student_cols), |(i, j)| {
678 let teacher_i = (i * rows) / student_rows;
679 let teacher_j = (j * cols) / student_cols;
680 teacher_tensor[[teacher_i, teacher_j]] * 0.8 });
682
683 student_weights.insert(layer_name.clone(), student_tensor);
684 }
685
686 let mut distillation_loss = 1.0;
688 for epoch in 0..20 {
689 distillation_loss *= 0.95; if epoch % 5 == 0 {
693 println!(" 📉 Epoch {epoch}: Distillation loss = {distillation_loss:.4}");
694 }
695 }
696
697 Ok(DistillationResult {
698 student_weights,
699 final_loss: distillation_loss,
700 compression_ratio: 0.5, })
702 }
703
704 fn calculate_distillation_loss(
706 &self,
707 teacher_output: &Array1<f32>,
708 student_output: &Array1<f32>,
709 ) -> f32 {
710 let teacher_soft = self.apply_temperature_softmax(teacher_output, self.config.temperature);
711 let student_soft = self.apply_temperature_softmax(student_output, self.config.temperature);
712
713 teacher_soft
715 .iter()
716 .zip(student_soft.iter())
717 .map(|(&t, &s)| {
718 if t > 0.0 {
719 t * (t / s.max(1e-8)).ln()
720 } else {
721 0.0
722 }
723 })
724 .sum()
725 }
726
727 fn apply_temperature_softmax(&self, logits: &Array1<f32>, temperature: f32) -> Array1<f32> {
729 let scaled = logits.mapv(|x| x / temperature);
730 let max_val = scaled.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
731 let exp_vals = scaled.mapv(|x| (x - max_val).exp());
732 let sum_exp = exp_vals.sum();
733 exp_vals.mapv(|x| x / sum_exp)
734 }
735}
736
/// Evolutionary neural architecture search over simple layer stacks.
pub struct NASProcessor {
    /// Active search configuration.
    pub config: NASConfig,
    /// Current population of candidate architectures.
    pub population: Vec<ArchitectureCandidate>,
}
743
744impl NASProcessor {
745 pub fn new(config: NASConfig) -> Self {
747 Self {
748 config,
749 population: Vec::new(),
750 }
751 }
752
753 pub async fn search_architecture(&mut self) -> Result<OptimalArchitecture> {
755 println!("🔍 Starting Neural Architecture Search...");
756
757 self.initialize_population()?;
759
760 let mut best_architecture = None;
761 let mut best_score = f32::NEG_INFINITY;
762
763 for generation in 0..20 {
765 let mut scores = Vec::new();
767 for candidate in &self.population {
768 let score = self.evaluate_architecture_readonly(candidate).await?;
769 scores.push(score);
770 if score > best_score {
771 best_score = score;
772 best_architecture = Some(candidate.clone());
773 }
774 }
775
776 for (i, score) in scores.into_iter().enumerate() {
778 self.population[i].score = score;
779 }
780
781 self.evolve_population()?;
783
784 if generation % 5 == 0 {
785 println!(" 🧬 Generation {generation}: Best score = {best_score:.4}");
786 }
787 }
788
789 let optimal = best_architecture.ok_or_else(|| anyhow!("No optimal architecture found"))?;
790
791 Ok(OptimalArchitecture {
792 architecture: optimal.architecture,
793 performance_score: optimal.score,
794 memory_usage: optimal.estimated_memory,
795 inference_time: optimal.estimated_latency,
796 })
797 }
798
799 fn initialize_population(&mut self) -> Result<()> {
801 self.population.clear();
802
803 for _ in 0..self.config.num_architectures {
804 let architecture = self.generate_random_architecture()?;
805 let candidate = ArchitectureCandidate {
806 architecture,
807 score: 0.0,
808 estimated_memory: 0.0,
809 estimated_latency: 0.0,
810 };
811 self.population.push(candidate);
812 }
813
814 Ok(())
815 }
816
817 fn generate_random_architecture(&self) -> Result<Architecture> {
819 #[allow(unused_imports)]
820 use scirs2_core::random::{Random, Rng};
821 let mut rng = Random::default();
822
823 let num_layers = rng.random_range(2..11); let mut layers = Vec::new();
825
826 for _ in 0..num_layers {
827 let layer_type = match rng.random_range(0..4) {
828 0 => LayerType::Linear,
829 1 => LayerType::Attention,
830 2 => LayerType::Convolution,
831 _ => LayerType::Normalization,
832 };
833
834 let input_dim = rng.random_range(128..640);
835 let output_dim = rng.random_range(128..640);
836
837 layers.push(LayerConfig {
838 layer_type,
839 input_dim,
840 output_dim,
841 activation: ActivationType::ReLU,
842 });
843 }
844
845 Ok(Architecture {
846 layers,
847 skip_connections: rng.random_f64() < 0.5,
848 normalization: rng.random_f64() < 0.5,
849 })
850 }
851
852 async fn evaluate_architecture_readonly(
854 &self,
855 candidate: &ArchitectureCandidate,
856 ) -> Result<f32> {
857 let complexity_score = self.calculate_complexity_score(&candidate.architecture);
859 let efficiency_score = self.calculate_efficiency_score(&candidate.architecture);
860 let hardware_score = self.calculate_hardware_score(&candidate.architecture);
861
862 let score = complexity_score * 0.4 + efficiency_score * 0.4 + hardware_score * 0.2;
864
865 Ok(score)
866 }
867
868 async fn evaluate_architecture(&self, candidate: &mut ArchitectureCandidate) -> Result<f32> {
870 let complexity_score = self.calculate_complexity_score(&candidate.architecture);
872 let efficiency_score = self.calculate_efficiency_score(&candidate.architecture);
873 let hardware_score = self.calculate_hardware_score(&candidate.architecture);
874
875 candidate.estimated_memory = self.estimate_memory_usage(&candidate.architecture);
877 candidate.estimated_latency = self.estimate_inference_time(&candidate.architecture);
878
879 let score = complexity_score * 0.4 + efficiency_score * 0.4 + hardware_score * 0.2;
881
882 Ok(score)
883 }
884
885 fn calculate_complexity_score(&self, architecture: &Architecture) -> f32 {
887 let total_params: usize = architecture
888 .layers
889 .iter()
890 .map(|layer| layer.input_dim * layer.output_dim)
891 .sum();
892
893 let optimal_params = 100_000;
895 let ratio = total_params as f32 / optimal_params as f32;
896 (-((ratio - 1.0).powi(2))).exp() }
898
899 fn calculate_efficiency_score(&self, architecture: &Architecture) -> f32 {
901 let mut score = 0.0;
902
903 for layer in &architecture.layers {
905 score += match layer.layer_type {
906 LayerType::Linear => 0.8,
907 LayerType::Attention => 0.6,
908 LayerType::Convolution => 0.7,
909 LayerType::Normalization => 0.9,
910 };
911 }
912
913 if architecture.skip_connections {
915 score += 0.2;
916 }
917 if architecture.normalization {
918 score += 0.1;
919 }
920
921 score / architecture.layers.len() as f32
922 }
923
924 fn calculate_hardware_score(&self, architecture: &Architecture) -> f32 {
926 let memory_usage = self.estimate_memory_usage(architecture);
927 let inference_time = self.estimate_inference_time(architecture);
928
929 let memory_score = if memory_usage <= self.config.hardware_constraints.max_memory_mb as f32
930 {
931 1.0 - (memory_usage / self.config.hardware_constraints.max_memory_mb as f32)
932 } else {
933 0.0
934 };
935
936 let time_score = if inference_time <= self.config.hardware_constraints.max_inference_time_ms
937 {
938 1.0 - (inference_time / self.config.hardware_constraints.max_inference_time_ms)
939 } else {
940 0.0
941 };
942
943 (memory_score + time_score) / 2.0
944 }
945
946 fn estimate_memory_usage(&self, architecture: &Architecture) -> f32 {
948 let param_memory: usize = architecture
949 .layers
950 .iter()
951 .map(|layer| layer.input_dim * layer.output_dim * 4) .sum();
953
954 param_memory as f32 / (1024.0 * 1024.0) }
956
957 fn estimate_inference_time(&self, architecture: &Architecture) -> f32 {
959 let ops_count: usize = architecture
960 .layers
961 .iter()
962 .map(|layer| layer.input_dim * layer.output_dim)
963 .sum();
964
965 ops_count as f32 / 1_000_000.0 }
968
969 fn evolve_population(&mut self) -> Result<()> {
971 self.population.sort_by(|a, b| {
973 b.score
974 .partial_cmp(&a.score)
975 .unwrap_or(std::cmp::Ordering::Equal)
976 });
977
978 let survivors = self.population.len() / 2;
980 self.population.truncate(survivors);
981
982 let mut offspring = Vec::new();
984 for parent in &self.population {
985 let mut child = parent.clone();
986 self.mutate_architecture(&mut child.architecture)?;
987 child.score = 0.0; offspring.push(child);
989 }
990
991 self.population.extend(offspring);
992 Ok(())
993 }
994
995 fn mutate_architecture(&self, architecture: &mut Architecture) -> Result<()> {
997 #[allow(unused_imports)]
998 use scirs2_core::random::{Random, Rng};
999 let mut rng = Random::default();
1000
1001 let mutation_type = rng.random_range(0..4);
1002
1003 match mutation_type {
1004 0 => {
1005 let layer_count = architecture.layers.len();
1007 if layer_count > 0 {
1008 if let Some(layer) = architecture
1009 .layers
1010 .get_mut(rng.random_range(0..layer_count))
1011 {
1012 layer.output_dim = (layer.output_dim as f32
1013 * (0.8 + rng.random_f64() as f32 * 0.4))
1014 as usize;
1015 layer.output_dim = layer.output_dim.clamp(32, 1024);
1016 }
1017 }
1018 }
1019 1 => {
1020 let layer_count = architecture.layers.len();
1022 if layer_count > 0 {
1023 if let Some(layer) = architecture
1024 .layers
1025 .get_mut(rng.random_range(0..layer_count))
1026 {
1027 layer.layer_type = match rng.random_range(0..4) {
1028 0 => LayerType::Linear,
1029 1 => LayerType::Attention,
1030 2 => LayerType::Convolution,
1031 _ => LayerType::Normalization,
1032 };
1033 }
1034 }
1035 }
1036 2 => {
1037 architecture.skip_connections = !architecture.skip_connections;
1039 }
1040 _ => {
1041 architecture.normalization = !architecture.normalization;
1043 }
1044 }
1045
1046 Ok(())
1047 }
1048}
1049
/// Goals and stage toggles for a compression run.
#[derive(Debug, Clone)]
pub struct CompressionTarget {
    /// Desired fractional size reduction (0.0..=1.0).
    pub target_size_reduction: f32,
    /// Desired inference speedup factor.
    pub target_speedup: f32,
    /// Minimum fraction of original accuracy to retain.
    pub maintain_accuracy: f32,
    /// Run the quantization stage.
    pub enable_quantization: bool,
    /// Run the pruning stage.
    pub enable_pruning: bool,
    /// Run the knowledge-distillation stage.
    pub enable_distillation: bool,
    /// Run neural architecture search.
    pub enable_nas: bool,
}
1062
1063impl Default for CompressionTarget {
1064 fn default() -> Self {
1065 Self {
1066 target_size_reduction: 0.5,
1067 target_speedup: 2.0,
1068 maintain_accuracy: 0.95,
1069 enable_quantization: true,
1070 enable_pruning: true,
1071 enable_distillation: false,
1072 enable_nas: false,
1073 }
1074 }
1075}
1076
/// Aggregate metrics collected across compression stages.
#[derive(Debug, Clone, Default)]
pub struct CompressionStats {
    /// Fractional reduction in model size (1 - compressed/original).
    pub size_reduction_ratio: f32,
    /// Absolute memory saved in megabytes.
    pub memory_savings_mb: f32,
    /// Fraction of weights zeroed by pruning.
    pub sparsity_ratio: f32,
    /// Fractional size reduction attributable to quantization.
    pub quantization_ratio: f32,
    /// Final distillation loss (0.0 if distillation was skipped).
    pub distillation_loss: f32,
    /// Measured/estimated inference speedup factor.
    pub inference_speedup: f32,
}
1086
/// Output of a full compression run, retaining all intermediate weight sets.
#[derive(Debug, Clone)]
pub struct CompressedModel {
    /// Untouched input weights.
    pub original_weights: HashMap<String, Array2<f32>>,
    /// Final pipeline output (pruned, and distilled if enabled).
    pub compressed_weights: HashMap<String, Array2<f32>>,
    /// Weights after the quantization stage.
    pub quantized_weights: HashMap<String, Array2<f32>>,
    /// The target/toggles that drove this run.
    pub compression_config: CompressionTarget,
    /// Aggregate metrics for the run.
    pub stats: CompressionStats,
}
1095
/// Output of the quantization stage.
#[derive(Debug, Clone)]
pub struct QuantizationResult {
    /// Quantized (grid-snapped) weights per layer.
    pub quantized_weights: HashMap<String, Array2<f32>>,
    /// Fractional storage reduction achieved.
    pub compression_ratio: f32,
    /// Bit width that was applied.
    pub bit_precision: u8,
    /// Method that was applied.
    pub method: QuantizationMethod,
}
1103
/// Output of the pruning stage.
#[derive(Debug, Clone)]
pub struct PruningResult {
    /// Weights with pruned entries zeroed.
    pub pruned_weights: HashMap<String, Array2<f32>>,
    /// Fraction of weights actually zeroed.
    pub sparsity_achieved: f32,
    /// Method that was applied.
    pub method: PruningMethod,
}
1110
/// Output of the knowledge-distillation stage.
#[derive(Debug, Clone)]
pub struct DistillationResult {
    /// Weights of the distilled student model.
    pub student_weights: HashMap<String, Array2<f32>>,
    /// Distillation loss after the last epoch.
    pub final_loss: f32,
    /// Student-to-teacher size ratio.
    pub compression_ratio: f32,
}
1117
/// Best architecture found by NAS, with its evaluation metrics.
#[derive(Debug, Clone)]
pub struct OptimalArchitecture {
    /// The winning architecture.
    pub architecture: Architecture,
    /// Combined evaluation score.
    pub performance_score: f32,
    /// Estimated memory footprint in MB.
    pub memory_usage: f32,
    /// Estimated inference latency in ms.
    pub inference_time: f32,
}
1125
/// One member of the NAS population.
#[derive(Debug, Clone)]
pub struct ArchitectureCandidate {
    /// Candidate architecture.
    pub architecture: Architecture,
    /// Last evaluated score (0.0 before evaluation).
    pub score: f32,
    /// Cached memory estimate in MB (0.0 until computed).
    pub estimated_memory: f32,
    /// Cached latency estimate in ms (0.0 until computed).
    pub estimated_latency: f32,
}
1133
/// A simple sequential network description used by the NAS search.
#[derive(Debug, Clone)]
pub struct Architecture {
    /// Ordered layer stack.
    pub layers: Vec<LayerConfig>,
    /// Whether residual/skip connections are used.
    pub skip_connections: bool,
    /// Whether normalization layers are used.
    pub normalization: bool,
}
1140
/// Configuration of a single layer in an [`Architecture`].
#[derive(Debug, Clone)]
pub struct LayerConfig {
    /// Kind of layer.
    pub layer_type: LayerType,
    /// Input feature dimension.
    pub input_dim: usize,
    /// Output feature dimension.
    pub output_dim: usize,
    /// Activation applied after the layer.
    pub activation: ActivationType,
}
1148
/// Layer kinds available in the NAS search space.
#[derive(Debug, Clone)]
pub enum LayerType {
    Linear,
    Attention,
    Convolution,
    Normalization,
}
1156
/// Activation functions available in the NAS search space.
#[derive(Debug, Clone)]
pub enum ActivationType {
    ReLU,
    GELU,
    Tanh,
    Sigmoid,
}
1164
#[cfg(test)]
mod tests {
    use super::*;

    // Default quantization config should be 8-bit, per-channel, symmetric.
    #[test]
    fn test_quantization_config_default() {
        let config = QuantizationConfig::default();
        assert_eq!(config.bit_precision, 8);
        assert!(config.per_channel);
        assert!(config.symmetric);
    }

    // Default pruning config: 50% unstructured sparsity, 10 fine-tune epochs.
    #[test]
    fn test_pruning_config_default() {
        let config = PruningConfig::default();
        assert_eq!(config.sparsity_ratio, 0.5);
        assert!(!config.structured);
        assert_eq!(config.fine_tune_epochs, 10);
    }

    // Quantization params from a small tensor must have a positive scale and
    // a consistent min/max range.
    #[test]
    fn test_quantization_processor() {
        let config = QuantizationConfig::default();
        let processor = QuantizationProcessor::new(config);

        let tensor = Array2::from_shape_fn((4, 4), |(i, j)| (i + j) as f32 * 0.1);
        let params = processor.calculate_quantization_params(&tensor).unwrap();

        assert!(params.scale > 0.0);
        assert!(params.min_val <= params.max_val);
    }

    // Large diagonal entries must survive magnitude pruning of a tensor whose
    // off-diagonal values are near zero.
    #[test]
    fn test_pruning_processor() {
        let config = PruningConfig::default();
        let processor = PruningProcessor::new(config);

        let tensor = Array2::from_shape_fn((4, 4), |(i, j)| if i == j { 1.0 } else { 0.01 });
        let mask = processor.generate_pruning_mask(&tensor).unwrap();

        assert!(mask[[0, 0]]);
        assert!(mask[[1, 1]]);
    }

    // End-to-end: the default pipeline (pruning + quantization) must shrink a
    // small two-layer model and keep one output tensor per input layer.
    #[tokio::test]
    async fn test_model_compression_manager() {
        let mut manager = ModelCompressionManager::new();

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Array2::from_shape_fn((8, 8), |(i, j)| (i + j) as f32 * 0.1),
        );
        weights.insert(
            "layer2".to_string(),
            Array2::from_shape_fn((8, 4), |(i, j)| (i as f32 - j as f32) * 0.05),
        );

        let target = CompressionTarget::default();
        let result = manager.compress_model(&weights, target).await.unwrap();

        assert!(result.stats.size_reduction_ratio > 0.0);
        assert!(result.stats.memory_savings_mb >= 0.0);
        assert_eq!(result.compressed_weights.len(), weights.len());
    }
}