// athena_rs 3.26.2
//
// Hyper-performant polyglot database driver
// Documentation
//! Strict validation for user-supplied S3 endpoint and bucket names before any SDK/s3cmd I/O.

use actix_web::HttpResponse;
use actix_web::http::StatusCode;
use serde_json::json;

use crate::api::response::{bad_request, error_response_with_code};
use crate::api::storage::errors::STORAGE_ERROR_CODE_R2_INVALID_ACCESS_KEY_ID_FORMAT;

/// Upper bound on the trimmed bucket name length (`str::len`, i.e. bytes) accepted by
/// `validate_bucket_name`.
const MAX_BUCKET_LEN: usize = 63;
/// Lower bound on the trimmed bucket name length accepted by `validate_bucket_name`.
const MIN_BUCKET_LEN: usize = 3;
/// Upper bound on the parsed endpoint host length accepted by `validate_storage_endpoint`.
const MAX_ENDPOINT_HOST_LEN: usize = 253;
/// Upper bound on the length of each credential field checked by `validate_access_credentials`.
const MAX_CREDENTIAL_FIELD_LEN: usize = 2048;
/// Upper bound on the trimmed region length accepted by `validate_region`.
const MAX_REGION_LEN: usize = 64;

/// Reports whether plain-`http` storage endpoints are permitted.
///
/// Reads the `ATHENA_STORAGE_ALLOW_HTTP` environment variable on every call. Only the exact
/// value `1` or a case-insensitive `true` enables the flag; the value is not trimmed. A variable
/// that is unset or cannot be read yields `false`.
fn storage_allow_insecure_http() -> bool {
    match std::env::var("ATHENA_STORAGE_ALLOW_HTTP") {
        Ok(flag) => flag == "1" || flag.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}

/// Optional post-auth denylist for obvious scanner junk (low false positive for real bucket names).
///
/// Returns `true` when the ASCII-lowercased input contains any of the known probe fragments
/// listed in `MARKERS`, so matching is case-insensitive for ASCII letters.
pub fn scanner_junk_in_storage_field(s: &str) -> bool {
    // Substrings typical of code-injection / SQL-injection probes; all stored lowercase because
    // the haystack is lowercased before the search.
    const MARKERS: [&str; 7] = [
        "__import__",
        "eval(",
        ".urlopen(",
        "exec(",
        "child_process",
        "utl_inaddr",
        "from dual",
    ];

    let haystack = s.to_ascii_lowercase();
    for marker in MARKERS.iter() {
        if haystack.contains(*marker) {
            return true;
        }
    }
    false
}

/// Returns `HttpResponse` error when the bucket name violates S3 DNS-style constraints.
///
/// Surrounding whitespace is ignored. The remaining name must:
///
/// * be between `MIN_BUCKET_LEN` and `MAX_BUCKET_LEN` bytes long,
/// * be pure ASCII,
/// * start and end with a letter or digit,
/// * contain only lowercase letters, digits, dots, and hyphens,
/// * not contain two adjacent dots,
/// * not be formatted as an IPv4 address (e.g. `192.168.5.4`).
///
/// # Errors
///
/// Returns the `bad_request` response for the first rule that fails, checked in the order above
/// (the character-set and adjacent-dot rules are evaluated together, left to right).
pub fn validate_bucket_name(bucket: &str) -> Result<(), HttpResponse> {
    let b = bucket.trim();
    if !(MIN_BUCKET_LEN..=MAX_BUCKET_LEN).contains(&b.len()) {
        return Err(bad_request(
            "Invalid bucket",
            format!("bucket must be between {MIN_BUCKET_LEN} and {MAX_BUCKET_LEN} characters"),
        ));
    }
    if !b.is_ascii() {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must contain only ASCII characters",
        ));
    }

    let bytes = b.as_bytes();
    let edges_ok = matches!(
        (bytes.first(), bytes.last()),
        (Some(first), Some(last))
            if first.is_ascii_alphanumeric() && last.is_ascii_alphanumeric()
    );
    if !edges_ok {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must start and end with a letter or digit",
        ));
    }

    // Single left-to-right pass: reports whichever comes first, an illegal character or a second
    // consecutive dot. This makes a separate `contains("..")` check unnecessary.
    let mut prev_dot = false;
    for &c in bytes {
        match c {
            b'a'..=b'z' | b'0'..=b'9' | b'-' => prev_dot = false,
            b'.' if prev_dot => {
                return Err(bad_request(
                    "Invalid bucket",
                    "bucket cannot contain adjacent dots",
                ));
            }
            b'.' => prev_dot = true,
            _ => {
                return Err(bad_request(
                    "Invalid bucket",
                    "bucket may only contain lowercase letters, digits, dots, and hyphens",
                ));
            }
        }
    }

    // S3 rejects names that look like an IP address. After the character-set check only the
    // dotted-decimal IPv4 form is still possible (no colons survive), so IPv4 parsing suffices.
    if b.parse::<std::net::Ipv4Addr>().is_ok() {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must not be formatted as an IP address",
        ));
    }

    Ok(())
}

