// mofa_foundation/llm/openai.rs
//! OpenAI Provider Implementation
//!
//! 使用 `async-openai` crate 实现 OpenAI API 交互
//!
//! # 支持的服务
//!
//! - OpenAI API (api.openai.com)
//! - Azure OpenAI
//! - 兼容 OpenAI API 的本地服务 (Ollama, vLLM, LocalAI 等)
//!
//! # 示例
//!
//! ```rust,ignore
//! use mofa_foundation::llm::openai::{OpenAIProvider, OpenAIConfig};
//!
//! // 使用 OpenAI
//! let provider = OpenAIProvider::new("sk-xxx");
//!
//! // 使用自定义 endpoint
//! let provider = OpenAIProvider::with_config(
//!     OpenAIConfig::new("sk-xxx")
//!         .with_base_url("http://localhost:11434/v1")
//!         .with_model("llama2")
//! );
//!
//! // 使用 Azure OpenAI
//! let provider = OpenAIProvider::azure("https://xxx.openai.azure.com", "api-key", "deployment");
//! ```
30use super::provider::{ChatStream, LLMProvider, ModelCapabilities, ModelInfo};
31use super::types::*;
32use async_openai::{
33    Client,
34    config::OpenAIConfig as AsyncOpenAIConfig,
35    types::{
36        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
37        ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage,
38        ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs,
39        ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
40        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
41        ChatCompletionToolArgs, ChatCompletionToolChoiceOption, ChatCompletionToolType,
42        CreateChatCompletionRequestArgs, FunctionObjectArgs, ImageDetail as OpenAIImageDetail,
43        ImageUrl as OpenAIImageUrl, InputAudio, InputAudioFormat,
44    },
45};
46use async_trait::async_trait;
47use futures::StreamExt;
48
49/// OpenAI Provider 配置
/// Configuration for the OpenAI provider.
///
/// Holds credentials, endpoint overrides, and default request parameters
/// that are applied when a `ChatCompletionRequest` leaves them unset.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// API key used as the bearer token.
    pub api_key: String,
    /// Optional API base URL (Azure or local OpenAI-compatible servers).
    pub base_url: Option<String>,
    /// Optional organization ID.
    pub org_id: Option<String>,
    /// Default model used when a request's `model` field is empty.
    pub default_model: String,
    /// Default sampling temperature applied when a request sets none.
    pub default_temperature: f32,
    /// Default maximum number of output tokens.
    pub default_max_tokens: u32,
    /// Request timeout in seconds.
    /// NOTE(review): not applied to the underlying client in
    /// `OpenAIProvider::with_config` — confirm whether it should be.
    pub timeout_secs: u64,
}
67
68impl Default for OpenAIConfig {
69    fn default() -> Self {
70        Self {
71            api_key: String::new(),
72            base_url: None,
73            org_id: None,
74            default_model: "gpt-4o".to_string(),
75            default_temperature: 0.7,
76            default_max_tokens: 4096,
77            timeout_secs: 60,
78        }
79    }
80}
81
82impl OpenAIConfig {
83    /// 创建新配置
84    pub fn new(api_key: impl Into<String>) -> Self {
85        Self {
86            api_key: api_key.into(),
87            ..Default::default()
88        }
89    }
90
91    /// 从环境变量创建配置
92    pub fn from_env() -> Self {
93        Self {
94            api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
95            base_url: std::env::var("OPENAI_BASE_URL").ok(),
96            default_model: std::env::var("OPENAI_MODEL").unwrap_or_default(),
97            ..Default::default()
98        }
99    }
100
101    /// 设置 base URL
102    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
103        self.base_url = Some(url.into());
104        self
105    }
106
107    /// 设置默认模型
108    pub fn with_model(mut self, model: impl Into<String>) -> Self {
109        self.default_model = model.into();
110        self
111    }
112
113    /// 设置默认温度
114    pub fn with_temperature(mut self, temp: f32) -> Self {
115        self.default_temperature = temp;
116        self
117    }
118
119    /// 设置默认最大 token 数
120    pub fn with_max_tokens(mut self, tokens: u32) -> Self {
121        self.default_max_tokens = tokens;
122        self
123    }
124
125    /// 设置组织 ID
126    pub fn with_org_id(mut self, org_id: impl Into<String>) -> Self {
127        self.org_id = Some(org_id.into());
128        self
129    }
130
131    /// 设置超时
132    pub fn with_timeout(mut self, secs: u64) -> Self {
133        self.timeout_secs = secs;
134        self
135    }
136}
137
/// OpenAI LLM Provider.
///
/// Supports the OpenAI API and OpenAI-compatible services
/// (Azure OpenAI, Ollama, vLLM, LocalAI, ...).
pub struct OpenAIProvider {
    // Underlying async-openai HTTP client.
    client: Client<AsyncOpenAIConfig>,
    // Provider configuration (defaults for model/temperature/etc.).
    config: OpenAIConfig,
}
145
impl OpenAIProvider {
    /// Creates a provider from an API key with default configuration.
    pub fn new(api_key: impl Into<String>) -> Self {
        let config = OpenAIConfig::new(api_key);
        Self::with_config(config)
    }

