// ollama_rs/generation/tools/mod.rs

1#[cfg_attr(docsrs, doc(cfg(feature = "tool-implementations")))]
2#[cfg(feature = "tool-implementations")]
3pub mod implementations;
4
5use std::{future::Future, pin::Pin};
6
7use schemars::{generate::SchemaSettings, JsonSchema, Schema};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
/// Result type shared by the tool traits in this module.
///
/// The error side is a boxed, thread-safe `std::error::Error`, so any
/// `Error + Send + Sync + 'static` type can be propagated into it with `?`.
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Sync + Send>>;
12
13/// It's highly recommended that the `JsonSchema` has descriptions for all attributes.
14/// Descriptions can be defined with `#[schemars(description = "Hi I am an attribute")]` above each attribute
15// TODO enforce at compile-time
16pub trait Tool: Send + Sync {
17    type Params: Parameters;
18
19    fn name() -> &'static str;
20    fn description() -> &'static str;
21
22    /// Call the tool.
23    /// Note that returning an Err will cause it to be bubbled up. If you want the LLM to handle the error,
24    /// return that error as a string.
25    fn call(
26        &mut self,
27        parameters: Self::Params,
28    ) -> impl Future<Output = Result<String>> + Send + Sync;
29}
30
31pub trait Parameters: DeserializeOwned + JsonSchema {}
32
33impl<P: DeserializeOwned + JsonSchema> Parameters for P {}
34
35pub(crate) trait ToolHolder: Send + Sync {
36    fn call(
37        &mut self,
38        parameters: Value,
39    ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>>;
40}
41
42impl<T: Tool> ToolHolder for T {
43    fn call(
44        &mut self,
45        parameters: Value,
46    ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>> {
47        Box::pin(async move {
48            // Json returned from the model can sometimes be in different formats, see https://github.com/pepperoni21/ollama-rs/issues/210
49            // This is a work-around for this issue.
50            let param_value = match serde_json::from_value(parameters.clone()) {
51                // We first try with the ToolCallFunction format
52                Ok(ToolCallFunction { name: _, arguments }) => arguments,
53                Err(_err) => {
54                    // If that fails we then try the ToolInfo format
55                    let ti: ToolInfo = serde_json::from_value(parameters)?;
56                    ti.function.parameters.to_value()
57                }
58            };
59
60            let param = serde_json::from_value(param_value)?;
61
62            T::call(self, param).await
63        })
64    }
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct ToolInfo {
69    #[serde(rename = "type")]
70    pub tool_type: ToolType,
71    pub function: ToolFunctionInfo,
72}
73
74impl ToolInfo {
75    pub(crate) fn new<P: Parameters, T: Tool<Params = P>>() -> Self {
76        let mut settings = SchemaSettings::draft07();
77        settings.inline_subschemas = true;
78        let generator = settings.into_generator();
79
80        let parameters = generator.into_root_schema_for::<P>();
81
82        Self {
83            tool_type: ToolType::Function,
84            function: ToolFunctionInfo {
85                name: T::name().to_string(),
86                description: T::description().to_string(),
87                parameters,
88            },
89        }
90    }
91}
92
93#[derive(Clone, Debug, Serialize, Deserialize)]
94pub enum ToolType {
95    #[serde(rename_all(deserialize = "PascalCase"))]
96    Function,
97}
98
99#[derive(Clone, Debug, Serialize, Deserialize)]
100pub struct ToolFunctionInfo {
101    pub name: String,
102    pub description: String,
103    pub parameters: Schema,
104}
105
106#[derive(Clone, Debug, Serialize, Deserialize)]
107pub struct ToolCall {
108    pub function: ToolCallFunction,
109}
110
111#[derive(Clone, Debug, Serialize, Deserialize)]
112pub struct ToolCallFunction {
113    pub name: String,
114    // I don't love this (the Value)
115    // But fixing it would be a big effort
116    // FIXME
117    #[serde(alias = "parameters")]
118    pub arguments: Value,
119}