chess_vector_engine/utils/
object_pool.rs1use ndarray::Array1;
3use std::cell::RefCell;
4use std::collections::VecDeque;
5use std::sync::{Arc, Mutex};
6
/// A pool of reusable objects of type `T` that can be shared across threads
/// (it is `Sync` whenever `T: Send`).
///
/// Objects are handed out wrapped in a [`PooledObject`], which puts the object
/// back into the pool when dropped, unless the pool already holds `max_size`
/// idle objects, in which case the object is simply dropped. Objects are
/// returned as-is: the pool never resets their contents.
pub struct ObjectPool<T> {
    /// Idle objects awaiting reuse; shared with every outstanding
    /// [`PooledObject`] so each can return itself on drop.
    pool: Arc<Mutex<VecDeque<T>>>,
    /// Creates a fresh object whenever no idle one is available.
    factory: Arc<dyn Fn() -> T + Send + Sync>,
    /// Maximum number of idle objects retained.
    max_size: usize,
}

/// Locks `mutex`, recovering the guard if a previous holder panicked.
///
/// The queue only ever holds complete, owned objects, so a poisoned lock does
/// not expose a broken invariant. Recovering keeps the pool usable and keeps
/// `Drop` from panicking, which would abort the process during unwinding.
fn lock_recovering<U>(mutex: &Mutex<U>) -> std::sync::MutexGuard<'_, U> {
    mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}

impl<T> ObjectPool<T> {
    /// Creates an empty pool that builds new objects with `factory` and keeps
    /// at most `max_size` idle objects.
    pub fn new<F>(factory: F, max_size: usize) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        Self {
            pool: Arc::new(Mutex::new(VecDeque::new())),
            factory: Arc::new(factory),
            max_size,
        }
    }

    /// Takes an idle object from the pool, or creates one with the factory if
    /// the pool is empty.
    pub fn get(&self) -> PooledObject<T> {
        // Pop in its own statement so the lock is released before the factory
        // runs. A slow factory then doesn't block other threads, and a
        // panicking factory can no longer poison the pool.
        let idle = lock_recovering(&self.pool).pop_front();
        let object = idle.unwrap_or_else(|| (self.factory)());

        PooledObject {
            object: Some(object),
            pool: Arc::clone(&self.pool),
            max_size: self.max_size,
        }
    }

    /// Number of idle objects currently held by the pool.
    pub fn size(&self) -> usize {
        lock_recovering(&self.pool).len()
    }

    /// Drops every idle object held by the pool.
    pub fn clear(&self) {
        // Move the objects out first so their destructors run without the lock held.
        let idle = std::mem::take(&mut *lock_recovering(&self.pool));
        drop(idle);
    }
}

/// An object checked out of an [`ObjectPool`]; derefs to `T` and returns the
/// object to its pool when dropped.
pub struct PooledObject<T> {
    /// `Some` for the whole life of the wrapper; emptied only by `drop`.
    object: Option<T>,
    /// The idle queue of the pool this object came from.
    pool: Arc<Mutex<VecDeque<T>>>,
    /// Copy of the owning pool's `max_size`.
    max_size: usize,
}

impl<T> PooledObject<T> {
    /// Borrows the pooled object.
    pub fn get(&self) -> &T {
        self.object
            .as_ref()
            .expect("PooledObject holds its object until dropped")
    }

    /// Mutably borrows the pooled object.
    pub fn get_mut(&mut self) -> &mut T {
        self.object
            .as_mut()
            .expect("PooledObject holds its object until dropped")
    }
}

impl<T> Drop for PooledObject<T> {
    fn drop(&mut self) {
        if let Some(object) = self.object.take() {
            let mut idle = lock_recovering(&self.pool);
            if idle.len() < self.max_size {
                idle.push_back(object);
            } else {
                // Pool is full: release the lock before running `T`'s destructor.
                drop(idle);
                drop(object);
            }
        }
    }
}

impl<T> std::ops::Deref for PooledObject<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<T> std::ops::DerefMut for PooledObject<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.get_mut()
    }
}
95
/// A single-threaded object pool backed by a `RefCell` instead of a mutex.
///
/// Because of the `RefCell` it is not `Sync`. Handed-out objects borrow the
/// pool and go back into it on drop, up to `max_size` idle objects. Objects
/// are returned as-is; the pool never resets them.
pub struct ThreadLocalPool<T> {
    /// Idle objects awaiting reuse.
    pool: RefCell<VecDeque<T>>,
    /// Creates a fresh object whenever no idle one is available.
    factory: Box<dyn Fn() -> T>,
    /// Maximum number of idle objects retained.
    max_size: usize,
}

