use std::sync::Arc;
use ad_core_rs::ndarray::NDArray;
use ad_core_rs::ndarray_pool::NDArrayPool;
use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
const MAX_GATHER_PORTS: usize = 8;
pub struct GatherProcessor {
count: u64,
}
impl GatherProcessor {
pub fn new() -> Self {
Self { count: 0 }
}
pub fn total_received(&self) -> u64 {
self.count
}
}
impl Default for GatherProcessor {
fn default() -> Self {
Self::new()
}
}
impl NDPluginProcess for GatherProcessor {
fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
self.count += 1;
ProcessResult::arrays(vec![Arc::new(array.clone())])
}
fn plugin_type(&self) -> &str {
"NDPluginGather"
}
fn register_params(
&mut self,
base: &mut asyn_rs::port::PortDriverBase,
) -> asyn_rs::error::AsynResult<()> {
use asyn_rs::param::ParamType;
for i in 1..=MAX_GATHER_PORTS {
base.create_param(&format!("GATHER_NDARRAY_PORT_{}", i), ParamType::Octet)?;
base.create_param(&format!("GATHER_NDARRAY_ADDR_{}", i), ParamType::Int32)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use ad_core_rs::ndarray::{NDDataType, NDDimension};
#[test]
fn test_gather_processor() {
let mut proc = GatherProcessor::new();
let pool = NDArrayPool::new(1_000_000);
let arr1 = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
let arr2 = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
let result1 = proc.process_array(&arr1, &pool);
let result2 = proc.process_array(&arr2, &pool);
assert_eq!(result1.output_arrays.len(), 1);
assert_eq!(result2.output_arrays.len(), 1);
assert_eq!(proc.total_received(), 2);
}
}