lemma/
engine.rs

1use crate::evaluation::Evaluator;
2use crate::parsing::ast::Span;
3use crate::planning::plan;
4use crate::{parse, LemmaDoc, LemmaError, LemmaResult, ResourceLimits, Response};
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8/// Engine for evaluating Lemma rules
9///
10/// Pure Rust implementation that evaluates Lemma docs directly from the AST.
11/// Uses pre-built execution plans that are self-contained and ready for evaluation.
12pub struct Engine {
13    execution_plans: HashMap<String, crate::planning::ExecutionPlan>,
14    documents: HashMap<String, LemmaDoc>,
15    sources: HashMap<String, String>,
16    evaluator: Evaluator,
17    limits: ResourceLimits,
18}
19
20impl Default for Engine {
21    fn default() -> Self {
22        Self {
23            execution_plans: HashMap::new(),
24            documents: HashMap::new(),
25            sources: HashMap::new(),
26            evaluator: Evaluator,
27            limits: ResourceLimits::default(),
28        }
29    }
30}
31
32impl Engine {
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Create an engine with custom resource limits
38    pub fn with_limits(limits: ResourceLimits) -> Self {
39        Self {
40            execution_plans: HashMap::new(),
41            documents: HashMap::new(),
42            sources: HashMap::new(),
43            evaluator: Evaluator,
44            limits,
45        }
46    }
47
48    pub fn add_lemma_code(&mut self, lemma_code: &str, source: &str) -> LemmaResult<()> {
49        let new_docs = parse(lemma_code, source, &self.limits)?;
50
51        for doc in &new_docs {
52            let attribute = doc.attribute.clone().unwrap_or_else(|| doc.name.clone());
53            self.sources.insert(attribute, lemma_code.to_owned());
54            self.documents.insert(doc.name.clone(), doc.clone());
55        }
56
57        // Collect all documents (existing + new)
58        let all_docs: Vec<LemmaDoc> = self.documents.values().cloned().collect();
59
60        // Build execution plans for all new documents
61        for doc in &new_docs {
62            let execution_plan = plan(doc, &all_docs, self.sources.clone()).map_err(|errs| {
63                if errs.is_empty() {
64                    use crate::parsing::ast::Span;
65                    let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
66                    let source_text = self
67                        .sources
68                        .get(attribute)
69                        .map(|s| s.as_str())
70                        .unwrap_or("");
71                    LemmaError::engine(
72                        format!("Failed to build execution plan for document: {}", doc.name),
73                        Span {
74                            start: 0,
75                            end: 0,
76                            line: doc.start_line,
77                            col: 0,
78                        },
79                        attribute,
80                        std::sync::Arc::from(source_text),
81                        doc.name.clone(),
82                        doc.start_line,
83                        None::<String>,
84                    )
85                } else {
86                    errs.into_iter().next().unwrap_or_else(|| {
87                        use crate::parsing::ast::Span;
88                        let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
89                        let source_text = self
90                            .sources
91                            .get(attribute)
92                            .map(|s| s.as_str())
93                            .unwrap_or("");
94                        LemmaError::engine(
95                            format!("Failed to build execution plan for document: {}", doc.name),
96                            Span {
97                                start: 0,
98                                end: 0,
99                                line: doc.start_line,
100                                col: 0,
101                            },
102                            attribute,
103                            std::sync::Arc::from(source_text),
104                            doc.name.clone(),
105                            doc.start_line,
106                            None::<String>,
107                        )
108                    })
109                }
110            })?;
111
112            self.execution_plans
113                .insert(doc.name.clone(), execution_plan);
114        }
115
116        Ok(())
117    }
118
119    pub fn remove_document(&mut self, doc_name: &str) {
120        self.execution_plans.remove(doc_name);
121        self.documents.remove(doc_name);
122    }
123
124    pub fn list_documents(&self) -> Vec<String> {
125        self.documents.keys().cloned().collect()
126    }
127
128    pub fn get_document(&self, doc_name: &str) -> Option<&LemmaDoc> {
129        self.documents.get(doc_name)
130    }
131
132    pub fn get_document_facts(&self, doc_name: &str) -> Vec<&crate::LemmaFact> {
133        if let Some(doc) = self.documents.get(doc_name) {
134            doc.facts.iter().collect()
135        } else {
136            Vec::new()
137        }
138    }
139
140    pub fn get_document_rules(&self, doc_name: &str) -> Vec<&crate::LemmaRule> {
141        if let Some(doc) = self.documents.get(doc_name) {
142            doc.rules.iter().collect()
143        } else {
144            Vec::new()
145        }
146    }
147
148    /// Get the resolved schema type for a fact in a document (including imported/custom types).
149    ///
150    /// This uses the document's compiled execution plan, so it reflects the authoritative schema
151    /// after type resolution and overrides.
152    pub fn get_fact_schema_type(
153        &self,
154        doc_name: &str,
155        fact_name: &str,
156    ) -> Option<crate::LemmaType> {
157        let plan = self.execution_plans.get(doc_name)?;
158        let fact_path = plan.get_fact_path_by_str(fact_name)?;
159        plan.fact_schema.get(fact_path).cloned()
160    }
161
162    /// Get the set of fact names required to evaluate a document's rules.
163    ///
164    /// - If `rule_names` is empty, returns required facts for **all** rules in the document.
165    /// - Otherwise, returns required facts for the specified local rules (by name).
166    ///
167    /// Returned names match `FactReference::to_string()` / `FactPath::to_string()` (e.g. "age",
168    /// "order.price", etc.), so they can be used by UIs to decide what to prompt for.
169    pub fn get_required_fact_names(
170        &self,
171        doc_name: &str,
172        rule_names: &[String],
173    ) -> Option<HashSet<String>> {
174        let plan = self.execution_plans.get(doc_name)?;
175        let mut required: HashSet<String> = HashSet::new();
176
177        if rule_names.is_empty() {
178            for rule in &plan.rules {
179                for fact in &rule.needs_facts {
180                    required.insert(fact.to_string());
181                }
182            }
183            return Some(required);
184        }
185
186        for rule_name in rule_names {
187            let rule = plan.get_rule(rule_name)?;
188            for fact in &rule.needs_facts {
189                required.insert(fact.to_string());
190            }
191        }
192
193        Some(required)
194    }
195
196    /// Evaluate rules in a document with JSON values for facts.
197    ///
198    /// This is a convenience method that accepts JSON directly and converts it
199    /// to typed values using the document's fact type declarations.
200    ///
201    /// If `rule_names` is empty, evaluates all rules.
202    /// Otherwise, only returns results for the specified rules (dependencies still computed).
203    ///
204    /// Values are provided as JSON bytes (e.g., `b"{\"quantity\": 5, \"is_member\": true}"`).
205    /// They are automatically parsed to the expected type based on the document schema.
206    pub fn evaluate_json(
207        &self,
208        doc_name: &str,
209        rule_names: Vec<String>,
210        json: &[u8],
211    ) -> LemmaResult<Response> {
212        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
213            LemmaError::engine(
214                format!("Document '{}' not found", doc_name),
215                Span {
216                    start: 0,
217                    end: 0,
218                    line: 1,
219                    col: 0,
220                },
221                "<unknown>",
222                Arc::from(""),
223                "<unknown>",
224                1,
225                None::<String>,
226            )
227        })?;
228
229        let values = crate::serialization::from_json(json, base_plan)?;
230
231        self.evaluate_strict(doc_name, rule_names, values)
232    }
233
234    /// Evaluate rules in a document with string values for facts.
235    ///
236    /// This is the user-friendly API that accepts raw string values and parses them
237    /// to the appropriate types based on the document's fact type declarations.
238    /// Use this for CLI, HTTP APIs, and other user-facing interfaces.
239    ///
240    /// If `rule_names` is empty, evaluates all rules.
241    /// Otherwise, only returns results for the specified rules (dependencies still computed).
242    ///
243    /// Values are provided as name -> value string pairs (e.g., "type" -> "latte").
244    /// They are automatically parsed to the expected type based on the document schema.
245    pub fn evaluate(
246        &self,
247        doc_name: &str,
248        rule_names: Vec<String>,
249        values: HashMap<String, String>,
250    ) -> LemmaResult<Response> {
251        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
252            LemmaError::engine(
253                format!("Document '{}' not found", doc_name),
254                Span {
255                    start: 0,
256                    end: 0,
257                    line: 1,
258                    col: 0,
259                },
260                "<unknown>",
261                Arc::from(""),
262                "<unknown>",
263                1,
264                None::<String>,
265            )
266        })?;
267
268        let plan = base_plan.clone().with_values(values, &self.limits)?;
269
270        self.evaluate_plan(plan, rule_names)
271    }
272
273    /// Evaluate rules in a document with typed values for facts.
274    ///
275    /// This is the strict API that accepts pre-typed LiteralValue values.
276    /// Use this for programmatic APIs, protobuf, msgpack, FFI, and other
277    /// strongly-typed interfaces where values are already parsed.
278    ///
279    /// If `rule_names` is empty, evaluates all rules.
280    /// Otherwise, only returns results for the specified rules (dependencies still computed).
281    ///
282    /// Values are provided as name -> LiteralValue pairs (e.g., "age" -> Number(25)).
283    pub fn evaluate_strict(
284        &self,
285        doc_name: &str,
286        rule_names: Vec<String>,
287        values: HashMap<String, crate::LiteralValue>,
288    ) -> LemmaResult<Response> {
289        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
290            LemmaError::engine(
291                format!("Document '{}' not found", doc_name),
292                Span {
293                    start: 0,
294                    end: 0,
295                    line: 1,
296                    col: 0,
297                },
298                "<unknown>",
299                Arc::from(""),
300                "<unknown>",
301                1,
302                None::<String>,
303            )
304        })?;
305
306        let plan = base_plan.clone().with_typed_values(values, &self.limits)?;
307
308        self.evaluate_plan(plan, rule_names)
309    }
310
311    /// Invert a rule to find input domains that produce a desired outcome with JSON values.
312    ///
313    /// This is a convenience method that accepts JSON directly and converts it
314    /// to typed values using the document's fact type declarations.
315    ///
316    /// Returns an InversionResponse containing:
317    /// - `solutions`: Concrete domain constraints for each free variable
318    /// - `undetermined_facts`: Facts that are not fully determined
319    /// - `is_determined`: Whether all facts have concrete values
320    ///
321    /// Values are provided as JSON bytes (e.g., `b"{\"quantity\": 5, \"is_member\": true}"`).
322    /// They are automatically parsed to the expected type based on the document schema.
323    pub fn invert_json(
324        &self,
325        doc_name: &str,
326        rule_name: &str,
327        target: crate::inversion::Target,
328        json: &[u8],
329    ) -> LemmaResult<crate::InversionResponse> {
330        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
331            LemmaError::engine(
332                format!("Document '{}' not found", doc_name),
333                Span {
334                    start: 0,
335                    end: 0,
336                    line: 1,
337                    col: 0,
338                },
339                "<unknown>",
340                Arc::from(""),
341                "<unknown>",
342                1,
343                None::<String>,
344            )
345        })?;
346
347        let values = crate::serialization::from_json(json, base_plan)?;
348
349        self.invert_strict(doc_name, rule_name, target, values)
350    }
351
352    /// Invert a rule to find input domains that produce a desired outcome.
353    ///
354    /// This is the user-friendly API that accepts raw string values and parses them
355    /// to the appropriate types based on the document's fact type declarations.
356    ///
357    /// Returns an InversionResponse containing:
358    /// - `solutions`: Concrete domain constraints for each free variable
359    /// - `undetermined_facts`: Facts that are not fully determined
360    /// - `is_determined`: Whether all facts have concrete values
361    ///
362    /// Values are provided as name -> value string pairs (e.g., "quantity" -> "5").
363    /// They are automatically parsed to the expected type based on the document schema.
364    pub fn invert(
365        &self,
366        doc_name: &str,
367        rule_name: &str,
368        target: crate::inversion::Target,
369        values: HashMap<String, String>,
370    ) -> LemmaResult<crate::InversionResponse> {
371        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
372            LemmaError::engine(
373                format!("Document '{}' not found", doc_name),
374                Span {
375                    start: 0,
376                    end: 0,
377                    line: 1,
378                    col: 0,
379                },
380                "<unknown>",
381                Arc::from(""),
382                "<unknown>",
383                1,
384                None::<String>,
385            )
386        })?;
387
388        let plan = base_plan.clone().with_values(values, &self.limits)?;
389
390        // Collect provided fact paths
391        let provided_facts = plan.fact_values.keys().cloned().collect();
392
393        crate::inversion::invert(rule_name, target, &plan, &provided_facts)
394    }
395
396    /// Invert a rule to find input domains that produce a desired outcome.
397    ///
398    /// This is the strict API that accepts pre-typed LiteralValue values.
399    /// Use this for programmatic APIs, protobuf, msgpack, FFI, and other
400    /// strongly-typed interfaces where values are already parsed.
401    ///
402    /// Returns an InversionResponse containing:
403    /// - `solutions`: Concrete domain constraints for each free variable
404    /// - `undetermined_facts`: Facts that are not fully determined
405    /// - `is_determined`: Whether all facts have concrete values
406    ///
407    /// Values are provided as name -> LiteralValue pairs (e.g., "quantity" -> Number(5)).
408    pub fn invert_strict(
409        &self,
410        doc_name: &str,
411        rule_name: &str,
412        target: crate::inversion::Target,
413        values: HashMap<String, crate::LiteralValue>,
414    ) -> LemmaResult<crate::InversionResponse> {
415        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
416            LemmaError::engine(
417                format!("Document '{}' not found", doc_name),
418                Span {
419                    start: 0,
420                    end: 0,
421                    line: 1,
422                    col: 0,
423                },
424                "<unknown>",
425                Arc::from(""),
426                "<unknown>",
427                1,
428                None::<String>,
429            )
430        })?;
431
432        let plan = base_plan.clone().with_typed_values(values, &self.limits)?;
433
434        // Collect provided fact paths
435        let provided_facts = plan.fact_values.keys().cloned().collect();
436
437        crate::inversion::invert(rule_name, target, &plan, &provided_facts)
438    }
439
440    fn evaluate_plan(
441        &self,
442        plan: crate::planning::ExecutionPlan,
443        rule_names: Vec<String>,
444    ) -> LemmaResult<Response> {
445        let mut response = self.evaluator.evaluate(&plan)?;
446
447        if !rule_names.is_empty() {
448            response.filter_rules(&rule_names);
449        }
450
451        Ok(response)
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    /// Build a fresh engine with `code` loaded under the pseudo-file name `file`.
    fn engine_with(code: &str, file: &str) -> Engine {
        let mut engine = Engine::new();
        engine.add_lemma_code(code, file).unwrap();
        engine
    }

    /// Evaluate every rule of `doc` without supplying any fact values.
    fn evaluate_all(engine: &Engine, doc: &str) -> Response {
        engine.evaluate(doc, vec![], HashMap::new()).unwrap()
    }

    /// The result expected from a rule that yields the decimal number `text`.
    fn number(text: &str) -> crate::OperationResult {
        crate::OperationResult::Value(crate::LiteralValue::number(
            Decimal::from_str(text).unwrap(),
        ))
    }

    /// Result of the rule called `rule`; panics if the response lacks it.
    fn result_named<'a>(response: &'a Response, rule: &str) -> &'a crate::OperationResult {
        &response
            .results
            .values()
            .find(|r| r.rule.name == rule)
            .unwrap()
            .result
    }

    /// Result of the first rule in the response; panics if it is empty.
    fn first_result(response: &Response) -> &crate::OperationResult {
        &response.results.values().next().unwrap().result
    }

    #[test]
    fn test_evaluate_document_all_rules() {
        let engine = engine_with(
            r#"
        doc test
        fact x = 10
        fact y = 5
        rule sum = x + y
        rule product = x * y
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 2);
        assert_eq!(*result_named(&response, "sum"), number("15"));
        assert_eq!(*result_named(&response, "product"), number("50"));
    }

    #[test]
    fn test_evaluate_empty_facts() {
        let engine = engine_with(
            r#"
        doc test
        fact price = 100
        rule total = price * 2
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 1);
        assert_eq!(*first_result(&response), number("200"));
    }

    #[test]
    fn test_evaluate_boolean_rule() {
        let engine = engine_with(
            r#"
        doc test
        fact age = 25
        rule is_adult = age >= 18
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        let expected =
            crate::OperationResult::Value(crate::LiteralValue::boolean(crate::BooleanValue::True));
        assert_eq!(*first_result(&response), expected);
    }

    #[test]
    fn test_evaluate_with_unless_clause() {
        let engine = engine_with(
            r#"
        doc test
        fact quantity = 15
        rule discount = 0
          unless quantity >= 10 then 10
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(*first_result(&response), number("10"));
    }

    #[test]
    fn test_document_not_found() {
        let engine = Engine::new();
        let outcome = engine.evaluate("nonexistent", vec![], HashMap::new());
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_multiple_documents() {
        let mut engine = Engine::new();
        let documents = [
            (
                r#"
        doc doc1
        fact x = 10
        rule result = x * 2
    "#,
                "doc1.lemma",
            ),
            (
                r#"
        doc doc2
        fact y = 5
        rule result = y * 3
    "#,
                "doc2.lemma",
            ),
        ];
        for (code, file) in documents {
            engine.add_lemma_code(code, file).unwrap();
        }

        for (doc, expected) in [("doc1", "20"), ("doc2", "15")] {
            let response = evaluate_all(&engine, doc);
            assert_eq!(response.results[0].result, number(expected));
        }
    }

    #[test]
    fn test_runtime_error_mapping() {
        let engine = engine_with(
            r#"
        doc test
        fact numerator = 10
        fact denominator = 0
        rule division = numerator / denominator
    "#,
            "test.lemma",
        );

        // Division by zero is reported as a Veto result, not as an evaluation error.
        let outcome = engine.evaluate("test", vec![], HashMap::new());
        assert!(outcome.is_ok(), "Evaluation should succeed");
        let response = outcome.unwrap();

        let division = response
            .results
            .values()
            .find(|r| r.rule.name == "division");
        assert!(division.is_some(), "Should have division rule result");

        match &division.unwrap().result {
            crate::OperationResult::Veto(message) => {
                let mentions_zero = matches!(
                    message.as_ref(),
                    Some(text) if text.contains("Division by zero")
                );
                assert!(
                    mentions_zero,
                    "Veto message should mention division by zero: {:?}",
                    message
                );
            }
            other => panic!("Expected Veto for division by zero, got {:?}", other),
        }
    }

    #[test]
    fn test_rules_sorted_by_source_order() {
        let engine = engine_with(
            r#"
        doc test
        fact a = 1
        fact b = 2
        rule z = a + b
        rule y = a * b
        rule x = a - b
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 3);

        // Every rule must carry span information so results can be ordered by source.
        for entry in response.results.values() {
            assert!(
                entry.rule.source_location.is_some(),
                "Rule {} missing source_location",
                entry.rule.name
            );
        }

        let start_of = |name: &str| {
            response
                .results
                .values()
                .find(|r| r.rule.name == name)
                .unwrap()
                .rule
                .source_location
                .as_ref()
                .unwrap()
                .span
                .start
        };
        let (z_pos, y_pos, x_pos) = (start_of("z"), start_of("y"), start_of("x"));

        // Declaration order in the source is z, then y, then x.
        assert!(z_pos < y_pos);
        assert!(y_pos < x_pos);
    }

    #[test]
    fn test_rule_filtering_evaluates_dependencies() {
        let engine = engine_with(
            r#"
        doc test
        fact base = 100
        rule subtotal = base * 2
        rule tax = subtotal? * 10%
        rule total = subtotal? + tax?
    "#,
            "test.lemma",
        );

        // Only `total` is requested; `subtotal` and `tax` must still be computed internally.
        let requested = vec![String::from("total")];
        let response = engine.evaluate("test", requested, HashMap::new()).unwrap();

        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results.keys().next().unwrap(), "total");
        assert_eq!(*first_result(&response), number("220"));
    }
}