// koda_core/progress.rs

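//! Structured progress tracking.
//!
//! Automatically extracts progress entries from tool results and stores them
//! in the session's DB metadata. Because they live outside the message
//! history, they survive compaction. They are also injected into the system
//! prompt, so the LLM always knows what has been done even after the context
//! is trimmed.
//!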
//! ## Why progress tracking exists
//!
//! When compaction summarizes old messages, the model loses track of which
//! files were created or modified and which steps were completed. Progress
//! entries are stored in the DB (not in the message history) and re-injected
//! into the system prompt, giving the model a persistent "done" list.
//!
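//! ## What gets tracked
//!
//! - Files created or written (Write tool)
//! - Files modified (Edit tool)
//! - Files deleted (Delete tool)
//! - Bash tool output: test pass/fail results, build success/failure, and
//!   background process starts
//!
//! Other commands and their exit codes are not recorded. Only the most
//! recent entries are kept (at most 20).
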
21use crate::db::Database;
22use crate::persistence::Persistence;
23
24/// Extract progress from a tool call and persist it.
25pub async fn track_progress(
26    db: &Database,
27    session_id: &str,
28    tool_name: &str,
29    _tool_args: &str,
30    tool_result: &str,
31) {
32    let entry = match tool_name {
33        "Write" => extract_write_progress(tool_result),
34        "Edit" => extract_edit_progress(tool_result),
35        "Delete" => extract_delete_progress(tool_result),
36        "Bash" => extract_bash_progress(tool_result),
37        _ => None,
38    };
39
40    if let Some(entry) = entry {
41        append_progress(db, session_id, &entry).await;
42    }
43}
44
45/// Get the current progress summary for injection into the system prompt.
46pub async fn get_progress_summary(db: &Database, session_id: &str) -> Option<String> {
47    match db.get_metadata(session_id, "progress").await {
48        Ok(Some(progress)) if !progress.is_empty() => Some(format!(
49            "\n## Session Progress\n\
50                 The following actions have been completed this session:\n\
51                 {progress}"
52        )),
53        _ => None,
54    }
55}
56
57async fn append_progress(db: &Database, session_id: &str, entry: &str) {
58    let existing = db
59        .get_metadata(session_id, "progress")
60        .await
61        .ok()
62        .flatten()
63        .unwrap_or_default();
64
65    // Cap at 20 entries to avoid unbounded growth
66    let lines: Vec<&str> = existing.lines().collect();
67    let mut updated = if lines.len() >= 20 {
68        // Keep last 15 + new entry
69        lines[lines.len() - 15..].join("\n")
70    } else {
71        existing
72    };
73
74    if !updated.is_empty() {
75        updated.push('\n');
76    }
77    updated.push_str(entry);
78
79    let _ = db.set_metadata(session_id, "progress", &updated).await;
80}

/// Builds a progress entry from Write tool output.
///
/// Recognises output containing "Created" or "Wrote" anywhere, for example
/// "Created file: path" or "Wrote N bytes to path". The entry records the
/// output's first line, trimmed.
fn extract_write_progress(result: &str) -> Option<String> {
    let recognised = ["Created", "Wrote"].iter().any(|&kw| result.contains(kw));
    if !recognised {
        return None;
    }
    let summary = result.lines().next().unwrap_or(result).trim();
    Some(format!("- \u{2705} {summary}"))
}

/// Builds a progress entry from Edit tool output.
///
/// Recognises output containing "Applied", "edited" or "replacement". The
/// entry records the output's first line, trimmed and shortened to at most
/// 80 characters followed by "...".
fn extract_edit_progress(result: &str) -> Option<String> {
    const MAX_CHARS: usize = 80;

    let recognised = result.contains("Applied")
        || result.contains("edited")
        || result.contains("replacement");
    if !recognised {
        return None;
    }

    let first_line = result.lines().next().unwrap_or(result).trim();
    // Cut on a char boundary: byte-slicing (`&s[..80]`) panics when byte 80
    // falls inside a multi-byte UTF-8 character, e.g. in a non-ASCII path.
    let short = match first_line.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => format!("{}...", &first_line[..cut]),
        None => first_line.to_string(),
    };
    Some(format!("- \u{270f}\u{fe0f} {short}"))
}

/// Builds a progress entry from Delete tool output.
///
/// Recognises output containing "Deleted" or "removed" (case-sensitive). The
/// entry records the output's first line, trimmed.
fn extract_delete_progress(result: &str) -> Option<String> {
    let matched = result.contains("Deleted") || result.contains("removed");
    matched.then(|| {
        let headline = result.lines().next().unwrap_or(result).trim();
        format!("- \u{1f5d1}\u{fe0f} {headline}")
    })
}

/// Builds a progress entry from Bash tool output.
///
/// Recognises, in order:
/// 1. a background-process start. The entry records the `Command:` line,
///    shortened to at most 60 characters followed by "...", or "?" if no
///    such line exists.
/// 2. test outcomes ("test result: ok/FAILED", "tests passed/failed").
/// 3. build outcomes ("error:" together with "could not compile", a literal
///    "build succeeded", or "finished" together with "target").
///
/// Test and build markers are matched case-insensitively. The
/// background-process markers are matched case-sensitively.
fn extract_bash_progress(result: &str) -> Option<String> {
    const MAX_CMD_CHARS: usize = 60;

    if result.contains("Background process started") {
        // Expected line shape: "  Command: <cmd>".
        let cmd = result
            .lines()
            .find_map(|l| l.trim_start().strip_prefix("Command:"))
            .map(str::trim)
            .unwrap_or("?");
        // Cut on a char boundary; byte-slicing panics inside multi-byte UTF-8.
        let short = match cmd.char_indices().nth(MAX_CMD_CHARS) {
            Some((cut, _)) => format!("{}...", &cmd[..cut]),
            None => cmd.to_string(),
        };
        return Some(format!("- \u{1f4e1} Started background: {short}"));
    }

    let lower = result.to_lowercase();
    let has = |needle: &str| lower.contains(needle);

    // `cargo test` prints a "test result:" line per test binary. A single
    // FAILED line means the run failed, even if other binaries reported ok.
    if has("test result: failed") {
        Some("- \u{274c} Tests failed".to_string())
    } else if has("test result: ok") || has("tests passed") {
        Some("- \u{2705} Tests passed".to_string())
    } else if has("tests failed") {
        Some("- \u{274c} Tests failed".to_string())
    } else if has("error:") && has("could not compile") {
        // Checked before the loose "finished"/"target" success heuristic.
        Some("- \u{274c} Build failed".to_string())
    } else if has("build succeeded") || (has("finished") && has("target")) {
        Some("- \u{1f3d7}\u{fe0f} Build succeeded".to_string())
    } else {
        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_write_progress() {
        assert!(extract_write_progress("Created file: src/main.rs").is_some());
        assert!(extract_write_progress("Wrote 100 bytes to foo.rs").is_some());
        assert!(extract_write_progress("Error: permission denied").is_none());
    }

    #[test]
    fn test_edit_progress() {
        assert!(extract_edit_progress("Applied 2 replacements to src/lib.rs").is_some());
        assert!(extract_edit_progress("No changes needed").is_none());
    }

    #[test]
    fn test_bash_progress() {
        assert!(extract_bash_progress("test result: ok. 50 passed").is_some());
        assert!(extract_bash_progress("test result: FAILED. 1 failed").is_some());
        assert!(extract_bash_progress("hello world").is_none());
    }

    #[test]
    fn test_delete_progress() {
        assert!(extract_delete_progress("Deleted src/old.rs").is_some());
        assert!(extract_delete_progress("File not found").is_none());
    }

    // ── content / format assertions ───────────────────────────────────────

    #[test]
    fn test_write_progress_content_includes_first_line() {
        let result = extract_write_progress("Created file: src/main.rs").unwrap();
        assert!(
            result.contains("Created file: src/main.rs"),
            "got: {result}"
        );
    }

    #[test]
    fn test_delete_progress_removed_keyword() {
        // "removed" is an alternative trigger word
        assert!(extract_delete_progress("5 files removed successfully").is_some());
    }

    #[test]
    fn test_delete_progress_content_includes_first_line() {
        let result = extract_delete_progress("Deleted src/old.rs").unwrap();
        assert!(result.contains("Deleted src/old.rs"), "got: {result}");
    }

    #[test]
    fn test_edit_progress_long_line_is_truncated() {
        let long = "Applied replacements to ".to_string() + &"x".repeat(100);
        let result = extract_edit_progress(&long).unwrap();
        // Exactly the first 80 characters of the line are kept, then "...".
        let expected = format!("- \u{270f}\u{fe0f} {}...", &long[..80]);
        assert_eq!(result, expected, "long line should be truncated");
    }

    #[test]
    fn test_edit_progress_short_line_not_truncated() {
        let result = extract_edit_progress("Applied 3 replacements to src/lib.rs").unwrap();
        assert!(
            !result.contains("..."),
            "short line should not be truncated: {result}"
        );
    }

    // ── extract_bash_progress edge cases ─────────────────────────────────

    #[test]
    fn test_bash_progress_background_process() {
        let result = extract_bash_progress(
            "Background process started\n  Command: cargo watch -x test\n  PID: 1234",
        )
        .unwrap();
        assert!(
            result.contains("cargo watch"),
            "should include command: {result}"
        );
        assert!(result.contains("Started background"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_build_cargo_finish() {
        // cargo build output typically contains "Finished" and "target"
        let result =
            extract_bash_progress("   Finished `release` profile [optimized] target(s)").unwrap();
        assert!(result.contains("Build succeeded"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_build_succeeded_literal() {
        let result = extract_bash_progress("build succeeded").unwrap();
        assert!(result.contains("Build succeeded"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_build_failed() {
        let result =
            extract_bash_progress("error: could not compile `myapp` due to 3 errors").unwrap();
        assert!(result.contains("Build failed"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_tests_passed_text() {
        // Some test runners emit "tests passed" not "test result: ok"
        let result = extract_bash_progress("All 42 tests passed").unwrap();
        assert!(result.contains("Tests passed"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_tests_failed_text() {
        let result = extract_bash_progress("3 tests failed").unwrap();
        assert!(result.contains("Tests failed"), "got: {result}");
    }

    #[test]
    fn test_bash_progress_background_long_command_truncated() {
        let long_cmd = "x".repeat(80);
        let output = format!("Background process started\n  Command: {long_cmd}\n");
        let result = extract_bash_progress(&output).unwrap();
        assert!(
            result.contains("..."),
            "long command should be truncated: {result}"
        );
    }

    // ── async / DB-backed tests ──────────────────────────────────────────

    async fn test_db() -> (Database, tempfile::TempDir, String) {
        let dir = tempfile::TempDir::new().unwrap();
        let db = Database::open(&dir.path().join("progress_test.db"))
            .await
            .unwrap();
        let sid = db.create_session("koda", dir.path()).await.unwrap();
        (db, dir, sid)
    }

    #[tokio::test]
    async fn test_get_progress_summary_empty() {
        let (db, _dir, sid) = test_db().await;
        let result = get_progress_summary(&db, &sid).await;
        assert!(result.is_none(), "empty session should return None");
    }

    #[tokio::test]
    async fn test_track_progress_write_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "Write", "", "Created file: src/main.rs").await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("src/main.rs"), "got: {summary}");
        assert!(summary.contains("Session Progress"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_edit_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(
            &db,
            &sid,
            "Edit",
            "",
            "Applied 2 replacements to src/lib.rs",
        )
        .await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("lib.rs"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_bash_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(
            &db,
            &sid,
            "Bash",
            "",
            "test result: ok. 10 passed; 0 failed",
        )
        .await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("Tests passed"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_delete_tool() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "Delete", "", "Deleted src/old.rs").await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("Deleted"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_track_progress_unknown_tool_no_entry() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "UnknownTool", "", "some result").await;
        // Unknown tool → no entry tracked
        let summary = get_progress_summary(&db, &sid).await;
        assert!(summary.is_none(), "unknown tool should not create progress");
    }

    #[tokio::test]
    async fn test_progress_accumulates_multiple_entries() {
        let (db, _dir, sid) = test_db().await;
        track_progress(&db, &sid, "Write", "", "Created file: a.rs").await;
        track_progress(&db, &sid, "Write", "", "Created file: b.rs").await;
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        assert!(summary.contains("a.rs"), "got: {summary}");
        assert!(summary.contains("b.rs"), "got: {summary}");
    }

    #[tokio::test]
    async fn test_progress_is_capped() {
        let (db, _dir, sid) = test_db().await;
        // Zero-padded names so no file name is a substring of another.
        for i in 1..=25 {
            let result = format!("Created file: f{i:02}.rs");
            track_progress(&db, &sid, "Write", "", &result).await;
        }
        let summary = get_progress_summary(&db, &sid).await.unwrap();
        let entries = summary.lines().filter(|l| l.starts_with("- ")).count();
        assert!(
            entries <= 20,
            "progress should be capped at 20 entries, got {entries}: {summary}"
        );
        assert!(summary.contains("f25.rs"), "newest entry must be kept: {summary}");
        assert!(!summary.contains("f01.rs"), "oldest entry should be dropped: {summary}");
    }
}