impl<T> ThreadLocalPool<T> {
    /// Creates an empty pool that builds new objects with `factory` and keeps
    /// at most `max_size` idle objects.
    pub fn new<F>(factory: F, max_size: usize) -> Self
    where
        F: Fn() -> T + 'static,
    {
        Self {
            pool: RefCell::new(VecDeque::new()),
            factory: Box::new(factory),
            max_size,
        }
    }

    /// Takes an idle object, or creates one with the factory if none is idle.
    pub fn get(&self) -> ThreadLocalPooledObject<'_, T> {
        // End the mutable borrow before the factory runs. Otherwise a factory
        // that touches this same pool (e.g. through a `thread_local!`) panics
        // with "already mutably borrowed".
        let idle = self.pool.borrow_mut().pop_front();
        let object = idle.unwrap_or_else(|| (self.factory)());

        ThreadLocalPooledObject {
            object: Some(object),
            pool: &self.pool,
            max_size: self.max_size,
        }
    }

    /// Number of idle objects currently held by the pool.
    pub fn size(&self) -> usize {
        self.pool.borrow().len()
    }

    /// Drops every idle object held by the pool.
    pub fn clear(&self) {
        // Move the objects out first so their destructors run after the borrow ends.
        let idle = std::mem::take(&mut *self.pool.borrow_mut());
        drop(idle);
    }
}

/// An object checked out of a [`ThreadLocalPool`]; derefs to `T` and returns
/// the object to its pool when dropped.
pub struct ThreadLocalPooledObject<'a, T> {
    /// `Some` for the whole life of the wrapper; emptied only by `drop`.
    object: Option<T>,
    /// Idle queue of the pool this object came from.
    pool: &'a RefCell<VecDeque<T>>,
    /// Copy of the owning pool's `max_size`.
    max_size: usize,
}

impl<'a, T> ThreadLocalPooledObject<'a, T> {
    /// Borrows the pooled object.
    pub fn get(&self) -> &T {
        self.object
            .as_ref()
            .expect("ThreadLocalPooledObject holds its object until dropped")
    }

    /// Mutably borrows the pooled object.
    pub fn get_mut(&mut self) -> &mut T {
        self.object
            .as_mut()
            .expect("ThreadLocalPooledObject holds its object until dropped")
    }
}

impl<'a, T> Drop for ThreadLocalPooledObject<'a, T> {
    fn drop(&mut self) {
        if let Some(object) = self.object.take() {
            // `try_borrow_mut`: returning to the pool is best-effort, and a
            // destructor must never panic. If the pool is full (or somehow
            // borrowed right now), the object is simply dropped.
            match self.pool.try_borrow_mut() {
                Ok(mut idle) if idle.len() < self.max_size => idle.push_back(object),
                _ => {}
            }
        }
    }
}

