1use std::sync::Arc;
2use std::sync::atomic::{AtomicI32, AtomicU32, AtomicU64, Ordering};
3
4use parking_lot::Mutex;
5
6use crate::error::{ADError, ADResult};
7use crate::ndarray::{NDArray, NDDataBuffer, NDDataType, NDDimension};
8use crate::ndarray_handle::{NDArrayHandle, pooled_array};
9use crate::timestamp::EpicsTimestamp;
10
/// A bounded pool of reusable `NDArray` buffers.
///
/// Tracks total buffer capacity against a byte budget and keeps released
/// arrays on a free list for reuse by later allocations.
pub struct NDArrayPool {
    // Byte budget for the pool; allocations that would exceed it fail
    // with `ADError::PoolExhausted` (see `alloc` / `alloc_copy`).
    max_memory: usize,
    // Total buffer bytes currently attributed to the pool (checked-out
    // arrays plus free-list entries), by capacity.
    allocated_bytes: AtomicU64,
    // Monotonic counter stamped onto each allocated array's `unique_id`;
    // starts at 1 (see `new`).
    next_unique_id: AtomicI32,
    // Released arrays held for reuse; `alloc` scans this for a best fit.
    free_list: Mutex<Vec<NDArray>>,
    // Number of buffers the pool currently accounts for (in use + free).
    num_alloc_buffers: AtomicU32,
    // Number of buffers sitting on `free_list`.
    num_free_buffers: AtomicU32,
}
25
26impl NDArrayPool {
27 pub fn new(max_memory: usize) -> Self {
28 Self {
29 max_memory,
30 allocated_bytes: AtomicU64::new(0),
31 next_unique_id: AtomicI32::new(1),
32 free_list: Mutex::new(Vec::new()),
33 num_alloc_buffers: AtomicU32::new(0),
34 num_free_buffers: AtomicU32::new(0),
35 }
36 }
37
38 pub fn alloc(&self, dims: Vec<NDDimension>, data_type: NDDataType) -> ADResult<NDArray> {
40 let num_elements: usize = dims.iter().map(|d| d.size).product();
41 let needed_bytes = num_elements * data_type.element_size();
42
43 let reused = {
45 let mut free = self.free_list.lock();
46 let mut best_idx = None;
48 let mut best_cap = usize::MAX;
49 for (i, arr) in free.iter().enumerate() {
50 let cap = arr.data.capacity_bytes();
51 if cap >= needed_bytes && cap < best_cap {
52 best_cap = cap;
53 best_idx = Some(i);
54 }
55 }
56 if let Some(idx) = best_idx {
57 let arr = free.swap_remove(idx);
58 self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
59 Some(arr)
60 } else {
61 None
62 }
63 };
64
65 let mut arr = if let Some(mut reused) = reused {
66 if reused.data.data_type() != data_type {
68 let old_cap = reused.data.capacity_bytes();
70 reused.data = NDDataBuffer::zeros(data_type, num_elements);
71 let new_cap = reused.data.capacity_bytes();
72 if new_cap > old_cap {
74 let diff = new_cap - old_cap;
75 let current = self.allocated_bytes.load(Ordering::Relaxed);
76 if current + diff as u64 > self.max_memory as u64 {
77 return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
78 }
79 self.allocated_bytes
80 .fetch_add(diff as u64, Ordering::Relaxed);
81 } else {
82 let diff = old_cap - new_cap;
83 self.allocated_bytes
84 .fetch_sub(diff as u64, Ordering::Relaxed);
85 }
86 } else {
87 reused.data.resize(num_elements);
88 }
89 reused.dims = dims;
90 reused.attributes.clear();
91 reused.codec = None;
92 reused
93 } else {
94 let current = self.allocated_bytes.load(Ordering::Relaxed);
96 if current + needed_bytes as u64 > self.max_memory as u64 {
97 return Err(ADError::PoolExhausted(needed_bytes, self.max_memory));
98 }
99 self.allocated_bytes
100 .fetch_add(needed_bytes as u64, Ordering::Relaxed);
101 self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
102 NDArray::new(dims, data_type)
103 };
104
105 arr.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
106 arr.timestamp = EpicsTimestamp::now();
107 Ok(arr)
108 }
109
110 pub fn alloc_copy(&self, source: &NDArray) -> ADResult<NDArray> {
112 let bytes = source.data.total_bytes();
113 let current = self.allocated_bytes.load(Ordering::Relaxed);
114 if current + bytes as u64 > self.max_memory as u64 {
115 return Err(ADError::PoolExhausted(bytes, self.max_memory));
116 }
117 self.allocated_bytes
118 .fetch_add(bytes as u64, Ordering::Relaxed);
119 self.num_alloc_buffers.fetch_add(1, Ordering::Relaxed);
120
121 let mut copy = source.clone();
122 copy.unique_id = self.next_unique_id.fetch_add(1, Ordering::Relaxed);
123 copy.timestamp = EpicsTimestamp::now();
124 Ok(copy)
125 }
126
127 pub fn release(&self, array: NDArray) {
129 let cap = array.data.capacity_bytes();
130 let mut free = self.free_list.lock();
131 free.push(array);
132 self.num_free_buffers.fetch_add(1, Ordering::Relaxed);
133
134 let total = self.allocated_bytes.load(Ordering::Relaxed) as usize;
136 if total > self.max_memory && !free.is_empty() {
137 free.sort_by(|a, b| b.data.capacity_bytes().cmp(&a.data.capacity_bytes()));
139 let mut excess = total.saturating_sub(self.max_memory);
140 while excess > 0 && !free.is_empty() {
141 let dropped = free.remove(0);
142 let dropped_cap = dropped.data.capacity_bytes();
143 self.allocated_bytes
144 .fetch_sub(dropped_cap.min(total) as u64, Ordering::Relaxed);
145 self.num_free_buffers.fetch_sub(1, Ordering::Relaxed);
146 self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
147 if dropped_cap >= excess {
148 break;
149 }
150 excess -= dropped_cap;
151 }
152 }
153 let _ = cap;
154 }
155
156 pub fn empty_free_list(&self) {
158 let mut free = self.free_list.lock();
159 let count = free.len() as u32;
160 for arr in free.drain(..) {
161 let cap = arr.data.capacity_bytes();
162 self.allocated_bytes
163 .fetch_sub(cap as u64, Ordering::Relaxed);
164 self.num_alloc_buffers.fetch_sub(1, Ordering::Relaxed);
165 }
166 self.num_free_buffers.fetch_sub(count, Ordering::Relaxed);
167 }
168
169 pub fn allocated_bytes(&self) -> u64 {
170 self.allocated_bytes.load(Ordering::Relaxed)
171 }
172
173 pub fn num_alloc_buffers(&self) -> u32 {
174 self.num_alloc_buffers.load(Ordering::Relaxed)
175 }
176
177 pub fn num_free_buffers(&self) -> u32 {
178 self.num_free_buffers.load(Ordering::Relaxed)
179 }
180
181 pub fn max_memory(&self) -> usize {
182 self.max_memory
183 }
184
185 pub fn alloc_handle(
188 pool: &Arc<Self>,
189 dims: Vec<NDDimension>,
190 data_type: NDDataType,
191 ) -> ADResult<NDArrayHandle> {
192 let array = pool.alloc(dims, data_type)?;
193 Ok(pooled_array(array, pool))
194 }
195}
196
197const _: fn() = || {
199 fn assert_send_sync<T: Send + Sync>() {}
200 assert_send_sync::<NDArrayPool>();
201};
202
#[cfg(test)]
mod tests {
    use super::*;

