aidale_provider/
openai.rs

//! OpenAI provider implementation using the async-openai crate.
//!
//! This provider implements the simplified Provider trait, exposing only
//! `chat_completion()` and `stream_chat_completion()` (plus `info()`). Higher-level
//! abstractions such as `generate_text()` and `generate_object()` are handled by
//! the Runtime layer.
6
7use aidale_core::error::AiError;
8use aidale_core::provider::{ChatCompletionStream, Provider};
9use aidale_core::types::*;
10use async_openai::config::OpenAIConfig;
11use async_openai::types::{
12    ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
13    ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequest,
14    CreateChatCompletionRequestArgs, CreateChatCompletionStreamResponse,
15    ResponseFormat as OpenAIResponseFormat,
16    ResponseFormatJsonSchema as OpenAIResponseFormatJsonSchema,
17};
18use async_openai::Client;
19use async_trait::async_trait;
20use futures::stream::{Stream, StreamExt};
21use std::sync::Arc;
22
23/// OpenAI provider using async-openai
24#[derive(Clone)]
25pub struct OpenAiProvider {
26    client: Client<OpenAIConfig>,
27    info: Arc<ProviderInfo>,
28}
29
30impl std::fmt::Debug for OpenAiProvider {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("OpenAiProvider")
33            .field("info", &self.info)
34            .finish()
35    }
36}
37
38impl OpenAiProvider {
39    /// Create a new OpenAI provider with default configuration
40    pub fn new(api_key: impl Into<String>) -> Self {
41        let config = OpenAIConfig::new().with_api_key(api_key);
42        let client = Client::with_config(config);
43
44        Self {
45            client,
46            info: Arc::new(ProviderInfo {
47                id: "openai".to_string(),
48                name: "OpenAI".to_string(),
49            }),
50        }
51    }
52
53    /// Create a builder for more configuration options
54    pub fn builder() -> OpenAiBuilder {
55        OpenAiBuilder::default()
56    }
57
58    /// Convert our Message type to OpenAI's ChatCompletionRequestMessage
59    fn convert_message(msg: &Message) -> Result<ChatCompletionRequestMessage, AiError> {
60        // Extract text content from message
61        let content = msg
62            .content
63            .iter()
64            .filter_map(|part| match part {
65                ContentPart::Text { text } => Some(text.clone()),
66                _ => None,
67            })
68            .collect::<Vec<_>>()
69            .join("\n");
70
71        match msg.role {
72            Role::System => {
73                let msg = ChatCompletionRequestSystemMessageArgs::default()
74                    .content(content)
75                    .build()
76                    .map_err(|e| {
77                        AiError::provider(format!("Failed to build system message: {}", e))
78                    })?;
79                Ok(ChatCompletionRequestMessage::System(msg))
80            }
81            Role::User => {
82                let msg = ChatCompletionRequestUserMessageArgs::default()
83                    .content(content)
84                    .build()
85                    .map_err(|e| {
86                        AiError::provider(format!("Failed to build user message: {}", e))
87                    })?;
88                Ok(ChatCompletionRequestMessage::User(msg))
89            }
90            Role::Assistant => {
91                let msg = async_openai::types::ChatCompletionRequestAssistantMessageArgs::default()
92                    .content(content)
93                    .build()
94                    .map_err(|e| {
95                        AiError::provider(format!("Failed to build assistant message: {}", e))
96                    })?;
97                Ok(ChatCompletionRequestMessage::Assistant(msg))
98            }
99            Role::Tool => {
100                // For tool messages, we'll use a system message as fallback
101                let msg = ChatCompletionRequestSystemMessageArgs::default()
102                    .content(content)
103                    .build()
104                    .map_err(|e| {
105                        AiError::provider(format!("Failed to build tool message: {}", e))
106                    })?;
107                Ok(ChatCompletionRequestMessage::System(msg))
108            }
109        }
110    }
111
112    /// Convert our ResponseFormat to OpenAI's ResponseFormat
113    fn convert_response_format(format: &ResponseFormat) -> Result<OpenAIResponseFormat, AiError> {
114        match format {
115            ResponseFormat::Text => Ok(OpenAIResponseFormat::Text),
116            ResponseFormat::JsonObject => Ok(OpenAIResponseFormat::JsonObject),
117            ResponseFormat::JsonSchema {
118                name,
119                schema,
120                strict,
121            } => {
122                let json_schema = OpenAIResponseFormatJsonSchema {
123                    name: name.clone(),
124                    schema: Some(schema.clone()),
125                    strict: Some(*strict),
126                    description: None,
127                };
128                Ok(OpenAIResponseFormat::JsonSchema { json_schema })
129            }
130        }
131    }
132
133    /// Build CreateChatCompletionRequest from our ChatCompletionRequest
134    fn build_request(
135        &self,
136        req: &ChatCompletionRequest,
137    ) -> Result<CreateChatCompletionRequest, AiError> {
138        let messages: Result<Vec<_>, _> = req.messages.iter().map(Self::convert_message).collect();
139
140        let mut builder = CreateChatCompletionRequestArgs::default();
141        builder.model(&req.model).messages(messages?);
142
143        if let Some(max_tokens) = req.max_tokens {
144            builder.max_tokens(max_tokens);
145        }
146        if let Some(temperature) = req.temperature {
147            builder.temperature(temperature);
148        }
149        if let Some(top_p) = req.top_p {
150            builder.top_p(top_p);
151        }
152        if let Some(frequency_penalty) = req.frequency_penalty {
153            builder.frequency_penalty(frequency_penalty);
154        }
155        if let Some(presence_penalty) = req.presence_penalty {
156            builder.presence_penalty(presence_penalty);
157        }
158        if let Some(stop) = &req.stop {
159            builder.stop(stop.clone());
160        }
161        if let Some(response_format) = &req.response_format {
162            builder.response_format(Self::convert_response_format(response_format)?);
163        }
164        if let Some(stream) = req.stream {
165            builder.stream(stream);
166        }
167
168        builder
169            .build()
170            .map_err(|e| AiError::provider(format!("Failed to build request: {}", e)))
171    }
172
173    /// Convert OpenAI response to our ChatCompletionResponse
174    fn convert_response(
175        &self,
176        response: async_openai::types::CreateChatCompletionResponse,
177    ) -> Result<ChatCompletionResponse, AiError> {
178        let choices = response
179            .choices
180            .into_iter()
181            .map(|choice| {
182                let message = Message {
183                    role: match choice.message.role {
184                        async_openai::types::Role::System => Role::System,
185                        async_openai::types::Role::User => Role::User,
186                        async_openai::types::Role::Assistant => Role::Assistant,
187                        async_openai::types::Role::Tool => Role::Tool,
188                        _ => Role::Assistant,
189                    },
190                    content: vec![ContentPart::Text {
191                        text: choice.message.content.unwrap_or_default(),
192                    }],
193                    name: None, // OpenAI doesn't return name in responses
194                };
195
196                let finish_reason = choice
197                    .finish_reason
198                    .map_or(FinishReason::Stop, |r| match r {
199                        async_openai::types::FinishReason::Stop => FinishReason::Stop,
200                        async_openai::types::FinishReason::Length => FinishReason::Length,
201                        async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
202                        async_openai::types::FinishReason::ContentFilter => {
203                            FinishReason::ContentFilter
204                        }
205                        _ => FinishReason::Other("unknown".to_string()),
206                    });
207
208                Choice {
209                    index: choice.index,
210                    message,
211                    finish_reason,
212                }
213            })
214            .collect();
215
216        let usage = response.usage.map_or(
217            Usage {
218                prompt_tokens: 0,
219                completion_tokens: 0,
220                total_tokens: 0,
221            },
222            |u| Usage {
223                prompt_tokens: u.prompt_tokens,
224                completion_tokens: u.completion_tokens,
225                total_tokens: u.total_tokens,
226            },
227        );
228
229        Ok(ChatCompletionResponse {
230            id: response.id,
231            model: response.model,
232            choices,
233            usage,
234            created: Some(response.created as u64),
235        })
236    }
237
238    /// Convert OpenAI stream chunk to our ChatCompletionChunk
239    fn convert_stream_chunk(
240        response: CreateChatCompletionStreamResponse,
241    ) -> Result<ChatCompletionChunk, AiError> {
242        let choices = response
243            .choices
244            .into_iter()
245            .map(|choice| {
246                let delta = MessageDelta {
247                    role: choice.delta.role.as_ref().map(|r| match r {
248                        async_openai::types::Role::System => Role::System,
249                        async_openai::types::Role::User => Role::User,
250                        async_openai::types::Role::Assistant => Role::Assistant,
251                        async_openai::types::Role::Tool => Role::Tool,
252                        _ => Role::Assistant,
253                    }),
254                    content: choice.delta.content,
255                    tool_calls: None,
256                };
257
258                let finish_reason = choice.finish_reason.map(|r| match r {
259                    async_openai::types::FinishReason::Stop => FinishReason::Stop,
260                    async_openai::types::FinishReason::Length => FinishReason::Length,
261                    async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
262                    async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
263                    _ => FinishReason::Other("unknown".to_string()),
264                });
265
266                ChoiceDelta {
267                    index: choice.index,
268                    delta,
269                    finish_reason,
270                }
271            })
272            .collect();
273
274        Ok(ChatCompletionChunk {
275            id: response.id,
276            model: response.model,
277            choices,
278            usage: None,
279        })
280    }
281}
282
283#[async_trait]
284impl Provider for OpenAiProvider {
285    fn info(&self) -> Arc<ProviderInfo> {
286        self.info.clone()
287    }
288
289    async fn chat_completion(
290        &self,
291        req: ChatCompletionRequest,
292    ) -> Result<ChatCompletionResponse, AiError> {
293        let openai_req = self.build_request(&req)?;
294
295        let response = self
296            .client
297            .chat()
298            .create(openai_req)
299            .await
300            .map_err(|e| AiError::provider(format!("OpenAI API error: {}", e)))?;
301
302        self.convert_response(response)
303    }
304
305    async fn stream_chat_completion(
306        &self,
307        req: ChatCompletionRequest,
308    ) -> Result<Box<ChatCompletionStream>, AiError> {
309        let mut openai_req = self.build_request(&req)?;
310        openai_req.stream = Some(true);
311
312        let stream = self
313            .client
314            .chat()
315            .create_stream(openai_req)
316            .await
317            .map_err(|e| AiError::provider(format!("OpenAI API error: {}", e)))?;
318
319        // Convert OpenAI stream to our ChatCompletionStream
320        let chat_stream = stream.map(|result| match result {
321            Ok(response) => Self::convert_stream_chunk(response),
322            Err(e) => Err(AiError::provider(format!("Stream error: {}", e))),
323        });
324
325        Ok(Box::new(chat_stream)
326            as Box<
327                dyn Stream<Item = Result<ChatCompletionChunk, AiError>> + Send + Unpin,
328            >)
329    }
330}
331
/// Builder for an OpenAI provider with custom configuration.
///
/// Every setting starts unset. An API key must be supplied before building.
/// `Debug` output never contains the API key itself.
#[derive(Clone, Default)]
pub struct OpenAiBuilder {
    /// API key used for every request. Building fails while this is `None`.
    api_key: Option<String>,
    /// Optional API base URL override (e.g. for OpenAI-compatible services).
    api_base: Option<String>,
    /// Optional organization ID.
    org_id: Option<String>,
}

