1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7pub struct QueuedArrayCounter {
10 count: AtomicUsize,
11 mutex: parking_lot::Mutex<()>,
12 condvar: parking_lot::Condvar,
13}
14
15impl QueuedArrayCounter {
16 pub fn new() -> Self {
18 Self {
19 count: AtomicUsize::new(0),
20 mutex: parking_lot::Mutex::new(()),
21 condvar: parking_lot::Condvar::new(),
22 }
23 }
24
25 pub fn increment(&self) {
27 self.count.fetch_add(1, Ordering::AcqRel);
28 }
29
30 pub fn decrement(&self) {
32 let prev = self.count.fetch_sub(1, Ordering::AcqRel);
33 if prev == 1 {
34 let _guard = self.mutex.lock();
35 self.condvar.notify_all();
36 }
37 }
38
39 pub fn get(&self) -> usize {
41 self.count.load(Ordering::Acquire)
42 }
43
44 pub fn wait_until_zero(&self, timeout: Duration) -> bool {
47 let mut guard = self.mutex.lock();
48 if self.count.load(Ordering::Acquire) == 0 {
49 return true;
50 }
51 !self
52 .condvar
53 .wait_while_for(
54 &mut guard,
55 |_| self.count.load(Ordering::Acquire) != 0,
56 timeout,
57 )
58 .timed_out()
59 }
60}
61
62impl Default for QueuedArrayCounter {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68pub struct ArrayMessage {
72 pub array: Arc<NDArray>,
73 pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
74 pub(crate) done_tx: Option<tokio::sync::oneshot::Sender<()>>,
77}
78
79impl Drop for ArrayMessage {
80 fn drop(&mut self) {
81 if let Some(tx) = self.done_tx.take() {
82 let _ = tx.send(());
83 }
84 if let Some(c) = self.counter.take() {
85 c.decrement();
86 }
87 }
88}
89
90#[derive(Clone)]
104pub struct NDArraySender {
105 tx: tokio::sync::mpsc::Sender<ArrayMessage>,
106 port_name: String,
107 enabled: Arc<AtomicBool>,
108 blocking_mode: Arc<AtomicBool>,
109 queued_counter: Option<Arc<QueuedArrayCounter>>,
110}
111
112impl NDArraySender {
113 pub async fn publish(&self, array: Arc<NDArray>) {
119 if !self.enabled.load(Ordering::Acquire) {
120 return;
121 }
122 if let Some(ref c) = self.queued_counter {
123 c.increment();
124 }
125
126 let blocking = self.blocking_mode.load(Ordering::Acquire);
127 let (done_tx, done_rx) = if blocking {
128 let (tx, rx) = tokio::sync::oneshot::channel();
129 (Some(tx), Some(rx))
130 } else {
131 (None, None)
132 };
133
134 let msg = ArrayMessage {
135 array,
136 counter: self.queued_counter.clone(),
137 done_tx,
138 };
139
140 if self.tx.send(msg).await.is_err() {
141 return;
143 }
144
145 if let Some(rx) = done_rx {
147 let _ = rx.await;
148 }
149 }
150
151 pub fn is_enabled(&self) -> bool {
153 self.enabled.load(Ordering::Acquire)
154 }
155
156 pub fn is_blocking(&self) -> bool {
158 self.blocking_mode.load(Ordering::Acquire)
159 }
160
161 pub fn port_name(&self) -> &str {
162 &self.port_name
163 }
164
165 pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
167 self.queued_counter = Some(counter);
168 }
169
170 pub(crate) fn set_mode_flags(
172 &mut self,
173 enabled: Arc<AtomicBool>,
174 blocking_mode: Arc<AtomicBool>,
175 ) {
176 self.enabled = enabled;
177 self.blocking_mode = blocking_mode;
178 }
179}
180
181pub struct NDArrayReceiver {
183 rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
184}
185
186impl NDArrayReceiver {
187 pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
189 self.rx.blocking_recv().map(|msg| msg.array.clone())
190 }
191
192 pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
194 self.rx.recv().await.map(|msg| msg.array.clone())
195 }
196
197 pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
200 self.rx.recv().await
201 }
202}
203
204pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
206 let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
207 (
208 NDArraySender {
209 tx,
210 port_name: port_name.to_string(),
211 enabled: Arc::new(AtomicBool::new(true)),
212 blocking_mode: Arc::new(AtomicBool::new(false)),
213 queued_counter: None,
214 },
215 NDArrayReceiver { rx },
216 )
217}
218
219pub struct NDArrayOutput {
221 senders: Vec<NDArraySender>,
222}
223
224impl NDArrayOutput {
225 pub fn new() -> Self {
226 Self {
227 senders: Vec::new(),
228 }
229 }
230
231 pub fn add(&mut self, sender: NDArraySender) {
232 self.senders.push(sender);
233 }
234
235 pub fn remove(&mut self, port_name: &str) {
236 self.senders.retain(|s| s.port_name != port_name);
237 }
238
239 pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
241 let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
242 Some(self.senders.swap_remove(idx))
243 }
244
245 pub async fn publish(&self, array: Arc<NDArray>) {
252 let futs = self.senders.iter().map(|s| s.publish(array.clone()));
253 futures_util::future::join_all(futs).await;
254 }
255
256 pub async fn publish_to(&self, index: usize, array: Arc<NDArray>) {
258 if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
259 sender.publish(array).await;
260 }
261 }
262
263 pub fn num_senders(&self) -> usize {
264 self.senders.len()
265 }
266
267 pub(crate) fn senders_clone(&self) -> Vec<NDArraySender> {
269 self.senders.clone()
270 }
271}
272
273#[derive(Clone)]
286pub struct ArrayPublisher {
287 output: Arc<parking_lot::Mutex<NDArrayOutput>>,
288}
289
290impl ArrayPublisher {
291 pub fn new(output: Arc<parking_lot::Mutex<NDArrayOutput>>) -> Self {
293 Self { output }
294 }
295
296 pub async fn publish(&self, array: Arc<NDArray>) {
298 let senders = self.output.lock().senders_clone();
299 let futs = senders.iter().map(|s| s.publish(array.clone()));
300 futures_util::future::join_all(futs).await;
301 }
302}
303
304impl Default for NDArrayOutput {
305 fn default() -> Self {
306 Self::new()
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Builds a 4-element UInt8 array tagged with `id` so tests can check delivery.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        Arc::new(arr)
    }

    #[tokio::test]
    async fn test_publish_receive_basic() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.publish(make_test_array(1)).await;
        sender.publish(make_test_array(2)).await;

        let a1 = receiver.recv().await.unwrap();
        assert_eq!(a1.unique_id, 1);
        let a2 = receiver.recv().await.unwrap();
        assert_eq!(a2.unique_id, 2);
    }

    /// With capacity 1 the publisher must wait for space instead of dropping arrays.
    #[tokio::test]
    async fn test_publish_no_drop() {
        let (sender, mut receiver) = ndarray_channel("TEST", 1);

        let s = sender.clone();
        let pub_handle = tokio::spawn(async move {
            s.publish(make_test_array(1)).await;
            s.publish(make_test_array(2)).await;
            s.publish(make_test_array(3)).await;
        });

        let a1 = receiver.recv().await.unwrap();
        assert_eq!(a1.unique_id, 1);
        let a2 = receiver.recv().await.unwrap();
        assert_eq!(a2.unique_id, 2);
        let a3 = receiver.recv().await.unwrap();
        assert_eq!(a3.unique_id, 3);

        pub_handle.await.unwrap();
    }

    /// In blocking mode `publish` returns only after the receiver drops the message.
    #[tokio::test]
    async fn test_blocking_callbacks_completion_wait() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.blocking_mode.store(true, Ordering::Release);

        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let recv_handle = tokio::spawn(async move {
            let msg = receiver.recv_msg().await.unwrap();
            assert_eq!(msg.array.unique_id, 42);
            tokio::time::sleep(Duration::from_millis(50)).await;
            completed_clone.store(true, Ordering::Release);
            // `msg` is dropped here, after `completed` is set.
        });

        sender.publish(make_test_array(42)).await;

        assert!(completed.load(Ordering::Acquire));

        recv_handle.await.unwrap();
    }

    /// When a blocking publish returns, the message's queue slot has been released.
    #[tokio::test]
    async fn test_blocking_publish_returns_after_counter_released() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());
        sender.blocking_mode.store(true, Ordering::Release);

        let recv_handle = tokio::spawn(async move {
            let msg = receiver.recv_msg().await.unwrap();
            assert_eq!(msg.array.unique_id, 7);
        });

        sender.publish(make_test_array(7)).await;
        assert_eq!(counter.get(), 0);

        recv_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_fanout_three_receivers() {
        let (s1, mut r1) = ndarray_channel("P1", 10);
        let (s2, mut r2) = ndarray_channel("P2", 10);
        let (s3, mut r3) = ndarray_channel("P3", 10);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);
        output.add(s3);

        output.publish(make_test_array(42)).await;

        assert_eq!(r1.recv().await.unwrap().unique_id, 42);
        assert_eq!(r2.recv().await.unwrap().unique_id, 42);
        assert_eq!(r3.recv().await.unwrap().unique_id, 42);
    }

    #[test]
    fn test_blocking_recv() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (sender, mut receiver) = ndarray_channel("TEST", 10);

        let handle = std::thread::spawn(move || {
            let arr = receiver.blocking_recv().unwrap();
            arr.unique_id
        });

        rt.block_on(sender.publish(make_test_array(99)));
        let id = handle.join().unwrap();
        assert_eq!(id, 99);
    }

    /// Publishing into a closed channel must neither panic nor hang.
    #[tokio::test]
    async fn test_channel_closed_on_receiver_drop() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        drop(receiver);
        sender.publish(make_test_array(1)).await;
    }

    /// A blocking publish into a closed channel returns promptly, and the
    /// rejected message still gives back its queue slot.
    #[tokio::test]
    async fn test_closed_channel_blocking_publish_releases_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());
        sender.blocking_mode.store(true, Ordering::Release);
        drop(receiver);

        tokio::time::timeout(Duration::from_secs(5), sender.publish(make_test_array(1)))
            .await
            .expect("blocking publish hung on a closed channel");
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);
        counter.increment();
        assert_eq!(counter.get(), 1);
        counter.increment();
        assert_eq!(counter.get(), 2);
        counter.decrement();
        assert_eq!(counter.get(), 1);
        counter.decrement();
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        counter.increment();

        let c = counter.clone();
        let h = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
        });

        assert!(counter.wait_until_zero(Duration::from_secs(5)));
        h.join().unwrap();
    }

    #[test]
    fn test_queued_counter_wait_already_zero() {
        let counter = QueuedArrayCounter::new();
        assert!(counter.wait_until_zero(Duration::ZERO));
    }

    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        assert!(!counter.wait_until_zero(Duration::from_millis(10)));
    }

    #[tokio::test]
    async fn test_publish_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        // `_receiver` (not `_`) keeps the channel open so the sends succeed.
        let (mut sender, _receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());

        sender.publish(make_test_array(1)).await;
        assert_eq!(counter.get(), 1);
        sender.publish(make_test_array(2)).await;
        assert_eq!(counter.get(), 2);
    }

    /// Pure synchronous drop logic; no async runtime needed.
    #[test]
    fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let msg = ArrayMessage {
            array: make_test_array(1),
            counter: Some(counter.clone()),
            done_tx: None,
        };
        assert_eq!(counter.get(), 1);
        drop(msg);
        assert_eq!(counter.get(), 0);
    }
}