lakestream 0.0.3

Portable file-utility for object-stores
Documentation
use std::collections::HashMap;

use hmac::{Hmac, Mac, NewMac};
use percent_encoding::{utf8_percent_encode, CONTROLS};
use sha2::{Digest, Sha256};
use url::{form_urlencoded, Url};

use super::client::S3ClientConfig;
use crate::utils::time::UtcTimeNow;
use crate::LakestreamError;

pub struct RequestBuilder {
    url: String,
}

impl RequestBuilder {
    pub fn new(url: &str) -> Self {
        let url = url.to_string();
        Self { url }
    }

    pub fn generate_headers(
        &self,
        config: &S3ClientConfig,
        method: &str,
        resource: Option<&str>,
        query_string: Option<&str>,
        payload_hash: Option<&str>,
    ) -> Result<HashMap<String, String>, LakestreamError> {
        let utc_now = UtcTimeNow::new();
        let date_stamp = utc_now.date_stamp();
        let x_amz_date = utc_now.x_amz_date();

        let credential_scope =
            format!("{}/{}/s3/aws4_request", date_stamp, config.region());
        let mut headers = self.initiate_headers(&x_amz_date, payload_hash);

        let url = Url::parse(&self.url)?;
        let host = url.host_str().ok_or("Missing host")?.to_owned();
        let host = match url.port() {
            Some(port) => format!("{}:{}", host, port),
            None => host,
        };

        headers.insert("host".to_string(), host);

        let canonical_uri = self.get_canonical_uri(&url, resource);

        let canonical_headers = self.get_canonical_headers(&headers);

        let mut signed_headers: Vec<String> =
            headers.keys().map(|key| key.to_lowercase()).collect();

        signed_headers.sort();

        let signed_headers_str = signed_headers.join(";");

        let canonical_query_string =
            self.get_canonical_query_string(query_string)?;

        let canonical_request = format!(
            "{}\n/{}\n{}\n{}\n{}\n{}",
            method,
            canonical_uri,
            canonical_query_string,
            canonical_headers,
            signed_headers_str,
            payload_hash.unwrap_or("UNSIGNED-PAYLOAD")
        );

        let string_to_sign = format!(
            "AWS4-HMAC-SHA256\n{}\n{}\n{:x}",
            x_amz_date,
            credential_scope,
            Sha256::digest(canonical_request.as_bytes())
        );
        let signing_key = self.generate_signing_key(
            &date_stamp,
            config.credentials().secret_key(),
            config.region(),
        );
        let signature = sign(&signing_key, string_to_sign.as_bytes());

        let authorization_header = format!(
            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
            config.credentials().access_key(),
            credential_scope,
            signed_headers_str,
            hex::encode(signature)
        );

        headers.insert("Authorization".to_string(), authorization_header);
        Ok(headers)
    }

    fn get_canonical_headers(
        &self,
        headers: &HashMap<String, String>,
    ) -> String {
        let mut canonical_headers = String::new();
        let mut headers_vec: Vec<(&String, &String)> = headers.iter().collect();
        headers_vec.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase()));

        for (header_name, header_value) in headers_vec {
            let header_name = header_name.trim().to_lowercase();
            if header_name.starts_with("x-amz-")
                && header_name != "x-amz-client-context"
                || header_name == "host"
                || header_name == "content-type"
                || header_name == "date"
            {
                canonical_headers +=
                    &format!("{}:{}\n", header_name, header_value.trim());
            }
        }

        canonical_headers
    }

    fn generate_signing_key(
        &self,
        date_stamp: &str,
        secret_key: &str,
        region: &str,
    ) -> Vec<u8> {
        let k_date = sign(
            format!("AWS4{}", secret_key).as_bytes(),
            date_stamp.as_bytes(),
        );
        let k_region = sign(&k_date, region.as_bytes());
        let k_service = sign(&k_region, b"s3");
        sign(&k_service, b"aws4_request")
    }

    fn initiate_headers(
        &self,
        x_amz_date: &str,
        payload_hash: Option<&str>,
    ) -> HashMap<String, String> {
        let mut headers = HashMap::new();
        headers.insert("x-amz-date".to_string(), x_amz_date.to_string());
        headers.insert(
            "x-amz-content-sha256".to_string(),
            payload_hash.unwrap_or("UNSIGNED-PAYLOAD").to_string(),
        );
        headers
    }

    fn get_canonical_uri(&self, url: &Url, resource: Option<&str>) -> String {
        let canonical_resource = form_urlencoded::byte_serialize(
            resource
                .unwrap_or_default()
                .trim_start_matches('/')
                .trim_end_matches('/')
                .as_bytes(),
        )
        .collect::<String>();
        let endpoint_path =
            url.path().trim_start_matches('/').trim_end_matches('/');

        if endpoint_path.is_empty() {
            canonical_resource
        } else {
            format!(
                "{}/{}",
                form_urlencoded::byte_serialize(endpoint_path.as_bytes())
                    .collect::<String>(),
                canonical_resource
            )
        }
    }

    fn get_canonical_query_string(
        &self,
        query_string: Option<&str>,
    ) -> Result<String, LakestreamError> {
        if query_string.as_ref().map_or(true, |s| s.is_empty()) {
            Ok(String::new())
        } else {
            let mut parts: Vec<(String, String)> = match query_string.as_ref() {
                Some(query) => query
                    .split('&')
                    .filter_map(|p| {
                        let mut split = p.splitn(2, '=');
                        match (split.next(), split.next()) {
                            (Some(k), Some(v)) => {
                                Some((k.to_string(), v.to_string()))
                            }
                            _ => None,
                        }
                    })
                    .collect(),
                None => Vec::new(),
            };
            parts.sort();

            let encoded_parts: Vec<String> = parts
                .into_iter()
                .map(|(k, v)| {
                    format!("{}={}", k, utf8_percent_encode(&v, CONTROLS))
                })
                .collect();

            Ok(encoded_parts.join("&"))
        }
    }
}

fn sign(key: &[u8], msg: &[u8]) -> Vec<u8> {
    let mut hmac = Hmac::<Sha256>::new_from_slice(key)
        .expect("HMAC can take key of any size");
    hmac.update(msg);
    let result = hmac.finalize();
    result.into_bytes().as_slice().to_vec()
}