// ollama_rs/generation/tools/mod.rs
#[cfg_attr(docsrs, doc(cfg(feature = "tool-implementations")))]
#[cfg(feature = "tool-implementations")]
pub mod implementations;
4
5use std::{future::Future, pin::Pin};
6
7use schemars::{generate::SchemaSettings, JsonSchema, Schema};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use serde_json::Value;
10
/// Result type used by the tool APIs in this module.
///
/// The error is a boxed, thread-safe (`Send + Sync`) [`std::error::Error`] trait object, so any
/// error type convertible into such a box can be propagated with `?`.
pub type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
12
13pub trait Tool: Send + Sync {
17 type Params: Parameters;
18
19 fn name() -> &'static str;
20 fn description() -> &'static str;
21
22 fn call(
26 &mut self,
27 parameters: Self::Params,
28 ) -> impl Future<Output = Result<String>> + Send + Sync;
29}
30
31pub trait Parameters: DeserializeOwned + JsonSchema {}
32
33impl<P: DeserializeOwned + JsonSchema> Parameters for P {}
34
35pub(crate) trait ToolHolder: Send + Sync {
36 fn call(
37 &mut self,
38 parameters: Value,
39 ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>>;
40}
41
42impl<T: Tool> ToolHolder for T {
43 fn call(
44 &mut self,
45 parameters: Value,
46 ) -> Pin<Box<dyn Future<Output = Result<String>> + '_ + Send + Sync>> {
47 Box::pin(async move {
48 let param_value = match serde_json::from_value(parameters.clone()) {
51 Ok(ToolCallFunction { name: _, arguments }) => arguments,
53 Err(_err) => match serde_json::from_value::<ToolInfo>(parameters.clone()) {
54 Ok(ti) => ti.function.parameters.to_value(),
55 Err(_err) => parameters,
56 },
57 };
58
59 let param = serde_json::from_value(param_value)?;
60
61 T::call(self, param).await
62 })
63 }
64}
65
66#[derive(Clone, Debug, Serialize, Deserialize)]
67pub struct ToolInfo {
68 #[serde(rename = "type")]
69 pub tool_type: ToolType,
70 pub function: ToolFunctionInfo,
71}
72
73impl ToolInfo {
74 pub(crate) fn new<P: Parameters, T: Tool<Params = P>>() -> Self {
75 let mut settings = SchemaSettings::draft07();
76 settings.inline_subschemas = true;
77 let generator = settings.into_generator();
78
79 let parameters = generator.into_root_schema_for::<P>();
80
81 Self {
82 tool_type: ToolType::Function,
83 function: ToolFunctionInfo {
84 name: T::name().to_string(),
85 description: T::description().to_string(),
86 parameters,
87 },
88 }
89 }
90}
91
92#[derive(Clone, Debug, Serialize, Deserialize)]
93pub enum ToolType {
94 #[serde(rename_all(deserialize = "PascalCase"))]
95 Function,
96}
97
98#[derive(Clone, Debug, Serialize, Deserialize)]
99pub struct ToolFunctionInfo {
100 pub name: String,
101 pub description: String,
102 pub parameters: Schema,
103}
104
105#[derive(Clone, Debug, Serialize, Deserialize)]
106pub struct ToolCall {
107 pub function: ToolCallFunction,
108}
109
110#[derive(Clone, Debug, Serialize, Deserialize)]
111pub struct ToolCallFunction {
112 pub name: String,
113 #[serde(alias = "parameters")]
117 pub arguments: Value,
118}