use super::SgLang;
use crate::utils::ErrorContext as EC;
use ast_grep_config::{DeserializeEnv, RuleCore, SerializableRuleCore};
use ast_grep_core::{
tree_sitter::{LanguageExt, StrDoc, TSRange},
Doc, Node,
};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use std::sync::OnceLock;
/// The `injected` field of an injection config entry.
///
/// The enum is `untagged`, so the variant is chosen by the shape of the value:
/// a single string becomes `Static`, a sequence of strings becomes `Dynamic`.
#[derive(Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum Injected {
/// One fixed language name. It is also used as the fallback language when a
/// match does not capture a `LANG` meta variable.
Static(String),
/// A list of language names. No fallback language is recorded for this
/// variant, so the language of each match must come from the `LANG` capture.
Dynamic(Vec<String>),
}
/// Serialized form of one custom language-injection entry.
///
/// Keys are camelCase in the config (e.g. `hostLanguage`).
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct SerializableInjection {
/// Rule definition used to find injected regions. Flattened, so its keys
/// (such as `rule`) sit at the same level as `hostLanguage` and `injected`.
#[serde(flatten)]
core: SerializableRuleCore,
/// Name of the language that contains the injected code; parsed with
/// `SgLang::from_str` at registration time.
host_language: String,
/// Language(s) that can be injected by this rule.
injected: Injected,
}
/// All custom injection rules registered for a single host language.
struct Injection {
/// The host language these rules apply to.
host: SgLang,
/// Compiled rules, each paired with an optional fallback language that is
/// used when a match has no `LANG` capture.
rules: Vec<(RuleCore, Option<String>)>,
/// Names of every language that may be injected into `host`.
injectable: HashSet<String>,
}
impl Injection {
fn new(lang: SgLang) -> Self {
Self {
host: lang,
rules: vec![],
injectable: Default::default(),
}
}
}
pub fn register_injetables(injections: Vec<SerializableInjection>) -> Result<()> {
let mut injectable = HashMap::new();
for injection in injections {
register_injetable(injection, &mut injectable)?;
}
merge_default_injecatable(&mut injectable);
LANG_INJECTIONS.set(injectable.into_values().collect()).ok();
let injects = LANG_INJECTIONS.get().expect("just initialized");
INJECTABLE_LANGS
.set(
injects
.iter()
.map(|inj| {
(
inj.host,
inj.injectable.iter().map(|s| s.as_str()).collect(),
)
})
.collect(),
)
.ok();
Ok(())
}
/// Adds each host language's own default injectable languages to the set of
/// names collected from custom rules.
///
/// Only hosts already present in `ret` are visited; a host that reports no
/// default injectable languages is left unchanged.
fn merge_default_injecatable(ret: &mut HashMap<SgLang, Injection>) {
  for (host, entry) in ret.iter_mut() {
    let defaults = match host {
      SgLang::Builtin(b) => b.injectable_languages(),
      SgLang::Custom(c) => c.injectable_languages(),
    };
    if let Some(names) = defaults {
      for name in names.iter() {
        entry.injectable.insert(name.to_string());
      }
    }
  }
}
fn register_injetable(
injection: SerializableInjection,
injectable: &mut HashMap<SgLang, Injection>,
) -> Result<()> {
let lang = SgLang::from_str(&injection.host_language)?;
let env = DeserializeEnv::new(lang);
let rule = injection.core.get_matcher(env).context(EC::LangInjection)?;
let default_lang = match &injection.injected {
Injected::Static(s) => Some(s.clone()),
Injected::Dynamic(_) => None,
};
let entry = injectable
.entry(lang)
.or_insert_with(|| Injection::new(lang));
match injection.injected {
Injected::Static(s) => {
entry.injectable.insert(s);
}
Injected::Dynamic(v) => entry.injectable.extend(v),
}
entry.rules.push((rule, default_lang));
Ok(())
}
/// Custom injection rules grouped by host language. Written by
/// `register_injetables`; read when extracting injections.
static LANG_INJECTIONS: OnceLock<Vec<Injection>> = OnceLock::new();
/// For each host language, the names of the languages that can be injected
/// into it. Built by `register_injetables` from `LANG_INJECTIONS`, whose
/// strings the `&'static str` entries borrow.
static INJECTABLE_LANGS: OnceLock<Vec<(SgLang, Vec<&'static str>)>> = OnceLock::new();
/// Returns the names of the languages that can be injected into `lang`.
///
/// If `lang` has an entry in the registered table (`INJECTABLE_LANGS`), that
/// entry is returned. Otherwise, including when nothing has been registered
/// yet, the answer comes from the language's own `injectable_languages()`.
pub fn injectable_languages(lang: SgLang) -> Option<&'static [&'static str]> {
  let registered = INJECTABLE_LANGS
    .get()
    .and_then(|table| table.iter().find(|(host, _)| *host == lang));
  match registered {
    Some((_, names)) => Some(names.as_slice()),
    None => match lang {
      SgLang::Builtin(b) => b.injectable_languages(),
      SgLang::Custom(c) => c.injectable_languages(),
    },
  }
}
pub fn extract_injections<L: LanguageExt>(
lang: &SgLang,
root: Node<StrDoc<L>>,
) -> Vec<(String, Vec<TSRange>)> {
let mut ret = match lang {
SgLang::Custom(c) => c.extract_injections(root.clone()),
SgLang::Builtin(b) => b.extract_injections(root.clone()),
};
let injections = LANG_INJECTIONS.get().map(|v| v.as_slice()).unwrap_or(&[]);
extract_custom_inject(lang, injections, root, &mut ret);
ret
}
/// Runs the custom injection rules registered for `lang` against `root` and
/// appends one `(language, [range])` entry to `ret` per match.
///
/// Does nothing when `injections` has no record whose host equals `lang`.
/// A match is skipped when it has no `CONTENT` capture, or when it has neither
/// a `LANG` capture nor a fallback language stored with the rule. The `LANG`
/// capture, when present, takes precedence over the fallback.
fn extract_custom_inject<L: LanguageExt>(
  lang: &SgLang,
  injections: &[Injection],
  root: Node<StrDoc<L>>,
  ret: &mut Vec<(String, Vec<TSRange>)>,
) {
  let Some(host_rules) = injections.iter().find(|inj| inj.host == *lang) else {
    return;
  };
  for (matcher, fallback) in &host_rules.rules {
    for found in root.find_all(matcher) {
      let env = found.get_env();
      let Some(content) = env.get_match("CONTENT") else {
        continue;
      };
      let target = match env.get_match("LANG") {
        Some(node) => node.text().to_string(),
        None => match fallback {
          Some(name) => name.clone(),
          None => continue,
        },
      };
      // Each match gets its own entry, even if the language repeats.
      ret.push((target, vec![node_to_range(content)]));
    }
  }
}
/// Builds a `TSRange` covering `node`.
///
/// The byte offsets come from `node.range()`; the start and end points are
/// built from the two components returned by `byte_point()` on the node's
/// start and end positions (presumably row and byte column; confirm against
/// the `Position` API).
fn node_to_range<D: Doc>(node: &Node<D>) -> TSRange {
  let bytes = node.range();
  let start = node.start_pos().byte_point();
  let end = node.end_pos().byte_point();
  TSRange {
    start_byte: bytes.start,
    end_byte: bytes.end,
    start_point: tree_sitter::Point::new(start.0, start.1),
    end_point: tree_sitter::Point::new(end.0, end.1),
  }
}
// Unit tests for injection config deserialization, registration and
// extraction of custom injected regions.
#[cfg(test)]
mod test {
use super::*;
use ast_grep_config::from_str;
use ast_grep_language::SupportLang;
// Config whose `injected` value is a sequence; the language of a match is
// taken from the `$LANG` capture in the pattern.
const DYNAMIC: &str = "
hostLanguage: js
rule:
pattern: styled.$LANG`$CONTENT`
injected: [css]";
// Config whose `injected` value is a single string; that string is the
// language of every match because the pattern has no `$LANG` capture.
const STATIC: &str = "
hostLanguage: js
rule:
pattern: styled`$CONTENT`
injected: css";
// A scalar `injected` must deserialize as `Static`, a sequence as `Dynamic`.
#[test]
fn test_deserialize() {
let inj: SerializableInjection = from_str(STATIC).expect("should ok");
assert!(matches!(inj.injected, Injected::Static(_)));
let inj: SerializableInjection = from_str(DYNAMIC).expect("should ok");
assert!(matches!(inj.injected, Injected::Dynamic(_)));
}
// Config that deserializes fine but whose rule refers to a node kind named
// `not_exist`, so registration is expected to fail.
const BAD: &str = "
hostLanguage: HTML
rule:
kind: not_exist
injected: [js, ts, tsx]";
// Registration of an invalid rule must fail with `EC::LangInjection` context.
#[test]
fn test_bad_inject() {
let mut map = HashMap::new();
let inj: SerializableInjection = from_str(BAD).expect("should ok");
let ret = register_injetable(inj, &mut map);
assert!(ret.is_err());
let ec = ret.unwrap_err().downcast::<EC>().expect("should ok");
assert!(matches!(ec, EC::LangInjection));
}
// Two rules for the same host end up in one map entry, and each kind of rule
// (static fallback, `$LANG` capture) yields a single `css` region.
#[test]
fn test_good_injection() {
let mut map = HashMap::new();
let inj: SerializableInjection = from_str(STATIC).expect("should ok");
let ret = register_injetable(inj, &mut map);
assert!(ret.is_ok());
let inj: SerializableInjection = from_str(DYNAMIC).expect("should ok");
let ret = register_injetable(inj, &mut map);
assert!(ret.is_ok());
// Both configs share the `js` host, so they are grouped together.
assert_eq!(map.len(), 1);
let injections: Vec<_> = map.into_values().collect();
let mut ret = Vec::new();
let lang = SgLang::from(SupportLang::JavaScript);
// Static rule: language comes from the fallback.
let sg = lang.ast_grep("const a = styled`.btn { margin: 0; }`");
let root = sg.root();
extract_custom_inject(&lang, &injections, root, &mut ret);
assert_eq!(ret.len(), 1);
assert_eq!(ret[0].0, "css");
assert_eq!(ret[0].1.len(), 1);
ret.clear();
// Dynamic rule: language comes from the `$LANG` capture (`css`).
let sg = lang.ast_grep("const a = styled.css`.btn { margin: 0; }`");
let root = sg.root();
extract_custom_inject(&lang, &injections, root, &mut ret);
assert_eq!(ret.len(), 1);
assert_eq!(ret[0].0, "css");
assert_eq!(ret[0].1.len(), 1);
}
// Two matches of the same rule produce two separate entries (one range each)
// rather than one entry holding both ranges.
#[test]
fn test_multiple_matches_produce_separate_entries() {
let mut map = HashMap::new();
let inj: SerializableInjection = from_str(STATIC).expect("should ok");
register_injetable(inj, &mut map).unwrap();
let injections: Vec<_> = map.into_values().collect();
let mut ret = Vec::new();
let lang = SgLang::from(SupportLang::JavaScript);
let sg = lang
.ast_grep("const a = styled`.btn { margin: 0; }`; const b = styled`.card { padding: 1em; }`");
let root = sg.root();
extract_custom_inject(&lang, &injections, root, &mut ret);
assert_eq!(ret.len(), 2);
assert_eq!(ret[0].0, "css");
assert_eq!(ret[0].1.len(), 1);
assert_eq!(ret[1].0, "css");
assert_eq!(ret[1].1.len(), 1);
// The two entries must point at different regions of the source.
assert_ne!(ret[0].1[0].start_byte, ret[1].1[0].start_byte);
}
}