#![deny(missing_docs)]
#![deny(missing_debug_implementations)]
#![forbid(unsafe_code)]
#![warn(clippy::all)]
#![warn(clippy::pedantic)]
#![warn(trivial_casts, trivial_numeric_casts)]
#![warn(unsafe_op_in_unsafe_fn)]
#![warn(unused_qualifications)]
use chrono::{NaiveDate, Utc};
use data_encoding::BASE32;
use serde::Deserialize;
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
use thiserror::Error;
use tracing::{debug, error, warn};
/// Allow-list of AWS requests, deserialized from YAML.
///
/// Maps account ID → region → permitted services. At the account and region level the
/// key `"*"` is a fallback that is used only when the exact key is absent. As a service
/// entry, `"*"` permits every service (see [`AWSCredential::is_request_allowed`]).
///
/// ```yaml
/// accounts:
///   "123456789012":
///     regions:
///       us-east-1:
///         services: ["s3", "ec2"]
///   "*":
///     regions:
///       "*":
///         services: ["sts"]
/// ```
#[derive(Debug, Deserialize, Eq, PartialEq)]
pub struct Config {
    /// Rules keyed by account ID, or `"*"` for any account not listed explicitly.
    accounts: HashMap<String, Account>,
}

/// Rules for one account: regions keyed by name, or `"*"` for any region not listed.
#[derive(Debug, Deserialize, Eq, PartialEq)]
struct Account {
    /// Permitted services per region.
    regions: HashMap<String, Services>,
}