    /// Creates a provider from environment variables
    /// (see [`OpenAIConfig::from_env`]).
    pub fn from_env() -> Self {
        Self::with_config(OpenAIConfig::from_env())
    }

    /// Creates a provider from an explicit configuration.
    ///
    /// Applies `base_url` and `org_id` to the underlying `async-openai`
    /// client when present. NOTE(review): `config.timeout_secs` is not
    /// applied here — confirm whether the HTTP client should honor it.
    pub fn with_config(config: OpenAIConfig) -> Self {
        let mut openai_config = AsyncOpenAIConfig::new().with_api_key(&config.api_key);

        if let Some(ref base_url) = config.base_url {
            openai_config = openai_config.with_api_base(base_url);
        }

        if let Some(ref org_id) = config.org_id {
            openai_config = openai_config.with_org_id(org_id);
        }

        let client = Client::with_config(openai_config);

        Self { client, config }
    }

    /// Creates an Azure OpenAI provider.
    ///
    /// `deployment` is used both as the URL path segment and as the
    /// default model name.
    pub fn azure(
        endpoint: impl Into<String>,
        api_key: impl Into<String>,
        deployment: impl Into<String>,
    ) -> Self {
        let endpoint = endpoint.into();
        let deployment = deployment.into();

        // Azure OpenAI uses a different URL layout than api.openai.com.
        let base_url = format!(
            "{}/openai/deployments/{}",
            endpoint.trim_end_matches('/'),
            deployment
        );

        let config = OpenAIConfig::new(api_key)
            .with_base_url(base_url)
            .with_model(deployment);

        Self::with_config(config)
    }

    /// Creates a provider for a local OpenAI-compatible service
    /// (Ollama, vLLM, LocalAI, ...). A placeholder API key is sent.
    pub fn local(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let config = OpenAIConfig::new("not-needed")
            .with_base_url(base_url)
            .with_model(model);

        Self::with_config(config)
    }

    /// Creates a provider for a local Ollama server on the default port.
    pub fn ollama(model: impl Into<String>) -> Self {
        Self::local("http://localhost:11434/v1", model)
    }

    /// Returns the underlying `async-openai` client.
    pub fn client(&self) -> &Client<AsyncOpenAIConfig> {
        &self.client
    }

    /// Returns the provider configuration.
    pub fn config(&self) -> &OpenAIConfig {
        &self.config
    }

    /// Converts a slice of our `ChatMessage`s into `async-openai`
    /// request messages, failing on the first conversion error.
    fn convert_messages(
        messages: &[ChatMessage],
    ) -> Result<Vec<ChatCompletionRequestMessage>, LLMError> {
        messages.iter().map(Self::convert_message).collect()
    }

