chess_vector_engine/utils/
object_pool.rs

1// Removed unused imports
2use ndarray::Array1;
3use std::cell::RefCell;
4use std::collections::VecDeque;
5use std::sync::{Arc, Mutex};
6
/// Thread-safe object pool for reusing expensive-to-create objects.
///
/// [`ObjectPool::get`] pops an idle object (or builds one with the factory) and
/// wraps it in a [`PooledObject`] guard; dropping the guard pushes the object back
/// while fewer than `max_size` objects are idle, otherwise the object is dropped.
///
/// Objects are returned **as-is**: the next caller sees whatever state the
/// previous user left behind. Use [`ResettablePool`] when a clean object is needed.
pub struct ObjectPool<T> {
    /// Idle objects, shared with every outstanding [`PooledObject`].
    pool: Arc<Mutex<VecDeque<T>>>,
    /// Builds a new object when no idle one is available.
    factory: Arc<dyn Fn() -> T + Send + Sync>,
    /// Maximum number of idle objects retained.
    max_size: usize,
}

/// Lock an idle-object queue, recovering the guard if the mutex is poisoned.
///
/// Poisoning only records that some thread panicked while holding the lock
/// (e.g. inside a `T` destructor run by [`ObjectPool::clear`]). The `VecDeque`
/// itself is still a valid queue, so carrying on is safe. Panicking instead
/// would turn one failure into a cascade, and a panic inside
/// `PooledObject::drop` during unwinding would abort the process.
fn lock_idle<T>(pool: &Mutex<VecDeque<T>>) -> std::sync::MutexGuard<'_, VecDeque<T>> {
    pool.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
}

impl<T> ObjectPool<T> {
    /// Create a new object pool.
    ///
    /// `factory` builds a fresh object whenever the pool is empty; at most
    /// `max_size` idle objects are kept (`0` disables reuse entirely).
    pub fn new<F>(factory: F, max_size: usize) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        Self {
            pool: Arc::new(Mutex::new(VecDeque::new())),
            factory: Arc::new(factory),
            max_size,
        }
    }

    /// Get an object from the pool, creating one with the factory if none is idle.
    ///
    /// The lock is released before the factory runs, so a slow factory does not
    /// block other threads and a panicking factory cannot poison the pool.
    pub fn get(&self) -> PooledObject<T> {
        // The guard is a temporary of this statement, so it is dropped here,
        // before the factory call below.
        let recycled = lock_idle(&self.pool).pop_front();
        let object = recycled.unwrap_or_else(|| (self.factory)());

        PooledObject {
            object: Some(object),
            pool: Arc::clone(&self.pool),
            max_size: self.max_size,
        }
    }

    /// Number of idle objects currently held by the pool.
    pub fn size(&self) -> usize {
        lock_idle(&self.pool).len()
    }

    /// Drop all idle objects. Objects currently checked out are unaffected and
    /// may still be returned when their guards are dropped.
    pub fn clear(&self) {
        lock_idle(&self.pool).clear();
    }
}

/// Guard around an object borrowed from an [`ObjectPool`].
///
/// Dereferences to the object. When dropped, the object is pushed back into the
/// pool if fewer than `max_size` objects are idle; otherwise it is dropped.
pub struct PooledObject<T> {
    /// Always `Some` until `drop` moves the object out.
    object: Option<T>,
    pool: Arc<Mutex<VecDeque<T>>>,
    max_size: usize,
}

impl<T> PooledObject<T> {
    /// Get a reference to the pooled object.
    pub fn get(&self) -> &T {
        self.object
            .as_ref()
            .expect("PooledObject holds its value until dropped")
    }

    /// Get a mutable reference to the pooled object.
    pub fn get_mut(&mut self) -> &mut T {
        self.object
            .as_mut()
            .expect("PooledObject holds its value until dropped")
    }
}

impl<T> Drop for PooledObject<T> {
    fn drop(&mut self) {
        if let Some(object) = self.object.take() {
            let mut idle = lock_idle(&self.pool);
            if idle.len() < self.max_size {
                idle.push_back(object);
            } else {
                // Pool is full: release the lock before running `T`'s destructor.
                drop(idle);
                drop(object);
            }
        }
    }
}

