1use crate::error::{LearningError, Result};
3use crate::models::{Rule, RuleScope};
4use ricecoder_storage::manager::PathResolver;
5use std::collections::HashMap;
6use std::path::PathBuf;
7use std::sync::Arc;
8use tokio::fs;
9use tokio::sync::RwLock;
10
11pub struct RuleStorage {
13 cache: Arc<RwLock<HashMap<String, Rule>>>,
15 scope: RuleScope,
17}
18
19impl RuleStorage {
20 pub fn new(scope: RuleScope) -> Self {
22 Self {
23 cache: Arc::new(RwLock::new(HashMap::new())),
24 scope,
25 }
26 }
27
28 fn get_storage_path(&self) -> Result<PathBuf> {
30 match self.scope {
31 RuleScope::Global => {
32 let global_path = PathResolver::resolve_global_path()?;
33 Ok(global_path.join("rules"))
34 }
35 RuleScope::Project => Ok(PathBuf::from(".ricecoder/rules")),
36 RuleScope::Session => Err(LearningError::PathResolutionFailed(
37 "Session scope has no persistent path".to_string(),
38 )),
39 }
40 }
41
42 async fn ensure_storage_dir(&self) -> Result<()> {
44 if self.scope == RuleScope::Session {
45 return Ok(());
46 }
47
48 let path = self.get_storage_path()?;
49 fs::create_dir_all(&path)
50 .await
51 .map_err(|e| LearningError::RuleStorageFailed(format!("Failed to create storage directory: {}", e)))?;
52
53 Ok(())
54 }
55
56 fn get_rule_file_path(&self, rule_id: &str) -> Result<PathBuf> {
58 let storage_path = self.get_storage_path()?;
59 Ok(storage_path.join(format!("{}.json", rule_id)))
60 }
61
62 pub async fn store_rule(&self, rule: Rule) -> Result<String> {
64 if rule.scope != self.scope {
65 return Err(LearningError::RuleStorageFailed(
66 format!("Rule scope {:?} does not match storage scope {:?}", rule.scope, self.scope),
67 ));
68 }
69
70 if self.scope == RuleScope::Session {
72 let mut cache = self.cache.write().await;
73 let rule_id = rule.id.clone();
74 cache.insert(rule_id.clone(), rule);
75 return Ok(rule_id);
76 }
77
78 self.ensure_storage_dir().await?;
80
81 let rule_id = rule.id.clone();
82 let file_path = self.get_rule_file_path(&rule_id)?;
83
84 let json = serde_json::to_string_pretty(&rule)
85 .map_err(|e| LearningError::SerializationError(e))?;
86
87 fs::write(&file_path, json)
88 .await
89 .map_err(|e| LearningError::RuleStorageFailed(format!("Failed to write rule file: {}", e)))?;
90
91 let mut cache = self.cache.write().await;
93 cache.insert(rule_id.clone(), rule);
94
95 Ok(rule_id)
96 }
97
98 pub async fn get_rule(&self, rule_id: &str) -> Result<Rule> {
100 {
102 let cache = self.cache.read().await;
103 if let Some(rule) = cache.get(rule_id) {
104 return Ok(rule.clone());
105 }
106 }
107
108 if self.scope == RuleScope::Session {
110 return Err(LearningError::RuleNotFound(rule_id.to_string()));
111 }
112
113 let file_path = self.get_rule_file_path(rule_id)?;
115
116 if !file_path.exists() {
117 return Err(LearningError::RuleNotFound(rule_id.to_string()));
118 }
119
120 let json = fs::read_to_string(&file_path)
121 .await
122 .map_err(|e| LearningError::RuleStorageFailed(format!("Failed to read rule file: {}", e)))?;
123
124 let rule: Rule = serde_json::from_str(&json)
125 .map_err(|e| LearningError::SerializationError(e))?;
126
127 {
129 let mut cache = self.cache.write().await;
130 cache.insert(rule_id.to_string(), rule.clone());
131 }
132
133 Ok(rule)
134 }
135
136 pub async fn list_rules(&self) -> Result<Vec<Rule>> {
138 if self.scope == RuleScope::Session {
140 let cache = self.cache.read().await;
141 return Ok(cache.values().cloned().collect());
142 }
143
144 let storage_path = self.get_storage_path()?;
146
147 if !storage_path.exists() {
148 return Ok(Vec::new());
149 }
150
151 let mut rules = Vec::new();
152 let mut dir_entries = fs::read_dir(&storage_path)
153 .await
154 .map_err(|e| LearningError::RuleStorageFailed(format!("Failed to read storage directory: {}", e)))?;
155
156 while let Some(entry) = dir_entries
157 .next_entry()
158 .await
159 .map_err(|e| LearningError::RuleStorageFailed(format!("Failed to read directory entry: {}", e)))?
160 {
161 let path = entry.path();
162
163 if path.extension().map_or(false, |ext| ext == "json") {
164 match fs::read_to_string(&path).await {
165 Ok(json) => match serde_json::from_str::<Rule>(&json) {
166 Ok(rule) => {
167 rules.push(rule);
168 }
169 Err(e) => {
170 eprintln!("Failed to deserialize rule from {:?}: {}", path, e);
171 }
172 },
173 Err(e) => {
174 eprintln!("Failed to read rule file {:?}: {}", path, e);
175 }
176 }
177 }
178 }
179
180 {
182 let mut cache = self.cache.write().await;
183 cache.clear();
184 for rule in &rules {
185 cache.insert(rule.id.clone(), rule.clone());
186 }
187 }
188
189 Ok(rules)
190 }
191
192 pub async fn filter_by_scope(&self, scope: RuleScope) -> Result<Vec<Rule>> {
194 let all_rules = self.list_rules().await?;
195 Ok(all_rules.into_iter().filter(|r| r.scope == scope).collect())
196 }
197
198 pub async fn delete_rule(&self, rule_id: &str) -> Result<()> {
200 {
202 let mut cache = self.cache.write().await;
203 cache.remove(rule_id);
204 }
205
206 if self.scope == RuleScope::Session {
208 return Ok(());
209 }
210
211 let file_path = self.get_rule_file_path(rule_id)?;
213
214 if file_path.exists() {
215 fs::remove_file(&file_path)
216 .await
217 .map_err(|e| LearningError::RuleStorageFailed(format!("Failed to delete rule file: {}", e)))?;
218 }
219
220 Ok(())
221 }
222
223 pub async fn update_rule(&self, rule: Rule) -> Result<String> {
225 if rule.scope != self.scope {
226 return Err(LearningError::RuleStorageFailed(
227 format!("Rule scope {:?} does not match storage scope {:?}", rule.scope, self.scope),
228 ));
229 }
230
231 self.delete_rule(&rule.id).await?;
233 self.store_rule(rule).await
234 }
235
236 pub async fn rule_count(&self) -> Result<usize> {
238 let rules = self.list_rules().await?;
239 Ok(rules.len())
240 }
241
242 pub async fn clear_all(&self) -> Result<()> {
244 {
246 let mut cache = self.cache.write().await;
247 cache.clear();
248 }
249
250 if self.scope == RuleScope::Session {
252 return Ok(());
253 }
254
255 let storage_path = self.get_storage_path()?;
257
258 if storage_path.exists() {
259 fs::remove_dir_all(&storage_path)
260 .await
261 .map_err(|e| LearningError::RuleStorageFailed(format!("Failed to clear storage: {}", e)))?;
262 }
263
264 Ok(())
265 }
266
267 pub async fn load_all(&self) -> Result<()> {
269 let rules = self.list_rules().await?;
270 let mut cache = self.cache.write().await;
271 cache.clear();
272 for rule in rules {
273 cache.insert(rule.id.clone(), rule);
274 }
275 Ok(())
276 }
277
278 pub fn get_scope(&self) -> RuleScope {
280 self.scope
281 }
282
283 pub async fn get_rules_by_pattern(&self, pattern: &str) -> Result<Vec<Rule>> {
285 let rules = self.list_rules().await?;
286 Ok(rules
287 .into_iter()
288 .filter(|r| r.pattern.contains(pattern))
289 .collect())
290 }
291
292 pub async fn get_rules_by_source(&self, source: crate::models::RuleSource) -> Result<Vec<Rule>> {
294 let rules = self.list_rules().await?;
295 Ok(rules.into_iter().filter(|r| r.source == source).collect())
296 }
297
298 pub async fn get_rules_by_confidence(&self, min_confidence: f32) -> Result<Vec<Rule>> {
300 if !(0.0..=1.0).contains(&min_confidence) {
301 return Err(LearningError::RuleStorageFailed(
302 "Confidence must be between 0.0 and 1.0".to_string(),
303 ));
304 }
305
306 let rules = self.list_rules().await?;
307 Ok(rules
308 .into_iter()
309 .filter(|r| r.confidence >= min_confidence)
310 .collect())
311 }
312
313 pub async fn get_rules_by_usage(&self) -> Result<Vec<Rule>> {
315 let mut rules = self.list_rules().await?;
316 rules.sort_by(|a, b| b.usage_count.cmp(&a.usage_count));
317 Ok(rules)
318 }
319
320 pub async fn get_rules_by_confidence_sorted(&self) -> Result<Vec<Rule>> {
322 let mut rules = self.list_rules().await?;
323 rules.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap_or(std::cmp::Ordering::Equal));
324 Ok(rules)
325 }
326
327 pub async fn get_rules_by_metadata(&self, key: &str, value: &serde_json::Value) -> Result<Vec<Rule>> {
329 let rules = self.list_rules().await?;
330 Ok(rules
331 .into_iter()
332 .filter(|r| r.metadata.get(key).map_or(false, |v| v == value))
333 .collect())
334 }
335
336 pub async fn get_rules_after(&self, timestamp: chrono::DateTime<chrono::Utc>) -> Result<Vec<Rule>> {
338 let rules = self.list_rules().await?;
339 Ok(rules
340 .into_iter()
341 .filter(|r| r.created_at > timestamp)
342 .collect())
343 }
344
345 pub async fn get_rules_updated_after(&self, timestamp: chrono::DateTime<chrono::Utc>) -> Result<Vec<Rule>> {
347 let rules = self.list_rules().await?;
348 Ok(rules
349 .into_iter()
350 .filter(|r| r.updated_at > timestamp)
351 .collect())
352 }
353
354 pub async fn get_rules_by_usage_count(&self, min_usage: u64) -> Result<Vec<Rule>> {
356 let rules = self.list_rules().await?;
357 Ok(rules
358 .into_iter()
359 .filter(|r| r.usage_count >= min_usage)
360 .collect())
361 }
362
363 pub async fn get_rules_by_success_rate(&self, min_success_rate: f32) -> Result<Vec<Rule>> {
365 if !(0.0..=1.0).contains(&min_success_rate) {
366 return Err(LearningError::RuleStorageFailed(
367 "Success rate must be between 0.0 and 1.0".to_string(),
368 ));
369 }
370
371 let rules = self.list_rules().await?;
372 Ok(rules
373 .into_iter()
374 .filter(|r| r.success_rate >= min_success_rate)
375 .collect())
376 }
377
378 pub async fn get_rules_by_version(&self, version: u32) -> Result<Vec<Rule>> {
380 let rules = self.list_rules().await?;
381 Ok(rules.into_iter().filter(|r| r.version == version).collect())
382 }
383
384 pub async fn get_rules_metadata(&self) -> Result<Vec<(String, String, f32, u64)>> {
386 let rules = self.list_rules().await?;
387 Ok(rules
388 .into_iter()
389 .map(|r| (r.id, r.pattern, r.confidence, r.usage_count))
390 .collect())
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::RuleSource;

    /// Builds a session-scoped rule with the given pattern, action and source.
    fn session_rule(pattern: &str, action: &str, source: RuleSource) -> Rule {
        Rule::new(
            RuleScope::Session,
            pattern.to_string(),
            action.to_string(),
            source,
        )
    }

    /// Stores each rule in order, failing the test on the first error.
    async fn store_all(storage: &RuleStorage, rules: Vec<Rule>) {
        for rule in rules {
            storage.store_rule(rule).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_session_rule_storage() {
        let storage = RuleStorage::new(RuleScope::Session);
        let rule = session_rule("pattern", "action", RuleSource::Learned);
        let expected_id = rule.id.clone();

        let stored = storage.store_rule(rule).await;
        assert!(stored.is_ok());
        assert_eq!(stored.unwrap(), expected_id);

        let fetched = storage.get_rule(&expected_id).await;
        assert!(fetched.is_ok());
        assert_eq!(fetched.unwrap().id, expected_id);
    }

    #[tokio::test]
    async fn test_session_list_rules() {
        let storage = RuleStorage::new(RuleScope::Session);
        store_all(
            &storage,
            vec![
                session_rule("pattern1", "action1", RuleSource::Learned),
                session_rule("pattern2", "action2", RuleSource::Manual),
            ],
        )
        .await;

        assert_eq!(storage.list_rules().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_session_delete_rule() {
        let storage = RuleStorage::new(RuleScope::Session);
        let rule = session_rule("pattern", "action", RuleSource::Learned);
        let id = rule.id.clone();
        storage.store_rule(rule).await.unwrap();

        assert!(storage.get_rule(&id).await.is_ok());
        storage.delete_rule(&id).await.unwrap();
        assert!(storage.get_rule(&id).await.is_err());
    }

    #[tokio::test]
    async fn test_session_update_rule() {
        let storage = RuleStorage::new(RuleScope::Session);
        let mut rule = session_rule("pattern", "action", RuleSource::Learned);
        let id = rule.id.clone();
        storage.store_rule(rule.clone()).await.unwrap();

        rule.pattern = "new_pattern".to_string();
        storage.update_rule(rule).await.unwrap();

        let fetched = storage.get_rule(&id).await.unwrap();
        assert_eq!(fetched.pattern, "new_pattern");
    }

    #[tokio::test]
    async fn test_session_rule_count() {
        let storage = RuleStorage::new(RuleScope::Session);
        assert_eq!(storage.rule_count().await.unwrap(), 0);

        store_all(
            &storage,
            vec![
                session_rule("pattern1", "action1", RuleSource::Learned),
                session_rule("pattern2", "action2", RuleSource::Manual),
            ],
        )
        .await;

        assert_eq!(storage.rule_count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_session_clear_all() {
        let storage = RuleStorage::new(RuleScope::Session);
        storage
            .store_rule(session_rule("pattern", "action", RuleSource::Learned))
            .await
            .unwrap();
        assert_eq!(storage.rule_count().await.unwrap(), 1);

        storage.clear_all().await.unwrap();
        assert_eq!(storage.rule_count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_filter_by_scope() {
        let storage = RuleStorage::new(RuleScope::Session);
        store_all(
            &storage,
            vec![
                session_rule("pattern1", "action1", RuleSource::Learned),
                session_rule("pattern2", "action2", RuleSource::Manual),
            ],
        )
        .await;

        let in_scope = storage.filter_by_scope(RuleScope::Session).await.unwrap();
        assert_eq!(in_scope.len(), 2);
    }

    #[tokio::test]
    async fn test_wrong_scope_error() {
        let storage = RuleStorage::new(RuleScope::Session);
        let global_rule = Rule::new(
            RuleScope::Global,
            "pattern".to_string(),
            "action".to_string(),
            RuleSource::Learned,
        );

        assert!(storage.store_rule(global_rule).await.is_err());
    }

    #[tokio::test]
    async fn test_get_nonexistent_rule() {
        let storage = RuleStorage::new(RuleScope::Session);
        assert!(storage.get_rule("nonexistent").await.is_err());
    }

    #[tokio::test]
    async fn test_get_rules_by_pattern() {
        let storage = RuleStorage::new(RuleScope::Session);
        store_all(
            &storage,
            vec![
                session_rule("pattern_a", "action1", RuleSource::Learned),
                session_rule("pattern_b", "action2", RuleSource::Manual),
                session_rule("pattern_a_extended", "action3", RuleSource::Learned),
            ],
        )
        .await;

        let matches = storage.get_rules_by_pattern("pattern_a").await.unwrap();
        assert_eq!(matches.len(), 2);
    }

    #[tokio::test]
    async fn test_get_rules_by_source() {
        let storage = RuleStorage::new(RuleScope::Session);
        store_all(
            &storage,
            vec![
                session_rule("pattern1", "action1", RuleSource::Learned),
                session_rule("pattern2", "action2", RuleSource::Manual),
                session_rule("pattern3", "action3", RuleSource::Learned),
            ],
        )
        .await;

        let learned = storage.get_rules_by_source(RuleSource::Learned).await.unwrap();
        assert_eq!(learned.len(), 2);

        let manual = storage.get_rules_by_source(RuleSource::Manual).await.unwrap();
        assert_eq!(manual.len(), 1);
    }

    #[tokio::test]
    async fn test_get_rules_by_confidence() {
        let storage = RuleStorage::new(RuleScope::Session);

        let mut high = session_rule("pattern1", "action1", RuleSource::Learned);
        high.confidence = 0.9;
        let mut low = session_rule("pattern2", "action2", RuleSource::Manual);
        low.confidence = 0.5;
        let mut medium = session_rule("pattern3", "action3", RuleSource::Learned);
        medium.confidence = 0.7;

        store_all(&storage, vec![high, low, medium]).await;

        assert_eq!(storage.get_rules_by_confidence(0.7).await.unwrap().len(), 2);
        assert_eq!(storage.get_rules_by_confidence(0.8).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_get_rules_by_usage() {
        let storage = RuleStorage::new(RuleScope::Session);

        let mut ten = session_rule("pattern1", "action1", RuleSource::Learned);
        ten.usage_count = 10;
        let mut five = session_rule("pattern2", "action2", RuleSource::Manual);
        five.usage_count = 5;
        let mut twenty = session_rule("pattern3", "action3", RuleSource::Learned);
        twenty.usage_count = 20;

        store_all(&storage, vec![ten, five, twenty]).await;

        let ordered = storage.get_rules_by_usage().await.unwrap();
        assert_eq!(ordered.len(), 3);
        assert_eq!(ordered[0].usage_count, 20);
        assert_eq!(ordered[1].usage_count, 10);
        assert_eq!(ordered[2].usage_count, 5);
    }

    #[tokio::test]
    async fn test_get_rules_by_usage_count() {
        let storage = RuleStorage::new(RuleScope::Session);

        let mut busy = session_rule("pattern1", "action1", RuleSource::Learned);
        busy.usage_count = 10;
        let mut quiet = session_rule("pattern2", "action2", RuleSource::Manual);
        quiet.usage_count = 5;

        store_all(&storage, vec![busy, quiet]).await;

        assert_eq!(storage.get_rules_by_usage_count(8).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_get_rules_by_success_rate() {
        let storage = RuleStorage::new(RuleScope::Session);

        let mut reliable = session_rule("pattern1", "action1", RuleSource::Learned);
        reliable.success_rate = 0.95;
        let mut flaky = session_rule("pattern2", "action2", RuleSource::Manual);
        flaky.success_rate = 0.5;

        store_all(&storage, vec![reliable, flaky]).await;

        assert_eq!(storage.get_rules_by_success_rate(0.8).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_get_rules_metadata() {
        let storage = RuleStorage::new(RuleScope::Session);

        let mut first = session_rule("pattern1", "action1", RuleSource::Learned);
        first.confidence = 0.9;
        first.usage_count = 10;
        let mut second = session_rule("pattern2", "action2", RuleSource::Manual);
        second.confidence = 0.5;
        second.usage_count = 5;

        store_all(&storage, vec![first, second]).await;

        assert_eq!(storage.get_rules_metadata().await.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_invalid_confidence_threshold() {
        let storage = RuleStorage::new(RuleScope::Session);
        assert!(storage.get_rules_by_confidence(1.5).await.is_err());
    }

    #[tokio::test]
    async fn test_invalid_success_rate_threshold() {
        let storage = RuleStorage::new(RuleScope::Session);
        assert!(storage.get_rules_by_success_rate(-0.1).await.is_err());
    }
}
822}