1use std::collections::HashMap;
8
9use crate::error::PeError;
10use crate::formatter::MessageFormatter;
11use crate::llm::{LlmResponse, ToolSchema};
12use crate::message::{
13 AiMessage, ContentBlock, InvalidToolCall, Message, MessageContent, ToolCall, UsageMetadata,
14};
15
/// Stateless [`MessageFormatter`] for the OpenAI chat-completions wire format.
///
/// The type carries no data, so it is free to copy and can be created with
/// either the unit-struct literal `OpenAiFormatter` or `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAiFormatter;
34
35impl MessageFormatter for OpenAiFormatter {
36 fn name(&self) -> &str {
37 "openai"
38 }
39
40 fn format_messages(&self, messages: &[Message]) -> Result<serde_json::Value, PeError> {
41 let mut result = Vec::with_capacity(messages.len());
42 for msg in messages {
43 if let Some(wire) = format_single_message(msg)? {
44 result.push(wire);
45 }
46 }
47 Ok(serde_json::Value::Array(result))
48 }
49
50 fn format_tools(&self, tools: &[ToolSchema]) -> Result<serde_json::Value, PeError> {
51 let defs: Vec<serde_json::Value> = tools
52 .iter()
53 .map(|t| {
54 let mut func = serde_json::json!({
55 "name": t.name,
56 "description": t.description,
57 "parameters": t.parameters,
58 });
59 if t.strict {
60 func["strict"] = serde_json::Value::Bool(true);
61 }
62 serde_json::json!({
63 "type": "function",
64 "function": func,
65 })
66 })
67 .collect();
68 Ok(serde_json::Value::Array(defs))
69 }
70
71 fn parse_response(&self, raw: &serde_json::Value) -> Result<LlmResponse, PeError> {
72 let choices = raw
73 .get("choices")
74 .and_then(|v| v.as_array())
75 .ok_or(PeError::LlmEmpty)?;
76
77 let choice = choices.first().ok_or(PeError::LlmEmpty)?;
78 let message = choice.get("message").ok_or(PeError::LlmEmpty)?;
79
80 let content = message
81 .get("content")
82 .and_then(|v| v.as_str())
83 .map(|s| MessageContent::Text(s.to_string()))
84 .unwrap_or_else(|| MessageContent::Text(String::new()));
85
86 let (tool_calls, invalid_tool_calls) = parse_wire_tool_calls(message);
87
88 let usage_metadata = raw.get("usage").and_then(|u| {
89 Some(UsageMetadata {
90 input_tokens: u.get("prompt_tokens")?.as_u64()? as u32,
91 output_tokens: u.get("completion_tokens")?.as_u64()? as u32,
92 total_tokens: u.get("total_tokens")?.as_u64()? as u32,
93 input_token_details: None,
94 output_token_details: None,
95 })
96 });
97
98 let mut provider_metadata = HashMap::new();
99 for (key, src) in [
100 ("id", raw as &serde_json::Value),
101 ("model", raw),
102 ("finish_reason", choice),
103 ] {
104 if let Some(val) = src.get(key).and_then(|v| v.as_str()) {
105 provider_metadata.insert(key.into(), serde_json::Value::String(val.to_string()));
106 }
107 }
108
109 Ok(LlmResponse {
110 message: AiMessage {
111 content,
112 tool_calls,
113 invalid_tool_calls,
114 usage_metadata,
115 response_metadata: HashMap::new(),
116 id: None,
117 },
118 provider_metadata,
119 })
120 }
121}
122
123fn format_single_message(msg: &Message) -> Result<Option<serde_json::Value>, PeError> {
125 Ok(Some(match msg {
126 Message::Human(m) => {
127 serde_json::json!({"role": "user", "content": content_to_wire(&m.content)})
128 }
129 Message::System(m) => serde_json::json!({"role": "system", "content": m.content}),
130 Message::Ai(m) => {
131 let mut obj = serde_json::json!({"role": "assistant"});
132 obj["content"] = m
133 .content
134 .as_text()
135 .map(|s| serde_json::Value::String(s.to_string()))
136 .unwrap_or(serde_json::Value::Null);
137 if !m.tool_calls.is_empty() {
138 let wire: Result<Vec<_>, PeError> = m.tool_calls.iter().map(|tc| {
139 let args = serde_json::to_string(&tc.args).map_err(|e| PeError::LlmProvider {
140 details: format!("failed to serialize tool call args for '{}': {e}", tc.name),
141 })?;
142 Ok(serde_json::json!({"id": tc.id, "type": "function", "function": {"name": tc.name, "arguments": args}}))
143 }).collect();
144 obj["tool_calls"] = serde_json::Value::Array(wire?);
145 }
146 obj
147 }
148 Message::Tool(m) => {
149 serde_json::json!({"role": "tool", "content": m.content, "tool_call_id": m.tool_call_id})
150 }
151 #[allow(unreachable_patterns)]
152 _ => return Ok(None),
153 }))
154}
155
156fn content_to_wire(content: &MessageContent) -> serde_json::Value {
158 match content {
159 MessageContent::Text(t) => serde_json::Value::String(t.clone()),
160 MessageContent::Blocks(blocks) => {
161 let parts: Vec<_> = blocks
162 .iter()
163 .filter_map(|block| match block {
164 ContentBlock::Text { text } => {
165 Some(serde_json::json!({"type": "text", "text": text}))
166 }
167 ContentBlock::Image { url } => {
168 Some(serde_json::json!({"type": "image_url", "image_url": {"url": url}}))
169 }
170 _ => None,
171 })
172 .collect();
173 serde_json::Value::Array(parts)
174 }
175 #[allow(unreachable_patterns)]
176 _ => serde_json::Value::String("[unsupported content type]".into()),
177 }
178}
179
180fn parse_wire_tool_calls(message: &serde_json::Value) -> (Vec<ToolCall>, Vec<InvalidToolCall>) {
182 let (mut valid, mut invalid) = (Vec::new(), Vec::new());
183 let Some(wire) = message.get("tool_calls").and_then(|v| v.as_array()) else {
184 return (valid, invalid);
185 };
186 for tc in wire {
187 let func = tc.get("function");
188 let id = tc
189 .get("id")
190 .and_then(|v| v.as_str())
191 .unwrap_or("")
192 .to_string();
193 let name = func
194 .and_then(|f| f.get("name"))
195 .and_then(|v| v.as_str())
196 .unwrap_or("")
197 .to_string();
198 let arguments = func
199 .and_then(|f| f.get("arguments"))
200 .and_then(|v| v.as_str())
201 .unwrap_or("")
202 .to_string();
203 match serde_json::from_str::<serde_json::Value>(&arguments) {
204 Ok(args) => valid.push(ToolCall { id, name, args }),
205 Err(e) => invalid.push(InvalidToolCall {
206 id,
207 name,
208 args: arguments,
209 error: e.to_string(),
210 }),
211 }
212 }
213 (valid, invalid)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Formats `msgs` with the formatter under test, panicking on error.
    fn wire_messages(msgs: &[Message]) -> serde_json::Value {
        OpenAiFormatter.format_messages(msgs).unwrap()
    }

    /// Formats `tools` with the formatter under test, panicking on error.
    fn wire_tools(tools: &[ToolSchema]) -> serde_json::Value {
        OpenAiFormatter.format_tools(tools).unwrap()
    }

    /// Parses `raw` with the formatter under test, panicking on error.
    fn parsed(raw: serde_json::Value) -> LlmResponse {
        OpenAiFormatter.parse_response(&raw).unwrap()
    }

    /// Parses `raw` and returns the error, panicking if parsing succeeds.
    fn parse_failure(raw: serde_json::Value) -> PeError {
        OpenAiFormatter.parse_response(&raw).unwrap_err()
    }

    /// Builds a minimal object-typed tool schema.
    fn schema(name: &str, description: &str, strict: bool) -> ToolSchema {
        ToolSchema {
            name: name.into(),
            description: description.into(),
            parameters: serde_json::json!({"type": "object"}),
            strict,
        }
    }

    #[test]
    fn test_name_returns_openai() {
        assert_eq!(OpenAiFormatter.name(), "openai");
    }

    #[test]
    fn test_format_human_message() {
        let wire = wire_messages(&[Message::human("Hello")]);
        let first = &wire[0];
        assert_eq!(first["role"], "user");
        assert_eq!(first["content"], "Hello");
    }

    #[test]
    fn test_format_system_message() {
        let wire = wire_messages(&[Message::system("Be helpful")]);
        let first = &wire[0];
        assert_eq!(first["role"], "system");
        assert_eq!(first["content"], "Be helpful");
    }

    #[test]
    fn test_format_ai_message_with_tool_calls() {
        let call = ToolCall {
            id: "call_1".into(),
            name: "search".into(),
            args: serde_json::json!({"q": "rust"}),
        };
        let assistant = AiMessage {
            content: MessageContent::Text(String::new()),
            tool_calls: vec![call],
            invalid_tool_calls: vec![],
            usage_metadata: None,
            response_metadata: HashMap::new(),
            id: None,
        };

        let wire = wire_messages(&[Message::Ai(assistant)]);
        let first = &wire[0];
        assert_eq!(first["role"], "assistant");

        let wire_call = &first["tool_calls"][0];
        assert_eq!(wire_call["function"]["name"], "search");
        assert_eq!(wire_call["type"], "function");
    }

    #[test]
    fn test_format_tool_message() {
        let wire = wire_messages(&[Message::tool("result data", "call_1")]);
        let first = &wire[0];
        assert_eq!(first["role"], "tool");
        assert_eq!(first["tool_call_id"], "call_1");
        assert_eq!(first["content"], "result data");
    }

    #[test]
    fn test_format_tools_with_strict() {
        let wire = wire_tools(&[schema("search", "Search the web", true)]);
        let first = &wire[0];
        assert_eq!(first["type"], "function");
        assert_eq!(first["function"]["name"], "search");
        assert_eq!(first["function"]["strict"], true);
    }

    #[test]
    fn test_format_tools_without_strict() {
        let wire = wire_tools(&[schema("calc", "Calculate", false)]);
        // The key must be absent, not merely `false`.
        assert!(wire[0]["function"].get("strict").is_none());
    }

    #[test]
    fn test_format_empty_tools() {
        assert_eq!(wire_tools(&[]), serde_json::json!([]));
    }

    #[test]
    fn test_parse_response_text() {
        let resp = parsed(serde_json::json!({
            "id": "chatcmpl-123",
            "model": "gpt-4",
            "choices": [{
                "message": { "content": "Hello world", "role": "assistant" },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }));

        assert_eq!(resp.message.content.as_text(), Some("Hello world"));

        let usage = resp.message.usage_metadata.as_ref().unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);

        let meta = &resp.provider_metadata;
        assert_eq!(meta["finish_reason"], "stop");
        assert_eq!(meta["model"], "gpt-4");
        assert_eq!(meta["id"], "chatcmpl-123");
    }

    #[test]
    fn test_parse_response_with_tool_calls() {
        let resp = parsed(serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_abc",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\":\"NYC\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": { "prompt_tokens": 20, "completion_tokens": 15, "total_tokens": 35 }
        }));

        let calls = &resp.message.tool_calls;
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].args["location"], "NYC");
    }

    #[test]
    fn test_parse_response_invalid_tool_call_json() {
        let resp = parsed(serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_bad",
                        "type": "function",
                        "function": {
                            "name": "broken",
                            "arguments": "not json{"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        }));

        assert!(resp.message.tool_calls.is_empty());

        let invalid = &resp.message.invalid_tool_calls;
        assert_eq!(invalid.len(), 1);
        assert_eq!(invalid[0].name, "broken");
    }

    #[test]
    fn test_parse_response_empty_choices_returns_error() {
        let err = parse_failure(serde_json::json!({ "choices": [] }));
        assert!(matches!(err, PeError::LlmEmpty));
    }

    #[test]
    fn test_parse_response_no_choices_key_returns_error() {
        let err = parse_failure(serde_json::json!({ "error": "bad request" }));
        assert!(matches!(err, PeError::LlmEmpty));
    }

    #[test]
    fn test_format_multimodal_content() {
        let blocks = vec![
            ContentBlock::Text {
                text: "What is this?".into(),
            },
            ContentBlock::Image {
                url: "https://example.com/img.png".into(),
            },
        ];
        let human = crate::message::HumanMessage {
            content: MessageContent::Blocks(blocks),
            id: None,
            name: None,
        };

        let wire = wire_messages(&[Message::Human(human)]);
        let content = &wire[0]["content"];
        assert!(content.is_array());
        assert_eq!(content[0]["type"], "text");
        assert_eq!(content[1]["type"], "image_url");
        assert_eq!(
            content[1]["image_url"]["url"],
            "https://example.com/img.png"
        );
    }

    #[test]
    fn test_format_multiple_messages_preserves_order() {
        let wire = wire_messages(&[
            Message::system("System prompt"),
            Message::human("Hello"),
            Message::ai("Hi there"),
        ]);
        assert_eq!(wire.as_array().unwrap().len(), 3);
        for (index, role) in ["system", "user", "assistant"].into_iter().enumerate() {
            assert_eq!(wire[index]["role"], role);
        }
    }
}