1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::path::Path;
4
5#[derive(Debug, Serialize, Deserialize, Default, Clone)]
9pub struct FeedbackTracker {
10 rules: HashMap<String, RuleStats>,
12}
13
14#[derive(Debug, Serialize, Deserialize, Default, Clone)]
16pub struct RuleStats {
17 pub shown: u64,
19 pub dismissed: u64,
21 pub fixed: u64,
23}
24
25impl RuleStats {
26 #[must_use]
28 #[allow(clippy::cast_precision_loss)]
29 pub fn dismiss_rate(&self) -> f64 {
30 if self.shown == 0 {
31 0.0
32 } else {
33 self.dismissed as f64 / self.shown as f64
34 }
35 }
36}
37
/// One entry of the list returned by [`FeedbackTracker::suggest_disable`]:
/// a rule whose dismiss rate exceeded the requested threshold.
#[derive(Clone, Debug, PartialEq)]
pub struct DisableSuggestion {
    /// Identifier of the rule being suggested for disabling.
    pub rule_id: String,
    /// The rule's dismiss rate at the time the suggestion was built.
    pub dismiss_rate: f64,
    /// The rule's raw `dismissed` counter at the time the suggestion was built.
    pub dismissed_count: u64,
}
45
46impl FeedbackTracker {
47 #[must_use]
49 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn record_shown(&mut self, rule_id: &str) {
55 self.rules.entry(rule_id.to_string()).or_default().shown += 1;
56 }
57
58 pub fn record_dismissed(&mut self, rule_id: &str) {
60 let stats = self.rules.entry(rule_id.to_string()).or_default();
61 stats.dismissed += 1;
62 }
63
64 pub fn record_fixed(&mut self, rule_id: &str) {
66 let stats = self.rules.entry(rule_id.to_string()).or_default();
67 stats.fixed += 1;
68 }
69
70 #[must_use]
72 pub fn get_stats(&self, rule_id: &str) -> Option<&RuleStats> {
73 self.rules.get(rule_id)
74 }
75
76 #[must_use]
80 pub fn suggest_disable(&self, threshold: f64, min_shown: u64) -> Vec<DisableSuggestion> {
81 let mut suggestions: Vec<DisableSuggestion> = self
82 .rules
83 .iter()
84 .filter(|(_, stats)| stats.shown >= min_shown && stats.dismiss_rate() > threshold)
85 .map(|(rule_id, stats)| DisableSuggestion {
86 rule_id: rule_id.clone(),
87 dismiss_rate: stats.dismiss_rate(),
88 dismissed_count: stats.dismissed,
89 })
90 .collect();
91
92 suggestions.sort_by(|a, b| {
94 b.dismiss_rate
95 .partial_cmp(&a.dismiss_rate)
96 .unwrap_or(std::cmp::Ordering::Equal)
97 });
98 suggestions
99 }
100
101 #[must_use]
103 pub fn rule_count(&self) -> usize {
104 self.rules.len()
105 }
106
107 pub fn load(path: &Path) -> anyhow::Result<Self> {
109 if path.exists() {
110 let content = std::fs::read_to_string(path)?;
111 let tracker: Self = serde_json::from_str(&content)?;
112 Ok(tracker)
113 } else {
114 Ok(Self::new())
115 }
116 }
117
118 pub fn save(&self, path: &Path) -> anyhow::Result<()> {
120 let content = serde_json::to_string_pretty(self)?;
121 std::fs::write(path, content)?;
122 Ok(())
123 }
124
125 #[must_use]
127 pub fn create_false_positive_report(
128 rule_id: &str,
129 text_snippet: &str,
130 max_snippet_len: usize,
131 ) -> FalsePositiveReport {
132 let snippet = if text_snippet.len() > max_snippet_len {
134 &text_snippet[..max_snippet_len]
135 } else {
136 text_snippet
137 };
138
139 FalsePositiveReport {
140 rule_id: rule_id.to_string(),
141 snippet: snippet.to_string(),
142 }
143 }
144}
145
146#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
148pub struct FalsePositiveReport {
149 pub rule_id: String,
150 pub snippet: String,
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed tracker knows no rules.
    #[test]
    fn new_tracker_is_empty() {
        let tracker = FeedbackTracker::new();
        assert_eq!(tracker.rule_count(), 0);
    }

    /// Each `record_*` call bumps exactly its own counter.
    #[test]
    fn record_and_retrieve_stats() {
        let mut tracker = FeedbackTracker::new();
        tracker.record_shown("spelling.typo");
        tracker.record_shown("spelling.typo");
        tracker.record_dismissed("spelling.typo");
        tracker.record_fixed("spelling.typo");

        let stats = tracker.get_stats("spelling.typo").unwrap();
        assert_eq!(stats.shown, 2);
        assert_eq!(stats.dismissed, 1);
        assert_eq!(stats.fixed, 1);
    }

    /// Looking up a rule that never had an event recorded yields `None`.
    #[test]
    fn get_stats_unknown_rule_is_none() {
        let mut tracker = FeedbackTracker::new();
        tracker.record_shown("known.rule");
        assert!(tracker.get_stats("unknown.rule").is_none());
        assert_eq!(tracker.rule_count(), 1);
    }

    #[test]
    fn dismiss_rate_calculation() {
        let stats = RuleStats {
            shown: 10,
            dismissed: 8,
            fixed: 2,
        };
        let rate = stats.dismiss_rate();
        assert!((rate - 0.8).abs() < f64::EPSILON);
    }

    /// With nothing shown the rate is defined as zero rather than NaN.
    #[test]
    fn dismiss_rate_zero_shown() {
        let stats = RuleStats::default();
        assert!((stats.dismiss_rate()).abs() < f64::EPSILON);
    }

    #[test]
    fn suggest_disable_above_threshold() {
        let mut tracker = FeedbackTracker::new();

        // Always dismissed: must be suggested.
        for _ in 0..10 {
            tracker.record_shown("noisy.rule");
            tracker.record_dismissed("noisy.rule");
        }

        // Rarely dismissed: below the threshold.
        for _ in 0..10 {
            tracker.record_shown("useful.rule");
        }
        tracker.record_dismissed("useful.rule");

        // High rate but shown too few times to qualify.
        tracker.record_shown("rare.rule");
        tracker.record_dismissed("rare.rule");

        let suggestions = tracker.suggest_disable(0.5, 5);
        assert_eq!(suggestions.len(), 1);
        assert_eq!(suggestions[0].rule_id, "noisy.rule");
        assert!((suggestions[0].dismiss_rate - 1.0).abs() < f64::EPSILON);
    }

    /// Suggestions come back ordered by dismiss rate, highest first.
    #[test]
    fn suggest_disable_orders_by_rate_descending() {
        let mut tracker = FeedbackTracker::new();
        for (rule_id, dismissed) in [("sometimes.rule", 6), ("always.rule", 10), ("often.rule", 8)] {
            for _ in 0..10 {
                tracker.record_shown(rule_id);
            }
            for _ in 0..dismissed {
                tracker.record_dismissed(rule_id);
            }
        }

        let ids: Vec<String> = tracker
            .suggest_disable(0.5, 5)
            .into_iter()
            .map(|suggestion| suggestion.rule_id)
            .collect();
        assert_eq!(ids, vec!["always.rule", "often.rule", "sometimes.rule"]);
    }

    #[test]
    fn save_and_load_roundtrip() {
        // Suffix the directory with the process id so concurrently running
        // test processes sharing the temp dir cannot delete each other's files.
        let dir = std::env::temp_dir()
            .join(format!("lang_check_feedback_test_{}", std::process::id()));
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join("feedback.json");

        let mut tracker = FeedbackTracker::new();
        tracker.record_shown("test.rule");
        tracker.record_dismissed("test.rule");
        tracker.save(&path).unwrap();

        let loaded = FeedbackTracker::load(&path).unwrap();
        let stats = loaded.get_stats("test.rule").unwrap();
        assert_eq!(stats.shown, 1);
        assert_eq!(stats.dismissed, 1);

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn load_missing_file_returns_empty() {
        let path = std::env::temp_dir().join("lang_check_feedback_nonexistent.json");
        let tracker = FeedbackTracker::load(&path).unwrap();
        assert_eq!(tracker.rule_count(), 0);
    }

    /// Text shorter than the limit is carried over unchanged.
    #[test]
    fn false_positive_report() {
        let report = FeedbackTracker::create_false_positive_report(
            "spelling.typo",
            "This is a perfectly valid sentence.",
            50,
        );
        assert_eq!(report.rule_id, "spelling.typo");
        assert_eq!(report.snippet, "This is a perfectly valid sentence.");
    }

    /// ASCII text longer than the limit is cut to exactly the limit.
    #[test]
    fn false_positive_report_truncation() {
        let long_text = "a".repeat(200);
        let report = FeedbackTracker::create_false_positive_report("test.rule", &long_text, 50);
        assert_eq!(report.snippet.len(), 50);
    }
}