Skip to main content

oxibonsai_model/
compression.rs

1//! Model compression pipeline: prune → quantize → report.
2//!
3//! Combines pruning and quantization stages into a unified pipeline that can
4//! be configured and applied to a collection of weight tensors. Each stage
5//! records statistics that are aggregated into a final `CompressionResult`.
6//!
7//! # Example
8//!
9//! ```rust
10//! use oxibonsai_model::compression::{CompressionConfig, compress_model};
11//! use oxibonsai_model::model_merge::WeightTensor;
12//!
13//! let tensors = vec![
14//!     WeightTensor::new("layer.weight", vec![1.0, -0.5, 0.3, 2.0], vec![2, 2]),
15//! ];
16//! let config = CompressionConfig::prune_then_quantize(0.5);
17//! let result = compress_model(&tensors, &config).expect("compression failed");
18//! println!("{}", result.summary());
19//! ```
20
21use crate::model_merge::WeightTensor;
22use crate::pruning::{prune_tensor, PruningConfig, PruningError};
23
24// ──────────────────────────────────────────────────────────────────
25// CompressionStage
26// ──────────────────────────────────────────────────────────────────
27
/// A single compression stage; stages are applied in the order they appear in
/// `CompressionConfig::stages`.
#[derive(Debug, Clone)]
pub enum CompressionStage {
    /// Prune weights to target sparsity using the provided configuration.
    Prune(PruningConfig),
    /// Quantize to INT8 (per-tensor, simulate INT8 precision loss while keeping
    /// f32 storage). Memory footprint is reported as `original * 0.25` (theoretical INT8 size).
    QuantizeInt8,
    /// Apply magnitude-based weight clipping: zero out weights whose absolute value
    /// falls below the given percentile of absolute values in each tensor.
    Clip {
        /// Fraction in `(0.0, 1.0]` — e.g. `0.1` zeros the bottom 10% of weights.
        /// The threshold element itself is zeroed too (`<=` comparison), so ties
        /// at the threshold can remove slightly more than the requested fraction.
        percentile: f32,
    },
}
43
44impl CompressionStage {
45    /// Human-readable name for this stage.
46    pub fn name(&self) -> &'static str {
47        match self {
48            CompressionStage::Prune(_) => "prune",
49            CompressionStage::QuantizeInt8 => "quantize_int8",
50            CompressionStage::Clip { .. } => "clip",
51        }
52    }
53}
54
55// ──────────────────────────────────────────────────────────────────
56// CompressionConfig
57// ──────────────────────────────────────────────────────────────────
58
/// Full pipeline configuration: ordered list of stages and global options.
///
/// The derived `Default` is equivalent to [`CompressionConfig::new`]: no
/// stages and `skip_embedding_layers = false`.
#[derive(Debug, Clone, Default)]
pub struct CompressionConfig {
    /// Ordered sequence of compression stages to apply.
    pub stages: Vec<CompressionStage>,
    /// When `true`, tensors whose name starts with `"embed"` or `"token"`
    /// (case-insensitive) are skipped (not passed through any compression stage).
    pub skip_embedding_layers: bool,
}
68
69impl CompressionConfig {
70    /// Create a new, empty compression config (no stages).
71    pub fn new() -> Self {
72        Self {
73            stages: Vec::new(),
74            skip_embedding_layers: false,
75        }
76    }
77
78    /// Append a stage and return `self` for chaining.
79    pub fn add_stage(mut self, stage: CompressionStage) -> Self {
80        self.stages.push(stage);
81        self
82    }
83
84    /// Convenience: L1-unstructured pruning at `sparsity`, followed by INT8 quantization.
85    pub fn prune_then_quantize(sparsity: f32) -> Self {
86        let prune_cfg = PruningConfig::unstructured_l1(sparsity);
87        Self::new()
88            .add_stage(CompressionStage::Prune(prune_cfg))
89            .add_stage(CompressionStage::QuantizeInt8)
90    }
91
92    /// Convenience: INT8 quantization only.
93    pub fn quantize_only() -> Self {
94        Self::new().add_stage(CompressionStage::QuantizeInt8)
95    }
96
97    /// Convenience: L1-unstructured pruning only at `sparsity`.
98    pub fn prune_only(sparsity: f32) -> Self {
99        let prune_cfg = PruningConfig::unstructured_l1(sparsity);
100        Self::new().add_stage(CompressionStage::Prune(prune_cfg))
101    }
102}
103
104// ──────────────────────────────────────────────────────────────────
105// StageStats
106// ──────────────────────────────────────────────────────────────────
107
/// Per-stage compression statistics, accumulated over all tensors seen by a stage.
#[derive(Debug, Clone)]
pub struct StageStats {
    /// Name of the stage that produced these stats (see `CompressionStage::name`).
    pub stage_name: String,
    /// Number of tensors that were processed by this stage.
    pub tensors_processed: usize,
    /// Number of tensors skipped (e.g. embedding layers with `skip_embedding_layers`).
    /// Skipped tensors still contribute to the params/memory totals, unchanged.
    pub tensors_skipped: usize,
    /// Total number of parameters (elements) entering this stage.
    pub params_before: usize,
    /// Number of non-zero parameters after this stage.
    pub nonzero_params_after: usize,
    /// Total memory (bytes) of all processed tensors before this stage,
    /// counted at f32 (4 bytes/element) dense storage.
    pub memory_before_bytes: usize,
    /// Total memory (bytes) of all processed tensors after this stage.
    /// For a quantize stage this is the *theoretical* INT8 size
    /// (`0.25 × before`), not the actual f32 storage retained in memory.
    pub memory_after_bytes: usize,
}
126
127impl StageStats {
128    /// Ratio of `memory_before_bytes / memory_after_bytes`. Returns `1.0` if
129    /// `memory_after_bytes` is zero to avoid division by zero.
130    pub fn compression_ratio(&self) -> f32 {
131        if self.memory_after_bytes == 0 {
132            return 1.0;
133        }
134        self.memory_before_bytes as f32 / self.memory_after_bytes as f32
135    }
136
137    /// Fraction of parameters that are zero after this stage.
138    pub fn sparsity(&self) -> f32 {
139        if self.params_before == 0 {
140            return 0.0;
141        }
142        let zeros = self.params_before.saturating_sub(self.nonzero_params_after);
143        zeros as f32 / self.params_before as f32
144    }
145}
146
147// ──────────────────────────────────────────────────────────────────
148// CompressionResult
149// ──────────────────────────────────────────────────────────────────
150
/// The outcome of running the full compression pipeline via `compress_model`.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// The compressed (and possibly sparsified) tensors in the original order.
    /// Storage stays dense f32 even after a quantize stage.
    pub compressed_tensors: Vec<WeightTensor>,
    /// One `StageStats` entry per stage in the pipeline, in execution order.
    pub stage_stats: Vec<StageStats>,
}
159
160impl CompressionResult {
161    /// Total number of parameters across all compressed tensors.
162    pub fn total_params(&self) -> usize {
163        self.compressed_tensors.iter().map(|t| t.data.len()).sum()
164    }
165
166    /// Total number of non-zero parameters across all compressed tensors.
167    pub fn total_nonzero(&self) -> usize {
168        self.compressed_tensors
169            .iter()
170            .map(|t| t.data.iter().filter(|&&x| x != 0.0).count())
171            .sum()
172    }
173
174    /// Overall sparsity (fraction of zero weights) across all compressed tensors.
175    pub fn overall_sparsity(&self) -> f32 {
176        let total = self.total_params();
177        if total == 0 {
178            return 0.0;
179        }
180        let nonzero = self.total_nonzero();
181        let zeros = total.saturating_sub(nonzero);
182        zeros as f32 / total as f32
183    }
184
185    /// Compression ratio: `memory_before / memory_after` using the first and last
186    /// stage's memory stats. Falls back to `1.0` if there are no stages.
187    pub fn total_compression_ratio(&self) -> f32 {
188        if self.stage_stats.is_empty() {
189            return 1.0;
190        }
191        let before = self.memory_before_bytes();
192        let after = self.memory_after_bytes();
193        if after == 0 {
194            return 1.0;
195        }
196        before as f32 / after as f32
197    }
198
199    /// Memory (bytes) before any compression: taken from the first stage's
200    /// `memory_before_bytes`. Returns `0` if there are no stage stats.
201    pub fn memory_before_bytes(&self) -> usize {
202        self.stage_stats
203            .first()
204            .map(|s| s.memory_before_bytes)
205            .unwrap_or(0)
206    }
207
208    /// Memory (bytes) after all compression: taken from the last stage's
209    /// `memory_after_bytes`. Returns `0` if there are no stage stats.
210    pub fn memory_after_bytes(&self) -> usize {
211        self.stage_stats
212            .last()
213            .map(|s| s.memory_after_bytes)
214            .unwrap_or(0)
215    }
216
217    /// Human-readable multi-line summary of all stages and overall statistics.
218    pub fn summary(&self) -> String {
219        let mut lines: Vec<String> = Vec::new();
220        lines.push(format!(
221            "=== Compression Summary ({} stage(s)) ===",
222            self.stage_stats.len()
223        ));
224        for (i, stats) in self.stage_stats.iter().enumerate() {
225            lines.push(format!(
226                "  Stage {}: [{}] processed={} skipped={} sparsity={:.4} ratio={:.3}x \
227                 memory={}B->{}B",
228                i + 1,
229                stats.stage_name,
230                stats.tensors_processed,
231                stats.tensors_skipped,
232                stats.sparsity(),
233                stats.compression_ratio(),
234                stats.memory_before_bytes,
235                stats.memory_after_bytes,
236            ));
237        }
238        lines.push(format!(
239            "  Overall: tensors={} total_params={} nonzero={} sparsity={:.4} \
240             compression_ratio={:.3}x memory={}B->{}B",
241            self.compressed_tensors.len(),
242            self.total_params(),
243            self.total_nonzero(),
244            self.overall_sparsity(),
245            self.total_compression_ratio(),
246            self.memory_before_bytes(),
247            self.memory_after_bytes(),
248        ));
249        lines.join("\n")
250    }
251}
252
253// ──────────────────────────────────────────────────────────────────
254// CompressionError
255// ──────────────────────────────────────────────────────────────────
256
/// Errors that can arise during the compression pipeline.
///
/// `Display` messages come from the `thiserror` `#[error]` attributes below;
/// `PruningError` converts in automatically via `#[from]` (used by the `?` in
/// `compress_model`'s prune stage).
#[derive(Debug, thiserror::Error)]
pub enum CompressionError {
    /// Wraps an underlying `PruningError` from the pruning stage.
    #[error("pruning error: {0}")]
    Pruning(#[from] PruningError),

    /// The input tensor slice is empty; there is nothing to compress.
    #[error("empty model: no tensors")]
    EmptyModel,

    /// The pipeline has no stages configured; nothing would be done.
    #[error("empty pipeline: no stages")]
    EmptyPipeline,

    /// The clip percentile is outside the valid range `(0.0, 1.0]`.
    #[error("invalid clip percentile {0}: must be in (0, 1]")]
    InvalidPercentile(f32),
}
276
277// ──────────────────────────────────────────────────────────────────
278// Internal helpers
279// ──────────────────────────────────────────────────────────────────
280
/// Return `true` if this tensor name denotes an embedding layer (skipped when
/// `CompressionConfig::skip_embedding_layers` is set).
///
/// Matching is case-insensitive on the `"embed"` / `"token"` name prefixes.
#[inline]
fn is_embedding_layer(name: &str) -> bool {
    let lowered = name.to_ascii_lowercase();
    ["embed", "token"].iter().any(|&prefix| lowered.starts_with(prefix))
}
287
288/// Bytes occupied by a `WeightTensor` in f32 form.
289#[inline]
290fn tensor_bytes(tensor: &WeightTensor) -> usize {
291    tensor.data.len() * core::mem::size_of::<f32>()
292}
293
294/// Count non-zero values in a tensor.
295#[inline]
296fn count_nonzero(tensor: &WeightTensor) -> usize {
297    tensor.data.iter().filter(|&&x| x != 0.0).count()
298}
299
300/// Apply the INT8 quantization stage to a single tensor in-place.
301///
302/// Quantises each tensor per-tensor using `scale = max(|w|) / 127`,
303/// then dequantises back to f32 to simulate the precision loss.
304/// The `memory_after_bytes` is reported as `memory_before * 0.25` (theoretical INT8 size).
305fn apply_quantize_int8_inplace(tensor: &mut WeightTensor) {
306    let data = &mut tensor.data;
307    if data.is_empty() {
308        return;
309    }
310
311    // Compute per-tensor max absolute value
312    let max_abs = data.iter().map(|w| w.abs()).fold(0.0_f32, f32::max);
313    if max_abs == 0.0 {
314        return; // all-zero tensor; nothing to do
315    }
316
317    let scale = max_abs / 127.0_f32;
318
319    // Quantize → i8 → dequantize back to f32
320    for w in data.iter_mut() {
321        let q = (*w / scale).round().clamp(-127.0_f32, 127.0_f32) as i8;
322        *w = q as f32 * scale;
323    }
324}
325
326/// Apply the clip stage to a single tensor in-place.
327///
328/// Computes the `percentile`-th percentile of absolute values and zeros every
329/// element whose absolute value is at or below that threshold.
330fn apply_clip_inplace(tensor: &mut WeightTensor, percentile: f32) -> Result<(), CompressionError> {
331    if percentile <= 0.0 || percentile > 1.0 {
332        return Err(CompressionError::InvalidPercentile(percentile));
333    }
334    let data = &mut tensor.data;
335    if data.is_empty() {
336        return Ok(());
337    }
338
339    // Collect absolute values and sort ascending
340    let mut abs_vals: Vec<f32> = data.iter().map(|w| w.abs()).collect();
341    abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
342
343    let n = abs_vals.len();
344    // Index of the percentile threshold value (0-based)
345    // e.g. percentile=0.1 on 10 elements → idx = ceil(0.1*10) - 1 = 0
346    let idx = ((percentile * n as f32).ceil() as usize)
347        .saturating_sub(1)
348        .min(n - 1);
349    let threshold = abs_vals[idx];
350
351    // Zero out elements at or below the threshold
352    for w in data.iter_mut() {
353        if w.abs() <= threshold {
354            *w = 0.0;
355        }
356    }
357
358    Ok(())
359}
360
/// Compute memory-after for a quantize_int8 stage (theoretical: 0.25 × original).
///
/// Uses exact integer arithmetic — `memory_before / 4`, rounded half away from
/// zero — instead of `(memory_before as f32 * 0.25).round()`. The f32 route
/// loses precision for sizes above 2^24 bytes (f32 has a 24-bit mantissa),
/// which real model sizes easily exceed; the integer form agrees with the old
/// result for small inputs and stays exact for all inputs.
#[inline]
fn quantize_int8_memory_after(memory_before: usize) -> usize {
    // INT8 is 1 byte vs 4 bytes for f32 → theoretical 4× compression.
    // `n / 4` plus one when the remainder is ≥ 2 is round-half-away-from-zero,
    // written without `n + 2` to avoid overflow near usize::MAX.
    memory_before / 4 + usize::from(memory_before % 4 >= 2)
}
367
368// ──────────────────────────────────────────────────────────────────
369// Public API
370// ──────────────────────────────────────────────────────────────────
371
/// Run the compression pipeline on `tensors` according to `config`.
///
/// Stages run in order over a working copy of the tensors, mutating it
/// in-place; the input slice is never modified. Per-stage statistics are
/// accumulated as each tensor is visited.
///
/// Returns a [`CompressionResult`] containing the compressed tensors and
/// per-stage statistics.
///
/// # Errors
///
/// - [`CompressionError::EmptyModel`] if `tensors` is empty.
/// - [`CompressionError::EmptyPipeline`] if `config.stages` is empty.
/// - [`CompressionError::Pruning`] if a pruning stage fails.
/// - [`CompressionError::InvalidPercentile`] if a clip stage has an invalid percentile.
pub fn compress_model(
    tensors: &[WeightTensor],
    config: &CompressionConfig,
) -> Result<CompressionResult, CompressionError> {
    if tensors.is_empty() {
        return Err(CompressionError::EmptyModel);
    }
    if config.stages.is_empty() {
        return Err(CompressionError::EmptyPipeline);
    }

    // Validate clip percentiles before doing any work, so we never return an
    // error after having partially mutated the working copy.
    for stage in &config.stages {
        if let CompressionStage::Clip { percentile } = stage {
            if *percentile <= 0.0 || *percentile > 1.0 {
                return Err(CompressionError::InvalidPercentile(*percentile));
            }
        }
    }

    // Working copy — we mutate in-place across stages
    let mut working: Vec<WeightTensor> = tensors.to_vec();
    let mut stage_stats: Vec<StageStats> = Vec::with_capacity(config.stages.len());

    for stage in &config.stages {
        let stage_name = stage.name().to_string();

        // Per-stage accumulators, reset for every stage.
        let mut tensors_processed = 0usize;
        let mut tensors_skipped = 0usize;
        let mut params_before = 0usize;
        let mut nonzero_after = 0usize;
        let mut memory_before = 0usize;
        let mut memory_after = 0usize;

        for tensor in working.iter_mut() {
            let should_skip = config.skip_embedding_layers && is_embedding_layer(&tensor.name);

            // NOTE(review): `tb` is always the dense f32 size. If a stage runs
            // *after* QuantizeInt8, its memory_before reverts to f32 bytes, so
            // the theoretical INT8 saving only shows in the pipeline totals
            // when quantize is the final stage (as in `prune_then_quantize`).
            let tb = tensor_bytes(tensor);
            params_before += tensor.data.len();
            memory_before += tb;

            if should_skip {
                tensors_skipped += 1;
                // Skipped tensors are carried through unchanged
                nonzero_after += count_nonzero(tensor);
                memory_after += tb;
                continue;
            }

            tensors_processed += 1;

            match stage {
                CompressionStage::Prune(prune_cfg) => {
                    // `?` converts PruningError via the #[from] impl on
                    // CompressionError::Pruning. The mask is discarded.
                    let (pruned, _mask) = prune_tensor(tensor, prune_cfg)?;
                    *tensor = pruned;
                    nonzero_after += count_nonzero(tensor);
                    memory_after += tensor_bytes(tensor);
                }
                CompressionStage::QuantizeInt8 => {
                    apply_quantize_int8_inplace(tensor);
                    nonzero_after += count_nonzero(tensor);
                    // Theoretical INT8 memory: 0.25 × original f32 bytes
                    memory_after += quantize_int8_memory_after(tb);
                }
                CompressionStage::Clip { percentile } => {
                    // percentile already validated above
                    apply_clip_inplace(tensor, *percentile)?;
                    nonzero_after += count_nonzero(tensor);
                    memory_after += tensor_bytes(tensor);
                }
            }
        }

        stage_stats.push(StageStats {
            stage_name,
            tensors_processed,
            tensors_skipped,
            params_before,
            nonzero_params_after: nonzero_after,
            memory_before_bytes: memory_before,
            memory_after_bytes: memory_after,
        });
    }

    Ok(CompressionResult {
        compressed_tensors: working,
        stage_stats,
    })
}
472
473/// Estimate the compressed size in bytes without actually compressing.
474///
475/// For `Prune` stages the estimate assumes a fraction of weights will be
476/// zeroed (sparsity), but f32 storage is retained (same byte count).
477/// For `QuantizeInt8` stages memory is reduced by 4x (theoretical INT8).
478/// For `Clip` stages the estimate is the same as `Prune` (storage unchanged).
479///
480/// Returns `0` if `tensors` is empty or `config.stages` is empty.
481pub fn estimate_compressed_size(tensors: &[WeightTensor], config: &CompressionConfig) -> usize {
482    if tensors.is_empty() || config.stages.is_empty() {
483        return 0;
484    }
485
486    // Total f32 bytes of all tensors
487    let total_f32_bytes: usize = tensors.iter().map(tensor_bytes).sum();
488
489    // Compute how many bytes are attributable to embedding layers (which are skipped)
490    let embedding_bytes: usize = if config.skip_embedding_layers {
491        tensors
492            .iter()
493            .filter(|t| is_embedding_layer(&t.name))
494            .map(tensor_bytes)
495            .sum()
496    } else {
497        0
498    };
499    let compressible_bytes = total_f32_bytes.saturating_sub(embedding_bytes);
500
501    // Apply each stage's size multiplier to the compressible portion
502    let mut size = compressible_bytes as f64;
503    for stage in &config.stages {
504        match stage {
505            CompressionStage::Prune(_) => {
506                // Pruning zeros weights but doesn't change storage in dense format
507                // No change to byte count for dense storage
508            }
509            CompressionStage::QuantizeInt8 => {
510                // Theoretical 4x compression
511                size *= 0.25;
512            }
513            CompressionStage::Clip { .. } => {
514                // Same as pruning — storage unchanged
515            }
516        }
517    }
518
519    // Add back embedding bytes unchanged
520    embedding_bytes + size.round() as usize
521}
522
523// ──────────────────────────────────────────────────────────────────
524// In-module smoke tests
525// ──────────────────────────────────────────────────────────────────
526
#[cfg(test)]
mod tests {
    use super::*;

