// oxirs_chat/memory_optimization/mod.rs

//! Memory optimization for AI operations
//!
//! This module provides memory-efficient operations for embeddings, model weights,
//! and large-scale AI processing to minimize memory footprint in production.
5
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::sync::{Arc, RwLock};

pub mod compression;
pub mod pooling;
pub mod streaming;
pub mod tensor_ops;

pub use compression::{CompressionAlgorithm, Compressor};
pub use pooling::{MemoryPool, PooledBuffer};
pub use streaming::{ChunkProcessor, StreamProcessor};
pub use tensor_ops::{MemoryEfficientTensor, TensorOptimizer};
19
20/// Configuration for memory optimization
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct MemoryOptimizationConfig {
23    /// Enable memory pooling
24    pub enable_pooling: bool,
25
26    /// Pool size in MB
27    pub pool_size_mb: usize,
28
29    /// Enable streaming for large datasets
30    pub enable_streaming: bool,
31
32    /// Chunk size for streaming (in records)
33    pub streaming_chunk_size: usize,
34
35    /// Enable compression for cached data
36    pub enable_compression: bool,
37
38    /// Compression algorithm
39    pub compression_algorithm: CompressionAlgorithm,
40
41    /// Enable tensor optimization
42    pub enable_tensor_optimization: bool,
43
44    /// Use low-precision for inference (f16 instead of f32)
45    pub use_low_precision: bool,
46}
47
48impl Default for MemoryOptimizationConfig {
49    fn default() -> Self {
50        Self {
51            enable_pooling: true,
52            pool_size_mb: 512, // 512MB default pool
53            enable_streaming: true,
54            streaming_chunk_size: 1000,
55            enable_compression: true,
56            compression_algorithm: CompressionAlgorithm::Zstd,
57            enable_tensor_optimization: true,
58            use_low_precision: false, // f32 by default for accuracy
59        }
60    }
61}
62
63/// Memory optimization manager
64pub struct MemoryOptimizer {
65    config: MemoryOptimizationConfig,
66    pool: Option<Arc<RwLock<MemoryPool>>>,
67    compressor: Option<Compressor>,
68    tensor_optimizer: Option<TensorOptimizer>,
69    metrics: Arc<RwLock<MemoryMetrics>>,
70}
71
/// Memory usage metrics.
///
/// All counters start at zero (via `Default`) and are updated by the
/// optimizer as allocations and compressions happen.
#[derive(Debug, Clone, Default)]
pub struct MemoryMetrics {
    /// Total bytes ever allocated.
    pub total_allocated: usize,
    /// Total bytes ever freed.
    pub total_freed: usize,
    /// Bytes currently in use.
    pub current_usage: usize,
    /// Highest value `current_usage` has reached.
    pub peak_usage: usize,
    /// Ratio of uncompressed to compressed size from the latest compression.
    pub compression_ratio: f64,
    /// Allocations served from the pool.
    pub pool_hits: u64,
    /// Allocations that fell back to the heap because the pool was exhausted.
    pub pool_misses: u64,
}
83
84impl MemoryOptimizer {
85    /// Create a new memory optimizer
86    pub fn new(config: MemoryOptimizationConfig) -> Result<Self> {
87        let pool = if config.enable_pooling {
88            Some(Arc::new(RwLock::new(MemoryPool::new(
89                config.pool_size_mb * 1024 * 1024,
90            ))))
91        } else {
92            None
93        };
94
95        let compressor = if config.enable_compression {
96            Some(Compressor::new(config.compression_algorithm))
97        } else {
98            None
99        };
100
101        let tensor_optimizer = if config.enable_tensor_optimization {
102            Some(TensorOptimizer::new(config.use_low_precision))
103        } else {
104            None
105        };
106
107        Ok(Self {
108            config,
109            pool,
110            compressor,
111            tensor_optimizer,
112            metrics: Arc::new(RwLock::new(MemoryMetrics::default())),
113        })
114    }
115
116    /// Allocate memory from pool if available, otherwise use heap
117    pub fn allocate(&self, size: usize) -> Result<PooledBuffer> {
118        if let Some(ref pool) = self.pool {
119            let mut pool_guard = pool
120                .write()
121                .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
122
123            match pool_guard.allocate(size) {
124                Ok(buffer) => {
125                    // Update metrics
126                    let mut metrics = self
127                        .metrics
128                        .write()
129                        .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
130                    metrics.pool_hits += 1;
131                    metrics.total_allocated += size;
132                    metrics.current_usage += size;
133                    if metrics.current_usage > metrics.peak_usage {
134                        metrics.peak_usage = metrics.current_usage;
135                    }
136
137                    Ok(buffer)
138                }
139                Err(_) => {
140                    // Pool exhausted, allocate on heap
141                    let mut metrics = self
142                        .metrics
143                        .write()
144                        .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
145                    metrics.pool_misses += 1;
146
147                    PooledBuffer::new_heap(size)
148                }
149            }
150        } else {
151            // Pooling disabled, allocate on heap
152            PooledBuffer::new_heap(size)
153        }
154    }
155
156    /// Compress data if compression is enabled
157    pub fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
158        if let Some(ref compressor) = self.compressor {
159            let compressed = compressor.compress(data)?;
160
161            // Update metrics
162            let mut metrics = self
163                .metrics
164                .write()
165                .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
166            let ratio = data.len() as f64 / compressed.len() as f64;
167            metrics.compression_ratio = ratio;
168
169            Ok(compressed)
170        } else {
171            Ok(data.to_vec())
172        }
173    }
174
175    /// Decompress data if compression is enabled
176    pub fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
177        if let Some(ref compressor) = self.compressor {
178            compressor.decompress(data)
179        } else {
180            Ok(data.to_vec())
181        }
182    }
183
184    /// Optimize tensor for memory efficiency
185    pub fn optimize_tensor(&self, tensor: &[f32]) -> Result<MemoryEfficientTensor> {
186        if let Some(ref optimizer) = self.tensor_optimizer {
187            optimizer.optimize(tensor)
188        } else {
189            // No optimization, wrap as-is
190            Ok(MemoryEfficientTensor::F32(tensor.to_vec()))
191        }
192    }
193
194    /// Get current memory metrics
195    pub fn metrics(&self) -> Result<MemoryMetrics> {
196        let metrics = self
197            .metrics
198            .read()
199            .map_err(|e| anyhow!("Failed to acquire read lock: {}", e))?;
200        Ok(metrics.clone())
201    }
202
203    /// Reset metrics
204    pub fn reset_metrics(&self) -> Result<()> {
205        let mut metrics = self
206            .metrics
207            .write()
208            .map_err(|e| anyhow!("Failed to acquire write lock: {}", e))?;
209        *metrics = MemoryMetrics::default();
210        Ok(())
211    }
212
213    /// Get pool statistics
214    pub fn pool_hit_rate(&self) -> Result<f64> {
215        let metrics = self
216            .metrics
217            .read()
218            .map_err(|e| anyhow!("Failed to acquire read lock: {}", e))?;
219
220        let total = metrics.pool_hits + metrics.pool_misses;
221        if total == 0 {
222            return Ok(0.0);
223        }
224
225        Ok(metrics.pool_hits as f64 / total as f64)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience: optimizer with default configuration.
    fn default_optimizer() -> MemoryOptimizer {
        MemoryOptimizer::new(MemoryOptimizationConfig::default()).expect("should succeed")
    }

    #[test]
    fn test_memory_optimizer_creation() {
        let opt = default_optimizer();
        let m = opt.metrics().expect("should succeed");
        assert_eq!(m.total_allocated, 0);
    }

    #[test]
    fn test_memory_allocation() {
        let opt = default_optimizer();

        let buf = opt.allocate(1024).expect("should succeed");
        assert!(buf.len() >= 1024);

        let m = opt.metrics().expect("should succeed");
        assert_eq!(m.pool_hits, 1);
        assert_eq!(m.total_allocated, 1024);
    }

    #[test]
    fn test_compression() {
        let opt = default_optimizer();

        let original = vec![42u8; 1000];
        let packed = opt.compress(&original).expect("should succeed");
        assert!(packed.len() < original.len());

        let unpacked = opt.decompress(&packed).expect("should succeed");
        assert_eq!(unpacked, original);
    }

    #[test]
    fn test_pool_hit_rate() {
        let opt = default_optimizer();

        // Two allocations served from the pool.
        let _first = opt.allocate(1024).expect("should succeed");
        let _second = opt.allocate(2048).expect("should succeed");

        assert!(opt.pool_hit_rate().expect("should succeed") > 0.0);
    }
}