1mod bash;
11mod cpp;
12mod csharp;
13mod css;
14mod elixir;
15mod go;
16mod haskell;
17mod html;
18mod json;
19mod kotlin;
20mod lua;
21mod parsers;
22mod php;
23mod python;
24mod ruby;
25mod rust;
26mod scala;
27mod swift;
28mod yaml;
29
30use ast_grep_core::matcher::{Pattern, PatternBuilder, PatternError};
31pub use html::Html;
32
33use ast_grep_core::meta_var::MetaVariable;
34use ast_grep_core::tree_sitter::{StrDoc, TSLanguage, TSRange};
35use ast_grep_core::Node;
36use ignore::types::{Types, TypesBuilder};
37use serde::de::Visitor;
38use serde::{de, Deserialize, Deserializer, Serialize};
39use std::borrow::Cow;
40use std::collections::HashMap;
41use std::fmt;
42use std::fmt::{Display, Formatter};
43use std::iter::repeat;
44use std::path::Path;
45use std::str::FromStr;
46
47pub use ast_grep_core::language::Language;
48pub use ast_grep_core::tree_sitter::LanguageExt;
49
/// Defines a public unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it, using the tree-sitter grammar returned by
/// `parsers::$func()`.
///
/// Unlike `impl_lang_expando!`, this macro does not override `expando_char`
/// or `pre_process_pattern`.
macro_rules! impl_lang {
  ($lang: ident, $func: ident) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self.get_ts_language().id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` derives `Copy`, so dereference instead of calling `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        parsers::$func().into()
      }
    }
  };
}
78
/// Rewrites runs of `$` in `query` into runs of `expando`.
///
/// A run of consecutive `$` is replaced, character for character, by
/// `expando` when either
/// * the character right after the run is `A..=Z` or `_`, or
/// * the run is exactly three `$` long (whatever follows, including the end
///   of the input).
///
/// Every other run of `$` is copied through unchanged, as is every non-`$`
/// character.
///
/// Returns `Cow::Borrowed(query)` when `query` contains no `$` at all (the
/// output would equal the input), otherwise a freshly built `Cow::Owned`.
fn pre_process_pattern(expando: char, query: &str) -> std::borrow::Cow<'_, str> {
  // Fast path: without any `$` there is nothing to rewrite, so skip the allocation.
  if !query.contains('$') {
    return std::borrow::Cow::Borrowed(query);
  }
  // Build the `String` directly instead of collecting a `Vec<char>` afterwards.
  let mut ret = String::with_capacity(query.len());
  let mut dollar_count = 0;
  for c in query.chars() {
    if c == '$' {
      // Defer emitting the sigils until we know what follows the run.
      dollar_count += 1;
      continue;
    }
    // Replace when an uppercase letter / `_` follows, or the run is exactly `$$$`.
    let need_replace = matches!(c, 'A'..='Z' | '_') || dollar_count == 3;
    let sigil = if need_replace { expando } else { '$' };
    ret.extend(repeat(sigil).take(dollar_count));
    dollar_count = 0;
    ret.push(c);
  }
  // A run that reaches the end of the input is only replaced when it is `$$$`.
  let sigil = if dollar_count == 3 { expando } else { '$' };
  ret.extend(repeat(sigil).take(dollar_count));
  std::borrow::Cow::Owned(ret)
}
99
/// Defines a public unit struct `$lang` and implements [`Language`] and
/// [`LanguageExt`] for it, using the tree-sitter grammar returned by
/// `$crate::parsers::$func()`.
///
/// In addition to what `impl_lang!` generates, the language reports `$char`
/// as its `expando_char` and pre-processes patterns with the file-level
/// `pre_process_pattern` helper, which rewrites `$` sigils into that char.
macro_rules! impl_lang_expando {
  ($lang: ident, $func: ident, $char: expr) => {
    #[derive(Clone, Copy, Debug)]
    pub struct $lang;
    impl Language for $lang {
      fn kind_to_id(&self, kind: &str) -> u16 {
        self.get_ts_language().id_for_node_kind(kind, true)
      }
      fn field_to_id(&self, field: &str) -> Option<u16> {
        self
          .get_ts_language()
          .field_id_for_name(field)
          .map(|f| f.get())
      }
      fn expando_char(&self) -> char {
        $char
      }
      fn pre_process_pattern<'q>(&self, query: &'q str) -> std::borrow::Cow<'q, str> {
        pre_process_pattern(self.expando_char(), query)
      }
      fn build_pattern(&self, builder: &PatternBuilder) -> Result<Pattern, PatternError> {
        // `$lang` derives `Copy`, so dereference instead of calling `clone()`.
        builder.build(|src| StrDoc::try_new(src, *self))
      }
    }
    impl LanguageExt for $lang {
      fn get_ts_language(&self) -> TSLanguage {
        $crate::parsers::$func().into()
      }
    }
  };
}
135
/// Textual names for a language type.
///
/// `ALIAS` lists the names under which the implementing type is recognised;
/// the code in this file compares them against input with
/// `eq_ignore_ascii_case`. Every implementor must also implement `Display`.
pub trait Alias: Display {
  /// The accepted names for the implementing language.
  const ALIAS: &'static [&'static str];
}
139
/// Wires a single language type `$ty` into the alias machinery:
///
/// * `Alias` with `$names` as its `ALIAS` list,
/// * `Display`, printing the type's `Debug` representation,
/// * `Deserialize`, accepting any string that `AliasVisitor` matches against
///   `ALIAS`,
/// * `From<$ty> for SupportLang`, mapping to the variant of the same name.
macro_rules! impl_alias {
  ($ty:ident => $names:expr) => {
    impl Alias for $ty {
      const ALIAS: &'static [&'static str] = $names;
    }

    impl fmt::Display for $ty {
      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self)
      }
    }

    impl<'de> Deserialize<'de> for $ty {
      fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      where
        D: Deserializer<'de>,
      {
        // Only acceptance matters: the matched alias itself is discarded.
        deserializer
          .deserialize_str(AliasVisitor {
            aliases: Self::ALIAS,
          })
          .map(|_| $ty)
      }
    }

    impl From<$ty> for SupportLang {
      fn from(_: $ty) -> Self {
        Self::$ty
      }
    }
  };
}
/// Applies `impl_alias!` to every `Type => names` pair and generates the
/// private `alias` function, which maps a `SupportLang` variant to the
/// `ALIAS` list of the type with the same name.
macro_rules! impl_aliases {
  ($($ty:ident => $names:expr),* $(,)?) => {
    $(impl_alias!($ty => $names);)*
    const fn alias(lang: SupportLang) -> &'static [&'static str] {
      match lang {
        $(SupportLang::$ty => $ty::ALIAS,)*
      }
    }
  };
}
186
// Languages that override the expando character: each invocation defines a
// unit struct whose `pre_process_pattern` rewrites `$` sigils into the given
// char (`_` for C, C++ and CSS, `µ` for the others).
// NOTE(review): presumably `$` is not accepted in identifiers by these
// grammars, hence the substitute character — confirm against the parsers.
impl_lang_expando!(C, language_c, '_');
impl_lang_expando!(Cpp, language_cpp, '_');
impl_lang_expando!(CSharp, language_c_sharp, 'µ');
impl_lang_expando!(Css, language_css, '_');
impl_lang_expando!(Elixir, language_elixir, 'µ');
impl_lang_expando!(Go, language_go, 'µ');
impl_lang_expando!(Haskell, language_haskell, 'µ');
impl_lang_expando!(Kotlin, language_kotlin, 'µ');
impl_lang_expando!(Php, language_php, 'µ');
impl_lang_expando!(Python, language_python, 'µ');
impl_lang_expando!(Ruby, language_ruby, 'µ');
impl_lang_expando!(Rust, language_rust, 'µ');
impl_lang_expando!(Swift, language_swift, 'µ');
223
// Languages defined through `impl_lang!`, i.e. without an `expando_char` or
// `pre_process_pattern` override.
impl_lang!(Bash, language_bash);
impl_lang!(Java, language_java);
impl_lang!(JavaScript, language_javascript);
impl_lang!(Json, language_json);
impl_lang!(Lua, language_lua);
impl_lang!(Scala, language_scala);
impl_lang!(Tsx, language_tsx);
impl_lang!(TypeScript, language_typescript);
impl_lang!(Yaml, language_yaml);
/// Enumerates the languages handled by this file.
///
/// Each variant corresponds to the unit struct of the same name; the
/// `Language` / `LanguageExt` implementations for this enum dispatch to that
/// struct.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Hash)]
pub enum SupportLang {
  Bash,
  C,
  Cpp,
  CSharp,
  Css,
  Go,
  Elixir,
  Haskell,
  Html,
  Java,
  JavaScript,
  Json,
  Kotlin,
  Lua,
  Php,
  Python,
  Ruby,
  Rust,
  Scala,
  Swift,
  Tsx,
  TypeScript,
  Yaml,
}
265
266impl SupportLang {
267 pub const fn all_langs() -> &'static [SupportLang] {
268 use SupportLang::*;
269 &[
270 Bash, C, Cpp, CSharp, Css, Elixir, Go, Haskell, Html, Java, JavaScript, Json, Kotlin, Lua,
271 Php, Python, Ruby, Rust, Scala, Swift, Tsx, TypeScript, Yaml,
272 ]
273 }
274
275 pub fn file_types(&self) -> Types {
276 file_types(*self)
277 }
278}
279
280impl fmt::Display for SupportLang {
281 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282 write!(f, "{:?}", self)
283 }
284}
285
/// Error type used when a language cannot be resolved.
#[derive(Debug)]
pub enum SupportLangErr {
  /// The given name is not a known language; carries the offending name.
  LanguageNotSupported(String),
}

