1use crate::db::Database;
22use crate::persistence::Persistence;
23
24pub async fn track_progress(
26 db: &Database,
27 session_id: &str,
28 tool_name: &str,
29 _tool_args: &str,
30 tool_result: &str,
31) {
32 let entry = match tool_name {
33 "Write" => extract_write_progress(tool_result),
34 "Edit" => extract_edit_progress(tool_result),
35 "Delete" => extract_delete_progress(tool_result),
36 "Bash" => extract_bash_progress(tool_result),
37 _ => None,
38 };
39
40 if let Some(entry) = entry {
41 append_progress(db, session_id, &entry).await;
42 }
43}
44
45pub async fn get_progress_summary(db: &Database, session_id: &str) -> Option<String> {
47 match db.get_metadata(session_id, "progress").await {
48 Ok(Some(progress)) if !progress.is_empty() => Some(format!(
49 "\n## Session Progress\n\
50 The following actions have been completed this session:\n\
51 {progress}"
52 )),
53 _ => None,
54 }
55}
56
57async fn append_progress(db: &Database, session_id: &str, entry: &str) {
58 let existing = db
59 .get_metadata(session_id, "progress")
60 .await
61 .ok()
62 .flatten()
63 .unwrap_or_default();
64
65 let lines: Vec<&str> = existing.lines().collect();
67 let mut updated = if lines.len() >= 20 {
68 lines[lines.len() - 15..].join("\n")
70 } else {
71 existing
72 };
73
74 if !updated.is_empty() {
75 updated.push('\n');
76 }
77 updated.push_str(entry);
78
79 let _ = db.set_metadata(session_id, "progress", &updated).await;
80}
81
/// Turns the output of a `Write` tool call into a progress line.
///
/// Returns `None` unless `result` mentions "Created" or "Wrote". Otherwise
/// the entry is the trimmed first line of `result`, prefixed with a
/// check-mark bullet.
fn extract_write_progress(result: &str) -> Option<String> {
    let succeeded = ["Created", "Wrote"].iter().any(|kw| result.contains(*kw));
    if !succeeded {
        return None;
    }

    let headline = match result.lines().next() {
        Some(line) => line,
        None => result,
    };
    Some(format!("- \u{2705} {}", headline.trim()))
}
91
/// Turns the output of an `Edit` tool call into a progress line.
///
/// Returns `None` unless `result` mentions "Applied", "edited" or
/// "replacement". Otherwise the entry is the trimmed first line of `result`,
/// prefixed with a pencil bullet. If that line is longer than
/// `MAX_CHARS` characters, it is cut to `MAX_CHARS` characters and
/// suffixed with `...`.
fn extract_edit_progress(result: &str) -> Option<String> {
    const MAX_CHARS: usize = 80;

    if !(result.contains("Applied") || result.contains("edited") || result.contains("replacement"))
    {
        return None;
    }

    let first_line = result.lines().next().unwrap_or(result).trim();
    // Truncate on a character boundary. Slicing at a raw byte offset
    // (`&s[..80]`) panics when that offset falls inside a multi-byte UTF-8
    // character, e.g. a non-ASCII file name.
    let short = match first_line.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &first_line[..cut]),
        None => first_line.to_string(),
    };
    Some(format!("- \u{270f}\u{fe0f} {short}"))
}
105
/// Turns the output of a `Delete` tool call into a progress line.
///
/// Returns `None` unless `result` mentions "Deleted" or "removed". Otherwise
/// the entry is the trimmed first line of `result`, prefixed with a
/// wastebasket bullet.
fn extract_delete_progress(result: &str) -> Option<String> {
    let was_deleted = result.contains("Deleted") || result.contains("removed");
    was_deleted.then(|| {
        let summary = result.lines().next().unwrap_or(result).trim();
        format!("- \u{1f5d1}\u{fe0f} {summary}")
    })
}
114
/// Turns the output of a `Bash` tool call into a progress line.
///
/// Recognised outputs, checked in this order:
/// - a background-process launch ("Background process started"), reported
///   with the command from its `Command:` line. The command is cut to
///   `MAX_CMD_CHARS` characters plus `...`; `?` is used if that line is absent.
/// - test results: `test result: failed`, then `test result: ok` /
///   `tests passed`, then `tests failed`;
/// - build results: `error:` + `could not compile`, then `build succeeded`
///   or `finished` + `target` (cargo's "Finished ... target(s)" line).
///
/// Matching is case-insensitive except for the background-process marker.
/// Returns `None` for anything else.
fn extract_bash_progress(result: &str) -> Option<String> {
    const MAX_CMD_CHARS: usize = 60;

    if result.contains("Background process started") {
        let cmd = result
            .lines()
            .find(|l| l.trim_start().starts_with("Command:"))
            .and_then(|l| l.split_once(':').map(|(_, v)| v))
            .map(|s| s.trim())
            .unwrap_or("?");
        // Truncate on a character boundary; `&cmd[..60]` panics when byte 60
        // lands inside a multi-byte UTF-8 character.
        let short = match cmd.char_indices().nth(MAX_CMD_CHARS) {
            Some((cut, _)) => format!("{}...", &cmd[..cut]),
            None => cmd.to_string(),
        };
        return Some(format!("- \u{1f4e1} Started background: {short}"));
    }

    let lower = result.to_lowercase();
    // Unambiguous failure markers win over success markers. `cargo test`
    // prints one `test result:` line per test binary, so output such as
    // "test result: ok ... test result: FAILED" is a failure.
    if lower.contains("test result: failed") {
        Some("- \u{274c} Tests failed".to_string())
    } else if lower.contains("test result: ok") || lower.contains("tests passed") {
        Some("- \u{2705} Tests passed".to_string())
    } else if lower.contains("tests failed") {
        Some("- \u{274c} Tests failed".to_string())
    } else if lower.contains("error:") && lower.contains("could not compile") {
        Some("- \u{274c} Build failed".to_string())
    } else if lower.contains("build succeeded")
        || (lower.contains("finished") && lower.contains("target"))
    {
        Some("- \u{1f3d7}\u{fe0f} Build succeeded".to_string())
    } else {
        None
    }
}
148
#[cfg(test)]
mod tests {
    //! Unit tests for the progress extractors, plus database-backed tests for
    //! `track_progress` / `get_progress_summary` (including history capping).
    use super::*;

