1pub mod distillation;
31pub mod pruning;
32pub mod quantization;
33
34pub use distillation::{
35 DenseLayer,
37 DistillationConfig,
39 DistillationConfigBuilder,
40 DistillationLoss,
41 DistillationStats,
42 DistillationTrainer,
43 EarlyStopping,
45 ForwardCache,
46 LearningRateSchedule,
47 MLPGradients,
48 OptimizerType,
49 SimpleMLP,
50 SimpleRng,
51 Temperature,
52 TrainingState,
53 cross_entropy_loss,
55 cross_entropy_with_label,
56 kl_divergence,
57 kl_divergence_from_logits,
58 log_softmax,
59 mse_loss,
60 soft_targets,
61 softmax,
62 train_student_model,
63};
64pub use pruning::{
65 FineTuneCallback,
67 GradientInfo,
68 ImportanceMethod,
69 LotteryTicketState,
70 MaskCreationMode,
71 NoOpFineTune,
72 PruningConfig,
74 PruningConfigBuilder,
75 PruningGranularity,
76 PruningMask,
77 PruningSchedule,
78 PruningStats,
79 PruningStrategy,
80 UnstructuredPruner,
81 WeightStatistics,
82 WeightTensor,
83 compute_channel_importance,
85 compute_gradient_importance,
86 compute_magnitude_importance,
87 compute_taylor_importance,
88 iterative_pruning,
89 prune_model,
90 prune_weights_direct,
91 prune_weights_with_gradients,
92 select_weights_to_prune,
93 structured_pruning,
94 unstructured_pruning,
95};
96pub use quantization::{
97 QuantizationConfig, QuantizationMode, QuantizationParams, QuantizationResult, QuantizationType,
98 calibrate_quantization, dequantize_tensor, quantize_model, quantize_tensor,
99};
100
101use crate::error::Result;
102use std::path::Path;
103use tracing::info;
104
/// Summary of one model-optimization run: sizes before/after plus
/// derived quality metrics.
#[derive(Debug, Clone)]
pub struct OptimizationStats {
    /// Model size in bytes before optimization.
    pub original_size: usize,
    /// Model size in bytes after optimization.
    pub optimized_size: usize,
    /// `original_size / optimized_size`; 0.0 when `optimized_size` is 0.
    pub compression_ratio: f32,
    /// Measured (or estimated) inference speedup factor.
    pub speedup: f32,
    /// Estimated accuracy change; negative means a loss.
    pub accuracy_delta: f32,
}

impl OptimizationStats {
    /// Build stats from raw measurements; `compression_ratio` is derived
    /// from the two sizes.
    #[must_use]
    pub fn new(
        original_size: usize,
        optimized_size: usize,
        speedup: f32,
        accuracy_delta: f32,
    ) -> Self {
        // Guard the division: an empty optimized model yields 0.0 rather
        // than infinity.
        let compression_ratio = match optimized_size {
            0 => 0.0,
            n => original_size as f32 / n as f32,
        };

        Self {
            original_size,
            optimized_size,
            compression_ratio,
            speedup,
            accuracy_delta,
        }
    }

    /// Absolute size reduction in bytes (clamped at zero if the model grew).
    #[must_use]
    pub fn size_reduction(&self) -> usize {
        self.original_size.saturating_sub(self.optimized_size)
    }

    /// Size reduction as a percentage of the original size; 0.0 when the
    /// original size is unknown/zero.
    #[must_use]
    pub fn size_reduction_percent(&self) -> f32 {
        if self.original_size == 0 {
            return 0.0;
        }
        (self.size_reduction() as f32 / self.original_size as f32) * 100.0
    }

