Skip to main content

ad_core/plugin/
channel.rs

1use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
7/// Type-erased blocking processor for inline array processing.
8pub(crate) trait BlockingProcessFn: Send + Sync {
9    fn process_and_publish(&self, array: &NDArray);
10}
11
12/// Tracks the number of queued (in-flight) arrays across non-blocking plugins.
13/// Used by drivers to perform a bounded wait at end of acquisition.
14pub struct QueuedArrayCounter {
15    count: AtomicUsize,
16    mutex: parking_lot::Mutex<()>,
17    condvar: parking_lot::Condvar,
18}
19
20impl QueuedArrayCounter {
21    /// Create a new counter starting at zero.
22    pub fn new() -> Self {
23        Self {
24            count: AtomicUsize::new(0),
25            mutex: parking_lot::Mutex::new(()),
26            condvar: parking_lot::Condvar::new(),
27        }
28    }
29
30    /// Increment the queued count (called before try_send).
31    pub fn increment(&self) {
32        self.count.fetch_add(1, Ordering::AcqRel);
33    }
34
35    /// Decrement the queued count. Notifies waiters when reaching zero.
36    pub fn decrement(&self) {
37        let prev = self.count.fetch_sub(1, Ordering::AcqRel);
38        if prev == 1 {
39            let _guard = self.mutex.lock();
40            self.condvar.notify_all();
41        }
42    }
43
44    /// Current queued count.
45    pub fn get(&self) -> usize {
46        self.count.load(Ordering::Acquire)
47    }
48
49    /// Wait until count reaches zero, or timeout expires.
50    /// Returns `true` if count is zero, `false` on timeout.
51    pub fn wait_until_zero(&self, timeout: Duration) -> bool {
52        let mut guard = self.mutex.lock();
53        if self.count.load(Ordering::Acquire) == 0 {
54            return true;
55        }
56        !self
57            .condvar
58            .wait_while_for(&mut guard, |_| {
59                self.count.load(Ordering::Acquire) != 0
60            }, timeout)
61            .timed_out()
62    }
63}
64
65impl Default for QueuedArrayCounter {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71/// Array message with optional queued-array counter.
72/// When dropped, decrements the counter (if present).
73pub struct ArrayMessage {
74    pub array: Arc<NDArray>,
75    pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
76}
77
78impl Drop for ArrayMessage {
79    fn drop(&mut self) {
80        if let Some(c) = self.counter.take() {
81            c.decrement();
82        }
83    }
84}
85
86/// Sender held by upstream. Supports blocking and non-blocking modes.
87#[derive(Clone)]
88pub struct NDArraySender {
89    tx: tokio::sync::mpsc::Sender<ArrayMessage>,
90    port_name: String,
91    dropped_count: Arc<AtomicU64>,
92    enabled: Arc<AtomicBool>,
93    blocking_mode: Arc<AtomicBool>,
94    blocking_processor: Option<Arc<dyn BlockingProcessFn>>,
95    queued_counter: Option<Arc<QueuedArrayCounter>>,
96}
97
98impl NDArraySender {
99    /// Send an array downstream. Behavior depends on mode:
100    /// - Disabled (`enable_callbacks=0`): silently dropped
101    /// - Blocking (`blocking_callbacks=1`): processed inline on caller's thread
102    /// - Non-blocking (default): queued for data thread (dropped if full)
103    pub fn send(&self, array: Arc<NDArray>) {
104        if !self.enabled.load(Ordering::Acquire) {
105            return;
106        }
107        if self.blocking_mode.load(Ordering::Acquire) {
108            if let Some(ref bp) = self.blocking_processor {
109                bp.process_and_publish(&array);
110                return;
111            }
112        }
113        // Non-blocking path: increment counter before try_send
114        if let Some(ref c) = self.queued_counter {
115            c.increment();
116        }
117        let msg = ArrayMessage {
118            array,
119            counter: self.queued_counter.clone(),
120        };
121        match self.tx.try_send(msg) {
122            Ok(()) => {}
123            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
124                // msg dropped here → Drop fires → counter decremented (net 0)
125                self.dropped_count.fetch_add(1, Ordering::Relaxed);
126            }
127            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
128                // msg dropped here → Drop fires → counter decremented
129            }
130        }
131    }
132
133    /// Whether this sender's plugin has callbacks enabled.
134    pub fn is_enabled(&self) -> bool {
135        self.enabled.load(Ordering::Acquire)
136    }
137
138    /// Whether this sender's plugin is in blocking mode.
139    pub fn is_blocking(&self) -> bool {
140        self.blocking_mode.load(Ordering::Acquire)
141    }
142
143    pub fn port_name(&self) -> &str {
144        &self.port_name
145    }
146
147    pub fn dropped_count(&self) -> u64 {
148        self.dropped_count.load(Ordering::Relaxed)
149    }
150
151    /// Set the queued-array counter for tracking in-flight arrays.
152    pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
153        self.queued_counter = Some(counter);
154    }
155
156    /// Configure blocking callback support. Used by plugin runtime.
157    pub(crate) fn with_blocking_support(
158        self,
159        enabled: Arc<AtomicBool>,
160        blocking_mode: Arc<AtomicBool>,
161        blocking_processor: Arc<dyn BlockingProcessFn>,
162    ) -> Self {
163        Self {
164            enabled,
165            blocking_mode,
166            blocking_processor: Some(blocking_processor),
167            ..self
168        }
169    }
170}
171
172/// Receiver held by downstream plugin.
173pub struct NDArrayReceiver {
174    rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
175}
176
177impl NDArrayReceiver {
178    /// Blocking receive (for use in std::thread data processing loops).
179    pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
180        self.rx.blocking_recv().map(|msg| msg.array.clone())
181    }
182
183    /// Async receive.
184    pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
185        self.rx.recv().await.map(|msg| msg.array.clone())
186    }
187
188    /// Receive the full ArrayMessage (crate-internal). The message's Drop
189    /// will signal completion when the caller is done with it.
190    pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
191        self.rx.recv().await
192    }
193}
194
195/// Create a matched sender/receiver pair.
196pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
197    let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
198    (
199        NDArraySender {
200            tx,
201            port_name: port_name.to_string(),
202            dropped_count: Arc::new(AtomicU64::new(0)),
203            enabled: Arc::new(AtomicBool::new(true)),
204            blocking_mode: Arc::new(AtomicBool::new(false)),
205            blocking_processor: None,
206            queued_counter: None,
207        },
208        NDArrayReceiver { rx },
209    )
210}
211
212/// Fan-out: broadcasts arrays to multiple downstream receivers.
213pub struct NDArrayOutput {
214    senders: Vec<NDArraySender>,
215}
216
217impl NDArrayOutput {
218    pub fn new() -> Self {
219        Self {
220            senders: Vec::new(),
221        }
222    }
223
224    pub fn add(&mut self, sender: NDArraySender) {
225        self.senders.push(sender);
226    }
227
228    pub fn remove(&mut self, port_name: &str) {
229        self.senders.retain(|s| s.port_name != port_name);
230    }
231
232    /// Remove a sender by port name and return it (if found).
233    pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
234        let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
235        Some(self.senders.swap_remove(idx))
236    }
237
238    /// Publish an array to all downstream receivers.
239    pub fn publish(&self, array: Arc<NDArray>) {
240        for sender in &self.senders {
241            sender.send(array.clone());
242        }
243    }
244
245    pub fn total_dropped(&self) -> u64 {
246        self.senders.iter().map(|s| s.dropped_count()).sum()
247    }
248
249    pub fn num_senders(&self) -> usize {
250        self.senders.len()
251    }
252}
253
254impl Default for NDArrayOutput {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};
    use std::sync::Mutex;

    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        Arc::new(arr)
    }

    /// Single-threaded tokio runtime for driving the async receive methods.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    /// Blocking processor that records the `unique_id` of every array it sees.
    #[derive(Default)]
    struct Recorder {
        ids: Mutex<Vec<i32>>,
    }

    impl BlockingProcessFn for Recorder {
        fn process_and_publish(&self, array: &NDArray) {
            self.ids.lock().unwrap().push(array.unique_id);
        }
    }

    #[test]
    fn test_send_receive_basic() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.send(make_test_array(1));
        sender.send(make_test_array(2));

        current_thread_runtime().block_on(async {
            let a1 = receiver.recv().await.unwrap();
            assert_eq!(a1.unique_id, 1);
            let a2 = receiver.recv().await.unwrap();
            assert_eq!(a2.unique_id, 2);
        });
    }

    #[test]
    fn test_back_pressure_drops() {
        let (sender, _receiver) = ndarray_channel("TEST", 2);
        // Fill the channel
        sender.send(make_test_array(1));
        sender.send(make_test_array(2));
        // These should be dropped
        sender.send(make_test_array(3));
        sender.send(make_test_array(4));

        assert_eq!(sender.dropped_count(), 2);
    }

    #[test]
    fn test_fanout_three_receivers() {
        let (s1, mut r1) = ndarray_channel("P1", 10);
        let (s2, mut r2) = ndarray_channel("P2", 10);
        let (s3, mut r3) = ndarray_channel("P3", 10);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);
        output.add(s3);

        output.publish(make_test_array(42));

        current_thread_runtime().block_on(async {
            assert_eq!(r1.recv().await.unwrap().unique_id, 42);
            assert_eq!(r2.recv().await.unwrap().unique_id, 42);
            assert_eq!(r3.recv().await.unwrap().unique_id, 42);
        });
    }

    #[test]
    fn test_fanout_total_dropped() {
        let (s1, _r1) = ndarray_channel("P1", 1);
        let (s2, _r2) = ndarray_channel("P2", 1);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);

        // Fill both channels
        output.publish(make_test_array(1));
        // Both full now
        output.publish(make_test_array(2));

        assert_eq!(output.total_dropped(), 2);
    }

    #[test]
    fn test_fanout_remove() {
        let (s1, _r1) = ndarray_channel("P1", 10);
        let (s2, _r2) = ndarray_channel("P2", 10);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);
        assert_eq!(output.num_senders(), 2);

        output.remove("P1");
        assert_eq!(output.num_senders(), 1);
    }

    #[test]
    fn test_fanout_take() {
        let (s1, _r1) = ndarray_channel("P1", 10);
        let (s2, _r2) = ndarray_channel("P2", 10);

        let mut output = NDArrayOutput::new();
        output.add(s1);
        output.add(s2);

        let taken = output.take("P1").expect("P1 was registered");
        assert_eq!(taken.port_name(), "P1");
        assert_eq!(output.num_senders(), 1);
        assert!(output.take("P1").is_none());
    }

    #[test]
    fn test_blocking_recv() {
        let (sender, mut receiver) = ndarray_channel("TEST", 10);

        let handle = std::thread::spawn(move || {
            let arr = receiver.blocking_recv().unwrap();
            arr.unique_id
        });

        sender.send(make_test_array(99));
        let id = handle.join().unwrap();
        assert_eq!(id, 99);
    }

    #[test]
    fn test_channel_closed_on_receiver_drop() {
        let (sender, receiver) = ndarray_channel("TEST", 10);
        drop(receiver);
        // Sending to closed channel should not panic
        sender.send(make_test_array(1));
        assert_eq!(sender.dropped_count(), 0); // closed, not "dropped"
    }

    #[test]
    fn test_send_to_closed_channel_releases_count() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());
        drop(receiver);

        sender.send(make_test_array(1));
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_disabled_sender_drops_silently() {
        let recorder = Arc::new(Recorder::default());
        let counter = Arc::new(QueuedArrayCounter::new());
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        let mut sender = sender.with_blocking_support(
            Arc::new(AtomicBool::new(false)),
            Arc::new(AtomicBool::new(true)),
            recorder.clone(),
        );
        sender.set_queued_counter(counter.clone());
        assert!(!sender.is_enabled());

        sender.send(make_test_array(1));
        // Neither processed inline nor queued nor counted as a drop.
        assert!(recorder.ids.lock().unwrap().is_empty());
        assert_eq!(counter.get(), 0);
        assert_eq!(sender.dropped_count(), 0);
        drop(sender);
        assert!(receiver.blocking_recv().is_none());
    }

    #[test]
    fn test_blocking_mode_processes_inline() {
        let recorder = Arc::new(Recorder::default());
        let counter = Arc::new(QueuedArrayCounter::new());
        let (sender, mut receiver) = ndarray_channel("TEST", 10);
        let mut sender = sender.with_blocking_support(
            Arc::new(AtomicBool::new(true)),
            Arc::new(AtomicBool::new(true)),
            recorder.clone(),
        );
        sender.set_queued_counter(counter.clone());
        assert!(sender.is_blocking());

        sender.send(make_test_array(7));
        assert_eq!(*recorder.ids.lock().unwrap(), vec![7]);
        // Inline processing bypasses the queue entirely.
        assert_eq!(counter.get(), 0);
        drop(sender);
        assert!(receiver.blocking_recv().is_none());
    }

    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);
        counter.increment();
        assert_eq!(counter.get(), 1);
        counter.increment();
        assert_eq!(counter.get(), 2);
        counter.decrement();
        assert_eq!(counter.get(), 1);
        counter.decrement();
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        counter.increment();

        let c = counter.clone();
        let h = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
            std::thread::sleep(Duration::from_millis(10));
            c.decrement();
        });

        assert!(counter.wait_until_zero(Duration::from_secs(5)));
        h.join().unwrap();
    }

    #[test]
    fn test_queued_counter_wait_when_already_zero() {
        let counter = QueuedArrayCounter::new();
        assert!(counter.wait_until_zero(Duration::from_millis(0)));
    }

    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        assert!(!counter.wait_until_zero(Duration::from_millis(10)));
    }

    #[test]
    fn test_send_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());

        sender.send(make_test_array(1));
        assert_eq!(counter.get(), 1);
        sender.send(make_test_array(2));
        assert_eq!(counter.get(), 2);
    }

    #[test]
    fn test_send_queue_full_no_net_increment() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, _receiver) = ndarray_channel("TEST", 1);
        sender.set_queued_counter(counter.clone());

        sender.send(make_test_array(1)); // fills queue
        assert_eq!(counter.get(), 1);
        sender.send(make_test_array(2)); // queue full → dropped → net 0 change
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn test_blocking_recv_releases_count_on_receipt() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());

        sender.send(make_test_array(3));
        assert_eq!(counter.get(), 1);
        let arr = receiver.blocking_recv().unwrap();
        assert_eq!(arr.unique_id, 3);
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_recv_msg_holds_count_until_drop() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut sender, mut receiver) = ndarray_channel("TEST", 10);
        sender.set_queued_counter(counter.clone());

        sender.send(make_test_array(5));
        let msg = current_thread_runtime()
            .block_on(receiver.recv_msg())
            .unwrap();
        assert_eq!(msg.array.unique_id, 5);
        // Still counted while the caller holds the message.
        assert_eq!(counter.get(), 1);
        drop(msg);
        assert_eq!(counter.get(), 0);
    }

    #[test]
    fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let msg = ArrayMessage {
            array: make_test_array(1),
            counter: Some(counter.clone()),
        };
        assert_eq!(counter.get(), 1);
        drop(msg);
        assert_eq!(counter.get(), 0);
    }
}