use crate::model_merge::WeightTensor;
use crate::pruning::{prune_tensor, PruningConfig, PruningError};
23
24#[derive(Debug, Clone)]
30pub enum CompressionStage {
31 Prune(PruningConfig),
33 QuantizeInt8,
36 Clip {
39 percentile: f32,
41 },
42}
43
44impl CompressionStage {
45 pub fn name(&self) -> &'static str {
47 match self {
48 CompressionStage::Prune(_) => "prune",
49 CompressionStage::QuantizeInt8 => "quantize_int8",
50 CompressionStage::Clip { .. } => "clip",
51 }
52 }
53}
54
55#[derive(Debug, Clone, Default)]
61pub struct CompressionConfig {
62 pub stages: Vec<CompressionStage>,
64 pub skip_embedding_layers: bool,
67}
68
69impl CompressionConfig {
70 pub fn new() -> Self {
72 Self {
73 stages: Vec::new(),
74 skip_embedding_layers: false,
75 }
76 }
77
78 pub fn add_stage(mut self, stage: CompressionStage) -> Self {
80 self.stages.push(stage);
81 self
82 }
83
84 pub fn prune_then_quantize(sparsity: f32) -> Self {
86 let prune_cfg = PruningConfig::unstructured_l1(sparsity);
87 Self::new()
88 .add_stage(CompressionStage::Prune(prune_cfg))
89 .add_stage(CompressionStage::QuantizeInt8)
90 }
91
92 pub fn quantize_only() -> Self {
94 Self::new().add_stage(CompressionStage::QuantizeInt8)
95 }
96
97 pub fn prune_only(sparsity: f32) -> Self {
99 let prune_cfg = PruningConfig::unstructured_l1(sparsity);
100 Self::new().add_stage(CompressionStage::Prune(prune_cfg))
101 }
102}
103
/// Per-stage bookkeeping recorded by `compress_model`.
#[derive(Debug, Clone)]
pub struct StageStats {
    /// Stage identifier (see `CompressionStage::name`).
    pub stage_name: String,
    /// Tensors actually transformed by this stage.
    pub tensors_processed: usize,
    /// Tensors passed through untouched (e.g. skipped embedding layers).
    pub tensors_skipped: usize,
    /// Parameter count entering this stage (including skipped tensors).
    pub params_before: usize,
    /// Parameters still nonzero after this stage.
    pub nonzero_params_after: usize,
    /// Estimated byte footprint entering this stage.
    pub memory_before_bytes: usize,
    /// Estimated byte footprint leaving this stage.
    pub memory_after_bytes: usize,
}

impl StageStats {
    /// Ratio of memory before to memory after; returns 1.0 when the
    /// after-size is zero to avoid a division by zero.
    pub fn compression_ratio(&self) -> f32 {
        match self.memory_after_bytes {
            0 => 1.0,
            after => self.memory_before_bytes as f32 / after as f32,
        }
    }

