use std::borrow::Cow;
use ast_grep_core::language::Language;
use ast_grep_core::matcher::PatternError;
use ast_grep_core::tree_sitter::{LanguageExt, StrDoc, TSLanguage};
use ast_grep_core::Pattern;
use crate::parser::LangId;
/// Languages that can be searched structurally with ast-grep.
///
/// Each variant is backed by a tree-sitter grammar (see the `LanguageExt`
/// implementation for the exact grammar crate used per variant).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AstGrepLang {
    /// TypeScript without JSX.
    TypeScript,
    /// TypeScript with JSX.
    Tsx,
    /// JavaScript.
    JavaScript,
    /// Python.
    Python,
    /// Rust.
    Rust,
    /// Go.
    Go,
    /// C.
    C,
    /// C++.
    Cpp,
    /// Zig.
    Zig,
    /// C#.
    CSharp,
    /// Solidity.
    Solidity,
    /// Vue single-file components.
    Vue,
    /// JSON (also used for JSONC files).
    Json,
}
impl AstGrepLang {
    /// Maps the parser-level [`LangId`] to the ast-grep language used for
    /// structural search.
    ///
    /// Returns `None` for languages that have no ast-grep grammar wired up in
    /// this module (Scala, Bash, and any other `LangId` not listed below).
    pub fn from_lang_id(lang_id: &LangId) -> Option<Self> {
        match lang_id {
            LangId::TypeScript => Some(Self::TypeScript),
            LangId::Tsx => Some(Self::Tsx),
            LangId::JavaScript => Some(Self::JavaScript),
            LangId::Python => Some(Self::Python),
            LangId::Rust => Some(Self::Rust),
            LangId::Go => Some(Self::Go),
            LangId::C => Some(Self::C),
            LangId::Cpp => Some(Self::Cpp),
            LangId::Zig => Some(Self::Zig),
            LangId::CSharp => Some(Self::CSharp),
            LangId::Solidity => Some(Self::Solidity),
            LangId::Vue => Some(Self::Vue),
            LangId::Json => Some(Self::Json),
            // No `AstGrepLang` variant (and so no grammar) exists for these.
            LangId::Scala => None,
            LangId::Bash => None,
            _ => None,
        }
    }

    /// Parses a language name or common alias (e.g. `"ts"`, `"c++"`, `"golang"`).
    ///
    /// Matching is case-insensitive. Returns `None` for unrecognised names.
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "typescript" | "ts" => Some(Self::TypeScript),
            "tsx" => Some(Self::Tsx),
            "javascript" | "js" => Some(Self::JavaScript),
            "python" | "py" => Some(Self::Python),
            "rust" | "rs" => Some(Self::Rust),
            "go" | "golang" => Some(Self::Go),
            "c" => Some(Self::C),
            "cpp" | "c++" | "cplusplus" => Some(Self::Cpp),
            "zig" => Some(Self::Zig),
            "csharp" | "c#" | "cs" => Some(Self::CSharp),
            "solidity" | "sol" => Some(Self::Solidity),
            "vue" => Some(Self::Vue),
            "json" | "jsonc" => Some(Self::Json),
            _ => None,
        }
    }

    /// File extensions (without the leading dot) handled by this language.
    ///
    /// Comparisons against this list are exact and case-sensitive.
    pub fn extensions(&self) -> &'static [&'static str] {
        match self {
            Self::TypeScript => &["ts", "mts", "cts"],
            Self::Tsx => &["tsx"],
            Self::JavaScript => &["js", "mjs", "cjs", "jsx"],
            Self::Python => &["py", "pyi"],
            Self::Rust => &["rs"],
            Self::Go => &["go"],
            Self::C => &["c", "h"],
            // Every C++ source suffix has its header counterpart:
            // cc/hh, cpp/hpp, cxx/hxx.
            Self::Cpp => &["cc", "cpp", "cxx", "hpp", "hh", "hxx"],
            Self::Zig => &["zig"],
            Self::CSharp => &["cs"],
            Self::Solidity => &["sol"],
            Self::Vue => &["vue"],
            Self::Json => &["json", "jsonc"],
        }
    }

    /// Returns `true` if `ext` (no leading dot, case-sensitive) is one of
    /// [`Self::extensions`].
    pub fn matches_extension(&self, ext: &str) -> bool {
        self.extensions().contains(&ext)
    }

    /// Returns `true` if `path`'s extension belongs to this language.
    ///
    /// Paths with no extension, or a non-UTF-8 one, never match (the empty
    /// string is not in any extension list).
    pub fn matches_path(&self, path: &std::path::Path) -> bool {
        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
        self.matches_extension(ext)
    }
}
impl Language for AstGrepLang {
    /// Resolves a named node kind to its numeric id in this language's grammar.
    fn kind_to_id(&self, kind: &str) -> u16 {
        let grammar = self.get_ts_language();
        grammar.id_for_node_kind(kind, true)
    }

