use super::referent_rule::{GlobalRules, ReferentRuleError, RuleRegistration};
use crate::check_var::CheckHint;
use crate::maybe::Maybe;
use crate::rule::{self, Rule, RuleSerializeError, SerializableRule};
use crate::rule_core::{RuleCoreError, SerializableRuleCore};
use crate::transform::Transformation;
use ast_grep_core::language::Language;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// A serialized global utility rule: a rule core together with the id it is
// registered under and the language it is compiled for. Consumed (via
// `into_map`) by `DeserializeEnv::parse_global_utils`.
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SerializableGlobalRule<L: Language> {
    // Flattened: the core's fields sit at the same level as `id` and
    // `language` in the serialized form.
    #[serde(flatten)]
    pub core: SerializableRuleCore,
    // Id under which the rule is registered; becomes the map key in `into_map`.
    pub id: String,
    // Language used to build this rule's own `DeserializeEnv`.
    pub language: L,
}
fn into_map<L: Language>(
rules: Vec<SerializableGlobalRule<L>>,
) -> HashMap<String, (L, SerializableRuleCore)> {
rules
.into_iter()
.map(|r| (r.id, (r.language, r.core)))
.collect()
}
/// Result of dependency ordering. The `Err` payload is the id that was
/// reached again while it was still being visited, i.e. where a dependency
/// cycle was detected.
type OrderResult<T> = Result<T, String>;
/// Environment handed to rule deserialization: a registry of utility rules
/// plus the language the rules are being deserialized for.
pub struct DeserializeEnv<L: Language> {
    /// Utility-rule registry; populated by `register_local_utils`, or replaced
    /// or extended by `with_globals` / `with_rewriters`.
    pub(crate) registration: RuleRegistration<L>,
    /// Language the rules are deserialized for.
    pub(crate) lang: L,
}
/// An item that can report the ids it depends on, so that a map of such
/// items can be ordered by `TopologicalSort`.
trait DependentRule: Sized {
    /// Calls `sorter.visit` for every id this item depends on, propagating
    /// the first cycle error encountered.
    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()>;
}
impl DependentRule for SerializableRule {
    /// A standalone rule depends on the util ids referenced by its composite fields.
    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
        visit_dependent_rule_ids::<Self>(self, sorter)
    }
}
impl<L: Language> DependentRule for (L, SerializableRuleCore) {
    /// A global util's dependencies come solely from its core's `rule`; the
    /// language half of the pair is not inspected.
    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
        let (_, core) = self;
        visit_dependent_rule_ids(&core.rule, sorter)
    }
}
impl DependentRule for Transformation {
    /// A transformation depends on the variable reported by `used_vars`.
    fn visit_dependency<'a>(&'a self, sorter: &mut TopologicalSort<'a, Self>) -> OrderResult<()> {
        sorter.visit(self.used_vars())
    }
}
/// Depth-first topological sorter over a map of `DependentRule` items keyed by id.
struct TopologicalSort<'a, T: DependentRule> {
    /// The items being ordered, keyed by id.
    maps: &'a HashMap<String, T>,
    /// Ids whose dependencies have all been handled, in completion order
    /// (dependencies before dependents).
    order: Vec<&'a str>,
    /// Per-id visit state: `false` while the id is being visited, `true` once done.
    seen: HashMap<&'a str, bool>,
}
impl<'a, T: DependentRule> TopologicalSort<'a, T> {
    /// Returns every key of `maps` ordered so that each entry comes after all
    /// of the in-map entries it depends on.
    ///
    /// Keys are started in `HashMap` iteration order, so the relative order of
    /// unrelated entries is unspecified.
    ///
    /// # Errors
    /// `Err(id)` when a dependency cycle is found, `id` being the key that was
    /// reached again while still in progress.
    fn get_order(maps: &HashMap<String, T>) -> OrderResult<Vec<&str>> {
        let mut sorter = TopologicalSort::new(maps);
        maps.keys().try_for_each(|id| sorter.visit(id))?;
        Ok(sorter.order)
    }

    /// Creates a sorter with no visited ids.
    fn new(maps: &'a HashMap<String, T>) -> Self {
        Self {
            maps,
            order: Vec::new(),
            seen: HashMap::new(),
        }
    }