    /// Converts a single message.
    ///
    /// User messages preserve text, image, and audio parts; system,
    /// assistant, and tool messages keep only the text parts (non-text
    /// parts are dropped for those roles).
    fn convert_message(msg: &ChatMessage) -> Result<ChatCompletionRequestMessage, LLMError> {
        // Flattened text-only view of the content, joined with newlines;
        // used by the roles that do not accept multi-part content.
        let text_only_content = msg
            .content
            .as_ref()
            .map(|c| match c {
                MessageContent::Text(s) => s.clone(),
                MessageContent::Parts(parts) => parts
                    .iter()
                    .filter_map(|p| match p {
                        ContentPart::Text { text } => Some(text.clone()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n"),
            })
            .unwrap_or_default();

        match msg.role {
            Role::System => Ok(ChatCompletionRequestSystemMessageArgs::default()
                .content(text_only_content)
                .build()
                .map_err(|e| LLMError::Other(e.to_string()))?
                .into()),
            Role::User => {
                let content = match msg.content.as_ref() {
                    Some(MessageContent::Text(s)) => {
                        ChatCompletionRequestUserMessageContent::Text(s.clone())
                    }
                    Some(MessageContent::Parts(parts)) => {
                        let mut out = Vec::new();
                        for part in parts {
                            match part {
                                ContentPart::Text { text } => {
                                    out.push(ChatCompletionRequestUserMessageContentPart::Text(
                                        ChatCompletionRequestMessageContentPartText {
                                            text: text.clone(),
                                        },
                                    ));
                                }
                                ContentPart::Image { image_url } => {
                                    let detail = image_url.detail.as_ref().map(|d| match d {
                                        ImageDetail::Auto => OpenAIImageDetail::Auto,
                                        ImageDetail::Low => OpenAIImageDetail::Low,
                                        ImageDetail::High => OpenAIImageDetail::High,
                                    });
                                    let image_part = ChatCompletionRequestMessageContentPartImage {
                                        image_url: OpenAIImageUrl {
                                            url: image_url.url.clone(),
                                            detail,
                                        },
                                    };
                                    out.push(
                                        ChatCompletionRequestUserMessageContentPart::ImageUrl(
                                            image_part,
                                        ),
                                    );
                                }
                                ContentPart::Audio { audio } => {
                                    // Only "wav" is mapped explicitly; every
                                    // other format string falls back to MP3.
                                    let format = match audio.format.to_lowercase().as_str() {
                                        "wav" => InputAudioFormat::Wav,
                                        _ => InputAudioFormat::Mp3,
                                    };
                                    let audio_part = ChatCompletionRequestMessageContentPartAudio {
                                        input_audio: InputAudio {
                                            data: audio.data.clone(),
                                            format,
                                        },
                                    };
                                    out.push(
                                        ChatCompletionRequestUserMessageContentPart::InputAudio(
                                            audio_part,
                                        ),
                                    );
                                }
                            }
                        }
                        ChatCompletionRequestUserMessageContent::Array(out)
                    }
                    None => ChatCompletionRequestUserMessageContent::Text(String::new()),
                };

                Ok(ChatCompletionRequestUserMessageArgs::default()
                    .content(content)
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?
                    .into())
            }
            Role::Assistant => {
                let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
                // An assistant message may legitimately carry no content
                // (e.g. a pure tool-call turn), so only set it when non-empty.
                if !text_only_content.is_empty() {
                    builder.content(text_only_content);
                }

                // Convert tool calls carried on the assistant message.
                if let Some(ref tool_calls) = msg.tool_calls {
                    let converted_calls: Vec<_> = tool_calls
                        .iter()
                        .map(|tc| async_openai::types::ChatCompletionMessageToolCall {
                            id: tc.id.clone(),
                            r#type: ChatCompletionToolType::Function,
                            function: async_openai::types::FunctionCall {
                                name: tc.function.name.clone(),
                                arguments: tc.function.arguments.clone(),
                            },
                        })
                        .collect();
                    builder.tool_calls(converted_calls);
                }

                Ok(builder
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?
                    .into())
            }
            Role::Tool => {
                // A tool message without an id should not normally occur;
                // "unknown" keeps the request well-formed rather than failing.
                let tool_call_id = msg
                    .tool_call_id
                    .clone()
                    .unwrap_or_else(|| "unknown".to_string());

                Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call_id)
                    .content(text_only_content)
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?
                    .into())
            }
        }
    }

    /// Converts our tool definitions into `async-openai` tool objects.
    fn convert_tools(
        tools: &[Tool],
    ) -> Result<Vec<async_openai::types::ChatCompletionTool>, LLMError> {
        tools
            .iter()
            .map(|tool| {
                let function = FunctionObjectArgs::default()
                    .name(&tool.function.name)
                    .description(tool.function.description.clone().unwrap_or_default())
                    .parameters(
                        tool.function
                            .parameters
                            .clone()
                            .unwrap_or(serde_json::json!({})),
                    )
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))?;

                ChatCompletionToolArgs::default()
                    .r#type(ChatCompletionToolType::Function)
                    .function(function)
                    .build()
                    .map_err(|e| LLMError::Other(e.to_string()))
            })
            .collect()
    }

    /// Converts a non-streaming `async-openai` response into our
    /// `ChatCompletionResponse`.
    fn convert_response(
        response: async_openai::types::CreateChatCompletionResponse,
    ) -> ChatCompletionResponse {
        let choices: Vec<Choice> = response
            .choices
            .into_iter()
            .map(|choice| {
                let message = Self::convert_response_message(choice.message);
                // Legacy FunctionCall finish reason is folded into ToolCalls.
                let finish_reason = choice.finish_reason.map(|r| match r {
                    async_openai::types::FinishReason::Stop => FinishReason::Stop,
                    async_openai::types::FinishReason::Length => FinishReason::Length,
                    async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
                    async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
                    async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
                });

                Choice {
                    index: choice.index,
                    message,
                    finish_reason,
                    logprobs: None,
                }
            })
            .collect();

        let usage = response.usage.map(|u| Usage {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
        });

        ChatCompletionResponse {
            id: response.id,
            object: response.object,
            created: response.created as u64,
            model: response.model,
            choices,
            usage,
            system_fingerprint: response.system_fingerprint,
        }
    }

    /// Converts a response message (always attributed to the assistant role).
    fn convert_response_message(
        msg: async_openai::types::ChatCompletionResponseMessage,
    ) -> ChatMessage {
        let content = msg.content.map(MessageContent::Text);

        let tool_calls = msg.tool_calls.map(|calls| {
            calls
                .into_iter()
                .map(|tc| ToolCall {
                    id: tc.id,
                    call_type: "function".to_string(),
                    function: FunctionCall {
                        name: tc.function.name,
                        arguments: tc.function.arguments,
                    },
                })
                .collect()
        });

        ChatMessage {
            role: Role::Assistant,
            content,
            name: None,
            tool_calls,
            tool_call_id: None,
        }
    }
}
460
#[async_trait]
impl LLMProvider for OpenAIProvider {
    fn name(&self) -> &str {
        "openai"
    }

