use std::path::{Path, PathBuf};
use xlog_core::Result;
use xlog_ir::ExecutionPlan;
use xlog_stats::{StatsManager, StatsSnapshot};
use crate::lower::Lowerer;
use crate::module::ModuleError;
use crate::optimizer::Optimizer;
use crate::parser::parse_program;
use crate::resolver::ModuleResolver;
use crate::stratify::stratify;
use crate::{BodyLiteral, Program, Query, Rule as AstRule, Term};
/// Drives compilation of a program (source text or parsed `Program`) into an
/// `ExecutionPlan`.
///
/// All state lives in the wrapped `Lowerer`; the relation ids and schemas
/// exposed by `rel_ids()` / `schemas()` are read straight from it, and
/// `reset()` replaces it with a fresh instance.
pub struct Compiler {
/// Lowering state that this compiler configures and then asks to lower each program.
lowerer: Lowerer,
}
use std::collections::HashMap;
use std::sync::Arc;
use xlog_core::{RelId, Schema};
impl Default for Compiler {
fn default() -> Self {
Self::new()
}
}
impl Compiler {
pub fn new() -> Self {
Self {
lowerer: Lowerer::new(),
}
}
pub fn set_max_active_rules(&mut self, max: usize) {
self.lowerer.set_max_active_rules(max);
}
pub fn compile(&mut self, source: &str) -> Result<ExecutionPlan> {
self.compile_with_stats_snapshot(source, None)
}
pub fn compile_with_stats_snapshot(
&mut self,
source: &str,
stats_snapshot: Option<&StatsSnapshot>,
) -> Result<ExecutionPlan> {
let program = parse_program(source)?;
self.compile_program_with_stats_snapshot(&program, stats_snapshot)
}
pub fn compile_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
self.compile_program_with_stats_snapshot(program, None)
}
pub fn compile_program_with_stats_snapshot(
&mut self,
program: &Program,
stats_snapshot: Option<&StatsSnapshot>,
) -> Result<ExecutionPlan> {
let program = desugar_queries_and_constraints(program);
let strata = stratify(&program)?;
let strata_preds: Vec<Vec<String>> = strata.into_iter().map(|s| s.predicates).collect();
self.lowerer.set_strata(strata_preds);
let mut cardinality_hints: HashMap<String, u64> = HashMap::new();
if let Some(snapshot) = stats_snapshot {
if !snapshot.rel_names.is_empty() {
let rel_name_by_id: HashMap<RelId, &str> = snapshot
.rel_names
.iter()
.map(|(id, name)| (*id, name.as_str()))
.collect();
for rel in &snapshot.relations {
if let Some(name) = rel_name_by_id.get(&rel.rel_id) {
cardinality_hints.insert((*name).to_string(), rel.cardinality);
}
}
}
}
self.lowerer.set_cardinality_hints(cardinality_hints);
let mut plan = self.lowerer.lower_program(&program)?;
let mut mgr = StatsManager::new();
let mut fact_counts: HashMap<String, u64> = HashMap::new();
for fact in program.facts() {
*fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
}
for (pred, rel_id) in self.lowerer.rel_ids() {
mgr.register_relation(*rel_id);
let rows = fact_counts.get(pred).copied().unwrap_or(0);
if rows > 0 {
mgr.update_cardinality(*rel_id, rows);
if let Some(schema) = self.lowerer.schemas().get(pred) {
mgr.update_byte_size(*rel_id, rows * schema.row_size_bytes() as u64);
}
}
}
if let Some(snapshot) = stats_snapshot {
if snapshot.rel_names.is_empty() {
mgr.merge_snapshot(snapshot);
} else {
let rel_name_by_id: HashMap<RelId, &str> = snapshot
.rel_names
.iter()
.map(|(id, name)| (*id, name.as_str()))
.collect();
for rel in &snapshot.relations {
let Some(pred) = rel_name_by_id.get(&rel.rel_id) else {
continue;
};
let Some(rel_id) = self.lowerer.rel_ids().get(*pred) else {
continue;
};
let mut remapped = rel.clone();
remapped.rel_id = *rel_id;
if let Some(schema) = self.lowerer.schemas().get(*pred) {
remapped.column_stats.retain(|col| {
col.col_idx < schema.arity()
&& schema.column_type(col.col_idx) == Some(col.dtype)
});
} else {
remapped.column_stats.clear();
}
mgr.register_relation(*rel_id);
if let Some(stats) = mgr.get_relation_stats_mut(*rel_id) {
*stats = remapped;
}
}
for js in &snapshot.join_selectivities {
if js.left_keys.len() != js.right_keys.len() {
continue;
}
let Some(left_pred) = rel_name_by_id.get(&js.left_rel) else {
continue;
};
let Some(right_pred) = rel_name_by_id.get(&js.right_rel) else {
continue;
};
let Some(&left_id) = self.lowerer.rel_ids().get(*left_pred) else {
continue;
};
let Some(&right_id) = self.lowerer.rel_ids().get(*right_pred) else {
continue;
};
let Some(left_schema) = self.lowerer.schemas().get(*left_pred) else {
continue;
};
let Some(right_schema) = self.lowerer.schemas().get(*right_pred) else {
continue;
};
if js.left_keys.iter().any(|&k| k >= left_schema.arity())
|| js.right_keys.iter().any(|&k| k >= right_schema.arity())
{
continue;
}
mgr.set_join_selectivity(
left_id,
right_id,
js.left_keys.clone(),
js.right_keys.clone(),
js.selectivity,
);
}
}
}
let schemas_by_rel_id: HashMap<RelId, Schema> = self
.lowerer
.rel_ids()
.iter()
.filter_map(|(pred, rel_id)| {
self.lowerer
.schemas()
.get(pred)
.map(|schema| (*rel_id, schema.clone()))
})
.collect();
let mut optimizer = Optimizer::new(Arc::new(mgr));
optimizer.set_schemas(schemas_by_rel_id);
for rules in &mut plan.rules_by_scc {
for rule in rules {
rule.body = optimizer.optimize(rule.body.clone());
}
}
Ok(plan)
}
pub fn reset(&mut self) {
self.lowerer = Lowerer::new();
}
pub fn rel_ids(&self) -> &HashMap<String, RelId> {
self.lowerer.rel_ids()
}
pub fn schemas(&self) -> &HashMap<String, Schema> {
self.lowerer.schemas()
}
}
/// Returns a copy of `program` with one extra rule appended per constraint
/// and per query, so later passes only have to deal with plain rules.
///
/// * Constraint `i` becomes `__xlog_constraint_{i}(1) :- <constraint body>`.
/// * Query `i` becomes `__xlog_query_{i}(<vars>) :- <query atom>`, where
///   `<vars>` are the distinct variables that appear directly as terms of the
///   query atom, in first-occurrence order. A query with no such variables
///   gets the single head term `1`.
fn desugar_queries_and_constraints(program: &Program) -> Program {
    let mut desugared = program.clone();

    for (idx, constraint) in program.constraints.iter().enumerate() {
        let head = crate::ast::Atom {
            predicate: format!("__xlog_constraint_{}", idx),
            terms: vec![Term::Integer(1)],
        };
        desugared.rules.push(AstRule {
            head,
            body: constraint.body.clone(),
        });
    }

    for (idx, query) in program.queries.iter().enumerate() {
        let Query { atom } = query;

        // Keep each variable once, preserving the order of first appearance.
        let mut seen_vars: std::collections::HashSet<&str> = std::collections::HashSet::new();
        let mut head_terms: Vec<Term> = atom
            .terms
            .iter()
            .filter_map(|term| match term {
                Term::Variable(name) if seen_vars.insert(name.as_str()) => {
                    Some(Term::Variable(name.clone()))
                }
                _ => None,
            })
            .collect();
        if head_terms.is_empty() {
            head_terms.push(Term::Integer(1));
        }

        let head = crate::ast::Atom {
            predicate: format!("__xlog_query_{}", idx),
            terms: head_terms,
        };
        desugared.rules.push(AstRule {
            head,
            body: vec![BodyLiteral::Positive(atom.clone())],
        });
    }

    desugared
}
/// Convenience wrapper: compiles `source` with a throwaway [`Compiler`].
///
/// # Errors
///
/// Returns whatever error [`Compiler::compile`] reports.
pub fn compile(source: &str) -> Result<ExecutionPlan> {
    Compiler::new().compile(source)
}
/// Creates a [`ModuleResolver`] over `search_paths` and loads the module named
/// after `entry_file`.
///
/// The module name is the file stem of `entry_file`; if the path has no stem,
/// or the stem is not valid UTF-8, the name `"main"` is used. The base
/// directory is the parent of `entry_file`, or `"."` when there is none.
///
/// # Errors
///
/// Returns the [`ModuleError`] reported by `ModuleResolver::load_module`.
pub fn load_modules(
    entry_file: &Path,
    search_paths: Vec<PathBuf>,
) -> std::result::Result<ModuleResolver, ModuleError> {
    // `Path::parent` yields `Some("")` for a bare relative file name such as
    // `main.xlog`, so an empty parent must also fall back to the current
    // directory; `unwrap_or` alone would never trigger for that case.
    let base_dir = entry_file
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty())
        .unwrap_or_else(|| Path::new("."));
    let module_name = entry_file
        .file_stem()
        .and_then(|stem| stem.to_str())
        .unwrap_or("main");

    let mut resolver = ModuleResolver::new(search_paths);
    resolver.load_module(base_dir, &[module_name.to_string()])?;
    Ok(resolver)
}
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::ScalarType;
    use xlog_ir::RirNode;
    use xlog_stats::RelationStats;
    use xlog_stats::StatsManager;

    /// Constructing and dropping a compiler must not panic.
    #[test]
    fn test_compiler_new() {
        let compiler = Compiler::new();
        drop(compiler);
    }

    /// A single ground fact compiles.
    #[test]
    fn test_compile_fact() {
        let mut compiler = Compiler::new();
        let result = compiler.compile("edge(1, 2).");
        assert!(result.is_ok(), "Failed to compile fact: {:?}", result.err());
    }

    /// A non-recursive rule compiles and yields at least one SCC.
    #[test]
    fn test_compile_simple_rule() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
