1use scirs2_core::ndarray::Array1;
7use std::cell::RefCell;
8use std::collections::VecDeque;
9
10#[derive(Debug)]
12pub struct ArrayPool {
13 array_size: usize,
15 max_capacity: usize,
17 pool: RefCell<VecDeque<Array1<f32>>>,
19 stats: RefCell<PoolStats>,
21}
22
/// Counters describing how effectively a pool is being reused.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PoolStats {
    /// Acquisitions served from an idle pooled array.
    pub hits: u64,
    /// Acquisitions that had to allocate because the pool was empty.
    pub misses: u64,
    /// Releases whose array was kept in the pool for reuse.
    pub returns: u64,
    /// Releases whose array was discarded because the pool was already full.
    pub drops: u64,
}

impl PoolStats {
    /// Percentage (0.0 ..= 100.0) of acquisitions that were served from the pool.
    ///
    /// Returns `0.0` when nothing has been acquired yet.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding so very large counters cannot overflow `u64`
        // (which would panic in debug builds). Converting the same integer
        // value to `f64` rounds identically regardless of source width.
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64 * 100.0
        }
    }
}
47
48impl ArrayPool {
49 pub fn new(array_size: usize, max_capacity: usize) -> Self {
51 Self {
52 array_size,
53 max_capacity,
54 pool: RefCell::new(VecDeque::with_capacity(max_capacity)),
55 stats: RefCell::new(PoolStats::default()),
56 }
57 }
58
59 pub fn array_size(&self) -> usize {
61 self.array_size
62 }
63
64 pub fn size(&self) -> usize {
66 self.pool.borrow().len()
67 }
68
69 pub fn stats(&self) -> PoolStats {
71 *self.stats.borrow()
72 }
73
74 pub fn acquire(&self) -> Array1<f32> {
76 let mut pool = self.pool.borrow_mut();
77 let mut stats = self.stats.borrow_mut();
78
79 if let Some(array) = pool.pop_front() {
80 stats.hits += 1;
81 array
82 } else {
83 stats.misses += 1;
84 Array1::zeros(self.array_size)
85 }
86 }
87
88 pub fn acquire_filled(&self, value: f32) -> Array1<f32> {
90 let mut array = self.acquire();
91 array.fill(value);
92 array
93 }
94
95 pub fn acquire_zeros(&self) -> Array1<f32> {
97 let mut array = self.acquire();
98 array.fill(0.0);
99 array
100 }
101
102 pub fn release(&self, array: Array1<f32>) {
104 if array.len() != self.array_size {
106 return;
107 }
108
109 let mut pool = self.pool.borrow_mut();
110 let mut stats = self.stats.borrow_mut();
111
112 if pool.len() < self.max_capacity {
113 pool.push_back(array);
114 stats.returns += 1;
115 } else {
116 stats.drops += 1;
117 }
119 }
120
121 pub fn clear(&self) {
123 self.pool.borrow_mut().clear();
124 }
125
126 pub fn warm(&self) {
128 let mut pool = self.pool.borrow_mut();
129 while pool.len() < self.max_capacity {
130 pool.push_back(Array1::zeros(self.array_size));
131 }
132 }
133}
134
135pub struct PooledArray<'a> {
137 array: Option<Array1<f32>>,
138 pool: &'a ArrayPool,
139}
140
141impl<'a> PooledArray<'a> {
142 pub fn new(pool: &'a ArrayPool) -> Self {
144 Self {
145 array: Some(pool.acquire()),
146 pool,
147 }
148 }
149
150 pub fn zeros(pool: &'a ArrayPool) -> Self {
152 Self {
153 array: Some(pool.acquire_zeros()),
154 pool,
155 }
156 }
157
158 pub fn as_array(&self) -> &Array1<f32> {
160 self.array.as_ref().unwrap()
161 }
162
163 pub fn as_array_mut(&mut self) -> &mut Array1<f32> {
165 self.array.as_mut().unwrap()
166 }
167
168 pub fn take(mut self) -> Array1<f32> {
170 self.array.take().unwrap()
171 }
172}
173
174impl Drop for PooledArray<'_> {
175 fn drop(&mut self) {
176 if let Some(array) = self.array.take() {
177 self.pool.release(array);
178 }
179 }
180}
181
182impl std::ops::Deref for PooledArray<'_> {
183 type Target = Array1<f32>;
184
185 fn deref(&self) -> &Self::Target {
186 self.as_array()
187 }
188}
189
190impl std::ops::DerefMut for PooledArray<'_> {
191 fn deref_mut(&mut self) -> &mut Self::Target {
192 self.as_array_mut()
193 }
194}
195
196#[derive(Debug)]
198pub struct MultiArrayPool {
199 pools: Vec<ArrayPool>,
201 sizes: Vec<usize>,
203}
204
205impl MultiArrayPool {
206 pub fn new() -> Self {
208 Self::with_sizes(&[32, 64, 128, 256, 512, 1024, 2048, 4096], 8)
210 }
211
212 pub fn with_sizes(sizes: &[usize], capacity_per_size: usize) -> Self {
214 let mut sorted_sizes: Vec<usize> = sizes.to_vec();
215 sorted_sizes.sort_unstable();
216
217 let pools = sorted_sizes
218 .iter()
219 .map(|&size| ArrayPool::new(size, capacity_per_size))
220 .collect();
221
222 Self {
223 pools,
224 sizes: sorted_sizes,
225 }
226 }
227
228 pub fn acquire(&self, min_size: usize) -> Array1<f32> {
230 if let Some(idx) = self.sizes.iter().position(|&s| s >= min_size) {
232 self.pools[idx].acquire()
233 } else {
234 Array1::zeros(min_size)
236 }
237 }
238
239 pub fn acquire_zeros(&self, min_size: usize) -> Array1<f32> {
241 let mut arr = self.acquire(min_size);
242 arr.fill(0.0);
243 arr
244 }
245
246 pub fn release(&self, array: Array1<f32>) {
248 let size = array.len();
249 if let Some(idx) = self.sizes.iter().position(|&s| s == size) {
251 self.pools[idx].release(array);
252 }
253 }
255
256 pub fn stats(&self) -> PoolStats {
258 let mut total = PoolStats::default();
259 for pool in &self.pools {
260 let s = pool.stats();
261 total.hits += s.hits;
262 total.misses += s.misses;
263 total.returns += s.returns;
264 total.drops += s.drops;
265 }
266 total
267 }
268
269 pub fn warm(&self) {
271 for pool in &self.pools {
272 pool.warm();
273 }
274 }
275
276 pub fn clear(&self) {
278 for pool in &self.pools {
279 pool.clear();
280 }
281 }
282}
283
284impl Default for MultiArrayPool {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290thread_local! {
292 static LOCAL_POOL: MultiArrayPool = MultiArrayPool::new();
293}
294
295pub fn tl_acquire(min_size: usize) -> Array1<f32> {
297 LOCAL_POOL.with(|pool| pool.acquire(min_size))
298}
299
300pub fn tl_acquire_zeros(min_size: usize) -> Array1<f32> {
302 LOCAL_POOL.with(|pool| pool.acquire_zeros(min_size))
303}
304
305pub fn tl_release(array: Array1<f32>) {
307 LOCAL_POOL.with(|pool| pool.release(array));
308}
309
310pub fn tl_stats() -> PoolStats {
312 LOCAL_POOL.with(|pool| pool.stats())
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_pool_basic() {
        let pool = ArrayPool::new(64, 4);

        // Empty pool: this is a miss and allocates.
        let arr1 = pool.acquire();
        assert_eq!(arr1.len(), 64);

        pool.release(arr1);
        assert_eq!(pool.size(), 1);

        // The released array is reused: this is a hit.
        let arr2 = pool.acquire();
        assert_eq!(arr2.len(), 64);
        assert_eq!(pool.size(), 0);

        let stats = pool.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_array_pool_capacity() {
        let pool = ArrayPool::new(32, 2);

        let a1 = pool.acquire();
        let a2 = pool.acquire();
        let a3 = pool.acquire();

        pool.release(a1);
        pool.release(a2);
        // Capacity is 2, so the third release is discarded.
        pool.release(a3);

        let stats = pool.stats();
        assert_eq!(stats.returns, 2);
        assert_eq!(stats.drops, 1);
    }

    #[test]
    fn test_release_wrong_size_is_ignored() {
        let pool = ArrayPool::new(32, 4);
        pool.release(Array1::zeros(16));

        assert_eq!(pool.size(), 0);
        let stats = pool.stats();
        assert_eq!(stats.returns, 0);
        assert_eq!(stats.drops, 0);
    }

    #[test]
    fn test_acquire_filled_and_zeros() {
        let pool = ArrayPool::new(8, 4);

        // Miss path: freshly allocated array.
        let filled = pool.acquire_filled(2.5);
        assert_eq!(filled.len(), 8);
        assert!(filled.iter().all(|&x| x == 2.5));

        // Hit path: the reused array must be overwritten with zeros.
        pool.release(filled);
        let zeros = pool.acquire_zeros();
        assert_eq!(zeros.len(), 8);
        assert!(zeros.iter().all(|&x| x == 0.0));

        let stats = pool.stats();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);
    }

    #[test]
    fn test_pooled_array_scope() {
        let pool = ArrayPool::new(32, 4);

        {
            let mut arr = PooledArray::zeros(&pool);
            arr[0] = 1.0;
            assert_eq!(arr.len(), 32);
        }
        // Dropping the guard returned the array to the pool.
        assert_eq!(pool.size(), 1);
    }

    #[test]
    fn test_pooled_array_take() {
        let pool = ArrayPool::new(32, 4);

        let owned = {
            let arr = PooledArray::zeros(&pool);
            arr.take()
        };

        assert_eq!(owned.len(), 32);
        // A taken array is not returned to the pool.
        assert_eq!(pool.size(), 0);
    }

    #[test]
    fn test_multi_pool() {
        let pool = MultiArrayPool::with_sizes(&[32, 64, 128], 4);

        // Served by the smallest bucket that fits.
        let arr = pool.acquire(50);
        assert_eq!(arr.len(), 64);

        pool.release(arr);
        assert_eq!(pool.stats().returns, 1);
    }

    #[test]
    fn test_multi_pool_unsorted_sizes() {
        let pool = MultiArrayPool::with_sizes(&[128, 32, 64], 4);
        assert_eq!(pool.acquire(1).len(), 32);
        assert_eq!(pool.acquire(64).len(), 64);
        assert_eq!(pool.acquire(65).len(), 128);
    }

    #[test]
    fn test_multi_pool_oversized_request() {
        let pool = MultiArrayPool::with_sizes(&[32, 64], 4);

        let arr = pool.acquire_zeros(100);
        assert_eq!(arr.len(), 100);
        assert!(arr.iter().all(|&x| x == 0.0));

        // No bucket matches length 100, so the array is not retained.
        pool.release(arr);
        assert_eq!(pool.stats().returns, 0);
    }

    #[test]
    fn test_pool_warm() {
        let pool = ArrayPool::new(64, 4);
        pool.warm();
        assert_eq!(pool.size(), 4);

        for _ in 0..4 {
            let _ = pool.acquire();
        }
        let stats = pool.stats();
        assert_eq!(stats.hits, 4);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_hit_rate() {
        let stats = PoolStats {
            hits: 80,
            misses: 20,
            returns: 0,
            drops: 0,
        };
        assert!((stats.hit_rate() - 80.0).abs() < 0.01);
        assert_eq!(PoolStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_thread_local_pool() {
        let arr = tl_acquire_zeros(100);
        assert!(arr.len() >= 100);

        tl_release(arr);

        let stats = tl_stats();
        assert!(stats.misses >= 1);
    }
}