// ares/llm/openai.rs
//! OpenAI LLM client implementation
//!
//! This module provides integration with OpenAI API and compatible endpoints.
//!
//! # Features
//!
//! Enable with the `openai` feature flag.
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::llm::{LLMClient, Provider};
//!
//! let provider = Provider::OpenAI {
//!     api_key: "sk-...".to_string(),
//!     api_base: "https://api.openai.com/v1".to_string(),
//!     model: "gpt-4".to_string(),
//! };
//! let client = provider.create_client().await?;
//! let response = client.generate("Hello!").await?;
//! ```

use crate::llm::client::{LLMClient, LLMResponse, ModelParams, TokenUsage};
use crate::llm::coordinator::{ConversationMessage, MessageRole};
use crate::types::{AppError, Result, ToolCall, ToolDefinition};
use async_openai::{
    config::OpenAIConfig,
    types::chat::{
        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
        ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionTools,
        CreateChatCompletionRequestArgs, FunctionCall, FunctionObject,
    },
    Client,
};
use async_trait::async_trait;
use futures::StreamExt;

/// OpenAI client for API-based inference
pub struct OpenAIClient {
    /// Underlying async-openai HTTP client, configured with the API key and base URL.
    client: Client<OpenAIConfig>,
    /// Model identifier sent with every request (e.g. "gpt-4").
    model: String,
    /// Inference parameters (temperature, max tokens, etc.) applied to every request.
    params: ModelParams,
}
46
47impl OpenAIClient {
48    /// Create a new OpenAI client
49    ///
50    /// # Arguments
51    ///
52    /// * `api_key` - OpenAI API key
53    /// * `api_base` - Base URL for the API (e.g., `https://api.openai.com/v1`)
54    /// * `model` - Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
55    pub fn new(api_key: String, api_base: String, model: String) -> Self {
56        Self::with_params(api_key, api_base, model, ModelParams::default())
57    }
58
59    /// Create a new OpenAI client with model parameters
60    ///
61    /// # Arguments
62    ///
63    /// * `api_key` - OpenAI API key
64    /// * `api_base` - Base URL for the API (e.g., `https://api.openai.com/v1`)
65    /// * `model` - Model identifier (e.g., "gpt-4", "gpt-3.5-turbo")
66    /// * `params` - Model inference parameters (temperature, max_tokens, etc.)
67    pub fn with_params(
68        api_key: String,
69        api_base: String,
70        model: String,
71        params: ModelParams,
72    ) -> Self {
73        let config = OpenAIConfig::new()
74            .with_api_key(api_key)
75            .with_api_base(api_base);
76
77        Self {
78            client: Client::with_config(config),
79            model,
80            params,
81        }
82    }
83
84    /// Convert ToolDefinition to ChatCompletionTool
85    fn convert_tool(tool: &ToolDefinition) -> ChatCompletionTools {
86        ChatCompletionTools::Function(ChatCompletionTool {
87            function: FunctionObject {
88                name: tool.name.clone(),
89                description: Some(tool.description.clone()),
90                parameters: Some(tool.parameters.clone()),
91                strict: None,
92            },
93        })
94    }
95
96    /// Extract tool calls from the response message tool calls
97    fn extract_tool_calls(tool_calls: &[ChatCompletionMessageToolCalls]) -> Vec<ToolCall> {
98        tool_calls
99            .iter()
100            .filter_map(|wrapper| match wrapper {
101                ChatCompletionMessageToolCalls::Function(call) => Some(ToolCall {
102                    id: call.id.clone(),
103                    name: call.function.name.clone(),
104                    arguments: serde_json::from_str(&call.function.arguments)
105                        .unwrap_or(serde_json::json!({})),
106                }),
107                ChatCompletionMessageToolCalls::Custom(_) => None,
108            })
109            .collect()
110    }
111
112    /// Convert a ConversationMessage to OpenAI's ChatCompletionRequestMessage
113    fn convert_conversation_message(
114        &self,
115        msg: &ConversationMessage,
116    ) -> Result<ChatCompletionRequestMessage> {
117        match msg.role {
118            MessageRole::System => {
119                let system_msg = ChatCompletionRequestSystemMessageArgs::default()
120                    .content(msg.content.clone())
121                    .build()
122                    .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
123                Ok(ChatCompletionRequestMessage::System(system_msg))
124            }
125            MessageRole::User => {
126                let user_msg = ChatCompletionRequestUserMessageArgs::default()
127                    .content(msg.content.clone())
128                    .build()
129                    .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
130                Ok(ChatCompletionRequestMessage::User(user_msg))
131            }
132            MessageRole::Assistant => {
133                let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
134
135                if !msg.content.is_empty() {
136                    builder.content(msg.content.clone());
137                }
138
139                // Convert tool calls if present
140                if !msg.tool_calls.is_empty() {
141                    let openai_tool_calls: Vec<ChatCompletionMessageToolCalls> = msg
142                        .tool_calls
143                        .iter()
144                        .map(|tc| {
145                            ChatCompletionMessageToolCalls::Function(
146                                ChatCompletionMessageToolCall {
147                                    id: tc.id.clone(),
148                                    function: FunctionCall {
149                                        name: tc.name.clone(),
150                                        arguments: serde_json::to_string(&tc.arguments)
151                                            .unwrap_or_else(|_| "{}".to_string()),
152                                    },
153                                },
154                            )
155                        })
156                        .collect();
157                    builder.tool_calls(openai_tool_calls);
158                }
159
160                let assistant_msg = builder.build().map_err(|e| {
161                    AppError::LLM(format!("Failed to build assistant message: {}", e))
162                })?;
163                Ok(ChatCompletionRequestMessage::Assistant(assistant_msg))
164            }
165            MessageRole::Tool => {
166                let tool_call_id = msg.tool_call_id.clone().ok_or_else(|| {
167                    AppError::LLM("Tool message must have a tool_call_id".to_string())
168                })?;
169
170                let tool_msg = ChatCompletionRequestToolMessageArgs::default()
171                    .tool_call_id(tool_call_id)
172                    .content(msg.content.clone())
173                    .build()
174                    .map_err(|e| AppError::LLM(format!("Failed to build tool message: {}", e)))?;
175                Ok(ChatCompletionRequestMessage::Tool(tool_msg))
176            }
177        }
178    }
179}
180
181#[async_trait]
182impl LLMClient for OpenAIClient {
183    async fn generate(&self, prompt: &str) -> Result<String> {
184        let message = ChatCompletionRequestUserMessageArgs::default()
185            .content(prompt)
186            .build()
187            .map_err(|e| AppError::LLM(format!("Failed to build message: {}", e)))?;
188
189        let mut builder = CreateChatCompletionRequestArgs::default();
190        builder.model(&self.model);
191        builder.messages(vec![ChatCompletionRequestMessage::User(message)]);
192
193        // Apply model parameters
194        if let Some(temp) = self.params.temperature {
195            builder.temperature(temp);
196        }
197        if let Some(max_tokens) = self.params.max_tokens {
198            builder.max_completion_tokens(max_tokens);
199        }
200        if let Some(top_p) = self.params.top_p {
201            builder.top_p(top_p);
202        }
203        if let Some(freq_penalty) = self.params.frequency_penalty {
204            builder.frequency_penalty(freq_penalty);
205        }
206        if let Some(pres_penalty) = self.params.presence_penalty {
207            builder.presence_penalty(pres_penalty);
208        }
209
210        let request = builder
211            .build()
212            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
213
214        let response = self
215            .client
216            .chat()
217            .create(request)
218            .await
219            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
220
221        response
222            .choices
223            .first()
224            .and_then(|choice| choice.message.content.clone())
225            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
226    }
227
228    async fn generate_with_system(&self, system: &str, prompt: &str) -> Result<String> {
229        let system_message = ChatCompletionRequestSystemMessageArgs::default()
230            .content(system)
231            .build()
232            .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
233
234        let user_message = ChatCompletionRequestUserMessageArgs::default()
235            .content(prompt)
236            .build()
237            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
238
239        let mut builder = CreateChatCompletionRequestArgs::default();
240        builder.model(&self.model);
241        builder.messages(vec![
242            ChatCompletionRequestMessage::System(system_message),
243            ChatCompletionRequestMessage::User(user_message),
244        ]);
245
246        // Apply model parameters
247        if let Some(temp) = self.params.temperature {
248            builder.temperature(temp);
249        }
250        if let Some(max_tokens) = self.params.max_tokens {
251            builder.max_completion_tokens(max_tokens);
252        }
253        if let Some(top_p) = self.params.top_p {
254            builder.top_p(top_p);
255        }
256        if let Some(freq_penalty) = self.params.frequency_penalty {
257            builder.frequency_penalty(freq_penalty);
258        }
259        if let Some(pres_penalty) = self.params.presence_penalty {
260            builder.presence_penalty(pres_penalty);
261        }
262
263        let request = builder
264            .build()
265            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
266
267        let response = self
268            .client
269            .chat()
270            .create(request)
271            .await
272            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
273
274        response
275            .choices
276            .first()
277            .and_then(|choice| choice.message.content.clone())
278            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
279    }
280
281    async fn generate_with_history(&self, messages: &[(String, String)]) -> Result<String> {
282        let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
283            messages
284                .iter()
285                .map(|(role, content)| {
286                    match role.as_str() {
287                        "system" => {
288                            let msg = ChatCompletionRequestSystemMessageArgs::default()
289                                .content(content.as_str())
290                                .build()
291                                .map_err(|e| {
292                                    AppError::LLM(format!("Failed to build system message: {}", e))
293                                })?;
294                            Ok(ChatCompletionRequestMessage::System(msg))
295                        }
296                        "assistant" => {
297                            let msg = ChatCompletionRequestAssistantMessageArgs::default()
298                                .content(content.as_str())
299                                .build()
300                                .map_err(|e| {
301                                    AppError::LLM(format!(
302                                        "Failed to build assistant message: {}",
303                                        e
304                                    ))
305                                })?;
306                            Ok(ChatCompletionRequestMessage::Assistant(msg))
307                        }
308                        _ => {
309                            // Default to user message
310                            let msg = ChatCompletionRequestUserMessageArgs::default()
311                                .content(content.as_str())
312                                .build()
313                                .map_err(|e| {
314                                    AppError::LLM(format!("Failed to build user message: {}", e))
315                                })?;
316                            Ok(ChatCompletionRequestMessage::User(msg))
317                        }
318                    }
319                })
320                .collect();
321
322        let mut builder = CreateChatCompletionRequestArgs::default();
323        builder.model(&self.model);
324        builder.messages(chat_messages?);
325
326        // Apply model parameters
327        if let Some(temp) = self.params.temperature {
328            builder.temperature(temp);
329        }
330        if let Some(max_tokens) = self.params.max_tokens {
331            builder.max_completion_tokens(max_tokens);
332        }
333        if let Some(top_p) = self.params.top_p {
334            builder.top_p(top_p);
335        }
336        if let Some(freq_penalty) = self.params.frequency_penalty {
337            builder.frequency_penalty(freq_penalty);
338        }
339        if let Some(pres_penalty) = self.params.presence_penalty {
340            builder.presence_penalty(pres_penalty);
341        }
342
343        let request = builder
344            .build()
345            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
346
347        let response = self
348            .client
349            .chat()
350            .create(request)
351            .await
352            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
353
354        response
355            .choices
356            .first()
357            .and_then(|choice| choice.message.content.clone())
358            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))
359    }
360
361    async fn generate_with_tools(
362        &self,
363        prompt: &str,
364        tools: &[ToolDefinition],
365    ) -> Result<LLMResponse> {
366        let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
367
368        let user_message = ChatCompletionRequestUserMessageArgs::default()
369            .content(prompt)
370            .build()
371            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
372
373        let mut builder = CreateChatCompletionRequestArgs::default();
374        builder.model(&self.model);
375        builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
376        builder.tools(openai_tools);
377
378        // Apply model parameters
379        if let Some(temp) = self.params.temperature {
380            builder.temperature(temp);
381        }
382        if let Some(max_tokens) = self.params.max_tokens {
383            builder.max_completion_tokens(max_tokens);
384        }
385        if let Some(top_p) = self.params.top_p {
386            builder.top_p(top_p);
387        }
388        if let Some(freq_penalty) = self.params.frequency_penalty {
389            builder.frequency_penalty(freq_penalty);
390        }
391        if let Some(pres_penalty) = self.params.presence_penalty {
392            builder.presence_penalty(pres_penalty);
393        }
394
395        let request = builder
396            .build()
397            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
398
399        let response = self
400            .client
401            .chat()
402            .create(request)
403            .await
404            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
405
406        let choice = response
407            .choices
408            .first()
409            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
410
411        let content = choice.message.content.clone().unwrap_or_default();
412
413        let finish_reason = choice
414            .finish_reason
415            .as_ref()
416            .map(|r| format!("{:?}", r).to_lowercase())
417            .unwrap_or_else(|| "stop".to_string());
418
419        let tool_calls = choice
420            .message
421            .tool_calls
422            .as_ref()
423            .map(|calls| Self::extract_tool_calls(calls))
424            .unwrap_or_default();
425
426        // Extract token usage if available
427        #[allow(clippy::unnecessary_cast)]
428        let usage = response
429            .usage
430            .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
431
432        Ok(LLMResponse {
433            content,
434            tool_calls,
435            finish_reason,
436            usage,
437        })
438    }
439
440    async fn generate_with_tools_and_history(
441        &self,
442        messages: &[ConversationMessage],
443        tools: &[ToolDefinition],
444    ) -> Result<LLMResponse> {
445        // Convert ConversationMessage to OpenAI format
446        let openai_messages: Vec<ChatCompletionRequestMessage> = messages
447            .iter()
448            .map(|msg| self.convert_conversation_message(msg))
449            .collect::<Result<Vec<_>>>()?;
450
451        // Convert tools to OpenAI format
452        let openai_tools: Vec<ChatCompletionTools> = tools.iter().map(Self::convert_tool).collect();
453
454        let mut builder = CreateChatCompletionRequestArgs::default();
455        builder.model(&self.model);
456        builder.messages(openai_messages);
457
458        if !openai_tools.is_empty() {
459            builder.tools(openai_tools);
460        }
461
462        // Apply model parameters
463        if let Some(temp) = self.params.temperature {
464            builder.temperature(temp);
465        }
466        if let Some(max_tokens) = self.params.max_tokens {
467            builder.max_completion_tokens(max_tokens);
468        }
469        if let Some(top_p) = self.params.top_p {
470            builder.top_p(top_p);
471        }
472        if let Some(freq_penalty) = self.params.frequency_penalty {
473            builder.frequency_penalty(freq_penalty);
474        }
475        if let Some(pres_penalty) = self.params.presence_penalty {
476            builder.presence_penalty(pres_penalty);
477        }
478
479        let request = builder
480            .build()
481            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
482
483        let response = self
484            .client
485            .chat()
486            .create(request)
487            .await
488            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
489
490        let choice = response
491            .choices
492            .first()
493            .ok_or_else(|| AppError::LLM("No response from OpenAI".to_string()))?;
494
495        let content = choice.message.content.clone().unwrap_or_default();
496
497        let finish_reason = choice
498            .finish_reason
499            .as_ref()
500            .map(|r| format!("{:?}", r).to_lowercase())
501            .unwrap_or_else(|| "stop".to_string());
502
503        let tool_calls = choice
504            .message
505            .tool_calls
506            .as_ref()
507            .map(|calls| Self::extract_tool_calls(calls))
508            .unwrap_or_default();
509
510        #[allow(clippy::unnecessary_cast)]
511        let usage = response
512            .usage
513            .map(|u| TokenUsage::new(u.prompt_tokens as u32, u.completion_tokens as u32));
514
515        Ok(LLMResponse {
516            content,
517            tool_calls,
518            finish_reason,
519            usage,
520        })
521    }
522
523    async fn stream(
524        &self,
525        prompt: &str,
526    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
527        let user_message = ChatCompletionRequestUserMessageArgs::default()
528            .content(prompt)
529            .build()
530            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
531
532        let mut builder = CreateChatCompletionRequestArgs::default();
533        builder.model(&self.model);
534        builder.messages(vec![ChatCompletionRequestMessage::User(user_message)]);
535
536        // Apply model parameters
537        if let Some(temp) = self.params.temperature {
538            builder.temperature(temp);
539        }
540        if let Some(max_tokens) = self.params.max_tokens {
541            builder.max_completion_tokens(max_tokens);
542        }
543        if let Some(top_p) = self.params.top_p {
544            builder.top_p(top_p);
545        }
546        if let Some(freq_penalty) = self.params.frequency_penalty {
547            builder.frequency_penalty(freq_penalty);
548        }
549        if let Some(pres_penalty) = self.params.presence_penalty {
550            builder.presence_penalty(pres_penalty);
551        }
552
553        let request = builder
554            .build()
555            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
556
557        let mut stream = self
558            .client
559            .chat()
560            .create_stream(request)
561            .await
562            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
563
564        let result_stream = async_stream::stream! {
565            while let Some(result) = stream.next().await {
566                match result {
567                    Ok(response) => {
568                        for choice in response.choices {
569                            if let Some(content) = choice.delta.content {
570                                yield Ok(content);
571                            }
572                        }
573                    }
574                    Err(e) => {
575                        yield Err(AppError::LLM(format!("Stream error: {}", e)));
576                    }
577                }
578            }
579        };
580
581        Ok(Box::new(Box::pin(result_stream)))
582    }
583
584    async fn stream_with_system(
585        &self,
586        system: &str,
587        prompt: &str,
588    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
589        let system_message = ChatCompletionRequestSystemMessageArgs::default()
590            .content(system)
591            .build()
592            .map_err(|e| AppError::LLM(format!("Failed to build system message: {}", e)))?;
593
594        let user_message = ChatCompletionRequestUserMessageArgs::default()
595            .content(prompt)
596            .build()
597            .map_err(|e| AppError::LLM(format!("Failed to build user message: {}", e)))?;
598
599        let mut builder = CreateChatCompletionRequestArgs::default();
600        builder.model(&self.model);
601        builder.messages(vec![
602            ChatCompletionRequestMessage::System(system_message),
603            ChatCompletionRequestMessage::User(user_message),
604        ]);
605
606        // Apply model parameters
607        if let Some(temp) = self.params.temperature {
608            builder.temperature(temp);
609        }
610        if let Some(max_tokens) = self.params.max_tokens {
611            builder.max_completion_tokens(max_tokens);
612        }
613        if let Some(top_p) = self.params.top_p {
614            builder.top_p(top_p);
615        }
616        if let Some(freq_penalty) = self.params.frequency_penalty {
617            builder.frequency_penalty(freq_penalty);
618        }
619        if let Some(pres_penalty) = self.params.presence_penalty {
620            builder.presence_penalty(pres_penalty);
621        }
622
623        let request = builder
624            .build()
625            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
626
627        let mut stream = self
628            .client
629            .chat()
630            .create_stream(request)
631            .await
632            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
633
634        let result_stream = async_stream::stream! {
635            while let Some(result) = stream.next().await {
636                match result {
637                    Ok(response) => {
638                        for choice in response.choices {
639                            if let Some(content) = choice.delta.content {
640                                yield Ok(content);
641                            }
642                        }
643                    }
644                    Err(e) => {
645                        yield Err(AppError::LLM(format!("Stream error: {}", e)));
646                    }
647                }
648            }
649        };
650
651        Ok(Box::new(Box::pin(result_stream)))
652    }
653
654    async fn stream_with_history(
655        &self,
656        messages: &[(String, String)],
657    ) -> Result<Box<dyn futures::Stream<Item = Result<String>> + Send + Unpin>> {
658        let chat_messages: std::result::Result<Vec<ChatCompletionRequestMessage>, AppError> =
659            messages
660                .iter()
661                .map(|(role, content)| {
662                    match role.as_str() {
663                        "system" => {
664                            let msg = ChatCompletionRequestSystemMessageArgs::default()
665                                .content(content.as_str())
666                                .build()
667                                .map_err(|e| {
668                                    AppError::LLM(format!("Failed to build system message: {}", e))
669                                })?;
670                            Ok(ChatCompletionRequestMessage::System(msg))
671                        }
672                        "assistant" => {
673                            let msg = ChatCompletionRequestAssistantMessageArgs::default()
674                                .content(content.as_str())
675                                .build()
676                                .map_err(|e| {
677                                    AppError::LLM(format!(
678                                        "Failed to build assistant message: {}",
679                                        e
680                                    ))
681                                })?;
682                            Ok(ChatCompletionRequestMessage::Assistant(msg))
683                        }
684                        _ => {
685                            // Default to user message
686                            let msg = ChatCompletionRequestUserMessageArgs::default()
687                                .content(content.as_str())
688                                .build()
689                                .map_err(|e| {
690                                    AppError::LLM(format!("Failed to build user message: {}", e))
691                                })?;
692                            Ok(ChatCompletionRequestMessage::User(msg))
693                        }
694                    }
695                })
696                .collect();
697
698        let mut builder = CreateChatCompletionRequestArgs::default();
699        builder.model(&self.model);
700        builder.messages(chat_messages?);
701
702        // Apply model parameters
703        if let Some(temp) = self.params.temperature {
704            builder.temperature(temp);
705        }
706        if let Some(max_tokens) = self.params.max_tokens {
707            builder.max_completion_tokens(max_tokens);
708        }
709        if let Some(top_p) = self.params.top_p {
710            builder.top_p(top_p);
711        }
712        if let Some(freq_penalty) = self.params.frequency_penalty {
713            builder.frequency_penalty(freq_penalty);
714        }
715        if let Some(pres_penalty) = self.params.presence_penalty {
716            builder.presence_penalty(pres_penalty);
717        }
718
719        let request = builder
720            .build()
721            .map_err(|e| AppError::LLM(format!("Failed to build request: {}", e)))?;
722
723        let mut stream = self
724            .client
725            .chat()
726            .create_stream(request)
727            .await
728            .map_err(|e| AppError::LLM(format!("OpenAI API error: {}", e)))?;
729
730        let result_stream = async_stream::stream! {
731            while let Some(result) = stream.next().await {
732                match result {
733                    Ok(response) => {
734                        for choice in response.choices {
735                            if let Some(content) = choice.delta.content {
736                                yield Ok(content);
737                            }
738                        }
739                    }
740                    Err(e) => {
741                        yield Err(AppError::LLM(format!("Stream error: {}", e)));
742                    }
743                }
744            }
745        };
746
747        Ok(Box::new(Box::pin(result_stream)))
748    }
749
750    fn model_name(&self) -> &str {
751        &self.model
752    }
753}
754
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a client should record the requested model name.
    #[test]
    fn test_client_creation() {
        let client = OpenAIClient::new(
            "test-key".to_string(),
            "https://api.openai.com/v1".to_string(),
            "gpt-4".to_string(),
        );
        assert_eq!(client.model_name(), "gpt-4");
    }

    /// A `ToolDefinition` should map onto the `Function` tool variant with
    /// its name and description carried over intact.
    #[test]
    fn test_tool_conversion() {
        let definition = ToolDefinition {
            name: "calculator".to_string(),
            description: "Performs math operations".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "operation": {"type": "string"},
                    "a": {"type": "number"},
                    "b": {"type": "number"}
                },
                "required": ["operation", "a", "b"]
            }),
        };

        match OpenAIClient::convert_tool(&definition) {
            ChatCompletionTools::Function(chat_tool) => {
                assert_eq!(chat_tool.function.name, "calculator");
                assert_eq!(
                    chat_tool.function.description.as_deref(),
                    Some("Performs math operations")
                );
            }
            ChatCompletionTools::Custom(_) => {
                panic!("Expected Function variant, got Custom");
            }
        }
    }
}