use crate::error::AwsError;
use base64::Engine as _;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use hmac::{Hmac, Mac};
use rand::RngCore;
use sha2::Sha256;
use std::sync::OnceLock;
use std::time::{SystemTime, UNIX_EPOCH};
/// AWS error code returned for any pagination token that fails validation.
const TOKEN_INVALID_CODE: &str = "InvalidParameterValue";
/// Message paired with `TOKEN_INVALID_CODE`; the same text is used for every token failure.
const TOKEN_INVALID_MSG: &str = "The pagination token is malformed or expired.";
/// AWS error code returned when a strictly-validated `MaxResults` is out of range.
const MAX_RESULTS_INVALID_CODE: &str = "ValidationException";

/// Format version stored in the first byte of every token envelope.
const TOKEN_VERSION: u8 = 1;
/// Number of leading HMAC-SHA256 output bytes kept as the token's authentication tag.
const TAG_LEN: usize = 16;
/// Smallest possible envelope: version byte + big-endian `u64` expiry + tag (empty marker).
const MIN_ENVELOPE_LEN: usize =
    std::mem::size_of::<u8>() + std::mem::size_of::<u64>() + TAG_LEN;
/// Lifetime of a freshly issued token, in seconds (six hours).
pub const TOKEN_TTL_SECONDS: u64 = 21_600;
/// HMAC-SHA256 used both to sign and to verify pagination tokens.
type HmacSha256 = Hmac<Sha256>;

/// Lazily-initialised, per-process token signing key.
static SIGNING_KEY: OnceLock<[u8; 32]> = OnceLock::new();

