use serde::Serialize;
use serde_json::Value;
use crate::llm_client::{
Request,
types::{ContentPart, Message, MessageContent, Role, ToolDefinition},
};
use super::OpenAIScheme;
/// Body of an OpenAI-style chat-completions request.
///
/// Unset `Option` fields are left out of the serialized JSON, as are the
/// `stop` and `tools` lists when they are empty. `OpenAIScheme::build_request`
/// fills in exactly one of `max_tokens` / `max_completion_tokens`.
#[derive(Debug, Serialize)]
pub(crate) struct OpenAIRequest {
    pub model: String,
    /// Output-token limit under the field name used by the non-legacy path.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u32>,
    /// Output-token limit under the legacy field name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<OpenAITool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}
/// The `stream_options` object of a streaming request.
#[derive(Debug, Serialize)]
pub(crate) struct StreamOptions {
    /// Asks the server to report token usage in the stream.
    pub include_usage: bool,
}
/// One entry of the request's `messages` array.
///
/// `content` has no skip attribute, so a `None` value is serialized as
/// `"content": null`. The tool-related fields and `name` are omitted when empty.
#[derive(Debug, Serialize)]
pub(crate) struct OpenAIMessage {
    /// Role name; this file produces `"system"`, `"user"`, `"assistant"` and `"tool"`.
    pub role: String,
    pub content: Option<OpenAIContent>,
    /// Tool calls attached to the message (built from tool-use parts).
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<OpenAIToolCall>,
    /// For `"tool"` messages: the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
/// Message content: a plain string or a list of typed parts.
///
/// `untagged` means the variant name is not written. `Text` serializes as a bare
/// JSON string and `Parts` as a JSON array.
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub(crate) enum OpenAIContent {
    Text(String),
    Parts(Vec<OpenAIContentPart>),
}
/// One element of a multi-part message content array.
///
/// Internally tagged: the variant is written as a `"type"` field with value
/// `"text"` or `"image_url"`, alongside the variant's own fields.
// The converter in this file only builds `Text` parts. `ImageUrl` is part of the
// wire format but is never constructed here, so dead-code warnings are silenced.
#[allow(dead_code)]
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
pub(crate) enum OpenAIContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
/// Payload of an `image_url` content part, serialized as `{ "url": ... }`.
#[derive(Debug, Serialize)]
pub(crate) struct ImageUrl {
    pub url: String,
}
/// A tool offered to the model.
///
/// Serialized as `{ "type": ..., "function": { ... } }`; the raw identifier
/// `r#type` serializes as `type`.
#[derive(Debug, Serialize)]
pub(crate) struct OpenAITool {
    /// Tool kind; `OpenAIScheme::convert_tool` always sets `"function"`.
    pub r#type: String,
    pub function: OpenAIToolFunction,
}
/// Function signature of an offered tool.
#[derive(Debug, Serialize)]
pub(crate) struct OpenAIToolFunction {
    pub name: String,
    /// Omitted from the JSON when the tool has no description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter schema, copied from the tool definition's `input_schema`.
    pub parameters: Value,
}
/// A tool invocation recorded on an assistant message.
#[derive(Debug, Serialize)]
pub(crate) struct OpenAIToolCall {
    /// Call id that later `"tool"` messages refer to via `tool_call_id`.
    pub id: String,
    /// Call kind; this file always sets `"function"`.
    pub r#type: String,
    pub function: OpenAIToolCallFunction,
}
/// Function name and arguments of a recorded tool call.
#[derive(Debug, Serialize)]
pub(crate) struct OpenAIToolCallFunction {
    pub name: String,
    /// Arguments as a string: the tool-use input rendered with `to_string()`.
    pub arguments: String,
}
impl OpenAIScheme {
pub(crate) fn build_request(&self, model: &str, request: &Request) -> OpenAIRequest {
let mut messages = Vec::new();
if let Some(system) = &request.system_prompt {
messages.push(OpenAIMessage {
role: "system".to_string(),
content: Some(OpenAIContent::Text(system.clone())),
tool_calls: vec![],
tool_call_id: None,
name: None,
});
}
messages.extend(request.messages.iter().map(|m| self.convert_message(m)));
let tools = request.tools.iter().map(|t| self.convert_tool(t)).collect();
let (max_tokens, max_completion_tokens) = if self.use_legacy_max_tokens {
(request.config.max_tokens, None)
} else {
(None, request.config.max_tokens)
};
OpenAIRequest {
model: model.to_string(),
max_completion_tokens,
max_tokens,
temperature: request.config.temperature,
top_p: request.config.top_p,
stop: request.config.stop_sequences.clone(),
stream: true,
stream_options: Some(StreamOptions {
include_usage: true,
}),
messages,
tools,
tool_choice: None, }
}
fn convert_message(&self, message: &Message) -> OpenAIMessage {
match &message.content {
MessageContent::ToolResult {
tool_use_id,
content,
} => OpenAIMessage {
role: "tool".to_string(),
content: Some(OpenAIContent::Text(content.clone())),
tool_calls: vec![],
tool_call_id: Some(tool_use_id.clone()),
name: None,
},
MessageContent::Text(text) => {
let role = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
};
OpenAIMessage {
role: role.to_string(),
content: Some(OpenAIContent::Text(text.clone())),
tool_calls: vec![],
tool_call_id: None,
name: None,
}
}
MessageContent::Parts(parts) => {
let role = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
};
let mut content_parts = Vec::new();
let mut tool_calls = Vec::new();
let mut is_tool_result = false;
let mut tool_result_id = None;
let mut tool_result_content = String::new();
for part in parts {
match part {
ContentPart::Text { text } => {
content_parts.push(OpenAIContentPart::Text { text: text.clone() });
}
ContentPart::ToolUse { id, name, input } => {
tool_calls.push(OpenAIToolCall {
id: id.clone(),
r#type: "function".to_string(),
function: OpenAIToolCallFunction {
name: name.clone(),
arguments: input.to_string(),
},
});
}
ContentPart::ToolResult {
tool_use_id,
content,
} => {
is_tool_result = true;
tool_result_id = Some(tool_use_id.clone());
tool_result_content = content.clone();
}
}
}
if is_tool_result {
OpenAIMessage {
role: "tool".to_string(),
content: Some(OpenAIContent::Text(tool_result_content)),
tool_calls: vec![],
tool_call_id: tool_result_id,
name: None,
}
} else {
let content = if content_parts.is_empty() {
None
} else if content_parts.len() == 1 {
if let OpenAIContentPart::Text { text } = &content_parts[0] {
Some(OpenAIContent::Text(text.clone()))
} else {
Some(OpenAIContent::Parts(content_parts))
}
} else {
Some(OpenAIContent::Parts(content_parts))
};
OpenAIMessage {
role: role.to_string(),
content,
tool_calls,
tool_call_id: None,
name: None,
}
}
}
}
}
fn convert_tool(&self, tool: &ToolDefinition) -> OpenAITool {
OpenAITool {
r#type: "function".to_string(),
function: OpenAIToolFunction {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.input_schema.clone(),
},
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the two token-limit fields of a built request as
    /// `(max_tokens, max_completion_tokens)`.
    fn token_limits(body: &OpenAIRequest) -> (Option<u32>, Option<u32>) {
        (body.max_tokens, body.max_completion_tokens)
    }

    /// A 100-token request with a single "Hello" user turn.
    fn capped_hello() -> Request {
        Request::new().user("Hello").max_tokens(100)
    }

    /// The system prompt is emitted as a leading "system" message before the user turn.
    #[test]
    fn test_build_simple_request() {
        let request = Request::new().system("System prompt").user("Hello");
        let body = OpenAIScheme::new().build_request("gpt-4o", &request);

        assert_eq!(body.model, "gpt-4o");
        let roles: Vec<&str> = body.messages.iter().map(|m| m.role.as_str()).collect();
        assert_eq!(roles, ["system", "user"]);

        match &body.messages[0].content {
            Some(OpenAIContent::Text(text)) => assert_eq!(text, "System prompt"),
            _ => panic!("Expected text content"),
        }
    }

    /// Each tool definition becomes exactly one function tool with the same name.
    #[test]
    fn test_build_request_with_tool() {
        let request = Request::new()
            .user("Check weather")
            .tool(ToolDefinition::new("weather").description("Get weather"));
        let body = OpenAIScheme::new().build_request("gpt-4o", &request);

        let names: Vec<&str> = body
            .tools
            .iter()
            .map(|t| t.function.name.as_str())
            .collect();
        assert_eq!(names, ["weather"]);
    }

    /// Legacy mode writes the limit into `max_tokens` only.
    #[test]
    fn test_build_request_legacy_max_tokens() {
        let scheme = OpenAIScheme::new().with_legacy_max_tokens(true);
        let body = scheme.build_request("llama3", &capped_hello());
        assert_eq!(token_limits(&body), (Some(100), None));
    }

    /// The default mode writes the limit into `max_completion_tokens` only.
    #[test]
    fn test_build_request_modern_max_tokens() {
        let scheme = OpenAIScheme::new();
        let body = scheme.build_request("gpt-4o", &capped_hello());
        assert_eq!(token_limits(&body), (None, Some(100)));
    }
}