// Source file: oxirs_chat/memory_optimization/tensor_ops.rs

1//! Memory-efficient tensor operations
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5
6/// Memory-efficient tensor representation
/// Memory-efficient tensor representation.
///
/// Each variant trades precision or generality for a smaller footprint;
/// `to_f32` recovers a dense f32 vector from any of them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MemoryEfficientTensor {
    /// Full precision (f32) — the baseline, no memory savings.
    F32(Vec<f32>),

    /// Half precision (f16 bit patterns stored in u16) - 50% memory reduction.
    F16(Vec<u16>),

    /// Quantized (i8) - 75% memory reduction.
    /// Values are affine-quantized; see `TensorOptimizer::quantize_i8`.
    I8 {
        // Quantized values, stored offset by -128 from the 0..=255 level index.
        data: Vec<i8>,
        // Step size between adjacent quantization levels.
        scale: f32,
        zero_point: f32, // Stores min value for dequantization
    },

    /// Sparse (only non-zero values)
    Sparse {
        // Positions of the retained values in the dense tensor.
        indices: Vec<usize>,
        // The retained non-zero values, parallel to `indices`.
        values: Vec<f32>,
        // Length of the original dense tensor.
        size: usize,
    },
}
29
30impl MemoryEfficientTensor {
31    /// Convert to f32 for computation
32    pub fn to_f32(&self) -> Vec<f32> {
33        match self {
34            Self::F32(data) => data.clone(),
35            Self::F16(data) => data.iter().map(|x| f16_to_f32(*x)).collect(),
36            Self::I8 {
37                data,
38                scale,
39                zero_point,
40            } => data
41                .iter()
42                .map(|x| (*x as f32 + 128.0) * scale + zero_point)
43                .collect(),
44            Self::Sparse {
45                indices,
46                values,
47                size,
48            } => {
49                let mut result = vec![0.0f32; *size];
50                for (idx, val) in indices.iter().zip(values.iter()) {
51                    result[*idx] = *val;
52                }
53                result
54            }
55        }
56    }
57
58    /// Get memory size in bytes
59    pub fn memory_size(&self) -> usize {
60        match self {
61            Self::F32(data) => data.len() * 4,
62            Self::F16(data) => data.len() * 2,
63            Self::I8 { data, .. } => data.len() + 8, // data + scale (f32) + zero_point (f32)
64            Self::Sparse {
65                indices, values, ..
66            } => indices.len() * 8 + values.len() * 4 + 8,
67        }
68    }
69}
70
71/// Tensor optimizer for memory efficiency
/// Tensor optimizer for memory efficiency.
///
/// Chooses the most compact `MemoryEfficientTensor` representation based on
/// the configured options and the tensor's measured sparsity.
pub struct TensorOptimizer {
    // When true, dense tensors are stored as f16 (unless quantization wins).
    use_low_precision: bool,
    // When true, dense tensors are quantized to i8 (takes precedence over f16).
    quantization_enabled: bool,
    // Sparsity ratio (0.0..=1.0) at or above which sparse storage is used.
    sparse_threshold: f32,
}
77
78impl TensorOptimizer {
79    pub fn new(use_low_precision: bool) -> Self {
80        Self {
81            use_low_precision,
82            quantization_enabled: false,
83            sparse_threshold: 0.5, // 50% sparsity threshold
84        }
85    }
86
87    /// Enable quantization for even more memory savings
88    pub fn with_quantization(mut self) -> Self {
89        self.quantization_enabled = true;
90        self
91    }
92
93    /// Set sparsity threshold (0.0 to 1.0)
94    pub fn with_sparse_threshold(mut self, threshold: f32) -> Self {
95        self.sparse_threshold = threshold;
96        self
97    }
98
99    /// Optimize tensor representation
100    pub fn optimize(&self, tensor: &[f32]) -> Result<MemoryEfficientTensor> {
101        // Check for sparsity
102        let sparsity = calculate_sparsity(tensor);
103        if sparsity >= self.sparse_threshold {
104            return Ok(self.to_sparse(tensor));
105        }
106
107        // Use quantization if enabled
108        if self.quantization_enabled {
109            return Ok(self.quantize_i8(tensor));
110        }
111
112        // Use half precision if enabled
113        if self.use_low_precision {
114            return Ok(self.to_f16(tensor));
115        }
116
117        // Default: keep as f32
118        Ok(MemoryEfficientTensor::F32(tensor.to_vec()))
119    }
120
121    fn to_f16(&self, tensor: &[f32]) -> MemoryEfficientTensor {
122        let data: Vec<u16> = tensor.iter().map(|x| f32_to_f16(*x)).collect();
123        MemoryEfficientTensor::F16(data)
124    }
125
126    fn quantize_i8(&self, tensor: &[f32]) -> MemoryEfficientTensor {
127        let min = tensor.iter().fold(f32::INFINITY, |a, &b| a.min(b));
128        let max = tensor.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
129
130        let scale = (max - min) / 255.0;
131        let _zero_point = 0i8; // Use 0 as zero point for simplicity
132
133        let data: Vec<i8> = tensor
134            .iter()
135            .map(|x| (((x - min) / scale).round() as i32 - 128).clamp(-128, 127) as i8)
136            .collect();
137
138        MemoryEfficientTensor::I8 {
139            data,
140            scale,
141            zero_point: min, // Store min as zero_point for dequantization
142        }
143    }
144
145    fn to_sparse(&self, tensor: &[f32]) -> MemoryEfficientTensor {
146        let mut indices = Vec::new();
147        let mut values = Vec::new();
148
149        for (idx, &val) in tensor.iter().enumerate() {
150            if val.abs() > 1e-6 {
151                indices.push(idx);
152                values.push(val);
153            }
154        }
155
156        MemoryEfficientTensor::Sparse {
157            indices,
158            values,
159            size: tensor.len(),
160        }
161    }
162
163    /// Calculate memory savings
164    pub fn memory_savings(&self, original: &[f32], optimized: &MemoryEfficientTensor) -> f64 {
165        let original_size = original.len() * 4; // f32 = 4 bytes
166        let optimized_size = optimized.memory_size();
167
168        1.0 - (optimized_size as f64 / original_size as f64)
169    }
170}
171
172/// Calculate sparsity (ratio of zero elements)
/// Fraction of elements that are (near-)zero, using `|x| < 1e-6` as the
/// zero test. An empty slice is reported as fully dense (0.0).
fn calculate_sparsity(tensor: &[f32]) -> f32 {
    let total = tensor.len();
    if total == 0 {
        return 0.0;
    }

    let zero_count: usize = tensor
        .iter()
        .map(|&value| usize::from(value.abs() < 1e-6))
        .sum();

    zero_count as f32 / total as f32
}
181
182/// Simple f32 to f16 conversion (simplified, not IEEE 754 compliant)
/// Simple f32 to f16 conversion (simplified, not fully IEEE 754 compliant).
///
/// Mantissa is truncated (no round-to-nearest) and f16 subnormals are
/// flushed to signed zero; values too large for f16 become infinity.
/// NaN inputs stay NaN.
fn f32_to_f16(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let frac = ((bits >> 13) & 0x3FF) as u16;

    // f32 zero or subnormal -> signed zero (subnormals flushed).
    if exp == 0 {
        return sign;
    }

    // Fix: f32 Inf/NaN (exp == 0xFF) previously fell through the overflow
    // check and NaN silently became Infinity. Keep NaN as a quiet f16 NaN.
    if exp == 0xFF {
        return if (bits & 0x007F_FFFF) != 0 {
            sign | 0x7E00 // quiet NaN
        } else {
            sign | 0x7C00 // +/- Infinity
        };
    }

    // Rebias exponent from 127 (f32) to 15 (f16).
    let exp_adj = exp - 127 + 15;
    if exp_adj >= 31 {
        return sign | 0x7C00; // Overflow -> Infinity
    }
    if exp_adj <= 0 {
        return sign; // Underflow -> zero (f16 subnormals flushed)
    }

    sign | ((exp_adj as u16) << 10) | frac
}
205
206/// Simple f16 to f32 conversion
/// Simple f16 to f32 conversion.
///
/// Decodes signed zero, subnormals, normals, infinities and NaN from an
/// IEEE 754 half-precision bit pattern stored in a `u16`.
fn f16_to_f32(value: u16) -> f32 {
    let sign = ((value >> 15) & 1) as u32;
    let exp = ((value >> 10) & 0x1F) as i32;
    let frac = (value & 0x3FF) as u32;

    if exp == 0 {
        if frac == 0 {
            // Signed zero.
            return if sign == 1 { -0.0 } else { 0.0 };
        }
        // Fix: f16 subnormals were previously flushed to zero, losing
        // representable values. A half subnormal is (-1)^sign * frac * 2^-24.
        let magnitude = frac as f32 * 2.0f32.powi(-24);
        return if sign == 1 { -magnitude } else { magnitude };
    }

    if exp == 31 {
        // Infinity (frac == 0) or NaN (frac != 0).
        return if frac == 0 {
            if sign == 1 {
                f32::NEG_INFINITY
            } else {
                f32::INFINITY
            }
        } else {
            f32::NAN
        };
    }

    // Normal number: rebias exponent from 15 to 127 and widen the mantissa
    // from 10 to 23 bits.
    let exp_adj = exp - 15 + 127;
    let bits = (sign << 31) | ((exp_adj as u32) << 23) | (frac << 13);
    f32::from_bits(bits)
}
232
#[cfg(test)]
mod tests {
    use super::*;

