Skip to main content

ad_core_rs/plugin/
channel.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7/// Type-erased blocking processor for inline array processing.
8pub(crate) trait BlockingProcessFn: Send + Sync {
9    fn process_and_publish(&self, array: &NDArray);
10}
11
12/// Tracks the number of queued (in-flight) arrays across non-blocking plugins.
13/// Used by drivers to perform a bounded wait at end of acquisition.
14pub struct QueuedArrayCounter {
15    count: AtomicUsize,
16    mutex: parking_lot::Mutex<()>,
17    condvar: parking_lot::Condvar,
18}
19
20impl QueuedArrayCounter {
21    /// Create a new counter starting at zero.
22    pub fn new() -> Self {
23        Self {
24            count: AtomicUsize::new(0),
25            mutex: parking_lot::Mutex::new(()),
26            condvar: parking_lot::Condvar::new(),
27        }
28    }
29
30    /// Increment the queued count (called before try_send).
31    pub fn increment(&self) {
32        self.count.fetch_add(1, Ordering::AcqRel);
33    }
34
35    /// Decrement the queued count. Notifies waiters when reaching zero.
36    pub fn decrement(&self) {
37        let prev = self.count.fetch_sub(1, Ordering::AcqRel);
38        if prev == 1 {
39            let _guard = self.mutex.lock();
40            self.condvar.notify_all();
41        }
42    }
43
44    /// Current queued count.
45    pub fn get(&self) -> usize {
46        self.count.load(Ordering::Acquire)
47    }
48
49    /// Wait until count reaches zero, or timeout expires.
50    /// Returns `true` if count is zero, `false` on timeout.
51    pub fn wait_until_zero(&self, timeout: Duration) -> bool {
52        let mut guard = self.mutex.lock();
53        if self.count.load(Ordering::Acquire) == 0 {
54            return true;
55        }
56        !self
57            .condvar
58            .wait_while_for(
59                &mut guard,
60                |_| self.count.load(Ordering::Acquire) != 0,
61                timeout,
62            )
63            .timed_out()
64    }
65}
66
67impl Default for QueuedArrayCounter {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73/// Array message with optional queued-array counter.
74/// When dropped, decrements the counter (if present).
75pub struct ArrayMessage {
76    pub array: Arc<NDArray>,
77    pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
78}
79
80impl Drop for ArrayMessage {
81    fn drop(&mut self) {
82        if let Some(c) = self.counter.take() {
83            c.decrement();
84        }
85    }
86}
87
88/// Sender held by upstream. Supports blocking and non-blocking modes.
89#[derive(Clone)]
90pub struct NDArraySender {
91    tx: tokio::sync::mpsc::Sender<ArrayMessage>,
92    port_name: String,
93    dropped_count: Arc<AtomicU64>,
94    enabled: Arc<AtomicBool>,
95    blocking_mode: Arc<AtomicBool>,
96    blocking_processor: Option<Arc<dyn BlockingProcessFn>>,
97    queued_counter: Option<Arc<QueuedArrayCounter>>,
98}
99
100impl NDArraySender {
101    /// Send an array downstream. Behavior depends on mode:
102    /// - Disabled (`enable_callbacks=0`): silently dropped
103    /// - Blocking (`blocking_callbacks=1`): processed inline on caller's thread
104    /// - Non-blocking (default): queued for data thread (dropped if full)
105    pub fn send(&self, array: Arc<NDArray>) {
106        if !self.enabled.load(Ordering::Acquire) {
107            return;
108        }
109        if self.blocking_mode.load(Ordering::Acquire) {
110            if let Some(ref bp) = self.blocking_processor {
111                bp.process_and_publish(&array);
112                return;
113            }
114        }
115        // Non-blocking path: increment counter before try_send
116        if let Some(ref c) = self.queued_counter {
117            c.increment();
118        }
119        let msg = ArrayMessage {
120            array,
121            counter: self.queued_counter.clone(),
122        };
123        match self.tx.try_send(msg) {
124            Ok(()) => {}
125            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
126                // msg dropped here → Drop fires → counter decremented (net 0)
127                self.dropped_count.fetch_add(1, Ordering::Relaxed);
128            }
129            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
130                // msg dropped here → Drop fires → counter decremented
131            }
132        }
133    }
134
135    /// Whether this sender's plugin has callbacks enabled.
136    pub fn is_enabled(&self) -> bool {
137        self.enabled.load(Ordering::Acquire)
138    }
139
140    /// Whether this sender's plugin is in blocking mode.
141    pub fn is_blocking(&self) -> bool {
142        self.blocking_mode.load(Ordering::Acquire)
143    }
144
145    pub fn port_name(&self) -> &str {
146        &self.port_name
147    }
148
149    pub fn dropped_count(&self) -> u64 {
150        self.dropped_count.load(Ordering::Relaxed)
151    }
152
153    /// Set the queued-array counter for tracking in-flight arrays.
154    pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
155        self.queued_counter = Some(counter);
156    }
157
158    /// Configure blocking callback support. Used by plugin runtime.
159    pub(crate) fn with_blocking_support(
160        self,
161        enabled: Arc<AtomicBool>,
162        blocking_mode: Arc<AtomicBool>,
163        blocking_processor: Arc<dyn BlockingProcessFn>,
164    ) -> Self {
165        Self {
166            enabled,
167            blocking_mode,
168            blocking_processor: Some(blocking_processor),
169            ..self
170        }
171    }
172}
173
174/// Receiver held by downstream plugin.
175pub struct NDArrayReceiver {
176    rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
177}
178
179impl NDArrayReceiver {
180    /// Blocking receive (for use in std::thread data processing loops).
181    pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
182        self.rx.blocking_recv().map(|msg| msg.array.clone())
183    }
184
185    /// Async receive.
186    pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
187        self.rx.recv().await.map(|msg| msg.array.clone())
188    }
189
190    /// Receive the full ArrayMessage (crate-internal). The message's Drop
191    /// will signal completion when the caller is done with it.
192    pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
193        self.rx.recv().await
194    }
195}
196
197/// Create a matched sender/receiver pair.
198pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
199    let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
200    (
201        NDArraySender {
202            tx,
203            port_name: port_name.to_string(),
204            dropped_count: Arc::new(AtomicU64::new(0)),
205            enabled: Arc::new(AtomicBool::new(true)),
206            blocking_mode: Arc::new(AtomicBool::new(false)),
207            blocking_processor: None,
208            queued_counter: None,
209        },
210        NDArrayReceiver { rx },
211    )
212}
213
214/// Fan-out: broadcasts arrays to multiple downstream receivers.
215pub struct NDArrayOutput {
216    senders: Vec<NDArraySender>,
217}
218
219impl NDArrayOutput {
220    pub fn new() -> Self {
221        Self {
222            senders: Vec::new(),
223        }
224    }
225
226    pub fn add(&mut self, sender: NDArraySender) {
227        self.senders.push(sender);
228    }
229
230    pub fn remove(&mut self, port_name: &str) {
231        self.senders.retain(|s| s.port_name != port_name);
232    }
233
234    /// Remove a sender by port name and return it (if found).
235    pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
236        let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
237        Some(self.senders.swap_remove(idx))
238    }
239
240    /// Publish an array to all downstream receivers.
241    pub fn publish(&self, array: Arc<NDArray>) {
242        for sender in &self.senders {
243            sender.send(array.clone());
244        }
245    }
246
247    /// Publish an array to a single downstream receiver by index (for scatter/round-robin).
248    pub fn publish_to(&self, index: usize, array: Arc<NDArray>) {
249        if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
250            sender.send(array);
251        }
252    }
253
254    pub fn total_dropped(&self) -> u64 {
255        self.senders.iter().map(|s| s.dropped_count()).sum()
256    }
257
258    pub fn num_senders(&self) -> usize {
259        self.senders.len()
260    }
261}
262
263impl Default for NDArrayOutput {
264    fn default() -> Self {
265        Self::new()
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Build a 4-element UInt8 array whose `unique_id` is `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut array = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        array.unique_id = id;
        Arc::new(array)
    }

    /// Drive `future` to completion on a fresh single-threaded tokio runtime.
    fn run_async<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(future)
    }

    #[test]
    fn test_send_receive_basic() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        for id in [1, 2] {
            sender.send(make_test_array(id));
        }

        run_async(async move {
            for expected in [1, 2] {
                let array = receiver.recv().await.unwrap();
                assert_eq!(array.unique_id, expected);
            }
        });
    }

    #[test]
    fn test_back_pressure_drops() {
        let (sender, _receiver) = ndarray_channel("TEST", 2);
        // Ids 1 and 2 fill the queue; ids 3 and 4 overflow it.
        for id in 1..=4 {
            sender.send(make_test_array(id));
        }

        assert_eq!(sender.dropped_count(), 2);
    }

    #[test]
    fn test_fanout_three_receivers() {
        let mut output = NDArrayOutput::new();
        let mut receivers = Vec::new();
        for port in ["P1", "P2", "P3"] {
            let (sender, receiver) = ndarray_channel(port, 10);
            output.add(sender);
            receivers.push(receiver);
        }

        output.publish(make_test_array(42));

        run_async(async move {
            for receiver in receivers.iter_mut() {
                assert_eq!(receiver.recv().await.unwrap().unique_id, 42);
            }
        });
    }

    #[test]
    fn test_fanout_total_dropped() {
        let (first, _keep_first) = ndarray_channel("P1", 1);
        let (second, _keep_second) = ndarray_channel("P2", 1);

        let mut output = NDArrayOutput::new();
        output.add(first);
        output.add(second);

        // Capacity is 1: the first publish fills both queues and the second
        // overflows both.
        output.publish(make_test_array(1));
        output.publish(make_test_array(2));

        assert_eq!(output.total_dropped(), 2);
    }

    #[test]
    fn test_fanout_remove() {
        let mut output = NDArrayOutput::new();
        let mut _receivers = Vec::new();
        for port in ["P1", "P2"] {
            let (sender, receiver) = ndarray_channel(port, 10);
            output.add(sender);
            _receivers.push(receiver);
        }
        assert_eq!(output.num_senders(), 2);

        output.remove("P1");
        assert_eq!(output.num_senders(), 1);
    }

    #[test]
    fn test_blocking_recv() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);

        let consumer = std::thread::spawn(move || receiver.blocking_recv().unwrap().unique_id);

        sender.send(make_test_array(99));
        assert_eq!(consumer.join().unwrap(), 99);
    }

    #[test]
    fn test_channel_closed_on_receiver_drop() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        drop(receiver);

        // A closed channel must not panic, and it is not counted as a drop.
        sender.send(make_test_array(1));
        assert_eq!(sender.dropped_count(), 0);
    }

    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);

        for expected in 1..=2 {
            counter.increment();
            assert_eq!(counter.get(), expected);
        }
        for expected in (0..=1).rev() {
            counter.decrement();
            assert_eq!(counter.get(), expected);
        }
    }

    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        (0..2).for_each(|_| counter.increment());

        let worker = {
            let counter = Arc::clone(&counter);
            std::thread::spawn(move || {
                for _ in 0..2 {
                    std::thread::sleep(Duration::from_millis(10));
                    counter.decrement();
                }
            })
        };

        assert!(counter.wait_until_zero(Duration::from_secs(5)));
        worker.join().unwrap();
    }

    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();

        let drained = counter.wait_until_zero(Duration::from_millis(10));
        assert!(!drained);
    }

    #[test]
    fn test_send_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(Arc::clone(&counter));

        for (sent, id) in (1..=2).enumerate() {
            sender.send(make_test_array(id));
            assert_eq!(counter.get(), sent + 1);
        }
    }

    #[test]
    fn test_send_queue_full_no_net_increment() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 1);
        sender.set_queued_counter(Arc::clone(&counter));

        // The first send fills the queue. The second is rejected, and its
        // increment is undone when the rejected message drops, so exactly one
        // array stays counted after each send.
        for id in 1..=2 {
            sender.send(make_test_array(id));
            assert_eq!(counter.get(), 1);
        }
    }

    #[test]
    fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();

        let message = ArrayMessage {
            array: make_test_array(1),
            counter: Some(Arc::clone(&counter)),
        };
        assert_eq!(counter.get(), 1);

        drop(message);
        assert_eq!(counter.get(), 0);
    }
}