// mofa_foundation/llm/openai.rs
//! OpenAI Provider Implementation
//!
//! Implements OpenAI API interaction using the `async-openai` crate.
//!
//! # Supported services
//!
//! - OpenAI API (api.openai.com)
//! - Azure OpenAI
//! - OpenAI-API-compatible local services (Ollama, vLLM, LocalAI, etc.)
//!
//! # Examples
//!
//! ```rust,ignore
//! use mofa_foundation::llm::openai::{OpenAIProvider, OpenAIConfig};
//!
//! // Using OpenAI
//! let provider = OpenAIProvider::new("sk-xxx");
//!
//! // Using a custom endpoint
//! let provider = OpenAIProvider::with_config(
//!     OpenAIConfig::new("sk-xxx")
//!         .with_base_url("http://localhost:11434/v1")
//!         .with_model("llama2")
//! );
//!
//! // Using Azure OpenAI
//! let provider = OpenAIProvider::azure("https://xxx.openai.azure.com", "api-key", "deployment");
//! ```

30use super::provider::{ChatStream, LLMProvider, ModelCapabilities, ModelInfo};
31use super::types::*;
32use async_openai::{
33    Client,
34    config::OpenAIConfig as AsyncOpenAIConfig,
35    types::{
36        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
37        ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage,
38        ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs,
39        ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
40        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
41        ChatCompletionToolArgs, ChatCompletionToolChoiceOption, ChatCompletionToolType,
42        CreateChatCompletionRequestArgs, FunctionObjectArgs, ImageDetail as OpenAIImageDetail,
43        ImageUrl as OpenAIImageUrl, InputAudio, InputAudioFormat,
44    },
45};
46use async_trait::async_trait;
47use futures::StreamExt;
48
/// Configuration for the OpenAI provider.
///
/// All fields have sensible defaults via [`Default`]; use the `with_*`
/// builder methods to override individual values.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// API key sent with every request.
    pub api_key: String,
    /// Base URL of the API; `None` means the async-openai default
    /// (api.openai.com).
    pub base_url: Option<String>,
    /// OpenAI organization ID, if any.
    pub org_id: Option<String>,
    /// Model used when a request does not specify one.
    pub default_model: String,
    /// Sampling temperature used when a request does not specify one.
    pub default_temperature: f32,
    /// Default maximum number of completion tokens.
    pub default_max_tokens: u32,
    /// Request timeout in seconds.
    /// NOTE(review): currently stored but not applied to the HTTP client
    /// in `OpenAIProvider::with_config` — confirm intended behavior.
    pub timeout_secs: u64,
}
67
68impl Default for OpenAIConfig {
69    fn default() -> Self {
70        Self {
71            api_key: String::new(),
72            base_url: None,
73            org_id: None,
74            default_model: "gpt-4o".to_string(),
75            default_temperature: 0.7,
76            default_max_tokens: 4096,
77            timeout_secs: 60,
78        }
79    }
80}
81
82impl OpenAIConfig {
83    /// 创建新配置
84    pub fn new(api_key: impl Into<String>) -> Self {
85        Self {
86            api_key: api_key.into(),
87            ..Default::default()
88        }
89    }
90
91    /// 从环境变量创建配置
92    pub fn from_env() -> Self {
93        Self {
94            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
95            base_url: std::env::var("OPENAI_BASE_URL").ok(),
96            default_model: std::env::var("OPENAI_MODEL").unwrap_or_default(),
97            ..Default::default()
98        }
99    }
100
101    /// 设置 base URL
102    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
103        self.base_url = Some(url.into());
104        self
105    }
106
107    /// 设置默认模型
108    pub fn with_model(mut self, model: impl Into<String>) -> Self {
109        self.default_model = model.into();
110        self
111    }
112
113    /// 设置默认温度
114    pub fn with_temperature(mut self, temp: f32) -> Self {
115        self.default_temperature = temp;
116        self
117    }
118
119    /// 设置默认最大 token 数
120    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
121        self.default_max_tokens = tokens;
122        self
123    }
124
125    /// 设置组织 ID
126    pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
127        self.org_id = Some(org_id.into());
128        self
129    }
130
131    /// 设置超时
132    pub fn with_timeout(mut self, secs: u64) -> Self {
133        self.timeout_secs = secs;
134        self
135    }
136}
137
/// OpenAI LLM Provider.
///
/// Supports the OpenAI API and compatible services (Azure OpenAI,
/// Ollama, vLLM, etc.) via a configurable base URL.
pub struct OpenAIProvider {
    /// Underlying async-openai HTTP client.
    client: Client<AsyncOpenAIConfig>,
    /// Provider configuration (defaults applied when a request omits values).
    config: OpenAIConfig,
}
145
impl OpenAIProvider {
    /// Creates a provider from an API key, with default settings.
    pub fn new(api_key: impl Into<String>) -> Self {
        let config = OpenAIConfig::new(api_key);
        Self::with_config(config)
    }

