// ad_core_rs/plugin/channel.rs

1use std::sync::Arc;
2use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
3use std::time::Duration;
4
5use crate::ndarray::NDArray;
6
/// Type-erased blocking processor for inline array processing.
///
/// `NDArraySender::send` invokes this on the caller's thread when the sender
/// is in blocking mode and a processor has been installed. The `Send + Sync`
/// bounds let a single processor be shared behind an `Arc` across sender clones.
pub(crate) trait BlockingProcessFn: Send + Sync {
    /// Process `array` inline and publish the result.
    ///
    /// What "publish" entails is decided by the implementor and is not visible
    /// in this file.
    fn process_and_publish(&self, array: &NDArray);
}
11
12/// Tracks the number of queued (in-flight) arrays across non-blocking plugins.
13/// Used by drivers to perform a bounded wait at end of acquisition.
14pub struct QueuedArrayCounter {
15    count: AtomicUsize,
16    mutex: parking_lot::Mutex<()>,
17    condvar: parking_lot::Condvar,
18}
19
20impl QueuedArrayCounter {
21    /// Create a new counter starting at zero.
22    pub fn new() -> Self {
23        Self {
24            count: AtomicUsize::new(0),
25            mutex: parking_lot::Mutex::new(()),
26            condvar: parking_lot::Condvar::new(),
27        }
28    }
29
30    /// Increment the queued count (called before try_send).
31    pub fn increment(&self) {
32        self.count.fetch_add(1, Ordering::AcqRel);
33    }
34
35    /// Decrement the queued count. Notifies waiters when reaching zero.
36    pub fn decrement(&self) {
37        let prev = self.count.fetch_sub(1, Ordering::AcqRel);
38        if prev == 1 {
39            let _guard = self.mutex.lock();
40            self.condvar.notify_all();
41        }
42    }
43
44    /// Current queued count.
45    pub fn get(&self) -> usize {
46        self.count.load(Ordering::Acquire)
47    }
48
49    /// Wait until count reaches zero, or timeout expires.
50    /// Returns `true` if count is zero, `false` on timeout.
51    pub fn wait_until_zero(&self, timeout: Duration) -> bool {
52        let mut guard = self.mutex.lock();
53        if self.count.load(Ordering::Acquire) == 0 {
54            return true;
55        }
56        !self
57            .condvar
58            .wait_while_for(
59                &mut guard,
60                |_| self.count.load(Ordering::Acquire) != 0,
61                timeout,
62            )
63            .timed_out()
64    }
65}
66
67impl Default for QueuedArrayCounter {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73/// Array message with optional queued-array counter.
74/// When dropped, decrements the counter (if present).
75pub struct ArrayMessage {
76    pub array: Arc<NDArray>,
77    pub(crate) counter: Option<Arc<QueuedArrayCounter>>,
78}
79
80impl Drop for ArrayMessage {
81    fn drop(&mut self) {
82        if let Some(c) = self.counter.take() {
83            c.decrement();
84        }
85    }
86}
87
88/// Sender held by upstream. Supports blocking and non-blocking modes.
89#[derive(Clone)]
90pub struct NDArraySender {
91    tx: tokio::sync::mpsc::Sender<ArrayMessage>,
92    port_name: String,
93    dropped_count: Arc<AtomicU64>,
94    enabled: Arc<AtomicBool>,
95    blocking_mode: Arc<AtomicBool>,
96    blocking_processor: Option<Arc<dyn BlockingProcessFn>>,
97    queued_counter: Option<Arc<QueuedArrayCounter>>,
98}
99
100impl NDArraySender {
101    /// Send an array downstream. Behavior depends on mode:
102    /// - Disabled (`enable_callbacks=0`): silently dropped
103    /// - Blocking (`blocking_callbacks=1`): processed inline on caller's thread
104    /// - Non-blocking (default): queued for data thread (dropped if full)
105    pub fn send(&self, array: Arc<NDArray>) {
106        if !self.enabled.load(Ordering::Acquire) {
107            return;
108        }
109        if self.blocking_mode.load(Ordering::Acquire) {
110            if let Some(ref bp) = self.blocking_processor {
111                bp.process_and_publish(&array);
112                return;
113            }
114        }
115        // Non-blocking path: increment counter before try_send
116        if let Some(ref c) = self.queued_counter {
117            c.increment();
118        }
119        let msg = ArrayMessage {
120            array,
121            counter: self.queued_counter.clone(),
122        };
123        match self.tx.try_send(msg) {
124            Ok(()) => {}
125            Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
126                // msg dropped here → Drop fires → counter decremented (net 0)
127                self.dropped_count.fetch_add(1, Ordering::Relaxed);
128            }
129            Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
130                // msg dropped here → Drop fires → counter decremented
131            }
132        }
133    }
134
135    /// Whether this sender's plugin has callbacks enabled.
136    pub fn is_enabled(&self) -> bool {
137        self.enabled.load(Ordering::Acquire)
138    }
139
140    /// Whether this sender's plugin is in blocking mode.
141    pub fn is_blocking(&self) -> bool {
142        self.blocking_mode.load(Ordering::Acquire)
143    }
144
145    pub fn port_name(&self) -> &str {
146        &self.port_name
147    }
148
149    pub fn dropped_count(&self) -> u64 {
150        self.dropped_count.load(Ordering::Relaxed)
151    }
152
153    /// Clone the shared dropped-array counter (for monitoring from the data thread).
154    pub(crate) fn dropped_count_shared(&self) -> Arc<AtomicU64> {
155        self.dropped_count.clone()
156    }
157
158    /// Clone the underlying tokio sender (for queue capacity checks from the data thread).
159    pub(crate) fn tx_clone(&self) -> tokio::sync::mpsc::Sender<ArrayMessage> {
160        self.tx.clone()
161    }
162
163    /// Set the queued-array counter for tracking in-flight arrays.
164    pub fn set_queued_counter(&mut self, counter: Arc<QueuedArrayCounter>) {
165        self.queued_counter = Some(counter);
166    }
167
168    /// Configure blocking callback support. Used by plugin runtime.
169    pub(crate) fn with_blocking_support(
170        self,
171        enabled: Arc<AtomicBool>,
172        blocking_mode: Arc<AtomicBool>,
173        blocking_processor: Arc<dyn BlockingProcessFn>,
174    ) -> Self {
175        Self {
176            enabled,
177            blocking_mode,
178            blocking_processor: Some(blocking_processor),
179            ..self
180        }
181    }
182}
183
184/// Receiver held by downstream plugin.
185pub struct NDArrayReceiver {
186    rx: tokio::sync::mpsc::Receiver<ArrayMessage>,
187}
188
189impl NDArrayReceiver {
190    /// Blocking receive (for use in std::thread data processing loops).
191    pub fn blocking_recv(&mut self) -> Option<Arc<NDArray>> {
192        self.rx.blocking_recv().map(|msg| msg.array.clone())
193    }
194
195    /// Async receive.
196    pub async fn recv(&mut self) -> Option<Arc<NDArray>> {
197        self.rx.recv().await.map(|msg| msg.array.clone())
198    }
199
200    /// Receive the full ArrayMessage (crate-internal). The message's Drop
201    /// will signal completion when the caller is done with it.
202    pub(crate) async fn recv_msg(&mut self) -> Option<ArrayMessage> {
203        self.rx.recv().await
204    }
205}
206
207/// Create a matched sender/receiver pair.
208pub fn ndarray_channel(port_name: &str, queue_size: usize) -> (NDArraySender, NDArrayReceiver) {
209    let (tx, rx) = tokio::sync::mpsc::channel(queue_size.max(1));
210    (
211        NDArraySender {
212            tx,
213            port_name: port_name.to_string(),
214            dropped_count: Arc::new(AtomicU64::new(0)),
215            enabled: Arc::new(AtomicBool::new(true)),
216            blocking_mode: Arc::new(AtomicBool::new(false)),
217            blocking_processor: None,
218            queued_counter: None,
219        },
220        NDArrayReceiver { rx },
221    )
222}
223
224/// Fan-out: broadcasts arrays to multiple downstream receivers.
225pub struct NDArrayOutput {
226    senders: Vec<NDArraySender>,
227}
228
229impl NDArrayOutput {
230    pub fn new() -> Self {
231        Self {
232            senders: Vec::new(),
233        }
234    }
235
236    pub fn add(&mut self, sender: NDArraySender) {
237        self.senders.push(sender);
238    }
239
240    pub fn remove(&mut self, port_name: &str) {
241        self.senders.retain(|s| s.port_name != port_name);
242    }
243
244    /// Remove a sender by port name and return it (if found).
245    pub fn take(&mut self, port_name: &str) -> Option<NDArraySender> {
246        let idx = self.senders.iter().position(|s| s.port_name == port_name)?;
247        Some(self.senders.swap_remove(idx))
248    }
249
250    /// Publish an array to all downstream receivers.
251    pub fn publish(&self, array: Arc<NDArray>) {
252        for sender in &self.senders {
253            sender.send(array.clone());
254        }
255    }
256
257    /// Publish an array to a single downstream receiver by index (for scatter/round-robin).
258    pub fn publish_to(&self, index: usize, array: Arc<NDArray>) {
259        if let Some(sender) = self.senders.get(index % self.senders.len().max(1)) {
260            sender.send(array);
261        }
262    }
263
264    pub fn total_dropped(&self) -> u64 {
265        self.senders.iter().map(|s| s.dropped_count()).sum()
266    }
267
268    pub fn num_senders(&self) -> usize {
269        self.senders.len()
270    }
271}
272
273impl Default for NDArrayOutput {
274    fn default() -> Self {
275        Self::new()
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ndarray::{NDArray, NDDataType, NDDimension};

    /// Build a one-dimensional, 4-element UInt8 array tagged with `id`.
    fn make_test_array(id: i32) -> Arc<NDArray> {
        let mut array = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        array.unique_id = id;
        Arc::new(array)
    }

    /// Drive `future` to completion on a fresh current-thread tokio runtime.
    fn run_async<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(future)
    }

    #[test]
    fn test_send_receive_basic() {
        let (tx, mut rx) = ndarray_channel("TEST", 10);
        tx.send(make_test_array(1));
        tx.send(make_test_array(2));

        run_async(async {
            let first = rx.recv().await.unwrap();
            assert_eq!(first.unique_id, 1);
            let second = rx.recv().await.unwrap();
            assert_eq!(second.unique_id, 2);
        });
    }

    #[test]
    fn test_back_pressure_drops() {
        let (tx, _rx) = ndarray_channel("TEST", 2);
        // The first two sends fill the queue; the last two must be dropped.
        for id in 1..=4 {
            tx.send(make_test_array(id));
        }

        assert_eq!(tx.dropped_count(), 2);
    }

    #[test]
    fn test_fanout_three_receivers() {
        let mut output = NDArrayOutput::new();
        let mut receivers = Vec::new();
        for port in ["P1", "P2", "P3"] {
            let (tx, rx) = ndarray_channel(port, 10);
            output.add(tx);
            receivers.push(rx);
        }

        output.publish(make_test_array(42));

        run_async(async {
            for rx in receivers.iter_mut() {
                assert_eq!(rx.recv().await.unwrap().unique_id, 42);
            }
        });
    }

    #[test]
    fn test_fanout_total_dropped() {
        let (tx_a, _rx_a) = ndarray_channel("P1", 1);
        let (tx_b, _rx_b) = ndarray_channel("P2", 1);

        let mut output = NDArrayOutput::new();
        output.add(tx_a);
        output.add(tx_b);

        // First publish fills both single-slot queues; the second overflows both.
        output.publish(make_test_array(1));
        output.publish(make_test_array(2));

        assert_eq!(output.total_dropped(), 2);
    }

    #[test]
    fn test_fanout_remove() {
        let (tx_a, _rx_a) = ndarray_channel("P1", 10);
        let (tx_b, _rx_b) = ndarray_channel("P2", 10);

        let mut output = NDArrayOutput::new();
        output.add(tx_a);
        output.add(tx_b);
        assert_eq!(output.num_senders(), 2);

        output.remove("P1");
        assert_eq!(output.num_senders(), 1);
    }

    #[test]
    fn test_blocking_recv() {
        let (tx, mut rx) = ndarray_channel("TEST", 10);

        let consumer = std::thread::spawn(move || rx.blocking_recv().unwrap().unique_id);

        tx.send(make_test_array(99));
        let received_id = consumer.join().unwrap();
        assert_eq!(received_id, 99);
    }

    #[test]
    fn test_channel_closed_on_receiver_drop() {
        let (tx, rx) = ndarray_channel("TEST", 10);
        drop(rx);
        // Sending to a closed channel must not panic...
        tx.send(make_test_array(1));
        // ...and is not accounted as a "dropped" array.
        assert_eq!(tx.dropped_count(), 0);
    }

    #[test]
    fn test_queued_counter_basic() {
        let counter = QueuedArrayCounter::new();
        assert_eq!(counter.get(), 0);

        for expected in [1, 2] {
            counter.increment();
            assert_eq!(counter.get(), expected);
        }
        for expected in [1, 0] {
            counter.decrement();
            assert_eq!(counter.get(), expected);
        }
    }

    #[test]
    fn test_queued_counter_wait_until_zero() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        counter.increment();

        let shared = Arc::clone(&counter);
        let worker = std::thread::spawn(move || {
            for _ in 0..2 {
                std::thread::sleep(Duration::from_millis(10));
                shared.decrement();
            }
        });

        assert!(counter.wait_until_zero(Duration::from_secs(5)));
        worker.join().unwrap();
    }

    #[test]
    fn test_queued_counter_wait_timeout() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let drained = counter.wait_until_zero(Duration::from_millis(10));
        assert!(!drained);
    }

    #[test]
    fn test_send_increments_counter() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut tx, _rx) = ndarray_channel("TEST", 10);
        tx.set_queued_counter(Arc::clone(&counter));

        tx.send(make_test_array(1));
        assert_eq!(counter.get(), 1);
        tx.send(make_test_array(2));
        assert_eq!(counter.get(), 2);
    }

    #[test]
    fn test_send_queue_full_no_net_increment() {
        let counter = Arc::new(QueuedArrayCounter::new());
        let (mut tx, _rx) = ndarray_channel("TEST", 1);
        tx.set_queued_counter(Arc::clone(&counter));

        // First send fills the single-slot queue.
        tx.send(make_test_array(1));
        assert_eq!(counter.get(), 1);
        // Second send is rejected; dropping the message cancels its increment.
        tx.send(make_test_array(2));
        assert_eq!(counter.get(), 1);
    }

    #[test]
    fn test_message_drop_decrements() {
        let counter = Arc::new(QueuedArrayCounter::new());
        counter.increment();
        let message = ArrayMessage {
            array: make_test_array(1),
            counter: Some(Arc::clone(&counter)),
        };
        assert_eq!(counter.get(), 1);
        drop(message);
        assert_eq!(counter.get(), 0);
    }
}
474}