Skip to main content

xlog_logic/
lower.rs

1//! Lowering from AST to IR
2//!
3//! This module transforms Datalog programs (AST) into the Relational IR (RIR)
4//! representation for execution. The lowering process:
5//!
6//! 1. Infers schemas from facts and predicate declarations
7//! 2. Tracks variable positions across atoms for join key computation
8//! 3. Builds left-deep join trees for multi-atom rule bodies
9//! 4. Handles negation via set difference (Diff) nodes
10//! 5. Wraps recursive predicates in Fixpoint nodes
11//! 6. Projects to match head variables
12
13use std::collections::{HashMap, HashSet};
14
15use xlog_core::{symbol, AggOp as CoreAggOp, RelId, Result, ScalarType, Schema, XlogError};
16use xlog_ir::{
17    CompareOp, CompiledRule, ConstValue, ExecutionPlan, Expr, JoinType, PlanBuilder, ProjectExpr,
18    RirMeta, RirNode, Scc, Stratum as IrStratum,
19};
20
21use crate::ast::{
22    AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Comparison, IsExpr, LearnableRule, PredColumn,
23    Program, Rule, Term, TypeRef,
24};
25use crate::stratify::{build_dependency_graph, find_sccs_for_lowering, DepType};
26
27struct JoinPlan<'a> {
28    node: RirNode,
29    leaf_order: Vec<&'a Atom>,
30    leaf_order_idx: Vec<usize>,
31    var_pos: HashMap<String, usize>,
32    width: usize,
33    est_rows: f64,
34    total_cost: f64,
35}
36
37fn pred_columns_for_decl(pred_decl: &crate::ast::PredDecl) -> Vec<PredColumn> {
38    if pred_decl.columns.is_empty() {
39        pred_decl
40            .types
41            .iter()
42            .cloned()
43            .map(|typ| PredColumn { name: None, typ })
44            .collect()
45    } else {
46        pred_decl.columns.clone()
47    }
48}
49
50fn resolve_pred_column_type(
51    predicate: &str,
52    index: usize,
53    typ: &TypeRef,
54    domains: &HashMap<String, ScalarType>,
55) -> Result<ScalarType> {
56    match typ {
57        TypeRef::Scalar(ty) => Ok(*ty),
58        TypeRef::Domain(name) => domains.get(name).copied().ok_or_else(|| {
59            XlogError::Compilation(format!(
60                "v0.8.5 unknown domain alias '{}' in predicate '{}' column {}",
61                name, predicate, index
62            ))
63        }),
64        TypeRef::List(_) | TypeRef::Term | TypeRef::Compound | TypeRef::PredRef => {
65            Ok(ScalarType::U64)
66        }
67    }
68}
69
70fn validate_lowerable_terms(program: &Program) -> Result<()> {
71    for rule in &program.rules {
72        validate_atom_terms(&rule.head, "rule head")?;
73        for lit in &rule.body {
74            match lit {
75                BodyLiteral::Positive(atom) => validate_atom_terms(atom, "positive body atom")?,
76                BodyLiteral::Negated(atom) => validate_atom_terms(atom, "negated body atom")?,
77                BodyLiteral::Epistemic(_) => {}
78                BodyLiteral::Comparison(cmp) => {
79                    validate_term_lowerable(&cmp.left, "comparison left operand")?;
80                    validate_term_lowerable(&cmp.right, "comparison right operand")?;
81                }
82                BodyLiteral::IsExpr(_) => {}
83                BodyLiteral::Univ(_) => {
84                    return Err(XlogError::Compilation(
85                        "v0.8.5 meta error: univ literal was not normalized before lowering"
86                            .to_string(),
87                    ));
88                }
89            }
90        }
91    }
92    for constraint in &program.constraints {
93        for lit in &constraint.body {
94            match lit {
95                BodyLiteral::Positive(atom) => validate_atom_terms(atom, "constraint body atom")?,
96                BodyLiteral::Negated(atom) => {
97                    validate_atom_terms(atom, "constraint negated body atom")?
98                }
99                BodyLiteral::Epistemic(_) => {}
100                BodyLiteral::Comparison(cmp) => {
101                    validate_term_lowerable(&cmp.left, "constraint comparison left operand")?;
102                    validate_term_lowerable(&cmp.right, "constraint comparison right operand")?;
103                }
104                BodyLiteral::IsExpr(_) => {}
105                BodyLiteral::Univ(_) => {
106                    return Err(XlogError::Compilation(
107                        "v0.8.5 meta error: univ literal was not normalized before lowering"
108                            .to_string(),
109                    ));
110                }
111            }
112        }
113    }
114    for query in &program.queries {
115        validate_atom_terms(&query.atom, "query atom")?;
116    }
117    for pf in &program.prob_facts {
118        validate_atom_terms(&pf.atom, "probabilistic fact")?;
119    }
120    for ad in &program.annotated_disjunctions {
121        for choice in &ad.choices {
122            validate_atom_terms(&choice.atom, "annotated disjunction choice")?;
123        }
124    }
125    for evidence in &program.evidence {
126        validate_atom_terms(&evidence.atom, "evidence atom")?;
127    }
128    for query in &program.prob_queries {
129        validate_atom_terms(&query.atom, "probabilistic query")?;
130    }
131    for neural in &program.neural_predicates {
132        validate_atom_terms(&neural.predicate, "neural predicate")?;
133    }
134    for learnable in &program.learnable_rules {
135        validate_atom_terms(&learnable.head, "learnable rule head")?;
136        for lit in &learnable.body {
137            if let BodyLiteral::Positive(atom) = lit {
138                validate_atom_terms(atom, "learnable rule body")?;
139            }
140        }
141    }
142    Ok(())
143}
144
145fn validate_atom_terms(atom: &Atom, context: &str) -> Result<()> {
146    for term in &atom.terms {
147        validate_term_lowerable(term, context)?;
148    }
149    Ok(())
150}
151
152fn validate_term_lowerable(term: &Term, context: &str) -> Result<()> {
153    match term {
154        Term::List(_) => Err(v085_term_not_lowerable(context, "list")),
155        Term::Cons { .. } => Err(v085_term_not_lowerable(context, "cons")),
156        Term::Compound { .. } => Err(v085_term_not_lowerable(context, "compound")),
157        Term::PredRef(_) => Err(v085_term_not_lowerable(context, "predref")),
158        Term::Variable(_)
159        | Term::Anonymous
160        | Term::Integer(_)
161        | Term::Float(_)
162        | Term::String(_)
163        | Term::Symbol(_)
164        | Term::Aggregate(_) => Ok(()),
165    }
166}
167
168fn v085_term_not_lowerable(context: &str, kind: &str) -> XlogError {
169    XlogError::Compilation(format!(
170        "v0.8.5 term form '{}' in {} is parsed but not lowerable before its G085 implementation node",
171        kind, context
172    ))
173}
174
175fn v085_term_kind(term: &Term) -> &'static str {
176    match term {
177        Term::List(_) => "list",
178        Term::Cons { .. } => "cons",
179        Term::Compound { .. } => "compound",
180        Term::PredRef(_) => "predref",
181        Term::Variable(_)
182        | Term::Anonymous
183        | Term::Integer(_)
184        | Term::Float(_)
185        | Term::String(_)
186        | Term::Symbol(_)
187        | Term::Aggregate(_) => "term",
188    }
189}
190
191/// Lowerer transforms AST programs into RIR execution plans.
192pub struct Lowerer {
193    /// Inferred or declared schemas for each predicate
194    schemas: HashMap<String, Schema>,
195    /// Stratification result (predicates grouped by strata)
196    strata: Vec<Vec<String>>,
197    /// Estimated cardinality per predicate (for join ordering)
198    est_cardinality: HashMap<String, u64>,
199    /// Optional cardinality hints per predicate (e.g., from runtime statistics).
200    cardinality_hints: HashMap<String, u64>,
201    /// Next available relation ID
202    next_rel_id: u32,
203    /// Mapping from predicate names to relation IDs
204    rel_ids: HashMap<String, RelId>,
205    /// SCCs for the program (from stratification)
206    sccs: Vec<Scc>,
207    /// Maximum active rules for TensorMaskedJoin (default 32)
208    max_active_rules: usize,
209}
210
211impl Default for Lowerer {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
217impl Lowerer {
218    /// Create a new lowerer instance
219    pub fn new() -> Self {
220        Self {
221            schemas: HashMap::new(),
222            strata: Vec::new(),
223            est_cardinality: HashMap::new(),
224            cardinality_hints: HashMap::new(),
225            next_rel_id: 0,
226            rel_ids: HashMap::new(),
227            sccs: Vec::new(),
228            max_active_rules: 32,
229        }
230    }
231
232    /// Set the maximum active rules for TensorMaskedJoin.
233    pub fn set_max_active_rules(&mut self, max: usize) {
234        self.max_active_rules = max;
235    }
236
237    /// Set the stratification result for ordering
238    pub(crate) fn set_strata(&mut self, strata: Vec<Vec<String>>) {
239        self.strata = strata;
240    }
241
242    /// Set cardinality hints (typically sourced from runtime statistics snapshots).
243    ///
244    /// These hints are used by lowering-time join ordering when available.
245    pub(crate) fn set_cardinality_hints(&mut self, hints: HashMap<String, u64>) {
246        self.cardinality_hints = hints;
247    }
248
249    /// Get the mapping from predicate names to relation IDs
250    pub fn rel_ids(&self) -> &HashMap<String, RelId> {
251        &self.rel_ids
252    }
253
254    /// Get the inferred schemas for predicates
255    pub fn schemas(&self) -> &HashMap<String, Schema> {
256        &self.schemas
257    }
258
259    pub(crate) fn create_helper_relation(&mut self, schema: Schema) -> (String, RelId) {
260        let name = format!("__w37_helper_{}", self.next_rel_id);
261        let rel_id = self.get_or_create_rel_id(&name);
262        self.schemas.insert(name.clone(), schema);
263        (name, rel_id)
264    }
265
266    /// Get or allocate a relation ID for a predicate
267    fn get_or_create_rel_id(&mut self, name: &str) -> RelId {
268        if let Some(&id) = self.rel_ids.get(name) {
269            id
270        } else {
271            let id = RelId(self.next_rel_id);
272            self.next_rel_id += 1;
273            self.rel_ids.insert(name.to_string(), id);
274            id
275        }
276    }
277
278    /// Infer schemas from facts and predicate declarations
279    fn infer_schemas(&mut self, program: &Program) -> Result<()> {
280        let domains: HashMap<String, ScalarType> = program
281            .domains
282            .iter()
283            .map(|domain| (domain.name.clone(), domain.typ))
284            .collect();
285
286        // First, use explicit predicate declarations
287        for pred_decl in &program.predicates {
288            let declared_columns = pred_columns_for_decl(pred_decl);
289            let columns: Vec<(String, ScalarType)> = declared_columns
290                .iter()
291                .enumerate()
292                .map(|(i, col)| {
293                    let name = col.name.clone().unwrap_or_else(|| format!("c{}", i));
294                    resolve_pred_column_type(&pred_decl.name, i, &col.typ, &domains)
295                        .map(|ty| (name, ty))
296                })
297                .collect::<Result<Vec<_>>>()?;
298            self.schemas
299                .insert(pred_decl.name.clone(), Schema::new(columns));
300        }
301
302        // Then, infer from facts (if no declaration exists)
303        for rule in program.facts() {
304            let pred = &rule.head.predicate;
305            if !self.schemas.contains_key(pred) {
306                let columns: Vec<(String, ScalarType)> = rule
307                    .head
308                    .terms
309                    .iter()
310                    .enumerate()
311                    .map(|(i, term)| {
312                        let ty = infer_term_type(term);
313                        (format!("c{}", i), ty)
314                    })
315                    .collect();
316                self.schemas.insert(pred.clone(), Schema::new(columns));
317            }
318        }
319
320        // Finally, infer from rule heads if we still don't have a schema
321        for rule in &program.rules {
322            let pred = &rule.head.predicate;
323            if !self.schemas.contains_key(pred) {
324                // Use default U64 type for variables
325                let columns: Vec<(String, ScalarType)> = rule
326                    .head
327                    .terms
328                    .iter()
329                    .enumerate()
330                    .map(|(i, term)| {
331                        let ty = match term {
332                            Term::Variable(name) => self
333                                .infer_head_term_type_from_body(rule, name)
334                                .unwrap_or_else(|| infer_term_type(term)),
335                            _ => infer_term_type(term),
336                        };
337                        (format!("c{}", i), ty)
338                    })
339                    .collect();
340                let schema = Schema::new(columns)
341                    .with_sort_labels(sort_labels_from_terms(&rule.head.terms))
342                    .expect("rule head sort labels match inferred schema arity");
343                self.schemas.insert(pred.clone(), schema);
344            }
345        }
346
347        // Also infer from rule bodies for EDB predicates that only appear in bodies
348        for rule in &program.rules {
349            for lit in &rule.body {
350                let atom = match lit {
351                    BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
352                    BodyLiteral::Epistemic(_)
353                    | BodyLiteral::Comparison(_)
354                    | BodyLiteral::IsExpr(_)
355                    | BodyLiteral::Univ(_) => continue,
356                };
357                let pred = &atom.predicate;
358                if self.schemas.contains_key(pred) {
359                    continue;
360                }
361                let columns: Vec<(String, ScalarType)> = atom
362                    .terms
363                    .iter()
364                    .enumerate()
365                    .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
366                    .collect();
367                let schema = Schema::new(columns)
368                    .with_sort_labels(sort_labels_from_terms(&atom.terms))
369                    .expect("body sort labels match inferred schema arity");
370                self.schemas.insert(pred.clone(), schema);
371            }
372        }
373
374        // Ensure schemas exist for probabilistic facts and annotated disjunctions
375        for pf in &program.prob_facts {
376            let pred = &pf.atom.predicate;
377            if self.schemas.contains_key(pred) {
378                continue;
379            }
380            let columns: Vec<(String, ScalarType)> = pf
381                .atom
382                .terms
383                .iter()
384                .enumerate()
385                .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
386                .collect();
387            self.schemas.insert(pred.clone(), Schema::new(columns));
388        }
389
390        for ad in &program.annotated_disjunctions {
391            for choice in &ad.choices {
392                let pred = &choice.atom.predicate;
393                if self.schemas.contains_key(pred) {
394                    continue;
395                }
396                let columns: Vec<(String, ScalarType)> = choice
397                    .atom
398                    .terms
399                    .iter()
400                    .enumerate()
401                    .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
402                    .collect();
403                self.schemas.insert(pred.clone(), Schema::new(columns));
404            }
405        }
406
407        Ok(())
408    }
409
410    fn infer_head_term_type_from_body(&self, rule: &Rule, var_name: &str) -> Option<ScalarType> {
411        for lit in &rule.body {
412            let atom = match lit {
413                BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
414                BodyLiteral::Epistemic(_)
415                | BodyLiteral::Comparison(_)
416                | BodyLiteral::IsExpr(_)
417                | BodyLiteral::Univ(_) => continue,
418            };
419            let schema = self.schemas.get(&atom.predicate)?;
420            for (idx, term) in atom.terms.iter().enumerate() {
421                if let Term::Variable(name) = term {
422                    if name == var_name {
423                        if let Some(ty) = schema.column_type(idx) {
424                            return Some(ty);
425                        }
426                    }
427                }
428            }
429        }
430        None
431    }
432
433    fn infer_cardinalities(&mut self, program: &Program) {
434        self.est_cardinality.clear();
435
436        let mut fact_counts: HashMap<String, u64> = HashMap::new();
437        for fact in program.facts() {
438            *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
439        }
440
441        for pred in self.schemas.keys() {
442            let est = self
443                .cardinality_hints
444                .get(pred)
445                .copied()
446                .or_else(|| fact_counts.get(pred).copied())
447                .unwrap_or(1000)
448                .max(1);
449            self.est_cardinality.insert(pred.clone(), est);
450        }
451    }
452
453    /// Build SCCs from the dependency graph
454    fn build_sccs(&mut self, program: &Program) {
455        let graph = build_dependency_graph(program);
456        let scc_groups = find_sccs_for_lowering(&graph);
457
458        self.sccs.clear();
459        for (id, predicates) in scc_groups.iter().enumerate() {
460            // An SCC is recursive if it has more than one predicate
461            // or if a single predicate depends on itself positively
462            let is_recursive = if predicates.len() > 1 {
463                true
464            } else {
465                let pred = &predicates[0];
466                graph
467                    .outgoing(pred)
468                    .iter()
469                    .any(|e| e.to == *pred && e.dep_type == DepType::Positive)
470            };
471
472            self.sccs.push(Scc {
473                id: id as u32,
474                predicates: predicates.clone(),
475                is_recursive,
476            });
477        }
478    }
479
480    /// Lower an entire program to an execution plan
481    pub fn lower_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
482        validate_lowerable_terms(program)?;
483        // Infer schemas
484        self.infer_schemas(program)?;
485        self.infer_cardinalities(program);
486
487        // Pre-allocate RelIds for declared predicates so schema-only programs
488        // can populate relation stores before any facts or executable rules
489        // mention those relations. This keeps ILP candidate generation and
490        // runtime relation upload aligned with declared schemas.
491        for pred_decl in &program.predicates {
492            self.get_or_create_rel_id(&pred_decl.name);
493        }
494
495        // Build SCCs
496        self.build_sccs(program);
497
498        // Build execution plan
499        let mut builder = PlanBuilder::new();
500
501        // Add SCCs to the builder
502        for scc in &self.sccs {
503            builder.add_scc(scc.clone());
504        }
505
506        // Build strata from our strata field
507        for (id, preds) in self.strata.iter().enumerate() {
508            // Find which SCCs belong to this stratum
509            let scc_ids: Vec<u32> = self
510                .sccs
511                .iter()
512                .filter(|scc| scc.predicates.iter().any(|p| preds.contains(p)))
513                .map(|scc| scc.id)
514                .collect();
515
516            if !scc_ids.is_empty() {
517                builder.add_stratum(IrStratum {
518                    id: id as u32,
519                    sccs: scc_ids,
520                });
521            }
522        }
523
524        // Lower each rule
525        let mut rules_by_pred: HashMap<String, Vec<&Rule>> = HashMap::new();
526        for rule in program.proper_rules() {
527            rules_by_pred
528                .entry(rule.head.predicate.clone())
529                .or_default()
530                .push(rule);
531        }
532
533        // Add facts as scan-only rules
534        for fact in program.facts() {
535            let pred = &fact.head.predicate;
536            let scc_id = self.find_scc_for_predicate(pred);
537            let rel_id = self.get_or_create_rel_id(pred);
538
539            let body = RirNode::Scan { rel: rel_id };
540            let meta = self.create_meta_for_predicate(pred);
541
542            builder.add_rule(
543                scc_id,
544                CompiledRule {
545                    head: pred.clone(),
546                    body,
547                    meta,
548                },
549            );
550        }
551
552        // Lower proper rules
553        for (pred, rules) in &rules_by_pred {
554            let scc_id = self.find_scc_for_predicate(pred);
555
556            for rule in rules {
557                let body = self.lower_rule(rule)?;
558                let meta = self.create_meta_for_predicate(pred);
559
560                builder.add_rule(
561                    scc_id,
562                    CompiledRule {
563                        head: pred.clone(),
564                        body,
565                        meta,
566                    },
567                );
568            }
569        }
570
571        // Lower learnable rules (RD-32)
572        // Pre-allocate RelIds for ALL learnable predicates (heads + bodies)
573        // so every lower_learnable_rule snapshot is complete.
574        for learnable in &program.learnable_rules {
575            self.get_or_create_rel_id(&learnable.head.predicate);
576            for lit in &learnable.body {
577                if let BodyLiteral::Positive(atom) = lit {
578                    self.get_or_create_rel_id(&atom.predicate);
579                }
580            }
581        }
582        for learnable in &program.learnable_rules {
583            let head_pred = &learnable.head.predicate;
584            let scc_id = self.find_scc_for_predicate(head_pred);
585            let body = self.lower_learnable_rule(learnable)?;
586            let meta = self.create_meta_for_predicate(head_pred);
587            builder.add_rule(
588                scc_id,
589                CompiledRule {
590                    head: head_pred.clone(),
591                    body,
592                    meta,
593                },
594            );
595        }
596
597        Ok(builder.build())
598    }
599
600    /// Find the SCC ID for a predicate
601    fn find_scc_for_predicate(&self, pred: &str) -> u32 {
602        self.sccs
603            .iter()
604            .find(|scc| scc.predicates.contains(&pred.to_string()))
605            .map(|scc| scc.id)
606            .unwrap_or(0)
607    }
608
609    /// Create metadata for a predicate
610    fn create_meta_for_predicate(&self, pred: &str) -> RirMeta {
611        let schema = self
612            .schemas
613            .get(pred)
614            .cloned()
615            .unwrap_or_else(|| Schema::new(vec![]));
616        RirMeta::with_schema(schema)
617    }
618
619    /// Lower a learnable rule template into a TensorMaskedJoin node.
620    /// RD-34: Validates body has exactly 2 positive atoms.
621    /// RD-36: Sorts rel_index by RelId for deterministic tensor dimension mapping.
622    /// RD-30: Uses get_or_create_rel_id for head (handles head-only predicates).
623    fn lower_learnable_rule(&mut self, rule: &LearnableRule) -> Result<RirNode> {
624        // RD-34: Validate body shape
625        if rule.body.len() != 2 {
626            return Err(XlogError::Compilation(format!(
627                "learnable rule '{}' requires exactly 2 body literals, got {}",
628                rule.mask_name,
629                rule.body.len()
630            )));
631        }
632        for (idx, lit) in rule.body.iter().enumerate() {
633            match lit {
634                BodyLiteral::Positive(_) => {}
635                _ => {
636                    return Err(XlogError::Compilation(format!(
637                        "learnable rule '{}' body[{}]: only positive atoms allowed",
638                        rule.mask_name, idx
639                    )));
640                }
641            }
642        }
643
644        // RD-36: Sort by RelId for deterministic mapping
645        let mut rel_index: Vec<(RelId, String)> = self
646            .rel_ids()
647            .iter()
648            .map(|(name, id)| (*id, name.clone()))
649            .collect();
650        rel_index.sort_by_key(|(id, _)| id.0);
651        let schema_size = rel_index.len();
652
653        let (left_keys, right_keys) =
654            self.extract_template_join_keys(&rule.body[0], &rule.body[1])?;
655
656        let head_rel_name = rule.head.predicate.clone();
657        // RD-30: Allocate lazily — head-only predicates may not have a RelId yet
658        let head_rel_id = self.get_or_create_rel_id(&head_rel_name);
659
660        // Compute head projection: map head variables to join result columns.
661        // Join result layout: [left_col_0..left_col_n, right_col_0..right_col_m].
662        let left_atom = rule.body[0].atom().unwrap();
663        let right_atom = rule.body[1].atom().unwrap();
664        let left_arity = left_atom.terms.len();
665
666        // Build variable -> first-occurrence column mapping over joined result
667        let mut var_to_col: HashMap<String, usize> = HashMap::new();
668        for (i, term) in left_atom.terms.iter().enumerate() {
669            if let Some(name) = term.variable_name() {
670                var_to_col.entry(name.to_string()).or_insert(i);
671            }
672        }
673        for (i, term) in right_atom.terms.iter().enumerate() {
674            if let Some(name) = term.variable_name() {
675                var_to_col.entry(name.to_string()).or_insert(left_arity + i);
676            }
677        }
678
679        let mut head_projection: Vec<usize> = Vec::new();
680        for term in &rule.head.terms {
681            if let Some(name) = term.variable_name() {
682                let col = var_to_col.get(name).ok_or_else(|| {
683                    XlogError::Compilation(format!(
684                        "Learnable rule head variable '{}' not found in body atoms \
685                         ({}, {}). All head variables must appear in the body.",
686                        name, left_atom.predicate, right_atom.predicate,
687                    ))
688                })?;
689                head_projection.push(*col);
690            } else {
691                return Err(XlogError::Compilation(format!(
692                    "Learnable rule head must contain only variables, \
693                     found constant {:?} in head of '{}'",
694                    term, head_rel_name,
695                )));
696            }
697        }
698
699        // Infer schema for head predicate from the learnable rule if not already set.
700        // The head's column types come from the projected join columns.
701        if !self.schemas.contains_key(&head_rel_name) {
702            let columns: Vec<(String, ScalarType)> = head_projection
703                .iter()
704                .enumerate()
705                .map(|(i, &col)| {
706                    // Determine the type from left or right atom's schema
707                    let ty = if col < left_arity {
708                        self.schemas
709                            .get(&left_atom.predicate)
710                            .and_then(|s| s.column_type(col))
711                            .unwrap_or(ScalarType::U32)
712                    } else {
713                        self.schemas
714                            .get(&right_atom.predicate)
715                            .and_then(|s| s.column_type(col - left_arity))
716                            .unwrap_or(ScalarType::U32)
717                    };
718                    (format!("c{}", i), ty)
719                })
720                .collect();
721            self.schemas
722                .insert(head_rel_name.clone(), Schema::new(columns));
723        }
724
725        Ok(RirNode::TensorMaskedJoin {
726            mask_name: rule.mask_name.clone(),
727            schema_size,
728            left_keys,
729            right_keys,
730            rel_index,
731            head_rel_name,
732            head_rel_id,
733            max_active_rules: self.max_active_rules,
734            head_projection,
735        })
736    }
737
738    /// Extract join keys from two body literals' shared variables.
739    /// For `b1(X, Z), b2(Z, Y)`, the shared variable Z gives left_keys=[1], right_keys=[0].
740    fn extract_template_join_keys(
741        &self,
742        left: &BodyLiteral,
743        right: &BodyLiteral,
744    ) -> Result<(Vec<usize>, Vec<usize>)> {
745        let left_atom = left
746            .atom()
747            .ok_or_else(|| XlogError::Compilation("Learnable body[0] is not an atom".into()))?;
748        let right_atom = right
749            .atom()
750            .ok_or_else(|| XlogError::Compilation("Learnable body[1] is not an atom".into()))?;
751
752        let mut left_keys = Vec::new();
753        let mut right_keys = Vec::new();
754
755        for (li, lt) in left_atom.terms.iter().enumerate() {
756            if let Some(lname) = lt.variable_name() {
757                for (ri, rt) in right_atom.terms.iter().enumerate() {
758                    if let Some(rname) = rt.variable_name() {
759                        if lname == rname {
760                            left_keys.push(li);
761                            right_keys.push(ri);
762                        }
763                    }
764                }
765            }
766        }
767
768        Ok((left_keys, right_keys))
769    }
770
771    /// Lower a single rule to an RIR node
772    fn lower_rule(&mut self, rule: &Rule) -> Result<RirNode> {
773        if let Some(lit) = rule.body.iter().find_map(|lit| match lit {
774            BodyLiteral::Epistemic(lit) => Some(lit),
775            _ => None,
776        }) {
777            return Err(XlogError::UnsupportedEpistemicConstruct {
778                construct: "RIR lowering boundary".to_string(),
779                context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
780            });
781        }
782
783        // Split body literals.
784        let (positive_atoms, negated_atoms, comparisons, is_exprs) =
785            Self::split_body_literals(&rule.body);
786
787        // Allocate RelIds for all body predicates in source order so join planning
788        // does not influence identifier assignment.
789        for lit in &rule.body {
790            match lit {
791                BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => {
792                    self.get_or_create_rel_id(&atom.predicate);
793                }
794                BodyLiteral::Epistemic(_)
795                | BodyLiteral::Comparison(_)
796                | BodyLiteral::IsExpr(_)
797                | BodyLiteral::Univ(_) => {}
798            }
799        }
800
801        // Plan positive atoms (join tree shape + leaf order).
802        //
803        // Rules with no positive atoms are legal for nullary/ground heads in our
804        // probabilistic profiles (e.g. `q() :- not p().`). Lower them by seeding
805        // the body with a unit relation ({()}) and applying filters/negations.
806        let (positive_root, leaf_order) = if positive_atoms.is_empty() {
807            (RirNode::Unit, Vec::new())
808        } else {
809            self.plan_positive_atoms(&positive_atoms)?
810        };
811
812        // Build variable environment from the planned leaf order (matches join output layout:
813        // left subtree columns then right subtree columns).
814        let mut var_env = VariableEnv::new();
815        let mut current_col = 0;
816        for atom in &leaf_order {
817            let schema = self.schemas.get(&atom.predicate);
818            for (i, term) in atom.terms.iter().enumerate() {
819                if let Term::Variable(name) = term {
820                    if name == "_" {
821                        continue;
822                    }
823                    var_env.add_occurrence(name, atom.predicate.clone(), i, current_col + i);
824                    // Also record the type for this variable (first occurrence wins)
825                    if !var_env.types.contains_key(name) {
826                        let typ = schema
827                            .and_then(|s| s.column_type(i))
828                            .unwrap_or(ScalarType::I64); // Default to I64 for arithmetic
829                        var_env.types.insert(name.to_string(), typ);
830                    }
831                }
832            }
833            current_col += atom.terms.len();
834        }
835        var_env.total_cols = current_col;
836
837        // Lower the body starting from the planned positive join root.
838        let body_node = self.lower_body_parts(
839            positive_root,
840            &negated_atoms,
841            &comparisons,
842            &is_exprs,
843            &mut var_env,
844        )?;
845
846        if rule.has_aggregation() {
847            return self.lower_aggregate_rule(&rule.head, body_node, &var_env);
848        }
849
850        // Project to head terms (variables and constants).
851        let projection_exprs = self.compute_head_projection(&rule.head, &var_env)?;
852
853        if Self::is_identity_projection(&projection_exprs, var_env.column_count()) {
854            Ok(body_node)
855        } else {
856            Ok(RirNode::Project {
857                input: Box::new(body_node),
858                columns: projection_exprs,
859            })
860        }
861    }
862
863    fn split_body_literals(
864        body: &[BodyLiteral],
865    ) -> (Vec<&Atom>, Vec<&Atom>, Vec<&Comparison>, Vec<&IsExpr>) {
866        let mut positive_atoms: Vec<&Atom> = Vec::new();
867        let mut negated_atoms: Vec<&Atom> = Vec::new();
868        let mut comparisons: Vec<&Comparison> = Vec::new();
869        let mut is_exprs: Vec<&IsExpr> = Vec::new();
870
871        for lit in body {
872            match lit {
873                BodyLiteral::Positive(atom) => positive_atoms.push(atom),
874                BodyLiteral::Negated(atom) => negated_atoms.push(atom),
875                BodyLiteral::Epistemic(_) => {}
876                BodyLiteral::Comparison(cmp) => comparisons.push(cmp),
877                BodyLiteral::IsExpr(is_expr) => is_exprs.push(is_expr),
878                BodyLiteral::Univ(_) => {}
879            }
880        }
881
882        (positive_atoms, negated_atoms, comparisons, is_exprs)
883    }
884
885    fn atom_vars(atom: &Atom) -> std::collections::HashSet<String> {
886        atom.terms
887            .iter()
888            .flat_map(|t| t.variables().into_iter())
889            .filter(|name| *name != "_")
890            .map(ToOwned::to_owned)
891            .collect()
892    }
893
894    fn estimate_atom_rows(&self, atom: &Atom) -> f64 {
895        let base = self
896            .est_cardinality
897            .get(&atom.predicate)
898            .copied()
899            .unwrap_or(1000)
900            .max(1) as f64;
901
902        let const_count = atom
903            .terms
904            .iter()
905            .filter(|t| term_to_const_value(t).is_some())
906            .count();
907
908        // Equality constants are usually selective; use a conservative default.
909        let selectivity = 0.1_f64.powi(const_count as i32);
910        (base * selectivity).max(1.0)
911    }
912
913    fn build_cartesian_join(
914        &self,
915        left: RirNode,
916        right: RirNode,
917        left_width: usize,
918        right_width: usize,
919    ) -> RirNode {
920        // Implement cross join by appending a constant key column to both inputs and joining on it,
921        // then projecting away the constant columns.
922        let left_const_col =
923            ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
924        let right_const_col =
925            ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
926
927        let mut left_cols: Vec<ProjectExpr> = (0..left_width).map(ProjectExpr::Column).collect();
928        left_cols.push(left_const_col);
929        let left_aug = RirNode::Project {
930            input: Box::new(left),
931            columns: left_cols,
932        };
933
934        let mut right_cols: Vec<ProjectExpr> = (0..right_width).map(ProjectExpr::Column).collect();
935        right_cols.push(right_const_col);
936        let right_aug = RirNode::Project {
937            input: Box::new(right),
938            columns: right_cols,
939        };
940
941        let joined = RirNode::Join {
942            left: Box::new(left_aug),
943            right: Box::new(right_aug),
944            left_keys: vec![left_width],
945            right_keys: vec![right_width],
946            join_type: JoinType::Inner,
947        };
948
949        let mut keep: Vec<ProjectExpr> = Vec::with_capacity(left_width + right_width);
950        keep.extend((0..left_width).map(ProjectExpr::Column));
951        let right_start = left_width + 1;
952        keep.extend((right_start..right_start + right_width).map(ProjectExpr::Column));
953
954        RirNode::Project {
955            input: Box::new(joined),
956            columns: keep,
957        }
958    }
959
960    fn make_leaf_plan<'a>(&mut self, atom: &'a Atom, orig_idx: usize) -> Result<JoinPlan<'a>> {
961        let rel_id = self.get_or_create_rel_id(&atom.predicate);
962        let scan = RirNode::Scan { rel: rel_id };
963        let node = self.apply_constant_filters(scan, atom, 0)?;
964
965        let mut var_pos: HashMap<String, usize> = HashMap::new();
966        for (i, term) in atom.terms.iter().enumerate() {
967            if let Term::Variable(name) = term {
968                if name != "_" {
969                    var_pos.entry(name.clone()).or_insert(i);
970                }
971            }
972        }
973
974        let est_rows = self.estimate_atom_rows(atom);
975        Ok(JoinPlan {
976            node,
977            leaf_order: vec![atom],
978            leaf_order_idx: vec![orig_idx],
979            var_pos,
980            width: atom.terms.len(),
981            est_rows,
982            total_cost: est_rows,
983        })
984    }
985
986    fn join_plans<'a>(&self, left: &JoinPlan<'a>, right: &JoinPlan<'a>) -> JoinPlan<'a> {
987        let shared_vars: Vec<&String> = left
988            .var_pos
989            .keys()
990            .filter(|v| right.var_pos.contains_key(*v))
991            .collect();
992
993        let node = if shared_vars.is_empty() {
994            self.build_cartesian_join(
995                left.node.clone(),
996                right.node.clone(),
997                left.width,
998                right.width,
999            )
1000        } else {
1001            let mut key_pairs: Vec<(usize, usize)> = shared_vars
1002                .iter()
1003                .filter_map(|v| {
1004                    Some((
1005                        left.var_pos.get(*v).copied()?,
1006                        right.var_pos.get(*v).copied()?,
1007                    ))
1008                })
1009                .collect();
1010            key_pairs.sort_unstable();
1011
1012            let (left_keys, right_keys): (Vec<usize>, Vec<usize>) = key_pairs.into_iter().unzip();
1013
1014            RirNode::Join {
1015                left: Box::new(left.node.clone()),
1016                right: Box::new(right.node.clone()),
1017                left_keys,
1018                right_keys,
1019                join_type: JoinType::Inner,
1020            }
1021        };
1022
1023        let mut leaf_order = left.leaf_order.clone();
1024        leaf_order.extend(right.leaf_order.iter().copied());
1025
1026        let mut leaf_order_idx = left.leaf_order_idx.clone();
1027        leaf_order_idx.extend_from_slice(&right.leaf_order_idx);
1028
1029        let mut var_pos = left.var_pos.clone();
1030        for (var, pos) in &right.var_pos {
1031            var_pos.entry(var.clone()).or_insert(left.width + *pos);
1032        }
1033
1034        let shared = shared_vars.len();
1035        let mut selectivity = if shared == 0 {
1036            1.0
1037        } else {
1038            0.1_f64.powi(shared as i32)
1039        };
1040        if shared == 0 {
1041            // Penalize cartesian joins strongly.
1042            selectivity *= 1.0e6;
1043        }
1044
1045        let output_rows = (left.est_rows * right.est_rows * selectivity).max(1.0);
1046
1047        // Hash join cost is sensitive to which side is build (right) and probe (left).
1048        let build_cost = right.est_rows;
1049        let probe_cost = left.est_rows * 0.5;
1050        let total_cost = left.total_cost + right.total_cost + build_cost + probe_cost + output_rows;
1051
1052        JoinPlan {
1053            node,
1054            leaf_order,
1055            leaf_order_idx,
1056            var_pos,
1057            width: left.width + right.width,
1058            est_rows: output_rows,
1059            total_cost,
1060        }
1061    }
1062
1063    fn plan_positive_atoms_bushy<'a>(
1064        &mut self,
1065        atoms: &[&'a Atom],
1066    ) -> Result<(RirNode, Vec<&'a Atom>)> {
1067        let n = atoms.len();
1068        if n == 0 {
1069            return Err(XlogError::Compilation("Empty rule body".to_string()));
1070        }
1071        if n == 1 {
1072            let plan = self.make_leaf_plan(atoms[0], 0)?;
1073            return Ok((plan.node, plan.leaf_order));
1074        }
1075
1076        let size = 1usize << n;
1077        let mut best: Vec<Option<JoinPlan<'a>>> = (0..size).map(|_| None).collect();
1078
1079        for (i, atom) in atoms.iter().enumerate() {
1080            best[1usize << i] = Some(self.make_leaf_plan(atom, i)?);
1081        }
1082
1083        fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1084            for (ai, bi) in a.iter().zip(b.iter()) {
1085                if ai != bi {
1086                    return ai < bi;
1087                }
1088            }
1089            a.len() < b.len()
1090        }
1091
1092        for mask in 1..size {
1093            if mask.count_ones() <= 1 {
1094                continue;
1095            }
1096
1097            let mut best_for_mask: Option<JoinPlan<'a>> = None;
1098
1099            let mut sub = (mask - 1) & mask;
1100            while sub > 0 {
1101                let a = sub;
1102                let b = mask ^ a;
1103                if b == 0 {
1104                    sub = (sub - 1) & mask;
1105                    continue;
1106                }
1107
1108                let (Some(plan_a), Some(plan_b)) = (&best[a], &best[b]) else {
1109                    sub = (sub - 1) & mask;
1110                    continue;
1111                };
1112
1113                // Consider both orientations: A ⋈ B and B ⋈ A.
1114                for (left, right) in [(plan_a, plan_b), (plan_b, plan_a)] {
1115                    let cand = self.join_plans(left, right);
1116                    let replace = match &best_for_mask {
1117                        None => true,
1118                        Some(current) => {
1119                            if cand.total_cost < current.total_cost {
1120                                true
1121                            } else if (cand.total_cost - current.total_cost).abs() < 1e-9 {
1122                                lex_lt(&cand.leaf_order_idx, &current.leaf_order_idx)
1123                            } else {
1124                                false
1125                            }
1126                        }
1127                    };
1128
1129                    if replace {
1130                        best_for_mask = Some(cand);
1131                    }
1132                }
1133
1134                sub = (sub - 1) & mask;
1135            }
1136
1137            best[mask] = best_for_mask;
1138        }
1139
1140        let full_mask = size - 1;
1141        if let Some(plan) = best[full_mask].take() {
1142            return Ok((plan.node, plan.leaf_order));
1143        }
1144
1145        // Should be unreachable, but fall back to greedy ordering.
1146        let ordered = self.order_positive_atoms_greedy(atoms);
1147        let mut dummy_env = VariableEnv::new();
1148        let node = self.build_join_tree(&ordered, &mut dummy_env)?;
1149        Ok((node, ordered))
1150    }
1151
1152    fn plan_positive_atoms<'a>(&mut self, atoms: &[&'a Atom]) -> Result<(RirNode, Vec<&'a Atom>)> {
1153        if atoms.len() <= 1 {
1154            if atoms.is_empty() {
1155                return Err(XlogError::Compilation("Empty rule body".to_string()));
1156            }
1157            let plan = self.make_leaf_plan(atoms[0], 0)?;
1158            return Ok((plan.node, plan.leaf_order));
1159        }
1160
1161        const MAX_BUSHY_DP_ATOMS: usize = 10;
1162        if atoms.len() <= MAX_BUSHY_DP_ATOMS {
1163            return self.plan_positive_atoms_bushy(atoms);
1164        }
1165
1166        // Greedy bushy join planning for large rules (scales beyond exponential DP).
1167        self.plan_positive_atoms_bushy_greedy(atoms)
1168    }
1169
1170    fn plan_positive_atoms_bushy_greedy<'a>(
1171        &mut self,
1172        atoms: &[&'a Atom],
1173    ) -> Result<(RirNode, Vec<&'a Atom>)> {
1174        if atoms.is_empty() {
1175            return Err(XlogError::Compilation("Empty rule body".to_string()));
1176        }
1177
1178        fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1179            for (ai, bi) in a.iter().zip(b.iter()) {
1180                if ai != bi {
1181                    return ai < bi;
1182                }
1183            }
1184            a.len() < b.len()
1185        }
1186
1187        let mut plans: Vec<JoinPlan<'a>> = Vec::with_capacity(atoms.len());
1188        for (idx, atom) in atoms.iter().enumerate() {
1189            plans.push(self.make_leaf_plan(atom, idx)?);
1190        }
1191
1192        while plans.len() > 1 {
1193            let mut best_pair: Option<(usize, usize, JoinPlan<'a>)> = None;
1194
1195            for i in 0..plans.len() {
1196                for j in (i + 1)..plans.len() {
1197                    let a = &plans[i];
1198                    let b = &plans[j];
1199
1200                    let cand_ab = self.join_plans(a, b);
1201                    let cand_ba = self.join_plans(b, a);
1202
1203                    let cand = if cand_ab.total_cost < cand_ba.total_cost
1204                        || (cand_ab.total_cost - cand_ba.total_cost).abs() < 1e-9
1205                            && lex_lt(&cand_ab.leaf_order_idx, &cand_ba.leaf_order_idx)
1206                    {
1207                        cand_ab
1208                    } else {
1209                        cand_ba
1210                    };
1211
1212                    let replace = match &best_pair {
1213                        None => true,
1214                        Some((_bi, _bj, best)) => {
1215                            if cand.total_cost < best.total_cost {
1216                                true
1217                            } else if (cand.total_cost - best.total_cost).abs() < 1e-9 {
1218                                lex_lt(&cand.leaf_order_idx, &best.leaf_order_idx)
1219                            } else {
1220                                false
1221                            }
1222                        }
1223                    };
1224
1225                    if replace {
1226                        best_pair = Some((i, j, cand));
1227                    }
1228                }
1229            }
1230
1231            let Some((i, j, joined)) = best_pair else {
1232                break;
1233            };
1234
1235            // Remove joined inputs from the plan list and replace with the join.
1236            let (a, b) = if i < j { (i, j) } else { (j, i) };
1237            plans.remove(b);
1238            plans.remove(a);
1239            plans.push(joined);
1240        }
1241
1242        let plan = plans
1243            .pop()
1244            .ok_or_else(|| XlogError::Compilation("Join planning failed".to_string()))?;
1245        Ok((plan.node, plan.leaf_order))
1246    }
1247
1248    fn order_positive_atoms_greedy<'a>(&self, atoms: &[&'a Atom]) -> Vec<&'a Atom> {
1249        let mut remaining: Vec<(usize, &Atom)> = atoms.iter().copied().enumerate().collect();
1250        let mut ordered: Vec<&Atom> = Vec::with_capacity(atoms.len());
1251        let mut bound_vars: HashSet<String> = HashSet::new();
1252
1253        while !remaining.is_empty() {
1254            let pick_idx = if ordered.is_empty() {
1255                remaining
1256                    .iter()
1257                    .enumerate()
1258                    .min_by(|(_, a), (_, b)| {
1259                        let (ai, aa) = **a;
1260                        let (bi, bb) = **b;
1261                        self.estimate_atom_rows(aa)
1262                            .partial_cmp(&self.estimate_atom_rows(bb))
1263                            .unwrap_or(std::cmp::Ordering::Equal)
1264                            .then(ai.cmp(&bi))
1265                    })
1266                    .map(|(idx, _)| idx)
1267                    .unwrap()
1268            } else {
1269                remaining
1270                    .iter()
1271                    .enumerate()
1272                    .min_by(|(_, a), (_, b)| {
1273                        let (ai, aa) = **a;
1274                        let (bi, bb) = **b;
1275
1276                        let a_vars = Self::atom_vars(aa);
1277                        let b_vars = Self::atom_vars(bb);
1278
1279                        let a_shared = a_vars.intersection(&bound_vars).count();
1280                        let b_shared = b_vars.intersection(&bound_vars).count();
1281
1282                        let a_score = if a_shared == 0 {
1283                            self.estimate_atom_rows(aa) * 1.0e12
1284                        } else {
1285                            self.estimate_atom_rows(aa) / a_shared as f64
1286                        };
1287                        let b_score = if b_shared == 0 {
1288                            self.estimate_atom_rows(bb) * 1.0e12
1289                        } else {
1290                            self.estimate_atom_rows(bb) / b_shared as f64
1291                        };
1292
1293                        a_score
1294                            .partial_cmp(&b_score)
1295                            .unwrap_or(std::cmp::Ordering::Equal)
1296                            .then(ai.cmp(&bi))
1297                    })
1298                    .map(|(idx, _)| idx)
1299                    .unwrap()
1300            };
1301
1302            let (_orig_idx, atom) = remaining.remove(pick_idx);
1303            ordered.push(atom);
1304            bound_vars.extend(Self::atom_vars(atom));
1305        }
1306
1307        ordered
1308    }
1309
1310    fn lower_body_parts(
1311        &mut self,
1312        positive_root: RirNode,
1313        negated_atoms: &[&Atom],
1314        comparisons: &[&Comparison],
1315        is_exprs: &[&IsExpr],
1316        var_env: &mut VariableEnv,
1317    ) -> Result<RirNode> {
1318        let mut result = positive_root;
1319
1320        // Apply comparisons as filters.
1321        for cmp in comparisons {
1322            result = self.apply_comparison(result, cmp, var_env)?;
1323        }
1324
1325        // Apply is-expressions (must be after atoms that bind the input variables).
1326        for is_expr in is_exprs {
1327            result = self.lower_is_expr(is_expr, result, var_env)?;
1328        }
1329
1330        // Handle negated atoms via Diff / semi-join.
1331        for neg_atom in negated_atoms {
1332            result = self.apply_negation(result, neg_atom, var_env)?;
1333        }
1334
1335        Ok(result)
1336    }
1337
1338    /// Build a left-deep join tree from positive atoms
1339    fn build_join_tree(&mut self, atoms: &[&Atom], var_env: &mut VariableEnv) -> Result<RirNode> {
1340        if atoms.is_empty() {
1341            return Err(XlogError::Compilation("Empty rule body".to_string()));
1342        }
1343
1344        // Start with the first atom as a scan
1345        let first_atom = atoms[0];
1346        let rel_id = self.get_or_create_rel_id(&first_atom.predicate);
1347        let mut result = RirNode::Scan { rel: rel_id };
1348        let mut result_vars = self.collect_atom_vars(first_atom);
1349        let mut result_width = first_atom.terms.len();
1350
1351        // Apply constant filters if any
1352        result = self.apply_constant_filters(result, first_atom, 0)?;
1353
1354        // Join with remaining atoms (left-deep)
1355        for atom in atoms.iter().skip(1) {
1356            let right_rel_id = self.get_or_create_rel_id(&atom.predicate);
1357            let right_scan = RirNode::Scan { rel: right_rel_id };
1358
1359            // Apply constant filters to the right side
1360            let right_filtered = self.apply_constant_filters(right_scan, atom, 0)?;
1361
1362            // Compute join keys based on shared variables
1363            let (left_keys, right_keys) = self.compute_join_keys(&result_vars, atom, result_width);
1364
1365            if left_keys.is_empty() {
1366                // Cartesian product (no shared variables)
1367                result = RirNode::Join {
1368                    left: Box::new(result),
1369                    right: Box::new(right_filtered),
1370                    left_keys: vec![],
1371                    right_keys: vec![],
1372                    join_type: JoinType::Inner,
1373                };
1374            } else {
1375                result = RirNode::Join {
1376                    left: Box::new(result),
1377                    right: Box::new(right_filtered),
1378                    left_keys,
1379                    right_keys,
1380                    join_type: JoinType::Inner,
1381                };
1382            }
1383
1384            // Update result vars for the next iteration
1385            for (i, term) in atom.terms.iter().enumerate() {
1386                if let Term::Variable(name) = term {
1387                    result_vars.push((name.clone(), result_width + i));
1388                }
1389            }
1390            result_width += atom.terms.len();
1391        }
1392
1393        // Update var_env with final positions
1394        var_env.total_cols = result_width;
1395
1396        Ok(result)
1397    }
1398
1399    /// Collect variable names and their positions within an atom
1400    fn collect_atom_vars(&self, atom: &Atom) -> Vec<(String, usize)> {
1401        atom.terms
1402            .iter()
1403            .enumerate()
1404            .filter_map(|(i, term)| {
1405                if let Term::Variable(name) = term {
1406                    Some((name.clone(), i))
1407                } else {
1408                    None
1409                }
1410            })
1411            .collect()
1412    }
1413
1414    /// Compute join keys between the current result and a new atom
1415    fn compute_join_keys(
1416        &self,
1417        left_vars: &[(String, usize)],
1418        right_atom: &Atom,
1419        _left_width: usize,
1420    ) -> (Vec<usize>, Vec<usize>) {
1421        let mut left_keys = Vec::new();
1422        let mut right_keys = Vec::new();
1423
1424        for (right_idx, term) in right_atom.terms.iter().enumerate() {
1425            if let Term::Variable(name) = term {
1426                // Find if this variable exists in the left side
1427                for (left_name, left_idx) in left_vars {
1428                    if left_name == name {
1429                        left_keys.push(*left_idx);
1430                        right_keys.push(right_idx);
1431                        break; // Only use first occurrence for join key
1432                    }
1433                }
1434            }
1435        }
1436
1437        (left_keys, right_keys)
1438    }
1439
1440    /// Apply constant filters for an atom
1441    fn apply_constant_filters(
1442        &self,
1443        input: RirNode,
1444        atom: &Atom,
1445        _base_col: usize,
1446    ) -> Result<RirNode> {
1447        let mut filters = Vec::new();
1448        let mut first_var_col: HashMap<&str, usize> = HashMap::new();
1449        let schema = self.schemas.get(&atom.predicate).ok_or_else(|| {
1450            XlogError::Compilation(format!("Missing schema for predicate {}", atom.predicate))
1451        })?;
1452
1453        for (i, term) in atom.terms.iter().enumerate() {
1454            if let Term::Variable(name) = term {
1455                if name != "_" {
1456                    if let Some(&first) = first_var_col.get(name.as_str()) {
1457                        filters.push(Expr::Compare {
1458                            left: Box::new(Expr::Column(first)),
1459                            op: CompareOp::Eq,
1460                            right: Box::new(Expr::Column(i)),
1461                        });
1462                    } else {
1463                        first_var_col.insert(name.as_str(), i);
1464                    }
1465                }
1466            }
1467
1468            let col_type = schema.column_type(i).ok_or_else(|| {
1469                XlogError::Compilation(format!(
1470                    "Missing column type for {} column {}",
1471                    atom.predicate, i
1472                ))
1473            })?;
1474            if let Some(const_val) = term_to_typed_const_value(term, col_type)? {
1475                filters.push(Expr::Compare {
1476                    left: Box::new(Expr::Column(i)),
1477                    op: CompareOp::Eq,
1478                    right: Box::new(Expr::Const(const_val)),
1479                });
1480            }
1481        }
1482
1483        if filters.is_empty() {
1484            Ok(input)
1485        } else {
1486            let predicate = if filters.len() == 1 {
1487                filters.pop().unwrap()
1488            } else {
1489                Expr::And(filters)
1490            };
1491
1492            Ok(RirNode::Filter {
1493                input: Box::new(input),
1494                predicate,
1495            })
1496        }
1497    }
1498
1499    /// Apply a comparison as a filter
1500    fn apply_comparison(
1501        &self,
1502        input: RirNode,
1503        cmp: &Comparison,
1504        var_env: &VariableEnv,
1505    ) -> Result<RirNode> {
1506        let (left_expr, right_expr) = match (&cmp.left, &cmp.right) {
1507            (Term::Variable(name), term) => {
1508                let col = var_env.get_column(name).ok_or_else(|| {
1509                    XlogError::Compilation(format!("Variable {} not found in environment", name))
1510                })?;
1511                let typ = var_env.get_type(name).ok_or_else(|| {
1512                    XlogError::Compilation(format!("Missing type for variable {}", name))
1513                })?;
1514                if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1515                    (Expr::Column(col), Expr::Const(const_val))
1516                } else {
1517                    (
1518                        self.term_to_expr(&cmp.left, var_env)?,
1519                        self.term_to_expr(&cmp.right, var_env)?,
1520                    )
1521                }
1522            }
1523            (term, Term::Variable(name)) => {
1524                let col = var_env.get_column(name).ok_or_else(|| {
1525                    XlogError::Compilation(format!("Variable {} not found in environment", name))
1526                })?;
1527                let typ = var_env.get_type(name).ok_or_else(|| {
1528                    XlogError::Compilation(format!("Missing type for variable {}", name))
1529                })?;
1530                if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1531                    (Expr::Const(const_val), Expr::Column(col))
1532                } else {
1533                    (
1534                        self.term_to_expr(&cmp.left, var_env)?,
1535                        self.term_to_expr(&cmp.right, var_env)?,
1536                    )
1537                }
1538            }
1539            _ => (
1540                self.term_to_expr(&cmp.left, var_env)?,
1541                self.term_to_expr(&cmp.right, var_env)?,
1542            ),
1543        };
1544
1545        let op = match cmp.op {
1546            CompOp::Eq => CompareOp::Eq,
1547            CompOp::Ne => CompareOp::Ne,
1548            CompOp::Lt => CompareOp::Lt,
1549            CompOp::Le => CompareOp::Le,
1550            CompOp::Gt => CompareOp::Gt,
1551            CompOp::Ge => CompareOp::Ge,
1552        };
1553
1554        Ok(RirNode::Filter {
1555            input: Box::new(input),
1556            predicate: Expr::Compare {
1557                left: Box::new(left_expr),
1558                op,
1559                right: Box::new(right_expr),
1560            },
1561        })
1562    }
1563
1564    /// Convert a term to an expression
1565    fn term_to_expr(&self, term: &Term, var_env: &VariableEnv) -> Result<Expr> {
1566        match term {
1567            Term::Variable(name) => {
1568                if let Some(col) = var_env.get_column(name) {
1569                    Ok(Expr::Column(col))
1570                } else {
1571                    Err(XlogError::Compilation(format!(
1572                        "Variable {} not found in environment",
1573                        name
1574                    )))
1575                }
1576            }
1577            Term::Anonymous => Err(XlogError::Compilation(
1578                "Anonymous wildcard '_' not allowed in comparisons".to_string(),
1579            )),
1580            Term::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
1581            Term::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
1582            Term::String(s) => Ok(Expr::Const(ConstValue::Symbol(s.clone()))),
1583            Term::Symbol(id) => Ok(Expr::Const(ConstValue::Symbol(symbol::resolve(*id)))),
1584            Term::Aggregate(_) => Err(XlogError::Compilation(
1585                "Aggregates not allowed in comparisons".to_string(),
1586            )),
1587            Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1588                Err(v085_term_not_lowerable("comparison", v085_term_kind(term)))
1589            }
1590        }
1591    }
1592
1593    /// Apply negation via set difference
1594    fn apply_negation(
1595        &mut self,
1596        input: RirNode,
1597        neg_atom: &Atom,
1598        var_env: &VariableEnv,
1599    ) -> Result<RirNode> {
1600        let rel_id = self.get_or_create_rel_id(&neg_atom.predicate);
1601        let neg_scan = RirNode::Scan { rel: rel_id };
1602
1603        // Apply constant filters to the negated atom
1604        let neg_filtered = self.apply_constant_filters(neg_scan, neg_atom, 0)?;
1605
1606        // Find which columns from the input correspond to variables in the negated atom
1607        let mut input_cols = Vec::new();
1608        let mut neg_cols = Vec::new();
1609
1610        for (neg_idx, term) in neg_atom.terms.iter().enumerate() {
1611            if let Term::Variable(name) = term {
1612                if let Some(col) = var_env.get_column(name) {
1613                    input_cols.push(col);
1614                    neg_cols.push(neg_idx);
1615                }
1616            }
1617        }
1618
1619        if input_cols.is_empty() {
1620            // No shared variables - this is an existence check
1621            // If the negated relation is non-empty, result is empty
1622            // This is a special case we handle with anti-join
1623            Ok(RirNode::Diff {
1624                left: Box::new(input),
1625                right: Box::new(neg_filtered),
1626            })
1627        } else {
1628            // Project the negated atom to only the shared variable columns
1629            let neg_projected = if neg_cols.len() < neg_atom.terms.len() {
1630                let neg_proj_exprs: Vec<ProjectExpr> =
1631                    neg_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1632                RirNode::Project {
1633                    input: Box::new(neg_filtered),
1634                    columns: neg_proj_exprs,
1635                }
1636            } else {
1637                neg_filtered
1638            };
1639
1640            // Project input to matching columns for the diff, then diff
1641            // Actually, for proper anti-join semantics we need to be careful.
1642            // The Diff operation subtracts matching tuples.
1643            // We need to project input to the shared columns, diff, then rejoin.
1644
1645            // Simpler approach: project input to shared columns, diff with negated,
1646            // then rejoin with original
1647            let input_proj_exprs: Vec<ProjectExpr> =
1648                input_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1649            let input_projected = RirNode::Project {
1650                input: Box::new(input.clone()),
1651                columns: input_proj_exprs,
1652            };
1653
1654            // The Diff gives us the keys that should be kept
1655            let kept_keys = RirNode::Diff {
1656                left: Box::new(input_projected),
1657                right: Box::new(neg_projected),
1658            };
1659
1660            // Join back with original input to get full tuples
1661            // This effectively filters the input to only rows where the key
1662            // is not in the negated relation
1663            Ok(RirNode::Join {
1664                left: Box::new(input),
1665                right: Box::new(kept_keys),
1666                left_keys: input_cols.clone(),
1667                right_keys: (0..input_cols.len()).collect(),
1668                join_type: JoinType::Semi,
1669            })
1670        }
1671    }
1672
1673    fn is_identity_projection(proj: &[ProjectExpr], input_cols: usize) -> bool {
1674        if proj.len() != input_cols {
1675            return false;
1676        }
1677        proj.iter()
1678            .enumerate()
1679            .all(|(i, e)| matches!(e, ProjectExpr::Column(c) if *c == i))
1680    }
1681
1682    /// Build a projection list that matches the rule head term order.
1683    ///
1684    /// For non-aggregate rules this supports:
1685    /// - Variables (column passthrough)
1686    /// - Constants (computed constant columns)
1687    fn compute_head_projection(
1688        &self,
1689        head: &Atom,
1690        var_env: &VariableEnv,
1691    ) -> Result<Vec<ProjectExpr>> {
1692        let mut cols = Vec::with_capacity(head.terms.len());
1693
1694        for term in &head.terms {
1695            match term {
1696                Term::Variable(name) => {
1697                    let col = var_env
1698                        .get_column(name)
1699                        .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1700                    cols.push(ProjectExpr::Column(col));
1701                }
1702                Term::Anonymous => {
1703                    return Err(XlogError::Compilation(
1704                        "Anonymous wildcard '_' not allowed in rule head".to_string(),
1705                    ));
1706                }
1707                Term::Aggregate(_) => {
1708                    return Err(XlogError::Compilation(
1709                        "Aggregate term in non-aggregate rule head".to_string(),
1710                    ));
1711                }
1712                Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1713                    let (expr, typ) = term_to_project_const_expr(term)?;
1714                    cols.push(ProjectExpr::Computed(expr, typ));
1715                }
1716                Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1717                    return Err(v085_term_not_lowerable(
1718                        "rule head projection",
1719                        v085_term_kind(term),
1720                    ));
1721                }
1722            }
1723        }
1724
1725        Ok(cols)
1726    }
1727
1728    /// Lower an aggregate rule head into `GroupBy` + final projection.
1729    fn lower_aggregate_rule(
1730        &mut self,
1731        head: &Atom,
1732        body: RirNode,
1733        var_env: &VariableEnv,
1734    ) -> Result<RirNode> {
1735        // Collect unique group keys in head order.
1736        let mut key_vars: Vec<String> = Vec::new();
1737        let mut key_var_to_pos: HashMap<String, usize> = HashMap::new();
1738        let mut key_src_cols: Vec<usize> = Vec::new();
1739
1740        // Collect unique aggregate specs (op, var) in head order.
1741        let mut agg_specs: Vec<(AggOp, String)> = Vec::new();
1742        let mut agg_to_pos: HashMap<(AggOp, String), usize> = HashMap::new();
1743        let mut value_vars: Vec<String> = Vec::new();
1744        let mut value_var_to_pos: HashMap<String, usize> = HashMap::new();
1745        let mut value_src_cols: Vec<usize> = Vec::new();
1746
1747        for term in &head.terms {
1748            match term {
1749                Term::Variable(name) => {
1750                    if !key_var_to_pos.contains_key(name) {
1751                        let col = var_env
1752                            .get_column(name)
1753                            .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1754                        let pos = key_vars.len();
1755                        key_vars.push(name.clone());
1756                        key_var_to_pos.insert(name.clone(), pos);
1757                        key_src_cols.push(col);
1758                    }
1759                }
1760                Term::Aggregate(agg) => {
1761                    let key = (agg.op, agg.variable.clone());
1762                    if let std::collections::hash_map::Entry::Vacant(entry) = agg_to_pos.entry(key)
1763                    {
1764                        // Ensure the aggregated variable is bound.
1765                        let col = var_env
1766                            .get_column(&agg.variable)
1767                            .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1768
1769                        // Ensure the value variable exists in the groupby input.
1770                        let value_pos = *value_var_to_pos
1771                            .entry(agg.variable.clone())
1772                            .or_insert_with(|| {
1773                                let p = value_vars.len();
1774                                value_vars.push(agg.variable.clone());
1775                                value_src_cols.push(col);
1776                                p
1777                            });
1778
1779                        let agg_pos = agg_specs.len();
1780                        agg_specs.push((agg.op, agg.variable.clone()));
1781                        entry.insert(agg_pos);
1782
1783                        // Keep clippy happy about unused value_pos in insert_with closure.
1784                        let _ = value_pos;
1785                    }
1786                }
1787                Term::Anonymous => {
1788                    return Err(XlogError::Compilation(
1789                        "Anonymous wildcard '_' not allowed in rule head".to_string(),
1790                    ));
1791                }
1792                Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1793                    // Constants are allowed in the head; they are projected after aggregation.
1794                }
1795                Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1796                    return Err(v085_term_not_lowerable(
1797                        "aggregate rule head",
1798                        v085_term_kind(term),
1799                    ));
1800                }
1801            }
1802        }
1803
1804        if agg_specs.is_empty() {
1805            return Err(XlogError::Compilation(
1806                "Rule marked as aggregate but no aggregate terms found".to_string(),
1807            ));
1808        }
1809
1810        // Build groupby input: [keys..., values...]. For global aggregates (no keys),
1811        // synthesize a constant key column so GroupBy is well-defined.
1812        let mut group_input_cols: Vec<ProjectExpr> = Vec::new();
1813        let mut key_cols: Vec<usize> = Vec::new();
1814
1815        if key_src_cols.is_empty() {
1816            group_input_cols.push(ProjectExpr::Computed(
1817                Expr::Const(ConstValue::U32(0)),
1818                ScalarType::U32,
1819            ));
1820            key_cols.push(0);
1821        } else {
1822            for (i, &col) in key_src_cols.iter().enumerate() {
1823                group_input_cols.push(ProjectExpr::Column(col));
1824                key_cols.push(i);
1825            }
1826        }
1827
1828        let value_offset = group_input_cols.len();
1829        for &col in &value_src_cols {
1830            group_input_cols.push(ProjectExpr::Column(col));
1831        }
1832
1833        let group_input = RirNode::Project {
1834            input: Box::new(body),
1835            columns: group_input_cols,
1836        };
1837
1838        // Build multi-aggregation spec list (value_col indices are in the group_input schema).
1839        let mut aggs: Vec<(usize, CoreAggOp)> = Vec::with_capacity(agg_specs.len());
1840        for (op, var) in &agg_specs {
1841            let value_pos = *value_var_to_pos
1842                .get(var)
1843                .ok_or_else(|| XlogError::UnsafeVariable(var.clone()))?;
1844            let value_col = value_offset + value_pos;
1845            aggs.push((value_col, convert_agg_op(op)));
1846        }
1847
1848        let groupby = RirNode::GroupBy {
1849            input: Box::new(group_input),
1850            key_cols,
1851            aggs,
1852        };
1853
1854        // Final projection to match head term order:
1855        // - variables map to group key columns
1856        // - aggregates map to groupby output agg columns (after keys)
1857        // - constants are computed columns
1858        let key_count = if key_src_cols.is_empty() {
1859            1
1860        } else {
1861            key_vars.len()
1862        };
1863
1864        let mut final_proj: Vec<ProjectExpr> = Vec::with_capacity(head.terms.len());
1865        for term in &head.terms {
1866            match term {
1867                Term::Variable(name) => {
1868                    let idx = if key_src_cols.is_empty() {
1869                        // Global aggregates have no key vars in the output; binding a variable in the head
1870                        // is a semantic error because it would be unbound.
1871                        return Err(XlogError::UnsafeVariable(name.clone()));
1872                    } else {
1873                        *key_var_to_pos
1874                            .get(name)
1875                            .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?
1876                    };
1877                    final_proj.push(ProjectExpr::Column(idx));
1878                }
1879                Term::Aggregate(agg) => {
1880                    let pos = *agg_to_pos
1881                        .get(&(agg.op, agg.variable.clone()))
1882                        .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1883                    final_proj.push(ProjectExpr::Column(key_count + pos));
1884                }
1885                Term::Anonymous => {
1886                    return Err(XlogError::Compilation(
1887                        "Anonymous wildcard '_' not allowed in rule head".to_string(),
1888                    ));
1889                }
1890                Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1891                    let (expr, typ) = term_to_project_const_expr(term)?;
1892                    final_proj.push(ProjectExpr::Computed(expr, typ));
1893                }
1894                Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1895                    return Err(v085_term_not_lowerable(
1896                        "aggregate rule projection",
1897                        v085_term_kind(term),
1898                    ));
1899                }
1900            }
1901        }
1902
1903        if final_proj.is_empty() {
1904            return Err(XlogError::Compilation(
1905                "Aggregate rule produced empty head projection".to_string(),
1906            ));
1907        }
1908
1909        Ok(RirNode::Project {
1910            input: Box::new(groupby),
1911            columns: final_proj,
1912        })
1913    }
1914
1915    /// Infer the result type of an arithmetic expression (strict same-type)
1916    pub(crate) fn infer_arith_type(
1917        &self,
1918        expr: &ArithExpr,
1919        var_env: &VariableEnv,
1920    ) -> Result<ScalarType> {
1921        match expr {
1922            ArithExpr::Variable(name) => var_env.get_type(name).ok_or_else(|| {
1923                XlogError::Compilation(format!("Unknown variable {} in arithmetic", name))
1924            }),
1925            ArithExpr::Integer(_) => Ok(ScalarType::I64),
1926            ArithExpr::Float(_) => Ok(ScalarType::F64),
1927
1928            ArithExpr::Add(l, r)
1929            | ArithExpr::Sub(l, r)
1930            | ArithExpr::Mul(l, r)
1931            | ArithExpr::Div(l, r) => {
1932                let lt = self.infer_arith_type(l, var_env)?;
1933                let rt = self.infer_arith_type(r, var_env)?;
1934
1935                if lt != rt {
1936                    return Err(XlogError::Compilation(format!(
1937                        "Type mismatch in arithmetic: {:?} vs {:?}. Use cast() for conversion.",
1938                        lt, rt
1939                    )));
1940                }
1941
1942                if !Self::is_numeric_type(&lt) {
1943                    return Err(XlogError::Compilation(format!(
1944                        "Arithmetic requires numeric type, got {:?}",
1945                        lt
1946                    )));
1947                }
1948
1949                Ok(lt)
1950            }
1951
1952            ArithExpr::Mod(l, r) => {
1953                let lt = self.infer_arith_type(l, var_env)?;
1954                let rt = self.infer_arith_type(r, var_env)?;
1955
1956                if lt != rt {
1957                    return Err(XlogError::Compilation(format!(
1958                        "Type mismatch in mod: {:?} vs {:?}",
1959                        lt, rt
1960                    )));
1961                }
1962
1963                if matches!(lt, ScalarType::F32 | ScalarType::F64) {
1964                    return Err(XlogError::Compilation(
1965                        "Modulo (%) not supported for floating point".into(),
1966                    ));
1967                }
1968
1969                Ok(lt)
1970            }
1971
1972            ArithExpr::Abs(inner) => {
1973                let t = self.infer_arith_type(inner, var_env)?;
1974                if !Self::is_numeric_type(&t) {
1975                    return Err(XlogError::Compilation(format!(
1976                        "abs requires numeric type, got {:?}",
1977                        t
1978                    )));
1979                }
1980                Ok(t)
1981            }
1982
1983            ArithExpr::Min(l, r) | ArithExpr::Max(l, r) => {
1984                let lt = self.infer_arith_type(l, var_env)?;
1985                let rt = self.infer_arith_type(r, var_env)?;
1986
1987                if lt != rt {
1988                    return Err(XlogError::Compilation(format!(
1989                        "Type mismatch in min/max: {:?} vs {:?}",
1990                        lt, rt
1991                    )));
1992                }
1993
1994                if !Self::is_numeric_type(&lt) {
1995                    return Err(XlogError::Compilation(format!(
1996                        "min/max requires numeric type, got {:?}",
1997                        lt
1998                    )));
1999                }
2000
2001                Ok(lt)
2002            }
2003
2004            ArithExpr::Pow(base, exp) => {
2005                let base_t = self.infer_arith_type(base, var_env)?;
2006                let exp_t = self.infer_arith_type(exp, var_env)?;
2007
2008                if !Self::is_numeric_type(&base_t) || !Self::is_numeric_type(&exp_t) {
2009                    return Err(XlogError::Compilation(format!(
2010                        "pow requires numeric operands, got {:?} and {:?}",
2011                        base_t, exp_t
2012                    )));
2013                }
2014
2015                // pow always returns f64 (standard math behavior)
2016                Ok(ScalarType::F64)
2017            }
2018
2019            ArithExpr::Cast(_, target) => Ok(*target),
2020
2021            ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2022                "User-defined function '{}' must be inlined before lowering",
2023                name
2024            ))),
2025
2026            ArithExpr::Conditional {
2027                then_expr,
2028                else_expr,
2029                ..
2030            } => {
2031                // Both branches must have the same type
2032                let then_type = self.infer_arith_type(then_expr, var_env)?;
2033                let else_type = self.infer_arith_type(else_expr, var_env)?;
2034                if then_type != else_type {
2035                    return Err(XlogError::Compilation(format!(
2036                        "Conditional branches have different types: {:?} vs {:?}",
2037                        then_type, else_type
2038                    )));
2039                }
2040                Ok(then_type)
2041            }
2042        }
2043    }
2044
2045    fn is_numeric_type(t: &ScalarType) -> bool {
2046        matches!(
2047            t,
2048            ScalarType::I32
2049                | ScalarType::I64
2050                | ScalarType::U32
2051                | ScalarType::U64
2052                | ScalarType::F32
2053                | ScalarType::F64
2054        )
2055    }
2056
2057    /// Convert ArithExpr to IR Expr
2058    fn arith_to_expr(&self, arith: &ArithExpr, var_env: &VariableEnv) -> Result<Expr> {
2059        match arith {
2060            ArithExpr::Variable(name) => {
2061                let col = var_env.get_column(name).ok_or_else(|| {
2062                    XlogError::Compilation(format!(
2063                        "Variable {} not bound before use in arithmetic",
2064                        name
2065                    ))
2066                })?;
2067                Ok(Expr::Column(col))
2068            }
2069            ArithExpr::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
2070            ArithExpr::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
2071
2072            ArithExpr::Add(l, r) => Ok(Expr::Add(
2073                Box::new(self.arith_to_expr(l, var_env)?),
2074                Box::new(self.arith_to_expr(r, var_env)?),
2075            )),
2076            ArithExpr::Sub(l, r) => Ok(Expr::Sub(
2077                Box::new(self.arith_to_expr(l, var_env)?),
2078                Box::new(self.arith_to_expr(r, var_env)?),
2079            )),
2080            ArithExpr::Mul(l, r) => Ok(Expr::Mul(
2081                Box::new(self.arith_to_expr(l, var_env)?),
2082                Box::new(self.arith_to_expr(r, var_env)?),
2083            )),
2084            ArithExpr::Div(l, r) => Ok(Expr::Div(
2085                Box::new(self.arith_to_expr(l, var_env)?),
2086                Box::new(self.arith_to_expr(r, var_env)?),
2087            )),
2088            ArithExpr::Mod(l, r) => Ok(Expr::Mod(
2089                Box::new(self.arith_to_expr(l, var_env)?),
2090                Box::new(self.arith_to_expr(r, var_env)?),
2091            )),
2092
2093            ArithExpr::Abs(e) => Ok(Expr::Abs(Box::new(self.arith_to_expr(e, var_env)?))),
2094            ArithExpr::Min(l, r) => Ok(Expr::Min(
2095                Box::new(self.arith_to_expr(l, var_env)?),
2096                Box::new(self.arith_to_expr(r, var_env)?),
2097            )),
2098            ArithExpr::Max(l, r) => Ok(Expr::Max(
2099                Box::new(self.arith_to_expr(l, var_env)?),
2100                Box::new(self.arith_to_expr(r, var_env)?),
2101            )),
2102            ArithExpr::Pow(l, r) => Ok(Expr::Pow(
2103                Box::new(self.arith_to_expr(l, var_env)?),
2104                Box::new(self.arith_to_expr(r, var_env)?),
2105            )),
2106            ArithExpr::Cast(e, t) => Ok(Expr::Cast(Box::new(self.arith_to_expr(e, var_env)?), *t)),
2107
2108            ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2109                "User-defined function '{}' must be inlined before lowering",
2110                name
2111            ))),
2112
2113            ArithExpr::Conditional {
2114                cond_left,
2115                cond_op,
2116                cond_right,
2117                then_expr,
2118                else_expr,
2119            } => {
2120                // Convert AST comparison operator to IR comparison operator
2121                let ir_cond_op = match cond_op {
2122                    CompOp::Eq => CompareOp::Eq,
2123                    CompOp::Ne => CompareOp::Ne,
2124                    CompOp::Lt => CompareOp::Lt,
2125                    CompOp::Le => CompareOp::Le,
2126                    CompOp::Gt => CompareOp::Gt,
2127                    CompOp::Ge => CompareOp::Ge,
2128                };
2129
2130                // Build the condition as a Compare expression
2131                let condition = Expr::Compare {
2132                    left: Box::new(self.arith_to_expr(cond_left, var_env)?),
2133                    op: ir_cond_op,
2134                    right: Box::new(self.arith_to_expr(cond_right, var_env)?),
2135                };
2136
2137                // Build then and else expressions (recursive for nested conditionals)
2138                let then_ir = self.arith_to_expr(then_expr, var_env)?;
2139                let else_ir = self.arith_to_expr(else_expr, var_env)?;
2140
2141                Ok(Expr::Conditional {
2142                    condition: Box::new(condition),
2143                    then_expr: Box::new(then_ir),
2144                    else_expr: Box::new(else_ir),
2145                })
2146            }
2147        }
2148    }
2149
2150    /// Lower an is-expression to a Project node with computed column
2151    fn lower_is_expr(
2152        &mut self,
2153        is_expr: &IsExpr,
2154        input: RirNode,
2155        var_env: &mut VariableEnv,
2156    ) -> Result<RirNode> {
2157        // 1. Verify target is NOT already bound
2158        if var_env.contains(&is_expr.target) {
2159            return Err(XlogError::Compilation(format!(
2160                "Variable {} already bound; 'is' requires fresh variable",
2161                is_expr.target
2162            )));
2163        }
2164
2165        // 2. Verify all variables in expression are bound
2166        for var in is_expr.expr.variables() {
2167            if !var_env.contains(var) {
2168                return Err(XlogError::Compilation(format!(
2169                    "Variable {} used in arithmetic but not bound",
2170                    var
2171                )));
2172            }
2173        }
2174
2175        // 3. Infer result type
2176        let result_type = self.infer_arith_type(&is_expr.expr, var_env)?;
2177
2178        // 4. Convert expression to IR
2179        let ir_expr = self.arith_to_expr(&is_expr.expr, var_env)?;
2180
2181        // 5. Build projection: pass through all existing columns + add computed column
2182        let num_cols = var_env.column_count();
2183        let mut proj_exprs: Vec<ProjectExpr> = (0..num_cols).map(ProjectExpr::Column).collect();
2184        proj_exprs.push(ProjectExpr::Computed(ir_expr, result_type));
2185
2186        // 6. Bind the new variable
2187        var_env.bind(&is_expr.target, num_cols, result_type);
2188
2189        Ok(RirNode::Project {
2190            input: Box::new(input),
2191            columns: proj_exprs,
2192        })
2193    }
2194}
2195
2196/// Track variable occurrences and column positions
2197pub(crate) struct VariableEnv {
2198    /// Maps variable name to list of (predicate, position in atom, global column)
2199    occurrences: HashMap<String, Vec<(String, usize, usize)>>,
2200    /// Total columns in current result
2201    total_cols: usize,
2202    /// Maps variable name to its type (for type inference)
2203    types: HashMap<String, ScalarType>,
2204}
2205
2206impl VariableEnv {
2207    fn new() -> Self {
2208        Self {
2209            occurrences: HashMap::new(),
2210            total_cols: 0,
2211            types: HashMap::new(),
2212        }
2213    }
2214
2215    fn add_occurrence(&mut self, var: &str, pred: String, atom_pos: usize, global_col: usize) {
2216        self.occurrences
2217            .entry(var.to_string())
2218            .or_default()
2219            .push((pred, atom_pos, global_col));
2220    }
2221
2222    fn get_column(&self, var: &str) -> Option<usize> {
2223        self.occurrences
2224            .get(var)
2225            .and_then(|occs| occs.first())
2226            .map(|(_, _, col)| *col)
2227    }
2228
2229    /// Bind a variable to a column with a specific type (for type inference)
2230    fn bind(&mut self, name: &str, column: usize, typ: ScalarType) {
2231        self.types.insert(name.to_string(), typ);
2232        // Also add occurrence for column lookup
2233        self.occurrences
2234            .entry(name.to_string())
2235            .or_default()
2236            .push(("".to_string(), 0, column));
2237        // Update total_cols to account for the new computed column
2238        // This is critical for chained is-expressions where each adds a column
2239        if column >= self.total_cols {
2240            self.total_cols = column + 1;
2241        }
2242    }
2243
2244    /// Get the type of a bound variable
2245    fn get_type(&self, name: &str) -> Option<ScalarType> {
2246        self.types.get(name).copied()
2247    }
2248
2249    /// Check if a variable is bound
2250    fn contains(&self, name: &str) -> bool {
2251        self.occurrences.contains_key(name)
2252    }
2253
2254    /// Get the current column count (for adding new computed columns)
2255    fn column_count(&self) -> usize {
2256        self.total_cols
2257    }
2258}
2259
2260/// Infer the type of a term
2261fn infer_term_type(term: &Term) -> ScalarType {
2262    match term {
2263        Term::Variable(_) | Term::Anonymous => ScalarType::U64, // Default for variables
2264        Term::Integer(i) => {
2265            if *i >= 0 && *i <= u32::MAX as i64 {
2266                ScalarType::U32
2267            } else {
2268                ScalarType::I64
2269            }
2270        }
2271        Term::Float(_) => ScalarType::F64,
2272        Term::String(_) | Term::Symbol(_) => ScalarType::Symbol,
2273        Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
2274            ScalarType::U64
2275        }
2276        Term::Aggregate(agg) => match agg.op {
2277            AggOp::Count => ScalarType::U32,
2278            AggOp::Sum => ScalarType::U64,
2279            AggOp::Min | AggOp::Max => ScalarType::U32,
2280            AggOp::LogSumExp => ScalarType::F64,
2281        },
2282    }
2283}
2284
2285fn sort_labels_from_terms(terms: &[Term]) -> Vec<String> {
2286    terms
2287        .iter()
2288        .enumerate()
2289        .map(|(idx, term)| match term {
2290            Term::Variable(name) if !name.trim().is_empty() => name.clone(),
2291            Term::Aggregate(agg) => format!("{:?}_{}", agg.op, agg.variable),
2292            Term::List(_) => format!("list{}", idx),
2293            Term::Cons { .. } => format!("cons{}", idx),
2294            Term::Compound { functor, .. } => functor.clone(),
2295            Term::PredRef(name) => name.clone(),
2296            _ => format!("c{}", idx),
2297        })
2298        .collect()
2299}
2300
2301/// Convert a term to a constant value (if it is a constant)
2302fn term_to_const_value(term: &Term) -> Option<ConstValue> {
2303    match term {
2304        Term::Integer(i) => Some(ConstValue::I64(*i)),
2305        Term::Float(f) => Some(ConstValue::F64(*f)),
2306        Term::String(s) => Some(ConstValue::Symbol(s.clone())),
2307        Term::Symbol(id) => Some(ConstValue::Symbol(symbol::resolve(*id))),
2308        Term::Variable(_)
2309        | Term::Anonymous
2310        | Term::Aggregate(_)
2311        | Term::List(_)
2312        | Term::Cons { .. }
2313        | Term::Compound { .. }
2314        | Term::PredRef(_) => None,
2315    }
2316}
2317
2318fn term_to_typed_const_value(term: &Term, expected: ScalarType) -> Result<Option<ConstValue>> {
2319    let const_val = match term {
2320        Term::Integer(i) => match expected {
2321            ScalarType::U32 => {
2322                if *i >= 0 && *i <= u32::MAX as i64 {
2323                    ConstValue::U32(*i as u32)
2324                } else {
2325                    return Err(XlogError::Compilation(format!(
2326                        "Integer literal {} out of range for {:?}",
2327                        i, expected
2328                    )));
2329                }
2330            }
2331            ScalarType::U64 => {
2332                if *i >= 0 {
2333                    ConstValue::U64(*i as u64)
2334                } else {
2335                    return Err(XlogError::Compilation(format!(
2336                        "Integer literal {} out of range for {:?}",
2337                        i, expected
2338                    )));
2339                }
2340            }
2341            ScalarType::I32 => {
2342                if *i >= i32::MIN as i64 && *i <= i32::MAX as i64 {
2343                    ConstValue::I32(*i as i32)
2344                } else {
2345                    return Err(XlogError::Compilation(format!(
2346                        "Integer literal {} out of range for {:?}",
2347                        i, expected
2348                    )));
2349                }
2350            }
2351            ScalarType::I64 => ConstValue::I64(*i),
2352            ScalarType::F32 => {
2353                let value = *i as f64;
2354                if value < f32::MIN as f64 || value > f32::MAX as f64 {
2355                    return Err(XlogError::Compilation(format!(
2356                        "Integer literal {} out of range for {:?}",
2357                        i, expected
2358                    )));
2359                }
2360                ConstValue::F32(value as f32)
2361            }
2362            ScalarType::F64 => ConstValue::F64(*i as f64),
2363            ScalarType::Bool => {
2364                if *i == 0 || *i == 1 {
2365                    ConstValue::Bool(*i == 1)
2366                } else {
2367                    return Err(XlogError::Compilation(format!(
2368                        "Integer literal {} not valid for {:?}",
2369                        i, expected
2370                    )));
2371                }
2372            }
2373            ScalarType::Symbol => {
2374                return Err(XlogError::Compilation(format!(
2375                    "Integer literal {} not valid for {:?}",
2376                    i, expected
2377                )));
2378            }
2379        },
2380        Term::Float(f) => match expected {
2381            ScalarType::F32 => {
2382                if !f.is_finite() {
2383                    return Err(XlogError::Compilation(format!(
2384                        "Float literal {} not valid for {:?}",
2385                        f, expected
2386                    )));
2387                }
2388                if *f < f32::MIN as f64 || *f > f32::MAX as f64 {
2389                    return Err(XlogError::Compilation(format!(
2390                        "Float literal {} out of range for {:?}",
2391                        f, expected
2392                    )));
2393                }
2394                ConstValue::F32(*f as f32)
2395            }
2396            ScalarType::F64 => ConstValue::F64(*f),
2397            ScalarType::U32
2398            | ScalarType::U64
2399            | ScalarType::I32
2400            | ScalarType::I64
2401            | ScalarType::Bool
2402            | ScalarType::Symbol => {
2403                return Err(XlogError::Compilation(format!(
2404                    "Float literal {} not valid for {:?}",
2405                    f, expected
2406                )));
2407            }
2408        },
2409        Term::String(s) => {
2410            if expected == ScalarType::Symbol {
2411                ConstValue::Symbol(s.clone())
2412            } else {
2413                return Err(XlogError::Compilation(format!(
2414                    "String literal {} not valid for {:?}",
2415                    s, expected
2416                )));
2417            }
2418        }
2419        Term::Symbol(id) => {
2420            if expected == ScalarType::Symbol {
2421                ConstValue::Symbol(symbol::resolve(*id))
2422            } else {
2423                return Err(XlogError::Compilation(format!(
2424                    "Symbol literal {} not valid for {:?}",
2425                    symbol::resolve(*id),
2426                    expected
2427                )));
2428            }
2429        }
2430        Term::Variable(_)
2431        | Term::Anonymous
2432        | Term::Aggregate(_)
2433        | Term::List(_)
2434        | Term::Cons { .. }
2435        | Term::Compound { .. }
2436        | Term::PredRef(_) => return Ok(None),
2437    };
2438
2439    Ok(Some(const_val))
2440}
2441
2442fn term_to_project_const_expr(term: &Term) -> Result<(Expr, ScalarType)> {
2443    match term {
2444        Term::Integer(i) => {
2445            if *i >= 0 && *i <= u32::MAX as i64 {
2446                Ok((Expr::Const(ConstValue::U32(*i as u32)), ScalarType::U32))
2447            } else {
2448                Ok((Expr::Const(ConstValue::I64(*i)), ScalarType::I64))
2449            }
2450        }
2451        Term::Float(f) => Ok((Expr::Const(ConstValue::F64(*f)), ScalarType::F64)),
2452        Term::String(s) => Ok((
2453            Expr::Const(ConstValue::Symbol(s.clone())),
2454            ScalarType::Symbol,
2455        )),
2456        Term::Symbol(id) => Ok((
2457            Expr::Const(ConstValue::Symbol(symbol::resolve(*id))),
2458            ScalarType::Symbol,
2459        )),
2460        Term::Variable(_)
2461        | Term::Anonymous
2462        | Term::Aggregate(_)
2463        | Term::List(_)
2464        | Term::Cons { .. }
2465        | Term::Compound { .. }
2466        | Term::PredRef(_) => Err(XlogError::Compilation("Expected constant term".to_string())),
2467    }
2468}
2469
2470/// Convert AST AggOp to core AggOp
2471fn convert_agg_op(op: &AggOp) -> CoreAggOp {
2472    match op {
2473        AggOp::Count => CoreAggOp::Count,
2474        AggOp::Sum => CoreAggOp::Sum,
2475        AggOp::Min => CoreAggOp::Min,
2476        AggOp::Max => CoreAggOp::Max,
2477        AggOp::LogSumExp => CoreAggOp::LogSumExp,
2478    }
2479}
2480
// SCC discovery for lowering (`find_sccs_for_lowering`) lives in the `stratify`
// module and is imported at the top of this file.
2483
#[cfg(test)]
mod arith_type_tests {
    use super::*;
    use crate::ast::ArithExpr;