/// Reports whether `host` points at a cloud metadata service or another internal-only target
/// that a user-supplied storage endpoint must never reach.
///
/// `host` is compared case-insensitively with leading/trailing dots removed. It is blocked when:
///
/// * it is `metadata`, `metadata.google.internal`, or any name ending in `.internal`;
/// * it is an IPv4 literal in the link-local range `169.254.0.0/16` (this covers the
///   `169.254.169.254` metadata address as well as neighbouring credential endpoints);
/// * it is an IPv6 literal (with or without URL brackets) that is either IPv4-mapped onto that
///   range or equal to the AWS IPv6 instance-metadata address `fd00:ec2::254`.
///
/// Loopback and private ranges are deliberately not blocked here.
fn blocked_endpoint_host(host: &str) -> bool {
    // AWS instance metadata service, IPv6 endpoint.
    const AWS_IMDS_V6: std::net::Ipv6Addr =
        std::net::Ipv6Addr::new(0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254);

    let name = host.trim_matches('.').to_ascii_lowercase();
    if name == "metadata" || name == "metadata.google.internal" || name.ends_with(".internal") {
        return true;
    }

    // URL host strings wrap IPv6 literals in brackets; strip them before parsing.
    let literal = name
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(name.as_str());

    match literal.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => v4.is_link_local(),
        Ok(std::net::IpAddr::V6(v6)) => {
            v6 == AWS_IMDS_V6
                || v6
                    .to_ipv4_mapped()
                    .map_or(false, |mapped| mapped.is_link_local())
        }
        Err(_) => false,
    }
}

/// Validates `endpoint` URL shape before passing it to AWS SDK / s3cmd.
///
/// The input is trimmed. If it does not already start with `http://` or `https://` (matched
/// ASCII case-insensitively, since URL schemes are case-insensitive), `https://` is prepended
/// before parsing. On success the parsed URL is returned.
///
/// # Errors
///
/// Returns a `bad_request` response when the endpoint:
///
/// * is empty or fails to parse as a URL;
/// * uses `http` while `storage_allow_insecure_http()` is `false`, or uses any scheme other than
///   `http`/`https`;
/// * carries userinfo, a path other than `/`, a query, or a fragment;
/// * has no host, a host longer than `MAX_ENDPOINT_HOST_LEN`, a host containing whitespace, or a
///   host rejected by `blocked_endpoint_host`;
/// * specifies port `0`.
pub fn validate_storage_endpoint(endpoint: &str) -> Result<reqwest::Url, HttpResponse> {
    // Byte-wise, ASCII case-insensitive prefix test; works on bytes so it can never slice through
    // the middle of a multi-byte UTF-8 character.
    fn starts_with_ignore_ascii_case(text: &str, prefix: &str) -> bool {
        text.as_bytes()
            .get(..prefix.len())
            .map_or(false, |head| head.eq_ignore_ascii_case(prefix.as_bytes()))
    }

    let trimmed = endpoint.trim();
    if trimmed.is_empty() {
        return Err(bad_request("Invalid endpoint", "endpoint field is empty"));
    }

    let has_scheme = starts_with_ignore_ascii_case(trimmed, "http://")
        || starts_with_ignore_ascii_case(trimmed, "https://");
    let candidate = if has_scheme {
        trimmed.to_owned()
    } else {
        // Bare `host[:port]` input defaults to TLS.
        format!("https://{trimmed}")
    };
    let url = reqwest::Url::parse(&candidate).map_err(|e| {
        bad_request(
            "Invalid endpoint",
            format!("invalid endpoint URL '{trimmed}': {e}"),
        )
    })?;

    if url.scheme() == "http" && !storage_allow_insecure_http() {
        return Err(bad_request(
            "Invalid endpoint",
            "only https endpoints are allowed (set ATHENA_STORAGE_ALLOW_HTTP=1 for insecure dev endpoints)",
        ));
    }
    // Defensive: the candidate always starts with an http(s) scheme, but keep the explicit check
    // so a future change to the scheme handling above cannot silently admit other schemes.
    if !matches!(url.scheme(), "http" | "https") {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must use http or https",
        ));
    }

    if !url.username().is_empty() || url.password().is_some() {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include userinfo",
        ));
    }

    if !matches!(url.path(), "" | "/") {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include a path (host[:port] only)",
        ));
    }
    if url.query().is_some() || url.fragment().is_some() {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include query or fragment",
        ));
    }

    let host = url
        .host_str()
        .ok_or_else(|| bad_request("Invalid endpoint", "endpoint URL does not contain a host"))?;

    if host.len() > MAX_ENDPOINT_HOST_LEN {
        return Err(bad_request("Invalid endpoint", "hostname is too long"));
    }

    if host.contains(|c: char| matches!(c, ' ' | '\t' | '\n' | '\r')) {
        return Err(bad_request(
            "Invalid endpoint",
            "hostname contains illegal whitespace",
        ));
    }

    if blocked_endpoint_host(host) {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint host is not allowed",
        ));
    }

    if url.port() == Some(0) {
        return Err(bad_request("Invalid endpoint", "invalid port"));
    }

    Ok(url)
}

