Skip to main content

ad_plugins_rs/
pos_plugin.rs

1use std::collections::{HashMap, VecDeque};
2use std::sync::Arc;
3
4use ad_core_rs::attributes::{NDAttrSource, NDAttrValue, NDAttribute};
5use ad_core_rs::ndarray::NDArray;
6use ad_core_rs::ndarray_pool::NDArrayPool;
7use ad_core_rs::plugin::runtime::{NDPluginProcess, ParamUpdate, ProcessResult};
8use serde::Deserialize;
9
/// How the processor walks the loaded position list.
///
/// - `Discard`: a position is removed from the list once it has been attached to a frame
///   (or skipped because its frame went missing).
/// - `Keep`: the list is retained and a cursor advances through it; once the cursor has
///   passed the last entry no further positions are attached (the list does not wrap).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PosMode {
    /// Consume positions from the front of the queue.
    Discard,
    /// Retain the list and advance a cursor over it.
    Keep,
}
16
17/// JSON-deserializable position list.
18#[derive(Debug, Deserialize)]
19pub struct PositionList {
20    pub positions: Vec<HashMap<String, f64>>,
21}
22
23/// NDPosPlugin processor: attaches position metadata to arrays from a position list.
24pub struct PosPluginProcessor {
25    positions: VecDeque<HashMap<String, f64>>,
26    all_positions: Vec<HashMap<String, f64>>,
27    mode: PosMode,
28    index: usize,
29    running: bool,
30    expected_id: i32,
31    missing_frames: usize,
32    duplicate_frames: usize,
33}
34
35impl PosPluginProcessor {
36    pub fn new(mode: PosMode) -> Self {
37        Self {
38            positions: VecDeque::new(),
39            all_positions: Vec::new(),
40            mode,
41            index: 0,
42            running: false,
43            expected_id: 0,
44            missing_frames: 0,
45            duplicate_frames: 0,
46        }
47    }
48
49    /// Load positions from a JSON string.
50    pub fn load_positions_json(&mut self, json_str: &str) -> Result<usize, serde_json::Error> {
51        let list: PositionList = serde_json::from_str(json_str)?;
52        let count = list.positions.len();
53        self.all_positions = list.positions.clone();
54        self.positions = list.positions.into();
55        self.index = 0;
56        Ok(count)
57    }
58
59    /// Load positions from an XML string (C++ NDPosPlugin format).
60    ///
61    /// Expected XML format:
62    /// ```xml
63    /// <positions>
64    ///   <position index="0">value1</position>
65    ///   <position index="1">value2</position>
66    /// </positions>
67    /// ```
68    ///
69    /// Each `<position>` element becomes a single-entry HashMap with key "position"
70    /// mapped to the parsed f64 value. If the value cannot be parsed as f64,
71    /// the position is skipped.
72    pub fn load_positions_xml(&mut self, xml_str: &str) -> Result<usize, String> {
73        let positions = parse_positions_xml(xml_str)?;
74        let count = positions.len();
75        self.all_positions = positions.clone();
76        self.positions = positions.into();
77        self.index = 0;
78        Ok(count)
79    }
80
81    /// Load positions from a string, auto-detecting format.
82    ///
83    /// If the content starts with '<' (after trimming whitespace), it is treated as XML.
84    /// Otherwise, it is treated as JSON.
85    pub fn load_positions_auto(&mut self, content: &str) -> Result<usize, String> {
86        if content.trim_start().starts_with('<') {
87            self.load_positions_xml(content)
88        } else {
89            self.load_positions_json(content)
90                .map_err(|e| format!("JSON parse error: {}", e))
91        }
92    }
93
94    /// Load positions directly.
95    pub fn load_positions(&mut self, positions: Vec<HashMap<String, f64>>) {
96        self.all_positions = positions.clone();
97        self.positions = positions.into();
98        self.index = 0;
99    }
100
101    /// Start processing.
102    pub fn start(&mut self) {
103        self.running = true;
104        self.expected_id = 0;
105        self.missing_frames = 0;
106        self.duplicate_frames = 0;
107    }
108
109    /// Stop processing.
110    pub fn stop(&mut self) {
111        self.running = false;
112    }
113
114    /// Clear all positions.
115    pub fn clear(&mut self) {
116        self.positions.clear();
117        self.all_positions.clear();
118        self.index = 0;
119    }
120
121    pub fn missing_frames(&self) -> usize {
122        self.missing_frames
123    }
124
125    pub fn duplicate_frames(&self) -> usize {
126        self.duplicate_frames
127    }
128
129    pub fn remaining_positions(&self) -> usize {
130        match self.mode {
131            PosMode::Discard => self.positions.len(),
132            PosMode::Keep => self.all_positions.len(),
133        }
134    }
135
136    fn current_position(&self) -> Option<&HashMap<String, f64>> {
137        match self.mode {
138            PosMode::Discard => self.positions.front(),
139            PosMode::Keep => {
140                if self.index < self.all_positions.len() {
141                    Some(&self.all_positions[self.index])
142                } else {
143                    None
144                }
145            }
146        }
147    }
148
149    fn advance(&mut self) {
150        match self.mode {
151            PosMode::Discard => {
152                self.positions.pop_front();
153            }
154            PosMode::Keep => {
155                self.index += 1;
156            }
157        }
158    }
159}
160
/// Check if a character can follow `<position` in a valid opening tag.
///
/// Valid: whitespace (attributes), `>` (end of tag), `/` (self-closing).
/// Anything else — notably `s`, as in `<positions` — means a different tag.
fn is_position_tag_boundary(c: char) -> bool {
    c.is_ascii_whitespace() || c == '>' || c == '/'
}

