1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Upper bound of the 1-based index range for which `register_params`
/// creates a `GATHER_NDARRAY_PORT_<n>` / `GATHER_NDARRAY_ADDR_<n>` parameter pair.
const MAX_GATHER_PORTS: usize = 8;
9
/// Plugin processor that counts every array handed to it and forwards a
/// copy of that array as its single output.
///
/// `Debug` is derived so the processor can be logged and used in
/// `{:?}`-based assertions.
#[derive(Debug)]
pub struct GatherProcessor {
    /// Number of arrays seen by `process_array` since construction.
    count: u64,
}
14
15impl GatherProcessor {
16 pub fn new() -> Self {
17 Self { count: 0 }
18 }
19
20 pub fn total_received(&self) -> u64 {
21 self.count
22 }
23}
24
25impl Default for GatherProcessor {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl NDPluginProcess for GatherProcessor {
32 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
33 self.count += 1;
34 ProcessResult::arrays(vec![Arc::new(array.clone())])
35 }
36
37 fn plugin_type(&self) -> &str {
38 "NDPluginGather"
39 }
40
41 fn register_params(
42 &mut self,
43 base: &mut asyn_rs::port::PortDriverBase,
44 ) -> asyn_rs::error::AsynResult<()> {
45 use asyn_rs::param::ParamType;
46 for i in 1..=MAX_GATHER_PORTS {
47 base.create_param(&format!("GATHER_NDARRAY_PORT_{}", i), ParamType::Octet)?;
48 base.create_param(&format!("GATHER_NDARRAY_ADDR_{}", i), ParamType::Int32)?;
49 }
50 Ok(())
51 }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Two arrays in: each call yields exactly one output array and the
    /// received counter ends at two.
    #[test]
    fn test_gather_processor() {
        let pool = NDArrayPool::new(1_000_000);
        let mut processor = GatherProcessor::new();

        let first = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        let second = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);

        let first_out = processor.process_array(&first, &pool);
        let second_out = processor.process_array(&second, &pool);

        assert_eq!(first_out.output_arrays.len(), 1);
        assert_eq!(second_out.output_arrays.len(), 1);
        assert_eq!(processor.total_received(), 2);
    }
}