llm_coding_tools_rig/allowed/
read.rs1use llm_coding_tools_core::operations::read_file;
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6use llm_coding_tools_core::{ToolContext, ToolError, ToolOutput};
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use schemars::{schema_for, JsonSchema};
10use serde::Deserialize;
11
/// Default for `offset` when the caller omits it: start at the first line.
const DEFAULT_OFFSET: usize = 1;
/// Default for `limit` when the caller omits it.
const DEFAULT_LIMIT: usize = 2000;

/// Serde default hook for [`ReadArgs::offset`].
const fn default_offset() -> usize {
    DEFAULT_OFFSET
}

/// Serde default hook for [`ReadArgs::limit`].
const fn default_limit() -> usize {
    DEFAULT_LIMIT
}
22
23#[derive(Debug, Clone, Deserialize, JsonSchema)]
25pub struct ReadArgs {
26 pub file_path: String,
28 #[serde(default = "default_offset")]
30 pub offset: usize,
31 #[serde(default = "default_limit")]
33 pub limit: usize,
34}
35
36#[derive(Debug, Clone)]
40pub struct ReadTool<const LINE_NUMBERS: bool = true> {
41 resolver: AllowedPathResolver,
42}
43
44impl<const LINE_NUMBERS: bool> ReadTool<LINE_NUMBERS> {
45 pub fn new(resolver: AllowedPathResolver) -> Self {
65 Self { resolver }
66 }
67}
68
69impl<const LINE_NUMBERS: bool> Tool for ReadTool<LINE_NUMBERS> {
70 const NAME: &'static str = tool_names::READ;
71
72 type Error = ToolError;
73 type Args = ReadArgs;
74 type Output = ToolOutput;
75
76 async fn definition(&self, _prompt: String) -> ToolDefinition {
77 let description = if LINE_NUMBERS {
78 "Read file contents with line numbers from allowed directories. \
79 Paths are relative to configured base directories."
80 } else {
81 "Read file contents from allowed directories. \
82 Paths are relative to configured base directories."
83 };
84 ToolDefinition {
85 name: <Self as Tool>::NAME.to_string(),
86 description: description.to_string(),
87 parameters: serde_json::to_value(schema_for!(ReadArgs))
88 .expect("schema serialization should never fail"),
89 }
90 }
91
92 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
93 read_file::<_, LINE_NUMBERS>(&self.resolver, &args.file_path, args.offset, args.limit).await
94 }
95}
96
97impl<const LINE_NUMBERS: bool> ToolContext for ReadTool<LINE_NUMBERS> {
98 const NAME: &'static str = tool_names::READ;
99
100 fn context(&self) -> &'static str {
101 llm_coding_tools_core::context::READ_ALLOWED
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a temp dir holding a two-line `test.txt` plus a resolver rooted there.
    /// The `TempDir` is returned so it stays alive for the duration of the test.
    fn two_line_fixture() -> (TempDir, AllowedPathResolver) {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("test.txt"), "hello\nworld\n").unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        (dir, resolver)
    }

    #[tokio::test]
    async fn reads_file_with_line_numbers() {
        let (_dir, resolver) = two_line_fixture();
        let tool: ReadTool<true> = ReadTool::new(resolver);
        let args = ReadArgs {
            file_path: "test.txt".to_string(),
            offset: 1,
            limit: 2000,
        };
        let result = tool.call(args).await.unwrap();
        assert_eq!(result.content, "L1: hello\nL2: world");
    }

    #[tokio::test]
    async fn reads_file_without_line_numbers() {
        let (_dir, resolver) = two_line_fixture();
        let tool: ReadTool<false> = ReadTool::new(resolver);
        let args = ReadArgs {
            file_path: "test.txt".to_string(),
            offset: 1,
            limit: 2000,
        };
        let result = tool.call(args).await.unwrap();
        assert!(result.content.contains("hello"));
        assert!(result.content.contains("world"));
        // The plain variant must not emit the `L<n>: ` prefixes.
        assert!(!result.content.contains("L1:"));
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool: ReadTool = ReadTool::new(resolver);
        let args = ReadArgs {
            file_path: "../../../etc/passwd".to_string(),
            offset: 1,
            limit: 100,
        };
        let result = tool.call(args).await;
        assert!(matches!(result, Err(ToolError::InvalidPath(_))));
    }

    #[test]
    fn args_use_defaults_when_omitted() {
        let args: ReadArgs =
            serde_json::from_value(serde_json::json!({ "file_path": "a.txt" })).unwrap();
        assert_eq!(args.file_path, "a.txt");
        assert_eq!(args.offset, DEFAULT_OFFSET);
        assert_eq!(args.limit, DEFAULT_LIMIT);
    }

    #[tokio::test]
    async fn definition_matches_line_number_mode() {
        let dir = TempDir::new().unwrap();
        let numbered: ReadTool<true> =
            ReadTool::new(AllowedPathResolver::new([dir.path()]).unwrap());
        let plain: ReadTool<false> =
            ReadTool::new(AllowedPathResolver::new([dir.path()]).unwrap());

        let numbered_def = numbered.definition(String::new()).await;
        let plain_def = plain.definition(String::new()).await;

        assert_eq!(numbered_def.name, tool_names::READ);
        assert_eq!(plain_def.name, tool_names::READ);
        assert!(numbered_def.description.contains("with line numbers"));
        assert!(!plain_def.description.contains("line numbers"));

        // Every argument field must be advertised to the model.
        let properties = &numbered_def.parameters["properties"];
        for field in ["file_path", "offset", "limit"] {
            assert!(properties.get(field).is_some(), "schema is missing `{field}`");
        }
    }
}