oxibonsai-model 0.1.1

Qwen3-8B Transformer implementation for OxiBonsai 1-bit inference
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
//! Model compression pipeline: prune → quantize → report.
//!
//! Combines pruning and quantization stages into a unified pipeline that can
//! be configured and applied to a collection of weight tensors. Each stage
//! records statistics that are aggregated into a final `CompressionResult`.
//!
//! # Example
//!
//! ```rust
//! use oxibonsai_model::compression::{CompressionConfig, compress_model};
//! use oxibonsai_model::model_merge::WeightTensor;
//!
//! let tensors = vec![
//!     WeightTensor::new("layer.weight", vec![1.0, -0.5, 0.3, 2.0], vec![2, 2]),
//! ];
//! let config = CompressionConfig::prune_then_quantize(0.5);
//! let result = compress_model(&tensors, &config).expect("compression failed");
//! println!("{}", result.summary());
//! ```

use crate::model_merge::WeightTensor;
use crate::pruning::{prune_tensor, PruningConfig, PruningError};

// ──────────────────────────────────────────────────────────────────
// CompressionStage
// ──────────────────────────────────────────────────────────────────

/// A compression stage to apply in sequence.
#[derive(Debug, Clone)]
pub enum CompressionStage {
    /// Prune weights to target sparsity using the provided configuration.
    Prune(PruningConfig),
    /// Quantize to INT8 (per-tensor, simulate INT8 precision loss while keeping
    /// f32 storage). Memory footprint is reported as `original * 0.25` (theoretical INT8 size).
    QuantizeInt8,
    /// Apply magnitude-based weight clipping: zero out weights whose absolute value
    /// falls below the given percentile of absolute values in each tensor.
    Clip {
        /// Fraction in `(0.0, 1.0]` — e.g. `0.1` zeros the bottom 10% of weights.
        percentile: f32,
    },
}

impl CompressionStage {
    /// Human-readable name for this stage.
    pub fn name(&self) -> &'static str {
        match self {
            CompressionStage::Prune(_) => "prune",
            CompressionStage::QuantizeInt8 => "quantize_int8",
            CompressionStage::Clip { .. } => "clip",
        }
    }
}

// ──────────────────────────────────────────────────────────────────
// CompressionConfig
// ──────────────────────────────────────────────────────────────────

/// Full pipeline configuration: ordered list of stages and global options.
#[derive(Debug, Clone, Default)]
pub struct CompressionConfig {
    /// Ordered sequence of compression stages to apply.
    pub stages: Vec<CompressionStage>,
    /// When `true`, tensors whose name starts with `"embed"` or `"token"` are
    /// skipped (not passed through any compression stage).
    pub skip_embedding_layers: bool,
}

impl CompressionConfig {
    /// Create a new, empty compression config (no stages).
    pub fn new() -> Self {
        Self {
            stages: Vec::new(),
            skip_embedding_layers: false,
        }
    }

    /// Append a stage and return `self` for chaining.
    pub fn add_stage(mut self, stage: CompressionStage) -> Self {
        self.stages.push(stage);
        self
    }

    /// Convenience: L1-unstructured pruning at `sparsity`, followed by INT8 quantization.
    pub fn prune_then_quantize(sparsity: f32) -> Self {
        let prune_cfg = PruningConfig::unstructured_l1(sparsity);
        Self::new()
            .add_stage(CompressionStage::Prune(prune_cfg))
            .add_stage(CompressionStage::QuantizeInt8)
    }

    /// Convenience: INT8 quantization only.
    pub fn quantize_only() -> Self {
        Self::new().add_stage(CompressionStage::QuantizeInt8)
    }

    /// Convenience: L1-unstructured pruning only at `sparsity`.
    pub fn prune_only(sparsity: f32) -> Self {
        let prune_cfg = PruningConfig::unstructured_l1(sparsity);
        Self::new().add_stage(CompressionStage::Prune(prune_cfg))
    }
}

// ──────────────────────────────────────────────────────────────────
// StageStats
// ──────────────────────────────────────────────────────────────────

/// Per-stage compression statistics.
#[derive(Debug, Clone)]
pub struct StageStats {
    /// Name of the stage that produced these stats.
    pub stage_name: String,
    /// Number of tensors that were processed by this stage.
    pub tensors_processed: usize,
    /// Number of tensors skipped (e.g. embedding layers with `skip_embedding_layers`).
    pub tensors_skipped: usize,
    /// Total number of parameters (elements) entering this stage.
    pub params_before: usize,
    /// Number of non-zero parameters after this stage.
    pub nonzero_params_after: usize,
    /// Total memory (bytes) of all processed tensors before this stage.
    pub memory_before_bytes: usize,
    /// Total memory (bytes) of all processed tensors after this stage.
    pub memory_after_bytes: usize,
}