edge(1, 2).
reach(X, Y) :- edge(X, Y).
"#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile simple rule: {:?}",
            result.err()
        );
        let plan = result.unwrap();
        assert!(!plan.sccs.is_empty(), "Expected at least one SCC");
    }

    /// A recursive (transitive closure) program compiles.
    #[test]
    fn test_compile_transitive_closure() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
edge(1, 2).
edge(2, 3).
edge(3, 4).
reach(X, Y) :- edge(X, Y).
reach(X, Z) :- reach(X, Y), edge(Y, Z).
"#,
        );
        assert!(result.is_ok(), "Failed to compile TC: {:?}", result.err());
        let plan = result.unwrap();
        assert!(!plan.sccs.is_empty());
    }

    /// A rule with a negated body literal compiles.
    #[test]
    fn test_compile_with_negation() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
node(1).
node(2).
node(3).
edge(1, 2).
isolated(X) :- node(X), not edge(X, Y).
"#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with negation: {:?}",
            result.err()
        );
    }

    /// A rule with a comparison in its body compiles.
    #[test]
    fn test_compile_with_comparison() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
value(1).
value(5).
value(10).
value(15).
small(X) :- value(X), X < 10.
"#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with comparison: {:?}",
            result.err()
        );
    }

    /// The schema of a derived predicate takes its column types from the
    /// body predicate it projects from.
    #[test]
    fn test_schema_infers_from_rule_body_types() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
