1mod bash;
22mod cpp;
23mod csharp;
24mod css;
25mod dart;
26mod elixir;
27mod go;
28mod haskell;
29mod hcl;
30mod html;
31mod json;
32mod kotlin;
33mod lua;
34mod markdown;
35mod nix;
36mod parsers;
37mod php;
38mod python;
39mod ruby;
40mod rust;
41mod scala;
42mod solidity;
43mod swift;
44mod yaml;
45
46use ast_grep_core::matcher::{Pattern, PatternBuilder, PatternError};
47pub use html::Html;
48
49use ast_grep_core::meta_var::MetaVariable;
50use ast_grep_core::tree_sitter::{StrDoc, TSLanguage, TSRange};
51use ast_grep_core::Node;
52use ignore::types::{Types, TypesBuilder};
53use serde::de::Visitor;
54use serde::{de, Deserialize, Deserializer, Serialize};
55use std::borrow::Cow;
56use std::fmt;
57use std::fmt::{Display, Formatter};
58use std::iter::repeat;
59use std::path::Path;
60use std::str::FromStr;
61
62pub use ast_grep_core::language::Language;
63pub use ast_grep_core::tree_sitter::LanguageExt;
64
/// Declares a unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it on top of the tree-sitter grammar returned by
/// `parsers::$parser()`.
///
/// This variant does not override `expando_char` or `pre_process_pattern`,
/// so patterns reach the parser with the trait's default handling.
macro_rules! impl_lang {
  ($lang: ident, $parser: ident) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        let grammar = self.get_ts_language();
        grammar.id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        let grammar = self.get_ts_language();
        grammar.field_id_for_name(field).map(|id| id.get())
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // The struct is `Copy`, so the document gets its own language value.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        parsers::$parser().into()
      }
    }
  };
}
93
/// Rewrites the `$` sigils of a pattern so that meta-variables use `expando`.
///
/// A run of `$` characters is replaced, one for one, by `expando` when
/// * the character right after the run is an ASCII uppercase letter or `_`
///   (e.g. `$A`, `$$A`, `$$$ARGS`, `$_`), or
/// * the run is exactly three `$` long (e.g. `$$$`, also at the very end of
///   the query).
///
/// Every other run of `$` is copied through unchanged.
///
/// Returns `Cow::Borrowed(query)` when the query contains no `$` at all, and
/// an owned, rewritten string otherwise.
fn pre_process_pattern(expando: char, query: &str) -> std::borrow::Cow<'_, str> {
  // Nothing to rewrite: avoid allocating a copy of the query.
  if !query.contains('$') {
    return std::borrow::Cow::Borrowed(query);
  }
  // `expando` may be wider than `$` in UTF-8, so this is only a lower bound.
  let mut ret = String::with_capacity(query.len());
  // Number of consecutive `$` seen but not yet written to `ret`.
  let mut dollar_count = 0;
  for c in query.chars() {
    if c == '$' {
      dollar_count += 1;
      continue;
    }
    // `c` ends the pending run; decide which sigil the run is written with.
    let need_replace = matches!(c, 'A'..='Z' | '_') || dollar_count == 3;
    let sigil = if need_replace { expando } else { '$' };
    ret.extend(repeat(sigil).take(dollar_count));
    dollar_count = 0;
    ret.push(c);
  }
  // A run at the very end has no following character: only `$$$` is rewritten.
  let sigil = if dollar_count == 3 { expando } else { '$' };
  ret.extend(repeat(sigil).take(dollar_count));
  std::borrow::Cow::Owned(ret)
}
114
/// Declares a unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it, like `impl_lang!`, but with a custom expando
/// character.
///
/// `$char` is reported by `expando_char`, and `pre_process_pattern` uses it to
/// rewrite the `$` sigils of a pattern before the pattern is parsed.
macro_rules! impl_lang_expando {
  ($lang: ident, $parser: ident, $char: expr) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        let grammar = self.get_ts_language();
        grammar.id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        let grammar = self.get_ts_language();
        grammar.field_id_for_name(field).map(|id| id.get())
      }
      fn expando_char(&self) -> char {
        $char
      }
      fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
        // Delegates to the module-level helper of the same name.
        let sigil = self.expando_char();
        pre_process_pattern(sigil, query)
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // The struct is `Copy`, so the document gets its own language value.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        $crate::parsers::$parser().into()
      }
    }
  };
}
150
/// Implemented by every language type to list the names it can be selected by.
///
/// The `Display` supertrait provides the language's printable name.
pub trait Alias: Display {
  /// All spellings accepted for this language.
  const ALIAS: &'static [&'static str];
}
154
/// Wires a single language type into the alias machinery.
///
/// For `$lang` this generates:
/// * `Alias`, with `$as` as the accepted names,
/// * `Display`, which prints the type's `Debug` representation,
/// * `Deserialize`, which accepts any string matched by `AliasVisitor` and
///   always yields the unit value `$lang`,
/// * `From<$lang> for SupportLang`, mapping to the variant of the same name.
macro_rules! impl_alias {
  ($lang:ident => $as:expr) => {
    impl Alias for $lang {
      const ALIAS: &'static [&'static str] = $as;
    }

    impl Display for $lang {
      fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self)
      }
    }

    impl<'de> Deserialize<'de> for $lang {
      fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // The matched alias itself is irrelevant; only success matters.
        deserializer
          .deserialize_str(AliasVisitor {
            aliases: Self::ALIAS,
          })
          .map(|_| $lang)
      }
    }

    impl From<$lang> for SupportLang {
      fn from(_: $lang) -> Self {
        SupportLang::$lang
      }
    }
  };
}
/// Applies `impl_alias!` to every `Lang => aliases` pair and generates the
/// `alias` lookup function that maps a `SupportLang` variant to the alias
/// list of the language type with the same name.
macro_rules! impl_aliases {
  ($($lang:ident => $as:expr),* $(,)?) => {
    $(impl_alias!($lang => $as);)*
    const fn alias(lang: SupportLang) -> &'static [&'static str] {
      match lang {
        $(SupportLang::$lang => <$lang as Alias>::ALIAS,)*
      }
    }
  };
}
201
202impl_lang_expando!(C, language_c, '๐');
205impl_lang_expando!(Cpp, language_cpp, '๐');
206impl_lang_expando!(CSharp, language_c_sharp, 'ยต');
210impl_lang_expando!(Css, language_css, '_');
212impl_lang_expando!(Elixir, language_elixir, 'ยต');
214impl_lang_expando!(Go, language_go, 'ยต');
217impl_lang_expando!(Haskell, language_haskell, 'ยต');
221impl_lang_expando!(Hcl, language_hcl, 'ยต');
223impl_lang_expando!(Kotlin, language_kotlin, 'ยต');
225impl_lang_expando!(Nix, language_nix, '_');
227impl_lang_expando!(Php, language_php, 'ยต');
229impl_lang_expando!(Python, language_python, 'ยต');
233impl_lang_expando!(Ruby, language_ruby, 'ยต');
235impl_lang_expando!(Rust, language_rust, 'ยต');
238impl_lang_expando!(Swift, language_swift, 'ยต');
240
241impl_lang!(Bash, language_bash);
244impl_lang!(Java, language_java);
245impl_lang!(JavaScript, language_javascript);
246impl_lang!(Json, language_json);
247impl_lang!(Lua, language_lua);
248impl_lang!(Markdown, language_markdown);
249impl_lang!(Scala, language_scala);
250impl_lang!(Solidity, language_solidity);
251impl_lang!(Tsx, language_tsx);
252impl_lang!(TypeScript, language_typescript);
253impl_lang!(Dart, language_dart);
254impl_lang!(Yaml, language_yaml);
255#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Hash)]
260pub enum SupportLang {
261 Bash,
262 C,
263 Cpp,
264 CSharp,
265 Css,
266 Dart,
267 Go,
268 Elixir,
269 Haskell,
270 Hcl,
271 Html,
272 Java,
273 JavaScript,
274 Json,
275 Kotlin,
276 Lua,
277 Markdown,
278 Nix,
279 Php,
280 Python,
281 Ruby,
282 Rust,
283 Scala,
284 Solidity,
285 Swift,
286 Tsx,
287 TypeScript,
288 Yaml,
289}
290
291impl SupportLang {
292 pub const fn all_langs() -> &'static [SupportLang] {
293 use SupportLang::*;
294 &[
295 Bash, C, Cpp, CSharp, Css, Dart, Elixir, Go, Haskell, Hcl, Html, Java, JavaScript, Json,
296 Kotlin, Lua, Markdown, Nix, Php, Python, Ruby, Rust, Scala, Solidity, Swift, Tsx, TypeScript,
297 Yaml,
298 ]
299 }
300
301 pub fn file_types(&self) -> Types {
302 file_types(*self)
303 }
304}
305
306impl fmt::Display for SupportLang {
307 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
308 write!(f, "{self:?}")
309 }
310}
311
/// Error returned when a string cannot be turned into a `SupportLang`.
#[derive(Debug)]
pub enum SupportLangErr {
  /// The given name matches no alias of any built-in language; the payload
  /// is the rejected name.
  LanguageNotSupported(String),
}

