Skip to main content

openai_compat/
request.rs

1//! Request pipeline with retries, mirroring the send/retry loop in
2//! `_base_client.py:1006-1128`.
3
4use std::time::Duration;
5
6use bytes::Bytes;
7use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION};
8use reqwest::Method;
9use serde::de::DeserializeOwned;
10
11use crate::client::Client;
12use crate::error::OpenAIError;
13use crate::retry::{calculate_retry_timeout, parse_retry_after, should_retry};
14
/// One field of a `multipart/form-data` request body.
///
/// Fields are kept as plain, cloneable data instead of a ready-made
/// `reqwest::multipart::Form`, because a fresh form has to be assembled for
/// each retry attempt.
#[derive(Debug, Clone)]
pub(crate) enum MultipartField {
    /// A plain text form value.
    Text {
        /// Form field name.
        name: String,
        /// Text content of the field.
        value: String,
    },
    /// A file held in memory.
    File {
        /// Form field name.
        name: String,
        /// File name attached to the part.
        filename: String,
        /// Raw file contents.
        bytes: Vec<u8>,
        /// MIME type for the part; no type is set explicitly when `None`.
        content_type: Option<String>,
    },
}
30
31#[derive(Debug, Clone, Default)]
32pub(crate) enum Payload {
33    #[default]
34    None,
35    Json(serde_json::Value),
36    Multipart(Vec<MultipartField>),
37}
38
39/// Per-request options assembled by resource methods.
40#[derive(Debug, Clone, Default)]
41pub(crate) struct RequestOptions {
42    pub query: Vec<(String, String)>,
43    pub body: Payload,
44    pub extra_headers: Vec<(String, String)>,
45    pub timeout: Option<Duration>,
46}
47
48impl RequestOptions {
49    pub fn json(body: serde_json::Value) -> Self {
50        Self {
51            body: Payload::Json(body),
52            ..Self::default()
53        }
54    }
55
56    pub fn multipart(fields: Vec<MultipartField>) -> Self {
57        Self {
58            body: Payload::Multipart(fields),
59            ..Self::default()
60        }
61    }
62
63    pub fn query(query: Vec<(String, String)>) -> Self {
64        Self {
65            query,
66            ..Self::default()
67        }
68    }
69}
70
71fn build_multipart(fields: &[MultipartField]) -> Result<reqwest::multipart::Form, OpenAIError> {
72    let mut form = reqwest::multipart::Form::new();
73    for field in fields {
74        form = match field {
75            MultipartField::Text { name, value } => form.text(name.clone(), value.clone()),
76            MultipartField::File {
77                name,
78                filename,
79                bytes,
80                content_type,
81            } => {
82                let mut part = reqwest::multipart::Part::bytes(bytes.clone())
83                    .file_name(filename.clone());
84                if let Some(ct) = content_type {
85                    part = part.mime_str(ct).map_err(|e| {
86                        OpenAIError::Config(format!("invalid content type {ct:?}: {e}"))
87                    })?;
88                }
89                form.part(name.clone(), part)
90            }
91        };
92    }
93    Ok(form)
94}
95
96impl Client {
97    fn request_headers(&self, extra: &[(String, String)]) -> Result<HeaderMap, OpenAIError> {
98        let config = self.config();
99        let mut headers = HeaderMap::new();
100
101        if let Some(azure) = &config.azure {
102            // Azure auth replaces the default Bearer header entirely:
103            // `api-key: <key>` or `Authorization: Bearer <AD token>`.
104            let (name, value) = azure.auth.header();
105            let mut value = HeaderValue::from_str(&value).map_err(|_| {
106                OpenAIError::Config("Azure credential contains invalid header characters".into())
107            })?;
108            value.set_sensitive(true);
109            headers.insert(HeaderName::from_static(name), value);
110        } else {
111            let mut auth = HeaderValue::from_str(&format!("Bearer {}", config.api_key))
112                .map_err(|_| {
113                    OpenAIError::Config("API key contains invalid header characters".into())
114                })?;
115            auth.set_sensitive(true);
116            headers.insert(AUTHORIZATION, auth);
117        }
118
119        if let Some(org) = &config.organization {
120            headers.insert(
121                HeaderName::from_static("openai-organization"),
122                HeaderValue::from_str(org)
123                    .map_err(|_| OpenAIError::Config("invalid organization header value".into()))?,
124            );
125        }
126        if let Some(project) = &config.project {
127            headers.insert(
128                HeaderName::from_static("openai-project"),
129                HeaderValue::from_str(project)
130                    .map_err(|_| OpenAIError::Config("invalid project header value".into()))?,
131            );
132        }
133        for (name, value) in config.default_headers.iter().chain(extra.iter()) {
134            let name = HeaderName::from_bytes(name.as_bytes())
135                .map_err(|_| OpenAIError::Config(format!("invalid header name {name:?}")))?;
136            let value = HeaderValue::from_str(value)
137                .map_err(|_| OpenAIError::Config(format!("invalid value for header {name}")))?;
138            headers.insert(name, value);
139        }
140        Ok(headers)
141    }
142
143    fn build_request(
144        &self,
145        method: &Method,
146        path: &str,
147        options: &RequestOptions,
148        attempt: u32,
149    ) -> Result<reqwest::RequestBuilder, OpenAIError> {
150        let config = self.config();
151        let path = if let Some(azure) = &config.azure {
152            // Mirror azure.py::_prepare_url: only deployments endpoints get a
153            // `/deployments/{...}` segment — from the pinned deployment, or
154            // derived from the body's `model` when none is pinned. All other
155            // endpoints go straight under `{endpoint}/openai`.
156            if let Some(deployment) = azure.deployment.as_deref() {
157                if crate::azure::DEPLOYMENTS_ENDPOINTS.contains(&path) {
158                    format!("/deployments/{deployment}{path}")
159                } else {
160                    path.to_string()
161                }
162            } else {
163                let model = match &options.body {
164                    Payload::Json(value) => value.get("model").and_then(|m| m.as_str()),
165                    _ => None,
166                };
167                crate::azure::rewrite_path(path, false, model)
168            }
169        } else {
170            path.to_string()
171        };
172        let url = format!("{}{}", self.base_url(), path);
173        let mut builder = self
174            .http()
175            .request(method.clone(), url)
176            .headers(self.request_headers(&options.extra_headers)?)
177            .header("x-stainless-retry-count", attempt);
178        if !config.default_query.is_empty() {
179            builder = builder.query(&config.default_query);
180        }
181        if !options.query.is_empty() {
182            builder = builder.query(&options.query);
183        }
184        if let Some(timeout) = options.timeout {
185            builder = builder.timeout(timeout);
186        }
187        builder = match &options.body {
188            Payload::None => builder,
189            Payload::Json(value) => builder.json(value),
190            Payload::Multipart(fields) => builder.multipart(build_multipart(fields)?),
191        };
192        Ok(builder)
193    }
194
195    /// Send a request, retrying per `_base_client.py` semantics, and return
196    /// the successful raw response (used for streaming and binary bodies).
197    pub(crate) async fn execute_raw(
198        &self,
199        method: Method,
200        path: &str,
201        options: RequestOptions,
202    ) -> Result<reqwest::Response, OpenAIError> {
203        let max_retries = self.config().max_retries;
204
205        for attempt in 0..=max_retries {
206            let request = self.build_request(&method, path, &options, attempt)?;
207            match request.send().await {
208                Err(err) => {
209                    // Connection-level failure: always retryable while
210                    // attempts remain.
211                    if attempt < max_retries {
212                        tokio::time::sleep(calculate_retry_timeout(attempt, None, rand::random()))
213                            .await;
214                        continue;
215                    }
216                    return Err(if err.is_timeout() {
217                        OpenAIError::Timeout
218                    } else {
219                        OpenAIError::Connection(err.to_string())
220                    });
221                }
222                Ok(response) => {
223                    let status = response.status();
224                    if status.is_success() {
225                        return Ok(response);
226                    }
227
228                    let headers = response.headers().clone();
229                    let request_id = headers
230                        .get("x-request-id")
231                        .and_then(|v| v.to_str().ok())
232                        .map(str::to_owned);
233                    if attempt < max_retries && should_retry(status.as_u16(), &headers) {
234                        tokio::time::sleep(calculate_retry_timeout(
235                            attempt,
236                            parse_retry_after(&headers),
237                            rand::random(),
238                        ))
239                        .await;
240                        continue;
241                    }
242
243                    let body = response.text().await.unwrap_or_default();
244                    return Err(OpenAIError::from_response(status.as_u16(), request_id, &body));
245                }
246            }
247        }
248        unreachable!("retry loop always returns")
249    }
250
251    /// Execute a request and deserialize the JSON response body.
252    pub(crate) async fn execute<T: DeserializeOwned>(
253        &self,
254        method: Method,
255        path: &str,
256        options: RequestOptions,
257    ) -> Result<T, OpenAIError> {
258        let response = self.execute_raw(method, path, options).await?;
259        let bytes = response.bytes().await?;
260        Ok(serde_json::from_slice(&bytes)?)
261    }
262
263    /// Execute a request and return the raw response bytes (e.g. audio).
264    pub(crate) async fn execute_bytes(
265        &self,
266        method: Method,
267        path: &str,
268        options: RequestOptions,
269    ) -> Result<Bytes, OpenAIError> {
270        let response = self.execute_raw(method, path, options).await?;
271        Ok(response.bytes().await?)
272    }
273
274    /// Escape hatch: `GET` an arbitrary API path with query parameters.
275    ///
276    /// `T` may be [`serde_json::Value`] for untyped access.
277    pub async fn get<T: DeserializeOwned>(
278        &self,
279        path: &str,
280        query: &[(&str, &str)],
281    ) -> Result<T, OpenAIError> {
282        let query = query
283            .iter()
284            .map(|(k, v)| (k.to_string(), v.to_string()))
285            .collect();
286        self.execute(Method::GET, path, RequestOptions::query(query))
287            .await
288    }
289
290    /// Escape hatch: `POST` an arbitrary API path with a JSON body.
291    ///
292    /// `T` may be [`serde_json::Value`] for untyped access.
293    pub async fn post<B: serde::Serialize, T: DeserializeOwned>(
294        &self,
295        path: &str,
296        body: &B,
297    ) -> Result<T, OpenAIError> {
298        let body = serde_json::to_value(body)?;
299        self.execute(Method::POST, path, RequestOptions::json(body))
300            .await
301    }
302
303    /// Escape hatch: `DELETE` an arbitrary API path.
304    pub async fn delete<T: DeserializeOwned>(&self, path: &str) -> Result<T, OpenAIError> {
305        self.execute(Method::DELETE, path, RequestOptions::default())
306            .await
307    }
308}