llm_link/api/
openai.rs

1use axum::{
2    extract::{Query, State},
3    http::{HeaderMap, StatusCode},
4    response::{IntoResponse, Json},
5    response::Response,
6    body::Body,
7};
8use futures::StreamExt;
9use serde::Deserialize;
10use serde_json::{json, Value};
11use std::convert::Infallible;
12use tracing::{info, warn, error};
13
14use crate::adapters::{ClientAdapter, FormatDetector};
15use crate::api::{AppState, convert};
16
/// Request body for the OpenAI-compatible chat completions endpoint.
///
/// Only `model`, `messages`, `stream`, and `tools` are consumed by the
/// handler below; the remaining fields are accepted so that well-formed
/// OpenAI payloads deserialize cleanly, but they are not forwarded yet.
#[derive(Debug, Deserialize)]
pub struct OpenAIChatRequest {
    /// Target model id; an empty string means "use the backend default".
    pub model: String,
    /// Chat messages in OpenAI wire format (converted by `convert::openai_messages_to_llm`).
    pub messages: Vec<Value>,
    /// `Some(true)` requests a streaming (SSE) response; otherwise a single JSON body.
    pub stream: Option<bool>,
    // Accepted for wire compatibility; currently unused.
    #[allow(dead_code)]
    pub max_tokens: Option<u32>,
    // Accepted for wire compatibility; currently unused.
    #[allow(dead_code)]
    pub temperature: Option<f32>,
    /// OpenAI tool definitions (converted by `convert::openai_tools_to_llm`).
    #[allow(dead_code)]
    pub tools: Option<Vec<Value>>,
    // Accepted for wire compatibility; currently unused.
    #[allow(dead_code)]
    pub tool_choice: Option<Value>,
}
31
/// Query parameters for the OpenAI-compatible models endpoint.
///
/// OpenAI's models listing takes no query parameters today; this empty
/// struct exists so the axum `Query` extractor has a concrete target type.
#[derive(Debug, Deserialize)]
pub struct OpenAIModelsParams {
    // OpenAI models endpoint parameters (if any)
}
36
37/// OpenAI Chat Completions API
38pub async fn chat(
39    headers: HeaderMap,
40    State(state): State<AppState>,
41    Json(request): Json<OpenAIChatRequest>,
42) -> Result<Response, StatusCode> {
43    // API Key 校验
44    enforce_api_key(&headers, &state)?;
45
46    info!("📝 Received request - model: {}, stream: {:?}, messages count: {}",
47          request.model, request.stream, request.messages.len());
48
49    // 验证模型
50    if !request.model.is_empty() {
51        let validation_result = {
52            let llm_service = state.llm_service.read().unwrap();
53            llm_service.validate_model(&request.model).await
54        };
55
56        match validation_result {
57            Ok(false) => {
58                error!("❌ Model validation failed: model '{}' not found", request.model);
59                return Err(StatusCode::BAD_REQUEST);
60            }
61            Err(e) => {
62                error!("❌ Model validation error: {:?}", e);
63                return Err(StatusCode::INTERNAL_SERVER_ERROR);
64            }
65            Ok(true) => {
66                info!("✅ Model '{}' validated successfully", request.model);
67            }
68        }
69    }
70
71    // 转换消息格式
72    match convert::openai_messages_to_llm(request.messages) {
73        Ok(messages) => {
74            info!("✅ Successfully converted {} messages", messages.len());
75            let model = if request.model.is_empty() { None } else { Some(request.model.as_str()) };
76
77            // 转换 tools 格式
78            let tools = request.tools.map(|t| convert::openai_tools_to_llm(t));
79            if tools.is_some() {
80                info!("🔧 Request includes {} tools", tools.as_ref().unwrap().len());
81                // Debug: log the first tool
82                if let Some(first_tool) = tools.as_ref().unwrap().first() {
83                    info!("🔧 First tool: {:?}", serde_json::to_value(first_tool).ok());
84                }
85            }
86
87            // 直接使用请求指定的模式(流式或非流式)
88            // 等待 llm-connector 修复流式 tool_calls 解析问题
89            if request.stream.unwrap_or(false) {
90                handle_streaming_request(headers, state, model, messages, tools).await
91            } else {
92                handle_non_streaming_request(state, model, messages, tools).await
93            }
94        }
95        Err(e) => {
96            error!("❌ Failed to convert OpenAI messages: {:?}", e);
97            Err(StatusCode::BAD_REQUEST)
98        }
99    }
100}
101
102/// 处理流式请求
103async fn handle_streaming_request(
104    headers: HeaderMap,
105    state: AppState,
106    model: Option<&str>,
107    messages: Vec<llm_connector::types::Message>,
108    tools: Option<Vec<llm_connector::types::Tool>>,
109) -> Result<Response, StatusCode> {
110    // 🎯 检测客户端类型(默认使用 OpenAI 适配器)
111    let config = state.config.read().unwrap();
112    let client_adapter = detect_openai_client(&headers, &config);
113    let (_stream_format, _) = FormatDetector::determine_format(&headers);
114    drop(config); // 释放读锁
115    
116    // 使用客户端偏好格式(SSE)
117    let final_format = client_adapter.preferred_format();
118    let content_type = FormatDetector::get_content_type(final_format);
119
120    info!("📡 Starting OpenAI streaming response - Format: {:?} ({})", final_format, content_type);
121
122    let stream_result = {
123        let llm_service = state.llm_service.read().unwrap();
124        llm_service.chat_stream_openai(model, messages.clone(), tools.clone(), final_format).await
125    };
126
127    match stream_result {
128        Ok(rx) => {
129            info!("✅ OpenAI streaming response started successfully");
130
131            let config_clone = state.config.clone();
132            let adapted_stream = rx.map(move |data| {
133                // SSE 格式的数据以 "data: " 开头,需要先提取 JSON 部分
134                let json_str = if data.starts_with("data: ") {
135                    &data[6..] // 去掉 "data: " 前缀
136                } else {
137                    &data
138                };
139
140                // 跳过空行和 [DONE] 标记
141                if json_str.trim().is_empty() || json_str.trim() == "[DONE]" {
142                    return data.to_string();
143                }
144
145                // 解析并适配响应数据
146                if let Ok(mut json_data) = serde_json::from_str::<Value>(json_str) {
147                    tracing::debug!("📝 Parsed JSON chunk, applying adaptations...");
148                    let config = config_clone.read().unwrap();
149                    client_adapter.apply_response_adaptations(&config, &mut json_data);
150
151                    match final_format {
152                        llm_connector::StreamFormat::SSE => {
153                            format!("data: {}\n\n", json_data)
154                        }
155                        llm_connector::StreamFormat::NDJSON => {
156                            format!("{}\n", json_data)
157                        }
158                        llm_connector::StreamFormat::Json => {
159                            json_data.to_string()
160                        }
161                    }
162                } else {
163                    tracing::debug!("⚠️ Failed to parse chunk as JSON: {}", json_str);
164                    data.to_string()
165                }
166            });
167
168            let body_stream = adapted_stream.map(|data| Ok::<_, Infallible>(data));
169            let body = Body::from_stream(body_stream);
170
171            let response = Response::builder()
172                .status(200)
173                .header("content-type", content_type)
174                .header("cache-control", "no-cache")
175                .body(body)
176                .unwrap();
177
178            Ok(response)
179        }
180        Err(e) => {
181            warn!("⚠️ OpenAI streaming failed, falling back to non-streaming: {:?}", e);
182            handle_non_streaming_request(state, model, messages, tools).await
183        }
184    }
185}
186
187/// 处理非流式请求
188async fn handle_non_streaming_request(
189    state: AppState,
190    model: Option<&str>,
191    messages: Vec<llm_connector::types::Message>,
192    tools: Option<Vec<llm_connector::types::Tool>>,
193) -> Result<Response, StatusCode> {
194    let chat_result = {
195        let llm_service = state.llm_service.read().unwrap();
196        llm_service.chat(model, messages, tools).await
197    };
198
199    match chat_result {
200        Ok(response) => {
201            let openai_response = convert::response_to_openai(response);
202            Ok(Json(openai_response).into_response())
203        }
204        Err(e) => {
205            error!("❌ OpenAI chat request failed: {:?}", e);
206            Err(StatusCode::INTERNAL_SERVER_ERROR)
207        }
208    }
209}
210
211/// OpenAI Models API
212pub async fn models(
213    headers: HeaderMap,
214    State(state): State<AppState>,
215    Query(_params): Query<OpenAIModelsParams>,
216) -> Result<impl IntoResponse, StatusCode> {
217    enforce_api_key(&headers, &state)?;
218
219    let models_result = {
220        let llm_service = state.llm_service.read().unwrap();
221        llm_service.list_models().await
222    };
223
224    match models_result {
225        Ok(models) => {
226            let openai_models: Vec<Value> = models.into_iter().map(|model| {
227                json!({
228                    "id": model.id,
229                    "object": "model",
230                    "created": chrono::Utc::now().timestamp(),
231                    "owned_by": "system"
232                })
233            }).collect();
234
235            let config = state.config.read().unwrap();
236            let current_provider = match &config.llm_backend {
237                crate::settings::LlmBackendSettings::OpenAI { .. } => "openai",
238                crate::settings::LlmBackendSettings::Anthropic { .. } => "anthropic",
239                crate::settings::LlmBackendSettings::Zhipu { .. } => "zhipu",
240                crate::settings::LlmBackendSettings::Ollama { .. } => "ollama",
241                crate::settings::LlmBackendSettings::Aliyun { .. } => "aliyun",
242                crate::settings::LlmBackendSettings::Volcengine { .. } => "volcengine",
243                crate::settings::LlmBackendSettings::Tencent { .. } => "tencent",
244            };
245
246            let response = json!({
247                "object": "list",
248                "data": openai_models,
249                "provider": current_provider,
250            });
251            Ok(Json(response))
252        }
253        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
254    }
255}
256
257/// OpenAI API Key 认证
258fn enforce_api_key(headers: &HeaderMap, state: &AppState) -> Result<(), StatusCode> {
259    let config = state.config.read().unwrap();
260    if let Some(cfg) = &config.apis.openai {
261        if cfg.enabled {
262            if let Some(expected_key) = cfg.api_key.as_ref() {
263                let header_name = cfg.api_key_header.as_deref().unwrap_or("authorization").to_ascii_lowercase();
264                
265                let value_opt = if header_name == "authorization" {
266                    headers.get(axum::http::header::AUTHORIZATION)
267                } else {
268                    match axum::http::HeaderName::from_bytes(header_name.as_bytes()) {
269                        Ok(name) => headers.get(name),
270                        Err(_) => None,
271                    }
272                };
273
274                if let Some(value) = value_opt {
275                    if let Ok(value_str) = value.to_str() {
276                        let token = if value_str.starts_with("Bearer ") {
277                            &value_str[7..]
278                        } else {
279                            value_str
280                        };
281
282                        if token == expected_key {
283                            info!("✅ OpenAI API key authentication successful");
284                            return Ok(());
285                        }
286                    }
287                }
288
289                warn!("🚫 OpenAI API key authentication failed");
290                return Err(StatusCode::UNAUTHORIZED);
291            }
292        }
293    }
294    Ok(())
295}
296
/// Detect the OpenAI client type.
///
/// The OpenAI-compatible endpoints always use the OpenAI adapter, so the
/// headers and settings are currently ignored; the parameters are kept for
/// signature parity with other endpoints' detectors.
fn detect_openai_client(_headers: &HeaderMap, _config: &crate::settings::Settings) -> ClientAdapter {
    // The OpenAI API always uses the OpenAI adapter
    ClientAdapter::OpenAI
}