Skip to main content

ad_plugins_rs/
gather.rs

1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Number of gather inputs; `register_params` creates a port/address
/// parameter pair for each input index in `1..=MAX_GATHER_PORTS`.
const MAX_GATHER_PORTS: usize = 8;
9
/// Pure gather processing logic (gathers from multiple senders into one stream).
///
/// The processor itself holds no per-port state; it only keeps a running
/// count of how many arrays have been handed to it.
#[derive(Debug, Default)]
pub struct GatherProcessor {
    /// Number of arrays received since construction; starts at zero.
    count: u64,
}

impl GatherProcessor {
    /// Creates a processor whose receive counter starts at zero.
    ///
    /// Equivalent to [`GatherProcessor::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self { count: 0 }
    }

    /// Returns the number of arrays this processor has received so far.
    #[must_use]
    pub const fn total_received(&self) -> u64 {
        self.count
    }
}
30
31impl NDPluginProcess for GatherProcessor {
32    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
33        self.count += 1;
34        ProcessResult::arrays(vec![Arc::new(array.clone())])
35    }
36
37    fn plugin_type(&self) -> &str {
38        "NDPluginGather"
39    }
40
41    fn register_params(
42        &mut self,
43        base: &mut asyn_rs::port::PortDriverBase,
44    ) -> asyn_rs::error::AsynResult<()> {
45        use asyn_rs::param::ParamType;
46        for i in 1..=MAX_GATHER_PORTS {
47            base.create_param(&format!("GATHER_NDARRAY_PORT_{}", i), ParamType::Octet)?;
48            base.create_param(&format!("GATHER_NDARRAY_ADDR_{}", i), ParamType::Int32)?;
49        }
50        Ok(())
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Builds a fresh one-dimensional, 4-element `UInt8` array.
    fn small_u8_array() -> NDArray {
        NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8)
    }

    /// Each input array yields exactly one output, and every call is counted.
    #[test]
    fn test_gather_processor() {
        let pool = NDArrayPool::new(1_000_000);
        let mut processor = GatherProcessor::new();

        let inputs = [small_u8_array(), small_u8_array()];
        let output_counts: Vec<usize> = inputs
            .iter()
            .map(|arr| processor.process_array(arr, &pool).output_arrays.len())
            .collect();

        assert_eq!(output_counts, vec![1, 1]);
        assert_eq!(processor.total_received(), 2);
    }
}