use crate::{
models::{Headers, KeyFormat},
state::AppState,
};
use actix_web::http::header::HeaderMap;
use sha2::{Digest, Sha256};
use std::{fs::metadata, io::ErrorKind, path::Path};
use url::Url;
pub fn check_admin(token: &str, data: &AppState) -> bool {
data.no_admin_key || token == data.admin_key
}
/// Checks `key` against the configured length and prefix constraints.
///
/// * `length` – required key length counted in `char`s; `0` disables the check.
/// * `prefix` – required leading text; an empty prefix disables the check.
///
/// Returns `true` only when every enabled constraint is satisfied.
pub fn is_key_valid(key: &str, length: u16, prefix: &str) -> bool {
    // Widen `length` rather than narrowing the count: `count as u16` wraps,
    // which let a key of `65_536 + length` characters pass the length check.
    let length_ok = length == 0 || key.chars().count() == usize::from(length);
    let prefix_ok = prefix.is_empty() || key.starts_with(prefix);
    length_ok && prefix_ok
}
/// Collects the forwarded-request headers into a [`Headers`] value.
///
/// `x-forwarded-for`, `x-original-host` and `x-original-uri` are each capped
/// at `original_length` characters. `x-akas-metadata` is capped at
/// `metadata_length`. Every field becomes `"-"` when its header is absent,
/// empty, unreadable as a string, or its cap is `0`.
pub fn parse_headers(headers: &HeaderMap, original_length: u16, metadata_length: u16) -> Headers {
    let capped_original = |name: &str| get_truncated_header(headers, name, original_length);
    Headers {
        forwarded_for: capped_original("x-forwarded-for"),
        original_host: capped_original("x-original-host"),
        original_uri: capped_original("x-original-uri"),
        metadata: get_truncated_header(headers, "x-akas-metadata", metadata_length),
    }
}
/// Returns the value of `header_name`, shortened to at most `max_len` characters.
///
/// The placeholder `"-"` is returned instead when `max_len` is `0`, when the
/// header is missing, when `to_str()` rejects its value, or when the value is empty.
fn get_truncated_header(headers: &HeaderMap, header_name: &str, max_len: u16) -> String {
    const PLACEHOLDER: &str = "-";
    let limit = usize::from(max_len);
    if limit == 0 {
        return PLACEHOLDER.to_string();
    }
    match headers.get(header_name).map(|value| value.to_str()) {
        Some(Ok(text)) if !text.is_empty() => {
            if text.len() <= limit {
                text.to_string()
            } else {
                // Truncate on character boundaries, never mid-codepoint.
                text.chars().take(limit).collect()
            }
        }
        _ => PLACEHOLDER.to_string(),
    }
}
/// Hashes the UTF-8 bytes of `input` with SHA-256 and returns the digest hex-encoded.
pub fn sha256_hex(input: &str) -> String {
    let digest = Sha256::digest(input.as_bytes());
    hex::encode(digest)
}
/// Validates a host URL, a key-file path and a key-format name.
///
/// Checks run in this order; the first failure is returned:
///
/// 1. `host` must parse as a [`Url`] (`InvalidInput`).
/// 2. `filename` must exist (`InvalidInput`).
/// 3. `filename` must resolve to a regular file, following symlinks (`InvalidInput`).
/// 4. `filename` must be openable for reading (`PermissionDenied`).
/// 5. `format` must be accepted by `KeyFormat::from_str` (`InvalidInput`).
///
/// # Errors
///
/// Returns a [`std::io::Error`] with the kind listed above and a
/// human-readable message. An unexpected I/O error while opening the file,
/// other than a permission failure, is propagated unchanged.
pub fn validate_inputs(host: &str, filename: &str, format: &str) -> Result<(), std::io::Error> {
    if Url::parse(host).is_err() {
        return Err(std::io::Error::new(
            ErrorKind::InvalidInput,
            format!("{} is not a valid URL.", host),
        ));
    }

    let file_path = Path::new(filename);
    if !file_path.exists() {
        return Err(std::io::Error::new(
            ErrorKind::InvalidInput,
            format!("{} is not a path to an existing file.", filename),
        ));
    }

    // Same semantics as `Path::is_file` (follows symlinks, error => false),
    // using a single metadata lookup.
    let is_regular_file = metadata(file_path)
        .map(|meta| meta.is_file())
        .unwrap_or(false);
    if !is_regular_file {
        return Err(std::io::Error::new(
            ErrorKind::InvalidInput,
            format!("{} is not a valid file.", filename),
        ));
    }

    // Probe readability by actually opening the file. `Permissions::readonly()`
    // only reports whether the file is write-protected. Using it rejected
    // perfectly readable read-only key files and accepted files this
    // process cannot read.
    if let Err(err) = std::fs::File::open(file_path) {
        if err.kind() == ErrorKind::PermissionDenied {
            return Err(std::io::Error::new(
                ErrorKind::PermissionDenied,
                format!("File {} is not readable.", filename),
            ));
        }
        return Err(err);
    }

    if KeyFormat::from_str(format).is_none() {
        return Err(std::io::Error::new(
            ErrorKind::InvalidInput,
            format!(
                "{} is not a valid format. Choose from '{}' or '{}'.",
                format,
                KeyFormat::Plain.as_str(),
                KeyFormat::Sha256.as_str()
            ),
        ));
    }

    Ok(())
}
#[cfg(test)]
mod check_admin_tests {
    use super::*;
    use crate::init_state;
    use crate::AppConfig;
    use crate::AppState;
    use prometheus::{IntCounterVec, Opts};
    use std::sync::Arc;

