use schemars::{JsonSchema, SchemaGenerator, generate::SchemaSettings};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use snafu::{ResultExt, Snafu};
/// A tool entry, in one of three shapes.
///
/// The enum is `#[serde(untagged)]`: no variant name is written, so each
/// variant serializes as an object containing only its single field. On
/// deserialization serde tries the variants in declaration order and keeps the
/// first one that matches.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum Tool {
/// A group of function declarations, written under `function_declarations`.
Function {
/// The functions declared by this tool entry.
function_declarations: Vec<FunctionDeclaration>,
},
/// A Google Search tool entry, written under `google_search`.
GoogleSearch {
/// Configuration payload (currently a struct with no fields).
google_search: GoogleSearchConfig,
},
/// A URL context tool entry, written under `url_context`.
URLContext {
/// Configuration payload (currently a struct with no fields).
url_context: URLContextConfig,
},
}
/// Payload of [`Tool::GoogleSearch`].
///
/// The struct has no fields, so it serializes as an empty JSON object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GoogleSearchConfig {}
/// Payload of [`Tool::URLContext`].
///
/// The struct has no fields, so it serializes as an empty JSON object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct URLContextConfig {}
impl Tool {
pub fn new(function_declaration: FunctionDeclaration) -> Self {
Self::Function { function_declarations: vec![function_declaration] }
}
pub fn with_functions(function_declarations: Vec<FunctionDeclaration>) -> Self {
Self::Function { function_declarations }
}
pub fn google_search() -> Self {
Self::GoogleSearch { google_search: GoogleSearchConfig {} }
}
pub fn url_context() -> Self {
Self::URLContext { url_context: URLContextConfig {} }
}
}
/// Behavior setting that can be attached to a [`FunctionDeclaration`].
///
/// Variants are serialized in SCREAMING_SNAKE_CASE (`BLOCKING`, `NON_BLOCKING`).
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Behavior {
/// Serialized as `BLOCKING`; this is the `Default` variant.
#[default]
Blocking,
/// Serialized as `NON_BLOCKING`.
NonBlocking,
}
/// Declaration of a callable function: its name, description and optional
/// behavior, parameter schema and response schema.
///
/// Every `Option` field is left out of the serialized form when it is `None`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionDeclaration {
/// Name of the function.
pub name: String,
/// Human-readable description of the function.
pub description: String,
/// Optional [`Behavior`] setting.
#[serde(skip_serializing_if = "Option::is_none")]
pub behavior: Option<Behavior>,
/// Parameter schema as a JSON value; crate-visible, filled in by
/// `FunctionDeclaration::with_parameters`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) parameters: Option<Value>,
/// Response schema as a JSON value; crate-visible, filled in by
/// `FunctionDeclaration::with_response`.
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) response: Option<Value>,
}
/// Produces the JSON schema of `Parameters` as a plain [`Value`].
///
/// The generator starts from schemars' `openapi3` settings preset with
/// `inline_subschemas` switched on and `meta_schema` cleared. The top-level
/// `title` entry is removed from the root schema before it is converted to a
/// JSON value.
fn generate_parameters_schema<Parameters>() -> Value
where
    Parameters: JsonSchema + Serialize,
{
    let settings = SchemaSettings::openapi3().with(|settings| {
        settings.meta_schema = None;
        settings.inline_subschemas = true;
    });

    let mut root = SchemaGenerator::new(settings).into_root_schema_for::<Parameters>();
    root.remove("title");
    root.to_value()
}
impl FunctionDeclaration {
    /// Creates a declaration with the given name, description and behavior.
    ///
    /// The parameter and response schemas start out unset; attach them with
    /// [`with_parameters`](Self::with_parameters) and
    /// [`with_response`](Self::with_response).
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        behavior: Option<Behavior>,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            behavior,
            parameters: None,
            response: None,
        }
    }

    /// Returns the declaration with its parameter schema replaced by the schema
    /// generated for `Parameters`.
    pub fn with_parameters<Parameters>(self) -> Self
    where
        Parameters: JsonSchema + Serialize,
    {
        Self { parameters: Some(generate_parameters_schema::<Parameters>()), ..self }
    }

    /// Returns the declaration with its response schema replaced by the schema
    /// generated for `Response`.
    pub fn with_response<Response>(self) -> Self
    where
        Response: JsonSchema + Serialize,
    {
        Self { response: Some(generate_parameters_schema::<Response>()), ..self }
    }
}
/// A function invocation: the function's name plus its JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function being called.
    pub name: String,
    /// Arguments of the call; `FunctionCall::get` expects a JSON object here.
    ///
    /// `#[serde(default)]` lets a payload that omits `args` deserialize to
    /// `Value::Null` instead of failing with a missing-field error (the Gemini
    /// API `FunctionCall` reference marks `args` as optional).
    #[serde(default)]
    pub args: serde_json::Value,
    /// Optional thought signature; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
