// llm_link/api/openai.rs

1use axum::{
2    extract::{Query, State},
3    http::{HeaderMap, StatusCode},
4    response::{IntoResponse, Json},
5    response::Response,
6    body::Body,
7};
8use futures::StreamExt;
9use serde::Deserialize;
10use serde_json::{json, Value};
11use std::convert::Infallible;
12use tracing::{info, warn, error};
13
14use crate::adapters::{ClientAdapter, FormatDetector};
15use crate::api::{AppState, convert};
16
/// Request body for the OpenAI-compatible chat-completions endpoint.
///
/// Only `model`, `messages`, `stream`, and `tools` are consumed by the
/// handlers in this file; the remaining fields are accepted so strict clients
/// don't fail deserialization, but are currently ignored.
#[derive(Debug, Deserialize)]
pub struct OpenAIChatRequest {
    /// Target model id. An empty string means "let the backend pick"
    /// (see `chat`: validation is skipped and `None` is forwarded).
    pub model: String,
    /// Conversation messages in raw OpenAI JSON form; converted via
    /// `convert::openai_messages_to_llm`.
    pub messages: Vec<Value>,
    /// `Some(true)` requests a streaming response; `None` defaults to false.
    pub stream: Option<bool>,
    // Accepted but not forwarded to the backend yet.
    #[allow(dead_code)]
    pub max_tokens: Option<u32>,
    // Accepted but not forwarded to the backend yet.
    #[allow(dead_code)]
    pub temperature: Option<f32>,
    /// OpenAI-format tool definitions; converted via
    /// `convert::openai_tools_to_llm` when present.
    #[allow(dead_code)]
    pub tools: Option<Vec<Value>>,
    // Accepted but not forwarded to the backend yet.
    #[allow(dead_code)]
    pub tool_choice: Option<Value>,
}
31
/// Query parameters for the OpenAI-compatible models endpoint.
///
/// Currently empty; it exists so the handler can take `Query<_>` and accept
/// future parameters without a breaking signature change.
#[derive(Debug, Deserialize)]
pub struct OpenAIModelsParams {
    // OpenAI models endpoint parameters (if any)
}
36
37/// OpenAI Chat Completions API
38pub async fn chat(
39    headers: HeaderMap,
40    State(state): State<AppState>,
41    Json(request): Json<OpenAIChatRequest>,
42) -> Result<Response, StatusCode> {
43    // API Key 校验
44    enforce_api_key(&headers, &state)?;
45
46    info!("📝 Received request - model: {}, stream: {:?}, messages count: {}",
47          request.model, request.stream, request.messages.len());
48
49    // 验证模型
50    if !request.model.is_empty() {
51        match state.llm_service.validate_model(&request.model).await {
52            Ok(false) => {
53                error!("❌ Model validation failed: model '{}' not found", request.model);
54                return Err(StatusCode::BAD_REQUEST);
55            }
56            Err(e) => {
57                error!("❌ Model validation error: {:?}", e);
58                return Err(StatusCode::INTERNAL_SERVER_ERROR);
59            }
60            Ok(true) => {
61                info!("✅ Model '{}' validated successfully", request.model);
62            }
63        }
64    }
65
66    // 转换消息格式
67    match convert::openai_messages_to_llm(request.messages) {
68        Ok(messages) => {
69            info!("✅ Successfully converted {} messages", messages.len());
70            let model = if request.model.is_empty() { None } else { Some(request.model.as_str()) };
71
72            // 转换 tools 格式
73            let tools = request.tools.map(|t| convert::openai_tools_to_llm(t));
74            if tools.is_some() {
75                info!("🔧 Request includes {} tools", tools.as_ref().unwrap().len());
76                // Debug: log the first tool
77                if let Some(first_tool) = tools.as_ref().unwrap().first() {
78                    info!("🔧 First tool: {:?}", serde_json::to_value(first_tool).ok());
79                }
80            }
81
82            // 直接使用请求指定的模式(流式或非流式)
83            // 等待 llm-connector 修复流式 tool_calls 解析问题
84            if request.stream.unwrap_or(false) {
85                handle_streaming_request(headers, state, model, messages, tools).await
86            } else {
87                handle_non_streaming_request(state, model, messages, tools).await
88            }
89        }
90        Err(e) => {
91            error!("❌ Failed to convert OpenAI messages: {:?}", e);
92            Err(StatusCode::BAD_REQUEST)
93        }
94    }
95}
96
97/// 处理流式请求
98async fn handle_streaming_request(
99    headers: HeaderMap,
100    state: AppState,
101    model: Option<&str>,
102    messages: Vec<llm_connector::types::Message>,
103    tools: Option<Vec<llm_connector::types::Tool>>,
104) -> Result<Response, StatusCode> {
105    // 🎯 检测客户端类型(默认使用 OpenAI 适配器)
106    let client_adapter = detect_openai_client(&headers, &state.config);
107    let (_stream_format, _) = FormatDetector::determine_format(&headers);
108    
109    // 使用客户端偏好格式(SSE)
110    let final_format = client_adapter.preferred_format();
111    let content_type = FormatDetector::get_content_type(final_format);
112
113    info!("📡 Starting OpenAI streaming response - Format: {:?} ({})", final_format, content_type);
114
115    match state.llm_service.chat_stream_openai(model, messages.clone(), tools.clone(), final_format).await {
116        Ok(rx) => {
117            info!("✅ OpenAI streaming response started successfully");
118
119            let config = state.config.clone();
120            let adapted_stream = rx.map(move |data| {
121                // SSE 格式的数据以 "data: " 开头,需要先提取 JSON 部分
122                let json_str = if data.starts_with("data: ") {
123                    &data[6..] // 去掉 "data: " 前缀
124                } else {
125                    &data
126                };
127
128                // 跳过空行和 [DONE] 标记
129                if json_str.trim().is_empty() || json_str.trim() == "[DONE]" {
130                    return data.to_string();
131                }
132
133                // 解析并适配响应数据
134                if let Ok(mut json_data) = serde_json::from_str::<Value>(json_str) {
135                    tracing::debug!("📝 Parsed JSON chunk, applying adaptations...");
136                    client_adapter.apply_response_adaptations(&config, &mut json_data);
137
138                    match final_format {
139                        llm_connector::StreamFormat::SSE => {
140                            format!("data: {}\n\n", json_data)
141                        }
142                        llm_connector::StreamFormat::NDJSON => {
143                            format!("{}\n", json_data)
144                        }
145                        llm_connector::StreamFormat::Json => {
146                            json_data.to_string()
147                        }
148                    }
149                } else {
150                    tracing::debug!("⚠️ Failed to parse chunk as JSON: {}", json_str);
151                    data.to_string()
152                }
153            });
154
155            let body_stream = adapted_stream.map(|data| Ok::<_, Infallible>(data));
156            let body = Body::from_stream(body_stream);
157
158            let response = Response::builder()
159                .status(200)
160                .header("content-type", content_type)
161                .header("cache-control", "no-cache")
162                .body(body)
163                .unwrap();
164
165            Ok(response)
166        }
167        Err(e) => {
168            warn!("⚠️ OpenAI streaming failed, falling back to non-streaming: {:?}", e);
169            handle_non_streaming_request(state, model, messages, tools).await
170        }
171    }
172}
173
174/// 处理非流式请求
175async fn handle_non_streaming_request(
176    state: AppState,
177    model: Option<&str>,
178    messages: Vec<llm_connector::types::Message>,
179    tools: Option<Vec<llm_connector::types::Tool>>,
180) -> Result<Response, StatusCode> {
181    match state.llm_service.chat(model, messages, tools).await {
182        Ok(response) => {
183            let openai_response = convert::response_to_openai(response);
184            Ok(Json(openai_response).into_response())
185        }
186        Err(e) => {
187            error!("❌ OpenAI chat request failed: {:?}", e);
188            Err(StatusCode::INTERNAL_SERVER_ERROR)
189        }
190    }
191}
192
193/// OpenAI Models API
194pub async fn models(
195    headers: HeaderMap,
196    State(state): State<AppState>,
197    Query(_params): Query<OpenAIModelsParams>,
198) -> Result<impl IntoResponse, StatusCode> {
199    enforce_api_key(&headers, &state)?;
200    
201    match state.llm_service.list_models().await {
202        Ok(models) => {
203            let openai_models: Vec<Value> = models.into_iter().map(|model| {
204                json!({
205                    "id": model.id,
206                    "object": "model",
207                    "created": chrono::Utc::now().timestamp(),
208                    "owned_by": "system"
209                })
210            }).collect();
211
212            let current_provider = match &state.config.llm_backend {
213                crate::settings::LlmBackendSettings::OpenAI { .. } => "openai",
214                crate::settings::LlmBackendSettings::Anthropic { .. } => "anthropic",
215                crate::settings::LlmBackendSettings::Zhipu { .. } => "zhipu",
216                crate::settings::LlmBackendSettings::Ollama { .. } => "ollama",
217                crate::settings::LlmBackendSettings::Aliyun { .. } => "aliyun",
218                crate::settings::LlmBackendSettings::Volcengine { .. } => "volcengine",
219                crate::settings::LlmBackendSettings::Tencent { .. } => "tencent",
220            };
221
222            let response = json!({
223                "object": "list",
224                "data": openai_models,
225                "provider": current_provider,
226            });
227            Ok(Json(response))
228        }
229        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
230    }
231}
232
233/// OpenAI API Key 认证
234fn enforce_api_key(headers: &HeaderMap, state: &AppState) -> Result<(), StatusCode> {
235    if let Some(cfg) = &state.config.apis.openai {
236        if cfg.enabled {
237            if let Some(expected_key) = cfg.api_key.as_ref() {
238                let header_name = cfg.api_key_header.as_deref().unwrap_or("authorization").to_ascii_lowercase();
239                
240                let value_opt = if header_name == "authorization" {
241                    headers.get(axum::http::header::AUTHORIZATION)
242                } else {
243                    match axum::http::HeaderName::from_bytes(header_name.as_bytes()) {
244                        Ok(name) => headers.get(name),
245                        Err(_) => None,
246                    }
247                };
248
249                if let Some(value) = value_opt {
250                    if let Ok(value_str) = value.to_str() {
251                        let token = if value_str.starts_with("Bearer ") {
252                            &value_str[7..]
253                        } else {
254                            value_str
255                        };
256
257                        if token == expected_key {
258                            info!("✅ OpenAI API key authentication successful");
259                            return Ok(());
260                        }
261                    }
262                }
263
264                warn!("🚫 OpenAI API key authentication failed");
265                return Err(StatusCode::UNAUTHORIZED);
266            }
267        }
268    }
269    Ok(())
270}
271
/// Picks the client adapter for the OpenAI-compatible endpoint.
///
/// The header/config parameters are currently unused: this API surface always
/// uses the OpenAI adapter. They are kept in the signature so per-client
/// detection can be added later without changing call sites.
fn detect_openai_client(_headers: &HeaderMap, _config: &crate::settings::Settings) -> ClientAdapter {
    // The OpenAI API always uses the OpenAI adapter.
    ClientAdapter::OpenAI
}