    /// Builds an `AppState` through `init_state` with fixed settings, varying
    /// only the admin-key configuration under test.
    fn create_default_app_state(no_admin: bool, admin_key: &str) -> AppState {
        let auth_counter = Arc::new(
            IntCounterVec::new(Opts::new("test_auth", "Test auth counter"), &["status"]).unwrap(),
        );
        init_state(AppConfig {
            admin_key: admin_key.to_string(),
            no_admin_key: no_admin,
            local: false,
            enable_metrics: false,
            port: 5001,
            log_level: "info".to_string(),
            original_length: 0,
            metadata_length: 0,
            key_length: 0,
            key_prefix: String::new(),
            auth_counter,
        })
        .unwrap()
    }

    #[test]
    fn test_check_admin_no_admin_key_set() {
        let data = create_default_app_state(true, "");
        assert!(check_admin("any_token", &data));
        assert!(check_admin("", &data));
    }

    #[test]
    fn test_check_admin_with_correct_key() {
        let data = create_default_app_state(false, "secret_key");
        assert!(check_admin("secret_key", &data));
    }

    #[test]
    fn test_check_admin_with_incorrect_key() {
        let data = create_default_app_state(false, "secret_key");
        assert!(!check_admin("wrong_key", &data));
    }

    #[test]
    fn test_check_admin_with_empty_token_and_set_key() {
        let data = create_default_app_state(false, "secret_key");
        assert!(!check_admin("", &data));
    }

    #[test]
    fn test_check_admin_empty_token_non_empty_key() {
        let data = create_default_app_state(false, "some_key");
        assert!(!check_admin("", &data));
    }

    #[test]
    fn test_check_admin_rejects_prefix_and_extension_of_key() {
        // Partial matches must never be accepted, whatever the comparison strategy.
        let data = create_default_app_state(false, "secret_key");
        assert!(!check_admin("secret", &data));
        assert!(!check_admin("secret_key_extra", &data));
    }
}
#[cfg(test)]
mod is_key_valid_tests {
    use super::*;

    #[test]
    fn test_valid_key_with_length_and_prefix() {
        assert!(
            is_key_valid("testkey123", 10, "test"),
            "10-char key with matching prefix must pass"
        );
    }

    #[test]
    fn test_valid_key_with_length_no_prefix() {
        assert!(
            is_key_valid("exactlength", 11, ""),
            "exact length with no prefix rule must pass"
        );
    }

    #[test]
    fn test_valid_key_with_prefix_no_length() {
        assert!(
            is_key_valid("mysecretkey", 0, "mysecret"),
            "prefix match with length rule disabled must pass"
        );
    }

    #[test]
    fn test_valid_key_no_constraints() {
        assert!(
            is_key_valid("anykeywilldo", 0, ""),
            "no constraints accepts any key"
        );
    }

    #[test]
    fn test_invalid_key_wrong_length() {
        assert!(!is_key_valid("short", 10, "short"), "too short");
        assert!(!is_key_valid("toolongkey", 5, "toolong"), "too long");
    }

    #[test]
    fn test_invalid_key_wrong_prefix() {
        assert!(
            !is_key_valid("wrongprefixkey", 14, "correct"),
            "right length but wrong prefix"
        );
    }

    #[test]
    fn test_invalid_key_wrong_length_and_prefix() {
        assert!(!is_key_valid("bad", 10, "wrong"), "both rules violated");
    }

