1use crate::error::{LearningError, Result};
3use crate::models::{Decision, DecisionContext};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8pub struct DecisionLogger {
10 decisions: Arc<RwLock<Vec<Decision>>>,
12 decision_index: Arc<RwLock<HashMap<String, usize>>>,
14 context_index: Arc<RwLock<HashMap<String, Vec<usize>>>>,
16}
17
18impl DecisionLogger {
19 pub fn new() -> Self {
21 Self {
22 decisions: Arc::new(RwLock::new(Vec::new())),
23 decision_index: Arc::new(RwLock::new(HashMap::new())),
24 context_index: Arc::new(RwLock::new(HashMap::new())),
25 }
26 }
27
28 pub async fn log_decision(&self, decision: Decision) -> Result<String> {
30 let decision_id = decision.id.clone();
31 let context_key = self.make_context_key(&decision.context);
32
33 let mut decisions = self.decisions.write().await;
34 let index = decisions.len();
35 decisions.push(decision);
36
37 let mut decision_index = self.decision_index.write().await;
39 decision_index.insert(decision_id.clone(), index);
40
41 let mut context_index = self.context_index.write().await;
42 context_index
43 .entry(context_key)
44 .or_insert_with(Vec::new)
45 .push(index);
46
47 Ok(decision_id)
48 }
49
50 pub async fn get_history(&self) -> Vec<Decision> {
52 self.decisions.read().await.clone()
53 }
54
55 pub async fn get_history_by_context(&self, context: &DecisionContext) -> Vec<Decision> {
57 let context_key = self.make_context_key(context);
58 let context_index = self.context_index.read().await;
59
60 if let Some(indices) = context_index.get(&context_key) {
61 let decisions = self.decisions.read().await;
62 indices
63 .iter()
64 .filter_map(|&idx| decisions.get(idx).cloned())
65 .collect()
66 } else {
67 Vec::new()
68 }
69 }
70
71 pub async fn get_history_by_type(&self, decision_type: &str) -> Vec<Decision> {
73 self.decisions
74 .read()
75 .await
76 .iter()
77 .filter(|d| d.decision_type == decision_type)
78 .cloned()
79 .collect()
80 }
81
82 pub async fn get_decision(&self, decision_id: &str) -> Result<Decision> {
84 let decision_index = self.decision_index.read().await;
85
86 if let Some(&idx) = decision_index.get(decision_id) {
87 let decisions = self.decisions.read().await;
88 decisions
89 .get(idx)
90 .cloned()
91 .ok_or_else(|| LearningError::DecisionNotFound(decision_id.to_string()))
92 } else {
93 Err(LearningError::DecisionNotFound(decision_id.to_string()))
94 }
95 }
96
97 pub async fn replay_decisions(&self) -> Vec<Decision> {
99 self.decisions.read().await.clone()
100 }
101
102 pub async fn replay_decisions_for_context(&self, context: &DecisionContext) -> Vec<Decision> {
104 self.get_history_by_context(context).await
105 }
106
107 pub async fn decision_count(&self) -> usize {
109 self.decisions.read().await.len()
110 }
111
112 pub async fn clear(&self) {
114 self.decisions.write().await.clear();
115 self.decision_index.write().await.clear();
116 self.context_index.write().await.clear();
117 }
118
119 pub async fn get_statistics(&self) -> DecisionStatistics {
121 let decisions = self.decisions.read().await;
122
123 let mut type_counts: HashMap<String, usize> = HashMap::new();
124 let mut agent_counts: HashMap<String, usize> = HashMap::new();
125
126 for decision in decisions.iter() {
127 *type_counts.entry(decision.decision_type.clone()).or_insert(0) += 1;
128 *agent_counts
129 .entry(decision.context.agent_type.clone())
130 .or_insert(0) += 1;
131 }
132
133 DecisionStatistics {
134 total_decisions: decisions.len(),
135 decision_types: type_counts,
136 agent_types: agent_counts,
137 }
138 }
139
140 fn make_context_key(&self, context: &DecisionContext) -> String {
142 format!(
143 "{}:{}",
144 context.project_path.display(),
145 context.file_path.display()
146 )
147 }
148}
149
150impl Default for DecisionLogger {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
/// Aggregate counts over the decisions held by a `DecisionLogger`, as
/// produced by `DecisionLogger::get_statistics`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DecisionStatistics {
    /// Total number of logged decisions.
    pub total_decisions: usize,
    /// Number of decisions per `decision_type`.
    pub decision_types: HashMap<String, usize>,
    /// Number of decisions per `context.agent_type`.
    pub agent_types: HashMap<String, usize>,
}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Builds a decision whose context has the given paths and agent type and
    /// a fixed line number of 10.
    fn create_test_decision(
        decision_type: &str,
        agent_type: &str,
        project_path: &str,
        file_path: &str,
    ) -> Decision {
        let context = DecisionContext {
            project_path: PathBuf::from(project_path),
            file_path: PathBuf::from(file_path),
            line_number: 10,
            agent_type: agent_type.to_string(),
        };

        Decision::new(
            context,
            decision_type.to_string(),
            serde_json::json!({"input": "test"}),
            serde_json::json!({"output": "result"}),
        )
    }

    #[tokio::test]
    async fn test_log_decision() {
        let logger = DecisionLogger::new();

        let decision = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        let decision_id = decision.id.clone();

        let result = logger.log_decision(decision).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), decision_id);

        assert_eq!(logger.decision_count().await, 1);
    }

    #[tokio::test]
    async fn test_get_history() {
        let logger = DecisionLogger::new();

        let decision1 = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        let decision2 = create_test_decision("refactor", "agent2", "/project", "/project/src/lib.rs");

        logger.log_decision(decision1).await.unwrap();
        logger.log_decision(decision2).await.unwrap();

        let history = logger.get_history().await;
        assert_eq!(history.len(), 2);
    }

    #[tokio::test]
    async fn test_get_history_by_type() {
        let logger = DecisionLogger::new();

        let decision1 = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        let decision2 = create_test_decision("refactor", "agent2", "/project", "/project/src/lib.rs");
        let decision3 = create_test_decision("code_gen", "agent1", "/project", "/project/src/utils.rs");

        logger.log_decision(decision1).await.unwrap();
        logger.log_decision(decision2).await.unwrap();
        logger.log_decision(decision3).await.unwrap();

        let code_gen_decisions = logger.get_history_by_type("code_gen").await;
        assert_eq!(code_gen_decisions.len(), 2);

        let refactor_decisions = logger.get_history_by_type("refactor").await;
        assert_eq!(refactor_decisions.len(), 1);
    }

    #[tokio::test]
    async fn test_get_history_by_context() {
        let logger = DecisionLogger::new();

        let context1 = DecisionContext {
            project_path: PathBuf::from("/project1"),
            file_path: PathBuf::from("/project1/src/main.rs"),
            line_number: 10,
            agent_type: "agent1".to_string(),
        };

        let context2 = DecisionContext {
            project_path: PathBuf::from("/project2"),
            file_path: PathBuf::from("/project2/src/main.rs"),
            line_number: 20,
            agent_type: "agent2".to_string(),
        };

        let decision1 = Decision::new(
            context1.clone(),
            "code_gen".to_string(),
            serde_json::json!({}),
            serde_json::json!({}),
        );

        let decision2 = Decision::new(
            context2.clone(),
            "refactor".to_string(),
            serde_json::json!({}),
            serde_json::json!({}),
        );

        let decision3 = Decision::new(
            context1.clone(),
            "code_gen".to_string(),
            serde_json::json!({}),
            serde_json::json!({}),
        );

        logger.log_decision(decision1).await.unwrap();
        logger.log_decision(decision2).await.unwrap();
        logger.log_decision(decision3).await.unwrap();

        let context1_decisions = logger.get_history_by_context(&context1).await;
        assert_eq!(context1_decisions.len(), 2);

        let context2_decisions = logger.get_history_by_context(&context2).await;
        assert_eq!(context2_decisions.len(), 1);
    }

    #[tokio::test]
    async fn test_get_history_by_context_unknown() {
        let logger = DecisionLogger::new();

        let decision = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        logger.log_decision(decision).await.unwrap();

        let other = DecisionContext {
            project_path: PathBuf::from("/other"),
            file_path: PathBuf::from("/other/src/main.rs"),
            line_number: 10,
            agent_type: "agent1".to_string(),
        };
        assert!(logger.get_history_by_context(&other).await.is_empty());
    }

    #[tokio::test]
    async fn test_context_ignores_line_and_agent() {
        let logger = DecisionLogger::new();

        // Same project and file; only line number and agent type differ.
        let first = DecisionContext {
            project_path: PathBuf::from("/project"),
            file_path: PathBuf::from("/project/src/main.rs"),
            line_number: 1,
            agent_type: "agent1".to_string(),
        };
        let second = DecisionContext {
            project_path: PathBuf::from("/project"),
            file_path: PathBuf::from("/project/src/main.rs"),
            line_number: 99,
            agent_type: "agent2".to_string(),
        };

        logger
            .log_decision(Decision::new(
                first.clone(),
                "code_gen".to_string(),
                serde_json::json!({}),
                serde_json::json!({}),
            ))
            .await
            .unwrap();
        logger
            .log_decision(Decision::new(
                second.clone(),
                "refactor".to_string(),
                serde_json::json!({}),
                serde_json::json!({}),
            ))
            .await
            .unwrap();

        let history = logger.get_history_by_context(&first).await;
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].decision_type, "code_gen");
        assert_eq!(history[1].decision_type, "refactor");
        assert_eq!(logger.get_history_by_context(&second).await.len(), 2);
    }

    #[tokio::test]
    async fn test_get_decision() {
        let logger = DecisionLogger::new();

        let decision = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        let decision_id = decision.id.clone();

        logger.log_decision(decision.clone()).await.unwrap();

        let retrieved = logger.get_decision(&decision_id).await;
        assert!(retrieved.is_ok());
        assert_eq!(retrieved.unwrap().id, decision_id);
    }

    #[tokio::test]
    async fn test_get_decision_not_found() {
        let logger = DecisionLogger::new();

        let result = logger.get_decision("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_replay_decisions() {
        let logger = DecisionLogger::new();

        let decision1 = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        let decision2 = create_test_decision("refactor", "agent2", "/project", "/project/src/lib.rs");

        logger.log_decision(decision1).await.unwrap();
        logger.log_decision(decision2).await.unwrap();

        let replayed = logger.replay_decisions().await;
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].decision_type, "code_gen");
        assert_eq!(replayed[1].decision_type, "refactor");
    }

    #[tokio::test]
    async fn test_replay_decisions_for_context() {
        let logger = DecisionLogger::new();

        let context = DecisionContext {
            project_path: PathBuf::from("/project"),
            file_path: PathBuf::from("/project/src/main.rs"),
            line_number: 10,
            agent_type: "agent1".to_string(),
        };

        let decision1 = Decision::new(
            context.clone(),
            "code_gen".to_string(),
            serde_json::json!({}),
            serde_json::json!({}),
        );

        let decision2 = Decision::new(
            context.clone(),
            "refactor".to_string(),
            serde_json::json!({}),
            serde_json::json!({}),
        );

        logger.log_decision(decision1).await.unwrap();
        logger.log_decision(decision2).await.unwrap();

        let replayed = logger.replay_decisions_for_context(&context).await;
        assert_eq!(replayed.len(), 2);
    }

    #[tokio::test]
    async fn test_decision_count() {
        let logger = DecisionLogger::new();

        assert_eq!(logger.decision_count().await, 0);

        let decision1 = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        logger.log_decision(decision1).await.unwrap();

        assert_eq!(logger.decision_count().await, 1);

        let decision2 = create_test_decision("refactor", "agent2", "/project", "/project/src/lib.rs");
        logger.log_decision(decision2).await.unwrap();

        assert_eq!(logger.decision_count().await, 2);
    }

    #[tokio::test]
    async fn test_clear() {
        let logger = DecisionLogger::new();

        let decision = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        let decision_id = decision.id.clone();
        let context = decision.context.clone();
        logger.log_decision(decision).await.unwrap();

        assert_eq!(logger.decision_count().await, 1);

        logger.clear().await;

        // The id and context indexes must be emptied along with the list.
        assert_eq!(logger.decision_count().await, 0);
        assert!(logger.get_history().await.is_empty());
        assert!(logger.get_decision(&decision_id).await.is_err());
        assert!(logger.get_history_by_context(&context).await.is_empty());
        assert_eq!(logger.get_statistics().await.total_decisions, 0);

        // The logger stays usable after clearing, and stale context entries
        // must not resolve to the newly logged decision.
        let decision = create_test_decision("refactor", "agent2", "/project", "/project/src/lib.rs");
        let new_id = decision.id.clone();
        logger.log_decision(decision).await.unwrap();
        assert_eq!(logger.get_decision(&new_id).await.unwrap().id, new_id);
        assert!(logger.get_history_by_context(&context).await.is_empty());
    }

    #[tokio::test]
    async fn test_get_statistics() {
        let logger = DecisionLogger::new();

        let decision1 = create_test_decision("code_gen", "agent1", "/project", "/project/src/main.rs");
        let decision2 = create_test_decision("refactor", "agent2", "/project", "/project/src/lib.rs");
        let decision3 = create_test_decision("code_gen", "agent1", "/project", "/project/src/utils.rs");

        logger.log_decision(decision1).await.unwrap();
        logger.log_decision(decision2).await.unwrap();
        logger.log_decision(decision3).await.unwrap();

        let stats = logger.get_statistics().await;

        assert_eq!(stats.total_decisions, 3);
        assert_eq!(stats.decision_types.get("code_gen"), Some(&2));
        assert_eq!(stats.decision_types.get("refactor"), Some(&1));
        assert_eq!(stats.agent_types.get("agent1"), Some(&2));
        assert_eq!(stats.agent_types.get("agent2"), Some(&1));
    }

    #[tokio::test]
    async fn test_multiple_decisions_same_context() {
        let logger = DecisionLogger::new();

        let context = DecisionContext {
            project_path: PathBuf::from("/project"),
            file_path: PathBuf::from("/project/src/main.rs"),
            line_number: 10,
            agent_type: "agent1".to_string(),
        };

        for i in 0..5 {
            let decision = Decision::new(
                context.clone(),
                format!("type_{}", i),
                serde_json::json!({}),
                serde_json::json!({}),
            );
            logger.log_decision(decision).await.unwrap();
        }

        let context_decisions = logger.get_history_by_context(&context).await;
        assert_eq!(context_decisions.len(), 5);
    }
}
436}