1use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::{Path, PathBuf};
7
/// File name of the user rules file; [`rules_path`] joins it onto the spool root.
const RULES_FILE_NAME: &str = "rules.toml";
9
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct UserRules {
12 #[serde(default)]
13 pub extraction: Vec<ExtractionRule>,
14 #[serde(default)]
15 pub context: Vec<ContextRule>,
16 #[serde(default)]
17 pub suppress: Vec<SuppressRule>,
18}
19
/// One entry of the `[[extraction]]` array in `rules.toml`.
///
/// NOTE(review): how `trigger` is matched against input is decided by the consumer of these
/// rules, which is not part of this module — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtractionRule {
    /// Text that activates the rule (the tests use `"技术选型"`, "technology selection").
    /// Required: a TOML entry without it fails to deserialize.
    pub trigger: String,
    /// Presumably the memory type assigned to what this rule extracts (tests use `"decision"`)
    /// — confirm with the consumer. Defaults to `"preference"` when omitted.
    #[serde(default = "default_memory_type")]
    pub memory_type: String,
    /// Free-form, human-readable explanation of the rule; empty when omitted.
    #[serde(default)]
    pub description: String,
}
28
/// One entry of the `[[context]]` array in `rules.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextRule {
    /// Scope the rule applies to; defaults to `"project"` when omitted.
    /// NOTE(review): the set of valid scope values is not defined in this module — confirm
    /// with the consumer.
    #[serde(default = "default_scope")]
    pub scope: String,
    /// Items the rule always includes (tests use `"架构约束"`, "architecture constraints");
    /// presumably injected into the assembled context — confirm with the consumer.
    /// Required: there is no serde default, so an entry without it fails to deserialize.
    pub always_include: Vec<String>,
    /// Free-form, human-readable explanation of the rule; empty when omitted.
    #[serde(default)]
    pub description: String,
}
37
/// One entry of the `[[suppress]]` array in `rules.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SuppressRule {
    /// Pattern selecting what to suppress. Required.
    /// NOTE(review): the test data (`"临时.*测试"`) looks like a regular expression, but this
    /// module never compiles or validates it — confirm the expected syntax with the consumer.
    pub pattern: String,
    /// Action to take on a match; defaults to `"skip"` when omitted.
    #[serde(default = "default_action")]
    pub action: String,
    /// Free-form, human-readable explanation of the rule; empty when omitted.
    #[serde(default)]
    pub description: String,
}
46
/// Serde default for `ExtractionRule::memory_type` when the key is absent.
fn default_memory_type() -> String {
    String::from("preference")
}
/// Serde default for `ContextRule::scope` when the key is absent.
fn default_scope() -> String {
    String::from("project")
}
/// Serde default for `SuppressRule::action` when the key is absent.
fn default_action() -> String {
    String::from("skip")
}
56
57pub fn rules_path(spool_root: &Path) -> PathBuf {
58 spool_root.join(RULES_FILE_NAME)
59}
60
61pub fn load(spool_root: &Path) -> UserRules {
62 let path = rules_path(spool_root);
63 match fs::read_to_string(&path) {
64 Ok(content) => toml::from_str(&content).unwrap_or_default(),
65 Err(_) => UserRules::default(),
66 }
67}
68
69pub fn save(spool_root: &Path, rules: &UserRules) -> anyhow::Result<()> {
70 let path = rules_path(spool_root);
71 if let Some(parent) = path.parent() {
72 fs::create_dir_all(parent)?;
73 }
74 let content = toml::to_string_pretty(rules)?;
75 fs::write(&path, content)?;
76 Ok(())
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn load_returns_default_when_file_missing() {
        let temp = tempdir().unwrap();
        let rules = load(temp.path());
        assert!(rules.extraction.is_empty());
        assert!(rules.context.is_empty());
        assert!(rules.suppress.is_empty());
    }

    #[test]
    fn save_and_load_roundtrip() {
        let temp = tempdir().unwrap();
        let rules = UserRules {
            extraction: vec![ExtractionRule {
                trigger: "技术选型".to_string(),
                memory_type: "decision".to_string(),
                description: "技术选型相关决策".to_string(),
            }],
            context: vec![ContextRule {
                scope: "project".to_string(),
                always_include: vec!["架构约束".to_string()],
                description: "".to_string(),
            }],
            suppress: vec![SuppressRule {
                pattern: "临时.*测试".to_string(),
                action: "skip".to_string(),
                description: "跳过临时测试内容".to_string(),
            }],
        };
        save(temp.path(), &rules).unwrap();
        let loaded = load(temp.path());

        // Check every field, not just one per section, so a lost or renamed key is caught.
        assert_eq!(loaded.extraction.len(), 1);
        assert_eq!(loaded.extraction[0].trigger, "技术选型");
        assert_eq!(loaded.extraction[0].memory_type, "decision");
        assert_eq!(loaded.extraction[0].description, "技术选型相关决策");

        assert_eq!(loaded.context.len(), 1);
        assert_eq!(loaded.context[0].scope, "project");
        assert_eq!(loaded.context[0].always_include, vec!["架构约束".to_string()]);
        assert_eq!(loaded.context[0].description, "");

        assert_eq!(loaded.suppress.len(), 1);
        assert_eq!(loaded.suppress[0].pattern, "临时.*测试");
        assert_eq!(loaded.suppress[0].action, "skip");
        assert_eq!(loaded.suppress[0].description, "跳过临时测试内容");
    }

    #[test]
    fn load_handles_partial_toml() {
        let temp = tempdir().unwrap();
        fs::write(
            temp.path().join(RULES_FILE_NAME),
            "[[extraction]]\ntrigger = \"test\"\n",
        )
        .unwrap();
        let rules = load(temp.path());
        assert_eq!(rules.extraction.len(), 1);
        // Omitted keys must be filled by the serde defaults.
        assert_eq!(rules.extraction[0].memory_type, "preference");
        assert_eq!(rules.extraction[0].description, "");
        assert!(rules.context.is_empty());
        assert!(rules.suppress.is_empty());
    }

    #[test]
    fn load_applies_context_and_suppress_defaults() {
        let temp = tempdir().unwrap();
        fs::write(
            temp.path().join(RULES_FILE_NAME),
            "[[context]]\nalways_include = [\"a\"]\n\n[[suppress]]\npattern = \"x\"\n",
        )
        .unwrap();
        let rules = load(temp.path());
        assert_eq!(rules.context.len(), 1);
        assert_eq!(rules.context[0].scope, "project");
        assert_eq!(rules.suppress.len(), 1);
        assert_eq!(rules.suppress[0].action, "skip");
    }

    #[test]
    fn load_falls_back_to_default_on_malformed_toml() {
        let temp = tempdir().unwrap();
        fs::write(
            temp.path().join(RULES_FILE_NAME),
            "[[extraction]\ntrigger = ",
        )
        .unwrap();
        let rules = load(temp.path());
        assert!(rules.extraction.is_empty());
        assert!(rules.context.is_empty());
        assert!(rules.suppress.is_empty());
    }

    #[test]
    fn save_creates_missing_spool_root() {
        let temp = tempdir().unwrap();
        let root = temp.path().join("nested").join("spool");
        save(&root, &UserRules::default()).unwrap();
        assert!(rules_path(&root).is_file());
        let loaded = load(&root);
        assert!(loaded.extraction.is_empty());
    }
}