kizzasi_inference/
pool.rs1use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
10use crate::error::{InferenceError, InferenceResult};
11
/// Identifies a class of interchangeable buffers.
///
/// Two keys compare equal (and hash identically) only when `size`, `dtype`
/// and `tag` all match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct BufferKey {
    /// Number of elements in the buffer.
    pub size: usize,
    /// Element type name; the constructors below produce `"f32"` and `"f64"`.
    pub dtype: String,
    /// Optional label that separates otherwise identical keys.
    pub tag: Option<String>,
}

impl BufferKey {
    /// Builds a key without a tag for `size` elements of the named type.
    fn untagged(size: usize, dtype: &str) -> Self {
        BufferKey {
            size,
            dtype: dtype.to_owned(),
            tag: None,
        }
    }

    /// Returns an untagged key for `size` elements with dtype `"f32"`.
    pub fn f32(size: usize) -> Self {
        Self::untagged(size, "f32")
    }

    /// Returns an untagged key for `size` elements with dtype `"f64"`.
    pub fn f64(size: usize) -> Self {
        Self::untagged(size, "f64")
    }

    /// Consumes the key and returns it with `tag` set, replacing any tag it
    /// carried before.
    pub fn with_tag(self, tag: impl Into<String>) -> Self {
        Self {
            tag: Some(tag.into()),
            ..self
        }
    }
}
48
49pub struct PooledBuffer<T> {
51 data: Vec<T>,
52 key: BufferKey,
53 pool: Arc<Mutex<TensorPoolInner>>,
54}
55
56impl<T> PooledBuffer<T> {
57 pub fn data(&self) -> &[T] {
59 &self.data
60 }
61
62 pub fn data_mut(&mut self) -> &mut [T] {
64 &mut self.data
65 }
66
67 pub fn len(&self) -> usize {
69 self.data.len()
70 }
71
72 pub fn is_empty(&self) -> bool {
74 self.data.is_empty()
75 }
76
77 pub fn into_vec(mut self) -> Vec<T> {
80 std::mem::take(&mut self.data)
82 }
83}
84
85impl<T> Drop for PooledBuffer<T> {
86 fn drop(&mut self) {
87 if !self.data.is_empty() {
89 if let Ok(mut pool) = self.pool.lock() {
90 pool.return_raw_buffer(self.key.clone(), std::mem::take(&mut self.data));
91 }
92 }
93 }
94}
95
96impl<T> std::ops::Deref for PooledBuffer<T> {
97 type Target = [T];
98 fn deref(&self) -> &Self::Target {
99 &self.data
100 }
101}
102
103impl<T> std::ops::DerefMut for PooledBuffer<T> {
104 fn deref_mut(&mut self) -> &mut Self::Target {
105 &mut self.data
106 }
107}
108
109#[derive(Clone)]
114pub struct TensorPool {
115 inner: Arc<Mutex<TensorPoolInner>>,
116}
117
118struct TensorPoolInner {
119 f32_buffers: HashMap<BufferKey, Vec<Vec<f32>>>,
121 f64_buffers: HashMap<BufferKey, Vec<Vec<f64>>>,
123 max_buffers_per_key: usize,
125 stats: PoolStats,
127}
128
/// Usage counters maintained by the pool; all start at zero.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Acquisitions that had to allocate a fresh buffer.
    pub total_allocations: usize,
    /// Acquisitions served from an idle pooled buffer.
    pub total_reuses: usize,
    /// Buffers handed back to the pool, whether kept or not.
    pub total_returns: usize,
    /// Returned buffers that were dropped instead of being kept.
    pub total_discards: usize,
}
140
141impl TensorPool {
142 pub fn new() -> Self {
144 Self::with_capacity(16)
145 }
146
147 pub fn with_capacity(max_buffers_per_key: usize) -> Self {
149 Self {
150 inner: Arc::new(Mutex::new(TensorPoolInner {
151 f32_buffers: HashMap::new(),
152 f64_buffers: HashMap::new(),
153 max_buffers_per_key,
154 stats: PoolStats::default(),
155 })),
156 }
157 }
158
159 pub fn acquire_f32(&self, key: BufferKey) -> InferenceResult<PooledBuffer<f32>> {
161 let mut inner = self
162 .inner
163 .lock()
164 .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
165 let data = inner.f32_buffers.get_mut(&key).and_then(|pool| pool.pop());
166
167 let data = if let Some(mut buf) = data {
168 inner.stats.total_reuses += 1;
169 buf.clear();
171 buf.resize(key.size, 0.0);
172 buf
173 } else {
174 inner.stats.total_allocations += 1;
175 vec![0.0; key.size]
176 };
177
178 drop(inner); Ok(PooledBuffer {
181 data,
182 key,
183 pool: self.inner.clone(),
184 })
185 }
186
187 pub fn acquire_f64(&self, key: BufferKey) -> InferenceResult<PooledBuffer<f64>> {
189 let mut inner = self
190 .inner
191 .lock()
192 .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
193 let data = inner.f64_buffers.get_mut(&key).and_then(|pool| pool.pop());
194
195 let data = if let Some(mut buf) = data {
196 inner.stats.total_reuses += 1;
197 buf.clear();
199 buf.resize(key.size, 0.0);
200 buf
201 } else {
202 inner.stats.total_allocations += 1;
203 vec![0.0; key.size]
204 };
205
206 drop(inner); Ok(PooledBuffer {
209 data,
210 key,
211 pool: self.inner.clone(),
212 })
213 }
214
215 pub fn clear(&self) -> InferenceResult<()> {
217 let mut inner = self
218 .inner
219 .lock()
220 .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
221 inner.f32_buffers.clear();
222 inner.f64_buffers.clear();
223 Ok(())
224 }
225
226 pub fn stats(&self) -> InferenceResult<PoolStats> {
228 let inner = self
229 .inner
230 .lock()
231 .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
232 Ok(inner.stats.clone())
233 }
234
235 pub fn pooled_count(&self) -> InferenceResult<usize> {
237 let inner = self
238 .inner
239 .lock()
240 .map_err(|e| InferenceError::LockError(format!("Failed to acquire lock: {}", e)))?;
241 Ok(inner.f32_buffers.values().map(|v| v.len()).sum::<usize>()
242 + inner.f64_buffers.values().map(|v| v.len()).sum::<usize>())
243 }
244}
245
246impl TensorPoolInner {
247 fn return_raw_buffer<T>(&mut self, key: BufferKey, buffer: Vec<T>) {
248 self.stats.total_returns += 1;
249
250 match key.dtype.as_str() {
251 "f32" => {
252 let pool = self.f32_buffers.entry(key).or_default();
253 if pool.len() < self.max_buffers_per_key {
254 let buffer: Vec<f32> = unsafe { std::mem::transmute(buffer) };
256 pool.push(buffer);
257 } else {
258 self.stats.total_discards += 1;
259 }
260 }
261 "f64" => {
262 let pool = self.f64_buffers.entry(key).or_default();
263 if pool.len() < self.max_buffers_per_key {
264 let buffer: Vec<f64> = unsafe { std::mem::transmute(buffer) };
266 pool.push(buffer);
267 } else {
268 self.stats.total_discards += 1;
269 }
270 }
271 _ => {
272 self.stats.total_discards += 1;
274 }
275 }
276 }
277}
278
279impl Default for TensorPool {
280 fn default() -> Self {
281 Self::new()
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks out an `f32` buffer for a clone of `key`, panicking on failure.
    fn take_f32(pool: &TensorPool, key: &BufferKey) -> PooledBuffer<f32> {
        pool.acquire_f32(key.clone())
            .expect("Failed to acquire buffer")
    }

    /// Checks out an `f64` buffer for a clone of `key`, panicking on failure.
    fn take_f64(pool: &TensorPool, key: &BufferKey) -> PooledBuffer<f64> {
        pool.acquire_f64(key.clone())
            .expect("Failed to acquire buffer")
    }

    /// Reads the pool counters, panicking on failure.
    fn snapshot(pool: &TensorPool) -> PoolStats {
        pool.stats().expect("Failed to get stats")
    }

    /// Reads the number of idle buffers, panicking on failure.
    fn idle(pool: &TensorPool) -> usize {
        pool.pooled_count().expect("Failed to get count")
    }

    /// A dropped buffer is reused by the next acquisition for the same key.
    #[test]
    fn test_buffer_pool_basic() {
        let pool = TensorPool::new();
        let key = BufferKey::f32(1024);

        {
            let first = take_f32(&pool, &key);
            assert_eq!(first.len(), 1024);

            let before = snapshot(&pool);
            assert_eq!(before.total_allocations, 1);
            assert_eq!(before.total_reuses, 0);
            // `first` goes back to the pool at the end of this scope.
        }

        let second = take_f32(&pool, &key);
        let after = snapshot(&pool);
        assert_eq!(after.total_allocations, 1);
        assert_eq!(after.total_reuses, 1);
        assert_eq!(after.total_returns, 1);

        drop(second);
    }

    /// Keys that differ in size or element type are allocated independently.
    #[test]
    fn test_buffer_pool_multiple_keys() {
        let pool = TensorPool::new();

        let small = take_f32(&pool, &BufferKey::f32(512));
        let large = take_f32(&pool, &BufferKey::f32(1024));
        let wide = take_f64(&pool, &BufferKey::f64(512));

        assert_eq!(small.len(), 512);
        assert_eq!(large.len(), 1024);
        assert_eq!(wide.len(), 512);

        drop(small);
        drop(large);
        drop(wide);

        let counters = snapshot(&pool);
        assert_eq!(counters.total_allocations, 3);
        assert_eq!(counters.total_returns, 3);
    }

    /// Returns beyond the per-key limit are discarded.
    #[test]
    fn test_buffer_pool_capacity_limit() {
        let pool = TensorPool::with_capacity(2);
        let key = BufferKey::f32(100);

        // Hold three buffers at once, then release them in acquisition order.
        let held: Vec<PooledBuffer<f32>> = (0..3).map(|_| take_f32(&pool, &key)).collect();
        drop(held);

        let counters = snapshot(&pool);
        assert_eq!(counters.total_allocations, 3);
        assert_eq!(counters.total_reuses, 0);
        assert_eq!(counters.total_returns, 3);
        assert_eq!(counters.total_discards, 1);
        assert_eq!(idle(&pool), 2);
    }

    /// Tags make otherwise identical keys distinct.
    #[test]
    fn test_buffer_pool_tagged_keys() {
        let pool = TensorPool::new();
        let keys = vec![
            BufferKey::f32(1024).with_tag("state"),
            BufferKey::f32(1024).with_tag("output"),
            BufferKey::f32(1024),
        ];

        let held: Vec<PooledBuffer<f32>> = keys.iter().map(|key| take_f32(&pool, key)).collect();
        for buf in &held {
            assert_eq!(buf.len(), 1024);
        }
        drop(held);

        let counters = snapshot(&pool);
        assert_eq!(counters.total_allocations, 3);
        assert_eq!(idle(&pool), 3);
    }

    /// A reused buffer comes back zeroed.
    #[test]
    fn test_buffer_clear() {
        let pool = TensorPool::new();
        let key = BufferKey::f32(100);

        {
            let mut dirty = take_f32(&pool, &key);
            dirty[0] = 42.0;
        }

        let reused = take_f32(&pool, &key);
        assert_eq!(reused[0], 0.0);
    }

    /// `into_vec` detaches the storage, so nothing is returned to the pool.
    #[test]
    fn test_pooled_buffer_into_vec() {
        let pool = TensorPool::new();
        let key = BufferKey::f32(100);

        let mut buf = take_f32(&pool, &key);
        buf[0] = 42.0;

        let detached = buf.into_vec();
        assert_eq!(detached[0], 42.0);
        assert_eq!(detached.len(), 100);

        assert_eq!(snapshot(&pool).total_returns, 0);
    }

    /// Four threads acquire and release through the same pool.
    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(TensorPool::new());

        let mut workers = Vec::with_capacity(4);
        for i in 0..4 {
            let pool = Arc::clone(&pool);
            workers.push(thread::spawn(move || {
                for _ in 0..100 {
                    let key = BufferKey::f32(1024).with_tag(format!("thread_{}", i));
                    let buf = pool.acquire_f32(key).expect("Failed to acquire buffer");
                    assert_eq!(buf.len(), 1024);
                    drop(buf);
                }
            }));
        }

        for worker in workers {
            worker.join().expect("Thread panicked");
        }

        let counters = snapshot(&pool);
        assert!(counters.total_allocations > 0);
        assert!(counters.total_reuses > 0);
    }
}