impl<'a, T> std::ops::Deref for ThreadLocalPooledObject<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<'a, T> std::ops::DerefMut for ThreadLocalPooledObject<'a, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.get_mut()
    }
}
184
185pub struct VectorPool {
187 pool: ThreadLocalPool<Array1<f32>>,
188 vector_size: usize,
189}
190
191impl VectorPool {
192 pub fn new(vector_size: usize, max_size: usize) -> Self {
194 let pool = ThreadLocalPool::new(move || Array1::zeros(vector_size), max_size);
195
196 Self { pool, vector_size }
197 }
198
199 pub fn get_zeroed(&self) -> ThreadLocalPooledObject<Array1<f32>> {
201 let mut vec = self.pool.get();
202 vec.fill(0.0);
203 vec
204 }
205
206 pub fn get(&self) -> ThreadLocalPooledObject<Array1<f32>> {
208 self.pool.get()
209 }
210
211 pub fn vector_size(&self) -> usize {
213 self.vector_size
214 }
215
216 pub fn size(&self) -> usize {
218 self.pool.size()
219 }
220
221 pub fn clear(&self) {
223 self.pool.clear();
224 }
225}
226
227pub struct VectorPoolManager {
229 pools: std::collections::HashMap<usize, VectorPool>,
230 max_pool_size: usize,
231}
232
233impl VectorPoolManager {
234 pub fn new(max_pool_size: usize) -> Self {
236 Self {
237 pools: std::collections::HashMap::new(),
238 max_pool_size,
239 }
240 }
241
242 pub fn get_pool(&mut self, vector_size: usize) -> &VectorPool {
244 self.pools
245 .entry(vector_size)
246 .or_insert_with(|| VectorPool::new(vector_size, self.max_pool_size))
247 }
248
249 pub fn clear_all(&mut self) {
251 for pool in self.pools.values() {
252 pool.clear();
253 }
254 }
255}
256
257thread_local! {
259 static VECTOR_POOL_MANAGER: RefCell<VectorPoolManager> = RefCell::new(VectorPoolManager::new(16));
260}
261
262thread_local! {
264 static VECTOR_POOL_1024: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
265 static VECTOR_POOL_512: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
266 static VECTOR_POOL_256: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
267 static VECTOR_POOL_128: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
268 static VECTOR_POOL_64: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
269}
270
271pub fn get_vector(size: usize) -> Array1<f32> {
273 match size {
274 1024 => get_vector_from_pool(&VECTOR_POOL_1024, size),
275 512 => get_vector_from_pool(&VECTOR_POOL_512, size),
276 256 => get_vector_from_pool(&VECTOR_POOL_256, size),
277 128 => get_vector_from_pool(&VECTOR_POOL_128, size),
278 64 => get_vector_from_pool(&VECTOR_POOL_64, size),
279 _ => Array1::zeros(size), }
281}
282
283pub fn get_zeroed_vector(size: usize) -> Array1<f32> {
285 let mut vec = get_vector(size);
286 vec.fill(0.0);
287 vec
288}
289
290fn get_vector_from_pool(
292 pool: &'static std::thread::LocalKey<std::cell::RefCell<VecDeque<Array1<f32>>>>,
293 size: usize,
294) -> Array1<f32> {
295 pool.with(|pool_ref| {
296 let mut pool = pool_ref.borrow_mut();
297 pool.pop_front().unwrap_or_else(|| Array1::zeros(size))
298 })
299}
300
301pub fn return_vector(mut vec: Array1<f32>) {
303 let size = vec.len();
304
305 let pool = match size {
307 1024 => Some(&VECTOR_POOL_1024),
308 512 => Some(&VECTOR_POOL_512),
309 256 => Some(&VECTOR_POOL_256),
310 128 => Some(&VECTOR_POOL_128),
311 64 => Some(&VECTOR_POOL_64),
312 _ => None,
313 };
314
315 if let Some(pool) = pool {
316 vec.fill(0.0);
318
319 pool.with(|pool_ref| {
320 let mut pool = pool_ref.borrow_mut();
321
322 if pool.len() < 10 {
324 pool.push_back(vec);
325 }
326 });
328 }
329 }
331
332pub struct PooledVector {
334 vec: Option<Array1<f32>>,
335}
336
337impl PooledVector {
338 pub fn new(size: usize) -> Self {
340 Self {
341 vec: Some(get_vector(size)),
342 }
343 }
344
345 pub fn zeroed(size: usize) -> Self {
347 Self {
348 vec: Some(get_zeroed_vector(size)),
349 }
350 }
351
352 pub fn as_ref(&self) -> &Array1<f32> {
354 self.vec.as_ref().expect("Vector should always be present")
355 }
356
357 pub fn as_mut(&mut self) -> &mut Array1<f32> {
359 self.vec.as_mut().expect("Vector should always be present")
360 }
361
362 pub fn take(mut self) -> Array1<f32> {
364 self.vec.take().expect("Vector should always be present")
365 }
366}
367
368impl Drop for PooledVector {
369 fn drop(&mut self) {
370 if let Some(vec) = self.vec.take() {
371 return_vector(vec);
372 }
373 }
374}
375
376impl std::ops::Deref for PooledVector {
377 type Target = Array1<f32>;
378
379 fn deref(&self) -> &Self::Target {
380 self.as_ref()
381 }
382}
383
384impl std::ops::DerefMut for PooledVector {
385 fn deref_mut(&mut self) -> &mut Self::Target {
386 self.as_mut()
387 }
388}
389
390pub fn clear_vector_pools() {
392 VECTOR_POOL_1024.with(|pool| pool.borrow_mut().clear());
393 VECTOR_POOL_512.with(|pool| pool.borrow_mut().clear());
394 VECTOR_POOL_256.with(|pool| pool.borrow_mut().clear());
395 VECTOR_POOL_128.with(|pool| pool.borrow_mut().clear());
396 VECTOR_POOL_64.with(|pool| pool.borrow_mut().clear());
397}
398
399pub fn get_vector_pool_stats() -> std::collections::HashMap<usize, usize> {
401 let mut stats = std::collections::HashMap::new();
402
403 VECTOR_POOL_1024.with(|pool| {
404 stats.insert(1024, pool.borrow().len());
405 });
406 VECTOR_POOL_512.with(|pool| {
407 stats.insert(512, pool.borrow().len());
408 });
409 VECTOR_POOL_256.with(|pool| {
410 stats.insert(256, pool.borrow().len());
411 });
412 VECTOR_POOL_128.with(|pool| {
413 stats.insert(128, pool.borrow().len());
414 });
415 VECTOR_POOL_64.with(|pool| {
416 stats.insert(64, pool.borrow().len());
417 });
418
419 stats
420}
421
422pub type MovePool = ObjectPool<Vec<chess::ChessMove>>;
424
425pub fn create_move_pool(max_size: usize) -> MovePool {
427 ObjectPool::new(Vec::new, max_size)
428}
429
430pub type HashMapPool<K, V> = ObjectPool<std::collections::HashMap<K, V>>;
432
433pub fn create_hashmap_pool<K, V>(max_size: usize) -> HashMapPool<K, V>
435where
436 K: std::hash::Hash + Eq + 'static,
437 V: 'static,
438{
439 ObjectPool::new(std::collections::HashMap::new, max_size)
440}
441
/// Types whose contents can be cleared in place so the value can be reused.
pub trait Resettable {
    /// Returns `self` to an empty state without replacing the value itself.
    fn reset(&mut self);
}