    fn default_model(&self) -> &str {
        &self.config.default_model
    }

    /// Well-known OpenAI model names; not exhaustive for compatible servers.
    fn supported_models(&self) -> Vec<&str> {
        vec![
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-4",
            "gpt-3.5-turbo",
            "o1",
            "o1-mini",
            "o1-preview",
        ]
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_vision(&self) -> bool {
        true
    }

    fn supports_embedding(&self) -> bool {
        true
    }

    /// Performs a non-streaming chat completion.
    ///
    /// An empty `request.model` falls back to the configured default
    /// model; an unset temperature falls back to the configured default.
    async fn chat(&self, request: ChatCompletionRequest) -> LLMResult<ChatCompletionResponse> {
        let messages = Self::convert_messages(&request.messages)?;

        let model = if request.model.is_empty() {
            self.config.default_model.clone()
        } else {
            request.model.clone()
        };

        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model(&model).messages(messages);

        // Optional sampling parameters.
        if let Some(temp) = request.temperature {
            builder.temperature(temp);
        } else {
            builder.temperature(self.config.default_temperature);
        }

        if let Some(max_tokens) = request.max_tokens {
            builder.max_tokens(max_tokens);
        }

        if let Some(top_p) = request.top_p {
            builder.top_p(top_p);
        }

        if let Some(ref stop) = request.stop {
            builder.stop(stop.clone());
        }

        if let Some(freq_penalty) = request.frequency_penalty {
            builder.frequency_penalty(freq_penalty);
        }

        if let Some(pres_penalty) = request.presence_penalty {
            builder.presence_penalty(pres_penalty);
        }

        if let Some(ref user) = request.user {
            builder.user(user);
        }

        // Tools and tool choice.
        if let Some(ref tools) = request.tools
            && !tools.is_empty()
        {
            let converted_tools = Self::convert_tools(tools)?;
            builder.tools(converted_tools);

            // tool_choice is only meaningful when tools are present.
            if let Some(ref choice) = request.tool_choice {
                let tc = match choice {
                    ToolChoice::Auto => ChatCompletionToolChoiceOption::Auto,
                    ToolChoice::None => ChatCompletionToolChoiceOption::None,
                    ToolChoice::Required => ChatCompletionToolChoiceOption::Required,
                    ToolChoice::Specific { function, .. } => ChatCompletionToolChoiceOption::Named(
                        async_openai::types::ChatCompletionNamedToolChoice {
                            r#type: ChatCompletionToolType::Function,
                            function: async_openai::types::FunctionName {
                                name: function.name.clone(),
                            },
                        },
                    ),
                };
                builder.tool_choice(tc);
            }
        }

        // Response format: only "json_object" is currently mapped; other
        // format types are silently ignored.
        if let Some(ref format) = request.response_format
            && format.format_type == "json_object"
        {
            builder.response_format(async_openai::types::ResponseFormat::JsonObject);
        }

        let openai_request = builder
            .build()
            .map_err(|e| LLMError::ConfigError(e.to_string()))?;

        let response = self
            .client
            .chat()
            .create(openai_request)
            .await
            .map_err(Self::convert_error)?;

        Ok(Self::convert_response(response))
    }

    /// Performs a streaming chat completion.
    ///
    /// Note: unlike `chat`, this does not fall back to the configured
    /// default temperature, and top_p/stop/penalties/user/tool_choice
    /// are not forwarded — TODO confirm whether that asymmetry is intended.
    async fn chat_stream(&self, request: ChatCompletionRequest) -> LLMResult<ChatStream> {
        let messages = Self::convert_messages(&request.messages)?;

        let model = if request.model.is_empty() {
            self.config.default_model.clone()
        } else {
            request.model.clone()
        };

        let mut builder = CreateChatCompletionRequestArgs::default();
        builder.model(&model).messages(messages).stream(true);

        if let Some(temp) = request.temperature {
            builder.temperature(temp);
        }

        if let Some(max_tokens) = request.max_tokens {
            builder.max_tokens(max_tokens);
        }

        // Tools (tool_choice is not forwarded on the streaming path).
        if let Some(ref tools) = request.tools
            && !tools.is_empty()
        {
            let converted_tools = Self::convert_tools(tools)?;
            builder.tools(converted_tools);
        }

        let openai_request = builder
            .build()
            .map_err(|e| LLMError::ConfigError(e.to_string()))?;

        let stream = self
            .client
            .chat()
            .create_stream(openai_request)
            .await
            .map_err(Self::convert_error)?;

        // Convert the stream, dropping UTF-8 decode errors (some
        // OpenAI-compatible APIs emit invalid UTF-8 chunks).
        let converted_stream = stream
            .filter_map(|result| async move {
                match result {
                    Ok(chunk) => Some(Ok(Self::convert_chunk(chunk))),
                    Err(e) => {
                        let err_str = e.to_string();
                        // Skip UTF-8 errors with a warning and keep the
                        // stream alive; propagate everything else.
                        if err_str.contains("stream did not contain valid UTF-8") || err_str.contains("utf8") {
                            tracing::warn!("Skipping invalid UTF-8 chunk from stream (may happen with some OpenAI-compatible APIs)");
                            None
                        } else {
                            Some(Err(Self::convert_error(e)))
                        }
                    }
                }
            });

        Ok(Box::pin(converted_stream))
    }

    /// Creates embeddings for one or more input strings.
    async fn embedding(&self, request: EmbeddingRequest) -> LLMResult<EmbeddingResponse> {
        use async_openai::types::CreateEmbeddingRequestArgs;

        // Normalize single/multiple inputs into one Vec.
        let input = match request.input {
            EmbeddingInput::Single(s) => vec![s],
            EmbeddingInput::Multiple(v) => v,
        };

        let openai_request = CreateEmbeddingRequestArgs::default()
            .model(&request.model)
            .input(input)
            .build()
            .map_err(|e| LLMError::ConfigError(e.to_string()))?;

        let response = self
            .client
            .embeddings()
            .create(openai_request)
            .await
            .map_err(Self::convert_error)?;

        let data: Vec<EmbeddingData> = response
            .data
            .into_iter()
            .map(|d| EmbeddingData {
                object: "embedding".to_string(),
                index: d.index,
                embedding: d.embedding,
            })
            .collect();

        Ok(EmbeddingResponse {
            object: "list".to_string(),
            model: response.model,
            data,
            usage: EmbeddingUsage {
                prompt_tokens: response.usage.prompt_tokens,
                total_tokens: response.usage.total_tokens,
            },
        })
    }

    /// Checks connectivity by issuing a tiny chat request.
    ///
    /// Returns `Ok(false)` (never `Err`) when the request fails for any
    /// reason, so callers can treat the result as a simple boolean.
    async fn health_check(&self) -> LLMResult<bool> {
        let request = ChatCompletionRequest::new(&self.config.default_model)
            .system("Say 'ok'")
            .max_tokens(5);

        match self.chat(request).await {
            Ok(_) => Ok(true),
            Err(_) => Ok(false),
        }
    }

    /// Returns static, hard-coded model metadata.
    ///
    /// OpenAI exposes no public model-info API, so known models get
    /// predefined entries and unknown names fall back to defaults.
    async fn get_model_info(&self, model: &str) -> LLMResult<ModelInfo> {
        let info = match model {
            "gpt-4o" => ModelInfo {
                id: "gpt-4o".to_string(),
                name: "GPT-4o".to_string(),
                description: Some("Most capable GPT-4 model with vision".to_string()),
                context_window: Some(128000),
                max_output_tokens: Some(16384),
                training_cutoff: Some("2023-10".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: true,
                    json_mode: true,
                    json_schema: true,
                },
            },
            "gpt-4o-mini" => ModelInfo {
                id: "gpt-4o-mini".to_string(),
                name: "GPT-4o Mini".to_string(),
                description: Some("Smaller, faster GPT-4o".to_string()),
                context_window: Some(128000),
                max_output_tokens: Some(16384),
                training_cutoff: Some("2023-10".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: true,
                    json_mode: true,
                    json_schema: true,
                },
            },
            "gpt-4-turbo" => ModelInfo {
                id: "gpt-4-turbo".to_string(),
                name: "GPT-4 Turbo".to_string(),
                description: Some("GPT-4 Turbo with vision".to_string()),
                context_window: Some(128000),
                max_output_tokens: Some(4096),
                training_cutoff: Some("2023-12".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: true,
                    json_mode: true,
                    json_schema: false,
                },
            },
            "gpt-3.5-turbo" => ModelInfo {
                id: "gpt-3.5-turbo".to_string(),
                name: "GPT-3.5 Turbo".to_string(),
                description: Some("Fast and cost-effective".to_string()),
                context_window: Some(16385),
                max_output_tokens: Some(4096),
                training_cutoff: Some("2021-09".to_string()),
                capabilities: ModelCapabilities {
                    streaming: true,
                    tools: true,
                    vision: false,
                    json_mode: true,
                    json_schema: false,
                },
            },
            // Unknown model: echo the name with default capabilities.
            _ => ModelInfo {
                id: model.to_string(),
                name: model.to_string(),
                description: None,
                context_window: None,
                max_output_tokens: None,
                training_cutoff: None,
                capabilities: ModelCapabilities::default(),
            },
        };

        Ok(info)
    }
}
781
impl OpenAIProvider {
    /// Converts a streaming `async-openai` chunk into our `ChatCompletionChunk`.
    fn convert_chunk(
        chunk: async_openai::types::CreateChatCompletionStreamResponse,
    ) -> ChatCompletionChunk {
        let choices: Vec<ChunkChoice> = chunk
            .choices
            .into_iter()
            .map(|choice| {
                let delta = ChunkDelta {
                    // Deltas only ever carry the assistant role.
                    role: choice.delta.role.map(|_| Role::Assistant),
                    content: choice.delta.content,
                    tool_calls: choice.delta.tool_calls.map(|calls| {
                        calls
                            .into_iter()
                            .map(|tc| ToolCallDelta {
                                index: tc.index,
                                id: tc.id,
                                call_type: Some("function".to_string()),
                                function: tc.function.map(|f| FunctionCallDelta {
                                    name: f.name,
                                    arguments: f.arguments,
                                }),
                            })
                            .collect()
                    }),
                };

                // Legacy FunctionCall finish reason is folded into ToolCalls.
                let finish_reason = choice.finish_reason.map(|r| match r {
                    async_openai::types::FinishReason::Stop => FinishReason::Stop,
                    async_openai::types::FinishReason::Length => FinishReason::Length,
                    async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
                    async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
                    async_openai::types::FinishReason::FunctionCall => FinishReason::ToolCalls,
                });

                ChunkChoice {
                    index: choice.index,
                    delta,
                    finish_reason,
                }
            })
            .collect();

        ChatCompletionChunk {
            id: chunk.id,
            object: "chat.completion.chunk".to_string(),
            created: chunk.created as u64,
            model: chunk.model,
            choices,
            usage: chunk.usage.map(|u| Usage {
                prompt_tokens: u.prompt_tokens,
                completion_tokens: u.completion_tokens,
                total_tokens: u.total_tokens,
            }),
        }
    }