    /// Builds the arithmetic expression `X + Y`.
    fn x_plus_y() -> ArithExpr {
        let operand = |name: &str| Box::new(ArithExpr::Variable(name.to_string()));
        ArithExpr::Add(operand("X"), operand("Y"))
    }

    /// Environment that binds `X` to column 0 and `Y` to column 1 with the
    /// given scalar types.
    fn env_with(x_ty: ScalarType, y_ty: ScalarType) -> VariableEnv {
        let mut env = VariableEnv::new();
        env.bind("X", 0, x_ty);
        env.bind("Y", 1, y_ty);
        env
    }

    /// Adding two `i64` operands is accepted and yields `i64`.
    #[test]
    fn test_arith_type_inference_same_type() {
        let lowerer = Lowerer::new();
        let env = env_with(ScalarType::I64, ScalarType::I64);

        let inferred = lowerer.infer_arith_type(&x_plus_y(), &env);
        assert!(inferred.is_ok());
        assert_eq!(inferred.unwrap(), ScalarType::I64);
    }

    /// Adding an `i64` operand to an `f64` operand is rejected.
    #[test]
    fn test_arith_type_inference_mismatch() {
        let lowerer = Lowerer::new();
        let env = env_with(ScalarType::I64, ScalarType::F64);

        let inferred = lowerer.infer_arith_type(&x_plus_y(), &env);
        assert!(inferred.is_err());
    }
}
2522
2523#[cfg(test)]
2524mod tests {
2525    use super::*;
2526    use crate::ast::*;
2527
2528    fn pred_decl(name: &str, types: Vec<ScalarType>) -> PredDecl {
2529        let type_refs: Vec<TypeRef> = types.into_iter().map(TypeRef::Scalar).collect();
2530        let columns = type_refs
2531            .iter()
2532            .cloned()
2533            .map(|typ| PredColumn { name: None, typ })
2534            .collect();
2535        PredDecl {
2536            name: name.to_string(),
2537            types: type_refs,
2538            columns,
2539            is_private: false,
2540        }
2541    }
2542
2543    /// Helper to create a simple edge atom
2544    fn edge_atom(x: &str, y: &str) -> Atom {
2545        Atom {
2546            predicate: "edge".to_string(),
2547            terms: vec![Term::Variable(x.to_string()), Term::Variable(y.to_string())],
2548        }
2549    }
2550
2551    /// Helper to create a reach atom
2552    fn reach_atom(x: &str, y: &str) -> Atom {
2553        Atom {
2554            predicate: "reach".to_string(),
2555            terms: vec![Term::Variable(x.to_string()), Term::Variable(y.to_string())],
2556        }
2557    }
2558
2559    /// Helper to create a node atom
2560    fn node_atom(x: &str) -> Atom {
2561        Atom {
2562            predicate: "node".to_string(),
2563            terms: vec![Term::Variable(x.to_string())],
2564        }
2565    }
2566
2567    #[test]
2568    fn test_lowerer_new() {
2569        let lowerer = Lowerer::new();
2570        assert!(lowerer.schemas.is_empty());
2571        assert!(lowerer.strata.is_empty());
2572        assert_eq!(lowerer.next_rel_id, 0);
2573    }
2574
2575    #[test]
2576    fn test_get_or_create_rel_id() {
2577        let mut lowerer = Lowerer::new();
2578        let id1 = lowerer.get_or_create_rel_id("edge");
2579        let id2 = lowerer.get_or_create_rel_id("reach");
2580        let id3 = lowerer.get_or_create_rel_id("edge");
2581
2582        assert_eq!(id1, RelId(0));
2583        assert_eq!(id2, RelId(1));
2584        assert_eq!(id3, RelId(0)); // Same as id1
2585    }
2586
2587    #[test]
2588    fn test_infer_schemas_from_facts() {
2589        let mut program = Program::new();
2590        program.rules.push(Rule {
2591            head: Atom {
2592                predicate: "edge".to_string(),
2593                terms: vec![Term::Integer(1), Term::Integer(2)],
2594            },
2595            body: vec![],
2596        });
2597
2598        let mut lowerer = Lowerer::new();
2599        lowerer.infer_schemas(&program).unwrap();
2600
2601        assert!(lowerer.schemas.contains_key("edge"));
2602        let schema = lowerer.schemas.get("edge").unwrap();
2603        assert_eq!(schema.arity(), 2);
2604    }
2605
2606    #[test]
2607    fn test_lower_simple_rule() {
2608        // reach(X, Y) :- edge(X, Y).
2609        let rule = Rule {
2610            head: reach_atom("X", "Y"),
2611            body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
2612        };
2613
2614        let mut lowerer = Lowerer::new();
2615        lowerer.schemas.insert(
2616            "edge".to_string(),
2617            Schema::new(vec![
2618                ("c0".to_string(), ScalarType::U32),
2619                ("c1".to_string(), ScalarType::U32),
2620            ]),
2621        );
2622
2623        let result = lowerer.lower_rule(&rule);
2624        assert!(result.is_ok());
2625
2626        let node = result.unwrap();
2627        // Should be just a scan (no projection needed since columns match)
2628        assert!(matches!(node, RirNode::Scan { .. }));
2629    }
2630
2631    #[test]
2632    fn test_lower_join_rule() {
2633        // reach(X, Z) :- reach(X, Y), edge(Y, Z).
2634        let rule = Rule {
2635            head: Atom {
2636                predicate: "reach".to_string(),
2637                terms: vec![
2638                    Term::Variable("X".to_string()),
2639                    Term::Variable("Z".to_string()),
2640                ],
2641            },
2642            body: vec![
2643                BodyLiteral::Positive(reach_atom("X", "Y")),
2644                BodyLiteral::Positive(edge_atom("Y", "Z")),
2645            ],
2646        };
2647
2648        let mut lowerer = Lowerer::new();
2649        lowerer.schemas.insert(
2650            "reach".to_string(),
2651            Schema::new(vec![
2652                ("c0".to_string(), ScalarType::U32),
2653                ("c1".to_string(), ScalarType::U32),
2654            ]),
2655        );
2656        lowerer.schemas.insert(
2657            "edge".to_string(),
2658            Schema::new(vec![
2659                ("c0".to_string(), ScalarType::U32),
2660                ("c1".to_string(), ScalarType::U32),
2661            ]),
2662        );
2663
2664        let result = lowerer.lower_rule(&rule);
2665        assert!(result.is_ok());
2666
2667        let node = result.unwrap();
2668        // Should be Project(Join(Scan, Scan))
2669        if let RirNode::Project { input, columns } = node {
2670            // X from reach (col 0), Z from edge (col 3)
2671            assert_eq!(
2672                columns,
2673                vec![ProjectExpr::Column(0), ProjectExpr::Column(3)]
2674            );
2675            assert!(matches!(*input, RirNode::Join { .. }));
2676            if let RirNode::Join {
2677                left_keys,
2678                right_keys,
2679                ..
2680            } = *input
2681            {
2682                assert_eq!(left_keys, vec![1]); // Y in reach (position 1)
2683                assert_eq!(right_keys, vec![0]); // Y in edge (position 0)
2684            }
2685        } else {
2686            panic!("Expected Project node");
2687        }
2688    }
2689
2690    #[test]
2691    fn test_join_order_prefers_smaller_relation() {
2692        // out(X) :- big(X), small(X).
2693        let rule = Rule {
2694            head: Atom {
2695                predicate: "out".to_string(),
2696                terms: vec![Term::Variable("X".to_string())],
2697            },
2698            body: vec![
2699                BodyLiteral::Positive(Atom {
2700                    predicate: "big".to_string(),
2701                    terms: vec![Term::Variable("X".to_string())],
2702                }),
2703                BodyLiteral::Positive(Atom {
2704                    predicate: "small".to_string(),
2705                    terms: vec![Term::Variable("X".to_string())],
2706                }),
2707            ],
2708        };
2709
2710        let mut lowerer = Lowerer::new();
2711        lowerer.schemas.insert(
2712            "big".to_string(),
2713            Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2714        );
2715        lowerer.schemas.insert(
2716            "small".to_string(),
2717            Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2718        );
2719
2720        // Ensure stable RelIds independent of join order.
2721        let big_id = lowerer.get_or_create_rel_id("big");
2722        let small_id = lowerer.get_or_create_rel_id("small");
2723        assert_eq!(big_id, RelId(0));
2724        assert_eq!(small_id, RelId(1));
2725
2726        // Prefer scanning the smaller relation first.
2727        lowerer.est_cardinality.insert("big".to_string(), 10_000);
2728        lowerer.est_cardinality.insert("small".to_string(), 10);
2729
2730        let node = lowerer.lower_rule(&rule).unwrap();
2731        let join = match node {
2732            RirNode::Project { input, .. } => *input,
2733            other => other,
2734        };
2735
2736        match join {
2737            RirNode::Join { left, right, .. } => {
2738                // Prefer building the hash table on the smaller relation (right/build side).
2739                assert!(matches!(*left, RirNode::Scan { rel } if rel == big_id));
2740                assert!(matches!(*right, RirNode::Scan { rel } if rel == small_id));
2741            }
2742            other => panic!("Expected Join node, got {:?}", other),
2743        }
2744    }
2745
2746    #[test]
2747    fn test_lower_negation() {
2748        // isolated(X) :- node(X), not edge(X, _).
2749        let rule = Rule {
2750            head: Atom {
2751                predicate: "isolated".to_string(),
2752                terms: vec![Term::Variable("X".to_string())],
2753            },
2754            body: vec![
2755                BodyLiteral::Positive(node_atom("X")),
2756                BodyLiteral::Negated(Atom {
2757                    predicate: "edge".to_string(),
2758                    terms: vec![
2759                        Term::Variable("X".to_string()),
2760                        Term::Variable("_".to_string()),
2761                    ],
2762                }),
2763            ],
2764        };
2765
2766        let mut lowerer = Lowerer::new();
2767        lowerer.schemas.insert(
2768            "node".to_string(),
2769            Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2770        );
2771        lowerer.schemas.insert(
2772            "edge".to_string(),
2773            Schema::new(vec![
2774                ("c0".to_string(), ScalarType::U32),
2775                ("c1".to_string(), ScalarType::U32),
2776            ]),
2777        );
2778
2779        let result = lowerer.lower_rule(&rule);
2780        assert!(result.is_ok());
2781
2782        // The result should involve a Diff or semi-join for negation
2783        let node = result.unwrap();
2784        // Verify the structure contains the negation handling
2785        fn contains_diff_or_semi(node: &RirNode) -> bool {
2786            match node {
2787                RirNode::Diff { .. } => true,
2788                RirNode::Join {
2789                    join_type: JoinType::Semi,
2790                    ..
2791                } => true,
2792                RirNode::Join { left, right, .. } => {
2793                    contains_diff_or_semi(left) || contains_diff_or_semi(right)
2794                }
2795                RirNode::Project { input, .. } => contains_diff_or_semi(input),
2796                RirNode::Filter { input, .. } => contains_diff_or_semi(input),
2797                _ => false,
2798            }
2799        }
2800        assert!(contains_diff_or_semi(&node));
2801    }
2802
2803    #[test]
2804    fn test_lower_comparison() {
2805        // greater(X, Y) :- pair(X, Y), X > Y.
2806        let rule = Rule {
2807            head: Atom {
2808                predicate: "greater".to_string(),
2809                terms: vec![
2810                    Term::Variable("X".to_string()),
2811                    Term::Variable("Y".to_string()),
2812                ],
2813            },
2814            body: vec![
2815                BodyLiteral::Positive(Atom {
2816                    predicate: "pair".to_string(),
2817                    terms: vec![
2818                        Term::Variable("X".to_string()),
2819                        Term::Variable("Y".to_string()),
2820                    ],
2821                }),
2822                BodyLiteral::Comparison(Comparison {
2823                    left: Term::Variable("X".to_string()),
2824                    op: CompOp::Gt,
2825                    right: Term::Variable("Y".to_string()),
2826                }),
2827            ],
2828        };
2829
2830        let mut lowerer = Lowerer::new();
2831        lowerer.schemas.insert(
2832            "pair".to_string(),
2833            Schema::new(vec![
2834                ("c0".to_string(), ScalarType::U32),
2835                ("c1".to_string(), ScalarType::U32),
2836            ]),
2837        );
2838
2839        let result = lowerer.lower_rule(&rule);
2840        assert!(result.is_ok());
2841
2842        let node = result.unwrap();
2843        // Should contain a Filter node
2844        fn contains_filter(node: &RirNode) -> bool {
2845            match node {
2846                RirNode::Filter { .. } => true,
2847                RirNode::Project { input, .. } => contains_filter(input),
2848                RirNode::Join { left, right, .. } => {
2849                    contains_filter(left) || contains_filter(right)
2850                }
2851                _ => false,
2852            }
2853        }
2854        assert!(contains_filter(&node));
2855    }
2856
2857    #[test]
2858    fn test_lower_constant_filter() {
2859        // specific_edge(Y) :- edge(1, Y).
2860        let rule = Rule {
2861            head: Atom {
2862                predicate: "specific_edge".to_string(),
2863                terms: vec![Term::Variable("Y".to_string())],
2864            },
2865            body: vec![BodyLiteral::Positive(Atom {
2866                predicate: "edge".to_string(),
2867                terms: vec![Term::Integer(1), Term::Variable("Y".to_string())],
2868            })],
2869        };
2870
2871        let mut lowerer = Lowerer::new();
2872        lowerer.schemas.insert(
2873            "edge".to_string(),
2874            Schema::new(vec![
2875                ("c0".to_string(), ScalarType::U32),
2876                ("c1".to_string(), ScalarType::U32),
2877            ]),
2878        );
2879
2880        let result = lowerer.lower_rule(&rule);
2881        assert!(result.is_ok());
2882
2883        let node = result.unwrap();
2884        // Should contain a Filter for the constant 1
2885        fn has_const_filter(node: &RirNode) -> bool {
2886            match node {
2887                RirNode::Filter {
2888                    predicate: Expr::Compare { right, .. },
2889                    ..
2890                } => matches!(**right, Expr::Const(_)),
2891                RirNode::Project { input, .. } => has_const_filter(input),
2892                _ => false,
2893            }
2894        }
2895        assert!(has_const_filter(&node));
2896    }
2897
2898    #[test]
2899    fn test_lower_repeated_variable_filter() {
2900        // self_loop(X) :- edge(X, X).
2901        let rule = Rule {
2902            head: Atom {
2903                predicate: "self_loop".to_string(),
2904                terms: vec![Term::Variable("X".to_string())],
2905            },
2906            body: vec![BodyLiteral::Positive(Atom {
2907                predicate: "edge".to_string(),
2908                terms: vec![
2909                    Term::Variable("X".to_string()),
2910                    Term::Variable("X".to_string()),
2911                ],
2912            })],
2913        };
2914
2915        let mut lowerer = Lowerer::new();
2916        lowerer.schemas.insert(
2917            "edge".to_string(),
2918            Schema::new(vec![
2919                ("c0".to_string(), ScalarType::U32),
2920                ("c1".to_string(), ScalarType::U32),
2921            ]),
2922        );
2923
2924        let node = lowerer.lower_rule(&rule).expect("lower_rule failed");
2925
2926        fn has_col_eq_filter(node: &RirNode) -> bool {
2927            match node {
2928                RirNode::Filter { predicate, .. } => match predicate {
2929                    Expr::Compare {
2930                        left,
2931                        op: CompareOp::Eq,
2932                        right,
2933                    } => {
2934                        matches!((&**left, &**right), (Expr::Column(0), Expr::Column(1)))
2935                            || matches!((&**left, &**right), (Expr::Column(1), Expr::Column(0)))
2936                    }
2937                    Expr::And(exprs) => exprs.iter().any(|e| match e {
2938                        Expr::Compare {
2939                            left,
2940                            op: CompareOp::Eq,
2941                            right,
2942                        } => {
2943                            matches!((&**left, &**right), (Expr::Column(0), Expr::Column(1)))
2944                                || matches!((&**left, &**right), (Expr::Column(1), Expr::Column(0)))
2945                        }
2946                        _ => false,
2947                    }),
2948                    _ => false,
2949                },
2950                RirNode::Project { input, .. } => has_col_eq_filter(input),
2951                _ => false,
2952            }
2953        }
2954
2955        assert!(has_col_eq_filter(&node));
2956    }
2957
2958    #[test]
2959    fn test_lower_program_simple() {
2960        let mut program = Program::new();
2961
2962        // edge(1, 2).
2963        program.rules.push(Rule {
2964            head: Atom {
2965                predicate: "edge".to_string(),
2966                terms: vec![Term::Integer(1), Term::Integer(2)],
2967            },
2968            body: vec![],
2969        });
2970
2971        // reach(X, Y) :- edge(X, Y).
2972        program.rules.push(Rule {
2973            head: reach_atom("X", "Y"),
2974            body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
2975        });
2976
2977        let mut lowerer = Lowerer::new();
2978        lowerer.set_strata(vec![vec!["edge".to_string()], vec!["reach".to_string()]]);
2979
2980        let result = lowerer.lower_program(&program);
2981        assert!(result.is_ok());
2982
2983        let plan = result.unwrap();
2984        assert!(!plan.sccs.is_empty());
2985    }
2986
2987    #[test]
2988    fn test_variable_env() {
2989        let mut env = VariableEnv::new();
2990        env.add_occurrence("X", "edge".to_string(), 0, 0);
2991        env.add_occurrence("Y", "edge".to_string(), 1, 1);
2992        env.add_occurrence("Y", "node".to_string(), 0, 2);
2993
2994        assert_eq!(env.get_column("X"), Some(0));
2995        assert_eq!(env.get_column("Y"), Some(1)); // First occurrence
2996        assert_eq!(env.get_column("Z"), None);
2997    }
2998
2999    #[test]
3000    fn test_infer_term_type() {
3001        assert_eq!(
3002            infer_term_type(&Term::Variable("X".to_string())),
3003            ScalarType::U64
3004        );
3005        assert_eq!(infer_term_type(&Term::Integer(42)), ScalarType::U32);
3006        assert_eq!(infer_term_type(&Term::Integer(i64::MAX)), ScalarType::I64);
3007        assert_eq!(infer_term_type(&Term::Float(3.25)), ScalarType::F64);
3008        assert_eq!(
3009            infer_term_type(&Term::Symbol(symbol::intern("foo"))),
3010            ScalarType::Symbol
3011        );
3012    }
3013
3014    #[test]
3015    fn test_convert_agg_op() {
3016        assert_eq!(convert_agg_op(&AggOp::Count), CoreAggOp::Count);
3017        assert_eq!(convert_agg_op(&AggOp::Sum), CoreAggOp::Sum);
3018        assert_eq!(convert_agg_op(&AggOp::Min), CoreAggOp::Min);
3019        assert_eq!(convert_agg_op(&AggOp::Max), CoreAggOp::Max);
3020        assert_eq!(convert_agg_op(&AggOp::LogSumExp), CoreAggOp::LogSumExp);
3021    }
3022
3023    #[test]
3024    fn test_variable_env_bind_updates_total_cols() {
3025        // Test that bind() properly updates total_cols for chained is-expressions
3026        let mut env = VariableEnv::new();
3027        env.total_cols = 2; // Simulate 2 columns from atoms
3028
3029        // Bind first computed variable at column 2
3030        env.bind("A", 2, ScalarType::I64);
3031        assert_eq!(
3032            env.column_count(),
3033            3,
3034            "total_cols should be 3 after first bind"
3035        );
3036        assert_eq!(env.get_column("A"), Some(2));
3037
3038        // Bind second computed variable at column 3
3039        env.bind("B", 3, ScalarType::I64);
3040        assert_eq!(
3041            env.column_count(),
3042            4,
3043            "total_cols should be 4 after second bind"
3044        );
3045        assert_eq!(env.get_column("B"), Some(3));
3046    }
3047
3048    #[test]
3049    fn test_lower_chained_is_expressions() {
3050        // result(A, B) :- input(X, Y), A is X + Y, B is A * 2.
3051        // This tests that chained is-expressions correctly update column indices
3052        let rule = Rule {
3053            head: Atom {
3054                predicate: "result".to_string(),
3055                terms: vec![
3056                    Term::Variable("A".to_string()),
3057                    Term::Variable("B".to_string()),
3058                ],
3059            },
3060            body: vec![
3061                BodyLiteral::Positive(Atom {
3062                    predicate: "input".to_string(),
3063                    terms: vec![
3064                        Term::Variable("X".to_string()),
3065                        Term::Variable("Y".to_string()),
3066                    ],
3067                }),
3068                BodyLiteral::IsExpr(IsExpr {
3069                    target: "A".to_string(),
3070                    expr: ArithExpr::Add(
3071                        Box::new(ArithExpr::Variable("X".to_string())),
3072                        Box::new(ArithExpr::Variable("Y".to_string())),
3073                    ),
3074                }),
3075                BodyLiteral::IsExpr(IsExpr {
3076                    target: "B".to_string(),
3077                    expr: ArithExpr::Mul(
3078                        Box::new(ArithExpr::Variable("A".to_string())),
3079                        Box::new(ArithExpr::Integer(2)),
3080                    ),
3081                }),
3082            ],
3083        };
3084
3085        let mut lowerer = Lowerer::new();
3086        lowerer.schemas.insert(
3087            "input".to_string(),
3088            Schema::new(vec![
3089                ("c0".to_string(), ScalarType::I64),
3090                ("c1".to_string(), ScalarType::I64),
3091            ]),
3092        );
3093
3094        let result = lowerer.lower_rule(&rule);
3095        assert!(
3096            result.is_ok(),
3097            "Lowering chained is-expressions should succeed: {:?}",
3098            result.err()
3099        );
3100
3101        let node = result.unwrap();
3102
3103        // The structure should be:
3104        // Project([col 2, col 3]) <-- final projection for A, B
3105        //   Project([col 0, col 1, col 2, A*2]) <-- second is-expr adds B at col 3
3106        //     Project([col 0, col 1, X+Y]) <-- first is-expr adds A at col 2
3107        //       Scan(input)
3108
3109        // Verify we have nested Project nodes
3110        fn count_projects(node: &RirNode) -> usize {
3111            match node {
3112                RirNode::Project { input, .. } => 1 + count_projects(input),
3113                _ => 0,
3114            }
3115        }
3116
3117        // We expect 3 Project nodes: 2 for is-expressions + 1 for final head projection
3118        let project_count = count_projects(&node);
3119        assert!(
3120            project_count >= 2,
3121            "Expected at least 2 Project nodes for chained is-exprs, got {}",
3122            project_count
3123        );
3124
3125        // Verify the final projection references columns 2 and 3 (A and B)
3126        if let RirNode::Project { columns, .. } = &node {
3127            assert_eq!(columns.len(), 2, "Head has 2 variables");
3128            // A should be at column 2, B at column 3
3129            assert_eq!(columns[0], ProjectExpr::Column(2), "A should be column 2");
3130            assert_eq!(columns[1], ProjectExpr::Column(3), "B should be column 3");
3131        } else {
3132            panic!("Expected top-level Project node");
3133        }
3134    }
3135
3136    #[test]
3137    fn test_u64_comparison_type_from_pred_decl() {
3138        // Test that u64 type from pred decl is preserved in comparison lowering
3139        let mut program = Program::new();
3140
3141        // pred count_data(symbol, u64).
3142        program.predicates.push(pred_decl(
3143            "count_data",
3144            vec![ScalarType::Symbol, ScalarType::U64],
3145        ));
3146
3147        // count_data(alice, 5).
3148        program.rules.push(Rule {
3149            head: Atom {
3150                predicate: "count_data".to_string(),
3151                terms: vec![
3152                    Term::Symbol(xlog_core::symbol::intern("alice")),
3153                    Term::Integer(5),
3154                ],
3155            },
3156            body: vec![],
3157        });
3158
3159        // pred big_count(symbol, u64).
3160        program.predicates.push(pred_decl(
3161            "big_count",
3162            vec![ScalarType::Symbol, ScalarType::U64],
3163        ));
3164
3165        // big_count(Name, Count) :- count_data(Name, Count), Count >= 3.
3166        program.rules.push(Rule {
3167            head: Atom {
3168                predicate: "big_count".to_string(),
3169                terms: vec![
3170                    Term::Variable("Name".to_string()),
3171                    Term::Variable("Count".to_string()),
3172                ],
3173            },
3174            body: vec![
3175                BodyLiteral::Positive(Atom {
3176                    predicate: "count_data".to_string(),
3177                    terms: vec![
3178                        Term::Variable("Name".to_string()),
3179                        Term::Variable("Count".to_string()),
3180                    ],
3181                }),
3182                BodyLiteral::Comparison(Comparison {
3183                    left: Term::Variable("Count".to_string()),
3184                    op: CompOp::Ge,
3185                    right: Term::Integer(3),
3186                }),
3187            ],
3188        });
3189
3190        let mut lowerer = Lowerer::new();
3191        lowerer.infer_schemas(&program).unwrap();
3192
3193        // Verify schema has correct types
3194        let schema = lowerer
3195            .schemas
3196            .get("count_data")
3197            .expect("schema for count_data");
3198        assert_eq!(
3199            schema.column_type(0),
3200            Some(ScalarType::Symbol),
3201            "First column should be Symbol"
3202        );
3203        assert_eq!(
3204            schema.column_type(1),
3205            Some(ScalarType::U64),
3206            "Second column should be U64"
3207        );
3208
3209        // Now test lowering the rule with comparison
3210        lowerer.set_strata(vec![
3211            vec!["count_data".to_string()],
3212            vec!["big_count".to_string()],
3213        ]);
3214        lowerer.build_sccs(&program);
3215
3216        let rule = &program.rules[1]; // big_count rule
3217        let result = lowerer.lower_rule(rule);
3218        assert!(
3219            result.is_ok(),
3220            "Lowering should succeed: {:?}",
3221            result.err()
3222        );
3223
3224        // Check that the filter has the correct constant type
3225        fn find_compare_const(node: &RirNode) -> Option<&ConstValue> {
3226            match node {
3227                RirNode::Filter { predicate, input } => {
3228                    if let Expr::Compare { right, .. } = predicate {
3229                        if let Expr::Const(val) = right.as_ref() {
3230                            return Some(val);
3231                        }
3232                    }
3233                    find_compare_const(input)
3234                }
3235                RirNode::Project { input, .. } => find_compare_const(input),
3236                RirNode::Join { left, right, .. } => {
3237                    find_compare_const(left).or_else(|| find_compare_const(right))
3238                }
3239                _ => None,
3240            }
3241        }
3242
3243        let node = result.unwrap();
3244        let const_val = find_compare_const(&node);
3245        assert!(const_val.is_some(), "Should find a constant in comparison");
3246
3247        // The constant should be U64(3), not I64(3)
3248        match const_val.unwrap() {
3249            ConstValue::U64(v) => assert_eq!(*v, 3, "Value should be 3"),
3250            other => panic!("Expected U64(3), got {:?}", other),
3251        }
3252    }
3253
3254    #[test]
3255    fn test_u64_comparison_with_aggregation() {
3256        use crate::ast::AggExpr;
3257
3258        // Test aggregation + comparison case
3259        let mut program = Program::new();
3260
3261        // pred reports_to(symbol, symbol).
3262        program.predicates.push(pred_decl(
3263            "reports_to",
3264            vec![ScalarType::Symbol, ScalarType::Symbol],
3265        ));
3266
3267        // reports_to facts
3268        program.rules.push(Rule {
3269            head: Atom {
3270                predicate: "reports_to".to_string(),
3271                terms: vec![
3272                    Term::Symbol(xlog_core::symbol::intern("alice")),
3273                    Term::Symbol(xlog_core::symbol::intern("bob")),
3274                ],
3275            },
3276            body: vec![],
3277        });
3278        program.rules.push(Rule {
3279            head: Atom {
3280                predicate: "reports_to".to_string(),
3281                terms: vec![
3282                    Term::Symbol(xlog_core::symbol::intern("carol")),
3283                    Term::Symbol(xlog_core::symbol::intern("bob")),
3284                ],
3285            },
3286            body: vec![],
3287        });
3288
3289        // pred direct_count(symbol, u64).
3290        program.predicates.push(pred_decl(
3291            "direct_count",
3292            vec![ScalarType::Symbol, ScalarType::U64],
3293        ));
3294
3295        // direct_count(Mgr, count(Emp)) :- reports_to(Emp, Mgr).
3296        program.rules.push(Rule {
3297            head: Atom {
3298                predicate: "direct_count".to_string(),
3299                terms: vec![
3300                    Term::Variable("Mgr".to_string()),
3301                    Term::Aggregate(AggExpr {
3302                        op: AggOp::Count,
3303                        variable: "Emp".to_string(),
3304                    }),
3305                ],
3306            },
3307            body: vec![BodyLiteral::Positive(Atom {
3308                predicate: "reports_to".to_string(),
3309                terms: vec![
3310                    Term::Variable("Emp".to_string()),
3311                    Term::Variable("Mgr".to_string()),
3312                ],
3313            })],
3314        });
3315
3316        // pred big_manager(symbol, u64).
3317        program.predicates.push(pred_decl(
3318            "big_manager",
3319            vec![ScalarType::Symbol, ScalarType::U64],
3320        ));
3321
3322        // big_manager(Mgr, Count) :- direct_count(Mgr, Count), Count >= 2.
3323        program.rules.push(Rule {
3324            head: Atom {
3325                predicate: "big_manager".to_string(),
3326                terms: vec![
3327                    Term::Variable("Mgr".to_string()),
3328                    Term::Variable("Count".to_string()),
3329                ],
3330            },
3331            body: vec![
3332                BodyLiteral::Positive(Atom {
3333                    predicate: "direct_count".to_string(),
3334                    terms: vec![
3335                        Term::Variable("Mgr".to_string()),
3336                        Term::Variable("Count".to_string()),
3337                    ],
3338                }),
3339                BodyLiteral::Comparison(Comparison {
3340                    left: Term::Variable("Count".to_string()),
3341                    op: CompOp::Ge,
3342                    right: Term::Integer(2),
3343                }),
3344            ],
3345        });
3346
3347        let mut lowerer = Lowerer::new();
3348        lowerer.infer_schemas(&program).unwrap();
3349
3350        // Verify schema has correct types
3351        let schema = lowerer
3352            .schemas
3353            .get("direct_count")
3354            .expect("schema for direct_count");
3355        assert_eq!(
3356            schema.column_type(0),
3357            Some(ScalarType::Symbol),
3358            "First column should be Symbol"
3359        );
3360        assert_eq!(
3361            schema.column_type(1),
3362            Some(ScalarType::U64),
3363            "Second column should be U64"
3364        );
3365
3366        lowerer.set_strata(vec![
3367            vec!["reports_to".to_string()],
3368            vec!["direct_count".to_string()],
3369            vec!["big_manager".to_string()],
3370        ]);
3371        lowerer.build_sccs(&program);
3372
3373        // Lower the big_manager rule (index 3: after 2 facts + aggregation rule)
3374        let big_manager_rule = &program.rules[3];
3375        let result = lowerer.lower_rule(big_manager_rule);
3376        assert!(
3377            result.is_ok(),
3378            "Lowering should succeed: {:?}",
3379            result.err()
3380        );
3381
3382        // Check that the filter has the correct constant type
3383        fn find_compare_const(node: &RirNode) -> Option<&ConstValue> {
3384            match node {
3385                RirNode::Filter { predicate, input } => {
3386                    if let Expr::Compare { right, .. } = predicate {
3387                        if let Expr::Const(val) = right.as_ref() {
3388                            return Some(val);
3389                        }
3390                    }
3391                    find_compare_const(input)
3392                }
3393                RirNode::Project { input, .. } => find_compare_const(input),
3394                RirNode::Join { left, right, .. } => {
3395                    find_compare_const(left).or_else(|| find_compare_const(right))
3396                }
3397                _ => None,
3398            }
3399        }
3400
3401        let node = result.unwrap();
3402        let const_val = find_compare_const(&node);
3403        assert!(const_val.is_some(), "Should find a constant in comparison");
3404
3405        // The constant should be U64(2), not I64(2)
3406        match const_val.unwrap() {
3407            ConstValue::U64(v) => assert_eq!(*v, 2, "Value should be 2"),
3408            other => panic!("Expected U64(2), got {:?}", other),
3409        }
3410    }
3411}