Skip to main content

ad_core_rs/plugin/
channel.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicI32, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7/// Tracks the number of queued (in-flight) arrays across plugins.
8/// Used by drivers to perform a bounded wait at end of acquisition.
9pub struct QueuedArrayCounter {
10    count: AtomicUsize,
11    mutex: parking_lot::Mutex<()>,
12    condvar: parking_lot::Condvar,
13}
14
15impl QueuedArrayCounter {
16    /// Create a new counter starting at zero.
17    pub fn new() -> Self {
18        Self {
19            count: AtomicUsize::new(0),
20            mutex: parking_lot::Mutex::new(()),
21            condvar: parking_lot::Condvar::new(),
22        }
23    }
24
25    /// Increment the queued count (called before send).
26    pub fn increment(&self) {
27        self.count.fetch_add(1, Ordering::AcqRel);
28    }
29
30    /// Decrement the queued count. Notifies waiters when reaching zero.
31    pub fn decrement(&self) {
32        let prev = self.count.fetch_sub(1, Ordering::AcqRel);
33        if prev == 1 {
34            let _guard = self.mutex.lock();
35            self.condvar.notify_all();
36        }
37    }
38
39    /// Current queued count.
40    pub fn get(&self) -> usize {
41        self.count.load(Ordering::Acquire)
42    }
43
44    /// Wait until count reaches zero, or timeout expires.
45    /// Returns `true` if count is zero, `false` on timeout.
46    pub fn wait_until_zero(&self, timeout: Duration) -> bool {
47        let mut guard = self.mutex.lock();
48        if self.count.load(Ordering::Acquire) == 0 {
49            return true;
50        }
51        !self
52            .condvar
53            .wait_while_for(
54                &mut guard,
55                |_| self.count.load(Ordering::Acquire) != 0,
56                timeout,
57            )
58            .timed_out()
59    }
60}
61
62impl Default for QueuedArrayCounter {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68/// Array message with optional queued-array counter and completion signal.
69/// When dropped, decrements the counter (if present) — this signals that
70/// the downstream plugin has finished processing the array.
71pub struct ArrayMessage {
72    pub array: Arc<NDArray>,
73    pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
74    /// When Some, the sender awaits this to confirm downstream processing completed.
75    /// Fired when ArrayMessage is dropped (i.e., after plugin process_array finishes).
76    pub(crate) done_tx: Option<tokio::sync::oneshot::Sender<()>>,
77}
78
79impl Drop for ArrayMessage {
80    fn drop(&mut self) {
81        if let Some(tx) = self.done_tx.take() {
82            let _ = tx.send(());
83        }
84        if let Some(c) = self.counter.take() {
85            c.decrement();
86        }
87    }
88}
89
/// Outcome of a `publish` call, mirroring C++ `driverCallback` accounting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PublishOutcome {
    /// The array was enqueued (and, in blocking mode, the sender waited until
    /// the downstream consumer released the message).
    Delivered,
    /// `enable_callbacks` was 0 — array not sent (not a drop, not counted).
    Disabled,
    /// The downstream queue was full and the array was dropped, matching C++
    /// `trySend` semantics. The sending `NDArraySender` has already added one
    /// to the downstream plugin's shared `DroppedArrays` counter (see
    /// `NDArraySender::dropped_arrays_counter`), so callers must not count it
    /// against that plugin a second time.
    DroppedQueueFull,
    /// The downstream channel was closed (receiver gone).
    ChannelClosed,
}
103
104/// Sender held by upstream.
105///
106/// # Default: drop-on-full (C++ parity)
107///
108/// By default `publish` uses a bounded `try_send`: when the downstream queue
109/// is full the array is **dropped** and `PublishOutcome::DroppedQueueFull` is
110/// returned, matching C++ `NDPluginDriver::driverCallback` `trySend` — a slow
111/// plugin drops frames rather than back-pressuring the detector driver.
112///
113/// # `blocking_callbacks=1`: reliable opt-in
114///
115/// When `blocking_callbacks` is set, `publish` instead uses a reliable
116/// `send().await` and waits for the downstream plugin to finish processing.
117/// This is the explicit opt-in for "never drop, apply back-pressure"
118/// behavior. It is NOT the default.
119#[derive(Clone)]
120pub struct NDArraySender {
121    tx: tokio::sync::mpsc::Sender<ArrayMessage>,
122    port_name: String,
123    enabled: Arc<AtomicBool>,
124    blocking_mode: Arc<AtomicBool>,
125    queued_counter: Option<Arc<QueuedArrayCounter>>,
126    /// Cumulative count of arrays dropped because this sender's downstream
127    /// input queue was full. Owned by the downstream plugin (which publishes
128    /// it to its `DROPPED_ARRAYS` param), shared back to every upstream
129    /// sender that feeds this plugin — matching C++ `driverCallback` which
130    /// increments the *receiving* plugin's `NDPluginDriverDroppedArrays`.
131    dropped_arrays: Arc<AtomicI32>,
132}
133
134impl NDArraySender {
135    /// Publish an array downstream.
136    ///
137    /// - `enable_callbacks=0`: returns `Disabled`, array not sent.
138    /// - `blocking_callbacks=0` (default): bounded `try_send` — on a full queue
139    ///   the array is dropped and `DroppedQueueFull` is returned (C++ parity).
140    /// - `blocking_callbacks=1`: reliable `send().await` + awaits downstream
141    ///   processing completion (explicit opt-in, never drops).
142    pub async fn publish(&self, array: Arc<NDArray>) -> PublishOutcome {
143        if !self.enabled.load(Ordering::Acquire) {
144            return PublishOutcome::Disabled;
145        }
146
147        let blocking = self.blocking_mode.load(Ordering::Acquire);
148
149        if !blocking {
150            // Drop-on-full path (C++ trySend). Build the message only on the
151            // way into try_send so a full queue does not touch the counter.
152            if let Some(ref c) = self.queued_counter {
153                c.increment();
154            }
155            let msg = ArrayMessage {
156                array,
157                counter: self.queued_counter.clone(),
158                done_tx: None,
159            };
160            return match self.tx.try_send(msg) {
161                Ok(()) => PublishOutcome::Delivered,
162                // `msg` is dropped here → counter decremented by ArrayMessage::drop.
163                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
164                    self.dropped_arrays.fetch_add(1, Ordering::AcqRel);
165                    PublishOutcome::DroppedQueueFull
166                }
167                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
168                    PublishOutcome::ChannelClosed
169                }
170            };
171        }
172
173        // Reliable blocking path: never drops, awaits completion.
174        if let Some(ref c) = self.queued_counter {
175            c.increment();
176        }
177        let (done_tx, done_rx) = tokio::sync::oneshot::channel();
178        let msg = ArrayMessage {
179            array,
180            counter: self.queued_counter.clone(),
181            done_tx: Some(done_tx),
182        };
183        if self.tx.send(msg).await.is_err() {
184            // Channel closed — counter was decremented by ArrayMessage::drop
185            return PublishOutcome::ChannelClosed;
186        }
187        let _ = done_rx.await;
188        PublishOutcome::Delivered
189    }
190
191    /// Whether this sender's plugin has callbacks enabled.
192    pub fn is_enabled(&self) -> bool {
193        self.enabled.load(Ordering::Acquire)
194    }
195
196    /// Whether this sender's plugin is in blocking mode.
197    pub fn is_blocking(&self) -> bool {
198        self.blocking_mode.load(Ordering::Acquire)
199    }
200
201    pub fn port_name(&self) -> &str {
202        &self.port_name
203    }
204
205    /// Set the queued-array counter for tracking in-flight arrays.
206    pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
207        self.queued_counter = Some(counter);
208    }
209
210    /// Attach the downstream plugin's shared `DroppedArrays` counter so that
211    /// a full-queue drop on this sender is accounted to that plugin (C++ parity).
212    pub fn set_dropped_arrays_counter(&mut self, counter: Arc<AtomicI32>) {
213        self.dropped_arrays = counter;
214    }
215
216    /// The shared `DroppedArrays` counter for this sender's downstream queue.
217    pub fn dropped_arrays_counter(&self) -> &Arc<AtomicI32> {
218        &self.dropped_arrays
219    }
220
221    /// Current capacity (free slots) of the downstream input queue.
222    pub fn capacity(&self) -> usize {
223        self.tx.capacity()
224    }
225
226    /// Maximum capacity of the downstream input queue.
227    pub fn max_capacity(&self) -> usize {
228        self.tx.max_capacity()
229    }
230
231    /// Set the enabled/blocking mode flags (used by plugin runtime wiring).
232    pub(crate) fn set_mode_flags(
233        &mut self,
234        enabled: Arc<AtomicBool>,
235        blocking_mode: Arc<AtomicBool>,
236    ) {
237        self.enabled = enabled;
238        self.blocking_mode = blocking_mode;
239    }
240}
241
242/// Receiver held by downstream plugin.
243pub struct NDArrayReceiver {
244    rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
245}
246
247impl NDArrayReceiver {
248    /// Number of currently buffered (pending) messages in the input queue.
249    pub fn pending(&self) -> usize {
250        self.rx.len()
251    }
252
253    /// Maximum capacity of the input queue.
254    pub fn max_capacity(&self) -> usize {
255        self.rx.max_capacity()
256    }
257
258    /// Number of free slots in the input queue (`max_capacity - pending`).
259    pub fn capacity(&self) -> usize {
260        self.rx.capacity()
261    }
262
263    /// Blocking receive (for use in std::thread data processing loops).
264    pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
265        self.rx.blocking_recv().map(|msg| msg.array.clone())
266    }
267
268    /// Async receive.
269    pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
270        self.rx.recv().await.map(|msg| msg.array.clone())
271    }
272
273    /// Receive the full ArrayMessage (crate-internal). The message's Drop
274    /// will signal completion when the caller is done with it.
275    pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
276        self.rx.recv().await
277    }
278}
279
280/// Create a matched sender/receiver pair.
281pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
282    let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
283    (
284        NDArraySender {
285            tx,
286            port_name: port_name.to_string(),
287            enabled: Arc::new(AtomicBool::new(true)),
288            blocking_mode: Arc::new(AtomicBool::new(false)),
289            queued_counter: None,
290            dropped_arrays: Arc::new(AtomicI32::new(0)),
291        },
292        NDArrayReceiver { rx },
293    )
294}
295
296/// Fan-out: publishes arrays to multiple downstream receivers.
297pub struct NDArrayOutput {
298    senders: Vec<NDArraySender>,
299}
300
301impl NDArrayOutput {
302    pub fn new() -> Self {
303        Self {
304            senders: Vec::new(),
305        }
306    }
307
308    pub fn add(&mut self, sender: NDArraySender) {
309        self.senders.push(sender);
310    }
311
312    pub fn remove(&mut self, port_name: &str) {
313        self.senders.retain(|s| s.port_name != port_name);
314    }
315
316    /// Remove a sender by port name and return it (if found).
317    pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
318        let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
319        Some(self.senders.swap_remove(idx))
320    }
321
322    /// Publish an array to all downstream receivers (async, concurrent).
323    ///
324    /// Each sender publishes independently. Returns the per-sender outcomes
325    /// so the caller can count `DroppedArrays` for any downstream whose queue
326    /// was full (C++ `driverCallback` semantics).
327    pub async fn publish(&self, array: Arc<NDArray>) -> Vec<PublishOutcome> {
328        let futs = self.senders.iter().map(|s| s.publish(array.clone()));
329        futures_util::future::join_all(futs).await
330    }
331
332    /// Publish an array to a single downstream receiver by index (for scatter/round-robin).
333    pub async fn publish_to(&self, index: usize, array: Arc<NDArray>) -> Option<PublishOutcome> {
334        if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
335            Some(sender.publish(array).await)
336        } else {
337            None
338        }
339    }
340
341    pub fn num_senders(&self) -> usize {
342        self.senders.len()
343    }
344
345    /// Clone the senders list (for publishing outside a lock in async context).
346    pub(crate) fn senders_clone(&self) -> Vec<NDArraySender> {
347        self.senders.clone()
348    }
349}
350
351/// Cloneable async handle for publishing arrays to downstream plugins.
352///
353/// This is the public API for driver acquisition tasks.
354/// Internally it snapshots the sender list, releases the lock, then
355/// publishes to all senders concurrently.
356///
357/// # Example
358/// ```ignore
359/// if config.array_callbacks {
360///     publisher.publish(Arc::new(frame)).await;
361/// }
362/// ```
363#[derive(Clone)]
364pub struct ArrayPublisher {
365    output: Arc<parking_lot::Mutex<NDArrayOutput>>,
366}
367
368impl ArrayPublisher {
369    /// Create a publisher backed by the given output.
370    pub fn new(output: Arc<parking_lot::Mutex<NDArrayOutput>>) -> Self {
371        Self { output }
372    }
373
374    /// Publish an array to all downstream plugins (async, concurrent fan-out).
375    ///
376    /// Returns the per-downstream outcomes — a `DroppedQueueFull` entry means
377    /// that downstream plugin's input queue was full and the array was dropped
378    /// (C++ `driverCallback` `trySend`). The driver should count those as
379    /// `DroppedArrays`.
380    pub async fn publish(&self, array: Arc<NDArray>) -> Vec<PublishOutcome> {
381        let senders = self.output.lock().senders_clone();
382        let futs = senders.iter().map(|s| s.publish(array.clone()));
383        futures_util::future::join_all(futs).await
384    }
385}
386
387impl Default for NDArrayOutput {
388    fn default() -> Self {
389        Self::new()
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Build a small UInt8 test array whose `unique_id` is `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        Arc::new(arr)
    }

    #[tokio::test]
    async fn test_publish_receive_basic() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::Delivered
        );
        assert_eq!(
            sender.publish(make_test_array(2)).await,
            PublishOutcome::Delivered
        );

        let a1 = receiver.recv().await.unwrap();
        assert_eq!(a1.unique_id, 1);
        let a2 = receiver.recv().await.unwrap();
        assert_eq!(a2.unique_id, 2);
    }

    #[tokio::test]
    async fn test_publish_blocking_no_drop() {
        // In blocking_callbacks mode, reliable send().await is used: even a
        // queue of 1 must not drop — the producer back-pressures instead.
        let (sender, mut receiver) = ndarray_channel("TEST", 1);
        sender.blocking_mode.store(true, Ordering::Release);

        let s = sender.clone();
        let pub_handle = tokio::spawn(async move {
            s.publish(make_test_array(1)).await;
            s.publish(make_test_array(2)).await;
            s.publish(make_test_array(3)).await;
        });

        // Receive all 3 — no drops in blocking mode.
        let a1 = receiver.recv().await.unwrap();
        assert_eq!(a1.unique_id, 1);
        let a2 = receiver.recv().await.unwrap();
        assert_eq!(a2.unique_id, 2);
        let a3 = receiver.recv().await.unwrap();
        assert_eq!(a3.unique_id, 3);

        pub_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_publish_drops_on_full_queue() {
        // B1: default (non-blocking) mode drops on a full queue and reports
        // DroppedQueueFull, matching C++ trySend.
        let (sender, _receiver) = ndarray_channel("TEST", 1);

        // First publish fills the queue.
        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::Delivered
        );
        // Second publish finds the queue full → dropped + counted.
        assert_eq!(
            sender.publish(make_test_array(2)).await,
            PublishOutcome::DroppedQueueFull
        );
        // The drop is accounted to the downstream plugin's shared counter.
        assert_eq!(sender.dropped_arrays_counter().load(Ordering::Acquire), 1);
    }

    #[tokio::test]
    async fn test_drop_on_full_does_not_leak_counter() {
        // A dropped array must not leave the queued-array counter incremented.
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 1);
        sender.set_queued_counter(counter.clone());

        sender.publish(make_test_array(1)).await; // delivered, counter=1
        assert_eq!(counter.get(), 1);
        let outcome = sender.publish(make_test_array(2)).await; // dropped
        assert_eq!(outcome, PublishOutcome::DroppedQueueFull);
        // Counter must still be 1 — the dropped array is not left counted.
        assert_eq!(counter.get(), 1);
    }

    #[tokio::test]
    async fn test_publish_disabled_sends_nothing() {
        // enable_callbacks=0: nothing is queued, counted, or recorded as dropped.
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());
        sender.enabled.store(false, Ordering::Release);

        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::Disabled
        );
        assert_eq!(receiver.pending(), 0);
        assert_eq!(counter.get(), 0);
        assert_eq!(sender.dropped_arrays_counter().load(Ordering::Acquire), 0);
    }

    #[tokio::test]
    async fn test_blocking_callbacks_completion_wait() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.blocking_mode.store(true, Ordering::Release);

        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        // Spawn receiver that takes some time to process
        let recv_handle = tokio::spawn(async move {
            let msg = receiver.recv_msg().await.unwrap();
            assert_eq!(msg.array.unique_id, 42);
            // Simulate processing time
            tokio::time::sleep(Duration::from_millis(50)).await;
            completed_clone.store(true, Ordering::Release);
            // msg dropped here → done_tx fires
        });

        // publish() should wait for completion
        sender.publish(make_test_array(42)).await;

        // By the time publish returns, downstream should have completed
        assert!(completed.load(Ordering::Acquire));

        recv_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_blocking_publish_to_closed_channel_releases_counter() {
        // A blocking send to a closed channel reports ChannelClosed and does
        // not leave the array counted as queued.
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());
        sender.blocking_mode.store(true, Ordering::Release);
        drop(receiver);

        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::ChannelClosed
        );
        assert_eq!(counter.get(), 0);
    }

    #[tokio::test]
    async fn test_fanout_three_receivers() {
        let (s1, mut r1) = ndarray_channel("P1", 10);
        let (s2, mut r2) = ndarray_channel("P2", 10);
        let (s3, mut r3) = ndarray_channel("P3", 10);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);
        output.add(s3);

        output.publish(make_test_array(42)).await;

        assert_eq!(r1.recv().await.unwrap().unique_id, 42);
        assert_eq!(r2.recv().await.unwrap().unique_id, 42);
        assert_eq!(r3.recv().await.unwrap().unique_id, 42);
    }

    #[tokio::test]
    async fn test_publish_to_empty_output_returns_none() {
        let output = NDArrayOutput::new();
        assert_eq!(output.publish_to(3, make_test_array(1)).await, None);
    }

    #[test]
    fn test_blocking_recv() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (sender, mut receiver) = ndarray_channel("TEST", 10);

        let handle = std::thread::spawn(move || {
            let arr = receiver.blocking_recv().unwrap();
            arr.unique_id
        });

        rt.block_on(sender.publish(make_test_array(99)));
        let id = handle.join().unwrap();
        assert_eq!(id, 99);
    }

    #[tokio::test]
    async fn test_channel_closed_on_receiver_drop() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        drop(receiver);
        // Sending to a closed channel must not panic and must report it.
        assert_eq!(
            sender.publish(make_test_array(1)).await,
            PublishOutcome::ChannelClosed
        );
    }

    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);
        counter.increment();
        assert_eq!(counter.get(), 1);
        counter.increment();
        assert_eq!(counter.get(), 2);
        counter.decrement();
        assert_eq!(counter.get(), 1);
        counter.decrement();
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        counter.increment();

        let c = counter.clone();
        let h = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
        });

        assert!(counter.wait_until_zero(Duration::from_secs(5)));
        h.join().unwrap();
    }

    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        assert!(!counter.wait_until_zero(Duration::from_millis(10)));
    }

    #[tokio::test]
    async fn test_publish_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());

        sender.publish(make_test_array(1)).await;
        assert_eq!(counter.get(), 1);
        sender.publish(make_test_array(2)).await;
        assert_eq!(counter.get(), 2);
    }

    #[tokio::test]
    async fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let msg = ArrayMessage {
            array: make_test_array(1),
            counter: Some(counter.clone()),
            done_tx: None,
        };
        assert_eq!(counter.get(), 1);
        drop(msg);
        assert_eq!(counter.get(), 0);
    }
}