use std::sync::Arc;
use ad_core_rs::ndarray::NDArray;
use ad_core_rs::ndarray_pool::NDArrayPool;
use ad_core_rs::plugin::runtime::{
NDPluginProcess, ParamChangeResult, ParamUpdate, PluginParamSnapshot, ProcessResult,
};
pub const MAX_GATHER_PORTS: usize = 8;
#[derive(Debug, Clone, Copy, Default)]
struct GatherPortParams {
port_idx: Option<usize>,
addr_idx: Option<usize>,
}
pub struct GatherProcessor {
count: u64,
num_ports: usize,
source_ports: [String; MAX_GATHER_PORTS],
source_addrs: [i32; MAX_GATHER_PORTS],
port_params: [GatherPortParams; MAX_GATHER_PORTS],
num_ports_idx: Option<usize>,
}
impl GatherProcessor {
pub fn new() -> Self {
Self {
count: 0,
num_ports: 0,
source_ports: Default::default(),
source_addrs: [0; MAX_GATHER_PORTS],
port_params: [GatherPortParams::default(); MAX_GATHER_PORTS],
num_ports_idx: None,
}
}
pub fn with_ports(ports: &[&str]) -> Self {
let mut proc = Self::new();
let n = ports.len().min(MAX_GATHER_PORTS);
proc.num_ports = n;
for (i, &name) in ports.iter().take(n).enumerate() {
proc.source_ports[i] = name.to_string();
}
proc
}
pub fn total_received(&self) -> u64 {
self.count
}
pub fn num_ports(&self) -> usize {
self.num_ports
}
pub fn source_port(&self, index: usize) -> &str {
if index < MAX_GATHER_PORTS {
&self.source_ports[index]
} else {
""
}
}
}
impl Default for GatherProcessor {
fn default() -> Self {
Self::new()
}
}
impl NDPluginProcess for GatherProcessor {
fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
self.count += 1;
ProcessResult::arrays(vec![Arc::new(array.clone())])
}
fn plugin_type(&self) -> &str {
"NDPluginGather"
}
fn register_params(
&mut self,
base: &mut asyn_rs::port::PortDriverBase,
) -> asyn_rs::error::AsynResult<()> {
use asyn_rs::param::ParamType;
for i in 0..MAX_GATHER_PORTS {
let port_name = format!("GATHER_NDARRAY_PORT_{}", i + 1);
let addr_name = format!("GATHER_NDARRAY_ADDR_{}", i + 1);
base.create_param(&port_name, ParamType::Octet)?;
base.create_param(&addr_name, ParamType::Int32)?;
self.port_params[i].port_idx = base.find_param(&port_name);
self.port_params[i].addr_idx = base.find_param(&addr_name);
}
base.create_param("GATHER_NUM_PORTS", ParamType::Int32)?;
self.num_ports_idx = base.find_param("GATHER_NUM_PORTS");
Ok(())
}
fn on_param_change(
&mut self,
reason: usize,
params: &PluginParamSnapshot,
) -> ParamChangeResult {
for i in 0..MAX_GATHER_PORTS {
if Some(reason) == self.port_params[i].port_idx {
if let Some(new_port) = params.value.as_string() {
self.source_ports[i] = new_port.to_string();
self.num_ports = self.source_ports.iter().filter(|s| !s.is_empty()).count();
if let Some(idx) = self.num_ports_idx {
return ParamChangeResult::updates(vec![ParamUpdate::int32(
idx,
self.num_ports as i32,
)]);
}
}
return ParamChangeResult::empty();
}
if Some(reason) == self.port_params[i].addr_idx {
self.source_addrs[i] = params.value.as_i32();
return ParamChangeResult::empty();
}
}
ParamChangeResult::empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
use ad_core_rs::ndarray::{NDDataType, NDDimension};
#[test]
fn test_gather_processor_passthrough() {
let mut proc = GatherProcessor::new();
let pool = NDArrayPool::new(1_000_000);
let arr1 = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
let arr2 = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
let result1 = proc.process_array(&arr1, &pool);
let result2 = proc.process_array(&arr2, &pool);
assert_eq!(result1.output_arrays.len(), 1);
assert_eq!(result2.output_arrays.len(), 1);
assert_eq!(proc.total_received(), 2);
}
#[test]
fn test_gather_with_ports() {
let proc = GatherProcessor::with_ports(&["SIM1", "SIM2", "SIM3"]);
assert_eq!(proc.num_ports(), 3);
assert_eq!(proc.source_port(0), "SIM1");
assert_eq!(proc.source_port(1), "SIM2");
assert_eq!(proc.source_port(2), "SIM3");
assert_eq!(proc.source_port(3), "");
}
#[test]
fn test_gather_multi_source_counting() {
let mut proc = GatherProcessor::with_ports(&["DRV1", "DRV2"]);
let pool = NDArrayPool::new(1_000_000);
for _ in 0..5 {
let arr = NDArray::new(vec![NDDimension::new(10)], NDDataType::UInt16);
proc.process_array(&arr, &pool);
}
assert_eq!(proc.total_received(), 5);
}
#[test]
fn test_gather_default() {
let proc = GatherProcessor::default();
assert_eq!(proc.total_received(), 0);
assert_eq!(proc.num_ports(), 0);
}
#[test]
fn test_gather_max_ports_clamped() {
let names: Vec<&str> = (0..12).map(|_| "PORT").collect();
let proc = GatherProcessor::with_ports(&names);
assert_eq!(proc.num_ports(), MAX_GATHER_PORTS);
}
}