1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7pub(crate) trait BlockingProcessFn: Send + Sync {
9 fn process_and_publish(&self, array: &NDArray);
10}
11
12pub struct QueuedArrayCounter {
15 count: AtomicUsize,
16 mutex: parking_lot::Mutex<()>,
17 condvar: parking_lot::Condvar,
18}
19
20impl QueuedArrayCounter {
21 pub fn new() -> Self {
23 Self {
24 count: AtomicUsize::new(0),
25 mutex: parking_lot::Mutex::new(()),
26 condvar: parking_lot::Condvar::new(),
27 }
28 }
29
30 pub fn increment(&self) {
32 self.count.fetch_add(1, Ordering::AcqRel);
33 }
34
35 pub fn decrement(&self) {
37 let prev = self.count.fetch_sub(1, Ordering::AcqRel);
38 if prev == 1 {
39 let _guard = self.mutex.lock();
40 self.condvar.notify_all();
41 }
42 }
43
44 pub fn get(&self) -> usize {
46 self.count.load(Ordering::Acquire)
47 }
48
49 pub fn wait_until_zero(&self, timeout: Duration) -> bool {
52 let mut guard = self.mutex.lock();
53 if self.count.load(Ordering::Acquire) == 0 {
54 return true;
55 }
56 !self
57 .condvar
58 .wait_while_for(
59 &mut guard,
60 |_| self.count.load(Ordering::Acquire) != 0,
61 timeout,
62 )
63 .timed_out()
64 }
65}
66
67impl Default for QueuedArrayCounter {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73pub struct ArrayMessage {
76 pub array: Arc<NDArray>,
77 pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
78}
79
80impl Drop for ArrayMessage {
81 fn drop(&mut self) {
82 if let Some(c) = self.counter.take() {
83 c.decrement();
84 }
85 }
86}
87
88#[derive(Clone)]
90pub struct NDArraySender {
91 tx: tokio::sync::mpsc::Sender<ArrayMessage>,
92 port_name: String,
93 dropped_count: Arc<AtomicU64>,
94 enabled: Arc<AtomicBool>,
95 blocking_mode: Arc<AtomicBool>,
96 blocking_processor: Option<Arc<dyn BlockingProcessFn>>,
97 queued_counter: Option<Arc<QueuedArrayCounter>>,
98}
99
100impl NDArraySender {
101 pub fn send(&self, array: Arc<NDArray>) {
106 if !self.enabled.load(Ordering::Acquire) {
107 return;
108 }
109 if self.blocking_mode.load(Ordering::Acquire) {
110 if let Some(ref bp) = self.blocking_processor {
111 bp.process_and_publish(&array);
112 return;
113 }
114 }
115 if let Some(ref c) = self.queued_counter {
117 c.increment();
118 }
119 let msg = ArrayMessage {
120 array,
121 counter: self.queued_counter.clone(),
122 };
123 match self.tx.try_send(msg) {
124 Ok(()) => {}
125 Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
126 self.dropped_count.fetch_add(1, Ordering::Relaxed);
128 }
129 Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
130 }
132 }
133 }
134
135 pub fn is_enabled(&self) -> bool {
137 self.enabled.load(Ordering::Acquire)
138 }
139
140 pub fn is_blocking(&self) -> bool {
142 self.blocking_mode.load(Ordering::Acquire)
143 }
144
145 pub fn port_name(&self) -> &str {
146 &self.port_name
147 }
148
149 pub fn dropped_count(&self) -> u64 {
150 self.dropped_count.load(Ordering::Relaxed)
151 }
152
153 pub(crate) fn dropped_count_shared(&self) -> Arc<AtomicU64> {
155 self.dropped_count.clone()
156 }
157
158 pub(crate) fn tx_clone(&self) -> tokio::sync::mpsc::Sender<ArrayMessage> {
160 self.tx.clone()
161 }
162
163 pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
165 self.queued_counter = Some(counter);
166 }
167
168 pub(crate) fn with_blocking_support(
170 self,
171 enabled: Arc<AtomicBool>,
172 blocking_mode: Arc<AtomicBool>,
173 blocking_processor: Arc<dyn BlockingProcessFn>,
174 ) -> Self {
175 Self {
176 enabled,
177 blocking_mode,
178 blocking_processor: Some(blocking_processor),
179 ..self
180 }
181 }
182}
183
184pub struct NDArrayReceiver {
186 rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
187}
188
189impl NDArrayReceiver {
190 pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
192 self.rx.blocking_recv().map(|msg| msg.array.clone())
193 }
194
195 pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
197 self.rx.recv().await.map(|msg| msg.array.clone())
198 }
199
200 pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
203 self.rx.recv().await
204 }
205}
206
207pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
209 let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
210 (
211 NDArraySender {
212 tx,
213 port_name: port_name.to_string(),
214 dropped_count: Arc::new(AtomicU64::new(0)),
215 enabled: Arc::new(AtomicBool::new(true)),
216 blocking_mode: Arc::new(AtomicBool::new(false)),
217 blocking_processor: None,
218 queued_counter: None,
219 },
220 NDArrayReceiver { rx },
221 )
222}
223
224pub struct NDArrayOutput {
226 senders: Vec<NDArraySender>,
227}
228
229impl NDArrayOutput {
230 pub fn new() -> Self {
231 Self {
232 senders: Vec::new(),
233 }
234 }
235
236 pub fn add(&mut self, sender: NDArraySender) {
237 self.senders.push(sender);
238 }
239
240 pub fn remove(&mut self, port_name: &str) {
241 self.senders.retain(|s| s.port_name != port_name);
242 }
243
244 pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
246 let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
247 Some(self.senders.swap_remove(idx))
248 }
249
250 pub fn publish(&self, array: Arc<NDArray>) {
252 for sender in &self.senders {
253 sender.send(array.clone());
254 }
255 }
256
257 pub fn publish_to(&self, index: usize, array: Arc<NDArray>) {
259 if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
260 sender.send(array);
261 }
262 }
263
264 pub fn total_dropped(&self) -> u64 {
265 self.senders.iter().map(|s| s.dropped_count()).sum()
266 }
267
268 pub fn num_senders(&self) -> usize {
269 self.senders.len()
270 }
271}
272
273impl Default for NDArrayOutput {
274 fn default() -> Self {
275 Self::new()
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Build a small (4-element, UInt8) array whose `unique_id` is `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut array = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        array.unique_id = id;
        Arc::new(array)
    }

    /// Current-thread tokio runtime used to drive the async `recv` calls.
    fn runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Open one channel per port name.
    ///
    /// Returns the senders and the receivers, each in the order of `ports`.
    fn open_channels(
        ports: &[&str],
        queue_size: usize,
    ) -> (Vec<NDArraySender>, Vec<NDArrayReceiver>) {
        ports
            .iter()
            .map(|&port| ndarray_channel(port, queue_size))
            .unzip()
    }

    #[test]
    fn test_send_receive_basic() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        for id in [1, 2] {
            sender.send(make_test_array(id));
        }

        runtime().block_on(async {
            for expected in [1, 2] {
                let array = receiver.recv().await.unwrap();
                assert_eq!(array.unique_id, expected);
            }
        });
    }

    #[test]
    fn test_back_pressure_drops() {
        // Room for two arrays and nothing drains the queue, so ids 3 and 4
        // are dropped.
        let (sender, _receiver) = ndarray_channel("TEST", 2);
        (1..=4).for_each(|id| sender.send(make_test_array(id)));

        assert_eq!(sender.dropped_count(), 2);
    }

    #[test]
    fn test_fanout_three_receivers() {
        let (senders, mut receivers) = open_channels(&["P1", "P2", "P3"], 10);
        let mut output = NDArrayOutput::new();
        senders.into_iter().for_each(|sender| output.add(sender));

        output.publish(make_test_array(42));

        runtime().block_on(async {
            for receiver in receivers.iter_mut() {
                assert_eq!(receiver.recv().await.unwrap().unique_id, 42);
            }
        });
    }

    #[test]
    fn test_fanout_total_dropped() {
        // Keep the receivers alive: a closed channel would not count as a drop.
        let (senders, _receivers) = open_channels(&["P1", "P2"], 1);
        let mut output = NDArrayOutput::new();
        for sender in senders {
            output.add(sender);
        }

        // Capacity is one per port, so the second publish overflows each
        // port once.
        for id in [1, 2] {
            output.publish(make_test_array(id));
        }

        assert_eq!(output.total_dropped(), 2);
    }

    #[test]
    fn test_fanout_remove() {
        let (senders, _receivers) = open_channels(&["P1", "P2"], 10);
        let mut output = NDArrayOutput::new();
        for sender in senders {
            output.add(sender);
        }
        assert_eq!(output.num_senders(), 2);

        output.remove("P1");
        assert_eq!(output.num_senders(), 1);
    }

    #[test]
    fn test_blocking_recv() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);

        let consumer = std::thread::spawn(move || receiver.blocking_recv().unwrap().unique_id);

        sender.send(make_test_array(99));
        assert_eq!(consumer.join().unwrap(), 99);
    }

    #[test]
    fn test_channel_closed_on_receiver_drop() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        drop(receiver);

        sender.send(make_test_array(1));

        // Sending into a closed channel is not back-pressure, so it is not
        // counted.
        assert_eq!(sender.dropped_count(), 0);
    }

    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);

        for expected in [1, 2] {
            counter.increment();
            assert_eq!(counter.get(), expected);
        }
        for expected in [1, 0] {
            counter.decrement();
            assert_eq!(counter.get(), expected);
        }
    }

    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        counter.increment();

        let worker = {
            let counter = Arc::clone(&counter);
            std::thread::spawn(move || {
                for _ in 0..2 {
                    std::thread::sleep(Duration::from_millis(10));
                    counter.decrement();
                }
            })
        };

        assert!(counter.wait_until_zero(Duration::from_secs(5)));
        worker.join().unwrap();
    }

    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();

        let reached_zero = counter.wait_until_zero(Duration::from_millis(10));
        assert!(!reached_zero);
    }

    #[test]
    fn test_send_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(Arc::clone(&counter));

        for (id, expected) in [(1, 1), (2, 2)] {
            sender.send(make_test_array(id));
            assert_eq!(counter.get(), expected);
        }
    }

    #[test]
    fn test_send_queue_full_no_net_increment() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 1);
        sender.set_queued_counter(Arc::clone(&counter));

        sender.send(make_test_array(1));
        assert_eq!(counter.get(), 1);

        // The queue is full, so the second message is rejected. Dropping it
        // undoes its increment.
        sender.send(make_test_array(2));
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();

        let message = ArrayMessage {
            array: make_test_array(1),
            counter: Some(Arc::clone(&counter)),
        };
        assert_eq!(counter.get(), 1);

        drop(message);
        assert_eq!(counter.get(), 0);
    }
}