1use std::sync::Arc;
2
3use ad_core::ndarray::NDArray;
4use ad_core::ndarray_pool::NDArrayPool;
5use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Stateless processor backing the `NDPluginScatter` plugin.
///
/// It carries no configuration or per-array state. Every instance is
/// interchangeable, which is why it is a cheap `Copy` unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScatterProcessor;

impl ScatterProcessor {
    /// Creates a new scatter processor.
    ///
    /// This is equivalent to [`ScatterProcessor::default`]. It is `const` so it
    /// can also be used in constant contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
22
23impl NDPluginProcess for ScatterProcessor {
24 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
25 ProcessResult::arrays(vec![Arc::new(array.clone())])
26 }
27
28 fn plugin_type(&self) -> &str {
29 "NDPluginScatter"
30 }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::{NDDataType, NDDimension};

    /// The processor should emit exactly one array, and that array should
    /// keep the input's `unique_id`.
    #[test]
    fn test_scatter_processor_passthrough() {
        let pool = NDArrayPool::new(1_000_000);
        let mut processor = ScatterProcessor::new();

        let mut input = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        input.unique_id = 42;

        let outcome = processor.process_array(&input, &pool);
        let emitted = &outcome.output_arrays;
        assert_eq!(emitted.len(), 1);
        assert_eq!(emitted[0].unique_id, 42);
    }
}