1use schemars::{JsonSchema, SchemaGenerator, generate::SchemaSettings};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use snafu::{ResultExt, Snafu};
5
/// A tool offered to the model: either client-side function declarations or one of the
/// built-in tools, each identified by its single field name in the JSON.
///
/// The enum is `#[serde(untagged)]`: serialization emits only the variant's field
/// (e.g. `{"google_search": {}}`), and deserialization tries the variants in declaration
/// order, taking the first one that deserializes successfully. Variant order therefore
/// matters and should not be changed casually.
///
/// Every variant except `Function` is reported as server-side by `Tool::is_server_side`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Tool {
    /// Functions declared by the client; see [`FunctionDeclaration`].
    Function {
        function_declarations: Vec<FunctionDeclaration>,
    },
    /// Google Search tool; its configuration currently has no fields.
    GoogleSearch {
        google_search: GoogleSearchConfig,
    },
    /// Google Maps tool; configuration is passed through as raw JSON.
    GoogleMaps {
        google_maps: Value,
    },
    /// Code-execution tool; configuration is passed through as raw JSON.
    CodeExecution {
        code_execution: Value,
    },
    /// URL-context tool; its configuration currently has no fields.
    URLContext {
        url_context: URLContextConfig,
    },
    /// File-search tool; configuration is passed through as raw JSON.
    FileSearch {
        file_search: Value,
    },
    /// Computer-use tool; configuration is passed through as raw JSON.
    ComputerUse {
        computer_use: Value,
    },
    /// MCP server tool; configuration is passed through as raw JSON.
    McpServer {
        // Explicitly pins the JSON key to `mcp_server`.
        #[serde(rename = "mcp_server")]
        mcp_server: Value,
    },
}
50
/// Configuration for [`Tool::GoogleSearch`]; currently has no fields.
// Kept as a braced (not unit) struct so serde serializes it as `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GoogleSearchConfig {}