impl Display for SupportLangErr {
  fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
    match self {
      Self::LanguageNotSupported(lang) => write!(f, "{lang} is not supported!"),
    }
  }
}

impl std::error::Error for SupportLangErr {}
327
328impl<'de> Deserialize<'de> for SupportLang {
329 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
330 where
331 D: Deserializer<'de>,
332 {
333 deserializer.deserialize_str(SupportLangVisitor)
334 }
335}
336
337struct SupportLangVisitor;
338
339impl Visitor<'_> for SupportLangVisitor {
340 type Value = SupportLang;
341
342 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
343 f.write_str("SupportLang")
344 }
345
346 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
347 where
348 E: de::Error,
349 {
350 v.parse().map_err(de::Error::custom)
351 }
352}
353struct AliasVisitor {
354 aliases: &'static [&'static str],
355}
356
357impl Visitor<'_> for AliasVisitor {
358 type Value = &'static str;
359
360 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
361 write!(f, "one of {:?}", self.aliases)
362 }
363
364 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
365 where
366 E: de::Error,
367 {
368 self
369 .aliases
370 .iter()
371 .copied()
372 .find(|&a| v.eq_ignore_ascii_case(a))
373 .ok_or_else(|| de::Error::invalid_value(de::Unexpected::Str(v), &self))
374 }
375}
376
377impl_aliases! {
378 Bash => &["bash"],
379 C => &["c"],
380 Cpp => &["cc", "c++", "cpp", "cxx"],
381 CSharp => &["cs", "csharp"],
382 Css => &["css"],
383 Dart => &["dart"],
384 Elixir => &["ex", "elixir"],
385 Go => &["go", "golang"],
386 Haskell => &["hs", "haskell"],
387 Hcl => &["hcl"],
388 Html => &["html"],
389 Java => &["java"],
390 JavaScript => &["javascript", "js", "jsx"],
391 Json => &["json"],
392 Kotlin => &["kotlin", "kt"],
393 Lua => &["lua"],
394 Markdown => &["markdown", "md"],
395 Nix => &["nix"],
396 Php => &["php"],
397 Python => &["py", "python"],
398 Ruby => &["rb", "ruby"],
399 Rust => &["rs", "rust"],
400 Scala => &["scala"],
401 Solidity => &["sol", "solidity"],
402 Swift => &["swift"],
403 TypeScript => &["ts", "typescript"],
404 Tsx => &["tsx"],
405 Yaml => &["yaml", "yml"],
406}
407
408impl FromStr for SupportLang {
410 type Err = SupportLangErr;
411 fn from_str(s: &str) -> Result<Self, Self::Err> {
412 for &lang in Self::all_langs() {
413 for moniker in alias(lang) {
414 if s.eq_ignore_ascii_case(moniker) {
415 return Ok(lang);
416 }
417 }
418 }
419 Err(SupportLangErr::LanguageNotSupported(s.to_string()))
420 }
421}
422
/// Expands to a `match` over `SupportLang` that calls `$method` with the given
/// arguments on the unit struct named like the matched variant.
///
/// `$target` is the `SupportLang` value (or reference) to dispatch on.
macro_rules! execute_lang_method {
  ($target: path, $method: ident, $($arg:tt),*) => {
    match $target {
      SupportLang::Bash => Bash.$method($($arg),*),
      SupportLang::C => C.$method($($arg),*),
      SupportLang::Cpp => Cpp.$method($($arg),*),
      SupportLang::CSharp => CSharp.$method($($arg),*),
      SupportLang::Css => Css.$method($($arg),*),
      SupportLang::Dart => Dart.$method($($arg),*),
      SupportLang::Elixir => Elixir.$method($($arg),*),
      SupportLang::Go => Go.$method($($arg),*),
      SupportLang::Haskell => Haskell.$method($($arg),*),
      SupportLang::Hcl => Hcl.$method($($arg),*),
      SupportLang::Html => Html.$method($($arg),*),
      SupportLang::Java => Java.$method($($arg),*),
      SupportLang::JavaScript => JavaScript.$method($($arg),*),
      SupportLang::Json => Json.$method($($arg),*),
      SupportLang::Kotlin => Kotlin.$method($($arg),*),
      SupportLang::Lua => Lua.$method($($arg),*),
      SupportLang::Markdown => Markdown.$method($($arg),*),
      SupportLang::Nix => Nix.$method($($arg),*),
      SupportLang::Php => Php.$method($($arg),*),
      SupportLang::Python => Python.$method($($arg),*),
      SupportLang::Ruby => Ruby.$method($($arg),*),
      SupportLang::Rust => Rust.$method($($arg),*),
      SupportLang::Scala => Scala.$method($($arg),*),
      SupportLang::Solidity => Solidity.$method($($arg),*),
      SupportLang::Swift => Swift.$method($($arg),*),
      SupportLang::Tsx => Tsx.$method($($arg),*),
      SupportLang::TypeScript => TypeScript.$method($($arg),*),
      SupportLang::Yaml => Yaml.$method($($arg),*),
    }
  }
}
458
/// Generates a `&self` method named `$method` with the given parameters and
/// return type whose body forwards the call, via `execute_lang_method!`, to
/// the language struct selected by `self`.
///
/// The matcher has no syntax for generic or lifetime parameters, so methods
/// that need them have to be written by hand.
macro_rules! impl_lang_method {
  ($method: ident, ($($arg:tt: $arg_ty:ty),*) => $ret: ty) => {
    #[inline]
    fn $method(&self, $($arg: $arg_ty),*) -> $ret {
      execute_lang_method!{ self, $method, $($arg),* }
    }
  };
}
467impl Language for SupportLang {
468 impl_lang_method!(kind_to_id, (kind: &str) => u16);
469 impl_lang_method!(field_to_id, (field: &str) => Option<u16>);
470 impl_lang_method!(meta_var_char, () => char);
471 impl_lang_method!(expando_char, () => char);
472 impl_lang_method!(extract_meta_var, (source: &str) => Option<MetaVariable>);
473 impl_lang_method!(build_pattern, (builder: &PatternBuilder) => Result<Pattern, PatternError>);
474 fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
475 execute_lang_method! { self, pre_process_pattern, query }
476 }
477 fn from_path<P: AsRef<Path>>(path: P) -> Option<Self> {
478 from_extension(path.as_ref())
479 }
480}
481
482impl LanguageExt for SupportLang {
483 impl_lang_method!(get_ts_language, () => TSLanguage);
484 impl_lang_method!(injectable_languages, () => Option<&'static [&'static str]>);
485 fn extract_injections<L: LanguageExt>(
486 &self,
487 root: Node<StrDoc<L>>,
488 ) -> Vec<(String, Vec<TSRange>)> {
489 match self {
490 SupportLang::Html => Html.extract_injections(root),
491 _ => Vec::new(),
492 }
493 }
494}
495
496fn extensions(lang: SupportLang) -> &'static [&'static str] {
497 use SupportLang::*;
498 match lang {
499 Bash => &[
500 "bash", "bats", "cgi", "command", "env", "fcgi", "ksh", "sh", "tmux", "tool", "zsh",
501 ],
502 C => &["c", "h"],
503 Cpp => &["cc", "hpp", "cpp", "c++", "hh", "cxx", "cu", "ino"],
504 CSharp => &["cs"],
505 Css => &["css", "scss"],
506 Dart => &["dart"],
507 Elixir => &["ex", "exs"],
508 Go => &["go"],
509 Haskell => &["hs"],
510 Hcl => &["hcl", "nomad", "tf", "tfvars", "workflow"],
511 Html => &["html", "htm", "xhtml"],
512 Java => &["java"],
513 JavaScript => &["cjs", "js", "mjs", "jsx"],
514 Json => &["json"],
515 Kotlin => &["kt", "ktm", "kts"],
516 Lua => &["lua"],
517 Markdown => &["markdown", "md"],
518 Nix => &["nix"],
519 Php => &["php"],
520 Python => &["py", "py3", "pyi", "bzl"],
521 Ruby => &["rb", "rbw", "gemspec"],
522 Rust => &["rs"],
523 Scala => &["scala", "sc", "sbt"],
524 Solidity => &["sol"],
525 Swift => &["swift"],
526 TypeScript => &["ts", "cts", "mts"],
527 Tsx => &["tsx"],
528 Yaml => &["yaml", "yml"],
529 }
530}
531
532fn from_extension(path: &Path) -> Option<SupportLang> {
536 let ext = path.extension()?.to_str()?;
537 SupportLang::all_langs()
538 .iter()
539 .copied()
540 .find(|&l| extensions(l).contains(&ext))
541}
542
543fn add_custom_file_type<'b>(
544 builder: &'b mut TypesBuilder,
545 file_type: &str,
546 suffix_list: &[&str],
547) -> &'b mut TypesBuilder {
548 for suffix in suffix_list {
549 let glob = format!("*.{suffix}");
550 builder
551 .add(file_type, &glob)
552 .expect("file pattern must compile");
553 }
554 builder.select(file_type)
555}
556
557fn file_types(lang: SupportLang) -> Types {
558 let mut builder = TypesBuilder::new();
559 let exts = extensions(lang);
560 let lang_name = lang.to_string();
561 add_custom_file_type(&mut builder, &lang_name, exts);
562 builder.build().expect("file type must be valid")
563}
564
565pub fn config_file_type() -> Types {
566 let mut builder = TypesBuilder::new();
567 let builder = add_custom_file_type(&mut builder, "yml", &["yml", "yaml"]);
568 builder.build().expect("yaml type must be valid")
569}
570
#[cfg(test)]
mod test {
  use super::*;
  use ast_grep_core::{matcher::MatcherExt, Pattern};

