use schemars::{generate::SchemaSettings, JsonSchema, SchemaGenerator};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use snafu::{ResultExt, Snafu};
/// A tool the model may use while generating a response.
///
/// The enum is `#[serde(untagged)]`: every variant serializes as a JSON object with
/// a single key (for example `{"google_search": {}}`), and deserialization picks the
/// first variant, in declaration order, whose required key is present. Variant
/// order therefore matters and must not be changed casually.
///
/// Serialized key spellings are unchanged: snake_case for every variant except
/// `codeExecution`. On deserialization each key additionally accepts its other
/// spelling (lowerCamelCase vs. snake_case) via `#[serde(alias)]`, so payloads
/// written in either convention parse. NOTE(review): this assumes the upstream API
/// follows the proto3 JSON mapping, whose parsers accept both spellings — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Tool {
    /// Client-declared functions; built by [`Tool::new`] / [`Tool::with_functions`].
    Function {
        #[serde(alias = "functionDeclarations")]
        function_declarations: Vec<FunctionDeclaration>,
    },
    /// Built-in Google Search tool; built by [`Tool::google_search`].
    GoogleSearch {
        #[serde(alias = "googleSearch")]
        google_search: GoogleSearchConfig,
    },
    /// Built-in URL context tool; built by [`Tool::url_context`].
    URLContext {
        #[serde(alias = "urlContext")]
        url_context: URLContextConfig,
    },
    /// Built-in Google Maps tool; built by [`Tool::google_maps`].
    GoogleMaps {
        #[serde(alias = "googleMaps")]
        google_maps: GoogleMapsConfig,
    },
    /// Built-in code-execution tool; built by [`Tool::code_execution`].
    CodeExecution {
        // Serialized in camelCase, unlike the other variants; the snake_case alias
        // keeps deserialization symmetric with them.
        #[serde(rename = "codeExecution", alias = "code_execution")]
        code_execution: CodeExecutionConfig,
    },
    /// Built-in file-search tool; built by [`Tool::file_search`].
    FileSearch {
        #[serde(alias = "fileSearch")]
        file_search: FileSearchConfig,
    },
}
/// Configuration payload for [`Tool::GoogleSearch`].
///
/// It is a braced struct with no fields (not a unit struct), so serde serializes it
/// as an empty JSON object `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GoogleSearchConfig {}
/// Configuration payload for [`Tool::URLContext`].
///
/// It is a braced struct with no fields (not a unit struct), so serde serializes it
/// as an empty JSON object `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct URLContextConfig {}
/// Configuration payload for [`Tool::GoogleMaps`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct GoogleMapsConfig {
/// Serialized as `enableWidget`; the key is omitted entirely when `None`.
/// NOTE(review): what the widget is/does is defined upstream — confirm there.
#[serde(skip_serializing_if = "Option::is_none")]
pub enable_widget: Option<bool>,
}
/// Configuration payload for [`Tool::CodeExecution`].
///
/// It is a braced struct with no fields (not a unit struct), so serde serializes it
/// as an empty JSON object `{}` rather than `null`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CodeExecutionConfig {}
/// A piece of source code tagged with its [`CodeLanguage`].
///
/// NOTE(review): presumably the code emitted by the model for the code-execution
/// tool — confirm against the code that parses/constructs this type.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ExecutableCode {
/// Language of `code`; serialized in SCREAMING_SNAKE_CASE (e.g. `"PYTHON"`).
pub language: CodeLanguage,
/// The source text itself.
pub code: String,
}
/// Programming language of an [`ExecutableCode`] block.
///
/// Serialized in SCREAMING_SNAKE_CASE, so `Python` maps to `"PYTHON"`. Only this one
/// value is modelled; any other string fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CodeLanguage {
Python,
}
/// Result reported for a code-execution step.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct CodeExecutionResult {
    /// How the run ended; one of the [`CodeExecutionOutcome`] variants.
    pub outcome: CodeExecutionOutcome,
    /// Text produced by the run.
    ///
    /// Defaults to an empty string when the key is absent: proto3 JSON serializers
    /// omit string fields that hold their default (empty) value, so a run that
    /// produced no output may arrive without `output` at all. Serialization is
    /// unaffected — the field is always written.
    #[serde(default)]
    pub output: String,
}
/// Outcome of a code-execution run, carried by [`CodeExecutionResult::outcome`].
///
/// Variants serialize as `"OUTCOME_OK"`, `"OUTCOME_FAILED"` and
/// `"OUTCOME_DEADLINE_EXCEEDED"`; any other string fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum CodeExecutionOutcome {
OutcomeOk,
OutcomeFailed,
OutcomeDeadlineExceeded,
}
/// Configuration payload for [`Tool::FileSearch`]; built by [`Tool::file_search`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct FileSearchConfig {
/// Store names, passed through as-is; serialized as `fileSearchStoreNames`.
pub file_search_store_names: Vec<String>,
/// Serialized as `metadataFilter`; omitted when `None`.
/// NOTE(review): the filter syntax is defined upstream, not here — confirm there.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata_filter: Option<String>,
}
impl Tool {
    /// Function-calling tool exposing exactly one declaration.
    pub fn new(function_declaration: FunctionDeclaration) -> Self {
        Self::with_functions(vec![function_declaration])
    }

