1use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7
8use crate::{
9 completion::ToolDefinition,
10 tool::{Tool, ToolSet},
11 vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndex, request::Filter},
12 wasm_compat::WasmCompatSend,
13};
14
15#[derive(Debug, thiserror::Error)]
17#[error("Mock tool error")]
18pub struct MockToolError;
19
20#[derive(Deserialize)]
22pub struct MockOperationArgs {
23 x: i32,
24 y: i32,
25}
26
27#[derive(Deserialize, Serialize)]
29pub struct MockAddTool;
30
31impl Tool for MockAddTool {
32 const NAME: &'static str = "add";
33 type Error = MockToolError;
34 type Args = MockOperationArgs;
35 type Output = i32;
36
37 async fn definition(&self, _prompt: String) -> ToolDefinition {
38 ToolDefinition {
39 name: Self::NAME.to_string(),
40 description: "Add x and y together".to_string(),
41 parameters: json!({
42 "type": "object",
43 "properties": {
44 "x": {
45 "type": "number",
46 "description": "The first number to add"
47 },
48 "y": {
49 "type": "number",
50 "description": "The second number to add"
51 }
52 },
53 "required": ["x", "y"],
54 }),
55 }
56 }
57
58 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
59 Ok(args.x + args.y)
60 }
61}
62
63#[derive(Deserialize, Serialize)]
65pub struct MockSubtractTool;
66
67impl Tool for MockSubtractTool {
68 const NAME: &'static str = "subtract";
69 type Error = MockToolError;
70 type Args = MockOperationArgs;
71 type Output = i32;
72
73 async fn definition(&self, _prompt: String) -> ToolDefinition {
74 ToolDefinition {
75 name: Self::NAME.to_string(),
76 description: "Subtract y from x".to_string(),
77 parameters: json!({
78 "type": "object",
79 "properties": {
80 "x": {
81 "type": "number",
82 "description": "The number to subtract from"
83 },
84 "y": {
85 "type": "number",
86 "description": "The number to subtract"
87 }
88 },
89 "required": ["x", "y"],
90 }),
91 }
92 }
93
94 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
95 Ok(args.x - args.y)
96 }
97}
98
99pub fn mock_math_toolset() -> ToolSet {
101 let mut toolset = ToolSet::default();
102 toolset.add_tool(MockAddTool);
103 toolset.add_tool(MockSubtractTool);
104 toolset
105}
106
107#[derive(Deserialize, Serialize)]
109pub struct MockStringOutputTool;
110
111impl Tool for MockStringOutputTool {
112 const NAME: &'static str = "string_output";
113 type Error = MockToolError;
114 type Args = serde_json::Value;
115 type Output = String;
116
117 async fn definition(&self, _prompt: String) -> ToolDefinition {
118 ToolDefinition {
119 name: Self::NAME.to_string(),
120 description: "Returns a multiline string".to_string(),
121 parameters: json!({
122 "type": "object",
123 "properties": {}
124 }),
125 }
126 }
127
128 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
129 Ok("Hello\nWorld".to_string())
130 }
131}
132
133#[derive(Deserialize, Serialize)]
135pub struct MockImageOutputTool;
136
137impl Tool for MockImageOutputTool {
138 const NAME: &'static str = "image_output";
139 type Error = MockToolError;
140 type Args = serde_json::Value;
141 type Output = String;
142
143 async fn definition(&self, _prompt: String) -> ToolDefinition {
144 ToolDefinition {
145 name: Self::NAME.to_string(),
146 description: "Returns image JSON".to_string(),
147 parameters: json!({
148 "type": "object",
149 "properties": {}
150 }),
151 }
152 }
153
154 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
155 Ok(json!({
156 "type": "image",
157 "data": "base64data==",
158 "mimeType": "image/png"
159 })
160 .to_string())
161 }
162}
163
164#[derive(Debug, Deserialize, Serialize)]
166pub struct MockImageGeneratorTool;
167
168impl Tool for MockImageGeneratorTool {
169 const NAME: &'static str = "generate_test_image";
170 type Error = MockToolError;
171 type Args = serde_json::Value;
172 type Output = String;
173
174 async fn definition(&self, _prompt: String) -> ToolDefinition {
175 ToolDefinition {
176 name: Self::NAME.to_string(),
177 description: "Generates a small test image (a 1x1 red pixel). Call this tool when asked to generate or show an image.".to_string(),
178 parameters: json!({
179 "type": "object",
180 "properties": {},
181 "required": []
182 }),
183 }
184 }
185
186 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
187 Ok(json!({
188 "type": "image",
189 "data": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==",
190 "mimeType": "image/png"
191 })
192 .to_string())
193 }
194}
195
196#[derive(Deserialize, Serialize)]
198pub struct MockObjectOutputTool;
199
200impl Tool for MockObjectOutputTool {
201 const NAME: &'static str = "object_output";
202 type Error = MockToolError;
203 type Args = serde_json::Value;
204 type Output = serde_json::Value;
205
206 async fn definition(&self, _prompt: String) -> ToolDefinition {
207 ToolDefinition {
208 name: Self::NAME.to_string(),
209 description: "Returns an object".to_string(),
210 parameters: json!({
211 "type": "object",
212 "properties": {}
213 }),
214 }
215 }
216
217 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
218 Ok(json!({
219 "status": "ok",
220 "count": 42
221 }))
222 }
223}
224
225pub struct MockExampleTool;
227
228impl Tool for MockExampleTool {
229 const NAME: &'static str = "example_tool";
230 type Error = MockToolError;
231 type Args = ();
232 type Output = String;
233
234 async fn definition(&self, _prompt: String) -> ToolDefinition {
235 ToolDefinition {
236 name: Self::NAME.to_string(),
237 description: "A tool that returns some example text.".to_string(),
238 parameters: json!({
239 "type": "object",
240 "properties": {},
241 "required": []
242 }),
243 }
244 }
245
246 async fn call(&self, _input: Self::Args) -> Result<Self::Output, Self::Error> {
247 Ok("Example answer".to_string())
248 }
249}
250
251#[derive(Clone)]
253pub struct MockBarrierTool {
254 pub barrier: Arc<tokio::sync::Barrier>,
256}
257
258impl MockBarrierTool {
259 pub fn new(barrier: Arc<tokio::sync::Barrier>) -> Self {
261 Self { barrier }
262 }
263}
264
265impl Tool for MockBarrierTool {
266 const NAME: &'static str = "barrier_tool";
267 type Error = MockToolError;
268 type Args = serde_json::Value;
269 type Output = String;
270
271 async fn definition(&self, _prompt: String) -> ToolDefinition {
272 ToolDefinition {
273 name: Self::NAME.to_string(),
274 description: "Waits at a barrier to test concurrency".to_string(),
275 parameters: json!({"type": "object", "properties": {}}),
276 }
277 }
278
279 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
280 self.barrier.wait().await;
281 Ok("done".to_string())
282 }
283}
284
285#[derive(Clone)]
287pub struct MockControlledTool {
288 pub started: Arc<tokio::sync::Notify>,
290 pub allow_finish: Arc<tokio::sync::Notify>,
292}
293
294impl MockControlledTool {
295 pub fn new(started: Arc<tokio::sync::Notify>, allow_finish: Arc<tokio::sync::Notify>) -> Self {
297 Self {
298 started,
299 allow_finish,
300 }
301 }
302}
303
304impl Tool for MockControlledTool {
305 const NAME: &'static str = "controlled";
306 type Error = MockToolError;
307 type Args = serde_json::Value;
308 type Output = i32;
309
310 async fn definition(&self, _prompt: String) -> ToolDefinition {
311 ToolDefinition {
312 name: Self::NAME.to_string(),
313 description: "Test tool".to_string(),
314 parameters: json!({"type": "object", "properties": {}}),
315 }
316 }
317
318 async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
319 self.started.notify_one();
320 self.allow_finish.notified().await;
321 Ok(42)
322 }
323}
324
325pub struct MockToolIndex {
327 tool_ids: Vec<String>,
328}
329
330impl MockToolIndex {
331 pub fn new(tool_ids: impl IntoIterator<Item = impl Into<String>>) -> Self {
333 Self {
334 tool_ids: tool_ids.into_iter().map(Into::into).collect(),
335 }
336 }
337}
338
339impl VectorStoreIndex for MockToolIndex {
340 type Filter = Filter<serde_json::Value>;
341
342 async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
343 &self,
344 _req: VectorSearchRequest,
345 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
346 Ok(vec![])
347 }
348
349 async fn top_n_ids(
350 &self,
351 _req: VectorSearchRequest,
352 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
353 Ok(self
354 .tool_ids
355 .iter()
356 .enumerate()
357 .map(|(i, id)| (1.0 - (i as f64 * 0.1), id.clone()))
358 .collect())
359 }
360}
361
362pub struct BarrierMockToolIndex {
364 barrier: Arc<tokio::sync::Barrier>,
365 tool_id: String,
366}
367
368impl BarrierMockToolIndex {
369 pub fn new(barrier: Arc<tokio::sync::Barrier>, tool_id: impl Into<String>) -> Self {
371 Self {
372 barrier,
373 tool_id: tool_id.into(),
374 }
375 }
376}
377
378impl VectorStoreIndex for BarrierMockToolIndex {
379 type Filter = Filter<serde_json::Value>;
380
381 async fn top_n<T: for<'a> Deserialize<'a> + WasmCompatSend>(
382 &self,
383 _req: VectorSearchRequest,
384 ) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
385 Ok(vec![])
386 }
387
388 async fn top_n_ids(
389 &self,
390 _req: VectorSearchRequest,
391 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
392 self.barrier.wait().await;
393 Ok(vec![(1.0, self.tool_id.clone())])
394 }
395}