1use schemars::{generate::SchemaSettings, JsonSchema, SchemaGenerator};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use snafu::{ResultExt, Snafu};
5
/// A tool made available to the model.
///
/// The enum is `#[serde(untagged)]`: each variant is serialized as a JSON object
/// with a single key (the variant's field name) mapping to that tool's
/// configuration, e.g. `{"google_search": {}}`. On deserialization serde tries the
/// variants in declaration order and keeps the first one that succeeds, so the
/// order of the variants below is significant.
// NOTE(review): only `code_execution` is explicitly renamed (to `codeExecution`);
// the other keys use the snake_case Rust field names as-is. Confirm the API
// accepts both spellings and that camelCase keys never need to be deserialized.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Tool {
    /// Function calling, exposing the listed [`FunctionDeclaration`]s.
    Function {
        function_declarations: Vec<FunctionDeclaration>,
    },
    /// Google Search tool, configured by [`GoogleSearchConfig`].
    GoogleSearch {
        google_search: GoogleSearchConfig,
    },
    /// URL context tool, configured by [`URLContextConfig`].
    URLContext {
        url_context: URLContextConfig,
    },
    /// Google Maps tool, configured by [`GoogleMapsConfig`].
    GoogleMaps {
        google_maps: GoogleMapsConfig,
    },
    /// Code execution tool, configured by [`CodeExecutionConfig`].
    CodeExecution {
        #[serde(rename = "codeExecution")]
        code_execution: CodeExecutionConfig,
    },
    /// File search tool, configured by [`FileSearchConfig`].
    FileSearch {
        file_search: FileSearchConfig,
    },
}
39
/// Configuration for [`Tool::GoogleSearch`]; it currently has no options.
///
/// Declared as a braced (not unit) struct so it serializes as `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GoogleSearchConfig {}
43
/// Configuration for [`Tool::URLContext`]; it currently has no options.
///
/// Declared as a braced (not unit) struct so it serializes as `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct URLContextConfig {}
47
/// Configuration for [`Tool::GoogleMaps`]; fields are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct GoogleMapsConfig {
    /// Serialized as `enableWidget`; omitted from the JSON when `None`.
    // NOTE(review): presumably asks the API to return a Maps widget — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enable_widget: Option<bool>,
}
56
/// Configuration for [`Tool::CodeExecution`]; it currently has no options.
///
/// Declared as a braced (not unit) struct so it serializes as `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CodeExecutionConfig {}
60
/// A block of source code together with the language it is written in.
// NOTE(review): presumably code emitted by the model for the code-execution
// tool — confirm against the callers that produce/consume it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ExecutableCode {
    /// Language of `code` (serialized as e.g. `"PYTHON"`).
    pub language: CodeLanguage,
    /// The source text.
    pub code: String,
}
72
/// Programming language of an [`ExecutableCode`] block.
///
/// Serialized in SCREAMING_SNAKE_CASE, so `Python` <-> `"PYTHON"`. Any other
/// string fails to deserialize, because there is no catch-all variant.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CodeLanguage {
    /// Python source code.
    Python,
}
80
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
85#[serde(rename_all = "camelCase")]
86pub struct CodeExecutionResult {
87 pub outcome: CodeExecutionOutcome,
89 pub output: String,
91}
92
/// Outcome of a code execution.
///
/// Serialized in SCREAMING_SNAKE_CASE: `"OUTCOME_OK"`, `"OUTCOME_FAILED"`,
/// `"OUTCOME_DEADLINE_EXCEEDED"`. Any other string fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CodeExecutionOutcome {
    /// Execution completed successfully.
    OutcomeOk,
    /// Execution finished with an error.
    OutcomeFailed,
    /// Execution did not finish within its deadline.
    OutcomeDeadlineExceeded,
}
104
/// Configuration for [`Tool::FileSearch`]; fields are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FileSearchConfig {
    /// Names of the file-search stores to use (`fileSearchStoreNames`).
    pub file_search_store_names: Vec<String>,

    /// Optional metadata filter expression (`metadataFilter`); omitted when `None`.
    // NOTE(review): the filter syntax is not validated here — it is passed through verbatim.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata_filter: Option<String>,
}
116
117impl Tool {
118 pub fn new(function_declaration: FunctionDeclaration) -> Self {
120 Self::Function {
121 function_declarations: vec![function_declaration],
122 }
123 }
124
125 pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
127 Self::Function {
128 function_declarations,
129 }
130 }
131
132 pub fn google_search() -> Self {
134 Self::GoogleSearch {
135 google_search: GoogleSearchConfig {},
136 }
137 }
138
139 pub fn url_context() -> Self {
141 Self::URLContext {
142 url_context: URLContextConfig {},
143 }
144 }
145
146 pub fn google_maps(enable_widget: Option<bool>) -> Self {
148 Self::GoogleMaps {
149 google_maps: GoogleMapsConfig { enable_widget },
150 }
151 }
152
153 pub fn code_execution() -> Self {
158 Self::CodeExecution {
159 code_execution: CodeExecutionConfig {},
160 }
161 }
162
163 pub fn file_search(store_names: Vec<String>, metadata_filter: Option<String>) -> Self {
165 Self::FileSearch {
166 file_search: FileSearchConfig {
167 file_search_store_names: store_names,
168 metadata_filter,
169 },
170 }
171 }
172}
173
/// Execution behavior of a declared function.
///
/// Serialized in SCREAMING_SNAKE_CASE: `"BLOCKING"` / `"NON_BLOCKING"`.
/// `Blocking` is the [`Default`].
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Behavior {
    // NOTE(review): presumably the model waits for the function's result
    // before continuing — confirm against the API docs.
    #[default]
    Blocking,
    // NOTE(review): presumably the model may continue while the function runs — confirm.
    NonBlocking,
}
187
/// Declaration of a function the model may call.
///
/// Construct it with [`FunctionDeclaration::new`], then attach schemas with
/// [`FunctionDeclaration::with_parameters`] / [`FunctionDeclaration::with_response`].
/// Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionDeclaration {
    /// Function name.
    pub name: String,
    /// Human-readable description of the function.
    pub description: String,
    /// Optional execution behavior; see [`Behavior`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub behavior: Option<Behavior>,
    /// JSON schema of the parameters, normally set via `with_parameters`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) parameters: Option<Value>,
    /// JSON schema of the response, normally set via `with_response`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) response: Option<Value>,
}
207
208fn generate_parameters_schema<Parameters>() -> Value
210where
211 Parameters: JsonSchema + Serialize,
212{
213 let schema_generator = SchemaGenerator::new(SchemaSettings::openapi3().with(|s| {
215 s.inline_subschemas = true;
216 s.meta_schema = None;
217 }));
218
219 let mut schema = schema_generator.into_root_schema_for::<Parameters>();
220
221 schema.remove("title");
223 schema.to_value()
224}
225
226impl FunctionDeclaration {
227 pub fn new(
229 name: impl Into<String>,
230 description: impl Into<String>,
231 behavior: Option<Behavior>,
232 ) -> Self {
233 Self {
234 name: name.into(),
235 description: description.into(),
236 behavior,
237 ..Default::default()
238 }
239 }
240
241 pub fn with_parameters<Parameters>(mut self) -> Self
243 where
244 Parameters: JsonSchema + Serialize,
245 {
246 self.parameters = Some(generate_parameters_schema::<Parameters>());
247 self
248 }
249
250 pub fn with_response<Response>(mut self) -> Self
252 where
253 Response: JsonSchema + Serialize,
254 {
255 self.response = Some(generate_parameters_schema::<Response>());
256 self
257 }
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
262pub struct FunctionCall {
263 pub name: String,
265 pub args: serde_json::Value,
267 #[serde(skip_serializing_if = "Option::is_none")]
269 pub thought_signature: Option<String>,
270}
271
/// Errors returned when reading an argument out of a [`FunctionCall`]
/// (see [`FunctionCall::get`]).
#[derive(Debug, Snafu)]
pub enum FunctionCallError {
    /// The entry for `key` exists but could not be deserialized into the
    /// requested type.
    #[snafu(display("failed to deserialize parameter '{key}'"))]
    Deserialization {
        source: serde_json::Error,
        key: String,
    },

    /// The arguments are a JSON object with no entry for `key`. `args` carries
    /// the full argument value for diagnostics.
    #[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
    MissingParameter {
        key: String,
        args: serde_json::Value,
    },

    /// The arguments are not a JSON object; `actual` holds their JSON text.
    #[snafu(display("arguments should be an object; actual: {actual}"))]
    ArgumentTypeMismatch { actual: String },
}
289
290impl FunctionCall {
291 pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
293 Self {
294 name: name.into(),
295 args,
296 thought_signature: None,
297 }
298 }
299
300 pub fn with_thought_signature(
302 name: impl Into<String>,
303 args: serde_json::Value,
304 thought_signature: impl Into<String>,
305 ) -> Self {
306 Self {
307 name: name.into(),
308 args,
309 thought_signature: Some(thought_signature.into()),
310 }
311 }
312
313 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
315 match &self.args {
316 serde_json::Value::Object(obj) => {
317 if let Some(value) = obj.get(key) {
318 serde_json::from_value(value.clone()).with_context(|_| DeserializationSnafu {
319 key: key.to_string(),
320 })
321 } else {
322 Err(MissingParameterSnafu {
323 key: key.to_string(),
324 args: self.args.clone(),
325 }
326 .build())
327 }
328 }
329 _ => Err(ArgumentTypeMismatchSnafu {
330 actual: self.args.to_string(),
331 }
332 .build()),
333 }
334 }
335}
336
/// A function's response payload, identified by the function's `name`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionResponse {
    /// Name of the function this response belongs to.
    pub name: String,
    /// Response payload as JSON; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<serde_json::Value>,
}
347
348impl FunctionResponse {
349 pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
351 Self {
352 name: name.into(),
353 response: Some(response),
354 }
355 }
356
357 pub fn from_schema<Response>(
359 name: impl Into<String>,
360 response: Response,
361 ) -> Result<Self, serde_json::Error>
362 where
363 Response: JsonSchema + Serialize,
364 {
365 let json = serde_json::to_value(&response)?;
366 Ok(Self {
367 name: name.into(),
368 response: Some(json),
369 })
370 }
371
372 pub fn from_str(
374 name: impl Into<String>,
375 response: impl Into<String>,
376 ) -> Result<Self, serde_json::Error> {
377 let json = serde_json::from_str(&response.into())?;
378 Ok(Self {
379 name: name.into(),
380 response: Some(json),
381 })
382 }
383}
384
/// Request-level configuration for tool use. Both sections are optional and
/// omitted from the serialized JSON when `None`.
// NOTE(review): there is no `rename_all` here, so keys serialize as snake_case
// (`function_calling_config`, `retrieval_config`), unlike the camelCase config
// structs in this file — confirm that this is intended.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ToolConfig {
    /// Function-calling settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_calling_config: Option<FunctionCallingConfig>,
    /// Retrieval settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retrieval_config: Option<RetrievalConfig>,
}
395
/// Function-calling settings; serialized as `{"mode": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCallingConfig {
    /// Selected function-calling mode; see [`FunctionCallingMode`].
    pub mode: FunctionCallingMode,
}
402
/// Function-calling mode, serialized in SCREAMING_SNAKE_CASE
/// (`"AUTO"`, `"ANY"`, `"NONE"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
    // NOTE(review): presumably the model decides whether to call a function — confirm.
    Auto,
    // NOTE(review): presumably the model is required to call some function — confirm.
    Any,
    // NOTE(review): presumably function calling is disabled — confirm.
    None,
}
414
/// Retrieval settings; fields are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalConfig {
    /// Coordinate (`latLng`); omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lat_lng: Option<LatLng>,
}
423
/// A latitude/longitude pair.
// NOTE(review): no range validation is performed; units are presumably
// degrees — confirm against the API docs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LatLng {
    /// Latitude component.
    pub latitude: f64,
    /// Longitude component.
    pub longitude: f64,
}
432
433impl LatLng {
434 pub fn new(latitude: f64, longitude: f64) -> Self {
436 Self {
437 latitude,
438 longitude,
439 }
440 }
441}