use actix_web::HttpResponse;
use actix_web::http::StatusCode;
use serde_json::json;
use crate::api::response::{bad_request, error_response_with_code};
use crate::api::storage::errors::STORAGE_ERROR_CODE_R2_INVALID_ACCESS_KEY_ID_FORMAT;
/// Maximum accepted bucket-name length, compared against the byte length of the trimmed name.
const MAX_BUCKET_LEN: usize = 63;
/// Minimum accepted bucket-name length, compared against the byte length of the trimmed name.
const MIN_BUCKET_LEN: usize = 3;
/// Maximum accepted byte length of the host part of a storage endpoint URL.
const MAX_ENDPOINT_HOST_LEN: usize = 253;
/// Maximum accepted byte length of each credential field (access key id and secret key),
/// measured on the raw, untrimmed value.
const MAX_CREDENTIAL_FIELD_LEN: usize = 2048;
/// Maximum accepted region length, compared against the byte length of the trimmed region.
const MAX_REGION_LEN: usize = 64;
/// Reports whether plain-`http` storage endpoints are permitted.
///
/// Reads `ATHENA_STORAGE_ALLOW_HTTP`. Only `"1"` or `"true"` (any ASCII case)
/// enable insecure endpoints. An unset, non-Unicode, or other value disables them.
fn storage_allow_insecure_http() -> bool {
    match std::env::var("ATHENA_STORAGE_ALLOW_HTTP") {
        Ok(flag) => flag == "1" || flag.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
/// Returns `true` if `s` contains any entry of a fixed list of suspicious
/// substrings. The list looks like code/SQL-injection probe fragments such as
/// `__import__`, `eval(` and `from dual`.
///
/// Matching is case-insensitive for ASCII letters: the input is ASCII-lowercased
/// before searching, and every needle is already lowercase.
pub fn scanner_junk_in_storage_field(s: &str) -> bool {
    const SUSPICIOUS_FRAGMENTS: [&str; 7] = [
        "__import__",
        "eval(",
        ".urlopen(",
        "exec(",
        "child_process",
        "utl_inaddr",
        "from dual",
    ];
    let haystack = s.to_ascii_lowercase();
    for &fragment in SUSPICIOUS_FRAGMENTS.iter() {
        if haystack.contains(fragment) {
            return true;
        }
    }
    false
}
/// Validates an S3-style bucket name.
///
/// Surrounding whitespace is trimmed before checking. The trimmed name must:
/// - be between [`MIN_BUCKET_LEN`] and [`MAX_BUCKET_LEN`] bytes long;
/// - contain only ASCII characters;
/// - start and end with a letter or digit;
/// - contain only lowercase letters, digits, dots, and hyphens;
/// - not contain adjacent dots;
/// - not be formatted as an IPv4 address (e.g. `192.168.5.4`), which S3-style
///   naming rules forbid.
///
/// # Errors
///
/// Returns a `bad_request` response describing the first rule that failed.
///
/// NOTE(review): only the trimmed form is validated. Presumably callers also
/// use the trimmed value when addressing the bucket; confirm.
pub fn validate_bucket_name(bucket: &str) -> Result<(), HttpResponse> {
    let b = bucket.trim();
    if b.len() < MIN_BUCKET_LEN || b.len() > MAX_BUCKET_LEN {
        // Built from the constants so the message cannot drift from the limits.
        return Err(bad_request(
            "Invalid bucket",
            format!("bucket must be between {MIN_BUCKET_LEN} and {MAX_BUCKET_LEN} characters"),
        ));
    }
    if !b.is_ascii() {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must contain only ASCII characters",
        ));
    }
    let bytes = b.as_bytes();
    // Indexing is safe: the length check above guarantees at least MIN_BUCKET_LEN bytes.
    if !bytes[0].is_ascii_alphanumeric() || !bytes[bytes.len() - 1].is_ascii_alphanumeric() {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must start and end with a letter or digit",
        ));
    }
    // One pass enforces the character set and rejects adjacent dots.
    // No separate ".." scan is needed afterwards.
    let mut prev_dot = false;
    for &c in bytes {
        match c {
            b'a'..=b'z' | b'0'..=b'9' | b'-' => {
                prev_dot = false;
            }
            b'.' => {
                if prev_dot {
                    return Err(bad_request(
                        "Invalid bucket",
                        "bucket cannot contain adjacent dots",
                    ));
                }
                prev_dot = true;
            }
            _ => {
                return Err(bad_request(
                    "Invalid bucket",
                    "bucket may only contain lowercase letters, digits, dots, and hyphens",
                ));
            }
        }
    }
    // Names like "10.0.0.1" pass the character rules above but are not valid
    // bucket names.
    if b.parse::<std::net::Ipv4Addr>().is_ok() {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must not be formatted as an IP address",
        ));
    }
    Ok(())
}
/// Returns `true` when `host` must never be used as a storage endpoint.
///
/// The host is compared case-insensitively, with leading and trailing dots removed.
/// Blocked hosts are:
/// - cloud metadata hostnames: `metadata`, `metadata.google.internal`, and
///   anything ending in `.internal`;
/// - IPv4 link-local addresses (`169.254.0.0/16`, which includes the
///   `169.254.169.254` metadata endpoint), also when written as an
///   IPv4-mapped IPv6 address such as `[::ffff:a9fe:a9fe]`;
/// - IPv6 link-local addresses (`fe80::/10`) and the AWS IPv6 instance
///   metadata address `fd00:ec2::254`.
///
/// IPv6 literals are accepted with or without the surrounding brackets that
/// `Url::host_str` produces. Loopback and private ranges are not blocked,
/// presumably so local development endpoints keep working.
fn blocked_endpoint_host(host: &str) -> bool {
    use std::net::{IpAddr, Ipv6Addr};

    let h = host.trim_matches('.').to_ascii_lowercase();
    if h == "metadata.google.internal" || h == "metadata" || h.ends_with(".internal") {
        return true;
    }

    let literal = h
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(h.as_str());
    match literal.parse::<IpAddr>() {
        Ok(IpAddr::V4(v4)) => v4.is_link_local(),
        Ok(IpAddr::V6(v6)) => match v6.to_ipv4_mapped() {
            Some(v4) => v4.is_link_local(),
            None => {
                let aws_imds_v6 = Ipv6Addr::new(0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254);
                // fe80::/10 — the top 10 bits of the first segment are 1111111010.
                (v6.segments()[0] & 0xffc0) == 0xfe80 || v6 == aws_imds_v6
            }
        },
        Err(_) => false,
    }
}
/// Validates a user-supplied storage endpoint and returns the parsed URL.
///
/// The input is trimmed first. Input without an explicit `scheme://` prefix is
/// treated as `https://<input>`. The endpoint must:
/// - use `https` (or `http` only when `ATHENA_STORAGE_ALLOW_HTTP` enables it);
/// - have no userinfo, no path other than `/`, no query, and no fragment;
/// - have a host of at most [`MAX_ENDPOINT_HOST_LEN`] bytes, with no whitespace,
///   that is not rejected by `blocked_endpoint_host`;
/// - not use port `0`.
///
/// # Errors
///
/// Returns a `bad_request` response describing the first rule that failed.
pub fn validate_storage_endpoint(endpoint: &str) -> Result<reqwest::Url, HttpResponse> {
    let trimmed = endpoint.trim();
    if trimmed.is_empty() {
        return Err(bad_request("Invalid endpoint", "endpoint field is empty"));
    }
    let url = parse_endpoint_url(trimmed)?;
    // The URL parser lowercases schemes, so these comparisons cover "HTTPS://..." too.
    match url.scheme() {
        "https" => {}
        "http" => {
            if !storage_allow_insecure_http() {
                return Err(bad_request(
                    "Invalid endpoint",
                    "only https endpoints are allowed (set ATHENA_STORAGE_ALLOW_HTTP=1 for insecure dev endpoints)",
                ));
            }
        }
        _ => {
            return Err(bad_request(
                "Invalid endpoint",
                "endpoint must use http or https",
            ));
        }
    }
    if !url.username().is_empty() || url.password().is_some() {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include userinfo",
        ));
    }
    if url.path() != "/" && !url.path().is_empty() {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include a path (host[:port] only)",
        ));
    }
    if url.query().is_some() || url.fragment().is_some() {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include query or fragment",
        ));
    }
    let host = url
        .host_str()
        .ok_or_else(|| bad_request("Invalid endpoint", "endpoint URL does not contain a host"))?;
    if host.len() > MAX_ENDPOINT_HOST_LEN {
        return Err(bad_request("Invalid endpoint", "hostname is too long"));
    }
    if host.chars().any(char::is_whitespace) {
        return Err(bad_request(
            "Invalid endpoint",
            "hostname contains illegal whitespace",
        ));
    }
    if blocked_endpoint_host(host) {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint host is not allowed",
        ));
    }
    if url.port() == Some(0) {
        return Err(bad_request("Invalid endpoint", "invalid port"));
    }
    Ok(url)
}

