use async_trait::async_trait;
use serde_json::Value;
use rucora_core::error::ToolError;
use rucora_core::tool::types::ToolCall;
/// Context handed to a [`ToolPolicy`] when it evaluates a single tool invocation.
#[derive(Debug, Clone)]
pub struct ToolCallContext {
/// The tool call being evaluated (policies in this file read its `name` and `input`).
pub tool_call: ToolCall,
}
/// Gate that is consulted with a [`ToolCallContext`] and answers with `Ok(())` or a
/// [`ToolError`].
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ToolPolicy: Send + Sync {
/// Evaluates the tool call described by `ctx`.
///
/// The implementations in this file return `Ok(())` when the call is acceptable and
/// `ToolError::PolicyDenied` when it is rejected.
async fn check(&self, ctx: &ToolCallContext) -> Result<(), ToolError>;
}
/// A [`ToolPolicy`] that accepts every tool call without inspecting it.
#[derive(Debug, Default, Clone)]
pub struct AllowAllToolPolicy;
#[async_trait]
impl ToolPolicy for AllowAllToolPolicy {
/// Always returns `Ok(())`; the context is ignored.
async fn check(&self, _ctx: &ToolCallContext) -> Result<(), ToolError> {
Ok(())
}
}
/// Allow- and deny-lists of command names for one command-running tool.
///
/// Both lists start out empty and are filled through the chainable
/// [`allow_command`](Self::allow_command) / [`deny_command`](Self::deny_command)
/// builders. Entries are stored exactly as given; interpretation is left to the
/// policy that consumes the config.
#[derive(Clone, Debug, Default)]
pub struct CommandPolicyConfig {
    /// Command names that are explicitly permitted.
    pub allowed_commands: Vec<String>,
    /// Command names that are explicitly rejected.
    pub denied_commands: Vec<String>,
}

impl CommandPolicyConfig {
    /// Creates a config with an empty allowlist and an empty denylist.
    pub fn new() -> Self {
        Self {
            allowed_commands: Vec::new(),
            denied_commands: Vec::new(),
        }
    }

    /// Appends `cmd` to the allowlist and returns the updated config.
    pub fn allow_command(self, cmd: impl Into<String>) -> Self {
        let Self {
            mut allowed_commands,
            denied_commands,
        } = self;
        allowed_commands.push(cmd.into());
        Self {
            allowed_commands,
            denied_commands,
        }
    }

