nika_engine/tools/
rig_adapter.rs1use std::future::Future;
25use std::pin::Pin;
26use std::sync::Arc;
27
28use rig::completion::ToolDefinition;
29use rig::tool::{ToolDyn, ToolError};
30use serde_json::Value;
31
32use super::FileTool;
33
/// Heap-allocated, pinned, `Send` future that may borrow for `'a` and resolves
/// to `Out`; this is the return shape of the `ToolDyn` methods in this module.
type BoxFuture<'a, Out> = Pin<Box<dyn Future<Output = Out> + Send + 'a>>;
36
/// Adapter that exposes a [`FileTool`] to rig agents through the [`ToolDyn`]
/// trait-object interface.
///
/// The wrapped tool is held behind an [`Arc`], so cloning the adapter only
/// bumps a reference count and every clone drives the same tool instance.
///
/// The struct deliberately carries no trait bounds. The
/// `FileTool + Send + Sync + 'static` requirement is stated on the `impl`
/// blocks that actually need it, rather than being duplicated here.
pub struct RigFileTool<T> {
    inner: Arc<T>,
}
44
45impl<T: FileTool + Send + Sync + 'static> RigFileTool<T> {
46 pub fn new(tool: T) -> Self {
48 Self {
49 inner: Arc::new(tool),
50 }
51 }
52
53 pub fn from_arc(tool: Arc<T>) -> Self {
55 Self { inner: tool }
56 }
57}
58
59impl<T: FileTool + Send + Sync + 'static> Clone for RigFileTool<T> {
60 fn clone(&self) -> Self {
61 Self {
62 inner: Arc::clone(&self.inner),
63 }
64 }
65}
66
67impl<T: FileTool + Send + Sync + 'static> std::fmt::Debug for RigFileTool<T> {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("RigFileTool")
70 .field("name", &self.inner.name())
71 .finish()
72 }
73}
74
75impl<T: FileTool + Send + Sync + 'static> ToolDyn for RigFileTool<T> {
76 fn name(&self) -> String {
77 self.inner.name().to_string()
78 }
79
80 fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
81 let def = ToolDefinition {
82 name: self.inner.name().to_string(),
83 description: self.inner.description().to_string(),
84 parameters: self.inner.parameters_schema(),
85 };
86 Box::pin(async move { def })
87 }
88
89 fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
90 let inner = Arc::clone(&self.inner);
91
92 Box::pin(async move {
93 let params: Value = serde_json::from_str(&args).map_err(|e| {
95 ToolError::ToolCallError(Box::new(std::io::Error::other(format!(
96 "Invalid JSON arguments: {}",
97 e
98 ))))
99 })?;
100
101 let result = inner.call(params).await.map_err(|e| {
103 ToolError::ToolCallError(Box::new(std::io::Error::other(e.to_string())))
104 })?;
105
106 if result.is_error {
108 Err(ToolError::ToolCallError(Box::new(std::io::Error::other(
109 result.content,
110 ))))
111 } else {
112 Ok(result.content)
113 }
114 })
115 }
116}
117
118use super::{EditTool, GlobTool, GrepTool, ReadTool, ToolContext, WriteTool};
123
124pub fn create_rig_file_tools(ctx: Arc<ToolContext>) -> Vec<Box<dyn ToolDyn>> {
139 vec![
140 Box::new(RigFileTool::new(ReadTool::new(Arc::clone(&ctx)))),
141 Box::new(RigFileTool::new(WriteTool::new(Arc::clone(&ctx)))),
142 Box::new(RigFileTool::new(EditTool::new(Arc::clone(&ctx)))),
143 Box::new(RigFileTool::new(GlobTool::new(Arc::clone(&ctx)))),
144 Box::new(RigFileTool::new(GrepTool::new(ctx))),
145 ]
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::context::testing::{create_test_file, setup_test};
    use tokio::fs;

    // Successful calls are unwrapped with `expect`, which prints the
    // `ToolError` on failure. A bare `assert!(result.is_ok())` would hide the
    // cause.

    #[tokio::test]
    async fn test_rig_file_tool_read() {
        let (temp_dir, ctx) = setup_test().await;
        let file_path = create_test_file(&temp_dir, "test.txt", "Hello, World!").await;

        let rig_tool = RigFileTool::new(ReadTool::new(ctx));
        assert_eq!(rig_tool.name(), "read");

        let def = rig_tool.definition(String::new()).await;
        assert_eq!(def.name, "read");
        assert!(def.description.contains("Read a file"));

        let args = serde_json::json!({
            "file_path": file_path.to_string_lossy()
        })
        .to_string();

        let content = rig_tool.call(args).await.expect("read call should succeed");
        assert!(content.contains("Hello, World!"));
    }

    #[tokio::test]
    async fn test_rig_file_tool_write() {
        let (temp_dir, ctx) = setup_test().await;
        let file_path = temp_dir.path().join("new_file.txt");

        let rig_tool = RigFileTool::new(WriteTool::new(ctx));

        let args = serde_json::json!({
            "file_path": file_path.to_string_lossy(),
            "content": "New content"
        })
        .to_string();

        rig_tool.call(args).await.expect("write call should succeed");

        let content = fs::read_to_string(&file_path)
            .await
            .expect("written file should be readable");
        assert_eq!(content, "New content");
    }

    #[tokio::test]
    async fn test_rig_file_tool_glob() {
        use crate::tools::context::testing::create_test_tree;
        let (temp_dir, ctx) = setup_test().await;

        create_test_tree(
            &temp_dir,
            &[("a.rs", "fn a()"), ("b.rs", "fn b()"), ("c.txt", "text")],
        )
        .await;

        let rig_tool = RigFileTool::new(GlobTool::new(ctx));

        let args = serde_json::json!({
            "pattern": "*.rs"
        })
        .to_string();

        let content = rig_tool.call(args).await.expect("glob call should succeed");
        assert!(content.contains("a.rs"));
        assert!(content.contains("b.rs"));
        assert!(!content.contains("c.txt"));
    }

    #[tokio::test]
    async fn test_create_rig_file_tools() {
        let (_temp_dir, ctx) = setup_test().await;

        let tools = create_rig_file_tools(ctx);

        // The factory returns a fixed, ordered set of tools.
        let names: Vec<String> = tools.iter().map(|t| t.name()).collect();
        assert_eq!(names, ["read", "write", "edit", "glob", "grep"]);
    }

    #[tokio::test]
    async fn test_rig_file_tool_error_handling() {
        let (_temp_dir, ctx) = setup_test().await;

        let rig_tool = RigFileTool::new(ReadTool::new(ctx));

        let args = serde_json::json!({
            "file_path": "/nonexistent/path/file.txt"
        })
        .to_string();

        // Whether the tool returns `Err` or an `is_error` result, the adapter
        // reports it as `ToolCallError`.
        let result = rig_tool.call(args).await;
        assert!(
            matches!(result, Err(ToolError::ToolCallError(_))),
            "expected ToolCallError, got {result:?}"
        );
    }

    #[tokio::test]
    async fn test_rig_file_tool_invalid_json_arguments() {
        let (_temp_dir, ctx) = setup_test().await;

        let rig_tool = RigFileTool::new(ReadTool::new(ctx));

        match rig_tool.call("not json".to_string()).await {
            Err(ToolError::ToolCallError(source)) => {
                assert!(source.to_string().contains("Invalid JSON arguments"));
            }
            other => panic!("expected ToolCallError for invalid JSON, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_rig_file_tool_clone_shares_inner_and_debug_shows_name() {
        let (_temp_dir, ctx) = setup_test().await;

        let rig_tool = RigFileTool::new(ReadTool::new(ctx));
        let cloned = rig_tool.clone();

        assert!(Arc::ptr_eq(&rig_tool.inner, &cloned.inner));
        assert!(format!("{rig_tool:?}").contains("read"));
    }
}