impl Display for SupportLangErr {
  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
    // Single-variant enum, so the pattern is irrefutable.
    let Self::LanguageNotSupported(name) = self;
    write!(f, "{name} is not supported!")
  }
}

impl std::error::Error for SupportLangErr {}
301
302impl<'de> Deserialize<'de> for SupportLang {
303 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
304 where
305 D: Deserializer<'de>,
306 {
307 deserializer.deserialize_str(SupportLangVisitor)
308 }
309}
310
311struct SupportLangVisitor;
312
313impl Visitor<'_> for SupportLangVisitor {
314 type Value = SupportLang;
315
316 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
317 f.write_str("SupportLang")
318 }
319
320 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
321 where
322 E: de::Error,
323 {
324 v.parse().map_err(de::Error::custom)
325 }
326}
327struct AliasVisitor {
328 aliases: &'static [&'static str],
329}
330
331impl Visitor<'_> for AliasVisitor {
332 type Value = &'static str;
333
334 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
335 write!(f, "one of {:?}", self.aliases)
336 }
337
338 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
339 where
340 E: de::Error,
341 {
342 self
343 .aliases
344 .iter()
345 .copied()
346 .find(|&a| v.eq_ignore_ascii_case(a))
347 .ok_or_else(|| de::Error::invalid_value(de::Unexpected::Str(v), &self))
348 }
349}
350
// Alias table: the names under which each language is recognised.
// `SupportLang::from_str` and `AliasVisitor` compare against these with
// `eq_ignore_ascii_case`. The invocation also generates, for every listed
// type, its `Alias`, `Display`, `Deserialize` and `From<_> for SupportLang`
// impls, plus the private `alias` lookup function.
impl_aliases! {
  Bash => &["bash"],
  C => &["c"],
  Cpp => &["cc", "c++", "cpp", "cxx"],
  CSharp => &["cs", "csharp"],
  Css => &["css"],
  Elixir => &["ex", "elixir"],
  Go => &["go", "golang"],
  Haskell => &["hs", "haskell"],
  Html => &["html"],
  Java => &["java"],
  JavaScript => &["javascript", "js", "jsx"],
  Json => &["json"],
  Kotlin => &["kotlin", "kt"],
  Lua => &["lua"],
  Php => &["php"],
  Python => &["py", "python"],
  Ruby => &["rb", "ruby"],
  Rust => &["rs", "rust"],
  Scala => &["scala"],
  Swift => &["swift"],
  TypeScript => &["ts", "typescript"],
  Tsx => &["tsx"],
  Yaml => &["yaml", "yml"],
}
376
377impl FromStr for SupportLang {
379 type Err = SupportLangErr;
380 fn from_str(s: &str) -> Result<Self, Self::Err> {
381 for &lang in Self::all_langs() {
382 for moniker in alias(lang) {
383 if s.eq_ignore_ascii_case(moniker) {
384 return Ok(lang);
385 }
386 }
387 }
388 Err(SupportLangErr::LanguageNotSupported(s.to_string()))
389 }
390}
391
/// Expands to a `match` on `$me` that forwards `$method(args…)` to the unit
/// struct named like the matched `SupportLang` variant.
macro_rules! execute_lang_method {
  ($me: path, $method: ident, $($arg:tt),*) => {
    match $me {
      SupportLang::Bash => Bash.$method($($arg,)*),
      SupportLang::C => C.$method($($arg,)*),
      SupportLang::Cpp => Cpp.$method($($arg,)*),
      SupportLang::CSharp => CSharp.$method($($arg,)*),
      SupportLang::Css => Css.$method($($arg,)*),
      SupportLang::Elixir => Elixir.$method($($arg,)*),
      SupportLang::Go => Go.$method($($arg,)*),
      SupportLang::Haskell => Haskell.$method($($arg,)*),
      SupportLang::Html => Html.$method($($arg,)*),
      SupportLang::Java => Java.$method($($arg,)*),
      SupportLang::JavaScript => JavaScript.$method($($arg,)*),
      SupportLang::Json => Json.$method($($arg,)*),
      SupportLang::Kotlin => Kotlin.$method($($arg,)*),
      SupportLang::Lua => Lua.$method($($arg,)*),
      SupportLang::Php => Php.$method($($arg,)*),
      SupportLang::Python => Python.$method($($arg,)*),
      SupportLang::Ruby => Ruby.$method($($arg,)*),
      SupportLang::Rust => Rust.$method($($arg,)*),
      SupportLang::Scala => Scala.$method($($arg,)*),
      SupportLang::Swift => Swift.$method($($arg,)*),
      SupportLang::Tsx => Tsx.$method($($arg,)*),
      SupportLang::TypeScript => TypeScript.$method($($arg,)*),
      SupportLang::Yaml => Yaml.$method($($arg,)*),
    }
  }
}
422
/// Generates an `#[inline]` `&self` method named `$method` with the given
/// parameter list and return type, whose body forwards the call through
/// `execute_lang_method!`.
macro_rules! impl_lang_method {
  ($method: ident, ($($arg:tt: $arg_ty:ty),*) => $ret: ty) => {
    #[inline]
    fn $method(&self, $($arg: $arg_ty),*) -> $ret {
      execute_lang_method!{ self, $method, $($arg),* }
    }
  };
}
431impl Language for SupportLang {
432 impl_lang_method!(kind_to_id, (kind: &str) => u16);
433 impl_lang_method!(field_to_id, (field: &str) => Option<u16>);
434 impl_lang_method!(meta_var_char, () => char);
435 impl_lang_method!(expando_char, () => char);
436 impl_lang_method!(extract_meta_var, (source: &str) => Option<MetaVariable>);
437 impl_lang_method!(build_pattern, (builder: &PatternBuilder) => Result<Pattern, PatternError>);
438 fn pre_process_pattern<'q>(&self, query: &'q str) -> Cow<'q, str> {
439 execute_lang_method! { self, pre_process_pattern, query }
440 }
441 fn from_path<P: AsRef<Path>>(path: P) -> Option<Self> {
442 from_extension(path.as_ref())
443 }
444}
445
446impl LanguageExt for SupportLang {
447 impl_lang_method!(get_ts_language, () => TSLanguage);
448 impl_lang_method!(injectable_languages, () => Option<&'static [&'static str]>);
449 fn extract_injections<L: LanguageExt>(
450 &self,
451 root: Node<StrDoc<L>>,
452 ) -> HashMap<String, Vec<TSRange>> {
453 match self {
454 SupportLang::Html => Html.extract_injections(root),
455 _ => HashMap::new(),
456 }
457 }
458}
459
460fn extensions(lang: SupportLang) -> &'static [&'static str] {
461 use SupportLang::*;
462 match lang {
463 Bash => &[
464 "bash", "bats", "cgi", "command", "env", "fcgi", "ksh", "sh", "tmux", "tool", "zsh",
465 ],
466 C => &["c", "h"],
467 Cpp => &["cc", "hpp", "cpp", "c++", "hh", "cxx", "cu", "ino"],
468 CSharp => &["cs"],
469 Css => &["css", "scss"],
470 Elixir => &["ex", "exs"],
471 Go => &["go"],
472 Haskell => &["hs"],
473 Html => &["html", "htm", "xhtml"],
474 Java => &["java"],
475 JavaScript => &["cjs", "js", "mjs", "jsx"],
476 Json => &["json"],
477 Kotlin => &["kt", "ktm", "kts"],
478 Lua => &["lua"],
479 Php => &["php"],
480 Python => &["py", "py3", "pyi", "bzl"],
481 Ruby => &["rb", "rbw", "gemspec"],
482 Rust => &["rs"],
483 Scala => &["scala", "sc", "sbt"],
484 Swift => &["swift"],
485 TypeScript => &["ts", "cts", "mts"],
486 Tsx => &["tsx"],
487 Yaml => &["yaml", "yml"],
488 }
489}
490
491fn from_extension(path: &Path) -> Option<SupportLang> {
495 let ext = path.extension()?.to_str()?;
496 SupportLang::all_langs()
497 .iter()
498 .copied()
499 .find(|&l| extensions(l).contains(&ext))
500}
501
502fn add_custom_file_type<'b>(
503 builder: &'b mut TypesBuilder,
504 file_type: &str,
505 suffix_list: &[&str],
506) -> &'b mut TypesBuilder {
507 for suffix in suffix_list {
508 let glob = format!("*.{suffix}");
509 builder
510 .add(file_type, &glob)
511 .expect("file pattern must compile");
512 }
513 builder.select(file_type)
514}
515
516fn file_types(lang: SupportLang) -> Types {
517 let mut builder = TypesBuilder::new();
518 let exts = extensions(lang);
519 let lang_name = lang.to_string();
520 add_custom_file_type(&mut builder, &lang_name, exts);
521 builder.build().expect("file type must be valid")
522}
523
524pub fn config_file_type() -> Types {
525 let mut builder = TypesBuilder::new();
526 let builder = add_custom_file_type(&mut builder, "yml", &["yml", "yaml"]);
527 builder.build().expect("yaml type must be valid")
528}
529
#[cfg(test)]
mod test {
  use super::*;
  use ast_grep_core::{matcher::MatcherExt, Pattern};