/// Parses the trimmed endpoint text into a URL, defaulting to `https://`.
///
/// A prefix before `://` is treated as an explicit scheme only when it is
/// syntactically a scheme. That way `ftp://host` is parsed as written (and later
/// rejected for its scheme), while `host:9000` is not misread as scheme `host`.
fn parse_endpoint_url(trimmed: &str) -> Result<reqwest::Url, HttpResponse> {
    let has_scheme = matches!(
        trimmed.split_once("://"),
        Some((scheme, _)) if is_url_scheme(scheme)
    );
    let candidate = if has_scheme {
        trimmed.to_owned()
    } else {
        format!("https://{trimmed}")
    };
    reqwest::Url::parse(&candidate).map_err(|e| {
        bad_request(
            "Invalid endpoint",
            format!("invalid endpoint URL '{trimmed}': {e}"),
        )
    })
}

/// Returns `true` if `candidate` matches the URL scheme grammar
/// `ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )`.
fn is_url_scheme(candidate: &str) -> bool {
    let mut chars = candidate.chars();
    matches!(chars.next(), Some(first) if first.is_ascii_alphabetic())
        && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
}
/// Validates a storage region identifier.
///
/// After trimming, the region must be non-empty, at most [`MAX_REGION_LEN`]
/// bytes long, and consist only of ASCII letters, digits, `-`, and `_`.
///
/// # Errors
///
/// Returns a `bad_request` response naming the first rule that failed.
pub fn validate_region(region: &str) -> Result<(), HttpResponse> {
    let candidate = region.trim();
    let allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
    if candidate.is_empty() {
        Err(bad_request("Invalid region", "region field is empty"))
    } else if candidate.len() > MAX_REGION_LEN {
        Err(bad_request("Invalid region", "region is too long"))
    } else if candidate.chars().any(|c| !allowed(c)) {
        Err(bad_request(
            "Invalid region",
            "region contains invalid characters",
        ))
    } else {
        Ok(())
    }
}
/// Performs generic sanity checks on an access key id / secret key pair.
///
/// Checks run in this order:
/// 1. `access_key_id` must not be blank (empty after trimming).
/// 2. `secret_key` must not be blank.
/// 3. Neither raw (untrimmed) value may exceed [`MAX_CREDENTIAL_FIELD_LEN`] bytes.
///
/// # Errors
///
/// Returns a `bad_request` response for the first failed check.
pub fn validate_access_credentials(
    access_key_id: &str,
    secret_key: &str,
) -> Result<(), HttpResponse> {
    let is_blank = |field: &str| field.trim().is_empty();

    if is_blank(access_key_id) {
        return Err(bad_request(
            "Invalid credentials",
            "access_key_id is required",
        ));
    }
    if is_blank(secret_key) {
        return Err(bad_request("Invalid credentials", "secret_key is required"));
    }

    let oversized = [access_key_id, secret_key]
        .iter()
        .any(|field| field.len() > MAX_CREDENTIAL_FIELD_LEN);
    if oversized {
        return Err(bad_request(
            "Invalid credentials",
            "credential fields are too long",
        ));
    }
    Ok(())
}
/// Applies credential-format rules that depend on the storage provider, as
/// inferred from the endpoint host.
///
/// Currently only Cloudflare R2 is checked. A host ending in
/// `.r2.cloudflarestorage.com` (case-insensitive, trailing dots ignored) requires
/// the trimmed `access_key_id` to be exactly 32 bytes and free of hyphens.
/// Endpoints without a host, or for any other provider, pass unchecked.
///
/// # Errors
///
/// Returns a 400 response carrying
/// `STORAGE_ERROR_CODE_R2_INVALID_ACCESS_KEY_ID_FORMAT` and provider details
/// when the R2 rule is violated.
pub fn validate_provider_specific_credentials(
    endpoint: &reqwest::Url,
    access_key_id: &str,
) -> Result<(), HttpResponse> {
    let Some(host) = endpoint.host_str() else {
        return Ok(());
    };
    // A fully-qualified name ("acct.r2.cloudflarestorage.com.") addresses the same
    // service, so strip trailing dots; otherwise it would slip past the suffix match.
    let normalized_host = host.trim().trim_end_matches('.').to_ascii_lowercase();
    let trimmed_access_key = access_key_id.trim();
    let is_r2 = normalized_host.ends_with(".r2.cloudflarestorage.com");
    if is_r2 && (trimmed_access_key.len() != 32 || trimmed_access_key.contains('-')) {
        return Err(error_response_with_code(
            StatusCode::BAD_REQUEST,
            "Invalid storage credentials",
            "Cloudflare R2 access_key_id must be a 32-character S3 access key ID without hyphens",
            STORAGE_ERROR_CODE_R2_INVALID_ACCESS_KEY_ID_FORMAT,
            Some(json!({
                "operation": "validate_credentials",
                "backend": "s3",
                "provider": "cloudflare_r2",
            })),
        ));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bucket_accepts_simple_name() {
        assert!(validate_bucket_name("my-bucket-1").is_ok());
    }

    #[test]
    fn bucket_accepts_dotted_name() {
        assert!(validate_bucket_name("logs.example.com").is_ok());
    }

    #[test]
    fn bucket_rejects_uppercase() {
        assert!(validate_bucket_name("MyBucket").is_err());
    }

    #[test]
    fn bucket_rejects_adjacent_dots() {
        assert!(validate_bucket_name("a..b").is_err());
    }

    #[test]
    fn bucket_rejects_out_of_range_lengths() {
        assert!(validate_bucket_name("ab").is_err());
        assert!(validate_bucket_name(&"a".repeat(MAX_BUCKET_LEN + 1)).is_err());
        assert!(validate_bucket_name(&"a".repeat(MAX_BUCKET_LEN)).is_ok());
    }

    #[test]
    fn bucket_rejects_non_alphanumeric_edges() {
        assert!(validate_bucket_name("-bucket").is_err());
        assert!(validate_bucket_name("bucket.").is_err());
    }

    #[test]
    fn scanner_junk_is_case_insensitive() {
        assert!(scanner_junk_in_storage_field("EVAL(x)"));
        assert!(!scanner_junk_in_storage_field("my-bucket"));
    }

    #[test]
    fn region_validation() {
        assert!(validate_region("us-east-1").is_ok());
        assert!(validate_region("  auto  ").is_ok());
        assert!(validate_region("").is_err());
        assert!(validate_region("us east").is_err());
        assert!(validate_region(&"x".repeat(MAX_REGION_LEN + 1)).is_err());
    }

    #[test]
    fn credentials_require_both_fields_within_limits() {
        assert!(validate_access_credentials("AKIA", "secret").is_ok());
        assert!(validate_access_credentials("", "secret").is_err());
        assert!(validate_access_credentials("AKIA", "   ").is_err());
        let oversized = "a".repeat(MAX_CREDENTIAL_FIELD_LEN + 1);
        assert!(validate_access_credentials(&oversized, "secret").is_err());
    }

    #[test]
    fn endpoint_without_scheme_defaults_to_https() {
        let url = validate_storage_endpoint("s3.example.com").expect("bare host should be accepted");
        assert_eq!(url.scheme(), "https");
        assert_eq!(url.host_str(), Some("s3.example.com"));
    }

    #[test]
    fn endpoint_rejects_disallowed_shapes() {
        assert!(validate_storage_endpoint("").is_err());
        assert!(validate_storage_endpoint("https://169.254.169.254").is_err());
        assert!(validate_storage_endpoint("https://s3.example.com/bucket").is_err());
        assert!(validate_storage_endpoint("https://s3.example.com?x=1").is_err());
        assert!(validate_storage_endpoint("user:pw@s3.example.com").is_err());
    }

    #[test]
    fn rejects_r2_access_key_ids_with_non_32_length() {
        let endpoint = reqwest::Url::parse("https://example.r2.cloudflarestorage.com").unwrap();
        let response = validate_provider_specific_credentials(
            &endpoint,
            "fba11c68-6eed-a7d4-e905-44b68432b2a8",
        )
        .expect_err("expected provider-specific R2 validation to fail");
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn accepts_well_formed_r2_access_key_id() {
        let endpoint = reqwest::Url::parse("https://example.r2.cloudflarestorage.com").unwrap();
        assert!(validate_provider_specific_credentials(
            &endpoint,
            "0123456789abcdef0123456789abcdef",
        )
        .is_ok());
    }

    #[test]
    fn skips_r2_rule_for_other_providers() {
        let endpoint = reqwest::Url::parse("https://s3.amazonaws.com").unwrap();
        assert!(validate_provider_specific_credentials(&endpoint, "AKIA-NOT-32").is_ok());
    }
}