// pocket_tts/quantize.rs
//! Quantization support for Pocket TTS
//!
//! This module provides quantization utilities for reduced memory footprint
//! and potentially faster inference on CPU.
//!
//! Note: Candle doesn't natively support int8 tensor operations, so we use
//! a simulated quantization approach that stores quantized values as f32 but
//! represents them using only 256 discrete levels (mimicking int8 range).
//!
//! For true int8 acceleration, use GGML/GGUF format weights with candle-quantized.

use anyhow::Result;
use candle_core::{DType, Tensor};
use std::collections::HashMap;

/// Quantization configuration
///
/// Controls which tensors are quantized (by name and size) and how many
/// discrete levels the simulated quantizer uses.
#[derive(Debug, Clone)]
pub struct QuantizeConfig {
    /// Layers to skip quantization (keep in full precision).
    /// Matched by substring against the weight's name.
    pub skip_layers: Vec<String>,
    /// Minimum tensor size in elements to quantize (smaller tensors stay full precision)
    pub min_size: usize,
    /// Number of quantization levels (256 for int8-like behavior)
    pub num_levels: usize,
}

27impl Default for QuantizeConfig {
28    fn default() -> Self {
29        Self {
30            skip_layers: vec![
31                // Embeddings often benefit from staying in full precision
32                "embed".to_string(),
33                "lut".to_string(),
34                // Final output projections
35                "out_proj".to_string(),
36                "eos_head".to_string(),
37            ],
38            min_size: 1024,  // Don't bother quantizing small tensors
39            num_levels: 256, // int8-like
40        }
41    }
42}
43
/// Quantized tensor wrapper that stores scale for dequantization
///
/// This uses simulated quantization - values are stored as f32 but discretized
/// to num_levels distinct values (256 for int8-equivalent).
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Quantized data (stored as f32 but with discrete values)
    pub data: Tensor,
    /// Scale factor for dequantization
    pub scale: f32,
    /// Zero point for asymmetric quantization (always 0.0 for the symmetric
    /// scheme used by `quantize`)
    pub zero_point: f32,
    /// Number of quantization levels used (0 marks a tensor that was kept
    /// in full precision by `quantize_weights`)
    pub num_levels: usize,
}