edge(1, 2).
edge(2, 3).
reach(X, Y) :- edge(X, Y).
"#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile rule for schema inference: {:?}",
            result.err()
        );
        let schema = compiler
            .schemas()
            .get("reach")
            .expect("missing reach schema");
        assert_eq!(
            schema.column_type(0),
            Some(ScalarType::U32),
            "reach column 0 should match edge column type"
        );
        assert_eq!(
            schema.column_type(1),
            Some(ScalarType::U32),
            "reach column 1 should match edge column type"
        );
    }

    /// Mutual negation between `p` and `q` must be rejected.
    #[test]
    fn test_compile_unstratifiable_fails() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
p :- not q.
q :- not p.
"#,
        );
        assert!(result.is_err(), "Should fail with stratification cycle");
    }

    /// Unterminated input must surface as an error, not a panic.
    #[test]
    fn test_compile_syntax_error_fails() {
        let mut compiler = Compiler::new();
        let result = compiler.compile("edge(1, 2");
        assert!(result.is_err(), "Should fail with syntax error");
    }

    /// The free `compile` function works without an explicit `Compiler`.
    #[test]
    fn test_compile_convenience_function() {
        let result = compile("edge(1, 2).");
        assert!(
            result.is_ok(),
            "Convenience compile failed: {:?}",
            result.err()
        );
    }

    /// A compiler can compile a different program after `reset()`.
    #[test]
    fn test_compiler_reset() {
        let mut compiler = Compiler::new();
        let result1 = compiler.compile("edge(1, 2).");
        assert!(result1.is_ok());
        compiler.reset();
        let result2 = compiler.compile("node(1). node(2).");
        assert!(result2.is_ok());
    }

    /// An explicit `pred` declaration is accepted alongside facts and rules.
    #[test]
    fn test_compile_with_pred_decl() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
