use serde::{Deserialize, Serialize, de::Visitor, ser::SerializeStruct};
use std::collections::HashMap;
/// A tool definition, serialized as an internally tagged object: the variant is
/// identified by a `type` key that sits alongside the variant's own fields.
///
/// Wire tags: `function` and `file_search` (from `rename_all = "snake_case"`),
/// `computer_use_preview` and `web_search_preview` (explicit renames).
/// NOTE(review): the tag names suggest this mirrors an upstream (OpenAI-style)
/// `tools` schema — confirm against that reference when adding fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    /// Caller-defined function tool (tag `function`).
    Function {
        /// Function name.
        name: String,
        /// Arbitrary JSON; presumably a JSON Schema for the arguments — TODO confirm.
        parameters: serde_json::Value,
        /// Strict flag; required when deserializing.
        strict: bool,
        /// Optional description. Missing on input -> `None`; there is no
        /// `skip_serializing_if`, so `None` is emitted as an explicit null.
        description: Option<String>,
    },
    /// File-search tool (tag `file_search`). All four fields are required when
    /// deserializing (none is an `Option` or has `#[serde(default)]`).
    FileSearch {
        /// IDs of the vector stores to use.
        vector_store_ids: Vec<String>,
        /// Either a single comparison or a compound filter; see [`FileSearchFilters`].
        filters: FileSearchFilters,
        /// Result limit; the `u8` type caps it at 255.
        max_num_results: u8,
        /// See [`RankingOptions`].
        ranking_options: RankingOptions,
    },
    /// Computer-use tool (tag `computer_use_preview`).
    #[serde(rename = "computer_use_preview")]
    ComputerUse {
        /// Display height; units not stated here (presumably pixels — confirm).
        display_height: u64,
        /// Display width; units not stated here (presumably pixels — confirm).
        display_width: u64,
        /// Target environment; see [`Environment`].
        environment: Environment,
    },
    /// Web-search tool (tag `web_search_preview`).
    #[serde(rename = "web_search_preview")]
    WebSearch {
        /// Required when deserializing: the type's `Default` (`Medium`) is not
        /// applied automatically because there is no `#[serde(default)]`.
        search_context_size: SearchContextSize,
        /// Optional location; `None` is emitted as an explicit null, not omitted.
        user_location: Option<UserLocation>,
    },
}
/// Location attached to [`Tool::WebSearch`].
///
/// Every field is required when deserializing. The derived `Default` is
/// `Approximate` with all strings empty; it is not used by serde here.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UserLocation {
    /// Serialized under the key `type` (serde drops the `r#` raw-identifier prefix).
    pub r#type: UserLocationType,
    /// City name; not validated here.
    pub city: String,
    /// NOTE(review): expected format (e.g. ISO country code) is not enforced — confirm.
    pub country: String,
    /// Region name; not validated here.
    pub region: String,
    /// NOTE(review): presumably an IANA timezone name — not validated here.
    pub timezone: String,
}
/// Kind of [`UserLocation`]. The only variant is `Approximate`, serialized as
/// `"approximate"` (lowercase), which is also the `Default`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum UserLocationType {
    #[default]
    Approximate,
}
/// Context-size setting for [`Tool::WebSearch`], serialized lowercase:
/// `"low"`, `"high"`, `"medium"`.
///
/// Declaration order does not reflect magnitude; the `Default` is `Medium`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SearchContextSize {
    Low,
    High,
    #[default]
    Medium,
}
/// Environment for [`Tool::ComputerUse`], serialized in snake_case:
/// `"mac"`, `"ubuntu"`, `"browser"`, `"windows"`.
///
/// `#[non_exhaustive]`: matches in other crates need a wildcard arm, so new
/// variants can be added without a breaking change.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Environment {
    Mac,
    Ubuntu,
    Browser,
    Windows,
}
/// Filter for [`Tool::FileSearch`]: either one comparison or a compound of
/// nested filters.
///
/// `#[serde(untagged)]`: deserialization tries `Single` first, then `Compound`,
/// so variant order matters. A [`ComparisonFilter`] requires `key`, `type` and
/// `value`; compound input (which has no `key`) falls through to `Compound`.
/// Serialization writes the inner value with no wrapper.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum FileSearchFilters {
    Single(ComparisonFilter),
    Compound(CompoundFilter),
}
/// Comparison of the attribute named `key` against `value` using operator `type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonFilter {
    /// Attribute name to compare.
    pub key: String,
    /// Operator; serialized under the key `type` (`r#` prefix dropped by serde).
    pub r#type: ComparisonFilterType,
    /// Number, boolean, or string operand; see [`ComparisonFilterValue`].
    pub value: ComparisonFilterValue,
}
/// Operand of a [`ComparisonFilter`].
///
/// `#[serde(untagged)]`: the first variant that deserializes wins, tried in
/// order `Number`, `Boolean`, `String`. Integer inputs therefore land in
/// `Number` as an `f64`. Serialization writes the bare value.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ComparisonFilterValue {
    Number(f64),
    Boolean(bool),
    String(String),
}
/// Comparison operator for [`ComparisonFilter`]. Each variant has an explicit
/// short wire name: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ComparisonFilterType {
    /// `"eq"`
    #[serde(rename = "eq")]
    Equals,
    /// `"ne"`
    #[serde(rename = "ne")]
    NotEqual,
    /// `"gt"`
    #[serde(rename = "gt")]
    GreaterThan,
    /// `"gte"`
    #[serde(rename = "gte")]
    GreaterThanOrEqual,
    /// `"lt"`
    #[serde(rename = "lt")]
    LessThan,
    /// `"lte"`
    #[serde(rename = "lte")]
    LessThanOrEqual,
}
/// Combination of several filters under one logical operator.
///
/// `filters` holds [`FileSearchFilters`], so compounds can nest recursively.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompoundFilter {
    /// Sub-filters; each may itself be a single comparison or another compound.
    pub filters: Vec<FileSearchFilters>,
    /// Combining operator; serialized under the key `type`.
    pub r#type: CompoundFilterType,
}
/// Logical operator for [`CompoundFilter`], serialized lowercase: `"and"`, `"or"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompoundFilterType {
    And,
    Or,
}
/// Ranking configuration for [`Tool::FileSearch`]. Both fields are required
/// when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RankingOptions {
    /// Ranker identifier as a free-form string.
    /// NOTE(review): valid ranker names are not checked here — confirm the allowed set.
    pub ranker: String,
    /// Score cutoff. NOTE(review): expected range (e.g. 0..=1) is not enforced here.
    pub score_threshold: f32,
}
/// Tool-selection setting with a hand-written serde representation (see the
/// `Serialize` / `Deserialize` impls for this type).
///
/// `None`, `Auto` and `Required` are written as the bare strings `"none"`,
/// `"auto"` and `"required"`. The hosted-tool variants are written as
/// `{"type": "file_search" | "web_search_preview" | "computer_use_preview"}`.
/// `Function(name)` is written as `{"name": name, "type": "function"}`.
/// The `Default` is `Auto`.
#[derive(Debug, Clone, Default)]
pub enum ToolChoice {
    /// `"none"`
    None,
    /// `"auto"` (default)
    #[default]
    Auto,
    /// `"required"`
    Required,
    /// `{"type": "file_search"}`
    FileSearch,
    /// `{"type": "web_search_preview"}`
    WebSearchPreview,
    /// `{"type": "computer_use_preview"}`
    ComputerUsePreview,
    /// `{"type": "function", "name": ...}`; carries the function name.
    Function(String),
}
impl<'de> Deserialize<'de> for ToolChoice {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ToolChoiceVisitor;
impl<'de> Visitor<'de> for ToolChoiceVisitor {
type Value = ToolChoice;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("string or struct")
}
fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
where
E: serde::de::Error,
{
match value {
"none" => Ok(ToolChoice::None),
"auto" => Ok(ToolChoice::Auto),
"required" => Ok(ToolChoice::Required),
_ => Err(serde::de::Error::unknown_variant(
value,
&["none", "auto", "required"],
)),
}
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'de>,
{
let mut record = HashMap::<String, String>::new();
while let Some((key, value)) = map.next_entry()? {
record.insert(key, value);
}
let Some(r#type) = record.get("type") else {
return Err(serde::de::Error::missing_field("type"));
};
match r#type.as_str() {
"file_search" => Ok(ToolChoice::FileSearch),
"web_search_preview" => Ok(ToolChoice::WebSearchPreview),
"computer_use_preview" => Ok(ToolChoice::ComputerUsePreview),
"function" => {
let Some(name) = record.get("name") else {
return Err(serde::de::Error::missing_field("name"));
};
Ok(ToolChoice::Function(name.clone()))
}
_ => Err(serde::de::Error::unknown_variant(
r#type.as_str(),
&[
"file_search",
"web_search_preview",
"computer_use_preview",
"function",
],
)),
}
}
}
deserializer.deserialize_any(ToolChoiceVisitor {})
}
}
/// Emits the mode variants (`None`, `Auto`, `Required`) as bare strings. Every
/// tool-specific variant becomes a struct with a `type` field, and
/// `Function(name)` also gets a leading `name` field.
///
/// The struct name passed to the serializer is `"Function"` for every
/// struct-shaped variant; self-describing formats such as JSON ignore it.
impl Serialize for ToolChoice {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Modes and named functions are written directly and return early; the
        // hosted tools only differ by their tag, so they share one emission below.
        let hosted_tag = match self {
            Self::None => return serializer.serialize_str("none"),
            Self::Auto => return serializer.serialize_str("auto"),
            Self::Required => return serializer.serialize_str("required"),
            Self::Function(name) => {
                let mut state = serializer.serialize_struct("Function", 2)?;
                state.serialize_field("name", name)?;
                state.serialize_field("type", "function")?;
                return state.end();
            }
            Self::FileSearch => "file_search",
            Self::WebSearchPreview => "web_search_preview",
            Self::ComputerUsePreview => "computer_use_preview",
        };
        let mut state = serializer.serialize_struct("Function", 1)?;
        state.serialize_field("type", hosted_tag)?;
        state.end()
    }
}