use super::{
ConstType, FlowLogParser, Lexeme, Rule,
declaration::{
CompDecl, ExternFn, InitDecl, InputDirective, OutputDirective, PrintSizeDirective,
RawTypeOp, Relation, split_type_alias,
},
error::{DirectiveKind, ParseError, grammar_bug},
inliner,
logic::{Arithmetic, AtomArg, Factor, FlowLogRule, FnCall, Head, LoopBlock, Predicate},
primitive::TypeRegistry,
segment::Segment,
};
use crate::common::{FileId, SourceMap, Span};
use pest::{Parser, iterators::Pair};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::{fmt, fs};
use tracing::{debug, info, warn};
/// A fully parsed FlowLog program, as produced by [`Program::parse`].
///
/// Holds the declared relations, the rules grouped into source-ordered
/// segments, extern function declarations, inline facts, and the type
/// registry that was used while resolving declared attribute types.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Program {
    // Declared relations, including those produced by component inlining.
    relations: Vec<Relation>,
    // Rules in source order: plain rule runs interleaved with loop/fixpoint blocks.
    segments: Vec<Segment>,
    // Extern (user-defined) function declarations.
    udfs: Vec<ExternFn>,
    // Inline facts keyed by relation name; each entry is the fact head's span
    // and its tuple of constants.
    facts: HashMap<String, Vec<(Span, Vec<ConstType>)>>,
    // Type aliases/subtypes from `.type` declarations and component inlining.
    type_registry: TypeRegistry,
}
impl fmt::Display for Program {
    /// Renders a human-readable dump of the program (used for debug logging).
    ///
    /// Empty sections are omitted. Facts are grouped by relation name and the
    /// relation names are printed in sorted order, so the output is
    /// deterministic even though facts are stored in a `HashMap`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "=============================================")?;
        writeln!(f, "FlowLog DATALOG PROGRAM")?;
        writeln!(f, "=============================================")?;
        writeln!(f)?;
        if !self.relations.is_empty() {
            writeln!(f, "Relations")?;
            writeln!(f, "---------------------------------------------")?;
            for rel in &self.relations {
                writeln!(f, "{}", rel)?;
            }
            writeln!(f)?;
        }
        if !self.udfs.is_empty() {
            writeln!(f, "Extern Functions")?;
            writeln!(f, "---------------------------------------------")?;
            for udf in &self.udfs {
                writeln!(f, "{}", udf)?;
            }
            writeln!(f)?;
        }
        if !self.segments.is_empty() {
            writeln!(f, "Program (source order)")?;
            writeln!(f, "---------------------------------------------")?;
            for (i, item) in self.segments.iter().enumerate() {
                match item {
                    Segment::Plain(rules) => {
                        writeln!(f, "[Segment {}]", i)?;
                        for rule in rules {
                            writeln!(f, " {}", rule)?;
                        }
                    }
                    Segment::Loop(block) | Segment::Fixpoint(block) => {
                        writeln!(f, "[Loop {}]", i)?;
                        writeln!(f, " {}", block)?;
                    }
                }
            }
            writeln!(f)?;
        }
        if !self.facts.is_empty() {
            writeln!(f, "Facts")?;
            writeln!(f, "---------------------------------------------")?;
            // `HashMap` iteration order is unspecified; sort by relation name so
            // repeated dumps of the same program are byte-identical.
            let mut by_relation: Vec<_> = self.facts.iter().collect();
            by_relation.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
            for (rel_name, facts) in by_relation {
                for (_, vals) in facts {
                    let values = vals
                        .iter()
                        .map(|c| c.to_string())
                        .collect::<Vec<_>>()
                        .join(", ");
                    writeln!(f, "{}({}).", rel_name, values)?;
                }
            }
        }
        Ok(())
    }
}
impl Program {
pub fn parse(path: &str, extended: bool, sm: &mut SourceMap) -> Result<Self, ParseError> {
Self::parse_with_includes(path, extended, &[], sm)
}
pub fn parse_with_includes(
path: &str,
extended: bool,
include_dirs: &[&Path],
sm: &mut SourceMap,
) -> Result<Self, ParseError> {
let file_path = PathBuf::from(path);
let root_file = sm
.load(&file_path)
.map_err(|source| ParseError::IncludeIo {
span: Span::DUMMY,
path: file_path.clone(),
source,
})?;
let base_dir: PathBuf = file_path
.parent()
.map(Path::to_path_buf)
.unwrap_or_else(|| PathBuf::from("."));
let mut in_progress = HashSet::new();
let mut completed = HashSet::new();
in_progress.insert(fs::canonicalize(&file_path).unwrap_or_else(|_| file_path.clone()));
let combined = resolve_includes(
sm.text(root_file).to_string(),
root_file,
&base_dir,
include_dirs,
&mut in_progress,
&mut completed,
sm,
)?;
let combined_file = sm.add(file_path.clone(), combined);
let mut pairs = FlowLogParser::parse(Rule::main_grammar, sm.text(combined_file))
.map_err(|e| ParseError::syntax_from_pest(&e, combined_file))?;
let root = pairs
.next()
.ok_or_else(|| grammar_bug("no parsed rule found"))?;
let mut program = Self::collect_program(root, extended, combined_file)?;
program.prune_dead_components();
debug!("\n{}", program);
info!("Successfully parsed program from '{}'.", path);
Ok(program)
}
#[must_use]
#[inline]
pub(crate) fn relations(&self) -> &[Relation] {
&self.relations
}
#[must_use]
pub fn edbs(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| self.is_edb_relation(rel))
.collect()
}
#[cfg(test)]
#[must_use]
#[inline]
pub(crate) fn file_backed_relations(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| rel.is_file_backed())
.collect()
}
#[cfg(test)]
#[must_use]
pub(crate) fn inline_fact_relations(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| self.has_inline_facts(rel.name()))
.collect()
}
#[must_use]
pub(crate) fn edb_names(&self) -> Vec<String> {
let mut names: Vec<String> = self
.edbs()
.iter()
.map(|rel| rel.name().to_string())
.collect();
names.sort_unstable();
names
}
#[must_use]
pub(crate) fn edb_fingerprints(&self) -> HashSet<u64> {
self.edbs().iter().map(|rel| rel.fingerprint()).collect()
}
#[must_use]
#[inline]
pub(crate) fn idbs(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| rel.is_output_printsize())
.collect()
}
#[must_use]
#[inline]
pub fn output_idbs(&self) -> Vec<&Relation> {
self.relations.iter().filter(|rel| rel.output()).collect()
}
#[must_use]
#[inline]
pub fn printsize_idbs(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| rel.printsize())
.collect()
}
#[must_use]
#[inline]
pub(crate) fn segments(&self) -> &[Segment] {
&self.segments
}
#[must_use]
pub fn rules(&self) -> Vec<&FlowLogRule> {
self.segments
.iter()
.flat_map(|item| item.as_rules())
.collect()
}
pub(crate) fn segments_mut(&mut self) -> &mut [Segment] {
&mut self.segments
}
pub(crate) fn facts_mut(&mut self) -> &mut HashMap<String, Vec<(Span, Vec<ConstType>)>> {
&mut self.facts
}
#[must_use]
pub(crate) fn rule(&self, rid: usize) -> &FlowLogRule {
let mut offset = 0;
for seg in &self.segments {
let rules: &[FlowLogRule] = match seg {
Segment::Plain(rules) => rules,
Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
};
if rid < offset + rules.len() {
return &rules[rid - offset];
}
offset += rules.len();
}
panic!("Parser error: rule ID {rid} out of bounds");
}
#[must_use]
#[inline]
pub fn facts(&self) -> &HashMap<String, Vec<(Span, Vec<ConstType>)>> {
&self.facts
}
#[must_use]
#[inline]
pub(crate) fn has_inline_facts(&self, relation_name: &str) -> bool {
self.facts.contains_key(relation_name)
}
#[must_use]
#[inline]
pub(crate) fn udfs(&self) -> &[ExternFn] {
&self.udfs
}
#[inline]
pub(crate) fn registry_and_segments_mut(&mut self) -> (&TypeRegistry, &mut [Segment]) {
(&self.type_registry, &mut self.segments)
}
#[inline]
fn is_edb_relation(&self, rel: &Relation) -> bool {
rel.has_input() || self.has_inline_facts(rel.name())
}
}
/// Recursively splices `.include` directives into `source` and returns the
/// flattened program text.
///
/// `source` is parsed with the main grammar so that include directives can be
/// located by span. Each directive is replaced by the recursively resolved
/// text of the file it names, while all other text is copied through
/// unchanged. A target whose canonical path is already in `completed` is
/// skipped with a warning. A target still in `in_progress` (i.e. on the
/// current include chain) is reported as a circular include.
///
/// # Errors
/// A syntax error if `source` does not parse, `CircularInclude` on a cycle,
/// and `IncludeIo` when an included file cannot be loaded.
fn resolve_includes(
    source: String,
    source_file: FileId,
    base_dir: &Path,
    include_dirs: &[&Path],
    in_progress: &mut HashSet<PathBuf>,
    completed: &mut HashSet<PathBuf>,
    sm: &mut SourceMap,
) -> Result<String, ParseError> {
    let root = FlowLogParser::parse(Rule::main_grammar, &source)
        .map_err(|e| ParseError::syntax_from_pest(&e, source_file))?
        .next()
        .ok_or_else(|| grammar_bug("no parsed rule found"))?;

    // True when `text` ends in a non-whitespace character, i.e. spliced text
    // would otherwise run straight into it.
    let needs_separator =
        |text: &str| text.chars().next_back().is_some_and(|c| !c.is_whitespace());

    let mut flattened = String::with_capacity(source.len());
    let mut copied_up_to = 0usize;

    let directives = root
        .into_inner()
        .filter(|node| node.as_rule() == Rule::include_directive);
    for directive in directives {
        let pest_span = directive.as_span();
        let (start, end) = (pest_span.start(), pest_span.end());
        let directive_span = Span::new(source_file, start as u32, end as u32);

        // Copy everything before the directive, then drop the directive text.
        flattened.push_str(&source[copied_up_to..start]);
        copied_up_to = end;

        let Some(path_node) = directive.into_inner().next() else {
            return Err(grammar_bug("include directive missing path"));
        };
        let raw = path_node.as_str().trim_matches('"');
        let target = resolve_one_include(raw, base_dir, include_dirs);
        let canonical = fs::canonicalize(&target).unwrap_or_else(|_| target.clone());

        if in_progress.contains(&canonical) {
            // NOTE(review): `chain` is collected from a HashSet, so its order is
            // unspecified rather than the actual include order.
            return Err(ParseError::CircularInclude {
                span: directive_span,
                path: target,
                chain: in_progress.iter().cloned().collect(),
            });
        }
        if completed.contains(&canonical) {
            warn!("Skipping duplicate include '{}'.", target.display());
            continue;
        }

        debug!("Including '{}'.", target.display());
        let included_file = sm.load(&target).map_err(|err| ParseError::IncludeIo {
            span: directive_span,
            path: target.clone(),
            source: err,
        })?;
        let included_text = sm.text(included_file).to_string();
        let included_dir = target.parent().unwrap_or(Path::new(".")).to_path_buf();

        in_progress.insert(canonical.clone());
        let nested = resolve_includes(
            included_text,
            included_file,
            &included_dir,
            include_dirs,
            in_progress,
            completed,
            sm,
        )?;
        in_progress.remove(&canonical);
        completed.insert(canonical);

        if needs_separator(flattened.as_str()) {
            flattened.push('\n');
        }
        flattened.push_str(&nested);
        if needs_separator(flattened.as_str()) {
            flattened.push('\n');
        }
    }

    flattened.push_str(&source[copied_up_to..]);
    Ok(flattened)
}
/// Resolves the target of an `.include "raw"` directive to a filesystem path.
///
/// Candidates are tried in order: `base_dir/raw` (the including file's
/// directory), then `dir/raw` for each entry of `include_dirs`. The first
/// candidate that exists wins. If none exists, `base_dir/raw` is returned, so
/// a subsequent load failure names that path.
fn resolve_one_include(raw: &str, base_dir: &Path, include_dirs: &[&Path]) -> PathBuf {
    std::iter::once(base_dir)
        .chain(include_dirs.iter().copied())
        .map(|dir| dir.join(raw))
        .find(|candidate| candidate.exists())
        .unwrap_or_else(|| base_dir.join(raw))
}
/// Builds the [`TypeRegistry`] from every top-level `type_alias_decl` node.
///
/// Subtype declarations are registered with `register_subtype`, plain aliases
/// with `register_alias`, in source order. Runs before relations are
/// collected, so declarations can refer to these types.
///
/// # Errors
/// Any error from `split_type_alias` or the registry's register methods.
fn build_type_registry(parsed_rule: Pair<Rule>, file: FileId) -> Result<TypeRegistry, ParseError> {
    let mut types = TypeRegistry::new();
    let alias_nodes = parsed_rule
        .into_inner()
        .filter(|item| item.as_rule() == Rule::type_alias_decl);
    for alias_node in alias_nodes {
        let (alias_name, kind, target, decl_span) = split_type_alias(alias_node, file)?;
        match kind {
            RawTypeOp::Alias => {
                types.register_alias(&alias_name, &target, decl_span)?;
            }
            RawTypeOp::Subtype => {
                types.register_subtype(&alias_name, &target, decl_span)?;
            }
        }
    }
    Ok(types)
}
/// Moves any buffered rules into a new `Segment::Plain` appended to `out`,
/// leaving `pending` empty. Does nothing when no rules are buffered.
#[inline]
fn flush_rules(pending: &mut Vec<FlowLogRule>, out: &mut Vec<Segment>) {
    if pending.is_empty() {
        return;
    }
    let batch = std::mem::take(pending);
    out.push(Segment::Plain(batch));
}
/// Replaces every `.` with `INLINER_SEP` in dotted relation names. This covers
/// relation declarations, rule heads and body atoms in all segments, and the
/// heads of the not-yet-extracted fact rules.
///
/// Relations are renamed from their `raw_name()`. Rules and facts use the
/// `name()` of their head/atoms.
fn normalize_inliner_dots(
    relations: &mut [Relation],
    segments: &mut [Segment],
    raw_facts: &mut [FlowLogRule],
) {
    relations
        .iter_mut()
        .filter(|decl| decl.name().contains('.'))
        .for_each(|decl| {
            let renamed = decl.raw_name().replace('.', INLINER_SEP);
            decl.set_name(renamed);
        });
    for_each_rule_mut(segments, normalize_rule_dots);
    raw_facts.iter_mut().for_each(normalize_rule_dots);
}
/// Rewrites `.` to `INLINER_SEP` in the head name and in every positive or
/// negative body atom name of `rule`. Other predicate kinds are left alone.
fn normalize_rule_dots(rule: &mut FlowLogRule) {
    let head = rule.head_mut();
    if head.name().contains('.') {
        let renamed = head.name().replace('.', INLINER_SEP);
        head.set_name(renamed);
    }
    for body_pred in rule.rhs_mut() {
        match body_pred {
            Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom)
                if atom.name().contains('.') =>
            {
                let renamed = atom.name().replace('.', INLINER_SEP);
                atom.set_name(renamed);
            }
            _ => {}
        }
    }
}
/// Replacement for `.` in dotted relation/atom names (U+00B7 MIDDLE DOT),
/// substituted by `normalize_inliner_dots` and `normalize_rule_dots`.
const INLINER_SEP: &str = "·";
/// Calls `f` on every rule in `segments`, in source order. This includes the
/// rules inside loop and fixpoint blocks.
fn for_each_rule_mut<F>(segments: &mut [Segment], mut f: F)
where
    F: FnMut(&mut FlowLogRule),
{
    for segment in segments {
        match segment {
            Segment::Plain(rs) => rs.iter_mut().for_each(&mut f),
            Segment::Loop(b) | Segment::Fixpoint(b) => b.rules_mut().iter_mut().for_each(&mut f),
        }
    }
}
/// Rejects the first directive whose relation name was already targeted by an
/// earlier directive of the same `kind`.
///
/// `name_of` and `span_of` project the relation name and source span out of a
/// directive, so one check serves every directive type.
///
/// # Errors
/// `ParseError::DuplicateDirective` carrying the repeat's span and the span of
/// the first occurrence.
fn check_duplicate_directives<T>(
    dirs: &[T],
    kind: DirectiveKind,
    name_of: impl Fn(&T) -> &str,
    span_of: impl Fn(&T) -> Span,
) -> Result<(), ParseError> {
    let mut first_seen: HashMap<&str, Span> = HashMap::with_capacity(dirs.len());
    for directive in dirs {
        let (name, span) = (name_of(directive), span_of(directive));
        // `insert` hands back the previously stored span when `name` repeats.
        if let Some(prior) = first_seen.insert(name, span) {
            return Err(ParseError::DuplicateDirective {
                span,
                prior,
                kind,
                name: name.to_string(),
            });
        }
    }
    Ok(())
}
impl Program {
    /// Builds a [`Program`] from the root `main_grammar` pair of a fully
    /// include-resolved source file.
    ///
    /// Top-level items are walked in source order to collect declarations,
    /// directives, extern functions, components, facts and rules. Consecutive
    /// rules are grouped into `Segment::Plain` runs, which are split by
    /// `loop`/`fixpoint` blocks and by component `init` declarations. Each
    /// init's inlined rules are spliced in at the position where the init
    /// appeared. After that, the function:
    /// - applies directives,
    /// - reclassifies UDF calls,
    /// - validates loop conditions,
    /// - normalizes dotted names,
    /// - checks all relation references.
    ///
    /// # Errors
    /// Any `ParseError` from the per-item constructors or the inliner, plus:
    /// - duplicate declarations,
    /// - loop/fixpoint blocks when `extended` is false,
    /// - the validation errors raised by the helpers in this impl.
    fn collect_program(
        parsed_rule: Pair<Rule>,
        extended: bool,
        file: FileId,
    ) -> Result<Self, ParseError> {
        let mut type_registry = build_type_registry(parsed_rule.clone(), file)?;
        let mut relations: Vec<Relation> = Vec::new();
        // `name()` of each declared relation -> span of its declaration.
        let mut decl_spans: HashMap<String, Span> = HashMap::new();
        let mut input_directives: Vec<InputDirective> = Vec::new();
        let mut output_directives: Vec<OutputDirective> = Vec::new();
        let mut printsize_directives: Vec<PrintSizeDirective> = Vec::new();
        let mut udfs: Vec<ExternFn> = Vec::new();
        let mut raw_facts: Vec<FlowLogRule> = Vec::new();
        // Rules seen since the last segment boundary.
        let mut current_rules: Vec<FlowLogRule> = Vec::new();
        let mut segments: Vec<Segment> = Vec::new();
        let mut comps: HashMap<String, CompDecl> = HashMap::new();
        // Each init paired with the segment index where its inlined rules belong.
        let mut inits_at_pos: Vec<(InitDecl, usize)> = Vec::new();
        for node in parsed_rule.into_inner() {
            match node.as_rule() {
                Rule::declaration => {
                    let rel = Relation::from_parsed_rule_with_registry(node, file, &type_registry)?;
                    Self::register_relation(rel, &mut decl_spans, &mut relations)?;
                }
                Rule::extern_fn => {
                    udfs.push(ExternFn::from_parsed_rule(node, file, &type_registry)?)
                }
                // Type aliases were already registered by `build_type_registry`.
                Rule::type_alias_decl => {}
                Rule::comp_decl => {
                    let comp = CompDecl::from_parsed_rule(node, file)?;
                    comps.insert(comp.name.clone(), comp);
                }
                Rule::init_decl => {
                    // Close the current plain segment so the init's rules can be
                    // inserted exactly at this point in source order.
                    flush_rules(&mut current_rules, &mut segments);
                    let init = InitDecl::from_parsed_rule(node, file)?;
                    inits_at_pos.push((init, segments.len()));
                }
                Rule::input_directive => {
                    input_directives.push(InputDirective::from_parsed_rule(node, file)?)
                }
                Rule::output_directive => {
                    output_directives.push(OutputDirective::from_parsed_rule(node, file)?)
                }
                Rule::printsize_directive => {
                    printsize_directives.push(PrintSizeDirective::from_parsed_rule(node, file)?)
                }
                Rule::rule => {
                    current_rules.extend(FlowLogRule::expand_from_parsed_rule(node, file)?);
                }
                Rule::loop_block => {
                    let block = LoopBlock::from_parsed_rule(node, file)?;
                    if !extended {
                        return Err(ParseError::LoopBlockInStandardMode { span: block.span() });
                    }
                    flush_rules(&mut current_rules, &mut segments);
                    segments.push(Segment::Loop(block));
                }
                Rule::fixpoint_block => {
                    let block = LoopBlock::from_parsed_rule(node, file)?;
                    if !extended {
                        return Err(ParseError::LoopBlockInStandardMode { span: block.span() });
                    }
                    flush_rules(&mut current_rules, &mut segments);
                    segments.push(Segment::Fixpoint(block));
                }
                Rule::fact => {
                    let head_node = node
                        .into_inner()
                        .next()
                        .ok_or_else(|| grammar_bug("fact missing head"))?;
                    raw_facts.push(FlowLogRule::new(
                        Head::from_parsed_rule(head_node, file)?,
                        vec![],
                    ));
                }
                Rule::include_directive => {
                    return Err(grammar_bug(
                        "unexpected include_directive in parsed tree; includes should have been resolved before parsing",
                    ));
                }
                _ => {}
            }
        }
        flush_rules(&mut current_rules, &mut segments);
        // Positions were recorded before any insertion; `shift` accounts for the
        // segments already inserted by earlier inits.
        let mut shift = 0usize;
        for (init, pos) in inits_at_pos {
            let mut out = inliner::InlinerOutput::default();
            inliner::inline_one("", init, &mut comps, &mut out, &mut type_registry)?;
            for rel in out.relations {
                Self::register_relation(rel, &mut decl_spans, &mut relations)?;
            }
            raw_facts.extend(out.facts);
            if !out.rules.is_empty() {
                segments.insert(pos + shift, Segment::Plain(out.rules));
                shift += 1;
            }
        }
        Self::apply_directives(
            &mut relations,
            input_directives,
            output_directives,
            printsize_directives,
        )?;
        Self::reclassify_udf_predicates(&mut segments, &udfs)?;
        Self::validate_loop_conditions(&segments, &relations)?;
        normalize_inliner_dots(&mut relations, &mut segments, &mut raw_facts);
        let mut program = Self {
            relations,
            segments,
            udfs,
            type_registry,
            ..Self::default()
        };
        for fact in raw_facts {
            program.extract_fact(fact);
        }
        program.validate_relation_references()?;
        Ok(program)
    }

