trustformers-core 0.1.1

Core traits and utilities for TrustformeRS
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
//! Mixed-bit quantization implementation for TrustformeRS
//!
//! This module provides mixed-bit quantization where different layers/channels
//! can use different quantization bit widths based on their importance and
//! sensitivity to quantization errors.

#![allow(unused_variables)] // Mixed-bit quantization implementation

use crate::errors::Result;
use crate::quantization::base::QuantizationScheme;
use crate::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Mixed-bit quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedBitConfig {
    /// Layer-specific quantization configurations
    pub layer_configs: HashMap<String, LayerQuantConfig>,
    /// Default configuration for layers not specified
    pub default_config: LayerQuantConfig,
    /// Sensitivity analysis configuration
    pub sensitivity_config: SensitivityConfig,
    /// Automatic bit allocation strategy
    pub auto_bit_allocation: Option<AutoBitAllocationStrategy>,
}

/// Layer-specific quantization configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerQuantConfig {
    /// Bit width for weights
    pub weight_bits: u8,
    /// Bit width for activations
    pub activation_bits: u8,
    /// Quantization scheme to use
    pub scheme: QuantizationScheme,
    /// Whether to use symmetric quantization
    pub symmetric: bool,
    /// Group size for grouped quantization
    pub group_size: Option<usize>,
    /// Channel-specific bit allocation
    pub channel_bits: Option<Vec<u8>>,
}

/// Sensitivity analysis configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SensitivityConfig {
    /// Number of calibration samples
    pub calibration_samples: usize,
    /// Threshold for sensitivity (higher = more sensitive)
    pub sensitivity_threshold: f32,
    /// Metrics to consider for sensitivity analysis
    pub metrics: Vec<SensitivityMetric>,
}

/// Sensitivity metrics for determining quantization bit allocation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SensitivityMetric {
    /// Gradient magnitude
    GradientMagnitude,
    /// Hessian diagonal
    HessianDiagonal,
    /// Activation variance
    ActivationVariance,
    /// Weight variance
    WeightVariance,
    /// Output sensitivity
    OutputSensitivity,
}

/// Automatic bit allocation strategies
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AutoBitAllocationStrategy {
    /// Sensitivity-based allocation
    SensitivityBased {
        /// Target model size compression ratio
        target_compression: f32,
        /// Minimum bits per layer
        min_bits: u8,
        /// Maximum bits per layer
        max_bits: u8,
    },
    /// Uniform allocation with adaptive adjustment
    AdaptiveUniform {
        /// Base bit width
        base_bits: u8,
        /// Adjustment range
        adjustment_range: u8,
    },
    /// Performance-driven allocation
    PerformanceDriven {
        /// Target inference latency (ms)
        target_latency: f32,
        /// Accuracy tolerance
        accuracy_tolerance: f32,
    },
}

/// Mixed-bit quantized tensor
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MixedBitQuantizedTensor {
    /// Layer name
    pub layer_name: String,
    /// Quantized data with different bit widths
    pub quantized_data: Vec<QuantizedBlock>,
    /// Original tensor shape
    pub shape: Vec<usize>,
    /// Quantization configuration used
    pub config: LayerQuantConfig,
    /// Sensitivity scores for each block
    pub sensitivity_scores: Vec<f32>,
}

/// A block of quantized data with specific bit width
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedBlock {
    /// Quantized data
    pub data: Vec<u8>,
    /// Scale factor
    pub scale: f32,
    /// Zero point
    pub zero_point: i32,
    /// Bit width used for this block
    pub bit_width: u8,
    /// Block shape
    pub block_shape: Vec<usize>,
    /// Block offset in the original tensor
    pub block_offset: Vec<usize>,
}

/// Mixed-bit quantizer
pub struct MixedBitQuantizer {
    config: MixedBitConfig,
    sensitivity_analyzer: SensitivityAnalyzer,
}

/// Sensitivity analyzer for determining optimal bit allocations
struct SensitivityAnalyzer {
    config: SensitivityConfig,
    sensitivity_cache: HashMap<String, Vec<f32>>,
}

impl Default for MixedBitConfig {
    fn default() -> Self {
        Self {
            layer_configs: HashMap::new(),
            default_config: LayerQuantConfig::default(),
            sensitivity_config: SensitivityConfig::default(),
            auto_bit_allocation: Some(AutoBitAllocationStrategy::SensitivityBased {
                target_compression: 0.25, // 4x compression
                min_bits: 2,
                max_bits: 8,
            }),
        }
    }
}

