use crate::common::{
errors::{OpenAIToolError, Result as OpenAIToolResult},
parameters::Parameters,
};
use serde::{ser::SerializeStruct, Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// A function (tool) definition, or a function call returned by the model.
///
/// Definitions are normally built with [`Function::new`]; calls decoded from a
/// response carry their arguments in [`Function::arguments`].
#[derive(Clone, Debug, Default)]
pub struct Function {
    /// Function name; the only field a decoded value must always contain.
    pub name: String,
    /// Optional human-readable description of what the function does.
    pub description: Option<String>,
    /// Optional parameter schema for the function.
    pub parameters: Option<Parameters>,
    /// Decoded call arguments, keyed by argument name; `None` when not set.
    pub arguments: Option<HashMap<String, Value>>,
    /// Strict flag for the definition; `false` unless set explicitly.
    pub strict: bool,
}
impl Function {
    /// Creates a function definition with a description and parameter schema.
    ///
    /// `arguments` is left unset (`None`).
    pub fn new<T: AsRef<str>, U: AsRef<str>>(name: T, description: U, parameters: Parameters, strict: bool) -> Self {
        let owned_name = String::from(name.as_ref());
        let owned_description = String::from(description.as_ref());
        Self {
            name: owned_name,
            description: Some(owned_description),
            parameters: Some(parameters),
            strict,
            ..Self::default()
        }
    }

    /// Returns a copy of the decoded call arguments.
    ///
    /// # Errors
    ///
    /// Returns an error when `arguments` is `None`.
    pub fn arguments_as_map(&self) -> OpenAIToolResult<HashMap<String, Value>> {
        self.arguments
            .as_ref()
            .cloned()
            .ok_or_else(|| OpenAIToolError::from(anyhow::anyhow!("Function arguments are not set")))
    }
}
impl Serialize for Function {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("Function", 4)?;
state.serialize_field("name", &self.name)?;
if let Some(description) = &self.description {
state.serialize_field("description", description)?;
}
if let Some(parameters) = &self.parameters {
state.serialize_field("parameters", parameters)?;
}
if let Some(arguments) = &self.arguments {
state.serialize_field("arguments", arguments)?;
}
state.serialize_field("strict", &self.strict)?;
if let Some(arguments) = &self.arguments {
if !arguments.is_empty() {
state.serialize_field("arguments", &serde_json::to_string(arguments).expect("Failed to serialize arguments in Function"))?;
}
}
state.end()
}
}
impl<'de> Deserialize<'de> for Function {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut function = Function::default();
let map: HashMap<String, Value> = HashMap::deserialize(deserializer)?;
if let Some(name) = map.get("name").and_then(Value::as_str) {
function.name = name.to_string();
} else {
return Err(serde::de::Error::missing_field("name"));
}
let arguments = map.get("arguments").and_then(Value::as_str);
if let Some(args) = arguments {
function.arguments = serde_json::from_str(args).ok();
} else {
function.arguments = None;
}
let parameters = map.get("parameters").and_then(Value::as_object);
if let Some(params) = parameters {
function.parameters = Some(Parameters::deserialize(Value::Object(params.clone())).map_err(serde::de::Error::custom)?);
} else {
function.parameters = None;
}
function.description = map.get("description").and_then(Value::as_str).map(String::from);
function.strict = map.get("strict").and_then(Value::as_bool).unwrap_or(false);
Ok(function)
}
}