1use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
/// Callback used by [`NDArraySender::send`] to handle an array inline instead
/// of queueing it.
///
/// `send` invokes this directly on the calling thread when the sender's
/// blocking-mode flag is set and a processor has been installed; the array is
/// then not placed on the channel at all.
pub(crate) trait BlockingProcessFn: Send + Sync {
    /// Handles `array` synchronously.
    ///
    /// NOTE(review): the name suggests the implementor both processes the
    /// array and publishes the result downstream — confirm against the
    /// implementors, which are not visible in this file.
    fn process_and_publish(&self, array: &NDArray);
}
11
/// Counts arrays that have been handed to a queue and not yet released, and
/// lets a thread block until that count reaches zero.
pub struct QueuedArrayCounter {
    /// Current number of outstanding arrays; updated with atomic operations.
    count: AtomicUsize,
    /// Protects no data of its own; it exists so `condvar` has a mutex to
    /// pair with.
    mutex: parking_lot::Mutex<()>,
    /// Notified by `decrement` when the count goes from 1 to 0.
    condvar: parking_lot::Condvar,
}
19
20impl QueuedArrayCounter {
21 pub fn new() -> Self {
23 Self {
24 count: AtomicUsize::new(0),
25 mutex: parking_lot::Mutex::new(()),
26 condvar: parking_lot::Condvar::new(),
27 }
28 }
29
30 pub fn increment(&self) {
32 self.count.fetch_add(1, Ordering::AcqRel);
33 }
34
35 pub fn decrement(&self) {
37 let prev = self.count.fetch_sub(1, Ordering::AcqRel);
38 if prev == 1 {
39 let _guard = self.mutex.lock();
40 self.condvar.notify_all();
41 }
42 }
43
44 pub fn get(&self) -> usize {
46 self.count.load(Ordering::Acquire)
47 }
48
49 pub fn wait_until_zero(&self, timeout: Duration) -> bool {
52 let mut guard = self.mutex.lock();
53 if self.count.load(Ordering::Acquire) == 0 {
54 return true;
55 }
56 !self
57 .condvar
58 .wait_while_for(&mut guard, |_| {
59 self.count.load(Ordering::Acquire) != 0
60 }, timeout)
61 .timed_out()
62 }
63}
64
65impl Default for QueuedArrayCounter {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
/// Envelope carried on the channel: an array plus an optional counter that is
/// decremented when the envelope is dropped.
pub struct ArrayMessage {
    /// The shared array being delivered.
    pub array: Arc<NDArray>,
    /// Counter to decrement on drop, if any. Whoever builds the message is
    /// responsible for having incremented it beforehand.
    pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
}
77
78impl Drop for ArrayMessage {
79 fn drop(&mut self) {
80 if let Some(c) = self.counter.take() {
81 c.decrement();
82 }
83 }
84}
85
/// Sending half of an array channel.
///
/// Cloning is cheap: the drop counter, the flags, the optional processor and
/// the optional queue counter are all behind `Arc`, so clones share them.
#[derive(Clone)]
pub struct NDArraySender {
    /// Bounded tokio channel the messages are pushed onto.
    tx: tokio::sync::mpsc::Sender<ArrayMessage>,
    /// Name identifying this sender; `NDArrayOutput` looks senders up by it.
    port_name: String,
    /// Number of arrays rejected because the channel was full.
    dropped_count: Arc<AtomicU64>,
    /// When `false`, `send` discards arrays without doing anything.
    enabled: Arc<AtomicBool>,
    /// When `true` and a processor is installed, `send` runs the processor
    /// inline instead of queueing.
    blocking_mode: Arc<AtomicBool>,
    /// Inline handler used in blocking mode; `None` until
    /// `with_blocking_support` is called.
    blocking_processor: Option<Arc<dyn BlockingProcessFn>>,
    /// Optional counter incremented for every array offered to the queue.
    queued_counter: Option<Arc<QueuedArrayCounter>>,
}
97
98impl NDArraySender {
99 pub fn send(&self, array: Arc<NDArray>) {
104 if !self.enabled.load(Ordering::Acquire) {
105 return;
106 }
107 if self.blocking_mode.load(Ordering::Acquire) {
108 if let Some(ref bp) = self.blocking_processor {
109 bp.process_and_publish(&array);
110 return;
111 }
112 }
113 if let Some(ref c) = self.queued_counter {
115 c.increment();
116 }
117 let msg = ArrayMessage {
118 array,
119 counter: self.queued_counter.clone(),
120 };
121 match self.tx.try_send(msg) {
122 Ok(()) => {}
123 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
124 self.dropped_count.fetch_add(1, Ordering::Relaxed);
126 }
127 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
128 }
130 }
131 }
132
133 pub fn is_enabled(&self) -> bool {
135 self.enabled.load(Ordering::Acquire)
136 }
137
138 pub fn is_blocking(&self) -> bool {
140 self.blocking_mode.load(Ordering::Acquire)
141 }
142
143 pub fn port_name(&self) -> &str {
144 &self.port_name
145 }
146
147 pub fn dropped_count(&self) -> u64 {
148 self.dropped_count.load(Ordering::Relaxed)
149 }
150
151 pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
153 self.queued_counter = Some(counter);
154 }
155
156 pub(crate) fn with_blocking_support(
158 self,
159 enabled: Arc<AtomicBool>,
160 blocking_mode: Arc<AtomicBool>,
161 blocking_processor: Arc<dyn BlockingProcessFn>,
162 ) -> Self {
163 Self {
164 enabled,
165 blocking_mode,
166 blocking_processor: Some(blocking_processor),
167 ..self
168 }
169 }
170}
171
/// Receiving half of an array channel created by `ndarray_channel`.
pub struct NDArrayReceiver {
    /// Underlying tokio receiver of message envelopes.
    rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
}
176
177impl NDArrayReceiver {
178 pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
180 self.rx.blocking_recv().map(|msg| msg.array.clone())
181 }
182
183 pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
185 self.rx.recv().await.map(|msg| msg.array.clone())
186 }
187
188 pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
191 self.rx.recv().await
192 }
193}
194
195pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
197 let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
198 (
199 NDArraySender {
200 tx,
201 port_name: port_name.to_string(),
202 dropped_count: Arc::new(AtomicU64::new(0)),
203 enabled: Arc::new(AtomicBool::new(true)),
204 blocking_mode: Arc::new(AtomicBool::new(false)),
205 blocking_processor: None,
206 queued_counter: None,
207 },
208 NDArrayReceiver { rx },
209 )
210}
211
/// Fan-out point: a list of senders that each receive every published array.
pub struct NDArrayOutput {
    /// Registered senders, visited in this order by `publish`.
    senders: Vec<NDArraySender>,
}
216
217impl NDArrayOutput {
218 pub fn new() -> Self {
219 Self {
220 senders: Vec::new(),
221 }
222 }
223
224 pub fn add(&mut self, sender: NDArraySender) {
225 self.senders.push(sender);
226 }
227
228 pub fn remove(&mut self, port_name: &str) {
229 self.senders.retain(|s| s.port_name != port_name);
230 }
231
232 pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
234 let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
235 Some(self.senders.swap_remove(idx))
236 }
237
238 pub fn publish(&self, array: Arc<NDArray>) {
240 for sender in &self.senders {
241 sender.send(array.clone());
242 }
243 }
244
245 pub fn total_dropped(&self) -> u64 {
246 self.senders.iter().map(|s| s.dropped_count()).sum()
247 }
248
249 pub fn num_senders(&self) -> usize {
250 self.senders.len()
251 }
252}
253
254impl Default for NDArrayOutput {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Builds a one-dimensional, 4-element `UInt8` array tagged with `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut array = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        array.unique_id = id;
        Arc::new(array)
    }

    /// Single-threaded tokio runtime used to drive the async receive calls.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Arrays come out of the channel in the order they were sent.
    #[test]
    fn test_send_receive_basic() {
        let (tx, mut rx) = ndarray_channel("TEST", 10);
        for id in 1..=2 {
            tx.send(make_test_array(id));
        }

        let rt = current_thread_runtime();
        rt.block_on(async {
            for expected_id in 1..=2 {
                let received = rx.recv().await.unwrap();
                assert_eq!(received.unique_id, expected_id);
            }
        });
    }

    /// With capacity 2 and nothing draining, the third and fourth sends are
    /// counted as dropped.
    #[test]
    fn test_back_pressure_drops() {
        let (tx, _rx) = ndarray_channel("TEST", 2);
        for id in 1..=4 {
            tx.send(make_test_array(id));
        }

        assert_eq!(tx.dropped_count(), 2);
    }

    /// One publish reaches every registered receiver.
    #[test]
    fn test_fanout_three_receivers() {
        let mut output = NDArrayOutput::new();
        let mut receivers = Vec::new();
        for port in ["P1", "P2", "P3"] {
            let (tx, rx) = ndarray_channel(port, 10);
            output.add(tx);
            receivers.push(rx);
        }

        output.publish(make_test_array(42));

        let rt = current_thread_runtime();
        rt.block_on(async {
            for rx in receivers.iter_mut() {
                assert_eq!(rx.recv().await.unwrap().unique_id, 42);
            }
        });
    }

    /// Drops are summed across all senders of an output.
    #[test]
    fn test_fanout_total_dropped() {
        // The receivers must stay alive so the channels report "full" rather
        // than "closed".
        let (tx1, _rx1) = ndarray_channel("P1", 1);
        let (tx2, _rx2) = ndarray_channel("P2", 1);

        let mut output = NDArrayOutput::new();
        output.add(tx1);
        output.add(tx2);

        // The first publish fills both single-slot queues; the second is
        // rejected by each of them.
        for id in 1..=2 {
            output.publish(make_test_array(id));
        }

        assert_eq!(output.total_dropped(), 2);
    }

    /// Removing by port name shrinks the sender list.
    #[test]
    fn test_fanout_remove() {
        let (tx1, _rx1) = ndarray_channel("P1", 10);
        let (tx2, _rx2) = ndarray_channel("P2", 10);

        let mut output = NDArrayOutput::new();
        output.add(tx1);
        output.add(tx2);
        assert_eq!(output.num_senders(), 2);

        output.remove("P1");
        assert_eq!(output.num_senders(), 1);
    }

    /// `blocking_recv` on a plain thread receives what another thread sends.
    #[test]
    fn test_blocking_recv() {
        let (tx, mut rx) = ndarray_channel("TEST", 10);

        let worker = std::thread::spawn(move || rx.blocking_recv().unwrap().unique_id);

        tx.send(make_test_array(99));
        let received_id = worker.join().unwrap();
        assert_eq!(received_id, 99);
    }

    /// Sending into a closed channel is not counted as a drop.
    #[test]
    fn test_channel_closed_on_receiver_drop() {
        let (tx, rx) = ndarray_channel("TEST", 10);
        drop(rx);
        tx.send(make_test_array(1));
        assert_eq!(tx.dropped_count(), 0);
    }

    /// Increment and decrement move the count up and down by one.
    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);

        for expected in 1..=2 {
            counter.increment();
            assert_eq!(counter.get(), expected);
        }
        for expected in (0..2).rev() {
            counter.decrement();
            assert_eq!(counter.get(), expected);
        }
    }

    /// A waiter is released once another thread brings the count to zero.
    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        for _ in 0..2 {
            counter.increment();
        }

        let worker_counter = Arc::clone(&counter);
        let worker = std::thread::spawn(move || {
            for _ in 0..2 {
                std::thread::sleep(Duration::from_millis(10));
                worker_counter.decrement();
            }
        });

        let reached_zero = counter.wait_until_zero(Duration::from_secs(5));
        assert!(reached_zero);
        worker.join().unwrap();
    }

    /// Waiting gives up with `false` when nobody decrements.
    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();

        let reached_zero = counter.wait_until_zero(Duration::from_millis(10));
        assert!(!reached_zero);
    }

    /// Each queued send bumps the attached counter.
    #[test]
    fn test_send_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut tx, _rx) = ndarray_channel("TEST", 10);
        tx.set_queued_counter(Arc::clone(&counter));

        tx.send(make_test_array(1));
        assert_eq!(counter.get(), 1);

        tx.send(make_test_array(2));
        assert_eq!(counter.get(), 2);
    }

    /// A send rejected by a full queue leaves the counter where it was.
    #[test]
    fn test_send_queue_full_no_net_increment() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut tx, _rx) = ndarray_channel("TEST", 1);
        tx.set_queued_counter(Arc::clone(&counter));

        // Fits in the single slot.
        tx.send(make_test_array(1));
        assert_eq!(counter.get(), 1);

        // Rejected: the increment is undone when the message is dropped.
        tx.send(make_test_array(2));
        assert_eq!(counter.get(), 1);
    }

    /// Dropping a message releases its counter slot.
    #[test]
    fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();

        let envelope = ArrayMessage {
            array: make_test_array(1),
            counter: Some(Arc::clone(&counter)),
        };
        assert_eq!(counter.get(), 1);

        drop(envelope);
        assert_eq!(counter.get(), 0);
    }
}