// ad_plugins_rs/gather.rs

1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{
6    NDPluginProcess, ParamChangeResult, ParamUpdate, PluginParamSnapshot, ProcessResult,
7};
8
/// Maximum number of gather input ports.
pub const MAX_GATHER_PORTS: usize = 8;

/// Per-port param indices for one gather source.
///
/// Both indices are `None` until the params have been registered.
#[derive(Debug, Clone, Copy, Default)]
struct GatherPortParams {
    /// Param index for GATHER_NDARRAY_PORT_N (Octet).
    port_idx: Option<usize>,
    /// Param index for GATHER_NDARRAY_ADDR_N (Int32).
    addr_idx: Option<usize>,
}

/// Pure gather processing logic: merges arrays from multiple upstream ports
/// into a single output stream.
///
/// Multi-source subscription is achieved at the IOC wiring level:
/// `NDGatherConfigure` registers the same `NDArraySender` with multiple
/// upstream `NDArrayOutput`s, so arrays from any configured source arrive
/// on the plugin's single input channel.
///
/// The processor stores the configured source port names and addresses as
/// params (GATHER_NDARRAY_PORT_1..8, GATHER_NDARRAY_ADDR_1..8) for
/// introspection and runtime reconfiguration via PVs.
#[derive(Debug)]
pub struct GatherProcessor {
    /// Total arrays received across all sources.
    count: u64,
    /// Number of configured source ports, i.e. slots whose name is non-empty
    /// (set during construction or param change).
    num_ports: usize,
    /// Configured source port names (indexed 0..MAX_GATHER_PORTS-1).
    /// An empty string marks an unused slot.
    source_ports: [String; MAX_GATHER_PORTS],
    /// Configured source addresses (indexed 0..MAX_GATHER_PORTS-1).
    source_addrs: [i32; MAX_GATHER_PORTS],
    /// Param indices for per-port params.
    port_params: [GatherPortParams; MAX_GATHER_PORTS],
    /// Param index for GATHER_NUM_PORTS.
    num_ports_idx: Option<usize>,
}

impl GatherProcessor {
    /// Create a processor with no configured sources and a zero array count.
    pub fn new() -> Self {
        Self {
            count: 0,
            num_ports: 0,
            source_ports: Default::default(),
            source_addrs: [0; MAX_GATHER_PORTS],
            port_params: [GatherPortParams::default(); MAX_GATHER_PORTS],
            num_ports_idx: None,
        }
    }

    /// Create a GatherProcessor pre-configured with the given source port names.
    ///
    /// Names are stored positionally; anything beyond [`MAX_GATHER_PORTS`]
    /// entries is ignored. Empty names occupy their slot but are not counted
    /// by [`num_ports`](Self::num_ports), which is the same rule applied when
    /// a port name is changed at runtime.
    pub fn with_ports(ports: &[&str]) -> Self {
        let mut gather = Self::new();
        // `zip` stops at the shorter side, which clamps to MAX_GATHER_PORTS.
        for (slot, name) in gather.source_ports.iter_mut().zip(ports) {
            *slot = (*name).to_string();
        }
        // Count only named slots so construction agrees with the recount
        // performed on a port-name param change.
        gather.num_ports = gather
            .source_ports
            .iter()
            .filter(|name| !name.is_empty())
            .count();
        gather
    }

    /// Total number of arrays processed so far.
    pub fn total_received(&self) -> u64 {
        self.count
    }

    /// Number of configured source ports.
    pub fn num_ports(&self) -> usize {
        self.num_ports
    }

    /// Get the configured source port name for the given index (0-based).
    ///
    /// Returns `""` for an unused slot or an out-of-range index.
    pub fn source_port(&self, index: usize) -> &str {
        self.source_ports.get(index).map_or("", String::as_str)
    }

    /// Get the configured source address for the given index (0-based).
    ///
    /// Returns `None` when `index` is not below [`MAX_GATHER_PORTS`].
    pub fn source_addr(&self, index: usize) -> Option<i32> {
        self.source_addrs.get(index).copied()
    }
}