    /// Appends `cmd` to the denylist and returns the updated config.
    pub fn deny_command(self, cmd: impl Into<String>) -> Self {
        let Self {
            allowed_commands,
            mut denied_commands,
        } = self;
        denied_commands.push(cmd.into());
        Self {
            allowed_commands,
            denied_commands,
        }
    }
}
/// Built-in [`ToolPolicy`] that inspects the command line of `shell` and `cmd_exec`
/// tool calls, each against its own [`CommandPolicyConfig`].
#[derive(Debug, Clone)]
pub struct DefaultToolPolicy {
/// Rules applied to calls of the `shell` tool.
shell: CommandPolicyConfig,
/// Rules applied to calls of the `cmd_exec` tool.
cmd_exec: CommandPolicyConfig,
}
impl Default for DefaultToolPolicy {
fn default() -> Self {
Self::new()
}
}
impl DefaultToolPolicy {
/// Builds the stock policy: `shell` starts with empty allow/deny lists, while
/// `cmd_exec` starts with an allowlist containing only `curl`.
pub fn new() -> Self {
    let shell = CommandPolicyConfig::new();
    let cmd_exec = CommandPolicyConfig::new().allow_command("curl");
    Self { shell, cmd_exec }
}
pub fn with_shell_config(mut self, cfg: CommandPolicyConfig) -> Self {
self.shell = cfg;
self
}
pub fn with_cmd_exec_config(mut self, cfg: CommandPolicyConfig) -> Self {
self.cmd_exec = cfg;
self
}
fn extract_command_line(tool_name: &str, input: &Value) -> Option<String> {
match tool_name {
"shell" => {
let command = input.get("command")?.as_str()?.trim().to_string();
let args = input
.get("args")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|x| x.as_str())
.map(|s| s.to_string())
.collect::<Vec<_>>()
})
.unwrap_or_default();
if args.is_empty() {
Some(command)
} else {
Some(format!("{} {}", command, args.join(" ")))
}
}
"cmd_exec" => Some(input.get("command")?.as_str()?.trim().to_string()),
_ => None,
}
}
/// Returns the normalised program name at the start of `command_line`.
///
/// The first whitespace-delimited token is taken, surrounding `"` and `'`
/// characters are stripped, the result is ASCII-lowercased, and a trailing `.exe`
/// is removed.
///
/// Lowercasing happens *before* the suffix is removed so that `RM.EXE` and
/// `Del.Exe` normalise to `rm` / `del` rather than `rm.exe` / `del.exe`.
///
/// Returns `None` when the line is empty or whitespace-only, or when nothing is
/// left of the first token after normalisation (for example `""` or `.exe`).
fn first_token(command_line: &str) -> Option<String> {
    let raw = command_line.split_whitespace().next()?;
    let unquoted = raw.trim_matches('"').trim_matches('\'');
    let lowered = unquoted.to_ascii_lowercase();
    let name = lowered.trim_end_matches(".exe");
    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
/// Reports whether `cmd` is one of the program names blocked by default.
///
/// The comparison is exact and case-sensitive, so callers are expected to pass an
/// already-normalised (lowercase) name.
fn is_dangerous_command(cmd: &str) -> bool {
    const BLOCKED: &[&str] = &[
        "rm", "del", "erase", "rmdir", "rd", "format", "mkfs", "dd", "shutdown", "reboot",
        "poweroff", "reg", "diskpart", "bcdedit", "sc", "net",
    ];
    BLOCKED.iter().any(|&name| name == cmd)
}
/// Reports whether `command_line` contains any forbidden substring: `|`, `&&`,
/// `;`, `>`, `<`, a backtick, `$(`, a line feed, or a carriage return.
///
/// This is a plain substring search; quoting is not taken into account.
// NOTE(review): a lone `&` is not in this list; only `&&` is matched.
fn contains_shell_operators(command_line: &str) -> bool {
    const FORBIDDEN: [&str; 9] = ["|", "&&", ";", ">", "<", "`", "$(", "\n", "\r"];
    for op in FORBIDDEN.iter() {
        if command_line.contains(*op) {
            return true;
        }
    }
    false
}
/// Applies the command rules in `cfg` to `command_line`.
///
/// Checks run in this order; the first failure is returned as
/// `ToolError::PolicyDenied`:
///
/// 1. `default.shell_operators`: the line contains a forbidden shell operator.
/// 2. `default.empty_command`: no program name could be extracted.
/// 3. `config.denied_command`: the program is on `cfg.denied_commands`.
/// 4. `default.dangerous_command`: the program is on the built-in block list.
/// 5. `config.not_allowed`: `cfg.allowed_commands` is non-empty and does not
///    contain the program.
///
/// Steps 3 and 4 match both the full first token and its basename (the part after
/// the last `/` or `\`), so a path-qualified invocation such as `/bin/rm` is
/// treated like `rm`. Step 5 still compares the full token only, so adding the
/// basename match can never make the policy more permissive.
///
/// List entries are compared ASCII-case-insensitively.
fn check_command(cfg: &CommandPolicyConfig, command_line: &str) -> Result<(), ToolError> {
    if Self::contains_shell_operators(command_line) {
        return Err(ToolError::PolicyDenied {
            rule_id: "default.shell_operators".to_string(),
            reason: "command contains forbidden shell operators".to_string(),
        });
    }

    let cmd = Self::first_token(command_line).ok_or_else(|| ToolError::PolicyDenied {
        rule_id: "default.empty_command".to_string(),
        reason: "empty command".to_string(),
    })?;

    // Program name without any directory prefix; equals `cmd` when there is none.
    let base = cmd
        .rsplit(|c: char| c == '/' || c == '\\')
        .next()
        .unwrap_or(cmd.as_str());

    let is_denied = cfg
        .denied_commands
        .iter()
        .any(|x| x.eq_ignore_ascii_case(&cmd) || x.eq_ignore_ascii_case(base));
    if is_denied {
        return Err(ToolError::PolicyDenied {
            rule_id: "config.denied_command".to_string(),
            reason: format!("command '{cmd}' is denied"),
        });
    }

    if Self::is_dangerous_command(&cmd) || Self::is_dangerous_command(base) {
        return Err(ToolError::PolicyDenied {
            rule_id: "default.dangerous_command".to_string(),
            reason: format!("dangerous command '{cmd}' is blocked by default"),
        });
    }

    // An empty allowlist means "no allowlist restriction".
    if !cfg.allowed_commands.is_empty()
        && !cfg
            .allowed_commands
            .iter()
            .any(|x| x.eq_ignore_ascii_case(&cmd))
    {
        return Err(ToolError::PolicyDenied {
            rule_id: "config.not_allowed".to_string(),
            reason: format!("command '{cmd}' is not in allowlist"),
        });
    }

    Ok(())
}
}
#[async_trait]
impl ToolPolicy for DefaultToolPolicy {
async fn check(&self, ctx: &ToolCallContext) -> Result<(), ToolError> {
let name = ctx.tool_call.name.as_str();
let input = &ctx.tool_call.input;
let Some(command_line) = Self::extract_command_line(name, input) else {
return Ok(());
};
match name {
"shell" => Self::check_command(&self.shell, &command_line),
"cmd_exec" => Self::check_command(&self.cmd_exec, &command_line),
_ => Ok(()),
}
}
}