/// Configuration for [`Tool::URLContext`]; currently has no fields.
// Kept as a braced (not unit) struct so serde serializes it as `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct URLContextConfig {}
58
59impl Tool {
60 pub fn new(function_declaration: FunctionDeclaration) -> Self {
62 Self::Function { function_declarations: vec![function_declaration] }
63 }
64
65 pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
67 Self::Function { function_declarations }
68 }
69
70 pub fn google_search() -> Self {
72 Self::GoogleSearch { google_search: GoogleSearchConfig {} }
73 }
74
75 pub fn url_context() -> Self {
77 Self::URLContext { url_context: URLContextConfig {} }
78 }
79
80 pub fn google_maps(config: Value) -> Self {
82 Self::GoogleMaps { google_maps: config }
83 }
84
85 pub fn code_execution() -> Self {
87 Self::CodeExecution { code_execution: Value::Object(Default::default()) }
88 }
89
90 pub fn file_search(config: Value) -> Self {
92 Self::FileSearch { file_search: config }
93 }
94
95 pub fn computer_use(config: Value) -> Self {
97 Self::ComputerUse { computer_use: config }
98 }
99
100 pub fn mcp_server(config: Value) -> Self {
102 Self::McpServer { mcp_server: config }
103 }
104
105 pub fn is_server_side(&self) -> bool {
112 !matches!(self, Self::Function { .. })
113 }
114}
115
116#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
118#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
119pub enum Behavior {
120 #[default]
123 Blocking,
124 NonBlocking,
128}
129
130#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
132pub struct FunctionDeclaration {
133 pub name: String,
135 pub description: String,
137 #[serde(skip_serializing_if = "Option::is_none")]
139 pub behavior: Option<Behavior>,
140 #[serde(skip_serializing_if = "Option::is_none")]
142 pub(crate) parameters: Option<Value>,
143 #[serde(skip_serializing_if = "Option::is_none")]
147 pub(crate) response: Option<Value>,
148}
149
150fn generate_parameters_schema<Parameters>() -> Value
152where
153 Parameters: JsonSchema + Serialize,
154{
155 let schema_generator = SchemaGenerator::new(SchemaSettings::openapi3().with(|s| {
157 s.inline_subschemas = true;
158 s.meta_schema = None;
159 }));
160
161 let mut schema = schema_generator.into_root_schema_for::<Parameters>();
162
163 schema.remove("title");
165 schema.to_value()
166}
167
168impl FunctionDeclaration {
169 pub fn new(
171 name: impl Into<String>,
172 description: impl Into<String>,
173 behavior: Option<Behavior>,
174 ) -> Self {
175 Self { name: name.into(), description: description.into(), behavior, ..Default::default() }
176 }
177
178 pub fn with_parameters<Parameters>(mut self) -> Self
180 where
181 Parameters: JsonSchema + Serialize,
182 {
183 self.parameters = Some(generate_parameters_schema::<Parameters>());
184 self
185 }
186
187 pub fn with_response<Response>(mut self) -> Self
189 where
190 Response: JsonSchema + Serialize,
191 {
192 self.response = Some(generate_parameters_schema::<Response>());
193 self
194 }
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
199pub struct FunctionCall {
200 pub name: String,
202 pub args: serde_json::Value,
204 #[serde(skip_serializing_if = "Option::is_none", default)]
209 pub id: Option<String>,
210 #[serde(
216 skip_serializing_if = "Option::is_none",
217 default,
218 rename = "thoughtSignature",
219 alias = "thought_signature"
220 )]
221 pub thought_signature: Option<String>,
222}
223
224#[derive(Debug, Snafu)]
225pub enum FunctionCallError {
226 #[snafu(display("failed to deserialize parameter '{key}'"))]
227 Deserialization { source: serde_json::Error, key: String },
228
229 #[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
230 MissingParameter { key: String, args: serde_json::Value },
231
232 #[snafu(display("arguments should be an object; actual: {actual}"))]
233 ArgumentTypeMismatch { actual: String },
234}
235
236impl FunctionCall {
237 pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
239 Self { name: name.into(), args, id: None, thought_signature: None }
240 }
241
242 pub fn with_thought_signature(
244 name: impl Into<String>,
245 args: serde_json::Value,
246 thought_signature: impl Into<String>,
247 ) -> Self {
248 Self {
249 name: name.into(),
250 args,
251 id: None,
252 thought_signature: Some(thought_signature.into()),
253 }
254 }
255
256 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
258 match &self.args {
259 serde_json::Value::Object(obj) => {
260 if let Some(value) = obj.get(key) {
261 serde_json::from_value(value.clone())
262 .with_context(|_| DeserializationSnafu { key: key.to_string() })
263 } else {
264 Err(MissingParameterSnafu { key: key.to_string(), args: self.args.clone() }
265 .build())
266 }
267 }
268 _ => Err(ArgumentTypeMismatchSnafu { actual: self.args.to_string() }.build()),
269 }
270 }
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
275pub struct FunctionResponse {
276 pub name: String,
278 #[serde(skip_serializing_if = "Option::is_none")]
281 pub response: Option<serde_json::Value>,
282 #[serde(default, skip_serializing_if = "Vec::is_empty")]
286 pub parts: Vec<FunctionResponsePart>,
287}
288
/// One extra data part attached to a [`FunctionResponse`].
///
/// `#[serde(untagged)]`: each variant serializes as a single-key object (`inlineData` or
/// `fileData`), and deserialization tries the variants in declaration order.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum FunctionResponsePart {
    /// Inline binary data, serialized under the `inlineData` key.
    InlineData {
        #[serde(rename = "inlineData")]
        inline_data: crate::Blob,
    },
    /// Reference to file data, serialized under the `fileData` key.
    FileData {
        #[serde(rename = "fileData")]
        file_data: crate::FileDataRef,
    },
}
307
308impl FunctionResponse {
309 pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
311 let response = match response {
312 serde_json::Value::Object(_) => response,
313 other => serde_json::json!({ "result": other }),
314 };
315 Self { name: name.into(), response: Some(response), parts: Vec::new() }
316 }
317
318 pub fn with_inline_data(
320 name: impl Into<String>,
321 response: serde_json::Value,
322 inline_data: Vec<crate::Blob>,
323 ) -> Self {
324 let response = match response {
325 serde_json::Value::Object(_) => response,
326 other => serde_json::json!({ "result": other }),
327 };
328 let parts = inline_data
329 .into_iter()
330 .map(|blob| FunctionResponsePart::InlineData { inline_data: blob })
331 .collect();
332 Self { name: name.into(), response: Some(response), parts }
333 }
334
335 pub fn with_file_data(
337 name: impl Into<String>,
338 response: serde_json::Value,
339 file_data: Vec<crate::FileDataRef>,
340 ) -> Self {
341 let response = match response {
342 serde_json::Value::Object(_) => response,
343 other => serde_json::json!({ "result": other }),
344 };
345 let parts = file_data
346 .into_iter()
347 .map(|fdr| FunctionResponsePart::FileData { file_data: fdr })
348 .collect();
349 Self { name: name.into(), response: Some(response), parts }
350 }
351
352 pub fn inline_data_only(name: impl Into<String>, inline_data: Vec<crate::Blob>) -> Self {
354 let parts = inline_data
355 .into_iter()
356 .map(|blob| FunctionResponsePart::InlineData { inline_data: blob })
357 .collect();
358 Self { name: name.into(), response: None, parts }
359 }
360
361 pub fn from_schema<Response>(
363 name: impl Into<String>,
364 response: Response,
365 ) -> Result<Self, serde_json::Error>
366 where
367 Response: JsonSchema + Serialize,
368 {
369 let json = serde_json::to_value(&response)?;
370 Ok(Self::new(name, json))
371 }
372
373 pub fn from_str(
375 name: impl Into<String>,
376 response: impl Into<String>,
377 ) -> Result<Self, serde_json::Error> {
378 let json = serde_json::from_str(&response.into())?;
379 Ok(Self::new(name, json))
380 }
381}
382
383#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
385pub struct ToolConfig {
386 #[serde(skip_serializing_if = "Option::is_none")]
388 pub function_calling_config: Option<FunctionCallingConfig>,
389 #[serde(skip_serializing_if = "Option::is_none", rename = "includeServerSideToolInvocations")]
392 pub include_server_side_tool_invocations: Option<bool>,
393 #[serde(skip_serializing_if = "Option::is_none", rename = "retrievalConfig")]
395 pub retrieval_config: Option<Value>,
396}
397
398#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
400pub struct FunctionCallingConfig {
401 pub mode: FunctionCallingMode,
403 #[serde(skip_serializing_if = "Option::is_none")]
407 pub allowed_function_names: Option<Vec<String>>,
408}
409
410#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
412#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
413pub enum FunctionCallingMode {
414 Auto,
416 Any,
418 None,
420 Validated,
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_config_include_server_side_tool_invocations_serde_round_trip() {
        let original = ToolConfig {
            include_server_side_tool_invocations: Some(true),
            ..ToolConfig::default()
        };

        let encoded = serde_json::to_value(&original).expect("ToolConfig should serialize");
        assert_eq!(encoded["includeServerSideToolInvocations"], true);
        assert!(encoded.get("include_server_side_tool_invocations").is_none());

        let decoded: ToolConfig =
            serde_json::from_value(encoded).expect("ToolConfig should deserialize");
        assert_eq!(decoded, original);
    }

    #[test]
    fn tool_config_default_omits_server_side_flag() {
        let defaults = ToolConfig::default();
        assert!(defaults.include_server_side_tool_invocations.is_none());
        assert!(defaults.retrieval_config.is_none());

        let encoded = serde_json::to_value(&defaults).expect("ToolConfig should serialize");
        assert!(encoded.get("includeServerSideToolInvocations").is_none());
    }

    #[test]
    fn function_calling_mode_validated_serde_round_trip() {
        let original = FunctionCallingConfig {
            mode: FunctionCallingMode::Validated,
            allowed_function_names: None,
        };

        let encoded =
            serde_json::to_value(&original).expect("FunctionCallingConfig should serialize");
        assert_eq!(encoded["mode"], "VALIDATED");

        let decoded: FunctionCallingConfig =
            serde_json::from_value(encoded).expect("FunctionCallingConfig should deserialize");
        assert_eq!(decoded.mode, FunctionCallingMode::Validated);
    }

    #[test]
    fn function_calling_config_with_allowed_names() {
        let allowed: Vec<String> =
            ["get_weather", "search"].iter().map(|name| name.to_string()).collect();
        let original = FunctionCallingConfig {
            mode: FunctionCallingMode::Any,
            allowed_function_names: Some(allowed),
        };

        let encoded =
            serde_json::to_value(&original).expect("FunctionCallingConfig should serialize");
        assert_eq!(encoded["mode"], "ANY");
        assert_eq!(
            encoded["allowed_function_names"],
            serde_json::json!(["get_weather", "search"])
        );

        let decoded: FunctionCallingConfig =
            serde_json::from_value(encoded).expect("FunctionCallingConfig should deserialize");
        assert_eq!(decoded, original);
    }

    #[test]
    fn function_calling_config_omits_none_allowed_names() {
        let auto_only =
            FunctionCallingConfig { mode: FunctionCallingMode::Auto, allowed_function_names: None };

        let encoded =
            serde_json::to_value(&auto_only).expect("FunctionCallingConfig should serialize");
        assert!(encoded.get("allowed_function_names").is_none());
    }

    #[test]
    fn function_call_with_id_serde_round_trip() {
        let call = FunctionCall {
            id: Some("fc_001".to_string()),
            ..FunctionCall::new("get_weather", serde_json::json!({"city": "Tokyo"}))
        };

        let encoded = serde_json::to_value(&call).expect("FunctionCall should serialize");
        assert_eq!(encoded["id"], "fc_001");

        let decoded: FunctionCall =
            serde_json::from_value(encoded).expect("FunctionCall should deserialize");
        assert_eq!(decoded.id.as_deref(), Some("fc_001"));
    }

    #[test]
    fn function_call_without_id_omits_field() {
        let call = FunctionCall::new("get_weather", serde_json::json!({"city": "Tokyo"}));

        let encoded = serde_json::to_value(&call).expect("FunctionCall should serialize");
        assert!(encoded.get("id").is_none());
    }

    #[test]
    fn function_call_deserializes_without_id() {
        let payload = serde_json::json!({
            "name": "get_weather",
            "args": {"city": "Tokyo"}
        });

        let call: FunctionCall =
            serde_json::from_value(payload).expect("FunctionCall should deserialize");
        assert!(call.id.is_none());
        assert_eq!(call.name, "get_weather");
    }
}