// lakestream 0.0.2
//
// Portable file-utility for object-stores
// Documentation
use std::collections::HashMap;

use hmac::{Hmac, Mac, NewMac};
use itertools::Itertools;
use percent_encoding::{utf8_percent_encode, CONTROLS};
use sha2::{Digest, Sha256};
use url::{form_urlencoded, Url};

use super::bucket::S3Credentials;
use crate::utils::time::UtcTimeNow;
use crate::AWS_MAX_LIST_OBJECTS;

/// Computes the HMAC-SHA256 of `msg` keyed with `key`.
///
/// Returns the raw digest bytes. This is the primitive used both to derive the
/// signing key and to produce the final request signature.
fn sign(key: &[u8], msg: &[u8]) -> Vec<u8> {
    let mut mac = Hmac::<Sha256>::new_from_slice(key)
        .expect("HMAC can take key of any size");
    mac.update(msg);
    mac.finalize().into_bytes().to_vec()
}

/// Request-signing state for an S3-style endpoint.
///
/// Holds everything the methods in this file need to build a request URL and
/// the matching AWS Signature Version 4 headers.
pub struct S3Client {
    /// Path appended to `endpoint_url` in `url()` and in the canonical URI;
    /// `new` initialises it to an empty string.
    resource: String,
    /// Region name used in the credential scope and in signing-key derivation.
    region: String,
    /// Source of the access key and secret key used for signing.
    credentials: S3Credentials,
    /// Base URL of the endpoint; parsed with `Url::parse` when signing.
    endpoint_url: String,
    /// Timestamp source created once in `new`; supplies the date stamp and
    /// the `x-amz-date` value.
    utc_now: UtcTimeNow,
    /// Query string of the most recently prepared request, if any; set by
    /// `generate_list_objects_headers`.
    query_string: Option<String>,
}

impl S3Client {
    pub fn new(
        endpoint_url: String,
        region: String,
        credentials: S3Credentials,
    ) -> S3Client {
        let resource = "".to_string();
        let utc_now = UtcTimeNow::new();

        S3Client {
            resource,
            region,
            credentials,
            endpoint_url,
            utc_now,
            query_string: None,
        }
    }

    pub fn url(&self) -> String {
        format!(
            "{}/{}?{}",
            &self.endpoint_url,
            &self.resource,
            self.query_string.as_ref().unwrap_or(&"".to_string())
        )
    }