impl Default for GatherProcessor {
    /// Equivalent to [`GatherProcessor::new`].
    fn default() -> Self {
        Self::new()
    }
}
94
95impl NDPluginProcess for GatherProcessor {
96    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
97        self.count += 1;
98        ProcessResult::arrays(vec![Arc::new(array.clone())])
99    }
100
101    fn plugin_type(&self) -> &str {
102        "NDPluginGather"
103    }
104
105    fn register_params(
106        &mut self,
107        base: &mut asyn_rs::port::PortDriverBase,
108    ) -> asyn_rs::error::AsynResult<()> {
109        use asyn_rs::param::ParamType;
110
111        // Register per-port params and store their indices
112        for i in 0..MAX_GATHER_PORTS {
113            let port_name = format!("GATHER_NDARRAY_PORT_{}", i + 1);
114            let addr_name = format!("GATHER_NDARRAY_ADDR_{}", i + 1);
115            base.create_param(&port_name, ParamType::Octet)?;
116            base.create_param(&addr_name, ParamType::Int32)?;
117            self.port_params[i].port_idx = base.find_param(&port_name);
118            self.port_params[i].addr_idx = base.find_param(&addr_name);
119        }
120
121        // Register aggregate param for number of configured ports
122        base.create_param("GATHER_NUM_PORTS", ParamType::Int32)?;
123        self.num_ports_idx = base.find_param("GATHER_NUM_PORTS");
124
125        Ok(())
126    }
127
128    fn on_param_change(
129        &mut self,
130        reason: usize,
131        params: &PluginParamSnapshot,
132    ) -> ParamChangeResult {
133        // Check if this is a GATHER_NDARRAY_PORT_N change
134        for i in 0..MAX_GATHER_PORTS {
135            if Some(reason) == self.port_params[i].port_idx {
136                if let Some(new_port) = params.value.as_string() {
137                    self.source_ports[i] = new_port.to_string();
138                    // Recount active ports
139                    self.num_ports = self.source_ports.iter().filter(|s| !s.is_empty()).count();
140                    if let Some(idx) = self.num_ports_idx {
141                        return ParamChangeResult::updates(vec![ParamUpdate::int32(
142                            idx,
143                            self.num_ports as i32,
144                        )]);
145                    }
146                }
147                return ParamChangeResult::empty();
148            }
149            if Some(reason) == self.port_params[i].addr_idx {
150                self.source_addrs[i] = params.value.as_i32();
151                return ParamChangeResult::empty();
152            }
153        }
154
155        ParamChangeResult::empty()
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Each processed array yields exactly one output and bumps the counter.
    #[test]
    fn test_gather_processor_passthrough() {
        let pool = NDArrayPool::new(1_000_000);
        let mut gather = GatherProcessor::new();

        let frames = [
            NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8),
            NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8),
        ];

        for frame in &frames {
            let outcome = gather.process_array(frame, &pool);
            assert_eq!(outcome.output_arrays.len(), 1);
        }

        assert_eq!(gather.total_received(), 2);
    }

    /// Names passed to `with_ports` are stored positionally.
    #[test]
    fn test_gather_with_ports() {
        let names = ["SIM1", "SIM2", "SIM3"];
        let gather = GatherProcessor::with_ports(&names);

        assert_eq!(gather.num_ports(), 3);
        for (slot, expected) in names.iter().enumerate() {
            assert_eq!(gather.source_port(slot), *expected);
        }
        // The first slot past the supplied names is empty.
        assert_eq!(gather.source_port(3), "");
    }

    /// The counter tracks every array, whichever source it came from.
    #[test]
    fn test_gather_multi_source_counting() {
        let pool = NDArrayPool::new(1_000_000);
        let mut gather = GatherProcessor::with_ports(&["DRV1", "DRV2"]);

        // Simulate arrays arriving from different sources (all arrive on same channel)
        (0..5).for_each(|_| {
            let frame = NDArray::new(vec![NDDimension::new(10)], NDDataType::UInt16);
            gather.process_array(&frame, &pool);
        });

        assert_eq!(gather.total_received(), 5);
    }

    /// A default processor has received nothing and has no ports.
    #[test]
    fn test_gather_default() {
        let gather = GatherProcessor::default();
        assert_eq!(gather.num_ports(), 0);
        assert_eq!(gather.total_received(), 0);
    }

    /// Supplying more names than MAX_GATHER_PORTS clamps the port count.
    #[test]
    fn test_gather_max_ports_clamped() {
        let names = vec!["PORT"; 12];
        let gather = GatherProcessor::with_ports(&names);
        assert_eq!(gather.num_ports(), MAX_GATHER_PORTS);
    }
}
214        let names: Vec<&str> = (0..12).map(|_| "PORT").collect();
215        let proc = GatherProcessor::with_ports(&names);
216        assert_eq!(proc.num_ports(), MAX_GATHER_PORTS);
217    }
218}