/// Returns the process-wide signing key, filling 32 random bytes on the first call.
///
/// The key is held only in memory, so tokens issued by one process fail verification in any
/// other process, including the same service after a restart.
fn signing_key() -> &'static [u8; 32] {
    SIGNING_KEY.get_or_init(|| {
        let mut rng = rand::thread_rng();
        let mut fresh = [0u8; 32];
        rng.fill_bytes(&mut fresh);
        fresh
    })
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reports a time before the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// One page of results produced by `paginate`.
///
/// Derives `Clone`/`PartialEq`/`Eq` (when `T` supports them) so callers and tests can copy
/// and compare whole pages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Page<T> {
    /// Items on this page, in the same order as the input collection.
    pub items: Vec<T>,
    /// Opaque token for requesting the following page; `None` when `paginate` found no item
    /// after this page (and always when it was asked for zero results).
    pub next_token: Option<String>,
}
/// Issues a signed pagination token for `marker`, valid for `TOKEN_TTL_SECONDS` from now.
///
/// The expiry saturates at `u64::MAX` rather than overflowing.
pub fn encode_token(marker: &str) -> String {
    let expires_at = now_unix().saturating_add(TOKEN_TTL_SECONDS);
    encode_token_with_expiry(marker, expires_at)
}
/// Builds and signs a token envelope with an explicit `expiry` (Unix seconds).
///
/// Layout before base64 (URL-safe, unpadded):
/// `TOKEN_VERSION (1 byte) || expiry (8 bytes, big-endian) || marker (UTF-8) || tag`,
/// where `tag` is the first `TAG_LEN` bytes of HMAC-SHA256 over everything preceding it.
fn encode_token_with_expiry(marker: &str, expiry: u64) -> String {
    let parts: [&[u8]; 3] = [&[TOKEN_VERSION], &expiry.to_be_bytes(), marker.as_bytes()];
    let signed_body = parts.concat();

    let mut hmac =
        HmacSha256::new_from_slice(signing_key()).expect("HMAC accepts any key length");
    hmac.update(&signed_body);
    let digest = hmac.finalize().into_bytes();

    let mut envelope = signed_body;
    envelope.extend_from_slice(&digest[..TAG_LEN]);
    URL_SAFE_NO_PAD.encode(&envelope)
}
/// Verifies a token produced by `encode_token` and returns the marker it carries.
///
/// The tag is checked before the expiry or marker are trusted.
///
/// # Errors
///
/// Every failure returns the same `InvalidParameterValue` bad-request error, so callers
/// cannot tell which check rejected the token. Failures are: invalid base64, an envelope
/// shorter than `MIN_ENVELOPE_LEN`, an unknown version byte, a tag mismatch, an expiry in
/// the past, or a marker that is not valid UTF-8.
pub fn decode_token(token: &str) -> Result<String, AwsError> {
    let raw = URL_SAFE_NO_PAD.decode(token).map_err(|_| token_invalid())?;
    // Short-circuit keeps `raw[0]` from being evaluated on an undersized buffer.
    if raw.len() < MIN_ENVELOPE_LEN || raw[0] != TOKEN_VERSION {
        return Err(token_invalid());
    }

    let (body, received_tag) = raw.split_at(raw.len() - TAG_LEN);
    let mut hmac =
        HmacSha256::new_from_slice(signing_key()).expect("HMAC accepts any key length");
    hmac.update(body);
    let computed = hmac.finalize().into_bytes();
    if !constant_time_eq(received_tag, &computed[..TAG_LEN]) {
        return Err(token_invalid());
    }

    // `body` is at least 9 bytes here: version byte followed by the 8-byte expiry.
    let (header, marker) = body.split_at(1 + 8);
    let mut expiry_be = [0u8; 8];
    expiry_be.copy_from_slice(&header[1..]);
    if u64::from_be_bytes(expiry_be) < now_unix() {
        return Err(token_invalid());
    }

    String::from_utf8(marker.to_vec()).map_err(|_| token_invalid())
}
/// The single bad-request error returned for every rejected pagination token.
fn token_invalid() -> AwsError {
    let (code, message) = (TOKEN_INVALID_CODE, TOKEN_INVALID_MSG);
    AwsError::bad_request(code, message)
}
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately. Length is not treated as secret;
/// the tags compared here are always `TAG_LEN` bytes.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR every byte difference into one accumulator, visiting all bytes regardless of content.
    let accumulated = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    accumulated == 0
}
/// Leniently resolves a client-supplied `MaxResults` into a page size that never exceeds `max`.
///
/// - `None` → `default`, capped at `max`.
/// - `Some(n)` with `n < 1` → 1 (or 0 if `max` is 0).
/// - `Some(n)` with `n >= 1` → `n`, capped at `max`.
///
/// Never fails; see `clamp_max_results_strict` for the rejecting variant.
pub fn cap_max_results(requested: Option<i64>, default: usize, max: usize) -> usize {
    match requested {
        None => default.min(max),
        // Floor non-positive requests to one item, but still honour the `max` cap.
        Some(n) if n < 1 => max.min(1),
        // `n` is positive, so widening to u64 is lossless. Capping in u64 before narrowing
        // stops a huge request from wrapping to a small value where `usize` is 32 bits;
        // the result is at most `max`, so the final cast cannot truncate.
        Some(n) => (n as u64).min(max as u64) as usize,
    }
}
/// Strictly validates a client-supplied `MaxResults`.
///
/// `None` resolves to `default` capped at `max`; an explicit value must lie in `1..=max`.
///
/// # Errors
///
/// Returns a `ValidationException` bad-request error when the requested value is below 1
/// or above `max`. The message includes the offending value.
pub fn clamp_max_results_strict(
    requested: Option<i64>,
    default: usize,
    max: usize,
) -> Result<usize, AwsError> {
    let Some(n) = requested else {
        return Ok(default.min(max));
    };
    if n < 1 {
        return Err(AwsError::bad_request(
            MAX_RESULTS_INVALID_CODE,
            format!("MaxResults must be at least 1, got {n}."),
        ));
    }
    // `n` is positive here, so widening both sides to u64 is lossless. Comparing before
    // narrowing stops a value above `usize::MAX` (possible on 32-bit targets) from
    // wrapping into range and being accepted.
    if n as u64 > max as u64 {
        return Err(AwsError::bad_request(
            MAX_RESULTS_INVALID_CODE,
            format!("MaxResults must be at most {max}, got {n}."),
        ));
    }
    // 1 <= n <= max, so this cast cannot truncate.
    Ok(n as usize)
}
pub fn paginate<T, F>(
items: Vec<T>,
max_results: usize,
starting_token: Option<&str>,
key_fn: F,
) -> Result<Page<T>, AwsError>
where
F: Fn(&T) -> String,
{
if max_results == 0 {
return Ok(Page {
items: Vec::new(),
next_token: None,
});
}
let start_idx = match starting_token {
None => 0,
Some(token) => {
let marker = decode_token(token)?;
items
.iter()
.position(|item| key_fn(item) >= marker)
.unwrap_or(items.len())
}
};
let total_len = items.len();
let take_n = max_results.min(total_len.saturating_sub(start_idx));
let mut iter = items.into_iter().skip(start_idx);
let page_items: Vec<T> = iter.by_ref().take(take_n).collect();
let next_token = iter.next().map(|next| encode_token(&key_fn(&next)));
Ok(Page {
items: page_items,
next_token,
})
}
#[cfg(test)]
mod tests {
    // Unit tests for token signing/verification, `paginate`, and the MaxResults helpers.
    use super::*;

