1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3
4use ad_core::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
5use ad_core::ndarray::NDArray;
6use ad_core::ndarray_pool::NDArrayPool;
7use ad_core::plugin::runtime::{NDPluginProcess, ParamUpdate, ProcessResult};
8use serde::Deserialize;
9
/// How the loaded position list is walked as frames arrive.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PosMode {
    /// Each position is used for one frame and then removed from the list.
    Discard,
    /// The list is retained and cycled, wrapping back to the first position after the last.
    Keep,
}
16
17#[derive(Debug, Deserialize)]
19pub struct PositionList {
20 pub positions: Vec<HashMap<String, f64>>,
21}
22
23pub struct PosPluginProcessor {
25 positions: VecDeque<HashMap<String, f64>>,
26 all_positions: Vec<HashMap<String, f64>>,
27 mode: PosMode,
28 index: usize,
29 running: bool,
30 expected_id: i32,
31 missing_frames: usize,
32 duplicate_frames: usize,
33}
34
35impl PosPluginProcessor {
36 pub fn new(mode: PosMode) -> Self {
37 Self {
38 positions: VecDeque::new(),
39 all_positions: Vec::new(),
40 mode,
41 index: 0,
42 running: false,
43 expected_id: 0,
44 missing_frames: 0,
45 duplicate_frames: 0,
46 }
47 }
48
49 pub fn load_positions_json(&mut self, json_str: &str) -> Result<usize, serde_json::Error> {
51 let list: PositionList = serde_json::from_str(json_str)?;
52 let count = list.positions.len();
53 self.all_positions = list.positions.clone();
54 self.positions = list.positions.into();
55 self.index = 0;
56 Ok(count)
57 }
58
59 pub fn load_positions(&mut self, positions: Vec<HashMap<String, f64>>) {
61 self.all_positions = positions.clone();
62 self.positions = positions.into();
63 self.index = 0;
64 }
65
66 pub fn start(&mut self) {
68 self.running = true;
69 self.expected_id = 0;
70 self.missing_frames = 0;
71 self.duplicate_frames = 0;
72 }
73
74 pub fn stop(&mut self) {
76 self.running = false;
77 }
78
79 pub fn clear(&mut self) {
81 self.positions.clear();
82 self.all_positions.clear();
83 self.index = 0;
84 }
85
86 pub fn missing_frames(&self) -> usize {
87 self.missing_frames
88 }
89
90 pub fn duplicate_frames(&self) -> usize {
91 self.duplicate_frames
92 }
93
94 pub fn remaining_positions(&self) -> usize {
95 match self.mode {
96 PosMode::Discard => self.positions.len(),
97 PosMode::Keep => self.all_positions.len(),
98 }
99 }
100
101 fn current_position(&self) -> Option<&HashMap<String, f64>> {
102 match self.mode {
103 PosMode::Discard => self.positions.front(),
104 PosMode::Keep => {
105 if self.all_positions.is_empty() {
106 None
107 } else {
108 Some(&self.all_positions[self.index % self.all_positions.len()])
109 }
110 }
111 }
112 }
113
114 fn advance(&mut self) {
115 match self.mode {
116 PosMode::Discard => { self.positions.pop_front(); }
117 PosMode::Keep => { self.index += 1; }
118 }
119 }
120}
121
122impl NDPluginProcess for PosPluginProcessor {
123 fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
124 if !self.running {
125 return ProcessResult::arrays(vec![Arc::new(array.clone())]);
126 }
127
128 let has_positions = match self.mode {
129 PosMode::Discard => !self.positions.is_empty(),
130 PosMode::Keep => !self.all_positions.is_empty(),
131 };
132
133 if !has_positions {
134 return ProcessResult::arrays(vec![Arc::new(array.clone())]);
135 }
136
137 if self.expected_id > 0 {
139 let uid = array.unique_id;
140 if uid > self.expected_id {
141 let diff = (uid - self.expected_id) as usize;
142 self.missing_frames += diff;
143 for _ in 0..diff {
144 self.advance();
145 let has = match self.mode {
146 PosMode::Discard => !self.positions.is_empty(),
147 PosMode::Keep => !self.all_positions.is_empty(),
148 };
149 if !has {
150 return ProcessResult::arrays(vec![Arc::new(array.clone())]);
151 }
152 }
153 } else if uid < self.expected_id {
154 self.duplicate_frames += 1;
155 return ProcessResult::empty();
156 }
157 }
158
159 let position = match self.current_position() {
160 Some(pos) => pos.clone(),
161 None => return ProcessResult::arrays(vec![Arc::new(array.clone())]),
162 };
163
164 let mut out = array.clone();
165 for (key, value) in &position {
166 out.attributes.add(NDAttribute {
167 name: key.clone(),
168 description: String::new(),
169 source: NDAttrSource::Driver,
170 value: NDAttrValue::Float64(*value),
171 });
172 }
173
174 self.advance();
175 self.expected_id = array.unique_id + 1;
176
177 let updates = vec![
178 ParamUpdate::int32(0, self.missing_frames as i32),
179 ParamUpdate::int32(1, self.duplicate_frames as i32),
180 ];
181
182 ProcessResult {
183 output_arrays: vec![Arc::new(out)],
184 param_updates: updates,
185 }
186 }
187
188 fn plugin_type(&self) -> &str {
189 "NDPosPlugin"
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::{NDDataType, NDDimension};

    /// Build a small 4-element `UInt8` array stamped with `id` as its unique id.
    fn make_array(id: i32) -> NDArray {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        arr
    }

    /// Build one position from `(attribute name, value)` pairs.
    fn position(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect()
    }

    /// Value of attribute `name` on the first output array, or `None` if absent.
    fn attr(result: &ProcessResult, name: &str) -> Option<f64> {
        result.output_arrays[0]
            .attributes
            .get(name)
            .map(|a| a.value.as_f64().unwrap())
    }

    /// Assert two floats agree to within 1e-10.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    fn new_pool() -> NDArrayPool {
        NDArrayPool::new(1_000_000)
    }

    #[test]
    fn test_discard_mode() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        plugin.load_positions(vec![
            position(&[("X", 1.5), ("Y", 2.3)]),
            position(&[("X", 3.1), ("Y", 4.2)]),
        ]);
        plugin.start();
        let pool = new_pool();

        let first = plugin.process_array(&make_array(1), &pool);
        assert_eq!(first.output_arrays.len(), 1);
        assert_close(attr(&first, "X").unwrap(), 1.5);
        assert_close(attr(&first, "Y").unwrap(), 2.3);

        let second = plugin.process_array(&make_array(2), &pool);
        assert_close(attr(&second, "X").unwrap(), 3.1);

        assert_eq!(plugin.remaining_positions(), 0);
    }

    #[test]
    fn test_keep_mode() {
        let mut plugin = PosPluginProcessor::new(PosMode::Keep);
        plugin.load_positions(vec![position(&[("X", 10.0)]), position(&[("X", 20.0)])]);
        plugin.start();
        let pool = new_pool();

        // The list wraps around after the last entry.
        for (id, expected) in [(1, 10.0), (2, 20.0), (3, 10.0)] {
            let result = plugin.process_array(&make_array(id), &pool);
            assert_close(attr(&result, "X").unwrap(), expected);
        }
    }

    #[test]
    fn test_keep_mode_missing_frames() {
        let mut plugin = PosPluginProcessor::new(PosMode::Keep);
        plugin.load_positions(vec![
            position(&[("X", 10.0)]),
            position(&[("X", 20.0)]),
            position(&[("X", 30.0)]),
        ]);
        plugin.start();
        let pool = new_pool();

        plugin.process_array(&make_array(1), &pool);
        let after_gap = plugin.process_array(&make_array(3), &pool);
        assert_eq!(plugin.missing_frames(), 1);
        assert_close(attr(&after_gap, "X").unwrap(), 30.0);

        let wrapped = plugin.process_array(&make_array(4), &pool);
        assert_close(attr(&wrapped, "X").unwrap(), 10.0);
    }

    #[test]
    fn test_missing_frames() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        plugin.load_positions(vec![
            position(&[("X", 1.0)]),
            position(&[("X", 2.0)]),
            position(&[("X", 3.0)]),
        ]);
        plugin.start();
        let pool = new_pool();

        plugin.process_array(&make_array(1), &pool);
        let result = plugin.process_array(&make_array(3), &pool);
        assert_eq!(plugin.missing_frames(), 1);
        assert_close(attr(&result, "X").unwrap(), 3.0);
    }

    #[test]
    fn test_gap_exhausting_positions_passes_through() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        plugin.load_positions(vec![position(&[("X", 1.0)]), position(&[("X", 2.0)])]);
        plugin.start();
        let pool = new_pool();

        plugin.process_array(&make_array(1), &pool);
        let result = plugin.process_array(&make_array(5), &pool);
        assert_eq!(plugin.missing_frames(), 3);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(attr(&result, "X").is_none());
    }

    #[test]
    fn test_positions_exhausted_passthrough() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        plugin.load_positions(vec![position(&[("X", 1.0)])]);
        plugin.start();
        let pool = new_pool();

        let tagged = plugin.process_array(&make_array(1), &pool);
        assert_close(attr(&tagged, "X").unwrap(), 1.0);

        let untagged = plugin.process_array(&make_array(2), &pool);
        assert_eq!(untagged.output_arrays.len(), 1);
        assert!(attr(&untagged, "X").is_none());
        assert_eq!(plugin.remaining_positions(), 0);
    }

    #[test]
    fn test_duplicate_frames() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        plugin.load_positions(vec![position(&[("X", 1.0)]), position(&[("X", 2.0)])]);
        plugin.start();
        let pool = new_pool();

        plugin.process_array(&make_array(1), &pool);
        let result = plugin.process_array(&make_array(1), &pool);
        assert_eq!(plugin.duplicate_frames(), 1);
        assert!(result.output_arrays.is_empty());
    }

    #[test]
    fn test_duplicate_does_not_consume_position() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        plugin.load_positions(vec![position(&[("X", 1.0)]), position(&[("X", 2.0)])]);
        plugin.start();
        let pool = new_pool();

        plugin.process_array(&make_array(1), &pool);
        plugin.process_array(&make_array(1), &pool);
        let next = plugin.process_array(&make_array(2), &pool);
        assert_close(attr(&next, "X").unwrap(), 2.0);
    }

    #[test]
    fn test_load_json() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        let json = r#"{"positions": [{"X": 1.5, "Y": 2.3}, {"X": 3.1, "Y": 4.2}]}"#;
        assert_eq!(plugin.load_positions_json(json).unwrap(), 2);
        assert_eq!(plugin.remaining_positions(), 2);
    }

    #[test]
    fn test_load_json_invalid_keeps_previous() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        plugin
            .load_positions_json(r#"{"positions": [{"X": 1.0}, {"X": 2.0}]}"#)
            .unwrap();
        assert!(plugin.load_positions_json("not json").is_err());
        assert_eq!(plugin.remaining_positions(), 2);
    }

    #[test]
    fn test_not_running_passthrough() {
        let mut plugin = PosPluginProcessor::new(PosMode::Discard);
        let pool = new_pool();
        let result = plugin.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(attr(&result, "X").is_none());
    }
}