1mod bash;
11mod cpp;
12mod csharp;
13mod css;
14mod elixir;
15mod go;
16mod haskell;
17mod html;
18mod json;
19mod kotlin;
20mod lua;
21mod nix;
22mod parsers;
23mod php;
24mod python;
25mod ruby;
26mod rust;
27mod scala;
28mod solidity;
29mod swift;
30mod yaml;
31
32use ast_grep_core::matcher::{Pattern, PatternBuilder, PatternError};
33pub use html::Html;
34
35use ast_grep_core::meta_var::MetaVariable;
36use ast_grep_core::tree_sitter::{StrDoc, TSLanguage, TSRange};
37use ast_grep_core::Node;
38use ignore::types::{Types, TypesBuilder};
39use serde::de::Visitor;
40use serde::{de, Deserialize, Deserializer, Serialize};
41use std::borrow::Cow;
42use std::collections::HashMap;
43use std::fmt;
44use std::fmt::{Display, Formatter};
45use std::iter::repeat;
46use std::path::Path;
47use std::str::FromStr;
48
49pub use ast_grep_core::language::Language;
50pub use ast_grep_core::tree_sitter::LanguageExt;
51
/// Defines a unit language struct `$lang` backed by the tree-sitter parser
/// returned by `parsers::$func()`, implementing [`Language`] and [`LanguageExt`].
///
/// No `expando_char` / `pre_process_pattern` override is generated, so the
/// trait defaults apply (use `impl_lang_expando!` when `$` must be rewritten).
macro_rules! impl_lang {
  ($lang: ident, $func: ident) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self.get_ts_language().id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` is `Copy`: dereference instead of calling `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        // `$crate::` keeps the path valid regardless of the invoking module.
        $crate::parsers::$func().into()
      }
    }
  };
}
80
/// Rewrites metavariable sigils in `query`, replacing `$` with `expando`.
///
/// A run of `$` is replaced in two cases:
/// - It is followed by an uppercase ASCII letter or `_` (e.g. `$A`, `$$B`, `$$$ARGS`).
/// - It is exactly three `$` long (the anonymous multi-match `$$$`), wherever it occurs,
///   including at the end of the query.
///
/// Every other `$` is kept verbatim.
///
/// Returns `Cow::Borrowed(query)` when `query` contains no `$` (nothing to
/// rewrite). Otherwise it returns a freshly built string.
fn pre_process_pattern(expando: char, query: &str) -> Cow<'_, str> {
  // Fast path: with no `$` the output equals the input, so skip allocating.
  if !query.contains('$') {
    return Cow::Borrowed(query);
  }
  let mut ret = String::with_capacity(query.len());
  let mut dollar_count = 0;
  for c in query.chars() {
    if c == '$' {
      dollar_count += 1;
      continue;
    }
    // Named metavariable ($A, $$A, $$$A, $_) or anonymous multi-match ($$$).
    let need_replace = matches!(c, 'A'..='Z' | '_') || dollar_count == 3;
    let sigil = if need_replace { expando } else { '$' };
    ret.extend(repeat(sigil).take(dollar_count));
    dollar_count = 0;
    ret.push(c);
  }
  // A trailing `$$$` is the anonymous multi-match; other trailing `$` runs are kept.
  let sigil = if dollar_count == 3 { expando } else { '$' };
  ret.extend(repeat(sigil).take(dollar_count));
  Cow::Owned(ret)
}
101
/// Like `impl_lang!`, but for languages whose patterns need `$` rewritten to
/// `$char` before parsing.
///
/// Generates `expando_char` returning `$char` and a `pre_process_pattern`
/// that delegates to the crate-level `pre_process_pattern` helper.
macro_rules! impl_lang_expando {
  ($lang: ident, $func: ident, $char: expr) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self.get_ts_language().id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn expando_char(&self) -> char {
        $char
      }
      fn pre_process_pattern<'q>(&self, query: &'q str) -> std::borrow::Cow<'q, str> {
        // Qualified through `$crate`, like `parsers` below, so the helper
        // resolves from any invoking module.
        $crate::pre_process_pattern(self.expando_char(), query)
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` is `Copy`: dereference instead of calling `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        $crate::parsers::$func().into()
      }
    }
  };
}
137
/// A language type that can be referred to by several names.
///
/// `Display` provides the canonical name. `ALIAS` lists the alternative
/// spellings, which this crate's parsers compare ASCII-case-insensitively.
pub trait Alias: Display {
  /// Accepted spellings for this language.
  const ALIAS: &'static [&'static str];
}
141
/// Connects a unit language struct to the alias machinery. It generates:
/// - `Alias` with the given name list;
/// - `Display`, printing the struct's `Debug` name;
/// - `Deserialize`, accepting any listed alias (checked by `AliasVisitor`);
/// - `From<$lang> for SupportLang`.
macro_rules! impl_alias {
  ($lang:ident => $as:expr) => {
    impl Alias for $lang {
      const ALIAS: &'static [&'static str] = $as;
    }

    impl fmt::Display for $lang {
      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self)
      }
    }

    impl<'de> Deserialize<'de> for $lang {
      fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      where
        D: Deserializer<'de>,
      {
        // Only validation matters; the matched alias text is discarded.
        let visitor = AliasVisitor {
          aliases: <$lang as Alias>::ALIAS,
        };
        deserializer.deserialize_str(visitor).map(|_matched| $lang)
      }
    }

    impl From<$lang> for SupportLang {
      fn from(_: $lang) -> Self {
        SupportLang::$lang
      }
    }
  };
}
/// Runs `impl_alias!` for each `Lang => &[...]` pair. It also generates
/// `alias(lang)`, a `const` lookup from a `SupportLang` to its alias list.
macro_rules! impl_aliases {
  ($($lang:ident => $as:expr),* $(,)?) => {
    $(
      impl_alias!($lang => $as);
    )*

    /// Returns the accepted aliases for `lang`.
    const fn alias(lang: SupportLang) -> &'static [&'static str] {
      match lang {
        $(
          SupportLang::$lang => <$lang as Alias>::ALIAS,
        )*
      }
    }
  };
}
188
189impl_lang_expando!(C, language_c, '๐');
192impl_lang_expando!(Cpp, language_cpp, '๐');
193impl_lang_expando!(CSharp, language_c_sharp, 'ยต');
197impl_lang_expando!(Css, language_css, '_');
199impl_lang_expando!(Elixir, language_elixir, 'ยต');
201impl_lang_expando!(Go, language_go, 'ยต');
204impl_lang_expando!(Haskell, language_haskell, 'ยต');
208impl_lang_expando!(Kotlin, language_kotlin, 'ยต');
210impl_lang_expando!(Nix, language_nix, '_');
212impl_lang_expando!(Php, language_php, 'ยต');
214impl_lang_expando!(Python, language_python, 'ยต');
218impl_lang_expando!(Ruby, language_ruby, 'ยต');
220impl_lang_expando!(Rust, language_rust, 'ยต');
223impl_lang_expando!(Swift, language_swift, 'ยต');
225
226impl_lang!(Bash, language_bash);
229impl_lang!(Java, language_java);
230impl_lang!(JavaScript, language_javascript);
231impl_lang!(Json, language_json);
232impl_lang!(Lua, language_lua);
233impl_lang!(Scala, language_scala);
234impl_lang!(Solidity, language_solidity);
235impl_lang!(Tsx, language_tsx);
236impl_lang!(TypeScript, language_typescript);
237impl_lang!(Yaml, language_yaml);
/// Every built-in language as a single `Copy` value.
///
/// Each variant dispatches to the unit struct of the same name (see
/// `execute_lang_method!`).
///
/// NOTE(review): the variant order here (`Go` before `Elixir`) differs from
/// `SupportLang::all_langs()`. This is harmless unless something depends on
/// discriminant values.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Hash)]
pub enum SupportLang {
  Bash,
  C,
  Cpp,
  CSharp,
  Css,
  Go,
  Elixir,
  Haskell,
  Html,
  Java,
  JavaScript,
  Json,
  Kotlin,
  Lua,
  Nix,
  Php,
  Python,
  Ruby,
  Rust,
  Scala,
  Solidity,
  Swift,
  Tsx,
  TypeScript,
  Yaml,
}
270
271impl SupportLang {
272 pub const fn all_langs() -> &'static [SupportLang] {
273 use SupportLang::*;
274 &[
275 Bash, C, Cpp, CSharp, Css, Elixir, Go, Haskell, Html, Java, JavaScript, Json, Kotlin, Lua,
276 Nix, Php, Python, Ruby, Rust, Scala, Solidity, Swift, Tsx, TypeScript, Yaml,
277 ]
278 }
279
280 pub fn file_types(&self) -> Types {
281 file_types(*self)
282 }
283}
284
285impl fmt::Display for SupportLang {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 write!(f, "{self:?}")
288 }
289}
290
/// Error returned when a string does not name any supported language.
#[derive(Debug)]
pub enum SupportLangErr {
  /// Holds the unrecognized input exactly as given.
  LanguageNotSupported(String),
}