impl<T> std::ops::Deref for PooledObject<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<T> std::ops::DerefMut for PooledObject<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.get_mut()
    }
}
95
/// Single-threaded object pool for reusing expensive-to-create objects.
///
/// Same contract as [`ObjectPool`]: objects are returned as-is and at most
/// `max_size` are kept idle. It uses a `RefCell` instead of a mutex. The boxed
/// factory has no `Send`/`Sync` bound, so the pool is `!Send` and `!Sync` and
/// stays on the thread that created it (e.g. inside a `thread_local!`).
pub struct ThreadLocalPool<T> {
    /// Idle objects waiting to be reused.
    pool: RefCell<VecDeque<T>>,
    /// Builds a new object when no idle one is available.
    factory: Box<dyn Fn() -> T>,
    /// Maximum number of idle objects retained.
    max_size: usize,
}

impl<T> ThreadLocalPool<T> {
    /// Create a new thread-local pool.
    ///
    /// `factory` builds a fresh object whenever the pool is empty; at most
    /// `max_size` idle objects are kept (`0` disables reuse entirely).
    pub fn new<F>(factory: F, max_size: usize) -> Self
    where
        F: Fn() -> T + 'static,
    {
        Self {
            pool: RefCell::new(VecDeque::new()),
            factory: Box::new(factory),
            max_size,
        }
    }

    /// Get an object from the pool, creating one with the factory if none is idle.
    pub fn get(&self) -> ThreadLocalPooledObject<'_, T> {
        // The `RefMut` is a temporary of this statement and is released before
        // the factory runs. A factory that inspects this pool (e.g. via `size`)
        // therefore does not trigger a `RefCell` borrow panic.
        let recycled = self.pool.borrow_mut().pop_front();
        let object = recycled.unwrap_or_else(|| (self.factory)());

        ThreadLocalPooledObject {
            object: Some(object),
            pool: &self.pool,
            max_size: self.max_size,
        }
    }

    /// Number of idle objects currently held by the pool.
    pub fn size(&self) -> usize {
        self.pool.borrow().len()
    }

    /// Drop all idle objects; checked-out objects are unaffected.
    pub fn clear(&self) {
        self.pool.borrow_mut().clear();
    }
}

/// Guard around an object borrowed from a [`ThreadLocalPool`]; it cannot
/// outlive the pool it came from (`'a`).
///
/// When dropped, the object is pushed back if fewer than `max_size` objects are
/// idle; otherwise it is dropped.
pub struct ThreadLocalPooledObject<'a, T> {
    /// Always `Some` until `drop` moves the object out.
    object: Option<T>,
    pool: &'a RefCell<VecDeque<T>>,
    max_size: usize,
}

impl<T> ThreadLocalPooledObject<'_, T> {
    /// Get a reference to the pooled object.
    pub fn get(&self) -> &T {
        self.object
            .as_ref()
            .expect("ThreadLocalPooledObject holds its value until dropped")
    }

    /// Get a mutable reference to the pooled object.
    pub fn get_mut(&mut self) -> &mut T {
        self.object
            .as_mut()
            .expect("ThreadLocalPooledObject holds its value until dropped")
    }
}

impl<T> Drop for ThreadLocalPooledObject<'_, T> {
    fn drop(&mut self) {
        if let Some(object) = self.object.take() {
            let mut idle = self.pool.borrow_mut();
            if idle.len() < self.max_size {
                idle.push_back(object);
            } else {
                // Pool is full: end the borrow before running `T`'s destructor.
                drop(idle);
                drop(object);
            }
        }
    }
}