    /// Heuristic verdict: the optimization pays off when it saves more than
    /// 20% of the size while shifting accuracy by less than 2 (absolute).
    #[must_use]
    pub fn is_worthwhile(&self) -> bool {
        let big_enough_saving = self.size_reduction_percent() > 20.0;
        let accuracy_preserved = self.accuracy_delta.abs() < 2.0;
        big_enough_saving && accuracy_preserved
    }
}
166
/// High-level optimization presets; each maps to a concrete pipeline
/// configuration via `OptimizationPipeline::from_profile`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationProfile {
    /// Preserve accuracy: light (Float16) quantization only, no pruning.
    Accuracy,
    /// Balance accuracy vs. footprint: Int8 quantization plus moderate
    /// (30%) magnitude pruning.
    Balanced,
    /// Favor inference speed: Int8 quantization plus 50% structured pruning.
    Speed,
    /// Favor minimal footprint: Int8 quantization plus aggressive (70%)
    /// structured pruning.
    Size,
}
179
/// A configurable multi-stage model optimization pipeline.
///
/// Stages run in a fixed order: pruning first, then quantization.
pub struct OptimizationPipeline {
    /// Quantization stage configuration; `None` disables quantization.
    pub quantization: Option<QuantizationConfig>,
    /// Pruning stage configuration; `None` disables pruning.
    pub pruning: Option<PruningConfig>,
    /// Whether weight sharing is requested. NOTE(review): this flag is not
    /// consumed by `optimize` in this module — confirm where it is applied.
    pub weight_sharing: bool,
    /// Whether operator fusion is requested. NOTE(review): likewise not
    /// consumed by `optimize` in this module.
    pub operator_fusion: bool,
}
191
192impl OptimizationPipeline {
193 #[must_use]
195 pub fn from_profile(profile: OptimizationProfile) -> Self {
196 match profile {
197 OptimizationProfile::Accuracy => Self {
198 quantization: Some(
199 QuantizationConfig::builder()
200 .quantization_type(QuantizationType::Float16)
201 .build(),
202 ),
203 pruning: None,
204 weight_sharing: false,
205 operator_fusion: true,
206 },
207 OptimizationProfile::Balanced => Self {
208 quantization: Some(
209 QuantizationConfig::builder()
210 .quantization_type(QuantizationType::Int8)
211 .per_channel(true)
212 .build(),
213 ),
214 pruning: Some(
215 PruningConfig::builder()
216 .sparsity_target(0.3)
217 .strategy(PruningStrategy::Magnitude)
218 .build(),
219 ),
220 weight_sharing: true,
221 operator_fusion: true,
222 },
223 OptimizationProfile::Speed => Self {
224 quantization: Some(
225 QuantizationConfig::builder()
226 .quantization_type(QuantizationType::Int8)
227 .per_channel(true)
228 .build(),
229 ),
230 pruning: Some(
231 PruningConfig::builder()
232 .sparsity_target(0.5)
233 .strategy(PruningStrategy::Structured)
234 .build(),
235 ),
236 weight_sharing: true,
237 operator_fusion: true,
238 },
239 OptimizationProfile::Size => Self {
240 quantization: Some(
241 QuantizationConfig::builder()
242 .quantization_type(QuantizationType::Int8)
243 .per_channel(true)
244 .build(),
245 ),
246 pruning: Some(
247 PruningConfig::builder()
248 .sparsity_target(0.7)
249 .strategy(PruningStrategy::Structured)
250 .build(),
251 ),
252 weight_sharing: true,
253 operator_fusion: true,
254 },
255 }
256 }
257
258 pub fn optimize<P: AsRef<std::path::Path>>(
263 &self,
264 input_path: P,
265 output_path: P,
266 ) -> Result<OptimizationStats> {
267 use tracing::info;
268
269 info!("Running optimization pipeline");
270
271 let input = input_path.as_ref();
272 let output = output_path.as_ref();
273
274 let original_size = std::fs::metadata(input)
276 .map(|m| m.len() as usize)
277 .unwrap_or(0);
278
279 let mut current_path = input.to_path_buf();
281
282 if let Some(ref config) = self.pruning {
284 let pruned_path = output.with_extension("pruned.onnx");
285 prune_model(¤t_path, &pruned_path, config)?;
286 current_path = pruned_path;
287 }
288
289 if let Some(ref config) = self.quantization {
291 let quantized_path = output.with_extension("quantized.onnx");
292 quantize_model(¤t_path, &quantized_path, config)?;
293 current_path = quantized_path;
294 }
295
296 std::fs::rename(¤t_path, output)?;
298
299 let optimized_size = std::fs::metadata(output)
301 .map(|m| m.len() as usize)
302 .unwrap_or(0);
303
304 let speedup = Self::measure_speedup(input, output)?;
306
307 let accuracy_delta = Self::estimate_accuracy_delta(self);
310
311 Ok(OptimizationStats::new(
312 original_size,
313 optimized_size,
314 speedup,
315 accuracy_delta,
316 ))
317 }
318
319 fn measure_speedup(original_path: &Path, optimized_path: &Path) -> Result<f32> {
321 use std::time::Instant;
322
323 const WARMUP_ITERS: usize = 5;
325 const BENCH_ITERS: usize = 20;
326
327 if !original_path.exists() || !optimized_path.exists() {
329 info!("Skipping speedup measurement: model files not accessible");
330 return Ok(1.5); }
332
333 let dummy_input = vec![0.0f32; 224 * 224 * 3]; let input_shape = vec![1, 3, 224, 224];
337
338 let original_time = match Self::benchmark_model(
340 original_path,
341 &dummy_input,
342 &input_shape,
343 WARMUP_ITERS,
344 BENCH_ITERS,
345 ) {
346 Ok(t) => t,
347 Err(e) => {
348 info!("Could not benchmark original model: {}, using estimate", e);
349 return Ok(1.5);
350 }
351 };
352
353 let optimized_time = match Self::benchmark_model(
355 optimized_path,
356 &dummy_input,
357 &input_shape,
358 WARMUP_ITERS,
359 BENCH_ITERS,
360 ) {
361 Ok(t) => t,
362 Err(e) => {
363 info!("Could not benchmark optimized model: {}, using estimate", e);
364 return Ok(1.5);
365 }
366 };
367
368 if optimized_time > 0.0 {
369 let speedup = (original_time / optimized_time) as f32;
370 info!(
371 "Measured speedup: {:.2}x (original: {:.2}ms, optimized: {:.2}ms)",
372 speedup,
373 original_time * 1000.0,
374 optimized_time * 1000.0
375 );
376 Ok(speedup)
377 } else {
378 Ok(1.5) }
380 }
381
    /// Load the ONNX model at `model_path` and measure its average inference
    /// latency in seconds: `warmup_iters` untimed runs followed by
    /// `bench_iters` timed runs, returning total-time / `bench_iters`.
    ///
    /// `input` is fed as a single flat tensor of shape `input_shape`; its
    /// length must equal the product of the shape dimensions.
    ///
    /// # Errors
    ///
    /// Fails if the session cannot be created, the model cannot be loaded,
    /// the model declares no inputs, the input length does not match
    /// `input_shape`, or any inference run fails.
    fn benchmark_model(
        model_path: &Path,
        input: &[f32],
        input_shape: &[usize],
        warmup_iters: usize,
        bench_iters: usize,
    ) -> Result<f64> {
        use ndarray::{Array, IxDyn};
        use ort::session::Session;
        use ort::value::TensorRef;
        use std::time::Instant;

        // NOTE(review): `mut` suggests ort's Session::run requires &mut self
        // in the version pinned here — confirm against the crate docs.
        let mut session = Session::builder()
            .map_err(|e| crate::error::ModelError::LoadFailed {
                reason: format!("Failed to create session builder: {}", e),
            })?
            .commit_from_file(model_path)
            .map_err(|e| crate::error::ModelError::LoadFailed {
                reason: format!("Failed to load model for benchmarking: {}", e),
            })?;

        // Feed the model's first declared input; additional inputs (if any)
        // are not provided.
        let inputs = session.inputs();
        let input_name = inputs
            .first()
            .ok_or_else(|| crate::error::ModelError::LoadFailed {
                reason: "No input tensors found in model".to_string(),
            })?
            .name()
            .to_string();

        let array_shape: Vec<usize> = input_shape.to_vec();
        let total_elements: usize = array_shape.iter().product();

        // Validate the flat buffer against the requested shape before
        // handing it to ndarray.
        if input.len() != total_elements {
            return Err(crate::error::InferenceError::InvalidInputShape {
                expected: array_shape.clone(),
                actual: vec![input.len()],
            }
            .into());
        }

        let input_array =
            Array::from_shape_vec(IxDyn(&array_shape), input.to_vec()).map_err(|e| {
                crate::error::InferenceError::Failed {
                    reason: format!("Failed to create input array: {}", e),
                }
            })?;

        // Untimed warmup runs (executed before the timer starts) so one-time
        // startup costs do not skew the measurement.
        for _ in 0..warmup_iters {
            let input_tensor = TensorRef::from_array_view(input_array.view()).map_err(|e| {
                crate::error::InferenceError::Failed {
                    reason: format!("Failed to create input tensor: {}", e),
                }
            })?;

            let _ = session
                .run(ort::inputs![input_name.as_str() => input_tensor])
                .map_err(|e| crate::error::InferenceError::Failed {
                    reason: format!("Warmup inference failed: {}", e),
                })?;
        }

        // Timed loop. NOTE(review): tensor-view construction happens inside
        // the timed region, so it is included in the reported latency.
        let start = Instant::now();
        for _ in 0..bench_iters {
            let input_tensor = TensorRef::from_array_view(input_array.view()).map_err(|e| {
                crate::error::InferenceError::Failed {
                    reason: format!("Failed to create input tensor: {}", e),
                }
            })?;

            let _ = session
                .run(ort::inputs![input_name.as_str() => input_tensor])
                .map_err(|e| crate::error::InferenceError::Failed {
                    reason: format!("Benchmark inference failed: {}", e),
                })?;
        }
        let elapsed = start.elapsed();

        // Average seconds per inference.
        let avg_time = elapsed.as_secs_f64() / bench_iters as f64;

        Ok(avg_time)
    }
473
474 fn estimate_accuracy_delta(&self) -> f32 {
476 let mut delta = 0.0f32;
477
478 if let Some(ref quant) = self.quantization {
480 delta += match quant.quantization_type {
481 QuantizationType::Float16 => -0.1, QuantizationType::Int8 => -0.5, QuantizationType::UInt8 => -0.5, QuantizationType::Int4 => -2.0, };
486 }
487
488 if let Some(ref prune) = self.pruning {
490 delta += -prune.sparsity_target * 2.0; }
492
493 delta
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimization_stats() {
        // 4:1 compression with a small accuracy loss should count as worthwhile.
        let stats = OptimizationStats::new(1_000_000, 250_000, 2.0, -0.5);

        assert_eq!(stats.size_reduction(), 750_000);
        assert!((stats.size_reduction_percent() - 75.0).abs() < 0.1);
        assert!((stats.compression_ratio - 4.0).abs() < 0.1);
        assert!(stats.is_worthwhile());
    }

    #[test]
    fn test_optimization_profile_accuracy() {
        // Accuracy preset: quantization without pruning, fusion enabled.
        let pipeline = OptimizationPipeline::from_profile(OptimizationProfile::Accuracy);
        assert!(pipeline.quantization.is_some());
        assert!(pipeline.pruning.is_none());
        assert!(pipeline.operator_fusion);
    }

    #[test]
    fn test_optimization_profile_speed() {
        // Speed preset enables every compression mechanism.
        let pipeline = OptimizationPipeline::from_profile(OptimizationProfile::Speed);
        assert!(pipeline.quantization.is_some());
        assert!(pipeline.pruning.is_some());
        assert!(pipeline.weight_sharing);
    }

    #[test]
    fn test_optimization_profile_size() {
        let pipeline = OptimizationPipeline::from_profile(OptimizationProfile::Size);
        assert!(pipeline.quantization.is_some());
        assert!(pipeline.pruning.is_some());

        // Size preset must prune aggressively (unreachable None: asserted above).
        match pipeline.pruning {
            Some(ref pruning) => assert!(pruning.sparsity_target >= 0.6),
            None => unreachable!("pruning presence asserted above"),
        }
    }
}