// llm_chain/tools/tool.rs

1use super::description::ToolDescription;
2
3use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Serialize};
5
/// Marker trait for the error types produced by [`Tool`] implementations.
///
/// It declares no methods. Its only job is to tag a concrete error type as a tool error so
/// that it can satisfy the `ToolError` bound on `Tool::Error`.
/// NOTE(review): the original comment says this exists so concrete errors can have a derived
/// `From` conversion; that wrapper type is not visible here — confirm against its definition.
pub trait ToolError {
    // Intentionally empty: this is a pure marker trait.
}
8
9/// The `Tool` trait defines an interface for tools that can be added to a `ToolCollection`.
10///
11/// A `Tool` is a function that takes a YAML-formatted input and returns a YAML-formatted output.
12/// It has a description that contains metadata about the tool, such as its name and usage.
13#[async_trait]
14pub trait Tool {
15    type Input: DeserializeOwned + Send + Sync;
16    type Output: Serialize;
17    type Error: std::fmt::Debug + std::error::Error + ToolError + From<serde_yaml::Error>;
18
19    async fn invoke_typed(&self, input: &Self::Input) -> Result<Self::Output, Self::Error>;
20
21    /// Returns the `ToolDescription` containing metadata about the tool.
22    fn description(&self) -> ToolDescription;
23
24    /// Invokes the tool with the given YAML-formatted input.
25    ///
26    /// # Errors
27    ///
28    /// Returns an `ToolUseError` if the input is not in the expected format or if the tool
29    /// fails to produce a valid output.
30    async fn invoke(&self, input: serde_yaml::Value) -> Result<serde_yaml::Value, Self::Error> {
31        let input = serde_yaml::from_value(input)
32            .map_err(<serde_yaml::Error as Into<Self::Error>>::into)?;
33        let output = self.invoke_typed(&input).await?;
34        Ok(serde_yaml::to_value(output)?)
35    }
36
37    /// Checks whether the tool matches the given name.
38    ///
39    /// This function is used to find the appropriate tool in a `ToolCollection` based on its name.
40    fn matches(&self, name: &str) -> bool {
41        self.description().name == name
42    }
43}