impl Default for LayerQuantConfig {
    fn default() -> Self {
        Self {
            weight_bits: 4,
            activation_bits: 8,
            scheme: QuantizationScheme::Int4,
            symmetric: true,
            group_size: Some(128),
            channel_bits: None,
        }
    }
}

impl Default for SensitivityConfig {
    fn default() -> Self {
        Self {
            calibration_samples: 128,
            sensitivity_threshold: 0.01,
            metrics: vec![
                SensitivityMetric::GradientMagnitude,
                SensitivityMetric::ActivationVariance,
                SensitivityMetric::WeightVariance,
            ],
        }
    }
}

impl MixedBitQuantizer {
    /// Create a new mixed-bit quantizer
    pub fn new(config: MixedBitConfig) -> Self {
        let sensitivity_analyzer = SensitivityAnalyzer::new(config.sensitivity_config.clone());
        Self {
            config,
            sensitivity_analyzer,
        }
    }

    /// Quantize a tensor using mixed-bit quantization
    pub fn quantize(
        &mut self,
        tensor: &Tensor,
        layer_name: &str,
    ) -> Result<MixedBitQuantizedTensor> {
        // Get or create layer configuration
        let layer_config = self
            .config
            .layer_configs
            .get(layer_name)
            .cloned()
            .unwrap_or_else(|| self.config.default_config.clone());

        // Analyze sensitivity if needed
        let sensitivity_scores = if let Some(ref auto_strategy) = self.config.auto_bit_allocation {
            self.sensitivity_analyzer
                .analyze_sensitivity(tensor, layer_name, &layer_config)?
        } else {
            vec![1.0; tensor.shape().iter().product()]
        };

        // Allocate bits based on sensitivity
        let bit_allocation = self.allocate_bits(&sensitivity_scores, &layer_config)?;

        // Quantize tensor into blocks
        let quantized_blocks = self.quantize_blocks(tensor, &bit_allocation, &layer_config)?;

        Ok(MixedBitQuantizedTensor {
            layer_name: layer_name.to_string(),
            quantized_data: quantized_blocks,
            shape: tensor.shape(),
            config: layer_config,
            sensitivity_scores,
        })
    }

