use schemars::{JsonSchema, SchemaGenerator, generate::SchemaSettings};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use snafu::{ResultExt, Snafu};
/// A tool entry that can be attached to a request.
///
/// The enum is `#[serde(untagged)]`: each variant is written as a JSON object
/// whose only key is the name of the variant's single field (for example
/// `{"google_search": {}}`), and deserialization picks the first variant, in
/// declaration order, whose field is present and parses.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Tool {
/// Caller-declared functions; the only variant for which
/// `Tool::is_server_side` returns `false`.
Function {
/// Declarations of the functions exposed through this tool.
function_declarations: Vec<FunctionDeclaration>,
},
/// Google Search tool, keyed as `google_search`.
GoogleSearch {
/// Search configuration (the config type currently has no fields).
google_search: GoogleSearchConfig,
},
/// Google Maps tool, keyed as `google_maps`.
GoogleMaps {
/// Free-form JSON configuration, passed through unchanged.
google_maps: Value,
},
/// Code execution tool, keyed as `code_execution`.
CodeExecution {
/// Free-form JSON configuration, passed through unchanged.
code_execution: Value,
},
/// URL context tool, keyed as `url_context`.
URLContext {
/// URL context configuration (the config type currently has no fields).
url_context: URLContextConfig,
},
/// File search tool, keyed as `file_search`.
FileSearch {
/// Free-form JSON configuration, passed through unchanged.
file_search: Value,
},
/// Computer use tool, keyed as `computer_use`.
ComputerUse {
/// Free-form JSON configuration, passed through unchanged.
computer_use: Value,
},
/// MCP server tool, keyed as `mcp_server`.
McpServer {
/// Free-form JSON configuration, passed through unchanged.
// NOTE(review): the rename below equals the field name, so it does not
// change the serialized key.
#[serde(rename = "mcp_server")]
mcp_server: Value,
},
}
/// Configuration payload for [`Tool::GoogleSearch`].
///
/// The struct currently declares no fields, so it serializes as an empty JSON
/// object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GoogleSearchConfig {}
/// Configuration payload for [`Tool::URLContext`].
///
/// The struct currently declares no fields, so it serializes as an empty JSON
/// object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct URLContextConfig {}
impl Tool {
pub fn new(function_declaration: FunctionDeclaration) -> Self {
Self::Function { function_declarations: vec![function_declaration] }
}
pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
Self::Function { function_declarations }
}
pub fn google_search() -> Self {
Self::GoogleSearch { google_search: GoogleSearchConfig {} }
}
pub fn url_context() -> Self {
Self::URLContext { url_context: URLContextConfig {} }
}
pub fn google_maps(config: Value) -> Self {
Self::GoogleMaps { google_maps: config }
}
pub fn code_execution() -> Self {
Self::CodeExecution { code_execution: Value::Object(Default::default()) }
}
pub fn file_search(config: Value) -> Self {
Self::FileSearch { file_search: config }
}
pub fn computer_use(config: Value) -> Self {
Self::ComputerUse { computer_use: config }
}
pub fn mcp_server(config: Value) -> Self {
Self::McpServer { mcp_server: config }
}
pub fn is_server_side(&self) -> bool {
!matches!(self, Self::Function { .. })
}
}
/// Optional behavior attached to a [`FunctionDeclaration`].
///
/// Variants are serialized in SCREAMING_SNAKE_CASE, i.e. `"BLOCKING"` and
/// `"NON_BLOCKING"`.
// NOTE(review): presumably selects whether the caller waits for the function
// result before continuing — confirm against the API documentation.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Behavior {
/// The default variant.
#[default]
Blocking,
/// The non-default alternative.
NonBlocking,
}
/// Declaration of a single callable function: its name, description and
/// optional behavior, parameter schema and response schema.
///
/// Every `Option` field is left out of the serialized output when it is
/// `None`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionDeclaration {
/// Name of the function.
pub name: String,
/// Description of the function.
pub description: String,
/// Optional [`Behavior`]; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub behavior: Option<Behavior>,
/// JSON schema of the parameters; filled in by
/// `FunctionDeclaration::with_parameters`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) parameters: Option<Value>,
/// JSON schema of the response; filled in by
/// `FunctionDeclaration::with_response`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) response: Option<Value>,
}
/// Produces the JSON schema of `Parameters` as a [`Value`].
///
/// The schema is generated with schemars' OpenAPI 3 settings, with subschemas
/// inlined and no meta-schema recorded, and the root `title` entry is removed
/// before the schema is converted to a JSON value.
fn generate_parameters_schema<Parameters>() -> Value
where
    Parameters: JsonSchema + Serialize,
{
    let mut settings = SchemaSettings::openapi3();
    settings.inline_subschemas = true;
    settings.meta_schema = None;

    let mut root = SchemaGenerator::new(settings).into_root_schema_for::<Parameters>();
    root.remove("title");
    root.to_value()
}
impl FunctionDeclaration {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
behavior: Option<Behavior>,
) -> Self {
Self { name: name.into(), description: description.into(), behavior, ..Default::default() }
}
pub fn with_parameters<Parameters>(mut self) -> Self
where
Parameters: JsonSchema + Serialize,
{
self.parameters = Some(generate_parameters_schema::<Parameters>());
self
}
pub fn with_response<Response>(mut self) -> Self
where
Response: JsonSchema + Serialize,
{
self.response = Some(generate_parameters_schema::<Response>());
self
}
}
/// A function call: the function name, its JSON arguments and optional
/// identifiers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
/// Name of the function being called.
pub name: String,
/// Arguments of the call; `FunctionCall::get` expects this to be a JSON
/// object.
pub args: serde_json::Value,
/// Optional call identifier; omitted from the output when `None` and
/// defaulted to `None` when missing from the input.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub id: Option<String>,
/// Optional thought signature. Written as `thoughtSignature`; on input
/// both `thoughtSignature` and `thought_signature` are accepted.
#[serde(
skip_serializing_if = "Option::is_none",
default,
rename = "thoughtSignature",
alias = "thought_signature"
)]
pub thought_signature: Option<String>,
}
/// Errors returned by `FunctionCall::get` when reading a typed argument.
#[derive(Debug, Snafu)]
pub enum FunctionCallError {
/// The value stored under `key` could not be deserialized into the
/// requested type.
#[snafu(display("failed to deserialize parameter '{key}'"))]
Deserialization { source: serde_json::Error, key: String },
/// The arguments object has no entry named `key`; `args` is a copy of the
/// full arguments value.
#[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
MissingParameter { key: String, args: serde_json::Value },
/// The arguments value is not a JSON object; `actual` is its JSON text.
#[snafu(display("arguments should be an object; actual: {actual}"))]
ArgumentTypeMismatch { actual: String },
}
impl FunctionCall {
    /// Creates a call with the given name and arguments, with no id and no
    /// thought signature.
    pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
        Self { name: name.into(), args, id: None, thought_signature: None }
    }

    /// Creates a call with the given name, arguments and thought signature,
    /// and no id.
    pub fn with_thought_signature(
        name: impl Into<String>,
        args: serde_json::Value,
        thought_signature: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            args,
            id: None,
            thought_signature: Some(thought_signature.into()),
        }
    }

    /// Reads the argument stored under `key` and deserializes it into `T`.
    ///
    /// # Errors
    ///
    /// * [`FunctionCallError::ArgumentTypeMismatch`] if `args` is not a JSON object.
    /// * [`FunctionCallError::MissingParameter`] if the object has no entry for `key`.
    /// * [`FunctionCallError::Deserialization`] if the entry cannot be
    ///   deserialized into `T`.
    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
        let Some(obj) = self.args.as_object() else {
            return Err(ArgumentTypeMismatchSnafu { actual: self.args.to_string() }.build());
        };
        let Some(value) = obj.get(key) else {
            return Err(
                MissingParameterSnafu { key: key.to_string(), args: self.args.clone() }.build()
            );
        };
        // `&serde_json::Value` is itself a `Deserializer`, so the value can be
        // decoded in place instead of deep-cloning it for `from_value`.
        T::deserialize(value).with_context(|_| DeserializationSnafu { key: key.to_string() })
    }
}
/// The result of a function call: the function name, an optional JSON
/// response and optional extra parts.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionResponse {
/// Name of the function this response belongs to.
pub name: String,
/// JSON response payload; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<serde_json::Value>,
/// Additional parts; omitted from the output when empty and defaulted to
/// an empty list when missing from the input.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub parts: Vec<FunctionResponsePart>,
}
/// One extra part attached to a [`FunctionResponse`].
///
/// The enum is `#[serde(untagged)]`, so each variant is written as an object
/// with a single key: `inlineData` or `fileData`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum FunctionResponsePart {
/// A part holding a `crate::Blob`, keyed as `inlineData`.
InlineData {
#[serde(rename = "inlineData")]
inline_data: crate::Blob,
},
/// A part holding a `crate::FileDataRef`, keyed as `fileData`.
FileData {
#[serde(rename = "fileData")]
file_data: crate::FileDataRef,
},
}
impl FunctionResponse {
pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
let response = match response {
serde_json::Value::Object(_) => response,
other => serde_json::json!({ "result": other }),
};
Self { name: name.into(), response: Some(response), parts: Vec::new() }
}
pub fn with_inline_data(
name: impl Into<String>,
response: serde_json::Value,
inline_data: Vec<crate::Blob>,
) -> Self {
let response = match response {
serde_json::Value::Object(_) => response,
other => serde_json::json!({ "result": other }),
};
let parts = inline_data
.into_iter()
.map(|blob| FunctionResponsePart::InlineData { inline_data: blob })
.collect();
Self { name: name.into(), response: Some(response), parts }
}
pub fn with_file_data(
name: impl Into<String>,
response: serde_json::Value,
file_data: Vec<crate::FileDataRef>,
) -> Self {
let response = match response {
serde_json::Value::Object(_) => response,
other => serde_json::json!({ "result": other }),
};
let parts = file_data
.into_iter()
.map(|fdr| FunctionResponsePart::FileData { file_data: fdr })
.collect();
Self { name: name.into(), response: Some(response), parts }
}
pub fn inline_data_only(name: impl Into<String>, inline_data: Vec<crate::Blob>) -> Self {
let parts = inline_data
.into_iter()
.map(|blob| FunctionResponsePart::InlineData { inline_data: blob })
.collect();
Self { name: name.into(), response: None, parts }
}
pub fn from_schema<Response>(
name: impl Into<String>,
response: Response,
) -> Result<Self, serde_json::Error>
where
Response: JsonSchema + Serialize,
{
let json = serde_json::to_value(&response)?;
Ok(Self::new(name, json))
}
pub fn from_str(
name: impl Into<String>,
response: impl Into<String>,
) -> Result<Self, serde_json::Error> {
let json = serde_json::from_str(&response.into())?;
Ok(Self::new(name, json))
}
}
/// Request-level tool configuration.
///
/// All fields are optional and omitted from the serialized output when
/// `None`; `Default` yields a value with every field unset.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ToolConfig {
/// Function calling settings; serialized under the field's own name,
/// `function_calling_config`.
#[serde(skip_serializing_if = "Option::is_none")]
pub function_calling_config: Option<FunctionCallingConfig>,
/// Serialized as `includeServerSideToolInvocations`.
#[serde(skip_serializing_if = "Option::is_none", rename = "includeServerSideToolInvocations")]
pub include_server_side_tool_invocations: Option<bool>,
/// Free-form JSON, serialized as `retrievalConfig`.
#[serde(skip_serializing_if = "Option::is_none", rename = "retrievalConfig")]
pub retrieval_config: Option<Value>,
}
/// Function calling settings: a mode plus an optional list of function names.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCallingConfig {
/// The function calling mode.
pub mode: FunctionCallingMode,
/// Optional list of function names; omitted from the output when `None`.
// NOTE(review): presumably restricts which functions may be called —
// confirm against the API documentation.
#[serde(skip_serializing_if = "Option::is_none")]
pub allowed_function_names: Option<Vec<String>>,
}
/// Mode value used by [`FunctionCallingConfig`].
///
/// Variants are serialized in SCREAMING_SNAKE_CASE: `"AUTO"`, `"ANY"`,
/// `"NONE"` and `"VALIDATED"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
/// Serialized as `"AUTO"`.
Auto,
/// Serialized as `"ANY"`.
Any,
/// Serialized as `"NONE"`.
None,
/// Serialized as `"VALIDATED"`.
Validated,
}
#[cfg(test)]
mod tests {
    use super::*;