/// Byte offset of an `index=` attribute inside an opening tag, if present.
///
/// Only matches when `index=` is preceded by whitespace, so attribute names that merely
/// end in `index` (e.g. `myindex=`) are not mistaken for it.
fn find_index_attr(tag: &str) -> Option<usize> {
    tag.match_indices("index=")
        .map(|(pos, _)| pos)
        .find(|&pos| tag[..pos].ends_with(|c: char| c.is_ascii_whitespace()))
}

/// Parse the value of an `index` attribute from the text immediately after `index=`.
///
/// Accepts `"N"`, `'N'`, or a bare run of ASCII digits. An empty value is an error
/// rather than a panic.
fn parse_index_value(after_eq: &str) -> Result<usize, String> {
    let raw = match after_eq.chars().next() {
        Some(quote) if quote == '"' || quote == '\'' => {
            // Quote characters are one byte, so slicing at 1 is on a char boundary.
            let inner = &after_eq[1..];
            let end = inner.find(quote).ok_or_else(|| {
                "Malformed XML: unclosed quote in index attribute".to_string()
            })?;
            &inner[..end]
        }
        _ => {
            let end = after_eq
                .find(|c: char| !c.is_ascii_digit())
                .unwrap_or(after_eq.len());
            &after_eq[..end]
        }
    };
    raw.parse::<usize>()
        .map_err(|e| format!("Invalid index value: {}", e))
}

