use crate::client::{HttpClient, HttpError, HttpErrorKind, HttpRequest, HttpRequestBody};
use http::header::{InvalidHeaderName, InvalidHeaderValue};
use http::uri::InvalidUri;
use http::{HeaderName, HeaderValue, Method, Uri};
/// Errors that can be recorded while a request is being assembled by
/// [`HttpRequestBuilder`].
///
/// Every variant wraps the underlying library error via `#[from]`, so the
/// builder methods can convert a failure with `.into()`.
#[derive(Debug, thiserror::Error)]
pub(crate) enum RequestBuilderError {
/// A value could not be converted into a valid URI or path-and-query.
#[error("Invalid URI")]
InvalidUri(#[from] InvalidUri),
/// A header value could not be converted into a [`HeaderValue`].
#[error("Invalid Header Value")]
InvalidHeaderValue(#[from] InvalidHeaderValue),
/// A header name could not be converted into a [`HeaderName`].
#[error("Invalid Header Name")]
InvalidHeaderName(#[from] InvalidHeaderName),
/// Serializing a value with `serde_json` failed.
#[error("JSON serialization error")]
SerdeJson(#[from] serde_json::Error),
/// Serializing a value with `serde_urlencoded` failed.
#[error("URL serialization error")]
SerdeUrl(#[from] serde_urlencoded::ser::Error),
}
impl From<RequestBuilderError> for HttpError {
fn from(value: RequestBuilderError) -> Self {
Self::new(HttpErrorKind::Request, value)
}
}
/// Allows conversions that cannot fail (whose error type is
/// [`std::convert::Infallible`]) to satisfy an `Into<RequestBuilderError>`
/// bound.
impl From<std::convert::Infallible> for RequestBuilderError {
    fn from(never: std::convert::Infallible) -> Self {
        // `Infallible` has no values, so an empty match is exhaustive and
        // this function can never actually be entered.
        match never {}
    }
}
/// Chainable builder that pairs an [`HttpClient`] with the request being
/// assembled.
///
/// The request is stored as a `Result` so that a failing builder step can be
/// recorded and reported later instead of every method returning a `Result`.
pub(crate) struct HttpRequestBuilder {
// Client associated with the request; returned as-is by `into_parts`.
client: HttpClient,
// The request under construction, or the error recorded by a failed step.
request: Result<HttpRequest, RequestBuilderError>,
}
impl HttpRequestBuilder {
    /// Creates a builder for `client` holding a new request with an empty body.
    pub(crate) fn new(client: HttpClient) -> Self {
        Self {
            client,
            request: Ok(HttpRequest::new(HttpRequestBody::empty())),
        }
    }

    /// Creates a builder that continues from an already constructed `request`.
    #[cfg(any(feature = "aws", feature = "azure"))]
    pub(crate) fn from_parts(client: HttpClient, request: HttpRequest) -> Self {
        Self {
            client,
            request: Ok(request),
        }
    }

    /// Sets the HTTP method.
    ///
    /// Does nothing if an earlier step already recorded an error.
    pub(crate) fn method(mut self, method: Method) -> Self {
        if let Ok(r) = &mut self.request {
            *r.method_mut() = method;
        }
        self
    }

    /// Sets the request URI from anything convertible into a [`Uri`].
    ///
    /// A failed conversion is recorded as the builder's error. If an error was
    /// already recorded, it is kept and this call has no effect.
    pub(crate) fn uri<U>(mut self, url: U) -> Self
    where
        U: TryInto<Uri>,
        U::Error: Into<RequestBuilderError>,
    {
        match (url.try_into(), &mut self.request) {
            (Ok(uri), Ok(r)) => *r.uri_mut() = uri,
            (Err(e), Ok(_)) => self.request = Err(e.into()),
            // Keep the error recorded by an earlier step.
            (_, Err(_)) => {}
        }
        self
    }

    /// Replaces the request's extensions with `extensions`.
    ///
    /// Does nothing if an earlier step already recorded an error.
    pub(crate) fn extensions(mut self, extensions: ::http::Extensions) -> Self {
        if let Ok(r) = &mut self.request {
            *r.extensions_mut() = extensions;
        }
        self
    }

    /// Inserts a header via `HeaderMap::insert`.
    ///
    /// If either conversion fails the error is recorded; when both the name
    /// and the value are invalid, the name error is the one reported. If an
    /// error was already recorded, it is kept and this call has no effect.
    pub(crate) fn header<K, V>(mut self, name: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        K::Error: Into<RequestBuilderError>,
        V: TryInto<HeaderValue>,
        V::Error: Into<RequestBuilderError>,
    {
        match (name.try_into(), value.try_into(), &mut self.request) {
            (Ok(name), Ok(value), Ok(r)) => {
                r.headers_mut().insert(name, value);
            }
            (Err(e), _, Ok(_)) => self.request = Err(e.into()),
            (_, Err(e), Ok(_)) => self.request = Err(e.into()),
            // Keep the error recorded by an earlier step.
            (_, _, Err(_)) => {}
        }
        self
    }

    /// Merges `headers` into the request's header map.
    ///
    /// For every header name in `headers`, the first value is written with
    /// the entry's `insert` and the remaining values for that name are
    /// appended to the same entry.
    ///
    /// Does nothing if an earlier step already recorded an error.
    #[cfg(feature = "aws")]
    pub(crate) fn headers(mut self, headers: http::HeaderMap) -> Self {
        use http::header::{Entry, OccupiedEntry};

        if let Ok(ref mut req) = self.request {
            // Entry of the most recently seen header name. The iterator
            // yields `None` as the key for further values of that same name.
            let mut prev_entry: Option<OccupiedEntry<'_, _>> = None;
            for (key, value) in headers {
                match key {
                    Some(key) => match req.headers_mut().entry(key) {
                        Entry::Occupied(mut e) => {
                            e.insert(value);
                            prev_entry = Some(e);
                        }
                        Entry::Vacant(e) => {
                            let e = e.insert_entry(value);
                            prev_entry = Some(e);
                        }
                    },
                    None => match prev_entry {
                        Some(ref mut entry) => {
                            entry.append(value);
                        }
                        None => unreachable!("HeaderMap::into_iter yielded None first"),
                    },
                }
            }
        }
        self
    }

    /// Sets the `Authorization` header to `Bearer {token}`.
    ///
    /// The header value is flagged as sensitive. A token that does not form a
    /// valid header value is recorded as the builder's error.
    #[cfg(feature = "gcp")]
    pub(crate) fn bearer_auth(mut self, token: &str) -> Self {
        let value = HeaderValue::try_from(format!("Bearer {token}"));
        match (value, &mut self.request) {
            (Ok(mut v), Ok(r)) => {
                v.set_sensitive(true);
                r.headers_mut().insert(http::header::AUTHORIZATION, v);
            }
            (Err(e), Ok(_)) => self.request = Err(e.into()),
            // Keep the error recorded by an earlier step.
            (_, Err(_)) => {}
        }
        self
    }

    /// Serializes `s` as JSON and uses the bytes as the request body.
    ///
    /// Only the body is changed; no headers are set by this method. A
    /// serialization failure is recorded as the builder's error.
    #[cfg(feature = "gcp")]
    pub(crate) fn json<S: serde::Serialize>(mut self, s: S) -> Self {
        match (serde_json::to_vec(&s), &mut self.request) {
            (Ok(json), Ok(request)) => {
                *request.body_mut() = json.into();
            }
            (Err(e), Ok(_)) => self.request = Err(e.into()),
            // Keep the error recorded by an earlier step.
            (_, Err(_)) => {}
        }
        self
    }

    /// Sets the URI's query string to the form-urlencoded serialization of
    /// `query`.
    ///
    /// The new path-and-query is built from the current URI *path* only, so
    /// any query already present on the URI is replaced rather than extended.
    ///
    /// If serialization fails, that error is recorded and the URI is left
    /// untouched. Does nothing if an earlier step already recorded an error.
    ///
    /// # Panics
    ///
    /// Panics if the URI parts cannot be reassembled with the new
    /// path-and-query (`Uri::from_parts` returns an error).
    #[cfg(any(test, feature = "aws", feature = "gcp", feature = "azure"))]
    pub(crate) fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
        let mut error = None;
        if let Ok(ref mut req) = self.request {
            let mut out = format!("{}?", req.uri().path());
            let start_position = out.len();
            let mut encoder = form_urlencoded::Serializer::for_suffix(&mut out, start_position);
            let serializer = serde_urlencoded::Serializer::new(&mut encoder);
            // Drop the encoder reference returned on success so that `out`
            // is no longer borrowed and can be moved below.
            let serialized = query.serialize(serializer).map(|_| ());
            match serialized {
                // Report the root cause and do not parse the partial output.
                Err(err) => error = Some(err.into()),
                Ok(()) => match http::uri::PathAndQuery::from_maybe_shared(out) {
                    Ok(p) => {
                        let mut parts = req.uri().clone().into_parts();
                        parts.path_and_query = Some(p);
                        *req.uri_mut() = Uri::from_parts(parts).unwrap();
                    }
                    Err(err) => error = Some(err.into()),
                },
            }
        }
        if let Some(err) = error {
            self.request = Err(err);
        }
        self
    }

    /// Serializes `form` as `application/x-www-form-urlencoded`, sets the
    /// matching `Content-Type` header and uses the result as the body.
    ///
    /// A serialization failure is recorded as the builder's error. Does
    /// nothing if an earlier step already recorded an error.
    #[cfg(any(feature = "gcp", feature = "azure"))]
    pub(crate) fn form<T: serde::Serialize>(mut self, form: T) -> Self {
        let mut error = None;
        if let Ok(ref mut req) = self.request {
            match serde_urlencoded::to_string(form) {
                Ok(body) => {
                    req.headers_mut().insert(
                        http::header::CONTENT_TYPE,
                        HeaderValue::from_static("application/x-www-form-urlencoded"),
                    );
                    *req.body_mut() = body.into();
                }
                Err(err) => error = Some(err.into()),
            }
        }
        if let Some(err) = error {
            self.request = Err(err);
        }
        self
    }

    /// Replaces the request body with `b`.
    ///
    /// Does nothing if an earlier step already recorded an error.
    #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
    pub(crate) fn body(mut self, b: impl Into<HttpRequestBody>) -> Self {
        if let Ok(r) = &mut self.request {
            *r.body_mut() = b.into();
        }
        self
    }

    /// Consumes the builder, returning the client together with either the
    /// assembled request or the error recorded while building it.
    pub(crate) fn into_parts(self) -> (HttpClient, Result<HttpRequest, RequestBuilderError>) {
        (self.client, self.request)
    }
}
/// Appends `query_pairs` to the query string of `uri`, form-urlencoding every
/// key and value.
///
/// A non-empty existing query is preserved and the new pairs are joined to it
/// with `&`. A URI without a path-and-query component gets the path `/`.
///
/// # Panics
///
/// Panics if the rebuilt string is not accepted as a path-and-query, or if
/// the URI parts cannot be reassembled (`Uri::from_parts` returns an error).
#[cfg(any(test, feature = "azure"))]
pub(crate) fn add_query_pairs<I, K, V>(uri: &mut Uri, query_pairs: I)
where
    I: IntoIterator,
    I::Item: std::borrow::Borrow<(K, V)>,
    K: AsRef<str>,
    V: AsRef<str>,
{
    let mut parts = uri.clone().into_parts();
    // Track whether a non-empty query already exists instead of inspecting
    // the last character: a query may legitimately end in '?', and such a
    // query still needs a '&' before the appended pairs.
    let (mut out, has_query) = match parts.path_and_query {
        Some(p) => match p.query() {
            Some(query) if !query.is_empty() => (format!("{}?{}", p.path(), query), true),
            _ => (format!("{}?", p.path()), false),
        },
        None => ("/?".to_string(), false),
    };
    let mut serializer = if has_query {
        // Starting at position 0 makes the serializer emit '&' before the
        // first new pair because the target string is already non-empty.
        form_urlencoded::Serializer::new(&mut out)
    } else {
        // Nothing follows the '?', so the first pair must not get a '&'.
        let start_position = out.len();
        form_urlencoded::Serializer::for_suffix(&mut out, start_position)
    };
    serializer.extend_pairs(query_pairs);
    parts.path_and_query = Some(out.try_into().unwrap());
    *uri = Uri::from_parts(parts).unwrap();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that `uri` renders exactly as `expected`.
    fn assert_uri(uri: &Uri, expected: &str) {
        let rendered = uri.to_string();
        assert_eq!(rendered, expected);
    }

    /// Unwraps the request held by `builder` and checks its URI renders as
    /// `expected`.
    fn assert_request_uri(builder: HttpRequestBuilder, expected: &str) {
        let (_client, request) = builder.into_parts();
        let rendered = request.unwrap().uri().to_string();
        assert_eq!(rendered, expected)
    }

    // Pairs added by successive calls accumulate and are percent-encoded.
    #[test]
    fn test_add_query_pairs() {
        let mut uri = Uri::from_static("https://foo@example.com/bananas");

        add_query_pairs(&mut uri, [("foo", "1")]);
        assert_uri(&uri, "https://foo@example.com/bananas?foo=1");

        add_query_pairs(&mut uri, [("bingo", "foo"), ("auth", "test")]);
        assert_uri(
            &uri,
            "https://foo@example.com/bananas?foo=1&bingo=foo&auth=test",
        );

        add_query_pairs(&mut uri, [("t1", "funky shenanigans"), ("a", "😀")]);
        assert_uri(
            &uri,
            "https://foo@example.com/bananas?foo=1&bingo=foo&auth=test&t1=funky+shenanigans&a=%F0%9F%98%80",
        );
    }

    // A URI without an explicit path gets "/" before the query.
    #[test]
    fn test_add_query_pairs_no_path() {
        let mut uri = Uri::from_static("https://foo@example.com");
        add_query_pairs(&mut uri, [("foo", "1")]);
        assert_uri(&uri, "https://foo@example.com/?foo=1");
    }

    // The builder's `query` step serializes pairs onto the request URI.
    #[test]
    fn test_request_builder_query() {
        let client = HttpClient::new(reqwest::Client::new());
        let builder = || HttpRequestBuilder::new(client.clone());

        let without_query = builder().uri("http://example.com/bananas");
        assert_request_uri(without_query, "http://example.com/bananas");

        let with_path = builder()
            .uri("http://example.com/bananas")
            .query(&[("foo", "1")]);
        assert_request_uri(with_path, "http://example.com/bananas?foo=1");

        let without_path = builder()
            .uri("http://example.com")
            .query(&[("foo", "1")]);
        assert_request_uri(without_path, "http://example.com/?foo=1");
    }
}