    // The server-side flag must use its camelCase key and survive a round trip.
    #[test]
    fn tool_config_include_server_side_tool_invocations_serde_round_trip() {
        let original = ToolConfig {
            function_calling_config: None,
            include_server_side_tool_invocations: Some(true),
            retrieval_config: None,
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["includeServerSideToolInvocations"], true);
        assert!(encoded.get("include_server_side_tool_invocations").is_none());

        let decoded = serde_json::from_value::<ToolConfig>(encoded).unwrap();
        assert_eq!(decoded, original);
    }

    // A default config has no optional fields set and does not emit the flag.
    #[test]
    fn tool_config_default_omits_server_side_flag() {
        let defaults = ToolConfig::default();
        assert_eq!(defaults.include_server_side_tool_invocations, None);
        assert_eq!(defaults.retrieval_config, None);

        let encoded = serde_json::to_value(&defaults).unwrap();
        assert!(encoded.get("includeServerSideToolInvocations").is_none());
    }

    // `Validated` is written as "VALIDATED" and read back as the same variant.
    #[test]
    fn function_calling_mode_validated_serde_round_trip() {
        let original = FunctionCallingConfig {
            mode: FunctionCallingMode::Validated,
            allowed_function_names: None,
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["mode"], "VALIDATED");

        let decoded = serde_json::from_value::<FunctionCallingConfig>(encoded).unwrap();
        assert_eq!(decoded.mode, FunctionCallingMode::Validated);
    }

