1mod bash;
22mod cpp;
23mod csharp;
24mod css;
25mod dart;
26mod elixir;
27mod go;
28mod haskell;
29mod hcl;
30mod html;
31mod json;
32mod kotlin;
33mod lua;
34mod nix;
35mod parsers;
36mod php;
37mod python;
38mod ruby;
39mod rust;
40mod scala;
41mod solidity;
42mod swift;
43mod yaml;
44
45use ast_grep_core::matcher::{Pattern, PatternBuilder, PatternError};
46pub use html::Html;
47
48use ast_grep_core::meta_var::MetaVariable;
49use ast_grep_core::tree_sitter::{StrDoc, TSLanguage, TSRange};
50use ast_grep_core::Node;
51use ignore::types::{Types, TypesBuilder};
52use serde::de::Visitor;
53use serde::{de, Deserialize, Deserializer, Serialize};
54use std::borrow::Cow;
55use std::fmt;
56use std::fmt::{Display, Formatter};
57use std::iter::repeat;
58use std::path::Path;
59use std::str::FromStr;
60
61pub use ast_grep_core::language::Language;
62pub use ast_grep_core::tree_sitter::LanguageExt;
63
/// Declares a unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it, using the tree-sitter grammar returned by
/// `parsers::$func`.
///
/// The generated impl does not override `expando_char` or
/// `pre_process_pattern`, so whatever the trait provides for them applies.
/// Use `impl_lang_expando!` for languages that need `$` rewritten.
macro_rules! impl_lang {
  ($lang: ident, $func: ident) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self
          .get_ts_language()
          .id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` derives `Copy`, so copy it out instead of calling `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        // `$crate::` keeps the path valid regardless of where the macro is
        // expanded, matching `impl_lang_expando!`.
        $crate::parsers::$func().into()
      }
    }
  };
}
92
/// Rewrites the `$` sigils of meta variables in `query` to `expando`.
///
/// A run of `$` characters is replaced (one `expando` per `$`) when either:
/// * the run is immediately followed by `A..=Z` or `_` (`$A`, `$$A`, `$$$A`), or
/// * the run is exactly three `$` long, whatever follows it, including the
///   end of the query (the anonymous `$$$` form).
///
/// Every other run of `$` is copied through unchanged.
///
/// Returns `Cow::Borrowed(query)` when the query contains no `$` at all, so
/// the common "nothing to do" case does not allocate; otherwise returns an
/// owned, rewritten string.
fn pre_process_pattern(expando: char, query: &str) -> std::borrow::Cow<'_, str> {
  // Fast path: without any `$` the output is byte-for-byte the input.
  if !query.contains('$') {
    return std::borrow::Cow::Borrowed(query);
  }
  // Build the `String` directly instead of collecting a `Vec<char>` first.
  let mut ret = String::with_capacity(query.len());
  // Number of consecutive `$` seen but not yet written out.
  let mut dollar_count = 0;
  for c in query.chars() {
    if c == '$' {
      dollar_count += 1;
      continue;
    }
    // `$A`, `$$A`, `$$$A` style names, or an anonymous `$$$`.
    let need_replace = matches!(c, 'A'..='Z' | '_') || dollar_count == 3;
    let sigil = if need_replace { expando } else { '$' };
    ret.extend(repeat(sigil).take(dollar_count));
    dollar_count = 0;
    ret.push(c);
  }
  // Flush a run of `$` that ends the query; only the anonymous `$$$` form
  // is rewritten here because no name character follows.
  let sigil = if dollar_count == 3 { expando } else { '$' };
  ret.extend(repeat(sigil).take(dollar_count));
  std::borrow::Cow::Owned(ret)
}
113
/// Declares a unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it, like `impl_lang!`, but with a custom expando
/// character `$char`.
///
/// The generated `pre_process_pattern` forwards to the free function
/// `pre_process_pattern`, which swaps the `$` sigils of meta variables for
/// `$char` before the pattern is built.
// NOTE(review): presumably needed because these grammars do not accept `$`
// inside identifiers — confirm against the individual parsers.
macro_rules! impl_lang_expando {
  ($lang: ident, $func: ident, $char: expr) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self
          .get_ts_language()
          .id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn expando_char(&self) -> char {
        $char
      }
      fn pre_process_pattern<'q>(&self, query: &'q str) -> std::borrow::Cow<'q, str> {
        pre_process_pattern(self.expando_char(), query)
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` derives `Copy`, so copy it out instead of calling `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        $crate::parsers::$func().into()
      }
    }
  };
}
149
/// Ties a language type to the names it can be referred to by.
///
/// Implementors must also be [`Display`].
pub trait Alias: Display {
  /// Names accepted for this language. Lookups in this file compare them
  /// ASCII case-insensitively.
  const ALIAS: &'static [&'static str];
}
153
/// Registers the alias list `$names` for the language struct `$lang`.
///
/// Generates four impls for `$lang`:
/// * [`Alias`], exposing `$names` as `ALIAS`;
/// * `Display`, which prints the type's `Debug` representation;
/// * `Deserialize`, which accepts any string matching one of the aliases;
/// * `From<$lang> for SupportLang`, mapping to the variant of the same name.
macro_rules! impl_alias {
  ($lang:ident => $names:expr) => {
    impl Alias for $lang {
      const ALIAS: &'static [&'static str] = $names;
    }

    impl fmt::Display for $lang {
      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
      }
    }

    impl<'de> Deserialize<'de> for $lang {
      fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Only validity matters here: the alias that matched is discarded.
        deserializer
          .deserialize_str(AliasVisitor {
            aliases: <Self as Alias>::ALIAS,
          })
          .map(|_| $lang)
      }
    }

    impl From<$lang> for SupportLang {
      fn from(_: $lang) -> Self {
        SupportLang::$lang
      }
    }
  };
}
/// Applies `impl_alias!` to every `Lang => aliases` pair and generates the
/// `alias` lookup function that maps a [`SupportLang`] variant to the alias
/// list of the struct with the same name.
///
/// Because `alias` is an exhaustive `match`, a `SupportLang` variant without
/// an entry here is a compile error.
macro_rules! impl_aliases {
  ($($name:ident => $names:expr),* $(,)?) => {
    $(impl_alias!($name => $names);)*
    const fn alias(lang: SupportLang) -> &'static [&'static str] {
      match lang {
        $(SupportLang::$name => <$name as Alias>::ALIAS,)*
      }
    }
  };
}
200
201impl_lang_expando!(C, language_c, '๐');
204impl_lang_expando!(Cpp, language_cpp, '๐');
205impl_lang_expando!(CSharp, language_c_sharp, 'ยต');
209impl_lang_expando!(Css, language_css, '_');
211impl_lang_expando!(Elixir, language_elixir, 'ยต');
213impl_lang_expando!(Go, language_go, 'ยต');
216impl_lang_expando!(Haskell, language_haskell, 'ยต');
220impl_lang_expando!(Hcl, language_hcl, 'ยต');
222impl_lang_expando!(Kotlin, language_kotlin, 'ยต');
224impl_lang_expando!(Nix, language_nix, '_');
226impl_lang_expando!(Php, language_php, 'ยต');
228impl_lang_expando!(Python, language_python, 'ยต');
232impl_lang_expando!(Ruby, language_ruby, 'ยต');
234impl_lang_expando!(Rust, language_rust, 'ยต');
237impl_lang_expando!(Swift, language_swift, 'ยต');
239
240impl_lang!(Bash, language_bash);
243impl_lang!(Java, language_java);
244impl_lang!(JavaScript, language_javascript);
245impl_lang!(Json, language_json);
246impl_lang!(Lua, language_lua);
247impl_lang!(Scala, language_scala);
248impl_lang!(Solidity, language_solidity);
249impl_lang!(Tsx, language_tsx);
250impl_lang!(TypeScript, language_typescript);
251impl_lang!(Dart, language_dart);
252impl_lang!(Yaml, language_yaml);
253#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Hash)]
258pub enum SupportLang {
259 Bash,
260 C,
261 Cpp,
262 CSharp,
263 Css,
264 Dart,
265 Go,
266 Elixir,
267 Haskell,
268 Hcl,
269 Html,
270 Java,
271 JavaScript,
272 Json,
273 Kotlin,
274 Lua,
275 Nix,
276 Php,
277 Python,
278 Ruby,
279 Rust,
280 Scala,
281 Solidity,
282 Swift,
283 Tsx,
284 TypeScript,
285 Yaml,
286}
287
288impl SupportLang {
289 pub const fn all_langs() -> &'static [SupportLang] {
290 use SupportLang::*;
291 &[
292 Bash, C, Cpp, CSharp, Css, Dart, Elixir, Go, Haskell, Hcl, Html, Java, JavaScript, Json,
293 Kotlin, Lua, Nix, Php, Python, Ruby, Rust, Scala, Solidity, Swift, Tsx, TypeScript, Yaml,
294 ]
295 }
296
297 pub fn file_types(&self) -> Types {
298 file_types(*self)
299 }
300}
301
302impl fmt::Display for SupportLang {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 write!(f, "{self:?}")
305 }
306}
307
/// Error returned when a language name cannot be resolved.
#[derive(Debug)]
pub enum SupportLangErr {
  /// No supported language goes by the contained name.
  LanguageNotSupported(String),
}

