use crate::{
ArmPattern, CapturedNode, CodePattern, MatchResult, NameMatcher, NodeKind, PatternExpr, Span,
};
use ryo_source::pure::{PureBlock, PureExpr, PureFn, PureMatchArm, PurePattern, PureStmt};
use std::collections::HashMap;
/// Accumulates the named captures produced while matching a pattern against an expression.
#[derive(Debug, Default)]
pub struct MatchContext {
    /// Captured nodes keyed by capture name (for example `"$x"`).
    pub captures: HashMap<String, CapturedNode>,
}

impl MatchContext {
    /// Creates a context with no captures.
    pub fn new() -> Self {
        Self {
            captures: HashMap::new(),
        }
    }

    /// Records `text` under `name`, replacing any earlier capture with the same name.
    ///
    /// The stored span is always the placeholder `Span::point(0, 0)`; no source
    /// location is tracked here.
    pub fn capture(&mut self, name: impl Into<String>, text: impl Into<String>) {
        let key: String = name.into();
        let node = CapturedNode::new(Span::point(0, 0), text.into());
        self.captures.insert(key, node);
    }

    /// Absorbs every capture from `other`; on a name collision the value from `other` wins.
    pub fn merge(&mut self, other: MatchContext) {
        let MatchContext { captures } = other;
        self.captures.extend(captures);
    }

    /// Converts this context into a successful `MatchResult` carrying its captures.
    pub fn into_match_result(self) -> MatchResult {
        let mut matched = MatchResult::matched();
        matched.captures = self.captures;
        matched
    }
}
/// Tests a single [`PureExpr`] node against a [`CodePattern`].
///
/// The pattern's `node` kind selects which expression variant is accepted; optional
/// named children (`"method"`, `"receiver"`, `"expr"`, ...) further constrain that
/// variant's parts. Children that are absent impose no constraint.
pub struct ExprMatcher<'p> {
    pattern: &'p CodePattern,
}

impl<'p> ExprMatcher<'p> {
    /// Creates a matcher for `pattern`.
    pub fn new(pattern: &'p CodePattern) -> Self {
        Self { pattern }
    }

    /// Matches `expr` against the pattern.
    ///
    /// Returns `None` when it does not match. On success, returns every capture
    /// collected along the way. If the pattern has a top-level `capture` name, the
    /// rendered text of the whole expression (via `expr_to_string`) is also stored
    /// under that name.
    pub fn matches(&self, expr: &PureExpr) -> Option<MatchContext> {
        let mut ctx = MatchContext::new();
        if !self.match_expr(expr, &mut ctx) {
            return None;
        }
        if let Some(ref capture_name) = self.pattern.capture {
            ctx.capture(capture_name.clone(), expr_to_string(expr));
        }
        Some(ctx)
    }

