1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Strategy used to pick the scatter index attached to each processed array.
///
/// Only round-robin distribution is implemented at present.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScatterMethod {
    /// Consecutive arrays receive consecutive indices (discriminant `0`).
    RoundRobin = 0,
}

impl ScatterMethod {
    /// Converts a raw integer parameter value into a [`ScatterMethod`].
    ///
    /// Every input maps to [`ScatterMethod::RoundRobin`], including values
    /// with no matching discriminant, because it is the only method available.
    pub fn from_i32(_v: i32) -> Self {
        ScatterMethod::RoundRobin
    }
}
18
19pub struct ScatterProcessor {
21 method: ScatterMethod,
22 current_index: usize,
23 method_idx: Option<usize>,
24}
25
26impl ScatterProcessor {
27 pub fn new() -> Self {
28 Self {
29 method: ScatterMethod::RoundRobin,
30 current_index: 0,
31 method_idx: None,
32 }
33 }
34}
35
36impl Default for ScatterProcessor {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl NDPluginProcess for ScatterProcessor {
43 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
44 let idx = self.current_index;
45 self.current_index += 1;
46 ProcessResult::scatter(vec![Arc::new(array.clone())], idx)
47 }
48
49 fn plugin_type(&self) -> &str {
50 "NDPluginScatter"
51 }
52
53 fn register_params(
54 &mut self,
55 base: &mut asyn_rs::port::PortDriverBase,
56 ) -> asyn_rs::error::AsynResult<()> {
57 use asyn_rs::param::ParamType;
58 base.create_param("SCATTER_METHOD", ParamType::Int32)?;
59 self.method_idx = base.find_param("SCATTER_METHOD");
60 Ok(())
61 }
62
63 fn on_param_change(
64 &mut self,
65 reason: usize,
66 params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
67 ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
68 if Some(reason) == self.method_idx {
69 self.method = ScatterMethod::from_i32(params.value.as_i32());
70 }
71 ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Consecutive arrays receive scatter indices 0, 1, 2, …, and each result
    /// carries a single output array.
    #[test]
    fn test_scatter_processor_round_robin() {
        let pool = NDArrayPool::new(1_000_000);
        let mut processor = ScatterProcessor::new();

        let mut input = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        input.unique_id = 42;

        let first = processor.process_array(&input, &pool);
        assert_eq!(first.scatter_index, Some(0));
        assert_eq!(first.output_arrays.len(), 1);

        for expected in 1..=2 {
            let next = processor.process_array(&input, &pool);
            assert_eq!(next.scatter_index, Some(expected));
        }
    }
}