composio_sdk/providers/
openai.rs1use serde::{Deserialize, Serialize};
30use crate::providers::Provider;
31use crate::models::response::ToolSchema;
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ChatCompletionToolParam {
38 pub r#type: String,
40 pub function: FunctionDefinition,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct FunctionDefinition {
47 pub name: String,
49 pub description: String,
51 pub parameters: serde_json::Value,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub strict: Option<bool>,
56}
57
/// Provider that converts tool schemas into OpenAI Chat Completions tool
/// definitions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenAIProvider {
    /// When `true`, wrapped function definitions carry `strict: Some(true)`;
    /// otherwise `strict` is left as `None`.
    strict: bool,
}

impl OpenAIProvider {
    /// Creates a provider with strict mode disabled.
    #[must_use]
    pub const fn new() -> Self {
        Self { strict: false }
    }

    /// Returns this provider with strict mode set to `strict`.
    ///
    /// With strict mode on, every wrapped function definition includes
    /// `"strict": true`. The tool's input schema is passed through unchanged.
    ///
    /// NOTE(review): OpenAI's strict mode imposes extra requirements on the
    /// schema (e.g. `additionalProperties: false`, all properties required).
    /// Confirm the incoming schemas satisfy them before enabling this.
    #[must_use]
    pub fn with_strict(mut self, strict: bool) -> Self {
        self.strict = strict;
        self
    }
}

impl Default for OpenAIProvider {
    /// Equivalent to [`OpenAIProvider::new`] (strict mode disabled).
    fn default() -> Self {
        Self::new()
    }
}
128
129impl Provider for OpenAIProvider {
130 type Tool = ChatCompletionToolParam;
131 type ToolCollection = Vec<ChatCompletionToolParam>;
132
133 fn name(&self) -> &str {
134 "openai"
135 }
136
137 fn wrap_tool(&self, tool: &ToolSchema) -> Self::Tool {
138 ChatCompletionToolParam {
139 r#type: "function".to_string(),
140 function: FunctionDefinition {
141 name: tool.slug.clone(),
142 description: tool.description.clone(),
143 parameters: tool.input_parameters.clone(),
144 strict: if self.strict { Some(true) } else { None },
145 },
146 }
147 }
148
149 fn wrap_tools(&self, tools: Vec<ToolSchema>) -> Self::ToolCollection {
150 tools.iter().map(|t| self.wrap_tool(t)).collect()
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a representative GitHub tool schema shared by the tests below.
    fn create_test_tool() -> ToolSchema {
        ToolSchema {
            slug: "GITHUB_CREATE_ISSUE".to_string(),
            name: "Create GitHub Issue".to_string(),
            description: "Create a new issue in a GitHub repository".to_string(),
            toolkit: "github".to_string(),
            input_parameters: json!({
                "type": "object",
                "properties": {
                    "owner": {"type": "string"},
                    "repo": {"type": "string"},
                    "title": {"type": "string"}
                },
                "required": ["owner", "repo", "title"]
            }),
            output_parameters: json!({}),
            version: "1.0.0".to_string(),
            available_versions: vec!["1.0.0".to_string()],
            is_deprecated: false,
            no_auth: false,
            scopes: vec![],
            tags: vec![],
        }
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = OpenAIProvider::new();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_wrap_tool_basic() {
        let provider = OpenAIProvider::new();
        let tool = create_test_tool();

        let wrapped = provider.wrap_tool(&tool);

        assert_eq!(wrapped.r#type, "function");
        assert_eq!(wrapped.function.name, "GITHUB_CREATE_ISSUE");
        assert_eq!(
            wrapped.function.description,
            "Create a new issue in a GitHub repository"
        );
        assert_eq!(wrapped.function.parameters, tool.input_parameters);
        assert!(wrapped.function.strict.is_none());
    }

    #[test]
    fn test_wrap_tool_with_strict() {
        let provider = OpenAIProvider::new().with_strict(true);
        let tool = create_test_tool();

        let wrapped = provider.wrap_tool(&tool);

        assert_eq!(wrapped.function.strict, Some(true));
    }

    #[test]
    fn test_wrap_tools() {
        let provider = OpenAIProvider::new();
        let tools = vec![create_test_tool(), create_test_tool()];
        let expected = serde_json::to_value(provider.wrap_tool(&tools[0])).unwrap();

        let wrapped = provider.wrap_tools(tools);

        assert_eq!(wrapped.len(), 2);
        // The owned (batch) path must produce exactly what the borrowed path does.
        for item in &wrapped {
            assert_eq!(serde_json::to_value(item).unwrap(), expected);
        }
    }

    #[test]
    fn test_wrap_tools_empty() {
        let provider = OpenAIProvider::new();
        assert!(provider.wrap_tools(Vec::new()).is_empty());
    }

    #[test]
    fn test_wrap_tools_with_strict() {
        let provider = OpenAIProvider::new().with_strict(true);

        let wrapped = provider.wrap_tools(vec![create_test_tool()]);

        assert_eq!(wrapped[0].function.strict, Some(true));
    }

    #[test]
    fn test_serialization() {
        let provider = OpenAIProvider::new();
        let tool = create_test_tool();
        let params = tool.input_parameters.clone();
        let wrapped = provider.wrap_tool(&tool);

        let value = serde_json::to_value(&wrapped).unwrap();

        // Exact structural match: also proves `strict` is omitted when unset.
        assert_eq!(
            value,
            json!({
                "type": "function",
                "function": {
                    "name": "GITHUB_CREATE_ISSUE",
                    "description": "Create a new issue in a GitHub repository",
                    "parameters": params
                }
            })
        );
    }

    #[test]
    fn test_serialization_strict() {
        let provider = OpenAIProvider::new().with_strict(true);
        let wrapped = provider.wrap_tool(&create_test_tool());

        let value = serde_json::to_value(&wrapped).unwrap();

        assert_eq!(value["function"]["strict"], json!(true));
    }

    #[test]
    fn test_serialization_round_trip() {
        let provider = OpenAIProvider::new().with_strict(true);
        let wrapped = provider.wrap_tool(&create_test_tool());

        let text = serde_json::to_string(&wrapped).unwrap();
        let back: ChatCompletionToolParam = serde_json::from_str(&text).unwrap();

        assert_eq!(
            serde_json::to_value(&back).unwrap(),
            serde_json::to_value(&wrapped).unwrap()
        );
    }
}