Skip to main content

brainwires_tool_system/
git.rs

1use anyhow::Result;
2use git2::Repository;
3use serde::Deserialize;
4use serde_json::{Value, json};
5use std::collections::HashMap;
6
7use brainwires_core::{Tool, ToolContext, ToolInputSchema, ToolResult};
8
/// Git operations tool: advertises the `git_*` tool definitions via
/// [`GitTool::get_tools`] and runs them via [`GitTool::execute`].
///
/// The type carries no state; it only namespaces the associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct GitTool;
11
12impl GitTool {
13    /// Get all git tool definitions
14    pub fn get_tools() -> Vec<Tool> {
15        vec![
16            Self::git_status_tool(),
17            Self::git_diff_tool(),
18            Self::git_log_tool(),
19            Self::git_stage_tool(),
20            Self::git_unstage_tool(),
21            Self::git_commit_tool(),
22            Self::git_push_tool(),
23            Self::git_pull_tool(),
24            Self::git_fetch_tool(),
25            Self::git_discard_tool(),
26            Self::git_branch_tool(),
27        ]
28    }
29
30    fn git_status_tool() -> Tool {
31        Tool {
32            name: "git_status".to_string(),
33            description: "Get git repository status".to_string(),
34            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
35            requires_approval: false,
36            ..Default::default()
37        }
38    }
39
40    fn git_diff_tool() -> Tool {
41        Tool {
42            name: "git_diff".to_string(),
43            description: "Get git diff of changes".to_string(),
44            input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
45            requires_approval: false,
46            ..Default::default()
47        }
48    }
49
50    fn git_log_tool() -> Tool {
51        let mut properties = HashMap::new();
52        properties.insert(
53            "limit".to_string(),
54            json!({"type": "number", "description": "Number of commits", "default": 10}),
55        );
56        Tool {
57            name: "git_log".to_string(),
58            description: "Get git commit history".to_string(),
59            input_schema: ToolInputSchema::object(properties, vec![]),
60            requires_approval: false,
61            ..Default::default()
62        }
63    }
64
65    fn git_stage_tool() -> Tool {
66        let mut properties = HashMap::new();
67        properties.insert("files".to_string(), json!({"type": "array", "items": {"type": "string"}, "description": "Files to stage. Use '.' for all."}));
68        Tool {
69            name: "git_stage".to_string(),
70            description: "Stage files for commit.".to_string(),
71            input_schema: ToolInputSchema::object(properties, vec!["files".to_string()]),
72            requires_approval: true,
73            ..Default::default()
74        }
75    }
76
77    fn git_unstage_tool() -> Tool {
78        let mut properties = HashMap::new();
79        properties.insert("files".to_string(), json!({"type": "array", "items": {"type": "string"}, "description": "Files to unstage."}));
80        Tool {
81            name: "git_unstage".to_string(),
82            description: "Unstage files from the staging area.".to_string(),
83            input_schema: ToolInputSchema::object(properties, vec!["files".to_string()]),
84            requires_approval: true,
85            ..Default::default()
86        }
87    }
88
89    fn git_commit_tool() -> Tool {
90        let mut properties = HashMap::new();
91        properties.insert(
92            "message".to_string(),
93            json!({"type": "string", "description": "Commit message"}),
94        );
95        properties.insert("all".to_string(), json!({"type": "boolean", "description": "Stage all modified files before committing", "default": false}));
96        Tool {
97            name: "git_commit".to_string(),
98            description: "Create a git commit with staged changes.".to_string(),
99            input_schema: ToolInputSchema::object(properties, vec!["message".to_string()]),
100            requires_approval: true,
101            ..Default::default()
102        }
103    }
104
105    fn git_push_tool() -> Tool {
106        let mut properties = HashMap::new();
107        properties.insert("remote".to_string(), json!({"type": "string", "description": "Remote name (default: origin)", "default": "origin"}));
108        properties.insert(
109            "branch".to_string(),
110            json!({"type": "string", "description": "Branch to push"}),
111        );
112        properties.insert("set_upstream".to_string(), json!({"type": "boolean", "description": "Set upstream tracking (-u)", "default": false}));
113        Tool {
114            name: "git_push".to_string(),
115            description: "Push commits to a remote repository.".to_string(),
116            input_schema: ToolInputSchema::object(properties, vec![]),
117            requires_approval: true,
118            ..Default::default()
119        }
120    }
121
122    fn git_pull_tool() -> Tool {
123        let mut properties = HashMap::new();
124        properties.insert("remote".to_string(), json!({"type": "string", "description": "Remote name (default: origin)", "default": "origin"}));
125        properties.insert(
126            "branch".to_string(),
127            json!({"type": "string", "description": "Branch to pull"}),
128        );
129        properties.insert("rebase".to_string(), json!({"type": "boolean", "description": "Use rebase instead of merge", "default": false}));
130        Tool {
131            name: "git_pull".to_string(),
132            description: "Pull changes from a remote repository.".to_string(),
133            input_schema: ToolInputSchema::object(properties, vec![]),
134            requires_approval: true,
135            ..Default::default()
136        }
137    }
138
139    fn git_fetch_tool() -> Tool {
140        let mut properties = HashMap::new();
141        properties.insert("remote".to_string(), json!({"type": "string", "description": "Remote name (default: origin)", "default": "origin"}));
142        properties.insert(
143            "all".to_string(),
144            json!({"type": "boolean", "description": "Fetch all remotes", "default": false}),
145        );
146        properties.insert("prune".to_string(), json!({"type": "boolean", "description": "Remove stale remote-tracking refs", "default": false}));
147        Tool {
148            name: "git_fetch".to_string(),
149            description: "Fetch changes from a remote without merging.".to_string(),
150            input_schema: ToolInputSchema::object(properties, vec![]),
151            requires_approval: false,
152            ..Default::default()
153        }
154    }
155
156    fn git_discard_tool() -> Tool {
157        let mut properties = HashMap::new();
158        properties.insert("files".to_string(), json!({"type": "array", "items": {"type": "string"}, "description": "Files to discard changes for."}));
159        Tool {
160            name: "git_discard".to_string(),
161            description: "Discard uncommitted changes. WARNING: Permanent!".to_string(),
162            input_schema: ToolInputSchema::object(properties, vec!["files".to_string()]),
163            requires_approval: true,
164            ..Default::default()
165        }
166    }
167
168    fn git_branch_tool() -> Tool {
169        let mut properties = HashMap::new();
170        properties.insert(
171            "name".to_string(),
172            json!({"type": "string", "description": "Branch name"}),
173        );
174        properties.insert("action".to_string(), json!({"type": "string", "enum": ["list", "create", "switch", "delete"], "description": "Action to perform", "default": "list"}));
175        properties.insert(
176            "force".to_string(),
177            json!({"type": "boolean", "description": "Force the action", "default": false}),
178        );
179        Tool {
180            name: "git_branch".to_string(),
181            description: "Manage git branches: list, create, switch, or delete.".to_string(),
182            input_schema: ToolInputSchema::object(properties, vec![]),
183            requires_approval: true,
184            ..Default::default()
185        }
186    }
187
188    /// Execute a git tool
189    #[tracing::instrument(name = "tool.execute", skip(input, context), fields(tool_name))]
190    pub fn execute(
191        tool_use_id: &str,
192        tool_name: &str,
193        input: &Value,
194        context: &ToolContext,
195    ) -> ToolResult {
196        let result = match tool_name {
197            "git_status" => Self::git_status(context),
198            "git_diff" => Self::git_diff(context),
199            "git_log" => Self::git_log(input, context),
200            "git_stage" => Self::git_stage(input, context),
201            "git_unstage" => Self::git_unstage(input, context),
202            "git_commit" => Self::git_commit(input, context),
203            "git_push" => Self::git_push(input, context),
204            "git_pull" => Self::git_pull(input, context),
205            "git_fetch" => Self::git_fetch(input, context),
206            "git_discard" => Self::git_discard(input, context),
207            "git_branch" => Self::git_branch(input, context),
208            _ => Err(anyhow::anyhow!("Unknown git tool: {}", tool_name)),
209        };
210        match result {
211            Ok(output) => ToolResult::success(tool_use_id.to_string(), output),
212            Err(e) => ToolResult::error(
213                tool_use_id.to_string(),
214                format!("Git operation failed: {}", e),
215            ),
216        }
217    }
218
219    fn git_status(context: &ToolContext) -> Result<String> {
220        let repo = Repository::discover(&context.working_directory)?;
221        let statuses = repo.statuses(None)?;
222        let mut output = String::from("Git Status:\n\n");
223        for entry in statuses.iter() {
224            let path = entry.path().unwrap_or("?");
225            let status = entry.status();
226            output.push_str(&format!("{:?} - {}\n", status, path));
227        }
228        Ok(output)
229    }
230
231    fn git_diff(context: &ToolContext) -> Result<String> {
232        let repo = Repository::discover(&context.working_directory)?;
233        let head = repo.head()?.peel_to_tree()?;
234        let diff = repo.diff_tree_to_workdir_with_index(Some(&head), None)?;
235        Ok(format!("Git Diff:\n{} files changed", diff.deltas().len()))
236    }
237
238    fn git_log(input: &Value, context: &ToolContext) -> Result<String> {
239        #[derive(Deserialize)]
240        struct Input {
241            #[serde(default = "default_limit")]
242            limit: usize,
243        }
244        fn default_limit() -> usize {
245            10
246        }
247        let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input { limit: 10 });
248        let repo = Repository::discover(&context.working_directory)?;
249        let mut revwalk = repo.revwalk()?;
250        revwalk.push_head()?;
251        let mut output = String::from("Git Log:\n\n");
252        for (i, oid) in revwalk.enumerate() {
253            if i >= params.limit {
254                break;
255            }
256            let commit = repo.find_commit(oid?)?;
257            output.push_str(&format!(
258                "{} - {}\n",
259                commit.id(),
260                commit.summary().unwrap_or("No message")
261            ));
262        }
263        Ok(output)
264    }
265
266    fn git_stage(input: &Value, context: &ToolContext) -> Result<String> {
267        #[derive(Deserialize)]
268        struct Input {
269            files: Vec<String>,
270        }
271        let params: Input = serde_json::from_value(input.clone())?;
272        let mut cmd = std::process::Command::new("git");
273        cmd.current_dir(&context.working_directory).arg("add");
274        for file in &params.files {
275            cmd.arg(file);
276        }
277        let output = cmd.output()?;
278        if output.status.success() {
279            Ok(format!(
280                "Successfully staged {} file(s)",
281                params.files.len()
282            ))
283        } else {
284            Err(anyhow::anyhow!(
285                "Failed to stage files: {}",
286                String::from_utf8_lossy(&output.stderr)
287            ))
288        }
289    }
290
291    fn git_unstage(input: &Value, context: &ToolContext) -> Result<String> {
292        #[derive(Deserialize)]
293        struct Input {
294            files: Vec<String>,
295        }
296        let params: Input = serde_json::from_value(input.clone())?;
297        let mut cmd = std::process::Command::new("git");
298        cmd.current_dir(&context.working_directory)
299            .args(["reset", "HEAD", "--"]);
300        for file in &params.files {
301            cmd.arg(file);
302        }
303        let output = cmd.output()?;
304        if output.status.success() {
305            Ok(format!(
306                "Successfully unstaged {} file(s)",
307                params.files.len()
308            ))
309        } else {
310            Err(anyhow::anyhow!(
311                "Failed to unstage files: {}",
312                String::from_utf8_lossy(&output.stderr)
313            ))
314        }
315    }
316
317    fn git_commit(input: &Value, context: &ToolContext) -> Result<String> {
318        #[derive(Deserialize)]
319        struct Input {
320            message: String,
321            #[serde(default)]
322            all: bool,
323        }
324        let params: Input = serde_json::from_value(input.clone())?;
325        let mut cmd = std::process::Command::new("git");
326        cmd.current_dir(&context.working_directory).arg("commit");
327        if params.all {
328            cmd.arg("-a");
329        }
330        cmd.args(["-m", &params.message]);
331        let output = cmd.output()?;
332        if output.status.success() {
333            Ok(format!(
334                "Commit successful:\n{}",
335                String::from_utf8_lossy(&output.stdout)
336            ))
337        } else {
338            Err(anyhow::anyhow!(
339                "Commit failed: {}",
340                String::from_utf8_lossy(&output.stderr)
341            ))
342        }
343    }
344
345    fn git_push(input: &Value, context: &ToolContext) -> Result<String> {
346        #[derive(Deserialize)]
347        struct Input {
348            #[serde(default = "dr")]
349            remote: String,
350            branch: Option<String>,
351            #[serde(default)]
352            set_upstream: bool,
353        }
354        fn dr() -> String {
355            "origin".to_string()
356        }
357        let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
358            remote: "origin".to_string(),
359            branch: None,
360            set_upstream: false,
361        });
362        let mut cmd = std::process::Command::new("git");
363        cmd.current_dir(&context.working_directory).arg("push");
364        if params.set_upstream {
365            cmd.arg("-u");
366        }
367        cmd.arg(&params.remote);
368        if let Some(ref branch) = params.branch {
369            cmd.arg(branch);
370        }
371        let output = cmd.output()?;
372        if output.status.success() {
373            Ok(format!(
374                "Push successful:\n{}{}",
375                String::from_utf8_lossy(&output.stdout),
376                String::from_utf8_lossy(&output.stderr)
377            ))
378        } else {
379            Err(anyhow::anyhow!(
380                "Push failed: {}",
381                String::from_utf8_lossy(&output.stderr)
382            ))
383        }
384    }
385
386    fn git_pull(input: &Value, context: &ToolContext) -> Result<String> {
387        #[derive(Deserialize)]
388        struct Input {
389            #[serde(default = "dr")]
390            remote: String,
391            branch: Option<String>,
392            #[serde(default)]
393            rebase: bool,
394        }
395        fn dr() -> String {
396            "origin".to_string()
397        }
398        let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
399            remote: "origin".to_string(),
400            branch: None,
401            rebase: false,
402        });
403        let mut cmd = std::process::Command::new("git");
404        cmd.current_dir(&context.working_directory).arg("pull");
405        if params.rebase {
406            cmd.arg("--rebase");
407        }
408        cmd.arg(&params.remote);
409        if let Some(ref branch) = params.branch {
410            cmd.arg(branch);
411        }
412        let output = cmd.output()?;
413        if output.status.success() {
414            Ok(format!(
415                "Pull successful:\n{}",
416                String::from_utf8_lossy(&output.stdout)
417            ))
418        } else {
419            Err(anyhow::anyhow!(
420                "Pull failed: {}",
421                String::from_utf8_lossy(&output.stderr)
422            ))
423        }
424    }
425
426    fn git_fetch(input: &Value, context: &ToolContext) -> Result<String> {
427        #[derive(Deserialize)]
428        struct Input {
429            #[serde(default = "dr")]
430            remote: String,
431            #[serde(default)]
432            all: bool,
433            #[serde(default)]
434            prune: bool,
435        }
436        fn dr() -> String {
437            "origin".to_string()
438        }
439        let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
440            remote: "origin".to_string(),
441            all: false,
442            prune: false,
443        });
444        let mut cmd = std::process::Command::new("git");
445        cmd.current_dir(&context.working_directory).arg("fetch");
446        if params.all {
447            cmd.arg("--all");
448        } else {
449            cmd.arg(&params.remote);
450        }
451        if params.prune {
452            cmd.arg("--prune");
453        }
454        let output = cmd.output()?;
455        if output.status.success() {
456            let stdout = String::from_utf8_lossy(&output.stdout);
457            let stderr = String::from_utf8_lossy(&output.stderr);
458            let fetch_output = if stdout.is_empty() && stderr.is_empty() {
459                "Already up to date.".to_string()
460            } else {
461                format!("{}{}", stdout, stderr)
462            };
463            Ok(format!("Fetch successful:\n{}", fetch_output))
464        } else {
465            Err(anyhow::anyhow!(
466                "Fetch failed: {}",
467                String::from_utf8_lossy(&output.stderr)
468            ))
469        }
470    }
471
472    fn git_discard(input: &Value, context: &ToolContext) -> Result<String> {
473        #[derive(Deserialize)]
474        struct Input {
475            files: Vec<String>,
476        }
477        let params: Input = serde_json::from_value(input.clone())?;
478        let mut cmd = std::process::Command::new("git");
479        cmd.current_dir(&context.working_directory)
480            .args(["checkout", "--"]);
481        for file in &params.files {
482            cmd.arg(file);
483        }
484        let output = cmd.output()?;
485        if output.status.success() {
486            Ok(format!(
487                "Successfully discarded changes to {} file(s)",
488                params.files.len()
489            ))
490        } else {
491            Err(anyhow::anyhow!(
492                "Failed to discard changes: {}",
493                String::from_utf8_lossy(&output.stderr)
494            ))
495        }
496    }
497
498    fn git_branch(input: &Value, context: &ToolContext) -> Result<String> {
499        #[derive(Deserialize)]
500        struct Input {
501            name: Option<String>,
502            #[serde(default = "da")]
503            action: String,
504            #[serde(default)]
505            force: bool,
506        }
507        fn da() -> String {
508            "list".to_string()
509        }
510        let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
511            name: None,
512            action: "list".to_string(),
513            force: false,
514        });
515        let branch_name = params.name.clone();
516        let mut cmd = std::process::Command::new("git");
517        cmd.current_dir(&context.working_directory);
518        match params.action.as_str() {
519            "list" => {
520                cmd.args(["branch", "-a", "-v"]);
521            }
522            "create" => {
523                let n = params
524                    .name
525                    .ok_or_else(|| anyhow::anyhow!("Branch name required"))?;
526                cmd.args(["branch", &n]);
527            }
528            "switch" => {
529                let n = params
530                    .name
531                    .ok_or_else(|| anyhow::anyhow!("Branch name required"))?;
532                cmd.args(["checkout", &n]);
533            }
534            "delete" => {
535                let n = params
536                    .name
537                    .ok_or_else(|| anyhow::anyhow!("Branch name required"))?;
538                if params.force {
539                    cmd.args(["branch", "-D", &n]);
540                } else {
541                    cmd.args(["branch", "-d", &n]);
542                }
543            }
544            _ => return Err(anyhow::anyhow!("Unknown branch action: {}", params.action)),
545        }
546        let output = cmd.output()?;
547        if output.status.success() {
548            let stdout = String::from_utf8_lossy(&output.stdout);
549            Ok(match params.action.as_str() {
550                "list" => format!("Branches:\n{}", stdout),
551                "create" => format!("Created branch '{}'", branch_name.unwrap_or_default()),
552                "switch" => format!("Switched to branch '{}'", branch_name.unwrap_or_default()),
553                "delete" => format!("Deleted branch '{}'", branch_name.unwrap_or_default()),
554                _ => stdout.to_string(),
555            })
556        } else {
557            Err(anyhow::anyhow!(
558                "Branch operation failed: {}",
559                String::from_utf8_lossy(&output.stderr)
560            ))
561        }
562    }
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ToolContext` whose working directory is the test process's cwd.
    fn create_test_context() -> ToolContext {
        let cwd = std::env::current_dir().unwrap();
        let working_directory = cwd.to_str().unwrap().to_owned();
        ToolContext {
            working_directory,
            ..Default::default()
        }
    }

    /// All eleven tools are advertised, including a representative sample by name.
    #[test]
    fn test_get_tools() {
        let tools = GitTool::get_tools();
        assert_eq!(tools.len(), 11);
        for expected in ["git_status", "git_commit", "git_branch"] {
            assert!(tools.iter().any(|tool| tool.name == expected));
        }
    }

    /// An unrecognised tool name yields an error result instead of panicking.
    #[test]
    fn test_execute_unknown_tool() {
        let context = create_test_context();
        let outcome = GitTool::execute("1", "unknown_tool", &json!({}), &context);
        assert!(outcome.is_error);
    }
}