    /// Allocate bits based on sensitivity scores
    fn allocate_bits(
        &self,
        sensitivity_scores: &[f32],
        config: &LayerQuantConfig,
    ) -> Result<Vec<u8>> {
        let mut bit_allocation = vec![config.weight_bits; sensitivity_scores.len()];

        if let Some(ref strategy) = self.config.auto_bit_allocation {
            match strategy {
                AutoBitAllocationStrategy::SensitivityBased {
                    target_compression,
                    min_bits,
                    max_bits,
                } => {
                    // Sort indices by sensitivity
                    let mut indexed_scores: Vec<(usize, f32)> = sensitivity_scores
                        .iter()
                        .enumerate()
                        .map(|(i, &score)| (i, score))
                        .collect();
                    indexed_scores
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));

                    // Calculate target total bits
                    let total_elements = sensitivity_scores.len();
                    let target_total_bits = (total_elements as f32
                        * config.weight_bits as f32
                        * target_compression) as usize;
                    let mut allocated_bits = 0;

                    // Allocate bits starting from most sensitive
                    for (idx, _) in indexed_scores {
                        let remaining_elements =
                            total_elements - allocated_bits / (*max_bits as usize);
                        let remaining_budget = target_total_bits.saturating_sub(allocated_bits);

                        let avg_bits_remaining =
                            remaining_budget.checked_div(remaining_elements).unwrap_or(0);
                        if avg_bits_remaining > 0 {
                            let bits = (avg_bits_remaining as u8).clamp(*min_bits, *max_bits);
                            bit_allocation[idx] = bits;
                            allocated_bits += bits as usize;
                        }
                    }
                },
                AutoBitAllocationStrategy::AdaptiveUniform {
                    base_bits,
                    adjustment_range,
                } => {
                    // Calculate mean sensitivity
                    let mean_sensitivity =
                        sensitivity_scores.iter().sum::<f32>() / sensitivity_scores.len() as f32;

                    for (i, &score) in sensitivity_scores.iter().enumerate() {
                        let normalized_score = score / mean_sensitivity;
                        let adjustment = (normalized_score * *adjustment_range as f32) as i8;
                        let bits = (*base_bits as i8 + adjustment).clamp(1, 8) as u8;
                        bit_allocation[i] = bits;
                    }
                },
                AutoBitAllocationStrategy::PerformanceDriven {
                    target_latency,
                    accuracy_tolerance,
                } => {
                    // Performance-driven allocation optimizes for latency while maintaining accuracy
                    return self.allocate_bits_performance_driven(
                        sensitivity_scores,
                        config,
                        *target_latency,
                        *accuracy_tolerance,
                    );
                },
            }
        }

        Ok(bit_allocation)
    }

    /// Allocate bits based on sensitivity scores (fallback implementation)
    #[allow(dead_code)]
    fn allocate_bits_sensitivity_based(
        &self,
        sensitivity_scores: &[f32],
        config: &LayerQuantConfig,
    ) -> Result<Vec<u8>> {
        let mut bit_allocation = vec![config.weight_bits; sensitivity_scores.len()];

        // Find sensitivity percentiles
        let mut sorted_scores = sensitivity_scores.to_vec();
        sorted_scores.sort_by(|a, b| a.partial_cmp(b).expect("Partial comparison failed"));

        let high_sensitivity_threshold =
            sorted_scores[(sorted_scores.len() * 90 / 100).min(sorted_scores.len() - 1)];
        let low_sensitivity_threshold = sorted_scores[sorted_scores.len() * 10 / 100];

        for (i, &score) in sensitivity_scores.iter().enumerate() {
            if score >= high_sensitivity_threshold {
                bit_allocation[i] = 8; // High precision for sensitive parts
            } else if score <= low_sensitivity_threshold {
                bit_allocation[i] = 2; // Low precision for insensitive parts
            } else {
                bit_allocation[i] = 4; // Medium precision
            }
        }

        Ok(bit_allocation)
    }

    /// Performance-driven bit allocation optimizing for latency while maintaining accuracy
    fn allocate_bits_performance_driven(
        &self,
        sensitivity_scores: &[f32],
        config: &LayerQuantConfig,
        target_latency: f32,
        accuracy_tolerance: f32,
    ) -> Result<Vec<u8>> {
        let total_elements = sensitivity_scores.len();

        // Model performance characteristics (simplified model)
        // Lower bits = faster computation but potentially lower accuracy
        let performance_factor = |bits: u8| -> f32 {
            match bits {
                1 => 0.1,  // Very fast, very low accuracy impact
                2 => 0.25, // Fast, low accuracy impact
                3 => 0.4,  // Medium-fast, medium accuracy impact
                4 => 0.6,  // Medium, medium accuracy impact
                5 => 0.75, // Medium-slow, higher accuracy
                6 => 0.85, // Slow, high accuracy
                7 => 0.92, // Very slow, very high accuracy
                8 => 1.0,  // Slowest, highest accuracy
                _ => 1.0,
            }
        };

        // Calculate accuracy impact based on sensitivity
        let accuracy_impact = |sensitivity: f32, bits: u8| -> f32 {
            let base_impact = sensitivity / 100.0; // Normalize sensitivity
            let bit_factor = (8.0 - bits as f32) / 7.0; // Higher impact with fewer bits
            base_impact * bit_factor
        };

        // Sort elements by sensitivity to prioritize important layers
        let mut indexed_scores: Vec<(usize, f32)> =
            sensitivity_scores.iter().enumerate().map(|(i, &score)| (i, score)).collect();
        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));

        // Start with lowest bits for maximum performance
        let mut current_bits = vec![2u8; total_elements];
        let mut current_latency = 0.0;
        let mut current_accuracy_loss = 0.0;

        // Calculate initial latency and accuracy
        for (i, &score) in sensitivity_scores.iter().enumerate() {
            current_latency += performance_factor(2);
            current_accuracy_loss += accuracy_impact(score, 2);
        }

        // Iteratively increase bits for most sensitive elements until we hit constraints
        for (idx, sensitivity) in indexed_scores {
            let current_element_bits = current_bits[idx];

            // Try increasing bits for this element
            for new_bits in (current_element_bits + 1)..=8 {
                let latency_change =
                    performance_factor(new_bits) - performance_factor(current_element_bits);
                let accuracy_change = accuracy_impact(sensitivity, current_element_bits)
                    - accuracy_impact(sensitivity, new_bits);

                let new_latency = current_latency + latency_change;
                let new_accuracy_loss = current_accuracy_loss - accuracy_change;

                // Check if this change fits within our constraints
                let normalized_latency = new_latency / total_elements as f32;
                if normalized_latency <= target_latency && new_accuracy_loss <= accuracy_tolerance {
                    // Apply the change
                    current_bits[idx] = new_bits;
                    current_latency = new_latency;
                    current_accuracy_loss = new_accuracy_loss;
                } else {
                    // Can't improve this element further, move to next
                    break;
                }
            }
        }

        // Apply final allocation
        let bit_allocation = current_bits;

        Ok(bit_allocation)
    }

    /// Quantize tensor into blocks with different bit widths
    fn quantize_blocks(
        &self,
        tensor: &Tensor,
        bit_allocation: &[u8],
        config: &LayerQuantConfig,
    ) -> Result<Vec<QuantizedBlock>> {
        let data = tensor.data()?;
        let shape = tensor.shape();
        let mut blocks = Vec::new();

        // Group elements by bit width
        let mut bit_groups: HashMap<u8, Vec<(usize, f32)>> = HashMap::new();
        for (i, (&bits, &value)) in bit_allocation.iter().zip(data.iter()).enumerate() {
            bit_groups.entry(bits).or_default().push((i, value));
        }

        // Quantize each group
        for (bit_width, elements) in bit_groups {
            let values: Vec<f32> = elements.iter().map(|(_, v)| *v).collect();
            let indices: Vec<usize> = elements.iter().map(|(i, _)| *i).collect();

            let (quantized_data, scale, zero_point) =
                self.quantize_group(&values, bit_width, config)?;

            blocks.push(QuantizedBlock {
                data: quantized_data,
                scale,
                zero_point,
                bit_width,
                block_shape: vec![values.len()],
                block_offset: vec![indices[0]], // Simplified for now
            });
        }

        Ok(blocks)
    }

    /// Quantize a group of values with specified bit width
    fn quantize_group(
        &self,
        values: &[f32],
        bit_width: u8,
        config: &LayerQuantConfig,
    ) -> Result<(Vec<u8>, f32, i32)> {
        if values.is_empty() {
            return Ok((Vec::new(), 1.0, 0));
        }

        let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        let qmin = 0;
        let qmax = (1 << bit_width) - 1;

        let (scale, zero_point) = if config.symmetric {
            let max_abs = max_val.abs().max(min_val.abs());
            let scale = max_abs / (qmax as f32 / 2.0);
            (scale, qmax / 2)
        } else {
            let scale = (max_val - min_val) / (qmax - qmin) as f32;
            let zero_point = qmin as f32 - min_val / scale;
            (scale, zero_point.round() as i32)
        };

        let mut quantized = Vec::with_capacity(values.len());
        for &value in values {
            let q_val = (value / scale + zero_point as f32).round() as i32;
            let clamped = q_val.clamp(qmin, qmax) as u8;
            quantized.push(clamped);
        }

        Ok((quantized, scale, zero_point))
    }

    /// Get compression ratio achieved
    pub fn compression_ratio(
        &self,
        original_size: usize,
        quantized_tensor: &MixedBitQuantizedTensor,
    ) -> f32 {
        let compressed_size: usize =
            quantized_tensor.quantized_data.iter().map(|block| block.data.len()).sum();

        original_size as f32 / compressed_size as f32
    }

    /// Estimate memory savings
    pub fn memory_savings(
        &self,
        original_tensor: &Tensor,
        quantized_tensor: &MixedBitQuantizedTensor,
    ) -> f32 {
        let original_bytes = original_tensor.size() * std::mem::size_of::<f32>();
        let quantized_bytes: usize =
            quantized_tensor.quantized_data.iter().map(|block| block.data.len()).sum();

        1.0 - (quantized_bytes as f32 / original_bytes as f32)
    }
}

