use serde::{Deserialize, Serialize};
use crate::canonical::{ChatRequest, ChatResponse, Message, PluginRequest, Role, StopReason, Usage};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiMessage {
pub role: String,
pub content: String,
}
/// Request body of an OpenAI-style chat-completions call.
///
/// Unset optional sampling parameters are left out of the serialized JSON,
/// and so is an empty `plugins` list.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OpenAiChatRequest {
    /// Model identifier requested by the caller.
    pub model: String,
    /// Conversation messages, which may include `"system"` entries.
    pub messages: Vec<OpenAiMessage>,
    /// Cap on generated tokens; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Streaming flag; treated as `false` when absent from the input.
    #[serde(default)]
    pub stream: bool,
    /// Plugin requests; empty when absent and omitted from output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub plugins: Vec<PluginRequest>,
}
/// Token accounting block of an OpenAI-style response.
///
/// Any counter missing from the input falls back to `0`. The derived
/// `Default` is all zeros.
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct OpenAiUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total tokens as reported (or computed) for the exchange.
    pub total_tokens: u32,
}
/// One completion candidate inside an OpenAI-style response.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OpenAiChoice {
    /// Position of this candidate within `choices`.
    pub index: u32,
    /// The generated message.
    pub message: OpenAiMessage,
    /// Reason generation ended (e.g. `"stop"`, `"length"`); may be `null`.
    pub finish_reason: Option<String>,
}
/// Response body of an OpenAI-style chat-completions call.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct OpenAiChatResponse {
    /// Completion identifier.
    pub id: String,
    /// Object type tag (this module emits `"chat.completion"`).
    pub object: String,
    /// Model that produced the completion.
    pub model: String,
    /// Completion candidates.
    pub choices: Vec<OpenAiChoice>,
    /// Token usage; all-zero counts when the field is absent from the input.
    #[serde(default)]
    pub usage: OpenAiUsage,
}
impl From<OpenAiChatRequest> for ChatRequest {
    /// Converts an OpenAI-format request into the canonical [`ChatRequest`].
    ///
    /// System instructions (`"system"` and the newer `"developer"` role) are
    /// lifted out of the message list into `system`. When several are present
    /// they are joined in order with a blank line, instead of only the last one
    /// being kept. `"assistant"` maps to [`Role::Assistant`]. Every other role
    /// maps to [`Role::User`]. `forced_provider` and `tags` start empty.
    fn from(req: OpenAiChatRequest) -> Self {
        let mut system_parts: Vec<String> = Vec::new();
        let mut messages = Vec::with_capacity(req.messages.len());
        for msg in req.messages {
            match msg.role.as_str() {
                "system" | "developer" => system_parts.push(msg.content),
                "assistant" => messages.push(Message {
                    role: Role::Assistant,
                    content: msg.content,
                }),
                _ => messages.push(Message {
                    role: Role::User,
                    content: msg.content,
                }),
            }
        }
        // Preserve the original `None` vs `Some("")` distinction: only a request
        // with no system messages at all yields `None`.
        let system = (!system_parts.is_empty()).then(|| system_parts.join("\n\n"));
        ChatRequest {
            model: req.model,
            system,
            messages,
            max_tokens: req.max_tokens,
            temperature: req.temperature,
            stream: req.stream,
            plugins: req.plugins,
            forced_provider: None,
            tags: Vec::new(),
        }
    }
}
impl From<&ChatRequest> for OpenAiChatRequest {
    /// Renders a canonical [`ChatRequest`] in the OpenAI wire format.
    ///
    /// If there is a canonical system prompt, it becomes a leading `"system"`
    /// message, followed by the conversation in its original order. The
    /// produced request is always non-streaming and carries no plugins,
    /// whatever the canonical request specifies.
    fn from(req: &ChatRequest) -> Self {
        let leading_system = req.system.iter().map(|prompt| OpenAiMessage {
            role: "system".to_string(),
            content: prompt.clone(),
        });
        let turns = req.messages.iter().map(|turn| {
            let wire_role = match turn.role {
                Role::Assistant => "assistant",
                Role::User => "user",
            };
            OpenAiMessage {
                role: wire_role.to_string(),
                content: turn.content.clone(),
            }
        });
        OpenAiChatRequest {
            model: req.model.clone(),
            messages: leading_system.chain(turns).collect(),
            max_tokens: req.max_tokens,
            temperature: req.temperature,
            stream: false,
            plugins: Vec::new(),
        }
    }
}
impl From<OpenAiChatResponse> for ChatResponse {
    /// Converts an OpenAI-format response into a canonical [`ChatResponse`].
    ///
    /// Only the first choice is used. With no choices, the content is empty
    /// and the stop reason is `StopReason::Other`. A `finish_reason` of
    /// `"length"` maps to `MaxTokens` and `"stop"` maps to `EndTurn`. Any other
    /// value, or `null`, maps to `Other`. The upstream `total_tokens` is not
    /// carried over.
    fn from(resp: OpenAiChatResponse) -> Self {
        // Move the first choice's fields out rather than cloning its content.
        let (content, finish_reason) = resp
            .choices
            .into_iter()
            .next()
            .map(|choice| (choice.message.content, choice.finish_reason))
            .unwrap_or_default();
        let stop_reason = match finish_reason.as_deref() {
            Some("length") => StopReason::MaxTokens,
            Some("stop") => StopReason::EndTurn,
            _ => StopReason::Other,
        };
        ChatResponse {
            id: resp.id,
            model: resp.model,
            content,
            stop_reason,
            usage: Usage {
                input_tokens: resp.usage.prompt_tokens,
                output_tokens: resp.usage.completion_tokens,
            },
        }
    }
}
impl From<ChatResponse> for OpenAiChatResponse {
    /// Renders a canonical [`ChatResponse`] as a single-choice OpenAI response.
    ///
    /// `StopReason::MaxTokens` becomes `"length"`. Both `EndTurn` and `Other`
    /// become `"stop"`, so the distinction between them is lost on the wire.
    /// `total_tokens` is the saturating sum of input and output tokens, so very
    /// large counts clamp at `u32::MAX` instead of overflowing (which panics in
    /// debug builds).
    fn from(resp: ChatResponse) -> Self {
        let finish_reason = match resp.stop_reason {
            StopReason::MaxTokens => "length",
            StopReason::EndTurn | StopReason::Other => "stop",
        };
        let prompt_tokens = resp.usage.input_tokens;
        let completion_tokens = resp.usage.output_tokens;
        OpenAiChatResponse {
            id: resp.id,
            object: "chat.completion".to_string(),
            model: resp.model,
            choices: vec![OpenAiChoice {
                index: 0,
                message: OpenAiMessage {
                    role: "assistant".to_string(),
                    content: resp.content,
                },
                finish_reason: Some(finish_reason.to_string()),
            }],
            usage: OpenAiUsage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens.saturating_add(completion_tokens),
            },
        }
    }
}