use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use async_trait::async_trait;
use crate::{Result, OdinError};
use crate::message::{OdinMessage, MessagePriority};
use crate::protocol::OdinProtocol;
/// Rule engine that stores [`Rule`]s by id and runs the matching ones against messages.
///
/// The rule table and the statistics each live behind their own `Arc<RwLock<..>>`,
/// so the engine's methods only need `&self` to read or update them.
#[derive(Debug)]
pub struct HELRuleEngine {
/// Registered rules, keyed by `Rule::id`.
rules: Arc<RwLock<HashMap<String, Rule>>>,
/// Running counters updated by the engine's methods.
stats: Arc<RwLock<RuleStats>>,
/// Protocol handle handed to rule actions when they execute.
/// `None` until `set_protocol` is called.
protocol: Option<Arc<OdinProtocol>>,
}
impl HELRuleEngine {
    /// Creates an engine with no rules, zeroed statistics and no protocol attached.
    pub fn new() -> Self {
        Self {
            rules: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(RuleStats::default())),
            protocol: None,
        }
    }

    /// Attaches the protocol handle that is passed to rule actions during
    /// [`execute_rules`](Self::execute_rules).
    pub fn set_protocol(&mut self, protocol: Arc<OdinProtocol>) {
        self.protocol = Some(protocol);
    }

    /// Validates `rule` and stores it under its id.
    ///
    /// A rule that is already registered under the same id is replaced.
    /// `rules_added` is incremented in either case.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`Rule::validate`] when the rule is invalid;
    /// nothing is stored in that case.
    pub async fn add_rule(&self, rule: Rule) -> Result<()> {
        rule.validate()?;
        let mut rules = self.rules.write().await;
        rules.insert(rule.id.clone(), rule);
        let mut stats = self.stats.write().await;
        stats.rules_added += 1;
        Ok(())
    }

    /// Removes the rule registered under `rule_id`.
    ///
    /// Returns `Ok(true)` when a rule was removed (and `rules_removed` was
    /// incremented), `Ok(false)` when no such rule existed.
    pub async fn remove_rule(&self, rule_id: &str) -> Result<bool> {
        let mut rules = self.rules.write().await;
        let removed = rules.remove(rule_id).is_some();
        if removed {
            let mut stats = self.stats.write().await;
            stats.rules_removed += 1;
        }
        Ok(removed)
    }

    /// Returns a clone of the rule registered under `rule_id`, if any.
    pub async fn get_rule(&self, rule_id: &str) -> Option<Rule> {
        let rules = self.rules.read().await;
        rules.get(rule_id).cloned()
    }

    /// Returns clones of all registered rules, in no particular order.
    pub async fn list_rules(&self) -> Vec<Rule> {
        let rules = self.rules.read().await;
        rules.values().cloned().collect()
    }

    /// Runs every rule that matches `message` and reports one result per executed rule.
    ///
    /// Rules are visited in a deterministic order: highest `priority` first,
    /// ties broken by ascending id. Iterating the backing `HashMap` directly
    /// would give an arbitrary order and ignore `Rule::priority` entirely.
    /// NOTE(review): "larger value runs first" is an assumption — confirm the
    /// intended direction of `Rule::priority`.
    ///
    /// A failing action does not abort the run: the failure is recorded in the
    /// returned [`RuleExecutionResult`] and counted in `rules_failed`.
    ///
    /// # Errors
    ///
    /// Returns an error if evaluating a rule's conditions fails. Results of
    /// rules that already ran are discarded, but their statistics remain recorded.
    pub async fn execute_rules(&self, message: &OdinMessage) -> Result<Vec<RuleExecutionResult>> {
        let rules = self.rules.read().await;

        let mut ordered: Vec<&Rule> = rules.values().collect();
        ordered.sort_by(|a, b| {
            b.priority
                .cmp(&a.priority)
                .then_with(|| a.id.cmp(&b.id))
        });

        let mut results = Vec::new();
        for rule in ordered {
            if !rule.matches(message).await? {
                continue;
            }

            let start_time = std::time::Instant::now();
            let result = rule.execute(message, self.protocol.as_ref()).await;
            let execution_time = start_time.elapsed();
            let error = result.err().map(|e| e.to_string());

            // Keep the stats write guard only for the counter update itself.
            {
                let mut stats = self.stats.write().await;
                stats.rules_executed += 1;
                stats.total_execution_time += execution_time;
                if error.is_some() {
                    stats.rules_failed += 1;
                }
            }

            results.push(RuleExecutionResult {
                rule_id: rule.id.clone(),
                success: error.is_none(),
                execution_time,
                error,
            });
        }
        Ok(results)
    }

    /// Returns a snapshot (clone) of the current statistics.
    pub async fn get_stats(&self) -> RuleStats {
        let stats = self.stats.read().await;
        stats.clone()
    }

    /// Removes all rules.
    ///
    /// `rules_cleared` counts calls to this method, not the number of rules dropped.
    pub async fn clear_rules(&self) -> Result<()> {
        let mut rules = self.rules.write().await;
        rules.clear();
        let mut stats = self.stats.write().await;
        stats.rules_cleared += 1;
        Ok(())
    }
}
impl Default for HELRuleEngine {
fn default() -> Self {
Self::new()
}
}
/// A named set of conditions and the actions to run when all of them hold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rule {
/// Unique identifier; `validate` rejects an empty id.
pub id: String,
/// Human-readable name; `validate` rejects an empty name.
pub name: String,
/// Free-form description; not checked by `validate`.
pub description: String,
/// Numeric priority; `Rule::new` initialises it to 0.
pub priority: i32,
/// Conditions that must all evaluate to `true` for the rule to match.
/// `validate` requires at least one.
pub conditions: Vec<Condition>,
/// Actions run in order when the rule executes. `validate` requires at least one.
pub actions: Vec<Action>,
/// A disabled rule never matches.
pub enabled: bool,
/// Arbitrary string key/value pairs; not interpreted by the code in this file.
pub metadata: HashMap<String, String>,
}
impl Rule {
pub fn new(id: String, name: String, description: String) -> Self {
Self {
id,
name,
description,
priority: 0,
conditions: Vec::new(),
actions: Vec::new(),
enabled: true,
metadata: HashMap::new(),
}
}
pub fn add_condition(mut self, condition: Condition) -> Self {
self.conditions.push(condition);
self
}
pub fn add_action(mut self, action: Action) -> Self {
self.actions.push(action);
self
}
pub fn priority(mut self, priority: i32) -> Self {
self.priority = priority;
self
}
pub fn enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn validate(&self) -> Result<()> {
if self.id.is_empty() {
return Err(OdinError::Rule("Rule ID cannot be empty".to_string()));
}
if self.name.is_empty() {
return Err(OdinError::Rule("Rule name cannot be empty".to_string()));
}
if self.conditions.is_empty() {
return Err(OdinError::Rule("Rule must have at least one condition".to_string()));
}
if self.actions.is_empty() {
return Err(OdinError::Rule("Rule must have at least one action".to_string()));
}
Ok(())
}
pub async fn matches(&self, message: &OdinMessage) -> Result<bool> {
if !self.enabled {
return Ok(false);
}
for condition in &self.conditions {
if !condition.evaluate(message).await? {
return Ok(false);
}
}
Ok(true)
}
pub async fn execute(&self, message: &OdinMessage, protocol: Option<&Arc<OdinProtocol>>) -> Result<()> {
for action in &self.actions {
action.execute(message, protocol).await?;
}
Ok(())
}
}
/// A single test applied to a message when deciding whether a rule matches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Condition {
/// Holds when `message.content` contains the given text.
ContentContains(String),
/// Holds when `message.source_node` contains the given pattern
/// (a plain `contains` check, not a regex or glob).
SourceMatches(String),
/// Holds when `message.target_node` contains the given pattern
/// (a plain `contains` check, not a regex or glob).
TargetMatches(String),
/// Holds when the message priority equals the given priority.
PriorityEquals(MessagePriority),
/// Custom expression. The expression is currently not interpreted:
/// `evaluate` returns `true` for every message.
Custom(String),
}
impl Condition {
    /// Evaluates this condition against `message`.
    ///
    /// * `ContentContains` — `message.content` contains the text.
    /// * `SourceMatches` / `TargetMatches` — the source / target node contains
    ///   the pattern (plain `contains`, no regex or glob handling).
    /// * `PriorityEquals` — the message priority equals the given one.
    /// * `Custom` — not implemented yet; always evaluates to `true`.
    ///
    /// No variant currently produces an error; the `Result` return type is
    /// kept as part of the interface.
    pub async fn evaluate(&self, message: &OdinMessage) -> Result<bool> {
        let holds = match self {
            Condition::ContentContains(text) => message.content.contains(text),
            Condition::SourceMatches(pattern) => message.source_node.contains(pattern),
            Condition::TargetMatches(pattern) => message.target_node.contains(pattern),
            Condition::PriorityEquals(priority) => message.priority == *priority,
            // The expression is deliberately ignored until custom evaluation
            // exists; the underscore keeps the build free of an
            // `unused_variables` warning.
            Condition::Custom(_expression) => true,
        };
        Ok(holds)
    }
}
/// Something a rule does once it has matched a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Action {
/// Sends a new message through the protocol's `send_message`.
/// When no protocol is available nothing is sent and the action still succeeds.
SendMessage {
/// Destination passed to `send_message`.
target: String,
/// Content of the message to send.
content: String,
/// Priority of the message to send.
priority: MessagePriority,
},
/// Prints a line prefixed with the level tag (stdout, or stderr for `Error`).
Log {
/// Severity, which selects the prefix and output stream.
level: LogLevel,
/// Text to print.
message: String,
},
/// Re-sends the triggering message's content and priority to another target.
/// When no protocol is available nothing is sent and the action still succeeds.
Forward {
/// Destination passed to `send_message`.
target: String,
},
/// Intended to replace the message content; executing it currently does nothing.
ModifyContent {
/// Replacement content (currently unused by `execute`).
new_content: String,
},
/// Custom action payload; executing it currently does nothing.
Custom(String),
}
impl Action {
pub async fn execute(&self, message: &OdinMessage, protocol: Option<&Arc<OdinProtocol>>) -> Result<()> {
match self {
Action::SendMessage { target, content, priority } => {
if let Some(protocol) = protocol {
protocol.send_message(target, content, *priority).await?;
}
}
Action::Log { level, message: log_msg } => {
match level {
LogLevel::Info => println!("[INFO] {}", log_msg),
LogLevel::Warning => println!("[WARN] {}", log_msg),
LogLevel::Error => eprintln!("[ERROR] {}", log_msg),
LogLevel::Debug => println!("[DEBUG] {}", log_msg),
}
}
Action::Forward { target } => {
if let Some(protocol) = protocol {
protocol.send_message(target, &message.content, message.priority).await?;
}
}
Action::ModifyContent { new_content: _ } => {
}
Action::Custom(_code) => {
}
}
Ok(())
}
}
/// Severity used by `Action::Log` to choose the line prefix and output stream.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LogLevel {
/// Printed to stdout with the `[INFO]` prefix.
Info,
/// Printed to stdout with the `[WARN]` prefix.
Warning,
/// Printed to stderr with the `[ERROR]` prefix.
Error,
/// Printed to stdout with the `[DEBUG]` prefix.
Debug,
}
/// Outcome of executing one matched rule against one message.
#[derive(Debug, Clone)]
pub struct RuleExecutionResult {
/// Id of the rule that was executed.
pub rule_id: String,
/// `true` when the rule's actions completed without error.
pub success: bool,
/// Wall-clock time spent running the rule's actions.
pub execution_time: std::time::Duration,
/// String form of the error when the execution failed, `None` otherwise.
pub error: Option<String>,
}
/// Counters describing what a rule engine has done so far.
#[derive(Debug, Clone, Default)]
pub struct RuleStats {
    /// Number of successful `add_rule` calls.
    pub rules_added: u64,
    /// Number of rules removed individually.
    pub rules_removed: u64,
    /// Number of rule executions (matched rules whose actions were run).
    pub rules_executed: u64,
    /// Number of rule executions that ended in an error.
    pub rules_failed: u64,
    /// Number of times the rule table was cleared.
    pub rules_cleared: u64,
    /// Sum of the execution times of all rule executions.
    pub total_execution_time: std::time::Duration,
}

