1pub mod server;
13use std::collections::HashMap;
14use std::fmt;
15use std::sync::Arc;
16
17use futures::Future;
18use serde::{Deserialize, Serialize};
19
20use crate::{
21 completion::{self, ToolDefinition},
22 embeddings::{embed::EmbedError, tool::ToolSchema},
23 wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync},
24};
25
26#[derive(Debug, thiserror::Error)]
27pub enum ToolError {
28 #[cfg(not(target_family = "wasm"))]
29 ToolCallError(#[from] Box<dyn std::error::Error + Send + Sync>),
31
32 #[cfg(target_family = "wasm")]
33 ToolCallError(#[from] Box<dyn std::error::Error>),
35 JsonError(#[from] serde_json::Error),
37}
38
39impl fmt::Display for ToolError {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 match self {
42 ToolError::ToolCallError(e) => {
43 let error_str = e.to_string();
44 if error_str.starts_with("ToolCallError: ") {
47 write!(f, "{}", error_str)
48 } else {
49 write!(f, "ToolCallError: {}", error_str)
50 }
51 }
52 ToolError::JsonError(e) => write!(f, "JsonError: {e}"),
53 }
54 }
55}
56
57pub trait Tool: Sized + WasmCompatSend + WasmCompatSync {
113 const NAME: &'static str;
116
117 type Error: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
119 type Args: for<'a> Deserialize<'a> + WasmCompatSend + WasmCompatSync;
121 type Output: Serialize;
123
124 fn name(&self) -> String {
126 Self::NAME.to_string()
127 }
128
129 fn definition(
132 &self,
133 _prompt: String,
134 ) -> impl Future<Output = ToolDefinition> + WasmCompatSend + WasmCompatSync;
135
136 fn call(
140 &self,
141 args: Self::Args,
142 ) -> impl Future<Output = Result<Self::Output, Self::Error>> + WasmCompatSend;
143}
144
145pub trait ToolEmbedding: Tool {
147 type InitError: std::error::Error + WasmCompatSend + WasmCompatSync + 'static;
149
150 type Context: for<'a> Deserialize<'a> + Serialize;
155
156 type State: WasmCompatSend;
160
161 fn embedding_docs(&self) -> Vec<String>;
165
166 fn context(&self) -> Self::Context;
168
169 fn init(state: Self::State, context: Self::Context) -> Result<Self, Self::InitError>;
171}
172
173pub trait ToolDyn: WasmCompatSend + WasmCompatSync {
175 fn name(&self) -> String;
177
178 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition>;
180
181 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>>;
183}
184
185fn serialize_tool_output(output: impl Serialize) -> serde_json::Result<String> {
186 match serde_json::to_value(output)? {
187 serde_json::Value::String(text) => Ok(text),
188 value => Ok(value.to_string()),
189 }
190}
191
192impl<T: Tool> ToolDyn for T {
193 fn name(&self) -> String {
194 self.name()
195 }
196
197 fn definition<'a>(&'a self, prompt: String) -> WasmBoxedFuture<'a, ToolDefinition> {
198 Box::pin(<Self as Tool>::definition(self, prompt))
199 }
200
201 fn call<'a>(&'a self, args: String) -> WasmBoxedFuture<'a, Result<String, ToolError>> {
202 Box::pin(async move {
203 match serde_json::from_str(&args) {
204 Ok(args) => <Self as Tool>::call(self, args)
205 .await
206 .map_err(|e| ToolError::ToolCallError(Box::new(e)))
207 .and_then(|output| serialize_tool_output(output).map_err(ToolError::JsonError)),
208 Err(e) => Err(ToolError::JsonError(e)),
209 }
210 })
211 }
212}
213
214#[cfg(feature = "rmcp")]
215#[cfg_attr(docsrs, doc(cfg(feature = "rmcp")))]
216pub mod rmcp;
217
218pub trait ToolEmbeddingDyn: ToolDyn {
220 fn context(&self) -> serde_json::Result<serde_json::Value>;
222
223 fn embedding_docs(&self) -> Vec<String>;
225}
226
227impl<T> ToolEmbeddingDyn for T
228where
229 T: ToolEmbedding + 'static,
230{
231 fn context(&self) -> serde_json::Result<serde_json::Value> {
232 serde_json::to_value(self.context())
233 }
234
235 fn embedding_docs(&self) -> Vec<String> {
236 self.embedding_docs()
237 }
238}
239
240#[derive(Clone)]
241pub(crate) enum ToolType {
242 Simple(Arc<dyn ToolDyn>),
243 Embedding(Arc<dyn ToolEmbeddingDyn>),
244}
245
246impl ToolType {
247 pub fn name(&self) -> String {
248 match self {
249 ToolType::Simple(tool) => tool.name(),
250 ToolType::Embedding(tool) => tool.name(),
251 }
252 }
253
254 pub async fn definition(&self, prompt: String) -> ToolDefinition {
255 match self {
256 ToolType::Simple(tool) => tool.definition(prompt).await,
257 ToolType::Embedding(tool) => tool.definition(prompt).await,
258 }
259 }
260
261 pub async fn call(&self, args: String) -> Result<String, ToolError> {
262 match self {
263 ToolType::Simple(tool) => tool.call(args).await,
264 ToolType::Embedding(tool) => tool.call(args).await,
265 }
266 }
267}
268
269#[derive(Debug, thiserror::Error)]
270pub enum ToolSetError {
271 #[error("ToolCallError: {0}")]
273 ToolCallError(#[from] ToolError),
274
275 #[error("ToolNotFoundError: {0}")]
277 ToolNotFoundError(String),
278
279 #[error("JsonError: {0}")]
281 JsonError(#[from] serde_json::Error),
282
283 #[error("Tool call interrupted")]
285 Interrupted,
286}
287
288#[derive(Default)]
290pub struct ToolSet {
291 pub(crate) tools: HashMap<String, ToolType>,
292}
293
294impl ToolSet {
295 pub fn from_tools(tools: Vec<impl ToolDyn + 'static>) -> Self {
297 let mut toolset = Self::default();
298 tools.into_iter().for_each(|tool| {
299 toolset.add_tool(tool);
300 });
301 toolset
302 }
303
304 pub fn from_tools_boxed(tools: Vec<Box<dyn ToolDyn + 'static>>) -> Self {
306 let mut toolset = Self::default();
307 tools.into_iter().for_each(|tool| {
308 toolset.add_tool_boxed(tool);
309 });
310 toolset
311 }
312
313 pub fn builder() -> ToolSetBuilder {
315 ToolSetBuilder::default()
316 }
317
318 pub fn contains(&self, toolname: &str) -> bool {
320 self.tools.contains_key(toolname)
321 }
322
323 pub fn add_tool(&mut self, tool: impl ToolDyn + 'static) {
325 self.tools
326 .insert(tool.name(), ToolType::Simple(Arc::new(tool)));
327 }
328
329 pub fn add_tool_boxed(&mut self, tool: Box<dyn ToolDyn>) {
331 self.tools
332 .insert(tool.name(), ToolType::Simple(Arc::from(tool)));
333 }
334
335 pub fn delete_tool(&mut self, tool_name: &str) {
337 let _ = self.tools.remove(tool_name);
338 }
339
340 pub fn add_tools(&mut self, toolset: ToolSet) {
342 self.tools.extend(toolset.tools);
343 }
344
345 pub(crate) fn get(&self, toolname: &str) -> Option<&ToolType> {
346 self.tools.get(toolname)
347 }
348
349 pub async fn get_tool_definitions(&self) -> Result<Vec<ToolDefinition>, ToolSetError> {
351 let mut defs = Vec::new();
352 for tool in self.tools.values() {
353 let def = tool.definition(String::new()).await;
354 defs.push(def);
355 }
356 Ok(defs)
357 }
358
359 pub async fn call(&self, toolname: &str, args: String) -> Result<String, ToolSetError> {
361 if let Some(tool) = self.tools.get(toolname) {
362 tracing::debug!(target: "rig",
363 "Calling tool {toolname} with args:\n{}",
364 args
365 );
366 Ok(tool.call(args).await?)
367 } else {
368 Err(ToolSetError::ToolNotFoundError(toolname.to_string()))
369 }
370 }
371
372 pub async fn documents(&self) -> Result<Vec<completion::Document>, ToolSetError> {
374 let mut docs = Vec::new();
375 for tool in self.tools.values() {
376 match tool {
377 ToolType::Simple(tool) => {
378 docs.push(completion::Document {
379 id: tool.name(),
380 text: format!(
381 "\
382 Tool: {}\n\
383 Definition: \n\
384 {}\
385 ",
386 tool.name(),
387 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
388 ),
389 additional_props: HashMap::new(),
390 });
391 }
392 ToolType::Embedding(tool) => {
393 docs.push(completion::Document {
394 id: tool.name(),
395 text: format!(
396 "\
397 Tool: {}\n\
398 Definition: \n\
399 {}\
400 ",
401 tool.name(),
402 serde_json::to_string_pretty(&tool.definition("".to_string()).await)?
403 ),
404 additional_props: HashMap::new(),
405 });
406 }
407 }
408 }
409 Ok(docs)
410 }
411
412 pub fn schemas(&self) -> Result<Vec<ToolSchema>, EmbedError> {
416 self.tools
417 .values()
418 .filter_map(|tool_type| {
419 if let ToolType::Embedding(tool) = tool_type {
420 Some(ToolSchema::try_from(&**tool))
421 } else {
422 None
423 }
424 })
425 .collect::<Result<Vec<_>, _>>()
426 }
427}
428
429#[derive(Default)]
430pub struct ToolSetBuilder {
432 tools: Vec<ToolType>,
433}
434
435impl ToolSetBuilder {
436 pub fn static_tool(mut self, tool: impl ToolDyn + 'static) -> Self {
438 self.tools.push(ToolType::Simple(Arc::new(tool)));
439 self
440 }
441
442 pub fn dynamic_tool(mut self, tool: impl ToolEmbeddingDyn + 'static) -> Self {
444 self.tools.push(ToolType::Embedding(Arc::new(tool)));
445 self
446 }
447
448 pub fn build(self) -> ToolSet {
450 ToolSet {
451 tools: self
452 .tools
453 .into_iter()
454 .map(|tool| (tool.name(), tool))
455 .collect(),
456 }
457 }
458}
459
#[cfg(test)]
mod tests {
    use crate::message::{DocumentSourceKind, ToolResultContent};
    use crate::test_utils::{
        MockImageOutputTool, MockObjectOutputTool, MockStringOutputTool, mock_math_toolset,
    };
    use serde_json::json;

    use super::*;

    /// Shared fixture: the mock math toolset from the crate's test utilities.
    fn get_test_toolset() -> ToolSet {
        mock_math_toolset()
    }

    #[tokio::test]
    async fn test_get_tool_definitions() {
        let toolset = get_test_toolset();
        let definitions = toolset.get_tool_definitions().await.unwrap();
        // The fixture is expected to register exactly two tools.
        assert_eq!(definitions.len(), 2);
    }

    #[test]
    fn test_tool_deletion() {
        let mut toolset = get_test_toolset();
        assert_eq!(toolset.tools.len(), 2);

        toolset.delete_tool("add");

        assert!(!toolset.contains("add"));
        assert_eq!(toolset.tools.len(), 1);
    }

    #[tokio::test]
    async fn string_tool_outputs_are_preserved_verbatim() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockStringOutputTool);

        let raw = toolset
            .call("string_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        // A string output must come back unquoted and unescaped.
        assert_eq!(raw, "Hello\nWorld");
    }

    #[tokio::test]
    async fn structured_string_tool_outputs_remain_parseable() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockImageOutputTool);

        let raw = toolset
            .call("image_output", "{}".to_string())
            .await
            .expect("tool should succeed");
        let content = ToolResultContent::from_tool_output(raw);

        assert_eq!(content.len(), 1);
        match content.first() {
            ToolResultContent::Image(image) => {
                assert!(matches!(image.data, DocumentSourceKind::Base64(_)));
                assert_eq!(image.media_type, Some(crate::message::ImageMediaType::PNG));
            }
            other => panic!("expected image tool result content, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn object_tool_outputs_still_serialize_as_json() {
        let mut toolset = ToolSet::default();
        toolset.add_tool(MockObjectOutputTool);

        let raw = toolset
            .call("object_output", "{}".to_string())
            .await
            .expect("tool should succeed");

        // Non-string outputs are returned as JSON text.
        assert!(raw.starts_with('{'));
        let parsed = serde_json::from_str::<serde_json::Value>(&raw).unwrap();
        assert_eq!(
            parsed,
            json!({
                "status": "ok",
                "count": 42
            })
        );
    }
}