    fn get_canonical_headers(
        &self,
        headers: &HashMap<String, String>,
    ) -> String {
        let mut canonical_headers = String::new();
        let mut headers_vec: Vec<(&String, &String)> = headers.iter().collect();
        headers_vec.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase()));

        for (header_name, header_value) in headers_vec {
            let header_name = header_name.trim().to_lowercase();
            if header_name.starts_with("x-amz-")
                && header_name != "x-amz-client-context"
                || header_name == "host"
                || header_name == "content-type"
                || header_name == "date"
            {
                canonical_headers +=
                    &format!("{}:{}\n", header_name, header_value.trim());
            }
        }

        canonical_headers
    }

    fn generate_signing_key(&self) -> Vec<u8> {
        let k_date = sign(
            format!("AWS4{}", self.credentials.secret_key()).as_bytes(),
            self.utc_now.date_stamp().as_bytes(),
        );
        let k_region = sign(&k_date, self.region.as_bytes());
        let k_service = sign(&k_region, b"s3");
        sign(&k_service, b"aws4_request")
    }

    fn initiate_headers(
        &self,
        headers: Option<HashMap<String, String>>,
        x_amz_date: &str,
        payload_hash: Option<&str>,
    ) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
        let mut headers = headers.unwrap_or_default();
        headers.insert("x-amz-date".to_string(), x_amz_date.to_string());
        headers.insert(
            "x-amz-content-sha256".to_string(),
            payload_hash.unwrap_or("UNSIGNED-PAYLOAD").to_string(),
        );
        Ok(headers)
    }

    fn get_canonical_uri(&self, url: &Url, resource: &str) -> String {
        let canonical_resource = form_urlencoded::byte_serialize(
            resource.trim_end_matches('/').as_bytes(),
        )
        .collect::<String>();
        let endpoint_path =
            url.path().trim_start_matches('/').trim_end_matches('/');

        if endpoint_path.is_empty() {
            canonical_resource
        } else {
            format!(
                "{}/{}",
                form_urlencoded::byte_serialize(endpoint_path.as_bytes())
                    .collect::<String>(),
                canonical_resource
            )
        }
    }

    fn get_canonical_query_string(
        &self,
    ) -> Result<String, Box<dyn std::error::Error>> {
        if self.query_string.as_ref().map_or(true, |s| s.is_empty()) {
            Ok(String::new())
        } else {
            let mut parts: Vec<(String, String)> =
                match self.query_string.as_ref() {
                    Some(query) => query
                        .split('&')
                        .filter_map(|p| {
                            let mut split = p.splitn(2, '=');
                            match (split.next(), split.next()) {
                                (Some(k), Some(v)) => {
                                    Some((k.to_string(), v.to_string()))
                                }
                                _ => None,
                            }
                        })
                        .collect(),
                    None => Vec::new(),
                };
            parts.sort();

            let encoded_parts: Vec<String> = parts
                .into_iter()
                .map(|(k, v)| {
                    format!("{}={}", k, utf8_percent_encode(&v, CONTROLS))
                })
                .collect();

            Ok(encoded_parts.join("&"))
        }
    }

    pub fn generate_list_buckets_headers(
        &mut self,
    ) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
        let method = "GET";

        self.generate_headers(method, None, None)
    }

    pub fn generate_list_objects_headers(
        &mut self,
        prefix: Option<&str>,
        max_keys: Option<u32>,
        continuation_token: Option<&str>,
    ) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
        let method = "GET";

        // Ensure max_keys does not exceed AWS_MAX_LIST_OBJECTS
        let max_keys = max_keys
            .map(|keys| std::cmp::min(keys, AWS_MAX_LIST_OBJECTS))
            .unwrap_or(AWS_MAX_LIST_OBJECTS);

        let mut query_parts = form_urlencoded::Serializer::new(String::new());
        query_parts.append_pair("list-type", "2");
        query_parts.append_pair("max-keys", &max_keys.to_string());
        query_parts.append_pair("delimiter", "/");
        query_parts.append_pair("encoding-type", "url");

        if let Some(p) = prefix {
            query_parts.append_pair("prefix", p);
        }
        if let Some(token) = continuation_token {
            query_parts.append_pair("continuation-token", token);
        }

        self.query_string = Some(query_parts.finish());

        self.generate_headers(method, None, None)
    }

    fn generate_headers(
        &mut self,
        method: &str,
        headers: Option<HashMap<String, String>>,
        payload_hash: Option<&str>,
    ) -> Result<HashMap<String, String>, Box<dyn std::error::Error>> {
        let date_stamp = self.utc_now.date_stamp();
        let x_amz_date = self.utc_now.x_amz_date();

        let credential_scope =
            format!("{}/{}/s3/aws4_request", date_stamp, self.region);
        let mut headers =
            self.initiate_headers(headers, &x_amz_date, payload_hash)?;

        let url = Url::parse(&self.endpoint_url)?;
        let host = url.host_str().ok_or("Missing host")?.to_owned();
        let host = match url.port() {
            Some(port) => host.replace(&format!(":{}", port), ""),
            None => host,
        };
        headers.insert("host".to_string(), host);

        let canonical_uri = self.get_canonical_uri(&url, &self.resource);

        let canonical_headers = self.get_canonical_headers(&headers);
        let signed_headers = headers
            .keys()
            .map(|key| key.to_lowercase())
            .sorted()
            .collect::<Vec<String>>()
            .join(";");

        let canonical_query_string = self.get_canonical_query_string()?;

        let canonical_request = format!(
            "{}\n/{}\n{}\n{}\n{}\n{}",
            method,
            canonical_uri,
            canonical_query_string,
            canonical_headers,
            signed_headers,
            payload_hash.unwrap_or("UNSIGNED-PAYLOAD")
        );

        let string_to_sign = format!(
            "AWS4-HMAC-SHA256\n{}\n{}\n{:x}",
            x_amz_date,
            credential_scope,
            Sha256::digest(canonical_request.as_bytes())
        );

        let signing_key = self.generate_signing_key();
        let signature = sign(&signing_key, string_to_sign.as_bytes());

        let authorization_header = format!(
            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
            self.credentials.access_key(),
            credential_scope,
            signed_headers,
            hex::encode(signature)
        );

        headers.insert("Authorization".to_string(), authorization_header);
        Ok(headers)
    }
}