1use std::sync::Arc;
2
3use ad_core_rs::ndarray::NDArray;
4use ad_core_rs::ndarray_pool::NDArrayPool;
5use ad_core_rs::plugin::runtime::{NDPluginProcess, ProcessResult};
6
/// Strategy used to pick which downstream output receives the next array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScatterMethod {
    /// Cycle through the outputs in order.
    RoundRobin = 0,
}

impl ScatterMethod {
    /// Converts a raw integer (the Int32 value of the `SCATTER_METHOD`
    /// parameter) into a method.
    ///
    /// Only [`ScatterMethod::RoundRobin`] exists, so every input value,
    /// including out-of-range ones, maps to it.
    pub fn from_i32(_raw: i32) -> Self {
        ScatterMethod::RoundRobin
    }
}
18
19pub struct ScatterProcessor {
30 method: ScatterMethod,
31 current_index: usize,
32 num_outputs: usize,
33 method_idx: Option<usize>,
34}
35
36impl ScatterProcessor {
37 pub fn new() -> Self {
38 Self {
39 method: ScatterMethod::RoundRobin,
40 current_index: 0,
41 num_outputs: 0,
42 method_idx: None,
43 }
44 }
45
46 pub fn set_num_outputs(&mut self, n: usize) {
49 self.num_outputs = n;
50 }
51}
52
53impl Default for ScatterProcessor {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl NDPluginProcess for ScatterProcessor {
60 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
61 let idx = if self.num_outputs > 0 {
62 self.current_index % self.num_outputs
63 } else {
64 self.current_index
65 };
66 self.current_index = self.current_index.wrapping_add(1);
67 ProcessResult::scatter(vec![Arc::new(array.clone())], idx)
68 }
69
70 fn plugin_type(&self) -> &str {
71 "NDPluginScatter"
72 }
73
74 fn register_params(
75 &mut self,
76 base: &mut asyn_rs::port::PortDriverBase,
77 ) -> asyn_rs::error::AsynResult<()> {
78 use asyn_rs::param::ParamType;
79 base.create_param("SCATTER_METHOD", ParamType::Int32)?;
80 self.method_idx = base.find_param("SCATTER_METHOD");
81 Ok(())
82 }
83
84 fn on_param_change(
85 &mut self,
86 reason: usize,
87 params: &ad_core_rs::plugin::runtime::PluginParamSnapshot,
88 ) -> ad_core_rs::plugin::runtime::ParamChangeResult {
89 if Some(reason) == self.method_idx {
90 self.method = ScatterMethod::from_i32(params.value.as_i32());
91 }
92 ad_core_rs::plugin::runtime::ParamChangeResult::updates(vec![])
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Builds a small 4-element UInt8 array for feeding the processor.
    fn small_array() -> NDArray {
        NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8)
    }

    #[test]
    fn test_scatter_processor_round_robin() {
        let mut proc = ScatterProcessor::new();
        proc.set_num_outputs(3);
        let pool = NDArrayPool::new(1_000_000);

        let mut arr = small_array();
        arr.unique_id = 42;

        let r0 = proc.process_array(&arr, &pool);
        assert_eq!(r0.scatter_index, Some(0));
        assert_eq!(r0.output_arrays.len(), 1);
        // The forwarded array must carry the input's metadata unchanged.
        assert_eq!(r0.output_arrays[0].unique_id, 42);

        // The remaining calls walk through 1 and 2, then wrap back to 0.
        for expected in [1, 2, 0] {
            let r = proc.process_array(&arr, &pool);
            assert_eq!(r.scatter_index, Some(expected));
        }
    }

    #[test]
    fn test_scatter_default_emits_raw_index() {
        let mut proc = ScatterProcessor::new();
        let pool = NDArrayPool::new(1_000_000);
        let arr = small_array();

        let indices: Vec<_> = (0..4)
            .map(|_| proc.process_array(&arr, &pool).scatter_index.unwrap())
            .collect();
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_scatter_single_output_always_zero() {
        let mut proc = ScatterProcessor::new();
        proc.set_num_outputs(1);
        let pool = NDArrayPool::new(1_000_000);
        let arr = small_array();

        for _ in 0..4 {
            assert_eq!(proc.process_array(&arr, &pool).scatter_index, Some(0));
        }
    }
}