// Source path: ollama_rs/generation/tools/mod.rs
// Feature-gated submodule: compiled only when the `tool-implementations` feature is enabled.
// When rustdoc runs with `--cfg docsrs` (the docs.rs convention), `doc(cfg(...))` marks the
// module in the rendered docs as requiring that feature.
#[cfg(feature = "tool-implementations")]
#[cfg_attr(docsrs, doc(cfg(feature = "tool-implementations")))]
pub mod implementations;
4
5use std::{future::Future, pin::Pin};
6
7use schemars::{generate::SchemaSettings, JsonSchema, Schema};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
/// Result type used by the tools API.
///
/// Any error is boxed as a thread-safe (`Send + Sync`) [`std::error::Error`] trait object.
pub type Result<T> = ::std::result::Result<T, Box<dyn ::std::error::Error + Sync + Send>>;
12
13pub trait Tool: Send + Sync {
17 type Params: Parameters;
18
19 fn name() -> &'static str;
20 fn description() -> &'static str;
21
22 fn call(
26 &mut self,
27 parameters: Self::Params,
28 ) -> impl Future<Output = Result<String>> + Send + Sync;
29}
30
31pub trait Parameters: DeserializeOwned + JsonSchema {}
32
33impl<P: DeserializeOwned + JsonSchema> Parameters for P {}
34
35pub(crate) trait ToolHolder: Send + Sync {
36 fn call(
37 &mut self,
38 parameters: Value,
39 ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>>;
40}
41
42impl<T: Tool> ToolHolder for T {
43 fn call(
44 &mut self,
45 parameters: Value,
46 ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>> {
47 Box::pin(async move {
48 let param_value = match serde_json::from_value(parameters.clone()) {
51 Ok(ToolCallFunction { name: _, arguments }) => arguments,
53 Err(_err) => {
54 let ti: ToolInfo = serde_json::from_value(parameters)?;
56 ti.function.parameters.to_value()
57 }
58 };
59
60 let param = serde_json::from_value(param_value)?;
61
62 T::call(self, param).await
63 })
64 }
65}
66
67#[derive(Clone, Debug, Serialize, Deserialize)]
68pub struct ToolInfo {
69 #[serde(rename = "type")]
70 pub tool_type: ToolType,
71 pub function: ToolFunctionInfo,
72}
73
74impl ToolInfo {
75 pub(crate) fn new<P: Parameters, T: Tool<Params = P>>() -> Self {
76 let mut settings = SchemaSettings::draft07();
77 settings.inline_subschemas = true;
78 let generator = settings.into_generator();
79
80 let parameters = generator.into_root_schema_for::<P>();
81
82 Self {
83 tool_type: ToolType::Function,
84 function: ToolFunctionInfo {
85 name: T::name().to_string(),
86 description: T::description().to_string(),
87 parameters,
88 },
89 }
90 }
91}
92
93#[derive(Clone, Debug, Serialize, Deserialize)]
94pub enum ToolType {
95 #[serde(rename_all(deserialize = "PascalCase"))]
96 Function,
97}
98
99#[derive(Clone, Debug, Serialize, Deserialize)]
100pub struct ToolFunctionInfo {
101 pub name: String,
102 pub description: String,
103 pub parameters: Schema,
104}
105
106#[derive(Clone, Debug, Serialize, Deserialize)]
107pub struct ToolCall {
108 pub function: ToolCallFunction,
109}
110
111#[derive(Clone, Debug, Serialize, Deserialize)]
112pub struct ToolCallFunction {
113 pub name: String,
114 #[serde(alias = "parameters")]
118 pub arguments: Value,
119}