1use aidale_core::error::AiError;
8use aidale_core::provider::{ChatCompletionStream, Provider};
9use aidale_core::types::*;
10use async_openai::config::OpenAIConfig;
11use async_openai::types::{
12 ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
13 ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequest,
14 CreateChatCompletionRequestArgs, CreateChatCompletionStreamResponse,
15 ResponseFormat as OpenAIResponseFormat,
16 ResponseFormatJsonSchema as OpenAIResponseFormatJsonSchema,
17};
18use async_openai::Client;
19use async_trait::async_trait;
20use futures::stream::{Stream, StreamExt};
21use std::sync::Arc;
22
/// OpenAI-backed implementation of the [`Provider`] trait.
///
/// Holds the configured HTTP client plus the provider identity; the identity
/// is shared behind an `Arc` so clones of the provider share one allocation.
#[derive(Clone)]
pub struct OpenAiProvider {
    // Configured async-openai client (carries the API key / base URL / org).
    client: Client<OpenAIConfig>,
    // Provider identity ("openai"/"OpenAI" by default; customizable via the
    // builder's `build_with_id`).
    info: Arc<ProviderInfo>,
}
29
/// Manual `Debug` impl that prints only the provider `info`.
///
/// NOTE(review): the `client` field is deliberately omitted — presumably
/// because it holds the API-key-bearing configuration and/or does not
/// implement `Debug`; confirm before adding it to the output.
impl std::fmt::Debug for OpenAiProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiProvider")
            .field("info", &self.info)
            .finish()
    }
}
37
38impl OpenAiProvider {
39 pub fn new(api_key: impl Into<String>) -> Self {
41 let config = OpenAIConfig::new().with_api_key(api_key);
42 let client = Client::with_config(config);
43
44 Self {
45 client,
46 info: Arc::new(ProviderInfo {
47 id: "openai".to_string(),
48 name: "OpenAI".to_string(),
49 }),
50 }
51 }
52
53 pub fn builder() -> OpenAiBuilder {
55 OpenAiBuilder::default()
56 }
57
58 fn convert_message(msg: &Message) -> Result<ChatCompletionRequestMessage, AiError> {
60 let content = msg
62 .content
63 .iter()
64 .filter_map(|part| match part {
65 ContentPart::Text { text } => Some(text.clone()),
66 _ => None,
67 })
68 .collect::<Vec<_>>()
69 .join("\n");
70
71 match msg.role {
72 Role::System => {
73 let msg = ChatCompletionRequestSystemMessageArgs::default()
74 .content(content)
75 .build()
76 .map_err(|e| {
77 AiError::provider(format!("Failed to build system message: {}", e))
78 })?;
79 Ok(ChatCompletionRequestMessage::System(msg))
80 }
81 Role::User => {
82 let msg = ChatCompletionRequestUserMessageArgs::default()
83 .content(content)
84 .build()
85 .map_err(|e| {
86 AiError::provider(format!("Failed to build user message: {}", e))
87 })?;
88 Ok(ChatCompletionRequestMessage::User(msg))
89 }
90 Role::Assistant => {
91 let msg = async_openai::types::ChatCompletionRequestAssistantMessageArgs::default()
92 .content(content)
93 .build()
94 .map_err(|e| {
95 AiError::provider(format!("Failed to build assistant message: {}", e))
96 })?;
97 Ok(ChatCompletionRequestMessage::Assistant(msg))
98 }
99 Role::Tool => {
100 let msg = ChatCompletionRequestSystemMessageArgs::default()
102 .content(content)
103 .build()
104 .map_err(|e| {
105 AiError::provider(format!("Failed to build tool message: {}", e))
106 })?;
107 Ok(ChatCompletionRequestMessage::System(msg))
108 }
109 }
110 }
111
112 fn convert_response_format(format: &ResponseFormat) -> Result<OpenAIResponseFormat, AiError> {
114 match format {
115 ResponseFormat::Text => Ok(OpenAIResponseFormat::Text),
116 ResponseFormat::JsonObject => Ok(OpenAIResponseFormat::JsonObject),
117 ResponseFormat::JsonSchema {
118 name,
119 schema,
120 strict,
121 } => {
122 let json_schema = OpenAIResponseFormatJsonSchema {
123 name: name.clone(),
124 schema: Some(schema.clone()),
125 strict: Some(*strict),
126 description: None,
127 };
128 Ok(OpenAIResponseFormat::JsonSchema { json_schema })
129 }
130 }
131 }
132
133 fn build_request(
135 &self,
136 req: &ChatCompletionRequest,
137 ) -> Result<CreateChatCompletionRequest, AiError> {
138 let messages: Result<Vec<_>, _> = req.messages.iter().map(Self::convert_message).collect();
139
140 let mut builder = CreateChatCompletionRequestArgs::default();
141 builder.model(&req.model).messages(messages?);
142
143 if let Some(max_tokens) = req.max_tokens {
144 builder.max_tokens(max_tokens);
145 }
146 if let Some(temperature) = req.temperature {
147 builder.temperature(temperature);
148 }
149 if let Some(top_p) = req.top_p {
150 builder.top_p(top_p);
151 }
152 if let Some(frequency_penalty) = req.frequency_penalty {
153 builder.frequency_penalty(frequency_penalty);
154 }
155 if let Some(presence_penalty) = req.presence_penalty {
156 builder.presence_penalty(presence_penalty);
157 }
158 if let Some(stop) = &req.stop {
159 builder.stop(stop.clone());
160 }
161 if let Some(response_format) = &req.response_format {
162 builder.response_format(Self::convert_response_format(response_format)?);
163 }
164 if let Some(stream) = req.stream {
165 builder.stream(stream);
166 }
167
168 builder
169 .build()
170 .map_err(|e| AiError::provider(format!("Failed to build request: {}", e)))
171 }
172
173 fn convert_response(
175 &self,
176 response: async_openai::types::CreateChatCompletionResponse,
177 ) -> Result<ChatCompletionResponse, AiError> {
178 let choices = response
179 .choices
180 .into_iter()
181 .map(|choice| {
182 let message = Message {
183 role: match choice.message.role {
184 async_openai::types::Role::System => Role::System,
185 async_openai::types::Role::User => Role::User,
186 async_openai::types::Role::Assistant => Role::Assistant,
187 async_openai::types::Role::Tool => Role::Tool,
188 _ => Role::Assistant,
189 },
190 content: vec![ContentPart::Text {
191 text: choice.message.content.unwrap_or_default(),
192 }],
193 name: None, };
195
196 let finish_reason = choice
197 .finish_reason
198 .map_or(FinishReason::Stop, |r| match r {
199 async_openai::types::FinishReason::Stop => FinishReason::Stop,
200 async_openai::types::FinishReason::Length => FinishReason::Length,
201 async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
202 async_openai::types::FinishReason::ContentFilter => {
203 FinishReason::ContentFilter
204 }
205 _ => FinishReason::Other("unknown".to_string()),
206 });
207
208 Choice {
209 index: choice.index,
210 message,
211 finish_reason,
212 }
213 })
214 .collect();
215
216 let usage = response.usage.map_or(
217 Usage {
218 prompt_tokens: 0,
219 completion_tokens: 0,
220 total_tokens: 0,
221 },
222 |u| Usage {
223 prompt_tokens: u.prompt_tokens,
224 completion_tokens: u.completion_tokens,
225 total_tokens: u.total_tokens,
226 },
227 );
228
229 Ok(ChatCompletionResponse {
230 id: response.id,
231 model: response.model,
232 choices,
233 usage,
234 created: Some(response.created as u64),
235 })
236 }
237
238 fn convert_stream_chunk(
240 response: CreateChatCompletionStreamResponse,
241 ) -> Result<ChatCompletionChunk, AiError> {
242 let choices = response
243 .choices
244 .into_iter()
245 .map(|choice| {
246 let delta = MessageDelta {
247 role: choice.delta.role.as_ref().map(|r| match r {
248 async_openai::types::Role::System => Role::System,
249 async_openai::types::Role::User => Role::User,
250 async_openai::types::Role::Assistant => Role::Assistant,
251 async_openai::types::Role::Tool => Role::Tool,
252 _ => Role::Assistant,
253 }),
254 content: choice.delta.content,
255 tool_calls: None,
256 };
257
258 let finish_reason = choice.finish_reason.map(|r| match r {
259 async_openai::types::FinishReason::Stop => FinishReason::Stop,
260 async_openai::types::FinishReason::Length => FinishReason::Length,
261 async_openai::types::FinishReason::ToolCalls => FinishReason::ToolCalls,
262 async_openai::types::FinishReason::ContentFilter => FinishReason::ContentFilter,
263 _ => FinishReason::Other("unknown".to_string()),
264 });
265
266 ChoiceDelta {
267 index: choice.index,
268 delta,
269 finish_reason,
270 }
271 })
272 .collect();
273
274 Ok(ChatCompletionChunk {
275 id: response.id,
276 model: response.model,
277 choices,
278 usage: None,
279 })
280 }
281}
282
283#[async_trait]
284impl Provider for OpenAiProvider {
285 fn info(&self) -> Arc<ProviderInfo> {
286 self.info.clone()
287 }
288
289 async fn chat_completion(
290 &self,
291 req: ChatCompletionRequest,
292 ) -> Result<ChatCompletionResponse, AiError> {
293 let openai_req = self.build_request(&req)?;
294
295 let response = self
296 .client
297 .chat()
298 .create(openai_req)
299 .await
300 .map_err(|e| AiError::provider(format!("OpenAI API error: {}", e)))?;
301
302 self.convert_response(response)
303 }
304
305 async fn stream_chat_completion(
306 &self,
307 req: ChatCompletionRequest,
308 ) -> Result<Box<ChatCompletionStream>, AiError> {
309 let mut openai_req = self.build_request(&req)?;
310 openai_req.stream = Some(true);
311
312 let stream = self
313 .client
314 .chat()
315 .create_stream(openai_req)
316 .await
317 .map_err(|e| AiError::provider(format!("OpenAI API error: {}", e)))?;
318
319 let chat_stream = stream.map(|result| match result {
321 Ok(response) => Self::convert_stream_chunk(response),
322 Err(e) => Err(AiError::provider(format!("Stream error: {}", e))),
323 });
324
325 Ok(Box::new(chat_stream)
326 as Box<
327 dyn Stream<Item = Result<ChatCompletionChunk, AiError>> + Send + Unpin,
328 >)
329 }
330}
331
/// Builder for [`OpenAiProvider`], allowing an alternate API base URL and an
/// organization id in addition to the required API key.
#[derive(Default)]
pub struct OpenAiBuilder {
    // Required; `build`/`build_with_id` fail with a configuration error when absent.
    api_key: Option<String>,
    // Optional endpoint override (e.g. a proxy or OpenAI-compatible server).
    api_base: Option<String>,
    // Optional OpenAI organization id.
    org_id: Option<String>,
}
339
340impl OpenAiBuilder {
341 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
343 self.api_key = Some(api_key.into());
344 self
345 }
346
347 pub fn api_base(mut self, api_base: impl Into<String>) -> Self {
349 self.api_base = Some(api_base.into());
350 self
351 }
352
353 pub fn organization(mut self, org_id: impl Into<String>) -> Self {
355 self.org_id = Some(org_id.into());
356 self
357 }
358
359 pub fn build(self) -> Result<OpenAiProvider, AiError> {
361 let api_key = self
362 .api_key
363 .ok_or_else(|| AiError::configuration("API key is required"))?;
364
365 let mut config = OpenAIConfig::new().with_api_key(api_key);
366
367 if let Some(api_base) = self.api_base {
368 config = config.with_api_base(api_base);
369 }
370
371 if let Some(org_id) = self.org_id {
372 config = config.with_org_id(org_id);
373 }
374
375 let client = Client::with_config(config);
376
377 Ok(OpenAiProvider {
378 client,
379 info: Arc::new(ProviderInfo {
380 id: "openai".to_string(),
381 name: "OpenAI".to_string(),
382 }),
383 })
384 }
385
386 pub fn build_with_id(
391 self,
392 provider_id: impl Into<String>,
393 provider_name: impl Into<String>,
394 ) -> Result<OpenAiProvider, AiError> {
395 let api_key = self
396 .api_key
397 .ok_or_else(|| AiError::configuration("API key is required"))?;
398
399 let mut config = OpenAIConfig::new().with_api_key(api_key);
400
401 if let Some(api_base) = self.api_base {
402 config = config.with_api_base(api_base);
403 }
404
405 if let Some(org_id) = self.org_id {
406 config = config.with_org_id(org_id);
407 }
408
409 let client = Client::with_config(config);
410
411 Ok(OpenAiProvider {
412 client,
413 info: Arc::new(ProviderInfo {
414 id: provider_id.into(),
415 name: provider_name.into(),
416 }),
417 })
418 }
419}