use crate::language::Language;
macro_rules! define_patterns {
( $( $variant:ident => $name:expr $(, aliases = [ $($alias:expr),* ])? ; )* ) => {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Pattern {
$( $variant, )*
}
impl std::fmt::Display for Pattern {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
$( Pattern::$variant => write!(f, $name), )*
}
}
}
impl std::str::FromStr for Pattern {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
$( $name => Ok(Pattern::$variant), )*
$( $( $( $alias => Ok(Pattern::$variant), )* )? )*
_ => anyhow::bail!(
"Unknown pattern '{}'. Valid: {}",
s,
Self::all_names().join(", ")
),
}
}
}
impl Pattern {
pub fn all_names() -> &'static [&'static str] {
&[ $( $name, )* ]
}
}
};
}
define_patterns! {
Builder => "builder";
ErrorSwallow => "error_swallow", aliases = ["error-swallow"];
Async => "async";
Mutex => "mutex";
Unsafe => "unsafe";
Recursion => "recursion";
}
impl Pattern {
pub fn matches(&self, content: &str, name: &str, language: Option<Language>) -> bool {
if let Some(lang) = language {
if let Some(matchers) = lang.def().structural_matchers {
let pattern_name = self.to_string();
for (matcher_name, matcher_fn) in matchers {
if *matcher_name == pattern_name {
return matcher_fn(content, name);
}
}
}
}
match self {
Self::Builder => matches_builder(content, name),
Self::ErrorSwallow => matches_error_swallow(content, language),
Self::Async => matches_async(content, language),
Self::Mutex => matches_mutex(content, language),
Self::Unsafe => matches_unsafe(content, language),
Self::Recursion => matches_recursion(content, name),
}
}
}
fn matches_builder(content: &str, _name: &str) -> bool {
content.contains("-> Self")
|| content.contains("-> &Self")
|| content.contains("-> &mut Self")
|| content.contains("return self")
|| content.contains("return this")
|| (content.contains(".set") && content.contains("return"))
}
fn matches_error_swallow(content: &str, language: Option<Language>) -> bool {
match language {
Some(Language::Rust) => {
content.contains("unwrap_or_default()")
|| content.contains("unwrap_or(())")
|| content.contains(".ok();")
|| content.contains("_ => {}")
|| content.contains("_ => ()")
}
Some(Language::Python) => {
content.contains("except:") && content.contains("pass")
|| content.contains("except Exception:")
&& (content.contains("pass") || content.contains("..."))
}
Some(Language::TypeScript | Language::JavaScript) => {
content.contains("catch") && content.contains("{}")
|| content.contains("catch (") && content.contains("// ignore")
}
Some(Language::Go) => {
content.contains("_ = err") || content.contains("_ = ")
}
_ => {
content.contains("catch") && content.contains("{}")
|| content.contains("except") && content.contains("pass")
}
}
}
fn matches_async(content: &str, language: Option<Language>) -> bool {
match language {
Some(Language::Rust) => content.contains("async fn") || content.contains(".await"),
Some(Language::Python) => content.contains("async def") || content.contains("await "),
Some(Language::TypeScript | Language::JavaScript) => {
content.contains("async ") || content.contains("await ")
}
Some(Language::Go) => {
content.contains("go func") || content.contains("go ") || content.contains("<-")
}
_ => content.contains("async") || content.contains("await"),
}
}
fn matches_mutex(content: &str, language: Option<Language>) -> bool {
match language {
Some(Language::Rust) => {
content.contains("Mutex") || content.contains("RwLock") || content.contains(".lock()")
}
Some(Language::Python) => content.contains("Lock()") || content.contains("threading.Lock"),
Some(Language::Go) => content.contains("sync.Mutex") || content.contains("sync.RWMutex"),
_ => {
content.contains("mutex")
|| content.contains("Mutex")
|| content.contains("lock()")
|| content.contains("Lock()")
}
}
}
fn matches_unsafe(content: &str, language: Option<Language>) -> bool {
match language {
Some(Language::Rust) => content.contains("unsafe "),
Some(Language::C) => {
content.contains("memcpy")
|| content.contains("strcpy")
|| content.contains("sprintf")
|| content.contains("gets(")
}
Some(Language::Go) => content.contains("unsafe.Pointer"),
_ => content.contains("unsafe"),
}
}
fn matches_recursion(content: &str, name: &str) -> bool {
if name.is_empty() {
return false;
}
let lines: Vec<&str> = content.lines().collect();
if lines.len() <= 1 {
return false;
}
let call_paren = format!("{}(", name);
let call_space = format!("{} (", name);
lines[1..]
.iter()
.any(|line| line.contains(&call_paren) || line.contains(&call_space))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pattern_parse_all_variants() {
assert!(matches!(
"builder".parse::<Pattern>().unwrap(),
Pattern::Builder
));
assert!(matches!(
"error_swallow".parse::<Pattern>().unwrap(),
Pattern::ErrorSwallow
));
assert!(matches!(
"error-swallow".parse::<Pattern>().unwrap(),
Pattern::ErrorSwallow
));
assert!(matches!(
"async".parse::<Pattern>().unwrap(),
Pattern::Async
));
assert!(matches!(
"mutex".parse::<Pattern>().unwrap(),
Pattern::Mutex
));
assert!(matches!(
"unsafe".parse::<Pattern>().unwrap(),
Pattern::Unsafe
));
assert!(matches!(
"recursion".parse::<Pattern>().unwrap(),
Pattern::Recursion
));
assert!("unknown".parse::<Pattern>().is_err());
}
#[test]
fn test_pattern_display_roundtrip() {
for name in Pattern::all_names() {
let p: Pattern = name.parse().unwrap();
assert_eq!(p.to_string(), *name);
}
}
#[test]
fn test_all_names_covers_all_variants() {
assert_eq!(Pattern::all_names().len(), 6);
for name in Pattern::all_names() {
assert!(
name.parse::<Pattern>().is_ok(),
"all_names entry '{}' failed to parse",
name
);
}
}
#[test]
fn test_builder_pattern() {
let pat = Pattern::Builder;
assert!(pat.matches(
"fn with_name(self, name: &str) -> Self { ... }",
"with_name",
None
));
assert!(pat.matches("fn build(self) -> &Self { ... }", "build", None));
assert!(!pat.matches("fn foo() -> i32 { 42 }", "foo", None));
}
#[test]
fn test_error_swallow_rust() {
let pat = Pattern::ErrorSwallow;
let lang = Some(Language::Rust);
assert!(pat.matches("let _ = result.unwrap_or_default();", "", lang));
assert!(pat.matches("result.ok();", "", lang));
assert!(pat.matches("match x { Ok(v) => v, _ => {} }", "", lang));
assert!(!pat.matches("let v = result?;", "", lang));
}
#[test]
fn test_error_swallow_python() {
let pat = Pattern::ErrorSwallow;
let lang = Some(Language::Python);
assert!(pat.matches("try:\n foo()\nexcept:\n pass", "", lang));
assert!(pat.matches("try:\n foo()\nexcept Exception:\n pass", "", lang));
assert!(!pat.matches(
"try:\n foo()\nexcept ValueError as e:\n log(e)",
"",
lang
));
}
#[test]
fn test_error_swallow_js() {
let pat = Pattern::ErrorSwallow;
let lang = Some(Language::JavaScript);
assert!(pat.matches("try { foo(); } catch (e) {}", "", lang));
assert!(pat.matches("try { foo(); } catch (e) { // ignore }", "", lang));
assert!(!pat.matches("try { foo(); } catch (e) { console.log(e); }", "", lang));
}
#[test]
fn test_async_rust() {
let pat = Pattern::Async;
assert!(pat.matches("async fn fetch() { ... }", "", Some(Language::Rust)));
assert!(pat.matches("let r = client.get(url).await?;", "", Some(Language::Rust)));
assert!(!pat.matches("fn sync_fetch() { ... }", "", Some(Language::Rust)));
}
#[test]
fn test_async_python() {
let pat = Pattern::Async;
assert!(pat.matches("async def fetch():", "", Some(Language::Python)));
assert!(pat.matches("result = await client.get(url)", "", Some(Language::Python)));
assert!(!pat.matches("def sync_fetch():", "", Some(Language::Python)));
}
#[test]
fn test_async_go() {
let pat = Pattern::Async;
let lang = Some(Language::Go);
assert!(pat.matches("go func() { ... }()", "", lang));
assert!(pat.matches("ch <- value", "", lang));
assert!(!pat.matches("func sync() { ... }", "", lang));
}
#[test]
fn test_mutex_rust() {
let pat = Pattern::Mutex;
let lang = Some(Language::Rust);
assert!(pat.matches("let guard = data.lock().unwrap();", "", lang));
assert!(pat.matches("let m = Mutex::new(0);", "", lang));
assert!(pat.matches("let rw = RwLock::new(vec![]);", "", lang));
assert!(!pat.matches("fn pure_function(x: i32) -> i32 { x + 1 }", "", lang));
}
#[test]
fn test_unsafe_rust() {
let pat = Pattern::Unsafe;
assert!(pat.matches("unsafe { ptr::read(src) }", "", Some(Language::Rust)));
assert!(!pat.matches("fn safe_function() { ... }", "", Some(Language::Rust)));
}
#[test]
fn test_unsafe_c() {
let pat = Pattern::Unsafe;
let lang = Some(Language::C);
assert!(pat.matches("memcpy(dst, src, n);", "", lang));
assert!(pat.matches("strcpy(buf, input);", "", lang));
assert!(pat.matches("sprintf(buf, fmt, arg);", "", lang));
assert!(!pat.matches("int add(int a, int b) { return a + b; }", "", lang));
}
#[test]
fn test_recursion_self_call() {
let pat = Pattern::Recursion;
let code =
"fn factorial(n: u32) -> u32 {\n if n <= 1 { 1 } else { n * factorial(n - 1) }\n}";
assert!(pat.matches(code, "factorial", None));
}
#[test]
fn test_recursion_no_self_call() {
let pat = Pattern::Recursion;
let code = "fn add(a: i32, b: i32) -> i32 {\n a + b\n}";
assert!(!pat.matches(code, "add", None));
}
#[test]
fn test_recursion_empty_name() {
let pat = Pattern::Recursion;
assert!(!pat.matches("fn foo() { foo() }", "", None));
}
#[test]
fn test_recursion_single_line() {
let pat = Pattern::Recursion;
assert!(!pat.matches("fn foo() { foo() }", "foo", None));
}
#[test]
fn test_structural_matchers_fallback() {
let pat = Pattern::Unsafe;
assert!(pat.matches("unsafe { ptr::read(p) }", "read_ptr", Some(Language::Rust)));
assert!(!pat.matches("fn safe() -> i32 { 42 }", "safe", Some(Language::Rust)));
}
#[test]
fn test_pattern_matches_no_language() {
let pat = Pattern::Async;
assert!(pat.matches("async function fetch() {}", "fetch", None));
assert!(!pat.matches("function sync() {}", "sync", None));
}
}