impl<T> std::ops::Deref for ThreadLocalPooledObject<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<T> std::ops::DerefMut for ThreadLocalPooledObject<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.get_mut()
    }
}
184
185/// Specialized vector pool for chess engine operations
186pub struct VectorPool {
187    pool: ThreadLocalPool<Array1<f32>>,
188    vector_size: usize,
189}
190
191impl VectorPool {
192    /// Create a new vector pool
193    pub fn new(vector_size: usize, max_size: usize) -> Self {
194        let pool = ThreadLocalPool::new(move || Array1::zeros(vector_size), max_size);
195
196        Self { pool, vector_size }
197    }
198
199    /// Get a zeroed vector from the pool
200    pub fn get_zeroed(&self) -> ThreadLocalPooledObject<Array1<f32>> {
201        let mut vec = self.pool.get();
202        vec.fill(0.0);
203        vec
204    }
205
206    /// Get a vector from the pool (contents undefined)
207    pub fn get(&self) -> ThreadLocalPooledObject<Array1<f32>> {
208        self.pool.get()
209    }
210
211    /// Get the vector size
212    pub fn vector_size(&self) -> usize {
213        self.vector_size
214    }
215
216    /// Get the current pool size
217    pub fn size(&self) -> usize {
218        self.pool.size()
219    }
220
221    /// Clear the pool
222    pub fn clear(&self) {
223        self.pool.clear();
224    }
225}
226
227/// Global vector pool manager
228pub struct VectorPoolManager {
229    pools: std::collections::HashMap<usize, VectorPool>,
230    max_pool_size: usize,
231}
232
233impl VectorPoolManager {
234    /// Create a new vector pool manager
235    pub fn new(max_pool_size: usize) -> Self {
236        Self {
237            pools: std::collections::HashMap::new(),
238            max_pool_size,
239        }
240    }
241
242    /// Get or create a vector pool for a specific size
243    pub fn get_pool(&mut self, vector_size: usize) -> &VectorPool {
244        self.pools
245            .entry(vector_size)
246            .or_insert_with(|| VectorPool::new(vector_size, self.max_pool_size))
247    }
248
249    /// Clear all pools
250    pub fn clear_all(&mut self) {
251        for pool in self.pools.values() {
252            pool.clear();
253        }
254    }
255}
256
257/// Thread-local vector pool instance
258thread_local! {
259    static VECTOR_POOL_MANAGER: RefCell<VectorPoolManager> = RefCell::new(VectorPoolManager::new(16));
260}
261
262/// Thread-local vector pool for efficient reuse
263thread_local! {
264    static VECTOR_POOL_1024: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
265    static VECTOR_POOL_512: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
266    static VECTOR_POOL_256: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
267    static VECTOR_POOL_128: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
268    static VECTOR_POOL_64: std::cell::RefCell<VecDeque<Array1<f32>>> = std::cell::RefCell::new(VecDeque::new());
269}
270
271/// Get a vector from the appropriate thread-local pool
272pub fn get_vector(size: usize) -> Array1<f32> {
273    match size {
274        1024 => get_vector_from_pool(&VECTOR_POOL_1024, size),
275        512 => get_vector_from_pool(&VECTOR_POOL_512, size),
276        256 => get_vector_from_pool(&VECTOR_POOL_256, size),
277        128 => get_vector_from_pool(&VECTOR_POOL_128, size),
278        64 => get_vector_from_pool(&VECTOR_POOL_64, size),
279        _ => Array1::zeros(size), // For non-standard sizes, just create new
280    }
281}
282
283/// Get a zeroed vector from the thread-local pool
284pub fn get_zeroed_vector(size: usize) -> Array1<f32> {
285    let mut vec = get_vector(size);
286    vec.fill(0.0);
287    vec
288}
289
290/// Helper function to get vector from specific pool
291fn get_vector_from_pool(
292    pool: &'static std::thread::LocalKey<std::cell::RefCell<VecDeque<Array1<f32>>>>,
293    size: usize,
294) -> Array1<f32> {
295    pool.with(|pool_ref| {
296        let mut pool = pool_ref.borrow_mut();
297        pool.pop_front().unwrap_or_else(|| Array1::zeros(size))
298    })
299}
300
301/// Return a vector to the appropriate thread-local pool
302pub fn return_vector(mut vec: Array1<f32>) {
303    let size = vec.len();
304
305    // Only pool commonly used sizes to prevent memory bloat
306    let pool = match size {
307        1024 => Some(&VECTOR_POOL_1024),
308        512 => Some(&VECTOR_POOL_512),
309        256 => Some(&VECTOR_POOL_256),
310        128 => Some(&VECTOR_POOL_128),
311        64 => Some(&VECTOR_POOL_64),
312        _ => None,
313    };
314
315    if let Some(pool) = pool {
316        // Reset the vector to zeros for reuse
317        vec.fill(0.0);
318
319        pool.with(|pool_ref| {
320            let mut pool = pool_ref.borrow_mut();
321
322            // Limit pool size to prevent memory bloat (max 10 vectors per size)
323            if pool.len() < 10 {
324                pool.push_back(vec);
325            }
326            // If pool is full, just drop the vector
327        });
328    }
329    // For non-standard sizes, just let the vector drop
330}
331
332/// RAII wrapper for automatic return to pool
333pub struct PooledVector {
334    vec: Option<Array1<f32>>,
335}
336
337impl PooledVector {
338    /// Create a new pooled vector
339    pub fn new(size: usize) -> Self {
340        Self {
341            vec: Some(get_vector(size)),
342        }
343    }
344
345    /// Create a new zeroed pooled vector
346    pub fn zeroed(size: usize) -> Self {
347        Self {
348            vec: Some(get_zeroed_vector(size)),
349        }
350    }
351
352    /// Get a reference to the underlying vector
353    pub fn as_ref(&self) -> &Array1<f32> {
354        self.vec.as_ref().expect("Vector should always be present")
355    }
356
357    /// Get a mutable reference to the underlying vector
358    pub fn as_mut(&mut self) -> &mut Array1<f32> {
359        self.vec.as_mut().expect("Vector should always be present")
360    }
361
362    /// Take ownership of the vector (prevents automatic return to pool)
363    pub fn take(mut self) -> Array1<f32> {
364        self.vec.take().expect("Vector should always be present")
365    }
366}
367
368impl Drop for PooledVector {
369    fn drop(&mut self) {
370        if let Some(vec) = self.vec.take() {
371            return_vector(vec);
372        }
373    }
374}
375
376impl std::ops::Deref for PooledVector {
377    type Target = Array1<f32>;
378
379    fn deref(&self) -> &Self::Target {
380        self.as_ref()
381    }
382}
383
384impl std::ops::DerefMut for PooledVector {
385    fn deref_mut(&mut self) -> &mut Self::Target {
386        self.as_mut()
387    }
388}
389
390/// Clear all thread-local vector pools (useful for testing and cleanup)
391pub fn clear_vector_pools() {
392    VECTOR_POOL_1024.with(|pool| pool.borrow_mut().clear());
393    VECTOR_POOL_512.with(|pool| pool.borrow_mut().clear());
394    VECTOR_POOL_256.with(|pool| pool.borrow_mut().clear());
395    VECTOR_POOL_128.with(|pool| pool.borrow_mut().clear());
396    VECTOR_POOL_64.with(|pool| pool.borrow_mut().clear());
397}
398
399/// Get statistics about thread-local vector pools
400pub fn get_vector_pool_stats() -> std::collections::HashMap<usize, usize> {
401    let mut stats = std::collections::HashMap::new();
402
403    VECTOR_POOL_1024.with(|pool| {
404        stats.insert(1024, pool.borrow().len());
405    });
406    VECTOR_POOL_512.with(|pool| {
407        stats.insert(512, pool.borrow().len());
408    });
409    VECTOR_POOL_256.with(|pool| {
410        stats.insert(256, pool.borrow().len());
411    });
412    VECTOR_POOL_128.with(|pool| {
413        stats.insert(128, pool.borrow().len());
414    });
415    VECTOR_POOL_64.with(|pool| {
416        stats.insert(64, pool.borrow().len());
417    });
418
419    stats
420}
421
422/// Pool for chess move vectors
423pub type MovePool = ObjectPool<Vec<chess::ChessMove>>;
424
425/// Create a move pool
426pub fn create_move_pool(max_size: usize) -> MovePool {
427    ObjectPool::new(Vec::new, max_size)
428}
429
430/// Pool for hash maps
431pub type HashMapPool<K, V> = ObjectPool<std::collections::HashMap<K, V>>;
432
433/// Create a hash map pool
434pub fn create_hashmap_pool<K, V>(max_size: usize) -> HashMapPool<K, V>
435where
436    K: std::hash::Hash + Eq + 'static,
437    V: 'static,
438{
439    ObjectPool::new(std::collections::HashMap::new, max_size)
440}
441
/// Objects that can be put back into a reusable state.
pub trait Resettable {
    /// Restore the object to its initial state, ready for reuse.
    fn reset(&mut self);
}