    /// Fraction of parameters that are zero after this stage; 0.0 when
    /// the stage saw no parameters at all.
    pub fn sparsity(&self) -> f32 {
        if self.params_before == 0 {
            0.0
        } else {
            let zero_count = self.params_before.saturating_sub(self.nonzero_params_after);
            zero_count as f32 / self.params_before as f32
        }
    }
}
146
147#[derive(Debug, Clone)]
153pub struct CompressionResult {
154 pub compressed_tensors: Vec<WeightTensor>,
156 pub stage_stats: Vec<StageStats>,
158}
159
160impl CompressionResult {
161 pub fn total_params(&self) -> usize {
163 self.compressed_tensors.iter().map(|t| t.data.len()).sum()
164 }
165
166 pub fn total_nonzero(&self) -> usize {
168 self.compressed_tensors
169 .iter()
170 .map(|t| t.data.iter().filter(|&&x| x != 0.0).count())
171 .sum()
172 }
173
174 pub fn overall_sparsity(&self) -> f32 {
176 let total = self.total_params();
177 if total == 0 {
178 return 0.0;
179 }
180 let nonzero = self.total_nonzero();
181 let zeros = total.saturating_sub(nonzero);
182 zeros as f32 / total as f32
183 }
184
185 pub fn total_compression_ratio(&self) -> f32 {
188 if self.stage_stats.is_empty() {
189 return 1.0;
190 }
191 let before = self.memory_before_bytes();
192 let after = self.memory_after_bytes();
193 if after == 0 {
194 return 1.0;
195 }
196 before as f32 / after as f32
197 }
198
199 pub fn memory_before_bytes(&self) -> usize {
202 self.stage_stats
203 .first()
204 .map(|s| s.memory_before_bytes)
205 .unwrap_or(0)
206 }
207
208 pub fn memory_after_bytes(&self) -> usize {
211 self.stage_stats
212 .last()
213 .map(|s| s.memory_after_bytes)
214 .unwrap_or(0)
215 }
216
217 pub fn summary(&self) -> String {
219 let mut lines: Vec<String> = Vec::new();
220 lines.push(format!(
221 "=== Compression Summary ({} stage(s)) ===",
222 self.stage_stats.len()
223 ));
224 for (i, stats) in self.stage_stats.iter().enumerate() {
225 lines.push(format!(
226 " Stage {}: [{}] processed={} skipped={} sparsity={:.4} ratio={:.3}x \
227 memory={}B->{}B",
228 i + 1,
229 stats.stage_name,
230 stats.tensors_processed,
231 stats.tensors_skipped,
232 stats.sparsity(),
233 stats.compression_ratio(),
234 stats.memory_before_bytes,
235 stats.memory_after_bytes,
236 ));
237 }
238 lines.push(format!(
239 " Overall: tensors={} total_params={} nonzero={} sparsity={:.4} \
240 compression_ratio={:.3}x memory={}B->{}B",
241 self.compressed_tensors.len(),
242 self.total_params(),
243 self.total_nonzero(),
244 self.overall_sparsity(),
245 self.total_compression_ratio(),
246 self.memory_before_bytes(),
247 self.memory_after_bytes(),
248 ));
249 lines.join("\n")
250 }
251}
252
/// Errors surfaced by the compression pipeline.
#[derive(Debug, thiserror::Error)]
pub enum CompressionError {
    /// Propagated from a `Prune` stage's `prune_tensor` call.
    #[error("pruning error: {0}")]
    Pruning(#[from] PruningError),

    /// The input model contained no tensors at all.
    #[error("empty model: no tensors")]
    EmptyModel,

    /// The configuration contained no stages.
    #[error("empty pipeline: no stages")]
    EmptyPipeline,

    /// A `Clip` stage carried a percentile outside (0, 1].
    #[error("invalid clip percentile {0}: must be in (0, 1]")]
    InvalidPercentile(f32),
}
276
/// Heuristic: does `name` look like an embedding/token layer?
///
/// Matches names beginning (ASCII-case-insensitively) with "embed" or
/// "token". Unlike the previous `to_ascii_lowercase()` approach this
/// allocates nothing, which matters because it runs once per tensor per
/// stage. Note "token" also matches e.g. "tokenizer"-prefixed names,
/// exactly as before.
#[inline]
fn is_embedding_layer(name: &str) -> bool {
    // ASCII-case-insensitive prefix test without allocating a String.
    // Byte-slice indexing cannot panic on UTF-8 boundaries.
    fn has_prefix(name: &str, prefix: &[u8]) -> bool {
        name.as_bytes()
            .get(..prefix.len())
            .map_or(false, |head| head.eq_ignore_ascii_case(prefix))
    }
    has_prefix(name, b"embed") || has_prefix(name, b"token")
}
287
288#[inline]
290fn tensor_bytes(tensor: &WeightTensor) -> usize {
291 tensor.data.len() * core::mem::size_of::<f32>()
292}
293
294#[inline]
296fn count_nonzero(tensor: &WeightTensor) -> usize {
297 tensor.data.iter().filter(|&&x| x != 0.0).count()
298}
299
300fn apply_quantize_int8_inplace(tensor: &mut WeightTensor) {
306 let data = &mut tensor.data;
307 if data.is_empty() {
308 return;
309 }
310
311 let max_abs = data.iter().map(|w| w.abs()).fold(0.0_f32, f32::max);
313 if max_abs == 0.0 {
314 return; }
316
317 let scale = max_abs / 127.0_f32;
318
319 for w in data.iter_mut() {
321 let q = (*w / scale).round().clamp(-127.0_f32, 127.0_f32) as i8;
322 *w = q as f32 * scale;
323 }
324}
325
326fn apply_clip_inplace(tensor: &mut WeightTensor, percentile: f32) -> Result<(), CompressionError> {
331 if percentile <= 0.0 || percentile > 1.0 {
332 return Err(CompressionError::InvalidPercentile(percentile));
333 }
334 let data = &mut tensor.data;
335 if data.is_empty() {
336 return Ok(());
337 }
338
339 let mut abs_vals: Vec<f32> = data.iter().map(|w| w.abs()).collect();
341 abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
342
343 let n = abs_vals.len();
344 let idx = ((percentile * n as f32).ceil() as usize)
347 .saturating_sub(1)
348 .min(n - 1);
349 let threshold = abs_vals[idx];
350
351 for w in data.iter_mut() {
353 if w.abs() <= threshold {
354 *w = 0.0;
355 }
356 }
357
358 Ok(())
359}
360
/// Estimated byte footprint after int8 quantization: one quarter of the
/// f32 footprint, rounded to nearest (ties rounding up).
///
/// Uses integer arithmetic instead of the previous
/// `(memory_before as f32 * 0.25).round()`, which silently loses
/// precision once `memory_before` exceeds f32's 24-bit mantissa
/// (~16 MiB) and could report a wrong size for large models. For real
/// inputs (len * 4 bytes) the division is exact anyway.
#[inline]
fn quantize_int8_memory_after(memory_before: usize) -> usize {
    // Overflow-safe round-half-up of memory_before / 4.
    memory_before / 4 + usize::from(memory_before % 4 >= 2)
}
367
368pub fn compress_model(
384 tensors: &[WeightTensor],
385 config: &CompressionConfig,
386) -> Result<CompressionResult, CompressionError> {
387 if tensors.is_empty() {
388 return Err(CompressionError::EmptyModel);
389 }
390 if config.stages.is_empty() {
391 return Err(CompressionError::EmptyPipeline);
392 }
393
394 for stage in &config.stages {
396 if let CompressionStage::Clip { percentile } = stage {
397 if *percentile <= 0.0 || *percentile > 1.0 {
398 return Err(CompressionError::InvalidPercentile(*percentile));
399 }
400 }
401 }
402
403 let mut working: Vec<WeightTensor> = tensors.to_vec();
405 let mut stage_stats: Vec<StageStats> = Vec::with_capacity(config.stages.len());
406
407 for stage in &config.stages {
408 let stage_name = stage.name().to_string();
409
410 let mut tensors_processed = 0usize;
411 let mut tensors_skipped = 0usize;
412 let mut params_before = 0usize;
413 let mut nonzero_after = 0usize;
414 let mut memory_before = 0usize;
415 let mut memory_after = 0usize;
416
417 for tensor in working.iter_mut() {
418 let should_skip = config.skip_embedding_layers && is_embedding_layer(&tensor.name);
419
420 let tb = tensor_bytes(tensor);
421 params_before += tensor.data.len();
422 memory_before += tb;
423
424 if should_skip {
425 tensors_skipped += 1;
426 nonzero_after += count_nonzero(tensor);
428 memory_after += tb;
429 continue;
430 }
431
432 tensors_processed += 1;
433
434 match stage {
435 CompressionStage::Prune(prune_cfg) => {
436 let (pruned, _mask) = prune_tensor(tensor, prune_cfg)?;
437 *tensor = pruned;
438 nonzero_after += count_nonzero(tensor);
439 memory_after += tensor_bytes(tensor);
440 }
441 CompressionStage::QuantizeInt8 => {
442 apply_quantize_int8_inplace(tensor);
443 nonzero_after += count_nonzero(tensor);
444 memory_after += quantize_int8_memory_after(tb);
446 }
447 CompressionStage::Clip { percentile } => {
448 apply_clip_inplace(tensor, *percentile)?;
450 nonzero_after += count_nonzero(tensor);
451 memory_after += tensor_bytes(tensor);
452 }
453 }
454 }
455
456 stage_stats.push(StageStats {
457 stage_name,
458 tensors_processed,
459 tensors_skipped,
460 params_before,
461 nonzero_params_after: nonzero_after,
462 memory_before_bytes: memory_before,
463 memory_after_bytes: memory_after,
464 });
465 }
466
467 Ok(CompressionResult {
468 compressed_tensors: working,
469 stage_stats,
470 })
471}
472
473pub fn estimate_compressed_size(tensors: &[WeightTensor], config: &CompressionConfig) -> usize {
482 if tensors.is_empty() || config.stages.is_empty() {
483 return 0;
484 }
485
486 let total_f32_bytes: usize = tensors.iter().map(tensor_bytes).sum();
488
489 let embedding_bytes: usize = if config.skip_embedding_layers {
491 tensors
492 .iter()
493 .filter(|t| is_embedding_layer(&t.name))
494 .map(tensor_bytes)
495 .sum()
496 } else {
497 0
498 };
499 let compressible_bytes = total_f32_bytes.saturating_sub(embedding_bytes);
500
501 let mut size = compressible_bytes as f64;
503 for stage in &config.stages {
504 match stage {
505 CompressionStage::Prune(_) => {
506 }
509 CompressionStage::QuantizeInt8 => {
510 size *= 0.25;
512 }
513 CompressionStage::Clip { .. } => {
514 }
516 }
517 }
518
519 embedding_bytes + size.round() as usize
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    // Thin wrapper over WeightTensor::new for terser fixtures.
    fn make_tensor(name: &str, data: Vec<f32>, shape: Vec<usize>) -> WeightTensor {
        WeightTensor::new(name, data, shape)
    }

    // 1.0, 2.0, ..., n as f32 — strictly increasing magnitudes, handy
    // for percentile/pruning tests with predictable thresholds.
    fn linear_data(n: usize) -> Vec<f32> {
        (1..=n).map(|i| i as f32).collect()
    }

    #[test]
    fn is_embedding_layer_matches_embed_prefix() {
        // Match is case-insensitive and prefix-based ("embed"/"token").
        assert!(is_embedding_layer("embed.weight"));
        assert!(is_embedding_layer("Embed.weight"));
        assert!(is_embedding_layer("embedding_layer"));
        assert!(is_embedding_layer("token_embedding"));
        assert!(!is_embedding_layer("linear.weight"));
        assert!(!is_embedding_layer("layer_norm"));
    }

    #[test]
    fn apply_quantize_int8_preserves_sign() {
        // Symmetric quantization must never flip a weight's sign.
        let mut t = make_tensor("w", vec![1.0, -2.0, 0.5, -0.25], vec![4]);
        apply_quantize_int8_inplace(&mut t);
        assert!(t.data[0] > 0.0);
        assert!(t.data[1] < 0.0);
        assert!(t.data[2] > 0.0);
        assert!(t.data[3] < 0.0);
    }

    #[test]
    fn apply_clip_zeros_small_values() {
        // 30th percentile of 1..=10 -> threshold 3.0; values <= 3 zeroed.
        let mut t = make_tensor("w", linear_data(10), vec![10]);
        apply_clip_inplace(&mut t, 0.3).expect("clip ok");
        assert_eq!(t.data[0], 0.0);
        assert_eq!(t.data[1], 0.0);
        assert_eq!(t.data[2], 0.0);
        assert!(t.data[9] != 0.0);
    }

    #[test]
    fn apply_clip_invalid_percentile_returns_error() {
        // Valid range is (0, 1]: 0.0 and out-of-range values are rejected,
        // exactly 1.0 is accepted.
        let mut t = make_tensor("w", vec![1.0; 4], vec![4]);
        assert!(apply_clip_inplace(&mut t, 0.0).is_err());
        assert!(apply_clip_inplace(&mut t, 1.1).is_err());
        assert!(apply_clip_inplace(&mut t, -0.5).is_err());
        assert!(apply_clip_inplace(&mut t, 1.0).is_ok());
    }

    #[test]
    fn stage_stats_compression_ratio_equals_before_over_after() {
        // Equal before/after memory -> ratio of exactly 1.0.
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 1,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        let ratio = stats.compression_ratio();
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    #[test]
    fn stage_stats_sparsity_half() {
        // 50 of 100 params nonzero -> sparsity 0.5.
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 2,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        assert!((stats.sparsity() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn compression_result_memory_helpers() {
        // before comes from the FIRST stage, after from the LAST stage;
        // the overall ratio spans the whole pipeline (40 -> 10 = 4x).
        let result = CompressionResult {
            compressed_tensors: vec![],
            stage_stats: vec![
                StageStats {
                    stage_name: "prune".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 40,
                },
                StageStats {
                    stage_name: "quantize_int8".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 10,
                },
            ],
        };
        assert_eq!(result.memory_before_bytes(), 40);
        assert_eq!(result.memory_after_bytes(), 10);
        assert!((result.total_compression_ratio() - 4.0).abs() < 1e-4);
    }

    #[test]
    fn compress_model_returns_same_tensor_count() {
        // The pipeline transforms tensors in place; it never drops any.
        let tensors = vec![
            make_tensor("layer1.weight", linear_data(8), vec![2, 4]),
            make_tensor("layer2.weight", linear_data(4), vec![2, 2]),
        ];
        let config = CompressionConfig::quantize_only();
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert_eq!(result.compressed_tensors.len(), 2);
    }

    #[test]
    fn compress_model_prune_reduces_nonzero() {
        // 50% unstructured pruning must zero at least one of 10 weights.
        let tensors = vec![make_tensor("layer.weight", linear_data(10), vec![10])];
        let config = CompressionConfig::prune_only(0.5);
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert!(result.total_nonzero() < 10);
    }
}