ollama_rs/generation/tools/
mod.rs

1#[cfg_attr(docsrs, doc(cfg(feature = "tool-implementations")))]
2#[cfg(feature = "tool-implementations")]
3pub mod implementations;
4
5use std::{error::Error, future::Future};
6
7use schemars::{gen::SchemaSettings, schema::RootSchema, JsonSchema};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::error::ToolCallError;
12
13/// It's highly recommended that the `JsonSchema` has descriptions for all attributes.
14/// Descriptions can be defined with `#[schemars(descripion = "Hi I am an attribute")]` above each attribute
15// TODO enforce at compile-time
16pub trait Tool {
17    type Params: Parameters;
18
19    fn name() -> &'static str;
20    fn description() -> &'static str;
21
22    /// Call the tool.
23    /// Note that returning an Err will cause it to be bubbled up. If you want the LLM to handle the error,
24    /// return that error as a string.
25    fn call(
26        &mut self,
27        parameters: Self::Params,
28    ) -> impl Future<Output = Result<String, Box<dyn Error + Sync + Send>>>;
29}
30
31pub trait Parameters: DeserializeOwned + JsonSchema {}
32
33impl<P: DeserializeOwned + JsonSchema> Parameters for P {}
34
35pub trait ToolGroup {
36    fn tool_info(out: &mut Vec<ToolInfo>);
37
38    fn call(
39        &mut self,
40        tool_call: &ToolCallFunction,
41    ) -> impl Future<Output = Result<String, ToolCallError>>;
42}
43
44impl ToolGroup for () {
45    fn tool_info(_: &mut Vec<ToolInfo>) {}
46
47    async fn call(&mut self, _tool_call: &ToolCallFunction) -> Result<String, ToolCallError> {
48        Err(ToolCallError::UnknownToolName)
49    }
50}
51
52impl<T: Tool> ToolGroup for T {
53    fn tool_info(out: &mut Vec<ToolInfo>) {
54        out.push(ToolInfo::new::<_, T>())
55    }
56
57    async fn call(&mut self, tool_call: &ToolCallFunction) -> Result<String, ToolCallError> {
58        if tool_call.name == T::name() {
59            let p = serde_json::from_value(tool_call.arguments.clone())?;
60            return Ok(serde_json::to_string(&self.call(p).await?)?);
61        }
62
63        Err(ToolCallError::UnknownToolName)
64    }
65}
66
67impl<A: ToolGroup, B: ToolGroup> ToolGroup for (A, B) {
68    fn tool_info(out: &mut Vec<ToolInfo>) {
69        A::tool_info(out);
70        B::tool_info(out);
71    }
72
73    async fn call(&mut self, arguments: &ToolCallFunction) -> Result<String, ToolCallError> {
74        match self.0.call(arguments).await {
75            Ok(x) => Ok(x),
76            Err(ToolCallError::UnknownToolName) => self.1.call(arguments).await,
77            Err(e) => Err(e),
78        }
79    }
80}
81
82#[derive(Clone, Debug, Serialize)]
83pub struct ToolInfo {
84    #[serde(rename = "type")]
85    tool_type: ToolType,
86    function: ToolFunctionInfo,
87}
88
89impl ToolInfo {
90    fn new<P: Parameters, T: Tool<Params = P>>() -> Self {
91        let mut settings = SchemaSettings::draft07();
92        settings.inline_subschemas = true;
93        let generator = settings.into_generator();
94
95        let parameters = generator.into_root_schema_for::<P>();
96
97        Self {
98            tool_type: ToolType::Function,
99            function: ToolFunctionInfo {
100                name: T::name(),
101                description: T::description(),
102                parameters,
103            },
104        }
105    }
106}
107
108#[derive(Clone, Debug, Serialize)]
109enum ToolType {
110    Function,
111}
112
113#[derive(Clone, Debug, Serialize)]
114struct ToolFunctionInfo {
115    name: &'static str,
116    description: &'static str,
117    parameters: RootSchema,
118}
119
120#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct ToolCall {
122    pub function: ToolCallFunction,
123}
124
125#[derive(Clone, Debug, Serialize, Deserialize)]
126pub struct ToolCallFunction {
127    name: String,
128    // I don't love this (the Value)
129    // But fixing it would be a big effort
130    arguments: Value,
131}