    /// Appends `rel` to `relations`, rejecting a second declaration with the
    /// same `name()`.
    ///
    /// # Errors
    /// `ParseError::DuplicateDecl` pointing at `rel` and the earlier declaration.
    fn register_relation(
        rel: Relation,
        decl_spans: &mut HashMap<String, Span>,
        relations: &mut Vec<Relation>,
    ) -> Result<(), ParseError> {
        if let Some(prior) = decl_spans.get(rel.name()) {
            return Err(ParseError::DuplicateDecl {
                span: rel.span(),
                prior: *prior,
                name: rel.raw_name().to_string(),
            });
        }
        decl_spans.insert(rel.name().to_string(), rel.span());
        relations.push(rel);
        Ok(())
    }

    /// Ensures every rule head, every positive/negative body atom, and every
    /// fact relation refers to a declared relation.
    ///
    /// # Errors
    /// `UndeclaredInRule` for the first offending head/atom, or
    /// `UndeclaredInFact` for an undeclared fact relation.
    fn validate_relation_references(&self) -> Result<(), ParseError> {
        let declared: HashSet<&str> = self.relations.iter().map(|r| r.name()).collect();
        for segment in &self.segments {
            let rules: &[FlowLogRule] = match segment {
                Segment::Plain(rules) => rules,
                Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
            };
            for rule in rules {
                let head = rule.head();
                if !declared.contains(head.name()) {
                    return Err(ParseError::UndeclaredInRule {
                        span: head.span(),
                        name: head.name().to_string(),
                    });
                }
                for pred in rule.rhs() {
                    if let Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom) = pred
                        && !declared.contains(atom.name())
                    {
                        return Err(ParseError::UndeclaredInRule {
                            span: atom.span(),
                            name: atom.name().to_string(),
                        });
                    }
                }
            }
        }
        for (rel_name, tuples) in &self.facts {
            if !declared.contains(rel_name.as_str()) {
                let span = tuples.first().map(|(s, _)| *s).unwrap_or(Span::DUMMY);
                return Err(ParseError::UndeclaredInFact {
                    span,
                    name: rel_name.clone(),
                });
            }
        }
        Ok(())
    }

    /// Applies `.input`, `.output` and `.printsize` directives to their
    /// relations, after rejecting duplicate directives of each kind.
    ///
    /// # Errors
    /// - `DuplicateDirective` for a repeated directive.
    /// - `UndeclaredInDirective` for a directive naming an unknown relation.
    /// - Any error from `set_output_params`.
    fn apply_directives(
        relations: &mut [Relation],
        input_directives: Vec<InputDirective>,
        output_directives: Vec<OutputDirective>,
        printsize_directives: Vec<PrintSizeDirective>,
    ) -> Result<(), ParseError> {
        check_duplicate_directives(
            &input_directives,
            DirectiveKind::Input,
            |d| d.relation_name(),
            |d| d.span(),
        )?;
        check_duplicate_directives(
            &output_directives,
            DirectiveKind::Output,
            |d| d.relation_name(),
            |d| d.span(),
        )?;
        check_duplicate_directives(
            &printsize_directives,
            DirectiveKind::PrintSize,
            |d| d.relation_name(),
            |d| d.span(),
        )?;
        for d in input_directives {
            match relations.iter_mut().find(|r| r.name() == d.relation_name()) {
                Some(rel) => rel.set_input_params(d.parameters().clone()),
                None => {
                    return Err(ParseError::UndeclaredInDirective {
                        span: d.span(),
                        kind: DirectiveKind::Input,
                        name: d.relation_name().to_string(),
                    });
                }
            }
        }
        for d in output_directives {
            match relations.iter_mut().find(|r| r.name() == d.relation_name()) {
                Some(rel) => {
                    rel.set_output(true);
                    if !d.parameters().is_empty() {
                        rel.set_output_params(d.parameters().clone())?;
                    }
                }
                None => {
                    return Err(ParseError::UndeclaredInDirective {
                        span: d.span(),
                        kind: DirectiveKind::Output,
                        name: d.relation_name().to_string(),
                    });
                }
            }
        }
        for d in printsize_directives {
            match relations.iter_mut().find(|r| r.name() == d.relation_name()) {
                Some(rel) => rel.set_printsize(true),
                None => {
                    return Err(ParseError::UndeclaredInDirective {
                        span: d.span(),
                        kind: DirectiveKind::PrintSize,
                        name: d.relation_name().to_string(),
                    });
                }
            }
        }
        Ok(())
    }

    /// Validates every loop/fixpoint block (as returned by `as_loop`):
    /// - `.iterative` relations must be declared;
    /// - every relation in an `until` condition must be a declared nullary
    ///   relation. It may match either exactly or in lowercase form.
    ///
    /// # Errors
    /// `UndeclaredInIterativeList`, `UndeclaredLoopCondition`, or
    /// `NonNullaryLoopCondition`.
    fn validate_loop_conditions(
        items: &[Segment],
        relations: &[Relation],
    ) -> Result<(), ParseError> {
        let declared: HashSet<&str> = relations.iter().map(|r| r.name()).collect();
        for item in items {
            let Some(block) = item.as_loop() else {
                continue;
            };
            for directive in block.iterative_relations() {
                let name = directive.name();
                if !declared.contains(name) {
                    return Err(ParseError::UndeclaredInIterativeList {
                        span: block.span(),
                        name: name.to_string(),
                    });
                }
            }
            let Some(cond) = block.condition() else {
                continue;
            };
            let Some(until_group) = cond.until_part() else {
                continue;
            };
            for rel in until_group.relations() {
                let name = rel.name();
                let lowered = name.to_lowercase();
                // A single lookup both checks declaration and fetches the arity.
                let Some(decl) = relations
                    .iter()
                    .find(|r| r.name() == name || r.name() == lowered.as_str())
                else {
                    return Err(ParseError::UndeclaredLoopCondition {
                        span: block.span(),
                        name: name.to_string(),
                    });
                };
                if decl.arity() != 0 {
                    return Err(ParseError::NonNullaryLoopCondition {
                        span: block.span(),
                        name: name.to_string(),
                        arity: decl.arity(),
                    });
                }
            }
        }
        Ok(())
    }

    /// Rewrites body atoms that name an extern function into
    /// `Predicate::FnCall`s across all segments. Processing stops at the first
    /// error.
    ///
    /// # Errors
    /// Propagates the first error from `reclassify_one_rule`.
    fn reclassify_udf_predicates(
        items: &mut [Segment],
        udfs: &[ExternFn],
    ) -> Result<(), ParseError> {
        let udf_names: HashSet<&str> = udfs.iter().map(ExternFn::name).collect();
        if udf_names.is_empty() {
            return Ok(());
        }
        let mut result = Ok(());
        for_each_rule_mut(items, |rule| {
            if result.is_ok() {
                result = Self::reclassify_one_rule(rule, &udf_names);
            }
        });
        result
    }

    /// Replaces UDF-named atoms in `rule`'s body with function calls.
    ///
    /// Variable and constant arguments become single-factor arithmetic
    /// arguments. A negative atom becomes a negated call. The rule is rebuilt
    /// only if at least one atom needs rewriting.
    ///
    /// # Errors
    /// `PlaceholderInUdf` if a UDF atom has a placeholder argument.
    fn reclassify_one_rule(
        rule: &mut FlowLogRule,
        udf_names: &HashSet<&str>,
    ) -> Result<(), ParseError> {
        let needs_rewrite = rule.rhs().iter().any(|p| {
            matches!(
                p,
                Predicate::PositiveAtom(a) | Predicate::NegativeAtom(a)
                if udf_names.contains(a.name())
            )
        });
        if !needs_rewrite {
            return Ok(());
        }
        let mut new_rhs = Vec::with_capacity(rule.rhs().len());
        for pred in rule.rhs() {
            new_rhs.push(match pred {
                Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom)
                    if udf_names.contains(atom.name()) =>
                {
                    let mut args = Vec::with_capacity(atom.arguments().len());
                    for a in atom.arguments() {
                        match a {
                            AtomArg::Var(v) => {
                                args.push(Arithmetic::new(Factor::Var(v.clone()), vec![]));
                            }
                            AtomArg::Const(c) => {
                                args.push(Arithmetic::new(Factor::Const(c.clone()), vec![]));
                            }
                            AtomArg::Placeholder => {
                                return Err(ParseError::PlaceholderInUdf {
                                    span: atom.span(),
                                    udf_name: atom.name().to_string(),
                                });
                            }
                        }
                    }
                    Predicate::FnCall(
                        FnCall::new(
                            atom.name().to_string(),
                            args,
                            matches!(pred, Predicate::NegativeAtom(_)),
                        )
                        .with_span(atom.span()),
                    )
                }
                other => other.clone(),
            });
        }
        *rule = FlowLogRule::new(rule.head().clone(), new_rhs);
        Ok(())
    }

    /// Records the constant tuple of a bodiless fact rule under its head's
    /// relation name, together with the head span.
    fn extract_fact(&mut self, fact_rule: FlowLogRule) {
        let rel_name = fact_rule.head().name().to_string();
        let span = fact_rule.head().span();
        let tuple = fact_rule.extract_constants_from_head();
        self.facts.entry(rel_name).or_default().push((span, tuple));
    }
}
/// Sentinel rule ID used in the dependency map for a body atom whose relation
/// is not the head of any rule; the traversal never follows it.
const NO_TOP_LEVEL_RULE_ID: usize = usize::MAX;
impl Program {
    /// Every rule across all segments, in the global order that defines rule
    /// IDs for dead-code analysis.
    fn rules_in_source_order(&self) -> impl Iterator<Item = &FlowLogRule> {
        self.segments.iter().flat_map(|item| match item {
            Segment::Plain(rules) => rules.as_slice(),
            Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
        })
    }

