1mod bash;
22mod cpp;
23mod csharp;
24mod css;
25mod elixir;
26mod go;
27mod haskell;
28mod hcl;
29mod html;
30mod json;
31mod kotlin;
32mod lua;
33mod nix;
34mod parsers;
35mod php;
36mod python;
37mod ruby;
38mod rust;
39mod scala;
40mod solidity;
41mod swift;
42mod yaml;
43
44use ast_grep_core::matcher::{Pattern, PatternBuilder, PatternError};
45pub use html::Html;
46
47use ast_grep_core::meta_var::MetaVariable;
48use ast_grep_core::tree_sitter::{StrDoc, TSLanguage, TSRange};
49use ast_grep_core::Node;
50use ignore::types::{Types, TypesBuilder};
51use serde::de::Visitor;
52use serde::{de, Deserialize, Deserializer, Serialize};
53use std::borrow::Cow;
54use std::collections::HashMap;
55use std::fmt;
56use std::fmt::{Display, Formatter};
57use std::iter::repeat;
58use std::path::Path;
59use std::str::FromStr;
60
61pub use ast_grep_core::language::Language;
62pub use ast_grep_core::tree_sitter::LanguageExt;
63
/// Declares the unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it, backed by the grammar returned from
/// `parsers::$func`.
///
/// Only `kind_to_id`, `field_to_id` and `build_pattern` are generated here;
/// every other trait method keeps whatever default the trait provides. Use
/// `impl_lang_expando!` for languages whose patterns must be rewritten before
/// parsing.
macro_rules! impl_lang {
  ($lang: ident, $func: ident) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self
          .get_ts_language()
          .id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` derives `Copy`, so a plain dereference replaces `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        // `$crate::` anchors the path at the crate root, matching
        // `impl_lang_expando!`, so the expansion does not depend on the
        // module it is invoked from.
        $crate::parsers::$func().into()
      }
    }
  };
}
92
/// Rewrites the `$` sigils of a pattern into `expando`.
///
/// A run of consecutive `$` characters is replaced by the same number of
/// `expando` characters when
/// * the run is directly followed by `A..=Z` or `_` (`$A`, `$$A`, `$$$A`), or
/// * the run is exactly three `$` long (`$$$`), including at the very end of
///   the query.
///
/// Every other `$` run is copied through unchanged.
///
/// Returns `Cow::Borrowed(query)` when the query contains no `$` at all, so
/// the common "nothing to rewrite" case does not allocate.
fn pre_process_pattern(expando: char, query: &str) -> std::borrow::Cow<'_, str> {
  // Fast path: without a `$` the output is byte-for-byte the input.
  if !query.contains('$') {
    return std::borrow::Cow::Borrowed(query);
  }
  let mut ret = String::with_capacity(query.len());
  // Number of `$` seen since the last non-`$` character; they are emitted
  // lazily because the sigil depends on the character that follows the run.
  let mut dollar_count = 0;
  for c in query.chars() {
    if c == '$' {
      dollar_count += 1;
      continue;
    }
    let need_replace = matches!(c, 'A'..='Z' | '_') || dollar_count == 3;
    let sigil = if need_replace { expando } else { '$' };
    ret.extend(repeat(sigil).take(dollar_count));
    dollar_count = 0;
    ret.push(c);
  }
  // Flush a trailing run: only an exact `$$$` is rewritten at the end.
  let sigil = if dollar_count == 3 { expando } else { '$' };
  ret.extend(repeat(sigil).take(dollar_count));
  std::borrow::Cow::Owned(ret)
}
113
/// Declares the unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it, like `impl_lang!`, but with a custom expando
/// character.
///
/// `$char` is returned from `expando_char`, and `pre_process_pattern` rewrites
/// the `$` sigils of a query into that character through the free function
/// `pre_process_pattern` before the pattern is built.
macro_rules! impl_lang_expando {
  ($lang: ident, $func: ident, $char: expr) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self
          .get_ts_language()
          .id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn expando_char(&self) -> char {
        $char
      }
      fn pre_process_pattern<'q>(&self, query: &'q str) -> std::borrow::Cow<'q, str> {
        pre_process_pattern(self.expando_char(), query)
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` derives `Copy`, so a plain dereference replaces `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        $crate::parsers::$func().into()
      }
    }
  };
}
149
/// A language type that can be referred to by one or more textual names.
///
/// Implementors must also be [`Display`]. The names in [`Alias::ALIAS`] are
/// the ones accepted when the language is deserialized or parsed from a
/// string; both paths compare ASCII case-insensitively.
pub trait Alias: Display {
  /// Every name this language answers to.
  const ALIAS: &'static [&'static str];
}
153
/// Wires a single language struct into the alias machinery.
///
/// For `$lang => $as` this generates:
/// * `Alias` with `ALIAS = $as`,
/// * `Display`, printing the `Debug` representation of the struct,
/// * `Deserialize`, accepting any string that `AliasVisitor` matches against
///   the aliases,
/// * `From<$lang> for SupportLang`, mapping to the variant of the same name.
macro_rules! impl_alias {
  ($lang:ident => $as:expr) => {
    impl Alias for $lang {
      const ALIAS: &'static [&'static str] = $as;
    }

    impl fmt::Display for $lang {
      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
      }
    }

    impl<'de> Deserialize<'de> for $lang {
      fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      where
        D: Deserializer<'de>,
      {
        // Which alias matched is irrelevant; only success matters.
        deserializer
          .deserialize_str(AliasVisitor {
            aliases: <$lang as Alias>::ALIAS,
          })
          .map(|_matched| $lang)
      }
    }

    impl From<$lang> for SupportLang {
      fn from(_: $lang) -> Self {
        SupportLang::$lang
      }
    }
  };
}
/// Applies `impl_alias!` to every `Lang => aliases` pair and generates the
/// `alias` lookup function that maps a `SupportLang` variant to the alias
/// list of the struct with the same name.
macro_rules! impl_aliases {
  ($($lang:ident => $as:expr),* $(,)?) => {
    $(impl_alias!($lang => $as);)*
    const fn alias(lang: SupportLang) -> &'static [&'static str] {
      match lang {
        $(SupportLang::$lang => <$lang as Alias>::ALIAS,)*
      }
    }
  };
}
200
201impl_lang_expando!(C, language_c, '๐');
204impl_lang_expando!(Cpp, language_cpp, '๐');
205impl_lang_expando!(CSharp, language_c_sharp, 'ยต');
209impl_lang_expando!(Css, language_css, '_');
211impl_lang_expando!(Elixir, language_elixir, 'ยต');
213impl_lang_expando!(Go, language_go, 'ยต');
216impl_lang_expando!(Haskell, language_haskell, 'ยต');
220impl_lang_expando!(Hcl, language_hcl, 'ยต');
222impl_lang_expando!(Kotlin, language_kotlin, 'ยต');
224impl_lang_expando!(Nix, language_nix, '_');
226impl_lang_expando!(Php, language_php, 'ยต');
228impl_lang_expando!(Python, language_python, 'ยต');
232impl_lang_expando!(Ruby, language_ruby, 'ยต');
234impl_lang_expando!(Rust, language_rust, 'ยต');
237impl_lang_expando!(Swift, language_swift, 'ยต');
239
240impl_lang!(Bash, language_bash);
243impl_lang!(Java, language_java);
244impl_lang!(JavaScript, language_javascript);
245impl_lang!(Json, language_json);
246impl_lang!(Lua, language_lua);
247impl_lang!(Scala, language_scala);
248impl_lang!(Solidity, language_solidity);
249impl_lang!(Tsx, language_tsx);
250impl_lang!(TypeScript, language_typescript);
251impl_lang!(Yaml, language_yaml);
252#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Hash)]
257pub enum SupportLang {
258 Bash,
259 C,
260 Cpp,
261 CSharp,
262 Css,
263 Go,
264 Elixir,
265 Haskell,
266 Hcl,
267 Html,
268 Java,
269 JavaScript,
270 Json,
271 Kotlin,
272 Lua,
273 Nix,
274 Php,
275 Python,
276 Ruby,
277 Rust,
278 Scala,
279 Solidity,
280 Swift,
281 Tsx,
282 TypeScript,
283 Yaml,
284}
285
286impl SupportLang {
287 pub const fn all_langs() -> &'static [SupportLang] {
288 use SupportLang::*;
289 &[
290 Bash, C, Cpp, CSharp, Css, Elixir, Go, Haskell, Hcl, Html, Java, JavaScript, Json, Kotlin,
291 Lua, Nix, Php, Python, Ruby, Rust, Scala, Solidity, Swift, Tsx, TypeScript, Yaml,
292 ]
293 }
294
295 pub fn file_types(&self) -> Types {
296 file_types(*self)
297 }
298}
299
300impl fmt::Display for SupportLang {
301 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
302 write!(f, "{self:?}")
303 }
304}
305
/// Error returned when a language name cannot be resolved to a `SupportLang`.
#[derive(Debug)]
pub enum SupportLangErr {
  /// Carries the name that was requested but not recognized.
  LanguageNotSupported(String),
}