    #[test]
    fn test_write_progress() {
        assert!(extract_write_progress("Created file: src/main.rs").is_some());
        assert!(extract_write_progress("Wrote 100 bytes to foo.rs").is_some());
        assert!(extract_write_progress("Error: permission denied").is_none());
    }

    #[test]
    fn test_edit_progress() {
        assert!(extract_edit_progress("Applied 2 replacements to src/lib.rs").is_some());
        assert!(extract_edit_progress("No changes needed").is_none());
    }

    #[test]
    fn test_bash_progress() {
        assert!(extract_bash_progress("test result: ok. 50 passed").is_some());
        assert!(extract_bash_progress("test result: FAILED. 1 failed").is_some());
        assert!(extract_bash_progress("hello world").is_none());
    }

    #[test]
    fn test_delete_progress() {
        assert!(extract_delete_progress("Deleted src/old.rs").is_some());
        assert!(extract_delete_progress("File not found").is_none());
    }

    #[test]
    fn test_write_progress_content_includes_first_line() {
        let result = extract_write_progress("Created file: src/main.rs").unwrap();
        assert!(
            result.contains("Created file: src/main.rs"),
            "got: {result}"
        );
    }

    #[test]
    fn test_delete_progress_removed_keyword() {
        assert!(extract_delete_progress("5 files removed successfully").is_some());
    }

    #[test]
    fn test_delete_progress_content_includes_first_line() {
        let result = extract_delete_progress("Deleted src/old.rs").unwrap();
        assert!(result.contains("Deleted src/old.rs"), "got: {result}");
    }

    #[test]
    fn test_edit_progress_long_line_is_truncated() {
        let long = "Applied replacements to ".to_string() + &"x".repeat(100);
        let result = extract_edit_progress(&long).unwrap();
        assert!(
            result.contains("..."),
            "long line should be truncated: {result}"
        );
    }