    /// Computes which rules and relations are reachable from the program roots.
    ///
    /// Roots are:
    /// - relations marked output/printsize;
    /// - relations with inline facts;
    /// - relations named in `until` loop conditions.
    ///
    /// Until names are resolved to their declared name, including the
    /// lowercase fallback accepted by loop-condition validation. Roots that
    /// have no rules, no facts and no input are treated as underived. Rule
    /// dependencies are then followed transitively through positive and
    /// negative body atoms.
    ///
    /// Returns `((needed_rule_ids, needed_relation_names), underived)`.
    /// - Rule IDs are global indices in source order.
    /// - `underived` holds the underived roots that stay unreachable.
    /// - If there are no roots, every rule and relation is kept.
    #[must_use]
    fn identify_needed_components(&self) -> ((HashSet<usize>, HashSet<String>), HashSet<String>) {
        let all_rules: Vec<&FlowLogRule> = self.rules_in_source_order().collect();
        let mut needed_preds: HashSet<String> = self
            .idbs()
            .into_iter()
            .map(|d| d.name().to_string())
            .collect();
        needed_preds.extend(self.facts.keys().cloned());

        let declared: HashSet<&str> = self.relations.iter().map(|r| r.name()).collect();
        for block in self.segments.iter().filter_map(Segment::as_loop) {
            let Some(stop) = block.condition().and_then(|cond| cond.until_part()) else {
                continue;
            };
            for rel in stop.relations() {
                let name = rel.name();
                // Loop-condition validation accepts a name that only matches a
                // declaration after lowercasing; resolve it the same way so the
                // referenced relation is not pruned as underived.
                let resolved = if declared.contains(name) {
                    name.to_string()
                } else {
                    name.to_lowercase()
                };
                needed_preds.insert(resolved);
            }
        }

        if needed_preds.is_empty() {
            let all_indices = (0..all_rules.len()).collect();
            let all_preds = self
                .relations
                .iter()
                .map(|d| d.name().to_string())
                .collect();
            return ((all_indices, all_preds), HashSet::new());
        }
        let mut head_to_rules: HashMap<String, Vec<usize>> = HashMap::new();
        for (i, r) in all_rules.iter().enumerate() {
            head_to_rules
                .entry(r.head().name().to_string())
                .or_default()
                .push(i);
        }
        let input_relations: HashSet<String> = self
            .relations
            .iter()
            .filter(|r| r.has_input())
            .map(|r| r.name().to_string())
            .collect();
        let underived: Vec<String> = needed_preds
            .iter()
            .filter(|p| {
                !head_to_rules.contains_key(p.as_str())
                    && !self.facts.contains_key(p.as_str())
                    && !input_relations.contains(p.as_str())
            })
            .cloned()
            .collect();
        for name in &underived {
            needed_preds.remove(name);
        }
        let mut needed_rules: HashSet<usize> = needed_preds
            .iter()
            .flat_map(|p| head_to_rules.get(p).into_iter().flatten().copied())
            .collect();
        // For each rule: (id of a rule deriving a body atom, that atom's name).
        let dep_map: HashMap<usize, Vec<(usize, String)>> = all_rules
            .iter()
            .enumerate()
            .map(|(i, r)| {
                let deps = r
                    .rhs()
                    .iter()
                    .filter_map(|pred| match pred {
                        Predicate::PositiveAtom(a) | Predicate::NegativeAtom(a) => Some(a.name()),
                        _ => None,
                    })
                    .flat_map(|atom_name| {
                        if let Some(ids) = head_to_rules.get(atom_name) {
                            ids.iter()
                                .map(|&id| (id, atom_name.to_string()))
                                .collect::<Vec<_>>()
                        } else {
                            vec![(NO_TOP_LEVEL_RULE_ID, atom_name.to_string())]
                        }
                    })
                    .collect();
                (i, deps)
            })
            .collect();
        let mut processed: HashSet<usize> = HashSet::new();
        let mut stack: Vec<usize> = needed_rules.iter().copied().collect();
        while let Some(rule_id) = stack.pop() {
            if !processed.insert(rule_id) {
                continue;
            }
            for &(dep_rule_id, ref pred_name) in dep_map.get(&rule_id).into_iter().flatten() {
                needed_preds.insert(pred_name.clone());
                if dep_rule_id != NO_TOP_LEVEL_RULE_ID && !processed.contains(&dep_rule_id) {
                    needed_rules.insert(dep_rule_id);
                    stack.push(dep_rule_id);
                }
            }
        }
        // An underived root can be pulled back in as a body dependency of a
        // needed rule; it is then retained, so it must not be reported as pruned.
        let underived: HashSet<String> = underived
            .into_iter()
            .filter(|name| !needed_preds.contains(name))
            .collect();
        ((needed_rules, needed_preds), underived)
    }