    /// Maps an `async-openai` error onto our `LLMError` taxonomy.
    ///
    /// API errors are classified by substring-matching the error message
    /// (rate limit, quota/billing, model not found, context/tokens,
    /// content filter); anything unmatched becomes a generic `ApiError`.
    fn convert_error(err: async_openai::error::OpenAIError) -> LLMError {
        match err {
            async_openai::error::OpenAIError::ApiError(api_err) => {
                let code = api_err.code.clone();
                let message = api_err.message.clone();

                // Classify by message content; matching is heuristic and
                // case-sensitive, so unmatched variants fall through.
                if message.contains("rate limit") {
                    LLMError::RateLimited(message)
                } else if message.contains("quota") || message.contains("billing") {
                    LLMError::QuotaExceeded(message)
                } else if message.contains("model") && message.contains("not found") {
                    LLMError::ModelNotFound(message)
                } else if message.contains("context") || message.contains("tokens") {
                    LLMError::ContextLengthExceeded(message)
                } else if message.contains("content") && message.contains("filter") {
                    LLMError::ContentFiltered(message)
                } else {
                    LLMError::ApiError { code, message }
                }
            }
            async_openai::error::OpenAIError::Reqwest(e) => {
                if e.is_timeout() {
                    LLMError::Timeout(e.to_string())
                } else {
                    LLMError::NetworkError(e.to_string())
                }
            }
            async_openai::error::OpenAIError::InvalidArgument(msg) => LLMError::ConfigError(msg),
            _ => LLMError::Other(err.to_string()),
        }
    }