    /// Function-calling tool exposing every declaration in `function_declarations`.
    pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
        Self::Function { function_declarations }
    }

    /// Google Search tool with its (empty) default configuration.
    pub fn google_search() -> Self {
        let google_search = GoogleSearchConfig {};
        Self::GoogleSearch { google_search }
    }

    /// URL context tool with its (empty) default configuration.
    pub fn url_context() -> Self {
        let url_context = URLContextConfig {};
        Self::URLContext { url_context }
    }

    /// Google Maps tool; `enable_widget` is forwarded verbatim into the config.
    pub fn google_maps(enable_widget: Option<bool>) -> Self {
        let google_maps = GoogleMapsConfig { enable_widget };
        Self::GoogleMaps { google_maps }
    }

    /// Code-execution tool with its (empty) default configuration.
    pub fn code_execution() -> Self {
        let code_execution = CodeExecutionConfig {};
        Self::CodeExecution { code_execution }
    }

    /// File-search tool over `store_names`, optionally narrowed by `metadata_filter`.
    pub fn file_search(store_names: Vec<String>, metadata_filter: Option<String>) -> Self {
        let file_search = FileSearchConfig {
            metadata_filter,
            file_search_store_names: store_names,
        };
        Self::FileSearch { file_search }
    }
}
/// Execution behaviour attached to a [`FunctionDeclaration`].
///
/// Serialized as `"BLOCKING"` / `"NON_BLOCKING"`; `Blocking` is the `Default`.
/// NOTE(review): the runtime semantics of each mode are defined upstream — confirm there.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Behavior {
#[default]
Blocking,
NonBlocking,
}
/// Declaration of a client-side function, offered to the model through [`Tool::Function`].
///
/// No `rename_all` is applied; every field name is already a single lowercase word.
/// `parameters` and `response` are crate-private and are filled in by
/// `FunctionDeclaration::with_parameters` / `FunctionDeclaration::with_response`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionDeclaration {
/// Function name.
pub name: String,
/// Free-form description; always serialized, even when empty.
pub description: String,
/// Optional execution behaviour; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub behavior: Option<Behavior>,
/// JSON schema of the arguments; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) parameters: Option<Value>,
/// JSON schema of the return value; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) response: Option<Value>,
}
/// Builds the JSON schema for `T` in the form attached to a [`FunctionDeclaration`].
///
/// Uses schemars' OpenAPI 3.0 settings with subschemas inlined, no `$schema`
/// meta-schema key, and the root-level `title` stripped. Used for both
/// argument and response schemas.
fn generate_parameters_schema<T>() -> Value
where
    T: JsonSchema + Serialize,
{
    let mut settings = SchemaSettings::openapi3();
    settings.inline_subschemas = true;
    settings.meta_schema = None;

    let mut root = SchemaGenerator::new(settings).into_root_schema_for::<T>();
    // schemars typically puts the Rust type name here; it is not wanted in the payload.
    let _ = root.remove("title");
    root.to_value()
}
impl FunctionDeclaration {
    /// Creates a declaration with no parameter or response schema attached.
    ///
    /// Attach schemas afterwards with [`FunctionDeclaration::with_parameters`] and
    /// [`FunctionDeclaration::with_response`].
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        behavior: Option<Behavior>,
    ) -> Self {
        let (name, description) = (name.into(), description.into());
        Self {
            name,
            description,
            behavior,
            parameters: None,
            response: None,
        }
    }

    /// Replaces the argument schema with one generated from `Parameters`.
    pub fn with_parameters<Parameters>(self) -> Self
    where
        Parameters: JsonSchema + Serialize,
    {
        Self {
            parameters: Some(generate_parameters_schema::<Parameters>()),
            ..self
        }
    }

    /// Replaces the response schema with one generated from `Response`.
    pub fn with_response<Response>(self) -> Self
    where
        Response: JsonSchema + Serialize,
    {
        Self {
            response: Some(generate_parameters_schema::<Response>()),
            ..self
        }
    }
}
/// A function invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Call arguments; [`FunctionCall::get`] expects a JSON object here.
    ///
    /// Falls back to an empty object when the key is absent, so a call to a
    /// parameterless function whose payload omits `args` still deserializes.
    #[serde(default = "default_function_call_args")]
    pub args: serde_json::Value,
    /// Optional thought signature; omitted from JSON when `None`.
    /// NOTE(review): serialized as `thought_signature` inside this object — confirm
    /// this matches where the upstream API expects/sends it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