impl Display for SupportLangErr {
  fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
    // Single-variant enum, so this destructuring is irrefutable.
    let SupportLangErr::LanguageNotSupported(name) = self;
    write!(f, "{name} is not supported!")
  }
}

impl std::error::Error for SupportLangErr {}
306
307impl<'de> Deserialize<'de> for SupportLang {
308 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
309 where
310 D: Deserializer<'de>,
311 {
312 deserializer.deserialize_str(SupportLangVisitor)
313 }
314}
315
316struct SupportLangVisitor;
317
318impl Visitor<'_> for SupportLangVisitor {
319 type Value = SupportLang;
320
321 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
322 f.write_str("SupportLang")
323 }
324
325 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
326 where
327 E: de::Error,
328 {
329 v.parse().map_err(de::Error::custom)
330 }
331}
332struct AliasVisitor {
333 aliases: &'static [&'static str],
334}
335
336impl Visitor<'_> for AliasVisitor {
337 type Value = &'static str;
338
339 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
340 write!(f, "one of {:?}", self.aliases)
341 }
342
343 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
344 where
345 E: de::Error,
346 {
347 self
348 .aliases
349 .iter()
350 .copied()
351 .find(|&a| v.eq_ignore_ascii_case(a))
352 .ok_or_else(|| de::Error::invalid_value(de::Unexpected::Str(v), &self))
353 }
354}
355
356impl_aliases! {
357 Bash => &["bash"],
358 C => &["c"],
359 Cpp => &["cc", "c++", "cpp", "cxx"],
360 CSharp => &["cs", "csharp"],
361 Css => &["css"],
362 Elixir => &["ex", "elixir"],
363 Go => &["go", "golang"],
364 Haskell => &["hs", "haskell"],
365 Html => &["html"],
366 Java => &["java"],
367 JavaScript => &["javascript", "js", "jsx"],
368 Json => &["json"],
369 Kotlin => &["kotlin", "kt"],
370 Lua => &["lua"],
371 Nix => &["nix"],
372 Php => &["php"],
373 Python => &["py", "python"],
374 Ruby => &["rb", "ruby"],
375 Rust => &["rs", "rust"],
376 Scala => &["scala"],
377 Solidity => &["sol", "solidity"],
378 Swift => &["swift"],
379 TypeScript => &["ts", "typescript"],
380 Tsx => &["tsx"],
381 Yaml => &["yaml", "yml"],
382}
383
384impl FromStr for SupportLang {
386 type Err = SupportLangErr;
387 fn from_str(s: &str) -> Result<Self, Self::Err> {
388 for &lang in Self::all_langs() {
389 for moniker in alias(lang) {
390 if s.eq_ignore_ascii_case(moniker) {
391 return Ok(lang);
392 }
393 }
394 }
395 Err(SupportLangErr::LanguageNotSupported(s.to_string()))
396 }
397}
398
/// Calls `$method($($pname),*)` on the unit language struct that matches the
/// `SupportLang` value `$me`.
macro_rules! execute_lang_method {
  ($me: path, $method: ident, $($pname:tt),*) => {
    match $me {
      SupportLang::Bash => Bash.$method($($pname,)*),
      SupportLang::C => C.$method($($pname,)*),
      SupportLang::Cpp => Cpp.$method($($pname,)*),
      SupportLang::CSharp => CSharp.$method($($pname,)*),
      SupportLang::Css => Css.$method($($pname,)*),
      SupportLang::Elixir => Elixir.$method($($pname,)*),
      SupportLang::Go => Go.$method($($pname,)*),
      SupportLang::Haskell => Haskell.$method($($pname,)*),
      SupportLang::Html => Html.$method($($pname,)*),
      SupportLang::Java => Java.$method($($pname,)*),
      SupportLang::JavaScript => JavaScript.$method($($pname,)*),
      SupportLang::Json => Json.$method($($pname,)*),
      SupportLang::Kotlin => Kotlin.$method($($pname,)*),
      SupportLang::Lua => Lua.$method($($pname,)*),
      SupportLang::Nix => Nix.$method($($pname,)*),
      SupportLang::Php => Php.$method($($pname,)*),
      SupportLang::Python => Python.$method($($pname,)*),
      SupportLang::Ruby => Ruby.$method($($pname,)*),
      SupportLang::Rust => Rust.$method($($pname,)*),
      SupportLang::Scala => Scala.$method($($pname,)*),
      SupportLang::Solidity => Solidity.$method($($pname,)*),
      SupportLang::Swift => Swift.$method($($pname,)*),
      SupportLang::Tsx => Tsx.$method($($pname,)*),
      SupportLang::TypeScript => TypeScript.$method($($pname,)*),
      SupportLang::Yaml => Yaml.$method($($pname,)*),
    }
  }
}

