Skip to main content

composio_sdk/providers/
openai.rs

1//! OpenAI provider for Chat Completions API
2//!
3//! This module provides the OpenAI provider implementation, which converts
4//! Composio tools to OpenAI's `ChatCompletionToolParam` format.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use composio_sdk::{ComposioClient, providers::OpenAIProvider};
10//!
11//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
12//! let client = ComposioClient::with_provider(OpenAIProvider::new())
13//!     .api_key("your_key")
14//!     .build()?;
15//!
16//! let session = client
17//!     .create_session("user_123")
18//!     .toolkits(vec!["github"])
19//!     .send()
20//!     .await?;
21//!
22//! // Get tools in OpenAI format
23//! let tools = session.get_provider_tools().await?;
24//! // tools: Vec<ChatCompletionToolParam>
25//! # Ok(())
26//! # }
27//! ```
28
29use serde::{Deserialize, Serialize};
30use crate::providers::Provider;
31use crate::models::response::ToolSchema;
32
33/// OpenAI tool format (ChatCompletionToolParam)
34///
35/// Represents a tool in OpenAI's Chat Completions API format.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct ChatCompletionToolParam {
38    /// The type of the tool (always "function" for function calling)
39    pub r#type: String,
40    /// The function definition
41    pub function: FunctionDefinition,
42}
43
44/// Function definition for OpenAI tools
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct FunctionDefinition {
47    /// The name of the function
48    pub name: String,
49    /// A description of what the function does
50    pub description: String,
51    /// The parameters the function accepts (JSON Schema)
52    pub parameters: serde_json::Value,
53    /// Whether to enable strict schema validation (OpenAI feature)
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub strict: Option<bool>,
56}
57
/// OpenAI provider for the Chat Completions API.
///
/// Converts Composio tools to OpenAI's `ChatCompletionToolParam` format and
/// optionally marks every converted tool as `strict`.
///
/// The type is cheap to clone and can be compared for equality, which makes
/// it easy to assert on provider configuration in tests.
///
/// # Example
///
/// ```no_run
/// use composio_sdk::{ComposioClient, providers::OpenAIProvider};
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // Default provider (no strict validation)
/// let provider = OpenAIProvider::new();
///
/// // With strict validation enabled
/// let provider = OpenAIProvider::new().with_strict(true);
///
/// let client = ComposioClient::with_provider(provider)
///     .api_key("your_key")
///     .build()?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenAIProvider {
    /// Whether converted tools should carry `strict: true`.
    strict: bool,
}
86
87impl OpenAIProvider {
88    /// Create a new OpenAI provider with default settings
89    ///
90    /// # Example
91    ///
92    /// ```rust
93    /// use composio_sdk::providers::OpenAIProvider;
94    ///
95    /// let provider = OpenAIProvider::new();
96    /// ```
97    pub fn new() -> Self {
98        Self { strict: false }
99    }
100    
101    /// Enable or disable strict schema validation
102    ///
103    /// When enabled, OpenAI will enforce strict schema validation on tool calls.
104    /// This is an OpenAI-specific feature that ensures tool calls match the schema exactly.
105    ///
106    /// # Arguments
107    ///
108    /// * `strict` - Whether to enable strict validation
109    ///
110    /// # Example
111    ///
112    /// ```rust
113    /// use composio_sdk::providers::OpenAIProvider;
114    ///
115    /// let provider = OpenAIProvider::new().with_strict(true);
116    /// ```
117    pub fn with_strict(mut self, strict: bool) -> Self {
118        self.strict = strict;
119        self
120    }
121}
122
123impl Default for OpenAIProvider {
124    fn default() -> Self {
125        Self::new()
126    }
127}
128
129impl Provider for OpenAIProvider {
130    type Tool = ChatCompletionToolParam;
131    type ToolCollection = Vec<ChatCompletionToolParam>;
132    
133    fn name(&self) -> &str {
134        "openai"
135    }
136    
137    fn wrap_tool(&self, tool: &ToolSchema) -> Self::Tool {
138        ChatCompletionToolParam {
139            r#type: "function".to_string(),
140            function: FunctionDefinition {
141                name: tool.slug.clone(),
142                description: tool.description.clone(),
143                parameters: tool.input_parameters.clone(),
144                strict: if self.strict { Some(true) } else { None },
145            },
146        }
147    }
148    
149    fn wrap_tools(&self, tools: Vec<ToolSchema>) -> Self::ToolCollection {
150        tools.iter().map(|t| self.wrap_tool(t)).collect()
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a representative GitHub tool schema used by every test below.
    fn create_test_tool() -> ToolSchema {
        ToolSchema {
            slug: "GITHUB_CREATE_ISSUE".to_string(),
            name: "Create GitHub Issue".to_string(),
            description: "Create a new issue in a GitHub repository".to_string(),
            toolkit: "github".to_string(),
            input_parameters: json!({
                "type": "object",
                "properties": {
                    "owner": {"type": "string"},
                    "repo": {"type": "string"},
                    "title": {"type": "string"}
                },
                "required": ["owner", "repo", "title"]
            }),
            output_parameters: json!({}),
            version: "1.0.0".to_string(),
            available_versions: vec!["1.0.0".to_string()],
            is_deprecated: false,
            no_auth: false,
            scopes: vec![],
            tags: vec![],
        }
    }

    #[test]
    fn test_openai_provider_name() {
        let provider = OpenAIProvider::new();
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn test_default_matches_new() {
        // `Default` must not silently enable strict mode.
        let wrapped = OpenAIProvider::default().wrap_tool(&create_test_tool());
        assert!(wrapped.function.strict.is_none());
    }

    #[test]
    fn test_wrap_tool_basic() {
        let provider = OpenAIProvider::new();
        let tool = create_test_tool();

        let wrapped = provider.wrap_tool(&tool);

        assert_eq!(wrapped.r#type, "function");
        assert_eq!(wrapped.function.name, "GITHUB_CREATE_ISSUE");
        assert_eq!(
            wrapped.function.description,
            "Create a new issue in a GitHub repository"
        );
        // The input schema must be forwarded untouched.
        assert_eq!(wrapped.function.parameters, tool.input_parameters);
        assert!(wrapped.function.strict.is_none());
    }

    #[test]
    fn test_wrap_tool_with_strict() {
        let provider = OpenAIProvider::new().with_strict(true);
        let tool = create_test_tool();

        let wrapped = provider.wrap_tool(&tool);

        assert_eq!(wrapped.function.strict, Some(true));
    }

    #[test]
    fn test_with_strict_can_be_disabled_again() {
        let provider = OpenAIProvider::new().with_strict(true).with_strict(false);

        let wrapped = provider.wrap_tool(&create_test_tool());

        assert!(wrapped.function.strict.is_none());
    }

    #[test]
    fn test_wrap_tools() {
        let provider = OpenAIProvider::new();
        let tools = vec![create_test_tool(), create_test_tool()];

        let wrapped = provider.wrap_tools(tools);

        assert_eq!(wrapped.len(), 2);
        for tool in &wrapped {
            assert_eq!(tool.r#type, "function");
            assert_eq!(tool.function.name, "GITHUB_CREATE_ISSUE");
        }
    }

    #[test]
    fn test_wrap_tools_empty() {
        let provider = OpenAIProvider::new();

        let wrapped = provider.wrap_tools(Vec::new());

        assert!(wrapped.is_empty());
    }

    #[test]
    fn test_serialization() {
        let provider = OpenAIProvider::new();
        let tool = create_test_tool();
        let wrapped = provider.wrap_tool(&tool);

        // Inspect the structure instead of substring-matching the raw text.
        let value = serde_json::to_value(&wrapped).expect("tool should serialize");

        assert_eq!(value["type"], json!("function"));
        assert_eq!(value["function"]["name"], json!("GITHUB_CREATE_ISSUE"));
        assert_eq!(
            value["function"]["description"],
            json!("Create a new issue in a GitHub repository")
        );
        assert_eq!(value["function"]["parameters"], tool.input_parameters);
        // `strict` is skipped entirely when it is `None`.
        assert!(value["function"].get("strict").is_none());
    }

    #[test]
    fn test_serialization_with_strict() {
        let provider = OpenAIProvider::new().with_strict(true);
        let wrapped = provider.wrap_tool(&create_test_tool());

        let value = serde_json::to_value(&wrapped).expect("tool should serialize");

        assert_eq!(value["function"]["strict"], json!(true));
    }

    #[test]
    fn test_serialization_round_trip() {
        let provider = OpenAIProvider::new().with_strict(true);
        let tool = create_test_tool();
        let wrapped = provider.wrap_tool(&tool);

        let value = serde_json::to_value(&wrapped).expect("tool should serialize");
        let restored: ChatCompletionToolParam =
            serde_json::from_value(value).expect("tool should deserialize");

        assert_eq!(restored.r#type, wrapped.r#type);
        assert_eq!(restored.function.name, wrapped.function.name);
        assert_eq!(restored.function.description, wrapped.function.description);
        assert_eq!(restored.function.parameters, wrapped.function.parameters);
        assert_eq!(restored.function.strict, wrapped.function.strict);
    }
}