/// Default for [`FunctionCall::args`]: an empty JSON object rather than `null`, so
/// `FunctionCall::get` reports a missing parameter (not a type mismatch) and
/// re-serialization emits `"args": {}`.
fn default_function_call_args() -> serde_json::Value {
    serde_json::Value::Object(serde_json::Map::new())
}
/// Errors returned by [`FunctionCall::get`].
#[derive(Debug, Snafu)]
pub enum FunctionCallError {
/// The parameter exists but could not be deserialized into the requested type.
#[snafu(display("failed to deserialize parameter '{key}'"))]
Deserialization {
source: serde_json::Error,
key: String,
},
/// `args` is an object but has no entry for `key`; `args` holds the full arguments
/// so the message shows what was actually received.
#[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
MissingParameter {
key: String,
args: serde_json::Value,
},
/// `args` is not a JSON object; `actual` is its serialized JSON text.
#[snafu(display("arguments should be an object; actual: {actual}"))]
ArgumentTypeMismatch { actual: String },
}
impl FunctionCall {
    /// Creates a call to `name` with `args` and no thought signature.
    pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
        Self {
            name: name.into(),
            args,
            thought_signature: None,
        }
    }

    /// Creates a call to `name` with `args` and the given thought signature.
    ///
    /// Despite the `with_` prefix this is a constructor, not a builder method.
    pub fn with_thought_signature(
        name: impl Into<String>,
        args: serde_json::Value,
        thought_signature: impl Into<String>,
    ) -> Self {
        Self {
            name: name.into(),
            args,
            thought_signature: Some(thought_signature.into()),
        }
    }

    /// Deserializes the argument named `key` into `T`.
    ///
    /// # Errors
    ///
    /// - [`FunctionCallError::ArgumentTypeMismatch`] if `args` is not a JSON object.
    /// - [`FunctionCallError::MissingParameter`] if the object has no `key` entry.
    /// - [`FunctionCallError::Deserialization`] if the entry does not fit `T`.
    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
        let obj = match &self.args {
            serde_json::Value::Object(obj) => obj,
            other => {
                return Err(ArgumentTypeMismatchSnafu {
                    actual: other.to_string(),
                }
                .build())
            }
        };
        let value = obj.get(key).ok_or_else(|| {
            MissingParameterSnafu {
                key,
                args: self.args.clone(),
            }
            .build()
        })?;
        // `&Value` is itself a serde `Deserializer`, so `T` is built straight from the
        // borrowed entry; `serde_json::from_value` would require a deep clone first.
        T::deserialize(value).with_context(|_| DeserializationSnafu { key })
    }
}
/// Result payload for a function call, identified by the function's `name`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionResponse {
/// Name of the function that produced this response.
pub name: String,
/// Arbitrary JSON result; omitted from JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<serde_json::Value>,
}
impl FunctionResponse {
    /// Wraps an already-built JSON value as the response for `name`.
    pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
        let name = name.into();
        Self {
            response: Some(response),
            name,
        }
    }

    /// Serializes `response` to JSON and wraps it as the response for `name`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `response` cannot be serialized.
    pub fn from_schema<Response>(
        name: impl Into<String>,
        response: Response,
    ) -> Result<Self, serde_json::Error>
    where
        Response: JsonSchema + Serialize,
    {
        serde_json::to_value(&response).map(|json| Self::new(name, json))
    }

    /// Parses `response` as JSON text and wraps it as the response for `name`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `response` is not valid JSON.
    pub fn from_str(
        name: impl Into<String>,
        response: impl Into<String>,
    ) -> Result<Self, serde_json::Error> {
        let text: String = response.into();
        serde_json::from_str::<serde_json::Value>(&text).map(|json| Self::new(name, json))
    }
}
/// Tool configuration with two independent, optional sections.
///
/// This struct has no `rename_all`, so its keys serialize in snake_case
/// (`function_calling_config`, `retrieval_config`); each is omitted when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ToolConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub function_calling_config: Option<FunctionCallingConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub retrieval_config: Option<RetrievalConfig>,
}
/// Wrapper around a [`FunctionCallingMode`]; serializes as `{"mode": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCallingConfig {
pub mode: FunctionCallingMode,
}
/// Function-calling mode; serialized as `"AUTO"`, `"ANY"` or `"NONE"`.
///
/// Any other string fails to deserialize.
/// NOTE(review): the variant `None` is an ordinary enum variant and is unrelated to
/// `Option::None`; write `FunctionCallingMode::None` to avoid confusion at call sites.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
Auto,
Any,
None,
}
/// Retrieval settings nested under [`ToolConfig::retrieval_config`].
///
/// NOTE(review): presumably a location hint for location-aware tools such as
/// [`Tool::GoogleMaps`] — confirm upstream.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalConfig {
/// Serialized as `latLng`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub lat_lng: Option<LatLng>,
}
/// A latitude/longitude pair, serialized as `{"latitude": .., "longitude": ..}`.
///
/// NOTE(review): units and valid ranges are not enforced here; presumably degrees — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LatLng {
pub latitude: f64,
pub longitude: f64,
}
impl LatLng {
pub fn new(latitude: f64, longitude: f64) -> Self {
Self {
latitude,
longitude,
}
}
}