use serde::{Deserialize, Serialize};
use super::content::JsonSchema;
/// Selects how tools may be chosen for a request.
///
/// On the wire the unit variants are their lowercase names (`"auto"`,
/// `"required"`, `"none"`). [`ToolChoice::Named`] is untagged, so it is
/// (de)serialized as the bare inner string. [`ToolChoice::Auto`] is the
/// [`Default`].
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// Serialized as `"auto"`; the default variant.
    #[default]
    Auto,
    /// Serialized as `"required"`.
    Required,
    /// Serialized as `"none"`.
    None,
    /// A specific tool, identified by name; serialized as the plain string.
    /// Must remain the last variant so the keyword variants are tried first.
    #[serde(untagged)]
    Named(String),
}
/// Parses a tool-choice string.
///
/// The keywords `auto`, `required` and `none` are recognized
/// case-insensitively. Any other input is treated as a tool name and is kept
/// exactly as given (e.g. `"GetWeather"` becomes `Named("GetWeather")`).
impl From<&str> for ToolChoice {
    fn from(s: &str) -> Self {
        // Lowercase only for keyword matching. The fallback must use the
        // caller's original spelling; otherwise mixed-case tool names would be
        // silently rewritten.
        match s.to_lowercase().as_str() {
            "auto" => ToolChoice::Auto,
            "required" => ToolChoice::Required,
            "none" => ToolChoice::None,
            _ => ToolChoice::Named(s.to_string()),
        }
    }
}
/// Declaration of a tool: its name, a description, a parameter schema and a
/// strict-mode flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Name of the tool.
    pub name: String,
    /// Free-text description of the tool.
    pub description: String,
    /// Schema for the tool's parameters.
    pub parameters: JsonSchema,
    /// Strict-mode flag; when the field is absent during deserialization it
    /// is filled in by `default_strict` (i.e. `true`).
    #[serde(default = "default_strict")]
    pub strict: bool,
}

/// Serde default for [`ToolDefinition::strict`]: strict mode is on unless the
/// input explicitly turns it off.
fn default_strict() -> bool {
    const STRICT_BY_DEFAULT: bool = true;
    STRICT_BY_DEFAULT
}
impl ToolDefinition {
pub fn new(
name: impl Into<String>,
description: impl Into<String>,
parameters: JsonSchema,
) -> Self {
Self {
name: name.into(),
description: description.into(),
parameters,
strict: true,
}
}
pub fn with_strict(mut self, strict: bool) -> Self {
self.strict = strict;
self
}
}
/// A named function together with its arguments, kept as unparsed JSON text.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Function {
    /// Name of the function.
    pub name: String,
    /// Raw JSON-encoded arguments; decode them with [`Function::parse_args`].
    pub arguments: String,
}

impl Function {
    /// Decodes [`Function::arguments`] from JSON into a `T`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error when `arguments` is not valid JSON or
    /// does not match the shape expected by `T`.
    pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
        let raw: &str = self.arguments.as_str();
        serde_json::from_str::<T>(raw)
    }
}
/// A tool invocation: an identifier, the function to call, a type tag and an
/// optional thought signature.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Function name plus raw JSON arguments.
    pub function: Function,
    /// Serialized under the key `type`; filled with `"function"` by
    /// `default_tool_type` when absent during deserialization.
    #[serde(rename = "type", default = "default_tool_type")]
    pub tool_type: String,
    /// Optional thought signature; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

/// Serde default for [`ToolCall::tool_type`].
fn default_tool_type() -> String {
    String::from("function")
}
impl ToolCall {
pub fn new(
id: impl Into<String>,
name: impl Into<String>,
arguments: impl Into<String>,
) -> Self {
Self {
id: id.into(),
function: Function {
name: name.into(),
arguments: arguments.into(),
},
tool_type: "function".to_string(),
thought_signature: None,
}
}
pub fn parse_args<T: for<'de> Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
self.function.parse_args()
}
}