kizzasi_inference/pool.rs

//! Memory pool for efficient tensor allocation and reuse.
//!
//! This module provides a buffer pool to reduce allocations during inference
//! by reusing pre-allocated tensor buffers. This is especially important for
//! streaming scenarios where state tensors are frequently created and destroyed.
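//!
//! # Example
//!
//! A minimal usage sketch. It assumes this module is publicly exported as
//! `kizzasi_inference::pool`, which is why the example is marked `ignore`:
//!
//! ```ignore
//! use kizzasi_inference::pool::{BufferKey, TensorPool};
//!
//! let pool = TensorPool::new();
//! let key = BufferKey::f32(1024).with_tag("state");
//!
//! // After the first iteration, each acquisition reuses the buffer that the
//! // previous iteration returned to the pool on drop.
//! for _ in 0..3 {
//!     let mut buf = pool.acquire_f32(key.clone()).expect("pool lock poisoned");
//!     buf[0] = 1.0; // write through `DerefMut`
//!     // `buf` goes out of scope here and is returned to the pool.
//! }
//! ```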

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use crate::error::{InferenceError, InferenceResult};

/// A key identifying a buffer shape and type configuration.
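///
/// # Example
///
/// An illustrative sketch (module path assumed, hence `ignore`): keys that
/// differ in size, dtype, or tag address separate pools and never share buffers.
///
/// ```ignore
/// use kizzasi_inference::pool::BufferKey;
///
/// let plain = BufferKey::f32(1024);
/// let tagged = BufferKey::f32(1024).with_tag("state");
/// let wide = BufferKey::f64(1024);
///
/// assert_ne!(plain, tagged); // same size and dtype, different tag
/// assert_ne!(plain.dtype, wide.dtype); // "f32" vs "f64"
/// ```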
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct BufferKey {
    /// Total number of elements in the buffer
    pub size: usize,
    /// Element type identifier (e.g., "f32", "f64")
    pub dtype: String,
    /// Optional semantic tag for specialized pools
    pub tag: Option<String>,
}

impl BufferKey {
    /// Create a new buffer key for f32 tensors.
    pub fn f32(size: usize) -> Self {
        Self {
            size,
            dtype: "f32".to_string(),
            tag: None,
        }
    }

    /// Create a new buffer key for f64 tensors.
    pub fn f64(size: usize) -> Self {
        Self {
            size,
            dtype: "f64".to_string(),
            tag: None,
        }
    }

    /// Add a semantic tag to this key.
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tag = Some(tag.into());
        self
    }
}

/// A pooled buffer that returns itself to the pool when dropped.
pub struct PooledBuffer<T> {
    data: Vec<T>,
    key: BufferKey,
    pool: Arc<Mutex<TensorPoolInner>>,
}

impl<T> PooledBuffer<T> {
    /// Get a reference to the underlying data.
    pub fn data(&self) -> &[T] {
        &self.data
    }

    /// Get a mutable reference to the underlying data.
    pub fn data_mut(&mut self) -> &mut [T] {
        &mut self.data
    }

    /// Get the size of the buffer.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Check if the buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Consume the buffer and return the underlying Vec.
    /// This prevents the buffer from being returned to the pool.
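    ///
    /// # Example
    ///
    /// An illustrative sketch (module path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use kizzasi_inference::pool::{BufferKey, TensorPool};
    ///
    /// let pool = TensorPool::new();
    /// let mut buf = pool.acquire_f32(BufferKey::f32(8)).expect("pool lock poisoned");
    /// buf[0] = 3.5;
    ///
    /// // Taking ownership of the Vec opts this buffer out of pooling.
    /// let owned: Vec<f32> = buf.into_vec();
    /// assert_eq!(owned[0], 3.5);
    /// ```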
    pub fn into_vec(mut self) -> Vec<T> {
        // Take data and leave empty vec to prevent return to pool
        std::mem::take(&mut self.data)
    }
}

impl<T> Drop for PooledBuffer<T> {
    fn drop(&mut self) {
        // Return buffer to pool only if it's not empty
        if !self.data.is_empty() {
            if let Ok(mut pool) = self.pool.lock() {
                pool.return_raw_buffer(self.key.clone(), std::mem::take(&mut self.data));
            }
        }
    }
}

impl<T> std::ops::Deref for PooledBuffer<T> {
    type Target = [T];
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl<T> std::ops::DerefMut for PooledBuffer<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.data
    }
}

/// A thread-safe memory pool for tensor buffers.
///
/// The pool maintains separate storage for different buffer configurations
/// and automatically grows as needed. Buffers are returned to the pool when dropped.
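///
/// # Example
///
/// A sketch of pool construction and bookkeeping (module path assumed, hence
/// `ignore`):
///
/// ```ignore
/// use kizzasi_inference::pool::{BufferKey, TensorPool};
///
/// // Keep at most 4 spare buffers per key.
/// let pool = TensorPool::with_capacity(4);
/// let key = BufferKey::f32(256);
///
/// let buf = pool.acquire_f32(key.clone()).expect("pool lock poisoned");
/// drop(buf); // returned to the pool here
///
/// let stats = pool.stats().expect("pool lock poisoned");
/// assert_eq!(stats.total_allocations, 1);
/// assert_eq!(pool.pooled_count().expect("pool lock poisoned"), 1);
/// ```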
#[derive(Clone)]
pub struct TensorPool {
    inner: Arc<Mutex<TensorPoolInner>>,
}

struct TensorPoolInner {
    /// Storage for f32 buffers
    f32_buffers: HashMap<BufferKey, Vec<Vec<f32>>>,
    /// Storage for f64 buffers
    f64_buffers: HashMap<BufferKey, Vec<Vec<f64>>>,
    /// Maximum number of buffers to keep per key
    max_buffers_per_key: usize,
    /// Statistics
    stats: PoolStats,
}

/// Counters describing how the pool has been used.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Number of fresh buffer allocations (pool misses)
    pub total_allocations: usize,
    /// Number of buffer reuses from the pool (pool hits)
    pub total_reuses: usize,
    /// Number of buffers handed back to the pool on drop,
    /// including those that were subsequently discarded
    pub total_returns: usize,
    /// Number of returned buffers that were discarded instead of pooled
    pub total_discards: usize,
}

impl TensorPool {
    /// Create a new tensor pool with default capacity (16 buffers per key).
    pub fn new() -> Self {
        Self::with_capacity(16)
    }

    /// Create a new tensor pool with specified maximum buffers per key.
    pub fn with_capacity(max_buffers_per_key: usize) -> Self {
        Self {
            inner: Arc::new(Mutex::new(TensorPoolInner {
                f32_buffers: HashMap::new(),
                f64_buffers: HashMap::new(),
                max_buffers_per_key,
                stats: PoolStats::default(),
            })),
        }
    }

    /// Acquire a pooled f32 buffer.
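    ///
    /// The returned buffer has exactly `key.size` elements and is zero-filled,
    /// whether it was freshly allocated or reused from the pool.
    ///
    /// # Example
    ///
    /// An illustrative sketch (module path assumed, hence `ignore`):
    ///
    /// ```ignore
    /// use kizzasi_inference::pool::{BufferKey, TensorPool};
    ///
    /// let pool = TensorPool::new();
    /// let mut buf = pool.acquire_f32(BufferKey::f32(16)).expect("pool lock poisoned");
    /// buf.data_mut().fill(7.0);
    /// drop(buf);
    ///
    /// // A reused buffer comes back zeroed, not with stale contents.
    /// let buf = pool.acquire_f32(BufferKey::f32(16)).expect("pool lock poisoned");
    /// assert!(buf.iter().all(|&x| x == 0.0));
    /// ```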
    pub fn acquire_f32(&self, key: BufferKey) -> InferenceResult<PooledBuffer<f32>> {
        let mut inner = self
            .inner
            .lock()
            .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
        let data = inner.f32_buffers.get_mut(&key).and_then(|pool| pool.pop());

        let data = if let Some(mut buf) = data {
            inner.stats.total_reuses += 1;
            // Clear the buffer for reuse
            buf.clear();
            buf.resize(key.size, 0.0);
            buf
        } else {
            inner.stats.total_allocations += 1;
            vec![0.0; key.size]
        };

        drop(inner); // Release lock before returning

        Ok(PooledBuffer {
            data,
            key,
            pool: self.inner.clone(),
        })
    }

    /// Acquire a pooled f64 buffer.
    pub fn acquire_f64(&self, key: BufferKey) -> InferenceResult<PooledBuffer<f64>> {
        let mut inner = self
            .inner
            .lock()
            .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
        let data = inner.f64_buffers.get_mut(&key).and_then(|pool| pool.pop());

        let data = if let Some(mut buf) = data {
            inner.stats.total_reuses += 1;
            // Clear the buffer for reuse
            buf.clear();
            buf.resize(key.size, 0.0);
            buf
        } else {
            inner.stats.total_allocations += 1;
            vec![0.0; key.size]
        };

        drop(inner); // Release lock before returning

        Ok(PooledBuffer {
            data,
            key,
            pool: self.inner.clone(),
        })
    }

    /// Clear all pooled buffers.
    pub fn clear(&self) -> InferenceResult<()> {
        let mut inner = self
            .inner
            .lock()
            .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
        inner.f32_buffers.clear();
        inner.f64_buffers.clear();
        Ok(())
    }

    /// Get pool statistics.
    pub fn stats(&self) -> InferenceResult<PoolStats> {
        let inner = self
            .inner
            .lock()
            .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
        Ok(inner.stats.clone())
    }

    /// Get the current number of pooled buffers.
    pub fn pooled_count(&self) -> InferenceResult<usize> {
        let inner = self
            .inner
            .lock()
            .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
        Ok(inner.f32_buffers.values().map(|v| v.len()).sum::<usize>()
            + inner.f64_buffers.values().map(|v| v.len()).sum::<usize>())
    }
}

impl TensorPoolInner {
    fn return_raw_buffer<T>(&mut self, key: BufferKey, buffer: Vec<T>) {
        self.stats.total_returns += 1;

        match key.dtype.as_str() {
            "f32" => {
                let pool = self.f32_buffers.entry(key).or_default();
                if pool.len() < self.max_buffers_per_key {
                    // SAFETY: buffers only reach this method from `PooledBuffer::drop`,
                    // and `acquire_f32` is the only place that builds a `PooledBuffer`
                    // keyed with dtype "f32", so `T` is `f32` here. The debug assertion
                    // guards that invariant during testing.
                    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<f32>());
                    let buffer: Vec<f32> = unsafe { std::mem::transmute(buffer) };
                    pool.push(buffer);
                } else {
                    self.stats.total_discards += 1;
                }
            }
            "f64" => {
                let pool = self.f64_buffers.entry(key).or_default();
                if pool.len() < self.max_buffers_per_key {
                    // SAFETY: as above, `acquire_f64` is the only source of buffers
                    // keyed with dtype "f64", so `T` is `f64` here.
                    debug_assert_eq!(std::mem::size_of::<T>(), std::mem::size_of::<f64>());
                    let buffer: Vec<f64> = unsafe { std::mem::transmute(buffer) };
                    pool.push(buffer);
                } else {
                    self.stats.total_discards += 1;
                }
            }
            _ => {
                // Unknown dtype: discard rather than risk a mistyped reuse
                self.stats.total_discards += 1;
            }
        }
    }
}

impl Default for TensorPool {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_buffer_pool_basic() {
        let pool = TensorPool::new();
        let key = BufferKey::f32(1024);

        // First acquisition should allocate
        let buf1 = pool
            .acquire_f32(key.clone())
            .expect("Failed to acquire buffer");
        assert_eq!(buf1.len(), 1024);
        let stats1 = pool.stats().expect("Failed to get stats");
        assert_eq!(stats1.total_allocations, 1);
        assert_eq!(stats1.total_reuses, 0);

        // Drop and reacquire should reuse
        drop(buf1);
        let buf2 = pool
            .acquire_f32(key.clone())
            .expect("Failed to acquire buffer");
        let stats2 = pool.stats().expect("Failed to get stats");
        assert_eq!(stats2.total_allocations, 1);
        assert_eq!(stats2.total_reuses, 1);
        assert_eq!(stats2.total_returns, 1);

        drop(buf2);
    }

    #[test]
    fn test_buffer_pool_multiple_keys() {
        let pool = TensorPool::new();
        let key1 = BufferKey::f32(512);
        let key2 = BufferKey::f32(1024);
        let key3 = BufferKey::f64(512);

        let buf1 = pool
            .acquire_f32(key1.clone())
            .expect("Failed to acquire buffer");
        let buf2 = pool
            .acquire_f32(key2.clone())
            .expect("Failed to acquire buffer");
        let buf3 = pool
            .acquire_f64(key3.clone())
            .expect("Failed to acquire buffer");

        assert_eq!(buf1.len(), 512);
        assert_eq!(buf2.len(), 1024);
        assert_eq!(buf3.len(), 512);

        drop(buf1);
        drop(buf2);
        drop(buf3);

        let stats = pool.stats().expect("Failed to get stats");
        assert_eq!(stats.total_allocations, 3);
        assert_eq!(stats.total_returns, 3);
    }

    #[test]
    fn test_buffer_pool_capacity_limit() {
        let pool = TensorPool::with_capacity(2);
        let key = BufferKey::f32(100);

        // Create 3 buffers simultaneously (before dropping any)
        // This forces 3 allocations since pool is empty
        let buf1 = pool
            .acquire_f32(key.clone())
            .expect("Failed to acquire buffer");
        let buf2 = pool
            .acquire_f32(key.clone())
            .expect("Failed to acquire buffer");
        let buf3 = pool
            .acquire_f32(key.clone())
            .expect("Failed to acquire buffer");

        // Now drop all 3 - they will try to return to pool
        // But pool capacity is 2, so one should be discarded
        drop(buf1);
        drop(buf2);
        drop(buf3);

        let stats = pool.stats().expect("Failed to get stats");
        // All 3 were new allocations (pool was empty)
        assert_eq!(stats.total_allocations, 3);
        assert_eq!(stats.total_reuses, 0);
        // All 3 returns attempted, but 1 discarded due to capacity
        assert_eq!(stats.total_returns, 3);
        assert_eq!(stats.total_discards, 1);
        // Pool should only have 2 buffers
        assert_eq!(pool.pooled_count().expect("Failed to get count"), 2);
    }

    #[test]
    fn test_buffer_pool_tagged_keys() {
        let pool = TensorPool::new();
        let key1 = BufferKey::f32(1024).with_tag("state");
        let key2 = BufferKey::f32(1024).with_tag("output");
        let key3 = BufferKey::f32(1024); // No tag

        let buf1 = pool
            .acquire_f32(key1.clone())
            .expect("Failed to acquire buffer");
        let buf2 = pool
            .acquire_f32(key2.clone())
            .expect("Failed to acquire buffer");
        let buf3 = pool
            .acquire_f32(key3.clone())
            .expect("Failed to acquire buffer");

        assert_eq!(buf1.len(), 1024);
        assert_eq!(buf2.len(), 1024);
        assert_eq!(buf3.len(), 1024);

        drop(buf1);
        drop(buf2);
        drop(buf3);

        // All should be separate pools
        let stats = pool.stats().expect("Failed to get stats");
        assert_eq!(stats.total_allocations, 3);
        assert_eq!(pool.pooled_count().expect("Failed to get count"), 3);
    }

    #[test]
    fn test_buffer_clear() {
        let pool = TensorPool::new();
        let key = BufferKey::f32(100);

        let mut buf = pool
            .acquire_f32(key.clone())
            .expect("Failed to acquire buffer");
        buf[0] = 42.0;
        drop(buf);

        // After reacquisition, buffer should be cleared
        let buf2 = pool.acquire_f32(key).expect("Failed to acquire buffer");
        assert_eq!(buf2[0], 0.0);
    }

    #[test]
    fn test_pooled_buffer_into_vec() {
        let pool = TensorPool::new();
        let key = BufferKey::f32(100);

        let mut buf = pool
            .acquire_f32(key.clone())
            .expect("Failed to acquire buffer");
        buf[0] = 42.0;

        let vec = buf.into_vec();
        assert_eq!(vec[0], 42.0);
        assert_eq!(vec.len(), 100);

        // Buffer should not have been returned to pool
        let stats = pool.stats().expect("Failed to get stats");
        assert_eq!(stats.total_returns, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(TensorPool::new());
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let pool = pool.clone();
                thread::spawn(move || {
                    for _ in 0..100 {
                        let key = BufferKey::f32(1024).with_tag(format!("thread_{}", i));
                        let buf = pool.acquire_f32(key).expect("Failed to acquire buffer");
                        assert_eq!(buf.len(), 1024);
                        drop(buf);
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        let stats = pool.stats().expect("Failed to get stats");
        assert!(stats.total_allocations > 0);
        assert!(stats.total_reuses > 0);
    }
}