    /// Checks the root node of `expr` against the pattern's node kind and children.
    fn match_expr(&self, expr: &PureExpr, ctx: &mut MatchContext) -> bool {
        match (&self.pattern.node, expr) {
            // NOTE(review): an "args" child is accepted but not enforced yet.
            (NodeKind::MethodCall, PureExpr::MethodCall { receiver, method, .. }) => {
                self.child_name_matches("method", method)
                    && self.child_expr_matches("receiver", receiver, ctx)
            }
            // NOTE(review): call arguments are not matched yet.
            (NodeKind::FunctionCall, PureExpr::Call { func, .. }) => {
                self.child_expr_matches("func", func, ctx)
            }
            (NodeKind::MacroCall, PureExpr::Macro { name, .. }) => {
                self.child_name_matches("macro", name)
            }
            (NodeKind::Try, PureExpr::Try(inner)) => self.child_expr_matches("expr", inner, ctx),
            (NodeKind::Await, PureExpr::Await(inner)) => {
                self.child_expr_matches("expr", inner, ctx)
            }
            (NodeKind::BinaryOp, PureExpr::Binary { op, left, right }) => {
                self.child_name_matches("op", op)
                    && self.child_expr_matches("left", left, ctx)
                    && self.child_expr_matches("right", right, ctx)
            }
            (NodeKind::Path, PureExpr::Path(path)) => self.child_name_matches("path", path),
            (NodeKind::Literal, PureExpr::Lit(lit)) => match self.child("value") {
                // Any name matcher (exact or pattern) applies to the literal's source text.
                Some(PatternExpr::Name(matcher)) => match_name(matcher, lit),
                // Reuse the shared literal comparison so both entry points agree.
                Some(value @ PatternExpr::Literal(_)) => self.match_pattern_expr(value, expr, ctx),
                _ => true,
            },
            (NodeKind::Expr, _) => true,
            (NodeKind::Block, PureExpr::Block { .. }) => true,
            // NOTE(review): the then/else branches are not constrained yet.
            (NodeKind::If, PureExpr::If { cond, .. }) => self.child_expr_matches("cond", cond, ctx),
            (NodeKind::Match, PureExpr::Match { expr: scrutinee, arms }) => {
                if !self.child_expr_matches("expr", scrutinee, ctx) {
                    return false;
                }
                if let Some(expected) = self.pattern.arm_count {
                    if arms.len() != expected {
                        return false;
                    }
                }
                // Every arm pattern must be satisfied by at least one arm, in any order.
                if let Some(arm_patterns) = &self.pattern.arms {
                    for ap in arm_patterns {
                        if !arms.iter().any(|arm| match_arm_pattern(ap, arm, ctx)) {
                            return false;
                        }
                    }
                }
                true
            }
            (NodeKind::Return, PureExpr::Return(value)) => match (self.child("expr"), value) {
                (None, _) => true,
                (Some(pattern), Some(returned)) => {
                    self.match_pattern_expr(pattern, returned, ctx)
                }
                // An "expr" child requires a value; bare `return` cannot satisfy it.
                (Some(_), None) => false,
            },
            (
                NodeKind::Loop,
                PureExpr::Loop { .. } | PureExpr::While { .. } | PureExpr::For { .. },
            ) => true,
            (NodeKind::Closure, PureExpr::Closure { .. }) => true,
            (NodeKind::Index, PureExpr::Index { expr: base, index }) => {
                self.child_expr_matches("expr", base, ctx)
                    && self.child_expr_matches("index", index, ctx)
            }
            _ => false,
        }
    }

    /// Looks up the child sub-pattern registered under `key`.
    fn child(&self, key: &str) -> Option<&PatternExpr> {
        self.pattern.children.get(key)
    }

    /// True when there is no `key` child, or when that child matches `expr`.
    /// Captures made by the child are written into `ctx`.
    fn child_expr_matches(&self, key: &str, expr: &PureExpr, ctx: &mut MatchContext) -> bool {
        match self.child(key) {
            Some(pattern) => self.match_pattern_expr(pattern, expr, ctx),
            None => true,
        }
    }

    /// True unless a `PatternExpr::Name` child under `key` rejects `name`.
    /// Children of any other kind are ignored here.
    fn child_name_matches(&self, key: &str, name: &str) -> bool {
        match self.child(key) {
            Some(PatternExpr::Name(matcher)) => match_name(matcher, name),
            _ => true,
        }
    }

