// bamboo_infrastructure/llm/protocol/openai.rs
use crate::llm::api::models::{
4 ChatMessage as OpenAIChatMessage, Content as OpenAIContent, ContentPart as OpenAIContentPart,
5 Role as OpenAIRole, Tool, ToolCall as OpenAIToolCall,
6};
7use crate::llm::models::ContentPart;
8use crate::llm::protocol::{FromProvider, ProtocolResult, ToProvider};
9use bamboo_domain::{FunctionCall, ToolCall};
10use bamboo_domain::{FunctionSchema, ToolSchema};
11use bamboo_domain::{Message, MessagePart, MessagePhase, Role};
12
/// Marker type for the OpenAI protocol implementation.
/// NOTE(review): no inherent methods are visible in this chunk — presumably
/// used as a type-level tag elsewhere; confirm before removing.
pub struct OpenAIProtocol;
15
16impl FromProvider<OpenAIChatMessage> for Message {
21 fn from_provider(msg: OpenAIChatMessage) -> ProtocolResult<Self> {
22 let role = convert_openai_role_to_internal(&msg.role);
23
24 let (content, content_parts) = match msg.content {
25 OpenAIContent::Text(text) => (text, None),
26 OpenAIContent::Parts(parts) => {
27 let text = parts
29 .iter()
30 .filter_map(|part| match part {
31 OpenAIContentPart::Text { text } => Some(text.as_str()),
32 OpenAIContentPart::ImageUrl { .. } => None,
33 })
34 .collect::<Vec<_>>()
35 .join("");
36 let message_parts: Vec<MessagePart> = parts.into_iter().map(Into::into).collect();
37 (text, Some(message_parts))
38 }
39 };
40
41 let tool_calls = msg
42 .tool_calls
43 .map(|calls| calls.into_iter().map(ToolCall::from_provider).collect())
44 .transpose()?;
45 let phase = match msg.phase.as_deref() {
46 Some("commentary") => Some(MessagePhase::Commentary),
47 Some("final_answer") => Some(MessagePhase::FinalAnswer),
48 _ => None,
49 };
50
51 Ok(Message {
52 id: String::new(), role,
54 content,
55 reasoning: None,
56 content_parts,
57 image_ocr: None,
58 phase,
59 tool_calls,
60 tool_call_id: msg.tool_call_id,
61 tool_success: None,
62 compressed: false,
63 compressed_by_event_id: None,
64 never_compress: false,
65 compression_level: 0,
66 created_at: chrono::Utc::now(),
67 metadata: None,
68 })
69 }
70}
71
72impl FromProvider<OpenAIToolCall> for ToolCall {
73 fn from_provider(tc: OpenAIToolCall) -> ProtocolResult<Self> {
74 Ok(ToolCall {
75 id: tc.id,
76 tool_type: tc.tool_type,
77 function: FunctionCall {
78 name: tc.function.name,
79 arguments: tc.function.arguments,
80 },
81 })
82 }
83}
84
85impl FromProvider<Tool> for ToolSchema {
86 fn from_provider(tool: Tool) -> ProtocolResult<Self> {
87 Ok(ToolSchema {
88 schema_type: tool.tool_type,
89 function: FunctionSchema {
90 name: tool.function.name,
91 description: tool.function.description.unwrap_or_default(),
92 parameters: tool.function.parameters,
93 },
94 })
95 }
96}
97
98impl ToProvider<OpenAIChatMessage> for Message {
103 fn to_provider(&self) -> ProtocolResult<OpenAIChatMessage> {
104 let role = convert_internal_role_to_openai(&self.role);
105
106 let content = match self.content_parts.as_ref() {
107 Some(parts) => {
108 OpenAIContent::Parts(parts.iter().cloned().map(ContentPart::from).collect())
109 }
110 None => OpenAIContent::Text(self.content.clone()),
111 };
112
113 let tool_calls = self
114 .tool_calls
115 .as_ref()
116 .map(|calls| calls.iter().map(|tc| tc.to_provider()).collect())
117 .transpose()?;
118
119 Ok(OpenAIChatMessage {
120 role,
121 content,
122 phase: self.phase.as_ref().map(|phase| phase.as_str().to_string()),
123 tool_calls,
124 tool_call_id: self.tool_call_id.clone(),
125 })
126 }
127}
128
129impl ToProvider<OpenAIToolCall> for ToolCall {
130 fn to_provider(&self) -> ProtocolResult<OpenAIToolCall> {
131 Ok(OpenAIToolCall {
132 id: self.id.clone(),
133 tool_type: self.tool_type.clone(),
134 function: crate::llm::api::models::FunctionCall {
135 name: self.function.name.clone(),
136 arguments: self.function.arguments.clone(),
137 },
138 })
139 }
140}
141
142impl ToProvider<Tool> for ToolSchema {
143 fn to_provider(&self) -> ProtocolResult<Tool> {
144 Ok(Tool {
145 tool_type: self.schema_type.clone(),
146 function: crate::llm::api::models::FunctionDefinition {
147 name: self.function.name.clone(),
148 description: Some(self.function.description.clone()),
149 parameters: self.function.parameters.clone(),
150 },
151 })
152 }
153}
154
155fn convert_openai_role_to_internal(role: &OpenAIRole) -> Role {
160 match role {
161 OpenAIRole::System => Role::System,
162 OpenAIRole::User => Role::User,
163 OpenAIRole::Assistant => Role::Assistant,
164 OpenAIRole::Tool => Role::Tool,
165 }
166}
167
168fn convert_internal_role_to_openai(role: &Role) -> OpenAIRole {
169 match role {
170 Role::System => OpenAIRole::System,
171 Role::User => OpenAIRole::User,
172 Role::Assistant => OpenAIRole::Assistant,
173 Role::Tool => OpenAIRole::Tool,
174 }
175}
176
/// Test-only convenience trait for converting between the OpenAI wire type
/// and the internal `Message` without naming the protocol traits explicitly.
#[cfg(test)]
pub trait OpenAIExt: Sized {
    /// Consumes `self` and produces the internal `Message`.
    fn into_internal(self) -> ProtocolResult<Message>;
    /// Produces the OpenAI wire representation of `self`.
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage>;
}
187
#[cfg(test)]
impl OpenAIExt for OpenAIChatMessage {
    /// Delegates to the `FromProvider` conversion.
    fn into_internal(self) -> ProtocolResult<Message> {
        Message::from_provider(self)
    }

