use async_trait::async_trait;
use rucora_core::{
error::ToolError,
tool::{Tool, ToolCategory},
};
use serde_json::{Value, json};
use std::path::{Path, PathBuf};
/// Git subcommands the tool accepts; the requested command is lower-cased before it is
/// looked up here.
///
/// The order is user-visible: the whole list is printed in the "unsupported command" error.
const ALLOWED_COMMANDS: &[&str] = &[
    "status", "log", "diff", "show", "branch", "remote", "config", "rev-parse", "add", "commit",
    "checkout", "stash", "reset", "revert", "merge", "rebase", "pull", "push", "fetch", "clone",
    "init", "describe", "tag",
];
/// Subcommands treated as write operations: rejected when writes are disabled and reported
/// through `is_write_operation` in the tool output.
///
/// `config` is listed because it rewrites repository configuration, which can point git at
/// arbitrary programs (e.g. `core.fsmonitor`, `core.sshCommand`, `core.hooksPath`) that
/// later "read-only" commands such as `status` or `fetch` would then execute.
///
/// NOTE(review): `branch`, `tag` and `remote` can also modify refs/remotes (e.g.
/// `branch -D`) but are kept readable so listing still works; tightening that needs
/// argument-aware checks.
const WRITE_COMMANDS: &[&str] = &[
    "add", "commit", "checkout", "stash", "reset", "revert", "merge", "rebase", "pull", "push",
    "fetch", "clone", "init", "config",
];
/// Option prefixes that are never passed through to git, because they make git run
/// arbitrary programs, write to arbitrary files, or bypass hooks and signing.
///
/// Entries are matched as prefixes of each argument, so attached values (`--exec=...`,
/// `-ckey=value`) are rejected as well.
///
/// NOTE(review): short aliases that share a spelling with common options cannot be listed
/// without collateral blocking, e.g. `git clone -u` (alias of `--upload-pack`) vs.
/// `git add -u` / `git push -u`.
const FORBIDDEN_ARGS: &[&str] = &[
    "--exec",
    "--upload-pack",
    "--receive-pack",
    "--pager",
    "--editor",
    "--no-verify",
    "--no-gpg-sign",
    // `git clone -c` / `--config` set arbitrary configuration (e.g. `core.sshCommand`).
    "-c",
    "--config",
    // `git clone` / `git init`: copy hooks from an arbitrary template directory.
    "--template",
    // diff-family commands (`diff`, `log`, `show`): write output to an arbitrary file.
    "--output",
    // `git rebase -x <cmd>`: runs `<cmd>` after each rebased commit.
    "-x",
];
/// Tool that runs a whitelisted set of `git` subcommands with argument, path and
/// write-permission checks.
///
/// Build it with [`GitTool::new`] and narrow it with [`GitTool::with_allowed_roots`] and
/// [`GitTool::with_allow_write`].
#[derive(Debug, Clone)]
pub struct GitTool {
    /// When `Some`, the repository directory must lie under one of these roots.
    allowed_roots: Option<Vec<PathBuf>>,
    /// Whether subcommands listed in `WRITE_COMMANDS` may run.
    allow_write: bool,
}
impl GitTool {
    /// Creates a tool with no repository-path restriction and write commands enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            allowed_roots: None,
            allow_write: true,
        }
    }

    /// Restricts the tool to repositories located under one of `roots`.
    ///
    /// Roots are resolved (symlinks, `..`) when a path is checked; a root that cannot be
    /// resolved is compared as given.
    #[must_use]
    pub fn with_allowed_roots(mut self, roots: Vec<PathBuf>) -> Self {
        self.allowed_roots = Some(roots);
        self
    }

    /// Enables (`true`) or disables (`false`) the subcommands listed in `WRITE_COMMANDS`.
    #[must_use]
    pub fn with_allow_write(mut self, allow: bool) -> Self {
        self.allow_write = allow;
        self
    }

    /// Returns `true` if the (already lower-cased) `command` is a write operation.
    fn is_write_command(&self, command: &str) -> bool {
        WRITE_COMMANDS.contains(&command)
    }

    /// Checks `command`, case-insensitively, against the whitelist and the write policy.
    ///
    /// # Errors
    /// Returns [`ToolError::Message`] if the command is not whitelisted, or if it is a
    /// write command while writes are disabled.
    fn validate_command(&self, command: &str) -> Result<(), ToolError> {
        let cmd_lower = command.to_lowercase();
        if !ALLOWED_COMMANDS.contains(&cmd_lower.as_str()) {
            return Err(ToolError::Message(format!(
                "不支持的 Git 命令:{command}(允许的命令:{ALLOWED_COMMANDS:?})"
            )));
        }
        if self.is_write_command(&cmd_lower) && !self.allow_write {
            return Err(ToolError::Message(format!(
                "Git 写入操作已被禁用:{command}"
            )));
        }
        Ok(())
    }

    /// Resolves `path` to a canonical directory and checks it against the sandbox rules.
    ///
    /// Relative paths are resolved against the process's current working directory.
    ///
    /// # Errors
    /// Returns [`ToolError::Message`] when the path cannot be resolved (e.g. it does not
    /// exist), when it lies outside every configured allowed root, or when it is (inside)
    /// a system-sensitive directory.
    fn validate_path(&self, path: &str) -> Result<PathBuf, ToolError> {
        let requested = Path::new(path);
        let absolute = if requested.is_absolute() {
            requested.to_path_buf()
        } else {
            std::env::current_dir()
                .unwrap_or_else(|_| PathBuf::from("."))
                .join(requested)
        };
        // Resolve `..` and symlinks before any containment check. An unresolvable path is
        // rejected: falling back to the raw path would let lexical `..` components pass
        // the component-wise `starts_with` check below, and git could not use it as a
        // working directory anyway.
        let canonical_path = absolute.canonicalize().map_err(|e| {
            ToolError::Message(format!("无法解析 Git 仓库路径:{path}({e})"))
        })?;

        if let Some(allowed_roots) = &self.allowed_roots {
            // Resolve the roots too, so a root configured through a symlink (or, on
            // Windows, without the `\\?\` verbatim prefix) still matches.
            let is_allowed = allowed_roots.iter().any(|root| {
                let root = root.canonicalize().unwrap_or_else(|_| root.clone());
                canonical_path.starts_with(root)
            });
            if !is_allowed {
                return Err(ToolError::Message(format!(
                    "Git 仓库路径不在允许的范围内(允许的根目录:{allowed_roots:?})"
                )));
            }
        }

        if Self::is_sensitive_system_path(&canonical_path) {
            return Err(ToolError::Message(format!(
                "禁止在系统敏感路径执行 Git 操作:{}",
                canonical_path.display()
            )));
        }
        Ok(canonical_path)
    }

    /// Returns `true` if the canonical `path` is, or lies under, a protected system
    /// directory.
    fn is_sensitive_system_path(path: &Path) -> bool {
        const SENSITIVE_PREFIXES: [&str; 9] = [
            "/etc/",
            "/proc/",
            "/sys/",
            "/dev/",
            "/boot/",
            "/bin/",
            "/sbin/",
            "c:\\windows\\",
            "c:\\program files",
        ];
        let lowered = path.to_string_lossy().to_lowercase();
        // `canonicalize` yields verbatim paths such as `\\?\C:\Windows` on Windows.
        let normalized = lowered.strip_prefix(r"\\?\").unwrap_or(&lowered);
        SENSITIVE_PREFIXES.iter().any(|&prefix| {
            // Match the directory itself (`/etc`) as well as anything beneath it.
            normalized.starts_with(prefix)
                || normalized == prefix.trim_end_matches(|c: char| c == '/' || c == '\\')
        })
    }

    /// Validates user-supplied arguments and returns them unchanged if all are acceptable.
    ///
    /// Arguments are passed to git directly (no shell), but some git options spawn
    /// programs of their own, so each argument is rejected when it:
    /// 1. is an option from `FORBIDDEN_ARGS`, including attached values (`--exec=...`,
    ///    `-ckey=value`) and unique-prefix abbreviations of long options
    ///    (`--upload-pac=...`), which git's option parser accepts. As a side effect,
    ///    `--edit` is rejected as an abbreviation of `--editor`;
    /// 2. contains shell metacharacters or line breaks;
    /// 3. contains a `..` path component, including inside an option value (`--foo=../x`).
    ///    Revision ranges such as `main..dev` or `a...b` are unaffected.
    ///
    /// Option matching is case-sensitive, as git's is: `-C` (e.g. `git log -C`) is a
    /// different option from `-c`.
    ///
    /// # Errors
    /// Returns [`ToolError::Message`] naming the first offending argument.
    fn sanitize_args(&self, args: &[String]) -> Result<Vec<String>, ToolError> {
        for arg in args {
            if Self::is_forbidden_option(arg) {
                return Err(ToolError::Message(format!(
                    "禁止使用 Git 参数:{arg}(存在安全风险)"
                )));
            }
            if Self::has_injection_chars(arg) {
                return Err(ToolError::Message(format!(
                    "参数包含危险字符,可能存在注入风险:{arg}"
                )));
            }
            if Self::has_parent_component(arg) {
                return Err(ToolError::Message(format!("参数包含路径遍历:{arg}")));
            }
        }
        Ok(args.to_vec())
    }

    /// Returns `true` if `arg` is, starts with, or abbreviates an entry of `FORBIDDEN_ARGS`.
    fn is_forbidden_option(arg: &str) -> bool {
        // Option name without an attached `=value`.
        let name = arg.split_once('=').map_or(arg, |(name, _)| name);
        // A bare `--` (end-of-options marker) prefixes every long option; never treat it
        // as an abbreviation.
        let may_abbreviate = name.len() > 2 && name.starts_with("--");
        FORBIDDEN_ARGS.iter().copied().any(|forbidden| {
            arg.starts_with(forbidden) || (may_abbreviate && forbidden.starts_with(name))
        })
    }

    /// Returns `true` if `arg` contains shell metacharacters or line breaks.
    fn has_injection_chars(arg: &str) -> bool {
        // A lone `|` already covers `||`.
        const PATTERNS: [&str; 9] = ["$(", "`", "|", ";", "&&", ">", "<", "\n", "\r"];
        PATTERNS.iter().any(|&pattern| arg.contains(pattern))
    }

    /// Returns `true` if any `/`-, `\`- or `=`-separated component of `arg` is exactly `..`.
    fn has_parent_component(arg: &str) -> bool {
        arg.split(|c: char| matches!(c, '/' | '\\' | '='))
            .any(|part| part == "..")
    }
}
impl Default for GitTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Tool for GitTool {
fn name(&self) -> &str {
"git"
}
fn description(&self) -> Option<&str> {
Some("执行 Git 操作(有安全限制:命令白名单、参数检查、路径限制)")
}
fn categories(&self) -> &'static [ToolCategory] {
&[ToolCategory::System]
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Git 命令,如 status、log、add、commit"
},
"args": {
"type": "array",
"items": {
"type": "string"
},
"description": "命令参数列表"
},
"path": {
"type": "string",
"description": "Git 仓库路径,默认为当前目录",
"default": "."
}
},
"required": ["command"]
})
}
async fn call(&self, input: Value) -> Result<Value, ToolError> {
let command = input
.get("command")
.and_then(|v| v.as_str())
.ok_or_else(|| ToolError::Message("缺少必需的 'command' 字段".to_string()))?;
self.validate_command(command)?;
let path_str = input.get("path").and_then(|v| v.as_str()).unwrap_or(".");
let work_dir = self.validate_path(path_str)?;
let args_vec: Vec<String> = input
.get("args")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(String::from))
.collect()
})
.unwrap_or_default();
let sanitized_args = self.sanitize_args(&args_vec)?;
let output = tokio::process::Command::new("git")
.arg(command)
.args(&sanitized_args)
.current_dir(&work_dir)
.env_clear()
.env("PATH", std::env::var("PATH").unwrap_or_default())
.env("HOME", std::env::var("HOME").unwrap_or_default())
.env(
"USERPROFILE",
std::env::var("USERPROFILE").unwrap_or_default(),
)
.output()
.await
.map_err(|e| ToolError::Message(format!("Git 命令执行失败:{e}")))?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
let is_write = self.is_write_command(&command.to_lowercase());
Ok(json!({
"command": command,
"args": sanitized_args,
"work_dir": work_dir.display().to_string(),
"stdout": stdout,
"stderr": stderr,
"success": output.status.success(),
"is_write_operation": is_write
}))
}
}