    /// Matches `expr` against one child pattern.
    ///
    /// The variants behave as follows:
    /// - `Pattern` recurses and merges the nested captures.
    /// - `Capture` records the expression text.
    /// - `Wildcard` accepts anything.
    /// - `Name` accepts only paths whose text satisfies the matcher.
    /// - `Literal` accepts only literals whose source text equals the expected value.
    fn match_pattern_expr(
        &self,
        pattern: &PatternExpr,
        expr: &PureExpr,
        ctx: &mut MatchContext,
    ) -> bool {
        match pattern {
            PatternExpr::Pattern(nested) => match ExprMatcher::new(nested).matches(expr) {
                Some(nested_ctx) => {
                    ctx.merge(nested_ctx);
                    true
                }
                None => false,
            },
            PatternExpr::Capture(var_name) => {
                ctx.capture(var_name.clone(), expr_to_string(expr));
                true
            }
            PatternExpr::Wildcard => true,
            PatternExpr::Name(name_matcher) => match expr {
                PureExpr::Path(path) => match_name(name_matcher, path),
                _ => false,
            },
            PatternExpr::Literal(expected) => match expr {
                PureExpr::Lit(lit) => match expected.as_str() {
                    Some(text) => lit == text,
                    // Non-string JSON values (numbers, booleans) are compared by their
                    // JSON text, so a pattern value of `42` matches the literal `42`.
                    None => *lit == expected.to_string(),
                },
                _ => false,
            },
        }
    }
}
fn match_name(matcher: &NameMatcher, name: &str) -> bool {
match matcher {
NameMatcher::Exact(expected) => name == expected,
NameMatcher::Pattern(pattern) => {
if let Some(ref prefix) = pattern.starts_with {
if !name.starts_with(prefix) {
return false;
}
}
if let Some(ref suffix) = pattern.ends_with {
if !name.ends_with(suffix) {
return false;
}
}
if let Some(ref substr) = pattern.contains {
if !name.contains(substr) {
return false;
}
}
if let Some(ref glob) = pattern.glob {
if !match_glob(glob, name) {
return false;
}
}
true
}
}
}
/// Matches `name` against a glob `pattern` in which `*` matches any (possibly empty)
/// run of characters and every other character matches itself.
///
/// A `*` may appear anywhere and any number of times: `get_*`, `*_id`, `*name*` and
/// `get_*_id` are all supported. A pattern without `*` must equal `name` exactly.
/// There is no `?` or character-class syntax.
fn match_glob(pattern: &str, name: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == name;
    }
    // With at least one '*', split yields >= 2 segments: a required prefix, a required
    // suffix, and zero or more middle pieces that must appear in order between them.
    let segments: Vec<&str> = pattern.split('*').collect();
    let first = segments[0];
    let last = segments[segments.len() - 1];
    let middle = &segments[1..segments.len() - 1];

    let mut rest = match name.strip_prefix(first) {
        Some(remaining) => remaining,
        None => return false,
    };
    // Taking the leftmost occurrence of each middle piece leaves the longest possible
    // remainder, which is optimal when '*' is the only wildcard.
    for &piece in middle {
        match rest.find(piece) {
            Some(pos) => rest = &rest[pos + piece.len()..],
            None => return false,
        }
    }
    // Checking the suffix against the unconsumed remainder prevents overlap with
    // text already matched by the prefix or middle pieces.
    rest.ends_with(last)
}
/// Returns `true` if `arm` satisfies the arm pattern `ap`.
///
/// When `ap.pattern_path` is set, the arm's pattern must refer to that path (see
/// `pattern_contains_path`). When `ap.body` is set, the arm body must match it. On
/// success, captures made by the body pattern are merged into `ctx`. On failure
/// `ctx` is left untouched.
fn match_arm_pattern(ap: &ArmPattern, arm: &PureMatchArm, ctx: &mut MatchContext) -> bool {
    if let Some(ref expected_path) = ap.pattern_path {
        if !pattern_contains_path(&arm.pattern, expected_path) {
            return false;
        }
    }
    if let Some(ref body_pattern) = ap.body {
        match ExprMatcher::new(body_pattern).matches(&arm.body) {
            // Previously these captures were dropped, so `Capture`s inside arm-body
            // patterns never reached the caller's result.
            Some(body_ctx) => ctx.merge(body_ctx),
            None => return false,
        }
    }
    true
}
/// Returns `true` if `pat` refers to `expected`. For or-, tuple- and reference-patterns,
/// it is enough for any pattern nested inside to refer to it.
///
/// Plain path and struct patterns match when their path equals `expected` or ends with
/// `::expected`, so `Filter::Recurse` matches an expected `Recurse`. Identifier
/// bindings match by exact name. Every other pattern kind is rejected.
fn pattern_contains_path(pat: &PurePattern, expected: &str) -> bool {
    // True when `path` is `expected` itself, or a qualified path ending in `::expected`.
    fn path_ends_with(path: &str, expected: &str) -> bool {
        match path.strip_suffix(expected) {
            Some(head) => head.is_empty() || head.ends_with("::"),
            None => false,
        }
    }

    match pat {
        PurePattern::Path(p) => path_ends_with(p, expected),
        PurePattern::Struct { path, .. } => path_ends_with(path, expected),
        PurePattern::Ident { name, .. } => name == expected,
        PurePattern::Ref { pattern, .. } => pattern_contains_path(pattern, expected),
        PurePattern::Or(alternatives) => alternatives
            .iter()
            .any(|alternative| pattern_contains_path(alternative, expected)),
        PurePattern::Tuple(elements) => elements
            .iter()
            .any(|element| pattern_contains_path(element, expected)),
        _ => false,
    }
}
/// Renders `expr` as approximate Rust source text, used as capture text.
///
/// The output is not meant to round-trip:
/// - Blocks, `if`, `match` and closures are abbreviated (`{ ... }`, `if ...`, ...).
/// - Binary operators are printed without grouping parentheses.
/// - Macros print as `name!(...)`.
/// - Any variant without an arm below becomes `<expr>`.
pub fn expr_to_string(expr: &PureExpr) -> String {
    // Renders each expression and joins the results with ", ".
    fn join<'a>(items: impl Iterator<Item = &'a PureExpr>) -> String {
        items.map(expr_to_string).collect::<Vec<_>>().join(", ")
    }

    match expr {
        PureExpr::Lit(s) => s.clone(),
        PureExpr::Path(s) => s.clone(),
        PureExpr::MethodCall {
            receiver,
            method,
            args,
            turbofish,
        } => {
            let receiver_str = expr_to_string(receiver);
            let turbofish_str = turbofish
                .as_ref()
                .map(|t| format!("::{}", t))
                .unwrap_or_default();
            format!(
                "{}.{}{}({})",
                receiver_str,
                method,
                turbofish_str,
                join(args.iter())
            )
        }
        PureExpr::Call { func, args } => {
            format!("{}({})", expr_to_string(func), join(args.iter()))
        }
        PureExpr::Binary { op, left, right } => {
            format!("{} {} {}", expr_to_string(left), op, expr_to_string(right))
        }
        PureExpr::Unary { op, expr } => format!("{}{}", op, expr_to_string(expr)),
        PureExpr::Try(expr) => format!("{}?", expr_to_string(expr)),
        PureExpr::Await(expr) => format!("{}.await", expr_to_string(expr)),
        PureExpr::Field { expr, field } => format!("{}.{}", expr_to_string(expr), field),
        // Previously fell through to `<expr>`, although `ExprMatcher` and `BodyScanner`
        // both handle indexing, so captures of `a[i]` lost their text.
        PureExpr::Index { expr, index } => {
            format!("{}[{}]", expr_to_string(expr), expr_to_string(index))
        }
        PureExpr::Return(Some(e)) => format!("return {}", expr_to_string(e)),
        PureExpr::Return(None) => "return".to_string(),
        PureExpr::Block { .. } => "{ ... }".to_string(),
        PureExpr::If { .. } => "if ...".to_string(),
        PureExpr::Match { .. } => "match ...".to_string(),
        PureExpr::Closure { .. } => "|...| ...".to_string(),
        // NOTE(review): a one-element tuple renders as `(x)` rather than `(x,)`.
        PureExpr::Tuple(items) => format!("({})", join(items.iter())),
        PureExpr::Array(items) => format!("[{}]", join(items.iter())),
        PureExpr::Macro { name, .. } => format!("{}!(...)", name),
        _ => "<expr>".to_string(),
    }
}
pub struct BodyScanner<'p> {
pattern: &'p CodePattern,
}
impl<'p> BodyScanner<'p> {
pub fn new(pattern: &'p CodePattern) -> Self {
Self { pattern }
}
pub fn scan_fn(&self, func: &PureFn) -> Vec<MatchResult> {
let mut results = Vec::new();
self.scan_block(&func.body, &mut results);
results
}
fn scan_block(&self, block: &PureBlock, results: &mut Vec<MatchResult>) {
for stmt in &block.stmts {
self.scan_stmt(stmt, results);
}
}
fn scan_stmt(&self, stmt: &PureStmt, results: &mut Vec<MatchResult>) {
match stmt {
PureStmt::Local {
init: Some(expr), ..
} => {
self.scan_expr(expr, results);
}
PureStmt::Semi(expr) | PureStmt::Expr(expr) => {
self.scan_expr(expr, results);
}
_ => {}
}
}
fn scan_expr(&self, expr: &PureExpr, results: &mut Vec<MatchResult>) {
let matcher = ExprMatcher::new(self.pattern);
if let Some(ctx) = matcher.matches(expr) {
results.push(ctx.into_match_result());
}
match expr {
PureExpr::MethodCall { receiver, args, .. } => {
self.scan_expr(receiver, results);
for arg in args {
self.scan_expr(arg, results);
}
}
PureExpr::Call { func, args } => {
self.scan_expr(func, results);
for arg in args {
self.scan_expr(arg, results);
}
}
PureExpr::Binary { left, right, .. } => {
self.scan_expr(left, results);
self.scan_expr(right, results);
}
PureExpr::Unary { expr, .. } => {
self.scan_expr(expr, results);
}
PureExpr::Try(expr) => {
self.scan_expr(expr, results);
}
PureExpr::Await(expr) => {
self.scan_expr(expr, results);
}
PureExpr::Field { expr, .. } => {
self.scan_expr(expr, results);
}
PureExpr::Index { expr, index } => {
self.scan_expr(expr, results);
self.scan_expr(index, results);
}
PureExpr::Block { block, .. } => {
self.scan_block(block, results);
}
PureExpr::If {
cond,
then_branch,
else_branch,
} => {
self.scan_expr(cond, results);
self.scan_block(then_branch, results);
if let Some(else_expr) = else_branch {
self.scan_expr(else_expr, results);
}
}
PureExpr::Match { expr, arms } => {
self.scan_expr(expr, results);
for arm in arms {
self.scan_expr(&arm.body, results);
}
}
PureExpr::Loop { body, .. } => {
self.scan_block(body, results);
}
PureExpr::While { cond, body, .. } => {
self.scan_expr(cond, results);
self.scan_block(body, results);
}
PureExpr::For { expr, body, .. } => {
self.scan_expr(expr, results);
self.scan_block(body, results);
}
PureExpr::Return(Some(e)) => {
self.scan_expr(e, results);
}
PureExpr::Break { expr: Some(e), .. } => {
self.scan_expr(e, results);
}
PureExpr::Closure { body, .. } => {
self.scan_expr(body, results);
}
PureExpr::Struct { fields, .. } => {
for (_, field_expr) in fields {
self.scan_expr(field_expr, results);
}
}
PureExpr::Tuple(items) | PureExpr::Array(items) => {
for item in items {
self.scan_expr(item, results);
}
}
_ => {}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ryo_source::pure::MacroDelimiter;

    /// `PatternExpr` that matches exactly `name`.
    fn exact_name(name: &'static str) -> PatternExpr {
        PatternExpr::Name(NameMatcher::Exact(name.into()))
    }

    /// Pattern for method calls named exactly `unwrap`.
    fn unwrap_call_pattern() -> CodePattern {
        CodePattern::new(NodeKind::MethodCall).with_child("method", exact_name("unwrap"))
    }

    /// `receiver.method()` with no turbofish and no arguments.
    fn method_call(receiver: &'static str, method: &'static str) -> PureExpr {
        PureExpr::MethodCall {
            receiver: Box::new(PureExpr::Path(receiver.into())),
            method: method.into(),
            turbofish: None,
            args: vec![],
        }
    }

    /// A macro invocation with empty tokens.
    fn macro_call(name: &'static str, delimiter: MacroDelimiter) -> PureExpr {
        PureExpr::Macro {
            name: name.into(),
            delimiter,
            tokens: "".into(),
        }
    }

    #[test]
    fn test_match_method_call() {
        let pattern = unwrap_call_pattern();
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&method_call("result", "unwrap")).is_some());
    }

