use std::collections::HashMap;
use crate::error::PeError;
use crate::formatter::MessageFormatter;
use crate::llm::{LlmResponse, ToolSchema};
use crate::message::{
AiMessage, ContentBlock, InvalidToolCall, Message, MessageContent, ToolCall, UsageMetadata,
};
/// [`MessageFormatter`] for the OpenAI Chat Completions API.
///
/// Stateless unit struct: it serializes crate [`Message`]s and [`ToolSchema`]s into the
/// Chat Completions request shape and parses Chat Completions response bodies back into an
/// [`LlmResponse`].
pub struct OpenAiFormatter;

impl MessageFormatter for OpenAiFormatter {
    /// Identifier of this formatter: `"openai"`.
    fn name(&self) -> &str {
        "openai"
    }

    /// Serializes `messages` into a JSON array of Chat Completions messages.
    ///
    /// Order is preserved; messages for which `format_single_message` yields `None` are
    /// left out.
    ///
    /// # Errors
    /// Propagates any error from `format_single_message` (e.g. tool-call arguments that
    /// fail to serialize).
    fn format_messages(&self, messages: &[Message]) -> Result<serde_json::Value, PeError> {
        let mut result = Vec::with_capacity(messages.len());
        for msg in messages {
            if let Some(wire) = format_single_message(msg)? {
                result.push(wire);
            }
        }
        Ok(serde_json::Value::Array(result))
    }

    /// Serializes tool schemas into the `tools` array of a Chat Completions request.
    ///
    /// Each entry is `{"type": "function", "function": {name, description, parameters}}`.
    /// `"strict": true` is added only for schemas with `strict` set; otherwise the key is
    /// absent. This implementation never returns an error.
    fn format_tools(&self, tools: &[ToolSchema]) -> Result<serde_json::Value, PeError> {
        let defs: Vec<serde_json::Value> = tools
            .iter()
            .map(|t| {
                let mut func = serde_json::json!({
                    "name": t.name,
                    "description": t.description,
                    "parameters": t.parameters,
                });
                if t.strict {
                    func["strict"] = serde_json::Value::Bool(true);
                }
                serde_json::json!({
                    "type": "function",
                    "function": func,
                })
            })
            .collect();
        Ok(serde_json::Value::Array(defs))
    }