impl Display for SupportLangErr {
  fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
    // The enum has a single variant, so an irrefutable binding is enough;
    // adding a variant later turns this line into a compile error.
    let Self::LanguageNotSupported(lang) = self;
    write!(f, "{lang} is not supported!")
  }
}

impl std::error::Error for SupportLangErr {}
321
322impl<'de> Deserialize<'de> for SupportLang {
323 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
324 where
325 D: Deserializer<'de>,
326 {
327 deserializer.deserialize_str(SupportLangVisitor)
328 }
329}
330
331struct SupportLangVisitor;
332
333impl Visitor<'_> for SupportLangVisitor {
334 type Value = SupportLang;
335
336 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
337 f.write_str("SupportLang")
338 }
339
340 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
341 where
342 E: de::Error,
343 {
344 v.parse().map_err(de::Error::custom)
345 }
346}
347struct AliasVisitor {
348 aliases: &'static [&'static str],
349}
350
351impl Visitor<'_> for AliasVisitor {
352 type Value = &'static str;
353
354 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
355 write!(f, "one of {:?}", self.aliases)
356 }
357
358 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
359 where
360 E: de::Error,
361 {
362 self
363 .aliases
364 .iter()
365 .copied()
366 .find(|&a| v.eq_ignore_ascii_case(a))
367 .ok_or_else(|| de::Error::invalid_value(de::Unexpected::Str(v), &self))
368 }
369}
370
371impl_aliases! {
372 Bash => &["bash"],
373 C => &["c"],
374 Cpp => &["cc", "c++", "cpp", "cxx"],
375 CSharp => &["cs", "csharp"],
376 Css => &["css"],
377 Elixir => &["ex", "elixir"],
378 Go => &["go", "golang"],
379 Haskell => &["hs", "haskell"],
380 Hcl => &["hcl"],
381 Html => &["html"],
382 Java => &["java"],
383 JavaScript => &["javascript", "js", "jsx"],
384 Json => &["json"],
385 Kotlin => &["kotlin", "kt"],
386 Lua => &["lua"],
387 Nix => &["nix"],
388 Php => &["php"],
389 Python => &["py", "python"],
390 Ruby => &["rb", "ruby"],
391 Rust => &["rs", "rust"],
392 Scala => &["scala"],
393 Solidity => &["sol", "solidity"],
394 Swift => &["swift"],
395 TypeScript => &["ts", "typescript"],
396 Tsx => &["tsx"],
397 Yaml => &["yaml", "yml"],
398}
399
400impl FromStr for SupportLang {
402 type Err = SupportLangErr;
403 fn from_str(s: &str) -> Result<Self, Self::Err> {
404 for &lang in Self::all_langs() {
405 for moniker in alias(lang) {
406 if s.eq_ignore_ascii_case(moniker) {
407 return Ok(lang);
408 }
409 }
410 }
411 Err(SupportLangErr::LanguageNotSupported(s.to_string()))
412 }
413}
414
/// Expands to a `match` over a `SupportLang` value that calls `$method` with
/// the given arguments on the unit struct named after the matched variant.
///
/// The match is exhaustive on purpose: a new `SupportLang` variant fails to
/// compile until it gets an arm here.
macro_rules! execute_lang_method {
  ($target: path, $method: ident, $($arg:tt),*) => {
    match $target {
      SupportLang::Bash => Bash.$method($($arg,)*),
      SupportLang::C => C.$method($($arg,)*),
      SupportLang::Cpp => Cpp.$method($($arg,)*),
      SupportLang::CSharp => CSharp.$method($($arg,)*),
      SupportLang::Css => Css.$method($($arg,)*),
      SupportLang::Elixir => Elixir.$method($($arg,)*),
      SupportLang::Go => Go.$method($($arg,)*),
      SupportLang::Haskell => Haskell.$method($($arg,)*),
      SupportLang::Hcl => Hcl.$method($($arg,)*),
      SupportLang::Html => Html.$method($($arg,)*),
      SupportLang::Java => Java.$method($($arg,)*),
      SupportLang::JavaScript => JavaScript.$method($($arg,)*),
      SupportLang::Json => Json.$method($($arg,)*),
      SupportLang::Kotlin => Kotlin.$method($($arg,)*),
      SupportLang::Lua => Lua.$method($($arg,)*),
      SupportLang::Nix => Nix.$method($($arg,)*),
      SupportLang::Php => Php.$method($($arg,)*),
      SupportLang::Python => Python.$method($($arg,)*),
      SupportLang::Ruby => Ruby.$method($($arg,)*),
      SupportLang::Rust => Rust.$method($($arg,)*),
      SupportLang::Scala => Scala.$method($($arg,)*),
      SupportLang::Solidity => Solidity.$method($($arg,)*),
      SupportLang::Swift => Swift.$method($($arg,)*),
      SupportLang::Tsx => Tsx.$method($($arg,)*),
      SupportLang::TypeScript => TypeScript.$method($($arg,)*),
      SupportLang::Yaml => Yaml.$method($($arg,)*),
    }
  }
}
448
/// Generates an `#[inline]` method `$method(&self, args...) -> $ret` whose
/// body forwards to the concrete language through `execute_lang_method!`.
macro_rules! impl_lang_method {
  ($method: ident, ($($arg:tt: $arg_ty:ty),*) => $ret: ty) => {
    #[inline]
    fn $method(&self, $($arg: $arg_ty),*) -> $ret {
      execute_lang_method!{ self, $method, $($arg),* }
    }
  };
}
457impl Language for SupportLang {
458 impl_lang_method!(kind_to_id, (kind: &str) => u16);
459 impl_lang_method!(field_to_id, (field: &str) => Option<u16>);
460 impl_lang_method!(meta_var_char, () => char);
461 impl_lang_method!(expando_char, () => char);
462 impl_lang_method!(extract_meta_var, (source: &str) => Option<MetaVariable>);
463 impl_lang_method!(build_pattern, (builder: &PatternBuilder) => Result<Pattern, PatternError>);
464 fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
465 execute_lang_method! { self, pre_process_pattern, query }
466 }
467 fn from_path<P: AsRef<Path>>(path: P) -> Option<Self> {
468 from_extension(path.as_ref())
469 }
470}
471
472impl LanguageExt for SupportLang {
473 impl_lang_method!(get_ts_language, () => TSLanguage);
474 impl_lang_method!(injectable_languages, () => Option<&'static [&'static str]>);
475 fn extract_injections<L: LanguageExt>(
476 &self,
477 root: Node<StrDoc<L>>,
478 ) -> HashMap<String, Vec<TSRange>> {
479 match self {
480 SupportLang::Html => Html.extract_injections(root),
481 _ => HashMap::new(),
482 }
483 }
484}
485
486fn extensions(lang: SupportLang) -> &'static [&'static str] {
487 use SupportLang::*;
488 match lang {
489 Bash => &[
490 "bash", "bats", "cgi", "command", "env", "fcgi", "ksh", "sh", "tmux", "tool", "zsh",
491 ],
492 C => &["c", "h"],
493 Cpp => &["cc", "hpp", "cpp", "c++", "hh", "cxx", "cu", "ino"],
494 CSharp => &["cs"],
495 Css => &["css", "scss"],
496 Elixir => &["ex", "exs"],
497 Go => &["go"],
498 Haskell => &["hs"],
499 Hcl => &["hcl"],
500 Html => &["html", "htm", "xhtml"],
501 Java => &["java"],
502 JavaScript => &["cjs", "js", "mjs", "jsx"],
503 Json => &["json"],
504 Kotlin => &["kt", "ktm", "kts"],
505 Lua => &["lua"],
506 Nix => &["nix"],
507 Php => &["php"],
508 Python => &["py", "py3", "pyi", "bzl"],
509 Ruby => &["rb", "rbw", "gemspec"],
510 Rust => &["rs"],
511 Scala => &["scala", "sc", "sbt"],
512 Solidity => &["sol"],
513 Swift => &["swift"],
514 TypeScript => &["ts", "cts", "mts"],
515 Tsx => &["tsx"],
516 Yaml => &["yaml", "yml"],
517 }
518}
519
520fn from_extension(path: &Path) -> Option<SupportLang> {
524 let ext = path.extension()?.to_str()?;
525 SupportLang::all_langs()
526 .iter()
527 .copied()
528 .find(|&l| extensions(l).contains(&ext))
529}
530
531fn add_custom_file_type<'b>(
532 builder: &'b mut TypesBuilder,
533 file_type: &str,
534 suffix_list: &[&str],
535) -> &'b mut TypesBuilder {
536 for suffix in suffix_list {
537 let glob = format!("*.{suffix}");
538 builder
539 .add(file_type, &glob)
540 .expect("file pattern must compile");
541 }
542 builder.select(file_type)
543}
544
545fn file_types(lang: SupportLang) -> Types {
546 let mut builder = TypesBuilder::new();
547 let exts = extensions(lang);
548 let lang_name = lang.to_string();
549 add_custom_file_type(&mut builder, &lang_name, exts);
550 builder.build().expect("file type must be valid")
551}
552
553pub fn config_file_type() -> Types {
554 let mut builder = TypesBuilder::new();
555 let builder = add_custom_file_type(&mut builder, "yml", &["yml", "yaml"]);
556 builder.build().expect("yaml type must be valid")
557}
558
#[cfg(test)]
mod test {
  use super::*;
  use ast_grep_core::{matcher::MatcherExt, Pattern};

