1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{
6 NDPluginProcess, ParamChangeResult, ParamUpdate, PluginParamSnapshot, ProcessResult,
7};
8
/// Maximum number of upstream source slots a [`GatherProcessor`] can hold.
pub const MAX_GATHER_PORTS: usize = 8;

/// Parameter-library indices for one gather source slot.
///
/// Both fields stay `None` until `register_params` creates the parameters and
/// looks their indices up.
#[derive(Debug, Clone, Copy, Default)]
struct GatherPortParams {
    /// Index of the `GATHER_NDARRAY_PORT_<n>` (octet) parameter.
    port_idx: Option<usize>,
    /// Index of the `GATHER_NDARRAY_ADDR_<n>` (int32) parameter.
    addr_idx: Option<usize>,
}

/// Processing state for the `NDPluginGather` plugin.
///
/// Arrays passed to `process_array` are forwarded unchanged. The processor
/// counts how many arrays it has received and records the configured source
/// port name and address for each of the `MAX_GATHER_PORTS` slots.
#[derive(Debug)]
pub struct GatherProcessor {
    /// Total number of arrays seen by `process_array`.
    count: u64,
    /// Number of slots whose port name is non-empty.
    num_ports: usize,
    /// Source port name per slot; an empty string marks an unconfigured slot
    /// and is excluded from `num_ports`.
    source_ports: [String; MAX_GATHER_PORTS],
    /// Source address per slot, as last written via `GATHER_NDARRAY_ADDR_<n>`.
    source_addrs: [i32; MAX_GATHER_PORTS],
    /// Parameter indices per slot, filled in by `register_params`.
    port_params: [GatherPortParams; MAX_GATHER_PORTS],
    /// Index of the `GATHER_NUM_PORTS` parameter, once registered.
    num_ports_idx: Option<usize>,
}

impl GatherProcessor {
    /// Creates a processor with every slot unconfigured and a zero receive count.
    pub fn new() -> Self {
        Self {
            count: 0,
            num_ports: 0,
            source_ports: Default::default(),
            source_addrs: [0; MAX_GATHER_PORTS],
            port_params: [GatherPortParams::default(); MAX_GATHER_PORTS],
            num_ports_idx: None,
        }
    }

    /// Creates a processor whose first slots are pre-filled with `ports`.
    ///
    /// At most `MAX_GATHER_PORTS` names are used; any extras are ignored.
    /// Empty names leave their slot unconfigured. `num_ports` counts only
    /// non-empty names, which is the same definition used when a port
    /// parameter changes at runtime.
    pub fn with_ports(ports: &[&str]) -> Self {
        let mut processor = Self::new();
        // `zip` stops at the shorter side, which clamps to MAX_GATHER_PORTS.
        for (slot, &name) in processor.source_ports.iter_mut().zip(ports) {
            *slot = name.to_owned();
        }
        processor.num_ports = processor
            .source_ports
            .iter()
            .filter(|name| !name.is_empty())
            .count();
        processor
    }

    /// Returns the total number of arrays received so far.
    pub fn total_received(&self) -> u64 {
        self.count
    }

    /// Returns the number of slots with a non-empty source port name.
    pub fn num_ports(&self) -> usize {
        self.num_ports
    }

    /// Returns the source port name for slot `index`.
    ///
    /// Returns `""` when the slot is unconfigured or `index >= MAX_GATHER_PORTS`.
    pub fn source_port(&self, index: usize) -> &str {
        self.source_ports.get(index).map_or("", String::as_str)
    }

    /// Returns the source address for slot `index`.
    ///
    /// Returns `None` when `index >= MAX_GATHER_PORTS`.
    pub fn source_addr(&self, index: usize) -> Option<i32> {
        self.source_addrs.get(index).copied()
    }
}