impl StageStats {
    /// Ratio of `memory_before_bytes / memory_after_bytes`. Returns `1.0` if
    /// `memory_after_bytes` is zero to avoid division by zero.
    pub fn compression_ratio(&self) -> f32 {
        if self.memory_after_bytes == 0 {
            return 1.0;
        }
        self.memory_before_bytes as f32 / self.memory_after_bytes as f32
    }

    /// Fraction of parameters that are zero after this stage.
    pub fn sparsity(&self) -> f32 {
        if self.params_before == 0 {
            return 0.0;
        }
        let zeros = self.params_before.saturating_sub(self.nonzero_params_after);
        zeros as f32 / self.params_before as f32
    }
}

// ──────────────────────────────────────────────────────────────────
// CompressionResult
// ──────────────────────────────────────────────────────────────────

/// The outcome of running the full compression pipeline.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// The compressed (and possibly sparsified) tensors in the original order.
    pub compressed_tensors: Vec<WeightTensor>,
    /// One `StageStats` entry per stage in the pipeline.
    pub stage_stats: Vec<StageStats>,
}

impl CompressionResult {
    /// Total number of parameters across all compressed tensors.
    pub fn total_params(&self) -> usize {
        self.compressed_tensors.iter().map(|t| t.data.len()).sum()
    }

    /// Total number of non-zero parameters across all compressed tensors.
    pub fn total_nonzero(&self) -> usize {
        self.compressed_tensors
            .iter()
            .map(|t| t.data.iter().filter(|&&x| x != 0.0).count())
            .sum()
    }

    /// Overall sparsity (fraction of zero weights) across all compressed tensors.
    pub fn overall_sparsity(&self) -> f32 {
        let total = self.total_params();
        if total == 0 {
            return 0.0;
        }
        let nonzero = self.total_nonzero();
        let zeros = total.saturating_sub(nonzero);
        zeros as f32 / total as f32
    }

    /// Compression ratio: `memory_before / memory_after` using the first and last
    /// stage's memory stats. Falls back to `1.0` if there are no stages.
    pub fn total_compression_ratio(&self) -> f32 {
        if self.stage_stats.is_empty() {
            return 1.0;
        }
        let before = self.memory_before_bytes();
        let after = self.memory_after_bytes();
        if after == 0 {
            return 1.0;
        }
        before as f32 / after as f32
    }

    /// Memory (bytes) before any compression: taken from the first stage's
    /// `memory_before_bytes`. Returns `0` if there are no stage stats.
    pub fn memory_before_bytes(&self) -> usize {
        self.stage_stats
            .first()
            .map(|s| s.memory_before_bytes)
            .unwrap_or(0)
    }

    /// Memory (bytes) after all compression: taken from the last stage's
    /// `memory_after_bytes`. Returns `0` if there are no stage stats.
    pub fn memory_after_bytes(&self) -> usize {
        self.stage_stats
            .last()
            .map(|s| s.memory_after_bytes)
            .unwrap_or(0)
    }

    /// Human-readable multi-line summary of all stages and overall statistics.
    pub fn summary(&self) -> String {
        let mut lines: Vec<String> = Vec::new();
        lines.push(format!(
            "=== Compression Summary ({} stage(s)) ===",
            self.stage_stats.len()
        ));
        for (i, stats) in self.stage_stats.iter().enumerate() {
            lines.push(format!(
                "  Stage {}: [{}] processed={} skipped={} sparsity={:.4} ratio={:.3}x \
                 memory={}B->{}B",
                i + 1,
                stats.stage_name,
                stats.tensors_processed,
                stats.tensors_skipped,
                stats.sparsity(),
                stats.compression_ratio(),
                stats.memory_before_bytes,
                stats.memory_after_bytes,
            ));
        }
        lines.push(format!(
            "  Overall: tensors={} total_params={} nonzero={} sparsity={:.4} \
             compression_ratio={:.3}x memory={}B->{}B",
            self.compressed_tensors.len(),
            self.total_params(),
            self.total_nonzero(),
            self.overall_sparsity(),
            self.total_compression_ratio(),
            self.memory_before_bytes(),
            self.memory_after_bytes(),
        ));
        lines.join("\n")
    }
}