  /// Asserts that `query` matches somewhere in `source` when both are parsed
  /// as `lang`.
  pub fn test_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let cand = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let found = pattern.find_node(cand.root());
    assert!(
      found.is_some(),
      "goal: {pattern:?}, candidate: {}",
      cand.root().get_inner_node().to_sexp(),
    );
  }

  /// Asserts that `query` matches nowhere in `source` when both are parsed
  /// as `lang`.
  pub fn test_non_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let cand = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let found = pattern.find_node(cand.root());
    assert!(
      found.is_none(),
      "goal: {pattern:?}, candidate: {}",
      cand.root().get_inner_node().to_sexp(),
    );
  }

  /// Applies the `pattern` -> `replacer` rewrite to `src`, asserts that the
  /// replacement reported a change, and returns the regenerated source.
  pub fn test_replace_lang(
    src: &str,
    pattern: &str,
    replacer: &str,
    lang: impl LanguageExt,
  ) -> String {
    let mut source = lang.ast_grep(src);
    let changed = source
      .replace(pattern, replacer)
      .expect("should parse successfully");
    assert!(changed);
    source.generate()
  }

  #[test]
  fn test_js_string() {
    test_match_lang("'a'", "'a'", JavaScript);
    test_match_lang("\"\"", "\"\"", JavaScript);
    test_match_lang("''", "''", JavaScript);
  }

  #[test]
  fn test_guess_by_extension() {
    let guessed = from_extension(Path::new("foo.rs"));
    assert_eq!(guessed, Some(SupportLang::Rust));
  }
}