llm_coding_tools_rig/allowed/
write.rs1use llm_coding_tools_core::operations::write_file;
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6use llm_coding_tools_core::{ToolContext, ToolError};
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use schemars::{schema_for, JsonSchema};
10use serde::Deserialize;
11
12#[derive(Debug, Clone, Deserialize, JsonSchema)]
14pub struct WriteToolArgs {
15 pub file_path: String,
17 pub content: String,
19}
20
21#[derive(Debug, Clone)]
23pub struct WriteTool {
24 resolver: AllowedPathResolver,
25}
26
27impl WriteTool {
28 pub fn new(resolver: AllowedPathResolver) -> Self {
34 Self { resolver }
35 }
36}
37
38impl Tool for WriteTool {
39 const NAME: &'static str = tool_names::WRITE;
40
41 type Error = ToolError;
42 type Args = WriteToolArgs;
43 type Output = String;
44
45 async fn definition(&self, _prompt: String) -> ToolDefinition {
46 ToolDefinition {
47 name: <Self as Tool>::NAME.to_string(),
48 description: "Write content to a file within allowed directories. \
49 Paths are relative to configured base directories."
50 .to_string(),
51 parameters: serde_json::to_value(schema_for!(WriteToolArgs))
52 .expect("schema generation should not fail"),
53 }
54 }
55
56 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
57 write_file(&self.resolver, &args.file_path, &args.content).await
58 }
59}
60
61impl ToolContext for WriteTool {
62 const NAME: &'static str = tool_names::WRITE;
63
64 fn context(&self) -> &'static str {
65 llm_coding_tools_core::context::WRITE_ALLOWED
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writing a new file reports the byte count and leaves exactly the
    /// requested content on disk.
    #[tokio::test]
    async fn writes_new_file() {
        let dir = TempDir::new().unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool = WriteTool::new(resolver);
        let result = tool
            .call(WriteToolArgs {
                file_path: "new.txt".to_string(),
                content: "hello".to_string(),
            })
            .await
            .unwrap();
        assert!(result.contains("5 bytes"));

        // An existence check alone would also pass for an empty or truncated file,
        // so read the file back and compare the bytes.
        let target = dir.path().join("new.txt");
        assert!(target.exists());
        let written = std::fs::read_to_string(&target).unwrap();
        assert_eq!(written, "hello");
    }

    /// A relative path that climbs out of the allowed directory is rejected
    /// with `ToolError::InvalidPath`.
    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool = WriteTool::new(resolver);
        let result = tool
            .call(WriteToolArgs {
                file_path: "../../../tmp/escape.txt".to_string(),
                content: "content".to_string(),
            })
            .await;
        assert!(matches!(result, Err(ToolError::InvalidPath(_))));
    }
}