Skip to main content

koda_core/
git.rs

1//! Git integration for context injection.
2//!
3//! - `git_context()`: compact git info for the system prompt
4//!
5//! File-level undo is handled by `undo.rs` (in-memory snapshots),
6//! not git. See DESIGN.md for rationale.
7
8use std::path::Path;
9use std::process::Command;
10
// ── Context injection (#263) ────────────────────────────────────

/// Size cap for each diff-stat section, measured with `str::len` (bytes);
/// longer sections are truncated before being added to the context.
const MAX_DIFF_STAT_CHARS: usize = 2000;

/// Number of commit subjects requested from `git log`.
const MAX_RECENT_COMMITS: usize = 5;
17
18/// Compact git context for injection into the system prompt.
19///
20/// Returns `None` if not in a git repo. Includes:
21/// - Current branch name
22/// - Staged diff stat (truncated)
23/// - Unstaged diff stat (truncated)
24/// - Last N commit subjects
25pub fn git_context(project_root: &Path) -> Option<String> {
26    let branch = git_cmd(project_root, &["rev-parse", "--abbrev-ref", "HEAD"])?;
27
28    let mut parts = vec![format!("[Git: branch={branch}")];
29
30    // Staged changes (stat only — token-efficient)
31    if let Some(staged) = git_cmd(project_root, &["diff", "--cached", "--stat"])
32        && !staged.trim().is_empty()
33    {
34        let truncated = truncate_str(&staged, MAX_DIFF_STAT_CHARS);
35        parts.push(format!("staged:\n{truncated}"));
36    }
37
38    // Unstaged changes (stat only)
39    if let Some(unstaged) = git_cmd(project_root, &["diff", "--stat"])
40        && !unstaged.trim().is_empty()
41    {
42        let truncated = truncate_str(&unstaged, MAX_DIFF_STAT_CHARS);
43        parts.push(format!("unstaged:\n{truncated}"));
44    }
45
46    // Untracked file count
47    if let Some(untracked) = git_cmd(
48        project_root,
49        &["ls-files", "--others", "--exclude-standard"],
50    ) {
51        let count = untracked.lines().count();
52        if count > 0 {
53            parts.push(format!("{count} untracked file(s)"));
54        }
55    }
56
57    // Recent commits
58    if let Some(log) = git_cmd(
59        project_root,
60        &[
61            "log",
62            "--oneline",
63            &format!("-{MAX_RECENT_COMMITS}"),
64            "--no-decorate",
65        ],
66    ) && !log.trim().is_empty()
67    {
68        parts.push(format!("recent commits:\n{log}"));
69    }
70
71    parts.push("]".to_string());
72    Some(parts.join(", "))
73}
74
75// ── Helpers ─────────────────────────────────────────────────────
76
/// Run `git` with `args` inside `cwd` and hand back its stdout.
///
/// Stdout is decoded lossily as UTF-8. Yields `None` when the process cannot
/// be spawned (e.g. `git` missing, `cwd` nonexistent) or exits with a
/// non-success status. Captured stderr is not used.
fn git_cmd(cwd: &Path, args: &[&str]) -> Option<String> {
    let finished = Command::new("git")
        .current_dir(cwd)
        .args(args)
        .output()
        .ok()?;

    if !finished.status.success() {
        return None;
    }

    Some(String::from_utf8_lossy(&finished.stdout).into_owned())
}
87
/// Truncate `s` to at most `max` bytes, preferring to cut at a line boundary.
///
/// Strings already within `max` bytes are returned unchanged. Otherwise the
/// kept prefix ends at the last `\n` inside the limit (or at the limit itself
/// if there is none), and a `  ... (N more lines)` marker is appended, where
/// `N` counts the dropped lines. The marker means the result can be slightly
/// longer than `max`.
fn truncate_str(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }

    // `max` is a byte count; step back to a UTF-8 char boundary so slicing
    // cannot panic on multi-byte characters. Index 0 is always a boundary.
    let mut cut = max;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }

    // Cut at the last newline that fits; the newline itself belongs to
    // neither side, so it is not counted as an extra (empty) dropped line.
    let (kept, rest) = match s[..cut].rfind('\n') {
        Some(nl) => (&s[..nl], &s[nl + 1..]),
        None => (&s[..cut], &s[cut..]),
    };
    let remaining = rest.lines().count();
    format!("{kept}\n  ... ({remaining} more lines)")
}
99
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_git_context_in_repo() {
        // The suite runs from inside the koda checkout, so `.` is a git repo.
        let ctx = git_context(Path::new(".")).expect("expected git context inside the koda repo");
        for needle in ["[Git: branch=", "recent commits:"] {
            assert!(ctx.contains(needle), "missing {needle:?} in {ctx:?}");
        }
    }

    #[test]
    fn test_git_context_not_a_repo() {
        let dir = tempfile::tempdir().unwrap();
        assert!(git_context(dir.path()).is_none());
    }

    #[test]
    fn test_truncate_str_short() {
        let out = truncate_str("hello", 100);
        assert_eq!(out, "hello");
    }

    #[test]
    fn test_truncate_str_long() {
        let input = (0..50)
            .map(|i| format!("line {i}"))
            .collect::<Vec<_>>()
            .join("\n");
        let out = truncate_str(&input, 50);
        // At most 50 bytes of content plus the "... (N more lines)" marker.
        assert!(out.len() <= 80);
        assert!(out.contains("more lines"));
    }
}