1use std::future::IntoFuture;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use futures::future::BoxFuture;
6use http::header::{Entry, HeaderName, ACCEPT, AUTHORIZATION, CONTENT_TYPE, COOKIE};
7use http::uri::PathAndQuery;
8use http::{header, HeaderMap, HeaderValue, Method, Uri, Version};
9use serde::Serialize;
10use serde_json::Value;
11
12use crate::error::ProtocolResult;
13use crate::middleware::Next;
14use crate::multipart::Form;
15use crate::{Client, Error, InMemoryBody, InMemoryResponse, Middleware, Request, Response};
16
/// `Accept` value asking for JSON; `set_json` inserts it when no `Accept` header is present.
pub static ACCEPT_JSON: HeaderValue = HeaderValue::from_static("application/json");
/// `Content-Type` for UTF-8 JSON bodies; `set_json` inserts it when no `Content-Type` is present.
pub static CONTENT_JSON: HeaderValue = HeaderValue::from_static("application/json; charset=utf-8");
/// `Content-Type` for URL-encoded form bodies; `form` inserts it when no `Content-Type` is present.
pub static CONTENT_URL_ENCODED: HeaderValue = HeaderValue::from_static("application/x-www-form-urlencoded");
20
/// Builder for an HTTP request associated with a client of type `C` and
/// carrying a body of type `B`.
///
/// All request-shaping fields are public, so they may also be edited directly
/// instead of through the builder methods.
#[derive(Debug)]
pub struct RequestBuilder<'a, C = Client, B = InMemoryBody> {
    /// Client the request belongs to; for `C = Client`, `send` passes it to
    /// the middleware chain.
    client: &'a C,

    /// HTTP version used when the request is built.
    pub version: Version,
    /// HTTP method of the request.
    pub method: Method,
    /// Target URI; `query` / `set_query` rewrite its path-and-query part.
    pub uri: Uri,
    /// Request headers.
    pub headers: HeaderMap,
    /// Request body; `None` until a body method is called. `build` uses
    /// `B::default()` if it is still `None`.
    pub body: Option<B>,
    /// Middleware handed to the middleware chain by `send`.
    pub middlewares: Vec<Arc<dyn Middleware>>,
}
40
41impl<'a> RequestBuilder<'a, ()> {
42 pub fn get(url: &str) -> RequestBuilder<'a, ()> {
43 RequestBuilder::new(&(), Method::GET, Uri::from_str(url).expect("Invalid URL"))
44 }
45 pub fn post(url: &str) -> RequestBuilder<'a, ()> {
46 RequestBuilder::new(&(), Method::POST, Uri::from_str(url).expect("Invalid URL"))
47 }
48 pub fn put(url: &str) -> RequestBuilder<'a, ()> {
49 RequestBuilder::new(&(), Method::PUT, Uri::from_str(url).expect("Invalid URL"))
50 }
51 pub fn delete(url: &str) -> RequestBuilder<'a, ()> {
52 RequestBuilder::new(&(), Method::DELETE, Uri::from_str(url).expect("Invalid URL"))
53 }
54 pub fn head(url: &str) -> RequestBuilder<'a, ()> {
55 RequestBuilder::new(&(), Method::HEAD, Uri::from_str(url).expect("Invalid URL"))
56 }
57}
58
59impl<'a, C> RequestBuilder<'a, C> {
60 pub fn new(client: &'a C, method: Method, uri: Uri) -> RequestBuilder<'a, C, InMemoryBody> {
61 RequestBuilder {
62 client,
63 version: Default::default(),
64 method,
65 uri,
66 headers: Default::default(),
67 body: Default::default(),
68 middlewares: Default::default(),
69 }
70 }
71
72 #[must_use]
73 pub fn form<S: Serialize>(mut self, obj: S) -> Self {
74 match self.body {
75 None => {
76 let body = serde_qs::to_string(&obj).unwrap();
77 self.body = Some(InMemoryBody::Text(body));
78 self.headers.entry(CONTENT_TYPE).or_insert(CONTENT_URL_ENCODED.clone());
79 self.headers.entry(ACCEPT).or_insert(HeaderValue::from_static("html/text"));
80 self
81 }
82 Some(InMemoryBody::Text(ref mut body)) => {
83 let new_body = serde_qs::to_string(&obj).unwrap();
84 body.push('&');
85 body.push_str(&new_body);
86 self
87 }
88 _ => {
89 panic!("Cannot add form to non-form body");
90 }
91 }
92 }
93
94 #[must_use]
96 pub fn set_json<S: Serialize>(mut self, obj: S) -> Self {
97 self.body = Some(InMemoryBody::Json(serde_json::to_value(obj).unwrap()));
98 self.headers.entry(CONTENT_TYPE).or_insert(CONTENT_JSON.clone());
99 self.headers.entry(ACCEPT).or_insert(ACCEPT_JSON.clone());
100 self
101 }
102
103 #[must_use]
105 pub fn json<S: Serialize>(mut self, obj: S) -> Self {
106 match self.body {
107 None => self.set_json(obj),
108 Some(InMemoryBody::Json(Value::Object(ref mut body))) => {
109 if let Value::Object(obj) = serde_json::to_value(obj).unwrap() {
110 body.extend(obj);
111 } else {
112 panic!("Tried to push a non-object to a json body.");
113 }
114 self
115 }
116 _ => panic!("Tried to call .json() on a non-json body. Use .set_json if you need to force a json body."),
117 }
118 }
119
120 #[must_use]
122 pub fn bytes(mut self, bytes: Vec<u8>) -> Self {
123 self.body = Some(InMemoryBody::Bytes(bytes));
125 self.headers.entry(CONTENT_TYPE).or_insert(HeaderValue::from_static("application/octet-stream"));
126 self
127 }
128
129 #[must_use]
131 pub fn text(mut self, text: String) -> Self {
132 self.body = Some(InMemoryBody::Text(text));
134 self.headers.entry(CONTENT_TYPE).or_insert(HeaderValue::from_static("text/plain"));
135 self
136 }
137
138 #[must_use]
139 pub fn multipart<B>(mut self, form: Form<B>) -> Self
140 where
141 Form<B>: Into<Vec<u8>>,
142 {
143 let content_type = form.full_content_type();
144 self.headers.entry(CONTENT_TYPE).or_insert(content_type.parse().unwrap());
145 let body: Vec<u8> = form.into();
146 match String::from_utf8(body) {
148 Ok(text) => self.body = Some(InMemoryBody::Text(text)),
149 Err(bytes) => self.body = Some(InMemoryBody::Bytes(bytes.into_bytes())),
150 }
151 self
153 }
154}
155
156impl<'a> RequestBuilder<'a> {
157 pub async fn send(self) -> ProtocolResult<Response> {
160 let client = self.client;
161 let (request, middlewares) = self.into_req_and_middleware();
162 let next = Next {
163 client,
164 middlewares: &middlewares,
165 };
166 next.run(request).await
167 }
168}
169
170impl<'a, C, B: Default> RequestBuilder<'a, C, B> {
171 pub fn build(self) -> Request<B> {
172 let mut b = Request::builder().method(self.method).uri(self.uri).version(self.version);
173 *b.headers_mut().unwrap() = self.headers;
174 b.body(self.body.unwrap_or_default()).expect("Failed to build request in .build")
175 }
176
177 pub fn into_req_and_middleware(self) -> (Request<B>, Vec<Arc<dyn Middleware>>) {
178 let mut request = http::Request::builder().method(self.method).uri(self.uri).version(self.version);
179 *request.headers_mut().unwrap() = self.headers;
180 let request = request.body(self.body.unwrap_or_default().into()).unwrap();
181 (request, self.middlewares)
182 }
183}
184
185impl<'a, C, B> RequestBuilder<'a, C, B> {
186 pub fn for_client(client: &'a C) -> RequestBuilder<'a, C> {
187 RequestBuilder {
188 client,
189 version: Default::default(),
190 method: Default::default(),
191 uri: Default::default(),
192 headers: Default::default(),
193 body: Default::default(),
194 middlewares: Default::default(),
195 }
196 }
197
198 #[must_use]
199 pub fn method(mut self, method: Method) -> Self {
200 self.method = method;
201 self
202 }
203
204 #[must_use]
205 pub fn url(mut self, uri: &str) -> Self {
206 self.uri = Uri::from_str(uri).expect("Invalid URI");
207 self
208 }
209
210 #[must_use]
211 pub fn set_headers<S: AsRef<str>, I: Iterator<Item = (S, S)>>(mut self, headers: I) -> Self {
212 self.headers = HeaderMap::new();
213 self.headers(headers)
214 }
215
216 #[must_use]
217 pub fn headers<S: AsRef<str>, I: Iterator<Item = (S, S)>>(mut self, headers: I) -> Self {
218 self.headers
219 .extend(headers.map(|(k, v)| (HeaderName::from_str(k.as_ref()).unwrap(), HeaderValue::from_str(v.as_ref()).unwrap())));
220 self
221 }
222
223 #[must_use]
224 pub fn header<K: TryInto<HeaderName>>(mut self, key: K, value: &str) -> Self
225 where
226 <K as TryInto<HeaderName>>::Error: std::fmt::Debug,
227 {
228 let header = key.try_into().expect("Failed to convert key to HeaderName");
229 self.headers.insert(header, HeaderValue::from_str(value).unwrap());
230 self
231 }
232
233 #[must_use]
234 pub fn cookie(mut self, key: &str, value: &str) -> Self {
235 match self.headers.entry(COOKIE) {
236 Entry::Occupied(mut e) => {
237 let v = e.get_mut();
238 *v = HeaderValue::from_str(&format!("{}; {}={}", v.to_str().unwrap(), key, value)).unwrap();
239 }
240 Entry::Vacant(_) => {
241 let value = HeaderValue::from_str(&format!("{key}={value}")).unwrap();
242 self.headers.insert(COOKIE, value);
243 }
244 }
245 self
246 }
247
248 #[must_use]
249 pub fn bearer_auth(mut self, token: &str) -> Self {
250 self.headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {token}")).unwrap());
251 self
252 }
253
254 #[must_use]
255 pub fn token_auth(mut self, token: &str) -> Self {
256 self.headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Token {token}")).unwrap());
257 self
258 }
259
260 #[must_use]
261 pub fn basic_auth(mut self, token: &str) -> Self {
262 self.headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Basic {token}")).unwrap());
263 self
264 }
265
266 #[must_use]
268 pub fn set_query<S: Serialize>(mut self, obj: S) -> Self {
269 let qs = serde_qs::to_string(&obj).expect("Failed to serialize query in .set_query");
270 let mut parts = std::mem::take(&mut self.uri).into_parts();
271 let pq = parts.path_and_query.unwrap();
272 let pq = PathAndQuery::from_str(&format!("{}?{}", pq.path(), qs)).unwrap();
273 parts.path_and_query = Some(pq);
274 self.uri = Uri::from_parts(parts).unwrap();
275 self
276 }
277
278 #[must_use]
288 pub fn query(mut self, k: &str, v: &str) -> Self {
289 let mut parts = std::mem::take(&mut self.uri).into_parts();
290 let pq = parts.path_and_query.unwrap();
291 let pq = PathAndQuery::from_str(
292 match pq.query() {
293 Some(q) => format!("{}?{}&{}={}", pq.path(), q, urlencoding::encode(k), urlencoding::encode(v)),
294 None => format!("{}?{}={}", pq.path(), urlencoding::encode(k), urlencoding::encode(v)),
295 }
296 .as_str(),
297 )
298 .unwrap();
299 parts.path_and_query = Some(pq);
300 self.uri = Uri::from_parts(parts).unwrap();
301 self
302 }
303
304 #[must_use]
305 pub fn content_type(mut self, content_type: &str) -> Self {
306 self.headers.insert(header::CONTENT_TYPE, content_type.parse().unwrap());
307 self
308 }
309
310 #[must_use]
312 pub fn body(mut self, body: B) -> Self {
313 self.body = Some(body);
314 self
315 }
316
317 #[must_use]
318 pub fn set_middlewares(mut self, middlewares: Vec<Arc<dyn Middleware>>) -> Self {
319 self.middlewares = middlewares;
320 self
321 }
322
323 #[must_use]
324 pub fn middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
325 self.middlewares.push(middleware);
326 self
327 }
328}
329
330impl<'a> IntoFuture for RequestBuilder<'a, Client> {
331 type Output = crate::InMemoryResult<InMemoryResponse>;
332 type IntoFuture = BoxFuture<'a, Self::Output>;
333
334 fn into_future(self) -> Self::IntoFuture {
335 Box::pin(async move {
336 let res = self.send().await;
337 let res = match res {
338 Ok(res) => res,
339 Err(e) => return Err(e.into()),
340 };
341 let (parts, body) = res.into_parts();
342 let mut body = match body.into_memory().await {
343 Ok(body) => body,
344 Err(e) => return Err(e.into()),
345 };
346 if let InMemoryBody::Bytes(bytes) = body {
347 body = match String::from_utf8(bytes) {
348 Ok(text) => InMemoryBody::Text(text),
349 Err(e) => InMemoryBody::Bytes(e.into_bytes()),
350 };
351 }
352 let status = &parts.status;
353 if status.is_client_error() || status.is_server_error() {
354 Err(Error::HttpError(InMemoryResponse::from_parts(parts, body)))
356 } else {
357 Ok(InMemoryResponse::from_parts(parts, body))
358 }
359 })
360 }
361}
362
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;

    /// Outer query object wrapping a nested struct.
    #[derive(Serialize, Deserialize)]
    pub struct TopLevel {
        inside: Nested,
    }

    /// Inner query object with a single numeric field.
    #[derive(Serialize, Deserialize)]
    pub struct Nested {
        a: usize,
    }

    /// `set_query` renders nested structs with `serde_qs` bracket syntax.
    #[test]
    fn test_query() {
        let client = Client::new();
        let params = TopLevel {
            inside: Nested { a: 1 },
        };
        let request = client.get("/api").set_query(params).build();
        let rendered = request.uri().to_string();
        assert_eq!(rendered, "/api?inside[a]=1");
    }
}
386}