agent_air_runtime/client/providers/openai/
mod.rs1mod sse;
2mod types;
3
4use async_stream::stream;
5use futures::Stream;
6
7use crate::client::error::LlmError;
8use crate::client::http::HttpClient;
9use crate::client::models::{Message, MessageOptions, StreamEvent};
10use crate::client::traits::{LlmProvider, StreamMsgFuture};
11use std::future::Future;
12use std::pin::Pin;
13
/// Human-readable message paired with [`ERROR_SSE_DECODE`] when the streamed
/// response body cannot be decoded as UTF-8.
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";

/// Error code used for [`LlmError`]s raised while decoding the SSE byte stream.
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";
/// Connection settings for an Azure OpenAI deployment.
///
/// The three values are passed unchanged to `types::get_azure_api_url` to
/// build the request URL when a provider is created with
/// `OpenAIProvider::azure`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AzureConfig {
    /// Azure resource name (presumably the host prefix of the Azure endpoint;
    /// confirm against `types::get_azure_api_url`).
    pub resource: String,
    /// Name of the Azure deployment that requests are sent to.
    pub deployment: String,
    /// API version string forwarded to the URL builder.
    pub api_version: String,
}
38
39pub struct OpenAIProvider {
46 api_key: String,
48 model: String,
50 base_url: Option<String>,
53 azure_config: Option<AzureConfig>,
55}
56
57impl OpenAIProvider {
58 pub fn new(api_key: String, model: String) -> Self {
60 Self {
61 api_key,
62 model,
63 base_url: None,
64 azure_config: None,
65 }
66 }
67
68 pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
73 Self {
74 api_key,
75 model,
76 base_url: Some(base_url),
77 azure_config: None,
78 }
79 }
80
81 pub fn azure(
87 api_key: String,
88 resource: String,
89 deployment: String,
90 api_version: String,
91 ) -> Self {
92 Self {
93 api_key,
94 model: String::new(), base_url: None,
96 azure_config: Some(AzureConfig {
97 resource,
98 deployment,
99 api_version,
100 }),
101 }
102 }
103
104 pub fn model(&self) -> &str {
106 &self.model
107 }
108
109 pub fn is_azure(&self) -> bool {
111 self.azure_config.is_some()
112 }
113
114 fn api_url(&self) -> String {
116 if let Some(azure) = &self.azure_config {
117 types::get_azure_api_url(&azure.resource, &azure.deployment, &azure.api_version)
118 } else {
119 types::get_api_url_with_base(self.base_url.as_deref())
120 }
121 }
122
123 fn get_headers(&self) -> Vec<(&'static str, String)> {
125 if self.azure_config.is_some() {
126 types::get_azure_request_headers(&self.api_key)
127 } else {
128 types::get_request_headers(&self.api_key)
129 }
130 }
131}
132
133impl LlmProvider for OpenAIProvider {
134 fn send_msg(
135 &self,
136 client: &HttpClient,
137 messages: &[Message],
138 options: &MessageOptions,
139 ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
140 let client = client.clone();
142 let model = self.model.clone();
143 let api_url = self.api_url();
144 let headers = self.get_headers();
145 let messages = messages.to_vec();
146 let options = options.clone();
147
148 Box::pin(async move {
149 let body = types::build_request_body(&messages, &options, &model)?;
151
152 let headers_ref: Vec<(&str, &str)> =
154 headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
155
156 let response = client.post(&api_url, &headers_ref, &body).await?;
158
159 types::parse_response(&response)
161 })
162 }
163
164 fn send_msg_stream(
165 &self,
166 client: &HttpClient,
167 messages: &[Message],
168 options: &MessageOptions,
169 ) -> StreamMsgFuture {
170 let client = client.clone();
172 let model = self.model.clone();
173 let api_url = self.api_url();
174 let headers = self.get_headers();
175 let messages = messages.to_vec();
176 let options = options.clone();
177
178 Box::pin(async move {
179 let body = types::build_streaming_request_body(&messages, &options, &model)?;
181
182 let headers_ref: Vec<(&str, &str)> =
184 headers.iter().map(|(k, v)| (*k, v.as_str())).collect();
185
186 let byte_stream = client.post_stream(&api_url, &headers_ref, &body).await?;
188
189 use futures::StreamExt;
191 let event_stream = stream! {
192 let mut buffer = String::new();
193 let mut byte_stream = byte_stream;
194 let mut stream_state = sse::StreamState::default();
195
196 while let Some(chunk_result) = byte_stream.next().await {
197 match chunk_result {
198 Ok(bytes) => {
199 if let Ok(text) = std::str::from_utf8(&bytes) {
201 buffer.push_str(text);
202 } else {
203 yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
204 break;
205 }
206
207 let (events, remaining) = sse::parse_sse_chunk(&buffer);
209 buffer = remaining;
210
211 for sse_event in events {
213 match sse::parse_stream_event(&sse_event, &mut stream_state) {
214 Ok(stream_events) => {
215 for stream_event in stream_events {
216 yield Ok(stream_event);
217 }
218 }
219 Err(e) => {
220 yield Err(e);
221 return;
222 }
223 }
224 }
225 }
226 Err(e) => {
227 yield Err(e);
228 break;
229 }
230 }
231 }
232 };
233
234 Ok(Box::pin(event_stream)
235 as Pin<
236 Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>,
237 >)
238 })
239 }
240}