Skip to main content

ad_core_rs/plugin/
channel.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7/// Tracks the number of queued (in-flight) arrays across plugins.
8/// Used by drivers to perform a bounded wait at end of acquisition.
9pub struct QueuedArrayCounter {
10    count: AtomicUsize,
11    mutex: parking_lot::Mutex<()>,
12    condvar: parking_lot::Condvar,
13}
14
15impl QueuedArrayCounter {
16    /// Create a new counter starting at zero.
17    pub fn new() -> Self {
18        Self {
19            count: AtomicUsize::new(0),
20            mutex: parking_lot::Mutex::new(()),
21            condvar: parking_lot::Condvar::new(),
22        }
23    }
24
25    /// Increment the queued count (called before send).
26    pub fn increment(&self) {
27        self.count.fetch_add(1, Ordering::AcqRel);
28    }
29
30    /// Decrement the queued count. Notifies waiters when reaching zero.
31    pub fn decrement(&self) {
32        let prev = self.count.fetch_sub(1, Ordering::AcqRel);
33        if prev == 1 {
34            let _guard = self.mutex.lock();
35            self.condvar.notify_all();
36        }
37    }
38
39    /// Current queued count.
40    pub fn get(&self) -> usize {
41        self.count.load(Ordering::Acquire)
42    }
43
44    /// Wait until count reaches zero, or timeout expires.
45    /// Returns `true` if count is zero, `false` on timeout.
46    pub fn wait_until_zero(&self, timeout: Duration) -> bool {
47        let mut guard = self.mutex.lock();
48        if self.count.load(Ordering::Acquire) == 0 {
49            return true;
50        }
51        !self
52            .condvar
53            .wait_while_for(
54                &mut guard,
55                |_| self.count.load(Ordering::Acquire) != 0,
56                timeout,
57            )
58            .timed_out()
59    }
60}
61
62impl Default for QueuedArrayCounter {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68/// Array message with optional queued-array counter and completion signal.
69/// When dropped, decrements the counter (if present) — this signals that
70/// the downstream plugin has finished processing the array.
71pub struct ArrayMessage {
72    pub array: Arc<NDArray>,
73    pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
74    /// When Some, the sender awaits this to confirm downstream processing completed.
75    /// Fired when ArrayMessage is dropped (i.e., after plugin process_array finishes).
76    pub(crate) done_tx: Option<tokio::sync::oneshot::Sender<()>>,
77}
78
79impl Drop for ArrayMessage {
80    fn drop(&mut self) {
81        if let Some(tx) = self.done_tx.take() {
82            let _ = tx.send(());
83        }
84        if let Some(c) = self.counter.take() {
85            c.decrement();
86        }
87    }
88}
89
90/// Sender held by upstream. Fully async, reliable (no drops).
91///
92/// # `blocking_callbacks` semantics
93///
94/// Both modes use reliable async enqueue (`send().await`). The difference is
95/// how long the caller waits:
96///
97/// - `blocking_callbacks=0`: waits until the message is in the downstream queue
98///   (enqueue guaranteed, processing NOT awaited).
99/// - `blocking_callbacks=1`: waits until the downstream plugin has finished
100///   processing the array (enqueue + completion awaited).
101///
102/// Neither mode drops arrays due to back-pressure — the caller yields instead.
103#[derive(Clone)]
104pub struct NDArraySender {
105    tx: tokio::sync::mpsc::Sender<ArrayMessage>,
106    port_name: String,
107    enabled: Arc<AtomicBool>,
108    blocking_mode: Arc<AtomicBool>,
109    queued_counter: Option<Arc<QueuedArrayCounter>>,
110}
111
112impl NDArraySender {
113    /// Publish an array downstream (async, reliable).
114    ///
115    /// - `enable_callbacks=0`: returns immediately, array not sent.
116    /// - `blocking_callbacks=0`: awaits queue admission only.
117    /// - `blocking_callbacks=1`: awaits queue admission + downstream processing completion.
118    pub async fn publish(&self, array: Arc<NDArray>) {
119        if !self.enabled.load(Ordering::Acquire) {
120            return;
121        }
122        if let Some(ref c) = self.queued_counter {
123            c.increment();
124        }
125
126        let blocking = self.blocking_mode.load(Ordering::Acquire);
127        let (done_tx, done_rx) = if blocking {
128            let (tx, rx) = tokio::sync::oneshot::channel();
129            (Some(tx), Some(rx))
130        } else {
131            (None, None)
132        };
133
134        let msg = ArrayMessage {
135            array,
136            counter: self.queued_counter.clone(),
137            done_tx,
138        };
139
140        if self.tx.send(msg).await.is_err() {
141            // Channel closed — counter was decremented by ArrayMessage::drop
142            return;
143        }
144
145        // blocking_callbacks=1: wait for downstream to finish processing
146        if let Some(rx) = done_rx {
147            let _ = rx.await;
148        }
149    }
150
151    /// Whether this sender's plugin has callbacks enabled.
152    pub fn is_enabled(&self) -> bool {
153        self.enabled.load(Ordering::Acquire)
154    }
155
156    /// Whether this sender's plugin is in blocking mode.
157    pub fn is_blocking(&self) -> bool {
158        self.blocking_mode.load(Ordering::Acquire)
159    }
160
161    pub fn port_name(&self) -> &str {
162        &self.port_name
163    }
164
165    /// Set the queued-array counter for tracking in-flight arrays.
166    pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
167        self.queued_counter = Some(counter);
168    }
169
170    /// Set the enabled/blocking mode flags (used by plugin runtime wiring).
171    pub(crate) fn set_mode_flags(
172        &mut self,
173        enabled: Arc<AtomicBool>,
174        blocking_mode: Arc<AtomicBool>,
175    ) {
176        self.enabled = enabled;
177        self.blocking_mode = blocking_mode;
178    }
179}
180
181/// Receiver held by downstream plugin.
182pub struct NDArrayReceiver {
183    rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
184}
185
186impl NDArrayReceiver {
187    /// Blocking receive (for use in std::thread data processing loops).
188    pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
189        self.rx.blocking_recv().map(|msg| msg.array.clone())
190    }
191
192    /// Async receive.
193    pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
194        self.rx.recv().await.map(|msg| msg.array.clone())
195    }
196
197    /// Receive the full ArrayMessage (crate-internal). The message's Drop
198    /// will signal completion when the caller is done with it.
199    pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
200        self.rx.recv().await
201    }
202}
203
204/// Create a matched sender/receiver pair.
205pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
206    let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
207    (
208        NDArraySender {
209            tx,
210            port_name: port_name.to_string(),
211            enabled: Arc::new(AtomicBool::new(true)),
212            blocking_mode: Arc::new(AtomicBool::new(false)),
213            queued_counter: None,
214        },
215        NDArrayReceiver { rx },
216    )
217}
218
219/// Fan-out: publishes arrays to multiple downstream receivers.
220pub struct NDArrayOutput {
221    senders: Vec<NDArraySender>,
222}
223
224impl NDArrayOutput {
225    pub fn new() -> Self {
226        Self {
227            senders: Vec::new(),
228        }
229    }
230
231    pub fn add(&mut self, sender: NDArraySender) {
232        self.senders.push(sender);
233    }
234
235    pub fn remove(&mut self, port_name: &str) {
236        self.senders.retain(|s| s.port_name != port_name);
237    }
238
239    /// Remove a sender by port name and return it (if found).
240    pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
241        let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
242        Some(self.senders.swap_remove(idx))
243    }
244
245    /// Publish an array to all downstream receivers (async, reliable, concurrent).
246    ///
247    /// Each sender is awaited independently — a slow downstream does not
248    /// block enqueue to sibling downstreams. The function returns after
249    /// all senders have completed their publish (enqueue or completion,
250    /// depending on `blocking_callbacks`).
251    pub async fn publish(&self, array: Arc<NDArray>) {
252        let futs = self.senders.iter().map(|s| s.publish(array.clone()));
253        futures_util::future::join_all(futs).await;
254    }
255
256    /// Publish an array to a single downstream receiver by index (for scatter/round-robin).
257    pub async fn publish_to(&self, index: usize, array: Arc<NDArray>) {
258        if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
259            sender.publish(array).await;
260        }
261    }
262
263    pub fn num_senders(&self) -> usize {
264        self.senders.len()
265    }
266
267    /// Clone the senders list (for publishing outside a lock in async context).
268    pub(crate) fn senders_clone(&self) -> Vec<NDArraySender> {
269        self.senders.clone()
270    }
271}
272
273/// Cloneable async handle for publishing arrays to downstream plugins.
274///
275/// This is the public API for driver acquisition tasks.
276/// Internally it snapshots the sender list, releases the lock, then
277/// publishes to all senders concurrently.
278///
279/// # Example
280/// ```ignore
281/// if config.array_callbacks {
282///     publisher.publish(Arc::new(frame)).await;
283/// }
284/// ```
285#[derive(Clone)]
286pub struct ArrayPublisher {
287    output: Arc<parking_lot::Mutex<NDArrayOutput>>,
288}
289
290impl ArrayPublisher {
291    /// Create a publisher backed by the given output.
292    pub fn new(output: Arc<parking_lot::Mutex<NDArrayOutput>>) -> Self {
293        Self { output }
294    }
295
296    /// Publish an array to all downstream plugins (async, concurrent fan-out).
297    pub async fn publish(&self, array: Arc<NDArray>) {
298        let senders = self.output.lock().senders_clone();
299        let futs = senders.iter().map(|s| s.publish(array.clone()));
300        futures_util::future::join_all(futs).await;
301    }
302}
303
304impl Default for NDArrayOutput {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Build a small UInt8 test array whose `unique_id` is `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let dims = vec![NDDimension::new(4)];
        let mut array = NDArray::new(dims, NDDataType::UInt8);
        array.unique_id = id;
        Arc::new(array)
    }

    /// Arrays come out of the channel in the order they were published.
    #[tokio::test]
    async fn test_publish_receive_basic() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        for id in 1..=2 {
            sender.publish(make_test_array(id)).await;
        }

        for expected in 1..=2 {
            let received = receiver.recv().await.unwrap();
            assert_eq!(received.unique_id, expected);
        }
    }

    /// A queue of one must still deliver everything, because `send().await`
    /// waits for space instead of dropping.
    #[tokio::test]
    async fn test_publish_no_drop() {
        let (sender, mut receiver) = ndarray_channel("TEST", 1);

        // Publisher task pushes three arrays through the single-slot queue.
        let publisher = sender.clone();
        let publish_task = tokio::spawn(async move {
            for id in 1..=3 {
                publisher.publish(make_test_array(id)).await;
            }
        });

        // All three arrive, in order.
        for expected in 1..=3 {
            let received = receiver.recv().await.unwrap();
            assert_eq!(received.unique_id, expected);
        }

        publish_task.await.unwrap();
    }

    /// In blocking mode `publish` resolves only after the receiver has
    /// dropped the message.
    #[tokio::test]
    async fn test_blocking_callbacks_completion_wait() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.blocking_mode.store(true, Ordering::Release);

        let finished = Arc::new(AtomicBool::new(false));
        let finished_flag = Arc::clone(&finished);

        // Receiver task that spends some time "processing" the message.
        let receive_task = tokio::spawn(async move {
            let message = receiver.recv_msg().await.unwrap();
            assert_eq!(message.array.unique_id, 42);
            tokio::time::sleep(Duration::from_millis(50)).await;
            finished_flag.store(true, Ordering::Release);
            // `message` goes out of scope here, which fires `done_tx`.
        });

        // Must not resolve before the receiver is done.
        sender.publish(make_test_array(42)).await;

        assert!(finished.load(Ordering::Acquire));

        receive_task.await.unwrap();
    }

    /// One publish on the output reaches each of three receivers.
    #[tokio::test]
    async fn test_fanout_three_receivers() {
        let mut output = NDArrayOutput::new();
        let mut receivers = Vec::new();
        for port in ["P1", "P2", "P3"].iter().copied() {
            let (sender, receiver) = ndarray_channel(port, 10);
            output.add(sender);
            receivers.push(receiver);
        }

        output.publish(make_test_array(42)).await;

        for receiver in receivers.iter_mut() {
            assert_eq!(receiver.recv().await.unwrap().unique_id, 42);
        }
    }

    /// `blocking_recv` works from a plain OS thread.
    #[test]
    fn test_blocking_recv() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let (sender, mut receiver) = ndarray_channel("TEST", 10);

        let worker = std::thread::spawn(move || {
            let array = receiver.blocking_recv().unwrap();
            array.unique_id
        });

        runtime.block_on(sender.publish(make_test_array(99)));
        let received_id = worker.join().unwrap();
        assert_eq!(received_id, 99);
    }

    /// Publishing after the receiver is gone must not panic.
    #[tokio::test]
    async fn test_channel_closed_on_receiver_drop() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        drop(receiver);
        sender.publish(make_test_array(1)).await;
    }

    /// Increment and decrement move the count by exactly one.
    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);

        for expected in 1..=2 {
            counter.increment();
            assert_eq!(counter.get(), expected);
        }
        for expected in (0..2).rev() {
            counter.decrement();
            assert_eq!(counter.get(), expected);
        }
    }

    /// A waiter is released once another thread brings the count to zero.
    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        counter.increment();

        let remote = Arc::clone(&counter);
        let worker = std::thread::spawn(move || {
            for _ in 0..2 {
                std::thread::sleep(Duration::from_millis(10));
                remote.decrement();
            }
        });

        let drained = counter.wait_until_zero(Duration::from_secs(5));
        assert!(drained);
        worker.join().unwrap();
    }

    /// The wait reports failure when the count never reaches zero.
    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let drained = counter.wait_until_zero(Duration::from_millis(10));
        assert!(!drained);
    }

    /// Each publish adds one to the attached counter while the messages stay
    /// queued.
    #[tokio::test]
    async fn test_publish_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, mut _receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(Arc::clone(&counter));

        for id in 1..=2 {
            sender.publish(make_test_array(id)).await;
            assert_eq!(counter.get(), id as usize);
        }
    }

    /// Dropping a message gives back its slot in the counter.
    #[tokio::test]
    async fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let message = ArrayMessage {
            array: make_test_array(1),
            counter: Some(Arc::clone(&counter)),
            done_tx: None,
        };
        assert_eq!(counter.get(), 1);
        drop(message);
        assert_eq!(counter.get(), 0);
    }
}