impl MixedBitQuantizedTensor {
    /// Dequantize back to original tensor
    pub fn dequantize(&self) -> Result<Tensor> {
        let total_elements: usize = self.shape.iter().product();
        let mut result = vec![0.0f32; total_elements];

        for block in &self.quantized_data {
            for (i, &quantized_val) in block.data.iter().enumerate() {
                let dequantized = (quantized_val as i32 - block.zero_point) as f32 * block.scale;
                // Simplified mapping - in practice, would need proper index mapping
                if i < result.len() {
                    result[i] = dequantized;
                }
            }
        }

        Tensor::from_vec(result, &self.shape)
    }

    /// Get average bit width used
    pub fn average_bit_width(&self) -> f32 {
        let total_elements: usize = self.quantized_data.iter().map(|b| b.data.len()).sum();
        if total_elements == 0 {
            return 0.0;
        }

        let total_bits: f32 = self
            .quantized_data
            .iter()
            .map(|block| block.data.len() as f32 * block.bit_width as f32)
            .sum();

        total_bits / total_elements as f32
    }

    /// Get memory footprint in bytes
    pub fn memory_footprint(&self) -> usize {
        self.quantized_data.iter().map(|block| block.data.len()).sum()
    }
}

impl SensitivityAnalyzer {
    fn new(config: SensitivityConfig) -> Self {
        Self {
            config,
            sensitivity_cache: HashMap::new(),
        }
    }

