Skip to main content

ad_plugins_rs/
pos_plugin.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3
4use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
5use ad_core_rs::ndarray::NDArray;
6use ad_core_rs::ndarray_pool::NDArrayPool;
7use ad_core_rs::plugin::runtime::{NDPluginProcess, ParamUpdate, ProcessResult};
8use serde::Deserialize;
9
/// Strategy used to walk the loaded position list.
///
/// The two variants select which storage the processor reads from and how it
/// moves on after a position has been handed out.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PosMode {
    /// Each position is removed from the front of the queue once it has been used.
    Discard,
    /// Positions are retained and revisited cyclically through a wrapping index.
    Keep,
}
16
17/// JSON-deserializable position list.
18#[derive(Debug, Deserialize)]
19pub struct PositionList {
20    pub positions: Vec<HashMap<String, f64>>,
21}
22
23/// NDPosPlugin processor: attaches position metadata to arrays from a position list.
24pub struct PosPluginProcessor {
25    positions: VecDeque<HashMap<String, f64>>,
26    all_positions: Vec<HashMap<String, f64>>,
27    mode: PosMode,
28    index: usize,
29    running: bool,
30    expected_id: i32,
31    missing_frames: usize,
32    duplicate_frames: usize,
33}
34
35impl PosPluginProcessor {
36    pub fn new(mode: PosMode) -> Self {
37        Self {
38            positions: VecDeque::new(),
39            all_positions: Vec::new(),
40            mode,
41            index: 0,
42            running: false,
43            expected_id: 0,
44            missing_frames: 0,
45            duplicate_frames: 0,
46        }
47    }
48
49    /// Load positions from a JSON string.
50    pub fn load_positions_json(&mut self, json_str: &str) -> Result<usize, serde_json::Error> {
51        let list: PositionList = serde_json::from_str(json_str)?;
52        let count = list.positions.len();
53        self.all_positions = list.positions.clone();
54        self.positions = list.positions.into();
55        self.index = 0;
56        Ok(count)
57    }
58
59    /// Load positions directly.
60    pub fn load_positions(&mut self, positions: Vec<HashMap<String, f64>>) {
61        self.all_positions = positions.clone();
62        self.positions = positions.into();
63        self.index = 0;
64    }
65
66    /// Start processing.
67    pub fn start(&mut self) {
68        self.running = true;
69        self.expected_id = 0;
70        self.missing_frames = 0;
71        self.duplicate_frames = 0;
72    }
73
74    /// Stop processing.
75    pub fn stop(&mut self) {
76        self.running = false;
77    }
78
79    /// Clear all positions.
80    pub fn clear(&mut self) {
81        self.positions.clear();
82        self.all_positions.clear();
83        self.index = 0;
84    }
85
86    pub fn missing_frames(&self) -> usize {
87        self.missing_frames
88    }
89
90    pub fn duplicate_frames(&self) -> usize {
91        self.duplicate_frames
92    }
93
94    pub fn remaining_positions(&self) -> usize {
95        match self.mode {
96            PosMode::Discard => self.positions.len(),
97            PosMode::Keep => self.all_positions.len(),
98        }
99    }
100
101    fn current_position(&self) -> Option<&HashMap<String, f64>> {
102        match self.mode {
103            PosMode::Discard => self.positions.front(),
104            PosMode::Keep => {
105                if self.all_positions.is_empty() {
106                    None
107                } else {
108                    Some(&self.all_positions[self.index % self.all_positions.len()])
109                }
110            }
111        }
112    }
113
114    fn advance(&mut self) {
115        match self.mode {
116            PosMode::Discard => {
117                self.positions.pop_front();
118            }
119            PosMode::Keep => {
120                self.index += 1;
121            }
122        }
123    }
124}
125
126impl NDPluginProcess for PosPluginProcessor {
127    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
128        if !self.running {
129            return ProcessResult::arrays(vec![Arc::new(array.clone())]);
130        }
131
132        let has_positions = match self.mode {
133            PosMode::Discard => !self.positions.is_empty(),
134            PosMode::Keep => !self.all_positions.is_empty(),
135        };
136
137        if !has_positions {
138            return ProcessResult::arrays(vec![Arc::new(array.clone())]);
139        }
140
141        // Frame ID tracking
142        if self.expected_id > 0 {
143            let uid = array.unique_id;
144            if uid > self.expected_id {
145                let diff = (uid - self.expected_id) as usize;
146                self.missing_frames += diff;
147                for _ in 0..diff {
148                    self.advance();
149                    let has = match self.mode {
150                        PosMode::Discard => !self.positions.is_empty(),
151                        PosMode::Keep => !self.all_positions.is_empty(),
152                    };
153                    if !has {
154                        return ProcessResult::arrays(vec![Arc::new(array.clone())]);
155                    }
156                }
157            } else if uid < self.expected_id {
158                self.duplicate_frames += 1;
159                return ProcessResult::empty();
160            }
161        }
162
163        let position = match self.current_position() {
164            Some(pos) => pos.clone(),
165            None => return ProcessResult::arrays(vec![Arc::new(array.clone())]),
166        };
167
168        let mut out = array.clone();
169        for (key, value) in &position {
170            out.attributes.add(NDAttribute {
171                name: key.clone(),
172                description: String::new(),
173                source: NDAttrSource::Driver,
174                value: NDAttrValue::Float64(*value),
175            });
176        }
177
178        self.advance();
179        self.expected_id = array.unique_id + 1;
180
181        let updates = vec![
182            ParamUpdate::int32(0, self.missing_frames as i32),
183            ParamUpdate::int32(1, self.duplicate_frames as i32),
184        ];
185
186        ProcessResult {
187            output_arrays: vec![Arc::new(out)],
188            param_updates: updates,
189            scatter_index: None,
190        }
191    }
192
193    fn plugin_type(&self) -> &str {
194        "NDPosPlugin"
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Builds a 4-element `UInt8` array carrying the given unique id.
    fn make_array(id: i32) -> NDArray {
        let mut arr = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        arr.unique_id = id;
        arr
    }

    /// Builds one position map from `(name, value)` pairs.
    fn position(entries: &[(&str, f64)]) -> HashMap<String, f64> {
        entries
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect()
    }

    /// Reads the `X` attribute of the first output array as `f64`.
    fn x_of(result: &ProcessResult) -> f64 {
        result.output_arrays[0]
            .attributes
            .get("X")
            .unwrap()
            .value
            .as_f64()
            .unwrap()
    }

    /// Asserts that two floats agree to within `1e-10`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_discard_mode() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        proc.load_positions(vec![
            position(&[("X", 1.5), ("Y", 2.3)]),
            position(&[("X", 3.1), ("Y", 4.2)]),
        ]);
        proc.start();

        let pool = NDArrayPool::new(1_000_000);

        let first = proc.process_array(&make_array(1), &pool);
        assert_eq!(first.output_arrays.len(), 1);
        assert_close(x_of(&first), 1.5);

        let second = proc.process_array(&make_array(2), &pool);
        assert_close(x_of(&second), 3.1);

        assert_eq!(proc.remaining_positions(), 0);
    }

    #[test]
    fn test_keep_mode() {
        let mut proc = PosPluginProcessor::new(PosMode::Keep);
        proc.load_positions(vec![position(&[("X", 10.0)]), position(&[("X", 20.0)])]);
        proc.start();

        let pool = NDArrayPool::new(1_000_000);

        let first = proc.process_array(&make_array(1), &pool);
        assert_close(x_of(&first), 10.0);

        let second = proc.process_array(&make_array(2), &pool);
        assert_close(x_of(&second), 20.0);

        // Wraps around
        let third = proc.process_array(&make_array(3), &pool);
        assert_close(x_of(&third), 10.0);
    }

    #[test]
    fn test_missing_frames() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        proc.load_positions(vec![
            position(&[("X", 1.0)]),
            position(&[("X", 2.0)]),
            position(&[("X", 3.0)]),
        ]);
        proc.start();

        let pool = NDArrayPool::new(1_000_000);

        proc.process_array(&make_array(1), &pool);

        // Frame 3 (skip frame 2)
        let result = proc.process_array(&make_array(3), &pool);
        assert_eq!(proc.missing_frames(), 1);
        assert_close(x_of(&result), 3.0);
    }

    #[test]
    fn test_duplicate_frames() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        proc.load_positions(vec![position(&[("X", 1.0)]), position(&[("X", 2.0)])]);
        proc.start();

        let pool = NDArrayPool::new(1_000_000);

        proc.process_array(&make_array(1), &pool);

        // Same id again: counted as a duplicate and dropped.
        let result = proc.process_array(&make_array(1), &pool);
        assert_eq!(proc.duplicate_frames(), 1);
        assert!(result.output_arrays.is_empty());
    }

    #[test]
    fn test_load_json() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        let json = r#"{"positions": [{"X": 1.5, "Y": 2.3}, {"X": 3.1, "Y": 4.2}]}"#;
        let count = proc.load_positions_json(json).unwrap();
        assert_eq!(count, 2);
        assert_eq!(proc.remaining_positions(), 2);
    }

    #[test]
    fn test_not_running_passthrough() {
        let mut proc = PosPluginProcessor::new(PosMode::Discard);
        let pool = NDArrayPool::new(1_000_000);
        let result = proc.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(result.output_arrays[0].attributes.get("X").is_none());
    }
}