1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21 completion::{self, ToolDefinition},
22 embeddings::{embed::EmbedError, tool::ToolSchema},
23 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
/// Errors that can occur while invoking a tool through [`ToolDyn::call`].
///
/// No `#[error(...)]` attributes are used here; the `Display` implementation
/// for this type is written by hand.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    #[cfg(not(target_family = "wasm"))]
    /// The tool's own `call` returned an error (non-wasm targets: the boxed
    /// error must be `Send + Sync`).
    ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),

    #[cfg(target_family = "wasm")]
    /// The tool's own `call` returned an error (wasm targets: no `Send`/`Sync`
    /// bound on the boxed error).
    ToolCallError(#[from] Box<dyn std::error::Error>),
    /// Deserializing the tool arguments or serializing the tool output failed.
    JsonError(#[from] serde_json::Error),
}
38
39impl fmt::Display for ToolError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 ToolError::ToolCallError(e) => {
43 let error_str = e.to_string();
44 if error_str.starts_with("ToolCallError: ") {
47 write!(f, "{}", error_str)
48 } else {
49 write!(f, "ToolCallError: {}", error_str)
50 }
51 }
52 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53 }
54 }
55}
56
/// A statically typed tool: typed arguments in, typed output out.
///
/// Every `Tool` automatically implements [`ToolDyn`] through a blanket impl,
/// which takes care of decoding the JSON argument string into [`Tool::Args`]
/// and encoding [`Tool::Output`] back into a string.
pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
    /// Static name of the tool; the default [`Tool::name`] returns it.
    const NAME: &'static str;

    /// Error type returned by [`Tool::call`].
    type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
    /// Argument type; must be deserializable because the dynamic wrapper
    /// decodes it from a JSON string.
    type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
    /// Output type; must be serializable because the dynamic wrapper encodes
    /// it into the returned string.
    type Output: Serialize;

    /// Returns the tool's name. Defaults to [`Tool::NAME`].
    fn name(&self) -> String {
        Self::NAME.to_string()
    }

    /// Produces the [`ToolDefinition`] describing this tool.
    ///
    /// `_prompt` is supplied by the caller; `ToolSet` passes an empty string.
    /// NOTE(review): presumably lets a tool tailor its definition to the
    /// user prompt — confirm against callers outside this file.
    fn definition(
        &self,
        _prompt: String,
    ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;

    /// Runs the tool with already-deserialized arguments.
    ///
    /// Returns the tool's output on success, or `Self::Error` on failure.
    fn call(
        &self,
        args: Self::Args,
    ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
}
144
/// Extension of [`Tool`] for tools that also expose embedding documents and a
/// serializable context.
///
/// Types implementing this trait (and `'static`) automatically implement
/// [`ToolEmbeddingDyn`].
pub trait ToolEmbedding: Tool {
    /// Error type returned by [`ToolEmbedding::init`].
    type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;

    /// Serializable context of the tool; [`ToolEmbeddingDyn::context`]
    /// converts it to a `serde_json::Value`.
    /// NOTE(review): presumably the data persisted alongside the embeddings
    /// and fed back into `init` — confirm against the embeddings module.
    type Context: for<'a> Deserialize<'a> + Serialize;

    /// State passed to [`ToolEmbedding::init`] together with the context.
    /// NOTE(review): presumably runtime-only data that cannot be serialized
    /// (clients, handles) — confirm.
    type State: WasmCompatSend;

    /// Returns the documents associated with this tool for embedding.
    fn embedding_docs(&self) -> Vec<String>;

    /// Returns the tool's context value.
    fn context(&self) -> Self::Context;

    /// Constructs the tool from a state and a context value.
    fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
}
172
/// String-in/string-out counterpart of [`Tool`] that returns boxed futures.
///
/// This file stores it behind `Arc<dyn ToolDyn>`; every [`Tool`] implements
/// it automatically via a blanket impl.
pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
    /// Returns the tool's name.
    fn name(&self) -> String;

    /// Returns a boxed future resolving to the tool's [`ToolDefinition`].
    fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;

    /// Calls the tool with its arguments encoded as a string (JSON in the
    /// blanket impl) and resolves to the output encoded as a string.
    fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
}
184
185fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
186 match serde_json::to_value(output)? {
187 serde_json::Value::String(text) => Ok(text),
188 value => Ok(value.to_string()),
189 }
190}
191
192impl<T: Tool> ToolDyn for T {
193 fn name(&self) -> String {
194 self.name()
195 }
196
197 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
198 Box::pin(<Self as Tool>::definition(self, prompt))
199 }
200
201 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
202 Box::pin(async move {
203 let args = match serde_json::from_str(&args) {
210 Ok(args) => Ok(args),
211 Err(err) if args.trim() == "null" => serde_json::from_str("{}").map_err(|_| err),
212 Err(err) => Err(err),
213 };
214 match args {
215 Ok(args) => <Self as Tool>::call(self, args)
216 .await
217 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
218 .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
219 Err(e) => Err(ToolError::JsonError(e)),
220 }
221 })
222 }
223}
224
225#[cfg(feature = "rmcp")]
226#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
227pub mod rmcp;
228
/// Counterpart of [`ToolEmbedding`] with no associated types, stored in this
/// file behind `Arc<dyn ToolEmbeddingDyn>`.
pub trait ToolEmbeddingDyn: ToolDyn {
    /// Returns the tool's context as a JSON value, or the serialization error.
    fn context(&self) -> serde_json::Result<serde_json::Value>;