    #[test]
    fn test_empty_key_with_constraints() {
        assert!(!is_key_valid("", 5, "prefix"), "empty key fails length and prefix");
        assert!(!is_key_valid("", 0, "prefix"), "empty key fails prefix");
        assert!(!is_key_valid("", 1, ""), "empty key fails length");
    }

    #[test]
    fn test_empty_key_no_constraints() {
        assert!(is_key_valid("", 0, ""), "empty key passes when unconstrained");
    }

    #[test]
    fn test_key_min_length() {
        assert!(is_key_valid("abc", 3, "ab"), "short key at exact length");
    }

    #[test]
    fn test_key_with_special_characters() {
        assert!(
            is_key_valid("k€y!@#$", 7, "k€y"),
            "length counts characters, not bytes"
        );
    }

    #[test]
    fn test_prefix_longer_than_key() {
        assert!(!is_key_valid("abc", 3, "abcd"), "prefix cannot exceed the key");
    }
}
#[cfg(test)]
mod parse_headers_tests {
    use super::*;
    use actix_web::http::header::{HeaderMap, HeaderName, HeaderValue};

    /// Builds a `HeaderMap` from `(name, value)` pairs; names must be lowercase.
    fn header_map(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(HeaderName::from_static(name), HeaderValue::from_static(value));
        }
        map
    }

    /// All four proxy headers populated with the values several tests share.
    fn all_proxy_headers() -> HeaderMap {
        header_map(&[
            ("x-forwarded-for", "203.0.113.195"),
            ("x-original-host", "my-host.com"),
            ("x-original-uri", "/some/path"),
            ("x-akas-metadata", "metadata1"),
        ])
    }

    /// Flattens a parse result into `[forwarded_for, original_host, original_uri, metadata]`.
    fn fields(parsed: &Headers) -> [&str; 4] {
        [
            parsed.forwarded_for.as_str(),
            parsed.original_host.as_str(),
            parsed.original_uri.as_str(),
            parsed.metadata.as_str(),
        ]
    }

    #[test]
    fn test_parse_headers_all_present() {
        let parsed = parse_headers(&all_proxy_headers(), 100, 50);
        assert_eq!(
            fields(&parsed),
            ["203.0.113.195", "my-host.com", "/some/path", "metadata1"]
        );
    }

    #[test]
    fn test_parse_headers_missing_headers() {
        let map = header_map(&[("x-original-other", "other-value")]);
        let parsed = parse_headers(&map, 100, 10);
        assert_eq!(fields(&parsed), ["-", "-", "-", "-"]);
    }

    #[test]
    fn test_parse_headers_all_missing_except_metadata() {
        let map = header_map(&[("x-akas-metadata", "metadata1")]);
        let parsed = parse_headers(&map, 100, 10);
        assert_eq!(fields(&parsed), ["-", "-", "-", "metadata1"]);
    }

    #[test]
    fn test_parse_headers_original_truncation() {
        let parsed = parse_headers(&all_proxy_headers(), 5, 50);
        assert_eq!(fields(&parsed), ["203.0", "my-ho", "/some", "metadata1"]);
    }

    #[test]
    fn test_parse_headers_akas_metadata_truncation() {
        let map = header_map(&[("x-akas-metadata", "long_value_that_should_be_truncated")]);
        let parsed = parse_headers(&map, 100, 10);
        assert_eq!(parsed.metadata, "long_value");
    }

    #[test]
    fn test_parse_headers_akas_metadata_exact_length() {
        let map = header_map(&[("x-akas-metadata", "12345")]);
        let parsed = parse_headers(&map, 100, 5);
        assert_eq!(fields(&parsed), ["-", "-", "-", "12345"]);
    }

    #[test]
    fn test_parse_headers_original_length_zero() {
        let parsed = parse_headers(&all_proxy_headers(), 0, 20);
        assert_eq!(fields(&parsed), ["-", "-", "-", "metadata1"]);
    }

    #[test]
    fn test_parse_headers_metadata_length_zero() {
        let map = header_map(&[
            ("x-original-host", "my-host.com"),
            ("x-akas-metadata", "metadata1"),
        ]);
        let parsed = parse_headers(&map, 100, 0);
        assert_eq!(fields(&parsed), ["-", "my-host.com", "-", "-"]);
    }

    #[test]
    fn test_parse_headers_empty_header_values() {
        let map = header_map(&[
            ("x-original-host", ""),
            ("x-original-uri", ""),
            ("x-akas-metadata", ""),
        ]);
        let parsed = parse_headers(&map, 100, 10);
        assert_eq!(fields(&parsed), ["-", "-", "-", "-"]);
    }
}
#[cfg(test)]
mod get_truncated_header_tests {
    use super::*;
    use actix_web::http::header::{HeaderMap, HeaderName, HeaderValue};

