use crate::security::{check_rate_limit, validate_ai_prompt, validate_command};
use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::io::Read;
use std::process::{Command, Stdio};
use std::time::{Duration, Instant};
/// Time budget, in seconds, for a single step. `execute_step` compares it
/// against elapsed wall-clock time, and the Linux/macOS `configure_sandbox`
/// variants also use it as the `RLIMIT_CPU` value.
const MAX_EXECUTION_TIMEOUT_SECS: u64 = 30;
/// Upper bound (10 MiB) on the number of stdout bytes `execute_step` accepts
/// from a command before aborting it.
const MAX_OUTPUT_SIZE_BYTES: usize = 10 * 1024 * 1024;
/// One unit of work in a [`Plan`]: a tool name plus its JSON-encoded arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Step {
/// Program name; `execute_step` passes it to `Command::new`.
pub tool: String,
/// Arguments as JSON. `parse_step_args` accepts an array, object, string,
/// or null and rejects any other top-level JSON type.
pub args: serde_json::Value,
}
/// A goal together with the ordered steps produced for it by [`plan`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Plan {
/// The goal text as returned by `validate_ai_prompt` inside [`plan`].
pub goal: String,
/// Steps in the order [`execute`] runs them.
pub steps: Vec<Step>,
}
/// Builds a [`Plan`] for a natural-language `goal`.
///
/// The goal is first passed through `validate_ai_prompt`, then the
/// `"agent_plan"` rate limit (10 calls per 60 seconds) is checked, and finally
/// the validated text is decomposed into steps by keyword matching.
///
/// # Errors
///
/// Returns an error if prompt validation fails or the rate limit is exceeded.
pub fn plan(goal: &str) -> Result<Plan> {
    /// Maximum number of goal bytes echoed to the log.
    const LOG_PREVIEW_BYTES: usize = 100;

    let sanitized_goal =
        validate_ai_prompt(goal).context("Invalid goal: prompt validation failed")?;
    check_rate_limit("agent_plan", 10, Duration::from_secs(60))
        .context("Rate limit exceeded for agent planning")?;

    if sanitized_goal.len() > LOG_PREVIEW_BYTES {
        // Slicing at a fixed byte offset panics when that offset falls inside
        // a multi-byte UTF-8 character, so back up to the nearest boundary.
        let mut cut = LOG_PREVIEW_BYTES;
        while !sanitized_goal.is_char_boundary(cut) {
            cut -= 1;
        }
        eprintln!(
            "[AGENT] Creating plan for goal: {}...",
            &sanitized_goal[..cut]
        );
    } else {
        eprintln!("[AGENT] Creating plan for goal: {}", sanitized_goal);
    }

    let steps = decompose_goal(&sanitized_goal);
    Ok(Plan {
        goal: sanitized_goal,
        steps,
    })
}
/// Turns a goal sentence into steps by case-insensitive keyword matching.
///
/// Every keyword group that matches contributes its step(s), in a fixed order:
/// listing, reading, searching, git sub-commands, environment, processes and
/// downloads. When nothing matches, a single `analyze` step carrying the goal
/// text is returned instead.
fn decompose_goal(goal: &str) -> Vec<Step> {
    let lowered = goal.to_lowercase();
    let mentions = |needle: &str| lowered.contains(needle);
    let step = |tool: &str, args: serde_json::Value| Step {
        tool: tool.to_string(),
        args,
    };
    let current_dir = || ".".to_string();

    let mut planned: Vec<Step> = Vec::new();

    // Directory listing: defaults to the current directory when no path is found.
    if mentions("list") && (mentions("file") || mentions("dir")) {
        let path = extract_path_from_goal(goal).unwrap_or_else(current_dir);
        planned.push(step("ls", serde_json::json!({"path": path})));
    }

    // File read: only planned when a path can actually be extracted.
    if mentions("read") && mentions("file") {
        planned.extend(
            extract_path_from_goal(goal)
                .map(|path| step("cat", serde_json::json!({"path": path}))),
        );
    }

    // Search: falls back to pattern "*" and path ".".
    if mentions("find") || mentions("search") {
        let pattern = extract_pattern_from_goal(goal).unwrap_or_else(|| "*".to_string());
        let path = extract_path_from_goal(goal).unwrap_or_else(current_dir);
        planned.push(step(
            "find",
            serde_json::json!({"path": path, "pattern": pattern}),
        ));
    }

    // Git: one step per mentioned sub-command, always in this order.
    if mentions("git") {
        for subcommand in ["status", "commit", "push", "pull"] {
            if mentions(subcommand) {
                planned.push(step("git", serde_json::json!({"command": subcommand})));
            }
        }
    }

    if mentions("system") || mentions("environment") {
        planned.push(step("env", serde_json::json!({})));
    }

    if mentions("process") || mentions("running") {
        planned.push(step("ps", serde_json::json!({})));
    }

    // Download: only planned when the goal contains an http(s) URL.
    if mentions("download") || mentions("fetch") {
        planned.extend(
            extract_url_from_goal(goal)
                .map(|url| step("http_get", serde_json::json!({"url": url}))),
        );
    }

    if planned.is_empty() {
        planned.push(step("analyze", serde_json::json!({"goal": goal})));
    }

    planned
}
/// Pulls a path-like token out of a goal sentence.
///
/// Candidates are tried in this order:
/// 1. the text between the first pair of double quotes,
/// 2. the text between the first pair of single quotes,
/// 3. the first whitespace-separated word that contains `/` or `\` or starts
///    with `.`, and contains none of `<`, `>`, `|`.
///
/// Returns `None` when no candidate is found.
fn extract_path_from_goal(goal: &str) -> Option<String> {
    // Text enclosed by the first two occurrences of `quote`, if both exist.
    let quoted = |quote: char| -> Option<String> {
        let (_, after_opening) = goal.split_once(quote)?;
        let (inner, _) = after_opening.split_once(quote)?;
        Some(inner.to_string())
    };

    if let Some(path) = quoted('"').or_else(|| quoted('\'')) {
        return Some(path);
    }

    goal.split_whitespace()
        .find(|token| {
            let path_like =
                token.contains('/') || token.contains('\\') || token.starts_with('.');
            let has_shell_meta = token.contains(|c: char| matches!(c, '<' | '>' | '|'));
            path_like && !has_shell_meta
        })
        .map(str::to_string)
}
/// Extracts the search pattern that follows one of the keywords `for`,
/// `named`, `matching` or `like` (each followed by a space) in `goal`.
///
/// Keywords are tried in that order and matched ASCII-case-insensitively. For
/// the first keyword that is followed by a word, that word is returned with
/// surrounding single/double quotes stripped. Returns `None` when no keyword
/// is followed by a word.
fn extract_pattern_from_goal(goal: &str) -> Option<String> {
    const KEYWORDS: [&str; 4] = ["for ", "named ", "matching ", "like "];

    // ASCII lowercasing never changes byte lengths, so an offset found in
    // `haystack` is also a valid char-boundary offset in `goal`. Full Unicode
    // lowercasing can grow or shrink characters, which made the offset point
    // at the wrong place (or into the middle of a character) in `goal`.
    let haystack = goal.to_ascii_lowercase();

    KEYWORDS.iter().find_map(|keyword| {
        let start = haystack.find(*keyword)?;
        let rest = &goal[start + keyword.len()..];
        let word = rest.split_whitespace().next()?;
        Some(word.trim_matches(|c| c == '"' || c == '\'').to_string())
    })
}
/// Returns the first whitespace-separated word of `goal` that begins with
/// `http://` or `https://`, or `None` if there is no such word.
fn extract_url_from_goal(goal: &str) -> Option<String> {
    const SCHEMES: [&str; 2] = ["http://", "https://"];

    goal.split_whitespace()
        .find(|token| SCHEMES.iter().any(|scheme| token.starts_with(*scheme)))
        .map(str::to_string)
}
pub fn execute(plan: &Plan) -> Result<()> {
validate_ai_prompt(&plan.goal).context("Invalid plan goal")?;
check_rate_limit("agent_execute", 5, Duration::from_secs(60))
.context("Rate limit exceeded for agent execution")?;
eprintln!("[AGENT] Executing plan with {} steps", plan.steps.len());
for (i, step) in plan.steps.iter().enumerate() {
eprintln!(
"[AGENT] Step {}/{}: {} {:?}",
i + 1,
plan.steps.len(),
step.tool,
step.args
);
if step.tool.is_empty() {
return Err(anyhow!("Step {}: tool name is empty", i));
}
if step.tool.len() > 100 {
return Err(anyhow!(
"Step {}: tool name too long (max 100 characters)",
i
));
}
if step.tool.contains('\0') {
return Err(anyhow!("Step {}: tool name contains null byte", i));
}
let args = parse_step_args(&step.args)
.with_context(|| format!("Step {}: failed to parse arguments", i))?;
validate_command(&step.tool, &args).with_context(|| {
format!("Step {}: command validation failed for '{}'", i, step.tool)
})?;
execute_step(&step.tool, &args).with_context(|| format!("Step {}: execution failed", i))?;
}
eprintln!("[AGENT] Plan execution completed successfully");
Ok(())
}
/// Flattens a step's JSON arguments into a list of command-line strings.
///
/// * `null` yields no arguments.
/// * A string yields itself as the only argument.
/// * An array yields one argument per element; each element must be a
///   string, number or bool.
/// * An object yields `key=value` for each entry whose value is a string,
///   number or bool; entries with any other value type are skipped.
///
/// # Errors
///
/// Returns an error for an array element of unsupported type, or when the
/// top-level value is none of the shapes listed above.
fn parse_step_args(args: &serde_json::Value) -> Result<Vec<String>> {
    /// Textual form of a JSON string/number/bool; `None` for anything else.
    fn scalar_text(value: &serde_json::Value) -> Option<String> {
        match value {
            serde_json::Value::String(text) => Some(text.clone()),
            serde_json::Value::Number(number) => Some(number.to_string()),
            serde_json::Value::Bool(flag) => Some(flag.to_string()),
            _ => None,
        }
    }

    match args {
        serde_json::Value::Null => Ok(Vec::new()),
        serde_json::Value::String(single) => Ok(vec![single.clone()]),
        serde_json::Value::Array(items) => items
            .iter()
            .enumerate()
            .map(|(position, item)| {
                scalar_text(item).ok_or_else(|| {
                    anyhow!(
                        "Argument {}: unsupported type (must be string, number, or bool)",
                        position
                    )
                })
            })
            .collect(),
        serde_json::Value::Object(fields) => Ok(fields
            .iter()
            .filter_map(|(key, value)| {
                scalar_text(value).map(|text| format!("{}={}", key, text))
            })
            .collect()),
        _ => Err(anyhow!("Arguments must be array, object, or string")),
    }
}
/// Reads `reader` until end-of-stream, retaining at most `limit` bytes.
///
/// When more than `limit` bytes arrive:
/// * with `overflow = Some(flag)`, the flag is set and reading stops
///   immediately, returning what was collected so far;
/// * with `overflow = None`, the excess is read and discarded so the writer
///   on the other end of the pipe never blocks.
fn drain_pipe<R: Read>(
    mut reader: R,
    limit: usize,
    overflow: Option<&std::sync::atomic::AtomicBool>,
) -> std::io::Result<Vec<u8>> {
    let mut collected = Vec::new();
    let mut chunk = [0u8; 4096];
    loop {
        let read = match reader.read(&mut chunk) {
            Ok(0) => return Ok(collected),
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        let room = limit.saturating_sub(collected.len());
        if read <= room {
            collected.extend_from_slice(&chunk[..read]);
            continue;
        }
        match overflow {
            Some(flag) => {
                flag.store(true, std::sync::atomic::Ordering::SeqCst);
                return Ok(collected);
            }
            // Keep whatever still fits, drop the rest, keep draining.
            None => collected.extend_from_slice(&chunk[..room]),
        }
    }
}

/// Spawns `tool` with `args`, waits for it to finish and reports the outcome.
///
/// The child runs with stdin closed and stdout/stderr piped. Both pipes are
/// drained on background threads so that
/// * the wall-clock timeout (`MAX_EXECUTION_TIMEOUT_SECS`) is enforced even
///   while the child produces no output, and
/// * a child that fills its stderr pipe cannot deadlock against us.
///
/// Stdout is limited to `MAX_OUTPUT_SIZE_BYTES`; exceeding it aborts the
/// command. Stderr is only used for the failure message and is truncated to
/// the same size.
///
/// Caveat: if a descendant of the command keeps the pipes open after the
/// command itself has exited, collecting the output waits until they close.
///
/// # Errors
///
/// Returns an error if the sandbox cannot be configured, the process cannot
/// be spawned, the timeout or output limit is exceeded (the child is killed
/// and reaped in both cases), stdout cannot be read, or the command exits
/// with a non-success status.
fn execute_step(tool: &str, args: &[String]) -> Result<()> {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// How often the child's exit status is polled.
    const POLL_INTERVAL: Duration = Duration::from_millis(10);
    /// Maximum number of stdout bytes echoed to the log.
    const PREVIEW_BYTES: usize = 200;

    /// Kills the child and waits for it so that no zombie process is left.
    fn kill_and_reap(child: &mut std::process::Child) {
        let _ = child.kill();
        let _ = child.wait();
    }

    let start_time = Instant::now();
    let timeout = Duration::from_secs(MAX_EXECUTION_TIMEOUT_SECS);

    eprintln!("[AGENT] Executing: {} {:?}", tool, args);
    eprintln!(
        "[AGENT] Sandbox: timeout={}s, max_output={}MB",
        MAX_EXECUTION_TIMEOUT_SECS,
        MAX_OUTPUT_SIZE_BYTES / (1024 * 1024)
    );

    let mut cmd = Command::new(tool);
    cmd.args(args);
    configure_sandbox(&mut cmd)?;
    cmd.stdout(Stdio::piped());
    cmd.stderr(Stdio::piped());
    cmd.stdin(Stdio::null());

    let mut child = cmd
        .spawn()
        .with_context(|| format!("Failed to spawn command: {}", tool))?;
    let stdout = child
        .stdout
        .take()
        .ok_or_else(|| anyhow!("Failed to capture stdout"))?;
    let stderr = child
        .stderr
        .take()
        .ok_or_else(|| anyhow!("Failed to capture stderr"))?;

    // Set by the stdout reader as soon as the output cap is exceeded.
    let overflow = Arc::new(AtomicBool::new(false));
    let stdout_reader = {
        let overflow = Arc::clone(&overflow);
        std::thread::spawn(move || drain_pipe(stdout, MAX_OUTPUT_SIZE_BYTES, Some(&*overflow)))
    };
    let stderr_reader =
        std::thread::spawn(move || drain_pipe(stderr, MAX_OUTPUT_SIZE_BYTES, None));

    let output_limit_exceeded = || {
        eprintln!(
            "[AGENT] ❌ Output size limit exceeded ({}MB)",
            MAX_OUTPUT_SIZE_BYTES / (1024 * 1024)
        );
        anyhow!(
            "Command output size limit exceeded ({} bytes)",
            MAX_OUTPUT_SIZE_BYTES
        )
    };

    // Poll for exit; the reader threads do the blocking I/O, so this loop
    // always gets a chance to enforce the limits.
    let status = loop {
        if overflow.load(Ordering::SeqCst) {
            kill_and_reap(&mut child);
            return Err(output_limit_exceeded());
        }
        match child.try_wait() {
            Ok(Some(status)) => break status,
            Ok(None) => {}
            Err(e) => {
                kill_and_reap(&mut child);
                return Err(anyhow!("Failed to poll command status: {}", e));
            }
        }
        if start_time.elapsed() > timeout {
            kill_and_reap(&mut child);
            eprintln!(
                "[AGENT] ❌ Timeout exceeded ({}s)",
                MAX_EXECUTION_TIMEOUT_SECS
            );
            return Err(anyhow!(
                "Command execution timeout ({} seconds exceeded)",
                MAX_EXECUTION_TIMEOUT_SECS
            ));
        }
        std::thread::sleep(POLL_INTERVAL);
    };

    let output = match stdout_reader.join() {
        Ok(Ok(bytes)) => bytes,
        Ok(Err(e)) => return Err(anyhow!("Failed to read command output: {}", e)),
        Err(_) => {
            return Err(anyhow!(
                "Failed to read command output: reader thread panicked"
            ))
        }
    };
    // Stderr is best-effort: it only decorates the failure message.
    let error = stderr_reader
        .join()
        .ok()
        .and_then(Result::ok)
        .unwrap_or_default();

    // The child may have exited between the overflow and the last poll.
    if overflow.load(Ordering::SeqCst) {
        return Err(output_limit_exceeded());
    }

    // The exit status is always checked, including when stdout reached EOF
    // before the process was observed to exit.
    if !status.success() {
        let stderr_str = String::from_utf8_lossy(&error);
        eprintln!("[AGENT] ❌ Command failed: {}", stderr_str);
        return Err(anyhow!(
            "Command failed with exit code {}: {}",
            status.code().unwrap_or(-1),
            stderr_str
        ));
    }

    let elapsed = start_time.elapsed();
    let stdout_str = String::from_utf8_lossy(&output);
    eprintln!(
        "[AGENT] ✓ Completed in {:.2}s ({} bytes output)",
        elapsed.as_secs_f64(),
        output.len()
    );
    if !stdout_str.is_empty() {
        let preview = if stdout_str.len() > PREVIEW_BYTES {
            // Back up to a char boundary; a raw byte slice could split a
            // multi-byte character and panic.
            let mut cut = PREVIEW_BYTES;
            while !stdout_str.is_char_boundary(cut) {
                cut -= 1;
            }
            format!("{}...", &stdout_str[..cut])
        } else {
            stdout_str.to_string()
        };
        eprintln!("[AGENT] Output: {}", preview.replace('\n', " "));
    }
    Ok(())
}
/// Windows variant of the sandbox hook: leaves the command untouched and only
/// logs that the timeout and output limits are the protections in effect.
#[cfg(target_os = "windows")]
fn configure_sandbox(_cmd: &mut Command) -> Result<()> {
eprintln!("[AGENT] Sandbox: Windows (timeout + output limits)");
Ok(())
}
#[cfg(target_os = "linux")]
fn configure_sandbox(cmd: &mut Command) -> Result<()> {
use std::os::unix::process::CommandExt;
unsafe {
cmd.pre_exec(|| {
use libc::{rlimit, setrlimit, RLIMIT_AS, RLIMIT_CPU, RLIMIT_FSIZE};
let cpu_limit = rlimit {
rlim_cur: MAX_EXECUTION_TIMEOUT_SECS,
rlim_max: MAX_EXECUTION_TIMEOUT_SECS,
};
if setrlimit(RLIMIT_CPU, &cpu_limit) != 0 {
eprintln!("[AGENT] Warning: Failed to set CPU limit");
}
let mem_limit = rlimit {
rlim_cur: 512 * 1024 * 1024,
rlim_max: 512 * 1024 * 1024,
};
if setrlimit(RLIMIT_AS, &mem_limit) != 0 {
eprintln!("[AGENT] Warning: Failed to set memory limit");
}
let file_limit = rlimit {
rlim_cur: 100 * 1024 * 1024,
rlim_max: 100 * 1024 * 1024,
};
if setrlimit(RLIMIT_FSIZE, &file_limit) != 0 {
eprintln!("[AGENT] Warning: Failed to set file size limit");
}
Ok(())
});
}
eprintln!(
"[AGENT] Sandbox: Linux (setrlimit: CPU={}s, MEM=512MB, FSIZE=100MB)",
MAX_EXECUTION_TIMEOUT_SECS
);
Ok(())
}
#[cfg(target_os = "macos")]
fn configure_sandbox(cmd: &mut Command) -> Result<()> {
use std::os::unix::process::CommandExt;
unsafe {
cmd.pre_exec(|| {
use libc::{rlimit, setrlimit, RLIMIT_AS, RLIMIT_CPU, RLIMIT_FSIZE};
let cpu_limit = rlimit {
rlim_cur: MAX_EXECUTION_TIMEOUT_SECS,
rlim_max: MAX_EXECUTION_TIMEOUT_SECS,
};
if setrlimit(RLIMIT_CPU, &cpu_limit) != 0 {
eprintln!("[AGENT] Warning: Failed to set CPU limit");
}
let mem_limit = rlimit {
rlim_cur: 512 * 1024 * 1024,
rlim_max: 512 * 1024 * 1024,
};
if setrlimit(RLIMIT_AS, &mem_limit) != 0 {
eprintln!("[AGENT] Warning: Failed to set memory limit");
}
let file_limit = rlimit {
rlim_cur: 100 * 1024 * 1024,
rlim_max: 100 * 1024 * 1024,
};
if setrlimit(RLIMIT_FSIZE, &file_limit) != 0 {
eprintln!("[AGENT] Warning: Failed to set file size limit");
}
Ok(())
});
}
eprintln!(
"[AGENT] Sandbox: macOS (setrlimit: CPU={}s, MEM=512MB, FSIZE=100MB)",
MAX_EXECUTION_TIMEOUT_SECS
);
Ok(())
}
/// Fallback sandbox hook for targets other than Windows, Linux and macOS:
/// leaves the command untouched and only logs that the timeout and output
/// limits are the protections in effect.
#[cfg(not(any(target_os = "windows", target_os = "linux", target_os = "macos")))]
fn configure_sandbox(_cmd: &mut Command) -> Result<()> {
eprintln!("[AGENT] Sandbox: Fallback (timeout + output limits only)");
Ok(())
}
// Unit tests for goal planning, the text-extraction helpers, argument
// flattening and plan execution.
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{configure_command_security, CommandSecurityConfig};
use std::collections::HashSet;
// `plan` accepts an ordinary goal and rejects both a 5000-character goal and
// an empty one.
#[test]
fn test_plan_validation() {
let result = plan("List all files in the current directory");
assert!(result.is_ok());
let long_goal = "a".repeat(5000);
let result = plan(&long_goal);
assert!(result.is_err());
let result = plan("");
assert!(result.is_err());
}
// "list ... files" goals plan an `ls` step first; "read file" goals with a
// quoted path plan a `cat` step first.
#[test]
fn test_plan_decomposition_files() {
let result = plan("List all files in /tmp directory").unwrap();
assert!(!result.steps.is_empty());
assert_eq!(result.steps[0].tool, "ls");
let result = plan("Read file '/etc/hosts'").unwrap();
assert!(!result.steps.is_empty());
assert_eq!(result.steps[0].tool, "cat");
}
// Goals mentioning git together with a known sub-command plan a `git` step.
#[test]
fn test_plan_decomposition_git() {
let result = plan("Check git status").unwrap();
assert!(!result.steps.is_empty());
assert_eq!(result.steps[0].tool, "git");
let result = plan("Push changes to git").unwrap();
assert!(!result.steps.is_empty());
assert_eq!(result.steps[0].tool, "git");
}
// "find" goals plan a `find` step.
#[test]
fn test_plan_decomposition_search() {
let result = plan("Find files matching '*.rs' in src/").unwrap();
assert!(!result.steps.is_empty());
assert_eq!(result.steps[0].tool, "find");
}
// Paths are recovered from single-quoted text, from a word starting with
// `.`, and from a word containing `/`.
#[test]
fn test_path_extraction() {
assert_eq!(
extract_path_from_goal("read file '/tmp/test.txt'"),
Some("/tmp/test.txt".to_string())
);
assert_eq!(
extract_path_from_goal("list ./src directory"),
Some("./src".to_string())
);
assert_eq!(
extract_path_from_goal("check /var/log"),
Some("/var/log".to_string())
);
}
// The word following "matching", "for" or "named" is returned, with
// surrounding quotes stripped.
#[test]
fn test_pattern_extraction() {
assert_eq!(
extract_pattern_from_goal("find files matching '*.rs'"),
Some("*.rs".to_string())
);
assert_eq!(
extract_pattern_from_goal("search for TODO"),
Some("TODO".to_string())
);
assert_eq!(
extract_pattern_from_goal("files named config.json"),
Some("config.json".to_string())
);
}
// Words starting with http:// or https:// are returned; goals without such
// a word yield `None`.
#[test]
fn test_url_extraction() {
assert_eq!(
extract_url_from_goal("download from https://example.com/file.txt"),
Some("https://example.com/file.txt".to_string())
);
assert_eq!(
extract_url_from_goal("fetch http://api.test.com/data"),
Some("http://api.test.com/data".to_string())
);
assert_eq!(extract_url_from_goal("no url here"), None);
}
// A plan without steps executes successfully.
#[test]
fn test_execute_empty_plan() {
let plan = Plan {
goal: "Test goal".to_string(),
steps: vec![],
};
let result = execute(&plan);
assert!(result.is_ok());
}
// With an allow-list holding a single platform-appropriate command, that
// command runs while `rm` is rejected.
// NOTE(review): `configure_command_security` presumably installs
// process-wide state, which could influence tests running in parallel —
// confirm against `crate::security`.
#[test]
fn test_execute_with_command_validation() {
let mut allowed = HashSet::new();
#[cfg(target_os = "windows")]
allowed.insert("cmd".to_string());
#[cfg(not(target_os = "windows"))]
allowed.insert("echo".to_string());
let config = CommandSecurityConfig {
allowed_commands: allowed,
log_attempts: false,
max_command_length: 1000,
max_args: 50,
};
configure_command_security(config).unwrap();
#[cfg(target_os = "windows")]
let plan = Plan {
goal: "Test".to_string(),
steps: vec![Step {
tool: "cmd".to_string(),
args: serde_json::json!(["/C", "echo", "hello"]),
}],
};
#[cfg(not(target_os = "windows"))]
let plan = Plan {
goal: "Test".to_string(),
steps: vec![Step {
tool: "echo".to_string(),
args: serde_json::json!(["hello", "world"]),
}],
};
let result = execute(&plan);
// Print the error so a failing run shows why the allowed command failed.
if let Err(e) = &result {
eprintln!("Error executing command: {:?}", e);
}
assert!(result.is_ok(), "command should succeed");
// `rm` is not in the allow-list, so execution must fail.
let plan = Plan {
goal: "Test".to_string(),
steps: vec![Step {
tool: "rm".to_string(),
args: serde_json::json!(["-rf", "/"]),
}],
};
let result = execute(&plan);
assert!(result.is_err(), "rm command should be blocked");
}
// Covers the four accepted argument shapes: array, object, string and null.
#[test]
fn test_parse_step_args() {
let args = serde_json::json!(["arg1", "arg2", 123]);
let result = parse_step_args(&args).unwrap();
assert_eq!(result, vec!["arg1", "arg2", "123"]);
// Object entries become `key=value`; membership is checked rather than
// position, so the test does not depend on entry order.
let args = serde_json::json!({"key": "value", "count": 42});
let result = parse_step_args(&args).unwrap();
assert!(result.contains(&"key=value".to_string()));
assert!(result.contains(&"count=42".to_string()));
let args = serde_json::json!("single-arg");
let result = parse_step_args(&args).unwrap();
assert_eq!(result, vec!["single-arg"]);
let args = serde_json::json!(null);
let result = parse_step_args(&args).unwrap();
assert_eq!(result, Vec::<String>::new());
}
}