object_store 0.13.2

A generic object store interface for uniformly interacting with AWS S3, Google Cloud Storage, Azure Blob Storage and local files.
Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::client::{HttpClient, HttpError, HttpErrorKind, HttpRequest, HttpRequestBody};
use http::header::{InvalidHeaderName, InvalidHeaderValue};
use http::uri::InvalidUri;
use http::{HeaderName, HeaderValue, Method, Uri};

#[derive(Debug, thiserror::Error)]
pub(crate) enum RequestBuilderError {
    #[error("Invalid URI")]
    InvalidUri(#[from] InvalidUri),

    #[error("Invalid Header Value")]
    InvalidHeaderValue(#[from] InvalidHeaderValue),

    #[error("Invalid Header Name")]
    InvalidHeaderName(#[from] InvalidHeaderName),

    #[error("JSON serialization error")]
    SerdeJson(#[from] serde_json::Error),

    #[error("URL serialization error")]
    SerdeUrl(#[from] serde_urlencoded::ser::Error),
}

impl From<RequestBuilderError> for HttpError {
    fn from(value: RequestBuilderError) -> Self {
        Self::new(HttpErrorKind::Request, value)
    }
}

// `Infallible` has no values, so this conversion can never run. It exists so that the generic
// setters bounded on `Error: Into<RequestBuilderError>` also accept arguments that already are a
// `Uri` / `HeaderName` / `HeaderValue`, whose `TryInto` error type is `Infallible`.
impl From<std::convert::Infallible> for RequestBuilderError {
    fn from(never: std::convert::Infallible) -> Self {
        match never {}
    }
}

/// Incrementally builds an [`HttpRequest`] paired with an [`HttpClient`].
///
/// Every setter consumes and returns the builder. As soon as a setter fails, `request` holds
/// the error, and all later setters leave it untouched. The outcome is retrieved with
/// [`HttpRequestBuilder::into_parts`].
pub(crate) struct HttpRequestBuilder {
    /// Client returned alongside the request by [`HttpRequestBuilder::into_parts`].
    client: HttpClient,
    /// The request being built, or the error that aborted building it.
    request: Result<HttpRequest, RequestBuilderError>,
}

impl HttpRequestBuilder {
    /// Creates a builder for `client` whose request starts with an empty body.
    pub(crate) fn new(client: HttpClient) -> Self {
        Self {
            client,
            request: Ok(HttpRequest::new(HttpRequestBody::empty())),
        }
    }

    /// Creates a builder that continues from an already-constructed `request`.
    #[cfg(any(feature = "aws", feature = "azure"))]
    pub(crate) fn from_parts(client: HttpClient, request: HttpRequest) -> Self {
        Self {
            client,
            request: Ok(request),
        }
    }

    /// Sets the HTTP method.
    pub(crate) fn method(mut self, method: Method) -> Self {
        if let Ok(r) = &mut self.request {
            *r.method_mut() = method;
        }
        self
    }

    /// Sets the request URI.
    ///
    /// `url` is converted with `TryInto<Uri>`. A failed conversion latches an error.
    pub(crate) fn uri<U>(mut self, url: U) -> Self
    where
        U: TryInto<Uri>,
        U::Error: Into<RequestBuilderError>,
    {
        match (url.try_into(), &mut self.request) {
            (Ok(uri), Ok(r)) => *r.uri_mut() = uri,
            (Err(e), Ok(_)) => self.request = Err(e.into()),
            (_, Err(_)) => {}
        }
        self
    }

    /// Replaces the request's extensions wholesale with `extensions`.
    pub(crate) fn extensions(mut self, extensions: ::http::Extensions) -> Self {
        if let Ok(r) = &mut self.request {
            *r.extensions_mut() = extensions;
        }
        self
    }

    /// Sets header `name` to `value`, replacing every value already stored under that name.
    ///
    /// If the name or the value fails to convert, an error is latched. When both fail, the
    /// name error is the one kept.
    pub(crate) fn header<K, V>(mut self, name: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        K::Error: Into<RequestBuilderError>,
        V: TryInto<HeaderValue>,
        V::Error: Into<RequestBuilderError>,
    {
        match (name.try_into(), value.try_into(), &mut self.request) {
            (Ok(name), Ok(value), Ok(r)) => {
                r.headers_mut().insert(name, value);
            }
            (Err(e), _, Ok(_)) => self.request = Err(e.into()),
            (_, Err(e), Ok(_)) => self.request = Err(e.into()),
            (_, _, Err(_)) => {}
        }
        self
    }

    /// Merges `headers` into the request.
    ///
    /// For every name present in `headers`, all existing values for that name are replaced
    /// by the (possibly multiple) values from `headers`. Names absent from `headers` are left
    /// untouched.
    #[cfg(feature = "aws")]
    pub(crate) fn headers(mut self, headers: http::HeaderMap) -> Self {
        use http::header::{Entry, OccupiedEntry};

        if let Ok(ref mut req) = self.request {
            // IntoIter of HeaderMap yields (Option<HeaderName>, HeaderValue).
            // The first time a name is yielded, it will be Some(name), and if
            // there are more values with the same name, the next yield will be
            // None.

            let mut prev_entry: Option<OccupiedEntry<'_, _>> = None;
            for (key, value) in headers {
                match key {
                    Some(key) => match req.headers_mut().entry(key) {
                        Entry::Occupied(mut e) => {
                            // `insert` drops all previous values for this name.
                            e.insert(value);
                            prev_entry = Some(e);
                        }
                        Entry::Vacant(e) => {
                            let e = e.insert_entry(value);
                            prev_entry = Some(e);
                        }
                    },
                    // Additional values for the name yielded just before this one.
                    None => match prev_entry {
                        Some(ref mut entry) => {
                            entry.append(value);
                        }
                        None => unreachable!("HeaderMap::into_iter yielded None first"),
                    },
                }
            }
        }
        self
    }

    /// Sets `Authorization: Bearer <token>` and marks the header value as sensitive.
    ///
    /// A token that is not a valid header value latches an error.
    #[cfg(feature = "gcp")]
    pub(crate) fn bearer_auth(mut self, token: &str) -> Self {
        let value = HeaderValue::try_from(format!("Bearer {token}"));
        match (value, &mut self.request) {
            (Ok(mut v), Ok(r)) => {
                v.set_sensitive(true);
                r.headers_mut().insert(http::header::AUTHORIZATION, v);
            }
            (Err(e), Ok(_)) => self.request = Err(e.into()),
            (_, Err(_)) => {}
        }
        self
    }

    /// Serializes `s` as JSON into the request body.
    ///
    /// Sets `Content-Type: application/json` unless a content type is already present, as
    /// `form` does for form bodies. Without it, servers cannot reliably identify the body
    /// format. A serialization failure latches an error.
    #[cfg(feature = "gcp")]
    pub(crate) fn json<S: serde::Serialize>(mut self, s: S) -> Self {
        match (serde_json::to_vec(&s), &mut self.request) {
            (Ok(json), Ok(request)) => {
                request
                    .headers_mut()
                    .entry(http::header::CONTENT_TYPE)
                    .or_insert(HeaderValue::from_static("application/json"));
                *request.body_mut() = json.into();
            }
            (Err(e), Ok(_)) => self.request = Err(e.into()),
            (_, Err(_)) => {}
        }
        self
    }

    /// Appends `query` to the URI's query string, serialized as
    /// `application/x-www-form-urlencoded`.
    ///
    /// Any query already on the URI, including one added by an earlier call, is preserved,
    /// and the new pairs follow it (as with `reqwest::RequestBuilder::query`). If the
    /// resulting query is empty, no dangling `?` is left on the URI.
    ///
    /// An error is latched if serialization fails or if the result is not a valid
    /// path-and-query. The serialization error takes precedence.
    #[cfg(any(test, feature = "aws", feature = "gcp", feature = "azure"))]
    pub(crate) fn query<T: serde::Serialize + ?Sized>(mut self, query: &T) -> Self {
        let mut error = None;
        if let Ok(ref mut req) = self.request {
            let mut out = format!("{}?", req.uri().path());
            // Everything from `start_position` on is treated as existing form content, so
            // the encoder inserts `&` before the first new pair only if a non-empty query
            // is already present.
            let start_position = out.len();
            if let Some(existing) = req.uri().query() {
                out.push_str(existing);
            }

            let mut encoder = form_urlencoded::Serializer::for_suffix(&mut out, start_position);
            let serializer = serde_urlencoded::Serializer::new(&mut encoder);

            if let Err(err) = query.serialize(serializer) {
                error = Some(err.into());
            }

            // No prior query and nothing appended: drop the dangling `?`.
            if out.len() == start_position {
                out.pop();
            }

            match http::uri::PathAndQuery::from_maybe_shared(out) {
                Ok(p) => {
                    let mut parts = req.uri().clone().into_parts();
                    parts.path_and_query = Some(p);
                    *req.uri_mut() = Uri::from_parts(parts).unwrap();
                }
                Err(err) => {
                    if error.is_none() {
                        error = Some(err.into());
                    }
                }
            }
        }
        if let Some(err) = error {
            self.request = Err(err);
        }
        self
    }

    /// Serializes `form` as an `application/x-www-form-urlencoded` body.
    ///
    /// This sets (overwrites) the `Content-Type` header accordingly. A serialization failure
    /// latches an error.
    #[cfg(any(feature = "gcp", feature = "azure"))]
    pub(crate) fn form<T: serde::Serialize>(mut self, form: T) -> Self {
        let mut error = None;
        if let Ok(ref mut req) = self.request {
            match serde_urlencoded::to_string(form) {
                Ok(body) => {
                    req.headers_mut().insert(
                        http::header::CONTENT_TYPE,
                        HeaderValue::from_static("application/x-www-form-urlencoded"),
                    );
                    *req.body_mut() = body.into();
                }
                Err(err) => error = Some(err.into()),
            }
        }
        if let Some(err) = error {
            self.request = Err(err);
        }
        self
    }

    /// Replaces the request body.
    #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
    pub(crate) fn body(mut self, b: impl Into<HttpRequestBody>) -> Self {
        if let Ok(r) = &mut self.request {
            *r.body_mut() = b.into();
        }
        self
    }

    /// Consumes the builder.
    ///
    /// Returns the client together with either the built request or the first error that
    /// was latched.
    pub(crate) fn into_parts(self) -> (HttpClient, Result<HttpRequest, RequestBuilderError>) {
        (self.client, self.request)
    }
}

/// Appends `query_pairs` to the query string of `uri`, encoded as
/// `application/x-www-form-urlencoded`.
///
/// Any existing query is kept, and the new pairs follow it, separated by `&`. A URI without a
/// path-and-query component gets `/` as its path. If there was no query and no pairs are
/// supplied, no dangling `?` is added.
#[cfg(any(test, feature = "azure"))]
pub(crate) fn add_query_pairs<I, K, V>(uri: &mut Uri, query_pairs: I)
where
    I: IntoIterator,
    I::Item: std::borrow::Borrow<(K, V)>,
    K: AsRef<str>,
    V: AsRef<str>,
{
    let mut parts = uri.clone().into_parts();

    let (path, existing_query) = match &parts.path_and_query {
        Some(p) => (p.path(), p.query().unwrap_or("")),
        None => ("/", ""),
    };
    let mut out = format!("{path}?{existing_query}");
    // Everything after the `?` is treated as existing form content. The serializer therefore
    // inserts `&` before the first new pair exactly when a non-empty query was already
    // present, even if that query happens to end with a `?` character.
    let start_position = path.len() + 1;

    let mut serializer = form_urlencoded::Serializer::for_suffix(&mut out, start_position);
    serializer.extend_pairs(query_pairs);

    // No prior query and nothing appended: drop the dangling `?`.
    if out.len() == start_position {
        out.pop();
    }

    // The path comes from a valid URI and the appended pairs are percent-encoded.
    parts.path_and_query = Some(out.try_into().unwrap());
    // NOTE(review): `Uri::from_parts` rejects an authority without a scheme, so an
    // authority-form `uri` would panic here (as it did before this change).
    *uri = Uri::from_parts(parts).unwrap();
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_query_pairs() {
        let mut uri = Uri::from_static("https://foo@example.com/bananas");

        // Each step appends to the URI produced by the previous one.
        let steps: [(&[(&str, &str)], &str); 3] = [
            (&[("foo", "1")], "https://foo@example.com/bananas?foo=1"),
            (
                &[("bingo", "foo"), ("auth", "test")],
                "https://foo@example.com/bananas?foo=1&bingo=foo&auth=test",
            ),
            (
                &[("t1", "funky shenanigans"), ("a", "😀")],
                "https://foo@example.com/bananas?foo=1&bingo=foo&auth=test&t1=funky+shenanigans&a=%F0%9F%98%80",
            ),
        ];

        for &(pairs, expected) in steps.iter() {
            add_query_pairs(&mut uri, pairs);
            assert_eq!(uri.to_string(), expected);
        }
    }

    #[test]
    fn test_add_query_pairs_no_path() {
        let mut uri = Uri::from_static("https://foo@example.com");
        add_query_pairs(&mut uri, [("foo", "1")]);
        let rendered = uri.to_string();
        assert_eq!(rendered, "https://foo@example.com/?foo=1");
    }

    #[test]
    fn test_request_builder_query() {
        let client = HttpClient::new(reqwest::Client::new());
        let builder_for = |uri: &str| HttpRequestBuilder::new(client.clone()).uri(uri);

        let cases = vec![
            (
                builder_for("http://example.com/bananas"),
                "http://example.com/bananas",
            ),
            (
                builder_for("http://example.com/bananas").query(&[("foo", "1")]),
                "http://example.com/bananas?foo=1",
            ),
            (
                builder_for("http://example.com").query(&[("foo", "1")]),
                "http://example.com/?foo=1",
            ),
        ];

        for (builder, expected) in cases.into_iter() {
            assert_request_uri(builder, expected);
        }
    }

    /// Asserts that `builder` produced a request (not an error) whose URI renders as `expected`.
    fn assert_request_uri(builder: HttpRequestBuilder, expected: &str) {
        let (_client, request) = builder.into_parts();
        let request = request.unwrap();
        assert_eq!(request.uri().to_string(), expected);
    }
}