/// Validates a user-supplied region string.
///
/// The value is trimmed first. It must then be non-empty, at most `MAX_REGION_LEN` bytes long,
/// and consist solely of ASCII letters, ASCII digits, `-`, or `_`.
///
/// # Errors
///
/// Returns a `bad_request` response titled "Invalid region" describing the first failed rule.
pub fn validate_region(region: &str) -> Result<(), HttpResponse> {
    let candidate = region.trim();

    let problem = if candidate.is_empty() {
        Some("region field is empty")
    } else if candidate.len() > MAX_REGION_LEN {
        Some("region is too long")
    } else if candidate
        .chars()
        .any(|ch| !(ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_')))
    {
        Some("region contains invalid characters")
    } else {
        None
    };

    match problem {
        Some(detail) => Err(bad_request("Invalid region", detail)),
        None => Ok(()),
    }
}

/// Performs basic presence and size checks on an access key pair.
///
/// Each field must contain something other than whitespace, and neither field (measured
/// untrimmed, in bytes) may exceed `MAX_CREDENTIAL_FIELD_LEN`.
///
/// # Errors
///
/// Returns a `bad_request` response titled "Invalid credentials". A missing `access_key_id` is
/// reported before a missing `secret_key`, which is reported before an oversized field.
pub fn validate_access_credentials(
    access_key_id: &str,
    secret_key: &str,
) -> Result<(), HttpResponse> {
    let id_missing = access_key_id.trim().is_empty();
    let secret_missing = secret_key.trim().is_empty();
    let oversized = [access_key_id, secret_key]
        .iter()
        .any(|field| field.len() > MAX_CREDENTIAL_FIELD_LEN);

    // Tuple order encodes the reporting priority of the three checks.
    let detail = match (id_missing, secret_missing, oversized) {
        (true, _, _) => "access_key_id is required",
        (false, true, _) => "secret_key is required",
        (false, false, true) => "credential fields are too long",
        (false, false, false) => return Ok(()),
    };
    Err(bad_request("Invalid credentials", detail))
}

/// Applies credential-format rules that depend on which provider the endpoint belongs to.
///
/// Currently only Cloudflare R2 is recognised: when the endpoint host (trimmed, ASCII-lowercased)
/// ends with `.r2.cloudflarestorage.com`, the trimmed `access_key_id` must be exactly 32 bytes
/// long and contain no hyphen. Endpoints without a host, and all other hosts, pass unchecked.
///
/// # Errors
///
/// Returns a 400 response built by `error_response_with_code` carrying
/// `STORAGE_ERROR_CODE_R2_INVALID_ACCESS_KEY_ID_FORMAT` when the R2 rule is violated.
pub fn validate_provider_specific_credentials(
    endpoint: &reqwest::Url,
    access_key_id: &str,
) -> Result<(), HttpResponse> {
    let host = match endpoint.host_str() {
        Some(raw) => raw.trim().to_ascii_lowercase(),
        None => return Ok(()),
    };
    if !host.ends_with(".r2.cloudflarestorage.com") {
        return Ok(());
    }

    let key = access_key_id.trim();
    let well_formed = key.len() == 32 && !key.contains('-');
    if well_formed {
        return Ok(());
    }

    let details = json!({
        "operation": "validate_credentials",
        "backend": "s3",
        "provider": "cloudflare_r2",
    });
    Err(error_response_with_code(
        StatusCode::BAD_REQUEST,
        "Invalid storage credentials",
        "Cloudflare R2 access_key_id must be a 32-character S3 access key ID without hyphens",
        STORAGE_ERROR_CODE_R2_INVALID_ACCESS_KEY_ID_FORMAT,
        Some(details),
    ))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase letters, digits, and hyphens form a valid bucket name.
    #[test]
    fn bucket_accepts_simple_name() {
        let outcome = validate_bucket_name("my-bucket-1");
        assert!(outcome.is_ok());
    }

    /// Uppercase letters are outside the permitted character set.
    #[test]
    fn bucket_rejects_uppercase() {
        let outcome = validate_bucket_name("MyBucket");
        assert!(outcome.is_err());
    }

    /// Two consecutive dots are refused.
    #[test]
    fn bucket_rejects_adjacent_dots() {
        let outcome = validate_bucket_name("a..b");
        assert!(outcome.is_err());
    }

    /// A UUID-shaped key (36 characters, hyphenated) is not a valid R2 access key ID.
    #[test]
    fn rejects_r2_access_key_ids_with_non_32_length() {
        let r2_endpoint = reqwest::Url::parse("https://example.r2.cloudflarestorage.com").unwrap();
        let uuid_shaped_key = "fba11c68-6eed-a7d4-e905-44b68432b2a8";

        let rejection = validate_provider_specific_credentials(&r2_endpoint, uuid_shaped_key)
            .expect_err("expected provider-specific R2 validation to fail");

        assert_eq!(rejection.status(), StatusCode::BAD_REQUEST);
    }
}