    /// Creates a provider from environment variables
    /// (see [`OpenAIConfig::from_env`]).
    pub fn from_env() -> Self {
        Self::with_config(OpenAIConfig::from_env())
    }

    /// Creates a provider from an explicit configuration.
    ///
    /// NOTE(review): `config.timeout_secs` is not applied to the
    /// underlying client here — confirm whether that is intentional.
    pub fn with_config(config: OpenAIConfig) -> Self {
        let mut openai_config = AsyncOpenAIConfig::new().with_api_key(&config.api_key);

        if let Some(ref base_url) = config.base_url {
            openai_config = openai_config.with_api_base(base_url);
        }

        if let Some(ref org_id) = config.org_id {
            openai_config = openai_config.with_org_id(org_id);
        }

        let client = Client::with_config(openai_config);

        Self { client, config }
    }

    /// Creates an Azure OpenAI provider.
    ///
    /// The deployment name doubles as the default model name.
    pub fn azure(
        endpoint: impl Into<String>,
        api_key: impl Into<String>,
        deployment: impl Into<String>,
    ) -> Self {
        let endpoint = endpoint.into();
        let deployment = deployment.into();

        // Azure OpenAI uses a different URL scheme:
        // {endpoint}/openai/deployments/{deployment}
        let base_url = format!(
            "{}/openai/deployments/{}",
            endpoint.trim_end_matches('/'),
            deployment
        );

        let config = OpenAIConfig::new(api_key)
            .with_base_url(base_url)
            .with_model(deployment);

        Self::with_config(config)
    }

    /// Creates a provider for an OpenAI-API-compatible local service
    /// (Ollama, vLLM, LocalAI, ...). A placeholder API key is used since
    /// such services typically do not check it.
    pub fn local(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let config = OpenAIConfig::new("not-needed")
            .with_base_url(base_url)
            .with_model(model);

        Self::with_config(config)
    }

    /// Returns the underlying async-openai client.
    pub fn client(&self) -> &Client<AsyncOpenAIConfig> {
        &self.client
    }

    /// Returns the provider configuration.
    pub fn config(&self) -> &OpenAIConfig {
        &self.config
    }

    /// Converts a slice of our `ChatMessage`s into async-openai
    /// request messages, failing on the first conversion error.
    fn convert_messages(
        messages: &[ChatMessage],
    ) -> Result<Vec<ChatCompletionRequestMessage>, LLMError> {
        messages.iter().map(Self::convert_message).collect()
    }

