// ad_core/ndarray_pool.rs
1use std::sync::atomic::{AtomicI32, AtomicU32, AtomicU64, Ordering};
2use std::sync::Arc;
3
4use parking_lot::Mutex;
5
6use crate::error::{ADError, ADResult};
7use crate::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
8use crate::ndarray_handle::{NDArrayHandle, pooled_array};
9use crate::timestamp::EpicsTimestamp;
10
/// NDArray factory with free-list reuse and memory tracking.
///
/// Mimics C++ ADCore's NDArrayPool: on alloc, checks the free list for a
/// buffer with sufficient capacity. On release, returns the buffer to the
/// free list for future reuse. The free list is sorted by capacity (descending)
/// and excess entries are dropped when max_memory is exceeded.
pub struct NDArrayPool {
    /// Soft cap (bytes) that `alloc`/`alloc_copy` enforce against `allocated_bytes`.
    max_memory: usize,
    /// Total bytes currently accounted to this pool — includes both in-use
    /// arrays and buffers parked on the free list (release does not deduct).
    allocated_bytes: AtomicU64,
    /// Monotonically increasing id stamped onto every allocated array; starts at 1.
    next_unique_id: AtomicI32,
    /// Released arrays kept for reuse, guarded by a parking_lot mutex.
    free_list: Mutex<Vec<NDArray>>,
    /// Number of buffers still accounted to the pool (in-use + free-list).
    num_alloc_buffers: AtomicU32,
    /// Number of buffers currently sitting on the free list.
    num_free_buffers: AtomicU32,
}
25
26impl NDArrayPool {
27    pub fn new(max_memory: usize) -> Self {
28        Self {
29            max_memory,
30            allocated_bytes: AtomicU64::new(0),
31            next_unique_id: AtomicI32::new(1),
32            free_list: Mutex::new(Vec::new()),
33            num_alloc_buffers: AtomicU32::new(0),
34            num_free_buffers: AtomicU32::new(0),
35        }
36    }
37
38    /// Allocate an NDArray. Tries to reuse a free-list entry with sufficient capacity.
39    pub fn alloc(&self, dims: Vec<NDDimension>, data_type: NDDataType) -> ADResult<NDArray> {
40        let num_elements: usize = dims.iter().map(|d| d.size).product();
41        let needed_bytes = num_elements * data_type.element_size();
42
43        // Try to find a reusable buffer in the free list
44        let reused = {
45            let mut free = self.free_list.lock();
46            // Find smallest buffer that is large enough (free list sorted descending by capacity)
47            let mut best_idx = None;
48            let mut best_cap = usize::MAX;
49            for (i, arr) in free.iter().enumerate() {
50                let cap = arr.data.capacity_bytes();
51                if cap >= needed_bytes && cap < best_cap {
52                    best_cap = cap;
53                    best_idx = Some(i);
54                }
55            }
56            if let Some(idx) = best_idx {
57                let arr = free.swap_remove(idx);
58                self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
59                Some(arr)
60            } else {
61                None
62            }
63        };
64
65        let mut arr = if let Some(mut reused) = reused {
66            // Reuse: retype the buffer if needed, resize to match
67            if reused.data.data_type() != data_type {
68                // Must reallocate with new type, but we keep the allocation tracked
69                let old_cap = reused.data.capacity_bytes();
70                reused.data = NDDataBuffer::zeros(data_type, num_elements);
71                let new_cap = reused.data.capacity_bytes();
72                // Adjust allocated_bytes for the difference
73                if new_cap > old_cap {
74                    let diff = new_cap - old_cap;
75                    let current = self.allocated_bytes.load(Ordering::Relaxed);
76                    if current + diff as u64 > self.max_memory as u64 {
77                        return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
78                    }
79                    self.allocated_bytes.fetch_add(diff as u64, Ordering::Relaxed);
80                } else {
81                    let diff = old_cap - new_cap;
82                    self.allocated_bytes.fetch_sub(diff as u64, Ordering::Relaxed);
83                }
84            } else {
85                reused.data.resize(num_elements);
86            }
87            reused.dims = dims;
88            reused.attributes.clear();
89            reused.codec = None;
90            reused
91        } else {
92            // Fresh allocation
93            let current = self.allocated_bytes.load(Ordering::Relaxed);
94            if current + needed_bytes as u64 > self.max_memory as u64 {
95                return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
96            }
97            self.allocated_bytes.fetch_add(needed_bytes as u64, Ordering::Relaxed);
98            self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
99            NDArray::new(dims, data_type)
100        };
101
102        arr.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
103        arr.timestamp = EpicsTimestamp::now();
104        Ok(arr)
105    }
106
107    /// Allocate a copy of an existing NDArray (new unique_id, data cloned).
108    pub fn alloc_copy(&self, source: &NDArray) -> ADResult<NDArray> {
109        let bytes = source.data.total_bytes();
110        let current = self.allocated_bytes.load(Ordering::Relaxed);
111        if current + bytes as u64 > self.max_memory as u64 {
112            return Err(ADError::PoolExhausted(bytes, self.max_memory));
113        }
114        self.allocated_bytes.fetch_add(bytes as u64, Ordering::Relaxed);
115        self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
116
117        let mut copy = source.clone();
118        copy.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
119        copy.timestamp = EpicsTimestamp::now();
120        Ok(copy)
121    }
122
123    /// Return an array to the free list for future reuse.
124    pub fn release(&self, array: NDArray) {
125        let cap = array.data.capacity_bytes();
126        let mut free = self.free_list.lock();
127        free.push(array);
128        self.num_free_buffers.fetch_add(1, Ordering::Relaxed);
129
130        // If total allocated exceeds max_memory, drop largest free entries
131        let total = self.allocated_bytes.load(Ordering::Relaxed) as usize;
132        if total > self.max_memory && !free.is_empty() {
133            // Sort descending by capacity so we drop largest first
134            free.sort_by(|a, b| b.data.capacity_bytes().cmp(&a.data.capacity_bytes()));
135            let mut excess = total.saturating_sub(self.max_memory);
136            while excess > 0 && !free.is_empty() {
137                let dropped = free.remove(0);
138                let dropped_cap = dropped.data.capacity_bytes();
139                self.allocated_bytes.fetch_sub(dropped_cap.min(total) as u64, Ordering::Relaxed);
140                self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
141                self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
142                if dropped_cap >= excess {
143                    break;
144                }
145                excess -= dropped_cap;
146            }
147        }
148        let _ = cap;
149    }
150
151    /// Clear all entries from the free list.
152    pub fn empty_free_list(&self) {
153        let mut free = self.free_list.lock();
154        let count = free.len() as u32;
155        for arr in free.drain(..) {
156            let cap = arr.data.capacity_bytes();
157            self.allocated_bytes.fetch_sub(cap as u64, Ordering::Relaxed);
158            self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
159        }
160        self.num_free_buffers.fetch_sub(count, Ordering::Relaxed);
161    }
162
163    pub fn allocated_bytes(&self) -> u64 {
164        self.allocated_bytes.load(Ordering::Relaxed)
165    }
166
167    pub fn num_alloc_buffers(&self) -> u32 {
168        self.num_alloc_buffers.load(Ordering::Relaxed)
169    }
170
171    pub fn num_free_buffers(&self) -> u32 {
172        self.num_free_buffers.load(Ordering::Relaxed)
173    }
174
175    pub fn max_memory(&self) -> usize {
176        self.max_memory
177    }
178
179    /// Allocate an NDArray wrapped in a pool-aware handle.
180    /// On final drop, the array is returned to this pool's free list.
181    pub fn alloc_handle(
182        pool: &Arc<Self>,
183        dims: Vec<NDDimension>,
184        data_type: NDDataType,
185    ) -> ADResult<NDArrayHandle> {
186        let array = pool.alloc(dims, data_type)?;
187        Ok(pooled_array(array, pool))
188    }
189}
190
191// Compile-time check: NDArrayPool is Send + Sync
192const _: fn() = || {
193    fn assert_send_sync<T: Send + Sync>() {}
194    assert_send_sync::<NDArrayPool>();
195};
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a 1-D dimension spec of `n` elements.
    fn dims(n: usize) -> Vec<NDDimension> {
        vec![NDDimension::new(n)]
    }

    #[test]
    fn test_alloc_auto_id() {
        let pool = NDArrayPool::new(1_000_000);
        let first = pool.alloc(dims(10), NDDataType::UInt8).unwrap();
        let second = pool.alloc(dims(10), NDDataType::UInt8).unwrap();
        assert_eq!(first.unique_id, 1);
        assert_eq!(second.unique_id, 2);
    }

    #[test]
    fn test_alloc_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        // 100 elements x 8 bytes (Float64) = 800 bytes.
        let _ = pool.alloc(dims(100), NDDataType::Float64).unwrap();
        assert_eq!(pool.allocated_bytes(), 800);
    }

    #[test]
    fn test_alloc_exceeds_max() {
        let pool = NDArrayPool::new(100);
        assert!(pool.alloc(dims(200), NDDataType::UInt8).is_err());
    }

    #[test]
    fn test_alloc_copy_preserves_data() {
        let pool = NDArrayPool::new(1_000_000);
        let mut source = pool.alloc(dims(4), NDDataType::UInt8).unwrap();
        if let NDDataBuffer::U8(ref mut bytes) = source.data {
            bytes.copy_from_slice(&[1, 2, 3, 4]);
        }

        let copy = pool.alloc_copy(&source).unwrap();
        assert_ne!(copy.unique_id, source.unique_id);
        assert_eq!(copy.dims.len(), source.dims.len());
        match copy.data {
            NDDataBuffer::U8(ref bytes) => assert_eq!(bytes, &[1, 2, 3, 4]),
            _ => panic!("wrong type"),
        }
    }

    #[test]
    fn test_alloc_copy_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        // 10 elements x 2 bytes (UInt16) = 20 bytes, doubled after the copy.
        let source = pool.alloc(dims(10), NDDataType::UInt16).unwrap();
        assert_eq!(pool.allocated_bytes(), 20);
        let _ = pool.alloc_copy(&source).unwrap();
        assert_eq!(pool.allocated_bytes(), 40);
    }

    #[test]
    fn test_alloc_copy_exceeds_max() {
        let pool = NDArrayPool::new(60);
        let source = pool.alloc(dims(50), NDDataType::UInt8).unwrap();
        assert!(pool.alloc_copy(&source).is_err());
    }

    // --- Free-list reuse tests ---

    #[test]
    fn test_release_and_reuse() {
        let pool = NDArrayPool::new(1_000_000);
        let original = pool.alloc(dims(100), NDDataType::UInt8).unwrap();
        let bytes_after_first = pool.allocated_bytes();
        assert_eq!(pool.num_alloc_buffers(), 1);

        // Hand the buffer back to the pool.
        pool.release(original);
        assert_eq!(pool.num_free_buffers(), 1);

        // The next (smaller) request should come out of the free list.
        let recycled = pool.alloc(dims(50), NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
        // No new allocation happened, so the byte count is unchanged.
        assert_eq!(pool.allocated_bytes(), bytes_after_first);
        assert_eq!(recycled.data.len(), 50);
    }

    #[test]
    fn test_free_list_prefers_smallest_sufficient() {
        let pool = NDArrayPool::new(10_000_000);
        let small = pool.alloc(dims(100), NDDataType::UInt8).unwrap();
        let large = pool.alloc(dims(10000), NDDataType::UInt8).unwrap();
        let medium = pool.alloc(dims(1000), NDDataType::UInt8).unwrap();

        pool.release(large);
        pool.release(medium);
        pool.release(small);
        assert_eq!(pool.num_free_buffers(), 3);

        // A 500-byte request should be satisfied by the medium buffer
        // (capacity 1000), not the large one (capacity 10000).
        let reused = pool.alloc(dims(500), NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 2);
        assert!(reused.data.capacity_bytes() >= 1000);
    }

    #[test]
    fn test_empty_free_list() {
        let pool = NDArrayPool::new(1_000_000);
        let first = pool.alloc(dims(100), NDDataType::UInt8).unwrap();
        let second = pool.alloc(dims(200), NDDataType::UInt8).unwrap();
        pool.release(first);
        pool.release(second);
        assert_eq!(pool.num_free_buffers(), 2);

        pool.empty_free_list();
        assert_eq!(pool.num_free_buffers(), 0);
        assert_eq!(pool.num_alloc_buffers(), 0);
    }

    #[test]
    fn test_num_free_buffers_tracking() {
        let pool = NDArrayPool::new(1_000_000);
        assert_eq!(pool.num_free_buffers(), 0);

        let arr = pool.alloc(dims(10), NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 0);

        pool.release(arr);
        assert_eq!(pool.num_free_buffers(), 1);

        let _ = pool.alloc(dims(5), NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
    }

    #[test]
    fn test_concurrent_alloc_release() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(NDArrayPool::new(10_000_000));
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let arr = pool.alloc(dims(100), NDDataType::UInt8).unwrap();
                        pool.release(arr);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Every array was released, so the free list cannot be empty.
        assert!(pool.num_free_buffers() > 0);
    }

    #[test]
    fn test_max_memory() {
        let pool = NDArrayPool::new(42);
        assert_eq!(pool.max_memory(), 42);
    }
}