/// Parse positions from the C++ NDPosPlugin XML format.
///
/// Handles the simple format:
/// ```xml
/// <positions>
///   <position index="0">123.45</position>
///   <position index="1">678.90</position>
/// </positions>
/// ```
///
/// This is a minimal hand-written parser for this trivial XML format,
/// avoiding the need for an external XML crate dependency.
///
/// - Elements are ordered by their `index` attribute (stable, so equal indices keep
///   document order); elements without one take the count of values parsed so far.
/// - Values that do not parse as `f64`, and self-closing `<position .../>` elements,
///   are skipped.
///
/// # Errors
///
/// Returns a message for an unclosed `<position` tag, a missing `</position>`, or an
/// `index` attribute that is unquoted-unterminated or not a non-negative integer.
fn parse_positions_xml(xml: &str) -> Result<Vec<HashMap<String, f64>>, String> {
    const OPEN_PREFIX: &str = "<position";
    const CLOSE_TAG: &str = "</position>";
    const INDEX_ATTR: &str = "index=";

    let mut positions: Vec<(usize, f64)> = Vec::new();
    let mut search_from = 0;

    while let Some(rel_start) = xml[search_from..].find(OPEN_PREFIX) {
        let open_start = search_from + rel_start;
        let after_prefix = open_start + OPEN_PREFIX.len();

        // `<position` at the very end of the input cannot start an element.
        let next_char = match xml[after_prefix..].chars().next() {
            Some(c) => c,
            None => break,
        };
        if !is_position_tag_boundary(next_char) {
            // This is <positions> or similar, skip past it
            search_from = after_prefix;
            continue;
        }

        let tag_end = open_start
            + xml[open_start..]
                .find('>')
                .ok_or_else(|| "Malformed XML: unclosed <position tag".to_string())?;
        let tag_content = &xml[open_start..tag_end];

        // A self-closing element has no value; searching for `</position>` here would
        // consume the next element's closing tag (and its value).
        if tag_content.ends_with('/') {
            search_from = tag_end + 1;
            continue;
        }

        let index = match find_index_attr(tag_content) {
            Some(attr_start) => parse_index_value(&tag_content[attr_start + INDEX_ATTR.len()..])?,
            // No index attribute, use sequential ordering
            None => positions.len(),
        };

        let value_start = tag_end + 1;
        let close_rel = xml[value_start..]
            .find(CLOSE_TAG)
            .ok_or_else(|| "Malformed XML: missing </position> closing tag".to_string())?;
        let value_str = xml[value_start..value_start + close_rel].trim();

        // Non-numeric values are skipped silently, as documented.
        if let Ok(value) = value_str.parse::<f64>() {
            positions.push((index, value));
        }

        search_from = value_start + close_rel + CLOSE_TAG.len();
    }

    // Stable sort: entries sharing an index keep document order.
    positions.sort_by_key(|&(idx, _)| idx);

    Ok(positions
        .into_iter()
        .map(|(_, value)| {
            let mut map = HashMap::with_capacity(1);
            map.insert("position".to_string(), value);
            map
        })
        .collect())
}
263
264impl NDPluginProcess for PosPluginProcessor {
265    fn process_array(&mut self, array: &NDArray, _pool: &NDArrayPool) -> ProcessResult {
266        if !self.running {
267            return ProcessResult::arrays(vec![Arc::new(array.clone())]);
268        }
269
270        let has_positions = match self.mode {
271            PosMode::Discard => !self.positions.is_empty(),
272            PosMode::Keep => !self.all_positions.is_empty(),
273        };
274
275        if !has_positions {
276            return ProcessResult::arrays(vec![Arc::new(array.clone())]);
277        }
278
279        // Frame ID tracking
280        if self.expected_id > 0 {
281            let uid = array.unique_id;
282            if uid > self.expected_id {
283                let diff = (uid - self.expected_id) as usize;
284                self.missing_frames += diff;
285                for _ in 0..diff {
286                    self.advance();
287                    let has = match self.mode {
288                        PosMode::Discard => !self.positions.is_empty(),
289                        PosMode::Keep => !self.all_positions.is_empty(),
290                    };
291                    if !has {
292                        return ProcessResult::arrays(vec![Arc::new(array.clone())]);
293                    }
294                }
295            } else if uid < self.expected_id {
296                self.duplicate_frames += 1;
297                return ProcessResult::empty();
298            }
299        }
300
301        let position = match self.current_position() {
302            Some(pos) => pos.clone(),
303            None => return ProcessResult::arrays(vec![Arc::new(array.clone())]),
304        };
305
306        let mut out = array.clone();
307        for (key, value) in &position {
308            out.attributes.add(NDAttribute {
309                name: key.clone(),
310                description: String::new(),
311                source: NDAttrSource::Driver,
312                value: NDAttrValue::Float64(*value),
313            });
314        }
315
316        self.advance();
317        self.expected_id = array.unique_id + 1;
318
319        let updates = vec![
320            ParamUpdate::int32(0, self.missing_frames as i32),
321            ParamUpdate::int32(1, self.duplicate_frames as i32),
322        ];
323
324        ProcessResult {
325            output_arrays: vec![Arc::new(out)],
326            param_updates: updates,
327            scatter_index: None,
328        }
329    }
330
331    fn plugin_type(&self) -> &str {
332        "NDPosPlugin"
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use ad_core_rs::ndarray::{NDDataType, NDDimension};

    /// Build a 4-element `UInt8` array tagged with `id` as its unique ID.
    fn make_array(id: i32) -> NDArray {
        let mut array = NDArray::new(vec![NDDimension::new(4)], NDDataType::UInt8);
        array.unique_id = id;
        array
    }

    /// Build one position entry from `(attribute name, value)` pairs.
    fn position(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value))
            .collect()
    }

    /// Read attribute `name` from the first output array as `f64`, panicking if absent.
    fn first_attr(result: &ProcessResult, name: &str) -> f64 {
        result.output_arrays[0]
            .attributes
            .get(name)
            .unwrap()
            .value
            .as_f64()
            .unwrap()
    }

    /// Assert two floats agree to within 1e-10.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-10);
    }

    #[test]
    fn test_discard_mode() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        processor.load_positions(vec![
            position(&[("X", 1.5), ("Y", 2.3)]),
            position(&[("X", 3.1), ("Y", 4.2)]),
        ]);
        processor.start();
        let pool = NDArrayPool::new(1_000_000);

        let first = processor.process_array(&make_array(1), &pool);
        assert_eq!(first.output_arrays.len(), 1);
        assert_close(first_attr(&first, "X"), 1.5);

        let second = processor.process_array(&make_array(2), &pool);
        assert_close(first_attr(&second, "X"), 3.1);

        assert_eq!(processor.remaining_positions(), 0);
    }

    #[test]
    fn test_keep_mode() {
        let mut processor = PosPluginProcessor::new(PosMode::Keep);
        processor.load_positions(vec![position(&[("X", 10.0)]), position(&[("X", 20.0)])]);
        processor.start();
        let pool = NDArrayPool::new(1_000_000);

        let first = processor.process_array(&make_array(1), &pool);
        assert_close(first_attr(&first, "X"), 10.0);

        let second = processor.process_array(&make_array(2), &pool);
        assert_close(first_attr(&second, "X"), 20.0);

        // Stops at end of list (no wrapping)
        let third = processor.process_array(&make_array(3), &pool);
        assert_eq!(third.output_arrays.len(), 1);
        assert!(third.output_arrays[0].attributes.get("X").is_none());
    }

    #[test]
    fn test_missing_frames() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        processor.load_positions(vec![
            position(&[("X", 1.0)]),
            position(&[("X", 2.0)]),
            position(&[("X", 3.0)]),
        ]);
        processor.start();
        let pool = NDArrayPool::new(1_000_000);

        processor.process_array(&make_array(1), &pool);

        // Frame 3 (skip frame 2)
        let result = processor.process_array(&make_array(3), &pool);
        assert_eq!(processor.missing_frames(), 1);
        assert_close(first_attr(&result, "X"), 3.0);
    }

    #[test]
    fn test_duplicate_frames() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        processor.load_positions(vec![position(&[("X", 1.0)]), position(&[("X", 2.0)])]);
        processor.start();
        let pool = NDArrayPool::new(1_000_000);

        processor.process_array(&make_array(1), &pool);

        let repeat = processor.process_array(&make_array(1), &pool);
        assert_eq!(processor.duplicate_frames(), 1);
        assert!(repeat.output_arrays.is_empty());
    }

    #[test]
    fn test_load_json() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let json = r#"{"positions": [{"X": 1.5, "Y": 2.3}, {"X": 3.1, "Y": 4.2}]}"#;
        assert_eq!(processor.load_positions_json(json).unwrap(), 2);
        assert_eq!(processor.remaining_positions(), 2);
    }

    #[test]
    fn test_not_running_passthrough() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let pool = NDArrayPool::new(1_000_000);
        let result = processor.process_array(&make_array(1), &pool);
        assert_eq!(result.output_arrays.len(), 1);
        assert!(result.output_arrays[0].attributes.get("X").is_none());
    }

    #[test]
    fn test_load_xml() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let xml = r#"<positions>
  <position index="0">1.5</position>
  <position index="1">2.3</position>
  <position index="2">3.7</position>
