1use crate::evaluation::Evaluator;
2use crate::parsing::ast::Span;
3use crate::planning::plan;
4use crate::{parse, LemmaDoc, LemmaError, LemmaResult, LemmaType, ResourceLimits, Response};
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8pub struct Engine {
13 execution_plans: HashMap<String, crate::planning::ExecutionPlan>,
14 documents: HashMap<String, LemmaDoc>,
15 sources: HashMap<String, String>,
16 evaluator: Evaluator,
17 limits: ResourceLimits,
18}
19
20impl Default for Engine {
21 fn default() -> Self {
22 Self {
23 execution_plans: HashMap::new(),
24 documents: HashMap::new(),
25 sources: HashMap::new(),
26 evaluator: Evaluator,
27 limits: ResourceLimits::default(),
28 }
29 }
30}
31
32impl Engine {
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn with_limits(limits: ResourceLimits) -> Self {
39 Self {
40 execution_plans: HashMap::new(),
41 documents: HashMap::new(),
42 sources: HashMap::new(),
43 evaluator: Evaluator,
44 limits,
45 }
46 }
47
48 pub fn add_lemma_code(&mut self, lemma_code: &str, source: &str) -> LemmaResult<()> {
49 let new_docs = parse(lemma_code, source, &self.limits)?;
50
51 for doc in &new_docs {
52 let attribute = doc.attribute.clone().unwrap_or_else(|| doc.name.clone());
53 self.sources.insert(attribute, lemma_code.to_owned());
54 self.documents.insert(doc.name.clone(), doc.clone());
55 }
56
57 let all_docs: Vec<LemmaDoc> = self.documents.values().cloned().collect();
59
60 for doc in &new_docs {
62 let execution_plan = plan(doc, &all_docs, self.sources.clone()).map_err(|errs| {
63 if errs.is_empty() {
64 use crate::parsing::ast::Span;
65 let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
66 let source_text = self
67 .sources
68 .get(attribute)
69 .map(|s| s.as_str())
70 .unwrap_or("");
71 LemmaError::engine(
72 format!("Failed to build execution plan for document: {}", doc.name),
73 Span {
74 start: 0,
75 end: 0,
76 line: doc.start_line,
77 col: 0,
78 },
79 attribute,
80 std::sync::Arc::from(source_text),
81 doc.name.clone(),
82 doc.start_line,
83 None::<String>,
84 )
85 } else {
86 errs.into_iter().next().unwrap_or_else(|| {
87 use crate::parsing::ast::Span;
88 let attribute = doc.attribute.as_deref().unwrap_or(&doc.name);
89 let source_text = self
90 .sources
91 .get(attribute)
92 .map(|s| s.as_str())
93 .unwrap_or("");
94 LemmaError::engine(
95 format!("Failed to build execution plan for document: {}", doc.name),
96 Span {
97 start: 0,
98 end: 0,
99 line: doc.start_line,
100 col: 0,
101 },
102 attribute,
103 std::sync::Arc::from(source_text),
104 doc.name.clone(),
105 doc.start_line,
106 None::<String>,
107 )
108 })
109 }
110 })?;
111
112 self.execution_plans
113 .insert(doc.name.clone(), execution_plan);
114 }
115
116 Ok(())
117 }
118
119 pub fn remove_document(&mut self, doc_name: &str) {
120 self.execution_plans.remove(doc_name);
121 self.documents.remove(doc_name);
122 }
123
124 pub fn list_documents(&self) -> Vec<String> {
125 self.documents.keys().cloned().collect()
126 }
127
128 pub fn get_document(&self, doc_name: &str) -> Option<&LemmaDoc> {
129 self.documents.get(doc_name)
130 }
131
132 pub fn get_execution_plan(&self, doc_name: &str) -> Option<&crate::planning::ExecutionPlan> {
137 self.execution_plans.get(doc_name)
138 }
139
140 pub fn get_document_rules(&self, doc_name: &str) -> Vec<&crate::LemmaRule> {
141 if let Some(doc) = self.documents.get(doc_name) {
142 doc.rules.iter().collect()
143 } else {
144 Vec::new()
145 }
146 }
147
148 pub fn get_facts(
156 &self,
157 doc_name: &str,
158 rule_names: &[String],
159 ) -> LemmaResult<HashMap<crate::FactPath, LemmaType>> {
160 let plan = self.execution_plans.get(doc_name).ok_or_else(|| {
161 LemmaError::engine(
162 format!("Document '{}' not found", doc_name),
163 Span {
164 start: 0,
165 end: 0,
166 line: 1,
167 col: 0,
168 },
169 "<engine>",
170 Arc::from(""),
171 doc_name,
172 1,
173 None::<String>,
174 )
175 })?;
176
177 let mut fact_paths = HashSet::new();
178
179 if rule_names.is_empty() {
180 for rule in plan.rules.iter().filter(|r| r.path.segments.is_empty()) {
182 fact_paths.extend(rule.needs_facts.iter().cloned());
183 }
184 } else {
185 for rule_name in rule_names {
186 let rule = plan.get_rule(rule_name).ok_or_else(|| {
187 LemmaError::engine(
188 format!("Rule '{}' not found in document '{}'", rule_name, doc_name),
189 Span {
190 start: 0,
191 end: 0,
192 line: 1,
193 col: 0,
194 },
195 "<engine>",
196 Arc::from(""),
197 doc_name,
198 1,
199 None::<String>,
200 )
201 })?;
202 fact_paths.extend(rule.needs_facts.iter().cloned());
203 }
204 }
205
206 let mut result = HashMap::new();
208 for fact_path in fact_paths {
209 if let Some(lemma_type) = plan.fact_schema.get(&fact_path) {
210 result.insert(fact_path, lemma_type.clone());
211 }
212 }
213
214 Ok(result)
215 }
216
217 pub fn evaluate_json(
228 &self,
229 doc_name: &str,
230 rule_names: Vec<String>,
231 json: &[u8],
232 ) -> LemmaResult<Response> {
233 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
234 LemmaError::engine(
235 format!("Document '{}' not found", doc_name),
236 Span {
237 start: 0,
238 end: 0,
239 line: 1,
240 col: 0,
241 },
242 "<engine>",
243 Arc::from(""),
244 doc_name,
245 1,
246 None::<String>,
247 )
248 })?;
249
250 let values = crate::serialization::from_json(json, base_plan)?;
251 let plan = base_plan.clone().with_values(values, &self.limits)?;
252
253 self.evaluate_plan(plan, rule_names)
254 }
255
256 pub fn evaluate(
268 &self,
269 doc_name: &str,
270 rule_names: Vec<String>,
271 values: HashMap<String, String>,
272 ) -> LemmaResult<Response> {
273 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
274 LemmaError::engine(
275 format!("Document '{}' not found", doc_name),
276 Span {
277 start: 0,
278 end: 0,
279 line: 1,
280 col: 0,
281 },
282 "<engine>",
283 Arc::from(""),
284 doc_name,
285 1,
286 None::<String>,
287 )
288 })?;
289
290 let plan = base_plan.clone().with_values(values, &self.limits)?;
291
292 self.evaluate_plan(plan, rule_names)
293 }
294
295 pub fn invert_json(
300 &self,
301 doc_name: &str,
302 rule_name: &str,
303 target: crate::inversion::Target,
304 json: &[u8],
305 ) -> LemmaResult<crate::InversionResponse> {
306 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
307 LemmaError::engine(
308 format!("Document '{}' not found", doc_name),
309 Span {
310 start: 0,
311 end: 0,
312 line: 1,
313 col: 0,
314 },
315 "<engine>",
316 Arc::from(""),
317 doc_name,
318 1,
319 None::<String>,
320 )
321 })?;
322
323 let values = crate::serialization::from_json(json, base_plan)?;
324 self.invert(doc_name, rule_name, target, values)
325 }
326
327 pub fn invert(
332 &self,
333 doc_name: &str,
334 rule_name: &str,
335 target: crate::inversion::Target,
336 values: HashMap<String, String>,
337 ) -> LemmaResult<crate::InversionResponse> {
338 let base_plan = self.execution_plans.get(doc_name).ok_or_else(|| {
339 LemmaError::engine(
340 format!("Document '{}' not found", doc_name),
341 Span {
342 start: 0,
343 end: 0,
344 line: 1,
345 col: 0,
346 },
347 "<engine>",
348 Arc::from(""),
349 doc_name,
350 1,
351 None::<String>,
352 )
353 })?;
354
355 let plan = base_plan.clone().with_values(values, &self.limits)?;
356 let provided_facts = plan.fact_values.keys().cloned().collect();
357
358 crate::inversion::invert(rule_name, target, &plan, &provided_facts)
359 }
360
361 fn evaluate_plan(
362 &self,
363 plan: crate::planning::ExecutionPlan,
364 rule_names: Vec<String>,
365 ) -> LemmaResult<Response> {
366 let mut response = self.evaluator.evaluate(&plan)?;
367
368 if !rule_names.is_empty() {
369 response.filter_rules(&rule_names);
370 }
371
372 Ok(response)
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    /// Returns a fresh engine with `code` loaded under the source name `source`.
    fn engine_with(code: &str, source: &str) -> Engine {
        let mut engine = Engine::new();
        engine.add_lemma_code(code, source).unwrap();
        engine
    }

    /// Evaluates every rule of `doc_name` without fact overrides. Panics on
    /// error.
    fn evaluate_all(engine: &Engine, doc_name: &str) -> Response {
        engine.evaluate(doc_name, vec![], HashMap::new()).unwrap()
    }

    /// The `OperationResult` expected for a numeric rule whose value is `text`.
    fn number(text: &str) -> crate::OperationResult {
        let value = Decimal::from_str(text).unwrap();
        crate::OperationResult::Value(crate::LiteralValue::number(value))
    }

    /// Result of the rule named `rule_name`. Panics if the response lacks it.
    fn result_of<'a>(response: &'a Response, rule_name: &str) -> &'a crate::OperationResult {
        let entry = response
            .results
            .values()
            .find(|r| r.rule.name == rule_name)
            .unwrap();
        &entry.result
    }

    #[test]
    fn test_evaluate_document_all_rules() {
        let engine = engine_with(
            r#"
        doc test
        fact x = 10
        fact y = 5
        rule sum = x + y
        rule product = x * y
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 2);
        assert_eq!(*result_of(&response, "sum"), number("15"));
        assert_eq!(*result_of(&response, "product"), number("50"));
    }

    #[test]
    fn test_evaluate_empty_facts() {
        let engine = engine_with(
            r#"
        doc test
        fact price = 100
        rule total = price * 2
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 1);
        let only = response.results.values().next().unwrap();
        assert_eq!(only.result, number("200"));
    }

    #[test]
    fn test_evaluate_boolean_rule() {
        let engine = engine_with(
            r#"
        doc test
        fact age = 25
        rule is_adult = age >= 18
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        let first = response.results.values().next().unwrap();
        let expected =
            crate::OperationResult::Value(crate::LiteralValue::boolean(crate::BooleanValue::True));
        assert_eq!(first.result, expected);
    }

    #[test]
    fn test_evaluate_with_unless_clause() {
        let engine = engine_with(
            r#"
        doc test
        fact quantity = 15
        rule discount = 0
          unless quantity >= 10 then 10
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        let first = response.results.values().next().unwrap();
        assert_eq!(first.result, number("10"));
    }

    #[test]
    fn test_document_not_found() {
        let engine = Engine::new();
        let outcome = engine.evaluate("nonexistent", vec![], HashMap::new());
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_multiple_documents() {
        let mut engine = engine_with(
            r#"
        doc doc1
        fact x = 10
        rule result = x * 2
        "#,
            "doc1.lemma",
        );
        engine
            .add_lemma_code(
                r#"
        doc doc2
        fact y = 5
        rule result = y * 3
        "#,
                "doc2.lemma",
            )
            .unwrap();

        let first = evaluate_all(&engine, "doc1");
        assert_eq!(first.results[0].result, number("20"));

        let second = evaluate_all(&engine, "doc2");
        assert_eq!(second.results[0].result, number("15"));
    }

    #[test]
    fn test_runtime_error_mapping() {
        let engine = engine_with(
            r#"
        doc test
        fact numerator = 10
        fact denominator = 0
        rule division = numerator / denominator
        "#,
            "test.lemma",
        );

        let outcome = engine.evaluate("test", vec![], HashMap::new());
        assert!(outcome.is_ok(), "Evaluation should succeed");

        let response = outcome.unwrap();
        let division = response
            .results
            .values()
            .find(|r| r.rule.name == "division");
        assert!(division.is_some(), "Should have division rule result");

        match &division.unwrap().result {
            crate::OperationResult::Veto(message) => {
                let mentions_zero = message
                    .as_ref()
                    .map(|m| m.contains("Division by zero"))
                    .unwrap_or(false);
                assert!(
                    mentions_zero,
                    "Veto message should mention division by zero: {:?}",
                    message
                );
            }
            other => panic!("Expected Veto for division by zero, got {:?}", other),
        }
    }

    #[test]
    fn test_rules_sorted_by_source_order() {
        let engine = engine_with(
            r#"
        doc test
        fact a = 1
        fact b = 2
        rule z = a + b
        rule y = a * b
        rule x = a - b
        "#,
            "test.lemma",
        );

        let response = evaluate_all(&engine, "test");
        assert_eq!(response.results.len(), 3);

        for entry in response.results.values() {
            assert!(
                entry.rule.source_location.is_some(),
                "Rule {} missing source_location",
                entry.rule.name
            );
        }

        // Byte offset at which the named rule starts in the source text.
        let start_of = |rule_name: &str| {
            response
                .results
                .values()
                .find(|r| r.rule.name == rule_name)
                .unwrap()
                .rule
                .source_location
                .as_ref()
                .unwrap()
                .span
                .start
        };

        let (z_pos, y_pos, x_pos) = (start_of("z"), start_of("y"), start_of("x"));
        assert!(z_pos < y_pos);
        assert!(y_pos < x_pos);
    }

    #[test]
    fn test_rule_filtering_evaluates_dependencies() {
        let engine = engine_with(
            r#"
        doc test
        fact base = 100
        rule subtotal = base * 2
        rule tax = subtotal? * 10%
        rule total = subtotal? + tax?
        "#,
            "test.lemma",
        );

        let response = engine
            .evaluate("test", vec!["total".to_string()], HashMap::new())
            .unwrap();

        assert_eq!(response.results.len(), 1);
        assert_eq!(response.results.keys().next().unwrap(), "total");

        let total = response.results.values().next().unwrap();
        assert_eq!(total.result, number("220"));
    }
}