impl RuleStats {
    /// Mean execution time per executed rule, or zero when nothing has run.
    ///
    /// The division is carried out on the full 64-bit execution count.
    /// Narrowing the count to `u32` would give a wrong mean once it exceeds
    /// `u32::MAX`, and would panic with a division by zero whenever its low
    /// 32 bits happen to be zero.
    pub fn average_execution_time(&self) -> std::time::Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if self.rules_executed == 0 {
            return std::time::Duration::ZERO;
        }

        let avg_nanos =
            self.total_execution_time.as_nanos() / u128::from(self.rules_executed);

        // The mean never exceeds the total, so its whole seconds fit in `u64`;
        // the remainder is below one second and therefore fits in `u32`.
        std::time::Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        )
    }

    /// Fraction of executions that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when nothing has been executed. The fields are public, so
    /// `rules_failed` may exceed `rules_executed`; the subtraction saturates at
    /// zero instead of overflowing.
    pub fn success_rate(&self) -> f64 {
        if self.rules_executed == 0 {
            return 0.0;
        }
        let succeeded = self.rules_executed.saturating_sub(self.rules_failed);
        succeeded as f64 / self.rules_executed as f64
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{MessageType, OdinMessage};

    /// Rule "test-rule" with the single condition "content contains `hello`".
    fn hello_rule() -> Rule {
        let id = "test-rule".to_string();
        let name = "Test Rule".to_string();
        let description = "A test rule".to_string();
        Rule::new(id, name, description)
            .add_condition(Condition::ContentContains("hello".to_string()))
    }

    /// [`hello_rule`] plus an info-level log action, so that it passes validation.
    fn hello_rule_with_log() -> Rule {
        hello_rule().add_action(Action::Log {
            level: LogLevel::Info,
            message: "Rule triggered".to_string(),
        })
    }

    // A rule built through the builder validates and exposes what was added.
    #[tokio::test]
    async fn test_rule_creation() {
        let rule = hello_rule_with_log();

        assert!(rule.validate().is_ok());
        assert_eq!(rule.id, "test-rule");
        assert_eq!(rule.conditions.len(), 1);
        assert_eq!(rule.actions.len(), 1);
    }

    // A rule added to the engine shows up in the listing.
    #[tokio::test]
    async fn test_rule_engine() {
        let engine = HELRuleEngine::new();
        engine.add_rule(hello_rule_with_log()).await.unwrap();

        let rules = engine.list_rules().await;
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].id, "test-rule");
    }

    // The rule matches content containing "hello" and rejects other content.
    #[tokio::test]
    async fn test_rule_matching() {
        let rule = hello_rule();

        let message = OdinMessage::new(
            MessageType::Standard,
            "source",
            "target",
            "hello world",
            MessagePriority::Normal,
        );
        assert!(rule.matches(&message).await.unwrap());

        let message2 = OdinMessage::new(
            MessageType::Standard,
            "source",
            "target",
            "goodbye world",
            MessagePriority::Normal,
        );
        assert!(!rule.matches(&message2).await.unwrap());
    }

    // A content condition evaluates to true when the text is present.
    #[tokio::test]
    async fn test_condition_evaluation() {
        let message = OdinMessage::new(
            MessageType::Standard,
            "source",
            "target",
            "this is a test message",
            MessagePriority::Normal,
        );
        let condition = Condition::ContentContains("test".to_string());

        assert!(condition.evaluate(&message).await.unwrap());
    }
}