Skip to main content

lemma/
engine.rs

1use crate::evaluation::Evaluator;
2use crate::parsing::ast::Span;
3use crate::planning::plan;
4use crate::{parse, LemmaDoc, LemmaError, LemmaResult, LemmaType, ResourceLimits, Response};
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8/// Engine for evaluating Lemma rules
9///
10/// Pure Rust implementation that evaluates Lemma docs directly from the AST.
11/// Uses pre-built execution plans that are self-contained and ready for evaluation.
12pub struct Engine {
13    execution_plans: HashMap<String, crate::planning::ExecutionPlan>,
14    documents: HashMap<String, LemmaDoc>,
15    sources: HashMap<String, String>,
16    evaluator: Evaluator,
17    limits: ResourceLimits,
18}
19
20impl Default for Engine {
21    fn default() -> Self {
22        Self {
23            execution_plans: HashMap::new(),
24            documents: HashMap::new(),
25            sources: HashMap::new(),
26            evaluator: Evaluator,
27            limits: ResourceLimits::default(),
28        }
29    }
30}
31
32impl Engine {
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    /// Create an engine with custom resource limits
38    pub fn with_limits(limits: ResourceLimits) -> Self {
39        Self {
40            execution_plans: HashMap::new(),
41            documents: HashMap::new(),
42            sources: HashMap::new(),
43            evaluator: Evaluator,
44            limits,
45        }
46    }
47
48    pub fn add_lemma_code(&mut self, lemma_code: &str, source: &str) -> LemmaResult<()> {
49        let new_docs = parse(lemma_code, source, &self.limits)?;
50
51        for doc in &new_docs {
52            let attribute = doc.attribute.clone().unwrap_or_else(|| doc.name.clone());
53            self.sources.insert(attribute, lemma_code.to_owned());
54            self.documents.insert(doc.name.clone(), doc.clone());
55        }
56
57        // Collect all documents (existing + new)
58        let all_docs: Vec<LemmaDoc> = self.documents.values().cloned().collect();
59
60        // Build execution plans for all new documents
61        for doc in &new_docs {
62            let execution_plan = plan(doc, &all_docs, self.sources.clone()).map_err(|errs| {
63                if errs.is_empty() {
64                    use crate::parsing::ast::Span;
65                    let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
66                    let source_text = self
67                        .sources
68                        .get(attribute)
69                        .map(|s| s.as_str())
70                        .unwrap_or("");
71                    LemmaError::engine(
72                        format!("Failed to build execution plan for document: {}", doc.name),
73                        Span {
74                            start: 0,
75                            end: 0,
76                            line: doc.start_line,
77                            col: 0,
78                        },
79                        attribute,
80                        std::sync::Arc::from(source_text),
81                        doc.name.clone(),
82                        doc.start_line,
83                        None::<String>,
84                    )
85                } else {
86                    errs.into_iter().next().unwrap_or_else(|| {
87                        use crate::parsing::ast::Span;
88                        let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
89                        let source_text = self
90                            .sources
91                            .get(attribute)
92                            .map(|s| s.as_str())
93                            .unwrap_or("");
94                        LemmaError::engine(
95                            format!("Failed to build execution plan for document: {}", doc.name),
96                            Span {
97                                start: 0,
98                                end: 0,
99                                line: doc.start_line,
100                                col: 0,
101                            },
102                            attribute,
103                            std::sync::Arc::from(source_text),
104                            doc.name.clone(),
105                            doc.start_line,
106                            None::<String>,
107                        )
108                    })
109                }
110            })?;
111
112            self.execution_plans
113                .insert(doc.name.clone(), execution_plan);
114        }
115
116        Ok(())
117    }
118
119    pub fn remove_document(&mut self, doc_name: &str) {
120        self.execution_plans.remove(doc_name);
121        self.documents.remove(doc_name);
122    }
123
124    pub fn list_documents(&self) -> Vec<String> {
125        self.documents.keys().cloned().collect()
126    }
127
128    pub fn get_document(&self, doc_name: &str) -> Option<&LemmaDoc> {
129        self.documents.get(doc_name)
130    }
131
132    /// Get the execution plan for a document.
133    ///
134    /// The execution plan contains the resolved fact schema, default values,
135    /// and topologically sorted rules ready for evaluation.
136    pub fn get_execution_plan(&self, doc_name: &str) -> Option<&crate::planning::ExecutionPlan> {
137        self.execution_plans.get(doc_name)
138    }
139
140    pub fn get_document_rules(&self, doc_name: &str) -> Vec<&crate::LemmaRule> {
141        if let Some(doc) = self.documents.get(doc_name) {
142            doc.rules.iter().collect()
143        } else {
144            Vec::new()
145        }
146    }
147
148    /// Get the facts (with types) required to evaluate a document's rules.
149    ///
150    /// - If `rule_names` is empty, returns facts for **all local** rules in the document.
151    /// - Otherwise, returns facts for the specified rules (by name).
152    ///
153    /// Returns a map from FactPath to resolved LemmaType.
154    /// This is the authoritative API for determining what inputs a rule needs.
155    pub fn get_facts(
156        &self,
157        doc_name: &str,
158        rule_names: &[String],
159    ) -> LemmaResult<HashMap<crate::FactPath, LemmaType>> {
160        let plan = self.execution_plans.get(doc_name).ok_or_else(|| {
161            LemmaError::engine(
162                format!("Document '{}' not found", doc_name),
163                Span {
164                    start: 0,
165                    end: 0,
166                    line: 1,
167                    col: 0,
168                },
169                "<engine>",
170                Arc::from(""),
171                doc_name,
172                1,
173                None::<String>,
174            )
175        })?;
176
177        let mut fact_paths = HashSet::new();
178
179        if rule_names.is_empty() {
180            // Default behavior: facts for all local rules.
181            for rule in plan.rules.iter().filter(|r| r.path.segments.is_empty()) {
182                fact_paths.extend(rule.needs_facts.iter().cloned());
183            }
184        } else {
185            for rule_name in rule_names {
186                let rule = plan.get_rule(rule_name).ok_or_else(|| {
187                    LemmaError::engine(
188                        format!("Rule '{}' not found in document '{}'", rule_name, doc_name),
189                        Span {
190                            start: 0,
191                            end: 0,
192                            line: 1,
193                            col: 0,
194                        },
195                        "<engine>",
196                        Arc::from(""),
197                        doc_name,
198                        1,
199                        None::<String>,
200                    )
201                })?;
202                fact_paths.extend(rule.needs_facts.iter().cloned());
203            }
204        }
205
206        // Build result map with types from fact_schema
207        let mut result = HashMap::new();
208        for fact_path in fact_paths {
209            if let Some(lemma_type) = plan.fact_schema.get(&fact_path) {
210                result.insert(fact_path, lemma_type.clone());
211            }
212        }
213
214        Ok(result)
215    }
216
217    /// Evaluate rules in a document with JSON values for facts.
218    ///
219    /// This is a convenience method that accepts JSON directly and converts it
220    /// to typed values using the document's fact type declarations.
221    ///
222    /// If `rule_names` is empty, evaluates all rules.
223    /// Otherwise, only returns results for the specified rules (dependencies still computed).
224    ///
225    /// Values are provided as JSON bytes (e.g., `b"{\"quantity\": 5, \"is_member\": true}"`).
226    /// They are automatically parsed to the expected type based on the document schema.
227    pub fn evaluate_json(
228        &self,
229        doc_name: &str,
230        rule_names: Vec<String>,
231        json: &[u8],
232    ) -> LemmaResult<Response> {
233        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
234            LemmaError::engine(
235                format!("Document '{}' not found", doc_name),
236                Span {
237                    start: 0,
238                    end: 0,
239                    line: 1,
240                    col: 0,
241                },
242                "<engine>",
243                Arc::from(""),
244                doc_name,
245                1,
246                None::<String>,
247            )
248        })?;
249
250        let values = crate::serialization::from_json(json, base_plan)?;
251        let plan = base_plan.clone().with_values(values, &self.limits)?;
252
253        self.evaluate_plan(plan, rule_names)
254    }
255
256    /// Evaluate rules in a document with string values for facts.
257    ///
258    /// This is the user-friendly API that accepts raw string values and parses them
259    /// to the appropriate types based on the document's fact type declarations.
260    /// Use this for CLI, HTTP APIs, and other user-facing interfaces.
261    ///
262    /// If `rule_names` is empty, evaluates all rules.
263    /// Otherwise, only returns results for the specified rules (dependencies still computed).
264    ///
265    /// Values are provided as name -> value string pairs (e.g., "type" -> "latte").
266    /// They are automatically parsed to the expected type based on the document schema.
267    pub fn evaluate(
268        &self,
269        doc_name: &str,
270        rule_names: Vec<String>,
271        values: HashMap<String, String>,
272    ) -> LemmaResult<Response> {
273        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
274            LemmaError::engine(
275                format!("Document '{}' not found", doc_name),
276                Span {
277                    start: 0,
278                    end: 0,
279                    line: 1,
280                    col: 0,
281                },
282                "<engine>",
283                Arc::from(""),
284                doc_name,
285                1,
286                None::<String>,
287            )
288        })?;
289
290        let plan = base_plan.clone().with_values(values, &self.limits)?;
291
292        self.evaluate_plan(plan, rule_names)
293    }
294
295    /// Invert a rule to find input domains that produce a desired outcome with JSON values.
296    ///
297    /// Values are provided as JSON bytes (e.g., `b"{\"quantity\": 5, \"is_member\": true}"`).
298    /// They are automatically parsed to the expected type based on the document schema.
299    pub fn invert_json(
300        &self,
301        doc_name: &str,
302        rule_name: &str,
303        target: crate::inversion::Target,
304        json: &[u8],
305    ) -> LemmaResult<crate::InversionResponse> {
306        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
307            LemmaError::engine(
308                format!("Document '{}' not found", doc_name),
309                Span {
310                    start: 0,
311                    end: 0,
312                    line: 1,
313                    col: 0,
314                },
315                "<engine>",
316                Arc::from(""),
317                doc_name,
318                1,
319                None::<String>,
320            )
321        })?;
322
323        let values = crate::serialization::from_json(json, base_plan)?;
324        self.invert(doc_name, rule_name, target, values)
325    }
326
327    /// Invert a rule to find input domains that produce a desired outcome.
328    ///
329    /// Values are provided as name -> value string pairs (e.g., "quantity" -> "5").
330    /// They are automatically parsed to the expected type based on the document schema.
331    pub fn invert(
332        &self,
333        doc_name: &str,
334        rule_name: &str,
335        target: crate::inversion::Target,
336        values: HashMap<String, String>,
337    ) -> LemmaResult<crate::InversionResponse> {
338        let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
339            LemmaError::engine(
340                format!("Document '{}' not found", doc_name),
341                Span {
342                    start: 0,
343                    end: 0,
344                    line: 1,
345                    col: 0,
346                },
347                "<engine>",
348                Arc::from(""),
349                doc_name,
350                1,
351                None::<String>,
352            )
353        })?;
354
355        let plan = base_plan.clone().with_values(values, &self.limits)?;
356        let provided_facts = plan.fact_values.keys().cloned().collect();
357
358        crate::inversion::invert(rule_name, target, &plan, &provided_facts)
359    }
360
361    fn evaluate_plan(
362        &self,
363        plan: crate::planning::ExecutionPlan,
364        rule_names: Vec<String>,
365    ) -> LemmaResult<Response> {
366        let mut response = self.evaluator.evaluate(&plan)?;
367
368        if !rule_names.is_empty() {
369            response.filter_rules(&rule_names);
370        }
371
372        Ok(response)
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    /// Builds an engine with `code` loaded under `source`, panicking if loading fails.
    fn engine_with(code: &str, source: &str) -> Engine {
        let mut engine = Engine::new();
        engine.add_lemma_code(code, source).unwrap();
        engine
    }

    /// Evaluates every rule of `doc_name` without supplying any fact values.
    fn evaluate_all(engine: &Engine, doc_name: &str) -> Response {
        engine.evaluate(doc_name, vec![], HashMap::new()).unwrap()
    }

    /// The `OperationResult` expected for a rule that yields the decimal `text`.
    fn number(text: &str) -> crate::OperationResult {
        let value = Decimal::from_str(text).unwrap();
        crate::OperationResult::Value(crate::LiteralValue::number(value))
    }

    #[test]
    fn test_evaluate_document_all_rules() {
        let engine = engine_with(
            r#"
        doc test
        fact x = 10
        fact y = 5
        rule sum = x + y
        rule product = x * y
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 2);

        let expect_rule = |name: &str, expected: crate::OperationResult| {
            let found = response
                .results
                .values()
                .find(|r| r.rule.name == name)
                .unwrap();
            assert_eq!(found.result, expected);
        };
        expect_rule("sum", number("15"));
        expect_rule("product", number("50"));
    }

    #[test]
    fn test_evaluate_empty_facts() {
        let engine = engine_with(
            r#"
        doc test
        fact price = 100
        rule total = price * 2
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 1);

        let only = response.results.values().next().unwrap();
        assert_eq!(only.result, number("200"));
    }

    #[test]
    fn test_evaluate_boolean_rule() {
        let engine = engine_with(
            r#"
        doc test
        fact age = 25
        rule is_adult = age >= 18
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        let expected =
            crate::OperationResult::Value(crate::LiteralValue::boolean(crate::BooleanValue::True));
        assert_eq!(response.results.values().next().unwrap().result, expected);
    }

    #[test]
    fn test_evaluate_with_unless_clause() {
        let engine = engine_with(
            r#"
        doc test
        fact quantity = 15
        rule discount = 0
          unless quantity >= 10 then 10
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.values().next().unwrap().result, number("10"));
    }

    #[test]
    fn test_document_not_found() {
        let outcome = Engine::new().evaluate("nonexistent", vec![], HashMap::new());
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_multiple_documents() {
        let mut engine = engine_with(
            r#"
        doc doc1
        fact x = 10
        rule result = x * 2
    "#,
            "doc1.lemma",
        );
        engine
            .add_lemma_code(
                r#"
        doc doc2
        fact y = 5
        rule result = y * 3
    "#,
                "doc2.lemma",
            )
            .unwrap();

        let first = evaluate_all(&engine, "doc1");
        assert_eq!(first.results[0].result, number("20"));

        let second = evaluate_all(&engine, "doc2");
        assert_eq!(second.results[0].result, number("15"));
    }

    #[test]
    fn test_runtime_error_mapping() {
        let engine = engine_with(
            r#"
        doc test
        fact numerator = 10
        fact denominator = 0
        rule division = numerator / denominator
    "#,
            "test.lemma",
        );

        // Division by zero returns a Veto (not an error) in the new evaluation design
        let outcome = engine.evaluate("test", vec![], HashMap::new());
        assert!(outcome.is_ok(), "Evaluation should succeed");
        let response = outcome.unwrap();

        let division = response
            .results
            .values()
            .find(|r| r.rule.name == "division");
        assert!(division.is_some(), "Should have division rule result");

        match &division.unwrap().result {
            crate::OperationResult::Veto(message) => {
                let mentions_zero =
                    matches!(message, Some(text) if text.contains("Division by zero"));
                assert!(
                    mentions_zero,
                    "Veto message should mention division by zero: {:?}",
                    message
                );
            }
            other => panic!("Expected Veto for division by zero, got {:?}", other),
        }
    }

    #[test]
    fn test_rules_sorted_by_source_order() {
        let engine = engine_with(
            r#"
        doc test
        fact a = 1
        fact b = 2
        rule z = a + b
        rule y = a * b
        rule x = a - b
    "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 3);

        // Check they all have span information for ordering
        for entry in response.results.values() {
            let has_location = entry.rule.source_location.is_some();
            assert!(
                has_location,
                "Rule {} missing source_location",
                entry.rule.name
            );
        }

        let start_of = |name: &str| {
            response
                .results
                .values()
                .find(|r| r.rule.name == name)
                .unwrap()
                .rule
                .source_location
                .as_ref()
                .unwrap()
                .span
                .start
        };

        // Verify source positions increase (z < y < x)
        let (z_pos, y_pos, x_pos) = (start_of("z"), start_of("y"), start_of("x"));
        assert!(z_pos < y_pos);
        assert!(y_pos < x_pos);
    }

    #[test]
    fn test_rule_filtering_evaluates_dependencies() {
        let engine = engine_with(
            r#"
        doc test
        fact base = 100
        rule subtotal = base * 2
        rule tax = subtotal? * 10%
        rule total = subtotal? + tax?
    "#,
            "test.lemma",
        );

        // Request only 'total', but it depends on 'subtotal' and 'tax'
        let requested = vec!["total".to_string()];
        let response = engine.evaluate("test", requested, HashMap::new()).unwrap();

        // Only 'total' should be in results
        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results.keys().next().unwrap(), "total");

        // But the value should be correct (dependencies were computed)
        let total = response.results.values().next().unwrap();
        assert_eq!(total.result, number("220"));
    }
}