    /// Analyze sensitivity of tensor elements
    fn analyze_sensitivity(
        &mut self,
        tensor: &Tensor,
        layer_name: &str,
        _config: &LayerQuantConfig,
    ) -> Result<Vec<f32>> {
        // Check cache first
        if let Some(cached_scores) = self.sensitivity_cache.get(layer_name) {
            return Ok(cached_scores.clone());
        }

        let data = tensor.data()?;
        let mut sensitivity_scores = vec![0.0; data.len()];

        // Analyze each configured metric
        for metric in &self.config.metrics {
            let metric_scores = self.compute_metric_scores(tensor, metric)?;

            // Combine metrics (simple averaging for now)
            for (i, score) in metric_scores.iter().enumerate() {
                sensitivity_scores[i] += score / self.config.metrics.len() as f32;
            }
        }

        // Cache the results
        self.sensitivity_cache
            .insert(layer_name.to_string(), sensitivity_scores.clone());

        Ok(sensitivity_scores)
    }

    /// Compute sensitivity scores for a specific metric
    fn compute_metric_scores(
        &self,
        tensor: &Tensor,
        metric: &SensitivityMetric,
    ) -> Result<Vec<f32>> {
        let data = tensor.data()?;

        match metric {
            SensitivityMetric::WeightVariance => {
                // Compute local variance as a sensitivity measure
                let mean = data.iter().sum::<f32>() / data.len() as f32;
                let variance: Vec<f32> = data.iter().map(|&x| (x - mean).powi(2)).collect();
                Ok(variance)
            },
            SensitivityMetric::GradientMagnitude => {
                // Approximate gradient magnitude using weight magnitude
                Ok(data.iter().map(|&x| x.abs()).collect())
            },
            SensitivityMetric::ActivationVariance => {
                // For weights, use magnitude as proxy for activation impact
                Ok(data.iter().map(|&x| x.abs()).collect())
            },
            SensitivityMetric::HessianDiagonal => {
                // Simplified hessian approximation
                let hessian_approx: Vec<f32> = data.iter().map(|&x| x.powi(2)).collect();
                Ok(hessian_approx)
            },
            SensitivityMetric::OutputSensitivity => {
                // Use weight magnitude as proxy for output sensitivity
                Ok(data.iter().map(|&x| x.abs()).collect())
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_mixed_bit_quantizer_creation() {
        let config = MixedBitConfig::default();
        let quantizer = MixedBitQuantizer::new(config);
        assert!(quantizer.config.auto_bit_allocation.is_some());
    }

    #[test]
    fn test_mixed_bit_quantization() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[4, 4])?;

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        assert_eq!(quantized.shape, vec![4, 4]);
        assert!(!quantized.quantized_data.is_empty());

        Ok(())
    }

    #[test]
    fn test_mixed_bit_dequantization() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[2, 2])?;

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        let dequantized = quantized.dequantize()?;

        assert_eq!(dequantized.shape(), tensor.shape());
        Ok(())
    }

    #[test]
    fn test_average_bit_width() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[8])?;

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        let avg_bits = quantized.average_bit_width();

        assert!(avg_bits > 0.0);
        assert!(avg_bits <= 8.0);
        Ok(())
    }

    #[test]
    fn test_compression_ratio() -> Result<()> {
        let mut quantizer = MixedBitQuantizer::new(MixedBitConfig::default());
        let tensor = Tensor::randn(&[1024])?; // Use larger tensor to overcome metadata overhead

        let quantized = quantizer.quantize(&tensor, "test_layer")?;
        let ratio = quantizer.compression_ratio(tensor.size(), &quantized);

        assert!(ratio >= 1.0); // Current implementation stores as bytes, so ratio may be 1.0
        Ok(())
    }

    #[test]
    fn test_sensitivity_analysis() -> Result<()> {
        let config = SensitivityConfig::default();
        let mut analyzer = SensitivityAnalyzer::new(config);
        let tensor = Tensor::randn(&[4, 4])?;

        let layer_config = LayerQuantConfig::default();
        let scores = analyzer.analyze_sensitivity(&tensor, "test_layer", &layer_config)?;

        assert_eq!(scores.len(), 16);
        assert!(scores.iter().all(|&score| score >= 0.0));
        Ok(())
    }
}