Skip to main content

pe_core/
openai_formatter.rs

1//! OpenAI-compatible message formatter.
2//!
3//! Converts pe-core `Message`, `ToolSchema`, and raw JSON responses to/from
4//! the OpenAI chat completions wire format. This is the default formatter
5//! and also works with OpenAI-compatible providers (xAI, DeepSeek, vLLM, etc.).
6
7use std::collections::HashMap;
8
9use crate::error::PeError;
10use crate::formatter::MessageFormatter;
11use crate::llm::{LlmResponse, ToolSchema};
12use crate::message::{
13    AiMessage, ContentBlock, InvalidToolCall, Message, MessageContent, ToolCall, UsageMetadata,
14};
15
/// Formatter for the OpenAI chat completions wire format.
///
/// Handles conversion between pe-core types and the JSON structure expected
/// by OpenAI and OpenAI-compatible APIs.
///
/// The type is a field-less unit struct, so it carries no state; it derives
/// `Debug`, `Clone`, `Copy`, and `Default` so it can be logged, duplicated,
/// and embedded in structs that derive those traits themselves.
///
/// # Example
///
/// ```
/// use pe_core::openai_formatter::OpenAiFormatter;
/// use pe_core::formatter::MessageFormatter;
/// use pe_core::Message;
///
/// let fmt = OpenAiFormatter;
/// let wire = fmt.format_messages(&[Message::human("Hi")]).unwrap();
/// assert_eq!(wire[0]["role"], "user");
/// assert_eq!(wire[0]["content"], "Hi");
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct OpenAiFormatter;
34
35impl MessageFormatter for OpenAiFormatter {
36    fn name(&self) -> &str {
37        "openai"
38    }
39
40    fn format_messages(&self, messages: &[Message]) -> Result<serde_json::Value, PeError> {
41        let mut result = Vec::with_capacity(messages.len());
42        for msg in messages {
43            if let Some(wire) = format_single_message(msg)? {
44                result.push(wire);
45            }
46        }
47        Ok(serde_json::Value::Array(result))
48    }
49
50    fn format_tools(&self, tools: &[ToolSchema]) -> Result<serde_json::Value, PeError> {
51        let defs: Vec<serde_json::Value> = tools
52            .iter()
53            .map(|t| {
54                let mut func = serde_json::json!({
55                    "name": t.name,
56                    "description": t.description,
57                    "parameters": t.parameters,
58                });
59                if t.strict {
60                    func["strict"] = serde_json::Value::Bool(true);
61                }
62                serde_json::json!({
63                    "type": "function",
64                    "function": func,
65                })
66            })
67            .collect();
68        Ok(serde_json::Value::Array(defs))
69    }
70
71    fn parse_response(&self, raw: &serde_json::Value) -> Result<LlmResponse, PeError> {
72        let choices = raw
73            .get("choices")
74            .and_then(|v| v.as_array())
75            .ok_or(PeError::LlmEmpty)?;
76
77        let choice = choices.first().ok_or(PeError::LlmEmpty)?;
78        let message = choice.get("message").ok_or(PeError::LlmEmpty)?;
79
80        let content = message
81            .get("content")
82            .and_then(|v| v.as_str())
83            .map(|s| MessageContent::Text(s.to_string()))
84            .unwrap_or_else(|| MessageContent::Text(String::new()));
85
86        let (tool_calls, invalid_tool_calls) = parse_wire_tool_calls(message);
87
88        let usage_metadata = raw.get("usage").and_then(|u| {
89            Some(UsageMetadata {
90                input_tokens: u.get("prompt_tokens")?.as_u64()? as u32,
91                output_tokens: u.get("completion_tokens")?.as_u64()? as u32,
92                total_tokens: u.get("total_tokens")?.as_u64()? as u32,
93                input_token_details: None,
94                output_token_details: None,
95            })
96        });
97
98        let mut provider_metadata = HashMap::new();
99        for (key, src) in [
100            ("id", raw as &serde_json::Value),
101            ("model", raw),
102            ("finish_reason", choice),
103        ] {
104            if let Some(val) = src.get(key).and_then(|v| v.as_str()) {
105                provider_metadata.insert(key.into(), serde_json::Value::String(val.to_string()));
106            }
107        }
108
109        Ok(LlmResponse {
110            message: AiMessage {
111                content,
112                tool_calls,
113                invalid_tool_calls,
114                usage_metadata,
115                response_metadata: HashMap::new(),
116                id: None,
117            },
118            provider_metadata,
119        })
120    }
121}
122
123/// Convert a single pe-core Message to OpenAI wire JSON.
124fn format_single_message(msg: &Message) -> Result<Option<serde_json::Value>, PeError> {
125    Ok(Some(match msg {
126        Message::Human(m) => {
127            serde_json::json!({"role": "user", "content": content_to_wire(&m.content)})
128        }
129        Message::System(m) => serde_json::json!({"role": "system", "content": m.content}),
130        Message::Ai(m) => {
131            let mut obj = serde_json::json!({"role": "assistant"});
132            obj["content"] = m
133                .content
134                .as_text()
135                .map(|s| serde_json::Value::String(s.to_string()))
136                .unwrap_or(serde_json::Value::Null);
137            if !m.tool_calls.is_empty() {
138                let wire: Result<Vec<_>, PeError> = m.tool_calls.iter().map(|tc| {
139                    let args = serde_json::to_string(&tc.args).map_err(|e| PeError::LlmProvider {
140                        details: format!("failed to serialize tool call args for '{}': {e}", tc.name),
141                    })?;
142                    Ok(serde_json::json!({"id": tc.id, "type": "function", "function": {"name": tc.name, "arguments": args}}))
143                }).collect();
144                obj["tool_calls"] = serde_json::Value::Array(wire?);
145            }
146            obj
147        }
148        Message::Tool(m) => {
149            serde_json::json!({"role": "tool", "content": m.content, "tool_call_id": m.tool_call_id})
150        }
151        #[allow(unreachable_patterns)]
152        _ => return Ok(None),
153    }))
154}
155
156/// Convert MessageContent to OpenAI wire JSON.
157fn content_to_wire(content: &MessageContent) -> serde_json::Value {
158    match content {
159        MessageContent::Text(t) => serde_json::Value::String(t.clone()),
160        MessageContent::Blocks(blocks) => {
161            let parts: Vec<_> = blocks
162                .iter()
163                .filter_map(|block| match block {
164                    ContentBlock::Text { text } => {
165                        Some(serde_json::json!({"type": "text", "text": text}))
166                    }
167                    ContentBlock::Image { url } => {
168                        Some(serde_json::json!({"type": "image_url", "image_url": {"url": url}}))
169                    }
170                    _ => None,
171                })
172                .collect();
173            serde_json::Value::Array(parts)
174        }
175        #[allow(unreachable_patterns)]
176        _ => serde_json::Value::String("[unsupported content type]".into()),
177    }
178}
179
180/// Parse tool calls from the wire response message JSON.
181fn parse_wire_tool_calls(message: &serde_json::Value) -> (Vec<ToolCall>, Vec<InvalidToolCall>) {
182    let (mut valid, mut invalid) = (Vec::new(), Vec::new());
183    let Some(wire) = message.get("tool_calls").and_then(|v| v.as_array()) else {
184        return (valid, invalid);
185    };
186    for tc in wire {
187        let func = tc.get("function");
188        let id = tc
189            .get("id")
190            .and_then(|v| v.as_str())
191            .unwrap_or("")
192            .to_string();
193        let name = func
194            .and_then(|f| f.get("name"))
195            .and_then(|v| v.as_str())
196            .unwrap_or("")
197            .to_string();
198        let arguments = func
199            .and_then(|f| f.get("arguments"))
200            .and_then(|v| v.as_str())
201            .unwrap_or("")
202            .to_string();
203        match serde_json::from_str::<serde_json::Value>(&arguments) {
204            Ok(args) => valid.push(ToolCall { id, name, args }),
205            Err(e) => invalid.push(InvalidToolCall {
206                id,
207                name,
208                args: arguments,
209                error: e.to_string(),
210            }),
211        }
212    }
213    (valid, invalid)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Formats `messages` with the OpenAI formatter, panicking on failure.
    fn wire_messages(messages: &[Message]) -> serde_json::Value {
        OpenAiFormatter.format_messages(messages).unwrap()
    }

    /// Formats `tools` with the OpenAI formatter, panicking on failure.
    fn wire_tools(tools: &[ToolSchema]) -> serde_json::Value {
        OpenAiFormatter.format_tools(tools).unwrap()
    }

    /// Parses a raw response body with the OpenAI formatter, panicking on failure.
    fn parsed(raw: &serde_json::Value) -> LlmResponse {
        OpenAiFormatter.parse_response(raw).unwrap()
    }

    /// Builds a tool schema whose parameter schema is a bare `{"type": "object"}`.
    fn object_tool(name: &str, description: &str, strict: bool) -> ToolSchema {
        ToolSchema {
            name: name.into(),
            description: description.into(),
            parameters: serde_json::json!({"type": "object"}),
            strict,
        }
    }

    #[test]
    fn test_name_returns_openai() {
        let formatter = OpenAiFormatter;
        assert_eq!(formatter.name(), "openai");
    }

    #[test]
    fn test_format_human_message() {
        let wire = wire_messages(&[Message::human("Hello")]);
        let first = &wire[0];
        assert_eq!(first["role"], "user");
        assert_eq!(first["content"], "Hello");
    }

    #[test]
    fn test_format_system_message() {
        let wire = wire_messages(&[Message::system("Be helpful")]);
        let first = &wire[0];
        assert_eq!(first["role"], "system");
        assert_eq!(first["content"], "Be helpful");
    }

    #[test]
    fn test_format_ai_message_with_tool_calls() {
        let search_call = ToolCall {
            id: "call_1".into(),
            name: "search".into(),
            args: serde_json::json!({"q": "rust"}),
        };
        let assistant = AiMessage {
            content: MessageContent::Text(String::new()),
            tool_calls: vec![search_call],
            invalid_tool_calls: vec![],
            usage_metadata: None,
            response_metadata: HashMap::new(),
            id: None,
        };

        let wire = wire_messages(&[Message::Ai(assistant)]);
        assert_eq!(wire[0]["role"], "assistant");

        let first_call = &wire[0]["tool_calls"][0];
        assert_eq!(first_call["function"]["name"], "search");
        assert_eq!(first_call["type"], "function");
    }

    #[test]
    fn test_format_tool_message() {
        let wire = wire_messages(&[Message::tool("result data", "call_1")]);
        let first = &wire[0];
        assert_eq!(first["role"], "tool");
        assert_eq!(first["tool_call_id"], "call_1");
        assert_eq!(first["content"], "result data");
    }

    #[test]
    fn test_format_tools_with_strict() {
        let wire = wire_tools(&[object_tool("search", "Search the web", true)]);
        let definition = &wire[0];
        assert_eq!(definition["type"], "function");
        assert_eq!(definition["function"]["name"], "search");
        assert_eq!(definition["function"]["strict"], true);
    }

    #[test]
    fn test_format_tools_without_strict() {
        let wire = wire_tools(&[object_tool("calc", "Calculate", false)]);
        // A non-strict schema must not emit the `strict` key at all.
        assert!(wire[0]["function"].get("strict").is_none());
    }

    #[test]
    fn test_format_empty_tools() {
        let wire = wire_tools(&[]);
        assert_eq!(wire, serde_json::Value::Array(Vec::new()));
    }

    #[test]
    fn test_parse_response_text() {
        let body = serde_json::json!({
            "id": "chatcmpl-123",
            "model": "gpt-4",
            "choices": [{
                "message": { "content": "Hello world", "role": "assistant" },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });

        let response = parsed(&body);
        assert_eq!(response.message.content.as_text(), Some("Hello world"));

        let usage = response.message.usage_metadata.as_ref().unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);

        let metadata = &response.provider_metadata;
        assert_eq!(metadata["finish_reason"], "stop");
        assert_eq!(metadata["model"], "gpt-4");
        assert_eq!(metadata["id"], "chatcmpl-123");
    }

    #[test]
    fn test_parse_response_with_tool_calls() {
        let body = serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_abc",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\":\"NYC\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": { "prompt_tokens": 20, "completion_tokens": 15, "total_tokens": 35 }
        });

        let response = parsed(&body);
        let calls = &response.message.tool_calls;
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].args["location"], "NYC");
    }

    #[test]
    fn test_parse_response_invalid_tool_call_json() {
        let body = serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_bad",
                        "type": "function",
                        "function": {
                            "name": "broken",
                            "arguments": "not json{"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        });

        let response = parsed(&body);
        assert!(response.message.tool_calls.is_empty());

        let rejected = &response.message.invalid_tool_calls;
        assert_eq!(rejected.len(), 1);
        assert_eq!(rejected[0].name, "broken");
    }

    #[test]
    fn test_parse_response_empty_choices_returns_error() {
        let body = serde_json::json!({ "choices": [] });
        let outcome = OpenAiFormatter.parse_response(&body);
        assert!(matches!(outcome.unwrap_err(), PeError::LlmEmpty));
    }

    #[test]
    fn test_parse_response_no_choices_key_returns_error() {
        let body = serde_json::json!({ "error": "bad request" });
        let outcome = OpenAiFormatter.parse_response(&body);
        assert!(matches!(outcome.unwrap_err(), PeError::LlmEmpty));
    }

    #[test]
    fn test_format_multimodal_content() {
        let blocks = vec![
            ContentBlock::Text {
                text: "What is this?".into(),
            },
            ContentBlock::Image {
                url: "https://example.com/img.png".into(),
            },
        ];
        let human = crate::message::HumanMessage {
            content: MessageContent::Blocks(blocks),
            id: None,
            name: None,
        };

        let wire = wire_messages(&[Message::Human(human)]);
        let parts = &wire[0]["content"];
        assert!(parts.is_array());
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[1]["type"], "image_url");
        assert_eq!(parts[1]["image_url"]["url"], "https://example.com/img.png");
    }

    #[test]
    fn test_format_multiple_messages_preserves_order() {
        let conversation = [
            Message::system("System prompt"),
            Message::human("Hello"),
            Message::ai("Hi there"),
        ];

        let wire = wire_messages(&conversation);
        assert_eq!(wire.as_array().unwrap().len(), 3);
        for (index, expected_role) in ["system", "user", "assistant"].into_iter().enumerate() {
            assert_eq!(wire[index]["role"], expected_role);
        }
    }
}