use std::path::Path;
use std::process::Command;
const MAX_DIFF_STAT_CHARS: usize = 2_000;
const MAX_RECENT_COMMITS: usize = 5;
pub fn git_context(project_root: &Path) -> Option<String> {
let branch = git_cmd(project_root, &["rev-parse", "--abbrev-ref", "HEAD"])?;
let mut parts = vec![format!("[Git: branch={branch}")];
if let Some(staged) = git_cmd(project_root, &["diff", "--cached", "--stat"])
&& !staged.trim().is_empty()
{
let truncated = truncate_str(&staged, MAX_DIFF_STAT_CHARS);
parts.push(format!("staged:\n{truncated}"));
}
if let Some(unstaged) = git_cmd(project_root, &["diff", "--stat"])
&& !unstaged.trim().is_empty()
{
let truncated = truncate_str(&unstaged, MAX_DIFF_STAT_CHARS);
parts.push(format!("unstaged:\n{truncated}"));
}
if let Some(untracked) = git_cmd(
project_root,
&["ls-files", "--others", "--exclude-standard"],
) {
let count = untracked.lines().count();
if count > 0 {
parts.push(format!("{count} untracked file(s)"));
}
}
if let Some(log) = git_cmd(
project_root,
&[
"log",
"--oneline",
&format!("-{MAX_RECENT_COMMITS}"),
"--no-decorate",
],
) && !log.trim().is_empty()
{
parts.push(format!("recent commits:\n{log}"));
}
parts.push("]".to_string());
Some(parts.join(", "))
}
/// Runs `git <args>` with `cwd` as the working directory and returns its stdout.
///
/// Returns `None` if the process could not be started (for example `git` is
/// not on `PATH` or `cwd` does not exist) or if git exited with a non-zero
/// status. Stderr is captured by `output()` and discarded. Invalid UTF-8 in
/// stdout is replaced lossily rather than treated as an error.
fn git_cmd(cwd: &Path, args: &[&str]) -> Option<String> {
    let result = Command::new("git").current_dir(cwd).args(args).output().ok()?;
    if !result.status.success() {
        return None;
    }
    let text = String::from_utf8_lossy(&result.stdout);
    Some(text.to_string())
}
/// Shortens `s` to at most `max` bytes for display, cutting at a line break.
///
/// If `s` fits within `max` bytes it is returned unchanged. Otherwise:
///
/// - The text is cut back to the last newline within the first `max` bytes.
///   If there is no such newline, it is cut at `max` itself, moved down to a
///   UTF-8 character boundary.
/// - A trailer `"\n ... (N more lines)"` is appended, where `N` is the number
///   of (possibly partial) lines that were dropped.
///
/// Never panics: the cut point always lands on a `char` boundary, so
/// multi-byte characters (e.g. non-ASCII file names in a diff stat) are
/// never split.
fn truncate_str(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // `max` may land inside a multi-byte character; slicing there would panic.
    // Index 0 is always a boundary, so this loop terminates.
    let mut limit = max;
    while !s.is_char_boundary(limit) {
        limit -= 1;
    }
    let head = &s[..limit];
    let (kept, dropped) = match head.rfind('\n') {
        // Skip the newline itself; otherwise the dropped text would start
        // with an empty line that `lines()` counts as an extra line.
        Some(nl) => (&s[..nl], &s[nl + 1..]),
        None => (head, &s[limit..]),
    };
    let remaining = dropped.lines().count();
    format!("{kept}\n ... ({remaining} more lines)")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_git_context_in_repo() {
        // cargo runs tests with the package root as the working directory.
        // This test assumes that directory is inside a git checkout with at
        // least one commit.
        let ctx = git_context(Path::new("."))
            .expect("git_context should succeed inside the project's git repository");
        assert!(ctx.contains("[Git: branch="));
        assert!(ctx.contains("recent commits:"));
    }

    #[test]
    fn test_git_context_not_a_repo() {
        // Assumes the system temp directory is not itself inside a git work tree.
        let tmp = tempfile::tempdir().expect("failed to create temp dir");
        assert!(git_context(tmp.path()).is_none());
    }

    #[test]
    fn test_truncate_str_short() {
        assert_eq!(truncate_str("hello", 100), "hello");
    }

    #[test]
    fn test_truncate_str_exact_length() {
        // An input exactly `max` bytes long must be returned untouched.
        assert_eq!(truncate_str("hello", 5), "hello");
    }

    #[test]
    fn test_truncate_str_long() {
        let input = (0..50)
            .map(|i| format!("line {i}"))
            .collect::<Vec<_>>()
            .join("\n");
        let truncated = truncate_str(&input, 50);
        assert!(truncated.len() <= 80);
        assert!(truncated.contains("more lines"));
    }
}