  /// Asserts that the pattern `query` matches somewhere in `source` when both
  /// are parsed with `lang`.
  pub fn test_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let tree = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let matched = pattern.find_node(tree.root()).is_some();
    assert!(
      matched,
      "goal: {pattern:?}, candidate: {}",
      tree.root().get_inner_node().to_sexp(),
    );
  }

  /// Asserts that the pattern `query` matches nowhere in `source` when both
  /// are parsed with `lang`.
  pub fn test_non_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let tree = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let unmatched = pattern.find_node(tree.root()).is_none();
    assert!(
      unmatched,
      "goal: {pattern:?}, candidate: {}",
      tree.root().get_inner_node().to_sexp(),
    );
  }

  /// Applies one `pattern` -> `replacer` rewrite to `src` and returns the
  /// regenerated source.
  ///
  /// Panics if the replacement fails or reports that nothing was replaced.
  pub fn test_replace_lang(
    src: &str,
    pattern: &str,
    replacer: &str,
    lang: impl LanguageExt,
  ) -> String {
    let mut tree = lang.ast_grep(src);
    let replaced = tree
      .replace(pattern, replacer)
      .expect("should parse successfully");
    assert!(replaced);
    tree.generate()
  }

  #[test]
  fn test_js_string() {
    for literal in ["'a'", "\"\"", "''"] {
      test_match_lang(literal, literal, JavaScript);
    }
  }

  #[test]
  fn test_guess_by_extension() {
    let cases = [
      ("foo.rs", SupportLang::Rust),
      ("README.md", SupportLang::Markdown),
      ("README.markdown", SupportLang::Markdown),
    ];
    for (file, expected) in cases {
      assert_eq!(from_extension(Path::new(file)), Some(expected));
    }
  }
}