    // Key function for the `&'static str` fixtures: each item is its own pagination key.
    fn key(s: &&'static str) -> String {
        (*s).to_string()
    }

    #[test]
    fn empty_input_returns_empty_page() {
        let page = paginate::<&str, _>(vec![], 10, None, key).unwrap();
        assert!(page.items.is_empty());
        assert!(page.next_token.is_none());
    }

    #[test]
    fn page_smaller_than_max_results_no_token() {
        let items = vec!["alpha", "bravo", "charlie"];
        let page = paginate(items, 10, None, key).unwrap();
        assert_eq!(page.items, vec!["alpha", "bravo", "charlie"]);
        assert!(page.next_token.is_none());
    }

    // A page that consumes exactly the remaining items must not emit a token.
    #[test]
    fn page_exactly_full_no_token_when_no_more() {
        let items = vec!["alpha", "bravo", "charlie"];
        let page = paginate(items, 3, None, key).unwrap();
        assert_eq!(page.items.len(), 3);
        assert!(page.next_token.is_none());
    }

    // The token carries the key of the first item NOT returned.
    #[test]
    fn page_full_with_more_emits_token() {
        let items = vec!["alpha", "bravo", "charlie", "delta"];
        let page = paginate(items, 2, None, key).unwrap();
        assert_eq!(page.items, vec!["alpha", "bravo"]);
        assert_eq!(
            decode_token(page.next_token.as_deref().unwrap()).unwrap(),
            "charlie"
        );
    }

    #[test]
    fn resuming_with_token_returns_next_page() {
        let items = vec!["alpha", "bravo", "charlie", "delta"];
        let token = encode_token("charlie");
        let page = paginate(items, 2, Some(&token), key).unwrap();
        assert_eq!(page.items, vec!["charlie", "delta"]);
        assert!(page.next_token.is_none());
    }

    // "bravo" is absent, so resumption starts at the next larger key.
    #[test]
    fn token_pointing_at_deleted_key_advances_to_next_present() {
        let items = vec!["alpha", "charlie", "delta"];
        let token = encode_token("bravo");
        let page = paginate(items, 10, Some(&token), key).unwrap();
        assert_eq!(page.items, vec!["charlie", "delta"]);
        assert!(page.next_token.is_none());
    }

    // "zzz" sorts after every key, so resumption starts past the end.
    #[test]
    fn token_past_end_returns_empty_page() {
        let items = vec!["alpha", "bravo"];
        let token = encode_token("zzz");
        let page = paginate(items, 10, Some(&token), key).unwrap();
        assert!(page.items.is_empty());
        assert!(page.next_token.is_none());
    }

    #[test]
    fn invalid_base64_token_returns_error() {
        let items = vec!["alpha"];
        let err = paginate(items, 10, Some("!!!not-base64!!!"), key).unwrap_err();
        assert_eq!(err.code, TOKEN_INVALID_CODE);
    }

    // Hand-builds a correctly signed, unexpired envelope whose marker is not UTF-8, so the
    // only check that can reject it is the final UTF-8 decode.
    #[test]
    fn invalid_utf8_marker_returns_error() {
        let bad_marker = [0xff, 0xfe, 0xfd];
        let expiry = now_unix().saturating_add(TOKEN_TTL_SECONDS);
        let mut envelope = Vec::new();
        envelope.push(TOKEN_VERSION);
        envelope.extend_from_slice(&expiry.to_be_bytes());
        envelope.extend_from_slice(&bad_marker);
        let mut mac = HmacSha256::new_from_slice(signing_key()).unwrap();
        mac.update(&envelope);
        let tag = mac.finalize().into_bytes();
        envelope.extend_from_slice(&tag[..TAG_LEN]);
        let token = URL_SAFE_NO_PAD.encode(&envelope);
        let items = vec!["alpha"];
        let err = paginate(items, 10, Some(&token), key).unwrap_err();
        assert_eq!(err.code, TOKEN_INVALID_CODE);
    }

    // Flipping a single bit in the last tag byte must break verification.
    #[test]
    fn tampered_token_is_rejected() {
        let token = encode_token("charlie");
        let mut bytes = URL_SAFE_NO_PAD.decode(&token).unwrap();
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        let tampered = URL_SAFE_NO_PAD.encode(&bytes);
        let err = decode_token(&tampered).unwrap_err();
        assert_eq!(err.code, TOKEN_INVALID_CODE);
    }

