ast-grep 0.41.1

Search and rewrite code at large scale using precise AST patterns.
Documentation
use super::SgLang;
use crate::utils::ErrorContext as EC;
use ast_grep_config::{DeserializeEnv, RuleCore, SerializableRuleCore};
use ast_grep_core::{
  tree_sitter::{LanguageExt, StrDoc, TSRange},
  Doc, Node,
};

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::OnceLock;

// NB: do not use SgLang in the (de)serialize interface,
// since Injected is used before lang registration in sgconfig.yml
/// The `injected` value of an injection config entry.
///
/// The enum is untagged: a plain string deserializes to [`Injected::Static`]
/// and a list of strings deserializes to [`Injected::Dynamic`].
#[derive(Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum Injected {
  /// A single fixed language name, e.g. `css`. It is also used as the
  /// fallback language when a match does not bind `LANG`.
  Static(String),
  /// Several candidate language names, e.g. `[js, ts]`. No fallback is
  /// recorded for this variant, so the language must come from `LANG`.
  Dynamic(Vec<String>),
}

/// One language-injection entry in its serializable form.
///
/// Keys are camelCase when (de)serialized (e.g. `hostLanguage`), and the
/// fields of the rule core are flattened into the same mapping.
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct SerializableInjection {
  /// The rule used to find injected regions inside the host language.
  /// A match must bind `CONTENT` to be extracted; it may also bind `LANG`.
  #[serde(flatten)]
  core: SerializableRuleCore,
  /// The host language, e.g. html, which contains other languages.
  host_language: String,
  /// Injected language according to the rule.
  /// It accepts either a string like js for a single static language,
  /// or an array of strings like [js, ts] for dynamic language detection.
  injected: Injected,
}

/// All registered injection rules that share one host language.
struct Injection {
  /// The host language whose syntax tree the rules are matched against.
  host: SgLang,
  /// Compiled rules, each paired with the language to fall back to when a
  /// match does not bind `LANG` (`None` for dynamic injections).
  rules: Vec<(RuleCore, Option<String>)>,
  /// Names of every language that may be injected into `host`.
  injectable: HashSet<String>,
}

impl Injection {
  fn new(lang: SgLang) -> Self {
    Self {
      host: lang,
      rules: vec![],
      injectable: Default::default(),
    }
  }
}

pub fn register_injetables(injections: Vec<SerializableInjection>) -> Result<()> {
  let mut injectable = HashMap::new();
  for injection in injections {
    register_injetable(injection, &mut injectable)?;
  }
  merge_default_injecatable(&mut injectable);
  LANG_INJECTIONS.set(injectable.into_values().collect()).ok();
  let injects = LANG_INJECTIONS.get().expect("just initialized");
  INJECTABLE_LANGS
    .set(
      injects
        .iter()
        .map(|inj| {
          (
            inj.host,
            inj.injectable.iter().map(|s| s.as_str()).collect(),
          )
        })
        .collect(),
    )
    .ok();
  Ok(())
}

/// Adds each host language's own default injectable languages to the
/// injectable set of its entry in `ret`.
///
/// Hosts that report no default injectable languages are left untouched.
fn merge_default_injecatable(ret: &mut HashMap<SgLang, Injection>) {
  for (host, entry) in ret.iter_mut() {
    let defaults = match host {
      SgLang::Builtin(b) => b.injectable_languages(),
      SgLang::Custom(c) => c.injectable_languages(),
    };
    if let Some(names) = defaults {
      entry
        .injectable
        .extend(names.iter().map(|name| name.to_string()));
    }
  }
}

fn register_injetable(
  injection: SerializableInjection,
  injectable: &mut HashMap<SgLang, Injection>,
) -> Result<()> {
  let lang = SgLang::from_str(&injection.host_language)?;
  let env = DeserializeEnv::new(lang);
  let rule = injection.core.get_matcher(env).context(EC::LangInjection)?;
  let default_lang = match &injection.injected {
    Injected::Static(s) => Some(s.clone()),
    Injected::Dynamic(_) => None,
  };
  let entry = injectable
    .entry(lang)
    .or_insert_with(|| Injection::new(lang));
  match injection.injected {
    Injected::Static(s) => {
      entry.injectable.insert(s);
    }
    Injected::Dynamic(v) => entry.injectable.extend(v),
  }
  entry.rules.push((rule, default_lang));
  Ok(())
}

/// Injection rules grouped by host language; set at most once by `register_injetables`.
static LANG_INJECTIONS: OnceLock<Vec<Injection>> = OnceLock::new();
/// Injectable language names per host language; `register_injetables` fills it
/// with string slices borrowed from the entries stored in `LANG_INJECTIONS`.
static INJECTABLE_LANGS: OnceLock<Vec<(SgLang, Vec<&'static str>)>> = OnceLock::new();

/// Returns the names of the languages that can be injected into `lang`.
///
/// If `lang` has an entry in `INJECTABLE_LANGS`, that entry is returned; it
/// already contains both custom and default injections. Otherwise the
/// language's own defaults are returned, which may be `None`.
pub fn injectable_languages(lang: SgLang) -> Option<&'static [&'static str]> {
  // NB: custom injection and builtin injections are resolved in INJECTABLE_LANGS
  let registered = INJECTABLE_LANGS
    .get()
    .and_then(|table| table.iter().find(|entry| entry.0 == lang));
  match registered {
    Some(entry) => Some(entry.1.as_slice()),
    None => match lang {
      SgLang::Builtin(b) => b.injectable_languages(),
      SgLang::Custom(c) => c.injectable_languages(),
    },
  }
}