/// Generates an `#[inline]` trait method `$method(&self, params...)` whose body
/// forwards to `execute_lang_method!`. The generated signature has no generic or
/// lifetime parameters.
macro_rules! impl_lang_method {
  ($method: ident, ($($pname:tt: $ptype:ty),*) => $return_type: ty) => {
    #[inline]
    fn $method(&self, $($pname: $ptype),*) -> $return_type {
      execute_lang_method! { self, $method, $($pname),* }
    }
  };
}
440impl Language for SupportLang {
441 impl_lang_method!(kind_to_id, (kind: &str) => u16);
442 impl_lang_method!(field_to_id, (field: &str) => Option<u16>);
443 impl_lang_method!(meta_var_char, () => char);
444 impl_lang_method!(expando_char, () => char);
445 impl_lang_method!(extract_meta_var, (source: &str) => Option<MetaVariable>);
446 impl_lang_method!(build_pattern, (builder: &PatternBuilder) => Result<Pattern, PatternError>);
447 fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
448 execute_lang_method! { self, pre_process_pattern, query }
449 }
450 fn from_path<P: AsRef<Path>>(path: P) -> Option<Self> {
451 from_extension(path.as_ref())
452 }
453}
454
455impl LanguageExt for SupportLang {
456 impl_lang_method!(get_ts_language, () => TSLanguage);
457 impl_lang_method!(injectable_languages, () => Option<&'static [&'static str]>);
458 fn extract_injections<L: LanguageExt>(
459 &self,
460 root: Node<StrDoc<L>>,
461 ) -> HashMap<String, Vec<TSRange>> {
462 match self {
463 SupportLang::Html => Html.extract_injections(root),
464 _ => HashMap::new(),
465 }
466 }
467}
468
469fn extensions(lang: SupportLang) -> &'static [&'static str] {
470 use SupportLang::*;
471 match lang {
472 Bash => &[
473 "bash", "bats", "cgi", "command", "env", "fcgi", "ksh", "sh", "tmux", "tool", "zsh",
474 ],
475 C => &["c", "h"],
476 Cpp => &["cc", "hpp", "cpp", "c++", "hh", "cxx", "cu", "ino"],
477 CSharp => &["cs"],
478 Css => &["css", "scss"],
479 Elixir => &["ex", "exs"],
480 Go => &["go"],
481 Haskell => &["hs"],
482 Html => &["html", "htm", "xhtml"],
483 Java => &["java"],
484 JavaScript => &["cjs", "js", "mjs", "jsx"],
485 Json => &["json"],
486 Kotlin => &["kt", "ktm", "kts"],
487 Lua => &["lua"],
488 Nix => &["nix"],
489 Php => &["php"],
490 Python => &["py", "py3", "pyi", "bzl"],
491 Ruby => &["rb", "rbw", "gemspec"],
492 Rust => &["rs"],
493 Scala => &["scala", "sc", "sbt"],
494 Solidity => &["sol"],
495 Swift => &["swift"],
496 TypeScript => &["ts", "cts", "mts"],
497 Tsx => &["tsx"],
498 Yaml => &["yaml", "yml"],
499 }
500}
501
502fn from_extension(path: &Path) -> Option<SupportLang> {
506 let ext = path.extension()?.to_str()?;
507 SupportLang::all_langs()
508 .iter()
509 .copied()
510 .find(|&l| extensions(l).contains(&ext))
511}
512
513fn add_custom_file_type<'b>(
514 builder: &'b mut TypesBuilder,
515 file_type: &str,
516 suffix_list: &[&str],
517) -> &'b mut TypesBuilder {
518 for suffix in suffix_list {
519 let glob = format!("*.{suffix}");
520 builder
521 .add(file_type, &glob)
522 .expect("file pattern must compile");
523 }
524 builder.select(file_type)
525}
526
527fn file_types(lang: SupportLang) -> Types {
528 let mut builder = TypesBuilder::new();
529 let exts = extensions(lang);
530 let lang_name = lang.to_string();
531 add_custom_file_type(&mut builder, &lang_name, exts);
532 builder.build().expect("file type must be valid")
533}
534
535pub fn config_file_type() -> Types {
536 let mut builder = TypesBuilder::new();
537 let builder = add_custom_file_type(&mut builder, "yml", &["yml", "yaml"]);
538 builder.build().expect("yaml type must be valid")
539}
540
#[cfg(test)]
mod test {
  use super::*;
  use ast_grep_core::{matcher::MatcherExt, Pattern};