    // Well-formed envelope signed with an all-zero key instead of the process key.
    #[test]
    fn forged_token_with_wrong_key_is_rejected() {
        let foreign_key = [0u8; 32];
        let mut envelope = Vec::new();
        envelope.push(TOKEN_VERSION);
        envelope.extend_from_slice(&now_unix().saturating_add(TOKEN_TTL_SECONDS).to_be_bytes());
        envelope.extend_from_slice(b"charlie");
        let mut mac = HmacSha256::new_from_slice(&foreign_key).unwrap();
        mac.update(&envelope);
        let tag = mac.finalize().into_bytes();
        envelope.extend_from_slice(&tag[..TAG_LEN]);
        let forged = URL_SAFE_NO_PAD.encode(&envelope);
        let err = decode_token(&forged).unwrap_err();
        assert_eq!(err.code, TOKEN_INVALID_CODE);
    }

    // Uses the crate-private helper to mint a correctly signed token that expired 60s ago.
    #[test]
    fn expired_token_is_rejected() {
        let already_expired = now_unix().saturating_sub(60);
        let token = encode_token_with_expiry("charlie", already_expired);
        let err = decode_token(&token).unwrap_err();
        assert_eq!(err.code, TOKEN_INVALID_CODE);
    }

    // Correctly signed with the real key, so only the version check can reject it.
    #[test]
    fn wrong_version_byte_is_rejected() {
        let mut envelope = Vec::new();
        envelope.push(99);
        envelope.extend_from_slice(&now_unix().saturating_add(TOKEN_TTL_SECONDS).to_be_bytes());
        envelope.extend_from_slice(b"charlie");
        let mut mac = HmacSha256::new_from_slice(signing_key()).unwrap();
        mac.update(&envelope);
        let tag = mac.finalize().into_bytes();
        envelope.extend_from_slice(&tag[..TAG_LEN]);
        let token = URL_SAFE_NO_PAD.encode(&envelope);
        let err = decode_token(&token).unwrap_err();
        assert_eq!(err.code, TOKEN_INVALID_CODE);
    }

    // "YQ" is valid base64 for a single byte, far below MIN_ENVELOPE_LEN.
    #[test]
    fn truncated_envelope_is_rejected() {
        let err = decode_token("YQ").unwrap_err();
        assert_eq!(err.code, TOKEN_INVALID_CODE);
    }

    // Follows next_token until exhausted; every item must appear exactly once, in order.
    #[test]
    fn round_trip_through_full_collection_yields_every_item_once() {
        let all: Vec<&'static str> = vec!["a", "b", "c", "d", "e", "f", "g"];
        let mut seen: Vec<&'static str> = Vec::new();
        let mut token: Option<String> = None;
        loop {
            let page = paginate(all.clone(), 2, token.as_deref(), key).unwrap();
            seen.extend(page.items);
            match page.next_token {
                Some(t) => token = Some(t),
                None => break,
            }
        }
        assert_eq!(seen, all);
    }

    #[test]
    fn cap_max_results_honors_default_when_unset() {
        assert_eq!(cap_max_results(None, 100, 1000), 100);
    }

    #[test]
    fn cap_max_results_caps_at_max() {
        assert_eq!(cap_max_results(Some(5000), 100, 1000), 1000);
    }

    #[test]
    fn cap_max_results_floors_at_one() {
        assert_eq!(cap_max_results(Some(0), 100, 1000), 1);
        assert_eq!(cap_max_results(Some(-3), 100, 1000), 1);
    }

    #[test]
    fn cap_max_results_caps_default_at_max() {
        assert_eq!(cap_max_results(None, 5000, 1000), 1000);
    }

    #[test]
    fn clamp_strict_accepts_in_range() {
        assert_eq!(clamp_max_results_strict(Some(50), 100, 1000).unwrap(), 50);
    }

    #[test]
    fn clamp_strict_uses_default_when_unset() {
        assert_eq!(clamp_max_results_strict(None, 100, 1000).unwrap(), 100);
    }

    #[test]
    fn clamp_strict_rejects_zero() {
        let err = clamp_max_results_strict(Some(0), 100, 1000).unwrap_err();
        assert_eq!(err.code, MAX_RESULTS_INVALID_CODE);
    }

    #[test]
    fn clamp_strict_rejects_above_max() {
        let err = clamp_max_results_strict(Some(2000), 100, 1000).unwrap_err();
        assert_eq!(err.code, MAX_RESULTS_INVALID_CODE);
    }

    #[test]
    fn clamp_strict_rejects_negative() {
        let err = clamp_max_results_strict(Some(-5), 100, 1000).unwrap_err();
        assert_eq!(err.code, MAX_RESULTS_INVALID_CODE);
    }
}