use crate::tools::Tool;
use async_openai::types::{ChatCompletionTool, ChatCompletionToolType, FunctionObject};
use async_trait::async_trait;
use schemars::schema_for;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::error::Error;
/// Result of running a command, serialized to JSON and returned to the model
/// by `ExecuteCommandTool`.
///
/// Field order is significant: it determines the key order of the serialized JSON.
/// `PartialEq`/`Eq` are derived so whole outputs can be compared directly.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct CommandOutput {
    /// Process exit status; `0` indicates success.
    pub exit_code: i32,
    /// Captured standard output.
    pub stdout: String,
    /// Captured standard error.
    pub stderr: String,
}
/// Tool exposed to the model as `list_files`.
///
/// Currently a placeholder: it returns a fixed listing regardless of the path requested.
#[derive(Debug, Clone)]
pub struct ListFilesTool;
// `schemars` emits these doc comments as JSON-schema `description`s, which the model
// sees when deciding how to call `list_files` — keep them model-facing and concise.
/// Arguments for the `list_files` tool.
#[derive(Deserialize, schemars::JsonSchema)]
struct ListFilesArgs {
    /// Path inside the sandbox whose files and directories should be listed (e.g. `./src`).
    path: String,
}
#[async_trait]
impl Tool for ListFilesTool {
    /// Returns the identifier the model uses to invoke this tool: `list_files`.
    fn name(&self) -> String {
        "list_files".to_string()
    }

    /// Builds the OpenAI function-tool definition; the parameter schema is derived
    /// from `ListFilesArgs` by `schemars`.
    fn schema(&self) -> ChatCompletionTool {
        ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: self.name(),
                description: Some(
                    "Lists files and directories at a given path in the sandbox.".to_string(),
                ),
                parameters: Some(schema_for!(ListFilesArgs).into()),
                strict: None,
            },
        }
    }

    /// Parses `{"path": ...}` and returns the listing as a JSON array of strings.
    ///
    /// Placeholder: the requested path is only logged, and a fixed listing is returned.
    ///
    /// # Errors
    /// Returns an error if `args` cannot be deserialized into `ListFilesArgs`
    /// (for example, when `path` is missing).
    async fn call(&self, args: Value) -> Result<String, Box<dyn Error + Send + Sync>> {
        let args: ListFilesArgs = serde_json::from_value(args)?;
        tracing::debug!(path = %args.path, "Listing files (placeholder)");
        // A fixed-size array serializes to the same JSON array as a `Vec`, without
        // the heap allocation (clippy::useless_vec).
        let dummy_files = ["README.md", "src/", "Cargo.toml"];
        Ok(serde_json::to_string(&dummy_files)?)
    }
}
/// Tool exposed to the model as `read_file`.
///
/// Currently a placeholder: it returns synthetic file contents that embed the requested path.
#[derive(Debug, Clone)]
pub struct ReadFileTool;
// `schemars` emits these doc comments as JSON-schema `description`s, which the model
// sees when deciding how to call `read_file` — keep them model-facing and concise.
/// Arguments for the `read_file` tool.
#[derive(Deserialize, schemars::JsonSchema)]
struct ReadFileArgs {
    /// Path of the file inside the sandbox to read (e.g. `./src/main.rs`).
    path: String,
}
#[async_trait]
impl Tool for ReadFileTool {
fn name(&self) -> String {
"read_file".to_string()
}
fn schema(&self) -> ChatCompletionTool {
ChatCompletionTool {
r#type: ChatCompletionToolType::Function,
function: FunctionObject {
name: self.name(),
description: Some(
"Reads the entire content of a file from the sandbox.".to_string(),
),
parameters: Some(schema_for!(ReadFileArgs).into()),
strict: None,
},
}
}
async fn call(&self, args: Value) -> Result<String, Box<dyn Error + Send + Sync>> {
let args: ReadFileArgs = serde_json::from_value(args)?;
let dummy_content = format!(
"// Contents of file: {}\nfn main() {{ println!(\"Hello, world!\"); }}",
args.path
);
tracing::debug!(path = %args.path, "Reading file (placeholder)");
Ok(serde_json::to_string(&dummy_content)?)
}
}
/// Tool exposed to the model as `write_file`.
///
/// Currently a placeholder: nothing is written; it only reports the byte count it received.
#[derive(Debug, Clone)]
pub struct WriteFileTool;
// `schemars` emits these doc comments as JSON-schema `description`s, which the model
// sees when deciding how to call `write_file` — keep them model-facing and concise.
/// Arguments for the `write_file` tool.
#[derive(Deserialize, schemars::JsonSchema)]
struct WriteFileArgs {
    /// Path of the file inside the sandbox to create or overwrite.
    path: String,
    /// Complete text to write; replaces any existing content of the file.
    content: String,
}
#[async_trait]
impl Tool for WriteFileTool {
fn name(&self) -> String {
"write_file".to_string()
}
fn schema(&self) -> ChatCompletionTool {
ChatCompletionTool {
r#type: ChatCompletionToolType::Function,
function: FunctionObject {
name: self.name(),
description: Some("Writes or overwrites a file in the sandbox.".to_string()),
parameters: Some(schema_for!(WriteFileArgs).into()),
strict: None,
},
}
}
async fn call(&self, args: Value) -> Result<String, Box<dyn Error + Send + Sync>> {
let args: WriteFileArgs = serde_json::from_value(args)?;
tracing::debug!(path = %args.path, bytes = args.content.len(), "Writing file (placeholder)");
Ok(json!({"status": "success", "bytes_written": args.content.len()}).to_string())
}
}
/// Tool exposed to the model as `execute_command`.
///
/// Currently a placeholder: only `cargo` invoked with a `test` argument "succeeds";
/// every other command reports failure.
#[derive(Debug, Clone)]
pub struct ExecuteCommandTool;
// `schemars` emits these doc comments as JSON-schema `description`s, which the model
// sees when deciding how to call `execute_command` — keep them model-facing and concise.
/// Arguments for the `execute_command` tool.
#[derive(Deserialize, schemars::JsonSchema)]
struct ExecuteCommandArgs {
    /// Name of the program to run (e.g. `cargo`).
    command: String,
    /// Arguments passed to `command`, one list element per argument (e.g. `["test"]`).
    /// Required; use an empty list when there are none.
    args: Vec<String>,
}
#[async_trait]
impl Tool for ExecuteCommandTool {
    /// Returns the identifier the model uses to invoke this tool: `execute_command`.
    fn name(&self) -> String {
        "execute_command".to_string()
    }

