1use std::sync::Arc;
2
3use ad_core::ndarray::NDArray;
4use ad_core::ndarray_pool::NDArrayPool;
5use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Pass-through processor for the `NDPluginGather` plugin.
///
/// Every array handed to `process_array` is re-emitted as the single output
/// array. The only state kept is a running count of how many arrays have
/// been received.
#[derive(Debug, Clone)]
pub struct GatherProcessor {
    /// Number of arrays received since construction.
    count: u64,
}

impl GatherProcessor {
    /// Creates a processor whose received-array counter starts at zero.
    #[must_use]
    pub fn new() -> Self {
        Self { count: 0 }
    }

    /// Returns how many arrays have been passed to `process_array` so far.
    #[must_use]
    pub fn total_received(&self) -> u64 {
        self.count
    }
}

impl Default for GatherProcessor {
    /// Equivalent to [`GatherProcessor::new`]: the counter starts at zero.
    fn default() -> Self {
        Self::new()
    }
}
27
28impl NDPluginProcess for GatherProcessor {
29 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
30 self.count += 1;
31 ProcessResult::arrays(vec![Arc::new(array.clone())])
32 }
33
34 fn plugin_type(&self) -> &str {
35 "NDPluginGather"
36 }
37}
38
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::{NDDataType, NDDimension};

    /// Each processed input yields exactly one output array, and the
    /// processor's received counter tracks the number of inputs.
    #[test]
    fn test_gather_processor() {
        let pool = NDArrayPool::new(1_000_000);
        let mut gather = GatherProcessor::new();

        let make_input = || NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        let results: Vec<_> = (0..2)
            .map(|_| gather.process_array(&make_input(), &pool))
            .collect();

        for result in &results {
            assert_eq!(result.output_arrays.len(), 1);
        }
        assert_eq!(gather.total_received(), 2);
    }
}