Skip to main content

bamboo_infrastructure/llm/protocol/
openai.rs

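//! OpenAI protocol conversions: maps between the OpenAI chat API models in
//! `crate::llm::api::models` and the internal `bamboo_domain` message, tool-call and
//! tool-schema types via the `FromProvider` / `ToProvider` traits.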
2
3use crate::llm::api::models::{
4    ChatMessage as OpenAIChatMessage, Content as OpenAIContent, ContentPart as OpenAIContentPart,
5    Role as OpenAIRole, Tool, ToolCall as OpenAIToolCall,
6};
7use crate::llm::models::ContentPart;
8use crate::llm::protocol::{FromProvider, ProtocolResult, ToProvider};
9use bamboo_domain::{FunctionCall, ToolCall};
10use bamboo_domain::{FunctionSchema, ToolSchema};
11use bamboo_domain::{Message, MessagePart, MessagePhase, Role};
12
/// OpenAI protocol converter.
///
/// A stateless unit type naming the OpenAI protocol. Within this module the actual
/// conversions are implemented directly on the domain and wire types (`Message`,
/// `ToolCall`, `ToolSchema`) through `FromProvider` / `ToProvider`, not on this struct.
pub struct OpenAIProtocol;
15
16// ============================================================================
17// OpenAI → Internal (FromProvider)
18// ============================================================================
19
20impl FromProvider<OpenAIChatMessage> for Message {
21    fn from_provider(msg: OpenAIChatMessage) -> ProtocolResult<Self> {
22        let role = convert_openai_role_to_internal(&msg.role);
23
24        let (content, content_parts) = match msg.content {
25            OpenAIContent::Text(text) => (text, None),
26            OpenAIContent::Parts(parts) => {
27                // Preserve parts (including images) while also producing a text-only projection.
28                let text = parts
29                    .iter()
30                    .filter_map(|part| match part {
31                        OpenAIContentPart::Text { text } => Some(text.as_str()),
32                        OpenAIContentPart::ImageUrl { .. } => None,
33                    })
34                    .collect::<Vec<_>>()
35                    .join("");
36                let message_parts: Vec<MessagePart> = parts.into_iter().map(Into::into).collect();
37                (text, Some(message_parts))
38            }
39        };
40
41        let tool_calls = msg
42            .tool_calls
43            .map(|calls| calls.into_iter().map(ToolCall::from_provider).collect())
44            .transpose()?;
45        let phase = match msg.phase.as_deref() {
46            Some("commentary") => Some(MessagePhase::Commentary),
47            Some("final_answer") => Some(MessagePhase::FinalAnswer),
48            _ => None,
49        };
50
51        Ok(Message {
52            id: String::new(), // Will be generated if needed
53            role,
54            content,
55            reasoning: None,
56            content_parts,
57            image_ocr: None,
58            phase,
59            tool_calls,
60            tool_call_id: msg.tool_call_id,
61            tool_success: None,
62            compressed: false,
63            compressed_by_event_id: None,
64            never_compress: false,
65            compression_level: 0,
66            created_at: chrono::Utc::now(),
67            metadata: None,
68        })
69    }
70}
71
72impl FromProvider<OpenAIToolCall> for ToolCall {
73    fn from_provider(tc: OpenAIToolCall) -> ProtocolResult<Self> {
74        Ok(ToolCall {
75            id: tc.id,
76            tool_type: tc.tool_type,
77            function: FunctionCall {
78                name: tc.function.name,
79                arguments: tc.function.arguments,
80            },
81        })
82    }
83}
84
85impl FromProvider<Tool> for ToolSchema {
86    fn from_provider(tool: Tool) -> ProtocolResult<Self> {
87        Ok(ToolSchema {
88            schema_type: tool.tool_type,
89            function: FunctionSchema {
90                name: tool.function.name,
91                description: tool.function.description.unwrap_or_default(),
92                parameters: tool.function.parameters,
93            },
94        })
95    }
96}
97
98// ============================================================================
99// Internal → OpenAI (ToProvider)
100// ============================================================================
101
102impl ToProvider<OpenAIChatMessage> for Message {
103    fn to_provider(&self) -> ProtocolResult<OpenAIChatMessage> {
104        let role = convert_internal_role_to_openai(&self.role);
105
106        let content = match self.content_parts.as_ref() {
107            Some(parts) => {
108                OpenAIContent::Parts(parts.iter().cloned().map(ContentPart::from).collect())
109            }
110            None => OpenAIContent::Text(self.content.clone()),
111        };
112
113        let tool_calls = self
114            .tool_calls
115            .as_ref()
116            .map(|calls| calls.iter().map(|tc| tc.to_provider()).collect())
117            .transpose()?;
118
119        Ok(OpenAIChatMessage {
120            role,
121            content,
122            phase: self.phase.as_ref().map(|phase| phase.as_str().to_string()),
123            tool_calls,
124            tool_call_id: self.tool_call_id.clone(),
125        })
126    }
127}
128
129impl ToProvider<OpenAIToolCall> for ToolCall {
130    fn to_provider(&self) -> ProtocolResult<OpenAIToolCall> {
131        Ok(OpenAIToolCall {
132            id: self.id.clone(),
133            tool_type: self.tool_type.clone(),
134            function: crate::llm::api::models::FunctionCall {
135                name: self.function.name.clone(),
136                arguments: self.function.arguments.clone(),
137            },
138        })
139    }
140}
141
142impl ToProvider<Tool> for ToolSchema {
143    fn to_provider(&self) -> ProtocolResult<Tool> {
144        Ok(Tool {
145            tool_type: self.schema_type.clone(),
146            function: crate::llm::api::models::FunctionDefinition {
147                name: self.function.name.clone(),
148                description: Some(self.function.description.clone()),
149                parameters: self.function.parameters.clone(),
150            },
151        })
152    }
153}
154
155// ============================================================================
156// Helper functions
157// ============================================================================
158
159fn convert_openai_role_to_internal(role: &OpenAIRole) -> Role {
160    match role {
161        OpenAIRole::System => Role::System,
162        OpenAIRole::User => Role::User,
163        OpenAIRole::Assistant => Role::Assistant,
164        OpenAIRole::Tool => Role::Tool,
165    }
166}
167
168fn convert_internal_role_to_openai(role: &Role) -> OpenAIRole {
169    match role {
170        Role::System => OpenAIRole::System,
171        Role::User => OpenAIRole::User,
172        Role::Assistant => OpenAIRole::Assistant,
173        Role::Tool => OpenAIRole::Tool,
174    }
175}
176
177// ============================================================================
178// Extension trait for ergonomic conversion
179// ============================================================================
180
/// Test-only convenience trait adding `.into_internal()` and `.to_openai()` to both the
/// OpenAI wire message and the internal [`Message`].
#[cfg(test)]
pub trait OpenAIExt: Sized {
    /// Consumes `self` and returns the internal [`Message`] form.
    fn into_internal(self) -> ProtocolResult<Message>;

    /// Returns the OpenAI chat-message form without consuming `self`.
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage>;
}
187
#[cfg(test)]
impl OpenAIExt for OpenAIChatMessage {
    /// Delegates to the `FromProvider<OpenAIChatMessage>` impl on [`Message`].
    fn into_internal(self) -> ProtocolResult<Message> {
        <Message as FromProvider<OpenAIChatMessage>>::from_provider(self)
    }

    /// The value is already in OpenAI form, so this just returns a clone.
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage> {
        let copy = Clone::clone(self);
        Ok(copy)
    }
}
198
#[cfg(test)]
impl OpenAIExt for Message {
    /// The value is already internal, so it is returned unchanged.
    fn into_internal(self) -> ProtocolResult<Message> {
        let unchanged = self;
        Ok(unchanged)
    }

    /// Delegates to the `ToProvider<OpenAIChatMessage>` impl on [`Message`].
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage> {
        <Self as ToProvider<OpenAIChatMessage>>::to_provider(self)
    }
}
209
#[cfg(test)]
mod tests {
    use super::*;
    // `OpenAIRole`, the domain `FunctionCall` and the domain `Role` already arrive via
    // `super::*`; only the wire-level function-call type needs its own alias here.
    use crate::llm::api::models::FunctionCall as OpenAIFunctionCall;

    /// Builds a user-role OpenAI message with the given content and optional phase label.
    fn openai_user_message(content: OpenAIContent, phase: Option<&str>) -> OpenAIChatMessage {
        OpenAIChatMessage {
            role: OpenAIRole::User,
            content,
            phase: phase.map(String::from),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_openai_to_internal_simple_message() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::User,
            content: OpenAIContent::Text("Hello".to_string()),
            phase: None,
            tool_calls: None,
            tool_call_id: None,
        };

        let internal_msg: Message = openai_msg.into_internal().unwrap();

        assert_eq!(internal_msg.role, Role::User);
        assert_eq!(internal_msg.content, "Hello");
        assert!(internal_msg.tool_calls.is_none());
    }

    #[test]
    fn test_internal_to_openai_simple_message() {
        let internal_msg = Message::user("Hello");

        let openai_msg: OpenAIChatMessage = internal_msg.to_openai().unwrap();

        assert_eq!(openai_msg.role, OpenAIRole::User);
        assert!(matches!(openai_msg.content, OpenAIContent::Text(ref t) if t == "Hello"));
        assert!(openai_msg.tool_calls.is_none());
    }

    #[test]
    fn test_openai_to_internal_with_tool_call() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::Assistant,
            content: OpenAIContent::Text(String::new()),
            phase: None,
            tool_calls: Some(vec![OpenAIToolCall {
                id: "call_1".to_string(),
                tool_type: "function".to_string(),
                function: OpenAIFunctionCall {
                    name: "search".to_string(),
                    arguments: r#"{"q":"test"}"#.to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let internal_msg: Message = Message::from_provider(openai_msg).unwrap();

        assert_eq!(internal_msg.role, Role::Assistant);
        assert!(internal_msg.tool_calls.is_some());
        let tool_calls = internal_msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].function.name, "search");
    }

    #[test]
    fn test_internal_to_openai_with_tool_call() {
        let tool_call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"q":"test"}"#.to_string(),
            },
        };

        let internal_msg = Message::assistant("", Some(vec![tool_call]));

        let openai_msg: OpenAIChatMessage = internal_msg.to_provider().unwrap();

        assert_eq!(openai_msg.role, OpenAIRole::Assistant);
        assert!(openai_msg.tool_calls.is_some());
        let tool_calls = openai_msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].function.name, "search");
        assert_eq!(tool_calls[0].function.arguments, r#"{"q":"test"}"#);
    }

    #[test]
    fn test_roundtrip_conversion() {
        let original = Message::user("Hello, world!");

        // Internal → OpenAI
        let openai_msg: OpenAIChatMessage = original.to_provider().unwrap();

        // OpenAI → Internal
        let roundtrip: Message = Message::from_provider(openai_msg).unwrap();

        assert_eq!(roundtrip.role, original.role);
        assert_eq!(roundtrip.content, original.content);
    }

    #[test]
    fn test_tool_schema_conversion() {
        let openai_tool = Tool {
            tool_type: "function".to_string(),
            function: crate::llm::api::models::FunctionDefinition {
                name: "search".to_string(),
                description: Some("Search the web".to_string()),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "q": { "type": "string" }
                    }
                }),
            },
        };

        // OpenAI → Internal
        let internal_schema: ToolSchema = ToolSchema::from_provider(openai_tool.clone()).unwrap();
        assert_eq!(internal_schema.function.name, "search");

        // Internal → OpenAI
        let roundtrip: Tool = internal_schema.to_provider().unwrap();
        assert_eq!(roundtrip.function.name, "search");
        assert_eq!(
            roundtrip.function.description,
            Some("Search the web".to_string())
        );
    }

    #[test]
    fn test_extension_trait() {
        let openai_msg = OpenAIChatMessage {
            role: OpenAIRole::User,
            content: OpenAIContent::Text("Test".to_string()),
            phase: None,
            tool_calls: None,
            tool_call_id: None,
        };

        // Using extension trait
        let internal = openai_msg.into_internal().unwrap();
        assert_eq!(internal.content, "Test");
    }

    #[test]
    fn test_openai_multipart_text_projection() {
        let openai_msg = openai_user_message(
            OpenAIContent::Parts(vec![
                OpenAIContentPart::Text {
                    text: "Hello, ".to_string(),
                },
                OpenAIContentPart::Text {
                    text: "world".to_string(),
                },
            ]),
            None,
        );

        let internal_msg = Message::from_provider(openai_msg).unwrap();

        // Text parts are concatenated verbatim into the plain-text projection...
        assert_eq!(internal_msg.content, "Hello, world");
        // ...while every original part is preserved.
        assert_eq!(internal_msg.content_parts.map(|parts| parts.len()), Some(2));
    }

    #[test]
    fn test_openai_phase_parsing() {
        let parse = |label: Option<&str>| {
            Message::from_provider(openai_user_message(
                OpenAIContent::Text("x".to_string()),
                label,
            ))
            .unwrap()
            .phase
        };

        assert!(matches!(
            parse(Some("commentary")),
            Some(MessagePhase::Commentary)
        ));
        assert!(matches!(
            parse(Some("final_answer")),
            Some(MessagePhase::FinalAnswer)
        ));
        // Unknown labels and a missing label both map to no phase.
        assert!(parse(Some("something_else")).is_none());
        assert!(parse(None).is_none());
    }
}