Skip to main content

bamboo_agent/agent/llm/protocol/
openai.rs

1//! OpenAI protocol conversion implementation.
2
3use crate::agent::core::tools::{FunctionCall, FunctionSchema, ToolCall, ToolSchema};
4use crate::agent::core::{Message, Role};
5use crate::agent::llm::api::models::{
6    ChatMessage as OpenAIChatMessage, Content as OpenAIContent, ContentPart as OpenAIContentPart,
7    Role as OpenAIRole, Tool, ToolCall as OpenAIToolCall,
8};
9use crate::agent::llm::protocol::{FromProvider, ProtocolResult, ToProvider};
10
/// OpenAI protocol converter.
///
/// A zero-sized marker type naming the OpenAI wire format. The conversions
/// themselves are the [`FromProvider`] / [`ToProvider`] impls in this module,
/// which are implemented on the internal types rather than on this struct.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OpenAIProtocol;
13
14// ============================================================================
15// OpenAI → Internal (FromProvider)
16// ============================================================================
17
18impl FromProvider<OpenAIChatMessage> for Message {
19    fn from_provider(msg: OpenAIChatMessage) -> ProtocolResult<Self> {
20        let role = convert_openai_role_to_internal(&msg.role);
21
22        let content = match msg.content {
23            OpenAIContent::Text(text) => text,
24            OpenAIContent::Parts(parts) => {
25                // Extract text from parts, ignore images for now
26                parts
27                    .into_iter()
28                    .filter_map(|part| match part {
29                        OpenAIContentPart::Text { text } => Some(text),
30                        OpenAIContentPart::ImageUrl { .. } => None,
31                    })
32                    .collect::<Vec<_>>()
33                    .join("")
34            }
35        };
36
37        let tool_calls = msg
38            .tool_calls
39            .map(|calls| calls.into_iter().map(ToolCall::from_provider).collect())
40            .transpose()?;
41
42        Ok(Message {
43            id: String::new(), // Will be generated if needed
44            role,
45            content,
46            tool_calls,
47            tool_call_id: msg.tool_call_id,
48            created_at: chrono::Utc::now(),
49        })
50    }
51}
52
53impl FromProvider<OpenAIToolCall> for ToolCall {
54    fn from_provider(tc: OpenAIToolCall) -> ProtocolResult<Self> {
55        Ok(ToolCall {
56            id: tc.id,
57            tool_type: tc.tool_type,
58            function: FunctionCall {
59                name: tc.function.name,
60                arguments: tc.function.arguments,
61            },
62        })
63    }
64}
65
66impl FromProvider<Tool> for ToolSchema {
67    fn from_provider(tool: Tool) -> ProtocolResult<Self> {
68        Ok(ToolSchema {
69            schema_type: tool.tool_type,
70            function: FunctionSchema {
71                name: tool.function.name,
72                description: tool.function.description.unwrap_or_default(),
73                parameters: tool.function.parameters,
74            },
75        })
76    }
77}
78
79// ============================================================================
80// Internal → OpenAI (ToProvider)
81// ============================================================================
82
83impl ToProvider<OpenAIChatMessage> for Message {
84    fn to_provider(&self) -> ProtocolResult<OpenAIChatMessage> {
85        let role = convert_internal_role_to_openai(&self.role);
86
87        let content = OpenAIContent::Text(self.content.clone());
88
89        let tool_calls = self
90            .tool_calls
91            .as_ref()
92            .map(|calls| calls.iter().map(|tc| tc.to_provider()).collect())
93            .transpose()?;
94
95        Ok(OpenAIChatMessage {
96            role,
97            content,
98            tool_calls,
99            tool_call_id: self.tool_call_id.clone(),
100        })
101    }
102}
103
104impl ToProvider<OpenAIToolCall> for ToolCall {
105    fn to_provider(&self) -> ProtocolResult<OpenAIToolCall> {
106        Ok(OpenAIToolCall {
107            id: self.id.clone(),
108            tool_type: self.tool_type.clone(),
109            function: crate::agent::llm::api::models::FunctionCall {
110                name: self.function.name.clone(),
111                arguments: self.function.arguments.clone(),
112            },
113        })
114    }
115}
116
117impl ToProvider<Tool> for ToolSchema {
118    fn to_provider(&self) -> ProtocolResult<Tool> {
119        Ok(Tool {
120            tool_type: self.schema_type.clone(),
121            function: crate::agent::llm::api::models::FunctionDefinition {
122                name: self.function.name.clone(),
123                description: Some(self.function.description.clone()),
124                parameters: self.function.parameters.clone(),
125            },
126        })
127    }
128}
129
130// ============================================================================
131// Helper functions
132// ============================================================================
133
134fn convert_openai_role_to_internal(role: &OpenAIRole) -> Role {
135    match role {
136        OpenAIRole::System => Role::System,
137        OpenAIRole::User => Role::User,
138        OpenAIRole::Assistant => Role::Assistant,
139        OpenAIRole::Tool => Role::Tool,
140    }
141}
142
143fn convert_internal_role_to_openai(role: &Role) -> OpenAIRole {
144    match role {
145        Role::System => OpenAIRole::System,
146        Role::User => OpenAIRole::User,
147        Role::Assistant => OpenAIRole::Assistant,
148        Role::Tool => OpenAIRole::Tool,
149    }
150}
151
152// ============================================================================
153// Extension trait for ergonomic conversion
154// ============================================================================
155
/// Test-only convenience trait: `.into_internal()` yields the internal
/// [`Message`] and `.to_openai()` yields an [`OpenAIChatMessage`].
#[cfg(test)]
pub trait OpenAIExt
where
    Self: Sized,
{
    /// Consumes `self` and produces the internal representation.
    fn into_internal(self) -> ProtocolResult<Message>;

    /// Borrows `self` and produces the OpenAI representation.
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage>;
}
162
#[cfg(test)]
impl OpenAIExt for OpenAIChatMessage {
    /// Delegates to the `FromProvider` conversion into [`Message`].
    fn into_internal(self) -> ProtocolResult<Message> {
        <Message as FromProvider<OpenAIChatMessage>>::from_provider(self)
    }

