use llm_coding_tools_core::operations::write_file;
use llm_coding_tools_core::path::AllowedPathResolver;
use llm_coding_tools_core::tool_names;
use llm_coding_tools_core::{ToolContext, ToolError};
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use schemars::{schema_for, JsonSchema};
use serde::Deserialize;
// Maintainer note: `schemars` copies the `///` doc comments below into the
// generated JSON schema as `description` fields, so they are model-facing text.
// Keep them short and free of rustdoc link syntax.
/// Arguments for writing a file.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct WriteToolArgs {
    /// Path of the file to write, relative to one of the configured base directories.
    /// Paths that resolve outside the allowed directories are rejected.
    pub file_path: String,
    /// Text content to write to the file.
    pub content: String,
}
/// Tool that writes file content on behalf of a model, restricted to paths
/// accepted by an [`AllowedPathResolver`].
#[derive(Clone, Debug)]
pub struct WriteTool {
    // Handed to `write_file` on every call to decide which paths are permitted.
    resolver: AllowedPathResolver,
}
impl WriteTool {
pub fn new(resolver: AllowedPathResolver) -> Self {
Self { resolver }
}
}
impl Tool for WriteTool {
const NAME: &'static str = tool_names::WRITE;
type Error = ToolError;
type Args = WriteToolArgs;
type Output = String;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: <Self as Tool>::NAME.to_string(),
description: "Write content to a file within allowed directories. \
Paths are relative to configured base directories."
.to_string(),
parameters: serde_json::to_value(schema_for!(WriteToolArgs))
.expect("schema generation should not fail"),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
write_file(&self.resolver, &args.file_path, &args.content).await
}
}
impl ToolContext for WriteTool {
    const NAME: &'static str = tool_names::WRITE;

    /// Returns the shared `WRITE_ALLOWED` context text from the core crate.
    fn context(&self) -> &'static str {
        use llm_coding_tools_core::context::WRITE_ALLOWED;
        WRITE_ALLOWED
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a `WriteTool` whose only allowed directory is `dir`.
    fn tool_rooted_at(dir: &TempDir) -> WriteTool {
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        WriteTool::new(resolver)
    }

    #[tokio::test]
    async fn writes_new_file() {
        let dir = TempDir::new().unwrap();
        let tool = tool_rooted_at(&dir);
        let result = tool
            .call(WriteToolArgs {
                file_path: "new.txt".to_string(),
                content: "hello".to_string(),
            })
            .await
            .unwrap();
        assert!(result.contains("5 bytes"));
        // Verify the requested bytes actually landed, not merely that a file exists.
        let written = std::fs::read_to_string(dir.path().join("new.txt")).unwrap();
        assert_eq!(written, "hello");
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();
        let tool = tool_rooted_at(&dir);
        let result = tool
            .call(WriteToolArgs {
                file_path: "../../../tmp/escape.txt".to_string(),
                content: "content".to_string(),
            })
            .await;
        assert!(matches!(result, Err(ToolError::InvalidPath(_))));
    }
}