    #[test]
    fn test_edit_progress_short_line_not_truncated() {
        let result = extract_edit_progress("Applied 3 replacements to src/lib.rs").unwrap();
        assert!(
            !result.contains("..."),
            "short line should not be truncated: {result}"
        );
    }

    #[test]
    fn test_bash_progress_background_process() {
        let result = extract_bash_progress(
            "Background process started\n Command: cargo watch -x test\n PID: 1234",
        )
        .unwrap();
        assert!(
            result.contains("cargo watch"),
            "should include command: {result}"
        );
        assert!(result.contains("Started background"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_build_cargo_finish() {
        let result =
            extract_bash_progress(" Finished `release` profile [optimized] target(s)").unwrap();
        assert!(result.contains("Build succeeded"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_build_succeeded_literal() {
        let result = extract_bash_progress("build succeeded").unwrap();
        assert!(result.contains("Build succeeded"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_build_failed() {
        let result =
            extract_bash_progress("error: could not compile `myapp` due to 3 errors").unwrap();
        assert!(result.contains("Build failed"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_tests_passed_text() {
        let result = extract_bash_progress("All 42 tests passed").unwrap();
        assert!(result.contains("Tests passed"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_tests_failed_text() {
        let result = extract_bash_progress("3 tests failed").unwrap();
        assert!(result.contains("Tests failed"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_background_long_command_truncated() {
        let long_cmd = "x".repeat(80);
        let output = format!("Background process started\n Command: {long_cmd}\n");
        let result = extract_bash_progress(&output).unwrap();
        assert!(
            result.contains("..."),
            "long command should be truncated: {result}"
        );
    }

    /// Opens a fresh database in a temporary directory and creates one
    /// session in it. The `TempDir` is returned so it outlives the test body.
    async fn test_db() -> (Database, tempfile::TempDir, String) {
        let dir = tempfile::TempDir::new().unwrap();
        let db = Database::open(&dir.path().join("progress_test.db"))
            .await
            .unwrap();
        let sid = db.create_session("koda", dir.path()).await.unwrap();
        (db, dir, sid)
    }

    #[tokio::test]
    async fn test_get_progress_summary_empty() {
        let (db, _dir, sid) = test_db().await;
        let result = get_progress_summary(&db, &sid).await;
        assert!(result.is_none(), "empty session should return None");
    }

    #[tokio::test]
    async fn test_track_progress_write_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "Write", "", "Created file: src/main.rs").await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("src/main.rs"), "got: {summary}");
        assert!(summary.contains("Session Progress"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_edit_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(
            &db,
            &sid,
            "Edit",
            "",
            "Applied 2 replacements to src/lib.rs",
        )
        .await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("lib.rs"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_bash_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(
            &db,
            &sid,
            "Bash",
            "",
            "test result: ok. 10 passed; 0 failed",
        )
        .await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("Tests passed"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_delete_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "Delete", "", "Deleted src/old.rs").await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("Deleted"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_unknown_tool_no_entry() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "UnknownTool", "", "some result").await;
        let summary = get_progress_summary(&db, &sid).await;
        assert!(summary.is_none(), "unknown tool should not create progress");
    }

    #[tokio::test]
    async fn test_progress_accumulates_multiple_entries() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "Write", "", "Created file: a.rs").await;
        track_progress(&db, &sid, "Write", "", "Created file: b.rs").await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("a.rs"), "got: {summary}");
        assert!(summary.contains("b.rs"), "got: {summary}");
    }

    /// The stored log is capped: after many entries only the most recent
    /// ones remain, and the entry count never exceeds 20.
    #[tokio::test]
    async fn test_progress_history_is_capped() {
        let (db, _dir, sid) = test_db().await;
        for i in 0..30 {
            let result = format!("Created file: file_{i:02}.rs");
            track_progress(&db, &sid, "Write", "", &result).await;
        }
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        // Header lines never start with "- ", so this counts entries only.
        let entries = summary.lines().filter(|l| l.starts_with("- ")).count();
        assert!(
            entries <= 20,
            "history should be capped, got {entries} entries: {summary}"
        );
        assert!(
            summary.contains("file_29.rs"),
            "newest entry must survive: {summary}"
        );
        assert!(
            !summary.contains("file_00.rs"),
            "oldest entry should be dropped: {summary}"
        );
    }
}