1use std::time::Duration;
9
10use bytes::Bytes;
11use futures::{Stream, StreamExt};
12use reqwest::multipart::{Form, Part};
13use reqwest::{Method, StatusCode};
14use reqwest_middleware::ClientWithMiddleware;
15use serde::{Serialize, de::DeserializeOwned};
16use url::Url;
17
18use crate::{
19 auth::SharedTokenProvider,
20 error::{ApiErrorResponse, SdkError},
21 interceptor::SharedInterceptor,
22};
23
24#[derive(Debug)]
26pub struct RequestSpec<'a, B: ?Sized = ()> {
27 pub method: Method,
28 pub path: &'a str,
31 pub query: &'a [(&'a str, Option<String>)],
33 pub body: Option<&'a B>,
35 pub extra_headers: &'a [(&'a str, String)],
37 pub timeout: Option<Duration>,
39}
40
41impl<B: ?Sized> Default for RequestSpec<'_, B> {
42 fn default() -> Self {
43 Self {
44 method: Method::GET,
45 path: "",
46 query: &[],
47 body: None,
48 extra_headers: &[],
49 timeout: None,
50 }
51 }
52}
53
/// Describes a single-file multipart upload, always sent as `POST`.
#[derive(Debug, Clone)]
pub struct MultipartRequestSpec<'a> {
    /// Path joined onto the transport's base URL.
    pub path: &'a str,
    /// Query parameters; entries whose value is `None` are left out of the URL.
    pub query: &'a [(&'a str, Option<String>)],
    /// Name of the form field that carries the file part.
    pub field_name: &'a str,
    /// File name attached to the part.
    pub filename: &'a str,
    /// MIME type of the part; an unparsable value is rejected before sending.
    pub content_type: &'a str,
    /// File contents; copied into the request.
    pub body: &'a [u8],
    /// Per-request timeout; `None` falls back to the transport's default timeout.
    pub timeout: Option<Duration>,
}
65
66#[derive(Debug, Clone)]
68pub struct Transport {
69 client: ClientWithMiddleware,
70 base_url: Url,
71 user_agent: String,
72 tokens: SharedTokenProvider,
73 interceptors: Vec<SharedInterceptor>,
74 default_timeout: Duration,
75}
76
77impl Transport {
78 pub(crate) fn new(
79 client: ClientWithMiddleware,
80 base_url: Url,
81 user_agent: String,
82 tokens: SharedTokenProvider,
83 interceptors: Vec<SharedInterceptor>,
84 default_timeout: Duration,
85 ) -> Self {
86 Self {
87 client,
88 base_url,
89 user_agent,
90 tokens,
91 interceptors,
92 default_timeout,
93 }
94 }
95
96 #[must_use]
98 pub fn base_url(&self) -> &Url {
99 &self.base_url
100 }
101
102 pub async fn request_json<B, R>(&self, spec: RequestSpec<'_, B>) -> Result<R, SdkError>
108 where
109 B: Serialize + ?Sized,
110 R: DeserializeOwned + 'static,
111 {
112 let url = self.resolve_url(spec.path)?;
113
114 let mut builder = self.client.request(spec.method.clone(), url.clone());
115 builder = builder.header(reqwest::header::USER_AGENT, &self.user_agent);
116
117 let token = self.tokens.token().await?;
118 builder = builder.header(reqwest::header::AUTHORIZATION, &token.authorization);
119
120 for (name, value) in spec.extra_headers {
121 builder = builder.header(*name, value);
122 }
123
124 let pairs: Vec<(&str, String)> = spec
126 .query
127 .iter()
128 .filter_map(|(k, v)| v.clone().map(|vv| (*k, vv)))
129 .collect();
130 if !pairs.is_empty() {
131 builder = builder.query(&pairs);
132 }
133
134 if let Some(body) = spec.body {
135 builder = builder.json(body);
136 }
137
138 if let Some(timeout) = spec.timeout {
139 builder = builder.timeout(timeout);
140 } else {
141 builder = builder.timeout(self.default_timeout);
142 }
143
144 let mut request = builder
145 .build()
146 .map_err(|e| SdkError::Serialize(e.to_string()))?;
147 for interceptor in &self.interceptors {
148 request = interceptor.on_request(request).await?;
149 }
150
151 let response = self.client.execute(request).await.map_err(SdkError::from)?;
152
153 for interceptor in &self.interceptors {
154 interceptor.on_response(&response).await?;
155 }
156
157 let status = response.status();
158 if status.is_success() {
159 decode_success::<R>(response).await
160 } else {
161 Err(decode_error(status, response).await)
162 }
163 }
164
165 #[allow(clippy::too_many_arguments)] pub async fn request_json_raw_body(
171 &self,
172 method: Method,
173 path: &str,
174 query: &[(&str, Option<String>)],
175 body: Vec<u8>,
176 content_type: &str,
177 extra_headers: &[(&str, String)],
178 timeout: Option<Duration>,
179 ) -> Result<serde_json::Value, SdkError> {
180 let url = self.resolve_url(path)?;
181
182 let ct = reqwest::header::HeaderValue::from_str(content_type)
183 .map_err(|e| SdkError::Serialize(format!("content-type: {e}")))?;
184
185 let mut builder = self.client.request(method.clone(), url.clone());
186 builder = builder.header(reqwest::header::USER_AGENT, &self.user_agent);
187 builder = builder.header(reqwest::header::CONTENT_TYPE, ct);
188
189 let token = self.tokens.token().await?;
190 builder = builder.header(reqwest::header::AUTHORIZATION, &token.authorization);
191
192 for (name, value) in extra_headers {
193 builder = builder.header(*name, value);
194 }
195
196 let pairs: Vec<(&str, String)> = query
197 .iter()
198 .filter_map(|(k, v)| v.clone().map(|vv| (*k, vv)))
199 .collect();
200 if !pairs.is_empty() {
201 builder = builder.query(&pairs);
202 }
203
204 builder = builder.body(body);
205
206 if let Some(t) = timeout {
207 builder = builder.timeout(t);
208 } else {
209 builder = builder.timeout(self.default_timeout);
210 }
211
212 let mut request = builder
213 .build()
214 .map_err(|e| SdkError::Serialize(e.to_string()))?;
215 for interceptor in &self.interceptors {
216 request = interceptor.on_request(request).await?;
217 }
218
219 let response = self.client.execute(request).await.map_err(SdkError::from)?;
220
221 for interceptor in &self.interceptors {
222 interceptor.on_response(&response).await?;
223 }
224
225 let status = response.status();
226 if status.is_success() {
227 decode_success::<serde_json::Value>(response).await
228 } else {
229 Err(decode_error(status, response).await)
230 }
231 }
232
233 pub async fn request_stream<B>(
241 &self,
242 spec: RequestSpec<'_, B>,
243 ) -> Result<
244 (
245 StatusCode,
246 impl Stream<Item = Result<Bytes, SdkError>> + Send,
247 ),
248 SdkError,
249 >
250 where
251 B: Serialize + ?Sized,
252 {
253 let url = self.resolve_url(spec.path)?;
254
255 let mut builder = self.client.request(spec.method.clone(), url.clone());
256 builder = builder.header(reqwest::header::USER_AGENT, &self.user_agent);
257
258 let token = self.tokens.token().await?;
259 builder = builder.header(reqwest::header::AUTHORIZATION, &token.authorization);
260
261 for (name, value) in spec.extra_headers {
262 builder = builder.header(*name, value);
263 }
264
265 let pairs: Vec<(&str, String)> = spec
266 .query
267 .iter()
268 .filter_map(|(k, v)| v.clone().map(|vv| (*k, vv)))
269 .collect();
270 if !pairs.is_empty() {
271 builder = builder.query(&pairs);
272 }
273
274 if let Some(body) = spec.body {
275 builder = builder.json(body);
276 }
277
278 if let Some(timeout) = spec.timeout {
279 builder = builder.timeout(timeout);
280 } else {
281 builder = builder.timeout(self.default_timeout);
282 }
283
284 let mut request = builder
285 .build()
286 .map_err(|e| SdkError::Serialize(e.to_string()))?;
287 for interceptor in &self.interceptors {
288 request = interceptor.on_request(request).await?;
289 }
290
291 let response = self.client.execute(request).await.map_err(SdkError::from)?;
292
293 for interceptor in &self.interceptors {
294 interceptor.on_response(&response).await?;
295 }
296
297 let status = response.status();
298 if status.is_success() {
299 let stream = response.bytes_stream().map(|r| r.map_err(SdkError::from));
300 Ok((status, stream))
301 } else {
302 Err(decode_error(status, response).await)
303 }
304 }
305
306 pub async fn request_multipart<R>(&self, spec: MultipartRequestSpec<'_>) -> Result<R, SdkError>
308 where
309 R: DeserializeOwned + 'static,
310 {
311 let url = self.resolve_url(spec.path)?;
312
313 let part = Part::bytes(spec.body.to_vec())
314 .file_name(spec.filename.to_string())
315 .mime_str(spec.content_type)
316 .map_err(|e| SdkError::Serialize(format!("multipart: {e}")))?;
317 let form = Form::new().part(spec.field_name.to_string(), part);
318
319 let mut builder = self.client.request(Method::POST, url.clone());
320 builder = builder.header(reqwest::header::USER_AGENT, &self.user_agent);
321
322 let token = self.tokens.token().await?;
323 builder = builder.header(reqwest::header::AUTHORIZATION, &token.authorization);
324
325 let pairs: Vec<(&str, String)> = spec
326 .query
327 .iter()
328 .filter_map(|(k, v)| v.clone().map(|vv| (*k, vv)))
329 .collect();
330 if !pairs.is_empty() {
331 builder = builder.query(&pairs);
332 }
333
334 builder = builder.multipart(form);
335
336 if let Some(timeout) = spec.timeout {
337 builder = builder.timeout(timeout);
338 } else {
339 builder = builder.timeout(self.default_timeout);
340 }
341
342 let mut request = builder
343 .build()
344 .map_err(|e| SdkError::Serialize(e.to_string()))?;
345 for interceptor in &self.interceptors {
346 request = interceptor.on_request(request).await?;
347 }
348
349 let response = self.client.execute(request).await.map_err(SdkError::from)?;
350
351 for interceptor in &self.interceptors {
352 interceptor.on_response(&response).await?;
353 }
354
355 let status = response.status();
356 if status.is_success() {
357 decode_success::<R>(response).await
358 } else {
359 Err(decode_error(status, response).await)
360 }
361 }
362
363 fn resolve_url(&self, path: &str) -> Result<Url, SdkError> {
364 let path = path.strip_prefix('/').unwrap_or(path);
365 let mut base = self.base_url.clone();
367 if !base.path().ends_with('/') {
368 let p = format!("{}/", base.path());
369 base.set_path(&p);
370 }
371 base.join(path)
372 .map_err(|e| SdkError::Config(format!("could not build URL from path {path}: {e}")))
373 }
374}
375
376async fn decode_success<R: DeserializeOwned + 'static>(
377 response: reqwest::Response,
378) -> Result<R, SdkError> {
379 if std::any::TypeId::of::<R>() == std::any::TypeId::of::<()>() {
382 let _ = response.bytes().await.map_err(SdkError::from)?;
384 return serde_json::from_str::<R>("null").map_err(SdkError::from);
386 }
387
388 let bytes = response.bytes().await.map_err(SdkError::from)?;
389 if bytes.is_empty() {
390 return serde_json::from_str::<R>("null").map_err(SdkError::from);
393 }
394 serde_json::from_slice::<R>(&bytes).map_err(SdkError::from)
395}
396
397async fn decode_error(status: StatusCode, response: reqwest::Response) -> SdkError {
398 let status_code = status.as_u16();
399 let bytes = match response.bytes().await {
400 Ok(b) => b,
401 Err(err) => {
402 return SdkError::Http {
403 status: status_code,
404 message: format!("failed to read error body: {err}"),
405 };
406 }
407 };
408
409 if status_code == 401 || status_code == 403 {
410 let message = serde_json::from_slice::<ApiErrorResponse>(&bytes).map_or_else(
411 |_| String::from_utf8_lossy(&bytes).to_string(),
412 |b| b.message,
413 );
414 return SdkError::Auth(message);
415 }
416
417 match serde_json::from_slice::<ApiErrorResponse>(&bytes) {
418 Ok(body) => SdkError::Api {
419 status: status_code,
420 body,
421 },
422 Err(_) => SdkError::Http {
423 status: status_code,
424 message: String::from_utf8_lossy(&bytes).to_string(),
425 },
426 }
427}