Skip to main content

ad_core/
ndarray_handle.rs

1use std::mem::ManuallyDrop;
2use std::ops::Deref;
3use std::sync::Arc;
4
5use crate::ndarray::NDArray;
6use crate::ndarray_pool::NDArrayPool;
7
8/// Pool-aware wrapper. On final drop, returns array to pool.
9pub struct PooledNDArray {
10    array: ManuallyDrop<NDArray>,
11    pool: Arc<NDArrayPool>,
12}
13
14impl Deref for PooledNDArray {
15    type Target = NDArray;
16    fn deref(&self) -> &NDArray {
17        &self.array
18    }
19}
20
21impl Drop for PooledNDArray {
22    fn drop(&mut self) {
23        // SAFETY: only taken once in drop, never accessed after
24        let array = unsafe { ManuallyDrop::take(&mut self.array) };
25        self.pool.release(array);
26    }
27}
28
29/// Cloneable handle. Inner Arc ensures pool return on last clone drop.
30pub type NDArrayHandle = Arc<PooledNDArray>;
31
32/// Create a pool-aware handle wrapping an NDArray.
33pub fn pooled_array(array: NDArray, pool: &Arc<NDArrayPool>) -> NDArrayHandle {
34    Arc::new(PooledNDArray {
35        array: ManuallyDrop::new(array),
36        pool: Arc::clone(pool),
37    })
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDDataType, NDDimension};

    /// Builds a fresh shared pool with the constructor argument every test here uses.
    fn new_pool() -> Arc<NDArrayPool> {
        Arc::new(NDArrayPool::new(1_000_000))
    }

    /// Allocates a single-dimension array of `dtype` from `pool`, panicking on failure.
    fn alloc_1d(pool: &Arc<NDArrayPool>, dim: NDDimension, dtype: NDDataType) -> NDArray {
        pool.alloc(vec![dim], dtype).unwrap()
    }

    #[test]
    fn test_pooled_array_returns_to_pool_on_drop() {
        let pool = new_pool();
        let array = alloc_1d(&pool, NDDimension::new(100), NDDataType::UInt8);
        assert_eq!(pool.num_free_buffers(), 0);

        // Wrapping and then dropping the only handle must hand the buffer back.
        drop(pooled_array(array, &pool));

        assert_eq!(pool.num_free_buffers(), 1);
    }

    #[test]
    fn test_clone_keeps_alive_drop_both_returns() {
        let pool = new_pool();
        let array = alloc_1d(&pool, NDDimension::new(100), NDDataType::UInt8);

        let first = pooled_array(array, &pool);
        let second = Arc::clone(&first);

        drop(first);
        assert_eq!(pool.num_free_buffers(), 0, "still one clone alive");

        drop(second);
        assert_eq!(pool.num_free_buffers(), 1, "both dropped, returned to pool");
    }

    #[test]
    fn test_deref_access() {
        let pool = new_pool();
        let array = alloc_1d(&pool, NDDimension::new(50), NDDataType::Float64);
        let expected_id = array.unique_id;

        let handle = pooled_array(array, &pool);
        // Field access goes Arc -> PooledNDArray -> NDArray through deref.
        assert_eq!(handle.unique_id, expected_id);
        assert_eq!(handle.data.len(), 50);
        assert_eq!(handle.dims[0].size, 50);
    }

    #[test]
    fn test_alloc_handle_via_pool() {
        let pool = new_pool();
        let handle =
            NDArrayPool::alloc_handle(&pool, vec![NDDimension::new(64)], NDDataType::UInt16)
                .unwrap();
        assert_eq!(handle.data.len(), 64);
        let allocated = pool.num_alloc_buffers();

        drop(handle);
        // The buffer becomes free, and the allocation count is left unchanged.
        assert_eq!(pool.num_free_buffers(), 1);
        assert_eq!(pool.num_alloc_buffers(), allocated);
    }
}