use lemma::parsing::ast::DateTimeValue;
use lemma::{Engine, SemanticDurationUnit};
use rust_decimal::Decimal;
use std::collections::HashMap;
use std::str::FromStr;
/// Evaluates `spec_name` against `facts` at the current time and returns the value
/// produced by `rule_name`.
///
/// # Panics
///
/// Panics (failing the calling test) if the engine fails to evaluate the spec, if
/// `rule_name` is absent from the results, or if the rule produced no value. Each
/// message names the spec so failures across the many examples are easy to trace.
fn get_rule_value(
    engine: &Engine,
    spec_name: &str,
    rule_name: &str,
    facts: HashMap<String, String>,
) -> lemma::LiteralValue {
    let now = DateTimeValue::now();
    // NOTE(review): the trailing `false` flag is passed through unchanged from the
    // original; its meaning is defined by `Engine::run` — confirm there.
    let response = engine
        .run(spec_name, Some(&now), facts, false)
        .unwrap_or_else(|e| panic!("failed to evaluate spec '{}': {:?}", spec_name, e));
    response
        .results
        .get(rule_name)
        .unwrap_or_else(|| panic!("rule '{}' not found in {}", rule_name, spec_name))
        .result
        .value()
        .unwrap_or_else(|| panic!("rule '{}' in {} had no value", rule_name, spec_name))
        .clone()
}
/// Builds a fresh [`Engine`] with every documentation example spec loaded into it.
///
/// Example files are located relative to the crate's manifest directory when Cargo
/// provides it at compile time, falling back to the current directory otherwise. As a
/// result the tests do not depend on the working directory the test binary is started
/// from. The original relative path is still used as the source label passed to the
/// engine.
///
/// # Panics
///
/// Panics if any example file cannot be read or if the engine rejects its contents.
/// The engine errors are joined into one message.
fn load_specs_folder_examples() -> Engine {
    const EXAMPLES: [&str; 5] = [
        "../documentation/examples/01_coffee_order.lemma",
        "../documentation/examples/02_library_fees.lemma",
        "../documentation/examples/03_recipe_scaling.lemma",
        "../documentation/examples/04_membership_benefits.lemma",
        "../documentation/examples/05_weather_clothing.lemma",
    ];
    // `option_env!` keeps this compiling even outside Cargo; under `cargo test` it is
    // the package root, which is also the directory the original relative paths assumed.
    let base_dir = std::path::Path::new(option_env!("CARGO_MANIFEST_DIR").unwrap_or("."));

    let mut engine = Engine::new();
    for path in EXAMPLES {
        let full_path = base_dir.join(path);
        let content = std::fs::read_to_string(&full_path)
            .unwrap_or_else(|e| panic!("Failed to read {}: {}", full_path.display(), e));
        engine
            .load(&content, lemma::SourceType::Labeled(path))
            .unwrap_or_else(|errs| {
                panic!(
                    "Failed to parse {}: {}",
                    path,
                    errs.iter()
                        .map(ToString::to_string)
                        .collect::<Vec<_>>()
                        .join("; ")
                )
            });
    }
    engine
}
/// Runs the `coffee_order` example with a fixed set of order facts and checks the
/// computed `total` rule.
#[test]
fn test_01_coffee_order() {
    let order: HashMap<String, String> = [
        ("product", "latte"),
        ("size", "large"),
        ("number_of_cups", "2"),
        ("has_loyalty_card", "true"),
        ("age", "70"),
    ]
    .iter()
    .map(|(key, value)| (key.to_string(), value.to_string()))
    .collect();

    let engine = load_specs_folder_examples();
    let total = get_rule_value(&engine, "coffee_order", "total", order);

    let expected_total =
        lemma::ValueKind::Scale(Decimal::from_str("7.644").unwrap(), "eur".to_string());
    assert_eq!(total.value, expected_total);
}
/// Runs the `library_fees` example for one set of facts and checks both the
/// `final_fee` and `can_checkout` rules.
#[test]
fn test_02_library_fees() {
    let engine = load_specs_folder_examples();
    let facts: HashMap<String, String> = [
        ("days_overdue", "5"),
        ("book_type", "regular"),
        ("is_first_offense", "false"),
    ]
    .iter()
    .map(|(key, value)| (key.to_string(), value.to_string()))
    .collect();

    // Every rule is evaluated against the same facts, so hand each run its own copy.
    let rule = |name: &str| get_rule_value(&engine, "library_fees", name, facts.clone());

    assert_eq!(
        rule("final_fee").value,
        lemma::ValueKind::Scale(Decimal::from_str("1.25").unwrap(), "eur".to_string())
    );
    assert_eq!(rule("can_checkout").value, lemma::ValueKind::Boolean(true));
}
/// Runs the `recipe_scaling` example and checks the `scaling_factor`, `baking_time`
/// and `oven_temperature` rules for a single set of facts.
#[test]
fn test_03_recipe_scaling() {
    let engine = load_specs_folder_examples();
    let mut facts = HashMap::new();
    facts.insert("original_servings".to_string(), "4".to_string());
    facts.insert("desired_servings".to_string(), "8".to_string());
    facts.insert("recipe_name".to_string(), "chocolate_cake".to_string());

    let scaling_factor = get_rule_value(&engine, "recipe_scaling", "scaling_factor", facts.clone());
    assert_eq!(
        scaling_factor.value,
        lemma::ValueKind::Number(Decimal::from_str("2").unwrap())
    );

    let baking_time = get_rule_value(&engine, "recipe_scaling", "baking_time", facts.clone());
    assert_eq!(
        baking_time.value,
        // The variant expression is already an owned temporary; cloning it was redundant.
        lemma::ValueKind::Duration(Decimal::from_str("40").unwrap(), SemanticDurationUnit::Minute)
    );

    // Last use of `facts`, so it can be moved rather than cloned.
    let oven_temp = get_rule_value(&engine, "recipe_scaling", "oven_temperature", facts);
    assert_eq!(
        oven_temp.value,
        lemma::ValueKind::Scale(Decimal::from_str("175").unwrap(), "celsius".to_string())
    );
}
/// Evaluates the membership example specs with no supplied facts. Checks the
/// `premium_membership` discount rate and several `membership_benefits` rules.
#[test]
fn test_04_membership_benefits() {
    let engine = load_specs_folder_examples();
    // Every rule in this test runs against an empty fact set.
    let eval = |spec: &str, rule: &str| get_rule_value(&engine, spec, rule, HashMap::new());

    assert_eq!(
        eval("premium_membership", "discount_rate").value,
        lemma::ValueKind::Ratio(
            Decimal::from_str("0.10").unwrap(),
            Some("percent".to_string())
        )
    );

    // Numeric rules of `membership_benefits`, checked in the same order as before.
    let numeric_expectations = [
        ("discount", "15"),
        ("shipping_cost", "0"),
        ("total_points", "325"),
    ];
    for &(rule, expected) in &numeric_expectations {
        assert_eq!(
            eval("membership_benefits", rule).value,
            lemma::ValueKind::Number(Decimal::from_str(expected).unwrap())
        );
    }
}
/// Runs the `weather_clothing` example for one set of weather facts and checks the
/// `clothing_layer` and `needs_jacket` rules.
#[test]
fn test_05_weather_clothing() {
    let engine = load_specs_folder_examples();
    let weather: HashMap<String, String> = [
        ("temperature", "15 celsius"),
        ("is_raining", "false"),
        ("wind_speed", "10"),
    ]
    .iter()
    .map(|(key, value)| (key.to_string(), value.to_string()))
    .collect();

    // Each evaluation consumes its fact map, so give every rule its own copy.
    let rule = |name: &str| get_rule_value(&engine, "weather_clothing", name, weather.clone());

    assert_eq!(
        rule("clothing_layer").value,
        lemma::ValueKind::Text("light".to_string())
    );
    assert_eq!(rule("needs_jacket").value, lemma::ValueKind::Boolean(false));
}
/// Checks that loading the documentation examples registers every expected spec.
///
/// The minimum spec count is derived from the list of expected spec names, so the two
/// cannot drift apart. Each expected name is then checked individually, and the
/// failure message lists the specs that are available.
#[test]
fn test_all_documentation_examples_parse() {
    let engine = load_specs_folder_examples();
    let specs = engine.list_specs();
    let key_specs = [
        "coffee_order",
        "library_fees",
        "recipe_scaling",
        "premium_membership",
        "membership_benefits",
        "weather_clothing",
    ];
    assert!(
        specs.len() >= key_specs.len(),
        "Expected at least {} specs from the documentation examples, found {}. Available: {:?}",
        key_specs.len(),
        specs.len(),
        specs
    );
    let spec_names: Vec<&str> = specs.iter().map(|d| d.name.as_str()).collect();
    for expected in key_specs {
        assert!(
            spec_names.contains(&expected),
            "Expected spec '{}' not found. Available: {:?}",
            expected,
            spec_names
        );
    }
}