use lemma::parsing::ast::DateTimeValue;
use lemma::Engine;
use rust_decimal::Decimal;
use std::collections::HashMap;
use std::str::FromStr;
/// Builds an [`Engine`] with two specs loaded, in order:
///
/// 1. `examples` — declares the shared `money` scale (eur/gbp units) and the
///    `priority` text type.
/// 2. `coffee_order` — imports `money` and `priority` from `examples` and
///    defines the pricing rules exercised by the tests in this file.
///
/// Note: `size` lists `"extra large"` as an option, but `size_multiplier` has
/// no branch for it, so that input vetoes ("Unknown size of coffee").
///
/// # Panics
/// Panics if either spec fails to load.
fn load_coffee_order() -> Engine {
    const EXAMPLES: &str = r#"
spec examples
data money: scale
-> decimals 2
-> unit eur 1.00
-> unit gbp 1.17
-> minimum 0 eur
data priority: text
-> option "low"
-> option "medium"
-> option "high"
"#;
    const COFFEE_ORDER: &str = r#"
spec coffee_order
data coffee: text
-> option "espresso"
-> option "latte"
-> option "cappuccino"
-> option "mocha"
data size: text
-> option "small"
-> option "medium"
-> option "large"
-> option "extra large"
data price : money from examples
data priority : priority from examples
data number_of_cups : number -> maximum 10
data has_loyalty_card: boolean
rule ordered_priority: veto "Unknown priority"
unless priority is "low" then 1
unless priority is "medium" then 2
unless priority is "high" then 3
rule base_price: veto "Unknown type of coffee"
unless coffee is "espresso" then 2.50 eur
unless coffee is "latte" then 3.50 eur
unless coffee is "cappuccino" then 3.50 eur
unless coffee is "mocha" then 4.00 eur
rule size_multiplier: veto "Unknown size of coffee"
unless size is "small" then 0.80
unless size is "medium" then 1.00
unless size is "large" then 1.20
rule price_per_cup: base_price * size_multiplier
rule subtotal: price_per_cup * number_of_cups
rule loyalty_discount: 0.0
unless has_loyalty_card then 0.10
rule discount_amount: subtotal * loyalty_discount
rule total: subtotal - discount_amount
"#;

    let mut engine = Engine::new();
    // `examples` must be loaded first: `coffee_order` imports types from it.
    let specs = [
        (EXAMPLES, "examples.lemma", "Failed to parse examples"),
        (
            COFFEE_ORDER,
            "coffee_order.lemma",
            "Failed to parse coffee_order",
        ),
    ];
    for (source, file_name, failure) in specs {
        let origin =
            lemma::SourceType::Path(std::sync::Arc::new(std::path::PathBuf::from(file_name)));
        engine.load(source, origin).expect(failure);
    }
    engine
}
/// Espresso, small, two cups, no loyalty card.
///
/// Checks every intermediate rule (base price 2.50 eur, size multiplier 0.80,
/// price per cup 2.00 eur, subtotal 4.00 eur, loyalty discount 0) and that the
/// total equals the subtotal *exactly* when no discount applies.
#[test]
fn test_coffee_order_espresso_small_no_loyalty() {
    let engine = load_coffee_order();
    let now = DateTimeValue::now();
    let data_values = HashMap::from([
        ("coffee".to_string(), "espresso".to_string()),
        ("size".to_string(), "small".to_string()),
        ("number_of_cups".to_string(), "2".to_string()),
        ("has_loyalty_card".to_string(), "false".to_string()),
    ]);
    let response = engine
        .run(None, "coffee_order", Some(&now), data_values, false)
        .expect("Evaluation failed");

    // Evaluated value of the named rule; panics if the rule is missing from
    // the results or produced no value.
    let value_of = |name: &str| {
        response
            .results
            .values()
            .find(|r| r.rule.name == name)
            .unwrap_or_else(|| panic!("{} rule not found", name))
            .result
            .value()
            .unwrap_or_else(|| panic!("{} should have value", name))
    };
    // Amount of a rule that must evaluate to a `money` scale expressed in eur.
    let eur_amount = |name: &str| match &value_of(name).value {
        lemma::ValueKind::Scale(amount, unit) => {
            assert_eq!(
                unit.as_str(),
                "eur",
                "{} should have unit 'eur', got: {:?}",
                name,
                unit
            );
            *amount
        }
        other => panic!("{} should be Scale type, got: {:?}", name, other),
    };
    // Value of a rule that must evaluate to a plain (unitless) number.
    let number = |name: &str| match &value_of(name).value {
        lemma::ValueKind::Number(n) => *n,
        other => panic!("{} should be Number type, got: {:?}", name, other),
    };
    let dec = |literal: &str| Decimal::from_str(literal).unwrap();

    let base_price = eur_amount("base_price");
    assert_eq!(
        base_price,
        dec("2.50"),
        "base_price should be exactly 2.50 (2.50 eur), got: {}",
        base_price
    );

    let size_multiplier = number("size_multiplier");
    assert_eq!(
        size_multiplier,
        dec("0.80"),
        "size_multiplier should be 0.80, got: {}",
        size_multiplier
    );

    let price_per_cup = eur_amount("price_per_cup");
    assert_eq!(
        price_per_cup,
        dec("2.00"),
        "price_per_cup should be exactly 2.00 (2.50 * 0.80), got: {}",
        price_per_cup
    );

    let subtotal = eur_amount("subtotal");
    assert_eq!(
        subtotal,
        dec("4.00"),
        "subtotal should be exactly 4.00 (2.00 * 2), got: {}",
        subtotal
    );

    let loyalty_discount = number("loyalty_discount");
    assert_eq!(
        loyalty_discount,
        dec("0.00"),
        "loyalty_discount should be 0.00, got: {}",
        loyalty_discount
    );

    // Decimal arithmetic is exact, so with a zero discount the total must match
    // the subtotal exactly; an approximate tolerance would hide a cent-level
    // error. (Decimal equality is numeric, so 4.00 == 4.000.)
    let total = eur_amount("total");
    assert_eq!(
        total, subtotal,
        "total should equal subtotal when discount is 0, got total: {}, subtotal: {}",
        total, subtotal
    );
}
/// Latte, large, three cups, with a loyalty card.
///
/// Checks the base price (3.50 eur), the large-size multiplier (1.20), the
/// loyalty discount (0.10), and that the subtotal (12.60 eur) and the
/// discounted total (11.34 eur) come out exactly.
#[test]
fn test_coffee_order_latte_large_with_loyalty() {
    let engine = load_coffee_order();
    let now = DateTimeValue::now();
    let data_values: HashMap<String, String> = [
        ("coffee", "latte"),
        ("size", "large"),
        ("number_of_cups", "3"),
        ("has_loyalty_card", "true"),
    ]
    .iter()
    .map(|&(key, value)| (key.to_owned(), value.to_owned()))
    .collect();
    let response = engine
        .run(None, "coffee_order", Some(&now), data_values, false)
        .expect("Evaluation failed");

    // Result entry for the named rule; panics if the rule is absent.
    let find_rule = |name: &str| {
        response
            .results
            .values()
            .find(|r| r.rule.name == name)
            .unwrap_or_else(|| panic!("{} rule not found", name))
    };

    let base = find_rule("base_price")
        .result
        .value()
        .expect("base_price should have value");
    if let lemma::ValueKind::Scale(amount, unit) = &base.value {
        assert_eq!(
            unit.as_str(),
            "eur",
            "base_price should have unit 'eur', got: {:?}",
            unit
        );
        assert_eq!(
            *amount,
            Decimal::from_str("3.50").unwrap(),
            "base_price should be exactly 3.50 (3.50 eur), got: {}",
            amount
        );
    } else {
        panic!("base_price should be Scale type, got: {:?}", base.value)
    }

    let multiplier = find_rule("size_multiplier")
        .result
        .value()
        .expect("size_multiplier should have value");
    if let lemma::ValueKind::Number(factor) = &multiplier.value {
        assert_eq!(
            *factor,
            Decimal::from_str("1.20").unwrap(),
            "size_multiplier should be 1.20, got: {}",
            factor
        );
    } else {
        panic!(
            "size_multiplier should be Number type, got: {:?}",
            multiplier.value
        )
    }

    let discount = find_rule("loyalty_discount")
        .result
        .value()
        .expect("loyalty_discount should have value");
    if let lemma::ValueKind::Number(rate) = &discount.value {
        assert_eq!(
            *rate,
            Decimal::from_str("0.10").unwrap(),
            "loyalty_discount should be exactly 0.10, got: {}",
            rate
        );
    } else {
        panic!(
            "loyalty_discount should be Number type, got: {:?}",
            discount.value
        )
    }

    // Both rules are looked up before either value is inspected.
    let subtotal_rule = find_rule("subtotal");
    let total_rule = find_rule("total");
    let subtotal_value = subtotal_rule
        .result
        .value()
        .expect("subtotal should have value");
    let total_value = total_rule
        .result
        .value()
        .expect("total should have value");

    let subtotal_num = if let lemma::ValueKind::Scale(amount, unit) = &subtotal_value.value {
        assert_eq!(
            unit.as_str(),
            "eur",
            "subtotal should have unit 'eur', got: {:?}",
            unit
        );
        *amount
    } else {
        panic!(
            "subtotal should be Scale type, got: {:?}",
            subtotal_value.value
        )
    };
    assert_eq!(
        subtotal_num,
        Decimal::from_str("12.60").unwrap(),
        "subtotal should be exactly 12.60 (4.20 * 3), got: {}",
        subtotal_num
    );

    let total_num = if let lemma::ValueKind::Scale(amount, unit) = &total_value.value {
        assert_eq!(
            unit.as_str(),
            "eur",
            "total should have unit 'eur', got: {:?}",
            unit
        );
        *amount
    } else {
        panic!("total should be Scale type, got: {:?}", total_value.value)
    };
    assert_eq!(
        total_num,
        Decimal::from_str("11.34").unwrap(),
        "total should be exactly 11.34 (12.60 - 1.26), got: {}",
        total_num
    );
}
/// `ordered_priority` maps each `priority` option to its rank:
/// "low" → 1, "medium" → 2, "high" → 3. The comparison is made against the
/// value's `Display` rendering.
#[test]
fn test_coffee_order_ordered_priority() {
    let engine = load_coffee_order();
    let now = DateTimeValue::now();
    let cases = [("low", "1"), ("medium", "2"), ("high", "3")];
    for &(priority, expected) in cases.iter() {
        // Only `priority` is supplied; `ordered_priority` depends on nothing else.
        let response = engine
            .run(
                None,
                "coffee_order",
                Some(&now),
                HashMap::from([("priority".to_string(), priority.to_string())]),
                false,
            )
            .expect("Evaluation failed");
        let rendered = response
            .results
            .values()
            .find(|r| r.rule.name == "ordered_priority")
            .expect("ordered_priority rule not found")
            .result
            .value()
            .expect("ordered_priority should have value");
        assert_eq!(
            rendered.to_string(),
            expected,
            "priority '{}' should map to {}, got: {}",
            priority,
            expected,
            rendered
        );
    }
}
/// `"extra large"` is a valid `size` option, but `size_multiplier` has no
/// branch for it. That rule must therefore veto ("Unknown size of coffee"),
/// and the dependent `price_per_cup` rule must not produce a value.
#[test]
fn test_coffee_order_invalid_size_veto() {
    let engine = load_coffee_order();
    let now = DateTimeValue::now();
    let data_values = HashMap::from([
        ("coffee".to_string(), "espresso".to_string()),
        ("size".to_string(), "extra large".to_string()),
        ("number_of_cups".to_string(), "1".to_string()),
    ]);
    let response = engine
        .run(None, "coffee_order", Some(&now), data_values, false)
        .expect("Evaluation should complete (even with veto)");
    let size_multiplier = response
        .results
        .values()
        .find(|r| r.rule.name == "size_multiplier")
        .expect("size_multiplier rule not found");
    assert!(
        size_multiplier.result.vetoed(),
        "size_multiplier should veto for 'extra large' size"
    );
    // `price_per_cup` is declared in the spec, so a missing result is itself a
    // failure. Tolerating its absence would let the veto-propagation check be
    // skipped and the test pass vacuously.
    let price_per_cup = response
        .results
        .values()
        .find(|r| r.rule.name == "price_per_cup")
        .expect("price_per_cup rule not found");
    assert!(
        price_per_cup.result.vetoed() || price_per_cup.result.value().is_none(),
        "price_per_cup should fail when size_multiplier vetoes"
    );
}