1use std::sync::atomic::{AtomicI32, AtomicU32, AtomicU64, Ordering};
2use std::sync::Arc;
3
4use parking_lot::Mutex;
5
6use crate::error::{ADError, ADResult};
7use crate::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
8use crate::ndarray_handle::{NDArrayHandle, pooled_array};
9use crate::timestamp::EpicsTimestamp;
10
11pub struct NDArrayPool {
18 max_memory: usize,
19 allocated_bytes: AtomicU64,
20 next_unique_id: AtomicI32,
21 free_list: Mutex<Vec<NDArray>>,
22 num_alloc_buffers: AtomicU32,
23 num_free_buffers: AtomicU32,
24}
25
26impl NDArrayPool {
27 pub fn new(max_memory: usize) -> Self {
28 Self {
29 max_memory,
30 allocated_bytes: AtomicU64::new(0),
31 next_unique_id: AtomicI32::new(1),
32 free_list: Mutex::new(Vec::new()),
33 num_alloc_buffers: AtomicU32::new(0),
34 num_free_buffers: AtomicU32::new(0),
35 }
36 }
37
38 pub fn alloc(&self, dims: Vec<NDDimension>, data_type: NDDataType) -> ADResult<NDArray> {
40 let num_elements: usize = dims.iter().map(|d| d.size).product();
41 let needed_bytes = num_elements * data_type.element_size();
42
43 let reused = {
45 let mut free = self.free_list.lock();
46 let mut best_idx = None;
48 let mut best_cap = usize::MAX;
49 for (i, arr) in free.iter().enumerate() {
50 let cap = arr.data.capacity_bytes();
51 if cap >= needed_bytes && cap < best_cap {
52 best_cap = cap;
53 best_idx = Some(i);
54 }
55 }
56 if let Some(idx) = best_idx {
57 let arr = free.swap_remove(idx);
58 self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
59 Some(arr)
60 } else {
61 None
62 }
63 };
64
65 let mut arr = if let Some(mut reused) = reused {
66 if reused.data.data_type() != data_type {
68 let old_cap = reused.data.capacity_bytes();
70 reused.data = NDDataBuffer::zeros(data_type, num_elements);
71 let new_cap = reused.data.capacity_bytes();
72 if new_cap > old_cap {
74 let diff = new_cap - old_cap;
75 let current = self.allocated_bytes.load(Ordering::Relaxed);
76 if current + diff as u64 > self.max_memory as u64 {
77 return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
78 }
79 self.allocated_bytes.fetch_add(diff as u64, Ordering::Relaxed);
80 } else {
81 let diff = old_cap - new_cap;
82 self.allocated_bytes.fetch_sub(diff as u64, Ordering::Relaxed);
83 }
84 } else {
85 reused.data.resize(num_elements);
86 }
87 reused.dims = dims;
88 reused.attributes.clear();
89 reused.codec = None;
90 reused
91 } else {
92 let current = self.allocated_bytes.load(Ordering::Relaxed);
94 if current + needed_bytes as u64 > self.max_memory as u64 {
95 return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
96 }
97 self.allocated_bytes.fetch_add(needed_bytes as u64, Ordering::Relaxed);
98 self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
99 NDArray::new(dims, data_type)
100 };
101
102 arr.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
103 arr.timestamp = EpicsTimestamp::now();
104 Ok(arr)
105 }
106
107 pub fn alloc_copy(&self, source: &NDArray) -> ADResult<NDArray> {
109 let bytes = source.data.total_bytes();
110 let current = self.allocated_bytes.load(Ordering::Relaxed);
111 if current + bytes as u64 > self.max_memory as u64 {
112 return Err(ADError::PoolExhausted(bytes, self.max_memory));
113 }
114 self.allocated_bytes.fetch_add(bytes as u64, Ordering::Relaxed);
115 self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
116
117 let mut copy = source.clone();
118 copy.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
119 copy.timestamp = EpicsTimestamp::now();
120 Ok(copy)
121 }
122
123 pub fn release(&self, array: NDArray) {
125 let cap = array.data.capacity_bytes();
126 let mut free = self.free_list.lock();
127 free.push(array);
128 self.num_free_buffers.fetch_add(1, Ordering::Relaxed);
129
130 let total = self.allocated_bytes.load(Ordering::Relaxed) as usize;
132 if total > self.max_memory && !free.is_empty() {
133 free.sort_by(|a, b| b.data.capacity_bytes().cmp(&a.data.capacity_bytes()));
135 let mut excess = total.saturating_sub(self.max_memory);
136 while excess > 0 && !free.is_empty() {
137 let dropped = free.remove(0);
138 let dropped_cap = dropped.data.capacity_bytes();
139 self.allocated_bytes.fetch_sub(dropped_cap.min(total) as u64, Ordering::Relaxed);
140 self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
141 self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
142 if dropped_cap >= excess {
143 break;
144 }
145 excess -= dropped_cap;
146 }
147 }
148 let _ = cap;
149 }
150
151 pub fn empty_free_list(&self) {
153 let mut free = self.free_list.lock();
154 let count = free.len() as u32;
155 for arr in free.drain(..) {
156 let cap = arr.data.capacity_bytes();
157 self.allocated_bytes.fetch_sub(cap as u64, Ordering::Relaxed);
158 self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
159 }
160 self.num_free_buffers.fetch_sub(count, Ordering::Relaxed);
161 }
162
163 pub fn allocated_bytes(&self) -> u64 {
164 self.allocated_bytes.load(Ordering::Relaxed)
165 }
166
167 pub fn num_alloc_buffers(&self) -> u32 {
168 self.num_alloc_buffers.load(Ordering::Relaxed)
169 }
170
171 pub fn num_free_buffers(&self) -> u32 {
172 self.num_free_buffers.load(Ordering::Relaxed)
173 }
174
175 pub fn max_memory(&self) -> usize {
176 self.max_memory
177 }
178
179 pub fn alloc_handle(
182 pool: &Arc<Self>,
183 dims: Vec<NDDimension>,
184 data_type: NDDataType,
185 ) -> ADResult<NDArrayHandle> {
186 let array = pool.alloc(dims, data_type)?;
187 Ok(pooled_array(array, pool))
188 }
189}
190
191const _: fn() = || {
193 fn assert_send_sync<T: Send + Sync>() {}
194 assert_send_sync::<NDArrayPool>();
195};
196
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_alloc_auto_id() {
        let pool = NDArrayPool::new(1_000_000);
        let a1 = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        let a2 = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(a1.unique_id, 1);
        assert_eq!(a2.unique_id, 2);
    }

    #[test]
    fn test_alloc_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        let _ = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::Float64)
            .unwrap();
        assert_eq!(pool.allocated_bytes(), 800);
    }

    #[test]
    fn test_alloc_exceeds_max() {
        let pool = NDArrayPool::new(100);
        let result = pool.alloc(vec![NDDimension::new(200)], NDDataType::UInt8);
        assert!(result.is_err());
        // A failed allocation must not leave anything charged against the pool.
        assert_eq!(pool.allocated_bytes(), 0);
        assert_eq!(pool.num_alloc_buffers(), 0);
    }

    #[test]
    fn test_alloc_copy_preserves_data() {
        let pool = NDArrayPool::new(1_000_000);
        let mut source = pool
            .alloc(vec![NDDimension::new(4)], NDDataType::UInt8)
            .unwrap();
        if let NDDataBuffer::U8(ref mut v) = source.data {
            v[0] = 1;
            v[1] = 2;
            v[2] = 3;
            v[3] = 4;
        }

        let copy = pool.alloc_copy(&source).unwrap();
        assert_ne!(copy.unique_id, source.unique_id);
        assert_eq!(copy.dims.len(), source.dims.len());
        if let NDDataBuffer::U8(ref v) = copy.data {
            assert_eq!(v, &[1, 2, 3, 4]);
        } else {
            panic!("wrong type");
        }
    }

    #[test]
    fn test_alloc_copy_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        let source = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt16)
            .unwrap();
        assert_eq!(pool.allocated_bytes(), 20);
        let _ = pool.alloc_copy(&source).unwrap();
        assert_eq!(pool.allocated_bytes(), 40);
        assert_eq!(pool.num_alloc_buffers(), 2);
    }

    #[test]
    fn test_alloc_copy_exceeds_max() {
        let pool = NDArrayPool::new(60);
        let source = pool
            .alloc(vec![NDDimension::new(50)], NDDataType::UInt8)
            .unwrap();
        assert!(pool.alloc_copy(&source).is_err());
        // Only the original 50-byte allocation stays charged.
        assert_eq!(pool.allocated_bytes(), 50);
        assert_eq!(pool.num_alloc_buffers(), 1);
    }

    // Releasing a buffer and allocating something that fits should reuse it
    // instead of charging new memory.
    #[test]
    fn test_release_and_reuse() {
        let pool = NDArrayPool::new(1_000_000);
        let arr = pool.alloc(vec![NDDimension::new(100)], NDDataType::UInt8).unwrap();
        let alloc_bytes_after_first = pool.allocated_bytes();
        assert_eq!(pool.num_alloc_buffers(), 1);

        pool.release(arr);
        assert_eq!(pool.num_free_buffers(), 1);

        let arr2 = pool.alloc(vec![NDDimension::new(50)], NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
        assert_eq!(pool.allocated_bytes(), alloc_bytes_after_first);
        assert_eq!(pool.num_alloc_buffers(), 1);
        assert_eq!(arr2.data.len(), 50);
    }

    #[test]
    fn test_free_list_prefers_smallest_sufficient() {
        let pool = NDArrayPool::new(10_000_000);
        let small = pool.alloc(vec![NDDimension::new(100)], NDDataType::UInt8).unwrap();
        let large = pool.alloc(vec![NDDimension::new(10000)], NDDataType::UInt8).unwrap();
        let medium = pool.alloc(vec![NDDimension::new(1000)], NDDataType::UInt8).unwrap();

        pool.release(large);
        pool.release(medium);
        pool.release(small);
        assert_eq!(pool.num_free_buffers(), 3);

        let reused = pool.alloc(vec![NDDimension::new(500)], NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 2);
        // The medium (1000-byte) buffer must win: the small one is too small, and the
        // large (10000-byte) one is sufficient but not the smallest sufficient one.
        let cap = reused.data.capacity_bytes();
        assert!((1000..10_000).contains(&cap), "picked capacity {cap}");
    }

    #[test]
    fn test_empty_free_list() {
        let pool = NDArrayPool::new(1_000_000);
        let a1 = pool.alloc(vec![NDDimension::new(100)], NDDataType::UInt8).unwrap();
        let a2 = pool.alloc(vec![NDDimension::new(200)], NDDataType::UInt8).unwrap();
        pool.release(a1);
        pool.release(a2);
        assert_eq!(pool.num_free_buffers(), 2);

        pool.empty_free_list();
        assert_eq!(pool.num_free_buffers(), 0);
        assert_eq!(pool.num_alloc_buffers(), 0);
        assert_eq!(pool.allocated_bytes(), 0);
    }

    #[test]
    fn test_num_free_buffers_tracking() {
        let pool = NDArrayPool::new(1_000_000);
        assert_eq!(pool.num_free_buffers(), 0);

        let a = pool.alloc(vec![NDDimension::new(10)], NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 0);

        pool.release(a);
        assert_eq!(pool.num_free_buffers(), 1);

        let _ = pool.alloc(vec![NDDimension::new(5)], NDDataType::UInt8).unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
    }

    #[test]
    fn test_concurrent_alloc_release() {
        use std::sync::Arc;
        use std::thread;

        const THREADS: u32 = 4;
        let pool = Arc::new(NDArrayPool::new(10_000_000));
        let mut handles = Vec::new();

        for _ in 0..THREADS {
            let pool = pool.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    let arr = pool.alloc(vec![NDDimension::new(100)], NDDataType::UInt8).unwrap();
                    pool.release(arr);
                }
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        assert!(pool.num_free_buffers() > 0);
        // Every buffer was released, so all of them sit on the free list.
        assert_eq!(pool.num_free_buffers(), pool.num_alloc_buffers());
        // Each thread holds at most one array at a time, so no more than one buffer
        // per thread should ever have been created.
        assert!(pool.num_alloc_buffers() <= THREADS);
        assert_eq!(pool.allocated_bytes(), 100 * u64::from(pool.num_alloc_buffers()));
    }

    #[test]
    fn test_max_memory() {
        let pool = NDArrayPool::new(42);
        assert_eq!(pool.max_memory(), 42);
    }
}
373}