use crate::errors::{Result, RuleEngineError};
use crate::rete::facts::{FactValue, TypedFacts};
use std::collections::HashMap;
/// Definition of a single named field (slot) belonging to a [`Template`].
#[derive(Clone, Debug, PartialEq)]
pub struct FieldDef {
    /// Field name; used as the key when reading from / writing to [`TypedFacts`].
    pub name: String,
    /// Type constraint that a stored value must satisfy (see [`FieldType::matches`]).
    pub field_type: FieldType,
    /// Explicit value used by [`Template::create_instance`]; when `None`, the
    /// type's [`FieldType::default_value`] is used instead.
    pub default_value: Option<FactValue>,
    /// When `true`, [`Template::validate`] rejects facts that lack this field.
    pub required: bool,
}
/// Type constraint attached to a [`FieldDef`].
#[derive(Clone, Debug, PartialEq)]
pub enum FieldType {
    /// Accepts `FactValue::String` values.
    String,
    /// Accepts `FactValue::Integer` values.
    Integer,
    /// Accepts `FactValue::Float` values.
    Float,
    /// Accepts `FactValue::Boolean` values.
    Boolean,
    /// Accepts `FactValue::Array` values whose every element satisfies the
    /// boxed element type.
    Array(Box<FieldType>),
    /// Accepts any value.
    Any,
}
impl FieldType {
    /// Returns `true` when `value` is acceptable for a field of this type.
    ///
    /// No numeric coercion is performed (an `Integer` value does not satisfy
    /// `Float`). An array matches only when every element matches the element
    /// type, so an empty array satisfies any `Array(_)` type.
    pub fn matches(&self, value: &FactValue) -> bool {
        match self {
            FieldType::Any => true,
            FieldType::String => matches!(value, FactValue::String(_)),
            FieldType::Integer => matches!(value, FactValue::Integer(_)),
            FieldType::Float => matches!(value, FactValue::Float(_)),
            FieldType::Boolean => matches!(value, FactValue::Boolean(_)),
            FieldType::Array(element_type) => {
                if let FactValue::Array(items) = value {
                    items.iter().all(|item| element_type.matches(item))
                } else {
                    false
                }
            }
        }
    }