    /// Converts a single `ChatMessage` into an async-openai request message.
    fn convert_message(msg: &ChatMessage) -> Result<ChatCompletionRequestMessage, LLMError> {
        // Flattened text content, used for roles that accept only plain
        // text (system / assistant / tool). Non-text parts are dropped
        // here; multimodal parts are only preserved for the user role below.
        let text_only_content = msg
            .content
            .as_ref()
            .map(|c| match c {
                MessageContent::Text(s) => s.clone(),
                MessageContent::Parts(parts) => parts
                    .iter()
                    .filter_map(|p| match p {
                        ContentPart::Text { text } => Some(text.clone()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n"),
            })
            .unwrap_or_default();

        match msg.role {
            Role::System => Ok(ChatCompletionRequestSystemMessageArgs::default()
                .content(text_only_content)
                .build()
                .map_err(|e| LLMError::Other(e.to_string()))?
                .into()),
            Role::User => {
                // User messages may carry multimodal content: plain text,
                // or an array of text / image / audio parts.
                let content = match msg.content.as_ref() {
                    Some(MessageContent::Text(s)) => {
                        ChatCompletionRequestUserMessageContent::Text(s.clone())
                    }
                    Some(MessageContent::Parts(parts)) => {
                        let mut out = Vec::new();
                        for part in parts {
                            match part {
                                ContentPart::Text { text } => {
                                    out.push(ChatCompletionRequestUserMessageContentPart::Text(
                                        ChatCompletionRequestMessageContentPartText {
                                            text: text.clone(),
                                        },
                                    ));
                                }
                                ContentPart::Image { image_url } => {
                                    let detail = image_url.detail.as_ref().map(|d| match d {
                                        ImageDetail::Auto => OpenAIImageDetail::Auto,
                                        ImageDetail::Low => OpenAIImageDetail::Low,
                                        ImageDetail::High => OpenAIImageDetail::High,
                                    });
                                    let image_part = ChatCompletionRequestMessageContentPartImage {
                                        image_url: OpenAIImageUrl {
                                            url: image_url.url.clone(),
                                            detail,
                                        },
                                    };
                                    out.push(
                                        ChatCompletionRequestUserMessageContentPart::ImageUrl(
                                            image_part,
                                        ),
                                    );
                                }
                                ContentPart::Audio { audio } => {
                                    // Only "wav" is mapped explicitly; any
                                    // other format string falls back to mp3.
                                    let format = match audio.format.to_lowercase().as_str() {
                                        "wav" => InputAudioFormat::Wav,
                                        _ => InputAudioFormat::Mp3,
                                    };
                                    let audio_part = ChatCompletionRequestMessageContentPartAudio {
                                        input_audio: InputAudio {
                                            data: audio.data.clone(),
                                            format,
                                        },
                                    };
                                    out.push(
                                        ChatCompletionRequestUserMessageContentPart::InputAudio(
                                            audio_part,
                                        ),
                                    );
                                }
                            }
                        }
                        ChatCompletionRequestUserMessageContent::Array(out)
                    }
                    None => ChatCompletionRequestUserMessageContent::Text(String::new()),
                };

                Ok(ChatCompletionRequestUserMessageArgs::default()
                    .content(content)
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?
                    .into())
            }
            Role::Assistant => {
                let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                // Content is optional on assistant messages (e.g. pure
                // tool-call turns have none).
                if !text_only_content.is_empty() {
                    builder.content(text_only_content);
                }

                // Forward any tool calls made by the assistant.
                if let Some(ref tool_calls) = msg.tool_calls {
                    let converted_calls: Vec<_> = tool_calls
                        .iter()
                        .map(|tc| async_openai::types::ChatCompletionMessageToolCall {
                            id: tc.id.clone(),
                            r#type: ChatCompletionToolType::Function,
                            function: async_openai::types::FunctionCall {
                                name: tc.function.name.clone(),
                                arguments: tc.function.arguments.clone(),
                            },
                        })
                        .collect();
                    builder.tool_calls(converted_calls);
                }

                Ok(builder
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?
                    .into())
            }
            Role::Tool => {
                // Tool results must reference the originating call; fall
                // back to "unknown" if the caller did not set the id.
                let tool_call_id = msg
                    .tool_call_id
                    .clone()
                    .unwrap_or_else(|| "unknown".to_string());

                Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call_id)
                    .content(text_only_content)
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?
                    .into())
            }
        }
    }

    /// Converts our tool definitions into async-openai tool definitions.
    fn convert_tools(
        tools: &[Tool],
    ) -> Result<Vec<async_openai::types::ChatCompletionTool>, LLMError> {
        tools
            .iter()
            .map(|tool| {
                let function = FunctionObjectArgs::default()
                    .name(&tool.function.name)
                    .description(tool.function.description.clone().unwrap_or_default())
                    .parameters(
                        // Missing parameter schemas become an empty JSON object.
                        tool.function
                            .parameters
                            .clone()
                            .unwrap_or(serde_json::json!({})),
                    )
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?;

                ChatCompletionToolArgs::default()
                    .r#type(ChatCompletionToolType::Function)
                    .function(function)
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))
            })
            .collect()
    }

    /// Converts an async-openai completion response into our response type.
    fn convert_response(
        response: async_openai::types::CreateChatCompletionResponse,
    ) -> ChatCompletionResponse {
        let choices: Vec<Choice> = response
            .choices
            .into_iter()
            .map(|choice| {
                let message = Self::convert_response_message(choice.message);
                // Note: the deprecated FunctionCall reason is folded into
                // ToolCalls.
                let finish_reason = choice.finish_reason.map(|r| match r {
                    async_openai::types::FinishReason::Stop => FinishReason::Stop,
                    async_openai::types::FinishReason::Length => FinishReason::Length,
                    async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
                    async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
                    async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
                });

                Choice {
                    index: choice.index,
                    message,
                    finish_reason,
                    logprobs: None,
                }
            })
            .collect();

        let usage = response.usage.map(|u| Usage {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
        });

        ChatCompletionResponse {
            id: response.id,
            object: response.object,
            created: response.created as u64,
            model: response.model,
            choices,
            usage,
            system_fingerprint: response.system_fingerprint,
        }
    }

    /// Converts an async-openai response message (always assistant role)
    /// into our `ChatMessage`.
    fn convert_response_message(
        msg: async_openai::types::ChatCompletionResponseMessage,
    ) -> ChatMessage {
        let content = msg.content.map(MessageContent::Text);

        let tool_calls = msg.tool_calls.map(|calls| {
            calls
                .into_iter()
                .map(|tc| ToolCall {
                    id: tc.id,
                    call_type: "function".to_string(),
                    function: FunctionCall {
                        name: tc.function.name,
                        arguments: tc.function.arguments,
                    },
                })
                .collect()
        });

        ChatMessage {
            role: Role::Assistant,
            content,
            name: None,
            tool_calls,
            tool_call_id: None,
        }
    }
}
455
#[async_trait]
impl LLMProvider for OpenAIProvider {
    /// Provider identifier.
    fn name(&self) -> &str {
        "openai"
    }