  /// Asserts that `query` matches some node of `source` parsed as `lang`.
  pub fn test_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let cand = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    assert!(
      pattern.find_node(cand.root()).is_some(),
      "goal: {pattern:?}, candidate: {}",
      cand.root().get_inner_node().to_sexp(),
    );
  }

  /// Asserts that `query` matches no node of `source` parsed as `lang`.
  pub fn test_non_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let cand = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    assert!(
      pattern.find_node(cand.root()).is_none(),
      "goal: {pattern:?}, candidate: {}",
      cand.root().get_inner_node().to_sexp(),
    );
  }

  /// Replaces `pattern` with `replacer` in `src` and returns the rewritten
  /// source. Panics if no replacement happened.
  pub fn test_replace_lang(
    src: &str,
    pattern: &str,
    replacer: &str,
    lang: impl LanguageExt,
  ) -> String {
    let mut source = lang.ast_grep(src);
    assert!(source
      .replace(pattern, replacer)
      .expect("should parse successfully"));
    source.generate()
  }

  #[test]
  fn test_js_string() {
    test_match_lang("'a'", "'a'", JavaScript);
    test_match_lang("\"\"", "\"\"", JavaScript);
    test_match_lang("''", "''", JavaScript);
  }

  #[test]
  fn test_guess_by_extension() {
    let path = Path::new("foo.rs");
    assert_eq!(from_extension(path), Some(SupportLang::Rust));
    assert_eq!(from_extension(Path::new("foo.tsx")), Some(SupportLang::Tsx));
    assert_eq!(from_extension(Path::new("Makefile")), None);
    assert_eq!(from_extension(Path::new("foo.unknown")), None);
  }

  // No extension is claimed by two languages, so every one maps back to its owner.
  #[test]
  fn test_extensions_round_trip() {
    for &lang in SupportLang::all_langs() {
      for ext in extensions(lang) {
        let file = format!("file.{ext}");
        assert_eq!(from_extension(Path::new(&file)), Some(lang), "{file}");
      }
    }
  }

  #[test]
  fn test_parse_alias() {
    assert_eq!("RS".parse::<SupportLang>().ok(), Some(SupportLang::Rust));
    assert_eq!("c++".parse::<SupportLang>().ok(), Some(SupportLang::Cpp));
    assert_eq!("yml".parse::<SupportLang>().ok(), Some(SupportLang::Yaml));
    assert!("cobol".parse::<SupportLang>().is_err());
  }

  // `Display` output (the variant name) must parse back to the same language.
  #[test]
  fn test_display_round_trip() {
    for &lang in SupportLang::all_langs() {
      let name = lang.to_string();
      assert_eq!(name.parse::<SupportLang>().ok(), Some(lang), "{name}");
    }
  }

  #[test]
  fn test_pre_process_pattern() {
    assert_eq!(pre_process_pattern('µ', "$A + $b"), "µA + $b");
    assert_eq!(pre_process_pattern('µ', "f($$$)"), "f(µµµ)");
    assert_eq!(pre_process_pattern('µ', "$$$ARGS"), "µµµARGS");
    assert_eq!(pre_process_pattern('µ', "$$a"), "$$a");
    assert_eq!(pre_process_pattern('µ', "a$"), "a$");
  }
}