/// Services permitted in one region of one account.
#[derive(Debug, Deserialize, Eq, PartialEq)]
struct Services {
    /// Service names (e.g. `s3`); an entry `"*"` permits any service.
    services: Vec<String>,
}
/// Credential scope parsed from an AWS Signature Version 4 `Credential=` value
/// (`<access-key-id>/<yyyymmdd>/<region>/<service>/<terminator>`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AWSCredential {
    /// Access key ID: the first `/`-separated component of the credential.
    pub access_key_id: String,
    /// Account ID decoded from `access_key_id`, zero-padded to at least twelve digits.
    pub account_id: String,
    /// Request date: the second component, parsed as `%Y%m%d`.
    pub date: NaiveDate,
    /// Region: the third component (e.g. `us-east-1`).
    pub region: String,
    /// Service: the fourth component (e.g. `s3`).
    pub service: String,
}
impl AWSCredential {
const BYTE_MASK: u64 = 0x7fff_ffff_ff80;
const ANY: &'static str = "*";
#[tracing::instrument]
pub fn new_from_http_authz(header: &str) -> Result<AWSCredential, AWSCredentialError> {
let start = header
.find("Credential=")
.ok_or_else(|| AWSCredentialError::AuthHeaderMissingParts(header.to_string()))?;
let value_start = start + 11;
let end = header[value_start..].find(',').unwrap_or(header.len());
let header = Ok(&header[value_start..value_start + end])?;
debug!(header = header);
Ok(AWSCredential::new(header))?
}
pub fn new(credential: &str) -> Result<AWSCredential, AWSCredentialError> {
let parts: Vec<&str> = credential.split('/').collect();
if parts.len() != 5 {
error!(error = %AWSCredentialError::CredentialComponentMissingParts(credential.to_string()));
return Err(AWSCredentialError::CredentialComponentMissingParts(
credential.to_string(),
));
}
let account_id = AWSCredential::get_account_id(parts[0].as_bytes())?;
let date = AWSCredential::parse_date(parts[1])?;
let service = parts[3].to_string();
debug!(
credential = credential,
access_key_id = parts[0].to_string(),
status = "Parsed"
);
Ok(AWSCredential {
access_key_id: parts[0].to_string(),
region: parts[2].to_string(),
account_id,
date,
service,
})
}
#[must_use]
pub fn is_request_allowed(&self, config: &Config) -> bool {
let Some(account) = &config
.accounts
.get(&self.account_id)
.or_else(|| config.accounts.get(Self::ANY))
else {
debug!(
access_key_id = self.access_key_id,
region = self.region,
service = self.service,
status = "Denied"
);
return false;
};
let Some(services) = account
.regions
.get(&self.region)
.or_else(|| account.regions.get(Self::ANY))
else {
debug!(
access_key_id = self.access_key_id,
region = self.region,
service = self.service,
status = "Denied"
);
return false;
};
if services.services.contains(&self.service)
|| services.services.contains(&Self::ANY.to_owned())
{
debug!(
access_key_id = self.access_key_id,
region = self.region,
service = self.service,
status = "Allowed"
);
return true;
}
debug!(
access_key_id = self.access_key_id,
region = self.region,
service = self.service,
status = "Denied"
);
false
}
fn parse_date(date_str: &str) -> Result<NaiveDate, AWSCredentialError> {
match NaiveDate::parse_from_str(date_str, "%Y%m%d") {
Ok(date) => Ok(date),
Err(e) => {
error!(error = %AWSCredentialError::DateParseError(e.to_string()));
Err(AWSCredentialError::DateParseError(e.to_string()))
}
}
}
fn get_account_id(access_key_id: &[u8]) -> Result<String, AWSCredentialError> {
if access_key_id.len() <= 12 {
error!(error = %AWSCredentialError::AccessKeyIDLengthError(access_key_id.len().to_string()));
return Err(AWSCredentialError::AccessKeyIDLengthError(
access_key_id.len().to_string(),
));
}
let key_part = &access_key_id[4..];
match BASE32.decode_len(key_part.len()) {
Ok(decode_len) => {
if decode_len != 10 {
error!(error = %AWSCredentialError::AccessKeyIDLengthError(decode_len.to_string()));
return Err(AWSCredentialError::AccountMissingFromAccessKeyId(
decode_len.to_string(),
));
}
}
Err(e) => {
error!(time = %Utc::now().to_rfc3339(), error = %AWSCredentialError::Base32DecodeError(e.to_string()));
return Err(AWSCredentialError::Base32DecodeError(e.to_string()));
}
};
let mut output: [u8; 10] = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
let _ = BASE32.decode_mut(key_part, &mut output);
let decodedb = u64::from_be_bytes([
0, 0, output[0], output[1], output[2], output[3], output[4], output[5],
]);
let e = (decodedb & AWSCredential::BYTE_MASK) >> 7;
debug!(credentials = e);
Ok(format!("{e:0>12}"))
}
pub fn read_config(&self, file_path: PathBuf) -> Result<Config, ConfigError> {
let mut file = File::open(file_path)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
debug!(status = "Config parsed.");
Ok(serde_yaml::from_str(&contents)?)
}
}
#[non_exhaustive]
#[derive(Error, Debug, PartialEq)]
pub enum AWSCredentialError {
#[error("Access Key ID invalid length, expected more than 12 chars got: {0}")]
AccessKeyIDLengthError(String),
#[error("Auth header missing parts: {0}")]
AuthHeaderMissingParts(String),
#[error("Could not find account id in access key: {0}")]
AccountMissingFromAccessKeyId(String),
#[error("Base32 Decode Error {0}")]
Base32DecodeError(String),
#[error("Credential component missing parts: {0}")]
CredentialComponentMissingParts(String),
#[error("Could not parse date {0}")]
DateParseError(String),
}
/// Errors produced while loading a [`Config`] from disk.
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum ConfigError {
    /// The configuration file could not be opened or read.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The file contents could not be deserialized as a YAML [`Config`].
    #[error("YAML parse error: {0}")]
    YamlParse(#[from] serde_yaml::Error),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Credential scope reused by several tests (account `029608264753`).
    const EC2_CREDENTIAL: &str = "ASIAQNZGKIQY56JQ7WML/20221228/eu-west-1/ec2/aws4_request";

    /// Writes `content` (plus a trailing newline) to a new temporary file.
    ///
    /// The returned `TempPath` deletes the file when dropped, so callers must keep it
    /// alive while the file is read. Converting it straight into a `PathBuf` would drop
    /// it and remove the file before the test could open it.
    fn temp_file_with_content(content: &str) -> tempfile::TempPath {
        let mut file = tempfile::NamedTempFile::new().unwrap();
        writeln!(file, "{content}").unwrap();
        file.into_temp_path()
    }

    /// Deserializes a [`Config`] from inline YAML.
    fn config_from_yaml(yaml: &str) -> Config {
        serde_yaml::from_str(yaml).unwrap()
    }

    #[test]
    fn correct_authz_header() {
        let authz_header = r#"
Authorization: AWS4-HMAC-SHA256
Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request,
SignedHeaders=host;range;x-amz-date,
Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024
"#;
        let acc = AWSCredential::new_from_http_authz(authz_header).unwrap();
        assert_eq!(acc.account_id, "581039954779".to_string());
        assert_eq!(acc.region, "us-east-1".to_string());
        assert_eq!(
            acc.date.format("%Y-%m-%d").to_string(),
            "2013-05-24".to_string()
        );
    }

    #[test]
    fn authz_header_without_trailing_comma() {
        // Regression: a credential at the very end of the header used to panic.
        let acc = AWSCredential::new_from_http_authz(
            "Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request",
        )
        .unwrap();
        assert_eq!(acc.account_id, "581039954779");
        assert_eq!(acc.service, "s3");
    }

    #[test]
    fn wrong_authz_header() {
        let authz_header = r#"
Authorization: Credent=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request,
"#;
        let acc = AWSCredential::new_from_http_authz(authz_header);
        assert_eq!(
            acc,
            Err(AWSCredentialError::AuthHeaderMissingParts(
                authz_header.to_string()
            ))
        );
    }

    #[test]
    fn empty_authz_header() {
        let acc = AWSCredential::new_from_http_authz("");
        assert_eq!(
            acc,
            Err(AWSCredentialError::AuthHeaderMissingParts(String::new()))
        );
    }

    #[test]
    fn long_authz_header() {
        let long_string = "a".repeat(10000);
        let acc = AWSCredential::new_from_http_authz(&long_string);
        assert_eq!(
            acc,
            Err(AWSCredentialError::AuthHeaderMissingParts(long_string))
        );
    }

    #[test]
    fn correct_credential_header() {
        let accone = AWSCredential::new(EC2_CREDENTIAL).unwrap();
        assert_eq!(accone.account_id, "029608264753".to_string());
        assert_eq!(accone.region, "eu-west-1".to_string());
        assert_eq!(
            accone.date.format("%Y-%m-%d").to_string(),
            "2022-12-28".to_string()
        );
    }

    #[test]
    fn wrong_credential_header() {
        let acc = AWSCredential::new("ASIAQNZGKI/20221228/eu-west-1/ec2/aws4_request");
        assert_eq!(
            acc,
            Err(AWSCredentialError::AccessKeyIDLengthError("10".to_string()))
        );
    }

    #[test]
    fn correct_date_format() {
        let d = AWSCredential::parse_date("20221228").unwrap();
        assert_eq!(d.format("%Y%m%d").to_string(), "20221228".to_string());
    }

    #[test]
    fn wrong_date_credential_header() {
        let acc = AWSCredential::new("ASIAQNZGKIQY56JQ7WML/202228/eu-west-1/ec2/aws4_request");
        assert_eq!(
            acc,
            Err(AWSCredentialError::DateParseError(
                "premature end of input".to_string()
            ))
        );
    }

    #[test]
    fn empty_credential_header() {
        let acc = AWSCredential::new("");
        assert_eq!(
            acc,
            Err(AWSCredentialError::CredentialComponentMissingParts(
                String::new()
            ))
        );
    }

    #[test]
    fn known_account() {
        let accone = AWSCredential::get_account_id(b"ASIAQNZGKIQY56JQ7WML");
        assert_eq!(accone.unwrap(), "029608264753".to_string());
    }

    #[test]
    fn known_account_zero() {
        let accone = AWSCredential::get_account_id(b"ASIAAAAAAAAAAAAAAAAA");
        assert_eq!(accone.unwrap(), "000000000000".to_string());
    }

    #[test]
    fn bad_account_input() {
        let acc = AWSCredential::get_account_id(b"A");
        assert_eq!(
            acc,
            Err(AWSCredentialError::AccessKeyIDLengthError("1".to_string()))
        );
    }

    #[test]
    fn long_account_input() {
        let long_string = "a".repeat(1000);
        let acc = AWSCredential::get_account_id(long_string.as_bytes());
        assert_eq!(
            acc,
            Err(AWSCredentialError::Base32DecodeError(
                "invalid length at 992".to_string()
            ))
        );
    }

    #[test]
    fn account_input_with_wrong_decoded_length() {
        // 24 base32 characters decode to 15 bytes, not the 10 an access key carries.
        let acc = AWSCredential::get_account_id(&[b'A'; 28]);
        assert_eq!(
            acc,
            Err(AWSCredentialError::AccountMissingFromAccessKeyId(
                "15".to_string()
            ))
        );
    }

    #[test]
    fn invalid_base32_symbol_is_rejected() {
        // '1' is not in the RFC 4648 base32 alphabet; this used to yield a bogus id.
        let acc = AWSCredential::get_account_id(b"AKIAIOSFODNN7EXAMP1E");
        assert!(
            matches!(acc, Err(AWSCredentialError::Base32DecodeError(_))),
            "{acc:?}"
        );
    }

    #[test]
    fn request_allowed_only_for_listed_service_and_region() {
        let config = config_from_yaml(
            "accounts:\n  \"581039954779\":\n    regions:\n      us-east-1:\n        services: [s3]\n",
        );
        let s3 = AWSCredential::new("AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request")
            .unwrap();
        let ec2 = AWSCredential::new("AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/ec2/aws4_request")
            .unwrap();
        let other_region =
            AWSCredential::new("AKIAIOSFODNN7EXAMPLE/20130524/eu-west-1/s3/aws4_request")
                .unwrap();
        let other_account = AWSCredential::new(EC2_CREDENTIAL).unwrap();
        assert!(s3.is_request_allowed(&config));
        assert!(!ec2.is_request_allowed(&config));
        assert!(!other_region.is_request_allowed(&config));
        assert!(!other_account.is_request_allowed(&config));
    }

    #[test]
    fn request_allowed_by_wildcards() {
        let config = config_from_yaml(
            "accounts:\n  \"*\":\n    regions:\n      \"*\":\n        services: [\"*\"]\n",
        );
        let acc = AWSCredential::new(EC2_CREDENTIAL).unwrap();
        assert!(acc.is_request_allowed(&config));
    }

    #[test]
    fn test_read_yaml_valid() {
        let temp = temp_file_with_content(
            "accounts:\n  \"*\":\n    regions:\n      eu-west-1:\n        services: [ec2]",
        );
        let aws_creds = AWSCredential::new(EC2_CREDENTIAL).unwrap();
        let config = aws_creds.read_config(temp.to_path_buf()).unwrap();
        assert!(aws_creds.is_request_allowed(&config));
    }

    #[test]
    fn test_read_yaml_invalid() {
        // A bare scalar is valid YAML but not a `Config` mapping.
        let temp = temp_file_with_content("not a valid yaml");
        let aws_creds = AWSCredential::new(EC2_CREDENTIAL).unwrap();
        let result = aws_creds.read_config(temp.to_path_buf());
        assert!(
            matches!(result, Err(ConfigError::YamlParse(_))),
            "{result:?}"
        );
    }

    #[test]
    fn test_read_yaml_file_not_found() {
        let file_path = PathBuf::from("non_existent_file.yaml");
        let aws_creds = AWSCredential::new(EC2_CREDENTIAL).unwrap();
        let result = aws_creds.read_config(file_path);
        assert!(matches!(result, Err(ConfigError::Io(_))), "{result:?}");
    }
}