//! dcbor-pattern 0.11.1
//!
//! Pattern matcher for dCBOR. This module provides [`MapPattern`], which
//! matches CBOR map structures.
use std::ops::RangeBounds;

use dcbor::prelude::*;

use crate::{
    Interval,
    pattern::{Matcher, Path, Pattern, vm::Instr},
};

/// A pattern that matches CBOR map values.
///
/// Each variant tests a different property of a map: merely being a map,
/// containing entries that satisfy key/value patterns, or having an entry
/// count within some interval.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MapPattern {
    /// Matches every map, regardless of its contents.
    Any,
    /// Matches a map only if, for every `(key, value)` pattern pair, at least
    /// one entry of the map satisfies both the key and the value pattern.
    Constraints(Vec<(Pattern, Pattern)>),
    /// Matches a map whose number of key-value pairs lies within the interval.
    Length(Interval),
}

impl MapPattern {
    /// Builds a pattern that accepts every map ([`MapPattern::Any`]).
    pub fn any() -> Self {
        Self::Any
    }

    /// Builds a pattern that accepts a map only when every `(key, value)`
    /// pattern pair is satisfied by at least one of its entries
    /// ([`MapPattern::Constraints`]).
    pub fn with_key_value_constraints(
        constraints: Vec<(Pattern, Pattern)>,
    ) -> Self {
        Self::Constraints(constraints)
    }

    /// Builds a pattern that accepts maps holding exactly `length`
    /// key-value pairs.
    pub fn with_length(length: usize) -> Self {
        Self::with_length_range(length..=length)
    }

    /// Builds a pattern that accepts maps whose number of key-value pairs
    /// falls inside `range`.
    pub fn with_length_range<R: RangeBounds<usize>>(range: R) -> Self {
        Self::with_length_interval(Interval::new(range))
    }

    /// Builds a pattern that accepts maps whose number of key-value pairs
    /// falls inside the already-constructed `interval`.
    pub fn with_length_interval(interval: Interval) -> Self {
        Self::Length(interval)
    }
}

impl Matcher for MapPattern {
    /// Returns `[[haystack]]` when `haystack` is a map satisfying this
    /// pattern, and an empty vector otherwise (including for non-map values).
    fn paths(&self, haystack: &CBOR) -> Vec<Path> {
        let CBORCase::Map(map) = haystack.as_case() else {
            // Not a map, so no map pattern can match.
            return vec![];
        };

        let is_match = match self {
            MapPattern::Any => true,
            // Each constraint needs at least one entry whose key and value
            // both match; different constraints may be satisfied by the same
            // entry.
            MapPattern::Constraints(constraints) => {
                constraints.iter().all(|(key_pattern, value_pattern)| {
                    map.iter().any(|(key, value)| {
                        key_pattern.matches(key) && value_pattern.matches(value)
                    })
                })
            }
            MapPattern::Length(interval) => interval.contains(map.len()),
        };

        if is_match {
            vec![vec![haystack.clone()]]
        } else {
            vec![]
        }
    }

    /// Compiles this pattern into a single `MatchStructure` instruction that
    /// refers to a copy of this pattern stored in `literals`, after first
    /// registering any capture names found in the nested patterns.
    fn compile(
        &self,
        code: &mut Vec<Instr>,
        literals: &mut Vec<Pattern>,
        captures: &mut Vec<String>,
    ) {
        // Collect capture names from inner patterns
        self.collect_capture_names(captures);

        let idx = literals.len();
        literals.push(Pattern::Structure(
            crate::pattern::StructurePattern::Map(self.clone()),
        ));
        code.push(Instr::MatchStructure(idx));
    }

    /// Appends the capture names of all nested key and value patterns to
    /// `names`. `Any` and `Length` contain no nested patterns and add nothing.
    fn collect_capture_names(&self, names: &mut Vec<String>) {
        match self {
            MapPattern::Any | MapPattern::Length(_) => {}
            MapPattern::Constraints(constraints) => {
                for (key_pattern, value_pattern) in constraints {
                    key_pattern.collect_capture_names(names);
                    value_pattern.collect_capture_names(names);
                }
            }
        }
    }

    /// Like [`Matcher::paths`], but also returns the captures produced by
    /// nested key/value patterns.
    ///
    /// For each constraint, the first map entry whose key and value both
    /// match supplies that constraint's captures. If any constraint cannot be
    /// satisfied, the result is empty paths and empty captures. No partial
    /// captures are returned for a map that does not match.
    fn paths_with_captures(
        &self,
        haystack: &CBOR,
    ) -> (Vec<Path>, std::collections::HashMap<String, Vec<Path>>) {
        use std::collections::HashMap;

        /// Re-roots every capture in `found` as the two-element path
        /// `[map_value, child]` and appends it to `all` under its name.
        fn merge_captures(
            all: &mut HashMap<String, Vec<Path>>,
            found: HashMap<String, Vec<Path>>,
            map_value: &CBOR,
            child: &CBOR,
        ) {
            for (name, capture_paths) in found {
                // NOTE(review): the inner capture path is intentionally
                // replaced by `[map, child]` (one entry per inner path), as
                // in the original implementation; deeper path segments below
                // `child` are not preserved. Confirm this is the desired
                // shape.
                let rerooted = capture_paths
                    .iter()
                    .map(|_inner_path| vec![map_value.clone(), child.clone()]);
                all.entry(name).or_default().extend(rerooted);
            }
        }

        let CBORCase::Map(map) = haystack.as_case() else {
            return (vec![], HashMap::new());
        };

        let constraints = match self {
            MapPattern::Constraints(constraints) => constraints,
            // These variants contain no nested patterns, so they cannot
            // produce captures.
            MapPattern::Any | MapPattern::Length(_) => {
                return (self.paths(haystack), HashMap::new());
            }
        };

        let mut all_captures: HashMap<String, Vec<Path>> = HashMap::new();

        for (key_pattern, value_pattern) in constraints {
            // Find the first entry satisfying this constraint. The value
            // pattern is only evaluated once the key pattern has matched.
            let matched = map.iter().find_map(|(key, value)| {
                let (key_paths, key_captures) =
                    key_pattern.paths_with_captures(key);
                if key_paths.is_empty() {
                    return None;
                }
                let (value_paths, value_captures) =
                    value_pattern.paths_with_captures(value);
                if value_paths.is_empty() {
                    return None;
                }
                Some((key, key_captures, value, value_captures))
            });

            let Some((key, key_captures, value, value_captures)) = matched
            else {
                // The map does not match, so discard any captures gathered
                // from earlier constraints rather than leaking them.
                return (vec![], HashMap::new());
            };

            merge_captures(&mut all_captures, key_captures, haystack, key);
            merge_captures(&mut all_captures, value_captures, haystack, value);
        }

        (vec![vec![haystack.clone()]], all_captures)
    }
}

impl std::fmt::Display for MapPattern {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MapPattern::Any => write!(f, "map"),
            MapPattern::Constraints(constraints) => {
                write!(f, "{{")?;
                for (i, (key_pattern, value_pattern)) in
                    constraints.iter().enumerate()
                {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}: {}", key_pattern, value_pattern)?;
                }
                write!(f, "}}")
            }
            MapPattern::Length(interval) => {
                write!(f, "{{{}}}", interval)
            }
        }
    }
}