use {
crate::PrincipalError,
std::fmt::{Display, Formatter, Result as FmtResult},
};
/// Kinds of IAM unique identifiers, distinguished by their four-character prefix.
///
/// The textual prefix of each variant is listed in its documentation. Variants are
/// declared in alphabetical order. Because `PartialOrd`/`Ord` are derived, this
/// declaration order *is* the sort order, so existing variants must not be reordered.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum IamIdPrefix {
    /// Access key ID (`AKIA`).
    AccessKey,
    /// Bearer token (`ABIA`).
    BearerToken,
    /// Certificate (`ASCA`).
    Certificate,
    /// Context-specific credential (`ACCA`).
    ContextSpecificCredential,
    /// Group (`AGPA`).
    Group,
    /// Instance profile (`AIPA`).
    InstanceProfile,
    /// Managed policy (`ANPA`).
    ManagedPolicy,
    /// Version of a managed policy (`ANVA`).
    ManagedPolicyVersion,
    /// Public key (`APKA`).
    PublicKey,
    /// Role (`AROA`).
    Role,
    /// Temporary access key ID (`ASIA`).
    TemporaryAccessKey,
    /// User (`AIDA`).
    User,
}
impl Display for IamIdPrefix {
    /// Writes the four-character prefix (for example `AKIA`).
    ///
    /// Width, fill, alignment and precision flags are honored, so `{:>6}` pads the
    /// prefix just as it would pad a plain `&str`.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        // `pad` applies the formatter's flags, whereas `write_str` silently ignores them.
        // Delegating to `as_str` keeps the prefix table in one place instead of a copy here.
        f.pad(self.as_str())
    }
}
impl AsRef<str> for IamIdPrefix {
fn as_ref(&self) -> &str {
match self {
Self::AccessKey => "AKIA",
Self::BearerToken => "ABIA",
Self::Certificate => "ASCA",
Self::ContextSpecificCredential => "ACCA",
Self::Group => "AGPA",
Self::InstanceProfile => "AIPA",
Self::ManagedPolicy => "ANPA",
Self::ManagedPolicyVersion => "ANVA",
Self::PublicKey => "APKA",
Self::Role => "AROA",
Self::TemporaryAccessKey => "ASIA",
Self::User => "AIDA",
}
}
}
impl IamIdPrefix {
    /// Returns the four-character prefix for this identifier kind, e.g. `"AKIA"` for
    /// [`IamIdPrefix::AccessKey`].
    ///
    /// Every prefix is a string literal, so the result is `&'static str`. It can
    /// therefore outlive the `IamIdPrefix` value it was obtained from. Callers expecting
    /// a shorter-lived `&str` are unaffected, since `&'static str` coerces.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::AccessKey => "AKIA",
            Self::BearerToken => "ABIA",
            Self::Certificate => "ASCA",
            Self::ContextSpecificCredential => "ACCA",
            Self::Group => "AGPA",
            Self::InstanceProfile => "AIPA",
            Self::ManagedPolicy => "ANPA",
            Self::ManagedPolicyVersion => "ANVA",
            Self::PublicKey => "APKA",
            Self::Role => "AROA",
            Self::TemporaryAccessKey => "ASIA",
            Self::User => "AIDA",
        }
    }
}
/// Validates a principal name.
///
/// A name is accepted when it is between 1 and `max_length` bytes long (inclusive) and
/// every byte is an ASCII letter, an ASCII digit, or one of `,` `-` `.` `=` `@` `_`.
///
/// # Errors
///
/// If the name is rejected, `map_err` is called once with an owned copy of `name` and
/// the resulting error is returned.
pub fn validate_name<F: FnOnce(String) -> PrincipalError>(
    name: &str,
    max_length: usize,
    map_err: F,
) -> Result<(), PrincipalError> {
    const EXTRA_ALLOWED: &[u8] = b",-.=@_";

    let acceptable = !name.is_empty()
        && name.len() <= max_length
        && name
            .bytes()
            .all(|byte| byte.is_ascii_alphanumeric() || EXTRA_ALLOWED.contains(&byte));

    if acceptable {
        Ok(())
    } else {
        Err(map_err(name.to_string()))
    }
}
/// Validates an IAM unique identifier that must begin with `prefix`.
///
/// The identifier is accepted when all of these hold:
/// * it starts with `prefix`;
/// * it is at least 20 bytes long;
/// * every byte, including those of the prefix, is an ASCII letter (either case) or a
///   digit in the range `2`–`7`.
///
/// # Errors
///
/// If the identifier is rejected, `map_err` is called once with an owned copy of `id`
/// and the resulting error is returned.
pub fn validate_identifier<F: FnOnce(String) -> PrincipalError>(
    id: &str,
    prefix: &str,
    map_err: F,
) -> Result<(), PrincipalError> {
    const MIN_ID_LEN: usize = 20;

    let allowed_byte = |byte: u8| byte.is_ascii_alphabetic() || (b'2'..=b'7').contains(&byte);
    let well_formed =
        id.starts_with(prefix) && id.len() >= MIN_ID_LEN && id.bytes().all(allowed_byte);

    match well_formed {
        true => Ok(()),
        false => Err(map_err(id.to_string())),
    }
}
/// Validates an IAM path.
///
/// A path is accepted when it is either the root path `/` on its own, or a string of 3
/// to 512 bytes that begins and ends with `/` with at least one byte in between.
/// Every byte must be printable, non-space ASCII (`0x21..=0x7E`), so any non-ASCII
/// UTF-8 input is rejected as well.
///
/// This matches the IAM API path pattern `(\u002F)|(\u002F[\u0021-\u007E]+\u002F)`
/// (length 1–512). In particular, the degenerate path `//` is rejected, because the
/// pattern requires at least one character between the slashes.
///
/// # Errors
///
/// Returns [`PrincipalError::InvalidPath`] holding a copy of `path` when it is invalid.
pub fn validate_path(path: &str) -> Result<(), PrincipalError> {
    const MAX_PATH_LEN: usize = 512;

    let bytes = path.as_bytes();
    let is_valid = path == "/"
        || (bytes.len() >= 3
            && bytes.len() <= MAX_PATH_LEN
            && bytes.first() == Some(&b'/')
            && bytes.last() == Some(&b'/')
            && bytes.iter().all(|b| (0x21..=0x7e).contains(b)));

    if is_valid {
        Ok(())
    } else {
        Err(PrincipalError::InvalidPath(path.to_string()))
    }
}
/// Validates a dot-separated DNS-style name.
///
/// The whole name must be between 1 and `max_length` bytes long (inclusive). Every
/// `.`-separated label must:
/// * be 1 to 63 bytes long (so leading, trailing, or doubled dots are rejected);
/// * contain only ASCII letters, ASCII digits, `_`, and `-`;
/// * not begin or end with `-`, and not contain two consecutive `-`.
///
/// # Errors
///
/// If the name is rejected, `map_err` is called once with an owned copy of `name` and
/// the resulting error is returned.
pub fn validate_dns<F: FnOnce(String) -> PrincipalError>(
    name: &str,
    max_length: usize,
    map_err: F,
) -> Result<(), PrincipalError> {
    /// Returns `true` when a single label satisfies the per-label rules.
    fn label_is_valid(label: &[u8]) -> bool {
        const MAX_LABEL_LEN: usize = 63;

        if label.is_empty() || label.len() > MAX_LABEL_LEN {
            return false;
        }
        // Hyphens may not lead, trail, or appear back-to-back.
        if label.first() == Some(&b'-') || label.last() == Some(&b'-') {
            return false;
        }
        if label.windows(2).any(|pair| pair[0] == b'-' && pair[1] == b'-') {
            return false;
        }
        label.iter().all(|&b| b == b'-' || b == b'_' || b.is_ascii_alphanumeric())
    }

    let length_ok = !name.is_empty() && name.len() <= max_length;
    if length_ok && name.as_bytes().split(|&b| b == b'.').all(label_is_valid) {
        Ok(())
    } else {
        Err(map_err(name.to_string()))
    }
}
#[cfg(test)]
mod test {
    use {
        super::{validate_dns, validate_identifier, validate_name, validate_path, IamIdPrefix},
        crate::PrincipalError,
        std::{
            collections::hash_map::DefaultHasher,
            hash::{Hash, Hasher},
        },
    };

