Skip to main content

rig_core/test_utils/
tools.rs

1//! Tool helpers for deterministic tests.
2
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7
8use crate::{
9    completion::ToolDefinition,
10    tool::{Tool, ToolSet},
11    vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndex, request::Filter},
12    wasm_compat::WasmCompatSend,
13};
14
15/// Shared error type for mock tools.
16#[derive(Debug, thiserror::Error)]
17#[error("Mock tool error")]
18pub struct MockToolError;
19
20/// Arguments for arithmetic mock tools.
21#[derive(Deserialize)]
22pub struct MockOperationArgs {
23    x: i32,
24    y: i32,
25}
26
27/// A mock tool that adds `x` and `y`.
28#[derive(Deserialize, Serialize)]
29pub struct MockAddTool;
30
31impl Tool for MockAddTool {
32    const NAME: &'static str = "add";
33    type Error = MockToolError;
34    type Args = MockOperationArgs;
35    type Output = i32;
36
37    async fn definition(&self, _prompt: String) -> ToolDefinition {
38        ToolDefinition {
39            name: Self::NAME.to_string(),
40            description: "Add x and y together".to_string(),
41            parameters: json!({
42                "type": "object",
43                "properties": {
44                    "x": {
45                        "type": "number",
46                        "description": "The first number to add"
47                    },
48                    "y": {
49                        "type": "number",
50                        "description": "The second number to add"
51                    }
52                },
53                "required": ["x", "y"],
54            }),
55        }
56    }
57
58    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
59        Ok(args.x + args.y)
60    }
61}
62
63/// A mock tool that subtracts `y` from `x`.
64#[derive(Deserialize, Serialize)]
65pub struct MockSubtractTool;
66
67impl Tool for MockSubtractTool {
68    const NAME: &'static str = "subtract";
69    type Error = MockToolError;
70    type Args = MockOperationArgs;
71    type Output = i32;
72
73    async fn definition(&self, _prompt: String) -> ToolDefinition {
74        ToolDefinition {
75            name: Self::NAME.to_string(),
76            description: "Subtract y from x".to_string(),
77            parameters: json!({
78                "type": "object",
79                "properties": {
80                    "x": {
81                        "type": "number",
82                        "description": "The number to subtract from"
83                    },
84                    "y": {
85                        "type": "number",
86                        "description": "The number to subtract"
87                    }
88                },
89                "required": ["x", "y"],
90            }),
91        }
92    }
93
94    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
95        Ok(args.x - args.y)
96    }
97}
98
99/// Create a [`ToolSet`] containing [`MockAddTool`] and [`MockSubtractTool`].
100pub fn mock_math_toolset() -> ToolSet {
101    let mut toolset = ToolSet::default();
102    toolset.add_tool(MockAddTool);
103    toolset.add_tool(MockSubtractTool);
104    toolset
105}
106
107/// A mock tool that returns a multiline string.
108#[derive(Deserialize, Serialize)]
109pub struct MockStringOutputTool;
110
111impl Tool for MockStringOutputTool {
112    const NAME: &'static str = "string_output";
113    type Error = MockToolError;
114    type Args = serde_json::Value;
115    type Output = String;
116
117    async fn definition(&self, _prompt: String) -> ToolDefinition {
118        ToolDefinition {
119            name: Self::NAME.to_string(),
120            description: "Returns a multiline string".to_string(),
121            parameters: json!({
122                "type": "object",
123                "properties": {}
124            }),
125        }
126    }
127
128    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
129        Ok("Hello\nWorld".to_string())
130    }
131}
132
133/// A mock tool that returns image JSON as a string.
134#[derive(Deserialize, Serialize)]
135pub struct MockImageOutputTool;
136
137impl Tool for MockImageOutputTool {
138    const NAME: &'static str = "image_output";
139    type Error = MockToolError;
140    type Args = serde_json::Value;
141    type Output = String;
142
143    async fn definition(&self, _prompt: String) -> ToolDefinition {
144        ToolDefinition {
145            name: Self::NAME.to_string(),
146            description: "Returns image JSON".to_string(),
147            parameters: json!({
148                "type": "object",
149                "properties": {}
150            }),
151        }
152    }
153
154    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
155        Ok(json!({
156            "type": "image",
157            "data": "base64data==",
158            "mimeType": "image/png"
159        })
160        .to_string())
161    }
162}
163
164/// A mock tool named `generate_test_image` that returns a 1x1 red PNG image payload.
165#[derive(Debug, Deserialize, Serialize)]
166pub struct MockImageGeneratorTool;
167
168impl Tool for MockImageGeneratorTool {
169    const NAME: &'static str = "generate_test_image";
170    type Error = MockToolError;
171    type Args = serde_json::Value;
172    type Output = String;
173
174    async fn definition(&self, _prompt: String) -> ToolDefinition {
175        ToolDefinition {
176            name: Self::NAME.to_string(),
177            description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
178            parameters: json!({
179                "type": "object",
180                "properties": {},
181                "required": []
182            }),
183        }
184    }
185
186    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
187        Ok(json!({
188            "type": "image",
189            "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
190            "mimeType": "image/png"
191        })
192        .to_string())
193    }
194}
195
196/// A mock tool that returns a JSON object.
197#[derive(Deserialize, Serialize)]
198pub struct MockObjectOutputTool;
199
200impl Tool for MockObjectOutputTool {
201    const NAME: &'static str = "object_output";
202    type Error = MockToolError;
203    type Args = serde_json::Value;
204    type Output = serde_json::Value;
205
206    async fn definition(&self, _prompt: String) -> ToolDefinition {
207        ToolDefinition {
208            name: Self::NAME.to_string(),
209            description: "Returns an object".to_string(),
210            parameters: json!({
211                "type": "object",
212                "properties": {}
213            }),
214        }
215    }
216
217    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
218        Ok(json!({
219            "status": "ok",
220            "count": 42
221        }))
222    }
223}
224
225/// A mock tool named `example_tool` that returns `"Example answer"`.
226pub struct MockExampleTool;
227
228impl Tool for MockExampleTool {
229    const NAME: &'static str = "example_tool";
230    type Error = MockToolError;
231    type Args = ();
232    type Output = String;
233
234    async fn definition(&self, _prompt: String) -> ToolDefinition {
235        ToolDefinition {
236            name: Self::NAME.to_string(),
237            description: "A tool that returns some example text.".to_string(),
238            parameters: json!({
239                "type": "object",
240                "properties": {},
241                "required": []
242            }),
243        }
244    }
245
246    async fn call(&self, _input: Self::Args) -> Result<Self::Output, Self::Error> {
247        Ok("Example answer".to_string())
248    }
249}
250
251/// A mock tool that waits at a barrier before returning `"done"`.
252#[derive(Clone)]
253pub struct MockBarrierTool {
254    /// Barrier waited on during each tool call.
255    pub barrier: Arc<tokio::sync::Barrier>,
256}
257
258impl MockBarrierTool {
259    /// Create a barrier-backed tool.
260    pub fn new(barrier: Arc<tokio::sync::Barrier>) -> Self {
261        Self { barrier }
262    }
263}
264
265impl Tool for MockBarrierTool {
266    const NAME: &'static str = "barrier_tool";
267    type Error = MockToolError;
268    type Args = serde_json::Value;
269    type Output = String;
270
271    async fn definition(&self, _prompt: String) -> ToolDefinition {
272        ToolDefinition {
273            name: Self::NAME.to_string(),
274            description: "Waits at a barrier to test concurrency".to_string(),
275            parameters: json!({"type": "object", "properties": {}}),
276        }
277    }
278
279    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
280        self.barrier.wait().await;
281        Ok("done".to_string())
282    }
283}
284
285/// A mock tool that notifies when started and waits for an explicit finish signal.
286#[derive(Clone)]
287pub struct MockControlledTool {
288    /// Notified when a tool call starts.
289    pub started: Arc<tokio::sync::Notify>,
290    /// Waited on before a tool call finishes.
291    pub allow_finish: Arc<tokio::sync::Notify>,
292}
293
294impl MockControlledTool {
295    /// Create a controlled tool from notification primitives.
296    pub fn new(started: Arc<tokio::sync::Notify>, allow_finish: Arc<tokio::sync::Notify>) -> Self {
297        Self {
298            started,
299            allow_finish,
300        }
301    }
302}
303
304impl Tool for MockControlledTool {
305    const NAME: &'static str = "controlled";
306    type Error = MockToolError;
307    type Args = serde_json::Value;
308    type Output = i32;
309
310    async fn definition(&self, _prompt: String) -> ToolDefinition {
311        ToolDefinition {
312            name: Self::NAME.to_string(),
313            description: "Test tool".to_string(),
314            parameters: json!({"type": "object", "properties": {}}),
315        }
316    }
317
318    async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
319        self.started.notify_one();
320        self.allow_finish.notified().await;
321        Ok(42)
322    }
323}
324
325/// A vector index that returns a predefined list of tool IDs from `top_n_ids`.
326pub struct MockToolIndex {
327    tool_ids: Vec<String>,
328}
329
330impl MockToolIndex {
331    /// Create a tool index that returns the given IDs in order.
332    pub fn new(tool_ids: impl IntoIterator<Item = impl Into<String>>) -> Self {
333        Self {
334            tool_ids: tool_ids.into_iter().map(Into::into).collect(),
335        }
336    }
337}
338
339impl VectorStoreIndex for MockToolIndex {
340    type Filter = Filter<serde_json::Value>;
341
342    async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
343        &self,
344        _req: VectorSearchRequest,
345    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
346        Ok(vec![])
347    }
348
349    async fn top_n_ids(
350        &self,
351        _req: VectorSearchRequest,
352    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
353        Ok(self
354            .tool_ids
355            .iter()
356            .enumerate()
357            .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
358            .collect())
359    }
360}
361
362/// A vector index that waits at a barrier before returning one tool ID.
363pub struct BarrierMockToolIndex {
364    barrier: Arc<tokio::sync::Barrier>,
365    tool_id: String,
366}
367
368impl BarrierMockToolIndex {
369    /// Create a barrier-backed tool index.
370    pub fn new(barrier: Arc<tokio::sync::Barrier>, tool_id: impl Into<String>) -> Self {
371        Self {
372            barrier,
373            tool_id: tool_id.into(),
374        }
375    }
376}
377
378impl VectorStoreIndex for BarrierMockToolIndex {
379    type Filter = Filter<serde_json::Value>;
380
381    async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
382        &self,
383        _req: VectorSearchRequest,
384    ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
385        Ok(vec![])
386    }
387
388    async fn top_n_ids(
389        &self,
390        _req: VectorSearchRequest,
391    ) -> Result<Vec<(f64, String)>, VectorStoreError> {
392        self.barrier.wait().await;
393        Ok(vec![(1.0, self.tool_id.clone())])
394    }
395}