    /// Visits `key` depth-first, emitting it into `order` after its dependencies.
    ///
    /// Ids that are not keys of `maps` are ignored (they are not ordered here).
    fn visit(&mut self, key: &'a str) -> OrderResult<()> {
        match self.seen.get(key).copied() {
            // Already emitted: nothing more to do.
            Some(true) => return Ok(()),
            // Still on the current DFS path: we walked back into it => cycle.
            Some(false) => return Err(key.to_string()),
            None => {}
        }
        let item = match self.maps.get(key) {
            Some(item) => item,
            None => return Ok(()),
        };
        self.seen.insert(key, false);
        item.visit_dependency(self)?;
        self.seen.insert(key, true);
        self.order.push(key);
        Ok(())
    }
}
/// Visits every util id that `rule` references through `matches`, descending
/// recursively into the composite `all`, `any` and `not` sub-rules (in that order).
///
/// Only these composite fields are walked; no other rule field is inspected here.
fn visit_dependent_rule_ids<'a, T: DependentRule>(
    rule: &'a SerializableRule,
    sort: &mut TopologicalSort<'a, T>,
) -> OrderResult<()> {
    if let Maybe::Present(referent) = &rule.matches {
        sort.visit(referent)?;
    }
    // `all` and `any` share one shape: a list of sub-rules to recurse into.
    for group in [&rule.all, &rule.any] {
        if let Maybe::Present(sub_rules) = group {
            for sub_rule in sub_rules {
                visit_dependent_rule_ids(sub_rule, sort)?;
            }
        }
    }
    if let Maybe::Present(negated) = &rule.not {
        visit_dependent_rule_ids(negated, sort)?;
    }
    Ok(())
}
impl<L: Language> DeserializeEnv<L> {
    /// Creates an environment for `lang` with a default (empty) registration.
    pub fn new(lang: L) -> Self {
        Self {
            registration: Default::default(),
            lang,
        }
    }

    /// Deserializes `utils` and registers each one as a local util.
    ///
    /// Utils are processed in dependency order. A util referenced via
    /// `matches` (also inside `all`/`any`/`not`) is deserialized and
    /// registered before the utils that reference it.
    ///
    /// # Errors
    /// - `RuleSerializeError::MatchesReference(ReferentRuleError::CyclicRule(id))`
    ///   when the utils reference each other in a cycle.
    /// - Any error from deserializing a util or from registering it.
    pub fn register_local_utils(
        self,
        utils: &HashMap<String, SerializableRule>,
    ) -> Result<Self, RuleSerializeError> {
        let order = TopologicalSort::get_order(utils)
            .map_err(ReferentRuleError::CyclicRule)
            .map_err(RuleSerializeError::MatchesReference)?;
        for id in order {
            // `order` only contains keys of `utils`.
            let rule = utils.get(id).expect("must exist");
            let rule = self.deserialize_rule(rule.clone())?;
            self.registration.insert_local(id, rule)?;
        }
        Ok(self)
    }

    /// Builds the global util registry from `utils`, in dependency order.
    ///
    /// Each rule is compiled in its own environment for its own language.
    /// That environment sees the globals inserted so far, and the rule is
    /// checked with `CheckHint::Global`.
    ///
    /// # Errors
    /// Returns a `RuleCoreError` when:
    /// - the globals form a dependency cycle;
    /// - a rule fails to compile;
    /// - inserting a rule into the registry fails.
    pub fn parse_global_utils(
        utils: Vec<SerializableGlobalRule<L>>,
    ) -> Result<GlobalRules<L>, RuleCoreError> {
        let registration = GlobalRules::default();
        let utils = into_map(utils);
        let order = TopologicalSort::get_order(&utils)
            .map_err(ReferentRuleError::CyclicRule)
            .map_err(RuleSerializeError::from)?;
        for id in order {
            // `order` only contains keys of `utils`.
            let (lang, core) = utils.get(id).expect("must exist");
            let env = DeserializeEnv::new(lang.clone()).with_globals(&registration);
            let matcher = core.get_matcher_with_hint(env, CheckHint::Global)?;
            registration
                .insert(id, matcher)
                .map_err(RuleSerializeError::MatchesReference)?;
        }
        Ok(registration)
    }

    /// Deserializes `serialized` into a `Rule` against this environment
    /// (delegates to `rule::deserialize_rule`).
    pub fn deserialize_rule(
        &self,
        serialized: SerializableRule,
    ) -> Result<Rule<L>, RuleSerializeError> {
        rule::deserialize_rule(serialized, self)
    }

