1use crate::error::{DataForgeError, Result};
6use crate::memory::StringPool;
7use rand::distributions::Distribution;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14
15#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
17pub enum RuleType {
18 Regex { pattern: String, flags: Option<String> },
20 Range { min: Value, max: Value },
22 Enum { values: Vec<Value> },
24 Length { min: Option<usize>, max: Option<usize> },
26 Format { format: String },
28 Custom { name: String, params: HashMap<String, Value> },
30 Composite { operator: LogicalOperator, rules: Vec<Rule> },
32}
33
34#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36pub enum LogicalOperator {
37 And,
38 Or,
39 Not,
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
44pub struct Rule {
45 pub id: String,
47 pub name: String,
49 pub rule_type: RuleType,
51 pub priority: u32,
53 pub enabled: bool,
55 pub description: Option<String>,
57 pub tags: Vec<String>,
59 pub parent_id: Option<String>,
61}
62
63#[derive(Debug, Clone)]
65pub struct RuleContext {
66 pub field_name: String,
68 pub current_value: Option<Value>,
70 pub params: HashMap<String, Value>,
72 pub generation_history: Vec<Value>,
74}
75
76#[derive(Debug, Clone)]
78pub struct RuleResult {
79 pub matched: bool,
81 pub value: Option<Value>,
83 pub error: Option<String>,
85 pub execution_time: std::time::Duration,
87}
88
89pub trait CustomRuleHandler: Send + Sync {
91 fn handle(&self, rule: &Rule, context: &RuleContext) -> Result<RuleResult>;
93
94 fn name(&self) -> &str;
96}
97
98pub struct RuleEngine {
100 rules: RwLock<HashMap<String, Rule>>,
102 regex_cache: RwLock<HashMap<String, Regex>>,
104 custom_handlers: RwLock<HashMap<String, Arc<dyn CustomRuleHandler>>>,
106 #[allow(dead_code)]
108 string_pool: Arc<StringPool>,
109 inheritance_tree: RwLock<HashMap<String, Vec<String>>>,
111}
112
113impl RuleEngine {
114 pub fn new(string_pool: Arc<StringPool>) -> Self {
116 Self {
117 rules: RwLock::new(HashMap::new()),
118 regex_cache: RwLock::new(HashMap::new()),
119 custom_handlers: RwLock::new(HashMap::new()),
120 string_pool,
121 inheritance_tree: RwLock::new(HashMap::new()),
122 }
123 }
124
125 pub fn add_rule(&self, rule: Rule) -> Result<()> {
127 let rule_id = rule.id.clone();
128
129 self.validate_rule(&rule)?;
131
132 if let Some(parent_id) = &rule.parent_id {
134 let mut tree = self.inheritance_tree.write().unwrap();
135 tree.entry(parent_id.clone()).or_insert_with(Vec::new).push(rule_id.clone());
136 }
137
138 let mut rules = self.rules.write().unwrap();
140 rules.insert(rule_id, rule);
141
142 Ok(())
143 }
144
145 pub fn remove_rule(&self, rule_id: &str) -> Result<()> {
147 let mut rules = self.rules.write().unwrap();
148
149 if let Some(rule) = rules.remove(rule_id) {
150 if let Some(parent_id) = &rule.parent_id {
152 let mut tree = self.inheritance_tree.write().unwrap();
153 if let Some(children) = tree.get_mut(parent_id) {
154 children.retain(|id| id != rule_id);
155 }
156 }
157
158 if let RuleType::Regex { pattern, .. } = &rule.rule_type {
160 let mut cache = self.regex_cache.write().unwrap();
161 cache.remove(pattern);
162 }
163
164 Ok(())
165 } else {
166 Err(DataForgeError::validation(&format!("Rule not found: {}", rule_id)))
167 }
168 }
169
170 pub fn execute_rule(&self, rule_id: &str, context: &RuleContext) -> Result<RuleResult> {
172 let start_time = std::time::Instant::now();
173
174 let rule = {
175 let rules = self.rules.read().unwrap();
176 rules.get(rule_id)
177 .ok_or_else(|| DataForgeError::validation(&format!("Rule not found: {}", rule_id)))?
178 .clone()
179 };
180
181 if !rule.enabled {
182 return Ok(RuleResult {
183 matched: false,
184 value: None,
185 error: Some("Rule is disabled".to_string()),
186 execution_time: start_time.elapsed(),
187 });
188 }
189
190 let result = self.execute_rule_internal(&rule, context);
191
192 Ok(RuleResult {
193 matched: result.is_ok(),
194 value: result.as_ref().ok().cloned(),
195 error: result.as_ref().err().map(|e| e.to_string()),
196 execution_time: start_time.elapsed(),
197 })
198 }
199
200 fn execute_rule_internal(&self, rule: &Rule, context: &RuleContext) -> Result<Value> {
202 match &rule.rule_type {
203 RuleType::Regex { pattern, flags } => {
204 self.execute_regex_rule(pattern, flags.as_deref(), context)
205 }
206 RuleType::Range { min, max } => {
207 self.execute_range_rule(min, max, context)
208 }
209 RuleType::Enum { values } => {
210 self.execute_enum_rule(values, context)
211 }
212 RuleType::Length { min, max } => {
213 self.execute_length_rule(*min, *max, context)
214 }
215 RuleType::Format { format } => {
216 self.execute_format_rule(format, context)
217 }
218 RuleType::Custom { name, params } => {
219 self.execute_custom_rule(name, params, rule, context)
220 }
221 RuleType::Composite { operator, rules } => {
222 self.execute_composite_rule(operator, rules, context)
223 }
224 }
225 }
226
227 fn execute_regex_rule(&self, pattern: &str, _flags: Option<&str>, _context: &RuleContext) -> Result<Value> {
229 let _regex = self.get_or_compile_regex(pattern)?;
230
231 use rand_regex::Regex as RandRegex;
233 let rand_regex = RandRegex::compile(pattern, 100)
234 .map_err(|e| DataForgeError::generator(&format!("Failed to compile regex for generation: {}", e)))?;
235
236 let mut rng = rand::thread_rng();
237 let generated = rand_regex.sample(&mut rng);
238
239 Ok(Value::String(generated))
240 }
241
242 fn execute_range_rule(&self, min: &Value, max: &Value, _context: &RuleContext) -> Result<Value> {
244 use rand::Rng;
245 let mut rng = rand::thread_rng();
246
247 match (min, max) {
248 (Value::Number(min_num), Value::Number(max_num)) => {
249 if let (Some(min_f), Some(max_f)) = (min_num.as_f64(), max_num.as_f64()) {
250 let value = rng.gen_range(min_f..=max_f);
251 Ok(Value::Number(serde_json::Number::from_f64(value).unwrap()))
252 } else {
253 Err(DataForgeError::validation("Invalid number range"))
254 }
255 }
256 _ => Err(DataForgeError::validation("Range rule requires numeric min and max values")),
257 }
258 }
259
260 fn execute_enum_rule(&self, values: &[Value], _context: &RuleContext) -> Result<Value> {
262 use rand::seq::SliceRandom;
263 let mut rng = rand::thread_rng();
264
265 values.choose(&mut rng)
266 .cloned()
267 .ok_or_else(|| DataForgeError::validation("Empty enum values"))
268 }
269
270 fn execute_length_rule(&self, min: Option<usize>, max: Option<usize>, _context: &RuleContext) -> Result<Value> {
272 use rand::Rng;
273 let mut rng = rand::thread_rng();
274
275 let min_len = min.unwrap_or(1);
276 let max_len = max.unwrap_or(20);
277 let length = rng.gen_range(min_len..=max_len);
278
279 let chars: String = (0..length)
280 .map(|_| rng.gen_range(b'a'..=b'z') as char)
281 .collect();
282
283 Ok(Value::String(chars))
284 }
285
286 fn execute_format_rule(&self, format: &str, context: &RuleContext) -> Result<Value> {
288 let mut result = format.to_string();
290
291 result = result.replace("{field_name}", &context.field_name);
293 result = result.replace("{random_number}", &rand::random::<u32>().to_string());
294 result = result.replace("{timestamp}", &chrono::Utc::now().timestamp().to_string());
295
296 Ok(Value::String(result))
297 }
298
299 fn execute_custom_rule(&self, name: &str, _params: &HashMap<String, Value>, rule: &Rule, context: &RuleContext) -> Result<Value> {
301 let handlers = self.custom_handlers.read().unwrap();
302
303 if let Some(handler) = handlers.get(name) {
304 let result = handler.handle(rule, context)?;
305 result.value.ok_or_else(|| DataForgeError::generator("Custom rule handler returned no value"))
306 } else {
307 Err(DataForgeError::validation(&format!("Custom rule handler not found: {}", name)))
308 }
309 }
310
311 fn execute_composite_rule(&self, operator: &LogicalOperator, rules: &[Rule], context: &RuleContext) -> Result<Value> {
313 match operator {
314 LogicalOperator::And => {
315 let mut last_value = Value::Null;
317 for rule in rules {
318 last_value = self.execute_rule_internal(rule, context)?;
319 }
320 Ok(last_value)
321 }
322 LogicalOperator::Or => {
323 for rule in rules {
325 if let Ok(value) = self.execute_rule_internal(rule, context) {
326 return Ok(value);
327 }
328 }
329 Err(DataForgeError::generator("No rule in OR composite succeeded"))
330 }
331 LogicalOperator::Not => {
332 Ok(Value::Null)
334 }
335 }
336 }
337
338 fn get_or_compile_regex(&self, pattern: &str) -> Result<Regex> {
340 {
342 let cache = self.regex_cache.read().unwrap();
343 if let Some(regex) = cache.get(pattern) {
344 return Ok(regex.clone());
345 }
346 }
347
348 let regex = Regex::new(pattern)
350 .map_err(|e| DataForgeError::validation(&format!("Invalid regex pattern: {}", e)))?;
351
352 {
354 let mut cache = self.regex_cache.write().unwrap();
355 cache.insert(pattern.to_string(), regex.clone());
356 }
357
358 Ok(regex)
359 }
360
361 fn validate_rule(&self, rule: &Rule) -> Result<()> {
363 {
365 let rules = self.rules.read().unwrap();
366 if rules.contains_key(&rule.id) {
367 return Err(DataForgeError::validation(&format!("Rule ID already exists: {}", rule.id)));
368 }
369 }
370
371 if let Some(parent_id) = &rule.parent_id {
373 let rules = self.rules.read().unwrap();
374 if !rules.contains_key(parent_id) {
375 return Err(DataForgeError::validation(&format!("Parent rule not found: {}", parent_id)));
376 }
377 }
378
379 match &rule.rule_type {
381 RuleType::Regex { pattern, .. } => {
382 Regex::new(pattern)
383 .map_err(|e| DataForgeError::validation(&format!("Invalid regex pattern: {}", e)))?;
384 }
385 RuleType::Range { min, max } => {
386 if !min.is_number() || !max.is_number() {
387 return Err(DataForgeError::validation("Range rule requires numeric min and max values"));
388 }
389 }
390 RuleType::Enum { values } => {
391 if values.is_empty() {
392 return Err(DataForgeError::validation("Enum rule requires at least one value"));
393 }
394 }
395 _ => {} }
397
398 Ok(())
399 }
400
401 pub fn register_custom_handler(&self, handler: Arc<dyn CustomRuleHandler>) {
403 let mut handlers = self.custom_handlers.write().unwrap();
404 handlers.insert(handler.name().to_string(), handler);
405 }
406
407 pub fn get_all_rules(&self) -> Vec<Rule> {
409 let rules = self.rules.read().unwrap();
410 rules.values().cloned().collect()
411 }
412
413 pub fn find_rules_by_tag(&self, tag: &str) -> Vec<Rule> {
415 let rules = self.rules.read().unwrap();
416 rules.values()
417 .filter(|rule| rule.tags.contains(&tag.to_string()))
418 .cloned()
419 .collect()
420 }
421
422 pub fn get_inheritance_chain(&self, rule_id: &str) -> Vec<String> {
424 let mut chain = Vec::new();
425 let mut current_id = rule_id.to_string();
426
427 let rules = self.rules.read().unwrap();
428
429 while let Some(rule) = rules.get(¤t_id) {
430 chain.push(current_id.clone());
431 if let Some(parent_id) = &rule.parent_id {
432 current_id = parent_id.clone();
433 } else {
434 break;
435 }
436 }
437
438 chain.reverse();
439 chain
440 }
441
442 pub fn clear_cache(&self) {
444 let mut cache = self.regex_cache.write().unwrap();
445 cache.clear();
446 }
447}
448
449impl Default for RuleEngine {
450 fn default() -> Self {
451 Self::new(Arc::new(StringPool::default()))
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled rule without a description; keeps the individual
    /// tests focused on the fields they actually care about.
    fn make_rule(
        id: &str,
        name: &str,
        rule_type: RuleType,
        priority: u32,
        tags: &[&str],
        parent_id: Option<&str>,
    ) -> Rule {
        Rule {
            id: id.to_string(),
            name: name.to_string(),
            rule_type,
            priority,
            enabled: true,
            description: None,
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            parent_id: parent_id.map(str::to_string),
        }
    }

    /// Context with no current value, parameters or history.
    fn empty_context(field_name: &str) -> RuleContext {
        RuleContext {
            field_name: field_name.to_string(),
            current_value: None,
            params: HashMap::new(),
            generation_history: Vec::new(),
        }
    }

    #[test]
    fn test_rule_engine_creation() {
        let string_pool = Arc::new(StringPool::default());
        let engine = RuleEngine::new(string_pool);

        assert_eq!(engine.get_all_rules().len(), 0);
    }

    #[test]
    fn test_add_rule() {
        let engine = RuleEngine::default();

        let rule = Rule {
            description: Some("Test phone number rule".to_string()),
            ..make_rule(
                "test_rule",
                "Test Rule",
                RuleType::Regex {
                    pattern: r"\d{3}-\d{3}-\d{4}".to_string(),
                    flags: None,
                },
                100,
                &["phone"],
                None,
            )
        };

        assert!(engine.add_rule(rule).is_ok());
        assert_eq!(engine.get_all_rules().len(), 1);
    }

    #[test]
    fn test_execute_enum_rule() {
        let engine = RuleEngine::default();

        let values = ["A", "B", "C"]
            .iter()
            .map(|v| Value::String(v.to_string()))
            .collect();
        let rule = make_rule(
            "enum_rule",
            "Enum Rule",
            RuleType::Enum { values },
            100,
            &[],
            None,
        );
        engine.add_rule(rule).unwrap();

        let context = empty_context("test_field");
        let result = engine.execute_rule("enum_rule", &context).unwrap();
        assert!(result.matched);

        // Fail loudly on a missing or non-string value instead of silently
        // skipping the membership check.
        match result.value {
            Some(Value::String(s)) => assert!(["A", "B", "C"].contains(&s.as_str())),
            other => panic!("expected a string value, got {:?}", other),
        }
    }

    #[test]
    fn test_rule_inheritance() {
        let engine = RuleEngine::default();

        let parent_rule = make_rule(
            "parent_rule",
            "Parent Rule",
            RuleType::Length { min: Some(5), max: Some(10) },
            50,
            &[],
            None,
        );
        engine.add_rule(parent_rule).unwrap();

        let child_rule = make_rule(
            "child_rule",
            "Child Rule",
            RuleType::Length { min: Some(3), max: Some(8) },
            100,
            &[],
            Some("parent_rule"),
        );
        engine.add_rule(child_rule).unwrap();

        let chain = engine.get_inheritance_chain("child_rule");
        assert_eq!(chain, vec!["parent_rule", "child_rule"]);
    }

    #[test]
    fn test_find_rules_by_tag() {
        let engine = RuleEngine::default();

        let rule1 = make_rule(
            "rule1",
            "Rule 1",
            RuleType::Length { min: Some(1), max: Some(10) },
            100,
            &["test", "demo"],
            None,
        );
        let rule2 = make_rule(
            "rule2",
            "Rule 2",
            RuleType::Length { min: Some(1), max: Some(5) },
            50,
            &["test"],
            None,
        );

        engine.add_rule(rule1).unwrap();
        engine.add_rule(rule2).unwrap();

        assert_eq!(engine.find_rules_by_tag("test").len(), 2);
        assert_eq!(engine.find_rules_by_tag("demo").len(), 1);
    }
}