    #[test]
    fn check_names() {
        validate_name("test", 32, PrincipalError::InvalidRoleName).unwrap();
        validate_name("test,name-.with=exactly@32_chars", 32, PrincipalError::InvalidRoleName).unwrap();
        assert_eq!(
            validate_name("bad!name", 32, PrincipalError::InvalidRoleName).unwrap_err().to_string(),
            r#"Invalid role name: "bad!name""#
        );
    }

    fn validate_group_id(id: &str) -> Result<(), PrincipalError> {
        validate_identifier(id, IamIdPrefix::Group.as_str(), PrincipalError::InvalidGroupId)
    }

    fn validate_instance_profile_id(id: &str) -> Result<(), PrincipalError> {
        validate_identifier(id, IamIdPrefix::InstanceProfile.as_str(), PrincipalError::InvalidInstanceProfileId)
    }

    fn validate_role_id(id: &str) -> Result<(), PrincipalError> {
        validate_identifier(id, IamIdPrefix::Role.as_str(), PrincipalError::InvalidRoleId)
    }

    fn validate_user_id(id: &str) -> Result<(), PrincipalError> {
        validate_identifier(id, IamIdPrefix::User.as_str(), PrincipalError::InvalidUserId)
    }

    #[test]
    fn check_identifiers() {
        validate_group_id("AGPA234567ABCDEFGHIJ").unwrap();
        let err = validate_group_id("AIDA234567ABCDEFGHIJ").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid group id: "AIDA234567ABCDEFGHIJ""#);
        let err = validate_group_id("AGPA234567ABCDEFGHI!").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid group id: "AGPA234567ABCDEFGHI!""#);
        let err = validate_group_id("AGPA234567ABCDEFGHI").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid group id: "AGPA234567ABCDEFGHI""#);
        validate_instance_profile_id("AIPAKLMNOPQRSTUVWXYZ").unwrap();
        let err = validate_instance_profile_id("AKIAKLMNOPQRSTUVWXYZ").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid instance profile id: "AKIAKLMNOPQRSTUVWXYZ""#);
        let err = validate_instance_profile_id("AIPAKLMNOPQRSTUVWXY!").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid instance profile id: "AIPAKLMNOPQRSTUVWXY!""#);
        let err = validate_instance_profile_id("AIPAKLMNOPQRSTUVWXY").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid instance profile id: "AIPAKLMNOPQRSTUVWXY""#);
        validate_role_id("AROAKLMNOPQRSTUVWXYZ").unwrap();
        let err = validate_role_id("AKIAKLMNOPQRSTUVWXYZ").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid role id: "AKIAKLMNOPQRSTUVWXYZ""#);
        let err = validate_role_id("AROAKLMNOPQRSTUVWXY!").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid role id: "AROAKLMNOPQRSTUVWXY!""#);
        let err = validate_role_id("AROAKLMNOPQRSTUVWXY").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid role id: "AROAKLMNOPQRSTUVWXY""#);
        validate_user_id("AIDAKLMNOPQRSTUVWXYZ").unwrap();
        let err = validate_user_id("AKIAKLMNOPQRSTUVWXYZ").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid user id: "AKIAKLMNOPQRSTUVWXYZ""#);
        let err = validate_user_id("AIDAKLMNOPQRSTUVWXY!").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid user id: "AIDAKLMNOPQRSTUVWXY!""#);
        let err = validate_user_id("AIDAKLMNOPQRSTUVWXY").unwrap_err();
        assert_eq!(err.to_string(), r#"Invalid user id: "AIDAKLMNOPQRSTUVWXY""#);
    }

    #[test]
    // Calling `clone()` on a `Copy` type is deliberate: it exercises the derived `Clone`.
    #[allow(clippy::clone_on_copy)]
    fn check_id_prefix_derived() {
        // Listed in declaration order, which the derived `Ord` must follow.
        let prefixes = [
            IamIdPrefix::AccessKey,
            IamIdPrefix::BearerToken,
            IamIdPrefix::Certificate,
            IamIdPrefix::ContextSpecificCredential,
            IamIdPrefix::Group,
            IamIdPrefix::InstanceProfile,
            IamIdPrefix::ManagedPolicy,
            IamIdPrefix::ManagedPolicyVersion,
            IamIdPrefix::PublicKey,
            IamIdPrefix::Role,
            IamIdPrefix::TemporaryAccessKey,
            IamIdPrefix::User,
        ];
        let p1a = IamIdPrefix::AccessKey;
        let p1b = p1a;
        let p2 = IamIdPrefix::BearerToken;
        assert_eq!(p1a, p1b);
        assert_eq!(p1a, p1a.clone());
        assert_ne!(p1a, p2);
        let mut h1a = DefaultHasher::new();
        let mut h1b = DefaultHasher::new();
        let mut h2 = DefaultHasher::new();
        p1a.hash(&mut h1a);
        p1b.hash(&mut h1b);
        p2.hash(&mut h2);
        let hash1a = h1a.finish();
        let hash1b = h1b.finish();
        let hash2 = h2.finish();
        assert_eq!(hash1a, hash1b);
        assert_ne!(hash1a, hash2);
        for i in 0..prefixes.len() {
            for j in i + 1..prefixes.len() {
                assert!(prefixes[i] < prefixes[j]);
                assert!(prefixes[j] > prefixes[i]);
                assert_eq!(prefixes[i].max(prefixes[j]), prefixes[j]);
            }
            let _ = format!("{:?}", prefixes[i]);
            assert_eq!(prefixes[i].to_string().as_str(), prefixes[i].as_ref());
        }
    }

    #[test]
    fn check_access_key() {
        assert_eq!(IamIdPrefix::AccessKey.as_ref(), "AKIA");
        assert_eq!(format!("{}", IamIdPrefix::AccessKey).as_str(), "AKIA");
    }

    /// Asserts that `validate_path` rejects `path` with an `InvalidPath` carrying the input.
    fn assert_invalid_path(path: &str) {
        match validate_path(path) {
            Err(PrincipalError::InvalidPath(p)) => assert_eq!(p, path),
            other => panic!("expected InvalidPath for {:?}, got {:?}", path, other),
        }
    }

    #[test]
    fn check_paths() {
        validate_path("/").unwrap();
        validate_path("/division_abc/subdivision_xyz/").unwrap();
        let longest = format!("/{}/", "a".repeat(510));
        assert_eq!(longest.len(), 512);
        validate_path(&longest).unwrap();
        assert_invalid_path(&format!("/{}/", "a".repeat(511)));
        assert_invalid_path("");
        assert_invalid_path("a/");
        assert_invalid_path("/a");
        assert_invalid_path("/a b/");
        assert_invalid_path("/tab\there/");
        assert_invalid_path("/caf\u{e9}/");
    }

    #[test]
    fn check_dns() {
        validate_dns("exa_mple.com", 256, PrincipalError::InvalidService).unwrap();
        let e = validate_dns("exa_mple.com.", 256, PrincipalError::InvalidService).unwrap_err();
        assert_eq!(e.to_string(), r#"Invalid service name: "exa_mple.com.""#);
        let e = validate_dns("example.com", 5, PrincipalError::InvalidService).unwrap_err();
        assert_eq!(e.to_string(), r#"Invalid service name: "example.com""#);
        validate_dns("exam-ple.com", 256, PrincipalError::InvalidService).unwrap();
        validate_dns("exam--ple.com", 256, PrincipalError::InvalidService).unwrap_err();
        validate_dns("-example.com", 256, PrincipalError::InvalidService).unwrap_err();
        validate_dns("example-.com", 256, PrincipalError::InvalidService).unwrap_err();
    }
}