    /// Model used when a request does not specify one.
    fn default_model(&self) -> &str {
        &self.config.default_model
    }

    /// Well-known OpenAI model names. Compatible backends may accept
    /// other names as well (see `OpenAIConfig::with_model`).
    fn supported_models(&self) -> Vec<&str> {
        vec![
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-4",
            "gpt-3.5-turbo",
            "o1",
            "o1-mini",
            "o1-preview",
        ]
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_vision(&self) -> bool {
        true
    }

    fn supports_embedding(&self) -> bool {
        true
    }

    /// Sends a (non-streaming) chat completion request.
    ///
    /// Falls back to the configured default model and temperature when
    /// the request leaves them unset.
    async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;

        // Empty model string means "use the configured default".
        let model = if request.model.is_empty() {
            self.config.default_model.clone()
        } else {
            request.model.clone()
        };

        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model(&model).messages(messages);

        // Optional parameters.
        if let Some(temp) = request.temperature {
            builder.temperature(temp);
        } else {
            builder.temperature(self.config.default_temperature);
        }

        if let Some(max_tokens) = request.max_tokens {
            builder.max_tokens(max_tokens);
        }

        if let Some(top_p) = request.top_p {
            builder.top_p(top_p);
        }

        if let Some(ref stop) = request.stop {
            builder.stop(stop.clone());
        }

        if let Some(freq_penalty) = request.frequency_penalty {
            builder.frequency_penalty(freq_penalty);
        }

        if let Some(pres_penalty) = request.presence_penalty {
            builder.presence_penalty(pres_penalty);
        }

        if let Some(ref user) = request.user {
            builder.user(user);
        }

        // Tools.
        if let Some(ref tools) = request.tools
            && !tools.is_empty()
        {
            let converted_tools = Self::convert_tools(tools)?;
            builder.tools(converted_tools);

            // tool_choice is only meaningful when tools are present.
            if let Some(ref choice) = request.tool_choice {
                let tc = match choice {
                    ToolChoice::Auto => ChatCompletionToolChoiceOption::Auto,
                    ToolChoice::None => ChatCompletionToolChoiceOption::None,
                    ToolChoice::Required => ChatCompletionToolChoiceOption::Required,
                    ToolChoice::Specific { function, .. } => ChatCompletionToolChoiceOption::Named(
                        async_openai::types::ChatCompletionNamedToolChoice {
                            r#type: ChatCompletionToolType::Function,
                            function: async_openai::types::FunctionName {
                                name: function.name.clone(),
                            },
                        },
                    ),
                };
                builder.tool_choice(tc);
            }
        }

        // Response format: only "json_object" is mapped; other format
        // types are ignored here.
        if let Some(ref format) = request.response_format
            && format.format_type == "json_object"
        {
            builder.response_format(async_openai::types::ResponseFormat::JsonObject);
        }

        let openai_request = builder
            .build()
            .map_err(|e| LLMError::ConfigError(e.to_string()))?;

        let response = self
            .client
            .chat()
            .create(openai_request)
            .await
            .map_err(Self::convert_error)?;

        Ok(Self::convert_response(response))
    }