    /// Resolves a field name to its numeric id, or `None` if the grammar has
    /// no field with that name.
    fn field_to_id(&self, field: &str) -> Option<u16> {
        let grammar = self.get_ts_language();
        grammar.field_id_for_name(field).map(|id| id.get())
    }

    /// Builds a pattern by parsing the pattern source with this language.
    fn build_pattern(
        &self,
        builder: &ast_grep_core::matcher::PatternBuilder,
    ) -> Result<Pattern, PatternError> {
        builder.build(|source| StrDoc::try_new(source, self.clone()))
    }

    /// Rewrites meta-variable sigils in `query` to this language's expando
    /// character before the pattern is parsed.
    ///
    /// When the expando is `$` the query is returned unchanged (borrowed).
    /// Otherwise a run of `$` is rewritten to the expando when it is followed by
    /// an ASCII uppercase letter or `_` (`$A`, `$$A`, `$$$A`, `$_`), or when
    /// the run is exactly three long (the anonymous `$$$`), including at the
    /// end of the query. All other `$` are kept literally.
    fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
        let expando = self.expando_char();
        if expando == '$' {
            return Cow::Borrowed(query);
        }

        let mut rewritten = String::with_capacity(query.len());
        // Count of consecutive `$` seen but not yet written out.
        let mut pending: usize = 0;
        for ch in query.chars() {
            if ch == '$' {
                pending += 1;
                continue;
            }
            let is_meta = pending == 3 || ch.is_ascii_uppercase() || ch == '_';
            let sigil = if is_meta { expando } else { '$' };
            rewritten.extend(std::iter::repeat(sigil).take(pending));
            pending = 0;
            rewritten.push(ch);
        }