    // Unique ids are handed out sequentially starting at 1.
    #[test]
    fn test_alloc_auto_id() {
        let pool = NDArrayPool::new(1_000_000);
        let first = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        let second = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(first.unique_id, 1);
        assert_eq!(second.unique_id, 2);
    }

    // 100 x f64 = 800 bytes must show up in the byte counter.
    #[test]
    fn test_alloc_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        pool.alloc(vec![NDDimension::new(100)], NDDataType::Float64)
            .unwrap();
        assert_eq!(pool.allocated_bytes(), 800);
    }

    // A request larger than the budget is rejected.
    #[test]
    fn test_alloc_exceeds_max() {
        let pool = NDArrayPool::new(100);
        assert!(pool
            .alloc(vec![NDDimension::new(200)], NDDataType::UInt8)
            .is_err());
    }

    // A copy carries the same bytes but gets its own id.
    #[test]
    fn test_alloc_copy_preserves_data() {
        let pool = NDArrayPool::new(1_000_000);
        let mut source = pool
            .alloc(vec![NDDimension::new(4)], NDDataType::UInt8)
            .unwrap();
        if let NDDataBuffer::U8(ref mut bytes) = source.data {
            // Fill the 4-element buffer with 1, 2, 3, 4.
            for (slot, value) in bytes.iter_mut().zip(1u8..) {
                *slot = value;
            }
        }

        let copy = pool.alloc_copy(&source).unwrap();
        assert_ne!(copy.unique_id, source.unique_id);
        assert_eq!(copy.dims.len(), source.dims.len());
        match copy.data {
            NDDataBuffer::U8(ref bytes) => assert_eq!(bytes, &[1, 2, 3, 4]),
            _ => panic!("wrong type"),
        }
    }

    // Copies are accounted just like fresh allocations.
    #[test]
    fn test_alloc_copy_tracks_bytes() {
        let pool = NDArrayPool::new(1_000_000);
        let source = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt16)
            .unwrap();
        assert_eq!(pool.allocated_bytes(), 20);
        let _copy = pool.alloc_copy(&source).unwrap();
        assert_eq!(pool.allocated_bytes(), 40);
    }

    // A copy that would overflow the budget fails.
    #[test]
    fn test_alloc_copy_exceeds_max() {
        let pool = NDArrayPool::new(60);
        let source = pool
            .alloc(vec![NDDimension::new(50)], NDDataType::UInt8)
            .unwrap();
        assert!(pool.alloc_copy(&source).is_err());
    }

    // Released buffers get reused without growing the byte count.
    #[test]
    fn test_release_and_reuse() {
        let pool = NDArrayPool::new(1_000_000);
        let first = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();
        let bytes_after_first = pool.allocated_bytes();
        assert_eq!(pool.num_alloc_buffers(), 1);

        pool.release(first);
        assert_eq!(pool.num_free_buffers(), 1);

        let second = pool
            .alloc(vec![NDDimension::new(50)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
        assert_eq!(pool.allocated_bytes(), bytes_after_first);
        assert_eq!(second.data.len(), 50);
    }

    // Best fit: the smallest buffer that can hold the request wins.
    #[test]
    fn test_free_list_prefers_smallest_sufficient() {
        let pool = NDArrayPool::new(10_000_000);
        let small = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();
        let large = pool
            .alloc(vec![NDDimension::new(10000)], NDDataType::UInt8)
            .unwrap();
        let medium = pool
            .alloc(vec![NDDimension::new(1000)], NDDataType::UInt8)
            .unwrap();

        pool.release(large);
        pool.release(medium);
        pool.release(small);
        assert_eq!(pool.num_free_buffers(), 3);

        // A 500-byte request should pick the 1000-byte buffer, not the
        // 10000-byte one.
        let reused = pool
            .alloc(vec![NDDimension::new(500)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 2);
        assert!(reused.data.capacity_bytes() >= 1000);
    }

    // Emptying the free list removes those buffers from all counters.
    #[test]
    fn test_empty_free_list() {
        let pool = NDArrayPool::new(1_000_000);
        let first = pool
            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
            .unwrap();
        let second = pool
            .alloc(vec![NDDimension::new(200)], NDDataType::UInt8)
            .unwrap();
        pool.release(first);
        pool.release(second);
        assert_eq!(pool.num_free_buffers(), 2);

        pool.empty_free_list();
        assert_eq!(pool.num_free_buffers(), 0);
        assert_eq!(pool.num_alloc_buffers(), 0);
    }

    // The free counter follows each release / reuse transition.
    #[test]
    fn test_num_free_buffers_tracking() {
        let pool = NDArrayPool::new(1_000_000);
        assert_eq!(pool.num_free_buffers(), 0);

        let arr = pool
            .alloc(vec![NDDimension::new(10)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 0);

        pool.release(arr);
        assert_eq!(pool.num_free_buffers(), 1);

        let _reused = pool
            .alloc(vec![NDDimension::new(5)], NDDataType::UInt8)
            .unwrap();
        assert_eq!(pool.num_free_buffers(), 0);
    }

    // Hammer alloc/release from several threads at once.
    #[test]
    fn test_concurrent_alloc_release() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(NDArrayPool::new(10_000_000));
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let arr = pool
                            .alloc(vec![NDDimension::new(100)], NDDataType::UInt8)
                            .unwrap();
                        pool.release(arr);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        assert!(pool.num_free_buffers() > 0);
    }

    // The configured budget is reported back verbatim.
    #[test]
    fn test_max_memory() {
        let pool = NDArrayPool::new(42);
        assert_eq!(pool.max_memory(), 42);
    }
}