    /// Sends a streaming chat completion request.
    ///
    /// NOTE(review): unlike `chat`, this does NOT fall back to the
    /// configured default temperature, and forwards fewer optional
    /// parameters (no top_p/stop/penalties/user/tool_choice) — confirm
    /// whether that asymmetry is intentional.
    async fn chat_stream(&self, request: ChatCompletionRequest) -> LLMResult<ChatStream> {
        let messages = Self::convert_messages(&request.messages)?;

        let model = if request.model.is_empty() {
            self.config.default_model.clone()
        } else {
            request.model.clone()
        };

        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model(&model).messages(messages).stream(true);

        if let Some(temp) = request.temperature {
            builder.temperature(temp);
        }

        if let Some(max_tokens) = request.max_tokens {
            builder.max_tokens(max_tokens);
        }

        // Tools.
        if let Some(ref tools) = request.tools
            && !tools.is_empty()
        {
            let converted_tools = Self::convert_tools(tools)?;
            builder.tools(converted_tools);
        }

        let openai_request = builder
            .build()
            .map_err(|e| LLMError::ConfigError(e.to_string()))?;

        let stream = self
            .client
            .chat()
            .create_stream(openai_request)
            .await
            .map_err(Self::convert_error)?;

        // Convert the stream, filtering out UTF-8 decoding errors (some
        // OpenAI-compatible APIs may emit invalid UTF-8 chunks). The
        // detection is substring-based on the error text.
        let converted_stream = stream
            .filter_map(|result| async move {
                match result {
                    Ok(chunk) => Some(Ok(Self::convert_chunk(chunk))),
                    Err(e) => {
                        let err_str = e.to_string();
                        // Skip UTF-8 errors: log a warning and keep consuming the stream.
                        if err_str.contains("stream did not contain valid UTF-8") || err_str.contains("utf8") {
                            tracing::warn!("Skipping invalid UTF-8 chunk from stream (may happen with some OpenAI-compatible APIs)");
                            None
                        } else {
                            Some(Err(Self::convert_error(e)))
                        }
                    }
                }
            });

        Ok(Box::pin(converted_stream))
    }