// ──────────────────────────────────────────────────────────────────
// CompressionError
// ──────────────────────────────────────────────────────────────────

/// Errors that can arise during the compression pipeline.
#[derive(Debug, thiserror::Error)]
pub enum CompressionError {
    /// Wraps an underlying `PruningError` from the pruning stage.
    #[error("pruning error: {0}")]
    Pruning(#[from] PruningError),

    /// The input tensor slice is empty; there is nothing to compress.
    #[error("empty model: no tensors")]
    EmptyModel,

    /// The pipeline has no stages configured; nothing would be done.
    #[error("empty pipeline: no stages")]
    EmptyPipeline,

    /// The clip percentile is outside the valid range `(0.0, 1.0]`.
    #[error("invalid clip percentile {0}: must be in (0, 1]")]
    InvalidPercentile(f32),
}

// ──────────────────────────────────────────────────────────────────
// Internal helpers
// ──────────────────────────────────────────────────────────────────

/// Return `true` if this tensor should be skipped based on `skip_embedding_layers`.
#[inline]
fn is_embedding_layer(name: &str) -> bool {
    let lower = name.to_ascii_lowercase();
    lower.starts_with("embed") || lower.starts_with("token")
}

/// Bytes occupied by a `WeightTensor` in f32 form.
#[inline]
fn tensor_bytes(tensor: &WeightTensor) -> usize {
    tensor.data.len() * core::mem::size_of::<f32>()
}

/// Count non-zero values in a tensor.
#[inline]
fn count_nonzero(tensor: &WeightTensor) -> usize {
    tensor.data.iter().filter(|&&x| x != 0.0).count()
}

/// Apply the INT8 quantization stage to a single tensor in-place.
///
/// Quantises each tensor per-tensor using `scale = max(|w|) / 127`,
/// then dequantises back to f32 to simulate the precision loss.
/// The `memory_after_bytes` is reported as `memory_before * 0.25` (theoretical INT8 size).
fn apply_quantize_int8_inplace(tensor: &mut WeightTensor) {
    let data = &mut tensor.data;
    if data.is_empty() {
        return;
    }

    // Compute per-tensor max absolute value
    let max_abs = data.iter().map(|w| w.abs()).fold(0.0_f32, f32::max);
    if max_abs == 0.0 {
        return; // all-zero tensor; nothing to do
    }

    let scale = max_abs / 127.0_f32;

    // Quantize → i8 → dequantize back to f32
    for w in data.iter_mut() {
        let q = (*w / scale).round().clamp(-127.0_f32, 127.0_f32) as i8;
        *w = q as f32 * scale;
    }
}

/// Apply the clip stage to a single tensor in-place.
///
/// Computes the `percentile`-th percentile of absolute values and zeros every
/// element whose absolute value is at or below that threshold.
fn apply_clip_inplace(tensor: &mut WeightTensor, percentile: f32) -> Result<(), CompressionError> {
    if percentile <= 0.0 || percentile > 1.0 {
        return Err(CompressionError::InvalidPercentile(percentile));
    }
    let data = &mut tensor.data;
    if data.is_empty() {
        return Ok(());
    }

    // Collect absolute values and sort ascending
    let mut abs_vals: Vec<f32> = data.iter().map(|w| w.abs()).collect();
    abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));

    let n = abs_vals.len();
    // Index of the percentile threshold value (0-based)
    // e.g. percentile=0.1 on 10 elements → idx = ceil(0.1*10) - 1 = 0
    let idx = ((percentile * n as f32).ceil() as usize)
        .saturating_sub(1)
        .min(n - 1);
    let threshold = abs_vals[idx];

    // Zero out elements at or below the threshold
    for w in data.iter_mut() {
        if w.abs() <= threshold {
            *w = 0.0;
        }
    }

    Ok(())
}

/// Compute memory-after for a quantize_int8 stage (theoretical: 0.25 × original).
#[inline]
fn quantize_int8_memory_after(memory_before: usize) -> usize {
    // INT8 is 1 byte vs 4 bytes for f32 → theoretical 4× compression
    (memory_before as f32 * 0.25).round() as usize
}

// ──────────────────────────────────────────────────────────────────
// Public API
// ──────────────────────────────────────────────────────────────────

