ryo-pattern 0.1.0

RyoPattern - AST pattern matching and lint rules for Ryo
Documentation
//! CodePattern - AST structural matching
//!
//! Describes structural patterns over AST nodes for pattern matching.

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// AST node type to match.
///
/// Variants are grouped into expressions, items, parts of items/calls, and
/// special nodes. They serialize under their own names (e.g. `"MethodCall"`):
/// `rename_all = "PascalCase"` leaves PascalCase identifiers unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "PascalCase")]
pub enum NodeKind {
    // Expressions
    /// Any expression.
    Expr,
    /// Literal (number / string / bool / etc.).
    Literal,
    /// Path expression (`foo::bar`).
    Path,
    /// Method call (`x.foo()`).
    MethodCall,
    /// Free function call (`foo()`).
    FunctionCall,
    /// Macro invocation (`println!()`).
    MacroCall,
    /// Binary operator expression.
    BinaryOp,
    /// Unary operator expression.
    UnaryOp,
    /// `if` / `if let` expression.
    If,
    /// `match` expression.
    Match,
    /// `loop` / `while` / `for` expression.
    Loop,
    /// Block expression `{ ... }`.
    Block,
    /// Closure expression (`|x| ...`).
    Closure,
    /// `.await` expression.
    Await,
    /// `?` (try) expression.
    Try,
    /// `return` expression.
    Return,
    /// Index expression (`a[i]`).
    Index,

    // Items
    /// `fn` item.
    Function,
    /// `struct` item.
    Struct,
    /// `enum` item.
    Enum,
    /// `trait` item.
    Trait,
    /// `impl` block.
    Impl,
    /// `mod` item.
    Mod,
    /// `use` declaration.
    Use,
    /// `const` item.
    Const,
    /// `static` item.
    Static,
    /// `type` alias.
    TypeAlias,

    // Parts (sub-nodes of items and calls)
    /// Struct / enum field.
    Field,
    /// Enum variant.
    Variant,
    /// Function parameter.
    Param,
    /// Call site argument.
    Arg,
    /// Generic type / lifetime argument.
    GenericArg,
    /// Lifetime token.
    Lifetime,
    /// Attribute (`#[...]`).
    Attribute,

    // Special
    /// `let` expression (`if let` / `while let`).
    LetExpr,
    /// Wildcard (matches any node).
    Wildcard,
    /// Unit `()`.
    Unit,
}

/// Name matcher for symbol/method names.
///
/// Serialized untagged: a plain JSON string is an exact match, a JSON object
/// is a [`NamePattern`].
// NOTE(review): `NamePattern` has only optional fields and does not deny
// unknown keys, so *any* JSON object with well-typed (or absent) known fields
// deserializes as `Pattern` here — e.g. a typo such as `{"start_with": "get_"}`
// yields a pattern with no criteria set. Confirm the matcher handles that.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum NameMatcher {
    /// Exact match
    Exact(String),
    /// Pattern-based match
    Pattern(NamePattern),
}

/// Pattern-based name matching.
///
/// Every criterion is optional and is omitted from the serialized form when
/// unset. `NamePattern::default()` has no criteria set, which makes
/// struct-update construction convenient.
// NOTE(review): how several populated criteria combine (all vs. any), and how
// a pattern with no criteria is treated, is decided by the matcher outside
// this file — confirm and document it here.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
pub struct NamePattern {
    /// Glob pattern (e.g., "get_*")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub glob: Option<String>,
    /// Regex pattern (e.g., "^is_")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    /// Starts with prefix
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub starts_with: Option<String>,
    /// Ends with suffix
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub ends_with: Option<String>,
    /// Contains substring
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub contains: Option<String>,
}

/// Pattern expression (recursive)
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum PatternExpr {
    /// Nested code pattern
    Pattern(Box<CodePattern>),
    /// Name matcher
    Name(NameMatcher),
    /// Literal value
    Literal(serde_json::Value),
    /// Wildcard (_) - matches anything
    Wildcard,
    /// Capture variable reference (e.g., "$VAR")
    Capture(String),
}

/// Match arm pattern for `Match` expressions.
///
/// Both criteria are optional and are omitted from the serialized form when
/// unset. `ArmPattern::default()` has neither criterion set.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ArmPattern {
    /// Match on the pattern path (e.g., "Some", "None", "Ok", "Err")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pattern_path: Option<String>,

    /// Match on the arm body expression
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub body: Option<Box<CodePattern>>,
}

/// AST Pattern for structural matching.
///
/// `node` is required when deserializing; every other named field is optional.
/// Any key of the serialized object that is not one of the named fields is
/// collected into `children` (via `#[serde(flatten)]`), keyed by child field
/// name.
// NOTE(review): because of the flatten, a misspelled key (e.g. `"arm_cout": 3`)
// is silently accepted as a child pattern instead of being rejected — confirm
// this is intended.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct CodePattern {
    /// AST node type to match
    pub node: NodeKind,

    /// Required arm count (for Match nodes)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arm_count: Option<usize>,

    /// Arm patterns (for Match nodes, order-independent matching)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub arms: Option<Vec<ArmPattern>>,

    /// Child patterns (field name -> pattern)
    // Flattened: entries appear as top-level keys of the serialized object,
    // not under a `children` key.
    #[serde(flatten)]
    pub children: HashMap<String, PatternExpr>,

    /// Capture variable (e.g., "$RECEIVER")
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub capture: Option<String>,

    /// Match zero or more items (ellipsis)
    // Unlike the optional fields above, this is always serialized (including
    // `false`); it defaults to `false` when absent from the input.
    #[serde(default)]
    pub ellipsis: bool,
}

