Skip to main content

ad_plugins/
pos_plugin.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3
4use ad_core::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
5use ad_core::ndarray::NDArray;
6use ad_core::ndarray_pool::NDArrayPool;
7use ad_core::plugin::runtime::{NDPluginProcess, ParamUpdate, ProcessResult};
8use serde::Deserialize;
9
/// How the plugin treats a position once it has been attached to a frame.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PosMode {
    /// Each position is used for exactly one frame and then removed from the list.
    Discard,
    /// Positions are retained and reused cyclically, wrapping back to the first.
    Keep,
}
16
17/// JSON-deserializable position list.
18#[derive(Debug, Deserialize)]
19pub struct PositionList {
20    pub positions: Vec<HashMap<String, f64>>,
21}
22
23/// NDPosPlugin processor: attaches position metadata to arrays from a position list.
24pub struct PosPluginProcessor {
25    positions: VecDeque<HashMap<String, f64>>,
26    all_positions: Vec<HashMap<String, f64>>,
27    mode: PosMode,
28    index: usize,
29    running: bool,
30    expected_id: i32,
31    missing_frames: usize,
32    duplicate_frames: usize,
33}
34
35impl PosPluginProcessor {
36    pub fn new(mode: PosMode) -> Self {
37        Self {
38            positions: VecDeque::new(),
39            all_positions: Vec::new(),
40            mode,
41            index: 0,
42            running: false,
43            expected_id: 0,
44            missing_frames: 0,
45            duplicate_frames: 0,
46        }
47    }
48
49    /// Load positions from a JSON string.
50    pub fn load_positions_json(&mut self, json_str: &str) -> Result<usize, serde_json::Error> {
51        let list: PositionList = serde_json::from_str(json_str)?;
52        let count = list.positions.len();
53        self.all_positions = list.positions.clone();
54        self.positions = list.positions.into();
55        self.index = 0;
56        Ok(count)
57    }
58
59    /// Load positions directly.
60    pub fn load_positions(&mut self, positions: Vec<HashMap<String, f64>>) {
61        self.all_positions = positions.clone();
62        self.positions = positions.into();
63        self.index = 0;
64    }
65
66    /// Start processing.
67    pub fn start(&mut self) {
68        self.running = true;
69        self.expected_id = 0;
70        self.missing_frames = 0;
71        self.duplicate_frames = 0;
72    }
73
74    /// Stop processing.
75    pub fn stop(&mut self) {
76        self.running = false;
77    }
78
79    /// Clear all positions.
80    pub fn clear(&mut self) {
81        self.positions.clear();
82        self.all_positions.clear();
83        self.index = 0;
84    }
85
86    pub fn missing_frames(&self) -> usize {
87        self.missing_frames
88    }
89
90    pub fn duplicate_frames(&self) -> usize {
91        self.duplicate_frames
92    }
93
94    pub fn remaining_positions(&self) -> usize {
95        match self.mode {
96            PosMode::Discard => self.positions.len(),
97            PosMode::Keep => self.all_positions.len(),
98        }
99    }
100
101    fn current_position(&self) -> Option<&HashMap<String, f64>> {
102        match self.mode {
103            PosMode::Discard => self.positions.front(),
104            PosMode::Keep => {
105                if self.all_positions.is_empty() {
106                    None
107                } else {
108                    Some(&self.all_positions[self.index % self.all_positions.len()])
109                }
110            }
111        }
112    }
113
114    fn advance(&mut self) {
115        match self.mode {
116            PosMode::Discard => { self.positions.pop_front(); }
117            PosMode::Keep => { self.index += 1; }
118        }
119    }
120}
121
122impl NDPluginProcess for PosPluginProcessor {
123    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
124        if !self.running {
125            return ProcessResult::arrays(vec![Arc::new(array.clone())]);
126        }
127
128        let has_positions = match self.mode {
129            PosMode::Discard => !self.positions.is_empty(),
130            PosMode::Keep => !self.all_positions.is_empty(),
131        };
132
133        if !has_positions {
134            return ProcessResult::arrays(vec![Arc::new(array.clone())]);
135        }
136
137        // Frame ID tracking
138        if self.expected_id > 0 {
139            let uid = array.unique_id;
140            if uid > self.expected_id {
141                let diff = (uid - self.expected_id) as usize;
142                self.missing_frames += diff;
143                for _ in 0..diff {
144                    self.advance();
145                    let has = match self.mode {
146                        PosMode::Discard => !self.positions.is_empty(),
147                        PosMode::Keep => !self.all_positions.is_empty(),
148                    };
149                    if !has {
150                        return ProcessResult::arrays(vec![Arc::new(array.clone())]);
151                    }
152                }
153            } else if uid < self.expected_id {
154                self.duplicate_frames += 1;
155                return ProcessResult::empty();
156            }
157        }
158
159        let position = match self.current_position() {
160            Some(pos) => pos.clone(),
161            None => return ProcessResult::arrays(vec![Arc::new(array.clone())]),
162        };
163
164        let mut out = array.clone();
165        for (key, value) in &position {
166            out.attributes.add(NDAttribute {
167                name: key.clone(),
168                description: String::new(),
169                source: NDAttrSource::Driver,
170                value: NDAttrValue::Float64(*value),
171            });
172        }
173
174        self.advance();
175        self.expected_id = array.unique_id + 1;
176
177        let updates = vec![
178            ParamUpdate::int32(0, self.missing_frames as i32),
179            ParamUpdate::int32(1, self.duplicate_frames as i32),
180        ];
181
182        ProcessResult {
183            output_arrays: vec![Arc::new(out)],
184            param_updates: updates,
185        }
186    }
187
188    fn plugin_type(&self) -> &str {
189        "NDPosPlugin"
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core::ndarray::{NDDataType, NDDimension};

    const EPS: f64 = 1e-10;

    fn make_array(id: i32) -> NDArray {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        arr
    }

    /// Builds one position from `(attribute name, value)` pairs.
    fn pos(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|&(name, value)| (name.to_string(), value)).collect()
    }

    /// Value of attribute `name` on the first output array; panics if absent.
    fn attr(result: &ProcessResult, name: &str) -> f64 {
        result.output_arrays[0].attributes.get(name).unwrap().value.as_f64().unwrap()
    }

    /// Whether the first output array carries an `X` attribute.
    fn has_x(result: &ProcessResult) -> bool {
        result.output_arrays[0].attributes.get("X").is_some()
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Started processor in `mode`, loaded with one `X` position per value.
    fn started(mode: PosMode, xs: &[f64]) -> PosPluginProcessor {
        let mut p = PosPluginProcessor::new(mode);
        p.load_positions(xs.iter().map(|&x| pos(&[("X", x)])).collect());
        p.start();
        p
    }

    #[test]
    fn test_discard_mode() {
        let mut p = PosPluginProcessor::new(PosMode::Discard);
        p.load_positions(vec![
            pos(&[("X", 1.5), ("Y", 2.3)]),
            pos(&[("X", 3.1), ("Y", 4.2)]),
        ]);
        p.start();
        let pool = NDArrayPool::new(1_000_000);

        let result = p.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert_close(attr(&result, "X"), 1.5);
        assert_close(attr(&result, "Y"), 2.3);

        let result = p.process_array(&make_array(2), &pool);
        assert_close(attr(&result, "X"), 3.1);

        assert_eq!(p.remaining_positions(), 0);
    }

    #[test]
    fn test_keep_mode() {
        let mut p = started(PosMode::Keep, &[10.0, 20.0]);
        let pool = NDArrayPool::new(1_000_000);

        assert_close(attr(&p.process_array(&make_array(1), &pool), "X"), 10.0);
        assert_close(attr(&p.process_array(&make_array(2), &pool), "X"), 20.0);
        // Wraps around.
        assert_close(attr(&p.process_array(&make_array(3), &pool), "X"), 10.0);
        // Keep mode never consumes positions.
        assert_eq!(p.remaining_positions(), 2);
    }

    #[test]
    fn test_keep_mode_missing_frames_wrap() {
        let mut p = started(PosMode::Keep, &[10.0, 20.0, 30.0]);
        let pool = NDArrayPool::new(1_000_000);

        assert_close(attr(&p.process_array(&make_array(1), &pool), "X"), 10.0);
        // Frame 2 missing: its position (20) is skipped.
        assert_close(attr(&p.process_array(&make_array(3), &pool), "X"), 30.0);
        assert_eq!(p.missing_frames(), 1);
        assert_close(attr(&p.process_array(&make_array(4), &pool), "X"), 10.0);
    }

    #[test]
    fn test_missing_frames() {
        let mut p = started(PosMode::Discard, &[1.0, 2.0, 3.0]);
        let pool = NDArrayPool::new(1_000_000);

        p.process_array(&make_array(1), &pool);

        // Frame 3 (skip frame 2).
        let result = p.process_array(&make_array(3), &pool);
        assert_eq!(p.missing_frames(), 1);
        assert_close(attr(&result, "X"), 3.0);
    }

    #[test]
    fn test_missing_frames_exhaust_positions() {
        let mut p = started(PosMode::Discard, &[1.0, 2.0]);
        let pool = NDArrayPool::new(1_000_000);

        p.process_array(&make_array(1), &pool);

        // Frames 2-4 missing: more gaps than positions left, so pass through untagged.
        let result = p.process_array(&make_array(5), &pool);
        assert_eq!(p.missing_frames(), 3);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(!has_x(&result));
        assert_eq!(p.remaining_positions(), 0);
    }

    #[test]
    fn test_duplicate_frames() {
        let mut p = started(PosMode::Discard, &[1.0, 2.0]);
        let pool = NDArrayPool::new(1_000_000);

        p.process_array(&make_array(1), &pool);

        let result = p.process_array(&make_array(1), &pool);
        assert_eq!(p.duplicate_frames(), 1);
        assert!(result.output_arrays.is_empty());

        // The dropped duplicate must not have consumed a position.
        let result = p.process_array(&make_array(2), &pool);
        assert_close(attr(&result, "X"), 2.0);
    }

    #[test]
    fn test_discard_exhausted_passthrough() {
        let mut p = started(PosMode::Discard, &[1.0]);
        let pool = NDArrayPool::new(1_000_000);

        assert_close(attr(&p.process_array(&make_array(1), &pool), "X"), 1.0);

        let result = p.process_array(&make_array(2), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(!has_x(&result));
    }

    #[test]
    fn test_load_json() {
        let mut p = PosPluginProcessor::new(PosMode::Discard);
        let json = r#"{"positions": [{"X": 1.5, "Y": 2.3}, {"X": 3.1, "Y": 4.2}]}"#;
        let count = p.load_positions_json(json).unwrap();
        assert_eq!(count, 2);
        assert_eq!(p.remaining_positions(), 2);
    }

    #[test]
    fn test_load_json_invalid() {
        let mut p = PosPluginProcessor::new(PosMode::Discard);
        assert!(p.load_positions_json("not json").is_err());
        assert!(p.load_positions_json(r#"{"positions": [{"X": "a"}]}"#).is_err());
        assert_eq!(p.remaining_positions(), 0);
    }

    #[test]
    fn test_not_running_passthrough() {
        let mut p = PosPluginProcessor::new(PosMode::Discard);
        let pool = NDArrayPool::new(1_000_000);
        let result = p.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(!has_x(&result));
    }

    #[test]
    fn test_stop_passthrough() {
        let mut p = started(PosMode::Discard, &[1.0]);
        p.stop();
        let pool = NDArrayPool::new(1_000_000);

        let result = p.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(!has_x(&result));
        assert_eq!(p.remaining_positions(), 1);
    }

    #[test]
    fn test_clear_removes_positions() {
        let mut p = started(PosMode::Discard, &[1.0, 2.0]);
        p.clear();
        assert_eq!(p.remaining_positions(), 0);

        let pool = NDArrayPool::new(1_000_000);
        let result = p.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(!has_x(&result));
    }
}
319}