/// Run the compression pipeline on `tensors` according to `config`.
///
/// Returns a [`CompressionResult`] containing the compressed tensors and
/// per-stage statistics.
///
/// # Errors
///
/// - [`CompressionError::EmptyModel`] if `tensors` is empty.
/// - [`CompressionError::EmptyPipeline`] if `config.stages` is empty.
/// - [`CompressionError::Pruning`] if a pruning stage fails.
/// - [`CompressionError::InvalidPercentile`] if a clip stage has an invalid percentile.
pub fn compress_model(
    tensors: &[WeightTensor],
    config: &CompressionConfig,
) -> Result<CompressionResult, CompressionError> {
    if tensors.is_empty() {
        return Err(CompressionError::EmptyModel);
    }
    if config.stages.is_empty() {
        return Err(CompressionError::EmptyPipeline);
    }

    // Validate clip percentiles before doing any work
    for stage in &config.stages {
        if let CompressionStage::Clip { percentile } = stage {
            if *percentile <= 0.0 || *percentile > 1.0 {
                return Err(CompressionError::InvalidPercentile(*percentile));
            }
        }
    }

    // Working copy — we mutate in-place across stages
    let mut working: Vec<WeightTensor> = tensors.to_vec();
    let mut stage_stats: Vec<StageStats> = Vec::with_capacity(config.stages.len());

    for stage in &config.stages {
        let stage_name = stage.name().to_string();

        let mut tensors_processed = 0usize;
        let mut tensors_skipped = 0usize;
        let mut params_before = 0usize;
        let mut nonzero_after = 0usize;
        let mut memory_before = 0usize;
        let mut memory_after = 0usize;

        for tensor in working.iter_mut() {
            let should_skip = config.skip_embedding_layers && is_embedding_layer(&tensor.name);

            let tb = tensor_bytes(tensor);
            params_before += tensor.data.len();
            memory_before += tb;

            if should_skip {
                tensors_skipped += 1;
                // Skipped tensors are carried through unchanged
                nonzero_after += count_nonzero(tensor);
                memory_after += tb;
                continue;
            }

            tensors_processed += 1;

            match stage {
                CompressionStage::Prune(prune_cfg) => {
                    let (pruned, _mask) = prune_tensor(tensor, prune_cfg)?;
                    *tensor = pruned;
                    nonzero_after += count_nonzero(tensor);
                    memory_after += tensor_bytes(tensor);
                }
                CompressionStage::QuantizeInt8 => {
                    apply_quantize_int8_inplace(tensor);
                    nonzero_after += count_nonzero(tensor);
                    // Theoretical INT8 memory: 0.25 × original f32 bytes
                    memory_after += quantize_int8_memory_after(tb);
                }
                CompressionStage::Clip { percentile } => {
                    // percentile already validated above
                    apply_clip_inplace(tensor, *percentile)?;
                    nonzero_after += count_nonzero(tensor);
                    memory_after += tensor_bytes(tensor);
                }
            }
        }

        stage_stats.push(StageStats {
            stage_name,
            tensors_processed,
            tensors_skipped,
            params_before,
            nonzero_params_after: nonzero_after,
            memory_before_bytes: memory_before,
            memory_after_bytes: memory_after,
        });
    }

    Ok(CompressionResult {
        compressed_tensors: working,
        stage_stats,
    })
}

/// Estimate the compressed size in bytes without actually compressing.
///
/// For `Prune` stages the estimate assumes a fraction of weights will be
/// zeroed (sparsity), but f32 storage is retained (same byte count).
/// For `QuantizeInt8` stages memory is reduced by 4x (theoretical INT8).
/// For `Clip` stages the estimate is the same as `Prune` (storage unchanged).
///
/// Returns `0` if `tensors` is empty or `config.stages` is empty.
pub fn estimate_compressed_size(tensors: &[WeightTensor], config: &CompressionConfig) -> usize {
    if tensors.is_empty() || config.stages.is_empty() {
        return 0;
    }

    // Total f32 bytes of all tensors
    let total_f32_bytes: usize = tensors.iter().map(tensor_bytes).sum();

    // Compute how many bytes are attributable to embedding layers (which are skipped)
    let embedding_bytes: usize = if config.skip_embedding_layers {
        tensors
            .iter()
            .filter(|t| is_embedding_layer(&t.name))
            .map(tensor_bytes)
            .sum()
    } else {
        0
    };
    let compressible_bytes = total_f32_bytes.saturating_sub(embedding_bytes);

    // Apply each stage's size multiplier to the compressible portion
    let mut size = compressible_bytes as f64;
    for stage in &config.stages {
        match stage {
            CompressionStage::Prune(_) => {
                // Pruning zeros weights but doesn't change storage in dense format
                // No change to byte count for dense storage
            }
            CompressionStage::QuantizeInt8 => {
                // Theoretical 4x compression
                size *= 0.25;
            }
            CompressionStage::Clip { .. } => {
                // Same as pruning — storage unchanged
            }
        }
    }

    // Add back embedding bytes unchanged
    embedding_bytes + size.round() as usize
}

