1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Strategy a [`ScatterProcessor`] uses to choose the output for each array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScatterMethod {
    /// Visit the outputs one after another; discriminant `0`.
    RoundRobin = 0,
}

impl ScatterMethod {
    /// Converts a raw integer method selector into a [`ScatterMethod`].
    ///
    /// Round-robin is the only method defined, so every input value —
    /// including ones that do not correspond to a known discriminant —
    /// yields [`ScatterMethod::RoundRobin`].
    pub fn from_i32(_v: i32) -> Self {
        ScatterMethod::RoundRobin
    }
}
18
19pub struct ScatterProcessor {
21 method: ScatterMethod,
22 current_index: usize,
23 num_outputs: usize,
24 method_idx: Option<usize>,
25}
26
27impl ScatterProcessor {
28 pub fn new() -> Self {
29 Self {
30 method: ScatterMethod::RoundRobin,
31 current_index: 0,
32 num_outputs: 1,
33 method_idx: None,
34 }
35 }
36}
37
38impl Default for ScatterProcessor {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl NDPluginProcess for ScatterProcessor {
45 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
46 let idx = if self.num_outputs > 0 {
47 self.current_index % self.num_outputs
48 } else {
49 self.current_index
50 };
51 self.current_index = self.current_index.wrapping_add(1);
52 ProcessResult::scatter(vec![Arc::new(array.clone())], idx)
53 }
54
55 fn plugin_type(&self) -> &str {
56 "NDPluginScatter"
57 }
58
59 fn register_params(
60 &mut self,
61 base: &mut asyn_rs::port::PortDriverBase,
62 ) -> asyn_rs::error::AsynResult<()> {
63 use asyn_rs::param::ParamType;
64 base.create_param("SCATTER_METHOD", ParamType::Int32)?;
65 self.method_idx = base.find_param("SCATTER_METHOD");
66 Ok(())
67 }
68
69 fn on_param_change(
70 &mut self,
71 reason: usize,
72 params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
73 ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
74 if Some(reason) == self.method_idx {
75 self.method = ScatterMethod::from_i32(params.value.as_i32());
76 }
77 ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Four arrays through a three-output scatter land on outputs 0, 1, 2
    /// and then wrap back to 0; each result carries exactly one array.
    #[test]
    fn test_scatter_processor_round_robin() {
        let pool = NDArrayPool::new(1_000_000);
        let mut proc = ScatterProcessor::new();
        proc.num_outputs = 3;

        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = 42;

        let expected = [0, 1, 2, 0];
        for (call, &want) in expected.iter().enumerate() {
            let result = proc.process_array(&arr, &pool);
            assert_eq!(result.scatter_index, Some(want));
            if call == 0 {
                assert_eq!(result.output_arrays.len(), 1);
            }
        }
    }
}