use super::{
ConstType, FlowLogParser, Lexeme, Rule,
declaration::{
CompDecl, ExternFn, InitDecl, InputDirective, OutputDirective, PrintSizeDirective,
RawTypeOp, Relation, split_type_alias,
},
error::{DirectiveKind, ParseError, grammar_bug},
inliner,
logic::{
Arithmetic, AtomArg, Factor, FlowLogRule, FnCall, Head, LoopBlock, Predicate,
consume_plan_directive,
},
primitive::TypeRegistry,
segment::Segment,
};
use crate::common::{FileId, SourceMap, Span};
use pest::{Parser, iterators::Pair};
use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::{fmt, fs};
use tracing::{debug, info, warn};
/// A parsed FlowLog program: relation declarations, rule segments in source
/// order, extern functions, inline facts, and the type registry.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Program {
/// Relation declarations collected from `.decl` nodes and from component inlining.
relations: Vec<Relation>,
/// Rule segments (plain rule batches, loop blocks, fixpoint blocks) in source order.
segments: Vec<Segment>,
/// Extern function declarations.
udfs: Vec<ExternFn>,
/// Inline facts keyed by relation name; each entry is the fact's span plus its constant tuple.
facts: HashMap<String, Vec<(Span, Vec<ConstType>)>>,
/// Output file names of `.output` relations that were pruned as not needed.
empty_output_files: Vec<String>,
/// Registry built from the program's type alias / subtype declarations.
type_registry: TypeRegistry,
}
/// Human-readable dump of the program: relations, extern functions, rule
/// segments in source order, and inline facts.
impl fmt::Display for Program {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "=============================================")?;
        writeln!(f, "FlowLog DATALOG PROGRAM")?;
        writeln!(f, "=============================================")?;
        writeln!(f)?;

        if !self.relations.is_empty() {
            writeln!(f, "Relations")?;
            writeln!(f, "---------------------------------------------")?;
            for rel in &self.relations {
                writeln!(f, "{}", rel)?;
            }
            writeln!(f)?;
        }

        if !self.udfs.is_empty() {
            writeln!(f, "Extern Functions")?;
            writeln!(f, "---------------------------------------------")?;
            for udf in &self.udfs {
                writeln!(f, "{}", udf)?;
            }
            writeln!(f)?;
        }

        if !self.segments.is_empty() {
            writeln!(f, "Program (source order)")?;
            writeln!(f, "---------------------------------------------")?;
            for (i, item) in self.segments.iter().enumerate() {
                match item {
                    Segment::Plain(rules) => {
                        writeln!(f, "[Segment {}]", i)?;
                        for rule in rules {
                            writeln!(f, " {}", rule)?;
                        }
                    }
                    Segment::Loop(block) | Segment::Fixpoint(block) => {
                        writeln!(f, "[Loop {}]", i)?;
                        writeln!(f, " {}", block)?;
                    }
                }
            }
            writeln!(f)?;
        }

        if !self.facts.is_empty() {
            writeln!(f, "Facts")?;
            writeln!(f, "---------------------------------------------")?;
            // `facts` is a HashMap, whose iteration order is arbitrary. Sort by
            // relation name so the dump is stable across runs; the per-relation
            // fact order (insertion order of the Vec) is left untouched.
            let mut by_relation: Vec<_> = self.facts.iter().collect();
            by_relation.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
            for (rel_name, facts) in by_relation {
                for (_, vals) in facts {
                    let values = vals
                        .iter()
                        .map(|c| c.to_string())
                        .collect::<Vec<_>>()
                        .join(", ");
                    writeln!(f, "{}({}).", rel_name, values)?;
                }
            }
        }
        Ok(())
    }
}
impl Program {
pub fn parse(path: &str, extended: bool, sm: &mut SourceMap) -> Result<Self, ParseError> {
Self::parse_with_includes(path, extended, &[], sm)
}
pub fn parse_with_includes(
path: &str,
extended: bool,
include_dirs: &[&Path],
sm: &mut SourceMap,
) -> Result<Self, ParseError> {
let file_path = PathBuf::from(path);
let root_file = sm
.load(&file_path)
.map_err(|source| ParseError::IncludeIo {
span: Span::DUMMY,
path: file_path.clone(),
source,
})?;
let base_dir: PathBuf = file_path
.parent()
.map(Path::to_path_buf)
.unwrap_or_else(|| PathBuf::from("."));
let mut in_progress = HashSet::new();
let mut completed = HashSet::new();
in_progress.insert(fs::canonicalize(&file_path).unwrap_or_else(|_| file_path.clone()));
let combined = resolve_includes(
sm.text(root_file).to_string(),
root_file,
&base_dir,
include_dirs,
&mut in_progress,
&mut completed,
sm,
)?;
let combined_file = sm.add(file_path.clone(), combined);
let mut pairs = FlowLogParser::parse(Rule::main_grammar, sm.text(combined_file))
.map_err(|e| ParseError::syntax_from_pest(&e, combined_file))?;
let root = pairs
.next()
.ok_or_else(|| grammar_bug("no parsed rule found"))?;
let mut program = Self::collect_program(root, extended, combined_file)?;
program.prune_dead_components();
program.materialize_orphan_relations();
debug!("\n{}", program);
info!("Successfully parsed program from '{}'.", path);
Ok(program)
}
#[must_use]
#[inline]
pub(crate) fn relations(&self) -> &[Relation] {
&self.relations
}
#[must_use]
pub(crate) fn relation_by_fingerprint(&self, fp: u64) -> Option<&Relation> {
self.relations.iter().find(|rel| rel.fingerprint() == fp)
}
#[must_use]
pub fn edbs(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| self.is_edb_relation(rel))
.collect()
}
#[cfg(test)]
#[must_use]
#[inline]
pub(crate) fn file_backed_relations(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| rel.is_file_backed())
.collect()
}
#[cfg(test)]
#[must_use]
pub(crate) fn inline_fact_relations(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| self.has_inline_facts(rel.name()))
.collect()
}
#[must_use]
pub(crate) fn edb_names(&self) -> Vec<String> {
let mut names: Vec<String> = self
.edbs()
.iter()
.map(|rel| rel.name().to_string())
.collect();
names.sort_unstable();
names
}
#[must_use]
pub(crate) fn edb_fingerprints(&self) -> HashSet<u64> {
self.edbs().iter().map(|rel| rel.fingerprint()).collect()
}
#[must_use]
#[inline]
pub(crate) fn idbs(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| rel.is_output_printsize())
.collect()
}
#[must_use]
#[inline]
pub fn output_idbs(&self) -> Vec<&Relation> {
self.relations.iter().filter(|rel| rel.output()).collect()
}
#[must_use]
#[inline]
pub fn empty_output_files(&self) -> &[String] {
&self.empty_output_files
}
#[must_use]
#[inline]
pub fn printsize_idbs(&self) -> Vec<&Relation> {
self.relations
.iter()
.filter(|rel| rel.printsize())
.collect()
}
#[must_use]
#[inline]
pub(crate) fn segments(&self) -> &[Segment] {
&self.segments
}
#[must_use]
pub fn rules(&self) -> Vec<&FlowLogRule> {
self.segments
.iter()
.flat_map(|item| item.as_rules())
.collect()
}
pub(crate) fn segments_mut(&mut self) -> &mut [Segment] {
&mut self.segments
}
pub(crate) fn facts_mut(&mut self) -> &mut HashMap<String, Vec<(Span, Vec<ConstType>)>> {
&mut self.facts
}
#[must_use]
pub(crate) fn rule(&self, rid: usize) -> &FlowLogRule {
let mut offset = 0;
for seg in &self.segments {
let rules: &[FlowLogRule] = match seg {
Segment::Plain(rules) => rules,
Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
};
if rid < offset + rules.len() {
return &rules[rid - offset];
}
offset += rules.len();
}
panic!("Parser error: rule ID {rid} out of bounds");
}
#[must_use]
#[inline]
pub fn facts(&self) -> &HashMap<String, Vec<(Span, Vec<ConstType>)>> {
&self.facts
}
#[must_use]
#[inline]
pub(crate) fn has_inline_facts(&self, relation_name: &str) -> bool {
self.facts.contains_key(relation_name)
}
#[must_use]
#[inline]
pub(crate) fn udfs(&self) -> &[ExternFn] {
&self.udfs
}
#[inline]
pub(crate) fn registry_and_segments_mut(&mut self) -> (&TypeRegistry, &mut [Segment]) {
(&self.type_registry, &mut self.segments)
}
#[inline]
fn is_edb_relation(&self, rel: &Relation) -> bool {
rel.has_input() || self.has_inline_facts(rel.name())
}
}
/// Recursively splices include directives of `source` into a single text.
///
/// `source` is parsed with the main grammar; each top-level include directive
/// is replaced by the recursively resolved text of the file it names, and all
/// other text is copied through unchanged.
///
/// * `source_file` - id of `source` in `sm`, used for error spans.
/// * `base_dir` - directory against which relative include paths are tried first.
/// * `include_dirs` - additional directories searched when the path is not found under `base_dir`.
/// * `in_progress` - canonical paths of files currently being expanded (cycle detection).
/// * `completed` - canonical paths of files already spliced once; later includes of them are skipped.
///
/// # Errors
/// Syntax errors in `source`, `ParseError::CircularInclude` for an include of a
/// file that is still in progress, and `ParseError::IncludeIo` when an included
/// file cannot be loaded.
fn resolve_includes(
source: String,
source_file: FileId,
base_dir: &Path,
include_dirs: &[&Path],
in_progress: &mut HashSet<PathBuf>,
completed: &mut HashSet<PathBuf>,
sm: &mut SourceMap,
) -> Result<String, ParseError> {
let mut pairs = FlowLogParser::parse(Rule::main_grammar, &source)
.map_err(|e| ParseError::syntax_from_pest(&e, source_file))?;
let root = pairs
.next()
.ok_or_else(|| grammar_bug("no parsed rule found"))?;
let mut out = String::with_capacity(source.len());
// Byte offset in `source` up to which text has already been emitted or dropped.
let mut cursor = 0usize;
for node in root.into_inner() {
if node.as_rule() != Rule::include_directive {
continue;
}
let span = node.as_span();
let directive_span = Span::new(source_file, span.start() as u32, span.end() as u32);
// Copy everything since the previous directive, then step over this directive's text.
out.push_str(&source[cursor..span.start()]);
cursor = span.end();
let path_node = node
.into_inner()
.next()
.ok_or_else(|| grammar_bug("include directive missing path"))?;
let raw = path_node.as_str().trim_matches('"');
let full_path = resolve_one_include(raw, base_dir, include_dirs);
// Fall back to the un-canonicalized path when canonicalization fails.
let canonical = fs::canonicalize(&full_path).unwrap_or_else(|_| full_path.clone());
if in_progress.contains(&canonical) {
return Err(ParseError::CircularInclude {
span: directive_span,
path: full_path.clone(),
// NOTE(review): `in_progress` is a HashSet, so this chain is not in include order.
chain: in_progress.iter().cloned().collect(),
});
}
if completed.contains(&canonical) {
// The directive text itself was already dropped by the cursor advance above.
warn!("Skipping duplicate include '{}'.", full_path.display());
continue;
}
debug!("Including '{}'.", full_path.display());
let included_file = sm
.load(&full_path)
.map_err(|source| ParseError::IncludeIo {
span: directive_span,
path: full_path.clone(),
source,
})?;
let included_source = sm.text(included_file).to_string();
// Nested includes resolve relative to the included file's own directory.
let included_base = full_path.parent().unwrap_or(Path::new(".")).to_path_buf();
in_progress.insert(canonical.clone());
let inlined = resolve_includes(
included_source,
included_file,
&included_base,
include_dirs,
in_progress,
completed,
sm,
)?;
in_progress.remove(&canonical);
completed.insert(canonical);
// Keep the spliced text separated by whitespace from what precedes and follows it.
if out.chars().last().is_some_and(|c| !c.is_whitespace()) {
out.push('\n');
}
out.push_str(&inlined);
if out.chars().last().is_some_and(|c| !c.is_whitespace()) {
out.push('\n');
}
}
// Remainder of the source after the last include directive.
out.push_str(&source[cursor..]);
Ok(out)
}
/// Picks the on-disk location for an include path.
///
/// Tries `base_dir/raw` first, then each entry of `include_dirs` in order,
/// and returns the first candidate that exists. When none exists the
/// `base_dir`-relative path is returned anyway, so the caller's load attempt
/// reports the failure.
fn resolve_one_include(raw: &str, base_dir: &Path, include_dirs: &[&Path]) -> PathBuf {
    let fallback = base_dir.join(raw);
    if fallback.exists() {
        return fallback;
    }
    include_dirs
        .iter()
        .map(|dir| dir.join(raw))
        .find(|candidate| candidate.exists())
        .unwrap_or(fallback)
}
/// Builds the type registry from the type alias declarations directly under
/// `parsed_rule`; every other child node is ignored.
///
/// # Errors
/// Propagates failures from splitting a declaration or registering it.
fn build_type_registry(parsed_rule: Pair<Rule>, file: FileId) -> Result<TypeRegistry, ParseError> {
    let mut registry = TypeRegistry::new();
    let alias_nodes = parsed_rule
        .into_inner()
        .filter(|node| node.as_rule() == Rule::type_alias_decl);
    for node in alias_nodes {
        let (name, op, parent, span) = split_type_alias(node, file)?;
        match op {
            RawTypeOp::Subtype => {
                registry.register_subtype(&name, &parent, span)?;
            }
            RawTypeOp::Alias => {
                registry.register_alias(&name, &parent, span)?;
            }
        }
    }
    Ok(registry)
}
/// Moves any buffered rules into `out` as one `Segment::Plain`, leaving
/// `pending` empty. Does nothing when there are no buffered rules.
#[inline]
fn flush_rules(pending: &mut Vec<FlowLogRule>, out: &mut Vec<Segment>) {
    if pending.is_empty() {
        return;
    }
    let batch = std::mem::take(pending);
    out.push(Segment::Plain(batch));
}
/// Replaces `.` with `INLINER_SEP` in relation names, in every rule of every
/// segment, and in the pending fact rules.
fn normalize_inliner_dots(
    relations: &mut [Relation],
    segments: &mut [Segment],
    raw_facts: &mut [FlowLogRule],
) {
    let dotted = relations
        .iter_mut()
        .filter(|rel| rel.name().contains('.'));
    for rel in dotted {
        // The replacement is derived from the raw (as-written) name.
        let renamed = rel.raw_name().replace('.', INLINER_SEP);
        rel.set_name(renamed);
    }

    for_each_rule_mut(segments, normalize_rule_dots);
    raw_facts.iter_mut().for_each(normalize_rule_dots);
}
/// Replaces `.` with `INLINER_SEP` in the head name and in the name of every
/// positive or negative body atom of `rule`. Other predicate kinds are left alone.
fn normalize_rule_dots(rule: &mut FlowLogRule) {
    let head = rule.head_mut();
    if head.name().contains('.') {
        let renamed = head.name().replace('.', INLINER_SEP);
        head.set_name(renamed);
    }

    for pred in rule.rhs_mut() {
        let (Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom)) = pred else {
            continue;
        };
        if atom.name().contains('.') {
            let renamed = atom.name().replace('.', INLINER_SEP);
            atom.set_name(renamed);
        }
    }
}
/// Separator (U+00B7, middle dot) substituted for `.` in relation, head and atom names.
const INLINER_SEP: &str = "·";
/// Calls `f` on every rule of every segment, in segment order and then rule
/// order, with mutable access. Loop and fixpoint blocks contribute their own rules.
fn for_each_rule_mut<F>(segments: &mut [Segment], mut f: F)
where
    F: FnMut(&mut FlowLogRule),
{
    for seg in segments {
        match seg {
            Segment::Plain(rs) => rs.iter_mut().for_each(&mut f),
            Segment::Loop(b) | Segment::Fixpoint(b) => b.rules_mut().iter_mut().for_each(&mut f),
        }
    }
}
/// Rejects a directive list that names the same relation twice.
///
/// `name_of` and `span_of` extract the relation name and source span from a
/// directive; `kind` is reported in the error.
///
/// # Errors
/// `ParseError::DuplicateDirective` for the first directive whose name was
/// already seen; `prior` is the span of the most recent earlier occurrence.
fn check_duplicate_directives<T>(
    dirs: &[T],
    kind: DirectiveKind,
    name_of: impl Fn(&T) -> &str,
    span_of: impl Fn(&T) -> Span,
) -> Result<(), ParseError> {
    let mut first_seen: HashMap<&str, Span> = HashMap::new();
    for directive in dirs {
        let name = name_of(directive);
        let span = span_of(directive);
        // `insert` hands back the previous span when the name was already recorded.
        if let Some(prior) = first_seen.insert(name, span) {
            return Err(ParseError::DuplicateDirective {
                span,
                prior,
                kind,
                name: name.to_string(),
            });
        }
    }
    Ok(())
}
impl Program {
/// Builds a `Program` from the root parse node of an include-resolved source.
///
/// Steps, in order: register type aliases; walk the top-level nodes collecting
/// declarations, directives, rules, loop/fixpoint blocks, facts, components and
/// instantiations; inline each instantiation at its source position; apply
/// directives; run validations and rewrites; extract facts; check that every
/// referenced relation is declared.
///
/// `extended` gates loop and fixpoint blocks: when false they are rejected.
///
/// # Errors
/// Returns the first `ParseError` raised by any step.
fn collect_program(
parsed_rule: Pair<Rule>,
extended: bool,
file: FileId,
) -> Result<Self, ParseError> {
// Type aliases are registered up front (separate pass over a clone of the tree)
// so declarations can be resolved against the registry during the main walk.
let mut type_registry = build_type_registry(parsed_rule.clone(), file)?;
let mut relations: Vec<Relation> = Vec::new();
// Relation name -> (raw name, declaration span), used to detect duplicate declarations.
let mut decl_spans: HashMap<String, (String, Span)> = HashMap::new();
let mut input_directives: Vec<InputDirective> = Vec::new();
let mut output_directives: Vec<OutputDirective> = Vec::new();
let mut printsize_directives: Vec<PrintSizeDirective> = Vec::new();
let mut udfs: Vec<ExternFn> = Vec::new();
let mut raw_facts: Vec<FlowLogRule> = Vec::new();
// Rules accumulated since the last flush; flushed into one `Segment::Plain`.
let mut current_rules: Vec<FlowLogRule> = Vec::new();
let mut segments: Vec<Segment> = Vec::new();
let mut comps: HashMap<String, CompDecl> = HashMap::new();
// Each instantiation paired with the segment index at which its rules belong.
let mut inits_at_pos: Vec<(InitDecl, usize)> = Vec::new();
// Index in `current_rules` where the most recent rule node's expansion starts;
// passed to plan directives and cleared by any node that is neither a rule nor a plan directive.
let mut plan_target_start: Option<usize> = None;
for node in parsed_rule.into_inner() {
let node_rule = node.as_rule();
match node_rule {
Rule::declaration => {
let rel = Relation::from_parsed_rule_with_registry(node, file, &type_registry)?;
if let Some((_prev_raw, prior)) = decl_spans.get(rel.name()) {
return Err(ParseError::DuplicateDecl {
span: rel.span(),
prior: *prior,
name: rel.raw_name().to_string(),
});
}
decl_spans.insert(
rel.name().to_string(),
(rel.raw_name().to_string(), rel.span()),
);
relations.push(rel);
}
Rule::extern_fn => {
udfs.push(ExternFn::from_parsed_rule(node, file, &type_registry)?)
}
// Type aliases were already handled by `build_type_registry`.
// NOTE(review): a later component with the same name replaces the earlier one silently.
Rule::type_alias_decl => {} Rule::comp_decl => {
let comp = CompDecl::from_parsed_rule(node, file)?;
comps.insert(comp.name.clone(), comp);
}
Rule::init_decl => {
// Flush first so `segments.len()` is the position of this instantiation in source order.
flush_rules(&mut current_rules, &mut segments);
let init = InitDecl::from_parsed_rule(node, file)?;
inits_at_pos.push((init, segments.len()));
}
Rule::input_directive => {
input_directives.push(InputDirective::from_parsed_rule(node, file)?)
}
Rule::output_directive => {
output_directives.push(OutputDirective::from_parsed_rule(node, file)?)
}
Rule::printsize_directive => {
printsize_directives.push(PrintSizeDirective::from_parsed_rule(node, file)?)
}
Rule::rule => {
// One rule node may expand into several rules (e.g. multiple heads).
let start = current_rules.len();
current_rules.extend(FlowLogRule::expand_from_parsed_rule(node, file)?);
plan_target_start = Some(start);
}
Rule::plan_directive => {
consume_plan_directive(node, file, &mut current_rules, &mut plan_target_start)?;
}
Rule::loop_block => {
let block = LoopBlock::from_parsed_rule(node, file)?;
if !extended {
return Err(ParseError::LoopBlockInStandardMode { span: block.span() });
}
flush_rules(&mut current_rules, &mut segments);
segments.push(Segment::Loop(block));
}
Rule::fixpoint_block => {
let block = LoopBlock::from_parsed_rule(node, file)?;
if !extended {
return Err(ParseError::LoopBlockInStandardMode { span: block.span() });
}
flush_rules(&mut current_rules, &mut segments);
segments.push(Segment::Fixpoint(block));
}
Rule::fact => {
// A fact is stored as a rule with an empty body until constants are extracted below.
let head_node = node
.into_inner()
.next()
.ok_or_else(|| grammar_bug("fact missing head"))?;
raw_facts.push(FlowLogRule::new(
Head::from_parsed_rule(head_node, file)?,
vec![],
));
}
Rule::include_directive => {
return Err(grammar_bug(
"unexpected include_directive in parsed tree; includes should have been resolved before parsing",
));
}
_ => {}
}
if !matches!(node_rule, Rule::rule | Rule::plan_directive) {
plan_target_start = None;
}
}
flush_rules(&mut current_rules, &mut segments);
// Lowercased instance name -> instance name as written, over all instantiations.
let global_instances: HashMap<String, String> = inits_at_pos
.iter()
.map(|(init, _)| (init.instance.to_lowercase(), init.instance.clone()))
.collect();
let global_decls: HashMap<String, String> = HashMap::new();
// Number of segments inserted so far; later recorded positions shift by this amount.
let mut shift = 0usize;
for (init, pos) in inits_at_pos {
let mut out = inliner::InlinerOutput::default();
inliner::inline_one(
"",
&global_instances,
&global_decls,
init,
&mut comps,
&mut out,
&mut type_registry,
)?;
// Inlined relations go through the same duplicate-declaration check as top-level ones.
for rel in out.relations {
if let Some((_prev_raw, prior)) = decl_spans.get(rel.name()) {
return Err(ParseError::DuplicateDecl {
span: rel.span(),
prior: *prior,
name: rel.raw_name().to_string(),
});
}
decl_spans.insert(
rel.name().to_string(),
(rel.raw_name().to_string(), rel.span()),
);
relations.push(rel);
}
raw_facts.extend(out.facts);
input_directives.extend(out.input_directives);
output_directives.extend(out.output_directives);
printsize_directives.extend(out.printsize_directives);
if !out.rules.is_empty() {
segments.insert(pos + shift, Segment::Plain(out.rules));
shift += 1;
}
}
Self::apply_directives(
&mut relations,
input_directives,
output_directives,
printsize_directives,
)?;
Self::validate_output_printsize_exclusion(&relations)?;
Self::reclassify_udf_predicates(&mut segments, &udfs)?;
Self::validate_loop_conditions(&segments, &relations)?;
normalize_inliner_dots(&mut relations, &mut segments, &mut raw_facts);
super::desugar::desugar_equality_assignments(&mut segments, &mut raw_facts)?;
let mut program = Self {
relations,
segments,
udfs,
type_registry,
..Self::default()
};
for fact in raw_facts {
program.extract_fact(fact)?;
}
program.validate_relation_references()?;
Ok(program)
}
fn materialize_orphan_relations(&mut self) {
let mut produced: HashSet<String> = HashSet::new();
let mut referenced: HashSet<String> = HashSet::new();
for segment in &self.segments {
let rules: &[FlowLogRule] = match segment {
Segment::Plain(rules) => rules,
Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
};
for rule in rules {
produced.insert(rule.head().name().to_string());
for pred in rule.rhs() {
if let Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom) = pred {
referenced.insert(atom.name().to_string());
}
}
}
}
let orphans: Vec<String> = self
.relations
.iter()
.filter(|rel| {
let name = rel.name();
referenced.contains(name)
&& !produced.contains(name)
&& !self.facts.contains_key(name)
&& !rel.has_input()
})
.map(|rel| rel.name().to_string())
.collect();
for name in orphans {
self.facts.entry(name).or_default();
}
}
/// Checks that every relation named by a rule head, by a positive or negative
/// body atom, or by an inline fact has a declaration.
///
/// # Errors
/// `ParseError::UndeclaredInRule` for the first undeclared head or atom
/// (rules are checked before facts), otherwise
/// `ParseError::UndeclaredInFact` for an undeclared fact relation.
fn validate_relation_references(&self) -> Result<(), ParseError> {
    let declared: HashSet<&str> = self.relations.iter().map(|r| r.name()).collect();

    let all_rules = self.segments.iter().flat_map(|segment| match segment {
        Segment::Plain(rules) => rules.as_slice(),
        Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
    });
    for rule in all_rules {
        let head = rule.head();
        if !declared.contains(head.name()) {
            return Err(ParseError::UndeclaredInRule {
                span: head.span(),
                name: head.name().to_string(),
            });
        }
        for pred in rule.rhs() {
            let (Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom)) = pred else {
                continue;
            };
            if declared.contains(atom.name()) {
                continue;
            }
            return Err(ParseError::UndeclaredInRule {
                span: atom.span(),
                name: atom.name().to_string(),
            });
        }
    }

    let undeclared_fact = self
        .facts
        .iter()
        .find(|(rel_name, _)| !declared.contains(rel_name.as_str()));
    if let Some((rel_name, tuples)) = undeclared_fact {
        // Report at the first fact of the relation, if one exists.
        let span = tuples.first().map_or(Span::DUMMY, |(s, _)| *s);
        return Err(ParseError::UndeclaredInFact {
            span,
            name: rel_name.clone(),
        });
    }
    Ok(())
}
/// Attaches input, output and printsize directives to their relations.
///
/// Each directive list is first checked for duplicate relation names; the
/// lists are then applied in the order input, output, printsize.
///
/// # Errors
/// `ParseError::DuplicateDirective` from the duplicate check,
/// `ParseError::UndeclaredInDirective` when a directive names a relation that
/// is not in `relations`, and any error from setting output parameters.
fn apply_directives(
    relations: &mut [Relation],
    input_directives: Vec<InputDirective>,
    output_directives: Vec<OutputDirective>,
    printsize_directives: Vec<PrintSizeDirective>,
) -> Result<(), ParseError> {
    check_duplicate_directives(
        &input_directives,
        DirectiveKind::Input,
        |d| d.relation_name(),
        |d| d.span(),
    )?;
    check_duplicate_directives(
        &output_directives,
        DirectiveKind::Output,
        |d| d.relation_name(),
        |d| d.span(),
    )?;
    check_duplicate_directives(
        &printsize_directives,
        DirectiveKind::PrintSize,
        |d| d.relation_name(),
        |d| d.span(),
    )?;

    for directive in input_directives {
        let target = relations
            .iter_mut()
            .find(|r| r.name() == directive.relation_name());
        let Some(rel) = target else {
            return Err(ParseError::UndeclaredInDirective {
                span: directive.span(),
                kind: DirectiveKind::Input,
                name: directive.relation_name().to_string(),
            });
        };
        rel.set_input_params(directive.parameters().clone());
    }

    for directive in output_directives {
        let target = relations
            .iter_mut()
            .find(|r| r.name() == directive.relation_name());
        let Some(rel) = target else {
            return Err(ParseError::UndeclaredInDirective {
                span: directive.span(),
                kind: DirectiveKind::Output,
                name: directive.relation_name().to_string(),
            });
        };
        rel.set_output(true);
        // Output parameters are only set when the directive actually carries some.
        if !directive.parameters().is_empty() {
            rel.set_output_params(directive.parameters().clone())?;
        }
    }

    for directive in printsize_directives {
        let target = relations
            .iter_mut()
            .find(|r| r.name() == directive.relation_name());
        let Some(rel) = target else {
            return Err(ParseError::UndeclaredInDirective {
                span: directive.span(),
                kind: DirectiveKind::PrintSize,
                name: directive.relation_name().to_string(),
            });
        };
        rel.set_printsize(true);
    }

    Ok(())
}
/// Rejects a relation that is marked both for output and for printsize.
///
/// # Errors
/// `ParseError::OutputAndPrintsizeConflict` for the first such relation.
fn validate_output_printsize_exclusion(relations: &[Relation]) -> Result<(), ParseError> {
    let conflicting = relations
        .iter()
        .find(|rel| rel.output() && rel.printsize());
    match conflicting {
        Some(rel) => Err(ParseError::OutputAndPrintsizeConflict {
            span: rel.span(),
            name: rel.raw_name().to_string(),
        }),
        None => Ok(()),
    }
}
fn validate_loop_conditions(
items: &[Segment],
relations: &[Relation],
) -> Result<(), ParseError> {
let declared: HashSet<&str> = relations.iter().map(|r| r.name()).collect();
for item in items {
let Some(block) = item.as_loop() else {
continue;
};
for directive in block.iterative_relations() {
let name = directive.name();
if !declared.contains(name) {
return Err(ParseError::UndeclaredInIterativeList {
span: block.span(),
name: name.to_string(),
});
}
}
let Some(cond) = block.condition() else {
continue;
};
let Some(until_group) = cond.until_part() else {
continue;
};
for rel in until_group.relations() {
let name = rel.name();
if !declared.contains(name) && !declared.contains(name.to_lowercase().as_str()) {
return Err(ParseError::UndeclaredLoopCondition {
span: block.span(),
name: name.to_string(),
});
}
let decl = relations
.iter()
.find(|r| r.name() == name || r.name() == name.to_lowercase().as_str())
.ok_or_else(|| grammar_bug("already confirmed declared above"))?;
if decl.arity() != 0 {
return Err(ParseError::NonNullaryLoopCondition {
span: block.span(),
name: decl.raw_name().to_string(),
arity: decl.arity(),
});
}
}
}
Ok(())
}
/// Rewrites body atoms whose name matches an extern function into function
/// call predicates, across every rule of every segment.
///
/// Returns immediately when there are no extern functions. After the first
/// failure the remaining rules are left untouched.
///
/// # Errors
/// The first error reported by `reclassify_one_rule`.
fn reclassify_udf_predicates(
    items: &mut [Segment],
    udfs: &[ExternFn],
) -> Result<(), ParseError> {
    let udf_names: HashSet<&str> = udfs.iter().map(|udf| udf.name()).collect();
    if udf_names.is_empty() {
        return Ok(());
    }

    // `for_each_rule_mut` cannot short-circuit, so remember the first failure
    // and skip all later rules.
    let mut outcome: Result<(), ParseError> = Ok(());
    for_each_rule_mut(items, |rule| {
        if outcome.is_err() {
            return;
        }
        outcome = Self::reclassify_one_rule(rule, &udf_names);
    });
    outcome
}
/// Rewrites one rule so that positive/negative atoms named after an extern
/// function become `Predicate::FnCall`s.
///
/// Variable and constant arguments become single-factor arithmetic
/// expressions; a call built from a negative atom is flagged as negated. The
/// rule is left untouched when no atom matches or when an error is returned.
///
/// # Errors
/// `ParseError::PlaceholderInUdf` when a matching atom has a placeholder argument.
fn reclassify_one_rule(
    rule: &mut FlowLogRule,
    udf_names: &HashSet<&str>,
) -> Result<(), ParseError> {
    let touches_udf = rule.rhs().iter().any(|p| {
        matches!(
            p,
            Predicate::PositiveAtom(a) | Predicate::NegativeAtom(a)
                if udf_names.contains(a.name())
        )
    });
    if !touches_udf {
        return Ok(());
    }

    let mut rewritten = Vec::with_capacity(rule.rhs().len());
    for pred in rule.rhs() {
        let atom = match pred {
            Predicate::PositiveAtom(atom) | Predicate::NegativeAtom(atom)
                if udf_names.contains(atom.name()) =>
            {
                atom
            }
            other => {
                rewritten.push(other.clone());
                continue;
            }
        };

        let args = atom
            .arguments()
            .iter()
            .map(|arg| match arg {
                AtomArg::Var(v) => Ok(Arithmetic::new(Factor::Var(v.clone()), vec![])),
                AtomArg::Const(c) => Ok(Arithmetic::new(Factor::Const(c.clone()), vec![])),
                AtomArg::Placeholder => Err(ParseError::PlaceholderInUdf {
                    span: atom.span(),
                    udf_name: atom.name().to_string(),
                }),
            })
            .collect::<Result<Vec<_>, ParseError>>()?;

        let negated = matches!(pred, Predicate::NegativeAtom(_));
        let call = FnCall::new(atom.name().to_string(), args, negated).with_span(atom.span());
        rewritten.push(Predicate::FnCall(call));
    }

    // NOTE(review): the rule is rebuilt from its head and new body; confirm that
    // `FlowLogRule::new` does not need to carry over any other per-rule state.
    *rule = FlowLogRule::new(rule.head().clone(), rewritten);
    Ok(())
}
/// Records a fact rule in the fact table under its head's relation name,
/// storing the head span together with the constants taken from the head.
///
/// # Errors
/// Propagates the error from extracting the head constants.
fn extract_fact(&mut self, fact_rule: FlowLogRule) -> Result<(), ParseError> {
    let head = fact_rule.head();
    let rel_name = head.name().to_string();
    let span = head.span();
    let tuple = fact_rule.extract_constants_from_head()?;
    let bucket = self.facts.entry(rel_name).or_default();
    bucket.push((span, tuple));
    Ok(())
}
}
/// Sentinel rule index for a body atom that no rule head produces.
const NO_TOP_LEVEL_RULE_ID: usize = usize::MAX;
impl Program {
/// Computes which rules and relations are reachable from the program's roots.
///
/// Returns `((needed_rules, needed_preds), underived)`:
/// * `needed_rules` - global rule indices (flattened across segments, in order) to keep;
/// * `needed_preds` - relation names to keep;
/// * `underived` - root relations that have no rule, no inline facts and no input directive.
///
/// Roots are the relations returned by `idbs()`, every relation with inline
/// facts, and every relation named in a loop block's `until` condition. When
/// there are no roots at all, everything is reported as needed.
#[must_use]
fn identify_needed_components(&self) -> ((HashSet<usize>, HashSet<String>), HashSet<String>) {
// Flattened rule list; an index into it is the rule's global id.
let all_rules: Vec<&FlowLogRule> = self
.segments
.iter()
.flat_map(|item| match item {
Segment::Plain(rules) => rules.as_slice(),
Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
})
.collect();
let mut needed_preds: HashSet<String> = self
.idbs()
.into_iter()
.map(|d| d.name().to_string())
.collect();
needed_preds.extend(self.facts.keys().cloned());
needed_preds.extend(
self.segments
.iter()
.filter_map(Segment::as_loop)
.flat_map(|block| {
block
.condition()
.and_then(|cond| cond.until_part())
.into_iter()
.flat_map(|stop| stop.relations().map(|rel| rel.name().to_string()))
}),
);
// No roots: keep every rule and every relation.
if needed_preds.is_empty() {
let all_indices = (0..all_rules.len()).collect();
let all_preds = self
.relations
.iter()
.map(|d| d.name().to_string())
.collect();
return ((all_indices, all_preds), HashSet::new());
}
// Head relation name -> global ids of the rules that produce it.
let mut head_to_rules: HashMap<String, Vec<usize>> = HashMap::new();
for (i, r) in all_rules.iter().enumerate() {
head_to_rules
.entry(r.head().name().to_string())
.or_default()
.push(i);
}
let input_relations: HashSet<String> = self
.relations
.iter()
.filter(|r| r.has_input())
.map(|r| r.name().to_string())
.collect();
// Roots with no producing rule, no facts and no input are dropped from the roots.
let underived: Vec<String> = needed_preds
.iter()
.filter(|p| {
!head_to_rules.contains_key(p.as_str())
&& !self.facts.contains_key(p.as_str())
&& !input_relations.contains(p.as_str())
})
.cloned()
.collect();
for name in &underived {
needed_preds.remove(name);
}
let mut needed_rules: HashSet<usize> = needed_preds
.iter()
.flat_map(|p| head_to_rules.get(p).into_iter().flatten().copied())
.collect();
// Rule id -> (producing rule id, atom name) for each positive/negative body atom;
// atoms with no producing rule carry `NO_TOP_LEVEL_RULE_ID`.
let dep_map: HashMap<usize, Vec<(usize, String)>> = all_rules
.iter()
.enumerate()
.map(|(i, r)| {
let deps = r
.rhs()
.iter()
.filter_map(|pred| match pred {
Predicate::PositiveAtom(a) | Predicate::NegativeAtom(a) => Some(a.name()),
_ => None,
})
.flat_map(|atom_name| {
if let Some(ids) = head_to_rules.get(atom_name) {
ids.iter()
.map(|&id| (id, atom_name.to_string()))
.collect::<Vec<_>>()
} else {
vec![(NO_TOP_LEVEL_RULE_ID, atom_name.to_string())]
}
})
.collect();
(i, deps)
})
.collect();
// Worklist traversal from the root rules over the dependency map.
let mut processed: HashSet<usize> = HashSet::new();
let mut stack: Vec<usize> = needed_rules.iter().copied().collect();
while let Some(rule_id) = stack.pop() {
if !processed.insert(rule_id) {
continue;
}
for &(dep_rule_id, ref pred_name) in dep_map.get(&rule_id).into_iter().flatten() {
// Every relation a needed rule reads is needed, even if nothing derives it.
needed_preds.insert(pred_name.clone());
if dep_rule_id != NO_TOP_LEVEL_RULE_ID && !processed.contains(&dep_rule_id) {
needed_rules.insert(dep_rule_id);
stack.push(dep_rule_id);
}
}
}
let underived: HashSet<String> = underived.into_iter().collect();
((needed_rules, needed_preds), underived)
}
/// Removes rules, relations and facts that `identify_needed_components`
/// did not report as needed.
///
/// Logs a warning listing what is dropped, records the output file name of
/// every output relation that is not needed in `empty_output_files`, and
/// discards segments that end up with no rules.
fn prune_dead_components(&mut self) {
    let ((needed_rules, needed_preds), underived) = self.identify_needed_components();

    let dead_relations: Vec<String> = self
        .relations
        .iter()
        .filter(|rel| !needed_preds.contains(rel.name()) && !underived.contains(rel.name()))
        .map(|rel| rel.raw_name().to_string())
        .collect();

    let dead_rules: Vec<String> = self
        .segments
        .iter()
        .flat_map(|item| match item {
            Segment::Plain(rules) => rules.as_slice(),
            Segment::Loop(block) | Segment::Fixpoint(block) => block.rules(),
        })
        .enumerate()
        .filter_map(|(i, r)| (!needed_rules.contains(&i)).then(|| format!("#{}: {}", i, r)))
        .collect();

    // One report line per non-empty category; nothing is logged when all are empty.
    let mut parts: Vec<String> = Vec::new();
    if !underived.is_empty() {
        let mut sorted: Vec<&str> = self
            .relations
            .iter()
            .filter(|r| underived.contains(r.name()))
            .map(Relation::raw_name)
            .collect();
        sorted.sort_unstable();
        parts.push(format!(
            " underived IDBs (declared but no rules): {}",
            sorted.join(", ")
        ));
    }
    if !dead_relations.is_empty() {
        parts.push(format!(
            " unreachable relations: {}",
            dead_relations.join(", ")
        ));
    }
    if !dead_rules.is_empty() {
        parts.push(format!(" unreachable rules: {}", dead_rules.join(", ")));
    }
    if !parts.is_empty() {
        warn!("Pruned dead components:\n{}", parts.join("\n"));
    }

    // Must be computed before the relations are dropped below.
    self.empty_output_files = self
        .relations
        .iter()
        .filter(|d| d.output() && !needed_preds.contains(d.name()))
        .map(Relation::output_file_name)
        .collect();
    self.relations.retain(|d| needed_preds.contains(d.name()));

    // Rules are visited in the same flattened order that assigned their global
    // ids, so a running counter recovers each rule's id.
    let mut next_rule_id = 0usize;
    let mut keep_next = || {
        let keep = needed_rules.contains(&next_rule_id);
        next_rule_id += 1;
        keep
    };

    let surviving: Vec<Segment> = std::mem::take(&mut self.segments)
        .into_iter()
        .filter_map(|segment| match segment {
            Segment::Plain(rules) => {
                let kept: Vec<FlowLogRule> =
                    rules.into_iter().filter(|_| keep_next()).collect();
                (!kept.is_empty()).then(|| Segment::Plain(kept))
            }
            Segment::Loop(mut block) => {
                block.rules_mut().retain(|_| keep_next());
                (!block.rules().is_empty()).then(|| Segment::Loop(block))
            }
            Segment::Fixpoint(mut block) => {
                block.rules_mut().retain(|_| keep_next());
                (!block.rules().is_empty()).then(|| Segment::Fixpoint(block))
            }
        })
        .collect();
    self.segments = surviving;

    self.facts
        .retain(|rel, _| needed_preds.contains(rel.as_str()));
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::ComparisonOperator;
use crate::parser::DataType;
use crate::parser::HeadArg;
use std::io::Write;
fn loop_blocks(program: &Program) -> Vec<&LoopBlock> {
program
.segments()
.iter()
.filter_map(|s| s.as_loop())
.collect()
}
fn parse_program(src: &str) -> Program {
parse_program_result(src).expect("parse failed")
}
fn parse_program_result(src: &str) -> Result<Program, ParseError> {
let mut tmp = tempfile::NamedTempFile::new().expect("failed to create temp file");
tmp.write_all(src.as_bytes())
.expect("failed to write temp file");
let mut sm = SourceMap::new();
Program::parse(&tmp.path().to_string_lossy(), true, &mut sm)
}
fn find_relation<'a>(program: &'a Program, name: &str) -> &'a Relation {
program
.relations()
.iter()
.find(|r| r.name() == name)
.unwrap_or_else(|| panic!("relation `{name}` not found"))
}
#[test]
fn decl_case_collision_rejected() {
let err = parse_program_result(
"
.decl edge(x: number)
.decl Edge(y: number)
",
)
.unwrap_err();
assert!(
matches!(err, ParseError::DuplicateDecl { .. }),
"got {err:?}"
);
}
#[test]
fn attr_case_collision_rejected() {
let err = parse_program_result(".decl edge(x: number, X: number)").unwrap_err();
assert!(
matches!(err, ParseError::DuplicateAttribute { .. }),
"got {err:?}"
);
}
#[test]
fn ordering_preserved() {
let src = "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.output c
a(X) :- b(X).
fixpoint { b(X) :- a(X). }
c(X) :- a(X).
";
let program = parse_program(src);
let items = program.segments();
assert_eq!(items.len(), 3);
assert!(matches!(&items[0], Segment::Plain(r) if r.len() == 1));
assert!(matches!(&items[1], Segment::Fixpoint(_)));
assert!(matches!(&items[2], Segment::Plain(r) if r.len() == 1));
}
#[test]
fn multiple_rules_before_loop_form_one_segment() {
let src = "
.decl a(x: number)
.decl b(x: number)
.output a
a(1) :- b(1).
a(2) :- b(2).
fixpoint { a(X) :- b(X). }
";
let program = parse_program(src);
let items = program.segments();
assert_eq!(items.len(), 2);
assert!(matches!(&items[0], Segment::Plain(r) if r.len() == 2));
assert!(matches!(&items[1], Segment::Fixpoint(_)));
}
#[test]
fn rules_method_flattens_segments() {
let src = "
.decl a(x: number)
.decl b(x: number)
.output a
a(X) :- b(X).
fixpoint { }
a(1) :- b(1).
";
let program = parse_program(src);
assert_eq!(program.rules().len(), 2);
}
#[test]
fn loop_block_with_declared_relation() {
let src = "
.decl done()
.decl edge(x: number, y: number)
.output done
loop until { done } { done() :- edge(1, 2). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
}
#[test]
fn fixpoint_needs_no_declaration() {
let src = "
.decl edge(x: number, y: number)
.output edge
fixpoint { edge(1, 2) :- edge(1, 2). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
}
#[test]
fn loop_while_needs_no_declaration() {
let src = "
.decl edge(x: number, y: number)
.output edge
loop while { @it <= 9 } { edge(1, 2) :- edge(1, 2). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
}
#[test]
fn iterative_declared_passes() {
let src = "
.decl edge(x: number, y: number)
.decl active_edge(x: number, y: number)
.output active_edge
fixpoint { .iterative active_edge active_edge(X, Y) :- edge(X, Y). }
";
let program = parse_program(src);
assert_eq!(loop_blocks(&program).len(), 1);
assert_eq!(loop_blocks(&program)[0].iterative_relations().len(), 1);
}
#[test]
fn dead_code_elimination_keeps_loop_until_relations() {
let src = "
.decl edge(x: number, y: number)
.decl keep()
.decl dead()
.output edge
edge(1, 2).
loop until { keep } {
keep() :- edge(1, 2).
}
dead() :- edge(2, 3).
";
let program = parse_program(src);
assert!(program.relations().iter().any(|rel| rel.name() == "keep"));
assert!(!program.relations().iter().any(|rel| rel.name() == "dead"));
}
/// Checks the three EDB views of a program: `edbs()` lists every input relation,
/// `file_backed_relations()` only those with an `.input` directive, and
/// `inline_fact_relations()` only those with inline facts. `both` appears in all three.
#[test]
fn edb_subsets_track_file_backed_inline_and_overlap_relations() {
    let parsed = parse_program(
        "
.decl file_only(x: number)
.decl fact_only(x: number)
.decl both(x: number)
.decl out(x: number)
.input file_only(IO=\"file\", filename=\"file_only.csv\", delimiter=\",\")
.input both(IO=\"file\", filename=\"both.csv\", delimiter=\",\")
.output out
fact_only(1).
both(2).
out(X) :- file_only(X).
out(X) :- fact_only(X).
out(X) :- both(X).
",
    );

    // Names are sorted so the comparisons do not depend on iteration order.
    let mut all_edbs: Vec<String> = parsed
        .edbs()
        .into_iter()
        .map(|r| r.name().to_string())
        .collect();
    all_edbs.sort_unstable();

    let mut from_files: Vec<String> = parsed
        .file_backed_relations()
        .into_iter()
        .map(|r| r.name().to_string())
        .collect();
    from_files.sort_unstable();

    let mut from_facts: Vec<String> = parsed
        .inline_fact_relations()
        .into_iter()
        .map(|r| r.name().to_string())
        .collect();
    from_facts.sort_unstable();

    assert_eq!(all_edbs, vec!["both", "fact_only", "file_only"]);
    assert_eq!(from_files, vec!["both", "file_only"]);
    assert_eq!(from_facts, vec!["both", "fact_only"]);
}
/// A parameterless `.input Edge` still marks the relation as file-backed input,
/// with file name `Edge.facts` and a tab delimiter.
#[test]
fn bare_input_directive_uses_souffle_defaults() {
    let parsed = parse_program(
        "
.decl Edge(a: symbol, b: symbol)
.input Edge
",
    );
    let edge_rel = find_relation(&parsed, "edge");
    assert!(edge_rel.has_input(), "bare .input attaches params");
    assert!(edge_rel.is_file_backed(), "absent IO= defaults to file");
    assert_eq!(edge_rel.input_file_name(), "Edge.facts");
    assert_eq!(edge_rel.input_delimiter(), "\t");
}
/// A rule with two heads (`b(X), c(X) :- a(X).`) is split into one rule per head,
/// in head order, each keeping the single body atom.
#[test]
fn multi_head_rule_expands() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.output b
.output c
b(X), c(X) :- a(X).
",
    );
    let expanded = parsed.rules();
    assert_eq!(expanded.len(), 2);
    for (rule, head_name) in expanded.iter().zip(["b", "c"]) {
        assert_eq!(rule.head().name(), head_name);
    }
    for rule in &expanded {
        assert_eq!(rule.rhs().len(), 1);
    }
}
/// Multi-head expansion also applies to rules inside a `fixpoint` block: the block
/// ends up with one rule per head, in head order.
#[test]
fn multi_head_rule_in_fixpoint() {
    let parsed = parse_program(
        "
.decl a(x: number, y: number)
.decl b(x: number, y: number)
.decl c(x: number, y: number)
.output b
.output c
fixpoint { b(X, Y), c(X, Y) :- a(X, Y). }
",
    );
    let blocks = loop_blocks(&parsed);
    assert_eq!(blocks.len(), 1);
    let block_rules = blocks[0].rules();
    assert_eq!(block_rules.len(), 2);
    assert_eq!(block_rules[0].head().name(), "b");
    assert_eq!(block_rules[1].head().name(), "c");
}
/// A body disjunction (`c(X) :- a(X); b(X).`) is split into one rule per arm, in
/// arm order, all sharing the head `c`.
#[test]
fn multi_body_rule_expands() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.output c
c(X) :- a(X); b(X).
",
    );
    let expanded = parsed.rules();
    assert_eq!(expanded.len(), 2);
    for rule in &expanded {
        assert_eq!(rule.head().name(), "c");
    }
    for (rule, body_atom) in expanded.iter().zip(["a", "b"]) {
        assert_eq!(rule.rhs()[0].name(), body_atom);
    }
}
/// Each arm of a parenthesised disjunction may itself be a conjunction: the rule
/// expands into two rules whose bodies are `[a, b]` and `[c, d]`.
#[test]
fn disjunction_arm_can_be_a_conjunction() {
    let parsed = parse_program(
        "
.decl a(x: number) .decl b(x: number)
.decl c(x: number) .decl d(x: number)
.decl r(x: number)
.output r
r(X) :- ( a(X), b(X) ; c(X), d(X) ).
",
    );
    let expanded = parsed.rules();
    assert_eq!(expanded.len(), 2);
    let bodies: Vec<Vec<&str>> = expanded
        .iter()
        .map(|rule| rule.rhs().iter().map(|atom| atom.name()).collect())
        .collect();
    for expected in [vec!["a", "b"], vec!["c", "d"]] {
        assert!(bodies.contains(&expected));
    }
}
/// Two conjoined disjunctions expand to their cross product: four rules whose first
/// two body atoms cover every `(a|b, c|d)` pairing.
#[test]
fn nested_disjunctions_cross_product() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl d(x: number)
.decl r(x: number)
.output r
r(X) :- ( a(X) ; b(X) ), ( c(X) ; d(X) ).
",
    );
    let expanded = parsed.rules();
    assert_eq!(expanded.len(), 4);
    let bodies: Vec<(&str, &str)> = expanded
        .iter()
        .map(|rule| (rule.rhs()[0].name(), rule.rhs()[1].name()))
        .collect();
    for pair in [("a", "c"), ("a", "d"), ("b", "c"), ("b", "d")] {
        assert!(bodies.contains(&pair));
    }
}
/// Two heads combined with a two-arm body disjunction expand to four rules, ordered
/// head-major: `(c, a)`, `(c, b)`, `(d, a)`, `(d, b)`.
#[test]
fn multi_head_multi_body_expands() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl d(x: number)
.output c
.output d
c(X), d(X) :- a(X); b(X).
",
    );
    let expanded = parsed.rules();
    assert_eq!(expanded.len(), 4);
    let expected = [("c", "a"), ("c", "b"), ("d", "a"), ("d", "b")];
    for (rule, (head_name, body_atom)) in expanded.iter().zip(expected) {
        assert_eq!(rule.head().name(), head_name);
        assert_eq!(rule.rhs()[0].name(), body_atom);
    }
}
/// Builds a diamond of `.include`s on disk (root -> left/right -> leaf) and checks
/// that parsing the root succeeds and yields exactly one `leaf_rel` relation.
#[test]
fn diamond_include_dedups_leaf() {
    let dir = tempfile::tempdir().expect("tempdir");
    let files = [
        (
            "leaf.dl",
            ".decl leaf_rel(x: number)\n.output leaf_rel\nleaf_rel(1).\n",
        ),
        ("left.dl", ".include \"leaf.dl\"\n"),
        ("right.dl", ".include \"leaf.dl\"\n"),
        ("root.dl", ".include \"left.dl\"\n.include \"right.dl\"\n"),
    ];
    for (name, body) in files {
        std::fs::write(dir.path().join(name), body).expect("write");
    }

    let mut sm = SourceMap::new();
    let root = dir.path().join("root.dl");
    let program = Program::parse(&root.to_string_lossy(), true, &mut sm)
        .expect("diamond include should succeed with dedup");

    let leaf_count = program
        .relations()
        .iter()
        .filter(|r| r.name() == "leaf_rel")
        .count();
    assert_eq!(leaf_count, 1, "leaf_rel inlined twice");
}
/// With the alias chain `A = B`, `B = C`, `C = number`, every attribute of `R`
/// resolves to `DataType::Int32`.
#[test]
fn type_alias_chain_resolves_to_root() {
    let parsed = parse_program(
        "
.type C = number
.type B = C
.type A = B
.decl R(x: A, y: B, z: C)
.output R
R(1, 2, 3).
",
    );
    let relation = find_relation(&parsed, "r");
    let expected = vec![DataType::Int32, DataType::Int32, DataType::Int32];
    assert_eq!(relation.data_type(), expected);
}
/// A negated body atom that names an `.extern fn` (`!cost(X)`) must show up as a
/// `Predicate::FnCall` that is still negated and still named `cost`.
#[test]
fn negated_udf_reclassification_preserves_negation() {
    let parsed = parse_program(
        "
.decl edge(x: number, y: number)
.decl out(x: number, y: number)
.output out
.extern fn cost(x: number) -> number
out(X, Y) :- edge(X, Y), !cost(X).
",
    );
    let all_rules = parsed.rules();
    let first_rule = all_rules[0];
    let fn_call = first_rule
        .rhs()
        .iter()
        .find_map(|pred| {
            if let Predicate::FnCall(call) = pred {
                Some(call)
            } else {
                None
            }
        })
        .expect("udf body atom should be reclassified to FnCall");
    assert!(
        fn_call.is_negated(),
        "negation lost during reclassification"
    );
    assert_eq!(fn_call.name(), "cost");
}
/// Test helper: returns the first column of every inline fact stored for `rel`,
/// in stored order, as integers.
///
/// # Panics
///
/// Panics with a descriptive message when `rel` has no facts entry, when a fact
/// tuple is empty, or when the first column is not a `ConstType::Int`.
fn fact_numbers(program: &Program, rel: &str) -> Vec<i64> {
    program
        .facts()
        .get(rel)
        .unwrap_or_else(|| panic!("no facts for `{rel}`"))
        .iter()
        // `first()` instead of `tuple[0]` so a nullary fact reports which relation
        // was malformed rather than a bare index-out-of-bounds panic.
        .map(|(_, tuple)| match tuple.first() {
            Some(ConstType::Int(n)) => *n,
            Some(other) => panic!("expected number in `{rel}`, got {other:?}"),
            None => panic!("expected number in `{rel}`, got an empty tuple"),
        })
        .collect()
}
/// `.output R` followed by `.printsize R` is rejected with
/// `ParseError::OutputAndPrintsizeConflict` carrying the name `R`.
#[test]
fn output_and_printsize_on_same_relation_rejected() {
    let src = "
.decl R(x: number)
R(1).
.output R
.printsize R
";
    let err = parse_program_result(src).unwrap_err();
    assert!(
        matches!(err, ParseError::OutputAndPrintsizeConflict { ref name, .. } if name == "R"),
        "expected OutputAndPrintsizeConflict, got {err:?}"
    );
}
/// The output/printsize conflict is also reported when `.printsize R` comes before
/// `.output R`.
#[test]
fn output_and_printsize_rejected_regardless_of_order() {
    let src = "
.decl R(x: number)
R(1).
.printsize R
.output R
";
    let err = parse_program_result(src).unwrap_err();
    assert!(
        matches!(err, ParseError::OutputAndPrintsizeConflict { .. }),
        "expected OutputAndPrintsizeConflict, got {err:?}"
    );
}
/// The output/printsize conflict is detected when both directives sit inside a
/// component body that is instantiated with `.init`.
#[test]
fn output_and_printsize_inside_comp_rejected() {
    let src = "
.comp C {
.decl Src(x: number)
.decl R(x: number)
Src(1).
R(x) :- Src(x).
.output R
.printsize R
}
.init c = C
";
    let err = parse_program_result(src).unwrap_err();
    assert!(
        matches!(err, ParseError::OutputAndPrintsizeConflict { .. }),
        "expected OutputAndPrintsizeConflict for comp-internal pair, got {err:?}"
    );
}
/// An `.output` on a relation with no facts or rules (`Nothing`) is absent from
/// `output_idbs()` but recorded in `empty_output_files()` as `Nothing.csv`; the
/// derived relation `Out` stays in `output_idbs()`.
#[test]
fn empty_output_recorded_and_pruned_from_dataflow() {
    let src = "
.decl Nothing(x: symbol)
.decl Src(x: symbol)
.decl Out(x: symbol)
Src(\"v\").
Out(x) :- Src(x).
.output Nothing
.output Out
";
    let program = parse_program(src);
    assert!(
        program.output_idbs().iter().all(|r| r.name() != "nothing"),
        "empty `.output` should be pruned from output_idbs, got: {:?}",
        program
            .output_idbs()
            .iter()
            .map(|r| r.name())
            .collect::<Vec<_>>()
    );
    assert_eq!(program.empty_output_files(), &["Nothing.csv"]);
    assert!(program.output_idbs().iter().any(|r| r.name() == "out"));
}
/// When an empty output relation carries an explicit `filename=` parameter, that
/// name (`custom.tsv`) is what `empty_output_files()` reports.
#[test]
fn empty_output_filename_param_honored() {
    let src = "
.decl Nothing(x: symbol)
.output Nothing(filename=\"custom.tsv\")
// companion derived rel to keep the dataflow non-empty so codegen works
.decl Filled(x: symbol)
Filled(\"v\").
.decl Out(x: symbol)
Out(x) :- Filled(x).
.output Out
";
    let parsed = parse_program(src);
    assert_eq!(parsed.empty_output_files(), &["custom.tsv"]);
}
/// The infix form `x cat y` in a head argument must be rejected by the parser.
#[test]
fn infix_cat_no_longer_parses() {
    let src = "
.decl A(x: symbol)
.decl B(x: symbol)
.decl C(x: symbol)
A(\"hi\").
B(\"there\").
C(x cat y) :- A(x), B(y).
";
    let res = parse_program_result(src);
    assert!(
        res.is_err(),
        "infix `cat` should be a parse error, but parsed: {res:?}"
    );
}
/// A relation instantiated from a component is found under the name `c·r`, while
/// its `raw_name()` is the dotted, case-preserving `c.R`.
#[test]
fn inlined_relation_raw_name_keeps_literal_dot() {
    let parsed = parse_program(
        "
.comp C {
.decl R(x: symbol)
.decl S(x: symbol)
R(x) :- S(x).
.output R
}
.init c = C
",
    );
    let inlined = find_relation(&parsed, "c·r");
    assert_eq!(inlined.name(), "c·r");
    assert_eq!(inlined.raw_name(), "c.R");
}
/// A member type reference (`configuration.Context`) resolves even though the
/// nested `.init configuration = ...` appears after the `.decl` that uses it.
#[test]
fn member_type_resolves_when_nested_init_follows_decl() {
    let parsed = parse_program(
        "
.type Value = symbol
.comp Cfg { .type Context = symbol }
.comp Analysis<Configuration> {
.decl RunningThread(ctx:configuration.Context, v:Value)
.init configuration = Configuration
}
.init mainAnalysis = Analysis<Cfg>
",
    );
    let running_thread = find_relation(&parsed, "mainanalysis·runningthread");
    assert_eq!(
        running_thread.data_type(),
        vec![DataType::String, DataType::String]
    );
}
/// `configuration.Context` used inside the abstract base component and inside the
/// enclosing analysis both resolve to the `Context` type defined by the concrete
/// subtype passed as the component parameter.
#[test]
fn self_referential_member_type_from_concrete_subtype() {
    let parsed = parse_program(
        "
.type Value = symbol
.type Invo = symbol
.comp AbstractConfiguration {
.decl ContextRequest(ctx:configuration.Context, invo:Invo)
}
.comp Analysis<Configuration> {
.init configuration = Configuration
.decl RunningThread(ctx:configuration.Context, v:Value)
}
.comp ConcreteConfiguration : AbstractConfiguration {
.type Context = symbol
}
.init mainAnalysis = Analysis<ConcreteConfiguration>
",
    );
    let context_request = find_relation(&parsed, "mainanalysis·configuration·contextrequest");
    assert_eq!(
        context_request.data_type(),
        vec![DataType::String, DataType::String]
    );
    let running_thread = find_relation(&parsed, "mainanalysis·runningthread");
    assert_eq!(
        running_thread.data_type(),
        vec![DataType::String, DataType::String]
    );
}
/// A `.type` alias declared inside a component can be used as an attribute type of
/// a relation declared in the same component.
#[test]
fn comp_local_type_alias_resolves_as_attr_type() {
    let parsed = parse_program(
        "
.comp C {
.type MethodType = symbol
.decl R(mt:MethodType, i:number)
}
.init c = C
",
    );
    let relation = find_relation(&parsed, "c·r");
    assert_eq!(
        relation.data_type(),
        vec![DataType::String, DataType::Int32]
    );
}
/// An unqualified type name (`Context`) used in a base component resolves to the
/// `.type` defined by the inheriting component that gets instantiated.
#[test]
fn bare_member_type_from_concrete_subtype_resolves() {
    let parsed = parse_program(
        "
.type Invo = symbol
.comp AbstractConfiguration {
.decl ContextRequest(ctx:Context, invo:Invo)
}
.comp ConcreteConfiguration : AbstractConfiguration {
.type Context = symbol
}
.init c = ConcreteConfiguration
",
    );
    let context_request = find_relation(&parsed, "c·contextrequest");
    assert_eq!(
        context_request.data_type(),
        vec![DataType::String, DataType::String]
    );
}
/// A rule inside one component instance may reference a relation of a sibling
/// top-level instance (`basic.SubtypeOf`); the body atom resolves to `basic·subtypeof`.
#[test]
fn sibling_instance_relation_ref_resolves() {
    let parsed = parse_program(
        "
.comp Lib { .decl SubtypeOf(a:symbol, b:symbol) }
.init basic = Lib
.comp Analysis {
.decl R(x:symbol)
R(x) :- basic.SubtypeOf(x, _).
}
.init main = Analysis
",
    );
    let all_rules = parsed.rules();
    let main_rule = all_rules
        .into_iter()
        .find(|candidate| candidate.head().name() == "main·r")
        .expect("main·r rule");
    let body: Vec<&str> = main_rule.rhs().iter().map(|atom| atom.name()).collect();
    assert!(
        body.contains(&"basic·subtypeof"),
        "sibling ref should resolve to basic·subtypeof, got {body:?}"
    );
}
/// For `R(t) :- A(x), t = x + 1.` the equality is removed from the body and the
/// head argument becomes an arithmetic expression.
#[test]
fn equality_assignment_arith_substituted_into_head() {
    let src = "
.decl A(x:number)
.decl R(t:number)
.input A(IO=\"file\",filename=\"A.csv\")
R(t) :- A(x), t = x + 1.
.output R
";
    let parsed = parse_program(src);
    let rule = parsed
        .rules()
        .into_iter()
        .find(|candidate| candidate.head().name() == "r")
        .expect("r rule");

    let body = rule.rhs();
    assert_eq!(body.len(), 1, "equality literal should be dropped");
    assert!(matches!(body[0], Predicate::PositiveAtom(_)));

    let head_args = rule.head().head_arguments();
    assert!(
        matches!(head_args[0], HeadArg::Arith(_)),
        "head arg should carry the substituted arithmetic"
    );
}
/// For `R(y) :- A(x), y = x.` the alias is eliminated: one body literal remains and
/// the head argument is the variable `x`.
#[test]
fn equality_assignment_alias_substituted_into_head() {
    let src = "
.decl A(x:symbol)
.decl R(y:symbol)
.input A(IO=\"file\",filename=\"A.csv\")
R(y) :- A(x), y = x.
.output R
";
    let parsed = parse_program(src);
    let rule = parsed
        .rules()
        .into_iter()
        .find(|candidate| candidate.head().name() == "r")
        .expect("r rule");
    assert_eq!(rule.rhs().len(), 1);

    let first_head_arg = &rule.head().head_arguments()[0];
    match first_head_arg {
        HeadArg::Var(v) => assert_eq!(v, "x"),
        other => panic!("expected aliased Var(x), got {other:?}"),
    }
}
/// A rule whose body is only a constant assignment (`P(t) :- t = "boolean".`) is
/// stored as a single fact for `p`, and no rule with head `p` remains.
#[test]
fn equality_assignment_const_only_becomes_fact() {
    let src = "
.decl P(t:symbol)
P(t) :- t = \"boolean\".
.output P
";
    let parsed = parse_program(src);
    let facts = parsed.facts();
    assert!(
        facts.contains_key("p"),
        "const-only assignment rule should become a fact"
    );
    assert_eq!(facts["p"].len(), 1);
    assert!(
        parsed.rules().iter().all(|rule| rule.head().name() != "p"),
        "no derivation rule should remain for the fact relation"
    );
}
/// When both sides of `t = x` are already bound by body atoms, the equality stays
/// in the body as exactly one `Predicate::Compare`.
#[test]
fn equality_between_bound_columns_is_kept_as_filter() {
    let src = "
.decl A(x:number)
.decl B(t:number)
.decl R(x:number, t:number)
.input A(IO=\"file\",filename=\"A.csv\")
.input B(IO=\"file\",filename=\"B.csv\")
R(x, t) :- A(x), B(t), t = x.
.output R
";
    let parsed = parse_program(src);
    let rule = parsed
        .rules()
        .into_iter()
        .find(|candidate| candidate.head().name() == "r")
        .expect("r rule");
    let filter_count = rule
        .rhs()
        .iter()
        .filter(|pred| matches!(pred, Predicate::Compare(_)))
        .count();
    assert_eq!(filter_count, 1, "filter equality must be preserved");
}
/// A declared relation with no facts, rules or `.input` (`O`) that a live rule
/// reads gets an empty facts entry; the `.input` relation `I` gets none.
#[test]
fn orphan_relation_referenced_by_live_rule_is_materialized_empty() {
    let src = "
.decl O(x:symbol)
.decl I(x:symbol)
.input I(IO=\"file\",filename=\"I.csv\")
.decl R(x:symbol)
R(x) :- O(x), I(x).
.output R
";
    let parsed = parse_program(src);
    let facts = parsed.facts();
    assert!(
        facts.contains_key("o"),
        "orphan relation should be materialized"
    );
    assert!(
        facts["o"].is_empty(),
        "materialized orphan must be empty"
    );
    assert!(
        !facts.contains_key("i"),
        ".input relation must not be materialized as an orphan"
    );
}
/// Assigning a computed value (`t = x + 1`) to a variable that only occurs in a
/// negated atom is rejected with `ParseError::AssignmentVarInNegation`.
#[test]
fn equality_assignment_into_negation_with_arith_errors() {
    let src = "
.decl A(x:number)
.decl B(t:number)
.decl R(x:number)
.input A(IO=\"file\",filename=\"A.csv\")
.input B(IO=\"file\",filename=\"B.csv\")
R(x) :- A(x), !B(t), t = x + 1.
.output R
";
    let err = parse_program_result(src)
        .expect_err("computed value into negated atom should error");
    assert!(
        matches!(err, ParseError::AssignmentVarInNegation { .. }),
        "expected AssignmentVarInNegation, got {err:?}"
    );
}
/// Exercises rules whose body consists only of an assignment.
///
/// Constant arithmetic (`1 + 2`, `(1 + 2) * 3`) is folded into a fact for `p`.
/// A builtin call, an unbound head variable and a division by zero are each
/// rejected with `ParseError::GroundRuleNotConst`.
#[test]
fn assignment_only_rule_folds_or_rejects() {
    // Constant arithmetic folds to a fact and leaves no rule behind.
    let program = parse_program(
        "
.decl P(x:number)
P(x) :- x = 1 + 2.
.output P
",
    );
    assert!(program.facts().contains_key("p"));
    assert!(program.rules().iter().all(|r| r.head().name() != "p"));

    // A builtin call cannot be evaluated to a constant.
    let err = parse_program_result(
        "
.decl P(s:symbol)
P(s) :- s = cat(\"a\", \"b\").
.output P
",
    )
    .expect_err("builtin in ground head should be rejected");
    assert!(
        matches!(err, ParseError::GroundRuleNotConst { .. }),
        "expected GroundRuleNotConst, got {err:?}"
    );

    // The head variable `x` is never assigned.
    let err = parse_program_result(
        "
.decl P(x:number)
P(x) :- y = 1.
.output P
",
    )
    .expect_err("unbound head var in ground rule should be rejected");
    assert!(
        matches!(err, ParseError::GroundRuleNotConst { .. }),
        "expected GroundRuleNotConst, got {err:?}"
    );

    // Division by zero cannot be folded.
    let err = parse_program_result(
        "
.decl P(x:number)
P(x) :- x = 1 / 0.
.output P
",
    )
    .expect_err("division by zero should be rejected");
    // Carries the same diagnostic as the sibling checks so a failure shows the
    // error that was actually returned.
    assert!(
        matches!(err, ParseError::GroundRuleNotConst { .. }),
        "expected GroundRuleNotConst, got {err:?}"
    );

    // Grouped constant arithmetic folds as well.
    let program = parse_program(
        "
.decl P(x:number)
P(x) :- x = (1 + 2) * 3.
.output P
",
    );
    assert!(program.facts().contains_key("p"));
}
/// Assignments that depend on each other (`b = a + 2`, `a = x + 1`) are written out
/// of dependency order here. After parsing, only the filter `b < 10` may remain as
/// a `Predicate::Compare`, and every variable on either side of it must be `x`.
#[test]
fn chained_assignments_resolve_and_substitute_into_filters() {
let program = parse_program(
"
.decl A(x:number)
.decl R(a:number, b:number)
.input A(IO=\"file\",filename=\"A.csv\")
R(a, b) :- A(x), b = a + 2, a = x + 1, b < 10.
.output R
",
);
let rule = program
.rules()
.into_iter()
.find(|r| r.head().name() == "r")
.expect("r rule");
// Count the comparison literals and gather the variables from both of their sides.
let mut compare_vars = Vec::new();
let mut compares = 0;
for pred in rule.rhs() {
if let Predicate::Compare(e) = pred {
compares += 1;
compare_vars.extend(e.left().vars().into_iter().cloned());
compare_vars.extend(e.right().vars().into_iter().cloned());
}
}
assert_eq!(compares, 1, "only the filter comparison remains");
assert!(
compare_vars.iter().all(|v| v == "x"),
"assignment vars must be fully substituted away: {rule}"
);
}
/// Two substitution checks.
///
/// 1. With `y = a - b` and `z = x * y`, the single head argument must be an
///    arithmetic expression whose first `rest()` factor is a `Factor::Group` with a
///    non-empty `rest()` of its own.
/// 2. With `t = x + 1` and head `sum(t)`, the aggregation's arithmetic must no
///    longer mention the variable `t`.
#[test]
fn multi_term_substitution_wraps_in_group() {
use crate::parser::Factor;
let program = parse_program(
"
.decl A(x:number, a:number, b:number)
.decl R(z:number)
.input A(IO=\"file\",filename=\"A.csv\")
R(z) :- A(x, a, b), y = a - b, z = x * y.
.output R
",
);
let rule = program
.rules()
.into_iter()
.find(|r| r.head().name() == "r")
.expect("r rule");
// The head must have exactly one argument and it must be arithmetic.
let [HeadArg::Arith(arith)] = rule.head().head_arguments() else {
panic!("expected one arithmetic head arg");
};
// First (operator, factor) pair after the leading factor.
let (_, factor) = &arith.rest()[0];
assert!(
matches!(factor, Factor::Group(inner) if !inner.rest().is_empty()),
"substituted multi-term value must be group-wrapped: {arith}"
);
// Second scenario: the assignment variable is used inside an aggregation.
let program = parse_program(
"
.decl A(g:number, x:number)
.decl S(g:number, s:number)
.input A(IO=\"file\",filename=\"A.csv\")
S(g, sum(t)) :- A(g, x), t = x + 1.
.output S
",
);
let rule = program
.rules()
.into_iter()
.find(|r| r.head().name() == "s")
.expect("s rule");
let agg = rule
.head()
.head_arguments()
.iter()
.find_map(|a| match a {
HeadArg::Aggregation(agg) => Some(agg),
_ => None,
})
.expect("aggregation head arg");
assert!(
!agg.arithmetic().vars().iter().any(|v| *v == "t"),
"assignment var must be substituted inside the aggregation"
);
}
/// Rules inside a `fixpoint` block that use `t = x + 1` must not keep any
/// equality comparison in their body after parsing.
#[test]
fn assignment_inside_fixpoint_desugared() {
let program = parse_program(
"
.decl A(x:number)
.decl R(t:number)
.input A(IO=\"file\",filename=\"A.csv\")
fixpoint {
R(t) :- A(x), t = x + 1.
R(t) :- R(x), t = x + 1, x < 5.
}
.output R
",
);
// NOTE(review): this loop asserts nothing if `program.rules()` is empty. Other
// tests read fixpoint rules through `loop_blocks(..)[0].rules()`, so confirm that
// `Program::rules()` also returns rules defined inside blocks.
for rule in program.rules() {
// Only `=` comparisons count; the `x < 5` filter is expected to stay.
let assignments = rule
.rhs()
.iter()
.filter(|p| matches!(p, Predicate::Compare(e) if *e.operator() == ComparisonOperator::Equal))
.count();
assert_eq!(assignments, 0, "assignments inside blocks must desugar");
}
}
/// Parentheses around a single factor must not change how assignments are handled.
///
/// Three cases: a grouped constant on the right (`t = ("boolean")`) still becomes a
/// fact; a grouped variable on the left (`(t) = x`) is still eliminated from the
/// body; and a grouped alias feeding a negated atom (`t = (x)`) is only parsed,
/// with no further assertion.
#[test]
fn single_factor_groups_are_transparent_to_desugar() {
let program = parse_program(
"
.decl P(t:symbol)
P(t) :- t = (\"boolean\").
.output P
",
);
assert!(
program.facts().contains_key("p"),
"grouped const-only assignment rule should become a fact"
);
let program = parse_program(
"
.decl A(x:number)
.decl R(t:number)
.input A(IO=\"file\",filename=\"A.csv\")
R(t) :- A(x), (t) = x.
.output R
",
);
let rule = program
.rules()
.into_iter()
.find(|r| r.head().name() == "r")
.expect("r rule");
assert!(
!rule
.rhs()
.iter()
.any(|p| matches!(p, Predicate::Compare(_))),
"grouped assignment must be eliminated, not left as a filter"
);
// Result intentionally discarded. Presumably `parse_program` panics on a parse
// error, which would make this a "must parse" check — confirm against the helper.
parse_program(
"
.decl A(x:number)
.decl B(t:number)
.decl R(x:number)
.input A(IO=\"file\",filename=\"A.csv\")
.input B(IO=\"file\",filename=\"B.csv\")
R(x) :- A(x), !B(t), t = (x).
.output R
",
);
}
/// An `.output` directive written inside a component body but naming a global
/// relation (`G`) must mark that global relation as output.
#[test]
fn comp_directive_targets_global_relation() {
    let parsed = parse_program(
        "
.decl G(x:symbol)
G(\"a\").
.comp C {
.decl L(x:symbol)
L(x) :- G(x).
.output G(IO=\"file\",filename=\"G.csv\",delimiter=\"\\t\")
}
.init c = C
",
    );
    let global = find_relation(&parsed, "g");
    assert!(
        global.output(),
        ".output of a global relation from inside a comp should apply"
    );
}
/// With `.override Foo` in the sub-component, the parent's facts `Foo(1)` and
/// `Foo(2)` are dropped and only the sub-component's `Foo(10)` remains.
#[test]
fn override_drops_parent_facts() {
    let parsed = parse_program(
        "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
Foo(2).
}
.comp Sub : Base {
.override Foo
Foo(10).
}
.init s = Sub
.output s.Foo
",
    );
    let surviving = fact_numbers(&parsed, "s·foo");
    assert_eq!(surviving, vec![10]);
}
/// With `.override Foo`, the parent's rule for `Foo` is dropped: exactly one
/// `s·foo` rule survives, and it is the sub-component's rule with the `x > 5` filter.
#[test]
fn override_drops_parent_derived_rule() {
    let parsed = parse_program(
        "
.comp Base {
.decl Foo(x: number) overridable
.decl Seed(x: number)
Foo(x) :- Seed(x).
}
.comp Sub : Base {
.override Foo
Foo(x) :- Seed(x), x > 5.
}
.init s = Sub
.input s.Seed(IO=\"file\", filename=\"Seed.csv\", delimiter=\",\")
.output s.Foo
",
    );
    let foo_rules: Vec<_> = parsed
        .rules()
        .into_iter()
        .filter(|candidate| candidate.head().name() == "s·foo")
        .collect();
    assert_eq!(foo_rules.len(), 1, "exactly one s·foo rule survives");
    assert!(
        foo_rules[0]
            .rhs()
            .iter()
            .any(|pred| matches!(pred, Predicate::Compare(_))),
        "override's filtered rule should survive"
    );
}
/// Marking a relation `overridable` without any `.override` in the sub-component
/// leaves the parent's single fact in place.
#[test]
fn overridable_without_override_keeps_parent_facts() {
    let parsed = parse_program(
        "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base { }
.init s = Sub
.output s.Foo
",
    );
    let inherited = parsed.facts().get("s·foo").expect("s·foo facts");
    assert_eq!(inherited.len(), 1);
}
/// `.override Foo` on a relation not declared `overridable` is rejected with
/// `ParseError::OverrideOfNonOverridable`.
#[test]
fn override_of_non_overridable_errors() {
    let err = parse_program_result(
        "
.comp Base { .decl Foo(x: number) Foo(1). }
.comp Sub : Base { .override Foo Foo(10). }
.init s = Sub
.output s.Foo
",
    )
    .unwrap_err();
    assert!(
        matches!(err, ParseError::OverrideOfNonOverridable { .. }),
        "got {err:?}"
    );
}
/// `.override Bar`, where no relation `Bar` exists in the parent, is rejected with
/// `ParseError::OverrideUnknownRelation`.
#[test]
fn override_unknown_relation_errors() {
    let err = parse_program_result(
        "
.comp Base { .decl Foo(x: number) overridable Foo(1). }
.comp Sub : Base { .override Bar Foo(10). }
.init s = Sub
.output s.Foo
",
    )
    .unwrap_err();
    assert!(
        matches!(err, ParseError::OverrideUnknownRelation { .. }),
        "got {err:?}"
    );
}
/// An override made in the middle of an inheritance chain (`Mid1`) is what a
/// further descendant (`Bot`) ends up with: `b·foo` holds only `2`.
#[test]
fn override_propagates_through_inheritance_chain() {
    let parsed = parse_program(
        "
.comp Top { .decl Foo(x: number) overridable Foo(1). }
.comp Mid1 : Top { .override Foo Foo(2). }
.comp Bot : Mid1 { }
.init b = Bot
.output b.Foo
",
    );
    let surviving = fact_numbers(&parsed, "b·foo");
    assert_eq!(surviving, vec![2]);
}
/// Overriding works through parametric components: with `Sub<number>`, the
/// parent's `Foo(0)` is dropped and only `Foo(42)` remains.
#[test]
fn override_parametric_type_substitution() {
    let parsed = parse_program(
        "
.comp Base<T> {
.decl Foo(x: T) overridable
Foo(0).
}
.comp Sub<T> : Base<T> {
.override Foo
Foo(42).
}
.init s = Sub<number>
.output s.Foo
",
    );
    let surviving = fact_numbers(&parsed, "s·foo");
    assert_eq!(surviving, vec![42]);
}
/// An `.override Foo` with no replacement facts or rules removes everything the
/// parent contributed: `s·foo` has no facts entry and no rule with that head.
#[test]
fn override_to_empty_drops_parent_derivations() {
    let src = "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base {
.override Foo
}
.init s = Sub
.output s.Foo
";
    let program = parse_program(src);
    // `contains_key` states the intent directly and matches the other fact checks
    // in this module (replaces `.get(..).is_none()`).
    assert!(!program.facts().contains_key("s·foo"));
    assert!(program.rules().iter().all(|r| r.head().name() != "s·foo"));
}
/// `.override` at the top level (outside any component) is a `ParseError::Syntax`.
#[test]
fn override_outside_comp_is_syntax_error() {
    let err = parse_program_result(
        "
.decl Foo(x: number)
.override Foo
.output Foo
",
    )
    .unwrap_err();
    assert!(matches!(err, ParseError::Syntax { .. }), "got {err:?}");
}
/// A sub-component that both overrides `Foo` and re-declares it is rejected with
/// `ParseError::OverrideRedeclaresRelation`.
#[test]
fn override_redeclaration_errors() {
    let err = parse_program_result(
        "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base {
.override Foo
.decl Foo(x: number)
Foo(10).
}
.init s = Sub
.output s.Foo
",
    )
    .unwrap_err();
    assert!(
        matches!(err, ParseError::OverrideRedeclaresRelation { .. }),
        "got {err:?}"
    );
}
/// Repeating `.override Foo` twice in the same component is accepted and still
/// leaves exactly one fact for `s·foo`.
#[test]
fn double_override_is_accepted() {
    let parsed = parse_program(
        "
.comp Base {
.decl Foo(x: number) overridable
Foo(1).
}
.comp Sub : Base {
.override Foo
.override Foo
Foo(10).
}
.init s = Sub
.output s.Foo
",
    );
    let surviving = parsed.facts().get("s·foo").expect("s·foo facts");
    assert_eq!(surviving.len(), 1);
}
/// The `overridable` qualifier on a top-level `.decl` is rejected with
/// `ParseError::OverridableOutsideComp`.
#[test]
fn overridable_outside_comp_errors() {
    let outcome = parse_program_result(".decl Foo(x: number) overridable\n.output Foo\n");
    let err = outcome.unwrap_err();
    assert!(
        matches!(err, ParseError::OverridableOutsideComp { .. }),
        "got {err:?}"
    );
}
/// `.plan (3, 1, 2)` after a rule with body `a, b, c` reorders the body to `c, a, b`.
#[test]
fn plan_reorders_positive_atoms() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl h(x: number)
.output h
h(X) :- a(X), b(X), c(X).
.plan (3, 1, 2)
",
    );
    let all_rules = parsed.rules();
    let planned = all_rules[0];
    let body_order: Vec<&str> = planned.rhs().iter().map(|atom| atom.name()).collect();
    assert_eq!(body_order, vec!["c", "a", "b"]);
}
/// The versioned plan form `.plan 1:(3, 1, 2)` applies the same permutation as the
/// bare form: body `a, b, c` becomes `c, a, b`.
#[test]
fn plan_souffle_form_applies_permutation() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl h(x: number)
.output h
h(X) :- a(X), b(X), c(X).
.plan 1:(3, 1, 2)
",
    );
    let all_rules = parsed.rules();
    let planned = all_rules[0];
    let body_order: Vec<&str> = planned.rhs().iter().map(|atom| atom.name()).collect();
    assert_eq!(body_order, vec!["c", "a", "b"]);
}
/// Plan indices count only positive atoms. With body `a, !d, b, c` and
/// `.plan (3, 1, 2)` the resulting body order is `c, !d, a, b`.
#[test]
fn plan_skips_negations_and_pins_only_positives() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl d(x: number)
.decl h(x: number)
.output h
h(X) :- a(X), !d(X), b(X), c(X).
.plan (3, 1, 2)
",
    );
    let all_rules = parsed.rules();
    let planned = all_rules[0];
    // Prefix negated atoms with `!` so the expected order below is readable.
    let labelled: Vec<String> = planned
        .rhs()
        .iter()
        .map(|pred| match pred {
            Predicate::NegativeAtom(atom) => format!("!{}", atom.name()),
            Predicate::PositiveAtom(atom) => atom.name().to_string(),
            other => format!("{other}"),
        })
        .collect();
    assert_eq!(labelled, vec!["c", "!d", "a", "b"]);
}
/// A `.plan` directive with no rule before it is rejected with `ParseError::PlanOrphan`.
#[test]
fn plan_without_preceding_rule_errors() {
    let err = parse_program_result(
        "
.decl a(x: number)
.output a
.plan (1)
",
    )
    .unwrap_err();
    assert!(matches!(err, ParseError::PlanOrphan { .. }), "got {err:?}");
}
/// A plan listing three indices for a rule with two positive atoms is rejected with
/// `PlanArityMismatch { expected: 2, found: 3 }`.
#[test]
fn plan_arity_mismatch_errors() {
    let err = parse_program_result(
        "
.decl a(x: number)
.decl b(x: number)
.decl h(x: number)
.output h
h(X) :- a(X), b(X).
.plan (1, 2, 3)
",
    )
    .unwrap_err();
    let is_expected = matches!(
        err,
        ParseError::PlanArityMismatch {
            expected: 2,
            found: 3,
            ..
        }
    );
    assert!(is_expected, "got {err:?}");
}
/// A plan index larger than the number of positive atoms is rejected with
/// `PlanIndexOutOfRange { index: 3, max: 2 }`.
#[test]
fn plan_index_out_of_range_errors() {
    let err = parse_program_result(
        "
.decl a(x: number)
.decl b(x: number)
.decl h(x: number)
.output h
h(X) :- a(X), b(X).
.plan (1, 3)
",
    )
    .unwrap_err();
    let is_expected = matches!(
        err,
        ParseError::PlanIndexOutOfRange {
            index: 3,
            max: 2,
            ..
        }
    );
    assert!(is_expected, "got {err:?}");
}
/// A plan that repeats an index (`(1, 1)`) is rejected with
/// `PlanDuplicateIndex { index: 1 }`.
#[test]
fn plan_duplicate_index_errors() {
    let err = parse_program_result(
        "
.decl a(x: number)
.decl b(x: number)
.decl h(x: number)
.output h
h(X) :- a(X), b(X).
.plan (1, 1)
",
    )
    .unwrap_err();
    assert!(
        matches!(err, ParseError::PlanDuplicateIndex { index: 1, .. }),
        "got {err:?}"
    );
}
/// A `.plan` directive following a rule inside a `fixpoint` block reorders that
/// rule's body: `a, b, c` becomes `c, a, b`.
#[test]
fn plan_inside_fixpoint_block() {
    let parsed = parse_program(
        "
.decl a(x: number)
.decl b(x: number)
.decl c(x: number)
.decl h(x: number)
.output h
fixpoint {
h(X) :- a(X), b(X), c(X).
.plan (3, 1, 2)
}
",
    );
    let blocks = loop_blocks(&parsed);
    let first_block = blocks[0];
    let body_order: Vec<&str> = first_block.rules()[0]
        .rhs()
        .iter()
        .map(|atom| atom.name())
        .collect();
    assert_eq!(body_order, vec!["c", "a", "b"]);
}
/// A `.plan` inside a component body applies to the instantiated rule: the `c·h`
/// rule's body is ordered `c·d, c·a, c·b` and the rule reports `plan_pinned()`.
#[test]
fn plan_inside_comp_body() {
    let parsed = parse_program(
        "
.comp C {
.decl A(x: number)
.decl B(x: number)
.decl D(x: number)
.decl H(x: number)
H(X) :- A(X), B(X), D(X).
.plan (3, 1, 2)
}
.init c = C
.output c.H
",
    );
    let all_rules = parsed.rules();
    let instantiated = all_rules
        .into_iter()
        .find(|candidate| candidate.head().name() == "c·h")
        .expect("instantiated H rule");
    let body_order: Vec<&str> = instantiated
        .rhs()
        .iter()
        .map(|atom| atom.name())
        .collect();
    assert_eq!(body_order, vec!["c·d", "c·a", "c·b"]);
    assert!(
        instantiated.plan_pinned(),
        "plan_pinned should survive inlining"
    );
}
/// A `.plan` with no preceding rule inside a component body is rejected with
/// `ParseError::PlanOrphan`.
#[test]
fn plan_orphan_inside_comp_errors() {
    let err = parse_program_result(
        "
.comp C {
.decl A(x: number)
.plan (1)
}
.init c = C
.output c.A
",
    )
    .unwrap_err();
    assert!(matches!(err, ParseError::PlanOrphan { .. }), "got {err:?}");
}
}