Skip to main content

ad_plugins/
scatter.rs

1use std::sync::Arc;
2
3use ad_core::ndarray::NDArray;
4use ad_core::ndarray_pool::NDArrayPool;
5use ad_core::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Scatter processor: forwards every incoming array unchanged.
///
/// The processor holds no state and performs no distribution logic itself.
/// Round-robin distribution is handled by wiring multiple `NDArraySender`
/// instances downstream.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScatterProcessor;

impl ScatterProcessor {
    /// Creates a new scatter processor.
    ///
    /// Equivalent to [`ScatterProcessor::default`]; the type carries no state,
    /// so this is usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
22
23impl NDPluginProcess for ScatterProcessor {
24    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
25        ProcessResult::arrays(vec![Arc::new(array.clone())])
26    }
27
28    fn plugin_type(&self) -> &str {
29        "NDPluginScatter"
30    }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::{NDDataType, NDDimension};

    /// A single array is forwarded as exactly one output with its id intact.
    #[test]
    fn test_scatter_processor_passthrough() {
        let mut processor = ScatterProcessor::new();
        let pool = NDArrayPool::new(1_000_000);

        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = 42;

        let result = processor.process_array(&arr, &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert_eq!(result.output_arrays[0].unique_id, 42);
    }

    /// The same processor instance forwards each of several arrays in turn;
    /// also exercises the `Default` construction path.
    #[test]
    fn test_scatter_processor_repeated_calls() {
        let mut processor = ScatterProcessor::default();
        let pool = NDArrayPool::new(1_000_000);

        for id in 1..=3 {
            let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
            arr.unique_id = id;

            let result = processor.process_array(&arr, &pool);
            assert_eq!(result.output_arrays.len(), 1);
            assert_eq!(result.output_arrays[0].unique_id, id);
        }
    }

    /// The plugin reports its type name.
    #[test]
    fn test_scatter_plugin_type() {
        let processor = ScatterProcessor::new();
        assert_eq!(processor.plugin_type(), "NDPluginScatter");
    }
}