1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicI32, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7pub struct QueuedArrayCounter {
10 count: AtomicUsize,
11 mutex: parking_lot::Mutex<()>,
12 condvar: parking_lot::Condvar,
13}
14
15impl QueuedArrayCounter {
16 pub fn new() -> Self {
18 Self {
19 count: AtomicUsize::new(0),
20 mutex: parking_lot::Mutex::new(()),
21 condvar: parking_lot::Condvar::new(),
22 }
23 }
24
25 pub fn increment(&self) {
27 self.count.fetch_add(1, Ordering::AcqRel);
28 }
29
30 pub fn decrement(&self) {
32 let prev = self.count.fetch_sub(1, Ordering::AcqRel);
33 if prev == 1 {
34 let _guard = self.mutex.lock();
35 self.condvar.notify_all();
36 }
37 }
38
39 pub fn get(&self) -> usize {
41 self.count.load(Ordering::Acquire)
42 }
43
44 pub fn wait_until_zero(&self, timeout: Duration) -> bool {
47 let mut guard = self.mutex.lock();
48 if self.count.load(Ordering::Acquire) == 0 {
49 return true;
50 }
51 !self
52 .condvar
53 .wait_while_for(
54 &mut guard,
55 |_| self.count.load(Ordering::Acquire) != 0,
56 timeout,
57 )
58 .timed_out()
59 }
60}
61
62impl Default for QueuedArrayCounter {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68pub struct ArrayMessage {
72 pub array: Arc<NDArray>,
73 pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
74 pub(crate) done_tx: Option<tokio::sync::oneshot::Sender<()>>,
77}
78
79impl Drop for ArrayMessage {
80 fn drop(&mut self) {
81 if let Some(tx) = self.done_tx.take() {
82 let _ = tx.send(());
83 }
84 if let Some(c) = self.counter.take() {
85 c.decrement();
86 }
87 }
88}
89
/// Result of offering one array to one downstream channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PublishOutcome {
    /// The array was enqueued (in blocking mode: and released by the receiver).
    Delivered,
    /// The sender is disabled, so the array was not offered at all.
    Disabled,
    /// Non-blocking mode only: the queue was full and the array was discarded.
    DroppedQueueFull,
    /// The receiving side of the channel is gone.
    ChannelClosed,
}
103
104#[derive(Clone)]
120pub struct NDArraySender {
121 tx: tokio::sync::mpsc::Sender<ArrayMessage>,
122 port_name: String,
123 enabled: Arc<AtomicBool>,
124 blocking_mode: Arc<AtomicBool>,
125 queued_counter: Option<Arc<QueuedArrayCounter>>,
126 dropped_arrays: Arc<AtomicI32>,
132}
133
134impl NDArraySender {
135 pub async fn publish(&self, array: Arc<NDArray>) -> PublishOutcome {
143 if !self.enabled.load(Ordering::Acquire) {
144 return PublishOutcome::Disabled;
145 }
146
147 let blocking = self.blocking_mode.load(Ordering::Acquire);
148
149 if !blocking {
150 if let Some(ref c) = self.queued_counter {
153 c.increment();
154 }
155 let msg = ArrayMessage {
156 array,
157 counter: self.queued_counter.clone(),
158 done_tx: None,
159 };
160 return match self.tx.try_send(msg) {
161 Ok(()) => PublishOutcome::Delivered,
162 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
164 self.dropped_arrays.fetch_add(1, Ordering::AcqRel);
165 PublishOutcome::DroppedQueueFull
166 }
167 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
168 PublishOutcome::ChannelClosed
169 }
170 };
171 }
172
173 if let Some(ref c) = self.queued_counter {
175 c.increment();
176 }
177 let (done_tx, done_rx) = tokio::sync::oneshot::channel();
178 let msg = ArrayMessage {
179 array,
180 counter: self.queued_counter.clone(),
181 done_tx: Some(done_tx),
182 };
183 if self.tx.send(msg).await.is_err() {
184 return PublishOutcome::ChannelClosed;
186 }
187 let _ = done_rx.await;
188 PublishOutcome::Delivered
189 }
190
191 pub fn is_enabled(&self) -> bool {
193 self.enabled.load(Ordering::Acquire)
194 }
195
196 pub fn is_blocking(&self) -> bool {
198 self.blocking_mode.load(Ordering::Acquire)
199 }
200
201 pub fn port_name(&self) -> &str {
202 &self.port_name
203 }
204
205 pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
207 self.queued_counter = Some(counter);
208 }
209
210 pub fn set_dropped_arrays_counter(&mut self, counter: Arc<AtomicI32>) {
213 self.dropped_arrays = counter;
214 }
215
216 pub fn dropped_arrays_counter(&self) -> &Arc<AtomicI32> {
218 &self.dropped_arrays
219 }
220
221 pub fn capacity(&self) -> usize {
223 self.tx.capacity()
224 }
225
226 pub fn max_capacity(&self) -> usize {
228 self.tx.max_capacity()
229 }
230
231 pub(crate) fn set_mode_flags(
233 &mut self,
234 enabled: Arc<AtomicBool>,
235 blocking_mode: Arc<AtomicBool>,
236 ) {
237 self.enabled = enabled;
238 self.blocking_mode = blocking_mode;
239 }
240}
241
242pub struct NDArrayReceiver {
244 rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
245}
246
247impl NDArrayReceiver {
248 pub fn pending(&self) -> usize {
250 self.rx.len()
251 }
252
253 pub fn max_capacity(&self) -> usize {
255 self.rx.max_capacity()
256 }
257
258 pub fn capacity(&self) -> usize {
260 self.rx.capacity()
261 }
262
263 pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
265 self.rx.blocking_recv().map(|msg| msg.array.clone())
266 }
267
268 pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
270 self.rx.recv().await.map(|msg| msg.array.clone())
271 }
272
273 pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
276 self.rx.recv().await
277 }
278}
279
280pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
282 let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
283 (
284 NDArraySender {
285 tx,
286 port_name: port_name.to_string(),
287 enabled: Arc::new(AtomicBool::new(true)),
288 blocking_mode: Arc::new(AtomicBool::new(false)),
289 queued_counter: None,
290 dropped_arrays: Arc::new(AtomicI32::new(0)),
291 },
292 NDArrayReceiver { rx },
293 )
294}
295
296pub struct NDArrayOutput {
298 senders: Vec<NDArraySender>,
299}
300
301impl NDArrayOutput {
302 pub fn new() -> Self {
303 Self {
304 senders: Vec::new(),
305 }
306 }
307
308 pub fn add(&mut self, sender: NDArraySender) {
309 self.senders.push(sender);
310 }
311
312 pub fn remove(&mut self, port_name: &str) {
313 self.senders.retain(|s| s.port_name != port_name);
314 }
315
316 pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
318 let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
319 Some(self.senders.swap_remove(idx))
320 }
321
322 pub async fn publish(&self, array: Arc<NDArray>) -> Vec<PublishOutcome> {
328 let futs = self.senders.iter().map(|s| s.publish(array.clone()));
329 futures_util::future::join_all(futs).await
330 }
331
332 pub async fn publish_to(&self, index: usize, array: Arc<NDArray>) -> Option<PublishOutcome> {
334 if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
335 Some(sender.publish(array).await)
336 } else {
337 None
338 }
339 }
340
341 pub fn num_senders(&self) -> usize {
342 self.senders.len()
343 }
344
345 pub(crate) fn senders_clone(&self) -> Vec<NDArraySender> {
347 self.senders.clone()
348 }
349}
350
351#[derive(Clone)]
364pub struct ArrayPublisher {
365 output: Arc<parking_lot::Mutex<NDArrayOutput>>,
366}
367
368impl ArrayPublisher {
369 pub fn new(output: Arc<parking_lot::Mutex<NDArrayOutput>>) -> Self {
371 Self { output }
372 }
373
374 pub async fn publish(&self, array: Arc<NDArray>) -> Vec<PublishOutcome> {
381 let senders = self.output.lock().senders_clone();
382 let futs = senders.iter().map(|s| s.publish(array.clone()));
383 futures_util::future::join_all(futs).await
384 }
385}
386
387impl Default for NDArrayOutput {
388 fn default() -> Self {
389 Self::new()
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Builds a 4-element UInt8 array tagged with `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        Arc::new(arr)
    }

    #[tokio::test]
    async fn test_publish_receive_basic() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        assert_eq!(sender.publish(make_test_array(1)).await, PublishOutcome::Delivered);
        assert_eq!(sender.publish(make_test_array(2)).await, PublishOutcome::Delivered);

        let a1 = receiver.recv().await.unwrap();
        assert_eq!(a1.unique_id, 1);
        let a2 = receiver.recv().await.unwrap();
        assert_eq!(a2.unique_id, 2);
    }

    #[tokio::test]
    async fn test_publish_blocking_no_drop() {
        let (sender, mut receiver) = ndarray_channel("TEST", 1);
        sender.blocking_mode.store(true, Ordering::Release);

        let s = sender.clone();
        let pub_handle = tokio::spawn(async move {
            for id in 1..=3 {
                assert_eq!(s.publish(make_test_array(id)).await, PublishOutcome::Delivered);
            }
        });

        for expected in 1..=3 {
            let arr = receiver.recv().await.unwrap();
            assert_eq!(arr.unique_id, expected);
        }

        pub_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_publish_drops_on_full_queue() {
        let (sender, _receiver) = ndarray_channel("TEST", 1);

        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::Delivered
        );
        assert_eq!(
            sender.publish(make_test_array(2)).await,
            PublishOutcome::DroppedQueueFull
        );
        assert_eq!(sender.dropped_arrays_counter().load(Ordering::Acquire), 1);
    }

    #[tokio::test]
    async fn test_drop_on_full_does_not_leak_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 1);
        sender.set_queued_counter(counter.clone());

        sender.publish(make_test_array(1)).await;
        assert_eq!(counter.get(), 1);

        let outcome = sender.publish(make_test_array(2)).await;
        assert_eq!(outcome, PublishOutcome::DroppedQueueFull);
        // The rejected message was dropped, undoing its increment.
        assert_eq!(counter.get(), 1);
    }

    #[tokio::test]
    async fn test_blocking_callbacks_completion_wait() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.blocking_mode.store(true, Ordering::Release);

        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let recv_handle = tokio::spawn(async move {
            // Hold the message (not just the array) so the publisher keeps
            // waiting until this task finishes its "work".
            let msg = receiver.recv_msg().await.unwrap();
            assert_eq!(msg.array.unique_id, 42);
            tokio::time::sleep(Duration::from_millis(50)).await;
            completed_clone.store(true, Ordering::Release);
        });

        let outcome = sender.publish(make_test_array(42)).await;
        assert_eq!(outcome, PublishOutcome::Delivered);

        assert!(completed.load(Ordering::Acquire));

        recv_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_fanout_three_receivers() {
        let (s1, mut r1) = ndarray_channel("P1", 10);
        let (s2, mut r2) = ndarray_channel("P2", 10);
        let (s3, mut r3) = ndarray_channel("P3", 10);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);
        output.add(s3);

        let outcomes = output.publish(make_test_array(42)).await;
        assert_eq!(outcomes, vec![PublishOutcome::Delivered; 3]);

        assert_eq!(r1.recv().await.unwrap().unique_id, 42);
        assert_eq!(r2.recv().await.unwrap().unique_id, 42);
        assert_eq!(r3.recv().await.unwrap().unique_id, 42);
    }

    #[tokio::test]
    async fn test_publish_to_empty_output() {
        let output = NDArrayOutput::new();
        assert_eq!(output.publish_to(3, make_test_array(1)).await, None);
    }

    #[tokio::test]
    async fn test_publish_to_wraps_index() {
        let (s1, r1) = ndarray_channel("P1", 10);
        let (s2, mut r2) = ndarray_channel("P2", 10);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);

        // 3 % 2 == 1 selects the second sender.
        assert_eq!(
            output.publish_to(3, make_test_array(7)).await,
            Some(PublishOutcome::Delivered)
        );
        assert_eq!(r2.recv().await.unwrap().unique_id, 7);
        assert_eq!(r1.pending(), 0);
    }

    #[test]
    fn test_blocking_recv() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (sender, mut receiver) = ndarray_channel("TEST", 10);

        let handle = std::thread::spawn(move || {
            let arr = receiver.blocking_recv().unwrap();
            arr.unique_id
        });

        let outcome = rt.block_on(sender.publish(make_test_array(99)));
        assert_eq!(outcome, PublishOutcome::Delivered);
        let id = handle.join().unwrap();
        assert_eq!(id, 99);
    }

    #[tokio::test]
    async fn test_channel_closed_on_receiver_drop() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        drop(receiver);
        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::ChannelClosed
        );
    }

    #[tokio::test]
    async fn test_channel_closed_blocking_releases_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());
        sender.blocking_mode.store(true, Ordering::Release);
        drop(receiver);

        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::ChannelClosed
        );
        // The unsent message was dropped, undoing its increment.
        assert_eq!(counter.get(), 0);
    }

    #[tokio::test]
    async fn test_publish_disabled() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        sender.enabled.store(false, Ordering::Release);

        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::Disabled
        );
        assert_eq!(receiver.pending(), 0);
    }

    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);
        counter.increment();
        assert_eq!(counter.get(), 1);
        counter.increment();
        assert_eq!(counter.get(), 2);
        counter.decrement();
        assert_eq!(counter.get(), 1);
        counter.decrement();
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        counter.increment();

        let c = counter.clone();
        let h = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
        });

        assert!(counter.wait_until_zero(Duration::from_secs(5)));
        h.join().unwrap();
    }

    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        assert!(!counter.wait_until_zero(Duration::from_millis(10)));
    }

    #[tokio::test]
    async fn test_publish_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());

        sender.publish(make_test_array(1)).await;
        assert_eq!(counter.get(), 1);
        sender.publish(make_test_array(2)).await;
        assert_eq!(counter.get(), 2);
    }

    #[tokio::test]
    async fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let msg = ArrayMessage {
            array: make_test_array(1),
            counter: Some(counter.clone()),
            done_tx: None,
        };
        assert_eq!(counter.get(), 1);
        drop(msg);
        assert_eq!(counter.get(), 0);
    }
}