1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3
4use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
5use ad_core_rs::ndarray::NDArray;
6use ad_core_rs::ndarray_pool::NDArrayPool;
7use ad_core_rs::plugin::runtime::{NDPluginProcess, ParamUpdate, ProcessResult};
8use serde::Deserialize;
9
/// Strategy for what happens to a position after it has been attached to an array.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PosMode {
    /// Positions are consumed: each one is removed from the queue once used.
    Discard,
    /// Positions are retained: the list is indexed modulo its length, so it
    /// starts over from the first entry after the last one has been used.
    Keep,
}
16
/// JSON document shape accepted by `PosPluginProcessor::load_positions_json`.
///
/// Deserialized from an object with a single `positions` array, for example
/// `{"positions": [{"X": 1.5, "Y": 2.3}, {"X": 3.1, "Y": 4.2}]}`.
#[derive(Debug, Deserialize)]
pub struct PositionList {
    /// One map per position. Each key/value pair of the map that is current
    /// when an array is processed is added to that array as a `Float64`
    /// attribute named after the key.
    pub positions: Vec<HashMap<String, f64>>,
}
22
/// Plugin processor that attaches pre-loaded position values to incoming
/// arrays as attributes and counts missing and duplicate frames.
pub struct PosPluginProcessor {
    /// Queue of positions still to be used in `PosMode::Discard`; the front
    /// entry is popped after each use.
    positions: VecDeque<HashMap<String, f64>>,
    /// Full copy of the loaded positions; indexed in `PosMode::Keep`.
    all_positions: Vec<HashMap<String, f64>>,
    /// Selects whether positions are consumed or reused.
    mode: PosMode,
    /// Running position counter for `PosMode::Keep`; reduced modulo
    /// `all_positions.len()` when a position is looked up.
    index: usize,
    /// While `false`, `process_array` forwards arrays unchanged.
    running: bool,
    /// `unique_id` expected on the next array. The id check is only applied
    /// while this is greater than zero; `start` resets it to zero.
    expected_id: i32,
    /// Number of ids that were skipped (incoming id above the expected one).
    missing_frames: usize,
    /// Number of arrays dropped because their id was below the expected one.
    duplicate_frames: usize,
}
34
35impl PosPluginProcessor {
36 pub fn new(mode: PosMode) -> Self {
37 Self {
38 positions: VecDeque::new(),
39 all_positions: Vec::new(),
40 mode,
41 index: 0,
42 running: false,
43 expected_id: 0,
44 missing_frames: 0,
45 duplicate_frames: 0,
46 }
47 }
48
49 pub fn load_positions_json(&mut self, json_str: &str) -> Result<usize, serde_json::Error> {
51 let list: PositionList = serde_json::from_str(json_str)?;
52 let count = list.positions.len();
53 self.all_positions = list.positions.clone();
54 self.positions = list.positions.into();
55 self.index = 0;
56 Ok(count)
57 }
58
59 pub fn load_positions(&mut self, positions: Vec<HashMap<String, f64>>) {
61 self.all_positions = positions.clone();
62 self.positions = positions.into();
63 self.index = 0;
64 }
65
66 pub fn start(&mut self) {
68 self.running = true;
69 self.expected_id = 0;
70 self.missing_frames = 0;
71 self.duplicate_frames = 0;
72 }
73
74 pub fn stop(&mut self) {
76 self.running = false;
77 }
78
79 pub fn clear(&mut self) {
81 self.positions.clear();
82 self.all_positions.clear();
83 self.index = 0;
84 }
85
86 pub fn missing_frames(&self) -> usize {
87 self.missing_frames
88 }
89
90 pub fn duplicate_frames(&self) -> usize {
91 self.duplicate_frames
92 }
93
94 pub fn remaining_positions(&self) -> usize {
95 match self.mode {
96 PosMode::Discard => self.positions.len(),
97 PosMode::Keep => self.all_positions.len(),
98 }
99 }
100
101 fn current_position(&self) -> Option<&HashMap<String, f64>> {
102 match self.mode {
103 PosMode::Discard => self.positions.front(),
104 PosMode::Keep => {
105 if self.all_positions.is_empty() {
106 None
107 } else {
108 Some(&self.all_positions[self.index % self.all_positions.len()])
109 }
110 }
111 }
112 }
113
114 fn advance(&mut self) {
115 match self.mode {
116 PosMode::Discard => {
117 self.positions.pop_front();
118 }
119 PosMode::Keep => {
120 self.index += 1;
121 }
122 }
123 }
124}
125
126impl NDPluginProcess for PosPluginProcessor {
127 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
128 if !self.running {
129 return ProcessResult::arrays(vec![Arc::new(array.clone())]);
130 }
131
132 let has_positions = match self.mode {
133 PosMode::Discard => !self.positions.is_empty(),
134 PosMode::Keep => !self.all_positions.is_empty(),
135 };
136
137 if !has_positions {
138 return ProcessResult::arrays(vec![Arc::new(array.clone())]);
139 }
140
141 if self.expected_id > 0 {
143 let uid = array.unique_id;
144 if uid > self.expected_id {
145 let diff = (uid - self.expected_id) as usize;
146 self.missing_frames += diff;
147 for _ in 0..diff {
148 self.advance();
149 let has = match self.mode {
150 PosMode::Discard => !self.positions.is_empty(),
151 PosMode::Keep => !self.all_positions.is_empty(),
152 };
153 if !has {
154 return ProcessResult::arrays(vec![Arc::new(array.clone())]);
155 }
156 }
157 } else if uid < self.expected_id {
158 self.duplicate_frames += 1;
159 return ProcessResult::empty();
160 }
161 }
162
163 let position = match self.current_position() {
164 Some(pos) => pos.clone(),
165 None => return ProcessResult::arrays(vec![Arc::new(array.clone())]),
166 };
167
168 let mut out = array.clone();
169 for (key, value) in &position {
170 out.attributes.add(NDAttribute {
171 name: key.clone(),
172 description: String::new(),
173 source: NDAttrSource::Driver,
174 value: NDAttrValue::Float64(*value),
175 });
176 }
177
178 self.advance();
179 self.expected_id = array.unique_id + 1;
180
181 let updates = vec![
182 ParamUpdate::int32(0, self.missing_frames as i32),
183 ParamUpdate::int32(1, self.duplicate_frames as i32),
184 ];
185
186 ProcessResult {
187 output_arrays: vec![Arc::new(out)],
188 param_updates: updates,
189 scatter_index: None,
190 }
191 }
192
193 fn plugin_type(&self) -> &str {
194 "NDPosPlugin"
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Builds a 4-element `UInt8` array carrying the given unique id.
    fn make_array(id: i32) -> NDArray {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        arr
    }

    /// Pool handed to `process_array`; the processor itself ignores it.
    fn make_pool() -> NDArrayPool {
        NDArrayPool::new(1_000_000)
    }

    /// Builds one position map from `(name, value)` pairs.
    fn position(entries: &[(&str, f64)]) -> HashMap<String, f64> {
        entries
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect()
    }

    /// Reads the `X` attribute of the first output array as `f64`.
    fn x_of(result: &ProcessResult) -> f64 {
        result.output_arrays[0]
            .attributes
            .get("X")
            .unwrap()
            .value
            .as_f64()
            .unwrap()
    }

    /// Asserts that two floats agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_discard_mode() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        proc.load_positions(vec![
            position(&[("X", 1.5), ("Y", 2.3)]),
            position(&[("X", 3.1), ("Y", 4.2)]),
        ]);
        proc.start();

        let pool = make_pool();

        let first = proc.process_array(&make_array(1), &pool);
        assert_eq!(first.output_arrays.len(), 1);
        assert_close(x_of(&first), 1.5);

        let second = proc.process_array(&make_array(2), &pool);
        assert_close(x_of(&second), 3.1);

        // Both positions have been consumed.
        assert_eq!(proc.remaining_positions(), 0);
    }

    #[test]
    fn test_keep_mode() {
        let mut proc = PosPluginProcessor::new(PosMode::Keep);
        proc.load_positions(vec![position(&[("X", 10.0)]), position(&[("X", 20.0)])]);
        proc.start();

        let pool = make_pool();

        let first = proc.process_array(&make_array(1), &pool);
        assert_close(x_of(&first), 10.0);

        let second = proc.process_array(&make_array(2), &pool);
        assert_close(x_of(&second), 20.0);

        // The third frame wraps around to the first position again.
        let third = proc.process_array(&make_array(3), &pool);
        assert_close(x_of(&third), 10.0);
    }

    #[test]
    fn test_missing_frames() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        proc.load_positions(vec![
            position(&[("X", 1.0)]),
            position(&[("X", 2.0)]),
            position(&[("X", 3.0)]),
        ]);
        proc.start();

        let pool = make_pool();

        proc.process_array(&make_array(1), &pool);

        // Id 2 never arrives, so its position is skipped and id 3 gets the
        // third position.
        let result = proc.process_array(&make_array(3), &pool);
        assert_eq!(proc.missing_frames(), 1);
        assert_close(x_of(&result), 3.0);
    }

    #[test]
    fn test_duplicate_frames() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        proc.load_positions(vec![position(&[("X", 1.0)]), position(&[("X", 2.0)])]);
        proc.start();

        let pool = make_pool();

        proc.process_array(&make_array(1), &pool);

        // Sending id 1 again is counted as a duplicate and produces no output.
        let result = proc.process_array(&make_array(1), &pool);
        assert_eq!(proc.duplicate_frames(), 1);
        assert!(result.output_arrays.is_empty());
    }

    #[test]
    fn test_load_json() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        let json = r#"{"positions": [{"X": 1.5, "Y": 2.3}, {"X": 3.1, "Y": 4.2}]}"#;
        let count = proc.load_positions_json(json).unwrap();
        assert_eq!(count, 2);
        assert_eq!(proc.remaining_positions(), 2);
    }

    #[test]
    fn test_not_running_passthrough() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        let pool = make_pool();
        // `start` was never called, so the array passes through untouched.
        let result = proc.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(result.output_arrays[0].attributes.get("X").is_none());
    }
}
360}