//! agent_air_runtime/client/providers/openai/mod.rs
//!
//! OpenAI chat-completions provider: standard OpenAI, OpenAI-compatible
//! endpoints (via a custom base URL), and Azure OpenAI (via `AzureConfig`).
1mod sse;
2mod types;
3
4use async_stream::stream;
5use futures::Stream;
6
7use crate::client::error::LlmError;
8use crate::client::http::HttpClient;
9use crate::client::models::{Message, MessageOptions, StreamEvent};
10use crate::client::traits::{LlmProvider, StreamMsgFuture};
11use std::future::Future;
12use std::pin::Pin;
13
// =============================================================================
// Constants
// =============================================================================

/// Machine-readable error code attached to `LlmError`s produced while
/// decoding the SSE byte stream in `send_msg_stream`.
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";

/// Human-readable message paired with `ERROR_SSE_DECODE` when the response
/// bytes are not valid UTF-8.
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";
23
24// =============================================================================
25// Provider
26// =============================================================================
27
/// Azure OpenAI configuration.
///
/// When set on an `OpenAIProvider`, requests are routed to Azure instead of
/// the standard OpenAI endpoint (see `OpenAIProvider::api_url` and
/// `OpenAIProvider::get_headers`).
#[derive(Clone)]
pub struct AzureConfig {
    /// Azure resource name (e.g., "my-resource").
    pub resource: String,
    /// Azure deployment name (e.g., "gpt-4-deployment").
    pub deployment: String,
    /// Azure API version (e.g., "2024-10-21").
    pub api_version: String,
}
38
/// OpenAI API provider.
///
/// Also supports OpenAI-compatible APIs (Groq, Together, Fireworks, etc.)
/// by specifying a custom base_url.
///
/// Also supports Azure OpenAI by specifying an AzureConfig.
pub struct OpenAIProvider {
    /// OpenAI API key (or Azure api-key when `azure_config` is set).
    api_key: String,
    /// Model identifier (e.g., "gpt-4"). Left empty for Azure, where the
    /// deployment name in `azure_config` stands in for the model.
    model: String,
    /// Custom base URL for OpenAI-compatible providers.
    /// If None, uses the default OpenAI endpoint.
    base_url: Option<String>,
    /// Azure configuration. If set, uses Azure OpenAI instead of standard OpenAI.
    azure_config: Option<AzureConfig>,
}
56
57impl OpenAIProvider {
58    /// Create a new OpenAI provider with API key and model.
59    pub fn new(api_key: String, model: String) -> Self {
60        Self {
61            api_key,
62            model,
63            base_url: None,
64            azure_config: None,
65        }
66    }
67
68    /// Create a new OpenAI-compatible provider with a custom base URL.
69    ///
70    /// Use this for providers like Groq, Together, Fireworks, etc.
71    /// The base_url should be the API base (e.g., "https://api.groq.com/openai/v1").
72    pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
73        Self {
74            api_key,
75            model,
76            base_url: Some(base_url),
77            azure_config: None,
78        }
79    }
80
81    /// Create a new Azure OpenAI provider.
82    ///
83    /// Azure OpenAI uses a different URL format and authentication header.
84    /// URL: https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
85    /// Auth: api-key header instead of Authorization: Bearer
86    pub fn azure(
87        api_key: String,
88        resource: String,
89        deployment: String,
90        api_version: String,
91    ) -> Self {
92        Self {
93            api_key,
94            model: String::new(), // Not used for Azure
95            base_url: None,
96            azure_config: Some(AzureConfig {
97                resource,
98                deployment,
99                api_version,
100            }),
101        }
102    }
103
104    /// Returns the model identifier.
105    pub fn model(&self) -> &str {
106        &self.model
107    }
108
109    /// Returns true if this provider is configured for Azure OpenAI.
110    pub fn is_azure(&self) -> bool {
111        self.azure_config.is_some()
112    }
113
114    /// Returns the API endpoint URL.
115    fn api_url(&self) -> String {
116        if let Some(azure) = &self.azure_config {
117            types::get_azure_api_url(&azure.resource, &azure.deployment, &azure.api_version)
118        } else {
119            types::get_api_url_with_base(self.base_url.as_deref())
120        }
121    }
122
123    /// Returns the request headers appropriate for this provider configuration.
124    fn get_headers(&self) -> Vec<(&'static str, String)> {
125        if self.azure_config.is_some() {
126            types::get_azure_request_headers(&self.api_key)
127        } else {
128            types::get_request_headers(&self.api_key)
129        }
130    }
131}
132
133impl LlmProvider for OpenAIProvider {
134    fn send_msg(
135        &self,
136        client: &HttpClient,
137        messages: &[Message],
138        options: &MessageOptions,
139    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
140        // Clone data for the async block
141        let client = client.clone();
142        let model = self.model.clone();
143        let api_url = self.api_url();
144        let headers = self.get_headers();
145        let messages = messages.to_vec();
146        let options = options.clone();
147
148        Box::pin(async move {
149            // Build request body
150            let body = types::build_request_body(&messages, &options, &model)?;
151
152            // Get headers
153            let headers_ref: Vec<(&str, &str)> =
154                headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
155
156            // Make the API call
157            let response = client.post(&api_url, &headers_ref, &body).await?;
158
159            // Parse and return the response
160            types::parse_response(&response)
161        })
162    }
163
164    fn send_msg_stream(
165        &self,
166        client: &HttpClient,
167        messages: &[Message],
168        options: &MessageOptions,
169    ) -> StreamMsgFuture {
170        // Clone data for the async block
171        let client = client.clone();
172        let model = self.model.clone();
173        let api_url = self.api_url();
174        let headers = self.get_headers();
175        let messages = messages.to_vec();
176        let options = options.clone();
177
178        Box::pin(async move {
179            // Build streaming request body
180            let body = types::build_streaming_request_body(&messages, &options, &model)?;
181
182            // Get headers
183            let headers_ref: Vec<(&str, &str)> =
184                headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
185
186            // Make the streaming API call
187            let byte_stream = client.post_stream(&api_url, &headers_ref, &body).await?;
188
189            // Convert byte stream to SSE events stream
190            use futures::StreamExt;
191            let event_stream = stream! {
192                let mut buffer = String::new();
193                let mut byte_stream = byte_stream;
194                let mut stream_state = sse::StreamState::default();
195
196                while let Some(chunk_result) = byte_stream.next().await {
197                    match chunk_result {
198                        Ok(bytes) => {
199                            // Append new bytes to buffer
200                            if let Ok(text) = std::str::from_utf8(&bytes) {
201                                buffer.push_str(text);
202                            } else {
203                                yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
204                                break;
205                            }
206
207                            // Parse complete SSE events from buffer
208                            let (events, remaining) = sse::parse_sse_chunk(&buffer);
209                            buffer = remaining;
210
211                            // Convert and yield each SSE event
212                            for sse_event in events {
213                                match sse::parse_stream_event(&sse_event, &mut stream_state) {
214                                    Ok(stream_events) => {
215                                        for stream_event in stream_events {
216                                            yield Ok(stream_event);
217                                        }
218                                    }
219                                    Err(e) => {
220                                        yield Err(e);
221                                        return;
222                                    }
223                                }
224                            }
225                        }
226                        Err(e) => {
227                            yield Err(e);
228                            break;
229                        }
230                    }
231                }
232            };
233
234            Ok(Box::pin(event_stream)
235                as Pin<
236                    Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>,
237                >)
238        })
239    }
240}