  /// Asserts that the pattern built from `query` for `lang` matches some node
  /// of `source`.
  pub fn test_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let grep = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let matched = pattern.find_node(grep.root()).is_some();
    assert!(
      matched,
      "goal: {pattern:?}, candidate: {}",
      grep.root().get_inner_node().to_sexp(),
    );
  }

  /// Asserts that the pattern built from `query` for `lang` matches no node
  /// of `source`.
  pub fn test_non_match_lang(query: &str, source: &str, lang: impl LanguageExt) {
    let grep = lang.ast_grep(source);
    let pattern = Pattern::new(query, lang);
    let unmatched = pattern.find_node(grep.root()).is_none();
    assert!(
      unmatched,
      "goal: {pattern:?}, candidate: {}",
      grep.root().get_inner_node().to_sexp(),
    );
  }

  /// Applies `replacer` to a match of `pattern` in `src` and returns the
  /// generated source; asserts that a replacement actually happened.
  pub fn test_replace_lang(
    src: &str,
    pattern: &str,
    replacer: &str,
    lang: impl LanguageExt,
  ) -> String {
    let mut source = lang.ast_grep(src);
    let replaced = source
      .replace(pattern, replacer)
      .expect("should parse successfully");
    assert!(replaced);
    source.generate()
  }

  #[test]
  fn test_js_string() {
    // Each string literal must match itself.
    for literal in ["'a'", "\"\"", "''"] {
      test_match_lang(literal, literal, JavaScript);
    }
  }

  #[test]
  fn test_guess_by_extension() {
    let guessed = from_extension(Path::new("foo.rs"));
    assert_eq!(guessed, Some(SupportLang::Rust));
  }
}