        // Flush a trailing run; only a bare `$$$` counts as a meta-variable.
        let sigil = if pending == 3 { expando } else { '$' };
        rewritten.extend(std::iter::repeat(sigil).take(pending));
        Cow::Owned(rewritten)
    }

    /// Meta-variable sigil used when parsing patterns for this language.
    ///
    /// Python, Rust, C, C++, Zig and C# use the micro sign (U+00B5); every
    /// other variant uses `$`.
    fn expando_char(&self) -> char {
        let uses_micro = matches!(
            self,
            Self::Python | Self::Rust | Self::C | Self::Cpp | Self::Zig | Self::CSharp
        );
        if uses_micro {
            '\u{00B5}'
        } else {
            '$'
        }
    }
}
impl LanguageExt for AstGrepLang {
fn get_ts_language(&self) -> TSLanguage {
match self {
Self::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
Self::Tsx => tree_sitter_typescript::LANGUAGE_TSX.into(),
Self::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
Self::Python => tree_sitter_python::LANGUAGE.into(),
Self::Rust => tree_sitter_rust::LANGUAGE.into(),
Self::Go => tree_sitter_go::LANGUAGE.into(),
Self::C => tree_sitter_c::LANGUAGE.into(),
Self::Cpp => tree_sitter_cpp::LANGUAGE.into(),
Self::Zig => tree_sitter_zig::LANGUAGE.into(),
Self::CSharp => tree_sitter_c_sharp::LANGUAGE.into(),
Self::Solidity => tree_sitter_solidity::LANGUAGE.into(),
Self::Vue => tree_sitter_vue::LANGUAGE.into(),
Self::Json => tree_sitter_json::LANGUAGE.into(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ast_grep_core::tree_sitter::LanguageExt;

    #[test]
    fn test_from_str() {
        assert_eq!(
            AstGrepLang::from_str("typescript"),
            Some(AstGrepLang::TypeScript)
        );
        assert_eq!(AstGrepLang::from_str("tsx"), Some(AstGrepLang::Tsx));
        assert_eq!(
            AstGrepLang::from_str("javascript"),
            Some(AstGrepLang::JavaScript)
        );
        assert_eq!(AstGrepLang::from_str("python"), Some(AstGrepLang::Python));
        assert_eq!(AstGrepLang::from_str("rust"), Some(AstGrepLang::Rust));
        assert_eq!(AstGrepLang::from_str("go"), Some(AstGrepLang::Go));
        assert_eq!(AstGrepLang::from_str("c"), Some(AstGrepLang::C));
        assert_eq!(AstGrepLang::from_str("cpp"), Some(AstGrepLang::Cpp));
        assert_eq!(AstGrepLang::from_str("zig"), Some(AstGrepLang::Zig));
        assert_eq!(AstGrepLang::from_str("c#"), Some(AstGrepLang::CSharp));
        assert_eq!(
            AstGrepLang::from_str("solidity"),
            Some(AstGrepLang::Solidity)
        );
        assert_eq!(AstGrepLang::from_str("sol"), Some(AstGrepLang::Solidity));
        assert_eq!(AstGrepLang::from_str("vue"), Some(AstGrepLang::Vue));
        assert_eq!(AstGrepLang::from_str("json"), Some(AstGrepLang::Json));
        assert_eq!(AstGrepLang::from_str("markdown"), None);
    }

    #[test]
    fn test_from_str_aliases_are_case_insensitive() {
        assert_eq!(AstGrepLang::from_str("TS"), Some(AstGrepLang::TypeScript));
        assert_eq!(AstGrepLang::from_str("C++"), Some(AstGrepLang::Cpp));
        assert_eq!(AstGrepLang::from_str("Golang"), Some(AstGrepLang::Go));
        assert_eq!(AstGrepLang::from_str("jsonc"), Some(AstGrepLang::Json));
        assert_eq!(AstGrepLang::from_str(""), None);
    }

    #[test]
    fn test_from_lang_id() {
        assert_eq!(
            AstGrepLang::from_lang_id(&LangId::Rust),
            Some(AstGrepLang::Rust)
        );
        assert_eq!(
            AstGrepLang::from_lang_id(&LangId::Tsx),
            Some(AstGrepLang::Tsx)
        );
        assert_eq!(AstGrepLang::from_lang_id(&LangId::Scala), None);
        assert_eq!(AstGrepLang::from_lang_id(&LangId::Bash), None);
    }

    #[test]
    fn test_matches_path() {
        use std::path::Path;
        assert!(AstGrepLang::Rust.matches_path(Path::new("src/lib.rs")));
        assert!(AstGrepLang::Tsx.matches_path(Path::new("App.tsx")));
        assert!(!AstGrepLang::TypeScript.matches_path(Path::new("App.tsx")));
        assert!(AstGrepLang::Python.matches_path(Path::new("stubs/mod.pyi")));
        assert!(AstGrepLang::Cpp.matches_path(Path::new("src/a.cxx")));
        // No extension at all.
        assert!(!AstGrepLang::Rust.matches_path(Path::new("Makefile")));
        // A dotfile has no extension, so it must not be treated as `json`.
        assert!(!AstGrepLang::Json.matches_path(Path::new(".json")));
    }

    #[test]
    fn test_ast_grep_basic() {
        let lang = AstGrepLang::TypeScript;
        let grep = lang.ast_grep("const x = 1;");
        let root = grep.root();
        assert!(root.find("const $X = $Y").is_some());
    }

    #[test]
    fn test_python_function_pattern() {
        let lang = AstGrepLang::Python;
        let source = "def add(a, b):\n return a + b\n";
        let grep = lang.ast_grep(source);
        let root = grep.root();
        let node = root
            .find("def $FUNC($$$):\n return $X")
            .expect("Python function pattern should match");
        assert_eq!(node.text(), "def add(a, b):\n return a + b");
    }

    #[test]
    fn test_python_expression_pattern() {
        let lang = AstGrepLang::Python;
        let source = "x = self.value + 1\n";
        let grep = lang.ast_grep(source);
        let root = grep.root();
        let found = root.find("self.$ATTR + $X");
        assert!(found.is_some(), "Python expression pattern should match");
    }

    #[test]
    fn test_rust_function_pattern() {
        let lang = AstGrepLang::Rust;
        let source = "fn add(a: i32, b: i32) -> i32 { a + b }";
        let grep = lang.ast_grep(source);
        let root = grep.root();
        let found = root.find("fn $NAME($$$) -> $RET { $$$BODY }");
        assert!(found.is_some(), "Rust function pattern should match");
    }

    #[test]
    fn test_expando_char() {
        assert_eq!(AstGrepLang::Python.expando_char(), '\u{00B5}');
        assert_eq!(AstGrepLang::Rust.expando_char(), '\u{00B5}');
        assert_eq!(AstGrepLang::C.expando_char(), '\u{00B5}');
        assert_eq!(AstGrepLang::Cpp.expando_char(), '\u{00B5}');
        assert_eq!(AstGrepLang::Zig.expando_char(), '\u{00B5}');
        assert_eq!(AstGrepLang::CSharp.expando_char(), '\u{00B5}');
        assert_eq!(AstGrepLang::TypeScript.expando_char(), '$');
        assert_eq!(AstGrepLang::Tsx.expando_char(), '$');
        assert_eq!(AstGrepLang::JavaScript.expando_char(), '$');
        assert_eq!(AstGrepLang::Go.expando_char(), '$');
        assert_eq!(AstGrepLang::Json.expando_char(), '$');
    }

    #[test]
    fn test_solidity_function_pattern_probe() {
        let lang = AstGrepLang::Solidity;
        let source = "contract C {\n function add(uint256 a) public pure returns (uint256) { return a; }\n}\n";
        let grep = lang.ast_grep(source);
        let root = grep.root();
        let patterns = [
            "return $X;",
            "uint256 $X",
            "function $NAME",
            "function add",
            "contract C { $$$ }",
        ];
        let any_matched = patterns.iter().any(|pat| root.find(*pat).is_some());
        assert!(
            any_matched,
            "no Solidity pattern matched — grammar wiring broken"
        );
    }

    #[test]
    fn solidity_expando_char_stays_dollar() {
        assert_eq!(AstGrepLang::Solidity.expando_char(), '$');
    }

    #[test]
    fn vue_expando_char_stays_dollar() {
        assert_eq!(AstGrepLang::Vue.expando_char(), '$');
    }

    #[test]
    fn solidity_meta_var_pattern_binds_capture() {
        let lang = AstGrepLang::Solidity;
        let source =
            "contract C {\n function add(uint256 a) public pure returns (uint256) { return a; }\n}\n";
        let grep = lang.ast_grep(source);
        let root = grep.root();
        let found = root.find("function $NAME($$$) public pure returns ($$$) { $$$ }");
        assert!(
            found.is_some(),
            "Solidity meta-var pattern must match — bug recurrence"
        );
    }

    #[test]
    fn test_pre_process_pattern_python() {
        let lang = AstGrepLang::Python;
        let result = lang.pre_process_pattern("def $FUNC($$$):");
        // Pin the exact rewrite: one sigil for `$FUNC`, three for `$$$`.
        assert_eq!(result, "def \u{00B5}FUNC(\u{00B5}\u{00B5}\u{00B5}):");
        assert!(
            !result.contains('$'),
            "Should not contain $ after preprocessing"
        );
    }

    #[test]
    fn test_pre_process_pattern_edge_cases() {
        let rust = AstGrepLang::Rust;
        // Lowercase after `$` is not a meta-variable and is left untouched.
        assert_eq!(rust.pre_process_pattern("let $x = 1;"), "let $x = 1;");
        assert_eq!(rust.pre_process_pattern("$$A"), "\u{00B5}\u{00B5}A");
        assert_eq!(rust.pre_process_pattern("$_"), "\u{00B5}_");
        // A trailing anonymous `$$$` is rewritten; a trailing `$$` is not.
        assert_eq!(
            rust.pre_process_pattern("f($$$"),
            "f(\u{00B5}\u{00B5}\u{00B5}"
        );
        assert_eq!(rust.pre_process_pattern("$$"), "$$");
        // Languages whose expando is `$` get the query back without allocating.
        let ts = AstGrepLang::TypeScript;
        assert!(matches!(
            ts.pre_process_pattern("const $X = 1"),
            Cow::Borrowed("const $X = 1")
        ));
    }
}