/// Resetting a `Vec` removes every element but keeps its allocated capacity.
impl<T> Resettable for Vec<T> {
    fn reset(&mut self) {
        Vec::clear(self);
    }
}

/// Resetting a `HashMap` removes every entry but keeps its allocated memory.
impl<K, V> Resettable for std::collections::HashMap<K, V>
where
    K: std::hash::Hash + Eq,
{
    fn reset(&mut self) {
        std::collections::HashMap::clear(self);
    }
}
462
463impl Resettable for Array1<f32> {
464    fn reset(&mut self) {
465        self.fill(0.0);
466    }
467}
468
469/// Pool for resettable objects
470pub struct ResettablePool<T: Resettable> {
471    pool: ObjectPool<T>,
472}
473
474impl<T: Resettable> ResettablePool<T> {
475    /// Create a new resettable pool
476    pub fn new<F>(factory: F, max_size: usize) -> Self
477    where
478        F: Fn() -> T + Send + Sync + 'static,
479    {
480        Self {
481            pool: ObjectPool::new(factory, max_size),
482        }
483    }
484
485    /// Get a reset object from the pool
486    pub fn get_reset(&self) -> PooledObject<T> {
487        let mut obj = self.pool.get();
488        obj.reset();
489        obj
490    }
491
492    /// Get an object from the pool (contents undefined)
493    pub fn get(&self) -> PooledObject<T> {
494        self.pool.get()
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_object_pool() {
        let pool = ObjectPool::new(Vec::<i32>::new, 10);

        // Test getting and returning objects
        {
            let mut obj1 = pool.get();
            obj1.push(1);
            obj1.push(2);
            assert_eq!(pool.size(), 0);
        }

        // Object should be returned to pool
        assert_eq!(pool.size(), 1);

        // Test reusing object: ObjectPool does not reset, so old contents survive
        {
            let obj2 = pool.get();
            assert_eq!(obj2.len(), 2);
        }
    }

    #[test]
    fn test_thread_local_pool() {
        let pool = ThreadLocalPool::new(Vec::<i32>::new, 5);

        {
            let mut obj = pool.get();
            obj.push(42);
            assert_eq!(pool.size(), 0);
        }

        assert_eq!(pool.size(), 1);

        {
            let obj = pool.get();
            assert_eq!(obj.len(), 1);
            assert_eq!(obj[0], 42);
        }
    }

    #[test]
    fn test_vector_pool() {
        let pool = VectorPool::new(100, 5);

        {
            let mut vec = pool.get_zeroed();
            vec[0] = 1.0;
            vec[1] = 2.0;
            assert_eq!(pool.size(), 0);
        }

        assert_eq!(pool.size(), 1);

        {
            let vec = pool.get_zeroed();
            assert_eq!(vec[0], 0.0); // Should be zeroed
            assert_eq!(vec[1], 0.0);
        }
    }

    #[test]
    fn test_resettable_pool() {
        let pool = ResettablePool::new(Vec::<i32>::new, 3);

        {
            let mut obj = pool.get_reset();
            obj.push(1);
            obj.push(2);
        }

        {
            let obj = pool.get_reset();
            assert_eq!(obj.len(), 0); // Should be reset
        }
    }

    #[test]
    fn test_pool_max_size() {
        let pool = ObjectPool::new(Vec::<i32>::new, 2);

        // Check out three objects at once so one more is returned than fits
        {
            let _obj1 = pool.get();
            let _obj2 = pool.get();
            let _obj3 = pool.get();
        }

        // Should only store 2 objects
        assert_eq!(pool.size(), 2);
    }

    #[test]
    fn test_global_vector_pool() {
        let vec1 = get_zeroed_vector(1024);
        assert_eq!(vec1.len(), 1024);

        let vec2 = get_vector(512);
        assert_eq!(vec2.len(), 512);
    }

    #[test]
    fn test_thread_local_vector_pooling() {
        // Clear pools to start fresh
        clear_vector_pools();

        // Get vectors of different sizes
        let vec1024 = get_vector(1024);
        let vec512 = get_vector(512);
        let vec256 = get_vector(256);

        assert_eq!(vec1024.len(), 1024);
        assert_eq!(vec512.len(), 512);
        assert_eq!(vec256.len(), 256);

        // Return vectors to pool
        return_vector(vec1024);
        return_vector(vec512);
        return_vector(vec256);

        // Check pool stats
        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&1024), Some(&1));
        assert_eq!(stats.get(&512), Some(&1));
        assert_eq!(stats.get(&256), Some(&1));

        // Get vectors again - should reuse from pool
        let vec1024_reused = get_vector(1024);
        let vec512_reused = get_vector(512);

        assert_eq!(vec1024_reused.len(), 1024);
        assert_eq!(vec512_reused.len(), 512);

        // Pool should now have one fewer vector
        let stats_after = get_vector_pool_stats();
        assert_eq!(stats_after.get(&1024), Some(&0));
        assert_eq!(stats_after.get(&512), Some(&0));
        assert_eq!(stats_after.get(&256), Some(&1)); // This one wasn't reused
    }

    #[test]
    fn test_pooled_vector_raii() {
        clear_vector_pools();

        // Create a pooled vector in scope
        {
            let mut pooled = PooledVector::new(1024);
            assert_eq!(pooled.len(), 1024);

            // Modify the vector
            pooled[0] = 42.0;
            assert_eq!(pooled[0], 42.0);
        } // pooled goes out of scope, should return to pool

        // Check that it was returned to pool
        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&1024), Some(&1));

        // Get the vector again - should be zeroed
        let vec = get_vector(1024);
        assert_eq!(vec[0], 0.0); // Should be reset to zero
    }

    #[test]
    fn test_pooled_vector_take() {
        clear_vector_pools();

        // Create a pooled vector and take ownership
        let pooled = PooledVector::new(512);
        let vec = pooled.take(); // Take ownership, won't return to pool

        assert_eq!(vec.len(), 512);

        // Pool should still be empty since we took ownership
        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&512), Some(&0));
    }

    #[test]
    fn test_pool_size_limit() {
        clear_vector_pools();

        // Hold 15 vectors at once so returning them really exceeds the limit of
        // 10. Returning each vector right after taking it would only ever pool one.
        let vectors: Vec<_> = (0..15).map(|_| get_vector(128)).collect();
        for vec in vectors {
            return_vector(vec);
        }

        // Pool should be capped at exactly 10 vectors
        let stats = get_vector_pool_stats();
        assert_eq!(
            stats.get(&128),
            Some(&10),
            "Pool size should be limited to 10"
        );

        // The pool still serves vectors of the right length
        let vec = get_vector(128);
        assert_eq!(vec.len(), 128);
    }

    #[test]
    fn test_non_standard_size_vectors() {
        // Non-standard sizes should not be pooled
        let vec = get_vector(100); // Non-standard size
        assert_eq!(vec.len(), 100);

        // Return it (should not be pooled)
        return_vector(vec);

        // Pool stats should not include this size
        let stats = get_vector_pool_stats();
        assert_eq!(stats.get(&100), None);
    }

    #[test]
    fn test_zeroed_vector_function() {
        let vec = get_zeroed_vector(256);
        assert_eq!(vec.len(), 256);

        // All elements should be zero
        for &value in vec.iter() {
            assert_eq!(value, 0.0);
        }
    }
}