1use anyhow::Result;
2use git2::Repository;
3use serde::Deserialize;
4use serde_json::{Value, json};
5use std::collections::HashMap;
6
7use brainwires_core::{Tool, ToolContext, ToolInputSchema, ToolResult};
8
/// Stateless namespace for the git tools exposed to the agent.
///
/// [`GitTool::get_tools`] returns the tool schemas, and [`GitTool::execute`]
/// dispatches a single tool call by name.
pub struct GitTool;
11
12impl GitTool {
13 pub fn get_tools() -> Vec<Tool> {
15 vec![
16 Self::git_status_tool(),
17 Self::git_diff_tool(),
18 Self::git_log_tool(),
19 Self::git_stage_tool(),
20 Self::git_unstage_tool(),
21 Self::git_commit_tool(),
22 Self::git_push_tool(),
23 Self::git_pull_tool(),
24 Self::git_fetch_tool(),
25 Self::git_discard_tool(),
26 Self::git_branch_tool(),
27 ]
28 }
29
30 fn git_status_tool() -> Tool {
31 Tool {
32 name: "git_status".to_string(),
33 description: "Get git repository status".to_string(),
34 input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
35 requires_approval: false,
36 ..Default::default()
37 }
38 }
39
40 fn git_diff_tool() -> Tool {
41 Tool {
42 name: "git_diff".to_string(),
43 description: "Get git diff of changes".to_string(),
44 input_schema: ToolInputSchema::object(HashMap::new(), vec![]),
45 requires_approval: false,
46 ..Default::default()
47 }
48 }
49
50 fn git_log_tool() -> Tool {
51 let mut properties = HashMap::new();
52 properties.insert(
53 "limit".to_string(),
54 json!({"type": "number", "description": "Number of commits", "default": 10}),
55 );
56 Tool {
57 name: "git_log".to_string(),
58 description: "Get git commit history".to_string(),
59 input_schema: ToolInputSchema::object(properties, vec![]),
60 requires_approval: false,
61 ..Default::default()
62 }
63 }
64
65 fn git_stage_tool() -> Tool {
66 let mut properties = HashMap::new();
67 properties.insert("files".to_string(), json!({"type": "array", "items": {"type": "string"}, "description": "Files to stage. Use '.' for all."}));
68 Tool {
69 name: "git_stage".to_string(),
70 description: "Stage files for commit.".to_string(),
71 input_schema: ToolInputSchema::object(properties, vec!["files".to_string()]),
72 requires_approval: true,
73 ..Default::default()
74 }
75 }
76
77 fn git_unstage_tool() -> Tool {
78 let mut properties = HashMap::new();
79 properties.insert("files".to_string(), json!({"type": "array", "items": {"type": "string"}, "description": "Files to unstage."}));
80 Tool {
81 name: "git_unstage".to_string(),
82 description: "Unstage files from the staging area.".to_string(),
83 input_schema: ToolInputSchema::object(properties, vec!["files".to_string()]),
84 requires_approval: true,
85 ..Default::default()
86 }
87 }
88
89 fn git_commit_tool() -> Tool {
90 let mut properties = HashMap::new();
91 properties.insert(
92 "message".to_string(),
93 json!({"type": "string", "description": "Commit message"}),
94 );
95 properties.insert("all".to_string(), json!({"type": "boolean", "description": "Stage all modified files before committing", "default": false}));
96 Tool {
97 name: "git_commit".to_string(),
98 description: "Create a git commit with staged changes.".to_string(),
99 input_schema: ToolInputSchema::object(properties, vec!["message".to_string()]),
100 requires_approval: true,
101 ..Default::default()
102 }
103 }
104
105 fn git_push_tool() -> Tool {
106 let mut properties = HashMap::new();
107 properties.insert("remote".to_string(), json!({"type": "string", "description": "Remote name (default: origin)", "default": "origin"}));
108 properties.insert(
109 "branch".to_string(),
110 json!({"type": "string", "description": "Branch to push"}),
111 );
112 properties.insert("set_upstream".to_string(), json!({"type": "boolean", "description": "Set upstream tracking (-u)", "default": false}));
113 Tool {
114 name: "git_push".to_string(),
115 description: "Push commits to a remote repository.".to_string(),
116 input_schema: ToolInputSchema::object(properties, vec![]),
117 requires_approval: true,
118 ..Default::default()
119 }
120 }
121
122 fn git_pull_tool() -> Tool {
123 let mut properties = HashMap::new();
124 properties.insert("remote".to_string(), json!({"type": "string", "description": "Remote name (default: origin)", "default": "origin"}));
125 properties.insert(
126 "branch".to_string(),
127 json!({"type": "string", "description": "Branch to pull"}),
128 );
129 properties.insert("rebase".to_string(), json!({"type": "boolean", "description": "Use rebase instead of merge", "default": false}));
130 Tool {
131 name: "git_pull".to_string(),
132 description: "Pull changes from a remote repository.".to_string(),
133 input_schema: ToolInputSchema::object(properties, vec![]),
134 requires_approval: true,
135 ..Default::default()
136 }
137 }
138
139 fn git_fetch_tool() -> Tool {
140 let mut properties = HashMap::new();
141 properties.insert("remote".to_string(), json!({"type": "string", "description": "Remote name (default: origin)", "default": "origin"}));
142 properties.insert(
143 "all".to_string(),
144 json!({"type": "boolean", "description": "Fetch all remotes", "default": false}),
145 );
146 properties.insert("prune".to_string(), json!({"type": "boolean", "description": "Remove stale remote-tracking refs", "default": false}));
147 Tool {
148 name: "git_fetch".to_string(),
149 description: "Fetch changes from a remote without merging.".to_string(),
150 input_schema: ToolInputSchema::object(properties, vec![]),
151 requires_approval: false,
152 ..Default::default()
153 }
154 }
155
156 fn git_discard_tool() -> Tool {
157 let mut properties = HashMap::new();
158 properties.insert("files".to_string(), json!({"type": "array", "items": {"type": "string"}, "description": "Files to discard changes for."}));
159 Tool {
160 name: "git_discard".to_string(),
161 description: "Discard uncommitted changes. WARNING: Permanent!".to_string(),
162 input_schema: ToolInputSchema::object(properties, vec!["files".to_string()]),
163 requires_approval: true,
164 ..Default::default()
165 }
166 }
167
168 fn git_branch_tool() -> Tool {
169 let mut properties = HashMap::new();
170 properties.insert(
171 "name".to_string(),
172 json!({"type": "string", "description": "Branch name"}),
173 );
174 properties.insert("action".to_string(), json!({"type": "string", "enum": ["list", "create", "switch", "delete"], "description": "Action to perform", "default": "list"}));
175 properties.insert(
176 "force".to_string(),
177 json!({"type": "boolean", "description": "Force the action", "default": false}),
178 );
179 Tool {
180 name: "git_branch".to_string(),
181 description: "Manage git branches: list, create, switch, or delete.".to_string(),
182 input_schema: ToolInputSchema::object(properties, vec![]),
183 requires_approval: true,
184 ..Default::default()
185 }
186 }
187
188 #[tracing::instrument(name = "tool.execute", skip(input, context), fields(tool_name))]
190 pub fn execute(
191 tool_use_id: &str,
192 tool_name: &str,
193 input: &Value,
194 context: &ToolContext,
195 ) -> ToolResult {
196 let result = match tool_name {
197 "git_status" => Self::git_status(context),
198 "git_diff" => Self::git_diff(context),
199 "git_log" => Self::git_log(input, context),
200 "git_stage" => Self::git_stage(input, context),
201 "git_unstage" => Self::git_unstage(input, context),
202 "git_commit" => Self::git_commit(input, context),
203 "git_push" => Self::git_push(input, context),
204 "git_pull" => Self::git_pull(input, context),
205 "git_fetch" => Self::git_fetch(input, context),
206 "git_discard" => Self::git_discard(input, context),
207 "git_branch" => Self::git_branch(input, context),
208 _ => Err(anyhow::anyhow!("Unknown git tool: {}", tool_name)),
209 };
210 match result {
211 Ok(output) => ToolResult::success(tool_use_id.to_string(), output),
212 Err(e) => ToolResult::error(
213 tool_use_id.to_string(),
214 format!("Git operation failed: {}", e),
215 ),
216 }
217 }
218
219 fn git_status(context: &ToolContext) -> Result<String> {
220 let repo = Repository::discover(&context.working_directory)?;
221 let statuses = repo.statuses(None)?;
222 let mut output = String::from("Git Status:\n\n");
223 for entry in statuses.iter() {
224 let path = entry.path().unwrap_or("?");
225 let status = entry.status();
226 output.push_str(&format!("{:?} - {}\n", status, path));
227 }
228 Ok(output)
229 }
230
231 fn git_diff(context: &ToolContext) -> Result<String> {
232 let repo = Repository::discover(&context.working_directory)?;
233 let head = repo.head()?.peel_to_tree()?;
234 let diff = repo.diff_tree_to_workdir_with_index(Some(&head), None)?;
235 Ok(format!("Git Diff:\n{} files changed", diff.deltas().len()))
236 }
237
238 fn git_log(input: &Value, context: &ToolContext) -> Result<String> {
239 #[derive(Deserialize)]
240 struct Input {
241 #[serde(default = "default_limit")]
242 limit: usize,
243 }
244 fn default_limit() -> usize {
245 10
246 }
247 let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input { limit: 10 });
248 let repo = Repository::discover(&context.working_directory)?;
249 let mut revwalk = repo.revwalk()?;
250 revwalk.push_head()?;
251 let mut output = String::from("Git Log:\n\n");
252 for (i, oid) in revwalk.enumerate() {
253 if i >= params.limit {
254 break;
255 }
256 let commit = repo.find_commit(oid?)?;
257 output.push_str(&format!(
258 "{} - {}\n",
259 commit.id(),
260 commit.summary().unwrap_or("No message")
261 ));
262 }
263 Ok(output)
264 }
265
266 fn git_stage(input: &Value, context: &ToolContext) -> Result<String> {
267 #[derive(Deserialize)]
268 struct Input {
269 files: Vec<String>,
270 }
271 let params: Input = serde_json::from_value(input.clone())?;
272 let mut cmd = std::process::Command::new("git");
273 cmd.current_dir(&context.working_directory).arg("add");
274 for file in ¶ms.files {
275 cmd.arg(file);
276 }
277 let output = cmd.output()?;
278 if output.status.success() {
279 Ok(format!(
280 "Successfully staged {} file(s)",
281 params.files.len()
282 ))
283 } else {
284 Err(anyhow::anyhow!(
285 "Failed to stage files: {}",
286 String::from_utf8_lossy(&output.stderr)
287 ))
288 }
289 }
290
291 fn git_unstage(input: &Value, context: &ToolContext) -> Result<String> {
292 #[derive(Deserialize)]
293 struct Input {
294 files: Vec<String>,
295 }
296 let params: Input = serde_json::from_value(input.clone())?;
297 let mut cmd = std::process::Command::new("git");
298 cmd.current_dir(&context.working_directory)
299 .args(["reset", "HEAD", "--"]);
300 for file in ¶ms.files {
301 cmd.arg(file);
302 }
303 let output = cmd.output()?;
304 if output.status.success() {
305 Ok(format!(
306 "Successfully unstaged {} file(s)",
307 params.files.len()
308 ))
309 } else {
310 Err(anyhow::anyhow!(
311 "Failed to unstage files: {}",
312 String::from_utf8_lossy(&output.stderr)
313 ))
314 }
315 }
316
317 fn git_commit(input: &Value, context: &ToolContext) -> Result<String> {
318 #[derive(Deserialize)]
319 struct Input {
320 message: String,
321 #[serde(default)]
322 all: bool,
323 }
324 let params: Input = serde_json::from_value(input.clone())?;
325 let mut cmd = std::process::Command::new("git");
326 cmd.current_dir(&context.working_directory).arg("commit");
327 if params.all {
328 cmd.arg("-a");
329 }
330 cmd.args(["-m", ¶ms.message]);
331 let output = cmd.output()?;
332 if output.status.success() {
333 Ok(format!(
334 "Commit successful:\n{}",
335 String::from_utf8_lossy(&output.stdout)
336 ))
337 } else {
338 Err(anyhow::anyhow!(
339 "Commit failed: {}",
340 String::from_utf8_lossy(&output.stderr)
341 ))
342 }
343 }
344
345 fn git_push(input: &Value, context: &ToolContext) -> Result<String> {
346 #[derive(Deserialize)]
347 struct Input {
348 #[serde(default = "dr")]
349 remote: String,
350 branch: Option<String>,
351 #[serde(default)]
352 set_upstream: bool,
353 }
354 fn dr() -> String {
355 "origin".to_string()
356 }
357 let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
358 remote: "origin".to_string(),
359 branch: None,
360 set_upstream: false,
361 });
362 let mut cmd = std::process::Command::new("git");
363 cmd.current_dir(&context.working_directory).arg("push");
364 if params.set_upstream {
365 cmd.arg("-u");
366 }
367 cmd.arg(¶ms.remote);
368 if let Some(ref branch) = params.branch {
369 cmd.arg(branch);
370 }
371 let output = cmd.output()?;
372 if output.status.success() {
373 Ok(format!(
374 "Push successful:\n{}{}",
375 String::from_utf8_lossy(&output.stdout),
376 String::from_utf8_lossy(&output.stderr)
377 ))
378 } else {
379 Err(anyhow::anyhow!(
380 "Push failed: {}",
381 String::from_utf8_lossy(&output.stderr)
382 ))
383 }
384 }
385
386 fn git_pull(input: &Value, context: &ToolContext) -> Result<String> {
387 #[derive(Deserialize)]
388 struct Input {
389 #[serde(default = "dr")]
390 remote: String,
391 branch: Option<String>,
392 #[serde(default)]
393 rebase: bool,
394 }
395 fn dr() -> String {
396 "origin".to_string()
397 }
398 let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
399 remote: "origin".to_string(),
400 branch: None,
401 rebase: false,
402 });
403 let mut cmd = std::process::Command::new("git");
404 cmd.current_dir(&context.working_directory).arg("pull");
405 if params.rebase {
406 cmd.arg("--rebase");
407 }
408 cmd.arg(¶ms.remote);
409 if let Some(ref branch) = params.branch {
410 cmd.arg(branch);
411 }
412 let output = cmd.output()?;
413 if output.status.success() {
414 Ok(format!(
415 "Pull successful:\n{}",
416 String::from_utf8_lossy(&output.stdout)
417 ))
418 } else {
419 Err(anyhow::anyhow!(
420 "Pull failed: {}",
421 String::from_utf8_lossy(&output.stderr)
422 ))
423 }
424 }
425
426 fn git_fetch(input: &Value, context: &ToolContext) -> Result<String> {
427 #[derive(Deserialize)]
428 struct Input {
429 #[serde(default = "dr")]
430 remote: String,
431 #[serde(default)]
432 all: bool,
433 #[serde(default)]
434 prune: bool,
435 }
436 fn dr() -> String {
437 "origin".to_string()
438 }
439 let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
440 remote: "origin".to_string(),
441 all: false,
442 prune: false,
443 });
444 let mut cmd = std::process::Command::new("git");
445 cmd.current_dir(&context.working_directory).arg("fetch");
446 if params.all {
447 cmd.arg("--all");
448 } else {
449 cmd.arg(¶ms.remote);
450 }
451 if params.prune {
452 cmd.arg("--prune");
453 }
454 let output = cmd.output()?;
455 if output.status.success() {
456 let stdout = String::from_utf8_lossy(&output.stdout);
457 let stderr = String::from_utf8_lossy(&output.stderr);
458 let fetch_output = if stdout.is_empty() && stderr.is_empty() {
459 "Already up to date.".to_string()
460 } else {
461 format!("{}{}", stdout, stderr)
462 };
463 Ok(format!("Fetch successful:\n{}", fetch_output))
464 } else {
465 Err(anyhow::anyhow!(
466 "Fetch failed: {}",
467 String::from_utf8_lossy(&output.stderr)
468 ))
469 }
470 }
471
472 fn git_discard(input: &Value, context: &ToolContext) -> Result<String> {
473 #[derive(Deserialize)]
474 struct Input {
475 files: Vec<String>,
476 }
477 let params: Input = serde_json::from_value(input.clone())?;
478 let mut cmd = std::process::Command::new("git");
479 cmd.current_dir(&context.working_directory)
480 .args(["checkout", "--"]);
481 for file in ¶ms.files {
482 cmd.arg(file);
483 }
484 let output = cmd.output()?;
485 if output.status.success() {
486 Ok(format!(
487 "Successfully discarded changes to {} file(s)",
488 params.files.len()
489 ))
490 } else {
491 Err(anyhow::anyhow!(
492 "Failed to discard changes: {}",
493 String::from_utf8_lossy(&output.stderr)
494 ))
495 }
496 }
497
498 fn git_branch(input: &Value, context: &ToolContext) -> Result<String> {
499 #[derive(Deserialize)]
500 struct Input {
501 name: Option<String>,
502 #[serde(default = "da")]
503 action: String,
504 #[serde(default)]
505 force: bool,
506 }
507 fn da() -> String {
508 "list".to_string()
509 }
510 let params: Input = serde_json::from_value(input.clone()).unwrap_or(Input {
511 name: None,
512 action: "list".to_string(),
513 force: false,
514 });
515 let branch_name = params.name.clone();
516 let mut cmd = std::process::Command::new("git");
517 cmd.current_dir(&context.working_directory);
518 match params.action.as_str() {
519 "list" => {
520 cmd.args(["branch", "-a", "-v"]);
521 }
522 "create" => {
523 let n = params
524 .name
525 .ok_or_else(|| anyhow::anyhow!("Branch name required"))?;
526 cmd.args(["branch", &n]);
527 }
528 "switch" => {
529 let n = params
530 .name
531 .ok_or_else(|| anyhow::anyhow!("Branch name required"))?;
532 cmd.args(["checkout", &n]);
533 }
534 "delete" => {
535 let n = params
536 .name
537 .ok_or_else(|| anyhow::anyhow!("Branch name required"))?;
538 if params.force {
539 cmd.args(["branch", "-D", &n]);
540 } else {
541 cmd.args(["branch", "-d", &n]);
542 }
543 }
544 _ => return Err(anyhow::anyhow!("Unknown branch action: {}", params.action)),
545 }
546 let output = cmd.output()?;
547 if output.status.success() {
548 let stdout = String::from_utf8_lossy(&output.stdout);
549 Ok(match params.action.as_str() {
550 "list" => format!("Branches:\n{}", stdout),
551 "create" => format!("Created branch '{}'", branch_name.unwrap_or_default()),
552 "switch" => format!("Switched to branch '{}'", branch_name.unwrap_or_default()),
553 "delete" => format!("Deleted branch '{}'", branch_name.unwrap_or_default()),
554 _ => stdout.to_string(),
555 })
556 } else {
557 Err(anyhow::anyhow!(
558 "Branch operation failed: {}",
559 String::from_utf8_lossy(&output.stderr)
560 ))
561 }
562 }
563}
564
565#[cfg(test)]
566mod tests {
567 use super::*;
568
569 fn create_test_context() -> ToolContext {
570 ToolContext {
571 working_directory: std::env::current_dir()
572 .unwrap()
573 .to_str()
574 .unwrap()
575 .to_string(),
576 ..Default::default()
577 }
578 }
579
580 #[test]
581 fn test_get_tools() {
582 let tools = GitTool::get_tools();
583 assert_eq!(tools.len(), 11);
584 let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
585 assert!(names.contains(&"git_status"));
586 assert!(names.contains(&"git_commit"));
587 assert!(names.contains(&"git_branch"));
588 }
589
590 #[test]
591 fn test_execute_unknown_tool() {
592 let context = create_test_context();
593 let input = json!({});
594 let result = GitTool::execute("1", "unknown_tool", &input, &context);
595 assert!(result.is_error);
596 }
597}