use serde::{Deserialize, Serialize};
use crate::providers::Provider;
use crate::models::response::ToolSchema;
/// One entry of the `tools` array sent to the OpenAI Chat Completions API.
///
/// Serializes as `{"type": ..., "function": {...}}`; the raw identifier `r#type` is emitted
/// as the JSON key `"type"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatCompletionToolParam {
    /// Tool kind. [`OpenAIProvider`] always sets this to `"function"`.
    pub r#type: String,
    /// The function the model may call.
    pub function: FunctionDefinition,
}

/// Description of a callable function exposed to the model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Identifier the model uses when it calls this function.
    pub name: String,
    /// Human-readable explanation of what the function does.
    pub description: String,
    /// JSON Schema describing the function's arguments.
    pub parameters: serde_json::Value,
    /// `Some(true)` requests OpenAI strict mode.
    ///
    /// `None` is left out of the serialized JSON entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// [`Provider`] that formats tool schemas for the OpenAI Chat Completions API.
///
/// Strict mode is off by default. Enable it with [`OpenAIProvider::with_strict`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenAIProvider {
    /// When `true`, wrapped function definitions carry `"strict": true`.
    strict: bool,
}

impl OpenAIProvider {
    /// Creates a provider with strict mode disabled.
    #[must_use]
    pub const fn new() -> Self {
        Self { strict: false }
    }

    /// Returns this provider with strict mode set to `strict`.
    ///
    /// The method consumes `self`, so the returned value must be used. Calling it without
    /// binding the result has no effect.
    #[must_use]
    pub fn with_strict(mut self, strict: bool) -> Self {
        self.strict = strict;
        self
    }

    /// Returns whether strict mode is enabled.
    #[must_use]
    pub const fn is_strict(&self) -> bool {
        self.strict
    }
}

impl Default for OpenAIProvider {
    /// Same as [`OpenAIProvider::new`], so strict mode is disabled.
    fn default() -> Self {
        Self::new()
    }
}
impl Provider for OpenAIProvider {
    type Tool = ChatCompletionToolParam;
    type ToolCollection = Vec<ChatCompletionToolParam>;

    /// Returns the provider identifier `"openai"`.
    fn name(&self) -> &str {
        "openai"
    }

    /// Converts a [`ToolSchema`] into an OpenAI `{"type": "function", ...}` tool entry.
    ///
    /// Field mapping:
    /// - The tool's `slug`, not its display `name`, becomes the function name. This is
    ///   presumably because display names may contain spaces, which OpenAI function names
    ///   disallow.
    /// - `description` is copied as-is.
    /// - `input_parameters` is copied as-is and used as the argument schema.
    ///
    /// NOTE(review): OpenAI strict mode requires schemas to set `additionalProperties: false`
    /// and mark every property as required. This method does not rewrite `input_parameters`
    /// to meet that requirement. Confirm upstream schemas comply before enabling strict mode.
    fn wrap_tool(&self, tool: &ToolSchema) -> Self::Tool {
        ChatCompletionToolParam {
            r#type: "function".to_string(),
            function: FunctionDefinition {
                name: tool.slug.clone(),
                description: tool.description.clone(),
                parameters: tool.input_parameters.clone(),
                // Use `None` rather than `Some(false)` so the field is skipped when
                // serialized.
                strict: self.strict.then_some(true),
            },
        }
    }