pred edge(u32, u32).
edge(1, 2).
edge(2, 3).
reach(X, Y) :- edge(X, Y).
"#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with pred decl: {:?}",
            result.err()
        );
    }

    /// A program mixing recursion and negation compiles and reports strata.
    #[test]
    fn test_compile_multi_stratum() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
// Base facts
edge(1, 2).
edge(2, 3).
edge(3, 1).
// Stratum 0: edge (base)
// Stratum 1: reach (depends on edge, recursive)
reach(X, Y) :- edge(X, Y).
reach(X, Z) :- reach(X, Y), edge(Y, Z).
// Stratum 2: non_reach (negates reach)
all_pairs(X, Y) :- edge(X, Z), edge(Y, W).
non_reach(X, Y) :- all_pairs(X, Y), not reach(X, Y).
"#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile multi-stratum: {:?}",
            result.err()
        );
        let plan = result.unwrap();
        // Only non-emptiness is checked here, so the message says exactly that.
        assert!(!plan.strata.is_empty(), "Expected at least one stratum");
    }

    /// An aggregate in the rule head lowers to `Project(GroupBy(..))`.
    #[test]
    fn test_compile_aggregation() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
edge(1, 2).
edge(1, 3).
edge(2, 3).
out_degree(X, count(Y)) :- edge(X, Y).
"#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with aggregation: {:?}",
            result.err()
        );
        let plan = result.unwrap();
        let out_degree_rules: Vec<_> = plan
            .rules_by_scc
            .iter()
            .flatten()
            .filter(|r| r.head == "out_degree")
            .collect();
        assert_eq!(out_degree_rules.len(), 1, "Expected one out_degree rule");
        let body = &out_degree_rules[0].body;
        match body {
            RirNode::Project { input, .. } => {
                assert!(
                    matches!(input.as_ref(), RirNode::GroupBy { .. }),
                    "Expected Project(GroupBy(..)), got {:?}",
                    input
                );
            }
            other => panic!("Expected Project(GroupBy(..)), got {:?}", other),
        }
    }

    /// Recompiling with a snapshot taken from a `StatsManager` succeeds.
    #[test]
    fn test_compile_with_stats_snapshot() {
        let mut compiler = Compiler::new();
        let source = r#"
edge(1, 2).
edge(2, 3).
reach(X, Y) :- edge(X, Y).
"#;
        // First compile only to learn which relation id `edge` was assigned.
        let _ = compiler.compile(source).expect("Initial compile failed");
        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");
        let mut mgr = StatsManager::new();
        mgr.register_relation(edge_id);
        mgr.update_cardinality(edge_id, 42);
        let snapshot = mgr.snapshot();
        let plan = compiler
            .compile_with_stats_snapshot(source, Some(&snapshot))
            .expect("Compile with snapshot failed");
        assert!(!plan.sccs.is_empty());
    }

    /// With a named snapshot giving `foo` a far larger cardinality than
    /// `edge`, the join in `out` must end up as `Join(Scan(foo), Scan(edge))`
    /// even though the rule body lists `edge` first.
    #[test]
    fn test_compile_with_named_stats_snapshot_reorders_joins() {
        let mut compiler = Compiler::new();
        let source = r#"
foo(1).
edge(1).
out(X) :- edge(X), foo(X).
"#;
        let mut edge_stats = RelationStats::new(RelId(0));
        edge_stats.update_cardinality(10);
        let mut foo_stats = RelationStats::new(RelId(1));
        foo_stats.update_cardinality(10_000);
        let snapshot = StatsSnapshot {
            relations: vec![edge_stats, foo_stats],
            join_selectivities: Vec::new(),
            rel_names: vec![
                (RelId(0), "edge".to_string()),
                (RelId(1), "foo".to_string()),
            ],
        };
        let plan = compiler
            .compile_with_stats_snapshot(source, Some(&snapshot))
            .expect("Compile with named snapshot failed");
        let foo_id = *compiler.rel_ids().get("foo").expect("foo rel_id missing");
        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");
        let out_rule = plan
            .rules_by_scc
            .iter()
            .flatten()
            .find(|r| r.head == "out")
            .expect("out rule missing");
        // Peel off any projections wrapping the join.
        let mut node = &out_rule.body;
        while let RirNode::Project { input, .. } = node {
            node = input;
        }
        match node {
            RirNode::Join { left, right, .. } => {
                assert!(matches!(**left, RirNode::Scan { rel } if rel == foo_id));
                assert!(matches!(**right, RirNode::Scan { rel } if rel == edge_id));
            }
            other => panic!("Expected Join node, got {:?}", other),
        }
    }
}