Skip to main content

oxigdal_ml/optimization/
mod.rs

1//! Model optimization techniques for efficient inference
2//!
3//! This module provides various model optimization techniques to reduce model size,
4//! improve inference speed, and reduce memory consumption while maintaining accuracy.
5//!
6//! # Techniques
7//!
8//! - **Quantization**: Reduce precision (FP32 -> INT8/FP16)
9//! - **Pruning**: Remove unnecessary weights and connections
10//! - **Knowledge Distillation**: Transfer knowledge from large to small models
11//! - **Model Compression**: GZIP, Huffman coding, weight sharing
12//!
13//! # Example
14//!
15//! ```no_run
16//! use oxigdal_ml::optimization::quantize_model;
17//! use oxigdal_ml::optimization::{QuantizationConfig, QuantizationType};
18//!
19//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! let config = QuantizationConfig::builder()
21//!     .quantization_type(QuantizationType::Int8)
22//!     .per_channel(true)
23//!     .build();
24//!
25//! quantize_model("model.onnx", "model_quantized.onnx", &config)?;
26//! # Ok(())
27//! # }
28//! ```
29
30pub mod distillation;
31pub mod pruning;
32pub mod quantization;
33
34pub use distillation::{
35    // Neural network components (for testing/extension)
36    DenseLayer,
37    // Core types
38    DistillationConfig,
39    DistillationConfigBuilder,
40    DistillationLoss,
41    DistillationStats,
42    DistillationTrainer,
43    // Optimizer types
44    EarlyStopping,
45    ForwardCache,
46    LearningRateSchedule,
47    MLPGradients,
48    OptimizerType,
49    SimpleMLP,
50    SimpleRng,
51    Temperature,
52    TrainingState,
53    // Core functions
54    cross_entropy_loss,
55    cross_entropy_with_label,
56    kl_divergence,
57    kl_divergence_from_logits,
58    log_softmax,
59    mse_loss,
60    soft_targets,
61    softmax,
62    train_student_model,
63};
64pub use pruning::{
65    // Unstructured pruning types
66    FineTuneCallback,
67    GradientInfo,
68    ImportanceMethod,
69    LotteryTicketState,
70    MaskCreationMode,
71    NoOpFineTune,
72    // Core configuration types
73    PruningConfig,
74    PruningConfigBuilder,
75    PruningGranularity,
76    PruningMask,
77    PruningSchedule,
78    PruningStats,
79    PruningStrategy,
80    UnstructuredPruner,
81    WeightStatistics,
82    WeightTensor,
83    // Helper functions
84    compute_channel_importance,
85    compute_gradient_importance,
86    compute_magnitude_importance,
87    compute_taylor_importance,
88    iterative_pruning,
89    prune_model,
90    prune_weights_direct,
91    prune_weights_with_gradients,
92    select_weights_to_prune,
93    structured_pruning,
94    unstructured_pruning,
95};
96pub use quantization::{
97    QuantizationConfig, QuantizationMode, QuantizationParams, QuantizationResult, QuantizationType,
98    calibrate_quantization, dequantize_tensor, quantize_model, quantize_tensor,
99};
100
101use crate::error::Result;
102use std::path::Path;
103use tracing::info;
104
/// Model optimization statistics.
///
/// Produced by [`OptimizationPipeline::optimize`]. Prefer constructing it with
/// [`OptimizationStats::new`] so that `compression_ratio` is derived consistently from
/// the two sizes.
#[derive(Debug, Clone, PartialEq)]
pub struct OptimizationStats {
    /// Original model size in bytes
    pub original_size: usize,
    /// Optimized model size in bytes
    pub optimized_size: usize,
    /// Compression ratio, `original_size / optimized_size` (`0.0` when the optimized
    /// size is zero, as computed by [`OptimizationStats::new`])
    pub compression_ratio: f32,
    /// Inference speedup factor (original latency divided by optimized latency,
    /// measured or estimated)
    pub speedup: f32,
    /// Accuracy change in percentage points; negative values mean accuracy was lost
    pub accuracy_delta: f32,
}
119
120impl OptimizationStats {
121    /// Creates optimization statistics
122    #[must_use]
123    pub fn new(
124        original_size: usize,
125        optimized_size: usize,
126        speedup: f32,
127        accuracy_delta: f32,
128    ) -> Self {
129        let compression_ratio = if optimized_size > 0 {
130            original_size as f32 / optimized_size as f32
131        } else {
132            0.0
133        };
134
135        Self {
136            original_size,
137            optimized_size,
138            compression_ratio,
139            speedup,
140            accuracy_delta,
141        }
142    }
143
144    /// Returns the size reduction in bytes
145    #[must_use]
146    pub fn size_reduction(&self) -> usize {
147        self.original_size.saturating_sub(self.optimized_size)
148    }
149
150    /// Returns the size reduction as a percentage
151    #[must_use]
152    pub fn size_reduction_percent(&self) -> f32 {
153        if self.original_size > 0 {
154            (self.size_reduction() as f32 / self.original_size as f32) * 100.0
155        } else {
156            0.0
157        }
158    }
159
160    /// Checks if optimization is worthwhile (> 20% size reduction with < 2% accuracy loss)
161    #[must_use]
162    pub fn is_worthwhile(&self) -> bool {
163        self.size_reduction_percent() > 20.0 && self.accuracy_delta.abs() < 2.0
164    }
165}
166
/// Model optimization profile.
///
/// Selects a preset for [`OptimizationPipeline::from_profile`]:
///
/// - `Accuracy`: FP16 quantization, no pruning
/// - `Balanced`: INT8 per-channel quantization, 30% magnitude pruning
/// - `Speed`: INT8 per-channel quantization, 50% structured pruning
/// - `Size`: INT8 per-channel quantization, 70% structured pruning
///
/// `Hash` is derived so profiles can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationProfile {
    /// Maximum accuracy, minimal optimization
    Accuracy,
    /// Balanced accuracy and speed
    Balanced,
    /// Maximum speed, aggressive optimization
    Speed,
    /// Minimum size for edge devices
    Size,
}
179
180/// Combined optimization pipeline
181pub struct OptimizationPipeline {
182    /// Quantization configuration
183    pub quantization: Option<QuantizationConfig>,
184    /// Pruning configuration
185    pub pruning: Option<PruningConfig>,
186    /// Whether to apply weight sharing
187    pub weight_sharing: bool,
188    /// Whether to apply operator fusion
189    pub operator_fusion: bool,
190}
191
192impl OptimizationPipeline {
193    /// Creates an optimization pipeline from a profile
194    #[must_use]
195    pub fn from_profile(profile: OptimizationProfile) -> Self {
196        match profile {
197            OptimizationProfile::Accuracy => Self {
198                quantization: Some(
199                    QuantizationConfig::builder()
200                        .quantization_type(QuantizationType::Float16)
201                        .build(),
202                ),
203                pruning: None,
204                weight_sharing: false,
205                operator_fusion: true,
206            },
207            OptimizationProfile::Balanced => Self {
208                quantization: Some(
209                    QuantizationConfig::builder()
210                        .quantization_type(QuantizationType::Int8)
211                        .per_channel(true)
212                        .build(),
213                ),
214                pruning: Some(
215                    PruningConfig::builder()
216                        .sparsity_target(0.3)
217                        .strategy(PruningStrategy::Magnitude)
218                        .build(),
219                ),
220                weight_sharing: true,
221                operator_fusion: true,
222            },
223            OptimizationProfile::Speed => Self {
224                quantization: Some(
225                    QuantizationConfig::builder()
226                        .quantization_type(QuantizationType::Int8)
227                        .per_channel(true)
228                        .build(),
229                ),
230                pruning: Some(
231                    PruningConfig::builder()
232                        .sparsity_target(0.5)
233                        .strategy(PruningStrategy::Structured)
234                        .build(),
235                ),
236                weight_sharing: true,
237                operator_fusion: true,
238            },
239            OptimizationProfile::Size => Self {
240                quantization: Some(
241                    QuantizationConfig::builder()
242                        .quantization_type(QuantizationType::Int8)
243                        .per_channel(true)
244                        .build(),
245                ),
246                pruning: Some(
247                    PruningConfig::builder()
248                        .sparsity_target(0.7)
249                        .strategy(PruningStrategy::Structured)
250                        .build(),
251                ),
252                weight_sharing: true,
253                operator_fusion: true,
254            },
255        }
256    }
257
258    /// Applies the optimization pipeline to a model
259    ///
260    /// # Errors
261    /// Returns an error if optimization fails
262    pub fn optimize<P: AsRef<std::path::Path>>(
263        &self,
264        input_path: P,
265        output_path: P,
266    ) -> Result<OptimizationStats> {
267        use tracing::info;
268
269        info!("Running optimization pipeline");
270
271        let input = input_path.as_ref();
272        let output = output_path.as_ref();
273
274        // Get original size
275        let original_size = std::fs::metadata(input)
276            .map(|m| m.len() as usize)
277            .unwrap_or(0);
278
279        // Apply optimizations in sequence
280        let mut current_path = input.to_path_buf();
281
282        // 1. Pruning (if configured)
283        if let Some(ref config) = self.pruning {
284            let pruned_path = output.with_extension("pruned.onnx");
285            prune_model(&current_path, &pruned_path, config)?;
286            current_path = pruned_path;
287        }
288
289        // 2. Quantization (if configured)
290        if let Some(ref config) = self.quantization {
291            let quantized_path = output.with_extension("quantized.onnx");
292            quantize_model(&current_path, &quantized_path, config)?;
293            current_path = quantized_path;
294        }
295
296        // 3. Final rename
297        std::fs::rename(&current_path, output)?;
298
299        // Get optimized size
300        let optimized_size = std::fs::metadata(output)
301            .map(|m| m.len() as usize)
302            .unwrap_or(0);
303
304        // Measure actual speedup by benchmarking both models
305        let speedup = Self::measure_speedup(input, output)?;
306
307        // Accuracy measurement would require test dataset
308        // For now, use conservative estimate based on optimization level
309        let accuracy_delta = Self::estimate_accuracy_delta(self);
310
311        Ok(OptimizationStats::new(
312            original_size,
313            optimized_size,
314            speedup,
315            accuracy_delta,
316        ))
317    }
318
319    /// Measures inference speedup between original and optimized model
320    fn measure_speedup(original_path: &Path, optimized_path: &Path) -> Result<f32> {
321        use std::time::Instant;
322
323        // Number of warmup and benchmark iterations
324        const WARMUP_ITERS: usize = 5;
325        const BENCH_ITERS: usize = 20;
326
327        // Check if both models exist
328        if !original_path.exists() || !optimized_path.exists() {
329            info!("Skipping speedup measurement: model files not accessible");
330            return Ok(1.5); // Conservative default estimate
331        }
332
333        // Create dummy input for benchmarking
334        // In production, would use representative dataset
335        let dummy_input = vec![0.0f32; 224 * 224 * 3]; // Typical image size
336        let input_shape = vec![1, 3, 224, 224];
337
338        // Benchmark original model
339        let original_time = match Self::benchmark_model(
340            original_path,
341            &dummy_input,
342            &input_shape,
343            WARMUP_ITERS,
344            BENCH_ITERS,
345        ) {
346            Ok(t) => t,
347            Err(e) => {
348                info!("Could not benchmark original model: {}, using estimate", e);
349                return Ok(1.5);
350            }
351        };
352
353        // Benchmark optimized model
354        let optimized_time = match Self::benchmark_model(
355            optimized_path,
356            &dummy_input,
357            &input_shape,
358            WARMUP_ITERS,
359            BENCH_ITERS,
360        ) {
361            Ok(t) => t,
362            Err(e) => {
363                info!("Could not benchmark optimized model: {}, using estimate", e);
364                return Ok(1.5);
365            }
366        };
367
368        if optimized_time > 0.0 {
369            let speedup = (original_time / optimized_time) as f32;
370            info!(
371                "Measured speedup: {:.2}x (original: {:.2}ms, optimized: {:.2}ms)",
372                speedup,
373                original_time * 1000.0,
374                optimized_time * 1000.0
375            );
376            Ok(speedup)
377        } else {
378            Ok(1.5) // Default if measurement fails
379        }
380    }
381
382    /// Benchmarks a single model
383    fn benchmark_model(
384        model_path: &Path,
385        input: &[f32],
386        input_shape: &[usize],
387        warmup_iters: usize,
388        bench_iters: usize,
389    ) -> Result<f64> {
390        use ndarray::{Array, IxDyn};
391        use ort::session::Session;
392        use ort::value::TensorRef;
393        use std::time::Instant;
394
395        // Load ONNX model
396        let mut session = Session::builder()
397            .map_err(|e| crate::error::ModelError::LoadFailed {
398                reason: format!("Failed to create session builder: {}", e),
399            })?
400            .commit_from_file(model_path)
401            .map_err(|e| crate::error::ModelError::LoadFailed {
402                reason: format!("Failed to load model for benchmarking: {}", e),
403            })?;
404
405        // Get input name
406        let inputs = session.inputs();
407        let input_name = inputs
408            .first()
409            .ok_or_else(|| crate::error::ModelError::LoadFailed {
410                reason: "No input tensors found in model".to_string(),
411            })?
412            .name()
413            .to_string();
414
415        // Create input array from data and shape
416        let array_shape: Vec<usize> = input_shape.to_vec();
417        let total_elements: usize = array_shape.iter().product();
418
419        // Validate input size
420        if input.len() != total_elements {
421            return Err(crate::error::InferenceError::InvalidInputShape {
422                expected: array_shape.clone(),
423                actual: vec![input.len()],
424            }
425            .into());
426        }
427
428        // Create ndarray with dynamic dimensions
429        let input_array =
430            Array::from_shape_vec(IxDyn(&array_shape), input.to_vec()).map_err(|e| {
431                crate::error::InferenceError::Failed {
432                    reason: format!("Failed to create input array: {}", e),
433                }
434            })?;
435
436        // Run warmup iterations
437        for _ in 0..warmup_iters {
438            let input_tensor = TensorRef::from_array_view(input_array.view()).map_err(|e| {
439                crate::error::InferenceError::Failed {
440                    reason: format!("Failed to create input tensor: {}", e),
441                }
442            })?;
443
444            let _ = session
445                .run(ort::inputs![input_name.as_str() => input_tensor])
446                .map_err(|e| crate::error::InferenceError::Failed {
447                    reason: format!("Warmup inference failed: {}", e),
448                })?;
449        }
450
451        // Run benchmark iterations with timing
452        let start = Instant::now();
453        for _ in 0..bench_iters {
454            let input_tensor = TensorRef::from_array_view(input_array.view()).map_err(|e| {
455                crate::error::InferenceError::Failed {
456                    reason: format!("Failed to create input tensor: {}", e),
457                }
458            })?;
459
460            let _ = session
461                .run(ort::inputs![input_name.as_str() => input_tensor])
462                .map_err(|e| crate::error::InferenceError::Failed {
463                    reason: format!("Benchmark inference failed: {}", e),
464                })?;
465        }
466        let elapsed = start.elapsed();
467
468        // Calculate average time per inference in seconds
469        let avg_time = elapsed.as_secs_f64() / bench_iters as f64;
470
471        Ok(avg_time)
472    }
473
474    /// Estimates accuracy delta based on optimization configuration
475    fn estimate_accuracy_delta(&self) -> f32 {
476        let mut delta = 0.0f32;
477
478        // Quantization impact
479        if let Some(ref quant) = self.quantization {
480            delta += match quant.quantization_type {
481                QuantizationType::Float16 => -0.1, // Minimal loss
482                QuantizationType::Int8 => -0.5,    // Small loss
483                QuantizationType::UInt8 => -0.5,   // Similar to Int8
484                QuantizationType::Int4 => -2.0,    // Moderate loss
485            };
486        }
487
488        // Pruning impact
489        if let Some(ref prune) = self.pruning {
490            delta += -prune.sparsity_target * 2.0; // Rough heuristic
491        }
492
493        delta
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimization_stats() {
        let stats = OptimizationStats::new(
            1000000, // 1 MB original
            250000,  // 250 KB optimized
            2.0,     // 2x speedup
            -0.5,    // 0.5% accuracy loss
        );

        assert_eq!(stats.size_reduction(), 750000);
        assert!((stats.size_reduction_percent() - 75.0).abs() < 0.1);
        assert!((stats.compression_ratio - 4.0).abs() < 0.1);
        assert!(stats.is_worthwhile());
    }

    #[test]
    fn test_optimization_stats_edge_cases() {
        // Optimized model larger than the original: reduction saturates at zero.
        let grown = OptimizationStats::new(100, 200, 1.0, 0.0);
        assert_eq!(grown.size_reduction(), 0);
        assert!(grown.size_reduction_percent().abs() < f32::EPSILON);
        assert!(!grown.is_worthwhile());

        // Zero sizes must not divide by zero.
        let empty = OptimizationStats::new(0, 0, 1.0, 0.0);
        assert!(empty.compression_ratio.abs() < f32::EPSILON);
        assert!(empty.size_reduction_percent().abs() < f32::EPSILON);

        // A large accuracy loss disqualifies even a big size win.
        assert!(!OptimizationStats::new(1000, 100, 1.0, -5.0).is_worthwhile());
    }

    #[test]
    fn test_optimization_profile_accuracy() {
        let pipeline = OptimizationPipeline::from_profile(OptimizationProfile::Accuracy);
        assert!(pipeline.quantization.is_some());
        assert!(pipeline.pruning.is_none());
        assert!(pipeline.operator_fusion);
    }

    #[test]
    fn test_optimization_profile_balanced() {
        let pipeline = OptimizationPipeline::from_profile(OptimizationProfile::Balanced);
        assert!(pipeline.quantization.is_some());
        assert!(pipeline.weight_sharing);
        assert!(pipeline.operator_fusion);

        let pruning = pipeline
            .pruning
            .as_ref()
            .expect("balanced profile configures pruning");
        assert!((pruning.sparsity_target - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_optimization_profile_speed() {
        let pipeline = OptimizationPipeline::from_profile(OptimizationProfile::Speed);
        assert!(pipeline.quantization.is_some());
        assert!(pipeline.pruning.is_some());
        assert!(pipeline.weight_sharing);
    }

    #[test]
    fn test_optimization_profile_size() {
        let pipeline = OptimizationPipeline::from_profile(OptimizationProfile::Size);
        assert!(pipeline.quantization.is_some());
        assert!(pipeline.pruning.is_some());

        if let Some(pruning) = &pipeline.pruning {
            // Size profile should have high sparsity
            assert!(pruning.sparsity_target >= 0.6);
        }
    }

    #[test]
    fn test_sparsity_increases_with_aggressiveness() {
        let sparsity = |profile: OptimizationProfile| -> f32 {
            OptimizationPipeline::from_profile(profile)
                .pruning
                .map_or(0.0, |config| config.sparsity_target)
        };

        assert!(sparsity(OptimizationProfile::Accuracy).abs() < f32::EPSILON);
        assert!(sparsity(OptimizationProfile::Balanced) < sparsity(OptimizationProfile::Speed));
        assert!(sparsity(OptimizationProfile::Speed) < sparsity(OptimizationProfile::Size));
    }

    #[test]
    fn test_estimated_accuracy_delta() {
        // FP16 only: -0.1 points.
        let accuracy =
            OptimizationPipeline::from_profile(OptimizationProfile::Accuracy)
                .estimate_accuracy_delta();
        assert!((accuracy + 0.1).abs() < 1e-5);

        // INT8 (-0.5) plus 30% pruning (-0.6): -1.1 points.
        let balanced =
            OptimizationPipeline::from_profile(OptimizationProfile::Balanced)
                .estimate_accuracy_delta();
        assert!((balanced + 1.1).abs() < 1e-5);

        // An empty pipeline is estimated to lose nothing.
        let empty = OptimizationPipeline {
            quantization: None,
            pruning: None,
            weight_sharing: false,
            operator_fusion: false,
        };
        assert!(empty.estimate_accuracy_delta().abs() < f32::EPSILON);
    }
}