// agent_core_runtime/client/providers/openai/
mod.rs1mod sse;
2mod types;
3
4use async_stream::stream;
5use futures::Stream;
6
7use crate::client::error::LlmError;
8use crate::client::http::HttpClient;
9use crate::client::models::{Message, MessageOptions, StreamEvent};
10use crate::client::traits::LlmProvider;
11use std::future::Future;
12use std::pin::Pin;
13
/// Error code attached to `LlmError`s raised while decoding the streamed SSE bytes.
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";

/// Message used when the streamed response contains bytes that are not valid UTF-8.
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";

/// Connection settings for an Azure OpenAI deployment.
///
/// All three values are passed to `types::get_azure_api_url` to build the request URL.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AzureConfig {
    /// Azure OpenAI resource name (presumably the `<resource>` host prefix — confirm in
    /// `types::get_azure_api_url`).
    pub resource: String,
    /// Name of the deployment to send requests to.
    pub deployment: String,
    /// Azure REST API version string.
    pub api_version: String,
}

39pub struct OpenAIProvider {
46 api_key: String,
48 model: String,
50 base_url: Option<String>,
53 azure_config: Option<AzureConfig>,
55}
56
57impl OpenAIProvider {
58 pub fn new(api_key: String, model: String) -> Self {
60 Self {
61 api_key,
62 model,
63 base_url: None,
64 azure_config: None,
65 }
66 }
67
68 pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
73 Self {
74 api_key,
75 model,
76 base_url: Some(base_url),
77 azure_config: None,
78 }
79 }
80
81 pub fn azure(api_key: String, resource: String, deployment: String, api_version: String) -> Self {
87 Self {
88 api_key,
89 model: String::new(), base_url: None,
91 azure_config: Some(AzureConfig {
92 resource,
93 deployment,
94 api_version,
95 }),
96 }
97 }
98
99 pub fn model(&self) -> &str {
101 &self.model
102 }
103
104 pub fn is_azure(&self) -> bool {
106 self.azure_config.is_some()
107 }
108
109 fn api_url(&self) -> String {
111 if let Some(azure) = &self.azure_config {
112 types::get_azure_api_url(&azure.resource, &azure.deployment, &azure.api_version)
113 } else {
114 types::get_api_url_with_base(self.base_url.as_deref())
115 }
116 }
117
118 fn get_headers(&self) -> Vec<(&'static str, String)> {
120 if self.azure_config.is_some() {
121 types::get_azure_request_headers(&self.api_key)
122 } else {
123 types::get_request_headers(&self.api_key)
124 }
125 }
126}
127
128impl LlmProvider for OpenAIProvider {
129 fn send_msg(
130 &self,
131 client: &HttpClient,
132 messages: &[Message],
133 options: &MessageOptions,
134 ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
135 let client = client.clone();
137 let model = self.model.clone();
138 let api_url = self.api_url();
139 let headers = self.get_headers();
140 let messages = messages.to_vec();
141 let options = options.clone();
142
143 Box::pin(async move {
144 let body = types::build_request_body(&messages, &options, &model)?;
146
147 let headers_ref: Vec<(&str, &str)> = headers
149 .iter()
150 .map(|(k, v)| (*k, v.as_str()))
151 .collect();
152
153 let response = client.post(&api_url, &headers_ref, &body).await?;
155
156 types::parse_response(&response)
158 })
159 }
160
161 fn send_msg_stream(
162 &self,
163 client: &HttpClient,
164 messages: &[Message],
165 options: &MessageOptions,
166 ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
167 let client = client.clone();
169 let model = self.model.clone();
170 let api_url = self.api_url();
171 let headers = self.get_headers();
172 let messages = messages.to_vec();
173 let options = options.clone();
174
175 Box::pin(async move {
176 let body = types::build_streaming_request_body(&messages, &options, &model)?;
178
179 let headers_ref: Vec<(&str, &str)> = headers
181 .iter()
182 .map(|(k, v)| (*k, v.as_str()))
183 .collect();
184
185 let byte_stream = client.post_stream(&api_url, &headers_ref, &body).await?;
187
188 use futures::StreamExt;
190 let event_stream = stream! {
191 let mut buffer = String::new();
192 let mut byte_stream = byte_stream;
193 let mut stream_state = sse::StreamState::default();
194
195 while let Some(chunk_result) = byte_stream.next().await {
196 match chunk_result {
197 Ok(bytes) => {
198 if let Ok(text) = std::str::from_utf8(&bytes) {
200 buffer.push_str(text);
201 } else {
202 yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
203 break;
204 }
205
206 let (events, remaining) = sse::parse_sse_chunk(&buffer);
208 buffer = remaining;
209
210 for sse_event in events {
212 match sse::parse_stream_event(&sse_event, &mut stream_state) {
213 Ok(stream_events) => {
214 for stream_event in stream_events {
215 yield Ok(stream_event);
216 }
217 }
218 Err(e) => {
219 yield Err(e);
220 return;
221 }
222 }
223 }
224 }
225 Err(e) => {
226 yield Err(e);
227 break;
228 }
229 }
230 }
231 };
232
233 Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
234 })
235 }
236}