ollama_rs/generation/tools/mod.rs

1#[cfg_attr(docsrs, doc(cfg(feature = "tool-implementations")))]
2#[cfg(feature = "tool-implementations")]
3pub mod implementations;
4
5use std::{future::Future, pin::Pin};
6
7use schemars::{generate::SchemaSettings, JsonSchema, Schema};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
/// Result type of the tool API: the error side is any boxed error that can be sent and
/// shared across threads.
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync + 'static>>;
12
13/// It's highly recommended that the `JsonSchema` has descriptions for all attributes.
14/// Descriptions can be defined with `#[schemars(description = "Hi I am an attribute")]` above each attribute
15// TODO enforce at compile-time
16pub trait Tool: Send + Sync {
17    type Params: Parameters;
18
19    fn name() -> &'static str;
20    fn description() -> &'static str;
21
22    /// Call the tool.
23    /// Note that returning an Err will cause it to be bubbled up. If you want the LLM to handle the error,
24    /// return that error as a string.
25    fn call(
26        &mut self,
27        parameters: Self::Params,
28    ) -> impl Future<Output = Result<String>> + Send + Sync;
29}
30
31pub trait Parameters: DeserializeOwned + JsonSchema {}
32
33impl<P: DeserializeOwned + JsonSchema> Parameters for P {}
34
35pub(crate) trait ToolHolder: Send + Sync {
36    fn call(
37        &mut self,
38        parameters: Value,
39    ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>>;
40}
41
42impl<T: Tool> ToolHolder for T {
43    fn call(
44        &mut self,
45        parameters: Value,
46    ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>> {
47        Box::pin(async move {
48            // Json returned from the model can sometimes be in different formats, see https://github.com/pepperoni21/ollama-rs/issues/210
49            // This is a work-around for this issue.
50            let param_value = match serde_json::from_value(parameters.clone()) {
51                // We first try with the ToolCallFunction format
52                Ok(ToolCallFunction { name: _, arguments }) => arguments,
53                Err(_err) => match serde_json::from_value::<ToolInfo>(parameters.clone()) {
54                    Ok(ti) => ti.function.parameters.to_value(),
55                    Err(_err) => parameters,
56                },
57            };
58
59            let param = serde_json::from_value(param_value)?;
60
61            T::call(self, param).await
62        })
63    }
64}
65
66#[derive(Clone, Debug, Serialize, Deserialize)]
67pub struct ToolInfo {
68    #[serde(rename = "type")]
69    pub tool_type: ToolType,
70    pub function: ToolFunctionInfo,
71}
72
73impl ToolInfo {
74    pub(crate) fn new<P: Parameters, T: Tool<Params = P>>() -> Self {
75        let mut settings = SchemaSettings::draft07();
76        settings.inline_subschemas = true;
77        let generator = settings.into_generator();
78
79        let parameters = generator.into_root_schema_for::<P>();
80
81        Self {
82            tool_type: ToolType::Function,
83            function: ToolFunctionInfo {
84                name: T::name().to_string(),
85                description: T::description().to_string(),
86                parameters,
87            },
88        }
89    }
90}
91
92#[derive(Clone, Debug, Serialize, Deserialize)]
93pub enum ToolType {
94    #[serde(rename_all(deserialize = "PascalCase"))]
95    Function,
96}
97
98#[derive(Clone, Debug, Serialize, Deserialize)]
99pub struct ToolFunctionInfo {
100    pub name: String,
101    pub description: String,
102    pub parameters: Schema,
103}
104
105#[derive(Clone, Debug, Serialize, Deserialize)]
106pub struct ToolCall {
107    pub function: ToolCallFunction,
108}
109
110#[derive(Clone, Debug, Serialize, Deserialize)]
111pub struct ToolCallFunction {
112    pub name: String,
113    // I don't love this (the Value)
114    // But fixing it would be a big effort
115    // FIXME
116    #[serde(alias = "parameters")]
117    pub arguments: Value,
118}