pub fn extract_injections<L: LanguageExt>(
  lang: &SgLang,
  root: Node<StrDoc<L>>,
) -> Vec<(String, Vec<TSRange>)> {
  let mut ret = match lang {
    SgLang::Custom(c) => c.extract_injections(root.clone()),
    SgLang::Builtin(b) => b.extract_injections(root.clone()),
  };
  let injections = LANG_INJECTIONS.get().map(|v| v.as_slice()).unwrap_or(&[]);
  extract_custom_inject(lang, injections, root, &mut ret);
  ret
}

/// Applies the injection rules registered for `lang` to `root` and appends
/// one `(language, [range])` item to `ret` per usable match.
///
/// A match is usable when it binds `CONTENT` (the injected region) and a
/// language can be determined: the text bound to `LANG` if present,
/// otherwise the rule's fallback language. Other matches are skipped.
/// Does nothing when `injections` has no entry whose host equals `lang`.
fn extract_custom_inject<L: LanguageExt>(
  lang: &SgLang,
  injections: &[Injection],
  root: Node<StrDoc<L>>,
  ret: &mut Vec<(String, Vec<TSRange>)>,
) {
  let Some(host_rules) = injections.iter().find(|inj| inj.host == *lang) else {
    return;
  };
  for (matcher, fallback) in &host_rules.rules {
    for found in root.find_all(matcher) {
      let env = found.get_env();
      let Some(content) = env.get_match("CONTENT") else {
        continue;
      };
      let injected = match env.get_match("LANG") {
        Some(node) => node.text().to_string(),
        None => match fallback {
          Some(name) => name.clone(),
          None => continue,
        },
      };
      ret.push((injected, vec![node_to_range(content)]));
    }
  }
}

/// Builds a `TSRange` from `node`: its byte range plus the start and end
/// points taken from the node's `byte_point()` positions.
fn node_to_range<D: Doc>(node: &Node<D>) -> TSRange {
  let bytes = node.range();
  let begin = node.start_pos().byte_point();
  let finish = node.end_pos().byte_point();
  TSRange {
    start_byte: bytes.start,
    end_byte: bytes.end,
    start_point: tree_sitter::Point::new(begin.0, begin.1),
    end_point: tree_sitter::Point::new(finish.0, finish.1),
  }
}

#[cfg(test)]
mod test {
  use super::*;
  use ast_grep_config::from_str;
  use ast_grep_language::SupportLang;

  // Dynamic injection: the language is taken from the `$LANG` capture.
  const DYNAMIC: &str = "
hostLanguage: js
rule:
  pattern: styled.$LANG`$CONTENT`
injected: [css]";
  // Static injection: every match is treated as css.
  const STATIC: &str = "
hostLanguage: js
rule:
  pattern: styled`$CONTENT`
injected: css";
  // The rule refers to a node kind that does not exist, so compiling it must fail.
  const BAD: &str = "
hostLanguage: HTML
rule:
  kind: not_exist
injected: [js, ts, tsx]";

  /// Parses a YAML injection config, panicking on malformed input.
  fn parse(yaml: &str) -> SerializableInjection {
    from_str(yaml).expect("should ok")
  }

  /// Parses `src` as JavaScript and returns what the custom rules extract from it.
  fn extract_js(injections: &[Injection], src: &str) -> Vec<(String, Vec<TSRange>)> {
    let lang = SgLang::from(SupportLang::JavaScript);
    let sg = lang.ast_grep(src);
    let mut found = Vec::new();
    extract_custom_inject(&lang, injections, sg.root(), &mut found);
    found
  }

  #[test]
  fn test_deserialize() {
    let fixed = parse(STATIC);
    assert!(matches!(fixed.injected, Injected::Static(_)));
    let detected = parse(DYNAMIC);
    assert!(matches!(detected.injected, Injected::Dynamic(_)));
  }

  #[test]
  fn test_bad_inject() {
    let mut map = HashMap::new();
    let outcome = register_injetable(parse(BAD), &mut map);
    assert!(outcome.is_err());
    let ec = outcome.unwrap_err().downcast::<EC>().expect("should ok");
    assert!(matches!(ec, EC::LangInjection));
  }

  #[test]
  fn test_good_injection() {
    let mut map = HashMap::new();
    for yaml in [STATIC, DYNAMIC] {
      assert!(register_injetable(parse(yaml), &mut map).is_ok());
    }
    // Both configs share the js host, so they land in one entry.
    assert_eq!(map.len(), 1);
    let injections: Vec<_> = map.into_values().collect();

    let found = extract_js(&injections, "const a = styled`.btn { margin: 0; }`");
    assert_eq!(found.len(), 1);
    assert_eq!(found[0].0, "css");
    assert_eq!(found[0].1.len(), 1);

    let found = extract_js(&injections, "const a = styled.css`.btn { margin: 0; }`");
    assert_eq!(found.len(), 1);
    assert_eq!(found[0].0, "css");
    assert_eq!(found[0].1.len(), 1);
  }

  #[test]
  fn test_multiple_matches_produce_separate_entries() {
    let mut map = HashMap::new();
    register_injetable(parse(STATIC), &mut map).unwrap();
    let injections: Vec<_> = map.into_values().collect();
    let found = extract_js(
      &injections,
      "const a = styled`.btn { margin: 0; }`; const b = styled`.card { padding: 1em; }`",
    );
    assert_eq!(found.len(), 2);
    for entry in &found {
      assert_eq!(entry.0, "css");
      assert_eq!(entry.1.len(), 1);
    }
    assert_ne!(found[0].1[0].start_byte, found[1].1[0].start_byte);
  }
}