1mod bash;
11mod cpp;
12mod csharp;
13mod css;
14mod elixir;
15mod go;
16mod haskell;
17mod hcl;
18mod html;
19mod json;
20mod kotlin;
21mod lua;
22mod nix;
23mod parsers;
24mod php;
25mod python;
26mod ruby;
27mod rust;
28mod scala;
29mod solidity;
30mod swift;
31mod yaml;
32
33use ast_grep_core::matcher::{Pattern, PatternBuilder, PatternError};
34pub use html::Html;
35
36use ast_grep_core::meta_var::MetaVariable;
37use ast_grep_core::tree_sitter::{StrDoc, TSLanguage, TSRange};
38use ast_grep_core::Node;
39use ignore::types::{Types, TypesBuilder};
40use serde::de::Visitor;
41use serde::{de, Deserialize, Deserializer, Serialize};
42use std::borrow::Cow;
43use std::collections::HashMap;
44use std::fmt;
45use std::fmt::{Display, Formatter};
46use std::iter::repeat;
47use std::path::Path;
48use std::str::FromStr;
49
50pub use ast_grep_core::language::Language;
51pub use ast_grep_core::tree_sitter::LanguageExt;
52
// Declares a zero-sized language type `$lang` backed by the tree-sitter grammar
// returned from `parsers::$func()`. It does not override `expando_char` or
// `pre_process_pattern`, so the `Language` trait's defaults apply to patterns.
macro_rules! impl_lang {
    ($lang: ident, $func: ident) => {
        #[derive(Clone, Copy, Debug)]
        pub struct $lang;

        impl LanguageExt for $lang {
            fn get_ts_language(&self) -> TSLanguage {
                let grammar = $crate::parsers::$func();
                grammar.into()
            }
        }

        impl Language for $lang {
            fn kind_to_id(&self, kind: &str) -> u16 {
                let ts_lang = self.get_ts_language();
                ts_lang.id_for_node_kind(kind, true)
            }
            fn field_to_id(&self, field: &str) -> Option<u16> {
                let ts_lang = self.get_ts_language();
                ts_lang.field_id_for_name(field).map(|id| id.get())
            }
            fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
                // `$lang` is `Copy`, so a plain dereference replaces `clone()`.
                let lang = *self;
                builder.build(move |src| StrDoc::try_new(src, lang))
            }
        }
    };
}
81
/// Rewrites the `$` sigils of metavariables in `query` to `expando`.
///
/// Rules, applied to each maximal run of `$` characters:
/// * a run followed by an ASCII uppercase letter or `_` (`$A`, `$$B`, `$$$ARGS`, `$_`)
///   has every `$` in the run replaced by `expando`;
/// * a run of exactly three `$` (`$$$`) is replaced regardless of what follows,
///   including when it ends the input;
/// * every other `$` is kept verbatim (e.g. `$a`, a trailing `$$`).
///
/// Returns `Cow::Borrowed(query)` without allocating when `query` contains no `$`.
/// Otherwise it returns an owned, rewritten copy.
fn pre_process_pattern(expando: char, query: &str) -> Cow<'_, str> {
    // Nothing to rewrite: avoid allocating entirely.
    if !query.contains('$') {
        return Cow::Borrowed(query);
    }
    // Build the `String` directly. The old version collected into a `Vec<char>`
    // (4 bytes per element, sized by byte length) and then copied into a `String`.
    let mut ret = String::with_capacity(query.len());
    let mut dollar_count = 0;
    for c in query.chars() {
        if c == '$' {
            dollar_count += 1;
            continue;
        }
        // `$A` / `$$A` / `$$$A` introduce a named metavariable;
        // `$$$` on its own is the anonymous multi-match.
        let need_replace = matches!(c, 'A'..='Z' | '_') || dollar_count == 3;
        let sigil = if need_replace { expando } else { '$' };
        ret.extend(repeat(sigil).take(dollar_count));
        dollar_count = 0;
        ret.push(c);
    }
    // A `$$$` at the very end of the query is still an anonymous multi-match.
    let sigil = if dollar_count == 3 { expando } else { '$' };
    ret.extend(repeat(sigil).take(dollar_count));
    Cow::Owned(ret)
}
102
// Declares a zero-sized language type `$lang` backed by `parsers::$func()`
// whose patterns are rewritten before parsing. Each metavariable's `$` sigil
// is replaced by `$char` (the language's expando character) through the
// module-level `pre_process_pattern` helper.
macro_rules! impl_lang_expando {
    ($lang: ident, $func: ident, $char: expr) => {
        #[derive(Clone, Copy, Debug)]
        pub struct $lang;

        impl LanguageExt for $lang {
            fn get_ts_language(&self) -> TSLanguage {
                let grammar = $crate::parsers::$func();
                grammar.into()
            }
        }

        impl Language for $lang {
            fn kind_to_id(&self, kind: &str) -> u16 {
                let ts_lang = self.get_ts_language();
                ts_lang.id_for_node_kind(kind, true)
            }
            fn field_to_id(&self, field: &str) -> Option<u16> {
                let ts_lang = self.get_ts_language();
                ts_lang.field_id_for_name(field).map(|id| id.get())
            }
            fn expando_char(&self) -> char {
                $char
            }
            fn pre_process_pattern<'q>(&self, query: &'q str) -> std::borrow::Cow<'q, str> {
                let expando = self.expando_char();
                pre_process_pattern(expando, query)
            }
            fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
                // `$lang` is `Copy`, so a plain dereference replaces `clone()`.
                let lang = *self;
                builder.build(move |src| StrDoc::try_new(src, lang))
            }
        }
    };
}
138
/// Associates a language type with the names it can be parsed from.
///
/// `Display` is a supertrait, so every implementor can also be printed.
pub trait Alias: Display {
    /// Accepted spellings for this language. The deserializers and
    /// `FromStr for SupportLang` in this module compare them against input
    /// using ASCII case-insensitive equality.
    const ALIAS: &'static [&'static str];
}
142
// For a single language type `$lang`, generates:
// * `Alias` with the accepted names `$as`,
// * `Display`, printing the type's `Debug` name,
// * `Deserialize`, accepting any alias (ASCII case-insensitive),
// * `From<$lang> for SupportLang`, mapping to the same-named variant.
macro_rules! impl_alias {
    ($lang:ident => $as:expr) => {
        impl Alias for $lang {
            const ALIAS: &'static [&'static str] = $as;
        }

        impl fmt::Display for $lang {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_fmt(format_args!("{:?}", self))
            }
        }

        impl<'de> Deserialize<'de> for $lang {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: Deserializer<'de>,
            {
                let visitor = AliasVisitor {
                    aliases: <$lang as Alias>::ALIAS,
                };
                // The matched alias itself is irrelevant; success means "this language".
                deserializer.deserialize_str(visitor).map(|_matched| $lang)
            }
        }

        impl From<$lang> for SupportLang {
            fn from(_: $lang) -> Self {
                SupportLang::$lang
            }
        }
    };
}
// Applies `impl_alias!` to every `Lang => &[aliases]` pair. It also emits a
// `const fn alias` that maps each `SupportLang` variant to its alias list.
// Every `SupportLang` variant must be listed, or the generated `match` is
// non-exhaustive and fails to compile.
macro_rules! impl_aliases {
    ($($lang:ident => $as:expr),* $(,)?) => {
        $(
            impl_alias!($lang => $as);
        )*

        const fn alias(lang: SupportLang) -> &'static [&'static str] {
            match lang {
                $(SupportLang::$lang => <$lang as Alias>::ALIAS,)*
            }
        }
    };
}
188}
189
190impl_lang_expando!(C, language_c, '๐');
193impl_lang_expando!(Cpp, language_cpp, '๐');
194impl_lang_expando!(CSharp, language_c_sharp, 'ยต');
198impl_lang_expando!(Css, language_css, '_');
200impl_lang_expando!(Elixir, language_elixir, 'ยต');
202impl_lang_expando!(Go, language_go, 'ยต');
205impl_lang_expando!(Haskell, language_haskell, 'ยต');
209impl_lang_expando!(Hcl, language_hcl, 'ยต');
211impl_lang_expando!(Kotlin, language_kotlin, 'ยต');
213impl_lang_expando!(Nix, language_nix, '_');
215impl_lang_expando!(Php, language_php, 'ยต');
217impl_lang_expando!(Python, language_python, 'ยต');
221impl_lang_expando!(Ruby, language_ruby, 'ยต');
223impl_lang_expando!(Rust, language_rust, 'ยต');
226impl_lang_expando!(Swift, language_swift, 'ยต');
228
// Languages declared without an expando character. They do not override
// `expando_char` or `pre_process_pattern`, so the `Language` trait's default
// behavior applies to their patterns.
impl_lang!(Bash, language_bash);
impl_lang!(Java, language_java);
impl_lang!(JavaScript, language_javascript);
impl_lang!(Json, language_json);
impl_lang!(Lua, language_lua);
impl_lang!(Scala, language_scala);
impl_lang!(Solidity, language_solidity);
impl_lang!(Tsx, language_tsx);
impl_lang!(TypeScript, language_typescript);
impl_lang!(Yaml, language_yaml);
/// Runtime-selectable union of every language defined in this module.
///
/// Implements [`Language`] and [`LanguageExt`] by forwarding each call to the
/// matching zero-sized language type. `Display` prints the variant name, and
/// `Deserialize`/`FromStr` accept the ASCII case-insensitive aliases listed in
/// the `impl_aliases!` invocation.
// NOTE(review): variant order (e.g. `Go` before `Elixir`) differs from
// `all_langs()`. It is intentionally left unchanged: reordering would change
// the discriminants, which derived `Serialize` may use in index-based formats.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Hash)]
pub enum SupportLang {
    Bash,
    C,
    Cpp,
    CSharp,
    Css,
    Go,
    Elixir,
    Haskell,
    Hcl,
    Html,
    Java,
    JavaScript,
    Json,
    Kotlin,
    Lua,
    Nix,
    Php,
    Python,
    Ruby,
    Rust,
    Scala,
    Solidity,
    Swift,
    Tsx,
    TypeScript,
    Yaml,
}
274
275impl SupportLang {
276 pub const fn all_langs() -> &'static [SupportLang] {
277 use SupportLang::*;
278 &[
279 Bash, C, Cpp, CSharp, Css, Elixir, Go, Haskell, Hcl, Html, Java, JavaScript, Json, Kotlin,
280 Lua, Nix, Php, Python, Ruby, Rust, Scala, Solidity, Swift, Tsx, TypeScript, Yaml,
281 ]
282 }
283
284 pub fn file_types(&self) -> Types {
285 file_types(*self)
286 }
287}
288
289impl fmt::Display for SupportLang {
290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291 write!(f, "{self:?}")
292 }
293}
294
/// Error produced when a string does not name any supported language.
#[derive(Debug)]
pub enum SupportLangErr {
    /// Holds the rejected input exactly as given.
    LanguageNotSupported(String),
}