    // Thin wrapper so test bodies read uniformly.
    fn make_tensor(name: &str, data: Vec<f32>, shape: Vec<usize>) -> WeightTensor {
        WeightTensor::new(name, data, shape)
    }

    // Produces [1.0, 2.0, ..., n as f32] — strictly increasing magnitudes,
    // handy for predicting which weights pruning/clipping removes.
    fn linear_data(n: usize) -> Vec<f32> {
        (1..=n).map(|i| i as f32).collect()
    }

    #[test]
    fn is_embedding_layer_matches_embed_prefix() {
        assert!(is_embedding_layer("embed.weight"));
        // Case-insensitive match.
        assert!(is_embedding_layer("Embed.weight"));
        assert!(is_embedding_layer("embedding_layer"));
        assert!(is_embedding_layer("token_embedding"));
        assert!(!is_embedding_layer("linear.weight"));
        assert!(!is_embedding_layer("layer_norm"));
    }

    #[test]
    fn apply_quantize_int8_preserves_sign() {
        let mut t = make_tensor("w", vec![1.0, -2.0, 0.5, -0.25], vec![4]);
        apply_quantize_int8_inplace(&mut t);
        // Quantize→dequantize changes magnitudes slightly but never signs.
        assert!(t.data[0] > 0.0);
        assert!(t.data[1] < 0.0);
        assert!(t.data[2] > 0.0);
        assert!(t.data[3] < 0.0);
    }