    /// Removes unreachable relations, rules, facts and emptied segments,
    /// logging one warning that summarizes what was pruned.
    fn prune_dead_components(&mut self) {
        let ((needed_rules, needed_preds), underived) = self.identify_needed_components();
        let dead_relations: Vec<_> = self
            .relations
            .iter()
            .filter(|d| !needed_preds.contains(d.name()) && !underived.contains(d.name()))
            .map(|d| d.name().to_string())
            .collect();
        let dead_rules: Vec<_> = self
            .rules_in_source_order()
            .enumerate()
            .filter(|(i, _)| !needed_rules.contains(i))
            .map(|(i, r)| format!("#{}: {}", i, r))
            .collect();
        if !underived.is_empty() || !dead_relations.is_empty() || !dead_rules.is_empty() {
            let mut parts = Vec::new();
            if !underived.is_empty() {
                let mut sorted: Vec<_> = underived.iter().collect();
                sorted.sort();
                parts.push(format!(
                    " underived IDBs (declared but no rules): {}",
                    sorted
                        .iter()
                        .map(|s| s.as_str())
                        .collect::<Vec<_>>()
                        .join(", ")
                ));
            }
            if !dead_relations.is_empty() {
                parts.push(format!(
                    " unreachable relations: {}",
                    dead_relations.join(", ")
                ));
            }
            if !dead_rules.is_empty() {
                parts.push(format!(" unreachable rules: {}", dead_rules.join(", ")));
            }
            warn!("Pruned dead components:\n{}", parts.join("\n"));
        }
        self.relations.retain(|d| needed_preds.contains(d.name()));

        // Rule IDs are global positions across all segments, so rules must be
        // visited in exactly the order `rules_in_source_order` yields them.
        let mut next_idx = 0usize;
        let mut keep_next = || {
            let keep = needed_rules.contains(&next_idx);
            next_idx += 1;
            keep
        };
        let new_items: Vec<Segment> = self
            .segments
            .drain(..)
            .filter_map(|item| match item {
                Segment::Plain(rules) => {
                    let filtered: Vec<FlowLogRule> =
                        rules.into_iter().filter(|_| keep_next()).collect();
                    if filtered.is_empty() {
                        None
                    } else {
                        Some(Segment::Plain(filtered))
                    }
                }
                Segment::Loop(mut block) => {
                    block.rules_mut().retain(|_| keep_next());
                    if block.rules().is_empty() {
                        None
                    } else {
                        Some(Segment::Loop(block))
                    }
                }
                Segment::Fixpoint(mut block) => {
                    block.rules_mut().retain(|_| keep_next());
                    if block.rules().is_empty() {
                        None
                    } else {
                        Some(Segment::Fixpoint(block))
                    }
                }
            })
            .collect();
        self.segments = new_items;
        self.facts
            .retain(|rel, _| needed_preds.contains(rel.as_str()));
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::DataType;
use std::io::Write;
fn loop_blocks(program: &Program) -> Vec<&LoopBlock> {
program
.segments()
.iter()
.filter_map(|s| s.as_loop())
.collect()
}
fn parse_program(src: &str) -> Program {
parse_program_result(src).expect("parse failed")
}
fn parse_program_result(src: &str) -> Result<Program, ParseError> {
let mut tmp = tempfile::NamedTempFile::new().expect("failed to create temp file");
tmp.write_all(src.as_bytes())
.expect("failed to write temp file");
let mut sm = SourceMap::new();
Program::parse(&tmp.path().to_string_lossy(), true, &mut sm)
}
#[test]
fn decl_case_collision_rejected() {
let err = parse_program_result(
"
.decl edge(x: number)
.decl Edge(y: number)
",
)
.unwrap_err();
assert!(
matches!(err, ParseError::DuplicateDecl { .. }),
"got {err:?}"
);
}
#[test]
fn attr_case_collision_rejected() {
let err = parse_program_result(".decl edge(x: number, X: number)").unwrap_err();
assert!(
matches!(err, ParseError::DuplicateAttribute { .. }),
"got {err:?}"
);
}
#[test]
fn ordering_preserved() {
let src = "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.output c
a(X) :- b(X).
fixpoint { b(X) :- a(X). }
c(X) :- a(X).
";
let program = parse_program(src);
let items = program.segments();
assert_eq!(items.len(), 3);
assert!(matches!(&items[0], Segment::Plain(r) if r.len() == 1));
assert!(matches!(&items[1], Segment::Fixpoint(_)));
assert!(matches!(&items[2], Segment::Plain(r) if r.len() == 1));
}
#[test]
fn multiple_rules_before_loop_form_one_segment() {
let src = "
.decl a(x: number)
.decl b(x: number)
.output a
a(1) :- b(1).
a(2) :- b(2).
fixpoint { a(X) :- b(X). }
";
let program = parse_program(src);
let items = program.segments();
assert_eq!(items.len(), 2);
assert!(matches!(&items[0], Segment::Plain(r) if r.len() == 2));
assert!(matches!(&items[1], Segment::Fixpoint(_)));
}
#[test]
fn rules_method_flattens_segments() {
let src = "
.decl a(x: number)
.decl b(x: number)
.output a
a(X) :- b(X).
fixpoint { }
a(1) :- b(1).
";
let program = parse_program(src);
assert_eq!(program.rules().len(), 2);
}
#[test]
fn loop_block_with_declared_relation() {
let src = "
.decl done()
.decl edge(x: number, y: number)
.output done
loop until { done } { done() :- edge(1, 2). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
}
#[test]
fn fixpoint_needs_no_declaration() {
let src = "
.decl edge(x: number, y: number)
.output edge
fixpoint { edge(1, 2) :- edge(1, 2). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
}
#[test]
fn loop_while_needs_no_declaration() {
let src = "
.decl edge(x: number, y: number)
.output edge
loop while { @it <= 9 } { edge(1, 2) :- edge(1, 2). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
}
#[test]
fn iterative_declared_passes() {
let src = "
.decl edge(x: number, y: number)
.decl active_edge(x: number, y: number)
.output active_edge
fixpoint { .iterative active_edge active_edge(X, Y) :- edge(X, Y). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
assert_eq!(loop_blocks(&program)[0].iterative_relations().len(), 1);
}
#[test]
fn dead_code_elimination_keeps_loop_until_relations() {
let src = "
.decl edge(x: number, y: number)
.decl keep()
.decl dead()
.output edge
edge(1, 2).
loop until { keep } {
keep() :- edge(1, 2).
}
dead() :- edge(2, 3).
";
let program = parse_program(src);
assert!(program.relations().iter().any(|rel| rel.name() == "keep"));
assert!(!program.relations().iter().any(|rel| rel.name() == "dead"));
}
#[test]
fn edb_subsets_track_file_backed_inline_and_overlap_relations() {
let src = "
.decl file_only(x: number)
.decl fact_only(x: number)
.decl both(x: number)
.decl out(x: number)
.input file_only(IO=\"file\", filename=\"file_only.csv\", delimiter=\",\")
.input both(IO=\"file\", filename=\"both.csv\", delimiter=\",\")
.output out
fact_only(1).
both(2).
out(X) :- file_only(X).
out(X) :- fact_only(X).
out(X) :- both(X).
";
let program = parse_program(src);
let mut edbs = program
.edbs()
.into_iter()
.map(|rel| rel.name().to_string())
.collect::<Vec<_>>();
edbs.sort_unstable();
let mut file_backed = program
.file_backed_relations()
.into_iter()
.map(|rel| rel.name().to_string())
.collect::<Vec<_>>();
file_backed.sort_unstable();
let mut inline_facts = program
.inline_fact_relations()
.into_iter()
.map(|rel| rel.name().to_string())
.collect::<Vec<_>>();
inline_facts.sort_unstable();
assert_eq!(edbs, vec!["both", "fact_only", "file_only"]);
assert_eq!(file_backed, vec!["both", "file_only"]);
assert_eq!(inline_facts, vec!["both", "fact_only"]);
}
#[test]
fn multi_head_rule_expands() {
let src = "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.output b
.output c
b(X), c(X) :- a(X).
";
let program = parse_program(src);
let rules = program.rules();
assert_eq!(rules.len(), 2);
assert_eq!(rules[0].head().name(), "b");
assert_eq!(rules[1].head().name(), "c");
assert_eq!(rules[0].rhs().len(), 1);
assert_eq!(rules[1].rhs().len(), 1);
}
#[test]
fn multi_head_rule_in_fixpoint() {
let src = "
.decl a(x: number, y: number)
.decl b(x: number, y: number)
.decl c(x: number, y: number)
.output b
.output c
fixpoint { b(X, Y), c(X, Y) :- a(X, Y). }
";
let program = parse_program(src);
let blocks = loop_blocks(&program);
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0].rules().len(), 2);
assert_eq!(blocks[0].rules()[0].head().name(), "b");
assert_eq!(blocks[0].rules()[1].head().name(), "c");
}
#[test]
fn multi_body_rule_expands() {
let src = "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.output c
c(X) :- a(X); b(X).
";
let program = parse_program(src);
let rules = program.rules();
assert_eq!(rules.len(), 2);
assert_eq!(rules[0].head().name(), "c");
assert_eq!(rules[1].head().name(), "c");
assert_eq!(rules[0].rhs()[0].name(), "a");
assert_eq!(rules[1].rhs()[0].name(), "b");
}
#[test]
fn disjunction_arm_can_be_a_conjunction() {
let src = "
.decl a(x: number) .decl b(x: number)
.decl c(x: number) .decl d(x: number)
.decl r(x: number)
.output r
r(X) :- ( a(X), b(X) ; c(X), d(X) ).
";
let program = parse_program(src);
let rules = program.rules();
assert_eq!(rules.len(), 2);
let bodies: Vec<Vec<&str>> = rules
.iter()
.map(|r| r.rhs().iter().map(|p| p.name()).collect())
.collect();
assert!(bodies.contains(&vec!["a", "b"]));
assert!(bodies.contains(&vec!["c", "d"]));
}
/// Two disjunctions in one body expand to their cross product. That is four
/// rules, each with one atom from the first group followed by one from the
/// second.
///
/// Whole bodies are compared, so extra atoms are caught. A short body fails
/// with a readable message instead of an index panic.
#[test]
fn nested_disjunctions_cross_product() {
    let src = "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl d(x: number)
.decl r(x: number)
.output r
r(X) :- ( a(X) ; b(X) ), ( c(X) ; d(X) ).
";
    let program = parse_program(src);
    let rules = program.rules();
    assert_eq!(rules.len(), 4);
    let bodies: Vec<Vec<&str>> = rules
        .iter()
        .map(|r| r.rhs().iter().map(|p| p.name()).collect())
        .collect();
    for expected in [
        vec!["a", "c"],
        vec!["a", "d"],
        vec!["b", "c"],
        vec!["b", "d"],
    ] {
        assert!(
            bodies.contains(&expected),
            "missing body {expected:?}; got {bodies:?}"
        );
    }
}
/// Two heads combined with two disjuncts expand to four rules. They are
/// ordered head-major: every disjunct for `c`, then every disjunct for `d`.
/// Each rule must carry exactly one body atom.
#[test]
fn multi_head_multi_body_expands() {
    let src = "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl d(x: number)
.output c
.output d
c(X), d(X) :- a(X); b(X).
";
    let program = parse_program(src);
    let rules = program.rules();
    // (head, sole body atom) for each expanded rule, in expected order.
    let expected = [("c", "a"), ("c", "b"), ("d", "a"), ("d", "b")];
    assert_eq!(rules.len(), expected.len());
    for (i, (rule, (head, body))) in rules.iter().zip(expected.iter()).enumerate() {
        assert_eq!(rule.head().name(), *head, "rule {i}: head");
        assert_eq!(rule.rhs().len(), 1, "rule {i}: body should be a single atom");
        assert_eq!(rule.rhs()[0].name(), *body, "rule {i}: body atom");
    }
}
/// Diamond-shaped includes: `root` includes `left` and `right`, and both of
/// those include `leaf`. The leaf file's declaration must be inlined only once.
#[test]
fn diamond_include_dedups_leaf() {
    let dir = tempfile::tempdir().expect("tempdir");
    // Written in order: the leaf first, then the two middle files, then the root.
    let files = [
        (
            "leaf.dl",
            ".decl leaf_rel(x: number)\n.output leaf_rel\nleaf_rel(1).\n",
        ),
        ("left.dl", ".include \"leaf.dl\"\n"),
        ("right.dl", ".include \"leaf.dl\"\n"),
        ("root.dl", ".include \"left.dl\"\n.include \"right.dl\"\n"),
    ];
    for (name, body) in files.iter() {
        std::fs::write(dir.path().join(name), body).expect("write");
    }

    let root = dir.path().join("root.dl");
    let mut sm = SourceMap::new();
    let program = Program::parse(&root.to_string_lossy(), true, &mut sm)
        .expect("diamond include should succeed with dedup");

    let leaf_count = program
        .relations()
        .iter()
        .filter(|rel| rel.name() == "leaf_rel")
        .count();
    assert_eq!(leaf_count, 1, "leaf_rel inlined twice");
}
/// A chain of type aliases (`A = B`, `B = C`, `C = number`) resolves every
/// column of `R` to the root type. The relation is looked up as `"r"`.
#[test]
fn type_alias_chain_resolves_to_root() {
    let src = "
.type C = number
.type B = C
.type A = B
.decl R(x: A, y: B, z: C)
.output R
R(1, 2, 3).
";
    let program = parse_program(src);
    let r = program
        .relations()
        .iter()
        .find(|rel| rel.name() == "r")
        .expect("relation `r` (declared as `R`) should be present");
    assert_eq!(
        r.data_type(),
        vec![DataType::Int32, DataType::Int32, DataType::Int32]
    );
}
/// A negated body atom naming an `.extern fn` is reclassified as a
/// `Predicate::FnCall`. The negation flag must survive that reclassification.
#[test]
fn negated_udf_reclassification_preserves_negation() {
    let src = "
.decl edge(x: number, y: number)
.decl out(x: number, y: number)
.output out
.extern fn cost(x: number) -> number
out(X, Y) :- edge(X, Y), !cost(X).
";
    let program = parse_program(src);
    let rules = program.rules();
    let first_rule = &rules[0];

    // First body predicate that is a function call, if any.
    let mut calls = first_rule.rhs().iter().filter_map(|atom| {
        if let Predicate::FnCall(call) = atom {
            Some(call)
        } else {
            None
        }
    });
    let fn_call = calls
        .next()
        .expect("udf body atom should be reclassified to FnCall");

    assert!(
        fn_call.is_negated(),
        "negation lost during reclassification"
    );
    assert_eq!(fn_call.name(), "cost");
}
/// Returns the integer value of every fact stored for `rel`, in stored order.
///
/// # Panics
///
/// Panics if `rel` has no stored facts. Also panics if any tuple is not
/// exactly one `ConstType::Int` column. An empty tuple or extra columns
/// produce a descriptive message instead of an index panic or being
/// silently ignored.
fn fact_numbers(program: &Program, rel: &str) -> Vec<i64> {
    let tuples = program
        .facts()
        .get(rel)
        .unwrap_or_else(|| panic!("no facts for `{rel}`"));
    tuples
        .iter()
        .map(|(_, tuple)| match tuple.as_slice() {
            [ConstType::Int(n)] => *n,
            other => panic!("expected a single number in `{rel}`, got {other:?}"),
        })
        .collect()
}
/// When `Sub` overrides `Foo`, the facts `Base` asserted for `Foo` are
/// dropped. Only `Sub`'s own fact remains on the instance `s`.
#[test]
fn override_drops_parent_facts() {
    let src = "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
Foo(2).
}
.comp Sub : Base {
.override Foo
Foo(10).
}
.init s = Sub
.output s.Foo
";
    let parsed = parse_program(src);
    let foo_values = fact_numbers(&parsed, "s·foo");
    assert_eq!(foo_values, vec![10]);
}
/// Overriding also drops the parent's derivation rules for the relation. Only
/// the overriding rule survives, recognisable by its comparison predicate
/// (`x > 5`).
#[test]
fn override_drops_parent_derived_rule() {
    let src = "
.comp Base {
.decl Foo(x: number) overridable
.decl Seed(x: number)
Foo(x) :- Seed(x).
}
.comp Sub : Base {
.override Foo
Foo(x) :- Seed(x), x > 5.
}
.init s = Sub
.input s.Seed(IO=\"file\", filename=\"Seed.csv\", delimiter=\",\")
.output s.Foo
";
    let parsed = parse_program(src);

    let mut foo_rules = Vec::new();
    for rule in parsed.rules() {
        if rule.head().name() == "s·foo" {
            foo_rules.push(rule);
        }
    }
    assert_eq!(foo_rules.len(), 1, "exactly one s·foo rule survives");

    let survivor = foo_rules[0];
    let has_compare = survivor
        .rhs()
        .iter()
        .any(|pred| matches!(pred, Predicate::Compare(_)));
    assert!(has_compare, "override's filtered rule should survive");
}
/// Marking a relation `overridable` has no effect unless a subcomponent
/// actually overrides it. The parent's fact `Foo(1)` is inherited unchanged.
///
/// The assertion pins the exact surviving value, not just the tuple count.
#[test]
fn overridable_without_override_keeps_parent_facts() {
    let src = "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base { }
.init s = Sub
.output s.Foo
";
    let program = parse_program(src);
    assert_eq!(fact_numbers(&program, "s·foo"), vec![1]);
}
/// Overriding a relation the parent did not mark `overridable` is rejected
/// with `ParseError::OverrideOfNonOverridable`.
#[test]
fn override_of_non_overridable_errors() {
    let src = "
.comp Base { .decl Foo(x: number) Foo(1). }
.comp Sub : Base { .override Foo Foo(10). }
.init s = Sub
.output s.Foo
";
    let outcome = parse_program_result(src);
    let err = outcome.unwrap_err();
    let is_expected_variant = matches!(err, ParseError::OverrideOfNonOverridable { .. });
    assert!(is_expected_variant, "got {err:?}");
}
/// `.override` naming a relation (`Bar`) that no parent declares is rejected
/// with `ParseError::OverrideUnknownRelation`.
#[test]
fn override_unknown_relation_errors() {
    let src = "
.comp Base { .decl Foo(x: number) overridable Foo(1). }
.comp Sub : Base { .override Bar Foo(10). }
.init s = Sub
.output s.Foo
";
    let outcome = parse_program_result(src);
    let err = outcome.unwrap_err();
    let is_expected_variant = matches!(err, ParseError::OverrideUnknownRelation { .. });
    assert!(is_expected_variant, "got {err:?}");
}
/// An override made in a middle component (`Mid1`) is inherited by its own
/// descendants. `Bot` therefore sees `Mid1`'s fact, not `Top`'s.
#[test]
fn override_propagates_through_inheritance_chain() {
    let src = "
.comp Top { .decl Foo(x: number) overridable Foo(1). }
.comp Mid1 : Top { .override Foo Foo(2). }
.comp Bot : Mid1 { }
.init b = Bot
.output b.Foo
";
    let parsed = parse_program(src);
    let foo_values = fact_numbers(&parsed, "b·foo");
    assert_eq!(foo_values, vec![2]);
}
/// Overrides work through parametric components. `Sub<T> : Base<T>`
/// instantiated at `number` keeps only the overriding fact.
#[test]
fn override_parametric_type_substitution() {
    let src = "
.comp Base<T> {
.decl Foo(x: T) overridable
Foo(0).
}
.comp Sub<T> : Base<T> {
.override Foo
Foo(42).
}
.init s = Sub<number>
.output s.Foo
";
    let parsed = parse_program(src);
    let foo_values = fact_numbers(&parsed, "s·foo");
    assert_eq!(foo_values, vec![42]);
}
/// An override that supplies no replacement leaves `s·foo` with no stored
/// facts and no rule deriving it.
#[test]
fn override_to_empty_drops_parent_derivations() {
    let src = "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base {
.override Foo
}
.init s = Sub
.output s.Foo
";
    let parsed = parse_program(src);
    let has_facts = parsed.facts().get("s·foo").is_some();
    assert!(!has_facts);
    let derives_foo = parsed
        .rules()
        .iter()
        .any(|rule| rule.head().name() == "s·foo");
    assert!(!derives_foo);
}
/// `.override` at top level, outside any `.comp`, does not parse. It is
/// reported as `ParseError::Syntax`.
#[test]
fn override_outside_comp_is_syntax_error() {
    let src = "
.decl Foo(x: number)
.override Foo
.output Foo
";
    let outcome = parse_program_result(src);
    let err = outcome.unwrap_err();
    let is_syntax_error = matches!(err, ParseError::Syntax { .. });
    assert!(is_syntax_error, "got {err:?}");
}
/// Re-declaring an overridden relation with `.decl` inside the subcomponent
/// is rejected with `ParseError::OverrideRedeclaresRelation`.
#[test]
fn override_redeclaration_errors() {
    let src = "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base {
.override Foo
.decl Foo(x: number)
Foo(10).
}
.init s = Sub
.output s.Foo
";
    let outcome = parse_program_result(src);
    let err = outcome.unwrap_err();
    let is_expected_variant = matches!(err, ParseError::OverrideRedeclaresRelation { .. });
    assert!(is_expected_variant, "got {err:?}");
}
/// Repeating `.override Foo` is accepted. The parent's `Foo(1)` is still
/// dropped and only the subcomponent's `Foo(10)` remains.
///
/// The assertion pins the exact surviving value, not just the tuple count.
#[test]
fn double_override_is_accepted() {
    let src = "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base {
.override Foo
.override Foo
Foo(10).
}
.init s = Sub
.output s.Foo
";
    let program = parse_program(src);
    assert_eq!(fact_numbers(&program, "s·foo"), vec![10]);
}
/// The `overridable` qualifier on a top-level `.decl` (outside any `.comp`)
/// is rejected with `ParseError::OverridableOutsideComp`.
#[test]
fn overridable_outside_comp_errors() {
    let src = concat!(".decl Foo(x: number) overridable\n", ".output Foo\n");
    let outcome = parse_program_result(src);
    let err = outcome.unwrap_err();
    let is_expected_variant = matches!(err, ParseError::OverridableOutsideComp { .. });
    assert!(is_expected_variant, "got {err:?}");
}
}