    /// Wraps every tool in `tools` via [`Provider::wrap_tool`], preserving order.
    fn wrap_tools(&self, tools: Vec<ToolSchema>) -> Self::ToolCollection {
        tools.iter().map(|tool| self.wrap_tool(tool)).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a representative GitHub "create issue" tool schema shared by the tests below.
    fn create_test_tool() -> ToolSchema {
        ToolSchema {
            slug: "GITHUB_CREATE_ISSUE".to_string(),
            name: "Create GitHub Issue".to_string(),
            description: "Create a new issue in a GitHub repository".to_string(),
            toolkit: "github".to_string(),
            input_parameters: json!({
                "type": "object",
                "properties": {
                    "owner": {"type": "string"},
                    "repo": {"type": "string"},
                    "title": {"type": "string"}
                },
                "required": ["owner", "repo", "title"]
            }),
            output_parameters: json!({}),
            version: "1.0.0".to_string(),
            available_versions: vec!["1.0.0".to_string()],
            is_deprecated: false,
            no_auth: false,
            scopes: vec![],
            tags: vec![],
        }
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = OpenAIProvider::new();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_wrap_tool_basic() {
        let provider = OpenAIProvider::new();
        let tool = create_test_tool();
        let wrapped = provider.wrap_tool(&tool);
        assert_eq!(wrapped.r#type, "function");
        // The slug, not the display name, becomes the function name.
        assert_eq!(wrapped.function.name, "GITHUB_CREATE_ISSUE");
        assert_eq!(wrapped.function.description, "Create a new issue in a GitHub repository");
        assert_eq!(wrapped.function.parameters, tool.input_parameters);
        assert!(wrapped.function.strict.is_none());
    }

    #[test]
    fn test_wrap_tool_with_strict() {
        let provider = OpenAIProvider::new().with_strict(true);
        let tool = create_test_tool();
        let wrapped = provider.wrap_tool(&tool);
        assert_eq!(wrapped.function.strict, Some(true));
    }

    #[test]
    fn test_default_is_not_strict() {
        let wrapped = OpenAIProvider::default().wrap_tool(&create_test_tool());
        assert!(wrapped.function.strict.is_none());
    }

    #[test]
    fn test_strict_can_be_disabled_again() {
        let provider = OpenAIProvider::new().with_strict(true).with_strict(false);
        let wrapped = provider.wrap_tool(&create_test_tool());
        assert!(wrapped.function.strict.is_none());
    }

    #[test]
    fn test_wrap_tools() {
        let provider = OpenAIProvider::new();
        let tools = vec![create_test_tool(), create_test_tool()];
        let wrapped = provider.wrap_tools(tools);
        assert_eq!(wrapped.len(), 2);
        assert!(wrapped
            .iter()
            .all(|param| param.function.name == "GITHUB_CREATE_ISSUE"));
    }

    #[test]
    fn test_wrap_tools_empty() {
        let provider = OpenAIProvider::new();
        assert!(provider.wrap_tools(Vec::new()).is_empty());
    }

    #[test]
    fn test_serialization() {
        let tool = create_test_tool();
        let wrapped = OpenAIProvider::new().wrap_tool(&tool);
        let value = serde_json::to_value(&wrapped).expect("tool param serializes");
        // Compare the whole JSON value so a mis-serialized `r#type` key, or an emitted
        // `strict` field, fails the test.
        let expected = json!({
            "type": "function",
            "function": {
                "name": "GITHUB_CREATE_ISSUE",
                "description": "Create a new issue in a GitHub repository",
                "parameters": tool.input_parameters.clone()
            }
        });
        assert_eq!(value, expected);
    }

    #[test]
    fn test_serialization_with_strict() {
        let wrapped = OpenAIProvider::new()
            .with_strict(true)
            .wrap_tool(&create_test_tool());
        let value = serde_json::to_value(&wrapped).expect("tool param serializes");
        assert_eq!(value["function"]["strict"], json!(true));
    }

    #[test]
    fn test_deserialization_round_trip() {
        let wrapped = OpenAIProvider::new()
            .with_strict(true)
            .wrap_tool(&create_test_tool());
        let text = serde_json::to_string(&wrapped).expect("tool param serializes");
        let parsed: ChatCompletionToolParam =
            serde_json::from_str(&text).expect("tool param deserializes");
        assert_eq!(parsed.r#type, wrapped.r#type);
        assert_eq!(parsed.function.name, wrapped.function.name);
        assert_eq!(parsed.function.description, wrapped.function.description);
        assert_eq!(parsed.function.parameters, wrapped.function.parameters);
        assert_eq!(parsed.function.strict, Some(true));
    }
}