    /// Returns a `HeaderMap` holding exactly one `name: value` entry.
    fn single_header(name: &'static str, value: &'static str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert(HeaderName::from_static(name), HeaderValue::from_static(value));
        map
    }

    #[test]
    fn test_header_exists_and_no_truncation_needed() {
        let map = single_header("x-test-header", "Hello World");
        assert_eq!(get_truncated_header(&map, "x-test-header", 20), "Hello World");
    }

    #[test]
    fn test_header_exists_and_truncation_needed() {
        let map = single_header(
            "x-test-header",
            "This is a long string that needs to be truncated.",
        );
        assert_eq!(get_truncated_header(&map, "x-test-header", 10), "This is a ");
    }

    #[test]
    fn test_header_exists_and_exact_length() {
        let map = single_header("x-test-header", "ExactLength");
        assert_eq!(get_truncated_header(&map, "x-test-header", 11), "ExactLength");
    }

    #[test]
    fn test_header_not_found() {
        let map = HeaderMap::new();
        assert_eq!(get_truncated_header(&map, "non-existent-header", 10), "-");
    }

    #[test]
    fn test_max_len_is_zero() {
        let map = single_header("x-test-header", "Some value");
        assert_eq!(get_truncated_header(&map, "x-test-header", 0), "-");
    }

    #[test]
    fn test_empty_header_value() {
        let map = single_header("x-empty-header", "");
        assert_eq!(get_truncated_header(&map, "x-empty-header", 10), "-");
    }
}
#[cfg(test)]
mod sha256_hex_tests {
    use super::*;
    use crate::state::EMPTY_STRING_SHA256;

    /// Asserts that hashing `input` yields the hex digest `expected`.
    fn assert_digest(input: &str, expected: &str) {
        assert_eq!(sha256_hex(input), expected, "digest of {:?}", input);
    }

    #[test]
    fn test_sha256_hex_hello_world() {
        assert_digest(
            "hello world",
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9",
        );
    }

    #[test]
    fn test_sha256_hex_empty_string() {
        assert_digest("", EMPTY_STRING_SHA256);
    }

    #[test]
    fn test_sha256_hex_long_string() {
        let lorem = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.";
        assert_digest(
            lorem,
            "1f38b148591b024f56cd04fa661758d758dd31d855a225c4645126e76be72f32",
        );
    }

    #[test]
    fn test_sha256_hex_with_unicode() {
        assert_digest(
            "你好世界",
            "beca6335b20ff57ccc47403ef4d9e0b8fccb4442b3151c2e7d50050673d43172",
        );
    }
}
#[cfg(test)]
mod validate_inputs_tests {
    use super::*;

    const KEY_FILE: &str = "tests/files/plain_key.txt";

    /// Runs `validate_inputs`, requiring an `InvalidInput` error with exactly `message`.
    fn assert_invalid_input(host: &str, filename: &str, format: &str, message: &str) {
        let err = validate_inputs(host, filename, format).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
        assert_eq!(err.to_string(), message);
    }

    #[test]
    fn test_validate_inputs_ok() {
        let cases = [
            ("https://example.com", "plain"),
            ("https://example.com", "sha256"),
            ("http://localhost:5001", "sha256"),
        ];
        for &(host, format) in cases.iter() {
            assert!(
                validate_inputs(host, KEY_FILE, format).is_ok(),
                "{} / {}",
                host,
                format
            );
        }
    }

    #[test]
    fn test_validate_inputs_invalid_host() {
        assert_invalid_input(
            "invalid_url",
            KEY_FILE,
            "plain",
            "invalid_url is not a valid URL.",
        );
    }

    #[test]
    fn test_validate_inputs_non_existent_file() {
        assert_invalid_input(
            "https://example.com",
            "tests/files/non_existent_file.txt",
            "plain",
            "tests/files/non_existent_file.txt is not a path to an existing file.",
        );
    }

    #[test]
    fn test_validate_inputs_dir_instead_file() {
        assert_invalid_input(
            "https://example.com",
            "tests/files",
            "plain",
            "tests/files is not a valid file.",
        );
    }

    #[test]
    fn test_validate_inputs_invalid_format() {
        assert_invalid_input(
            "https://example.com",
            KEY_FILE,
            "invalid_format",
            "invalid_format is not a valid format. Choose from 'plain' or 'sha256'.",
        );
    }
}