    /// Creates embeddings for one or more input strings.
    async fn embedding(&self, request: EmbeddingRequest) -> LLMResult<EmbeddingResponse> {
        use async_openai::types::CreateEmbeddingRequestArgs;

        // Normalize single/multiple inputs to a Vec.
        let input = match request.input {
            EmbeddingInput::Single(s) => vec![s],
            EmbeddingInput::Multiple(v) => v,
        };

        let openai_request = CreateEmbeddingRequestArgs::default()
            .model(&request.model)
            .input(input)
            .build()
            .map_err(|e| LLMError::ConfigError(e.to_string()))?;

        let response = self
            .client
            .embeddings()
            .create(openai_request)
            .await
            .map_err(Self::convert_error)?;

        let data: Vec<EmbeddingData> = response
            .data
            .into_iter()
            .map(|d| EmbeddingData {
                object: "embedding".to_string(),
                index: d.index,
                embedding: d.embedding,
            })
            .collect();

        Ok(EmbeddingResponse {
            object: "list".to_string(),
            model: response.model,
            data,
            usage: EmbeddingUsage {
                prompt_tokens: response.usage.prompt_tokens,
                total_tokens: response.usage.total_tokens,
            },
        })
    }

    /// Checks connectivity by sending a tiny chat request.
    ///
    /// Any error (auth, network, model) maps to `Ok(false)` — this never
    /// returns `Err`.
    async fn health_check(&self) -> LLMResult<bool> {
        let request = ChatCompletionRequest::new(&self.config.default_model)
            .system("Say 'ok'")
            .max_tokens(5);

        match self.chat(request).await {
            Ok(_) => Ok(true),
            Err(_) => Ok(false),
        }
    }