    // The allowed-names list is emitted verbatim and survives a round trip.
    #[test]
    fn function_calling_config_with_allowed_names() {
        let names = vec!["get_weather".to_string(), "search".to_string()];
        let original = FunctionCallingConfig {
            mode: FunctionCallingMode::Any,
            allowed_function_names: Some(names),
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["mode"], "ANY");
        assert_eq!(encoded["allowed_function_names"], serde_json::json!(["get_weather", "search"]));

        let decoded = serde_json::from_value::<FunctionCallingConfig>(encoded).unwrap();
        assert_eq!(decoded, original);
    }

    // A missing allowed-names list is left out of the output entirely.
    #[test]
    fn function_calling_config_omits_none_allowed_names() {
        let original =
            FunctionCallingConfig { mode: FunctionCallingMode::Auto, allowed_function_names: None };

        let encoded = serde_json::to_value(&original).unwrap();
        assert!(encoded.get("allowed_function_names").is_none());
    }

    // An explicit id is emitted and read back.
    #[test]
    fn function_call_with_id_serde_round_trip() {
        let original = FunctionCall {
            name: "get_weather".to_string(),
            args: serde_json::json!({"city": "Tokyo"}),
            id: Some("fc_001".to_string()),
            thought_signature: None,
        };

        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["id"], "fc_001");

        let decoded = serde_json::from_value::<FunctionCall>(encoded).unwrap();
        assert_eq!(decoded.id.as_deref(), Some("fc_001"));
    }

    // A call built without an id does not emit the `id` key.
    #[test]
    fn function_call_without_id_omits_field() {
        let original = FunctionCall::new("get_weather", serde_json::json!({"city": "Tokyo"}));

        let encoded = serde_json::to_value(&original).unwrap();
        assert!(encoded.get("id").is_none());
    }

    // Input without an `id` key deserializes with `id` unset.
    #[test]
    fn function_call_deserializes_without_id() {
        let input = serde_json::json!({
            "name": "get_weather",
            "args": {"city": "Tokyo"}
        });

        let decoded = serde_json::from_value::<FunctionCall>(input).unwrap();
        assert_eq!(decoded.id, None);
        assert_eq!(decoded.name, "get_weather");
    }
}