impl Display for SupportLangErr {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let SupportLangErr::LanguageNotSupported(lang) = self;
        f.write_fmt(format_args!("{lang} is not supported!"))
    }
}

impl std::error::Error for SupportLangErr {}
310
311impl<'de> Deserialize<'de> for SupportLang {
312 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
313 where
314 D: Deserializer<'de>,
315 {
316 deserializer.deserialize_str(SupportLangVisitor)
317 }
318}
319
320struct SupportLangVisitor;
321
322impl Visitor<'_> for SupportLangVisitor {
323 type Value = SupportLang;
324
325 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
326 f.write_str("SupportLang")
327 }
328
329 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
330 where
331 E: de::Error,
332 {
333 v.parse().map_err(de::Error::custom)
334 }
335}
336struct AliasVisitor {
337 aliases: &'static [&'static str],
338}
339
340impl Visitor<'_> for AliasVisitor {
341 type Value = &'static str;
342
343 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
344 write!(f, "one of {:?}", self.aliases)
345 }
346
347 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
348 where
349 E: de::Error,
350 {
351 self
352 .aliases
353 .iter()
354 .copied()
355 .find(|&a| v.eq_ignore_ascii_case(a))
356 .ok_or_else(|| de::Error::invalid_value(de::Unexpected::Str(v), &self))
357 }
358}
359
// Accepted names for each language, matched ASCII case-insensitively. Besides
// `Alias::ALIAS`, this generates `Display`, `Deserialize`, `From<_> for
// SupportLang`, and the `alias()` lookup that `FromStr for SupportLang` uses.
// Keep aliases unique across languages: `FromStr` returns the first language
// in `all_langs()` whose alias matches.
impl_aliases! {
    Bash => &["bash"],
    C => &["c"],
    Cpp => &["cc", "c++", "cpp", "cxx"],
    CSharp => &["cs", "csharp"],
    Css => &["css"],
    Elixir => &["ex", "elixir"],
    Go => &["go", "golang"],
    Haskell => &["hs", "haskell"],
    Hcl => &["hcl"],
    Html => &["html"],
    Java => &["java"],
    JavaScript => &["javascript", "js", "jsx"],
    Json => &["json"],
    Kotlin => &["kotlin", "kt"],
    Lua => &["lua"],
    Nix => &["nix"],
    Php => &["php"],
    Python => &["py", "python"],
    Ruby => &["rb", "ruby"],
    Rust => &["rs", "rust"],
    Scala => &["scala"],
    Solidity => &["sol", "solidity"],
    Swift => &["swift"],
    TypeScript => &["ts", "typescript"],
    Tsx => &["tsx"],
    Yaml => &["yaml", "yml"],
}
388
389impl FromStr for SupportLang {
391 type Err = SupportLangErr;
392 fn from_str(s: &str) -> Result<Self, Self::Err> {
393 for &lang in Self::all_langs() {
394 for moniker in alias(lang) {
395 if s.eq_ignore_ascii_case(moniker) {
396 return Ok(lang);
397 }
398 }
399 }
400 Err(SupportLangErr::LanguageNotSupported(s.to_string()))
401 }
402}
403
// Dispatches `$method` to the zero-sized language type that matches the
// `SupportLang` value `$me`, passing the `$pname` arguments through unchanged.
// Expands to a `use` item followed by a `match` expression, so it is meant to
// be the tail of a function body (as in `impl_lang_method!`).
// The `match` is exhaustive: adding a `SupportLang` variant without an arm
// here is a compile error.
macro_rules! execute_lang_method {
    ($me: path, $method: ident, $($pname:tt),*) => {
        use SupportLang as S;
        match $me {
            S::Bash => Bash.$method($($pname,)*),
            S::C => C.$method($($pname,)*),
            S::Cpp => Cpp.$method($($pname,)*),
            S::CSharp => CSharp.$method($($pname,)*),
            S::Css => Css.$method($($pname,)*),
            S::Elixir => Elixir.$method($($pname,)*),
            S::Go => Go.$method($($pname,)*),
            S::Haskell => Haskell.$method($($pname,)*),
            S::Hcl => Hcl.$method($($pname,)*),
            S::Html => Html.$method($($pname,)*),
            S::Java => Java.$method($($pname,)*),
            S::JavaScript => JavaScript.$method($($pname,)*),
            S::Json => Json.$method($($pname,)*),
            S::Kotlin => Kotlin.$method($($pname,)*),
            S::Lua => Lua.$method($($pname,)*),
            S::Nix => Nix.$method($($pname,)*),
            S::Php => Php.$method($($pname,)*),
            S::Python => Python.$method($($pname,)*),
            S::Ruby => Ruby.$method($($pname,)*),
            S::Rust => Rust.$method($($pname,)*),
            S::Scala => Scala.$method($($pname,)*),
            S::Solidity => Solidity.$method($($pname,)*),
            S::Swift => Swift.$method($($pname,)*),
            S::Tsx => Tsx.$method($($pname,)*),
            S::TypeScript => TypeScript.$method($($pname,)*),
            S::Yaml => Yaml.$method($($pname,)*),
        }
    }
}
437
// Generates an `#[inline]` method `$method(&self, params...) -> $return_type`
// whose body forwards every parameter to the concrete language selected by
// `self`, via `execute_lang_method!`.
macro_rules! impl_lang_method {
    ($method: ident, ($($pname:tt: $ptype:ty),*) => $return_type: ty) => {
        #[inline]
        fn $method(&self, $($pname: $ptype),*) -> $return_type {
            execute_lang_method!{ self, $method, $($pname),* }
        }
    };
}
// Each `Language` method forwards to the concrete language selected by `self`.
impl Language for SupportLang {
    impl_lang_method!(kind_to_id, (kind: &str) => u16);
    impl_lang_method!(field_to_id, (field: &str) => Option<u16>);
    impl_lang_method!(meta_var_char, () => char);
    impl_lang_method!(expando_char, () => char);
    impl_lang_method!(extract_meta_var, (source: &str) => Option<MetaVariable>);
    impl_lang_method!(build_pattern, (builder: &PatternBuilder) => Result<Pattern, PatternError>);
    // Written out by hand, presumably because `impl_lang_method!` cannot
    // declare the `'q` lifetime parameter this signature needs.
    fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
        execute_lang_method! { self, pre_process_pattern, query }
    }
    // Picks a language solely from the path's file extension (see `from_extension`).
    fn from_path<P: AsRef<Path>>(path: P) -> Option<Self> {
        from_extension(path.as_ref())
    }
}
460
461impl LanguageExt for SupportLang {
462 impl_lang_method!(get_ts_language, () => TSLanguage);
463 impl_lang_method!(injectable_languages, () => Option<&'static [&'static str]>);
464 fn extract_injections<L: LanguageExt>(
465 &self,
466 root: Node<StrDoc<L>>,
467 ) -> HashMap<String, Vec<TSRange>> {
468 match self {
469 SupportLang::Html => Html.extract_injections(root),
470 _ => HashMap::new(),
471 }
472 }
473}
474
475fn extensions(lang: SupportLang) -> &'static [&'static str] {
476 use SupportLang::*;
477 match lang {
478 Bash => &[
479 "bash", "bats", "cgi", "command", "env", "fcgi", "ksh", "sh", "tmux", "tool", "zsh",
480 ],
481 C => &["c", "h"],
482 Cpp => &["cc", "hpp", "cpp", "c++", "hh", "cxx", "cu", "ino"],
483 CSharp => &["cs"],
484 Css => &["css", "scss"],
485 Elixir => &["ex", "exs"],
486 Go => &["go"],
487 Haskell => &["hs"],
488 Hcl => &["hcl"],
489 Html => &["html", "htm", "xhtml"],
490 Java => &["java"],
491 JavaScript => &["cjs", "js", "mjs", "jsx"],
492 Json => &["json"],
493 Kotlin => &["kt", "ktm", "kts"],
494 Lua => &["lua"],
495 Nix => &["nix"],
496 Php => &["php"],
497 Python => &["py", "py3", "pyi", "bzl"],
498 Ruby => &["rb", "rbw", "gemspec"],
499 Rust => &["rs"],
500 Scala => &["scala", "sc", "sbt"],
501 Solidity => &["sol"],
502 Swift => &["swift"],
503 TypeScript => &["ts", "cts", "mts"],
504 Tsx => &["tsx"],
505 Yaml => &["yaml", "yml"],
506 }
507}
508
509fn from_extension(path: &Path) -> Option<SupportLang> {
513 let ext = path.extension()?.to_str()?;
514 SupportLang::all_langs()
515 .iter()
516 .copied()
517 .find(|&l| extensions(l).contains(&ext))
518}
519
520fn add_custom_file_type<'b>(
521 builder: &'b mut TypesBuilder,
522 file_type: &str,
523 suffix_list: &[&str],
524) -> &'b mut TypesBuilder {
525 for suffix in suffix_list {
526 let glob = format!("*.{suffix}");
527 builder
528 .add(file_type, &glob)
529 .expect("file pattern must compile");
530 }
531 builder.select(file_type)
532}
533
534fn file_types(lang: SupportLang) -> Types {
535 let mut builder = TypesBuilder::new();
536 let exts = extensions(lang);
537 let lang_name = lang.to_string();
538 add_custom_file_type(&mut builder, &lang_name, exts);
539 builder.build().expect("file type must be valid")
540}
541
542pub fn config_file_type() -> Types {
543 let mut builder = TypesBuilder::new();
544 let builder = add_custom_file_type(&mut builder, "yml", &["yml", "yaml"]);
545 builder.build().expect("yaml type must be valid")
546}
547
#[cfg(test)]
mod test {
    use super::*;
    use ast_grep_core::{matcher::MatcherExt, Pattern};