/// Errors returned by `FunctionCall::get` when extracting a typed argument.
#[derive(Debug, Snafu)]
pub enum FunctionCallError {
/// The value stored under `key` could not be deserialized into the
/// requested type; `source` carries the underlying `serde_json` error.
#[snafu(display("failed to deserialize parameter '{key}'"))]
Deserialization { source: serde_json::Error, key: String },
/// The arguments object has no entry for `key`; `args` is a copy of the
/// full arguments value for diagnostics.
#[snafu(display("parameter '{key}' is missing in arguments '{args}'"))]
MissingParameter { key: String, args: serde_json::Value },
/// The arguments value is not a JSON object; `actual` is its JSON text.
#[snafu(display("arguments should be an object; actual: {actual}"))]
ArgumentTypeMismatch { actual: String },
}
impl FunctionCall {
    /// Creates a call with the given name and arguments and no thought signature.
    pub fn new(name: impl Into<String>, args: serde_json::Value) -> Self {
        Self { name: name.into(), args, thought_signature: None }
    }

    /// Creates a call that also carries a thought signature.
    pub fn with_thought_signature(
        name: impl Into<String>,
        args: serde_json::Value,
        thought_signature: impl Into<String>,
    ) -> Self {
        Self { name: name.into(), args, thought_signature: Some(thought_signature.into()) }
    }

    /// Looks up `key` in the call's arguments and deserializes it into `T`.
    ///
    /// # Errors
    ///
    /// * [`FunctionCallError::ArgumentTypeMismatch`] if `args` is not a JSON object.
    /// * [`FunctionCallError::MissingParameter`] if the object has no entry for `key`.
    /// * [`FunctionCallError::Deserialization`] if the entry cannot be
    ///   deserialized into `T`.
    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Result<T, FunctionCallError> {
        let serde_json::Value::Object(fields) = &self.args else {
            return Err(ArgumentTypeMismatchSnafu { actual: self.args.to_string() }.build());
        };
        let Some(value) = fields.get(key) else {
            return Err(
                MissingParameterSnafu { key: key.to_string(), args: self.args.clone() }.build()
            );
        };
        // `&Value` is itself a `Deserializer`, so deserialize straight from the
        // borrowed entry instead of deep-cloning it for `from_value`.
        T::deserialize(value).with_context(|_| DeserializationSnafu { key: key.to_string() })
    }
}
/// The result of a function call, identified by the function's name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionResponse {
/// Name of the function this response belongs to.
pub name: String,
/// Response payload as JSON; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response: Option<serde_json::Value>,
}
impl FunctionResponse {
    /// Creates a response for `name` carrying an already-built JSON payload.
    pub fn new(name: impl Into<String>, response: serde_json::Value) -> Self {
        let name = name.into();
        Self { name, response: Some(response) }
    }

    /// Creates a response by serializing a typed value into JSON.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if `response` cannot be serialized.
    pub fn from_schema<Response>(
        name: impl Into<String>,
        response: Response,
    ) -> Result<Self, serde_json::Error>
    where
        Response: JsonSchema + Serialize,
    {
        serde_json::to_value(&response).map(|json| Self::new(name, json))
    }

    /// Creates a response by parsing `response` as JSON text.
    ///
    /// This is an inherent constructor, not an implementation of
    /// `std::str::FromStr`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if the text is not valid JSON.
    pub fn from_str(
        name: impl Into<String>,
        response: impl Into<String>,
    ) -> Result<Self, serde_json::Error> {
        let raw: String = response.into();
        let json: serde_json::Value = serde_json::from_str(&raw)?;
        Ok(Self::new(name, json))
    }
}
/// Tool-level configuration.
///
/// The `Default` value has no function-calling configuration set.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ToolConfig {
/// Optional function-calling configuration; left out of the serialized form
/// when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub function_calling_config: Option<FunctionCallingConfig>,
}
/// Function-calling configuration carried by [`ToolConfig`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCallingConfig {
/// The selected function-calling mode.
pub mode: FunctionCallingMode,
}
/// Function-calling mode used in [`FunctionCallingConfig`].
///
/// Variants are serialized in SCREAMING_SNAKE_CASE (`AUTO`, `ANY`, `NONE`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
/// Serialized as `AUTO`.
Auto,
/// Serialized as `ANY`.
Any,
/// Serialized as `NONE`.
None,
}