    /// Already in OpenAI form, so this returns a copy of `self`.
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage> {
        let copy = self.clone();
        Ok(copy)
    }
}
173
#[cfg(test)]
impl OpenAIExt for Message {
    /// Already internal, so `self` is handed back unchanged.
    fn into_internal(self) -> ProtocolResult<Message> {
        Ok(self)
    }

    /// Delegates to the `ToProvider` conversion into [`OpenAIChatMessage`].
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage> {
        <Self as ToProvider<OpenAIChatMessage>>::to_provider(self)
    }
}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::core::tools::FunctionCall;
    use crate::agent::core::Role;
    use crate::agent::llm::api::models::{FunctionCall as OpenAIFunctionCall, Role as OpenAIRole};

    #[test]
    fn test_openai_to_internal_simple_message() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::User,
            content: OpenAIContent::Text("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        let internal_msg: Message = openai_msg.into_internal().unwrap();

        assert_eq!(internal_msg.role, Role::User);
        assert_eq!(internal_msg.content, "Hello");
        assert!(internal_msg.tool_calls.is_none());
    }

    /// Multi-part content is flattened by concatenating text parts in order.
    #[test]
    fn test_openai_parts_concatenate_text_in_order() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::User,
            content: OpenAIContent::Parts(vec![
                OpenAIContentPart::Text {
                    text: "Hello, ".to_string(),
                },
                OpenAIContentPart::Text {
                    text: "world".to_string(),
                },
            ]),
            tool_calls: None,
            tool_call_id: None,
        };

        let internal_msg = Message::from_provider(openai_msg).unwrap();

        assert_eq!(internal_msg.content, "Hello, world");
    }

    #[test]
    fn test_internal_to_openai_simple_message() {
        let internal_msg = Message::user("Hello");

        let openai_msg: OpenAIChatMessage = internal_msg.to_openai().unwrap();

        assert_eq!(openai_msg.role, OpenAIRole::User);
        assert!(matches!(openai_msg.content, OpenAIContent::Text(ref t) if t == "Hello"));
        assert!(openai_msg.tool_calls.is_none());
    }

    #[test]
    fn test_openai_to_internal_with_tool_call() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::Assistant,
            content: OpenAIContent::Text(String::new()),
            tool_calls: Some(vec![OpenAIToolCall {
                id: "call_1".to_string(),
                tool_type: "function".to_string(),
                function: OpenAIFunctionCall {
                    name: "search".to_string(),
                    arguments: r#"{"q":"test"}"#.to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let internal_msg: Message = Message::from_provider(openai_msg).unwrap();

        assert_eq!(internal_msg.role, Role::Assistant);
        assert!(internal_msg.tool_calls.is_some());
        let tool_calls = internal_msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].function.name, "search");
    }

    #[test]
    fn test_internal_to_openai_with_tool_call() {
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"q":"test"}"#.to_string(),
            },
        };

        let internal_msg = Message::assistant("", Some(vec![tool_call]));

        let openai_msg: OpenAIChatMessage = internal_msg.to_provider().unwrap();

        assert_eq!(openai_msg.role, OpenAIRole::Assistant);
        assert!(openai_msg.tool_calls.is_some());
        let tool_calls = openai_msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].function.name, "search");
        assert_eq!(tool_calls[0].function.arguments, r#"{"q":"test"}"#);
    }

    /// A tool-result message keeps its `tool_call_id` in both directions.
    #[test]
    fn test_tool_result_roundtrip_preserves_tool_call_id() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::Tool,
            content: OpenAIContent::Text("result".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_1".to_string()),
        };

        let internal_msg = Message::from_provider(openai_msg).unwrap();
        assert_eq!(internal_msg.role, Role::Tool);
        assert_eq!(internal_msg.tool_call_id.as_deref(), Some("call_1"));

        let back: OpenAIChatMessage = internal_msg.to_provider().unwrap();
        assert_eq!(back.role, OpenAIRole::Tool);
        assert_eq!(back.tool_call_id.as_deref(), Some("call_1"));
    }

    #[test]
    fn test_roundtrip_conversion() {
        let original = Message::user("Hello, world!");

        // Internal → OpenAI
        let openai_msg: OpenAIChatMessage = original.to_provider().unwrap();

        // OpenAI → Internal
        let roundtrip: Message = Message::from_provider(openai_msg).unwrap();

        assert_eq!(roundtrip.role, original.role);
        assert_eq!(roundtrip.content, original.content);
    }

    #[test]
    fn test_tool_schema_conversion() {
        let openai_tool = Tool {
            tool_type: "function".to_string(),
            function: crate::agent::llm::api::models::FunctionDefinition {
                name: "search".to_string(),
                description: Some("Search the web".to_string()),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "q": { "type": "string" }
                    }
                }),
            },
        };

        // OpenAI → Internal
        let internal_schema: ToolSchema = ToolSchema::from_provider(openai_tool.clone()).unwrap();
        assert_eq!(internal_schema.function.name, "search");

        // Internal → OpenAI
        let roundtrip: Tool = internal_schema.to_provider().unwrap();
        assert_eq!(roundtrip.function.name, "search");
        assert_eq!(
            roundtrip.function.description,
            Some("Search the web".to_string())
        );
    }

    /// A tool without a description imports with an empty description string.
    #[test]
    fn test_tool_schema_missing_description_becomes_empty() {
        let openai_tool = Tool {
            tool_type: "function".to_string(),
            function: crate::agent::llm::api::models::FunctionDefinition {
                name: "noop".to_string(),
                description: None,
                parameters: serde_json::json!({ "type": "object" }),
            },
        };

        let internal_schema = ToolSchema::from_provider(openai_tool).unwrap();

        assert_eq!(internal_schema.function.name, "noop");
        assert_eq!(internal_schema.function.description, "");
    }

    #[test]
    fn test_extension_trait() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::User,
            content: OpenAIContent::Text("Test".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        // Using extension trait
        let internal = openai_msg.into_internal().unwrap();
        assert_eq!(internal.content, "Test");
    }
}
325}