/// Removes every element; `Vec::clear` leaves the allocated capacity untouched.
impl<T> Resettable for Vec<T> {
    fn reset(&mut self) {
        Vec::clear(self)
    }
}

/// Removes every entry; `HashMap::clear` keeps the allocated memory for reuse.
impl<K, V> Resettable for std::collections::HashMap<K, V>
where
    K: std::hash::Hash + Eq,
{
    fn reset(&mut self) {
        std::collections::HashMap::clear(self)
    }
}
462
463impl Resettable for Array1<f32> {
464 fn reset(&mut self) {
465 self.fill(0.0);
466 }
467}
468
469pub struct ResettablePool<T: Resettable> {
471 pool: ObjectPool<T>,
472}
473
474impl<T: Resettable> ResettablePool<T> {
475 pub fn new<F>(factory: F, max_size: usize) -> Self
477 where
478 F: Fn() -> T + Send + Sync + 'static,
479 {
480 Self {
481 pool: ObjectPool::new(factory, max_size),
482 }
483 }
484
485 pub fn get_reset(&self) -> PooledObject<T> {
487 let mut obj = self.pool.get();
488 obj.reset();
489 obj
490 }
491
492 pub fn get(&self) -> PooledObject<T> {
494 self.pool.get()
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_object_pool() {
        let pool = ObjectPool::new(|| Vec::<i32>::new(), 10);

        {
            let mut obj1 = pool.get();
            obj1.push(1);
            obj1.push(2);
            assert_eq!(pool.size(), 0);
        }

        assert_eq!(pool.size(), 1);

        {
            // ObjectPool does not clear objects, so the reused Vec keeps its contents.
            let obj2 = pool.get();
            assert_eq!(obj2.len(), 2);
        }
    }

    #[test]
    fn test_thread_local_pool() {
        let pool = ThreadLocalPool::new(|| Vec::<i32>::new(), 5);

        {
            let mut obj = pool.get();
            obj.push(42);
            assert_eq!(pool.size(), 0);
        }

        assert_eq!(pool.size(), 1);

        {
            let obj = pool.get();
            assert_eq!(obj.len(), 1);
            assert_eq!(obj[0], 42);
        }
    }

    #[test]
    fn test_vector_pool() {
        let pool = VectorPool::new(100, 5);

        {
            let mut vec = pool.get_zeroed();
            vec[0] = 1.0;
            vec[1] = 2.0;
            assert_eq!(pool.size(), 0);
        }

        assert_eq!(pool.size(), 1);

        {
            // get_zeroed must wipe the values written above.
            let vec = pool.get_zeroed();
            assert_eq!(vec[0], 0.0);
            assert_eq!(vec[1], 0.0);
        }
    }

    #[test]
    fn test_resettable_pool() {
        let pool = ResettablePool::new(|| Vec::<i32>::new(), 3);

        {
            let mut obj = pool.get_reset();
            obj.push(1);
            obj.push(2);
        }

        {
            let obj = pool.get_reset();
            assert_eq!(obj.len(), 0);
        }
    }

    #[test]
    fn test_pool_max_size() {
        let pool = ObjectPool::new(|| Vec::<i32>::new(), 2);

        {
            // Three objects are out at once; only two fit back into the pool.
            let _obj1 = pool.get();
            let _obj2 = pool.get();
            let _obj3 = pool.get();
        }

        assert_eq!(pool.size(), 2);
    }

    #[test]
    fn test_global_vector_pool() {
        let vec1 = get_zeroed_vector(1024);
        assert_eq!(vec1.len(), 1024);

        let vec2 = get_vector(512);
        assert_eq!(vec2.len(), 512);
    }

    #[test]
    fn test_thread_local_vector_pooling() {
        clear_vector_pools();

        let vec1024 = get_vector(1024);
        let vec512 = get_vector(512);
        let vec256 = get_vector(256);

        assert_eq!(vec1024.len(), 1024);
        assert_eq!(vec512.len(), 512);
        assert_eq!(vec256.len(), 256);

        return_vector(vec1024);
        return_vector(vec512);
        return_vector(vec256);

        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&1024), Some(&1));
        assert_eq!(stats.get(&512), Some(&1));
        assert_eq!(stats.get(&256), Some(&1));

        let vec1024_reused = get_vector(1024);
        let vec512_reused = get_vector(512);

        assert_eq!(vec1024_reused.len(), 1024);
        assert_eq!(vec512_reused.len(), 512);

        let stats_after = get_vector_pool_stats();
        assert_eq!(stats_after.get(&1024), Some(&0));
        assert_eq!(stats_after.get(&512), Some(&0));
        assert_eq!(stats_after.get(&256), Some(&1));
    }

    #[test]
    fn test_pooled_vector_raii() {
        clear_vector_pools();

        {
            let mut pooled = PooledVector::new(1024);
            assert_eq!(pooled.len(), 1024);

            pooled[0] = 42.0;
            assert_eq!(pooled[0], 42.0);
        }

        // Dropping the PooledVector returned it to the pool, zeroed.
        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&1024), Some(&1));

        let vec = get_vector(1024);
        assert_eq!(vec[0], 0.0);
    }

    #[test]
    fn test_pooled_vector_take() {
        clear_vector_pools();

        let pooled = PooledVector::new(512);
        let vec = pooled.take();
        assert_eq!(vec.len(), 512);

        // `take` detaches the vector, so nothing went back to the pool.
        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&512), Some(&0));
    }

    #[test]
    fn test_pool_size_limit() {
        clear_vector_pools();

        // Hold 15 vectors at the same time so that returning them actually
        // overflows the per-length cap; taking and returning one vector at a
        // time would never put more than one vector in the pool.
        let held: Vec<_> = (0..15).map(|_| get_vector(128)).collect();
        for vec in held {
            return_vector(vec);
        }

        let stats = get_vector_pool_stats();
        assert_eq!(
            stats.get(&128),
            Some(&10),
            "Pool size should be limited to 10"
        );

        let vec = get_vector(128);
        assert_eq!(vec.len(), 128);
    }

    #[test]
    fn test_non_standard_size_vectors() {
        let vec = get_vector(100);
        assert_eq!(vec.len(), 100);

        return_vector(vec);

        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&100), None);
    }

    #[test]
    fn test_zeroed_vector_function() {
        let vec = get_zeroed_vector(256);
        assert_eq!(vec.len(), 256);

        for &value in vec.iter() {
            assert_eq!(value, 0.0);
        }
    }
}