    /// Shortcut: creates a standard OpenAI provider from an API key.
    pub fn openai(api_key: impl Into<String>) -> OpenAIProvider {
        OpenAIProvider::new(api_key)
    }

    /// Shortcut: creates a provider for any OpenAI-compatible endpoint
    /// with an explicit base URL, API key, and model.
    pub fn openai_compatible(
        base_url: impl Into<String>,
        api_key: impl Into<String>,
        model: impl Into<String>,
    ) -> OpenAIProvider {
        let config = OpenAIConfig::new(api_key)
            .with_base_url(base_url)
            .with_model(model);
        OpenAIProvider::with_config(config)
    }
}
891
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods must store exactly the values they are given.
    #[test]
    fn test_config_builder() {
        let cfg = OpenAIConfig::new("sk-test")
            .with_max_tokens(2048)
            .with_temperature(0.5)
            .with_model("gpt-4")
            .with_base_url("http://localhost:8080");

        assert_eq!(cfg.api_key, "sk-test");
        assert_eq!(cfg.base_url.as_deref(), Some("http://localhost:8080"));
        assert_eq!(cfg.default_model, "gpt-4");
        assert_eq!(cfg.default_temperature, 0.5);
        assert_eq!(cfg.default_max_tokens, 2048);
    }

    /// The provider must identify itself as the "openai" backend.
    #[test]
    fn test_provider_name() {
        assert_eq!(OpenAIProvider::new("test-key").name(), "openai");
    }
}