    /// Returns model metadata.
    ///
    /// OpenAI has no public model-info API, so this returns hard-coded
    /// information for known models and a default-capability entry for
    /// unknown model names.
    async fn get_model_info(&self, model: &str) -> LLMResult<ModelInfo> {
        let info = match model {
            "gpt-4o" => ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                description: Some("Most capable GPT-4 model with vision".to_string()),
                context_window: Some(128000),
                max_output_tokens: Some(16384),
                training_cutoff: Some("2023-10".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: true,
                    json_mode: true,
                    json_schema: true,
                },
            },
            "gpt-4o-mini" => ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                description: Some("Smaller, faster GPT-4o".to_string()),
                context_window: Some(128000),
                max_output_tokens: Some(16384),
                training_cutoff: Some("2023-10".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: true,
                    json_mode: true,
                    json_schema: true,
                },
            },
            "gpt-4-turbo" => ModelInfo {
                id: "gpt-4-turbo".to_string(),
                name: "GPT-4 Turbo".to_string(),
                description: Some("GPT-4 Turbo with vision".to_string()),
                context_window: Some(128000),
                max_output_tokens: Some(4096),
                training_cutoff: Some("2023-12".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: true,
                    json_mode: true,
                    json_schema: false,
                },
            },
            "gpt-3.5-turbo" => ModelInfo {
                id: "gpt-3.5-turbo".to_string(),
                name: "GPT-3.5 Turbo".to_string(),
                description: Some("Fast and cost-effective".to_string()),
                context_window: Some(16385),
                max_output_tokens: Some(4096),
                training_cutoff: Some("2021-09".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: false,
                    json_mode: true,
                    json_schema: false,
                },
            },
            // Unknown model: echo the name with default capabilities.
            _ => ModelInfo {
                id: model.to_string(),
                name: model.to_string(),
                description: None,
                context_window: None,
                max_output_tokens: None,
                training_cutoff: None,
                capabilities: ModelCapabilities::default(),
            },
        };

        Ok(info)
    }
}
776
impl OpenAIProvider {
    /// Converts a streaming response chunk into our chunk type.
    fn convert_chunk(
        chunk: async_openai::types::CreateChatCompletionStreamResponse,
    ) -> ChatCompletionChunk {
        let choices: Vec<ChunkChoice> = chunk
            .choices
            .into_iter()
            .map(|choice| {
                let delta = ChunkDelta {
                    // Delta roles from OpenAI chat streams are mapped to
                    // Assistant regardless of the wire value.
                    role: choice.delta.role.map(|_| Role::Assistant),
                    content: choice.delta.content,
                    tool_calls: choice.delta.tool_calls.map(|calls| {
                        calls
                            .into_iter()
                            .map(|tc| ToolCallDelta {
                                index: tc.index,
                                id: tc.id,
                                call_type: Some("function".to_string()),
                                function: tc.function.map(|f| FunctionCallDelta {
                                    name: f.name,
                                    arguments: f.arguments,
                                }),
                            })
                            .collect()
                    }),
                };

                // Deprecated FunctionCall reason folds into ToolCalls.
                let finish_reason = choice.finish_reason.map(|r| match r {
                    async_openai::types::FinishReason::Stop => FinishReason::Stop,
                    async_openai::types::FinishReason::Length => FinishReason::Length,
                    async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
                    async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
                    async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
                });

                ChunkChoice {
                    index: choice.index,
                    delta,
                    finish_reason,
                }
            })
            .collect();

        ChatCompletionChunk {
            id: chunk.id,
            object: "chat.completion.chunk".to_string(),
            created: chunk.created as u64,
            model: chunk.model,
            choices,
            usage: chunk.usage.map(|u| Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
        }
    }

    /// Converts an async-openai error into our `LLMError`.
    ///
    /// API errors are classified by substring matching on the error
    /// message; branch order matters (e.g. "rate limit" is checked before
    /// "quota"). Messages that match none of the patterns become a
    /// generic `ApiError`.
    fn convert_error(err: async_openai::error::OpenAIError) -> LLMError {
        match err {
            async_openai::error::OpenAIError::ApiError(api_err) => {
                let code = api_err.code.clone();
                let message = api_err.message.clone();

                // Classify by message content.
                if message.contains("rate limit") {
                    LLMError::RateLimited(message)
                } else if message.contains("quota") || message.contains("billing") {
                    LLMError::QuotaExceeded(message)
                } else if message.contains("model") && message.contains("not found") {
                    LLMError::ModelNotFound(message)
                } else if message.contains("context") || message.contains("tokens") {
                    LLMError::ContextLengthExceeded(message)
                } else if message.contains("content") && message.contains("filter") {
                    LLMError::ContentFiltered(message)
                } else {
                    LLMError::ApiError { code, message }
                }
            }
            async_openai::error::OpenAIError::Reqwest(e) => {
                // Distinguish timeouts from other transport failures.
                if e.is_timeout() {
                    LLMError::Timeout(e.to_string())
                } else {
                    LLMError::NetworkError(e.to_string())
                }
            }
            async_openai::error::OpenAIError::InvalidArgument(msg) => LLMError::ConfigError(msg),
            _ => LLMError::Other(err.to_string()),
        }
    }

    /// Convenience constructor for a plain OpenAI provider.
    pub fn openai(api_key: impl Into<String>) -> OpenAIProvider {
        OpenAIProvider::new(api_key)
    }

    /// Convenience constructor for an OpenAI-API-compatible service with
    /// an explicit base URL, API key, and model.
    pub fn openai_compatible(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> OpenAIProvider {
        let config = OpenAIConfig::new(api_key)
            .with_base_url(base_url)
            .with_model(model);
        OpenAIProvider::with_config(config)
    }
}
886
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `with_*` builder method must store its value.
    #[test]
    fn test_config_builder() {
        let cfg = OpenAIConfig::new("sk-test")
            .with_base_url("http://localhost:8080")
            .with_model("gpt-4")
            .with_temperature(0.5)
            .with_max_tokens(2048);

        assert_eq!(cfg.api_key, "sk-test");
        assert_eq!(cfg.base_url.as_deref(), Some("http://localhost:8080"));
        assert_eq!(cfg.default_model, "gpt-4");
        assert_eq!(cfg.default_temperature, 0.5);
        assert_eq!(cfg.default_max_tokens, 2048);
    }

    /// The provider must identify itself as "openai".
    #[test]
    fn test_provider_name() {
        assert_eq!(OpenAIProvider::new("test-key").name(), "openai");
    }
}