ad_core/
ndarray_handle.rs1use std::mem::ManuallyDrop;
2use std::ops::Deref;
3use std::sync::Arc;
4
5use crate::ndarray::NDArray;
6use crate::ndarray_pool::NDArrayPool;
7
8pub struct PooledNDArray {
10 array: ManuallyDrop<NDArray>,
11 pool: Arc<NDArrayPool>,
12}
13
14impl Deref for PooledNDArray {
15 type Target = NDArray;
16 fn deref(&self) -> &NDArray {
17 &self.array
18 }
19}
20
21impl Drop for PooledNDArray {
22 fn drop(&mut self) {
23 let array = unsafe { ManuallyDrop::take(&mut self.array) };
25 self.pool.release(array);
26 }
27}
28
29pub type NDArrayHandle = Arc<PooledNDArray>;
31
32pub fn pooled_array(array: NDArray, pool: &Arc<NDArrayPool>) -> NDArrayHandle {
34 Arc::new(PooledNDArray {
35 array: ManuallyDrop::new(array),
36 pool: Arc::clone(pool),
37 })
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDDataType, NDDimension};

    /// Builds a pool large enough for every allocation made in these tests.
    fn new_pool() -> Arc<NDArrayPool> {
        Arc::new(NDArrayPool::new(1_000_000))
    }

    #[test]
    fn test_pooled_array_returns_to_pool_on_drop() {
        let pool = new_pool();
        let array = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();
        // Before any release, the free-buffer count is zero.
        assert_eq!(pool.num_free_buffers(), 0);

        drop(pooled_array(array, &pool));

        assert_eq!(pool.num_free_buffers(), 1);
    }

    #[test]
    fn test_clone_keeps_alive_drop_both_returns() {
        let pool = new_pool();
        let array = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();

        let first = pooled_array(array, &pool);
        let second = Arc::clone(&first);

        drop(first);
        assert_eq!(pool.num_free_buffers(), 0, "still one clone alive");

        drop(second);
        assert_eq!(pool.num_free_buffers(), 1, "both dropped, returned to pool");
    }

    #[test]
    fn test_deref_access() {
        let pool = new_pool();
        let array = pool
            .alloc(vec![NDDimension::new(50)], NDDataType::Float64)
            .unwrap();
        let expected_id = array.unique_id;

        let handle = pooled_array(array, &pool);
        // Field access auto-derefs Arc -> PooledNDArray -> NDArray.
        assert_eq!(handle.unique_id, expected_id);
        assert_eq!(handle.data.len(), 50);
        assert_eq!(handle.dims[0].size, 50);
    }

    #[test]
    fn test_alloc_handle_via_pool() {
        let pool = new_pool();
        let handle =
            NDArrayPool::alloc_handle(&pool, vec![NDDimension::new(64)], NDDataType::UInt16)
                .unwrap();
        assert_eq!(handle.data.len(), 64);
        let allocated = pool.num_alloc_buffers();

        drop(handle);
        // Releasing should recycle the buffer: it appears in the free list and
        // the allocated-buffer count does not change.
        assert_eq!(pool.num_free_buffers(), 1);
        assert_eq!(pool.num_alloc_buffers(), allocated);
    }
}
106}