// ──────────────────────────────────────────────────────────────────
// In-module smoke tests
// ──────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    fn make_tensor(name: &str, data: Vec<f32>, shape: Vec<usize>) -> WeightTensor {
        WeightTensor::new(name, data, shape)
    }

    fn linear_data(n: usize) -> Vec<f32> {
        (1..=n).map(|i| i as f32).collect()
    }

    #[test]
    fn is_embedding_layer_matches_embed_prefix() {
        assert!(is_embedding_layer("embed.weight"));
        assert!(is_embedding_layer("Embed.weight"));
        assert!(is_embedding_layer("embedding_layer"));
        assert!(is_embedding_layer("token_embedding"));
        assert!(!is_embedding_layer("linear.weight"));
        assert!(!is_embedding_layer("layer_norm"));
    }

    #[test]
    fn apply_quantize_int8_preserves_sign() {
        let mut t = make_tensor("w", vec![1.0, -2.0, 0.5, -0.25], vec![4]);
        apply_quantize_int8_inplace(&mut t);
        assert!(t.data[0] > 0.0);
        assert!(t.data[1] < 0.0);
        assert!(t.data[2] > 0.0);
        assert!(t.data[3] < 0.0);
    }

    #[test]
    fn apply_clip_zeros_small_values() {
        let mut t = make_tensor("w", linear_data(10), vec![10]);
        // Clip bottom 30%: values 1, 2, 3 should be zeroed
        apply_clip_inplace(&mut t, 0.3).expect("clip ok");
        assert_eq!(t.data[0], 0.0);
        assert_eq!(t.data[1], 0.0);
        assert_eq!(t.data[2], 0.0);
        assert!(t.data[9] != 0.0);
    }

    #[test]
    fn apply_clip_invalid_percentile_returns_error() {
        let mut t = make_tensor("w", vec![1.0; 4], vec![4]);
        assert!(apply_clip_inplace(&mut t, 0.0).is_err());
        assert!(apply_clip_inplace(&mut t, 1.1).is_err());
        assert!(apply_clip_inplace(&mut t, -0.5).is_err());
        assert!(apply_clip_inplace(&mut t, 1.0).is_ok()); // 1.0 is valid
    }

    #[test]
    fn stage_stats_compression_ratio_equals_before_over_after() {
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 1,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        let ratio = stats.compression_ratio();
        assert!((ratio - 1.0).abs() < 1e-6);
    }

    #[test]
    fn stage_stats_sparsity_half() {
        let stats = StageStats {
            stage_name: "prune".to_string(),
            tensors_processed: 2,
            tensors_skipped: 0,
            params_before: 100,
            nonzero_params_after: 50,
            memory_before_bytes: 400,
            memory_after_bytes: 400,
        };
        assert!((stats.sparsity() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn compression_result_memory_helpers() {
        let result = CompressionResult {
            compressed_tensors: vec![],
            stage_stats: vec![
                StageStats {
                    stage_name: "prune".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 40,
                },
                StageStats {
                    stage_name: "quantize_int8".to_string(),
                    tensors_processed: 1,
                    tensors_skipped: 0,
                    params_before: 10,
                    nonzero_params_after: 5,
                    memory_before_bytes: 40,
                    memory_after_bytes: 10,
                },
            ],
        };
        assert_eq!(result.memory_before_bytes(), 40);
        assert_eq!(result.memory_after_bytes(), 10);
        assert!((result.total_compression_ratio() - 4.0).abs() < 1e-4);
    }

    #[test]
    fn compress_model_returns_same_tensor_count() {
        let tensors = vec![
            make_tensor("layer1.weight", linear_data(8), vec![2, 4]),
            make_tensor("layer2.weight", linear_data(4), vec![2, 2]),
        ];
        let config = CompressionConfig::quantize_only();
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert_eq!(result.compressed_tensors.len(), 2);
    }

    #[test]
    fn compress_model_prune_reduces_nonzero() {
        let tensors = vec![make_tensor("layer.weight", linear_data(10), vec![10])];
        let config = CompressionConfig::prune_only(0.5);
        let result = compress_model(&tensors, &config).expect("compress ok");
        assert!(result.total_nonzero() < 10);
    }
}