// ares/llm/openai.rs
//! OpenAI LLM client implementation
//!
//! This module provides integration with the OpenAI API and compatible endpoints.
//!
//! # Features
//!
//! Enable with the `openai` feature flag.
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::llm::{LLMClient, Provider};
//!
//! let provider = Provider::OpenAI {
//!     api_key: "sk-...".to_string(),
//!     api_base: "https://api.openai.com/v1".to_string(),
//!     model: "gpt-4".to_string(),
//! };
//! let client = provider.create_client().await?;
//! let response = client.generate("Hello!").await?;
//! ```

23use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
24use crate::llm::coordinator::{ConversationMessage, MessageRole};
25use crate::types::{AppError, Result, ToolCall, ToolDefinition};
26use async_openai::{
27    config::OpenAIConfig,
28    types::chat::{
29        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
30        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
31        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
32        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
33        CreateChatCompletionRequestArgs, FunctionCall, FunctionObject,
34    },
35    Client,
36};
37use async_trait::async_trait;
38use futures::StreamExt;
39
/// OpenAI client for API-based inference
///
/// Wraps an `async_openai` HTTP client configured with an API key and base
/// URL, together with the model identifier and inference parameters that are
/// applied to every request this client sends.
pub struct OpenAIClient {
    // Underlying async-openai client (holds API key / base URL config).
    client: Client<OpenAIConfig>,
    // Model identifier sent with every request (e.g. "gpt-4").
    model: String,
    // Optional inference parameters (temperature, max_tokens, top_p, ...).
    params: ModelParams,
}
46
47impl OpenAIClient {
48    /// Create a new OpenAI client
49    ///
50    /// # Arguments
51    ///
52    /// * `api_key` - OpenAI API key
53    /// * `api_base` - Base URL for the API (e.g., `https://api.openai.com/v1`)
54    /// * `model` - Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
55    pub fn new(api_key: String, api_base: String, model: String) -> Self {
56        Self::with_params(api_key, api_base, model, ModelParams::default())
57    }
58
59    /// Create a new OpenAI client with model parameters
60    ///
61    /// # Arguments
62    ///
63    /// * `api_key` - OpenAI API key
64    /// * `api_base` - Base URL for the API (e.g., `https://api.openai.com/v1`)
65    /// * `model` - Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
66    /// * `params` - Model inference parameters (temperature, max_tokens, etc.)
67    pub fn with_params(
68        api_key: String,
69        api_base: String,
70        model: String,
71        params: ModelParams,
72    ) -> Self {
73        let config = OpenAIConfig::new()
74            .with_api_key(api_key)
75            .with_api_base(api_base);
76
77        Self {
78            client: Client::with_config(config),
79            model,
80            params,
81        }
82    }
83
84    /// Convert ToolDefinition to ChatCompletionTool
85    fn convert_tool(tool: &ToolDefinition) -> ChatCompletionTools {
86        ChatCompletionTools::Function(ChatCompletionTool {
87            function: FunctionObject {
88                name: tool.name.clone(),
89                description: Some(tool.description.clone()),
90                parameters: Some(tool.parameters.clone()),
91                strict: None,
92            },
93        })
94    }
95
96    /// Extract tool calls from the response message tool calls
97    fn extract_tool_calls(tool_calls: &[ChatCompletionMessageToolCalls]) -> Vec<ToolCall> {
98        tool_calls
99            .iter()
100            .filter_map(|wrapper| match wrapper {
101                ChatCompletionMessageToolCalls::Function(call) => Some(ToolCall {
102                    id: call.id.clone(),
103                    name: call.function.name.clone(),
104                    arguments: serde_json::from_str(&call.function.arguments)
105                        .unwrap_or(serde_json::json!({})),
106                }),
107                ChatCompletionMessageToolCalls::Custom(_) => None,
108            })
109            .collect()
110    }
111
112    /// Convert a ConversationMessage to OpenAI's ChatCompletionRequestMessage
113    fn convert_conversation_message(
114        &self,
115        msg: &ConversationMessage,
116    ) -> Result<ChatCompletionRequestMessage> {
117        match msg.role {
118            MessageRole::System => {
119                let system_msg = ChatCompletionRequestSystemMessageArgs::default()
120                    .content(msg.content.clone())
121                    .build()
122                    .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
123                Ok(ChatCompletionRequestMessage::System(system_msg))
124            }
125            MessageRole::User => {
126                let user_msg = ChatCompletionRequestUserMessageArgs::default()
127                    .content(msg.content.clone())
128                    .build()
129                    .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
130                Ok(ChatCompletionRequestMessage::User(user_msg))
131            }
132            MessageRole::Assistant => {
133                let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
134
135                if !msg.content.is_empty() {
136                    builder.content(msg.content.clone());
137                }
138
139                // Convert tool calls if present
140                if !msg.tool_calls.is_empty() {
141                    let openai_tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
142                        .tool_calls
143                        .iter()
144                        .map(|tc| {
145                            ChatCompletionMessageToolCalls::Function(
146                                ChatCompletionMessageToolCall {
147                                    id: tc.id.clone(),
148                                    function: FunctionCall {
149                                        name: tc.name.clone(),
150                                        arguments: serde_json::to_string(&tc.arguments)
151                                            .unwrap_or_else(|_| "{}".to_string()),
152                                    },
153                                },
154                            )
155                        })
156                        .collect();
157                    builder.tool_calls(openai_tool_calls);
158                }
159
160                let assistant_msg = builder.build().map_err(|e| {
161                    AppError::LLM(format!("Failed to build assistant message: {}", e))
162                })?;
163                Ok(ChatCompletionRequestMessage::Assistant(assistant_msg))
164            }
165            MessageRole::Tool => {
166                let tool_call_id = msg.tool_call_id.clone().ok_or_else(|| {
167                    AppError::LLM("Tool message must have a tool_call_id".to_string())
168                })?;
169
170                let tool_msg = ChatCompletionRequestToolMessageArgs::default()
171                    .tool_call_id(tool_call_id)
172                    .content(msg.content.clone())
173                    .build()
174                    .map_err(|e| AppError::LLM(format!("Failed to build tool message: {}", e)))?;
175                Ok(ChatCompletionRequestMessage::Tool(tool_msg))
176            }
177        }
178    }
179}
180
181#[async_trait]
182impl LLMClient for OpenAIClient {
183    async fn generate(&self, prompt: &str) -> Result<String> {
184        let message = ChatCompletionRequestUserMessageArgs::default()
185            .content(prompt)
186            .build()
187            .map_err(|e| AppError::LLM(format!("Failed to build message: {}", e)))?;
188
189        let mut builder = CreateChatCompletionRequestArgs::default();
190        builder.model(&self.model);
191        builder.messages(vec![ChatCompletionRequestMessage::User(message)]);
192
193        // Apply model parameters
194        if let Some(temp) = self.params.temperature {
195            builder.temperature(temp);
196        }
197        if let Some(max_tokens) = self.params.max_tokens {
198            builder.max_completion_tokens(max_tokens);
199        }
200        if let Some(top_p) = self.params.top_p {
201            builder.top_p(top_p);
202        }
203        if let Some(freq_penalty) = self.params.frequency_penalty {
204            builder.frequency_penalty(freq_penalty);
205        }
206        if let Some(pres_penalty) = self.params.presence_penalty {
207            builder.presence_penalty(pres_penalty);
208        }
209
210        let request = builder
211            .build()
212            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
213
214        let response = self
215            .client
216            .chat()
217            .create(request)
218            .await
219            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
220
221        response
222            .choices
223            .first()
224            .and_then(|choice| choice.message.content.clone())
225            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
226    }
227
228    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
229        let system_message = ChatCompletionRequestSystemMessageArgs::default()
230            .content(system)
231            .build()
232            .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
233
234        let user_message = ChatCompletionRequestUserMessageArgs::default()
235            .content(prompt)
236            .build()
237            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
238
239        let mut builder = CreateChatCompletionRequestArgs::default();
240        builder.model(&self.model);
241        builder.messages(vec![
242            ChatCompletionRequestMessage::System(system_message),
243            ChatCompletionRequestMessage::User(user_message),
244        ]);
245
246        // Apply model parameters
247        if let Some(temp) = self.params.temperature {
248            builder.temperature(temp);
249        }
250        if let Some(max_tokens) = self.params.max_tokens {
251            builder.max_completion_tokens(max_tokens);
252        }
253        if let Some(top_p) = self.params.top_p {
254            builder.top_p(top_p);
255        }
256        if let Some(freq_penalty) = self.params.frequency_penalty {
257            builder.frequency_penalty(freq_penalty);
258        }
259        if let Some(pres_penalty) = self.params.presence_penalty {
260            builder.presence_penalty(pres_penalty);
261        }
262
263        let request = builder
264            .build()
265            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
266
267        let response = self
268            .client
269            .chat()
270            .create(request)
271            .await
272            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
273
274        response
275            .choices
276            .first()
277            .and_then(|choice| choice.message.content.clone())
278            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
279    }
280
281    async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<LLMResponse> {
282        let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
283            messages
284                .iter()
285                .map(|(role, content)| {
286                    match role.as_str() {
287                        "system" => {
288                            let msg = ChatCompletionRequestSystemMessageArgs::default()
289                                .content(content.as_str())
290                                .build()
291                                .map_err(|e| {
292                                    AppError::LLM(format!("Failed to build system message: {}", e))
293                                })?;
294                            Ok(ChatCompletionRequestMessage::System(msg))
295                        }
296                        "assistant" => {
297                            let msg = ChatCompletionRequestAssistantMessageArgs::default()
298                                .content(content.as_str())
299                                .build()
300                                .map_err(|e| {
301                                    AppError::LLM(format!(
302                                        "Failed to build assistant message: {}",
303                                        e
304                                    ))
305                                })?;
306                            Ok(ChatCompletionRequestMessage::Assistant(msg))
307                        }
308                        _ => {
309                            // Default to user message
310                            let msg = ChatCompletionRequestUserMessageArgs::default()
311                                .content(content.as_str())
312                                .build()
313                                .map_err(|e| {
314                                    AppError::LLM(format!("Failed to build user message: {}", e))
315                                })?;
316                            Ok(ChatCompletionRequestMessage::User(msg))
317                        }
318                    }
319                })
320                .collect();
321
322        let mut builder = CreateChatCompletionRequestArgs::default();
323        builder.model(&self.model);
324        builder.messages(chat_messages?);
325
326        // Apply model parameters
327        if let Some(temp) = self.params.temperature {
328            builder.temperature(temp);
329        }
330        if let Some(max_tokens) = self.params.max_tokens {
331            builder.max_completion_tokens(max_tokens);
332        }
333        if let Some(top_p) = self.params.top_p {
334            builder.top_p(top_p);
335        }
336        if let Some(freq_penalty) = self.params.frequency_penalty {
337            builder.frequency_penalty(freq_penalty);
338        }
339        if let Some(pres_penalty) = self.params.presence_penalty {
340            builder.presence_penalty(pres_penalty);
341        }
342
343        let request = builder
344            .build()
345            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
346
347        let response = self
348            .client
349            .chat()
350            .create(request)
351            .await
352            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
353
354        let content = response
355            .choices
356            .first()
357            .and_then(|choice| choice.message.content.clone())
358            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
359
360        #[allow(clippy::unnecessary_cast)]
361        let usage = response
362            .usage
363            .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
364
365        Ok(LLMResponse {
366            content,
367            tool_calls: vec![],
368            finish_reason: "stop".to_string(),
369            usage,
370        })
371    }
372
373    async fn generate_with_tools(
374        &self,
375        prompt: &str,
376        tools: &[ToolDefinition],
377    ) -> Result<LLMResponse> {
378        let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
379
380        let user_message = ChatCompletionRequestUserMessageArgs::default()
381            .content(prompt)
382            .build()
383            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
384
385        let mut builder = CreateChatCompletionRequestArgs::default();
386        builder.model(&self.model);
387        builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
388        builder.tools(openai_tools);
389
390        // Apply model parameters
391        if let Some(temp) = self.params.temperature {
392            builder.temperature(temp);
393        }
394        if let Some(max_tokens) = self.params.max_tokens {
395            builder.max_completion_tokens(max_tokens);
396        }
397        if let Some(top_p) = self.params.top_p {
398            builder.top_p(top_p);
399        }
400        if let Some(freq_penalty) = self.params.frequency_penalty {
401            builder.frequency_penalty(freq_penalty);
402        }
403        if let Some(pres_penalty) = self.params.presence_penalty {
404            builder.presence_penalty(pres_penalty);
405        }
406
407        let request = builder
408            .build()
409            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
410
411        let response = self
412            .client
413            .chat()
414            .create(request)
415            .await
416            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
417
418        let choice = response
419            .choices
420            .first()
421            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
422
423        let content = choice.message.content.clone().unwrap_or_default();
424
425        let finish_reason = choice
426            .finish_reason
427            .as_ref()
428            .map(|r| format!("{:?}", r).to_lowercase())
429            .unwrap_or_else(|| "stop".to_string());
430
431        let tool_calls = choice
432            .message
433            .tool_calls
434            .as_ref()
435            .map(|calls| Self::extract_tool_calls(calls))
436            .unwrap_or_default();
437
438        // Extract token usage if available
439        #[allow(clippy::unnecessary_cast)]
440        let usage = response
441            .usage
442            .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
443
444        Ok(LLMResponse {
445            content,
446            tool_calls,
447            finish_reason,
448            usage,
449        })
450    }
451
452    async fn generate_with_tools_and_history(
453        &self,
454        messages: &[ConversationMessage],
455        tools: &[ToolDefinition],
456    ) -> Result<LLMResponse> {
457        // Convert ConversationMessage to OpenAI format
458        let openai_messages: Vec<ChatCompletionRequestMessage> = messages
459            .iter()
460            .map(|msg| self.convert_conversation_message(msg))
461            .collect::<Result<Vec<_>>>()?;
462
463        // Convert tools to OpenAI format
464        let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
465
466        let mut builder = CreateChatCompletionRequestArgs::default();
467        builder.model(&self.model);
468        builder.messages(openai_messages);
469
470        if !openai_tools.is_empty() {
471            builder.tools(openai_tools);
472        }
473
474        // Apply model parameters
475        if let Some(temp) = self.params.temperature {
476            builder.temperature(temp);
477        }
478        if let Some(max_tokens) = self.params.max_tokens {
479            builder.max_completion_tokens(max_tokens);
480        }
481        if let Some(top_p) = self.params.top_p {
482            builder.top_p(top_p);
483        }
484        if let Some(freq_penalty) = self.params.frequency_penalty {
485            builder.frequency_penalty(freq_penalty);
486        }
487        if let Some(pres_penalty) = self.params.presence_penalty {
488            builder.presence_penalty(pres_penalty);
489        }
490
491        let request = builder
492            .build()
493            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
494
495        let response = self
496            .client
497            .chat()
498            .create(request)
499            .await
500            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
501
502        let choice = response
503            .choices
504            .first()
505            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
506
507        let content = choice.message.content.clone().unwrap_or_default();
508
509        let finish_reason = choice
510            .finish_reason
511            .as_ref()
512            .map(|r| format!("{:?}", r).to_lowercase())
513            .unwrap_or_else(|| "stop".to_string());
514
515        let tool_calls = choice
516            .message
517            .tool_calls
518            .as_ref()
519            .map(|calls| Self::extract_tool_calls(calls))
520            .unwrap_or_default();
521
522        #[allow(clippy::unnecessary_cast)]
523        let usage = response
524            .usage
525            .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
526
527        Ok(LLMResponse {
528            content,
529            tool_calls,
530            finish_reason,
531            usage,
532        })
533    }
534
535    async fn stream(
536        &self,
537        prompt: &str,
538    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
539        let user_message = ChatCompletionRequestUserMessageArgs::default()
540            .content(prompt)
541            .build()
542            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
543
544        let mut builder = CreateChatCompletionRequestArgs::default();
545        builder.model(&self.model);
546        builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
547
548        // Apply model parameters
549        if let Some(temp) = self.params.temperature {
550            builder.temperature(temp);
551        }
552        if let Some(max_tokens) = self.params.max_tokens {
553            builder.max_completion_tokens(max_tokens);
554        }
555        if let Some(top_p) = self.params.top_p {
556            builder.top_p(top_p);
557        }
558        if let Some(freq_penalty) = self.params.frequency_penalty {
559            builder.frequency_penalty(freq_penalty);
560        }
561        if let Some(pres_penalty) = self.params.presence_penalty {
562            builder.presence_penalty(pres_penalty);
563        }
564
565        let request = builder
566            .build()
567            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
568
569        let mut stream = self
570            .client
571            .chat()
572            .create_stream(request)
573            .await
574            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
575
576        let result_stream = async_stream::stream! {
577            while let Some(result) = stream.next().await {
578                match result {
579                    Ok(response) => {
580                        for choice in response.choices {
581                            if let Some(content) = choice.delta.content {
582                                yield Ok(content);
583                            }
584                        }
585                    }
586                    Err(e) => {
587                        yield Err(AppError::LLM(format!("Stream error: {}", e)));
588                    }
589                }
590            }
591        };
592
593        Ok(Box::new(Box::pin(result_stream)))
594    }
595
596    async fn stream_with_system(
597        &self,
598        system: &str,
599        prompt: &str,
600    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
601        let system_message = ChatCompletionRequestSystemMessageArgs::default()
602            .content(system)
603            .build()
604            .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
605
606        let user_message = ChatCompletionRequestUserMessageArgs::default()
607            .content(prompt)
608            .build()
609            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
610
611        let mut builder = CreateChatCompletionRequestArgs::default();
612        builder.model(&self.model);
613        builder.messages(vec![
614            ChatCompletionRequestMessage::System(system_message),
615            ChatCompletionRequestMessage::User(user_message),
616        ]);
617
618        // Apply model parameters
619        if let Some(temp) = self.params.temperature {
620            builder.temperature(temp);
621        }
622        if let Some(max_tokens) = self.params.max_tokens {
623            builder.max_completion_tokens(max_tokens);
624        }
625        if let Some(top_p) = self.params.top_p {
626            builder.top_p(top_p);
627        }
628        if let Some(freq_penalty) = self.params.frequency_penalty {
629            builder.frequency_penalty(freq_penalty);
630        }
631        if let Some(pres_penalty) = self.params.presence_penalty {
632            builder.presence_penalty(pres_penalty);
633        }
634
635        let request = builder
636            .build()
637            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
638
639        let mut stream = self
640            .client
641            .chat()
642            .create_stream(request)
643            .await
644            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
645
646        let result_stream = async_stream::stream! {
647            while let Some(result) = stream.next().await {
648                match result {
649                    Ok(response) => {
650                        for choice in response.choices {
651                            if let Some(content) = choice.delta.content {
652                                yield Ok(content);
653                            }
654                        }
655                    }
656                    Err(e) => {
657                        yield Err(AppError::LLM(format!("Stream error: {}", e)));
658                    }
659                }
660            }
661        };
662
663        Ok(Box::new(Box::pin(result_stream)))
664    }
665
666    async fn stream_with_history(
667        &self,
668        messages: &[(String, String)],
669    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
670        let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
671            messages
672                .iter()
673                .map(|(role, content)| {
674                    match role.as_str() {
675                        "system" => {
676                            let msg = ChatCompletionRequestSystemMessageArgs::default()
677                                .content(content.as_str())
678                                .build()
679                                .map_err(|e| {
680                                    AppError::LLM(format!("Failed to build system message: {}", e))
681                                })?;
682                            Ok(ChatCompletionRequestMessage::System(msg))
683                        }
684                        "assistant" => {
685                            let msg = ChatCompletionRequestAssistantMessageArgs::default()
686                                .content(content.as_str())
687                                .build()
688                                .map_err(|e| {
689                                    AppError::LLM(format!(
690                                        "Failed to build assistant message: {}",
691                                        e
692                                    ))
693                                })?;
694                            Ok(ChatCompletionRequestMessage::Assistant(msg))
695                        }
696                        _ => {
697                            // Default to user message
698                            let msg = ChatCompletionRequestUserMessageArgs::default()
699                                .content(content.as_str())
700                                .build()
701                                .map_err(|e| {
702                                    AppError::LLM(format!("Failed to build user message: {}", e))
703                                })?;
704                            Ok(ChatCompletionRequestMessage::User(msg))
705                        }
706                    }
707                })
708                .collect();
709
710        let mut builder = CreateChatCompletionRequestArgs::default();
711        builder.model(&self.model);
712        builder.messages(chat_messages?);
713
714        // Apply model parameters
715        if let Some(temp) = self.params.temperature {
716            builder.temperature(temp);
717        }
718        if let Some(max_tokens) = self.params.max_tokens {
719            builder.max_completion_tokens(max_tokens);
720        }
721        if let Some(top_p) = self.params.top_p {
722            builder.top_p(top_p);
723        }
724        if let Some(freq_penalty) = self.params.frequency_penalty {
725            builder.frequency_penalty(freq_penalty);
726        }
727        if let Some(pres_penalty) = self.params.presence_penalty {
728            builder.presence_penalty(pres_penalty);
729        }
730
731        let request = builder
732            .build()
733            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
734
735        let mut stream = self
736            .client
737            .chat()
738            .create_stream(request)
739            .await
740            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
741
742        let result_stream = async_stream::stream! {
743            while let Some(result) = stream.next().await {
744                match result {
745                    Ok(response) => {
746                        for choice in response.choices {
747                            if let Some(content) = choice.delta.content {
748                                yield Ok(content);
749                            }
750                        }
751                    }
752                    Err(e) => {
753                        yield Err(AppError::LLM(format!("Stream error: {}", e)));
754                    }
755                }
756            }
757        };
758
759        Ok(Box::new(Box::pin(result_stream)))
760    }
761
762    fn model_name(&self) -> &str {
763        &self.model
764    }
765}
766
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor should store the model name verbatim.
    #[test]
    fn test_client_creation() {
        let api_key = "test-key".to_string();
        let api_base = "https://api.openai.com/v1".to_string();
        let client = OpenAIClient::new(api_key, api_base, "gpt-4".to_string());

        assert_eq!(client.model_name(), "gpt-4");
    }

    /// A ToolDefinition should round-trip into a Function tool with the same
    /// name and description.
    #[test]
    fn test_tool_conversion() {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "operation": {"type": "string"},
                "a": {"type": "number"},
                "b": {"type": "number"}
            },
            "required": ["operation", "a", "b"]
        });
        let tool = ToolDefinition {
            name: "calculator".to_string(),
            description: "Performs math operations".to_string(),
            parameters,
        };

        let ChatCompletionTools::Function(chat_tool) = OpenAIClient::convert_tool(&tool) else {
            panic!("Expected Function variant, got Custom");
        };
        assert_eq!(chat_tool.function.name, "calculator");
        assert_eq!(
            chat_tool.function.description,
            Some("Performs math operations".to_string())
        );
    }
}