git_iris/agents/tools/
git.rs

1//! Git operations tools for Rig-based agents
2//!
3//! This module provides Git operations using Rig's tool system.
4
5use anyhow::Result;
6use rig::completion::ToolDefinition;
7use rig::tool::Tool;
8use serde::{Deserialize, Serialize};
9
10use crate::context::ChangeType;
11use crate::define_tool_error;
12use crate::git::StagedFile;
13
14use super::common::{get_current_repo, parameters_schema};
15
16define_tool_error!(GitError);
17
/// Record `change` in `changes` unless an identical label is already present.
///
/// Insertion order is preserved, so the first occurrence of a label decides
/// where it appears in the list.
fn add_change(changes: &mut Vec<&'static str>, change: &'static str) {
    let already_recorded = changes.iter().any(|existing| *existing == change);
    if already_recorded {
        return;
    }
    changes.push(change);
}
24
/// Report whether `line` looks like the start of a function definition for the
/// language identified by the file extension `ext`.
///
/// The check is a plain textual heuristic on the (already trimmed) line:
/// keyword prefixes for Rust, Python and Go, and for JS/TS either a
/// `function` prefix or a zero-argument arrow function assignment.
/// Unknown extensions always yield `false`.
fn is_function_def(line: &str, ext: &str) -> bool {
    let starts_with_any = |prefixes: &[&str]| prefixes.iter().any(|&p| line.starts_with(p));
    let contains_any = |needles: &[&str]| needles.iter().any(|&n| line.contains(n));

    match ext {
        "rs" => starts_with_any(&["pub fn ", "fn ", "pub async fn ", "async fn "]),
        "ts" | "tsx" | "js" | "jsx" => {
            starts_with_any(&["function ", "async function "])
                || contains_any(&[" = () =>", " = async () =>"])
        }
        "py" => starts_with_any(&["def ", "async def "]),
        "go" => line.starts_with("func "),
        _ => false,
    }
}
45
/// Report whether `line` starts with an import-like keyword for the language
/// identified by the file extension `ext`.
///
/// For JS/TS both `import ` and `export ` prefixes count. Unknown extensions
/// always yield `false`.
fn is_import(line: &str, ext: &str) -> bool {
    let prefixes: &[&str] = match ext {
        "rs" => &["use ", "pub use "],
        "ts" | "tsx" | "js" | "jsx" => &["import ", "export "],
        "py" => &["import ", "from "],
        "go" => &["import "],
        _ => &[],
    };
    prefixes.iter().any(|&prefix| line.starts_with(prefix))
}
56
/// Report whether `line` starts with a type-definition keyword for the language
/// identified by the file extension `ext`.
///
/// Rust: `struct`/`enum` with or without `pub`; JS/TS: `interface`, `type`,
/// `class`; Python: `class`; Go: `type`. Unknown extensions always yield `false`.
fn is_type_def(line: &str, ext: &str) -> bool {
    let prefixes: &[&str] = match ext {
        "rs" => &["pub struct ", "struct ", "pub enum ", "enum "],
        "ts" | "tsx" | "js" | "jsx" => &["interface ", "type ", "class "],
        "py" => &["class "],
        "go" => &["type "],
        _ => &[],
    };
    prefixes.iter().any(|&prefix| line.starts_with(prefix))
}
76
77/// Detect semantic change types from diff content
78#[allow(clippy::cognitive_complexity)]
79fn detect_semantic_changes(diff: &str, path: &str) -> Vec<&'static str> {
80    use std::path::Path;
81
82    let mut changes = Vec::new();
83
84    // Get file extension
85    let ext = Path::new(path)
86        .extension()
87        .and_then(|e| e.to_str())
88        .map(str::to_lowercase)
89        .unwrap_or_default();
90
91    // Only analyze supported languages
92    let supported = matches!(
93        ext.as_str(),
94        "rs" | "ts" | "tsx" | "js" | "jsx" | "py" | "go"
95    );
96
97    if supported {
98        // Analyze added lines for patterns
99        for line in diff
100            .lines()
101            .filter(|l| l.starts_with('+') && !l.starts_with("+++"))
102        {
103            let line = line.trim_start_matches('+').trim();
104
105            if is_function_def(line, &ext) {
106                add_change(&mut changes, "adds function");
107            }
108            if is_import(line, &ext) {
109                add_change(&mut changes, "modifies imports");
110            }
111            if is_type_def(line, &ext) {
112                add_change(&mut changes, "adds type");
113            }
114            // Rust-specific: impl blocks
115            if ext == "rs" && line.starts_with("impl ") {
116                add_change(&mut changes, "adds impl");
117            }
118        }
119    }
120
121    // Check for general change patterns
122    let has_deletions = diff
123        .lines()
124        .any(|l| l.starts_with('-') && !l.starts_with("---"));
125    let has_additions = diff
126        .lines()
127        .any(|l| l.starts_with('+') && !l.starts_with("+++"));
128
129    if has_deletions && has_additions && changes.is_empty() {
130        changes.push("refactors code");
131    } else if has_deletions && !has_additions {
132        changes.push("removes code");
133    }
134
135    changes
136}
137
138/// Calculate relevance score for a file (0.0 - 1.0)
139/// Higher score = more important for commit message
140#[allow(clippy::case_sensitive_file_extension_comparisons)]
141fn calculate_relevance_score(file: &StagedFile) -> (f32, Vec<&'static str>) {
142    let mut score: f32 = 0.5; // Base score
143    let mut reasons = Vec::new();
144    let path = file.path.to_lowercase();
145
146    // Change type scoring
147    match file.change_type {
148        ChangeType::Added => {
149            score += 0.15;
150            reasons.push("new file");
151        }
152        ChangeType::Modified => {
153            score += 0.1;
154        }
155        ChangeType::Deleted => {
156            score += 0.05;
157            reasons.push("deleted");
158        }
159    }
160
161    // File type scoring - source code is most important
162    if path.ends_with(".rs")
163        || path.ends_with(".py")
164        || path.ends_with(".ts")
165        || path.ends_with(".tsx")
166        || path.ends_with(".js")
167        || path.ends_with(".jsx")
168        || path.ends_with(".go")
169        || path.ends_with(".java")
170        || path.ends_with(".kt")
171        || path.ends_with(".swift")
172        || path.ends_with(".c")
173        || path.ends_with(".cpp")
174        || path.ends_with(".h")
175    {
176        score += 0.15;
177        reasons.push("source code");
178    } else if path.ends_with(".toml")
179        || path.ends_with(".json")
180        || path.ends_with(".yaml")
181        || path.ends_with(".yml")
182    {
183        score += 0.1;
184        reasons.push("config");
185    } else if path.ends_with(".md") || path.ends_with(".txt") || path.ends_with(".rst") {
186        score += 0.02;
187        reasons.push("docs");
188    }
189
190    // Path-based scoring
191    if path.contains("/src/") || path.starts_with("src/") {
192        score += 0.1;
193        reasons.push("core source");
194    }
195    if path.contains("/test") || path.contains("_test.") || path.contains(".test.") {
196        score -= 0.1;
197        reasons.push("test file");
198    }
199    if path.contains("generated") || path.contains(".lock") || path.contains("package-lock") {
200        score -= 0.2;
201        reasons.push("generated/lock");
202    }
203    if path.contains("/vendor/") || path.contains("/node_modules/") {
204        score -= 0.3;
205        reasons.push("vendored");
206    }
207
208    // Diff size scoring (estimate from diff length)
209    let diff_lines = file.diff.lines().count();
210    if diff_lines > 10 && diff_lines < 200 {
211        score += 0.1;
212        reasons.push("substantive changes");
213    } else if diff_lines >= 200 {
214        score += 0.05;
215        reasons.push("large diff");
216    }
217
218    // Add semantic change detection
219    let semantic_changes = detect_semantic_changes(&file.diff, &file.path);
220    for change in semantic_changes {
221        if !reasons.contains(&change) {
222            // Boost score for structural changes
223            if change == "adds function" || change == "adds type" || change == "adds impl" {
224                score += 0.1;
225            }
226            reasons.push(change);
227        }
228    }
229
230    // Clamp to 0.0-1.0
231    score = score.clamp(0.0, 1.0);
232
233    (score, reasons)
234}
235
236/// Scored file for output
237struct ScoredFile<'a> {
238    file: &'a StagedFile,
239    score: f32,
240    reasons: Vec<&'static str>,
241}
242
243/// Build the diff output string from scored files
244fn format_diff_output(
245    scored_files: &[ScoredFile],
246    total_files: usize,
247    is_filtered: bool,
248    include_diffs: bool,
249) -> String {
250    let mut output = String::new();
251    let showing = scored_files.len();
252
253    // Calculate stats
254    let additions: usize = scored_files
255        .iter()
256        .map(|sf| sf.file.diff.lines().filter(|l| l.starts_with('+')).count())
257        .sum();
258    let deletions: usize = scored_files
259        .iter()
260        .map(|sf| sf.file.diff.lines().filter(|l| l.starts_with('-')).count())
261        .sum();
262    let total_lines = additions + deletions;
263
264    // Categorize size
265    let (size, guidance) = if is_filtered {
266        ("Filtered", "Showing requested files only.")
267    } else if total_files <= 3 && total_lines < 100 {
268        ("Small", "Focus on all files equally.")
269    } else if total_files <= 10 && total_lines < 500 {
270        ("Medium", "Prioritize files with >60% relevance.")
271    } else {
272        (
273            "Large",
274            "Use files=['path1','path2'] with detail='standard' to analyze specific files.",
275        )
276    };
277
278    // Header
279    let files_info = if is_filtered {
280        format!("{showing} of {total_files} files")
281    } else {
282        format!("{total_files} files")
283    };
284    output.push_str(&format!(
285        "=== CHANGES SUMMARY ===\n{files_info} | +{additions} -{deletions} | Size: {size} ({total_lines} lines)\nGuidance: {guidance}\n\n"
286    ));
287
288    // File list
289    output.push_str("Files by importance:\n");
290    for sf in scored_files {
291        let reasons = if sf.reasons.is_empty() {
292            String::new()
293        } else {
294            format!(" ({})", sf.reasons.join(", "))
295        };
296        output.push_str(&format!(
297            "  [{:.0}%] {:?} {}{reasons}\n",
298            sf.score * 100.0,
299            sf.file.change_type,
300            sf.file.path
301        ));
302    }
303    output.push('\n');
304
305    // Diffs or hint
306    if include_diffs {
307        output.push_str("=== DIFFS ===\n");
308        for sf in scored_files {
309            output.push_str(&format!(
310                "--- {} [{:.0}% relevance]\n",
311                sf.file.path,
312                sf.score * 100.0
313            ));
314            output.push_str(&sf.file.diff);
315            output.push('\n');
316        }
317    } else if is_filtered {
318        output.push_str("(Use detail='standard' to see full diffs for these files)\n");
319    } else {
320        output.push_str(
321            "(Use detail='standard' with files=['file1','file2'] to see specific diffs)\n",
322        );
323    }
324
325    output
326}
327
328// Git status tool
329#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct GitStatus;
331
332#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
333pub struct GitStatusArgs {
334    #[serde(default)]
335    pub include_unstaged: bool,
336}
337
338impl Tool for GitStatus {
339    const NAME: &'static str = "git_status";
340    type Error = GitError;
341    type Args = GitStatusArgs;
342    type Output = String;
343
344    async fn definition(&self, _: String) -> ToolDefinition {
345        ToolDefinition {
346            name: "git_status".to_string(),
347            description: "Get current Git repository status including staged and unstaged files"
348                .to_string(),
349            parameters: parameters_schema::<GitStatusArgs>(),
350        }
351    }
352
353    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
354        let repo = get_current_repo().map_err(GitError::from)?;
355
356        let files_info = repo
357            .extract_files_info(args.include_unstaged)
358            .map_err(GitError::from)?;
359
360        let mut output = String::new();
361        output.push_str(&format!("Branch: {}\n", files_info.branch));
362        output.push_str(&format!(
363            "Files changed: {}\n",
364            files_info.staged_files.len()
365        ));
366
367        for file in &files_info.staged_files {
368            output.push_str(&format!("  {}: {:?}\n", file.path, file.change_type));
369        }
370
371        Ok(output)
372    }
373}
374
375// Git diff tool
376#[derive(Debug, Clone, Serialize, Deserialize)]
377pub struct GitDiff;
378
379/// Detail level for diff output
380#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema, Default)]
381#[serde(rename_all = "lowercase")]
382pub enum DetailLevel {
383    /// Summary only: file list with stats and relevance scores, no diffs (default)
384    #[default]
385    Summary,
386    /// Standard: includes full diffs (use with `files` filter for large changesets)
387    Standard,
388}
389
390#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
391pub struct GitDiffArgs {
392    /// Use "staged" or omit for staged changes, or specify commit/branch
393    #[serde(default)]
394    pub from: Option<String>,
395    /// Target commit/branch (use with from)
396    #[serde(default)]
397    pub to: Option<String>,
398    /// Detail level: "summary" (default) for overview, "standard" for full diffs
399    #[serde(default)]
400    pub detail: DetailLevel,
401    /// Filter to specific files (use with detail="standard" for targeted analysis)
402    #[serde(default)]
403    pub files: Option<Vec<String>>,
404}
405
406impl Tool for GitDiff {
407    const NAME: &'static str = "git_diff";
408    type Error = GitError;
409    type Args = GitDiffArgs;
410    type Output = String;
411
412    async fn definition(&self, _: String) -> ToolDefinition {
413        ToolDefinition {
414            name: "git_diff".to_string(),
415            description: "Get Git diff for file changes. Returns summary by default (file list with relevance scores). Use detail='standard' with files=['path1','path2'] to get full diffs for specific files. Progressive approach: call once for summary, then again with files filter for important ones.".to_string(),
416            parameters: parameters_schema::<GitDiffArgs>(),
417        }
418    }
419
420    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
421        let repo = get_current_repo().map_err(GitError::from)?;
422
423        // Normalize empty strings to None (LLMs often send "" instead of null)
424        let from = args.from.filter(|s| !s.is_empty());
425        let to = args.to.filter(|s| !s.is_empty());
426
427        // Handle the case where we want staged changes
428        // - No args: get staged changes
429        // - from="staged": get staged changes
430        // - Otherwise: get commit range
431        let files = match (from.as_deref(), to.as_deref()) {
432            (None | Some("staged"), None) | (Some("staged"), Some("HEAD")) => {
433                // Get staged changes
434                let files_info = repo.extract_files_info(false).map_err(GitError::from)?;
435                files_info.staged_files
436            }
437            (Some(from), Some(to)) => {
438                // Get changes between two commits/branches
439                repo.get_commit_range_files(from, to)
440                    .map_err(GitError::from)?
441            }
442            (None, Some(_)) => {
443                // Invalid: to without from
444                return Err(GitError(
445                    "Cannot specify 'to' without 'from'. Use both or neither.".to_string(),
446                ));
447            }
448            (Some(from), None) => {
449                // Get changes from a specific commit to HEAD (already handled "staged" above)
450                repo.get_commit_range_files(from, "HEAD")
451                    .map_err(GitError::from)?
452            }
453        };
454
455        // Score and sort files by relevance
456        let mut scored_files: Vec<ScoredFile> = files
457            .iter()
458            .map(|file| {
459                let (score, reasons) = calculate_relevance_score(file);
460                ScoredFile {
461                    file,
462                    score,
463                    reasons,
464                }
465            })
466            .collect();
467
468        // Sort by score descending (most important first)
469        scored_files.sort_by(|a, b| {
470            b.score
471                .partial_cmp(&a.score)
472                .unwrap_or(std::cmp::Ordering::Equal)
473        });
474
475        // Track total before filtering
476        let total_files = scored_files.len();
477
478        // Filter to specific files if requested
479        let is_filtered = args.files.is_some();
480        if let Some(ref filter) = args.files {
481            scored_files.retain(|sf| filter.iter().any(|f| sf.file.path.contains(f)));
482        }
483
484        // Build output
485        let include_diffs = matches!(args.detail, DetailLevel::Standard);
486        Ok(format_diff_output(
487            &scored_files,
488            total_files,
489            is_filtered,
490            include_diffs,
491        ))
492    }
493}
494
495// Git log tool
496#[derive(Debug, Clone, Serialize, Deserialize)]
497pub struct GitLog;
498
499#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
500pub struct GitLogArgs {
501    #[serde(default)]
502    pub count: Option<usize>,
503}
504
505impl Tool for GitLog {
506    const NAME: &'static str = "git_log";
507    type Error = GitError;
508    type Args = GitLogArgs;
509    type Output = String;
510
511    async fn definition(&self, _: String) -> ToolDefinition {
512        ToolDefinition {
513            name: "git_log".to_string(),
514            description: "Get Git commit history".to_string(),
515            parameters: parameters_schema::<GitLogArgs>(),
516        }
517    }
518
519    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
520        let repo = get_current_repo().map_err(GitError::from)?;
521
522        let commits = repo
523            .get_recent_commits(args.count.unwrap_or(10))
524            .map_err(GitError::from)?;
525
526        let mut output = String::new();
527        output.push_str("Recent commits:\n");
528
529        for commit in commits {
530            output.push_str(&format!(
531                "{}: {} ({})\n",
532                commit.hash, commit.message, commit.author
533            ));
534        }
535
536        Ok(output)
537    }
538}
539
540// Git repository info tool
541#[derive(Debug, Clone, Serialize, Deserialize)]
542pub struct GitRepoInfo;
543
544#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
545pub struct GitRepoInfoArgs {}
546
547impl Tool for GitRepoInfo {
548    const NAME: &'static str = "git_repo_info";
549    type Error = GitError;
550    type Args = GitRepoInfoArgs;
551    type Output = String;
552
553    async fn definition(&self, _: String) -> ToolDefinition {
554        ToolDefinition {
555            name: "git_repo_info".to_string(),
556            description: "Get general information about the Git repository".to_string(),
557            parameters: parameters_schema::<GitRepoInfoArgs>(),
558        }
559    }
560
561    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
562        let repo = get_current_repo().map_err(GitError::from)?;
563
564        let branch = repo.get_current_branch().map_err(GitError::from)?;
565        let remote_url = repo.get_remote_url().unwrap_or("None").to_string();
566
567        let mut output = String::new();
568        output.push_str("Repository Information:\n");
569        output.push_str(&format!("Current Branch: {branch}\n"));
570        output.push_str(&format!("Remote URL: {remote_url}\n"));
571        output.push_str(&format!(
572            "Repository Path: {}\n",
573            repo.repo_path().display()
574        ));
575
576        Ok(output)
577    }
578}
579
580// Git changed files tool
581#[derive(Debug, Clone, Serialize, Deserialize)]
582pub struct GitChangedFiles;
583
584#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
585pub struct GitChangedFilesArgs {
586    #[serde(default)]
587    pub from: Option<String>,
588    #[serde(default)]
589    pub to: Option<String>,
590}
591
592impl Tool for GitChangedFiles {
593    const NAME: &'static str = "git_changed_files";
594    type Error = GitError;
595    type Args = GitChangedFilesArgs;
596    type Output = String;
597
598    async fn definition(&self, _: String) -> ToolDefinition {
599        ToolDefinition {
600            name: "git_changed_files".to_string(),
601            description: "Get list of files that have changed between commits or branches"
602                .to_string(),
603            parameters: parameters_schema::<GitChangedFilesArgs>(),
604        }
605    }
606
607    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
608        let repo = get_current_repo().map_err(GitError::from)?;
609
610        // Normalize empty strings to None (LLMs often send "" instead of null)
611        let from = args.from.filter(|s| !s.is_empty());
612        let mut to = args.to.filter(|s| !s.is_empty());
613
614        // Default to HEAD when the caller provides only a starting point.
615        if from.is_some() && to.is_none() {
616            to = Some("HEAD".to_string());
617        }
618
619        let files = match (from, to) {
620            (Some(from), Some(to)) => {
621                // When both from and to are provided, get files changed between commits/branches
622                let range_files = repo
623                    .get_commit_range_files(&from, &to)
624                    .map_err(GitError::from)?;
625                range_files.iter().map(|f| f.path.clone()).collect()
626            }
627            (None, Some(to)) => {
628                // When only to is provided, get files changed in that single commit
629                repo.get_file_paths_for_commit(&to)
630                    .map_err(GitError::from)?
631            }
632            (Some(_from), None) => {
633                // Invalid: from without to doesn't make sense for file listing
634                return Err(GitError(
635                    "Cannot specify 'from' without 'to' for file listing".to_string(),
636                ));
637            }
638            (None, None) => {
639                // When neither are provided, get staged files
640                let files_info = repo.extract_files_info(false).map_err(GitError::from)?;
641                files_info.file_paths
642            }
643        };
644
645        let mut output = String::new();
646        output.push_str("Changed files:\n");
647
648        for file in files {
649            output.push_str(&format!("  {file}\n"));
650        }
651
652        Ok(output)
653    }
654}