// ollama_rs/generation/tools/mod.rs

/// Module `implementations`, compiled only when the `tool-implementations`
/// cargo feature is enabled. When documentation is built with `--cfg docsrs`,
/// rustdoc marks it as requiring that feature.
#[cfg(feature = "tool-implementations")]
#[cfg_attr(docsrs, doc(cfg(feature = "tool-implementations")))]
pub mod implementations;
4
5use std::{error::Error, future::Future};
6
7use schemars::{gen::SchemaSettings, schema::RootSchema, JsonSchema};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
11use crate::error::ToolCallError;
12
13pub trait Tool {
17 type Params: Parameters;
18
19 fn name() -> &'static str;
20 fn description() -> &'static str;
21
22 fn call(
26 &mut self,
27 parameters: Self::Params,
28 ) -> impl Future<Output = Result<String, Box<dyn Error + Sync + Send>>>;
29}
30
31pub trait Parameters: DeserializeOwned + JsonSchema {}
32
33impl<P: DeserializeOwned + JsonSchema> Parameters for P {}
34
35pub trait ToolGroup {
36 fn tool_info(out: &mut Vec<ToolInfo>);
37
38 fn call(
39 &mut self,
40 tool_call: &ToolCallFunction,
41 ) -> impl Future<Output = Result<String, ToolCallError>>;
42}
43
44impl ToolGroup for () {
45 fn tool_info(_: &mut Vec<ToolInfo>) {}
46
47 async fn call(&mut self, _tool_call: &ToolCallFunction) -> Result<String, ToolCallError> {
48 Err(ToolCallError::UnknownToolName)
49 }
50}
51
52impl<T: Tool> ToolGroup for T {
53 fn tool_info(out: &mut Vec<ToolInfo>) {
54 out.push(ToolInfo::new::<_, T>())
55 }
56
57 async fn call(&mut self, tool_call: &ToolCallFunction) -> Result<String, ToolCallError> {
58 if tool_call.name == T::name() {
59 let p = serde_json::from_value(tool_call.arguments.clone())?;
60 return Ok(serde_json::to_string(&self.call(p).await?)?);
61 }
62
63 Err(ToolCallError::UnknownToolName)
64 }
65}
66
67impl<A: ToolGroup, B: ToolGroup> ToolGroup for (A, B) {
68 fn tool_info(out: &mut Vec<ToolInfo>) {
69 A::tool_info(out);
70 B::tool_info(out);
71 }
72
73 async fn call(&mut self, arguments: &ToolCallFunction) -> Result<String, ToolCallError> {
74 match self.0.call(arguments).await {
75 Ok(x) => Ok(x),
76 Err(ToolCallError::UnknownToolName) => self.1.call(arguments).await,
77 Err(e) => Err(e),
78 }
79 }
80}
81
82#[derive(Clone, Debug, Serialize)]
83pub struct ToolInfo {
84 #[serde(rename = "type")]
85 tool_type: ToolType,
86 function: ToolFunctionInfo,
87}
88
89impl ToolInfo {
90 fn new<P: Parameters, T: Tool<Params = P>>() -> Self {
91 let mut settings = SchemaSettings::draft07();
92 settings.inline_subschemas = true;
93 let generator = settings.into_generator();
94
95 let parameters = generator.into_root_schema_for::<P>();
96
97 Self {
98 tool_type: ToolType::Function,
99 function: ToolFunctionInfo {
100 name: T::name(),
101 description: T::description(),
102 parameters,
103 },
104 }
105 }
106}
107
108#[derive(Clone, Debug, Serialize)]
109enum ToolType {
110 Function,
111}
112
113#[derive(Clone, Debug, Serialize)]
114struct ToolFunctionInfo {
115 name: &'static str,
116 description: &'static str,
117 parameters: RootSchema,
118}
119
120#[derive(Clone, Debug, Serialize, Deserialize)]
121pub struct ToolCall {
122 pub function: ToolCallFunction,
123}
124
125#[derive(Clone, Debug, Serialize, Deserialize)]
126pub struct ToolCallFunction {
127 name: String,
128 arguments: Value,
131}