1use std::time::Duration;
5
6use bytes::Bytes;
7use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
8use reqwest::Method;
9use serde::de::DeserializeOwned;
10
11use crate::client::Client;
12use crate::error::OpenAIError;
13use crate::retry::{calculate_retry_timeout, parse_retry_after, should_retry};
14
/// One field of a `multipart/form-data` request body.
///
/// Fields are turned into a `reqwest::multipart::Form` by `build_multipart` every time a
/// request attempt is built, so the same list can be reused across retries.
#[derive(Debug, Clone)]
pub(crate) enum MultipartField {
    /// A plain text form field.
    Text {
        /// Form field name.
        name: String,
        /// Field value.
        value: String,
    },
    /// A file part.
    File {
        /// Form field name.
        name: String,
        /// File name attached to the part.
        filename: String,
        /// Raw file contents; copied into the form on every attempt.
        bytes: Vec<u8>,
        /// Optional MIME type for the part; an invalid value makes form building fail
        /// with a configuration error.
        content_type: Option<String>,
    },
}
30
/// Body of an outgoing request.
#[derive(Debug, Clone, Default)]
pub(crate) enum Payload {
    /// No body. This is the default.
    #[default]
    None,
    /// A JSON body, attached with `RequestBuilder::json`.
    Json(serde_json::Value),
    /// A `multipart/form-data` body built from these fields.
    Multipart(Vec<MultipartField>),
}
38
/// Per-request settings consumed by the `Client` request helpers.
#[derive(Debug, Clone, Default)]
pub(crate) struct RequestOptions {
    /// Query parameters, appended after the client's default query parameters.
    pub query: Vec<(String, String)>,
    /// Request body.
    pub body: Payload,
    /// Extra headers; they replace client default headers that have the same name.
    pub extra_headers: Vec<(String, String)>,
    /// When set, applied with `RequestBuilder::timeout` to every attempt.
    pub timeout: Option<Duration>,
}
47
48impl RequestOptions {
49 pub fn json(body: serde_json::Value) -> Self {
50 Self {
51 body: Payload::Json(body),
52 ..Self::default()
53 }
54 }
55
56 pub fn multipart(fields: Vec<MultipartField>) -> Self {
57 Self {
58 body: Payload::Multipart(fields),
59 ..Self::default()
60 }
61 }
62
63 pub fn query(query: Vec<(String, String)>) -> Self {
64 Self {
65 query,
66 ..Self::default()
67 }
68 }
69}
70
71fn build_multipart(fields: &[MultipartField]) -> Result<reqwest::multipart::Form, OpenAIError> {
72 let mut form = reqwest::multipart::Form::new();
73 for field in fields {
74 form = match field {
75 MultipartField::Text { name, value } => form.text(name.clone(), value.clone()),
76 MultipartField::File {
77 name,
78 filename,
79 bytes,
80 content_type,
81 } => {
82 let mut part = reqwest::multipart::Part::bytes(bytes.clone())
83 .file_name(filename.clone());
84 if let Some(ct) = content_type {
85 part = part.mime_str(ct).map_err(|e| {
86 OpenAIError::Config(format!("invalid content type {ct:?}: {e}"))
87 })?;
88 }
89 form.part(name.clone(), part)
90 }
91 };
92 }
93 Ok(form)
94}
95
96impl Client {
97 fn request_headers(&self, extra: &[(String, String)]) -> Result<HeaderMap, OpenAIError> {
98 let config = self.config();
99 let mut headers = HeaderMap::new();
100
101 if let Some(azure) = &config.azure {
102 let (name, value) = azure.auth.header();
105 let mut value = HeaderValue::from_str(&value).map_err(|_| {
106 OpenAIError::Config("Azure credential contains invalid header characters".into())
107 })?;
108 value.set_sensitive(true);
109 headers.insert(HeaderName::from_static(name), value);
110 } else {
111 let mut auth = HeaderValue::from_str(&format!("Bearer {}", config.api_key))
112 .map_err(|_| {
113 OpenAIError::Config("API key contains invalid header characters".into())
114 })?;
115 auth.set_sensitive(true);
116 headers.insert(AUTHORIZATION, auth);
117 }
118
119 if let Some(org) = &config.organization {
120 headers.insert(
121 HeaderName::from_static("openai-organization"),
122 HeaderValue::from_str(org)
123 .map_err(|_| OpenAIError::Config("invalid organization header value".into()))?,
124 );
125 }
126 if let Some(project) = &config.project {
127 headers.insert(
128 HeaderName::from_static("openai-project"),
129 HeaderValue::from_str(project)
130 .map_err(|_| OpenAIError::Config("invalid project header value".into()))?,
131 );
132 }
133 for (name, value) in config.default_headers.iter().chain(extra.iter()) {
134 let name = HeaderName::from_bytes(name.as_bytes())
135 .map_err(|_| OpenAIError::Config(format!("invalid header name {name:?}")))?;
136 let value = HeaderValue::from_str(value)
137 .map_err(|_| OpenAIError::Config(format!("invalid value for header {name}")))?;
138 headers.insert(name, value);
139 }
140 Ok(headers)
141 }
142
143 fn build_request(
144 &self,
145 method: &Method,
146 path: &str,
147 options: &RequestOptions,
148 attempt: u32,
149 ) -> Result<reqwest::RequestBuilder, OpenAIError> {
150 let config = self.config();
151 let path = if let Some(azure) = &config.azure {
152 if let Some(deployment) = azure.deployment.as_deref() {
157 if crate::azure::DEPLOYMENTS_ENDPOINTS.contains(&path) {
158 format!("/deployments/{deployment}{path}")
159 } else {
160 path.to_string()
161 }
162 } else {
163 let model = match &options.body {
164 Payload::Json(value) => value.get("model").and_then(|m| m.as_str()),
165 _ => None,
166 };
167 crate::azure::rewrite_path(path, false, model)
168 }
169 } else {
170 path.to_string()
171 };
172 let url = format!("{}{}", self.base_url(), path);
173 let mut builder = self
174 .http()
175 .request(method.clone(), url)
176 .headers(self.request_headers(&options.extra_headers)?)
177 .header("x-stainless-retry-count", attempt);
178 if !config.default_query.is_empty() {
179 builder = builder.query(&config.default_query);
180 }
181 if !options.query.is_empty() {
182 builder = builder.query(&options.query);
183 }
184 if let Some(timeout) = options.timeout {
185 builder = builder.timeout(timeout);
186 }
187 builder = match &options.body {
188 Payload::None => builder,
189 Payload::Json(value) => builder.json(value),
190 Payload::Multipart(fields) => builder.multipart(build_multipart(fields)?),
191 };
192 Ok(builder)
193 }
194
195 pub(crate) async fn execute_raw(
198 &self,
199 method: Method,
200 path: &str,
201 options: RequestOptions,
202 ) -> Result<reqwest::Response, OpenAIError> {
203 let max_retries = self.config().max_retries;
204
205 for attempt in 0..=max_retries {
206 let request = self.build_request(&method, path, &options, attempt)?;
207 match request.send().await {
208 Err(err) => {
209 if attempt < max_retries {
212 tokio::time::sleep(calculate_retry_timeout(attempt, None, rand::random()))
213 .await;
214 continue;
215 }
216 return Err(if err.is_timeout() {
217 OpenAIError::Timeout
218 } else {
219 OpenAIError::Connection(err.to_string())
220 });
221 }
222 Ok(response) => {
223 let status = response.status();
224 if status.is_success() {
225 return Ok(response);
226 }
227
228 let headers = response.headers().clone();
229 let request_id = headers
230 .get("x-request-id")
231 .and_then(|v| v.to_str().ok())
232 .map(str::to_owned);
233 if attempt < max_retries && should_retry(status.as_u16(), &headers) {
234 tokio::time::sleep(calculate_retry_timeout(
235 attempt,
236 parse_retry_after(&headers),
237 rand::random(),
238 ))
239 .await;
240 continue;
241 }
242
243 let body = response.text().await.unwrap_or_default();
244 return Err(OpenAIError::from_response(status.as_u16(), request_id, &body));
245 }
246 }
247 }
248 unreachable!("retry loop always returns")
249 }
250
251 pub(crate) async fn execute<T: DeserializeOwned>(
253 &self,
254 method: Method,
255 path: &str,
256 options: RequestOptions,
257 ) -> Result<T, OpenAIError> {
258 let response = self.execute_raw(method, path, options).await?;
259 let bytes = response.bytes().await?;
260 Ok(serde_json::from_slice(&bytes)?)
261 }
262
263 pub(crate) async fn execute_bytes(
265 &self,
266 method: Method,
267 path: &str,
268 options: RequestOptions,
269 ) -> Result<Bytes, OpenAIError> {
270 let response = self.execute_raw(method, path, options).await?;
271 Ok(response.bytes().await?)
272 }
273
274 pub async fn get<T: DeserializeOwned>(
278 &self,
279 path: &str,
280 query: &[(&str, &str)],
281 ) -> Result<T, OpenAIError> {
282 let query = query
283 .iter()
284 .map(|(k, v)| (k.to_string(), v.to_string()))
285 .collect();
286 self.execute(Method::GET, path, RequestOptions::query(query))
287 .await
288 }
289
290 pub async fn post<B: serde::Serialize, T: DeserializeOwned>(
294 &self,
295 path: &str,
296 body: &B,
297 ) -> Result<T, OpenAIError> {
298 let body = serde_json::to_value(body)?;
299 self.execute(Method::POST, path, RequestOptions::json(body))
300 .await
301 }
302
303 pub async fn delete<T: DeserializeOwned>(&self, path: &str) -> Result<T, OpenAIError> {
305 self.execute(Method::DELETE, path, RequestOptions::default())
306 .await
307 }
308}