Skip to main content

ad_plugins_rs/
scatter.rs

1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Strategy used to pick the downstream consumer for each array.
///
/// Round-robin is currently the only strategy; its discriminant is `0`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScatterMethod {
    /// Hand successive arrays to successive outputs, wrapping at the end.
    RoundRobin = 0,
}

impl ScatterMethod {
    /// Maps a raw `SCATTER_METHOD` parameter value onto a method.
    ///
    /// Every value, including unrecognised ones, yields
    /// [`ScatterMethod::RoundRobin`], since it is the only method available.
    pub fn from_i32(_raw: i32) -> Self {
        ScatterMethod::RoundRobin
    }
}
18
19/// Scatter processor: distributes arrays to downstream plugins in round-robin
20/// order.
21///
22/// The actual number of downstream consumers is discovered by the plugin
23/// runtime, which publishes a scatter result to `scatter_index % senders.len()`.
24/// `num_outputs == 0` (the default) means "use the live downstream count": the
25/// processor emits a raw incrementing index and the runtime maps it onto the
26/// connected consumers — equivalent to C++ NDPluginScatter advancing
27/// `nextScatter` over its registered clients. A non-zero `num_outputs` caps the
28/// round-robin span for tests / fixed fan-out.
29pub struct ScatterProcessor {
30    method: ScatterMethod,
31    current_index: usize,
32    num_outputs: usize,
33    method_idx: Option<usize>,
34}
35
36impl ScatterProcessor {
37    pub fn new() -> Self {
38        Self {
39            method: ScatterMethod::RoundRobin,
40            current_index: 0,
41            num_outputs: 0,
42            method_idx: None,
43        }
44    }
45
46    /// Override the round-robin span. `0` means "use the live downstream
47    /// consumer count" (the default).
48    pub fn set_num_outputs(&mut self, n: usize) {
49        self.num_outputs = n;
50    }
51}
52
53impl Default for ScatterProcessor {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl NDPluginProcess for ScatterProcessor {
60    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
61        let idx = if self.num_outputs > 0 {
62            self.current_index % self.num_outputs
63        } else {
64            self.current_index
65        };
66        self.current_index = self.current_index.wrapping_add(1);
67        ProcessResult::scatter(vec![Arc::new(array.clone())], idx)
68    }
69
70    fn plugin_type(&self) -> &str {
71        "NDPluginScatter"
72    }
73
74    fn register_params(
75        &mut self,
76        base: &mut asyn_rs::port::PortDriverBase,
77    ) -> asyn_rs::error::AsynResult<()> {
78        use asyn_rs::param::ParamType;
79        base.create_param("SCATTER_METHOD", ParamType::Int32)?;
80        self.method_idx = base.find_param("SCATTER_METHOD");
81        Ok(())
82    }
83
84    fn on_param_change(
85        &mut self,
86        reason: usize,
87        params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
88    ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
89        if Some(reason) == self.method_idx {
90            self.method = ScatterMethod::from_i32(params.value.as_i32());
91        }
92        ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// A tiny 4-element UInt8 array; its contents are irrelevant to scatter.
    fn small_array() -> NDArray {
        NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8)
    }

    #[test]
    fn test_scatter_processor_round_robin() {
        let pool = NDArrayPool::new(1_000_000);
        let mut scatter = ScatterProcessor::new();
        scatter.set_num_outputs(3);

        let mut input = small_array();
        input.unique_id = 42;

        let first = scatter.process_array(&input, &pool);
        assert_eq!(first.scatter_index, Some(0));
        assert_eq!(first.output_arrays.len(), 1);

        // Remaining outputs in order, then wrap back to the first one.
        for expected in [1, 2, 0] {
            let result = scatter.process_array(&input, &pool);
            assert_eq!(result.scatter_index, Some(expected));
        }
    }

    #[test]
    fn test_scatter_default_emits_raw_index() {
        // With num_outputs left at 0 the processor hands out an unbounded,
        // incrementing index; reducing it over the live consumer count is the
        // runtime's job (idx % senders.len()), so scatter never collapses into
        // a passthrough to consumer 0.
        let pool = NDArrayPool::new(1_000_000);
        let mut scatter = ScatterProcessor::new();
        let input = small_array();

        let indices: Vec<_> = std::iter::repeat_with(|| {
            scatter
                .process_array(&input, &pool)
                .scatter_index
                .unwrap()
        })
        .take(4)
        .collect();

        // Strictly increasing raw indices, not all zero.
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }
}