Skip to main content

fabula_dsl/
compiler.rs

1//! Compiler: AST → fabula Pattern and MemGraph.
2//!
3//! The compiler is generic over a [`TypeMapper`] trait that bridges DSL
4//! literal types (strings, numbers, booleans, node references) to the
5//! target pattern type system. The default [`MemMapper`] produces
6//! `Pattern<String, MemValue>` for testing and in-memory evaluation.
7
8use crate::ast::*;
9use crate::error::ParseError;
10use fabula::builder::{NegationBuilder, PatternBuilder, StageBuilder};
11use fabula::compose;
12use fabula::datasource::ValueConstraint;
13use fabula::interval::AllenRelation;
14use fabula::pattern::Pattern;
15use fabula_memory::MemValue;
16use std::collections::{HashMap, HashSet};
17use std::fmt::Debug;
18
19// ---------------------------------------------------------------------------
20// TypeMapper trait
21// ---------------------------------------------------------------------------
22
/// Bridge between the DSL's parsed literals and a concrete pattern type system.
///
/// The parser hands the compiler labels as strings and values as strings,
/// numbers, booleans, or node references. An implementation of this trait
/// decides how each of those is represented in the target `Pattern<L, V>`.
///
/// Every conversion returns a `Result` so that a mapper may reject input it
/// does not recognise (for instance, a label that is absent from a predicate
/// registry).
///
/// # Example
///
/// ```rust,ignore
/// struct WkMapper { labels: HashMap<String, u32> }
///
/// impl TypeMapper for WkMapper {
///     type L = u32;
///     type V = paracausality::Value;
///
///     fn label(&self, s: &str) -> Result<u32, String> {
///         self.labels.get(s).copied()
///             .ok_or_else(|| format!("unknown predicate '{}'", s))
///     }
///     // ...
/// }
/// ```
pub trait TypeMapper {
    /// Label type of the produced patterns (e.g., `String`, `u32`).
    type L: Clone + Debug;
    /// Value type of the produced patterns (e.g., `MemValue`, `Value`).
    type V: Clone + Debug;

