// oxirs_chat/memory_optimization/tensor_ops.rs

use anyhow::Result;
use serde::{Deserialize, Serialize};
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub enum MemoryEfficientTensor {
9 F32(Vec<f32>),
11
12 F16(Vec<u16>),
14
15 I8 {
17 data: Vec<i8>,
18 scale: f32,
19 zero_point: f32, },
21
22 Sparse {
24 indices: Vec<usize>,
25 values: Vec<f32>,
26 size: usize,
27 },
28}
29
30impl MemoryEfficientTensor {
31 pub fn to_f32(&self) -> Vec<f32> {
33 match self {
34 Self::F32(data) => data.clone(),
35 Self::F16(data) => data.iter().map(|x| f16_to_f32(*x)).collect(),
36 Self::I8 {
37 data,
38 scale,
39 zero_point,
40 } => data
41 .iter()
42 .map(|x| (*x as f32 + 128.0) * scale + zero_point)
43 .collect(),
44 Self::Sparse {
45 indices,
46 values,
47 size,
48 } => {
49 let mut result = vec![0.0f32; *size];
50 for (idx, val) in indices.iter().zip(values.iter()) {
51 result[*idx] = *val;
52 }
53 result
54 }
55 }
56 }
57
58 pub fn memory_size(&self) -> usize {
60 match self {
61 Self::F32(data) => data.len() * 4,
62 Self::F16(data) => data.len() * 2,
63 Self::I8 { data, .. } => data.len() + 8, Self::Sparse {
65 indices, values, ..
66 } => indices.len() * 8 + values.len() * 4 + 8,
67 }
68 }
69}
70
/// Chooses a `MemoryEfficientTensor` representation for dense f32 data
/// according to the configured precision / quantization / sparsity policy.
pub struct TensorOptimizer {
    // When true (and quantization is off), dense tensors are stored as f16.
    use_low_precision: bool,
    // When true, dense tensors are affinely quantized to i8.
    quantization_enabled: bool,
    // Minimum fraction of near-zero elements required for sparse storage.
    sparse_threshold: f32,
}
77
78impl TensorOptimizer {
79 pub fn new(use_low_precision: bool) -> Self {
80 Self {
81 use_low_precision,
82 quantization_enabled: false,
83 sparse_threshold: 0.5, }
85 }
86
87 pub fn with_quantization(mut self) -> Self {
89 self.quantization_enabled = true;
90 self
91 }
92
93 pub fn with_sparse_threshold(mut self, threshold: f32) -> Self {
95 self.sparse_threshold = threshold;
96 self
97 }
98
99 pub fn optimize(&self, tensor: &[f32]) -> Result<MemoryEfficientTensor> {
101 let sparsity = calculate_sparsity(tensor);
103 if sparsity >= self.sparse_threshold {
104 return Ok(self.to_sparse(tensor));
105 }
106
107 if self.quantization_enabled {
109 return Ok(self.quantize_i8(tensor));
110 }
111
112 if self.use_low_precision {
114 return Ok(self.to_f16(tensor));
115 }
116
117 Ok(MemoryEfficientTensor::F32(tensor.to_vec()))
119 }
120
121 fn to_f16(&self, tensor: &[f32]) -> MemoryEfficientTensor {
122 let data: Vec<u16> = tensor.iter().map(|x| f32_to_f16(*x)).collect();
123 MemoryEfficientTensor::F16(data)
124 }
125
126 fn quantize_i8(&self, tensor: &[f32]) -> MemoryEfficientTensor {
127 let min = tensor.iter().fold(f32::INFINITY, |a, &b| a.min(b));
128 let max = tensor.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
129
130 let scale = (max - min) / 255.0;
131 let _zero_point = 0i8; let data: Vec<i8> = tensor
134 .iter()
135 .map(|x| (((x - min) / scale).round() as i32 - 128).clamp(-128, 127) as i8)
136 .collect();
137
138 MemoryEfficientTensor::I8 {
139 data,
140 scale,
141 zero_point: min, }
143 }
144
145 fn to_sparse(&self, tensor: &[f32]) -> MemoryEfficientTensor {
146 let mut indices = Vec::new();
147 let mut values = Vec::new();
148
149 for (idx, &val) in tensor.iter().enumerate() {
150 if val.abs() > 1e-6 {
151 indices.push(idx);
152 values.push(val);
153 }
154 }
155
156 MemoryEfficientTensor::Sparse {
157 indices,
158 values,
159 size: tensor.len(),
160 }
161 }
162
163 pub fn memory_savings(&self, original: &[f32], optimized: &MemoryEfficientTensor) -> f64 {
165 let original_size = original.len() * 4; let optimized_size = optimized.memory_size();
167
168 1.0 - (optimized_size as f64 / original_size as f64)
169 }
170}
171
/// Fraction of near-zero entries (|x| < 1e-6) in `tensor`; 0.0 for empty input.
fn calculate_sparsity(tensor: &[f32]) -> f32 {
    match tensor.len() {
        0 => 0.0,
        n => {
            let near_zero = tensor.iter().filter(|v| v.abs() < 1e-6).count();
            near_zero as f32 / n as f32
        }
    }
}
181
/// Convert an f32 to an IEEE 754 binary16 bit pattern (mantissa truncated;
/// no round-to-nearest).
///
/// Out-of-range magnitudes become signed infinity; values too small for a
/// normal f16 (including all f32 subnormals, which are far below the f16
/// subnormal range) flush to signed zero. NaN is preserved as a quiet NaN.
fn f32_to_f16(value: f32) -> u16 {
    let bits = value.to_bits();
    let sign = (bits >> 16) & 0x8000;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let frac = (bits >> 13) & 0x3FF;

    // f32 zero or subnormal -> signed zero.
    if exp == 0 {
        return sign as u16;
    }

    // f32 infinity or NaN. Checking the full 23-bit mantissa distinguishes
    // them; the previous code only re-biased the exponent, which silently
    // turned NaN into infinity.
    if exp == 0xFF {
        return if (bits & 0x007F_FFFF) == 0 {
            (sign | 0x7C00) as u16 // +/- infinity
        } else {
            (sign | 0x7E00) as u16 // quiet NaN
        };
    }

    let exp_adj = exp - 127 + 15;
    if exp_adj >= 31 {
        // Overflow: magnitude too large for f16 -> infinity.
        return (sign | 0x7C00) as u16;
    }
    if exp_adj <= 0 {
        // Underflow: would need an f16 subnormal; flush to signed zero.
        return sign as u16;
    }

    (sign | ((exp_adj as u32) << 10) | frac) as u16
}
205
/// Convert an IEEE 754 binary16 bit pattern to f32.
///
/// Handles signed zero, subnormals, infinities and NaN explicitly; normal
/// values are rebuilt by re-biasing the exponent (15 -> 127).
fn f16_to_f32(value: u16) -> f32 {
    let sign = ((value >> 15) & 1) as u32;
    let exp = ((value >> 10) & 0x1F) as i32;
    let frac = (value & 0x3FF) as u32;

    if exp == 0 {
        // Zero (frac == 0) or f16 subnormal: the value is frac * 2^-24.
        // The previous code dropped subnormals to zero; they are exactly
        // representable in f32, so reconstruct them.
        let magnitude = frac as f32 / 16_777_216.0; // 2^24
        return if sign == 1 { -magnitude } else { magnitude };
    }

    if exp == 31 {
        // Infinity or NaN.
        return if frac == 0 {
            if sign == 1 {
                f32::NEG_INFINITY
            } else {
                f32::INFINITY
            }
        } else {
            f32::NAN
        };
    }

    // Normal number: re-bias the exponent and widen the fraction field.
    let exp_adj = exp - 15 + 127;
    f32::from_bits((sign << 31) | ((exp_adj as u32) << 23) | (frac << 13))
}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f32_tensor() {
        let values = vec![1.0, 2.0, 3.0];
        let tensor = MemoryEfficientTensor::F32(values.clone());

        assert_eq!(tensor.to_f32(), values);
        // 3 elements * 4 bytes each.
        assert_eq!(tensor.memory_size(), 12);
    }

    #[test]
    fn test_f16_tensor() {
        let optimizer = TensorOptimizer::new(true);
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let compact = optimizer.optimize(&input).expect("should succeed");
        // 4 elements * 2 bytes each.
        assert_eq!(compact.memory_size(), 8);
        assert_eq!(compact.to_f32().len(), input.len());
    }

    #[test]
    fn test_quantized_tensor() {
        let optimizer = TensorOptimizer::new(false).with_quantization();
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let compact = optimizer.optimize(&input).expect("should succeed");
        let roundtrip = compact.to_f32();

        assert_eq!(roundtrip.len(), input.len());
        // i8 quantization over the range [1, 5] stays well within 0.5.
        for (a, b) in input.iter().zip(roundtrip.iter()) {
            assert!((a - b).abs() < 0.5, "Expected {} but got {}", a, b);
        }
    }

    #[test]
    fn test_sparse_tensor() {
        let optimizer = TensorOptimizer::new(false).with_sparse_threshold(0.5);
        // Sparsity 4/6 >= 0.5, so sparse storage must be selected.
        let input = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0];

        let compact = optimizer.optimize(&input).expect("should succeed");
        assert!(matches!(compact, MemoryEfficientTensor::Sparse { .. }));

        let roundtrip = compact.to_f32();
        assert_eq!(roundtrip.len(), input.len());
        for (a, b) in input.iter().zip(roundtrip.iter()) {
            assert!((a - b).abs() < 0.001);
        }
    }

    #[test]
    fn test_calculate_sparsity() {
        // 4 of 6 entries are (near-)zero.
        let sparse = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0];
        assert!((calculate_sparsity(&sparse) - 0.666).abs() < 0.01);

        let dense = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(calculate_sparsity(&dense), 0.0);
    }

    #[test]
    fn test_memory_savings() {
        let optimizer = TensorOptimizer::new(false).with_quantization();
        let input = vec![1.0; 1000];

        let compact = optimizer.optimize(&input).expect("should succeed");
        // i8 storage should save well over 70% versus dense f32.
        assert!(optimizer.memory_savings(&input, &compact) > 0.7);
    }
}