    /// Parses a Chat Completions response body into an [`LlmResponse`].
    ///
    /// Only the first entry of `choices` is read. Its `message.content` becomes the text
    /// content (empty when absent or `null`), and its `message.tool_calls` are split into
    /// valid and invalid calls.
    ///
    /// `usage` is mapped to [`UsageMetadata`] when `prompt_tokens` and `completion_tokens`
    /// are present. Counts above `u32::MAX` are clamped, and a missing `total_tokens` is
    /// derived from the other two.
    ///
    /// The string fields `id`, `model` (top level), `finish_reason` (choice) and `refusal`
    /// (message) are copied into `provider_metadata` when present.
    ///
    /// # Errors
    /// [`PeError::LlmEmpty`] when `choices` is missing, is not an array, is empty, or its
    /// first entry has no `message`.
    fn parse_response(&self, raw: &serde_json::Value) -> Result<LlmResponse, PeError> {
        let choice = raw
            .get("choices")
            .and_then(|v| v.as_array())
            .and_then(|choices| choices.first())
            .ok_or(PeError::LlmEmpty)?;
        let message = choice.get("message").ok_or(PeError::LlmEmpty)?;
        // `content` is `null` when the model only calls tools or refuses; map that to "".
        let content = MessageContent::Text(
            message
                .get("content")
                .and_then(|v| v.as_str())
                .unwrap_or("")
                .to_string(),
        );
        let (tool_calls, invalid_tool_calls) = parse_wire_tool_calls(message);
        let usage_metadata = raw.get("usage").and_then(|usage| {
            // Counts arrive as u64; clamp to u32::MAX instead of letting `as` truncate them.
            let count = |key: &str| -> Option<u32> {
                let n = usage.get(key)?.as_u64()?;
                Some(n.min(u64::from(u32::MAX)) as u32)
            };
            let input_tokens = count("prompt_tokens")?;
            let output_tokens = count("completion_tokens")?;
            Some(UsageMetadata {
                input_tokens,
                output_tokens,
                // Derive the total when a server omits `total_tokens`.
                total_tokens: count("total_tokens")
                    .unwrap_or_else(|| input_tokens.saturating_add(output_tokens)),
                input_token_details: None,
                output_token_details: None,
            })
        });
        let mut provider_metadata = HashMap::new();
        // (field name, object it lives on); only string-valued fields are recorded.
        let sources: [(&str, &serde_json::Value); 4] = [
            ("id", raw),
            ("model", raw),
            ("finish_reason", choice),
            // Set by OpenAI (with `content` null) when the model declines the request.
            ("refusal", message),
        ];
        for (key, src) in sources {
            if let Some(val) = src.get(key).and_then(|v| v.as_str()) {
                provider_metadata.insert(key.into(), serde_json::Value::String(val.to_string()));
            }
        }
        Ok(LlmResponse {
            message: AiMessage {
                content,
                tool_calls,
                invalid_tool_calls,
                usage_metadata,
                response_metadata: HashMap::new(),
                id: None,
            },
            provider_metadata,
        })
    }
}
fn format_single_message(msg: &Message) -> Result<Option<serde_json::Value>, PeError> {
Ok(Some(match msg {
Message::Human(m) => {
serde_json::json!({"role": "user", "content": content_to_wire(&m.content)})
}
Message::System(m) => serde_json::json!({"role": "system", "content": m.content}),
Message::Ai(m) => {
let mut obj = serde_json::json!({"role": "assistant"});
obj["content"] = m
.content
.as_text()
.map(|s| serde_json::Value::String(s.to_string()))
.unwrap_or(serde_json::Value::Null);
if !m.tool_calls.is_empty() {
let wire: Result<Vec<_>, PeError> = m.tool_calls.iter().map(|tc| {
let args = serde_json::to_string(&tc.args).map_err(|e| PeError::LlmProvider {
details: format!("failed to serialize tool call args for '{}': {e}", tc.name),
})?;
Ok(serde_json::json!({"id": tc.id, "type": "function", "function": {"name": tc.name, "arguments": args}}))
}).collect();
obj["tool_calls"] = serde_json::Value::Array(wire?);
}
obj
}
Message::Tool(m) => {
serde_json::json!({"role": "tool", "content": m.content, "tool_call_id": m.tool_call_id})
}
#[allow(unreachable_patterns)]
_ => return Ok(None),
}))
}
fn content_to_wire(content: &MessageContent) -> serde_json::Value {
match content {
MessageContent::Text(t) => serde_json::Value::String(t.clone()),
MessageContent::Blocks(blocks) => {
let parts: Vec<_> = blocks
.iter()
.filter_map(|block| match block {
ContentBlock::Text { text } => {
Some(serde_json::json!({"type": "text", "text": text}))
}
ContentBlock::Image { url } => {
Some(serde_json::json!({"type": "image_url", "image_url": {"url": url}}))
}
_ => None,
})
.collect();
serde_json::Value::Array(parts)
}
#[allow(unreachable_patterns)]
_ => serde_json::Value::String("[unsupported content type]".into()),
}
}
fn parse_wire_tool_calls(message: &serde_json::Value) -> (Vec<ToolCall>, Vec<InvalidToolCall>) {
let (mut valid, mut invalid) = (Vec::new(), Vec::new());
let Some(wire) = message.get("tool_calls").and_then(|v| v.as_array()) else {
return (valid, invalid);
};
for tc in wire {
let func = tc.get("function");
let id = tc
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let name = func
.and_then(|f| f.get("name"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let arguments = func
.and_then(|f| f.get("arguments"))
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
match serde_json::from_str::<serde_json::Value>(&arguments) {
Ok(args) => valid.push(ToolCall { id, name, args }),
Err(e) => invalid.push(InvalidToolCall {
id,
name,
args: arguments,
error: e.to_string(),
}),
}
}
(valid, invalid)
}
/// Unit tests for OpenAI wire-format conversion in both directions.
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `format_messages`, failing the test if it errors.
    fn to_wire(msgs: &[Message]) -> serde_json::Value {
        OpenAiFormatter
            .format_messages(msgs)
            .expect("format_messages should succeed")
    }

    /// Runs `parse_response`, failing the test if it errors.
    fn parse(raw: &serde_json::Value) -> LlmResponse {
        OpenAiFormatter
            .parse_response(raw)
            .expect("parse_response should succeed")
    }

    /// Tool schema with a bare object parameter schema.
    fn schema(name: &'static str, description: &'static str, strict: bool) -> ToolSchema {
        ToolSchema {
            name: name.into(),
            description: description.into(),
            parameters: serde_json::json!({"type": "object"}),
            strict,
        }
    }

    #[test]
    fn test_name_returns_openai() {
        let formatter = OpenAiFormatter;
        assert_eq!(formatter.name(), "openai");
    }

    #[test]
    fn test_format_human_message() {
        let wire = to_wire(&[Message::human("Hello")]);
        let user = &wire[0];
        assert_eq!(user["role"], "user");
        assert_eq!(user["content"], "Hello");
    }

    #[test]
    fn test_format_system_message() {
        let wire = to_wire(&[Message::system("Be helpful")]);
        let system = &wire[0];
        assert_eq!(system["role"], "system");
        assert_eq!(system["content"], "Be helpful");
    }

    #[test]
    fn test_format_ai_message_with_tool_calls() {
        let call = ToolCall {
            id: "call_1".into(),
            name: "search".into(),
            args: serde_json::json!({"q": "rust"}),
        };
        let msg = Message::Ai(AiMessage {
            content: MessageContent::Text(String::new()),
            tool_calls: vec![call],
            invalid_tool_calls: Vec::new(),
            usage_metadata: None,
            response_metadata: HashMap::new(),
            id: None,
        });
        let wire = to_wire(&[msg]);
        let assistant = &wire[0];
        let first_call = &assistant["tool_calls"][0];
        assert_eq!(assistant["role"], "assistant");
        assert_eq!(first_call["function"]["name"], "search");
        assert_eq!(first_call["type"], "function");
    }

    #[test]
    fn test_format_tool_message() {
        let wire = to_wire(&[Message::tool("result data", "call_1")]);
        let tool = &wire[0];
        assert_eq!(tool["role"], "tool");
        assert_eq!(tool["tool_call_id"], "call_1");
        assert_eq!(tool["content"], "result data");
    }

    #[test]
    fn test_format_tools_with_strict() {
        let wire = OpenAiFormatter
            .format_tools(&[schema("search", "Search the web", true)])
            .unwrap();
        let tool = &wire[0];
        assert_eq!(tool["type"], "function");
        assert_eq!(tool["function"]["name"], "search");
        assert_eq!(tool["function"]["strict"], true);
    }

    #[test]
    fn test_format_tools_without_strict() {
        let wire = OpenAiFormatter
            .format_tools(&[schema("calc", "Calculate", false)])
            .unwrap();
        assert!(wire[0]["function"].get("strict").is_none());
    }

    #[test]
    fn test_format_empty_tools() {
        let wire = OpenAiFormatter.format_tools(&[]).unwrap();
        assert_eq!(wire, serde_json::Value::Array(Vec::new()));
    }

    #[test]
    fn test_parse_response_text() {
        let raw = serde_json::json!({
            "id": "chatcmpl-123",
            "model": "gpt-4",
            "choices": [{
                "message": { "content": "Hello world", "role": "assistant" },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });
        let resp = parse(&raw);
        assert_eq!(resp.message.content.as_text(), Some("Hello world"));
        let usage = resp
            .message
            .usage_metadata
            .as_ref()
            .expect("usage should be parsed");
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
        let expected = [
            ("finish_reason", "stop"),
            ("model", "gpt-4"),
            ("id", "chatcmpl-123"),
        ];
        for (key, value) in expected {
            assert_eq!(resp.provider_metadata[key], value);
        }
    }

    #[test]
    fn test_parse_response_with_tool_calls() {
        let raw = serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_abc",
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\":\"NYC\"}"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": { "prompt_tokens": 20, "completion_tokens": 15, "total_tokens": 35 }
        });
        let resp = parse(&raw);
        let calls = &resp.message.tool_calls;
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].args["location"], "NYC");
    }

    #[test]
    fn test_parse_response_invalid_tool_call_json() {
        let raw = serde_json::json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call_bad",
                        "type": "function",
                        "function": {
                            "name": "broken",
                            "arguments": "not json{"
                        }
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        });
        let resp = parse(&raw);
        let message = &resp.message;
        assert!(message.tool_calls.is_empty());
        assert_eq!(message.invalid_tool_calls.len(), 1);
        assert_eq!(message.invalid_tool_calls[0].name, "broken");
    }

    #[test]
    fn test_parse_response_empty_choices_returns_error() {
        let raw = serde_json::json!({ "choices": [] });
        let outcome = OpenAiFormatter.parse_response(&raw);
        assert!(matches!(outcome, Err(PeError::LlmEmpty)));
    }

    #[test]
    fn test_parse_response_no_choices_key_returns_error() {
        let raw = serde_json::json!({ "error": "bad request" });
        let outcome = OpenAiFormatter.parse_response(&raw);
        assert!(matches!(outcome, Err(PeError::LlmEmpty)));
    }

    #[test]
    fn test_format_multimodal_content() {
        let image_url = "https://example.com/img.png";
        let msg = Message::Human(crate::message::HumanMessage {
            content: MessageContent::Blocks(vec![
                ContentBlock::Text {
                    text: "What is this?".into(),
                },
                ContentBlock::Image {
                    url: image_url.into(),
                },
            ]),
            id: None,
            name: None,
        });
        let wire = to_wire(&[msg]);
        let parts = &wire[0]["content"];
        assert!(parts.is_array());
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[1]["type"], "image_url");
        assert_eq!(parts[1]["image_url"]["url"], image_url);
    }

    #[test]
    fn test_format_multiple_messages_preserves_order() {
        let wire = to_wire(&[
            Message::system("System prompt"),
            Message::human("Hello"),
            Message::ai("Hi there"),
        ]);
        let roles: Vec<&str> = wire
            .as_array()
            .expect("wire should be an array")
            .iter()
            .map(|m| m["role"].as_str().unwrap_or(""))
            .collect();
        assert_eq!(roles, ["system", "user", "assistant"]);
    }
}