    /// Already in wire form — just hand back a clone.
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage> {
        let wire = self.clone();
        Ok(wire)
    }
}
198
#[cfg(test)]
impl OpenAIExt for Message {
    /// Already internal — pass `self` through unchanged.
    fn into_internal(self) -> ProtocolResult<Message> {
        Ok(self)
    }

    /// Delegates to the `ToProvider` conversion.
    fn to_openai(&self) -> ProtocolResult<OpenAIChatMessage> {
        self.to_provider()
    }
}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::api::models::{FunctionCall as OpenAIFunctionCall, Role as OpenAIRole};
    use bamboo_domain::FunctionCall;
    use bamboo_domain::Role;

    /// A plain text wire message keeps its role and content after conversion.
    #[test]
    fn test_openai_to_internal_simple_message() {
        let wire = OpenAIChatMessage {
            role: OpenAIRole::User,
            content: OpenAIContent::Text("Hello".to_string()),
            phase: None,
            tool_calls: None,
            tool_call_id: None,
        };

        let internal = wire.into_internal().unwrap();

        assert_eq!(internal.role, Role::User);
        assert_eq!(internal.content, "Hello");
        assert!(internal.tool_calls.is_none());
    }

    /// A plain internal message serializes to a text-content wire message.
    #[test]
    fn test_internal_to_openai_simple_message() {
        let internal = Message::user("Hello");

        let wire = internal.to_openai().unwrap();

        assert_eq!(wire.role, OpenAIRole::User);
        assert!(matches!(wire.content, OpenAIContent::Text(ref t) if t == "Hello"));
        assert!(wire.tool_calls.is_none());
    }

    /// Tool calls on an assistant wire message survive the conversion.
    #[test]
    fn test_openai_to_internal_with_tool_call() {
        let wire = OpenAIChatMessage {
            role: OpenAIRole::Assistant,
            content: OpenAIContent::Text(String::new()),
            phase: None,
            tool_calls: Some(vec![OpenAIToolCall {
                id: "call_1".to_string(),
                tool_type: "function".to_string(),
                function: OpenAIFunctionCall {
                    name: "search".to_string(),
                    arguments: r#"{"q":"test"}"#.to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let internal = Message::from_provider(wire).unwrap();

        assert_eq!(internal.role, Role::Assistant);
        let calls = internal.tool_calls.expect("tool calls should survive conversion");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_1");
        assert_eq!(calls[0].function.name, "search");
    }

    /// Tool calls on an assistant internal message survive serialization.
    #[test]
    fn test_internal_to_openai_with_tool_call() {
        let call = ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "search".to_string(),
                arguments: r#"{"q":"test"}"#.to_string(),
            },
        };

        let internal = Message::assistant("", Some(vec![call]));
        let wire = internal.to_provider().unwrap();

        assert_eq!(wire.role, OpenAIRole::Assistant);
        let calls = wire.tool_calls.expect("tool calls should survive serialization");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_1");
        assert_eq!(calls[0].function.name, "search");
        assert_eq!(calls[0].function.arguments, r#"{"q":"test"}"#);
    }

    /// internal -> wire -> internal preserves role and content.
    #[test]
    fn test_roundtrip_conversion() {
        let original = Message::user("Hello, world!");

        let wire = original.to_provider().unwrap();
        let restored = Message::from_provider(wire).unwrap();

        assert_eq!(restored.role, original.role);
        assert_eq!(restored.content, original.content);
    }

    /// A tool schema converts to the domain type and back without losing
    /// the name or (non-empty) description.
    #[test]
    fn test_tool_schema_conversion() {
        let wire_tool = Tool {
            tool_type: "function".to_string(),
            function: crate::llm::api::models::FunctionDefinition {
                name: "search".to_string(),
                description: Some("Search the web".to_string()),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "q": { "type": "string" }
                    }
                }),
            },
        };

        let schema = ToolSchema::from_provider(wire_tool.clone()).unwrap();
        assert_eq!(schema.function.name, "search");

        let restored = schema.to_provider().unwrap();
        assert_eq!(restored.function.name, "search");
        assert_eq!(
            restored.function.description,
            Some("Search the web".to_string())
        );
    }

    /// The `OpenAIExt` helper converts wire messages to internal ones.
    #[test]
    fn test_extension_trait() {
        let wire = OpenAIChatMessage {
            role: OpenAIRole::User,
            content: OpenAIContent::Text("Test".to_string()),
            phase: None,
            tool_calls: None,
            tool_call_id: None,
        };

        let internal = wire.into_internal().unwrap();
        assert_eq!(internal.content, "Test");
    }
}