// ad_core_rs/ndarray_pool.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicI32, AtomicU32, AtomicU64, Ordering};
3
4use parking_lot::Mutex;
5
6use crate::error::{ADError, ADResult};
7use crate::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
8use crate::ndarray_handle::{NDArrayHandle, pooled_array};
9use crate::timestamp::EpicsTimestamp;
10
11/// NDArray factory with free-list reuse and memory tracking.
12///
13/// Mimics C++ ADCore's NDArrayPool: on alloc, checks the free list for a
14/// buffer with sufficient capacity. On release, returns the buffer to the
15/// free list for future reuse. The free list is sorted by capacity (descending)
16/// and excess entries are dropped when max_memory is exceeded.
17pub struct NDArrayPool {
18    max_memory: usize,
19    allocated_bytes: AtomicU64,
20    next_unique_id: AtomicI32,
21    free_list: Mutex<Vec<NDArray>>,
22    num_alloc_buffers: AtomicU32,
23    num_free_buffers: AtomicU32,
24}
25
26impl NDArrayPool {
27    pub fn new(max_memory: usize) -> Self {
28        Self {
29            max_memory,
30            allocated_bytes: AtomicU64::new(0),
31            next_unique_id: AtomicI32::new(1),
32            free_list: Mutex::new(Vec::new()),
33            num_alloc_buffers: AtomicU32::new(0),
34            num_free_buffers: AtomicU32::new(0),
35        }
36    }
37
38    /// Allocate an NDArray. Tries to reuse a free-list entry with sufficient capacity.
39    pub fn alloc(&self, dims: Vec<NDDimension>, data_type: NDDataType) -> ADResult<NDArray> {
40        let num_elements: usize = dims.iter().map(|d| d.size).product();
41        let needed_bytes = num_elements * data_type.element_size();
42
43        // Try to find a reusable buffer in the free list
44        let reused = {
45            let mut free = self.free_list.lock();
46            // Find smallest buffer that is large enough (free list sorted descending by capacity)
47            let mut best_idx = None;
48            let mut best_cap = usize::MAX;
49            for (i, arr) in free.iter().enumerate() {
50                let cap = arr.data.capacity_bytes();
51                if cap >= needed_bytes && cap < best_cap {
52                    best_cap = cap;
53                    best_idx = Some(i);
54                }
55            }
56            if let Some(idx) = best_idx {
57                let arr = free.swap_remove(idx);
58                self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
59                Some(arr)
60            } else {
61                None
62            }
63        };
64
65        let mut arr = if let Some(mut reused) = reused {
66            // Reuse: retype the buffer if needed, resize to match
67            if reused.data.data_type() != data_type {
68                // Must reallocate with new type, but we keep the allocation tracked
69                let old_cap = reused.data.capacity_bytes();
70                reused.data = NDDataBuffer::zeros(data_type, num_elements);
71                let new_cap = reused.data.capacity_bytes();
72                // Adjust allocated_bytes for the difference
73                if new_cap > old_cap {
74                    let diff = new_cap - old_cap;
75                    let current = self.allocated_bytes.load(Ordering::Relaxed);
76                    if current + diff as u64 > self.max_memory as u64 {
77                        return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
78                    }
79                    self.allocated_bytes
80                        .fetch_add(diff as u64, Ordering::Relaxed);
81                } else {
82                    let diff = old_cap - new_cap;
83                    self.allocated_bytes
84                        .fetch_sub(diff as u64, Ordering::Relaxed);
85                }
86            } else {
87                reused.data.resize(num_elements);
88            }
89            reused.dims = dims;
90            reused.attributes.clear();
91            reused.codec = None;
92            reused
93        } else {
94            // Fresh allocation
95            let current = self.allocated_bytes.load(Ordering::Relaxed);
96            if current + needed_bytes as u64 > self.max_memory as u64 {
97                return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
98            }
99            self.allocated_bytes
100                .fetch_add(needed_bytes as u64, Ordering::Relaxed);
101            self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
102            NDArray::new(dims, data_type)
103        };
104
105        arr.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
106        arr.timestamp = EpicsTimestamp::now();
107        Ok(arr)
108    }
109
110    /// Allocate a copy of an existing NDArray (new unique_id, data cloned).
111    pub fn alloc_copy(&self, source: &NDArray) -> ADResult<NDArray> {
112        let bytes = source.data.total_bytes();
113        let current = self.allocated_bytes.load(Ordering::Relaxed);
114        if current + bytes as u64 > self.max_memory as u64 {
115            return Err(ADError::PoolExhausted(bytes, self.max_memory));
116        }
117        self.allocated_bytes
118            .fetch_add(bytes as u64, Ordering::Relaxed);
119        self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
120
121        let mut copy = source.clone();
122        copy.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
123        copy.timestamp = EpicsTimestamp::now();
124        Ok(copy)
125    }
126
127    /// Return an array to the free list for future reuse.
128    pub fn release(&self, array: NDArray) {
129        let cap = array.data.capacity_bytes();
130        let mut free = self.free_list.lock();
131        free.push(array);
132        self.num_free_buffers.fetch_add(1, Ordering::Relaxed);
133
134        // If total allocated exceeds max_memory, drop largest free entries
135        let total = self.allocated_bytes.load(Ordering::Relaxed) as usize;
136        if total > self.max_memory && !free.is_empty() {
137            // Sort descending by capacity so we drop largest first
138            free.sort_by(|a, b| b.data.capacity_bytes().cmp(&a.data.capacity_bytes()));
139            let mut excess = total.saturating_sub(self.max_memory);
140            while excess > 0 && !free.is_empty() {
141                let dropped = free.remove(0);
142                let dropped_cap = dropped.data.capacity_bytes();
143                self.allocated_bytes
144                    .fetch_sub(dropped_cap.min(total) as u64, Ordering::Relaxed);
145                self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
146                self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
147                if dropped_cap >= excess {
148                    break;
149                }
150                excess -= dropped_cap;
151            }
152        }
153        let _ = cap;
154    }
155
156    /// Clear all entries from the free list.
157    pub fn empty_free_list(&self) {
158        let mut free = self.free_list.lock();
159        let count = free.len() as u32;
160        for arr in free.drain(..) {
161            let cap = arr.data.capacity_bytes();
162            self.allocated_bytes
163                .fetch_sub(cap as u64, Ordering::Relaxed);
164            self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
165        }
166        self.num_free_buffers.fetch_sub(count, Ordering::Relaxed);
167    }
168
169    pub fn allocated_bytes(&self) -> u64 {
170        self.allocated_bytes.load(Ordering::Relaxed)
171    }
172
173    pub fn num_alloc_buffers(&self) -> u32 {
174        self.num_alloc_buffers.load(Ordering::Relaxed)
175    }
176
177    pub fn num_free_buffers(&self) -> u32 {
178        self.num_free_buffers.load(Ordering::Relaxed)
179    }
180
181    pub fn max_memory(&self) -> usize {
182        self.max_memory
183    }
184
185    /// Allocate an NDArray wrapped in a pool-aware handle.
186    /// On final drop, the array is returned to this pool's free list.
187    pub fn alloc_handle(
188        pool: &Arc<Self>,
189        dims: Vec<NDDimension>,
190        data_type: NDDataType,
191    ) -> ADResult<NDArrayHandle> {
192        let array = pool.alloc(dims, data_type)?;
193        Ok(pooled_array(array, pool))
194    }
195}
196
197// Compile-time check: NDArrayPool is Send + Sync
198const _: fn() = || {
199    fn assert_send_sync<T: Send + Sync>() {}
200    assert_send_sync::<NDArrayPool>();
201};
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alloc_auto_id() {
        let pool = NDArrayPool::new(1_000_000);
        let a1 = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        let a2 = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(a1.unique_id, 1);
        assert_eq!(a2.unique_id, 2);
    }

    #[test]
    fn test_alloc_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        let _ = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::Float64)
            .unwrap();
        assert_eq!(pool.allocated_bytes(), 800);
    }

    #[test]
    fn test_alloc_exceeds_max() {
        let pool = NDArrayPool::new(100);
        let result = pool.alloc(vec![NDDimension::new(200)], NDDataType::UInt8);
        assert!(result.is_err());
    }

    #[test]
    fn test_alloc_copy_preserves_data() {
        let pool = NDArrayPool::new(1_000_000);
        let mut source = pool
            .alloc(vec![NDDimension::new(4)], NDDataType::UInt8)
            .unwrap();
        if let NDDataBuffer::U8(ref mut v) = source.data {
            v[0] = 1;
            v[1] = 2;
            v[2] = 3;
            v[3] = 4;
        }

        let copy = pool.alloc_copy(&source).unwrap();
        assert_ne!(copy.unique_id, source.unique_id);
        assert_eq!(copy.dims.len(), source.dims.len());
        if let NDDataBuffer::U8(ref v) = copy.data {
            assert_eq!(v, &[1, 2, 3, 4]);
        } else {
            panic!("wrong type");
        }
    }

    #[test]
    fn test_alloc_copy_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        let source = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt16)
            .unwrap();
        assert_eq!(pool.allocated_bytes(), 20);
        let _ = pool.alloc_copy(&source).unwrap();
        assert_eq!(pool.allocated_bytes(), 40);
    }

    #[test]
    fn test_alloc_copy_exceeds_max() {
        let pool = NDArrayPool::new(60);
        let source = pool
            .alloc(vec![NDDimension::new(50)], NDDataType::UInt8)
            .unwrap();
        assert!(pool.alloc_copy(&source).is_err());
    }

    // --- Free-list reuse tests ---

    #[test]
    fn test_release_and_reuse() {
        let pool = NDArrayPool::new(1_000_000);
        let arr = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();
        let alloc_bytes_after_first = pool.allocated_bytes();
        assert_eq!(pool.num_alloc_buffers(), 1);

        // Release back to free list
        pool.release(arr);
        assert_eq!(pool.num_free_buffers(), 1);

        // Alloc again — should reuse the freed buffer
        let arr2 = pool
            .alloc(vec![NDDimension::new(50)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
        // allocated_bytes should be unchanged (reused buffer)
        assert_eq!(pool.allocated_bytes(), alloc_bytes_after_first);
        assert_eq!(arr2.data.len(), 50);
    }

    #[test]
    fn test_free_list_prefers_smallest_sufficient() {
        let pool = NDArrayPool::new(10_000_000);
        let small = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();
        let large = pool
            .alloc(vec![NDDimension::new(10000)], NDDataType::UInt8)
            .unwrap();
        let medium = pool
            .alloc(vec![NDDimension::new(1000)], NDDataType::UInt8)
            .unwrap();

        pool.release(large);
        pool.release(medium);
        pool.release(small);
        assert_eq!(pool.num_free_buffers(), 3);

        // Request 500 bytes — should pick medium (1000 cap), not large (10000 cap)
        let reused = pool
            .alloc(vec![NDDimension::new(500)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 2);
        // The reused buffer should have capacity >= 1000 (from medium)...
        assert!(reused.data.capacity_bytes() >= 1000);
        // ...and must not be the large buffer, otherwise best-fit is broken.
        assert!(reused.data.capacity_bytes() < 10_000);
    }

    #[test]
    fn test_empty_free_list() {
        let pool = NDArrayPool::new(1_000_000);
        let a1 = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();
        let a2 = pool
            .alloc(vec![NDDimension::new(200)], NDDataType::UInt8)
            .unwrap();
        pool.release(a1);
        pool.release(a2);
        assert_eq!(pool.num_free_buffers(), 2);

        pool.empty_free_list();
        assert_eq!(pool.num_free_buffers(), 0);
        assert_eq!(pool.num_alloc_buffers(), 0);
    }

    #[test]
    fn test_num_free_buffers_tracking() {
        let pool = NDArrayPool::new(1_000_000);
        assert_eq!(pool.num_free_buffers(), 0);

        let a = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 0);

        pool.release(a);
        assert_eq!(pool.num_free_buffers(), 1);

        let _ = pool
            .alloc(vec![NDDimension::new(5)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
    }

    #[test]
    fn test_concurrent_alloc_release() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(NDArrayPool::new(10_000_000));
        let mut handles = Vec::new();

        for _ in 0..4 {
            let pool = pool.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    let arr = pool
                        .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
                        .unwrap();
                    pool.release(arr);
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Every array was released, so at least one buffer is cached...
        assert!(pool.num_free_buffers() > 0);
        // ...and every buffer the pool created must be back on the free list.
        assert_eq!(pool.num_free_buffers(), pool.num_alloc_buffers());
    }

    #[test]
    fn test_max_memory() {
        let pool = NDArrayPool::new(42);
        assert_eq!(pool.max_memory(), 42);
    }
}