use crate::error::AwsError;
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::Sha256;
use std::sync::OnceLock;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// Format version written as the first byte of every token envelope; `verify`
/// rejects any token whose first byte differs.
const TOKEN_VERSION: u8 = 1;
/// Number of leading HMAC-SHA256 output bytes kept as the token's tag.
const TAG_LEN: usize = 16;
/// Smallest valid decoded envelope: 1 version byte + 8 expiry bytes + the tag
/// (that is, a token whose principal is empty).
const MIN_ENVELOPE_LEN: usize = 1 + 8 + TAG_LEN;
/// HMAC instantiated with SHA-256, used to authenticate token envelopes.
type HmacSha256 = Hmac<Sha256>;
/// Process-wide signing key; empty until `signing_key` fills it on first use.
static SIGNING_KEY: OnceLock<[u8; 32]> = OnceLock::new();
/// Returns the process-wide HMAC signing key, creating it on first call.
///
/// The key is 32 random bytes drawn from `rand::thread_rng()`. It lives only in
/// the `SIGNING_KEY` static, so every call within one process sees the same key.
/// Nothing in this function persists or shares the key, so a token minted by one
/// process is not expected to verify in another.
fn signing_key() -> &'static [u8; 32] {
    /// Produces a fresh random 32-byte key.
    fn generate() -> [u8; 32] {
        let mut key = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut key);
        key
    }

    SIGNING_KEY.get_or_init(generate)
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time earlier than the epoch, this returns `0`
/// instead of failing.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: fall back to the epoch itself.
        Err(_) => 0,
    }
}
pub fn mint(principal: impl AsRef<str>, ttl: Duration) -> String {
let expiry = now_unix().saturating_add(ttl.as_secs());
let principal = principal.as_ref().as_bytes();
let mut envelope = Vec::with_capacity(1 + 8 + principal.len() + TAG_LEN);
envelope.push(TOKEN_VERSION);
envelope.extend_from_slice(&expiry.to_be_bytes());
envelope.extend_from_slice(principal);
let mut mac = HmacSha256::new_from_slice(signing_key()).expect("HMAC accepts any key length");
mac.update(&envelope);
let tag = mac.finalize().into_bytes();
envelope.extend_from_slice(&tag[..TAG_LEN]);
URL_SAFE_NO_PAD.encode(&envelope)
}
/// Validates a bearer token produced by `mint` and returns its principal.
///
/// Checks run in this order: base64 decoding, minimum length, version byte,
/// HMAC tag, expiry, and finally UTF-8 validity of the principal. The expiry
/// and principal are only read after the tag has been authenticated.
///
/// A token is still accepted while the current second equals its expiry; it is
/// rejected once the clock is strictly past the expiry.
///
/// # Errors
///
/// Returns the error built by `denied()` if any check fails. Every failure
/// yields the same error, so the caller cannot tell which check rejected the
/// token.
pub fn verify(token: &str) -> Result<String, AwsError> {
    let Ok(raw) = URL_SAFE_NO_PAD.decode(token) else {
        return Err(denied());
    };
    // `||` short-circuits, so `raw[0]` is only read once the length is known
    // to be sufficient.
    if raw.len() < MIN_ENVELOPE_LEN || raw[0] != TOKEN_VERSION {
        return Err(denied());
    }

    // The trailing TAG_LEN bytes are the tag; everything before them is signed.
    let (body, presented_tag) = raw.split_at(raw.len() - TAG_LEN);
    let mut signer =
        HmacSha256::new_from_slice(signing_key()).expect("HMAC accepts any key length");
    signer.update(body);
    let computed = signer.finalize().into_bytes();
    if !constant_time_eq(presented_tag, &computed[..TAG_LEN]) {
        return Err(denied());
    }

    // body = version (1) || expiry (8) || principal (rest); the length check
    // above guarantees at least 9 bytes here.
    let (expiry_field, principal_bytes) = body[1..].split_at(8);
    let mut expiry_raw = [0u8; 8];
    expiry_raw.copy_from_slice(expiry_field);
    if now_unix() > u64::from_be_bytes(expiry_raw) {
        return Err(denied());
    }

    match std::str::from_utf8(principal_bytes) {
        Ok(principal) => Ok(principal.to_owned()),
        Err(_) => Err(denied()),
    }
}
/// Builds the single access-denied error used for every token rejection.
///
/// `verify` returns this for malformed, forged and expired tokens alike, so the
/// message deliberately does not say which check failed.
fn denied() -> AwsError {
    const MESSAGE: &str = "Invalid or expired bearer token.";
    AwsError::access_denied(MESSAGE)
}
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different lengths compare unequal immediately. For equal lengths,
/// every byte pair is XOR-ed and OR-accumulated, and the slices are equal only
/// if the accumulator stays zero; the amount of work in that pass does not
/// depend on where the inputs differ.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-builds and encodes a token envelope in the same layout `mint`
    /// writes, but with a caller-chosen version byte, expiry and key, so tests
    /// can craft tokens that `mint` itself would never produce.
    fn sign_envelope(version: u8, expiry: u64, principal: &[u8], key: &[u8]) -> String {
        let mut envelope = Vec::with_capacity(1 + 8 + principal.len() + TAG_LEN);
        envelope.push(version);
        envelope.extend_from_slice(&expiry.to_be_bytes());
        envelope.extend_from_slice(principal);
        let mut mac = HmacSha256::new_from_slice(key).unwrap();
        mac.update(&envelope);
        let tag = mac.finalize().into_bytes();
        envelope.extend_from_slice(&tag[..TAG_LEN]);
        URL_SAFE_NO_PAD.encode(&envelope)
    }

    #[test]
    fn round_trip_recovers_principal() {
        let tok = mint(
            "arn:aws:iam::111122223333:user/alice",
            Duration::from_secs(60),
        );
        let got = verify(&tok).unwrap();
        assert_eq!(got, "arn:aws:iam::111122223333:user/alice");
    }

    /// Positive control for `sign_envelope`: with the real key, the current
    /// version and a future expiry the hand-built token must verify. Without
    /// this, the negative tests below could pass merely because the helper
    /// produced a malformed envelope.
    #[test]
    fn hand_built_token_with_real_key_accepted() {
        let tok = sign_envelope(TOKEN_VERSION, now_unix() + 60, b"alice", signing_key());
        assert_eq!(verify(&tok).unwrap(), "alice");
    }

    #[test]
    fn tampered_token_rejected() {
        let tok = mint("alice", Duration::from_secs(60));
        let mut bytes = URL_SAFE_NO_PAD.decode(&tok).unwrap();
        // Flip one bit in the last tag byte.
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        let tampered = URL_SAFE_NO_PAD.encode(&bytes);
        assert!(verify(&tampered).is_err());
    }

    #[test]
    fn forged_with_wrong_key_rejected() {
        let foreign = [0u8; 32];
        let tok = sign_envelope(TOKEN_VERSION, now_unix() + 60, b"alice", &foreign);
        assert!(verify(&tok).is_err());
    }

    #[test]
    fn truncated_token_rejected() {
        // "YQ" decodes to a single byte, far below MIN_ENVELOPE_LEN.
        assert!(verify("YQ").is_err());
    }

    /// A correctly signed token whose expiry is one second in the past must be
    /// rejected. The envelope is signed directly so the test does not have to
    /// sleep across a second boundary.
    #[test]
    fn expired_token_rejected() {
        let expiry = now_unix().saturating_sub(1);
        let tok = sign_envelope(TOKEN_VERSION, expiry, b"alice", signing_key());
        assert!(verify(&tok).is_err());
    }

    /// `mint` must embed `now + ttl` as the big-endian expiry in bytes 1..9.
    /// The clock is sampled on both sides of the call to tolerate a second
    /// boundary being crossed in between.
    #[test]
    fn minted_expiry_is_now_plus_ttl() {
        let before = now_unix();
        let tok = mint("alice", Duration::from_secs(60));
        let after = now_unix();

        let bytes = URL_SAFE_NO_PAD.decode(&tok).unwrap();
        let mut raw = [0u8; 8];
        raw.copy_from_slice(&bytes[1..9]);
        let expiry = u64::from_be_bytes(raw);

        assert!((before + 60..=after + 60).contains(&expiry));
    }

    #[test]
    fn empty_principal_round_trips() {
        let tok = mint("", Duration::from_secs(60));
        assert_eq!(verify(&tok).unwrap(), "");
    }

    #[test]
    fn wrong_version_byte_rejected() {
        // Validly signed with the real key, so only the version byte is wrong.
        let tok = sign_envelope(99, now_unix() + 60, b"alice", signing_key());
        assert!(verify(&tok).is_err());
    }
}