1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use async_trait::async_trait;
8use crate::{Result, OdinError};
9use crate::message::{OdinMessage, MessagePriority};
10use crate::protocol::OdinProtocol;
11
12#[derive(Debug)]
14pub struct HELRuleEngine {
15 rules: Arc<RwLock<HashMap<String, Rule>>>,
17 stats: Arc<RwLock<RuleStats>>,
19 protocol: Option<Arc<OdinProtocol>>,
21}
22
23impl HELRuleEngine {
24 pub fn new() -> Self {
26 Self {
27 rules: Arc::new(RwLock::new(HashMap::new())),
28 stats: Arc::new(RwLock::new(RuleStats::default())),
29 protocol: None,
30 }
31 }
32
33 pub fn set_protocol(&mut self, protocol: Arc<OdinProtocol>) {
35 self.protocol = Some(protocol);
36 }
37
38 pub async fn add_rule(&self, rule: Rule) -> Result<()> {
40 rule.validate()?;
41
42 let mut rules = self.rules.write().await;
43 rules.insert(rule.id.clone(), rule);
44
45 let mut stats = self.stats.write().await;
46 stats.rules_added += 1;
47
48 Ok(())
49 }
50
51 pub async fn remove_rule(&self, rule_id: &str) -> Result<bool> {
53 let mut rules = self.rules.write().await;
54 let removed = rules.remove(rule_id).is_some();
55
56 if removed {
57 let mut stats = self.stats.write().await;
58 stats.rules_removed += 1;
59 }
60
61 Ok(removed)
62 }
63
64 pub async fn get_rule(&self, rule_id: &str) -> Option<Rule> {
66 let rules = self.rules.read().await;
67 rules.get(rule_id).cloned()
68 }
69
70 pub async fn list_rules(&self) -> Vec<Rule> {
72 let rules = self.rules.read().await;
73 rules.values().cloned().collect()
74 }
75
76 pub async fn execute_rules(&self, message: &OdinMessage) -> Result<Vec<RuleExecutionResult>> {
78 let rules = self.rules.read().await;
79 let mut results = Vec::new();
80
81 for rule in rules.values() {
82 if rule.matches(message).await? {
83 let start_time = std::time::Instant::now();
84 let result = rule.execute(message, self.protocol.as_ref()).await;
85 let execution_time = start_time.elapsed();
86
87 let mut stats = self.stats.write().await;
89 stats.rules_executed += 1;
90 stats.total_execution_time += execution_time;
91
92 if result.is_err() {
93 stats.rules_failed += 1;
94 }
95
96 results.push(RuleExecutionResult {
97 rule_id: rule.id.clone(),
98 success: result.is_ok(),
99 execution_time,
100 error: result.err().map(|e| e.to_string()),
101 });
102 }
103 }
104
105 Ok(results)
106 }
107
108 pub async fn get_stats(&self) -> RuleStats {
110 let stats = self.stats.read().await;
111 stats.clone()
112 }
113
114 pub async fn clear_rules(&self) -> Result<()> {
116 let mut rules = self.rules.write().await;
117 rules.clear();
118
119 let mut stats = self.stats.write().await;
120 stats.rules_cleared += 1;
121
122 Ok(())
123 }
124}
125
126impl Default for HELRuleEngine {
127 fn default() -> Self {
128 Self::new()
129 }
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct Rule {
135 pub id: String,
137 pub name: String,
139 pub description: String,
141 pub priority: i32,
143 pub conditions: Vec<Condition>,
145 pub actions: Vec<Action>,
147 pub enabled: bool,
149 pub metadata: HashMap<String, String>,
151}
152
153impl Rule {
154 pub fn new(id: String, name: String, description: String) -> Self {
156 Self {
157 id,
158 name,
159 description,
160 priority: 0,
161 conditions: Vec::new(),
162 actions: Vec::new(),
163 enabled: true,
164 metadata: HashMap::new(),
165 }
166 }
167
168 pub fn add_condition(mut self, condition: Condition) -> Self {
170 self.conditions.push(condition);
171 self
172 }
173
174 pub fn add_action(mut self, action: Action) -> Self {
176 self.actions.push(action);
177 self
178 }
179
180 pub fn priority(mut self, priority: i32) -> Self {
182 self.priority = priority;
183 self
184 }
185
186 pub fn enabled(mut self, enabled: bool) -> Self {
188 self.enabled = enabled;
189 self
190 }
191
192 pub fn validate(&self) -> Result<()> {
194 if self.id.is_empty() {
195 return Err(OdinError::Rule("Rule ID cannot be empty".to_string()));
196 }
197
198 if self.name.is_empty() {
199 return Err(OdinError::Rule("Rule name cannot be empty".to_string()));
200 }
201
202 if self.conditions.is_empty() {
203 return Err(OdinError::Rule("Rule must have at least one condition".to_string()));
204 }
205
206 if self.actions.is_empty() {
207 return Err(OdinError::Rule("Rule must have at least one action".to_string()));
208 }
209
210 Ok(())
211 }
212
213 pub async fn matches(&self, message: &OdinMessage) -> Result<bool> {
215 if !self.enabled {
216 return Ok(false);
217 }
218
219 for condition in &self.conditions {
220 if !condition.evaluate(message).await? {
221 return Ok(false);
222 }
223 }
224
225 Ok(true)
226 }
227
228 pub async fn execute(&self, message: &OdinMessage, protocol: Option<&Arc<OdinProtocol>>) -> Result<()> {
230 for action in &self.actions {
231 action.execute(message, protocol).await?;
232 }
233 Ok(())
234 }
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
239pub enum Condition {
240 ContentContains(String),
242 SourceMatches(String),
244 TargetMatches(String),
246 PriorityEquals(MessagePriority),
248 Custom(String),
250}
251
252impl Condition {
253 pub async fn evaluate(&self, message: &OdinMessage) -> Result<bool> {
255 match self {
256 Condition::ContentContains(text) => {
257 Ok(message.content.contains(text))
258 }
259 Condition::SourceMatches(pattern) => {
260 Ok(message.source_node.contains(pattern))
261 }
262 Condition::TargetMatches(pattern) => {
263 Ok(message.target_node.contains(pattern))
264 }
265 Condition::PriorityEquals(priority) => {
266 Ok(message.priority == *priority)
267 }
268 Condition::Custom(expression) => {
269 Ok(true)
272 }
273 }
274 }
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
279pub enum Action {
280 SendMessage {
282 target: String,
283 content: String,
284 priority: MessagePriority,
285 },
286 Log {
288 level: LogLevel,
289 message: String,
290 },
291 Forward {
293 target: String,
294 },
295 ModifyContent {
297 new_content: String,
298 },
299 Custom(String),
301}
302
303impl Action {
304 pub async fn execute(&self, message: &OdinMessage, protocol: Option<&Arc<OdinProtocol>>) -> Result<()> {
306 match self {
307 Action::SendMessage { target, content, priority } => {
308 if let Some(protocol) = protocol {
309 protocol.send_message(target, content, *priority).await?;
310 }
311 }
312 Action::Log { level, message: log_msg } => {
313 match level {
314 LogLevel::Info => println!("[INFO] {}", log_msg),
315 LogLevel::Warning => println!("[WARN] {}", log_msg),
316 LogLevel::Error => eprintln!("[ERROR] {}", log_msg),
317 LogLevel::Debug => println!("[DEBUG] {}", log_msg),
318 }
319 }
320 Action::Forward { target } => {
321 if let Some(protocol) = protocol {
322 protocol.send_message(target, &message.content, message.priority).await?;
323 }
324 }
325 Action::ModifyContent { new_content: _ } => {
326 }
329 Action::Custom(_code) => {
330 }
333 }
334 Ok(())
335 }
336}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
340pub enum LogLevel {
341 Info,
342 Warning,
343 Error,
344 Debug,
345}
346
/// Outcome of running one matched rule.
#[derive(Debug, Clone)]
pub struct RuleExecutionResult {
    /// Id of the rule that ran.
    pub rule_id: String,
    /// `true` when the rule's actions finished without error.
    pub success: bool,
    /// Wall-clock time spent executing the rule's actions.
    pub execution_time: std::time::Duration,
    /// Text of the error when `success` is `false`, otherwise `None`.
    pub error: Option<String>,
}
355
/// Counters describing what a rule engine has done so far.
#[derive(Debug, Clone, Default)]
pub struct RuleStats {
    /// Number of successful `add_rule` calls.
    pub rules_added: u64,
    /// Number of rules removed individually.
    pub rules_removed: u64,
    /// Number of rule executions (successful or not).
    pub rules_executed: u64,
    /// Number of rule executions that ended in an error.
    pub rules_failed: u64,
    /// Number of times the rule table was cleared.
    pub rules_cleared: u64,
    /// Sum of the execution times of all executed rules.
    pub total_execution_time: std::time::Duration,
}

impl RuleStats {
    /// Mean execution time per executed rule, or zero when nothing has run.
    ///
    /// The division is carried out on the full 64-bit execution count.
    /// Narrowing the count to `u32` would give wrong averages above
    /// `u32::MAX` executions and would divide by zero whenever the count is a
    /// multiple of 2^32.
    pub fn average_execution_time(&self) -> std::time::Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if self.rules_executed == 0 {
            return std::time::Duration::ZERO;
        }

        let avg_nanos = self.total_execution_time.as_nanos() / u128::from(self.rules_executed);

        // The average never exceeds the total, so its whole seconds fit in
        // `u64`, and the remainder is below one second so it fits in `u32`.
        std::time::Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        )
    }

    /// Fraction of executions that did not fail, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when nothing has been executed. Because the fields are
    /// public, `rules_failed` may exceed `rules_executed`; the subtraction
    /// saturates at zero rather than underflowing.
    pub fn success_rate(&self) -> f64 {
        if self.rules_executed == 0 {
            return 0.0;
        }

        let succeeded = self.rules_executed.saturating_sub(self.rules_failed);
        succeeded as f64 / self.rules_executed as f64
    }
}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{MessageType, OdinMessage};

    /// Rule with a single "content contains hello" condition and no actions.
    fn hello_rule() -> Rule {
        Rule::new(
            "test-rule".to_string(),
            "Test Rule".to_string(),
            "A test rule".to_string(),
        )
        .add_condition(Condition::ContentContains("hello".to_string()))
    }

    /// `hello_rule` plus one info-level log action, which makes it pass validation.
    fn logging_hello_rule() -> Rule {
        hello_rule().add_action(Action::Log {
            level: LogLevel::Info,
            message: "Rule triggered".to_string(),
        })
    }

    /// Standard, normal-priority message from "source" to "target".
    fn standard_message(content: &'static str) -> OdinMessage {
        OdinMessage::new(
            MessageType::Standard,
            "source",
            "target",
            content,
            MessagePriority::Normal,
        )
    }

    // The builder produces a rule that validates and keeps what was added.
    #[tokio::test]
    async fn test_rule_creation() {
        let rule = logging_hello_rule();

        assert!(rule.validate().is_ok());
        assert_eq!(rule.id, "test-rule");
        assert_eq!(rule.conditions.len(), 1);
        assert_eq!(rule.actions.len(), 1);
    }

    // A rule added to the engine shows up in the listing.
    #[tokio::test]
    async fn test_rule_engine() {
        let engine = HELRuleEngine::new();

        engine.add_rule(logging_hello_rule()).await.unwrap();

        let rules = engine.list_rules().await;
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].id, "test-rule");
    }

    // Matching follows the content condition.
    #[tokio::test]
    async fn test_rule_matching() {
        let rule = hello_rule();

        let message = standard_message("hello world");
        assert!(rule.matches(&message).await.unwrap());

        let message2 = standard_message("goodbye world");
        assert!(!rule.matches(&message2).await.unwrap());
    }

    // A content condition evaluates to true when the text is present.
    #[tokio::test]
    async fn test_condition_evaluation() {
        let condition = Condition::ContentContains("test".to_string());

        let message = standard_message("this is a test message");

        assert!(condition.evaluate(&message).await.unwrap());
    }
}