use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::function_schema::FunctionSchema;
/// Identifies a provider adapter; used, for example, as the key of
/// [`ToolsSchema::custom_tools`].
///
/// Serialized as a lowercase string: `"openai"` or `"gemini"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AdapterType {
    /// The OpenAI adapter.
    ///
    /// Renamed explicitly because `rename_all = "snake_case"` inserts an underscore
    /// before every capital letter and would otherwise produce `"open_a_i"`. That
    /// legacy spelling is still accepted when deserializing.
    #[serde(rename = "openai", alias = "open_a_i")]
    OpenAI,
    /// The Gemini adapter (serialized as `"gemini"`).
    Gemini,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
Auto,
None,
Required,
Function { name: String },
}
impl ToolChoice {
pub fn to_openai_value(&self) -> Value {
match self {
ToolChoice::Auto => Value::String("auto".into()),
ToolChoice::None => Value::String("none".into()),
ToolChoice::Required => Value::String("required".into()),
ToolChoice::Function { name } => serde_json::json!({
"type": "function",
"function": { "name": name }
}),
}
}
}
/// The set of tools offered for a request.
///
/// `standard_tools` holds [`FunctionSchema`] definitions. `custom_tools` optionally
/// holds raw JSON tool values grouped by [`AdapterType`]. These are presumably
/// adapter-specific definitions passed through verbatim; confirm against the adapters.
#[derive(Debug, Clone, Default)]
pub struct ToolsSchema {
    /// Function tools described by [`FunctionSchema`].
    pub standard_tools: Vec<FunctionSchema>,
    /// Raw JSON tool values per adapter; `None` when none were supplied.
    pub custom_tools: Option<HashMap<AdapterType, Vec<Value>>>,
}

impl ToolsSchema {
    /// Creates a schema with only standard tools (`custom_tools` is `None`).
    #[must_use]
    pub fn new(tools: Vec<FunctionSchema>) -> Self {
        Self {
            standard_tools: tools,
            custom_tools: None,
        }
    }

    /// Creates a schema with standard tools and per-adapter custom tools.
    ///
    /// `custom` is stored as `Some(custom)` even when the map is empty.
    #[must_use]
    pub fn with_custom(
        tools: Vec<FunctionSchema>,
        custom: HashMap<AdapterType, Vec<Value>>,
    ) -> Self {
        Self {
            standard_tools: tools,
            custom_tools: Some(custom),
        }
    }

    /// Returns a new schema whose standard tools are `self`'s, followed by `extra`.
    ///
    /// `custom_tools` is copied unchanged and duplicates are not removed. `self` is
    /// not modified, so ignoring the result is always a mistake.
    #[must_use = "merge_standard returns a new schema and does not modify `self`"]
    pub fn merge_standard(&self, extra: &[FunctionSchema]) -> Self {
        // Size the vector once instead of cloning and then regrowing for `extra`.
        let mut standard_tools = Vec::with_capacity(self.standard_tools.len() + extra.len());
        standard_tools.extend_from_slice(&self.standard_tools);
        standard_tools.extend_from_slice(extra);
        Self {
            standard_tools,
            custom_tools: self.custom_tools.clone(),
        }
    }
}