use actix_web::HttpResponse;
use crate::api::response::bad_request;
/// Upper bound, in bytes, on the trimmed bucket name accepted by `validate_bucket_name`.
const MAX_BUCKET_LEN: usize = 63;
/// Lower bound, in bytes, on the trimmed bucket name accepted by `validate_bucket_name`.
const MIN_BUCKET_LEN: usize = 3;
/// Upper bound, in bytes, on the host component accepted by `validate_storage_endpoint`.
const MAX_ENDPOINT_HOST_LEN: usize = 253;
/// Upper bound, in bytes, applied to each credential field by `validate_access_credentials`.
const MAX_CREDENTIAL_FIELD_LEN: usize = 2048;
/// Upper bound, in bytes, on the trimmed region accepted by `validate_region`.
const MAX_REGION_LEN: usize = 64;
/// Reports whether plain-`http` storage endpoints are permitted.
///
/// The `ATHENA_STORAGE_ALLOW_HTTP` environment variable is read on every call. The result is
/// `true` only when the value is exactly `1` or equals `true` ignoring ASCII case; any other
/// value, or a variable that cannot be read, yields `false`.
fn storage_allow_insecure_http() -> bool {
    match std::env::var("ATHENA_STORAGE_ALLOW_HTTP") {
        Ok(flag) => flag == "1" || flag.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
/// Returns `true` when `s` contains one of a fixed set of code/SQL-injection probe fragments.
///
/// Matching is done on an ASCII-lowercased copy of `s`, so the fragments are found regardless
/// of ASCII letter case. An empty string never matches.
pub fn scanner_junk_in_storage_field(s: &str) -> bool {
    // Fragments are stored lowercase because the haystack is lowercased before searching.
    const MARKERS: [&str; 7] = [
        "__import__",
        "eval(",
        ".urlopen(",
        "exec(",
        "child_process",
        "utl_inaddr",
        "from dual",
    ];

    let haystack = s.to_ascii_lowercase();
    for &marker in &MARKERS {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}
/// Validates a storage bucket name.
///
/// Leading and trailing whitespace is ignored for the purpose of validation; the caller's
/// string is not modified. The trimmed name must:
///
/// - be between `MIN_BUCKET_LEN` and `MAX_BUCKET_LEN` bytes long,
/// - consist only of ASCII characters,
/// - start and end with an ASCII letter or digit,
/// - contain only lowercase letters, digits, dots, and hyphens,
/// - not contain two dots in a row.
///
/// # Errors
///
/// Returns the `bad_request` response describing the first rule that is violated, checked in
/// the order listed above (the last two rules are checked together, byte by byte).
pub fn validate_bucket_name(bucket: &str) -> Result<(), HttpResponse> {
    let b = bucket.trim();
    if b.len() < MIN_BUCKET_LEN || b.len() > MAX_BUCKET_LEN {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must be between 3 and 63 characters",
        ));
    }
    if !b.is_ascii() {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must contain only ASCII characters",
        ));
    }

    let bytes = b.as_bytes();
    // `first`/`last` keep this check panic-free even if the length limits are ever relaxed.
    let is_edge_ok = |edge: Option<&u8>| matches!(edge, Some(c) if c.is_ascii_alphanumeric());
    if !is_edge_ok(bytes.first()) || !is_edge_ok(bytes.last()) {
        return Err(bad_request(
            "Invalid bucket",
            "bucket must start and end with a letter or digit",
        ));
    }

    // Single pass: enforces the allowed alphabet and rejects ".." as soon as the second dot is
    // seen, so no separate substring search for ".." is needed afterwards.
    let mut prev_dot = false;
    for &c in bytes {
        match c {
            b'a'..=b'z' | b'0'..=b'9' | b'-' => {
                prev_dot = false;
            }
            b'.' => {
                if prev_dot {
                    return Err(bad_request(
                        "Invalid bucket",
                        "bucket cannot contain adjacent dots",
                    ));
                }
                prev_dot = true;
            }
            _ => {
                return Err(bad_request(
                    "Invalid bucket",
                    "bucket may only contain lowercase letters, digits, dots, and hyphens",
                ));
            }
        }
    }
    Ok(())
}
/// Returns `true` when `host` names a destination that storage endpoints must never target.
///
/// The host is compared after stripping leading/trailing dots and ASCII-lowercasing. Blocked
/// destinations are:
///
/// - the names `metadata`, `metadata.google.internal`, and anything ending in `.internal`;
/// - any link-local IPv4 literal (`169.254.0.0/16`), which covers `169.254.169.254` as well as
///   neighbouring metadata/credential addresses such as `169.254.170.2`;
/// - an IPv6 literal that is the IPv4-mapped form of a link-local IPv4 address
///   (for example `[::ffff:a9fe:a9fe]`), which would otherwise bypass the IPv4 check;
/// - the IPv6 instance-metadata address `fd00:ec2::254`.
///
/// IPv6 literals may be given with or without the surrounding brackets.
fn blocked_endpoint_host(host: &str) -> bool {
    let h = host.trim_matches('.').to_ascii_lowercase();
    if h == "metadata.google.internal" || h == "metadata" || h.ends_with(".internal") {
        return true;
    }

    // URL host strings wrap IPv6 literals in brackets; strip them so the literal parses.
    let literal = h
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(h.as_str());

    match literal.parse::<std::net::IpAddr>() {
        Ok(std::net::IpAddr::V4(v4)) => v4.is_link_local(),
        Ok(std::net::IpAddr::V6(v6)) => {
            let imds_v6 = std::net::Ipv6Addr::new(0xfd00, 0x0ec2, 0, 0, 0, 0, 0, 0x0254);
            v6 == imds_v6 || matches!(v6.to_ipv4_mapped(), Some(v4) if v4.is_link_local())
        }
        // Not an IP literal: only the name rules above apply.
        Err(_) => false,
    }
}
/// Parses and validates a storage endpoint, returning the parsed URL on success.
///
/// The input is trimmed first. If it does not already start with `http://` or `https://`
/// (compared ignoring ASCII case, since URL schemes are case-insensitive), `https://` is
/// prepended before parsing.
///
/// The parsed URL is accepted only if all of the following hold:
///
/// - the scheme is `https`, or `http` while `storage_allow_insecure_http()` returns `true`;
/// - it carries no username or password;
/// - its path is empty or `/`;
/// - it has no query string and no fragment;
/// - it has a host, at most `MAX_ENDPOINT_HOST_LEN` bytes long, containing no space, tab,
///   newline, or carriage return, and not rejected by `blocked_endpoint_host`;
/// - an explicit port, if reported by the URL, is not `0`.
///
/// # Errors
///
/// Returns a `bad_request` response titled "Invalid endpoint" describing the first check that
/// fails, including an empty input or a URL that cannot be parsed.
pub fn validate_storage_endpoint(endpoint: &str) -> Result<reqwest::Url, HttpResponse> {
    let trimmed = endpoint.trim();
    if trimmed.is_empty() {
        return Err(bad_request("Invalid endpoint", "endpoint field is empty"));
    }

    // Case-insensitive scheme detection: without it `HTTPS://host` would get a second
    // `https://` prepended and fail later with a misleading "must not include a path" error.
    let has_prefix = |prefix: &str| {
        matches!(trimmed.get(..prefix.len()), Some(head) if head.eq_ignore_ascii_case(prefix))
    };
    let with_scheme;
    let candidate: &str = if has_prefix("http://") || has_prefix("https://") {
        trimmed
    } else {
        with_scheme = format!("https://{trimmed}");
        &with_scheme
    };
    let url = reqwest::Url::parse(candidate).map_err(|e| {
        bad_request(
            "Invalid endpoint",
            format!("invalid endpoint URL '{trimmed}': {e}"),
        )
    })?;

    match url.scheme() {
        "https" => {}
        "http" if storage_allow_insecure_http() => {}
        "http" => {
            return Err(bad_request(
                "Invalid endpoint",
                "only https endpoints are allowed (set ATHENA_STORAGE_ALLOW_HTTP=1 for insecure dev endpoints)",
            ));
        }
        // Defensive: the prefix handling above should only ever yield http or https.
        _ => {
            return Err(bad_request(
                "Invalid endpoint",
                "endpoint must use http or https",
            ));
        }
    }

    if !url.username().is_empty() || url.password().is_some() {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include userinfo",
        ));
    }
    if !matches!(url.path(), "" | "/") {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include a path (host[:port] only)",
        ));
    }
    if url.query().is_some() || url.fragment().is_some() {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint must not include query or fragment",
        ));
    }

    let host = url
        .host_str()
        .ok_or_else(|| bad_request("Invalid endpoint", "endpoint URL does not contain a host"))?;
    if host.len() > MAX_ENDPOINT_HOST_LEN {
        return Err(bad_request("Invalid endpoint", "hostname is too long"));
    }
    if host.contains(&[' ', '\t', '\n', '\r'][..]) {
        return Err(bad_request(
            "Invalid endpoint",
            "hostname contains illegal whitespace",
        ));
    }
    if blocked_endpoint_host(host) {
        return Err(bad_request(
            "Invalid endpoint",
            "endpoint host is not allowed",
        ));
    }

    if url.port() == Some(0) {
        return Err(bad_request("Invalid endpoint", "invalid port"));
    }
    Ok(url)
}
/// Validates a storage region identifier.
///
/// The value is trimmed before checking. It must be non-empty, at most `MAX_REGION_LEN` bytes
/// long, and made up solely of ASCII letters, ASCII digits, `-`, and `_`.
///
/// # Errors
///
/// Returns a `bad_request` response titled "Invalid region" for the first rule that fails.
pub fn validate_region(region: &str) -> Result<(), HttpResponse> {
    let candidate = region.trim();

    let is_allowed_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'-' || b == b'_';

    let problem = if candidate.is_empty() {
        Some("region field is empty")
    } else if candidate.len() > MAX_REGION_LEN {
        Some("region is too long")
    } else if !candidate.bytes().all(is_allowed_byte) {
        // Any byte of a non-ASCII character fails the predicate, so checking bytes is
        // equivalent to checking characters here.
        Some("region contains invalid characters")
    } else {
        None
    };

    match problem {
        Some(detail) => Err(bad_request("Invalid region", detail)),
        None => Ok(()),
    }
}
/// Validates an access-key / secret-key pair.
///
/// Each field must contain something other than whitespace, and neither field (measured
/// untrimmed, in bytes) may exceed `MAX_CREDENTIAL_FIELD_LEN`.
///
/// # Errors
///
/// Returns a `bad_request` response titled "Invalid credentials": first for a blank
/// `access_key_id`, then for a blank `secret_key`, then for an over-long field.
pub fn validate_access_credentials(
    access_key_id: &str,
    secret_key: &str,
) -> Result<(), HttpResponse> {
    let detail = if access_key_id.trim().is_empty() {
        "access_key_id is required"
    } else if secret_key.trim().is_empty() {
        "secret_key is required"
    } else if [access_key_id, secret_key]
        .iter()
        .any(|field| field.len() > MAX_CREDENTIAL_FIELD_LEN)
    {
        "credential fields are too long"
    } else {
        return Ok(());
    };
    Err(bad_request("Invalid credentials", detail))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lowercase name with hyphens and a digit passes validation.
    #[test]
    fn bucket_accepts_simple_name() {
        let outcome = validate_bucket_name("my-bucket-1");
        assert!(outcome.is_ok());
    }

    /// Uppercase letters are outside the allowed alphabet.
    #[test]
    fn bucket_rejects_uppercase() {
        let outcome = validate_bucket_name("MyBucket");
        assert!(outcome.is_err());
    }

    /// Two dots in a row are rejected.
    #[test]
    fn bucket_rejects_adjacent_dots() {
        let outcome = validate_bucket_name("a..b");
        assert!(outcome.is_err());
    }
}