impl Display for SupportLangErr {
  fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
    // Single-variant enum: an irrefutable pattern is enough, and adding a
    // variant later turns this line into a compile error.
    let Self::LanguageNotSupported(lang) = self;
    write!(f, "{lang} is not supported!")
  }
}

impl std::error::Error for SupportLangErr {}
323
324impl<'de> Deserialize<'de> for SupportLang {
325 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
326 where
327 D: Deserializer<'de>,
328 {
329 deserializer.deserialize_str(SupportLangVisitor)
330 }
331}
332
333struct SupportLangVisitor;
334
335impl Visitor<'_> for SupportLangVisitor {
336 type Value = SupportLang;
337
338 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
339 f.write_str("SupportLang")
340 }
341
342 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
343 where
344 E: de::Error,
345 {
346 v.parse().map_err(de::Error::custom)
347 }
348}
349struct AliasVisitor {
350 aliases: &'static [&'static str],
351}
352
353impl Visitor<'_> for AliasVisitor {
354 type Value = &'static str;
355
356 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
357 write!(f, "one of {:?}", self.aliases)
358 }
359
360 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
361 where
362 E: de::Error,
363 {
364 self
365 .aliases
366 .iter()
367 .copied()
368 .find(|&a| v.eq_ignore_ascii_case(a))
369 .ok_or_else(|| de::Error::invalid_value(de::Unexpected::Str(v), &self))
370 }
371}
372
373impl_aliases! {
374 Bash => &["bash"],
375 C => &["c"],
376 Cpp => &["cc", "c++", "cpp", "cxx"],
377 CSharp => &["cs", "csharp"],
378 Css => &["css"],
379 Dart => &["dart"],
380 Elixir => &["ex", "elixir"],
381 Go => &["go", "golang"],
382 Haskell => &["hs", "haskell"],
383 Hcl => &["hcl"],
384 Html => &["html"],
385 Java => &["java"],
386 JavaScript => &["javascript", "js", "jsx"],
387 Json => &["json"],
388 Kotlin => &["kotlin", "kt"],
389 Lua => &["lua"],
390 Nix => &["nix"],
391 Php => &["php"],
392 Python => &["py", "python"],
393 Ruby => &["rb", "ruby"],
394 Rust => &["rs", "rust"],
395 Scala => &["scala"],
396 Solidity => &["sol", "solidity"],
397 Swift => &["swift"],
398 TypeScript => &["ts", "typescript"],
399 Tsx => &["tsx"],
400 Yaml => &["yaml", "yml"],
401}
402
403impl FromStr for SupportLang {
405 type Err = SupportLangErr;
406 fn from_str(s: &str) -> Result<Self, Self::Err> {
407 for &lang in Self::all_langs() {
408 for moniker in alias(lang) {
409 if s.eq_ignore_ascii_case(moniker) {
410 return Ok(lang);
411 }
412 }
413 }
414 Err(SupportLangErr::LanguageNotSupported(s.to_string()))
415 }
416}
417
/// Dispatches `$method(args…)` on a [`SupportLang`] value `$me` to the unit
/// struct named after the matching variant.
///
/// The `match` is exhaustive, so a new `SupportLang` variant must get an arm
/// here before the crate compiles again.
macro_rules! execute_lang_method {
  ($me: path, $method: ident, $($arg:tt),*) => {
    match $me {
      SupportLang::Bash => Bash.$method($($arg),*),
      SupportLang::C => C.$method($($arg),*),
      SupportLang::Cpp => Cpp.$method($($arg),*),
      SupportLang::CSharp => CSharp.$method($($arg),*),
      SupportLang::Css => Css.$method($($arg),*),
      SupportLang::Dart => Dart.$method($($arg),*),
      SupportLang::Elixir => Elixir.$method($($arg),*),
      SupportLang::Go => Go.$method($($arg),*),
      SupportLang::Haskell => Haskell.$method($($arg),*),
      SupportLang::Hcl => Hcl.$method($($arg),*),
      SupportLang::Html => Html.$method($($arg),*),
      SupportLang::Java => Java.$method($($arg),*),
      SupportLang::JavaScript => JavaScript.$method($($arg),*),
      SupportLang::Json => Json.$method($($arg),*),
      SupportLang::Kotlin => Kotlin.$method($($arg),*),
      SupportLang::Lua => Lua.$method($($arg),*),
      SupportLang::Nix => Nix.$method($($arg),*),
      SupportLang::Php => Php.$method($($arg),*),
      SupportLang::Python => Python.$method($($arg),*),
      SupportLang::Ruby => Ruby.$method($($arg),*),
      SupportLang::Rust => Rust.$method($($arg),*),
      SupportLang::Scala => Scala.$method($($arg),*),
      SupportLang::Solidity => Solidity.$method($($arg),*),
      SupportLang::Swift => Swift.$method($($arg),*),
      SupportLang::Tsx => Tsx.$method($($arg),*),
      SupportLang::TypeScript => TypeScript.$method($($arg),*),
      SupportLang::Yaml => Yaml.$method($($arg),*),
    }
  }
}
452
/// Generates a `&self` method named `$method` with the given parameter list
/// and return type whose body forwards to the per-language implementation
/// through `execute_lang_method!`.
///
/// The pattern has no slot for generics or lifetimes, so methods that need
/// them are written out by hand.
macro_rules! impl_lang_method {
  ($method: ident, ($($arg:tt: $arg_ty:ty),*) => $ret: ty) => {
    #[inline]
    fn $method(&self, $($arg: $arg_ty),*) -> $ret {
      execute_lang_method!{ self, $method, $($arg),* }
    }
  };
}
461impl Language for SupportLang {
462 impl_lang_method!(kind_to_id, (kind: &str) => u16);
463 impl_lang_method!(field_to_id, (field: &str) => Option<u16>);
464 impl_lang_method!(meta_var_char, () => char);
465 impl_lang_method!(expando_char, () => char);
466 impl_lang_method!(extract_meta_var, (source: &str) => Option<MetaVariable>);
467 impl_lang_method!(build_pattern, (builder: &PatternBuilder) => Result<Pattern, PatternError>);
468 fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
469 execute_lang_method! { self, pre_process_pattern, query }
470 }
471 fn from_path<P: AsRef<Path>>(path: P) -> Option<Self> {
472 from_extension(path.as_ref())
473 }
474}
475
476impl LanguageExt for SupportLang {
477 impl_lang_method!(get_ts_language, () => TSLanguage);
478 impl_lang_method!(injectable_languages, () => Option<&'static [&'static str]>);
479 fn extract_injections<L: LanguageExt>(
480 &self,
481 root: Node<StrDoc<L>>,
482 ) -> Vec<(String, Vec<TSRange>)> {
483 match self {
484 SupportLang::Html => Html.extract_injections(root),
485 _ => Vec::new(),
486 }
487 }
488}
489
490fn extensions(lang: SupportLang) -> &'static [&'static str] {
491 use SupportLang::*;
492 match lang {
493 Bash => &[
494 "bash", "bats", "cgi", "command", "env", "fcgi", "ksh", "sh", "tmux", "tool", "zsh",
495 ],
496 C => &["c", "h"],
497 Cpp => &["cc", "hpp", "cpp", "c++", "hh", "cxx", "cu", "ino"],
498 CSharp => &["cs"],
499 Css => &["css", "scss"],
500 Dart => &["dart"],
501 Elixir => &["ex", "exs"],
502 Go => &["go"],
503 Haskell => &["hs"],
504 Hcl => &["hcl", "nomad", "tf", "tfvars", "workflow"],
505 Html => &["html", "htm", "xhtml"],
506 Java => &["java"],
507 JavaScript => &["cjs", "js", "mjs", "jsx"],
508 Json => &["json"],
509 Kotlin => &["kt", "ktm", "kts"],
510 Lua => &["lua"],
511 Nix => &["nix"],
512 Php => &["php"],
513 Python => &["py", "py3", "pyi", "bzl"],
514 Ruby => &["rb", "rbw", "gemspec"],
515 Rust => &["rs"],
516 Scala => &["scala", "sc", "sbt"],
517 Solidity => &["sol"],
518 Swift => &["swift"],
519 TypeScript => &["ts", "cts", "mts"],
520 Tsx => &["tsx"],
521 Yaml => &["yaml", "yml"],
522 }
523}
524
525fn from_extension(path: &Path) -> Option<SupportLang> {
529 let ext = path.extension()?.to_str()?;
530 SupportLang::all_langs()
531 .iter()
532 .copied()
533 .find(|&l| extensions(l).contains(&ext))
534}
535
536fn add_custom_file_type<'b>(
537 builder: &'b mut TypesBuilder,
538 file_type: &str,
539 suffix_list: &[&str],
540) -> &'b mut TypesBuilder {
541 for suffix in suffix_list {
542 let glob = format!("*.{suffix}");
543 builder
544 .add(file_type, &glob)
545 .expect("file pattern must compile");
546 }
547 builder.select(file_type)
548}
549
550fn file_types(lang: SupportLang) -> Types {
551 let mut builder = TypesBuilder::new();
552 let exts = extensions(lang);
553 let lang_name = lang.to_string();
554 add_custom_file_type(&mut builder, &lang_name, exts);
555 builder.build().expect("file type must be valid")
556}
557
558pub fn config_file_type() -> Types {
559 let mut builder = TypesBuilder::new();
560 let builder = add_custom_file_type(&mut builder, "yml", &["yml", "yaml"]);
561 builder.build().expect("yaml type must be valid")
562}
563
#[cfg(test)]
mod test {
  use super::*;
  use ast_grep_core::{matcher::MatcherExt, Pattern};

