1use schemars::{generate::SchemaSettings, JsonSchema, SchemaGenerator};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use snafu::{ResultExt, Snafu};
5
6#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8#[serde(untagged)]
9pub enum Tool {
10 Function {
12 function_declarations: Vec<FunctionDeclaration>,
14 },
15 GoogleSearch {
17 google_search: GoogleSearchConfig,
19 },
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
24pub struct GoogleSearchConfig {}
25
26impl Tool {
27 pub fn new(function_declaration: FunctionDeclaration) -> Self {
29 Self::Function {
30 function_declarations: vec![function_declaration],
31 }
32 }
33
34 pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
36 Self::Function {
37 function_declarations,
38 }
39 }
40
41 pub fn google_search() -> Self {
43 Self::GoogleSearch {
44 google_search: GoogleSearchConfig {},
45 }
46 }
47}
48
49#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
51#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
52pub enum Behavior {
53 #[default]
56 Blocking,
57 NonBlocking,
61}
62
63#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
65pub struct FunctionDeclaration {
66 pub name: String,
68 pub description: String,
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub behavior: Option<Behavior>,
73 #[serde(skip_serializing_if = "Option::is_none")]
75 pub(crate) parameters: Option<Value>,
76 #[serde(skip_serializing_if = "Option::is_none")]
80 pub(crate) response: Option<Value>,
81}
82
83fn generate_parameters_schema<Parameters>() -> Value
85where
86 Parameters: JsonSchema + Serialize,
87{
88 let schema_generator = SchemaGenerator::new(SchemaSettings::openapi3().with(|s| {
90 s.inline_subschemas = true;
91 s.meta_schema = None;
92 }));
93
94 let mut schema = schema_generator.into_root_schema_for::<Parameters>();
95
96 schema.remove("title");
98 schema.to_value()
99}
100
101impl FunctionDeclaration {
102 pub fn new(
104 name: impl Into<String>,
105 description: impl Into<String>,
106 behavior: Option<Behavior>,
107 ) -> Self {
108 Self {
109 name: name.into(),
110 description: description.into(),
111 behavior,
112 ..Default::default()
113 }
114 }
115
116 pub fn with_parameters<Parameters>(mut self) -> Self
118 where
119 Parameters: JsonSchema + Serialize,
120 {
121 self.parameters = Some(generate_parameters_schema::<Parameters>());
122 self
123 }
124
125 pub fn with_response<Response>(mut self) -> Self
127 where
128 Response: JsonSchema + Serialize,
129 {
130 self.response = Some(generate_parameters_schema::<Response>());
131 self
132 }
133}
134
135#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
137pub struct FunctionCall {
138 pub name: String,
140 pub args: serde_json::Value,
142 #[serde(skip_serializing_if = "Option::is_none")]
144 pub thought_signature: Option<String>,
145}
146
147#[derive(Debug, Snafu)]
148pub enum FunctionCallError {
149 #[snafu(display("failed to deserialize parameter '{key}'"))]
150 Deserialization {
151 source: serde_json::Error,
152 key: String,
153 },
154
155 #[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
156 MissingParameter {
157 key: String,
158 args: serde_json::Value,
159 },
160
161 #[snafu(display("arguments should be an object; actual: {actual}"))]
162 ArgumentTypeMismatch { actual: String },
163}
164
165impl FunctionCall {
166 pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
168 Self {
169 name: name.into(),
170 args,
171 thought_signature: None,
172 }
173 }
174
175 pub fn with_thought_signature(
177 name: impl Into<String>,
178 args: serde_json::Value,
179 thought_signature: impl Into<String>,
180 ) -> Self {
181 Self {
182 name: name.into(),
183 args,
184 thought_signature: Some(thought_signature.into()),
185 }
186 }
187
188 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
190 match &self.args {
191 serde_json::Value::Object(obj) => {
192 if let Some(value) = obj.get(key) {
193 serde_json::from_value(value.clone()).with_context(|_| DeserializationSnafu {
194 key: key.to_string(),
195 })
196 } else {
197 Err(MissingParameterSnafu {
198 key: key.to_string(),
199 args: self.args.clone(),
200 }
201 .build())
202 }
203 }
204 _ => Err(ArgumentTypeMismatchSnafu {
205 actual: self.args.to_string(),
206 }
207 .build()),
208 }
209 }
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
214pub struct FunctionResponse {
215 pub name: String,
217 #[serde(skip_serializing_if = "Option::is_none")]
220 pub response: Option<serde_json::Value>,
221}
222
223impl FunctionResponse {
224 pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
226 Self {
227 name: name.into(),
228 response: Some(response),
229 }
230 }
231
232 pub fn from_schema<Response>(
234 name: impl Into<String>,
235 response: Response,
236 ) -> Result<Self, serde_json::Error>
237 where
238 Response: JsonSchema + Serialize,
239 {
240 let json = serde_json::to_value(&response)?;
241 Ok(Self {
242 name: name.into(),
243 response: Some(json),
244 })
245 }
246
247 pub fn from_str(
249 name: impl Into<String>,
250 response: impl Into<String>,
251 ) -> Result<Self, serde_json::Error> {
252 let json = serde_json::from_str(&response.into())?;
253 Ok(Self {
254 name: name.into(),
255 response: Some(json),
256 })
257 }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
262pub struct ToolConfig {
263 #[serde(skip_serializing_if = "Option::is_none")]
265 pub function_calling_config: Option<FunctionCallingConfig>,
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
270pub struct FunctionCallingConfig {
271 pub mode: FunctionCallingMode,
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
277#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
278pub enum FunctionCallingMode {
279 Auto,
281 Any,
283 None,
285}