use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, BTreeSet};
use itertools::Itertools;
use miette::{ensure, Diagnostic, Result};
use thiserror::Error;
use crate::data::program::{
FixedRuleArg, MagicSymbol, NormalFormAtom, NormalFormProgram, NormalFormRulesOrFixed,
StratifiedNormalFormProgram,
};
use crate::data::symb::{Symbol, PROG_ENTRY};
use crate::parse::SourceSpan;
use crate::query::graph::{
generalized_kahn, reachable_components, strongly_connected_components, Graph, StratifiedGraph,
};
impl NormalFormAtom {
    /// Lists the rule this atom applies, if any, keyed by the rule's name.
    ///
    /// The value is `false` for a plain rule application and `true` for a
    /// negated one. Every other kind of atom yields an empty map.
    fn contained_rules(&self) -> BTreeMap<&Symbol, bool> {
        let mut found = BTreeMap::new();
        match self {
            NormalFormAtom::Rule(app) => {
                found.insert(&app.name, false);
            }
            NormalFormAtom::NegatedRule(app) => {
                found.insert(&app.name, true);
            }
            // None of these atoms refer to a rule, so they contribute nothing.
            NormalFormAtom::Relation(_)
            | NormalFormAtom::NegatedRelation(_)
            | NormalFormAtom::Predicate(_)
            | NormalFormAtom::Unification(_)
            | NormalFormAtom::HnswSearch(_)
            | NormalFormAtom::FtsSearch(_)
            | NormalFormAtom::LshSearch(_) => {}
        }
        found
    }
}
/// Builds the dependency graph of a normal-form program.
///
/// Every entry of `nf_prog.prog` becomes a node. Each node maps the names it
/// depends on to a flag, and the flag is `true` when at least one occurrence
/// of that dependency satisfies any of the following:
///
/// * the occurrence is negated;
/// * the dependency is a fixed rule;
/// * the dependency is a "meet" rule set other than the depending rule itself
///   (a rule set that aggregates, and only with meet aggregations);
/// * the depending rule set aggregates — except for the self-reference of a
///   meet rule set, which is exempt from this particular condition.
///
/// Fixed rules depend on their in-memory arguments only, and those edges are
/// always flagged.
fn convert_normal_form_program_to_graph(
    nf_prog: &NormalFormProgram,
) -> StratifiedGraph<&'_ Symbol> {
    // Rule sets that aggregate and whose aggregations are all meet aggregations.
    let meet_rules: BTreeSet<_> = nf_prog
        .prog
        .iter()
        .filter_map(|(name, entry)| match entry {
            NormalFormRulesOrFixed::Fixed { .. } => None,
            NormalFormRulesOrFixed::Rules { rules } => {
                let aggregates = rules
                    .iter()
                    .any(|rule| rule.aggr.iter().any(|slot| slot.is_some()));
                let only_meet = rules.iter().all(|rule| {
                    rule.aggr.iter().all(|slot| match slot {
                        Some((aggr, _)) => aggr.is_meet,
                        None => true,
                    })
                });
                if aggregates && only_meet {
                    Some(name)
                } else {
                    None
                }
            }
        })
        .collect();

    // Names bound to fixed rules rather than to ordinary rule sets.
    let fixed_rules: BTreeSet<_> = nf_prog
        .prog
        .iter()
        .filter_map(|(name, entry)| match entry {
            NormalFormRulesOrFixed::Fixed { .. } => Some(name),
            NormalFormRulesOrFixed::Rules { .. } => None,
        })
        .collect();

    nf_prog
        .prog
        .iter()
        .map(|(name, entry)| {
            let mut edges: BTreeMap<&Symbol, bool> = BTreeMap::new();
            match entry {
                NormalFormRulesOrFixed::Fixed { fixed } => {
                    for arg in &fixed.rule_args {
                        match arg {
                            FixedRuleArg::InMem { name: dep, .. } => {
                                edges.insert(dep, true);
                            }
                            FixedRuleArg::Stored { .. } | FixedRuleArg::NamedStored { .. } => {}
                        }
                    }
                }
                NormalFormRulesOrFixed::Rules { rules } => {
                    let aggregates = rules
                        .iter()
                        .any(|rule| rule.aggr.iter().any(|slot| slot.is_some()));
                    // `meet_rules` was built with exactly the "aggregates and
                    // only with meet aggregations" test, so a lookup suffices.
                    let self_is_meet = meet_rules.contains(name);
                    for rule in rules {
                        for atom in &rule.body {
                            for (dep, is_negated) in atom.contained_rules() {
                                let self_ref = name == dep;
                                let dep_is_meet = !self_ref && meet_rules.contains(dep);
                                let dep_is_fixed = fixed_rules.contains(dep);
                                // An aggregating rule flags every dependency
                                // except a meet rule's reference to itself.
                                let flagged_by_aggr = aggregates && !(self_is_meet && self_ref);
                                let flagged =
                                    flagged_by_aggr || dep_is_fixed || dep_is_meet || is_negated;
                                // Once flagged, an edge stays flagged.
                                *edges.entry(dep).or_insert(false) |= flagged;
                            }
                        }
                    }
                }
            }
            (name, edges)
        })
        .collect()
}
/// Drops the per-edge flags of a stratified graph, keeping only which node
/// points to which.
fn reduce_to_graph<'a>(g: &StratifiedGraph<&'a Symbol>) -> Graph<&'a Symbol> {
    g.iter()
        .map(|(node, edges)| {
            let targets: Vec<&'a Symbol> = edges.keys().copied().collect();
            (*node, targets)
        })
        .collect()
}
/// Checks that no flagged edge of `g` connects two nodes of the same
/// strongly connected component.
///
/// # Errors
///
/// Returns an `eval::unstratifiable` diagnostic for the first flagged edge
/// found whose source and target both lie in one component. The diagnostic
/// carries the edge's target and the members of that component.
fn verify_no_cycle(g: &StratifiedGraph<&'_ Symbol>, sccs: &[BTreeSet<&Symbol>]) -> Result<()> {
    #[derive(Debug, Error, Diagnostic)]
    #[error("Query is unstratifiable")]
    #[diagnostic(code(eval::unstratifiable))]
    #[diagnostic(help(
        "The rule '{0}' is in the strongly connected component {1:?},\n\
         and is involved in at least one forbidden dependency \n\
         (negation, non-meet aggregation, or algorithm-application)."
    ))]
    struct UnStratifiableProgram(String, Vec<String>);

    for (node, deps) in g {
        // Only components that contain the source node can hold an offending edge.
        for component in sccs.iter().filter(|scc| scc.contains(node)) {
            for (dep, flagged) in deps {
                ensure!(
                    !(*flagged && component.contains(dep)),
                    UnStratifiableProgram(
                        dep.to_string(),
                        component.iter().map(|member| member.to_string()).collect()
                    )
                );
            }
        }
    }
    Ok(())
}
fn make_scc_reduced_graph(
sccs: &[BTreeSet<&Symbol>],
graph: &StratifiedGraph<&Symbol>,
) -> (BTreeMap<Symbol, usize>, StratifiedGraph<usize>) {
let indices = sccs
.iter()
.enumerate()
.flat_map(|(idx, scc)| scc.iter().map(move |k| ((*k).clone(), idx)))
.collect::<BTreeMap<_, _>>();
let mut ret: BTreeMap<usize, BTreeMap<usize, bool>> = Default::default();
for (from, tos) in graph {
let from_idx = *indices.get(from).unwrap();
let cur_entry = ret.entry(from_idx).or_default();
for (to, poisoned) in tos {
let to_idx = match indices.get(to) {
Some(i) => *i,
None => continue,
};
if from_idx == to_idx {
continue;
}
match cur_entry.entry(to_idx) {
Entry::Vacant(e) => {
e.insert(*poisoned);
}
Entry::Occupied(mut e) => {
let old_p = *e.get();
e.insert(old_p || *poisoned);
}
}
}
}
(indices, ret)
}
impl NormalFormProgram {
    /// Splits the program into strata.
    ///
    /// The steps are:
    /// 1. build the dependency graph and keep only what is reachable from the
    ///    program entry symbol;
    /// 2. compute strongly connected components and reject the program if a
    ///    flagged dependency lies inside one of them;
    /// 3. collapse the components and order them with `generalized_kahn`;
    /// 4. place every rule set in the stratum at the index assigned to its
    ///    component.
    ///
    /// The second element of the returned pair maps each dependency (wrapped
    /// as `MagicSymbol::Muggle`) to the largest value of
    /// `n_strata - 1 - stratum` over the rules that depend on it.
    ///
    /// # Errors
    ///
    /// Propagates failures from `strongly_connected_components` and from the
    /// stratifiability check.
    pub(crate) fn into_stratified_program(
        self,
    ) -> Result<(StratifiedNormalFormProgram, BTreeMap<MagicSymbol, usize>)> {
        let prog_entry: &Symbol = &Symbol::new(PROG_ENTRY, SourceSpan(0, 0));
        let full_stratified = convert_normal_form_program_to_graph(&self);
        let full_graph = reduce_to_graph(&full_stratified);

        // Everything not reachable from the entry symbol is pruned.
        let reachable: BTreeSet<_> = reachable_components(&full_graph, &prog_entry)
            .into_iter()
            .map(|sym| (*sym).clone())
            .collect();
        let stratified_graph: StratifiedGraph<_> = full_stratified
            .into_iter()
            .filter(|(sym, _)| reachable.contains(sym))
            .collect();
        let graph: Graph<_> = full_graph
            .into_iter()
            .filter(|(sym, _)| reachable.contains(sym))
            .collect();

        let sccs: Vec<BTreeSet<&Symbol>> = strongly_connected_components(&graph)?
            .into_iter()
            .map(|component| component.into_iter().cloned().collect())
            .collect();
        verify_no_cycle(&stratified_graph, &sccs)?;

        let (scc_of_rule, reduced_graph) = make_scc_reduced_graph(&sccs, &stratified_graph);
        let strata = generalized_kahn(&reduced_graph, stratified_graph.len());
        let n_strata = strata.len();

        // Component index -> index of the stratum that contains it.
        let mut stratum_of_scc = BTreeMap::new();
        for (stratum, members) in strata.into_iter().enumerate() {
            for scc_idx in members {
                stratum_of_scc.insert(scc_idx, stratum);
            }
        }

        let mut ret: Vec<NormalFormProgram> = (0..n_strata)
            .map(|_| NormalFormProgram {
                prog: BTreeMap::new(),
                disable_magic_rewrite: self.disable_magic_rewrite,
            })
            .collect_vec();

        let mut store_lifetimes = BTreeMap::new();
        for (fr, tos) in &stratified_graph {
            let fr_stratum = match scc_of_rule
                .get(fr)
                .and_then(|scc_idx| stratum_of_scc.get(scc_idx))
            {
                Some(stratum) => *stratum,
                None => continue,
            };
            // `fr_stratum` is an index into the strata, so this cannot underflow.
            let used_in = n_strata - 1 - fr_stratum;
            for to in tos.keys() {
                let magic_to = MagicSymbol::Muggle {
                    inner: (*to).clone(),
                };
                // Keep the largest value seen for each dependency.
                match store_lifetimes.entry(magic_to) {
                    Entry::Occupied(mut slot) => {
                        if used_in > *slot.get() {
                            slot.insert(used_in);
                        }
                    }
                    Entry::Vacant(slot) => {
                        slot.insert(used_in);
                    }
                }
            }
        }

        // Rule sets that were pruned have no component and are left out.
        for (name, ruleset) in self.prog {
            let stratum = scc_of_rule
                .get(&name)
                .and_then(|scc_idx| stratum_of_scc.get(scc_idx))
                .copied();
            if let Some(stratum_idx) = stratum {
                let target = ret.get_mut(stratum_idx).unwrap();
                target.prog.insert(name, ruleset);
            }
        }
        Ok((StratifiedNormalFormProgram(ret), store_lifetimes))
    }
}
#[cfg(test)]
mod tests {
    use crate::DbInstance;

    /// Smoke test: a program mixing a recursive rule (`w`), aggregating rules
    /// (`y`, `z`) and an entry that reads from both must run without error.
    #[test]
    fn test_dependencies() {
        let script = r#"
x[a] <- [[1], [2]]
w[a] := a in [2]
w[a] := w[b], a = b + 1, a < 10
y[count(a)] := x[a]
y[count(a)] := w[a]
z[count(a)] := y[a]
z[count(a)] := y[b], a = b + 1
?[a] := z[a]
?[a] := w[a]
"#;
        let db = DbInstance::default();
        // Only successful evaluation is checked; the rows are not inspected.
        let outcome = db.run_default(script).unwrap();
        let _rows = outcome.rows;
    }
}