60impl QuantizedTensor {
61    /// Quantize a tensor using symmetric per-tensor quantization
62    ///
63    /// This discretizes values to num_levels distinct values centered around 0,
64    /// simulating int8 quantization behavior while using f32 storage.
65    pub fn quantize(tensor: &Tensor, num_levels: usize) -> Result<Self> {
66        // Convert to f32 if needed
67        let tensor_f32 = tensor.to_dtype(DType::F32)?;
68
69        // Find max absolute value for symmetric quantization
70        let abs_max = tensor_f32.abs()?.max_all()?.to_scalar::<f32>()?;
71
72        // Calculate scale (half the range for symmetric)
73        let half_levels = (num_levels / 2) as f32;
74        let scale = if abs_max > 0.0 {
75            abs_max / (half_levels - 1.0)
76        } else {
77            1.0
78        };
79
80        // Quantize: q = round(x / scale), then dequantize back: x' = q * scale
81        // This simulates quantization while staying in f32
82        let scale_tensor = Tensor::new(&[scale], tensor.device())?;
83        let quantized = tensor_f32.broadcast_div(&scale_tensor)?;
84        let quantized = quantized.round()?;
85        let clamped = quantized.clamp(-(half_levels - 1.0) as f64, (half_levels - 1.0) as f64)?;
86        let data = clamped.broadcast_mul(&scale_tensor)?;
87
88        Ok(Self {
89            data,
90            scale,
91            zero_point: 0.0, // Symmetric quantization
92            num_levels,
93        })
94    }
95
96    /// Get the quantized tensor data
97    pub fn data(&self) -> &Tensor {
98        &self.data
99    }
100
101    /// Get the scale value
102    pub fn scale(&self) -> f32 {
103        self.scale
104    }
105
106    /// Get theoretical memory savings ratio compared to f32
107    /// (In practice, data is still stored as f32, but this shows potential savings)
108    pub fn theoretical_memory_savings(&self) -> f32 {
109        match self.num_levels {
110            256 => 4.0,   // int8 would be 4x smaller than f32
111            65536 => 2.0, // int16 would be 2x smaller
112            _ => 1.0,
113        }
114    }
115}
116
117/// Check if a layer name should skip quantization
118fn should_skip_layer(name: &str, config: &QuantizeConfig) -> bool {
119    config.skip_layers.iter().any(|skip| name.contains(skip))
120}
121
122/// Quantize a collection of weights according to config
123///
124/// Returns quantized weights. Layers in skip_layers or smaller than min_size
125/// are returned unchanged.
126pub fn quantize_weights(
127    weights: &HashMap<String, Tensor>,
128    config: &QuantizeConfig,
129) -> Result<HashMap<String, QuantizedTensor>> {
130    let mut quantized = HashMap::new();
131
132    for (name, tensor) in weights {
133        // Skip small tensors and excluded layers
134        if tensor.elem_count() < config.min_size || should_skip_layer(name, config) {
135            // Keep unquantized (scale=1, no discretization)
136            quantized.insert(
137                name.clone(),
138                QuantizedTensor {
139                    data: tensor.clone(),
140                    scale: 1.0,
141                    zero_point: 0.0,
142                    num_levels: 0, // Indicates not actually quantized
143                },
144            );
145        } else {
146            quantized.insert(
147                name.clone(),
148                QuantizedTensor::quantize(tensor, config.num_levels)?,
149            );
150        }
151    }
152
153    Ok(quantized)
154}
155
156/// Calculate signal-to-noise ratio between original and quantized tensors
157pub fn calculate_snr(original: &Tensor, quantized: &Tensor) -> Result<f32> {
158    let original_f32 = original.to_dtype(DType::F32)?;
159    let quantized_f32 = quantized.to_dtype(DType::F32)?;
160
161    // SNR = 10 * log10(signal_power / noise_power)
162    let signal_power = original_f32.sqr()?.mean_all()?.to_scalar::<f32>()?;
163    let noise = (&original_f32 - &quantized_f32)?;
164    let noise_power = noise.sqr()?.mean_all()?.to_scalar::<f32>()?;
165
166    if noise_power <= 0.0 {
167        return Ok(f32::INFINITY); // Perfect reconstruction
168    }
169
170    Ok(10.0 * (signal_power / noise_power).log10())
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_quantize_tensor() {
        let dev = Device::Cpu;
        let original = Tensor::new(&[1.0f32, 2.0, -3.0, 4.5, -2.1], &dev).unwrap();

        let q = QuantizedTensor::quantize(&original, 256).unwrap();

        // int8-like quantization should reconstruct with a healthy SNR.
        let snr = calculate_snr(&original, &q.data).unwrap();
        assert!(snr > 30.0, "SNR {} is too low", snr);
    }

    #[test]
    fn test_quantize_large_tensor() {
        let dev = Device::Cpu;
        // A longer signal with smoothly varying values.
        let samples: Vec<f32> = (0..10000).map(|i| (i as f32 * 0.01).sin() * 10.0).collect();
        let original = Tensor::new(&samples[..], &dev).unwrap();

        let q = QuantizedTensor::quantize(&original, 256).unwrap();
        let snr = calculate_snr(&original, &q.data).unwrap();

        // Varied, well-spread values should quantize cleanly.
        assert!(snr > 30.0, "SNR {} is too low", snr);
    }

    #[test]
    fn test_quantize_config_skip_layers() {
        let config = QuantizeConfig::default();
        assert!(should_skip_layer("model.embed_tokens", &config));
        assert!(should_skip_layer("decoder.out_proj", &config));
        assert!(!should_skip_layer("encoder.layers.0.linear", &config));
    }

    #[test]
    fn test_theoretical_savings() {
        let dev = Device::Cpu;
        let t = Tensor::new(&[1.0f32, 2.0, 3.0], &dev).unwrap();
        let q = QuantizedTensor::quantize(&t, 256).unwrap();
        assert_eq!(q.theoretical_memory_savings(), 4.0);
    }
}