    /// Zero value for this type, used when a field has no explicit default:
    /// empty string, `0`, `0.0`, `false`, an empty array, or `Null` for `Any`.
    pub fn default_value(&self) -> FactValue {
        match self {
            FieldType::Any => FactValue::Null,
            FieldType::Array(_) => FactValue::Array(vec![]),
            FieldType::Boolean => FactValue::Boolean(false),
            FieldType::Float => FactValue::Float(0.0),
            FieldType::Integer => FactValue::Integer(0),
            FieldType::String => FactValue::String(String::new()),
        }
    }
}
/// A named schema: an ordered list of field definitions plus a name-to-index
/// lookup table for them.
#[derive(Clone, Debug)]
pub struct Template {
    /// Template name (also the key used by [`TemplateRegistry::register`]).
    pub name: String,
    /// Field definitions, in the order they were added.
    pub fields: Vec<FieldDef>,
    /// Maps a field name to its index in `fields`; kept up to date by
    /// [`Template::add_field`]. NOTE(review): `fields` is public, so edits made
    /// to it directly are not reflected here.
    field_map: HashMap<String, usize>,
}
impl Template {
    /// Creates an empty template named `name` with no field definitions.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            fields: Vec::new(),
            field_map: HashMap::new(),
        }
    }

    /// Adds `field` to the template and returns `self` for chaining.
    ///
    /// If a field with the same name is already defined, the new definition
    /// replaces the old one in place and keeps its original position. This
    /// way `fields`, `field_map`, [`Template::validate`] and
    /// [`Template::create_instance`] all agree on a single definition per name.
    pub fn add_field(&mut self, field: FieldDef) -> &mut Self {
        let existing = self.field_map.get(&field.name).copied();
        match existing {
            // `fields` is public and could have been edited directly, so only
            // trust the stored index if it still points at a field of this name.
            Some(idx) if matches!(self.fields.get(idx), Some(old) if old.name == field.name) => {
                self.fields[idx] = field;
            }
            _ => {
                self.field_map.insert(field.name.clone(), self.fields.len());
                self.fields.push(field);
            }
        }
        self
    }

    /// Checks `facts` against this template's field definitions.
    ///
    /// Fields are checked in definition order. Keys present in `facts` that
    /// the template does not define are not inspected.
    ///
    /// # Errors
    /// Returns `RuleEngineError::EvaluationError` for the first field that is
    /// required but absent, or whose value does not satisfy
    /// [`FieldType::matches`].
    pub fn validate(&self, facts: &TypedFacts) -> Result<()> {
        for field in &self.fields {
            match facts.get(&field.name) {
                None if field.required => {
                    return Err(RuleEngineError::EvaluationError {
                        message: format!(
                            "Required field '{}' missing in template '{}'",
                            field.name, self.name
                        ),
                    });
                }
                Some(val) if !field.field_type.matches(val) => {
                    return Err(RuleEngineError::EvaluationError {
                        message: format!(
                            "Field '{}' has wrong type. Expected {:?}, got {:?}",
                            field.name, field.field_type, val
                        ),
                    });
                }
                _ => {}
            }
        }
        Ok(())
    }

    /// Builds a fact set containing every defined field.
    ///
    /// Each field gets its explicit `default_value` if one was given, otherwise
    /// its type's [`FieldType::default_value`].
    pub fn create_instance(&self) -> TypedFacts {
        let mut facts = TypedFacts::new();
        for field in &self.fields {
            let value = field
                .default_value
                .clone()
                .unwrap_or_else(|| field.field_type.default_value());
            facts.set(&field.name, value);
        }
        facts
    }

    /// Looks up a field definition by name.
    pub fn get_field(&self, name: &str) -> Option<&FieldDef> {
        self.field_map
            .get(name)
            .and_then(|idx| self.fields.get(*idx))
    }
}
/// Fluent builder that accumulates field definitions into a [`Template`].
///
/// Derives `Debug` and `Clone` so that a partially configured builder can be
/// inspected or forked, like the [`Template`] it wraps.
#[derive(Debug, Clone)]
pub struct TemplateBuilder {
    /// Template under construction; returned by [`TemplateBuilder::build`].
    template: Template,
}
impl TemplateBuilder {
pub fn new(name: impl Into<String>) -> Self {
Self {
template: Template::new(name),
}
}
pub fn string_field(mut self, name: impl Into<String>) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type: FieldType::String,
default_value: None,
required: false,
});
self
}
pub fn required_string(mut self, name: impl Into<String>) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type: FieldType::String,
default_value: None,
required: true,
});
self
}
pub fn integer_field(mut self, name: impl Into<String>) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type: FieldType::Integer,
default_value: None,
required: false,
});
self
}
pub fn float_field(mut self, name: impl Into<String>) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type: FieldType::Float,
default_value: None,
required: false,
});
self
}
pub fn boolean_field(mut self, name: impl Into<String>) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type: FieldType::Boolean,
default_value: None,
required: false,
});
self
}
pub fn field_with_default(
mut self,
name: impl Into<String>,
field_type: FieldType,
default: FactValue,
) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type,
default_value: Some(default),
required: false,
});
self
}
pub fn array_field(mut self, name: impl Into<String>, element_type: FieldType) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type: FieldType::Array(Box::new(element_type)),
default_value: None,
required: false,
});
self
}
pub fn multislot_field(self, name: impl Into<String>, element_type: FieldType) -> Self {
self.array_field(name, element_type)
}
pub fn required_array_field(
mut self,
name: impl Into<String>,
element_type: FieldType,
) -> Self {
self.template.add_field(FieldDef {
name: name.into(),
field_type: FieldType::Array(Box::new(element_type)),
default_value: None,
required: true,
});
self
}
pub fn required_multislot_field(
self,
name: impl Into<String>,
element_type: FieldType,
) -> Self {
self.required_array_field(name, element_type)
}
pub fn build(self) -> Template {
self.template
}
}
/// Collection of templates keyed by template name.
///
/// Derives `Debug` and `Clone` so that a registry can be inspected or
/// snapshotted; both are available because [`Template`] implements them.
#[derive(Debug, Clone)]
pub struct TemplateRegistry {
    /// Registered templates, keyed by `Template::name`.
    templates: HashMap<String, Template>,
}
impl TemplateRegistry {
pub fn new() -> Self {
Self {
templates: HashMap::new(),
}
}
pub fn register(&mut self, template: Template) {
self.templates.insert(template.name.clone(), template);
}
pub fn get(&self, name: &str) -> Option<&Template> {
self.templates.get(name)
}
pub fn create_instance(&self, template_name: &str) -> Result<TypedFacts> {
let template = self
.get(template_name)
.ok_or_else(|| RuleEngineError::EvaluationError {
message: format!("Template '{}' not found", template_name),
})?;
Ok(template.create_instance())
}
pub fn validate(&self, template_name: &str, facts: &TypedFacts) -> Result<()> {
let template = self
.get(template_name)
.ok_or_else(|| RuleEngineError::EvaluationError {
message: format!("Template '{}' not found", template_name),
})?;
template.validate(facts)
}
pub fn list_templates(&self) -> Vec<&str> {
self.templates.keys().map(|s| s.as_str()).collect()
}
}
impl Default for TemplateRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fact set from `(key, value)` pairs, inserted in order.
    fn facts_from(entries: Vec<(&str, FactValue)>) -> TypedFacts {
        let mut facts = TypedFacts::new();
        for (key, value) in entries {
            facts.set(key, value);
        }
        facts
    }

    /// A "Person" template with a required `name` and an optional `age`.
    fn person_with_required_name() -> Template {
        TemplateBuilder::new("Person")
            .required_string("name")
            .integer_field("age")
            .build()
    }

    #[test]
    fn test_template_builder() {
        let person = TemplateBuilder::new("Person")
            .required_string("name")
            .integer_field("age")
            .boolean_field("is_adult")
            .build();
        assert_eq!(person.name, "Person");
        assert_eq!(person.fields.len(), 3);
        let name_field = person.get_field("name").unwrap();
        assert!(name_field.required);
    }

    #[test]
    fn test_create_instance() {
        let person = TemplateBuilder::new("Person")
            .string_field("name")
            .integer_field("age")
            .build();
        let fresh = person.create_instance();
        let empty_name = FactValue::String(String::new());
        assert_eq!(fresh.get("name"), Some(&empty_name));
        assert_eq!(fresh.get("age"), Some(&FactValue::Integer(0)));
    }

    #[test]
    fn test_validation_success() {
        let facts = facts_from(vec![
            ("name", FactValue::String(String::from("Alice"))),
            ("age", FactValue::Integer(30)),
        ]);
        assert!(person_with_required_name().validate(&facts).is_ok());
    }

    #[test]
    fn test_validation_missing_required() {
        let facts = facts_from(vec![("age", FactValue::Integer(30))]);
        assert!(person_with_required_name().validate(&facts).is_err());
    }

    #[test]
    fn test_validation_wrong_type() {
        let person = TemplateBuilder::new("Person")
            .string_field("name")
            .integer_field("age")
            .build();
        let facts = facts_from(vec![
            ("name", FactValue::String(String::from("Alice"))),
            ("age", FactValue::String(String::from("thirty"))),
        ]);
        assert!(person.validate(&facts).is_err());
    }

    #[test]
    fn test_template_registry() {
        let order = TemplateBuilder::new("Order")
            .required_string("order_id")
            .float_field("amount")
            .build();
        let mut registry = TemplateRegistry::new();
        registry.register(order);
        assert!(registry.get("Order").is_some());
        assert!(registry.create_instance("Order").is_ok());
        assert_eq!(registry.list_templates(), vec!["Order"]);
    }

    #[test]
    fn test_array_field() {
        let cart = TemplateBuilder::new("ShoppingCart")
            .array_field("items", FieldType::String)
            .build();
        let items = FactValue::Array(vec![
            FactValue::String(String::from("item1")),
            FactValue::String(String::from("item2")),
        ]);
        let facts = facts_from(vec![("items", items)]);
        assert!(cart.validate(&facts).is_ok());
    }

    #[test]
    fn test_field_with_default() {
        let config = TemplateBuilder::new("Config")
            .field_with_default("timeout", FieldType::Integer, FactValue::Integer(30))
            .build();
        let fresh = config.create_instance();
        assert_eq!(fresh.get("timeout"), Some(&FactValue::Integer(30)));
    }
}