</positions>"#;
        assert_eq!(processor.load_positions_xml(xml).unwrap(), 3);
        assert_eq!(processor.remaining_positions(), 3);
    }

    #[test]
    fn test_load_xml_out_of_order() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let xml = r#"<positions>
  <position index="2">30.0</position>
  <position index="0">10.0</position>
  <position index="1">20.0</position>
</positions>"#;
        assert_eq!(processor.load_positions_xml(xml).unwrap(), 3);

        processor.start();
        let pool = NDArrayPool::new(1_000_000);

        // Should be sorted by index: 10.0, 20.0, 30.0
        let first = processor.process_array(&make_array(1), &pool);
        assert_close(first_attr(&first, "position"), 10.0);

        let second = processor.process_array(&make_array(2), &pool);
        assert_close(first_attr(&second, "position"), 20.0);
    }

    #[test]
    fn test_load_xml_no_index() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let xml = r#"<positions>
  <position>5.5</position>
  <position>6.6</position>
</positions>"#;
        assert_eq!(processor.load_positions_xml(xml).unwrap(), 2);
    }

    #[test]
    fn test_load_auto_json() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let json = r#"{"positions": [{"X": 1.5}]}"#;
        assert_eq!(processor.load_positions_auto(json).unwrap(), 1);
    }

    #[test]
    fn test_load_auto_xml() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let xml = r#"<positions><position index="0">99.9</position></positions>"#;
        assert_eq!(processor.load_positions_auto(xml).unwrap(), 1);
    }

    #[test]
    fn test_load_xml_empty() {
        let mut processor = PosPluginProcessor::new(PosMode::Discard);
        let xml = r#"<positions></positions>"#;
        assert_eq!(processor.load_positions_xml(xml).unwrap(), 0);
    }
}