// Hand-written so the secret API key is never printed. Only its presence is shown.
impl std::fmt::Debug for OpenAiBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiBuilder")
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("api_base", &self.api_base)
            .field("org_id", &self.org_id)
            .finish()
    }
}
339
340impl OpenAiBuilder {
341    /// Set API key
342    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
343        self.api_key = Some(api_key.into());
344        self
345    }
346
347    /// Set API base URL (for OpenAI-compatible APIs like DeepSeek)
348    pub fn api_base(mut self, api_base: impl Into<String>) -> Self {
349        self.api_base = Some(api_base.into());
350        self
351    }
352
353    /// Set organization ID
354    pub fn organization(mut self, org_id: impl Into<String>) -> Self {
355        self.org_id = Some(org_id.into());
356        self
357    }
358
359    /// Build the provider
360    pub fn build(self) -> Result<OpenAiProvider, AiError> {
361        let api_key = self
362            .api_key
363            .ok_or_else(|| AiError::configuration("API key is required"))?;
364
365        let mut config = OpenAIConfig::new().with_api_key(api_key);
366
367        if let Some(api_base) = self.api_base {
368            config = config.with_api_base(api_base);
369        }
370
371        if let Some(org_id) = self.org_id {
372            config = config.with_org_id(org_id);
373        }
374
375        let client = Client::with_config(config);
376
377        Ok(OpenAiProvider {
378            client,
379            info: Arc::new(ProviderInfo {
380                id: "openai".to_string(),
381                name: "OpenAI".to_string(),
382            }),
383        })
384    }
385
386    /// Build a provider with a custom provider ID and name
387    ///
388    /// This is useful for OpenAI-compatible APIs like DeepSeek that use
389    /// the same protocol but different endpoints.
390    pub fn build_with_id(
391        self,
392        provider_id: impl Into<String>,
393        provider_name: impl Into<String>,
394    ) -> Result<OpenAiProvider, AiError> {
395        let api_key = self
396            .api_key
397            .ok_or_else(|| AiError::configuration("API key is required"))?;
398
399        let mut config = OpenAIConfig::new().with_api_key(api_key);
400
401        if let Some(api_base) = self.api_base {
402            config = config.with_api_base(api_base);
403        }
404
405        if let Some(org_id) = self.org_id {
406            config = config.with_org_id(org_id);
407        }
408
409        let client = Client::with_config(config);
410
411        Ok(OpenAiProvider {
412            client,
413            info: Arc::new(ProviderInfo {
414                id: provider_id.into(),
415                name: provider_name.into(),
416            }),
417        })
418    }
419}