    /// Returns the documents associated with this tool for embedding.
    fn embedding_docs(&self) -> Vec<String>;
}
237
238impl<T> ToolEmbeddingDyn for T
239where
240 T: ToolEmbedding + 'static,
241{
242 fn context(&self) -> serde_json::Result<serde_json::Value> {
243 serde_json::to_value(self.context())
244 }
245
246 fn embedding_docs(&self) -> Vec<String> {
247 self.embedding_docs()
248 }
249}
250
/// A registered tool, stored behind an `Arc` so the enum is cheap to clone.
#[derive(Clone)]
pub(crate) enum ToolType {
    /// A tool that only implements [`ToolDyn`].
    Simple(Arc<dyn ToolDyn>),
    /// A tool that additionally implements [`ToolEmbeddingDyn`].
    Embedding(Arc<dyn ToolEmbeddingDyn>),
}
256
257impl ToolType {
258 pub fn name(&self) -> String {
259 match self {
260 ToolType::Simple(tool) => tool.name(),
261 ToolType::Embedding(tool) => tool.name(),
262 }
263 }
264
265 pub async fn definition(&self, prompt: String) -> ToolDefinition {
266 match self {
267 ToolType::Simple(tool) => tool.definition(prompt).await,
268 ToolType::Embedding(tool) => tool.definition(prompt).await,
269 }
270 }
271
272 pub async fn call(&self, args: String) -> Result<String, ToolError> {
273 match self {
274 ToolType::Simple(tool) => tool.call(args).await,
275 ToolType::Embedding(tool) => tool.call(args).await,
276 }
277 }
278}
279
/// Errors returned by [`ToolSet`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ToolSetError {
    /// The invoked tool returned a [`ToolError`].
    #[error("ToolCallError: {0}")]
    ToolCallError(#[from] ToolError),

    /// No tool is registered under the requested name (carried in the variant).
    #[error("ToolNotFoundError: {0}")]
    ToolNotFoundError(String),

    /// A JSON (de)serialization step failed, e.g. while rendering a tool
    /// definition in `ToolSet::documents`.
    #[error("JsonError: {0}")]
    JsonError(#[from] serde_json::Error),

    /// The tool call was interrupted. Nothing in this file constructs this
    /// variant.
    #[error("Tool call interrupted")]
    Interrupted,
}
298
/// A collection of tools that can be looked up and called by name.
#[derive(Default)]
pub struct ToolSet {
    /// Registered tools, keyed by the name each tool reported when inserted.
    pub(crate) tools: HashMap<String, ToolType>,
}
304
305impl ToolSet {
306 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
308 let mut toolset = Self::default();
309 tools.into_iter().for_each(|tool| {
310 toolset.add_tool(tool);
311 });
312 toolset
313 }
314
315 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
317 let mut toolset = Self::default();
318 tools.into_iter().for_each(|tool| {
319 toolset.add_tool_boxed(tool);
320 });
321 toolset
322 }
323
324 pub fn builder() -> ToolSetBuilder {
326 ToolSetBuilder::default()
327 }
328
329 pub fn contains(&self, toolname: &str) -> bool {
331 self.tools.contains_key(toolname)
332 }
333
334 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
336 self.tools
337 .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
338 }
339
340 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
342 self.tools
343 .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
344 }
345
346 pub fn delete_tool(&mut self, tool_name: &str) {
348 let _ = self.tools.remove(tool_name);
349 }
350
351 pub fn add_tools(&mut self, toolset: ToolSet) {
353 self.tools.extend(toolset.tools);
354 }
355
356 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
357 self.tools.get(toolname)
358 }
359
360 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
362 let mut defs = Vec::new();
363 for tool in self.tools.values() {
364 let def = tool.definition(String::new()).await;
365 defs.push(def);
366 }
367 Ok(defs)
368 }
369
370 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
372 if let Some(tool) = self.tools.get(toolname) {
373 tracing::debug!(target: "rig",
374 "Calling tool {toolname} with args:\n{}",
375 args
376 );
377 Ok(tool.call(args).await?)
378 } else {
379 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
380 }
381 }
382
383 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
385 let mut docs = Vec::new();
386 for tool in self.tools.values() {
387 match tool {
388 ToolType::Simple(tool) => {
389 docs.push(completion::Document {
390 id: tool.name(),
391 text: format!(
392 "\
393 Tool: {}\n\
394 Definition: \n\
395 {}\
396 ",
397 tool.name(),
398 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
399 ),
400 additional_props: HashMap::new(),
401 });
402 }
403 ToolType::Embedding(tool) => {
404 docs.push(completion::Document {
405 id: tool.name(),
406 text: format!(
407 "\
408 Tool: {}\n\
409 Definition: \n\
410 {}\
411 ",
412 tool.name(),
413 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
414 ),
415 additional_props: HashMap::new(),
416 });
417 }
418 }
419 }
420 Ok(docs)
421 }
422
423 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
427 self.tools
428 .values()
429 .filter_map(|tool_type| {
430 if let ToolType::Embedding(tool) = tool_type {
431 Some(ToolSchema::try_from(&**tool))
432 } else {
433 None
434 }
435 })
436 .collect::<Result<Vec<_>, _>>()
437 }
438}
439
/// Builder that accumulates tools and turns them into a [`ToolSet`].
#[derive(Default)]
pub struct ToolSetBuilder {
    /// Tools added so far, in insertion order.
    tools: Vec<ToolType>,
}
445
446impl ToolSetBuilder {
447 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
449 self.tools.push(ToolType::Simple(Arc::new(tool)));
450 self
451 }
452
453 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
455 self.tools.push(ToolType::Embedding(Arc::new(tool)));
456 self
457 }
458
459 pub fn build(self) -> ToolSet {
461 ToolSet {
462 tools: self
463 .tools
464 .into_iter()
465 .map(|tool| (tool.name(), tool))
466 .collect(),
467 }
468 }
469}
470
// Unit tests for `ToolSet` and the blanket `ToolDyn` implementation, driven by
// the mock tools from `crate::test_utils`.
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use crate::test_utils::{
        MockExampleTool, MockImageOutputTool, MockObjectOutputTool, MockStringOutputTool,
        mock_math_toolset,
    };
    use serde_json::json;

    use super::*;

    /// Shared fixture; the tests below expect it to contain two tools, one of
    /// them named "add".
    fn get_test_toolset() -> ToolSet {
        mock_math_toolset()
    }

    /// Every registered tool yields exactly one definition.
    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let tools = toolset.get_tool_definitions().await.unwrap();
        assert_eq!(tools.len(), 2);
    }

    /// Deleting a tool removes only that entry.
    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);
        toolset.delete_tool("add");
        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    /// A string output is returned as raw text (the newline is not escaped and
    /// no JSON quotes are added).
    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockStringOutputTool);

        let output = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert_eq!(output, "Hello\nWorld");
    }

    /// The output of the image mock can still be interpreted by
    /// `ToolResultContent::from_tool_output` as a single base64 PNG image.
    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockImageOutputTool);

        let output = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(output);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    /// Non-string outputs are returned as JSON text.
    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockObjectOutputTool);

        let output = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        assert!(output.starts_with('{'));
        assert_eq!(
            serde_json::from_str::<serde_json::Value>(&output).unwrap(),
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }

    /// `null` arguments are accepted by `MockExampleTool`.
    /// NOTE(review): per the test name and message its `Args` is presumably
    /// `()`, which decodes from `null` directly without the `{}` fallback —
    /// confirm in `test_utils`.
    #[tokio::test]
    async fn null_args_are_preserved_for_unit_args() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockExampleTool);

        let output = toolset
            .call("example_tool", "null".to_string())
            .await
            .expect("unit args should accept null without object fallback");

        assert_eq!(output, "Example answer");
    }

    /// When `null` cannot be decoded into the argument type, the blanket
    /// `ToolDyn::call` retries with `{}`; here that leaves the optional
    /// `label` field unset, so the tool answers "default".
    #[tokio::test]
    async fn null_args_are_normalized_to_empty_object() {
        use crate::test_utils::MockToolError;

        // Argument struct whose only field is optional, so `{}` is valid input.
        #[derive(serde::Deserialize, serde::Serialize)]
        struct NoRequiredArgs {
            label: Option<String>,
        }

        struct NoArgTool;

        impl Tool for NoArgTool {
            const NAME: &'static str = "no_arg_tool";
            type Error = MockToolError;
            type Args = NoRequiredArgs;
            type Output = String;

            async fn definition(&self, _prompt: String) -> ToolDefinition {
                ToolDefinition {
                    name: Self::NAME.to_string(),
                    description: "Tool with no required arguments".to_string(),
                    parameters: json!({"type": "object", "properties": {}}),
                }
            }

            async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
                Ok(args.label.unwrap_or_else(|| "default".to_string()))
            }
        }

        let mut toolset = ToolSet::default();
        toolset.add_tool(NoArgTool);

        // Pass a bare `null` as the argument payload.
        let output = toolset
            .call("no_arg_tool", "null".to_string())
            .await
            .expect("null args should succeed after normalisation");

        assert_eq!(output, "default");
    }
}