    /// Builds the OpenAI function-tool definition; the parameter schema is derived
    /// from `ExecuteCommandArgs` by `schemars`.
    fn schema(&self) -> ChatCompletionTool {
        ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: self.name(),
                description: Some("Executes a shell command in the sandbox.".to_string()),
                parameters: Some(schema_for!(ExecuteCommandArgs).into()),
                strict: None,
            },
        }
    }

    /// Parses `{"command": ..., "args": [...]}` and returns a JSON-serialized
    /// `CommandOutput`.
    ///
    /// Placeholder: nothing is executed. `cargo` with a `test` argument yields a
    /// canned passing test run (exit code 0); anything else yields exit code 1 with
    /// an error message on `stderr`.
    ///
    /// # Errors
    /// Returns an error if `args` cannot be deserialized into `ExecuteCommandArgs`
    /// (for example, when `args` is missing).
    async fn call(&self, args: Value) -> Result<String, Box<dyn Error + Send + Sync>> {
        let args: ExecuteCommandArgs = serde_json::from_value(args)?;
        tracing::debug!(command = %args.command, args_count = args.args.len(), "Executing command (placeholder)");
        // Compare borrowed strings instead of allocating a `String` just to call `contains`.
        let is_cargo_test = args.command == "cargo" && args.args.iter().any(|arg| arg == "test");
        let output = if is_cargo_test {
            CommandOutput {
                exit_code: 0,
                stdout: "running 3 tests\ntest tests::test_one ... ok\ntest tests::test_two ... ok\ntest tests::test_three ... ok\n\ntest result: ok. 3 passed; 0 failed"
                    .to_string(),
                stderr: String::new(),
            }
        } else {
            CommandOutput {
                exit_code: 1,
                stdout: String::new(),
                stderr: "Command not found or failed".to_string(),
            }
        };
        Ok(serde_json::to_string(&output)?)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Asserts that `$tool` reports `$name` both from `name()` and in its schema,
    /// and that the schema carries a description.
    macro_rules! assert_tool_identity {
        ($tool:expr, $name:expr) => {{
            assert_eq!($tool.name(), $name);
            let schema = $tool.schema();
            assert_eq!(schema.function.name, $name);
            assert!(schema.function.description.is_some());
        }};
    }

    #[tokio::test]
    async fn test_list_files_tool() {
        let tool = ListFilesTool;
        assert_tool_identity!(tool, "list_files");

        let raw = tool.call(json!({ "path": "./src" })).await.unwrap();
        let files: Vec<String> = serde_json::from_str(&raw).unwrap();
        assert!(!files.is_empty());
        assert!(files.iter().any(|name| name == "README.md"));
    }

    #[tokio::test]
    async fn test_read_file_tool() {
        let tool = ReadFileTool;
        assert_tool_identity!(tool, "read_file");

        let path = "./src/main.rs";
        let raw = tool.call(json!({ "path": path })).await.unwrap();
        let content: String = serde_json::from_str(&raw).unwrap();
        assert!(content.contains("Contents of file"));
        assert!(content.contains(path));
    }

    #[tokio::test]
    async fn test_write_file_tool() {
        let tool = WriteFileTool;
        assert_tool_identity!(tool, "write_file");

        let content = "fn main() { println!(\"test\"); }";
        let raw = tool
            .call(json!({ "path": "./test.rs", "content": content }))
            .await
            .unwrap();
        let response: Value = serde_json::from_str(&raw).unwrap();
        assert_eq!(response["status"], "success");
        assert_eq!(response["bytes_written"], content.len());
    }

    #[tokio::test]
    async fn test_execute_command_tool_cargo_test() {
        let tool = ExecuteCommandTool;
        assert_tool_identity!(tool, "execute_command");

        let raw = tool
            .call(json!({ "command": "cargo", "args": ["test"] }))
            .await
            .unwrap();
        let CommandOutput {
            exit_code,
            stdout,
            stderr,
        } = serde_json::from_str::<CommandOutput>(&raw).unwrap();
        assert_eq!(exit_code, 0);
        assert!(stdout.contains("3 passed"));
        assert!(stderr.is_empty());
    }

    #[tokio::test]
    async fn test_execute_command_tool_unknown_command() {
        let raw = ExecuteCommandTool
            .call(json!({ "command": "unknown_cmd", "args": ["--help"] }))
            .await
            .unwrap();
        let CommandOutput {
            exit_code,
            stdout,
            stderr,
        } = serde_json::from_str::<CommandOutput>(&raw).unwrap();
        assert_eq!(exit_code, 1);
        assert!(stdout.is_empty());
        assert!(stderr.contains("Command not found"));
    }

    #[tokio::test]
    async fn test_list_files_invalid_args() {
        let outcome = ListFilesTool.call(json!({})).await;
        assert!(outcome.is_err(), "Should fail with missing path argument");
    }

    #[tokio::test]
    async fn test_read_file_invalid_args() {
        let outcome = ReadFileTool.call(json!({})).await;
        assert!(outcome.is_err(), "Should fail with missing path argument");
    }

    #[tokio::test]
    async fn test_write_file_invalid_args() {
        let outcome = WriteFileTool.call(json!({ "path": "./test.rs" })).await;
        assert!(outcome.is_err(), "Should fail with missing content argument");
    }

    #[tokio::test]
    async fn test_execute_command_invalid_args() {
        let outcome = ExecuteCommandTool.call(json!({ "command": "ls" })).await;
        assert!(outcome.is_err(), "Should fail with missing args argument");
    }

    #[test]
    fn test_command_output_serialization() {
        let original = CommandOutput {
            exit_code: 0,
            stdout: "Hello".to_string(),
            stderr: String::new(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let CommandOutput {
            exit_code,
            stdout,
            stderr,
        } = serde_json::from_str::<CommandOutput>(&encoded).unwrap();
        assert_eq!(exit_code, 0);
        assert_eq!(stdout, "Hello");
        assert_eq!(stderr, "");
    }
}