    #[test]
    fn test_no_match_different_method() {
        let pattern = unwrap_call_pattern();
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&method_call("result", "expect")).is_none());
    }

    #[test]
    fn test_capture_receiver() {
        let pattern = unwrap_call_pattern()
            .with_child("receiver", PatternExpr::Capture("$x".into()))
            .with_capture("$call");
        let matcher = ExprMatcher::new(&pattern);
        let ctx = matcher
            .matches(&method_call("my_result", "unwrap"))
            .expect("unwrap call should match");
        assert!(ctx.captures.contains_key("$x"));
        assert!(ctx.captures.contains_key("$call"));
        assert_eq!(ctx.captures["$x"].text, "my_result");
    }

    #[test]
    fn test_glob_matching() {
        for (glob, name) in [("get_*", "get_name"), ("*_id", "user_id"), ("*", "anything")] {
            assert!(match_glob(glob, name), "{} should match {}", glob, name);
        }
        assert!(!match_glob("get_*", "set_name"));
    }

    #[test]
    fn test_literal_match_with_name_pattern() {
        let pattern = CodePattern::new(NodeKind::Literal).with_child("value", exact_name("true"));
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&PureExpr::Lit("true".into())).is_some());
        assert!(matcher.matches(&PureExpr::Lit("false".into())).is_none());
        assert!(matcher.matches(&PureExpr::Lit("42".into())).is_none());
    }

    #[test]
    fn test_literal_match_with_literal_pattern() {
        let pattern = CodePattern::new(NodeKind::Literal)
            .with_child("value", PatternExpr::Literal(serde_json::json!("true")));
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&PureExpr::Lit("true".into())).is_some());
        assert!(matcher.matches(&PureExpr::Lit("false".into())).is_none());
    }

    #[test]
    fn test_macro_call_match() {
        let pattern = CodePattern::new(NodeKind::MacroCall).with_child("macro", exact_name("todo"));
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher
            .matches(&macro_call("todo", MacroDelimiter::Paren))
            .is_some());
        assert!(matcher
            .matches(&macro_call("println", MacroDelimiter::Paren))
            .is_none());
        assert!(matcher
            .matches(&macro_call("vec", MacroDelimiter::Bracket))
            .is_none());
    }

    #[test]
    fn test_macro_call_no_filter_matches_all() {
        let pattern = CodePattern::new(NodeKind::MacroCall);
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher
            .matches(&macro_call("todo", MacroDelimiter::Paren))
            .is_some());
        assert!(matcher
            .matches(&macro_call("vec", MacroDelimiter::Bracket))
            .is_some());
    }

    #[test]
    fn test_path_match_exact() {
        let pattern =
            CodePattern::new(NodeKind::Path).with_child("path", exact_name("Filter::Recurse"));
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher
            .matches(&PureExpr::Path("Filter::Recurse".into()))
            .is_some());
        assert!(matcher
            .matches(&PureExpr::Path("Filter::Include".into()))
            .is_none());
        assert!(matcher
            .matches(&PureExpr::Path("something_else".into()))
            .is_none());
    }

    #[test]
    fn test_path_no_filter_matches_all() {
        let pattern = CodePattern::new(NodeKind::Path);
        let matcher = ExprMatcher::new(&pattern);
        assert!(matcher.matches(&PureExpr::Path("anything".into())).is_some());
    }
}