    #[test]
    fn apply_clip_zeros_small_values() {
        let mut t = make_tensor("w", linear_data(10), vec![10]);
        // Clip bottom 30%: values 1, 2, 3 should be zeroed
        apply_clip_inplace(&mut t, 0.3).expect("clip ok");
        assert_eq!(t.data[0], 0.0);
        assert_eq!(t.data[1], 0.0);
        assert_eq!(t.data[2], 0.0);
        assert!(t.data[9] != 0.0);
    }

    #[test]
    fn apply_clip_invalid_percentile_returns_error() {
        let mut t = make_tensor("w", vec![1.0; 4], vec![4]);
        // Valid range is (0.0, 1.0]: zero and out-of-range values are rejected.
        assert!(apply_clip_inplace(&mut t, 0.0).is_err());
        assert!(apply_clip_inplace(&mut t, 1.1).is_err());
        assert!(apply_clip_inplace(&mut t, -0.5).is_err());
        assert!(apply_clip_inplace(&mut t, 1.0).is_ok()); // 1.0 is valid
    }

    #[test]
    fn stage_stats_compression_ratio_equals_before_over_after() {
        // Equal before/after memory → ratio of exactly 1.0.
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 1,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        let ratio = stats.compression_ratio();
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    #[test]
    fn stage_stats_sparsity_half() {
        // 50 of 100 params remain non-zero → sparsity 0.5.
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 2,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        assert!((stats.sparsity() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn compression_result_memory_helpers() {
        // Totals come from the FIRST stage's before and the LAST stage's after.
        let result = CompressionResult {
            compressed_tensors: vec![],
            stage_stats: vec![
                StageStats {
                    stage_name: "prune".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 40,
                },
                StageStats {
                    stage_name: "quantize_int8".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 10,
                },
            ],
        };
        assert_eq!(result.memory_before_bytes(), 40);
        assert_eq!(result.memory_after_bytes(), 10);
        assert!((result.total_compression_ratio() - 4.0).abs() < 1e-4);
    }

    #[test]
    fn compress_model_returns_same_tensor_count() {
        let tensors = vec![
            make_tensor("layer1.weight", linear_data(8), vec![2, 4]),
            make_tensor("layer2.weight", linear_data(4), vec![2, 2]),
        ];
        let config = CompressionConfig::quantize_only();
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert_eq!(result.compressed_tensors.len(), 2);
    }

    #[test]
    fn compress_model_prune_reduces_nonzero() {
        let tensors = vec![make_tensor("layer.weight", linear_data(10), vec![10])];
        let config = CompressionConfig::prune_only(0.5);
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert!(result.total_nonzero() < 10);
    }
}