    /// Orders transformations so each one comes after the transformation
    /// defining the variable it uses (per `Transformation::used_vars`).
    ///
    /// # Errors
    /// `Err(var)` naming the variable at which a dependency cycle was detected.
    pub(crate) fn get_transform_order<'a>(
        &self,
        trans: &'a HashMap<String, Transformation>,
    ) -> Result<Vec<&'a str>, String> {
        TopologicalSort::get_order(trans)
    }

    /// Replaces this environment's registration with one built from `globals`
    /// via `RuleRegistration::from_globals`.
    ///
    /// The previous registration is discarded.
    pub fn with_globals(self, globals: &GlobalRules<L>) -> Self {
        Self {
            registration: RuleRegistration::from_globals(globals),
            lang: self.lang,
        }
    }

    /// Passes this environment's registration through
    /// `RuleRegistration::with_rewriters(globals)`.
    pub fn with_rewriters(self, globals: &GlobalRules<L>) -> Self {
        Self {
            registration: self.registration.with_rewriters(globals),
            lang: self.lang,
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::TypeScript;
    use crate::{from_str, Rule};
    use anyhow::Result;
    use ast_grep_core::Matcher;

    /// Registers two local utils, where `accessor-name` depends on
    /// `member-name`. Then deserializes a rule that matches `accessor-name`.
    fn get_dependent_utils() -> Result<(Rule<TypeScript>, DeserializeEnv<TypeScript>)> {
        // YAML is indentation-sensitive: the nested keys must stay indented
        // under their util id, otherwise they turn into top-level utils.
        let utils = from_str(
            "
accessor-name:
  matches: member-name
  regex: whatever
member-name:
  kind: identifier
",
        )?;
        let env = DeserializeEnv::new(TypeScript::Tsx).register_local_utils(&utils)?;
        assert_eq!(utils.len(), 2);
        let rule = from_str("matches: accessor-name").unwrap();
        Ok((env.deserialize_rule(rule).unwrap(), env))
    }

    #[test]
    fn test_local_util_matches() -> Result<()> {
        let (rule, _env) = get_dependent_utils()?;
        let grep = TypeScript::Tsx.ast_grep("whatever");
        assert!(grep.root().find(rule).is_some());
        Ok(())
    }

    #[test]
    fn test_local_util_kinds() -> Result<()> {
        // HashMap iteration order varies between runs; repeat to exercise
        // different registration orders.
        for _ in 0..10 {
            let (rule, _env) = get_dependent_utils()?;
            assert!(rule.potential_kinds().is_some());
        }
        Ok(())
    }

    #[test]
    fn test_using_global_rule_in_local() -> Result<()> {
        let utils = from_str(
            "
local-rule:
  matches: global-rule
",
        )?;
        DeserializeEnv::new(TypeScript::Tsx).register_local_utils(&utils)?;
        Ok(())
    }

    #[test]
    fn test_using_cyclic_local() -> Result<()> {
        let utils = from_str(
            "
local-rule:
  matches: local-rule
",
        )?;
        let ret = DeserializeEnv::new(TypeScript::Tsx).register_local_utils(&utils);
        assert!(matches!(
            ret,
            Err(RuleSerializeError::MatchesReference(
                ReferentRuleError::CyclicRule(_)
            ))
        ));
        Ok(())
    }

    #[test]
    fn test_using_transitive_cycle() -> Result<()> {
        let utils = from_str(
            "
local-rule-a:
  matches: local-rule-b
local-rule-b:
  all:
    - matches: local-rule-c
local-rule-c:
  any:
    - matches: local-rule-a
",
        )?;
        let ret = DeserializeEnv::new(TypeScript::Tsx).register_local_utils(&utils);
        assert!(matches!(
            ret,
            Err(RuleSerializeError::MatchesReference(
                ReferentRuleError::CyclicRule(_)
            ))
        ));
        Ok(())
    }

    #[test]
    fn test_cyclic_not() -> Result<()> {
        let utils = from_str(
            "
local-rule-a:
  not: {matches: local-rule-b}
local-rule-b:
  matches: local-rule-a",
        )?;
        let ret = DeserializeEnv::new(TypeScript::Tsx).register_local_utils(&utils);
        assert!(matches!(
            ret,
            Err(RuleSerializeError::MatchesReference(
                ReferentRuleError::CyclicRule(_)
            ))
        ));
        Ok(())
    }
}