impl Default for GatherProcessor {
    /// Equivalent to [`GatherProcessor::new`].
    fn default() -> Self {
        Self::new()
    }
}
94
95impl NDPluginProcess for GatherProcessor {
96 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
97 self.count += 1;
98 ProcessResult::arrays(vec![Arc::new(array.clone())])
99 }
100
101 fn plugin_type(&self) -> &str {
102 "NDPluginGather"
103 }
104
105 fn register_params(
106 &mut self,
107 base: &mut asyn_rs::port::PortDriverBase,
108 ) -> asyn_rs::error::AsynResult<()> {
109 use asyn_rs::param::ParamType;
110
111 for i in 0..MAX_GATHER_PORTS {
113 let port_name = format!("GATHER_NDARRAY_PORT_{}", i + 1);
114 let addr_name = format!("GATHER_NDARRAY_ADDR_{}", i + 1);
115 base.create_param(&port_name, ParamType::Octet)?;
116 base.create_param(&addr_name, ParamType::Int32)?;
117 self.port_params[i].port_idx = base.find_param(&port_name);
118 self.port_params[i].addr_idx = base.find_param(&addr_name);
119 }
120
121 base.create_param("GATHER_NUM_PORTS", ParamType::Int32)?;
123 self.num_ports_idx = base.find_param("GATHER_NUM_PORTS");
124
125 Ok(())
126 }
127
128 fn on_param_change(
129 &mut self,
130 reason: usize,
131 params: &PluginParamSnapshot,
132 ) -> ParamChangeResult {
133 for i in 0..MAX_GATHER_PORTS {
135 if Some(reason) == self.port_params[i].port_idx {
136 if let Some(new_port) = params.value.as_string() {
137 self.source_ports[i] = new_port.to_string();
138 self.num_ports = self.source_ports.iter().filter(|s| !s.is_empty()).count();
140 if let Some(idx) = self.num_ports_idx {
141 return ParamChangeResult::updates(vec![ParamUpdate::int32(
142 idx,
143 self.num_ports as i32,
144 )]);
145 }
146 }
147 return ParamChangeResult::empty();
148 }
149 if Some(reason) == self.port_params[i].addr_idx {
150 self.source_addrs[i] = params.value.as_i32();
151 return ParamChangeResult::empty();
152 }
153 }
154
155 ParamChangeResult::empty()
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Each input array yields exactly one output array, and every call is counted.
    #[test]
    fn test_gather_processor_passthrough() {
        let pool = NDArrayPool::new(1_000_000);
        let mut processor = GatherProcessor::new();

        let inputs = [
            NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8),
            NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8),
        ];
        let results: Vec<_> = inputs
            .iter()
            .map(|input| processor.process_array(input, &pool))
            .collect();

        for result in &results {
            assert_eq!(result.output_arrays.len(), 1);
        }
        assert_eq!(processor.total_received(), 2);
    }

    /// Names passed to `with_ports` land in consecutive slots; unused slots read as "".
    #[test]
    fn test_gather_with_ports() {
        let processor = GatherProcessor::with_ports(&["SIM1", "SIM2", "SIM3"]);
        assert_eq!(processor.num_ports(), 3);

        let expected = ["SIM1", "SIM2", "SIM3", ""];
        for (slot, name) in expected.iter().enumerate() {
            assert_eq!(processor.source_port(slot), *name);
        }
    }

    /// The receive counter is independent of how many source ports are configured.
    #[test]
    fn test_gather_multi_source_counting() {
        let pool = NDArrayPool::new(1_000_000);
        let mut processor = GatherProcessor::with_ports(&["DRV1", "DRV2"]);

        (0..5).for_each(|_| {
            // Simulates arrays arriving from the configured sources.
            let frame = NDArray::new(vec![NDDimension::new(10)], NDDataType::UInt16);
            let _ = processor.process_array(&frame, &pool);
        });

        assert_eq!(processor.total_received(), 5);
    }

    /// `Default` yields an empty processor.
    #[test]
    fn test_gather_default() {
        let processor = GatherProcessor::default();
        assert_eq!(processor.total_received(), 0);
        assert_eq!(processor.num_ports(), 0);
    }

    /// Supplying more than MAX_GATHER_PORTS names keeps only the first MAX_GATHER_PORTS.
    #[test]
    fn test_gather_max_ports_clamped() {
        let too_many = vec!["PORT"; 12];
        let processor = GatherProcessor::with_ports(&too_many);
        assert_eq!(processor.num_ports(), MAX_GATHER_PORTS);
    }
}