1use crate::evaluation::Evaluator;
2use crate::parsing::ast::Span;
3use crate::planning::plan;
4use crate::{parse, LemmaDoc, LemmaError, LemmaResult, ResourceLimits, Response};
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
/// In-memory registry of parsed Lemma documents together with the execution
/// plans built for them, plus the evaluator and resource limits used to run
/// those plans.
pub struct Engine {
    /// Execution plans keyed by document name; filled by `add_lemma_code`.
    execution_plans: HashMap<String, crate::planning::ExecutionPlan>,
    /// Parsed documents keyed by document name.
    documents: HashMap<String, LemmaDoc>,
    /// Raw Lemma source text keyed by a document's attribute (or its name when
    /// it has no attribute); passed to the planner and used in error reports.
    sources: HashMap<String, String>,
    /// Evaluator that runs execution plans.
    evaluator: Evaluator,
    /// Limits applied when parsing source code and when binding fact values.
    limits: ResourceLimits,
}
19
20impl Default for Engine {
21 fn default() -> Self {
22 Self {
23 execution_plans: HashMap::new(),
24 documents: HashMap::new(),
25 sources: HashMap::new(),
26 evaluator: Evaluator,
27 limits: ResourceLimits::default(),
28 }
29 }
30}
31
32impl Engine {
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn with_limits(limits: ResourceLimits) -> Self {
39 Self {
40 execution_plans: HashMap::new(),
41 documents: HashMap::new(),
42 sources: HashMap::new(),
43 evaluator: Evaluator,
44 limits,
45 }
46 }
47
48 pub fn add_lemma_code(&mut self, lemma_code: &str, source: &str) -> LemmaResult<()> {
49 let new_docs = parse(lemma_code, source, &self.limits)?;
50
51 for doc in &new_docs {
52 let attribute = doc.attribute.clone().unwrap_or_else(|| doc.name.clone());
53 self.sources.insert(attribute, lemma_code.to_owned());
54 self.documents.insert(doc.name.clone(), doc.clone());
55 }
56
57 let all_docs: Vec<LemmaDoc> = self.documents.values().cloned().collect();
59
60 for doc in &new_docs {
62 let execution_plan = plan(doc, &all_docs, self.sources.clone()).map_err(|errs| {
63 if errs.is_empty() {
64 use crate::parsing::ast::Span;
65 let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
66 let source_text = self
67 .sources
68 .get(attribute)
69 .map(|s| s.as_str())
70 .unwrap_or("");
71 LemmaError::engine(
72 format!("Failed to build execution plan for document: {}", doc.name),
73 Span {
74 start: 0,
75 end: 0,
76 line: doc.start_line,
77 col: 0,
78 },
79 attribute,
80 std::sync::Arc::from(source_text),
81 doc.name.clone(),
82 doc.start_line,
83 None::<String>,
84 )
85 } else {
86 errs.into_iter().next().unwrap_or_else(|| {
87 use crate::parsing::ast::Span;
88 let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
89 let source_text = self
90 .sources
91 .get(attribute)
92 .map(|s| s.as_str())
93 .unwrap_or("");
94 LemmaError::engine(
95 format!("Failed to build execution plan for document: {}", doc.name),
96 Span {
97 start: 0,
98 end: 0,
99 line: doc.start_line,
100 col: 0,
101 },
102 attribute,
103 std::sync::Arc::from(source_text),
104 doc.name.clone(),
105 doc.start_line,
106 None::<String>,
107 )
108 })
109 }
110 })?;
111
112 self.execution_plans
113 .insert(doc.name.clone(), execution_plan);
114 }
115
116 Ok(())
117 }
118
119 pub fn remove_document(&mut self, doc_name: &str) {
120 self.execution_plans.remove(doc_name);
121 self.documents.remove(doc_name);
122 }
123
124 pub fn list_documents(&self) -> Vec<String> {
125 self.documents.keys().cloned().collect()
126 }
127
128 pub fn get_document(&self, doc_name: &str) -> Option<&LemmaDoc> {
129 self.documents.get(doc_name)
130 }
131
132 pub fn get_document_facts(&self, doc_name: &str) -> Vec<&crate::LemmaFact> {
133 if let Some(doc) = self.documents.get(doc_name) {
134 doc.facts.iter().collect()
135 } else {
136 Vec::new()
137 }
138 }
139
140 pub fn get_document_rules(&self, doc_name: &str) -> Vec<&crate::LemmaRule> {
141 if let Some(doc) = self.documents.get(doc_name) {
142 doc.rules.iter().collect()
143 } else {
144 Vec::new()
145 }
146 }
147
148 pub fn get_fact_schema_type(
153 &self,
154 doc_name: &str,
155 fact_name: &str,
156 ) -> Option<crate::LemmaType> {
157 let plan = self.execution_plans.get(doc_name)?;
158 let fact_path = plan.get_fact_path_by_str(fact_name)?;
159 plan.fact_schema.get(fact_path).cloned()
160 }
161
162 pub fn get_required_fact_names(
170 &self,
171 doc_name: &str,
172 rule_names: &[String],
173 ) -> Option<HashSet<String>> {
174 let plan = self.execution_plans.get(doc_name)?;
175 let mut required: HashSet<String> = HashSet::new();
176
177 if rule_names.is_empty() {
178 for rule in &plan.rules {
179 for fact in &rule.needs_facts {
180 required.insert(fact.to_string());
181 }
182 }
183 return Some(required);
184 }
185
186 for rule_name in rule_names {
187 let rule = plan.get_rule(rule_name)?;
188 for fact in &rule.needs_facts {
189 required.insert(fact.to_string());
190 }
191 }
192
193 Some(required)
194 }
195
196 pub fn evaluate_json(
207 &self,
208 doc_name: &str,
209 rule_names: Vec<String>,
210 json: &[u8],
211 ) -> LemmaResult<Response> {
212 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
213 LemmaError::engine(
214 format!("Document '{}' not found", doc_name),
215 Span {
216 start: 0,
217 end: 0,
218 line: 1,
219 col: 0,
220 },
221 "<unknown>",
222 Arc::from(""),
223 "<unknown>",
224 1,
225 None::<String>,
226 )
227 })?;
228
229 let values = crate::serialization::from_json(json, base_plan)?;
230
231 self.evaluate_strict(doc_name, rule_names, values)
232 }
233
234 pub fn evaluate(
246 &self,
247 doc_name: &str,
248 rule_names: Vec<String>,
249 values: HashMap<String, String>,
250 ) -> LemmaResult<Response> {
251 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
252 LemmaError::engine(
253 format!("Document '{}' not found", doc_name),
254 Span {
255 start: 0,
256 end: 0,
257 line: 1,
258 col: 0,
259 },
260 "<unknown>",
261 Arc::from(""),
262 "<unknown>",
263 1,
264 None::<String>,
265 )
266 })?;
267
268 let plan = base_plan.clone().with_values(values, &self.limits)?;
269
270 self.evaluate_plan(plan, rule_names)
271 }
272
273 pub fn evaluate_strict(
284 &self,
285 doc_name: &str,
286 rule_names: Vec<String>,
287 values: HashMap<String, crate::LiteralValue>,
288 ) -> LemmaResult<Response> {
289 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
290 LemmaError::engine(
291 format!("Document '{}' not found", doc_name),
292 Span {
293 start: 0,
294 end: 0,
295 line: 1,
296 col: 0,
297 },
298 "<unknown>",
299 Arc::from(""),
300 "<unknown>",
301 1,
302 None::<String>,
303 )
304 })?;
305
306 let plan = base_plan.clone().with_typed_values(values, &self.limits)?;
307
308 self.evaluate_plan(plan, rule_names)
309 }
310
311 pub fn invert_json(
324 &self,
325 doc_name: &str,
326 rule_name: &str,
327 target: crate::inversion::Target,
328 json: &[u8],
329 ) -> LemmaResult<crate::InversionResponse> {
330 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
331 LemmaError::engine(
332 format!("Document '{}' not found", doc_name),
333 Span {
334 start: 0,
335 end: 0,
336 line: 1,
337 col: 0,
338 },
339 "<unknown>",
340 Arc::from(""),
341 "<unknown>",
342 1,
343 None::<String>,
344 )
345 })?;
346
347 let values = crate::serialization::from_json(json, base_plan)?;
348
349 self.invert_strict(doc_name, rule_name, target, values)
350 }
351
352 pub fn invert(
365 &self,
366 doc_name: &str,
367 rule_name: &str,
368 target: crate::inversion::Target,
369 values: HashMap<String, String>,
370 ) -> LemmaResult<crate::InversionResponse> {
371 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
372 LemmaError::engine(
373 format!("Document '{}' not found", doc_name),
374 Span {
375 start: 0,
376 end: 0,
377 line: 1,
378 col: 0,
379 },
380 "<unknown>",
381 Arc::from(""),
382 "<unknown>",
383 1,
384 None::<String>,
385 )
386 })?;
387
388 let plan = base_plan.clone().with_values(values, &self.limits)?;
389
390 let provided_facts = plan.fact_values.keys().cloned().collect();
392
393 crate::inversion::invert(rule_name, target, &plan, &provided_facts)
394 }
395
396 pub fn invert_strict(
409 &self,
410 doc_name: &str,
411 rule_name: &str,
412 target: crate::inversion::Target,
413 values: HashMap<String, crate::LiteralValue>,
414 ) -> LemmaResult<crate::InversionResponse> {
415 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
416 LemmaError::engine(
417 format!("Document '{}' not found", doc_name),
418 Span {
419 start: 0,
420 end: 0,
421 line: 1,
422 col: 0,
423 },
424 "<unknown>",
425 Arc::from(""),
426 "<unknown>",
427 1,
428 None::<String>,
429 )
430 })?;
431
432 let plan = base_plan.clone().with_typed_values(values, &self.limits)?;
433
434 let provided_facts = plan.fact_values.keys().cloned().collect();
436
437 crate::inversion::invert(rule_name, target, &plan, &provided_facts)
438 }
439
440 fn evaluate_plan(
441 &self,
442 plan: crate::planning::ExecutionPlan,
443 rule_names: Vec<String>,
444 ) -> LemmaResult<Response> {
445 let mut response = self.evaluator.evaluate(&plan)?;
446
447 if !rule_names.is_empty() {
448 response.filter_rules(&rule_names);
449 }
450
451 Ok(response)
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    /// A fresh engine with `code` loaded under `source`; panics if loading fails.
    fn engine_with(code: &str, source: &str) -> Engine {
        let mut engine = Engine::new();
        engine.add_lemma_code(code, source).unwrap();
        engine
    }

    /// Evaluates every rule of `doc_name` without overriding any fact.
    fn evaluate_all(engine: &Engine, doc_name: &str) -> Response {
        engine.evaluate(doc_name, vec![], HashMap::new()).unwrap()
    }

    /// The successful numeric result carrying the decimal written as `literal`.
    fn number(literal: &str) -> crate::OperationResult {
        crate::OperationResult::Value(crate::LiteralValue::number(
            Decimal::from_str(literal).unwrap(),
        ))
    }

    /// The result of the first rule in `response`.
    fn first_result(response: &Response) -> &crate::OperationResult {
        &response.results.values().next().unwrap().result
    }

    /// The result of the rule named `rule_name` in `response`.
    fn result_named<'a>(response: &'a Response, rule_name: &str) -> &'a crate::OperationResult {
        &response
            .results
            .values()
            .find(|r| r.rule.name == rule_name)
            .unwrap()
            .result
    }

    #[test]
    fn test_evaluate_document_all_rules() {
        let engine = engine_with(
            r#"
        doc test
        fact x = 10
        fact y = 5
        rule sum = x + y
        rule product = x * y
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 2);
        assert_eq!(result_named(&response, "sum"), &number("15"));
        assert_eq!(result_named(&response, "product"), &number("50"));
    }

    #[test]
    fn test_evaluate_empty_facts() {
        let engine = engine_with(
            r#"
        doc test
        fact price = 100
        rule total = price * 2
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 1);
        assert_eq!(first_result(&response), &number("200"));
    }

    #[test]
    fn test_evaluate_boolean_rule() {
        let engine = engine_with(
            r#"
        doc test
        fact age = 25
        rule is_adult = age >= 18
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        let expected =
            crate::OperationResult::Value(crate::LiteralValue::boolean(crate::BooleanValue::True));
        assert_eq!(first_result(&response), &expected);
    }

    #[test]
    fn test_evaluate_with_unless_clause() {
        let engine = engine_with(
            r#"
        doc test
        fact quantity = 15
        rule discount = 0
          unless quantity >= 10 then 10
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(first_result(&response), &number("10"));
    }

    #[test]
    fn test_document_not_found() {
        let engine = Engine::new();
        let outcome = engine.evaluate("nonexistent", vec![], HashMap::new());
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_multiple_documents() {
        let mut engine = engine_with(
            r#"
        doc doc1
        fact x = 10
        rule result = x * 2
        "#,
            "doc1.lemma",
        );
        engine
            .add_lemma_code(
                r#"
        doc doc2
        fact y = 5
        rule result = y * 3
        "#,
                "doc2.lemma",
            )
            .unwrap();

        let first = evaluate_all(&engine, "doc1");
        assert_eq!(first.results[0].result, number("20"));

        let second = evaluate_all(&engine, "doc2");
        assert_eq!(second.results[0].result, number("15"));
    }

    #[test]
    fn test_runtime_error_mapping() {
        let engine = engine_with(
            r#"
        doc test
        fact numerator = 10
        fact denominator = 0
        rule division = numerator / denominator
        "#,
            "test.lemma",
        );

        let outcome = engine.evaluate("test", vec![], HashMap::new());
        assert!(outcome.is_ok(), "Evaluation should succeed");

        let response = outcome.unwrap();
        let division = response
            .results
            .values()
            .find(|r| r.rule.name == "division");
        assert!(division.is_some(), "Should have division rule result");

        match &division.unwrap().result {
            crate::OperationResult::Veto(message) => {
                let mentions_division_by_zero = message
                    .as_ref()
                    .map(|text| text.contains("Division by zero"))
                    .unwrap_or(false);
                assert!(
                    mentions_division_by_zero,
                    "Veto message should mention division by zero: {:?}",
                    message
                );
            }
            other => panic!("Expected Veto for division by zero, got {:?}", other),
        }
    }

    #[test]
    fn test_rules_sorted_by_source_order() {
        let engine = engine_with(
            r#"
        doc test
        fact a = 1
        fact b = 2
        rule z = a + b
        rule y = a * b
        rule x = a - b
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 3);

        // Every rule result must know where its rule was declared.
        for result in response.results.values() {
            assert!(
                result.rule.source_location.is_some(),
                "Rule {} missing source_location",
                result.rule.name
            );
        }

        let start_of = |rule_name: &str| {
            response
                .results
                .values()
                .find(|r| r.rule.name == rule_name)
                .unwrap()
                .rule
                .source_location
                .as_ref()
                .unwrap()
                .span
                .start
        };
        let z_pos = start_of("z");
        let y_pos = start_of("y");
        let x_pos = start_of("x");

        assert!(z_pos < y_pos);
        assert!(y_pos < x_pos);
    }

    #[test]
    fn test_rule_filtering_evaluates_dependencies() {
        let engine = engine_with(
            r#"
        doc test
        fact base = 100
        rule subtotal = base * 2
        rule tax = subtotal? * 10%
        rule total = subtotal? + tax?
        "#,
            "test.lemma",
        );

        // Only `total` is requested, but its dependencies still have to run.
        let response = engine
            .evaluate("test", vec!["total".to_string()], HashMap::new())
            .unwrap();

        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results.keys().next().unwrap(), "total");
        assert_eq!(first_result(&response), &number("220"));
    }
}