    /// Map a label, as written in the DSL, to the target label type.
    fn label(&self, s: &str) -> Result<Self::L, String>;
    /// Map a string literal to a target value.
    fn string_value(&self, s: &str) -> Result<Self::V, String>;
    /// Map a numeric literal to a target value.
    fn num_value(&self, n: f64) -> Result<Self::V, String>;
    /// Map a boolean literal to a target value.
    fn bool_value(&self, b: bool) -> Result<Self::V, String>;
    /// Map a reference to the node called `name` to a target value.
    fn node_ref(&self, name: &str) -> Result<Self::V, String>;
}
64
65/// Default mapper that produces `Pattern<String, MemValue>`.
66pub struct MemMapper;
67
68impl TypeMapper for MemMapper {
69    type L = String;
70    type V = MemValue;
71
72    fn label(&self, s: &str) -> Result<String, String> {
73        Ok(s.to_string())
74    }
75    fn string_value(&self, s: &str) -> Result<MemValue, String> {
76        Ok(MemValue::Str(s.to_string()))
77    }
78    fn num_value(&self, n: f64) -> Result<MemValue, String> {
79        Ok(MemValue::Num(n))
80    }
81    fn bool_value(&self, b: bool) -> Result<MemValue, String> {
82        Ok(MemValue::Bool(b))
83    }
84    fn node_ref(&self, name: &str) -> Result<MemValue, String> {
85        Ok(MemValue::Node(name.to_string()))
86    }
87}
88
89// ---------------------------------------------------------------------------
90// Pattern compilation
91// ---------------------------------------------------------------------------
92
93/// Compile a pattern AST into a `Pattern<String, MemValue>` using the default mapper.
94pub fn compile_pattern(ast: &PatternAst) -> Result<Pattern<String, MemValue>, ParseError> {
95    compile_pattern_with(ast, &MemMapper)
96}
97
98/// Compile a [`PatternBody`] (from [`crate::parser::Parser::parse_pattern_body()`])
99/// into a `Pattern<String, MemValue>` using the default mapper.
100///
101/// The `name` is assigned to the resulting pattern since the body doesn't
102/// include the `pattern name { }` header.
103pub fn compile_pattern_body(
104    name: &str,
105    body: &PatternBody,
106) -> Result<Pattern<String, MemValue>, ParseError> {
107    compile_pattern_body_with(name, body, &MemMapper)
108}
109
110/// Compile a [`PatternBody`] using a custom [`TypeMapper`].
111pub fn compile_pattern_body_with<M: TypeMapper>(
112    name: &str,
113    body: &PatternBody,
114    mapper: &M,
115) -> Result<Pattern<M::L, M::V>, ParseError> {
116    let ast = PatternAst {
117        name: name.to_string(),
118        stages: body.stages.clone(),
119        negations: body.negations.clone(),
120        temporals: body.temporals.clone(),
121        metadata: body.metadata.clone(),
122        deadline: body.deadline,
123        unordered_groups: body.unordered_groups.clone(),
124        private: body.private,
125    };
126    compile_pattern_with(&ast, mapper)
127}
128
129/// Compile a pattern AST using a custom [`TypeMapper`].
130///
131/// Validates variable scoping: `?var` sources must reference a variable
132/// that was bound by `-> ?var` in a prior clause or is the current stage anchor.
133pub fn compile_pattern_with<M: TypeMapper>(
134    ast: &PatternAst,
135    mapper: &M,
136) -> Result<Pattern<M::L, M::V>, ParseError> {
137    let mut builder = PatternBuilder::<M::L, M::V>::new(&ast.name);
138
139    // Track variables bound across stages
140    let mut bound_vars: HashSet<String> = HashSet::new();
141
142    // Pre-collect bindings from concurrent group siblings so scoping is
143    // symmetric within a group (order-independent).
144    let mut group_prebound: HashMap<usize, HashSet<String>> = HashMap::new();
145    for group in &ast.unordered_groups {
146        let mut group_vars = HashSet::new();
147        for &si in group {
148            if let Some(stage) = ast.stages.get(si) {
149                group_vars.insert(stage.anchor.clone());
150                for clause in &stage.clauses {
151                    if let ClauseTarget::Bind(ref var) = clause.target {
152                        group_vars.insert(var.clone());
153                    }
154                }
155            }
156        }
157        for &si in group {
158            group_prebound.insert(si, group_vars.clone());
159        }
160    }
161
162    for (stage_idx, stage) in ast.stages.iter().enumerate() {
163        let anchor = stage.anchor.clone();
164
165        // Stage anchor is implicitly in scope within its clauses.
166        // For concurrent group members, sibling bindings are also in scope.
167        let mut stage_scope = bound_vars.clone();
168        stage_scope.insert(anchor.clone());
169        if let Some(sibling_vars) = group_prebound.get(&stage_idx) {
170            stage_scope.extend(sibling_vars.iter().cloned());
171        }
172
173        // Validate ?var sources are bound
174        validate_clause_sources(&stage.clauses, &stage_scope)?;
175
176        // Collect new bindings from this stage's clause targets
177        for clause in &stage.clauses {
178            if let ClauseTarget::Bind(ref var) = clause.target {
179                if var == &anchor {
180                    return Err(ParseError {
181                        line: 0,
182                        column: 0,
183                        span: (0, 0),
184                        message: format!(
185                            "binding '-> ?{}' collides with stage anchor '{}'. \
186                             This silently constrains ?{} to self-loops only. \
187                             Use a different variable name.",
188                            var, anchor, var
189                        ),
190                    });
191                }
192                bound_vars.insert(var.clone());
193            }
194        }
195        // Stage anchor is bound for subsequent stages
196        bound_vars.insert(anchor.clone());
197
198        let clauses = stage.clauses.clone();
199        builder = builder.stage(&anchor, |s| build_stage(s, &clauses, mapper));
200    }
201
202    for neg in &ast.negations {
203        // Negation clauses can reference bound vars from completed stages
204        validate_clause_sources(&neg.clauses, &bound_vars)?;
205
206        // Reject unless_between where both anchors are in the same concurrent group
207        if let NegationKind::Between(start, end) = &neg.kind {
208            for group in &ast.unordered_groups {
209                let start_in = group
210                    .iter()
211                    .any(|&i| ast.stages.get(i).is_some_and(|s| s.anchor == *start));
212                let end_in = group
213                    .iter()
214                    .any(|&i| ast.stages.get(i).is_some_and(|s| s.anchor == *end));
215                if start_in && end_in {
216                    return Err(ParseError {
217                        line: 0,
218                        column: 0,
219                        span: (0, 0),
220                        message: format!(
221                            "unless_between anchors '{}' and '{}' are in the same concurrent group. \
222                             Temporal ordering between concurrent stages is undefined.",
223                            start, end
224                        ),
225                    });
226                }
227            }
228        }
229
230        let clauses = neg.clauses.clone();
231        builder = match &neg.kind {
232            NegationKind::Between(start, end) => {
233                builder.unless_between(start, end, |n| build_negation(n, &clauses, mapper))
234            }
235            NegationKind::After(start) => {
236                builder.unless_after(start, |n| build_negation(n, &clauses, mapper))
237            }
238            NegationKind::Global => builder.unless_global(|n| build_negation(n, &clauses, mapper)),
239        };
240    }
241
242    for temp in &ast.temporals {
243        let relation = parse_allen_relation(&temp.relation).map_err(|msg| ParseError {
244            line: 0,
245            column: 0,
246            span: (0, 0),
247            message: msg,
248        })?;
249        if temp.gap_min.is_some() || temp.gap_max.is_some() {
250            builder = builder.temporal_with_gap(
251                &temp.left,
252                relation,
253                &temp.right,
254                fabula::pattern::MetricGap {
255                    min: temp.gap_min,
256                    max: temp.gap_max,
257                },
258            );
259        } else {
260            builder = builder.temporal(&temp.left, relation, &temp.right);
261        }
262    }
263
264    // Convert ordered metadata pairs to HashMap (last-write-wins for duplicates)
265    for (key, value) in &ast.metadata {
266        builder = builder.metadata(key, value);
267    }
268
269    if let Some(deadline) = ast.deadline {
270        if deadline < 1.0 {
271            return Err(ParseError {
272                line: 0,
273                column: 0,
274                span: (0, 0),
275                message: format!("deadline must be a positive integer, got {}", deadline),
276            });
277        }
278        builder = builder.deadline(deadline as u64);
279    }
280
281    let mut pattern = builder.build();
282    pattern.unordered_groups = ast.unordered_groups.clone();
283    pattern.private = ast.private;
284    Ok(pattern)
285}
286
287/// Validate that all `?var` sources in clauses reference bound variables.
288/// Accumulates bindings from `-> ?var` targets clause-by-clause within the list.
289fn validate_clause_sources(
290    clauses: &[ClauseAst],
291    initial_scope: &HashSet<String>,
292) -> Result<(), ParseError> {
293    let mut scope = initial_scope.clone();
294    for clause in clauses {
295        if clause.source_kind == SourceKind::Var && !scope.contains(&clause.source) {
296            return Err(ParseError {
297                line: 0,
298                column: 0,
299                span: (0, 0),
300                message: format!(
301                    "variable '?{}' used as source but not yet bound. \
302                     Bind it with '-> ?{}' in a prior clause, or use '{}' \
303                     (without ?) for a literal node name.",
304                    clause.source, clause.source, clause.source
305                ),
306            });
307        }
308        // Validate ConstraintVar references are in scope
309        if let ClauseTarget::ConstraintVar(_, ref var) = clause.target {
310            if !scope.contains(var) {
311                return Err(ParseError {
312                    line: 0,
313                    column: 0,
314                    span: (0, 0),
315                    message: format!(
316                        "variable '?{}' used in constraint but not yet bound. \
317                         Bind it with '-> ?{}' in a prior clause or stage.",
318                        var, var
319                    ),
320                });
321            }
322        }
323        // Negation (!) is only valid with literal values and node references.
324        // Constraints and bindings cannot be negated.
325        if clause.negated {
326            match &clause.target {
327                ClauseTarget::Constraint(..) | ClauseTarget::ConstraintVar(..) => {
328                    return Err(ParseError {
329                        line: 0,
330                        column: 0,
331                        span: (0, 0),
332                        message: format!(
333                            "negated constraints ('! {}.{} < value') are not supported. \
334                             Rewrite as the inverse constraint \
335                             (e.g., '! x.v < 0.5' becomes 'x.v >= 0.5').",
336                            clause.source, clause.label
337                        ),
338                    });
339                }
340                ClauseTarget::Bind(var) => {
341                    return Err(ParseError {
342                        line: 0,
343                        column: 0,
344                        span: (0, 0),
345                        message: format!(
346                            "negated bindings ('! {}.{} -> ?{}') are not supported.",
347                            clause.source, clause.label, var
348                        ),
349                    });
350                }
351                _ => {} // Literals and NodeRefs can be negated
352            }
353        }
354        // Bind target for subsequent clauses
355        if let ClauseTarget::Bind(ref var) = clause.target {
356            scope.insert(var.clone());
357        }
358    }
359    Ok(())
360}
361
362// ---------------------------------------------------------------------------
363// Stage and clause compilation (generic over TypeMapper)
364// ---------------------------------------------------------------------------
365
366fn build_stage<M: TypeMapper>(
367    mut s: StageBuilder<M::L, M::V>,
368    clauses: &[ClauseAst],
369    mapper: &M,
370) -> StageBuilder<M::L, M::V> {
371    for clause in clauses {
372        s = add_clause_to_stage(s, clause, mapper);
373    }
374    s
375}
376
377fn add_clause_to_stage<M: TypeMapper>(
378    s: StageBuilder<M::L, M::V>,
379    clause: &ClauseAst,
380    mapper: &M,
381) -> StageBuilder<M::L, M::V> {
382    let source = &clause.source;
383    // Unwrap mapper results — validation errors in mapper are propagated
384    // at a higher level; within the builder callback we cannot return Result.
385    let label = mapper
386        .label(&clause.label)
387        .expect("label mapping failed in stage builder");
388
389    match &clause.target {
390        ClauseTarget::LiteralStr(val) => {
391            let v = mapper
392                .string_value(val)
393                .expect("string_value mapping failed");
394            if clause.negated {
395                s.not_edge(source, label, v)
396            } else {
397                s.edge(source, label, v)
398            }
399        }
400        ClauseTarget::LiteralNum(val) => {
401            let v = mapper.num_value(*val).expect("num_value mapping failed");
402            if clause.negated {
403                s.not_edge(source, label, v)
404            } else {
405                s.edge(source, label, v)
406            }
407        }
408        ClauseTarget::LiteralBool(val) => {
409            let v = mapper.bool_value(*val).expect("bool_value mapping failed");
410            if clause.negated {
411                s.not_edge(source, label, v)
412            } else {
413                s.edge(source, label, v)
414            }
415        }
416        ClauseTarget::Bind(var) => s.edge_bind(source, label, var),
417        ClauseTarget::NodeRef(node) => {
418            let v = mapper.node_ref(node).expect("node_ref mapping failed");
419            if clause.negated {
420                s.not_edge(source, label, v)
421            } else {
422                s.edge(source, label, v)
423            }
424        }
425        ClauseTarget::Constraint(op, val) => {
426            let constraint = make_constraint_with(mapper, *op, val);
427            s.edge_constrained(source, label, constraint)
428        }
429        ClauseTarget::ConstraintVar(op, var) => {
430            let constraint = make_var_constraint(*op, var);
431            s.edge_constrained(source, label, constraint)
432        }
433    }
434}
435
436fn build_negation<M: TypeMapper>(
437    mut n: NegationBuilder<M::L, M::V>,
438    clauses: &[ClauseAst],
439    mapper: &M,
440) -> NegationBuilder<M::L, M::V> {
441    for clause in clauses {
442        n = add_clause_to_negation(n, clause, mapper);
443    }
444    n
445}
446
447fn add_clause_to_negation<M: TypeMapper>(
448    n: NegationBuilder<M::L, M::V>,
449    clause: &ClauseAst,
450    mapper: &M,
451) -> NegationBuilder<M::L, M::V> {
452    let source = &clause.source;
453    let label = mapper
454        .label(&clause.label)
455        .expect("label mapping failed in negation builder");
456
457    match &clause.target {
458        ClauseTarget::LiteralStr(val) => {
459            let v = mapper
460                .string_value(val)
461                .expect("string_value mapping failed");
462            n.edge(source, label, v)
463        }
464        ClauseTarget::LiteralNum(val) => {
465            let v = mapper.num_value(*val).expect("num_value mapping failed");
466            n.edge(source, label, v)
467        }
468        ClauseTarget::LiteralBool(val) => {
469            let v = mapper.bool_value(*val).expect("bool_value mapping failed");
470            n.edge(source, label, v)
471        }
472        ClauseTarget::Bind(var) => n.edge_bind(source, label, var),
473        ClauseTarget::NodeRef(node) => {
474            let v = mapper.node_ref(node).expect("node_ref mapping failed");
475            n.edge(source, label, v)
476        }
477        ClauseTarget::Constraint(op, val) => {
478            let constraint = make_constraint_with(mapper, *op, val);
479            n.edge_constrained(source, label, constraint)
480        }
481        ClauseTarget::ConstraintVar(op, var) => {
482            let constraint = make_var_constraint(*op, var);
483            n.edge_constrained(source, label, constraint)
484        }
485    }
486}
487
488fn make_constraint_with<M: TypeMapper>(
489    mapper: &M,
490    op: ConstraintOp,
491    val: &ConstraintValue,
492) -> ValueConstraint<M::V> {
493    let v = match val {
494        ConstraintValue::Num(n) => mapper
495            .num_value(*n)
496            .expect("num_value mapping failed in constraint"),
497        ConstraintValue::Str(s) => mapper
498            .string_value(s)
499            .expect("string_value mapping failed in constraint"),
500    };
501    match op {
502        ConstraintOp::Eq => ValueConstraint::Eq(v),
503        ConstraintOp::Lt => ValueConstraint::Lt(v),
504        ConstraintOp::Gt => ValueConstraint::Gt(v),
505        ConstraintOp::Lte => ValueConstraint::Lte(v),
506        ConstraintOp::Gte => ValueConstraint::Gte(v),
507    }
508}
509
510fn make_var_constraint<V>(op: ConstraintOp, var: &str) -> ValueConstraint<V> {
511    match op {
512        ConstraintOp::Eq => ValueConstraint::EqVar(var.to_string()),
513        ConstraintOp::Lt => ValueConstraint::LtVar(var.to_string()),
514        ConstraintOp::Gt => ValueConstraint::GtVar(var.to_string()),
515        ConstraintOp::Lte => ValueConstraint::LteVar(var.to_string()),
516        ConstraintOp::Gte => ValueConstraint::GteVar(var.to_string()),
517    }
518}
519
520// ---------------------------------------------------------------------------
521// Compose compilation
522// ---------------------------------------------------------------------------
523
524/// Compile a compose directive using the default mapper.
525pub fn compile_compose(
526    ast: &ComposeAst,
527    known: &HashMap<String, Pattern<String, MemValue>>,
528) -> Result<Vec<Pattern<String, MemValue>>, ParseError> {
529    compile_compose_with(ast, known, &MemMapper)
530}
531
532/// Compile a compose directive using a custom [`TypeMapper`].
533///
534/// Resolves pattern names against already-compiled patterns in `known`.
535/// Returns one or more patterns (choice returns multiple).
536#[allow(clippy::type_complexity)]
537pub fn compile_compose_with<M: TypeMapper>(
538    ast: &ComposeAst,
539    known: &HashMap<String, Pattern<M::L, M::V>>,
540    _mapper: &M,
541) -> Result<Vec<Pattern<M::L, M::V>>, ParseError> {
542    let resolve = |name: &str| -> Result<&Pattern<M::L, M::V>, ParseError> {
543        known.get(name).ok_or_else(|| ParseError {
544            line: 0,
545            column: 0,
546            span: (0, 0),
547            message: format!(
548                "compose '{}' references pattern '{}' which has not been defined yet. \
549                 Define it before the compose directive.",
550                ast.name, name
551            ),
552        })
553    };
554
555    match &ast.body {
556        ComposeBody::Sequence {
557            left,
558            right,
559            shared,
560        } => {
561            let a = resolve(left)?;
562            let b = resolve(right)?;
563            let shared_refs: Vec<&str> = shared.iter().map(|s| s.as_str()).collect();
564            Ok(vec![compose::sequence(&ast.name, a, b, &shared_refs)])
565        }
566        ComposeBody::Choice {
567            alternatives,
568            exclusive,
569        } => {
570            let pats = alternatives
571                .iter()
572                .map(|name| resolve(name))
573                .collect::<Result<Vec<_>, _>>()?;
574            Ok(compose::choice(&ast.name, &pats, *exclusive))
575        }
576        ComposeBody::Repeat {
577            pattern,
578            min,
579            max,
580            shared,
581        } => {
582            let p = resolve(pattern)?;
583            let shared_refs: Vec<&str> = shared.iter().map(|s| s.as_str()).collect();
584            if *min < 1 {
585                return Err(ParseError {
586                    line: 0,
587                    column: 0,
588                    span: (0, 0),
589                    message: "repeat count must be at least 1".to_string(),
590                });
591            }
592            if let Some(max_val) = max {
593                if *max_val < *min {
594                    return Err(ParseError {
595                        line: 0,
596                        column: 0,
597                        span: (0, 0),
598                        message: format!("repeat max ({}) must be >= min ({})", max_val, min),
599                    });
600                }
601            }
602            // Exact repeat (min == max): use original unrolled repeat for backward compat
603            if *max == Some(*min) {
604                Ok(vec![compose::repeat(&ast.name, p, *min, &shared_refs)])
605            } else {
606                Ok(vec![compose::repeat_range(
607                    &ast.name,
608                    p,
609                    *min,
610                    *max,
611                    &shared_refs,
612                )])
613            }
614        }
615    }
616}
617
618// ---------------------------------------------------------------------------
619// Allen relation parsing
620// ---------------------------------------------------------------------------
621
622fn parse_allen_relation(s: &str) -> Result<AllenRelation, String> {
623    match s {
624        "before" => Ok(AllenRelation::Before),
625        "after" => Ok(AllenRelation::After),
626        "meets" => Ok(AllenRelation::Meets),
627        "met_by" => Ok(AllenRelation::MetBy),
628        "overlaps" => Ok(AllenRelation::Overlaps),
629        "overlapped_by" => Ok(AllenRelation::OverlappedBy),
630        "during" => Ok(AllenRelation::During),
631        "contains" => Ok(AllenRelation::Contains),
632        "starts" => Ok(AllenRelation::Starts),
633        "started_by" => Ok(AllenRelation::StartedBy),
634        "finishes" => Ok(AllenRelation::Finishes),
635        "finished_by" => Ok(AllenRelation::FinishedBy),
636        "equals" => Ok(AllenRelation::Equals),
637        _ => Err(format!("unknown Allen relation '{}'. Expected one of: before, after, meets, met_by, overlaps, overlapped_by, during, contains, starts, started_by, finishes, finished_by, equals", s)),
638    }
639}
640
641// ---------------------------------------------------------------------------
642// Graph compilation (always MemGraph — test-only)
643// ---------------------------------------------------------------------------
644
645/// Compile a graph AST into a `MemGraph`.
646pub fn compile_graph(ast: &GraphAst) -> fabula_memory::MemGraph {
647    let mut graph = fabula_memory::MemGraph::new();
648
649    for edge in &ast.edges {
650        match &edge.target {
651            EdgeTarget::Str(val) => {
652                if let Some(end) = edge.time_end {
653                    graph.add_edge_bounded(
654                        &edge.source,
655                        &edge.label,
656                        MemValue::Str(val.clone()),
657                        edge.time_start,
658                        end,
659                    );
660                } else {
661                    graph.add_str(&edge.source, &edge.label, val, edge.time_start);
662                }
663            }
664            EdgeTarget::Num(val) => {
665                if let Some(end) = edge.time_end {
666                    graph.add_edge_bounded(
667                        &edge.source,
668                        &edge.label,
669                        MemValue::Num(*val),
670                        edge.time_start,
671                        end,
672                    );
673                } else {
674                    graph.add_num(&edge.source, &edge.label, *val, edge.time_start);
675                }
676            }
677            EdgeTarget::Bool(val) => {
678                if let Some(end) = edge.time_end {
679                    graph.add_edge_bounded(
680                        &edge.source,
681                        &edge.label,
682                        MemValue::Bool(*val),
683                        edge.time_start,
684                        end,
685                    );
686                } else {
687                    graph.add_edge(
688                        &edge.source,
689                        &edge.label,
690                        MemValue::Bool(*val),
691                        edge.time_start,
692                    );
693                }
694            }
695            EdgeTarget::NodeRef(node) => {
696                if let Some(end) = edge.time_end {
697                    graph.add_edge_bounded(
698                        &edge.source,
699                        &edge.label,
700                        MemValue::Node(node.clone()),
701                        edge.time_start,
702                        end,
703                    );
704                } else {
705                    graph.add_ref(&edge.source, &edge.label, node, edge.time_start);
706                }
707            }
708        }
709    }
710
711    if let Some(t) = ast.now {
712        graph.set_time(t);
713    }
714
715    graph
716}
717
718// ---------------------------------------------------------------------------
719// Tests
720// ---------------------------------------------------------------------------
721
722#[cfg(test)]
723mod tests {
724    use super::*;
725    use crate::lexer::Lexer;
726    use crate::parser::Parser;
727
728    fn parse_ast(input: &str) -> PatternAst {
729        let tokens = Lexer::new(input).tokenize().unwrap();
730        let mut parser = Parser::new(tokens);
731        parser.parse_pattern_only().unwrap()
732    }
733
734    #[test]
735    fn mem_mapper_matches_existing_behavior() {
736        let input = r#"pattern test {
737            stage e1 {
738                e1.eventType = "betray"
739                e1.actor -> ?char
740            }
741        }"#;
742        let ast = parse_ast(input);
743        let direct = compile_pattern(&ast).unwrap();
744        let via_mapper = compile_pattern_with(&ast, &MemMapper).unwrap();
745        assert_eq!(direct, via_mapper);
746    }
747
748    /// A custom mapper that uppercases labels and wraps values.
749    #[derive(Debug, Clone)]
750    enum UpperValue {
751        Text(String),
752        Number(f64),
753        Flag(bool),
754        Ref(String),
755    }
756
757    struct UpperMapper;
758
759    impl TypeMapper for UpperMapper {
760        type L = String;
761        type V = UpperValue;
762
763        fn label(&self, s: &str) -> Result<String, String> {
764            Ok(s.to_uppercase())
765        }
766        fn string_value(&self, s: &str) -> Result<UpperValue, String> {
767            Ok(UpperValue::Text(s.to_uppercase()))
768        }
769        fn num_value(&self, n: f64) -> Result<UpperValue, String> {
770            Ok(UpperValue::Number(n))
771        }
772        fn bool_value(&self, b: bool) -> Result<UpperValue, String> {
773            Ok(UpperValue::Flag(b))
774        }
775        fn node_ref(&self, name: &str) -> Result<UpperValue, String> {
776            Ok(UpperValue::Ref(name.to_uppercase()))
777        }
778    }
779
780    #[test]
781    fn custom_mapper_transforms_labels() {
782        let input = r#"pattern test {
783            stage e1 {
784                e1.eventType = "betray"
785                e1.actor -> ?char
786            }
787        }"#;
788        let ast = parse_ast(input);
789        let pattern = compile_pattern_with(&ast, &UpperMapper).unwrap();
790        assert_eq!(pattern.stages[0].clauses[0].label, "EVENTTYPE");
791        assert_eq!(pattern.stages[0].clauses[1].label, "ACTOR");
792    }
793
794    #[test]
795    fn custom_mapper_transforms_values() {
796        let input = r#"pattern test {
797            stage e1 {
798                e1.eventType = "betray"
799                e1.score > 5
800            }
801        }"#;
802        let ast = parse_ast(input);
803        let pattern = compile_pattern_with(&ast, &UpperMapper).unwrap();
804        // String literal uppercased
805        match &pattern.stages[0].clauses[0].target {
806            fabula::pattern::Target::Literal(UpperValue::Text(s)) => assert_eq!(s, "BETRAY"),
807            other => panic!("expected Text, got {:?}", other),
808        }
809        // Constraint value mapped through num_value
810        match &pattern.stages[0].clauses[1].target {
811            fabula::pattern::Target::Constraint(ValueConstraint::Gt(UpperValue::Number(n))) => {
812                assert_eq!(*n, 5.0);
813            }
814            other => panic!("expected Gt(Number), got {:?}", other),
815        }
816    }
817
818    /// A mapper that rejects unknown labels.
819    struct StrictMapper;
820
821    impl TypeMapper for StrictMapper {
822        type L = u32;
823        type V = String;
824
825        fn label(&self, s: &str) -> Result<u32, String> {
826            match s {
827                "eventType" => Ok(1),
828                "actor" => Ok(2),
829                _ => Err(format!("unknown predicate '{}'", s)),
830            }
831        }
832        fn string_value(&self, s: &str) -> Result<String, String> {
833            Ok(s.to_string())
834        }
835        fn num_value(&self, n: f64) -> Result<String, String> {
836            Ok(n.to_string())
837        }
838        fn bool_value(&self, b: bool) -> Result<String, String> {
839            Ok(b.to_string())
840        }
841        fn node_ref(&self, name: &str) -> Result<String, String> {
842            Ok(name.to_string())
843        }
844    }
845
846    #[test]
847    fn strict_mapper_succeeds_with_known_labels() {
848        let input = r#"pattern test {
849            stage e1 {
850                e1.eventType = "betray"
851                e1.actor -> ?char
852            }
853        }"#;
854        let ast = parse_ast(input);
855        let pattern = compile_pattern_with(&ast, &StrictMapper).unwrap();
856        assert_eq!(pattern.stages[0].clauses[0].label, 1u32);
857        assert_eq!(pattern.stages[0].clauses[1].label, 2u32);
858    }
859
860    #[test]
861    #[should_panic(expected = "unknown predicate 'badLabel'")]
862    fn strict_mapper_panics_on_unknown_label() {
863        let input = r#"pattern test {
864            stage e1 {
865                e1.badLabel = "value"
866            }
867        }"#;
868        let ast = parse_ast(input);
869        // The mapper error propagates as a panic from within the builder callback
870        let _ = compile_pattern_with(&ast, &StrictMapper);
871    }
872
873    #[test]
874    fn metadata_parsed_and_compiled() {
875        let input = r#"pattern my_rule {
876            meta("severity", "high")
877            meta("mitre", "T1078")
878            stage e1 {
879                e1.eventType = "betray"
880            }
881        }"#;
882        let ast = parse_ast(input);
883        assert_eq!(ast.metadata.len(), 2);
884        assert_eq!(
885            ast.metadata[0],
886            ("severity".to_string(), "high".to_string())
887        );
888        assert_eq!(ast.metadata[1], ("mitre".to_string(), "T1078".to_string()));
889
890        let pattern = compile_pattern(&ast).unwrap();
891        assert_eq!(pattern.metadata.get("severity").unwrap(), "high");
892        assert_eq!(pattern.metadata.get("mitre").unwrap(), "T1078");
893    }
894
895    #[test]
896    fn metadata_after_stages() {
897        let input = r#"pattern test {
898            stage e1 { e1.type = "x" }
899            meta("key", "val")
900        }"#;
901        let ast = parse_ast(input);
902        let pattern = compile_pattern(&ast).unwrap();
903        assert_eq!(pattern.metadata.get("key").unwrap(), "val");
904        assert_eq!(pattern.stages.len(), 1);
905    }
906
907    #[test]
908    fn metadata_duplicate_key_last_wins() {
909        let input = r#"pattern test {
910            meta("key", "first")
911            meta("key", "second")
912            stage e1 { e1.type = "x" }
913        }"#;
914        let ast = parse_ast(input);
915        assert_eq!(ast.metadata.len(), 2); // AST preserves both
916
917        let pattern = compile_pattern(&ast).unwrap();
918        assert_eq!(pattern.metadata.get("key").unwrap(), "second"); // last wins
919        assert_eq!(pattern.metadata.len(), 1);
920    }
921
922    #[test]
923    fn compile_pattern_body_with_metadata() {
924        let input = r#"pattern wrapper {
925            meta("source", "test")
926            stage e1 { e1.type = "x" }
927        }"#;
928        let tokens = Lexer::new(input).tokenize().unwrap();
929        let mut parser = Parser::new(tokens);
930        parser.expect(crate::lexer::TokenKind::Pattern).unwrap();
931        let _name = parser.expect_ident().unwrap();
932        parser.expect(crate::lexer::TokenKind::LBrace).unwrap();
933        let body = parser.parse_pattern_body().unwrap();
934
935        assert_eq!(body.metadata.len(), 1);
936
937        let pattern = compile_pattern_body("renamed", &body).unwrap();
938        assert_eq!(pattern.name, "renamed");
939        assert_eq!(pattern.metadata.get("source").unwrap(), "test");
940    }
941
942    #[test]
943    fn deadline_parsed_and_compiled() {
944        let input = r#"pattern sla {
945            deadline 2880
946            stage e1 { e1.type = "submit" }
947        }"#;
948        let ast = parse_ast(input);
949        assert_eq!(ast.deadline, Some(2880.0));
950
951        let pattern = compile_pattern(&ast).unwrap();
952        assert_eq!(pattern.deadline_ticks, Some(2880));
953    }
954
955    #[test]
956    fn no_deadline_is_none() {
957        let input = r#"pattern test {
958            stage e1 { e1.type = "x" }
959        }"#;
960        let ast = parse_ast(input);
961        assert_eq!(ast.deadline, None);
962
963        let pattern = compile_pattern(&ast).unwrap();
964        assert_eq!(pattern.deadline_ticks, None);
965    }
966
967    #[test]
968    fn deadline_with_metadata() {
969        let input = r#"pattern sla {
970            meta("severity", "high")
971            deadline 100
972            stage e1 { e1.type = "x" }
973        }"#;
974        let pattern = compile_pattern(&parse_ast(input)).unwrap();
975        assert_eq!(pattern.deadline_ticks, Some(100));
976        assert_eq!(pattern.metadata.get("severity").unwrap(), "high");
977    }
978
979    #[test]
980    fn deadline_zero_rejected() {
981        let input = r#"pattern bad {
982            deadline 0
983            stage e1 { e1.type = "x" }
984        }"#;
985        let result = compile_pattern(&parse_ast(input));
986        assert!(result.is_err());
987        assert!(result.unwrap_err().message.contains("positive integer"));
988    }
989
990    // -----------------------------------------------------------------------
991    // Cross-stage value comparison (ConstraintVar)
992    // -----------------------------------------------------------------------
993
994    #[test]
995    fn constraint_var_gt_parsed_and_compiled() {
996        let input = r#"pattern escalation {
997            stage e1 {
998                e1.type = "order"
999                e1.price -> ?base_price
1000            }
1001            stage e2 {
1002                e2.type = "order"
1003                e2.price > ?base_price
1004            }
1005        }"#;
1006        let ast = parse_ast(input);
1007        assert!(matches!(
1008            &ast.stages[1].clauses[1].target,
1009            ClauseTarget::ConstraintVar(ConstraintOp::Gt, var) if var == "base_price"
1010        ));
1011
1012        let pattern = compile_pattern(&ast).unwrap();
1013        match &pattern.stages[1].clauses[1].target {
1014            fabula::pattern::Target::Constraint(ValueConstraint::GtVar(v)) => {
1015                assert_eq!(v, "base_price");
1016            }
1017            other => panic!("expected GtVar, got {:?}", other),
1018        }
1019    }
1020
1021    #[test]
1022    fn constraint_var_all_operators() {
1023        for (op_str, expected_op) in [
1024            ("<", ConstraintOp::Lt),
1025            (">", ConstraintOp::Gt),
1026            ("<=", ConstraintOp::Lte),
1027            (">=", ConstraintOp::Gte),
1028            ("=", ConstraintOp::Eq),
1029        ] {
1030            let input = format!(
1031                r#"pattern test {{
1032                    stage e1 {{ e1.val -> ?v }}
1033                    stage e2 {{ e2.val {} ?v }}
1034                }}"#,
1035                op_str
1036            );
1037            let ast = parse_ast(&input);
1038            assert!(
1039                matches!(
1040                    &ast.stages[1].clauses[0].target,
1041                    ClauseTarget::ConstraintVar(op, var) if *op == expected_op && var == "v"
1042                ),
1043                "failed for operator {}",
1044                op_str
1045            );
1046        }
1047    }
1048
1049    #[test]
1050    fn constraint_var_unbound_rejected() {
1051        let input = r#"pattern bad {
1052            stage e1 {
1053                e1.type = "x"
1054                e1.score > ?unbound
1055            }
1056        }"#;
1057        let result = compile_pattern(&parse_ast(input));
1058        assert!(result.is_err());
1059        assert!(result.unwrap_err().message.contains("not yet bound"));
1060    }
1061
1062    #[test]
1063    fn constraint_var_negated_rejected() {
1064        let input = r#"pattern bad {
1065            stage e1 { e1.val -> ?v }
1066            stage e2 { ! e2.val > ?v }
1067        }"#;
1068        let result = compile_pattern(&parse_ast(input));
1069        assert!(result.is_err());
1070        assert!(result.unwrap_err().message.contains("negated constraints"));
1071    }
1072
1073    // -----------------------------------------------------------------------
1074    // Concurrent (unordered) groups
1075    // -----------------------------------------------------------------------
1076
1077    #[test]
1078    fn concurrent_parsed_and_compiled() {
1079        let input = r#"pattern test {
1080            stage setup { setup.type = "start" }
1081            concurrent {
1082                stage a { a.type = "alpha" }
1083                stage b { b.type = "beta" }
1084            }
1085            stage end { end.type = "finish" }
1086        }"#;
1087        let ast = parse_ast(input);
1088        assert_eq!(ast.stages.len(), 4); // setup, a, b, end
1089        assert_eq!(ast.unordered_groups.len(), 1);
1090        assert_eq!(ast.unordered_groups[0], vec![1, 2]); // a and b
1091
1092        let pattern = compile_pattern(&ast).unwrap();
1093        assert_eq!(pattern.stages.len(), 4);
1094        assert_eq!(pattern.unordered_groups.len(), 1);
1095        assert_eq!(pattern.unordered_groups[0], vec![1, 2]);
1096    }
1097
1098    #[test]
1099    fn concurrent_only_group() {
1100        let input = r#"pattern test {
1101            concurrent {
1102                stage a { a.type = "alpha" }
1103                stage b { b.type = "beta" }
1104            }
1105        }"#;
1106        let ast = parse_ast(input);
1107        assert_eq!(ast.stages.len(), 2);
1108        assert_eq!(ast.unordered_groups, vec![vec![0, 1]]);
1109    }
1110
1111    #[test]
1112    fn concurrent_multiple_groups() {
1113        let input = r#"pattern test {
1114            concurrent {
1115                stage a { a.type = "alpha" }
1116                stage b { b.type = "beta" }
1117            }
1118            stage mid { mid.type = "mid" }
1119            concurrent {
1120                stage c { c.type = "gamma" }
1121                stage d { d.type = "delta" }
1122            }
1123        }"#;
1124        let ast = parse_ast(input);
1125        assert_eq!(ast.stages.len(), 5);
1126        assert_eq!(ast.unordered_groups.len(), 2);
1127        assert_eq!(ast.unordered_groups[0], vec![0, 1]);
1128        assert_eq!(ast.unordered_groups[1], vec![3, 4]);
1129    }
1130
1131    #[test]
1132    fn concurrent_unless_between_same_group_rejected() {
1133        let input = r#"pattern bad {
1134            concurrent {
1135                stage a { a.type = "alpha" }
1136                stage b { b.type = "beta" }
1137            }
1138            unless between a b {
1139                mid.type = "block"
1140            }
1141        }"#;
1142        let result = compile_pattern(&parse_ast(input));
1143        assert!(result.is_err());
1144        assert!(result
1145            .unwrap_err()
1146            .message
1147            .contains("same concurrent group"));
1148    }
1149
1150    #[test]
1151    fn concurrent_unless_between_different_groups_ok() {
1152        let input = r#"pattern ok {
1153            stage setup { setup.type = "start" }
1154            concurrent {
1155                stage a { a.type = "alpha" }
1156                stage b { b.type = "beta" }
1157            }
1158            unless between setup a {
1159                mid.type = "block"
1160            }
1161        }"#;
1162        let result = compile_pattern(&parse_ast(input));
1163        assert!(result.is_ok());
1164    }
1165
1166    #[test]
1167    fn concurrent_dsl_evaluate() {
1168        let doc = crate::parse_document(
1169            r#"
1170            pattern test {
1171                concurrent {
1172                    stage a { a.type = "alpha" }
1173                    stage b { b.type = "beta" }
1174                }
1175            }
1176
1177            graph {
1178                @1 ev1.type = "beta"
1179                @2 ev2.type = "alpha"
1180                now = 10
1181            }
1182            "#,
1183        )
1184        .unwrap();
1185
1186        assert_eq!(doc.patterns[0].unordered_groups, vec![vec![0, 1]]);
1187
1188        let mut engine = fabula::engine::SiftEngine::<String, String, MemValue, i64>::new();
1189        engine.register(doc.patterns[0].clone());
1190        let matches = engine.evaluate(&doc.graphs[0]);
1191        assert_eq!(matches.len(), 1);
1192    }
1193}