use std::{collections::BTreeMap, env};
use crate::{http_config::build_client, preview};
use async_trait::async_trait;
use futures_util::{StreamExt, stream::BoxStream};
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
use rucora_core::{
error::ProviderError,
provider::{
LlmProvider,
types::{
ChatMessage, ChatRequest, ChatResponse, ChatStreamChunk, FinishReason, ResponseFormat,
Role,
},
},
tool::types::{ToolCall, ToolDefinition},
};
use serde_json::{Value, json};
use tracing::debug;
/// Fallback model name: used by `OpenAiProvider::new`, and by `OpenAiProvider::from_env`
/// when the `OPENAI_DEFAULT_MODEL` environment variable is not set.
const OPENAI_DEFAULT_MODEL: &str = "gpt-4o-mini";
/// Chat provider that talks to an OpenAI-style `/chat/completions` endpoint rooted at a
/// configurable base URL.
#[derive(Clone)]
pub struct OpenAiProvider {
    /// API root (e.g. `https://api.openai.com/v1`); `/chat/completions` is appended per request.
    base_url: String,
    /// Model used when a `ChatRequest` does not name one.
    default_model: String,
    /// HTTP client built once with the JSON content-type and bearer-auth default headers.
    client: reqwest::Client,
}
impl OpenAiProvider {
/// Converts a transport-level `reqwest::Error` into a `ProviderError`.
///
/// Timeouts become `Timeout` (the elapsed time is not known here, so it is reported as
/// zero); connect/request failures become a retriable `Network` error that keeps the
/// original error as its source; everything else becomes a plain `Message`.
fn map_reqwest_error(e: reqwest::Error) -> ProviderError {
    let message = e.to_string();
    if e.is_timeout() {
        return ProviderError::Timeout {
            message,
            elapsed: std::time::Duration::ZERO,
        };
    }
    let is_network_failure = e.is_connect() || e.is_request();
    if is_network_failure {
        return ProviderError::Network {
            message,
            source: Some(Box::new(e)),
            retriable: true,
        };
    }
    ProviderError::Message(message)
}
/// Classifies a non-success HTTP status into a `ProviderError`.
///
/// 401/403 map to `Authentication`, 429 to `RateLimit` (no `Retry-After` value is
/// available here), and any other status to `Api` carrying the numeric status.
fn map_http_error(status: reqwest::StatusCode, message: String) -> ProviderError {
    let numeric_status = status.as_u16();
    if numeric_status == 401 || numeric_status == 403 {
        return ProviderError::Authentication { message };
    }
    if numeric_status == 429 {
        return ProviderError::RateLimit {
            message,
            retry_after: None,
        };
    }
    ProviderError::Api {
        status: numeric_status,
        message,
        code: None,
    }
}
/// Builds a provider from environment variables.
///
/// * `OPENAI_API_KEY` — required.
/// * `OPENAI_BASE_URL` — optional, defaults to `https://api.openai.com/v1`.
/// * `OPENAI_DEFAULT_MODEL` — optional, defaults to `OPENAI_DEFAULT_MODEL`.
///
/// Values are trimmed, and a variable that is set but empty (e.g. `OPENAI_BASE_URL=` in a
/// `.env` file) is treated as unset: an empty base URL would otherwise produce the
/// relative URL `/chat/completions`, and an empty key a provider that only fails later.
///
/// # Errors
///
/// Returns `ProviderError::Message` when `OPENAI_API_KEY` is missing, empty, or not
/// valid Unicode.
pub fn from_env() -> Result<Self, ProviderError> {
    /// Reads `key`, trimming whitespace and mapping unset/empty/non-Unicode to `None`.
    fn non_empty_env(key: &str) -> Option<String> {
        env::var(key)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    }

    let api_key = non_empty_env("OPENAI_API_KEY")
        .ok_or_else(|| ProviderError::Message("缺少环境变量 OPENAI_API_KEY".to_string()))?;
    let base_url = non_empty_env("OPENAI_BASE_URL")
        .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
    let default_model = non_empty_env("OPENAI_DEFAULT_MODEL")
        .unwrap_or_else(|| OPENAI_DEFAULT_MODEL.to_string());
    Ok(Self::with_model(base_url, api_key, default_model))
}
/// Creates a provider for `base_url` / `api_key` that defaults to `OPENAI_DEFAULT_MODEL`.
pub fn new(base_url: impl Into<String>, api_key: impl Into<String>) -> Self {
    let fallback_model = String::from(OPENAI_DEFAULT_MODEL);
    Self::with_model(base_url, api_key, fallback_model)
}
pub fn with_model(
base_url: impl Into<String>,
api_key: impl Into<String>,
default_model: impl Into<String>,
) -> Self {
let api_key = api_key.into();
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
if let Ok(v) = HeaderValue::from_str(&format!("Bearer {api_key}")) {
headers.insert(AUTHORIZATION, v);
}
let client = build_client(headers);
Self {
client,
base_url: base_url.into(),
default_model: default_model.into(),
}
}
pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
self.default_model = model.into();
self
}
/// Returns the model used when a request does not specify one.
pub fn default_model(&self) -> &str {
    self.default_model.as_str()
}
/// Returns the wire-format `role` string the chat-completions API expects for `role`.
fn map_role(role: &Role) -> &'static str {
    match *role {
        Role::System => "system",
        Role::User => "user",
        Role::Assistant => "assistant",
        Role::Tool => "tool",
    }
}
fn build_messages(messages: &[ChatMessage]) -> Vec<Value> {
messages
.iter()
.map(|m| {
let mut obj = json!({
"role": Self::map_role(&m.role),
"content": m.content,
});
if let Some(name) = &m.name
&& let Some(map) = obj.as_object_mut()
{
map.insert("name".to_string(), Value::String(name.clone()));
}
obj
})
.collect()
}
fn build_response_format(fmt: &ResponseFormat) -> Value {
match fmt {
ResponseFormat::JsonObject => json!({"type": "json_object"}),
ResponseFormat::JsonSchema {
name,
schema,
strict,
} => {
let mut obj = json!({
"type": "json_schema",
"json_schema": {
"name": name,
"schema": schema,
}
});
if let Some(strict) = strict
&& let Some(root) = obj.as_object_mut()
&& let Some(js) = root.get_mut("json_schema").and_then(|v| v.as_object_mut())
{
js.insert("strict".to_string(), json!(strict));
}
obj
}
}
}
/// Wraps each tool definition as a `{"type": "function", "function": {…}}` entry.
fn build_tools(tools: &[ToolDefinition]) -> Vec<Value> {
    let mut specs = Vec::with_capacity(tools.len());
    for tool in tools {
        let function = json!({
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.input_schema,
        });
        specs.push(json!({
            "type": "function",
            "function": function,
        }));
    }
    specs
}
fn parse_tool_calls(message: &Value) -> Vec<ToolCall> {
let mut out = Vec::new();
let Some(tool_calls) = message.get("tool_calls") else {
return out;
};
let Some(arr) = tool_calls.as_array() else {
return out;
};
for item in arr {
let id = item
.get("id")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let function = item.get("function").cloned().unwrap_or(Value::Null);
let name = function
.get("name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let args_raw = function
.get("arguments")
.and_then(|v| v.as_str())
.unwrap_or("{}");
let input: Value = serde_json::from_str(args_raw).unwrap_or_else(|_| {
Value::String(args_raw.to_string())
});
if !id.is_empty() && !name.is_empty() {
out.push(ToolCall { id, name, input });
}
}
out
}
/// Maps an API `finish_reason` string to `FinishReason`; unknown values become `Other`.
fn parse_finish_reason(fr: &str) -> FinishReason {
    if fr == "stop" {
        FinishReason::Stop
    } else if fr == "length" {
        FinishReason::Length
    } else if fr == "tool_calls" {
        FinishReason::ToolCall
    } else {
        FinishReason::Other
    }
}
}
#[async_trait]
impl LlmProvider for OpenAiProvider {
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
let model = request
.model
.clone()
.unwrap_or_else(|| self.default_model.clone());
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
let messages = Self::build_messages(&request.messages);
let last_user_preview = request
.messages
.iter()
.rev()
.find(|m| m.role == Role::User)
.map(|m| preview(&m.content, 600));
debug!(
provider = "openai",
url = %url,
model = %model,
messages_len = request.messages.len(),
tools_len = request.tools.as_ref().map_or(0, |t| t.len()),
last_user = last_user_preview.as_deref().unwrap_or(""),
"provider.chat.start"
);
let mut body = json!({
"model": model,
"messages": messages,
});
if let Some(tools) = request.tools.as_ref()
&& let Some(map) = body.as_object_mut()
{
map.insert("tools".to_string(), Value::Array(Self::build_tools(tools)));
}
if let Some(t) = request.temperature
&& let Some(map) = body.as_object_mut()
{
map.insert("temperature".to_string(), json!(t));
}
if let Some(max_tokens) = request.max_tokens
&& let Some(map) = body.as_object_mut()
{
map.insert("max_tokens".to_string(), json!(max_tokens));
}
if let Some(fmt) = request.response_format.as_ref()
&& let Some(map) = body.as_object_mut()
{
map.insert(
"response_format".to_string(),
Self::build_response_format(fmt),
);
}
if let Some(top_p) = request.top_p
&& let Some(map) = body.as_object_mut()
{
map.insert("top_p".to_string(), json!(top_p));
}
if let Some(top_k) = request.top_k
&& let Some(map) = body.as_object_mut()
{
map.insert("top_k".to_string(), json!(top_k));
}
if let Some(frequency_penalty) = request.frequency_penalty
&& let Some(map) = body.as_object_mut()
{
map.insert("frequency_penalty".to_string(), json!(frequency_penalty));
}
if let Some(presence_penalty) = request.presence_penalty
&& let Some(map) = body.as_object_mut()
{
map.insert("presence_penalty".to_string(), json!(presence_penalty));
}
if let Some(stop) = request.stop.as_ref()
&& let Some(map) = body.as_object_mut()
{
map.insert("stop".to_string(), json!(stop));
}
if let Some(extra) = request.extra.as_ref()
&& let Some(map) = body.as_object_mut()
&& let Some(extra_map) = extra.as_object()
{
for (key, value) in extra_map {
map.insert(key.clone(), value.clone());
}
}
debug!(
provider = "openai",
model = %model,
body = %preview(&body.to_string(), 1200),
"provider.chat.request_body"
);
let start = std::time::Instant::now();
let resp = self
.client
.post(url)
.json(&body)
.send()
.await
.map_err(Self::map_reqwest_error)?;
let status = resp.status();
let text = resp
.text()
.await
.map_err(|e| ProviderError::Message(format!("读取响应失败:{e}")))?;
let elapsed_ms = start.elapsed().as_millis() as u64;
debug!(
provider = "openai",
status = %status,
elapsed_ms,
"provider.chat.http.done"
);
debug!(
provider = "openai",
status = %status,
body = %preview(&text, 1200),
"provider.chat.response_body"
);
if !status.is_success() {
let error_msg = if status == reqwest::StatusCode::NOT_FOUND {
format!(
"OpenAI 请求失败:status={} body={} \n\n\
提示:404 错误可能是因为:\n\
1. Base URL 不正确,请检查是否为有效的 API 端点\n\
2. 模型名称不正确,请确认模型在该平台可用\n\
3. API 路径不正确,某些平台可能需要特定的路径格式\n\n\
当前配置:\n\
- Base URL: {}\n\
- Model: {}",
status, text, self.base_url, model
)
} else {
format!("OpenAI 请求失败:status={status} body={text}")
};
return Err(Self::map_http_error(status, error_msg));
}
let data: Value = serde_json::from_str(&text).map_err(|e| {
ProviderError::Message(format!(
"解析响应 JSON 失败:{}。响应内容:{}",
e,
preview(&text, 500)
))
})?;
let message = data
.get("choices")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|c| c.get("message"));
if message.is_none() {
if let Some(error) = data.get("error") {
return Err(ProviderError::Message(format!("API 返回错误:{error}")));
}
return Err(ProviderError::Message(format!(
"OpenAI 响应格式不兼容。响应内容:{}",
preview(&text, 500)
)));
}
let Some(message) = message.cloned() else {
return Err(ProviderError::Message(
"OpenAI 响应缺少 message 字段".to_string(),
));
};
let mut content = message
.get("content")
.and_then(|v| {
if let Some(s) = v.as_str() {
return Some(s.to_string());
}
if let Some(obj) = v.as_object() {
if let Some(text) = obj.get("text").and_then(|t| t.as_str()) {
return Some(text.to_string());
}
}
None
})
.unwrap_or_default();
let tool_calls = Self::parse_tool_calls(&message);
if content.trim().is_empty()
&& tool_calls.is_empty()
&& let Some(r) = message.get("reasoning").and_then(|v| v.as_str())
&& !r.trim().is_empty()
{
content = r.to_string();
}
if !tool_calls.is_empty() {
let names: Vec<&str> = tool_calls.iter().map(|c| c.name.as_str()).collect();
debug!(
provider = "openai",
tool_calls_len = tool_calls.len(),
tool_call_names = ?names,
"provider.chat.tool_calls"
);
}
let usage = data
.get("usage")
.and_then(|u| u.as_object())
.map(|usage_obj| rucora_core::provider::types::Usage {
prompt_tokens: usage_obj
.get("prompt_tokens")
.and_then(|v| v.as_u64())
.map_or(0, |v| v as u32),
completion_tokens: usage_obj
.get("completion_tokens")
.and_then(|v| v.as_u64())
.map_or(0, |v| v as u32),
total_tokens: usage_obj
.get("total_tokens")
.and_then(|v| v.as_u64())
.map_or(0, |v| v as u32),
});
let finish_reason = data
.get("choices")
.and_then(|v| v.as_array())
.and_then(|arr| arr.first())
.and_then(|c| c.get("finish_reason"))
.and_then(|fr| fr.as_str())
.map(Self::parse_finish_reason);
debug!(
provider = "openai",
assistant_content_len = content.len(),
"provider.chat.parsed"
);
Ok(ChatResponse {
message: ChatMessage {
role: Role::Assistant,
content,
name: None,
},
tool_calls,
usage,
finish_reason,
})
}
fn stream_chat(
&self,
request: ChatRequest,
) -> Result<BoxStream<'static, Result<ChatStreamChunk, ProviderError>>, ProviderError> {
let model = request
.model
.clone()
.unwrap_or_else(|| self.default_model.clone());
let client = self.client.clone();
let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/'));
let preview = |s: &str, max: usize| {
if s.len() <= max {
s.to_string()
} else {
let truncated: String = s.char_indices().take(max).map(|(_, c)| c).collect();
format!("{}...<truncated:{}>", truncated, s.len())
}
};
debug!(
provider = "openai",
url = %url,
model = %model,
messages_len = request.messages.len(),
tools_len = request.tools.as_ref().map_or(0, |t| t.len()),
"provider.stream_chat.start"
);
let mut body = json!({
"model": model,
"messages": Self::build_messages(&request.messages),
"stream": true,
});
if let Some(tools) = request.tools.as_ref()
&& let Some(map) = body.as_object_mut()
{
map.insert("tools".to_string(), Value::Array(Self::build_tools(tools)));
}
if let Some(t) = request.temperature
&& let Some(map) = body.as_object_mut()
{
map.insert("temperature".to_string(), json!(t));
}
if let Some(max_tokens) = request.max_tokens
&& let Some(map) = body.as_object_mut()
{
map.insert("max_tokens".to_string(), json!(max_tokens));
}
if let Some(fmt) = request.response_format.as_ref()
&& let Some(map) = body.as_object_mut()
{
map.insert(
"response_format".to_string(),
Self::build_response_format(fmt),
);
}
debug!(
provider = "openai",
model = %model,
body = %preview(&body.to_string(), 1200),
"provider.stream_chat.request_body"
);
let stream = async_stream::try_stream! {
let start = std::time::Instant::now();
let resp = client
.post(url)
.json(&body)
.send()
.await
.map_err(Self::map_reqwest_error)?;
let status = resp.status();
if !status.is_success() {
Err(Self::map_http_error(
status,
format!("OpenAI stream 请求失败:status={status}"),
))?;
}
debug!(
provider = "openai",
status = %status,
elapsed_ms = start.elapsed().as_millis() as u64,
"provider.stream_chat.http.started"
);
let mut buf = String::new();
let mut bytes_stream = resp.bytes_stream();
let mut tool_call_parts: BTreeMap<usize, (String, String, String)> = BTreeMap::new();
while let Some(item) = bytes_stream.next().await {
let bytes = item.map_err(Self::map_reqwest_error)?;
let chunk = String::from_utf8_lossy(&bytes);
buf.push_str(&chunk);
while let Some(idx) = buf.find("\n\n") {
let event = buf[..idx].to_string();
buf = buf[idx + 2..].to_string();
let mut data_lines: Vec<&str> = Vec::new();
for line in event.lines() {
let line = line.trim();
if let Some(rest) = line.strip_prefix("data:") {
data_lines.push(rest.trim());
}
}
if data_lines.is_empty() {
continue;
}
let data = data_lines.join("\n");
if data == "[DONE]" {
break;
}
let v: Value = serde_json::from_str(&data)
.map_err(|e| ProviderError::Message(format!("SSE JSON 解析失败: {e} data={data}")))?;
let choice = v
.get("choices")
.and_then(|c| c.as_array())
.and_then(|arr| arr.first());
let delta_obj = choice.and_then(|c0| c0.get("delta"));
let delta = delta_obj
.and_then(|d| d.get("content"))
.and_then(|s| s.as_str())
.map(|s| s.to_string());
if let Some(tool_calls) = delta_obj
.and_then(|d| d.get("tool_calls"))
.and_then(|tc| tc.as_array())
{
for item in tool_calls {
let index = item.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
let entry = tool_call_parts
.entry(index)
.or_insert_with(|| (String::new(), String::new(), String::new()));
if let Some(id) = item.get("id").and_then(|v| v.as_str())
&& !id.is_empty()
{
entry.0 = id.to_string();
}
if let Some(function) = item.get("function") {
if let Some(name) = function.get("name").and_then(|v| v.as_str()) {
entry.1.push_str(name);
}
if let Some(arguments) = function.get("arguments").and_then(|v| v.as_str()) {
entry.2.push_str(arguments);
}
}
}
}
let finish_reason = choice
.and_then(|c0| c0.get("finish_reason"))
.and_then(|fr| fr.as_str())
.map(Self::parse_finish_reason);
if delta.is_some() {
yield ChatStreamChunk {
delta,
tool_calls: vec![],
usage: None,
finish_reason: finish_reason.clone(),
};
}
if matches!(finish_reason, Some(FinishReason::ToolCall)) {
let tool_calls = tool_call_parts
.values()
.filter_map(|(id, name, args_raw)| {
if name.is_empty() {
return None;
}
let input = serde_json::from_str(args_raw)
.unwrap_or_else(|_| Value::String(args_raw.clone()));
Some(ToolCall {
id: id.clone(),
name: name.clone(),
input,
})
})
.collect::<Vec<_>>();
if !tool_calls.is_empty() {
yield ChatStreamChunk {
delta: None,
tool_calls,
usage: None,
finish_reason,
};
}
}
}
if buf.contains("[DONE]") {
break;
}
}
};
Ok(Box::pin(stream))
}
}