impl CodePattern {
    /// Create a new CodePattern for the given node kind
    pub fn new(node: NodeKind) -> Self {
        Self {
            node,
            arm_count: None,
            arms: None,
            children: HashMap::new(),
            capture: None,
            ellipsis: false,
        }
    }

    /// Add a child pattern
    pub fn with_child(mut self, name: impl Into<String>, pattern: PatternExpr) -> Self {
        self.children.insert(name.into(), pattern);
        self
    }

    /// Set capture variable
    pub fn with_capture(mut self, var: impl Into<String>) -> Self {
        self.capture = Some(var.into());
        self
    }

    /// Set ellipsis mode
    pub fn with_ellipsis(mut self) -> Self {
        self.ellipsis = true;
        self
    }
}

/// Body match conditions for symbol bodies.
///
/// Each list is optional. `None` (the `Default`) leaves that condition unset
/// and omits it from the serialized form.
// NOTE(review): the docs for `contains` and `all_of` describe the same rule
// ("each listed pattern has at least one match"). Any real difference lives in
// the matcher outside this file — confirm and document it here.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct BodyMatch {
    /// At least one node matches each pattern (existential)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub contains: Option<Vec<CodePattern>>,

    /// No node matches these patterns (negation)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub not_contains: Option<Vec<CodePattern>>,

    /// All listed patterns must have at least one match
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub all_of: Option<Vec<CodePattern>>,
}

impl BodyMatch {
    /// Construct an empty `BodyMatch`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a contains pattern
    pub fn contains(mut self, pattern: CodePattern) -> Self {
        self.contains.get_or_insert_with(Vec::new).push(pattern);
        self
    }

    /// Add a not_contains pattern
    pub fn not_contains(mut self, pattern: CodePattern) -> Self {
        self.not_contains.get_or_insert_with(Vec::new).push(pattern);
        self
    }

    /// Add an all_of pattern
    pub fn all_of(mut self, pattern: CodePattern) -> Self {
        self.all_of.get_or_insert_with(Vec::new).push(pattern);
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_code_pattern_builder() {
        let pattern = CodePattern::new(NodeKind::MethodCall)
            .with_child(
                "method",
                PatternExpr::Name(NameMatcher::Exact("unwrap".into())),
            )
            .with_capture("$UNWRAP");

        assert_eq!(pattern.node, NodeKind::MethodCall);
        assert_eq!(pattern.capture, Some("$UNWRAP".to_string()));
        assert_eq!(
            pattern.children.get("method"),
            Some(&PatternExpr::Name(NameMatcher::Exact("unwrap".into())))
        );
        assert!(!pattern.ellipsis);
    }

    #[test]
    fn test_body_match_builder() {
        let body = BodyMatch::new()
            .contains(CodePattern::new(NodeKind::MethodCall))
            .not_contains(CodePattern::new(NodeKind::MacroCall));

        assert!(body.contains.is_some());
        assert!(body.not_contains.is_some());
        assert!(body.all_of.is_none());
    }

    #[test]
    fn test_body_match_builder_accumulates_in_order() {
        let body = BodyMatch::new()
            .all_of(CodePattern::new(NodeKind::Await))
            .all_of(CodePattern::new(NodeKind::Try));

        assert_eq!(
            body.all_of,
            Some(vec![
                CodePattern::new(NodeKind::Await),
                CodePattern::new(NodeKind::Try),
            ])
        );
    }

    #[test]
    fn test_serialize_deserialize() {
        let pattern = CodePattern::new(NodeKind::MethodCall)
            .with_child(
                "method",
                PatternExpr::Name(NameMatcher::Exact("unwrap".into())),
            )
            .with_capture("$RECV")
            .with_ellipsis();

        let json = serde_json::to_string(&pattern).unwrap();
        let deserialized: CodePattern = serde_json::from_str(&json).unwrap();

        // Compare the whole value, not just `node`, so that a lossy round-trip
        // of `children`, `capture` or `ellipsis` fails the test.
        assert_eq!(pattern, deserialized);
    }

    #[test]
    fn test_deserialize_nested_and_name_pattern_children() {
        let json = r#"{
            "node": "MethodCall",
            "receiver": { "node": "Path" },
            "method": { "starts_with": "get_" }
        }"#;
        let parsed: CodePattern = serde_json::from_str(json).unwrap();

        assert_eq!(parsed.node, NodeKind::MethodCall);
        // Only the unknown keys end up in `children`; `node` does not.
        assert_eq!(parsed.children.len(), 2);
        assert_eq!(
            parsed.children.get("receiver"),
            Some(&PatternExpr::Pattern(Box::new(CodePattern::new(
                NodeKind::Path
            ))))
        );
        assert_eq!(
            parsed.children.get("method"),
            Some(&PatternExpr::Name(NameMatcher::Pattern(NamePattern {
                glob: None,
                regex: None,
                starts_with: Some("get_".to_string()),
                ends_with: None,
                contains: None,
            })))
        );
        assert!(!parsed.ellipsis);
        assert!(parsed.capture.is_none());
    }
}