//! OpenAI provider implementation.
//! Path: agent_core_runtime/client/providers/openai/mod.rs

1mod sse;
2mod types;
3
4use async_stream::stream;
5use futures::Stream;
6
7use crate::client::error::LlmError;
8use crate::client::http::HttpClient;
9use crate::client::models::{Message, MessageOptions, StreamEvent};
10use crate::client::traits::LlmProvider;
11use std::future::Future;
12use std::pin::Pin;
13
14// =============================================================================
15// Constants
16// =============================================================================
17
18/// Error code for SSE decoding errors.
19const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";
20
21/// Error message for invalid UTF-8 in stream.
22const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";
23
24// =============================================================================
25// Provider
26// =============================================================================
27
/// Azure OpenAI configuration.
///
/// Holds only routing information (no credentials — the API key lives on
/// `OpenAIProvider`), so deriving `Debug` is safe and aids diagnostics.
#[derive(Debug, Clone)]
pub struct AzureConfig {
    /// Azure resource name (e.g., "my-resource").
    pub resource: String,
    /// Azure deployment name (e.g., "gpt-4-deployment").
    pub deployment: String,
    /// Azure API version (e.g., "2024-10-21").
    pub api_version: String,
}
38
/// OpenAI API provider.
///
/// Also supports OpenAI-compatible APIs (Groq, Together, Fireworks, etc.)
/// by specifying a custom base_url.
///
/// Also supports Azure OpenAI by specifying an AzureConfig.
///
/// NOTE(review): `Debug` is not derived here, presumably because `api_key`
/// is a secret and a derived `Debug` would print it — confirm before adding
/// derives.
pub struct OpenAIProvider {
    /// OpenAI API key.
    /// For Azure this is sent in the `api-key` header rather than as a
    /// `Authorization: Bearer` token (see `azure`).
    api_key: String,
    /// Model identifier (e.g., "gpt-4").
    /// Left empty when configured for Azure, where the deployment name
    /// selects the model (see `azure`).
    model: String,
    /// Custom base URL for OpenAI-compatible providers.
    /// If None, uses the default OpenAI endpoint.
    base_url: Option<String>,
    /// Azure configuration. If set, uses Azure OpenAI instead of standard OpenAI.
    azure_config: Option<AzureConfig>,
}
56
57impl OpenAIProvider {
58    /// Create a new OpenAI provider with API key and model.
59    pub fn new(api_key: String, model: String) -> Self {
60        Self {
61            api_key,
62            model,
63            base_url: None,
64            azure_config: None,
65        }
66    }
67
68    /// Create a new OpenAI-compatible provider with a custom base URL.
69    ///
70    /// Use this for providers like Groq, Together, Fireworks, etc.
71    /// The base_url should be the API base (e.g., "https://api.groq.com/openai/v1").
72    pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
73        Self {
74            api_key,
75            model,
76            base_url: Some(base_url),
77            azure_config: None,
78        }
79    }
80
81    /// Create a new Azure OpenAI provider.
82    ///
83    /// Azure OpenAI uses a different URL format and authentication header.
84    /// URL: https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
85    /// Auth: api-key header instead of Authorization: Bearer
86    pub fn azure(api_key: String, resource: String, deployment: String, api_version: String) -> Self {
87        Self {
88            api_key,
89            model: String::new(), // Not used for Azure
90            base_url: None,
91            azure_config: Some(AzureConfig {
92                resource,
93                deployment,
94                api_version,
95            }),
96        }
97    }
98
99    /// Returns the model identifier.
100    pub fn model(&self) -> &str {
101        &self.model
102    }
103
104    /// Returns true if this provider is configured for Azure OpenAI.
105    pub fn is_azure(&self) -> bool {
106        self.azure_config.is_some()
107    }
108
109    /// Returns the API endpoint URL.
110    fn api_url(&self) -> String {
111        if let Some(azure) = &self.azure_config {
112            types::get_azure_api_url(&azure.resource, &azure.deployment, &azure.api_version)
113        } else {
114            types::get_api_url_with_base(self.base_url.as_deref())
115        }
116    }
117
118    /// Returns the request headers appropriate for this provider configuration.
119    fn get_headers(&self) -> Vec<(&'static str, String)> {
120        if self.azure_config.is_some() {
121            types::get_azure_request_headers(&self.api_key)
122        } else {
123            types::get_request_headers(&self.api_key)
124        }
125    }
126}
127
128impl LlmProvider for OpenAIProvider {
129    fn send_msg(
130        &self,
131        client: &HttpClient,
132        messages: &[Message],
133        options: &MessageOptions,
134    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
135        // Clone data for the async block
136        let client = client.clone();
137        let model = self.model.clone();
138        let api_url = self.api_url();
139        let headers = self.get_headers();
140        let messages = messages.to_vec();
141        let options = options.clone();
142
143        Box::pin(async move {
144            // Build request body
145            let body = types::build_request_body(&messages, &options, &model)?;
146
147            // Get headers
148            let headers_ref: Vec<(&str, &str)> = headers
149                .iter()
150                .map(|(k, v)| (*k, v.as_str()))
151                .collect();
152
153            // Make the API call
154            let response = client.post(&api_url, &headers_ref, &body).await?;
155
156            // Parse and return the response
157            types::parse_response(&response)
158        })
159    }
160
161    fn send_msg_stream(
162        &self,
163        client: &HttpClient,
164        messages: &[Message],
165        options: &MessageOptions,
166    ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
167        // Clone data for the async block
168        let client = client.clone();
169        let model = self.model.clone();
170        let api_url = self.api_url();
171        let headers = self.get_headers();
172        let messages = messages.to_vec();
173        let options = options.clone();
174
175        Box::pin(async move {
176            // Build streaming request body
177            let body = types::build_streaming_request_body(&messages, &options, &model)?;
178
179            // Get headers
180            let headers_ref: Vec<(&str, &str)> = headers
181                .iter()
182                .map(|(k, v)| (*k, v.as_str()))
183                .collect();
184
185            // Make the streaming API call
186            let byte_stream = client.post_stream(&api_url, &headers_ref, &body).await?;
187
188            // Convert byte stream to SSE events stream
189            use futures::StreamExt;
190            let event_stream = stream! {
191                let mut buffer = String::new();
192                let mut byte_stream = byte_stream;
193                let mut stream_state = sse::StreamState::default();
194
195                while let Some(chunk_result) = byte_stream.next().await {
196                    match chunk_result {
197                        Ok(bytes) => {
198                            // Append new bytes to buffer
199                            if let Ok(text) = std::str::from_utf8(&bytes) {
200                                buffer.push_str(text);
201                            } else {
202                                yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
203                                break;
204                            }
205
206                            // Parse complete SSE events from buffer
207                            let (events, remaining) = sse::parse_sse_chunk(&buffer);
208                            buffer = remaining;
209
210                            // Convert and yield each SSE event
211                            for sse_event in events {
212                                match sse::parse_stream_event(&sse_event, &mut stream_state) {
213                                    Ok(stream_events) => {
214                                        for stream_event in stream_events {
215                                            yield Ok(stream_event);
216                                        }
217                                    }
218                                    Err(e) => {
219                                        yield Err(e);
220                                        return;
221                                    }
222                                }
223                            }
224                        }
225                        Err(e) => {
226                            yield Err(e);
227                            break;
228                        }
229                    }
230                }
231            };
232
233            Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
234        })
235    }
236}