  /// Asserts that the pattern `query` matches somewhere in `source` when
  /// both are parsed as `lang`.
  pub fn test_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let cand = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let found = pattern.find_node(cand.root());
    assert!(
      found.is_some(),
      "goal: {pattern:?}, candidate: {}",
      cand.root().get_inner_node().to_sexp(),
    );
  }

  /// Asserts that the pattern `query` matches nowhere in `source` when both
  /// are parsed as `lang`.
  pub fn test_non_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let cand = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let found = pattern.find_node(cand.root());
    assert!(
      found.is_none(),
      "goal: {pattern:?}, candidate: {}",
      cand.root().get_inner_node().to_sexp(),
    );
  }

  /// Applies one `pattern` -> `replacer` rewrite to `src` and returns the
  /// regenerated source. Asserts that a replacement actually took place.
  pub fn test_replace_lang(
    src: &str,
    pattern: &str,
    replacer: &str,
    lang: impl LanguageExt,
  ) -> String {
    let mut source = lang.ast_grep(src);
    let replaced = source
      .replace(pattern, replacer)
      .expect("should parse successfully");
    assert!(replaced);
    source.generate()
  }

  #[test]
  fn test_js_string() {
    let cases = [("'a'", "'a'"), ("\"\"", "\"\""), ("''", "''")];
    for (query, source) in cases {
      test_match_lang(query, source, JavaScript);
    }
  }

  #[test]
  fn test_guess_by_extension() {
    let guessed = from_extension(Path::new("foo.rs"));
    assert_eq!(guessed, Some(SupportLang::Rust));
  }
}