    /// Asserts that `query`, compiled as a pattern for `lang`, matches some
    /// node in `source`. `pub`, presumably so per-language test modules can
    /// reuse it.
    pub fn test_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
        let cand = lang.ast_grep(source);
        let pattern = Pattern::new(query, lang);
        assert!(
            pattern.find_node(cand.root()).is_some(),
            "goal: {pattern:?}, candidate: {}",
            cand.root().get_inner_node().to_sexp(),
        );
    }

    /// Asserts that `query`, compiled as a pattern for `lang`, matches no node
    /// in `source`.
    pub fn test_non_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
        let cand = lang.ast_grep(source);
        let pattern = Pattern::new(query, lang);
        assert!(
            pattern.find_node(cand.root()).is_none(),
            "goal: {pattern:?}, candidate: {}",
            cand.root().get_inner_node().to_sexp(),
        );
    }

    /// Replaces `pattern` with `replacer` in `src` (parsed as `lang`) and
    /// returns the regenerated source. Panics if no replacement happened.
    pub fn test_replace_lang(
        src: &str,
        pattern: &str,
        replacer: &str,
        lang: impl LanguageExt,
    ) -> String {
        let mut source = lang.ast_grep(src);
        assert!(source
            .replace(pattern, replacer)
            .expect("should parse successfully"));
        source.generate()
    }

    #[test]
    fn test_js_string() {
        test_match_lang("'a'", "'a'", JavaScript);
        test_match_lang("\"\"", "\"\"", JavaScript);
        test_match_lang("''", "''", JavaScript);
    }

    #[test]
    fn test_guess_by_extension() {
        let path = Path::new("foo.rs");
        assert_eq!(from_extension(path), Some(SupportLang::Rust));
    }

    #[test]
    fn test_guess_by_extension_edge_cases() {
        assert_eq!(from_extension(Path::new("foo.tsx")), Some(SupportLang::Tsx));
        assert_eq!(from_extension(Path::new("foo.h")), Some(SupportLang::C));
        assert_eq!(from_extension(Path::new("foo.unknown")), None);
        // No extension at all.
        assert_eq!(from_extension(Path::new("Makefile")), None);
    }

    #[test]
    fn test_pre_process_pattern() {
        let expando = 'µ';
        // Named metavariables: every `$` in the run is replaced.
        assert_eq!(pre_process_pattern(expando, "$A"), "µA");
        assert_eq!(pre_process_pattern(expando, "$$B"), "µµB");
        assert_eq!(pre_process_pattern(expando, "$$$ARGS"), "µµµARGS");
        assert_eq!(pre_process_pattern(expando, "$_"), "µ_");
        // Anonymous multi-match, mid-pattern and trailing.
        assert_eq!(pre_process_pattern(expando, "foo($$$)"), "foo(µµµ)");
        assert_eq!(pre_process_pattern(expando, "$$$"), "µµµ");
        // `$` that does not start a metavariable is kept.
        assert_eq!(pre_process_pattern(expando, "$a"), "$a");
        assert_eq!(pre_process_pattern(expando, "a $$"), "a $$");
        assert_eq!(pre_process_pattern(expando, "let x = 1;"), "let x = 1;");
        assert_eq!(pre_process_pattern(expando, ""), "");
    }

    #[test]
    fn test_parse_alias() {
        assert_eq!("rs".parse::<SupportLang>().ok(), Some(SupportLang::Rust));
        assert_eq!("RUST".parse::<SupportLang>().ok(), Some(SupportLang::Rust));
        assert_eq!("c++".parse::<SupportLang>().ok(), Some(SupportLang::Cpp));
        assert_eq!("Yml".parse::<SupportLang>().ok(), Some(SupportLang::Yaml));
        let err = "cobol".parse::<SupportLang>().unwrap_err();
        assert_eq!(err.to_string(), "cobol is not supported!");
    }

    #[test]
    fn test_display_round_trips_through_from_str() {
        for &lang in SupportLang::all_langs() {
            assert_eq!(lang.to_string().parse::<SupportLang>().ok(), Some(lang));
        }
    }
}