    // F32 variant is lossless and its memory size is 4 bytes per element.
    #[test]
    fn test_f32_tensor() {
        let data = vec![1.0, 2.0, 3.0];
        let tensor = MemoryEfficientTensor::F32(data.clone());

        let recovered = tensor.to_f32();
        assert_eq!(recovered, data);
        assert_eq!(tensor.memory_size(), 12); // 3 * 4 bytes
    }

    // With low precision enabled (and no quantization), a dense tensor is
    // stored as f16: 2 bytes per element.
    #[test]
    fn test_f16_tensor() {
        let optimizer = TensorOptimizer::new(true);
        let data = vec![1.0, 2.0, 3.0, 4.0];

        let optimized = optimizer.optimize(&data).expect("should succeed");
        assert_eq!(optimized.memory_size(), 8); // 4 * 2 bytes

        let recovered = optimized.to_f32();
        assert_eq!(recovered.len(), data.len());
    }

    // i8 quantization round-trips within the quantization step size.
    #[test]
    fn test_quantized_tensor() {
        let optimizer = TensorOptimizer::new(false).with_quantization();
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let optimized = optimizer.optimize(&data).expect("should succeed");
        let recovered = optimized.to_f32();

        assert_eq!(recovered.len(), data.len());
        // Values should be close but not exact due to quantization
        // i8 quantization introduces more error, so use larger tolerance
        for (a, b) in data.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 0.5, "Expected {} but got {}", a, b);
        }
    }

    // A tensor at/above the sparsity threshold is stored sparsely and
    // reconstructs exactly (zeros restored).
    #[test]
    fn test_sparse_tensor() {
        let optimizer = TensorOptimizer::new(false).with_sparse_threshold(0.5);
        let data = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0]; // 66% sparse

        let optimized = optimizer.optimize(&data).expect("should succeed");
        assert!(matches!(optimized, MemoryEfficientTensor::Sparse { .. }));

        let recovered = optimized.to_f32();
        assert_eq!(recovered.len(), data.len());

        for (a, b) in data.iter().zip(recovered.iter()) {
            assert!((a - b).abs() < 0.001);
        }
    }

    // Sparsity is the ratio of near-zero elements (4 of 6 -> ~0.666).
    #[test]
    fn test_calculate_sparsity() {
        let sparse = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0];
        let sparsity = calculate_sparsity(&sparse);
        assert!((sparsity - 0.666).abs() < 0.01);

        let dense = vec![1.0, 2.0, 3.0, 4.0];
        let sparsity_dense = calculate_sparsity(&dense);
        assert_eq!(sparsity_dense, 0.0);
    }

    // 1000 f32s (4000 bytes) quantized to i8 (~1008 bytes) saves ~75%.
    #[test]
    fn test_memory_savings() {
        let optimizer = TensorOptimizer::new(false).with_quantization();
        let data = vec![1.0; 1000];

        let optimized = optimizer.optimize(&data).expect("should succeed");
        let savings = optimizer.memory_savings(&data, &optimized);

        assert!(savings > 0.7); // Should save >70% with i8 quantization
    }
}