1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7pub(crate) trait BlockingProcessFn: Send + Sync {
9 fn process_and_publish(&self, array: &NDArray);
10}
11
12pub struct QueuedArrayCounter {
15 count: AtomicUsize,
16 mutex: parking_lot::Mutex<()>,
17 condvar: parking_lot::Condvar,
18}
19
20impl QueuedArrayCounter {
21 pub fn new() -> Self {
23 Self {
24 count: AtomicUsize::new(0),
25 mutex: parking_lot::Mutex::new(()),
26 condvar: parking_lot::Condvar::new(),
27 }
28 }
29
30 pub fn increment(&self) {
32 self.count.fetch_add(1, Ordering::AcqRel);
33 }
34
35 pub fn decrement(&self) {
37 let prev = self.count.fetch_sub(1, Ordering::AcqRel);
38 if prev == 1 {
39 let _guard = self.mutex.lock();
40 self.condvar.notify_all();
41 }
42 }
43
44 pub fn get(&self) -> usize {
46 self.count.load(Ordering::Acquire)
47 }
48
49 pub fn wait_until_zero(&self, timeout: Duration) -> bool {
52 let mut guard = self.mutex.lock();
53 if self.count.load(Ordering::Acquire) == 0 {
54 return true;
55 }
56 !self
57 .condvar
58 .wait_while_for(
59 &mut guard,
60 |_| self.count.load(Ordering::Acquire) != 0,
61 timeout,
62 )
63 .timed_out()
64 }
65}
66
67impl Default for QueuedArrayCounter {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73pub struct ArrayMessage {
76 pub array: Arc<NDArray>,
77 pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
78}
79
80impl Drop for ArrayMessage {
81 fn drop(&mut self) {
82 if let Some(c) = self.counter.take() {
83 c.decrement();
84 }
85 }
86}
87
88#[derive(Clone)]
90pub struct NDArraySender {
91 tx: tokio::sync::mpsc::Sender<ArrayMessage>,
92 port_name: String,
93 dropped_count: Arc<AtomicU64>,
94 enabled: Arc<AtomicBool>,
95 blocking_mode: Arc<AtomicBool>,
96 blocking_processor: Option<Arc<dyn BlockingProcessFn>>,
97 queued_counter: Option<Arc<QueuedArrayCounter>>,
98}
99
100impl NDArraySender {
101 pub fn send(&self, array: Arc<NDArray>) {
106 if !self.enabled.load(Ordering::Acquire) {
107 return;
108 }
109 if self.blocking_mode.load(Ordering::Acquire) {
110 if let Some(ref bp) = self.blocking_processor {
111 bp.process_and_publish(&array);
112 return;
113 }
114 }
115 if let Some(ref c) = self.queued_counter {
117 c.increment();
118 }
119 let msg = ArrayMessage {
120 array,
121 counter: self.queued_counter.clone(),
122 };
123 match self.tx.try_send(msg) {
124 Ok(()) => {}
125 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
126 self.dropped_count.fetch_add(1, Ordering::Relaxed);
128 }
129 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
130 }
132 }
133 }
134
135 pub fn is_enabled(&self) -> bool {
137 self.enabled.load(Ordering::Acquire)
138 }
139
140 pub fn is_blocking(&self) -> bool {
142 self.blocking_mode.load(Ordering::Acquire)
143 }
144
145 pub fn port_name(&self) -> &str {
146 &self.port_name
147 }
148
149 pub fn dropped_count(&self) -> u64 {
150 self.dropped_count.load(Ordering::Relaxed)
151 }
152
153 pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
155 self.queued_counter = Some(counter);
156 }
157
158 pub(crate) fn with_blocking_support(
160 self,
161 enabled: Arc<AtomicBool>,
162 blocking_mode: Arc<AtomicBool>,
163 blocking_processor: Arc<dyn BlockingProcessFn>,
164 ) -> Self {
165 Self {
166 enabled,
167 blocking_mode,
168 blocking_processor: Some(blocking_processor),
169 ..self
170 }
171 }
172}
173
174pub struct NDArrayReceiver {
176 rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
177}
178
179impl NDArrayReceiver {
180 pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
182 self.rx.blocking_recv().map(|msg| msg.array.clone())
183 }
184
185 pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
187 self.rx.recv().await.map(|msg| msg.array.clone())
188 }
189
190 pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
193 self.rx.recv().await
194 }
195}
196
197pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
199 let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
200 (
201 NDArraySender {
202 tx,
203 port_name: port_name.to_string(),
204 dropped_count: Arc::new(AtomicU64::new(0)),
205 enabled: Arc::new(AtomicBool::new(true)),
206 blocking_mode: Arc::new(AtomicBool::new(false)),
207 blocking_processor: None,
208 queued_counter: None,
209 },
210 NDArrayReceiver { rx },
211 )
212}
213
214pub struct NDArrayOutput {
216 senders: Vec<NDArraySender>,
217}
218
219impl NDArrayOutput {
220 pub fn new() -> Self {
221 Self {
222 senders: Vec::new(),
223 }
224 }
225
226 pub fn add(&mut self, sender: NDArraySender) {
227 self.senders.push(sender);
228 }
229
230 pub fn remove(&mut self, port_name: &str) {
231 self.senders.retain(|s| s.port_name != port_name);
232 }
233
234 pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
236 let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
237 Some(self.senders.swap_remove(idx))
238 }
239
240 pub fn publish(&self, array: Arc<NDArray>) {
242 for sender in &self.senders {
243 sender.send(array.clone());
244 }
245 }
246
247 pub fn publish_to(&self, index: usize, array: Arc<NDArray>) {
249 if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
250 sender.send(array);
251 }
252 }
253
254 pub fn total_dropped(&self) -> u64 {
255 self.senders.iter().map(|s| s.dropped_count()).sum()
256 }
257
258 pub fn num_senders(&self) -> usize {
259 self.senders.len()
260 }
261}
262
263impl Default for NDArrayOutput {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Builds a one-dimensional, 4-element `UInt8` array tagged with `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut array = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        array.unique_id = id;
        Arc::new(array)
    }

    /// Builds a single-threaded tokio runtime for driving async receives.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Arrays come out of the channel in the order they were sent.
    #[test]
    fn test_send_receive_basic() {
        let (tx, mut rx) = ndarray_channel("TEST", 10);
        tx.send(make_test_array(1));
        tx.send(make_test_array(2));

        let runtime = current_thread_runtime();
        runtime.block_on(async {
            let first = rx.recv().await.unwrap();
            assert_eq!(first.unique_id, 1);
            let second = rx.recv().await.unwrap();
            assert_eq!(second.unique_id, 2);
        });
    }

    /// Sends beyond the queue capacity are counted as dropped.
    #[test]
    fn test_back_pressure_drops() {
        let (tx, _rx) = ndarray_channel("TEST", 2);
        // Two fit in the queue, the remaining two are rejected.
        for id in 1..=4 {
            tx.send(make_test_array(id));
        }

        assert_eq!(tx.dropped_count(), 2);
    }

    /// Every attached receiver gets a published array.
    #[test]
    fn test_fanout_three_receivers() {
        let (tx1, mut rx1) = ndarray_channel("P1", 10);
        let (tx2, mut rx2) = ndarray_channel("P2", 10);
        let (tx3, mut rx3) = ndarray_channel("P3", 10);

        let mut fanout = NDArrayOutput::new();
        fanout.add(tx1);
        fanout.add(tx2);
        fanout.add(tx3);

        fanout.publish(make_test_array(42));

        let runtime = current_thread_runtime();
        runtime.block_on(async {
            assert_eq!(rx1.recv().await.unwrap().unique_id, 42);
            assert_eq!(rx2.recv().await.unwrap().unique_id, 42);
            assert_eq!(rx3.recv().await.unwrap().unique_id, 42);
        });
    }

    /// `total_dropped` adds up the drops of all senders.
    #[test]
    fn test_fanout_total_dropped() {
        let (tx1, _rx1) = ndarray_channel("P1", 1);
        let (tx2, _rx2) = ndarray_channel("P2", 1);

        let mut fanout = NDArrayOutput::new();
        fanout.add(tx1);
        fanout.add(tx2);

        // First publish fills both single-slot queues; the second is dropped
        // by both senders.
        fanout.publish(make_test_array(1));
        fanout.publish(make_test_array(2));

        assert_eq!(fanout.total_dropped(), 2);
    }

    /// Removing by port name detaches the matching sender.
    #[test]
    fn test_fanout_remove() {
        let (tx1, _rx1) = ndarray_channel("P1", 10);
        let (tx2, _rx2) = ndarray_channel("P2", 10);

        let mut fanout = NDArrayOutput::new();
        fanout.add(tx1);
        fanout.add(tx2);
        assert_eq!(fanout.num_senders(), 2);

        fanout.remove("P1");
        assert_eq!(fanout.num_senders(), 1);
    }

    /// A plain thread can receive with `blocking_recv`.
    #[test]
    fn test_blocking_recv() {
        let (tx, mut rx) = ndarray_channel("TEST", 10);

        let worker = std::thread::spawn(move || {
            let received = rx.blocking_recv().unwrap();
            received.unique_id
        });

        tx.send(make_test_array(99));
        let received_id = worker.join().unwrap();
        assert_eq!(received_id, 99);
    }

    /// Sending into a closed channel is not counted as a drop.
    #[test]
    fn test_channel_closed_on_receiver_drop() {
        let (tx, rx) = ndarray_channel("TEST", 10);
        drop(rx);
        tx.send(make_test_array(1));
        assert_eq!(tx.dropped_count(), 0);
    }

    /// Increment and decrement move the count by one each.
    #[test]
    fn test_queued_counter_basic() {
        let queued = QueuedArrayCounter::new();
        assert_eq!(queued.get(), 0);
        queued.increment();
        assert_eq!(queued.get(), 1);
        queued.increment();
        assert_eq!(queued.get(), 2);
        queued.decrement();
        assert_eq!(queued.get(), 1);
        queued.decrement();
        assert_eq!(queued.get(), 0);
    }

    /// A waiter is released once another thread drains the counter.
    #[test]
    fn test_queued_counter_wait_until_zero() {
        let queued = Arc::new(QueuedArrayCounter::new());
        queued.increment();
        queued.increment();

        let shared = Arc::clone(&queued);
        let drainer = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            shared.decrement();
            std::thread::sleep(Duration::from_millis(10));
            shared.decrement();
        });

        assert!(queued.wait_until_zero(Duration::from_secs(5)));
        drainer.join().unwrap();
    }

    /// Waiting on a non-zero counter that nobody drains times out.
    #[test]
    fn test_queued_counter_wait_timeout() {
        let queued = Arc::new(QueuedArrayCounter::new());
        queued.increment();
        let drained = queued.wait_until_zero(Duration::from_millis(10));
        assert!(!drained);
    }

    /// Each queued array bumps the attached counter.
    #[test]
    fn test_send_increments_counter() {
        let queued = Arc::new(QueuedArrayCounter::new());
        let (mut tx, _rx) = ndarray_channel("TEST", 10);
        tx.set_queued_counter(Arc::clone(&queued));

        tx.send(make_test_array(1));
        assert_eq!(queued.get(), 1);
        tx.send(make_test_array(2));
        assert_eq!(queued.get(), 2);
    }

    /// An array rejected by a full queue leaves the counter unchanged.
    #[test]
    fn test_send_queue_full_no_net_increment() {
        let queued = Arc::new(QueuedArrayCounter::new());
        let (mut tx, _rx) = ndarray_channel("TEST", 1);
        tx.set_queued_counter(Arc::clone(&queued));

        tx.send(make_test_array(1));
        assert_eq!(queued.get(), 1);
        tx.send(make_test_array(2));
        assert_eq!(queued.get(), 1);
    }

    /// Dropping a message decrements its counter.
    #[test]
    fn test_message_drop_decrements() {
        let queued = Arc::new(QueuedArrayCounter::new());
        queued.increment();
        let message = ArrayMessage {
            array: make_test_array(1),
            counter: Some(Arc::clone(&queued)),
        };
        assert_eq!(queued.get(), 1);
        drop(message);
        assert_eq!(queued.get(), 0);
    }
}