use crate::exception::{Error, Result};
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
/// Converts between numeric pagination positions and opaque cursor strings.
///
/// The `Send + Sync` bound lets one encoder instance be shared across threads.
pub trait CursorEncoder: Send + Sync {
    /// Produces a cursor string representing `position`.
    fn encode(&self, position: usize) -> Result<String>;

    /// Recovers the position from a cursor previously produced by
    /// [`encode`](Self::encode), or returns an error if the cursor is rejected.
    fn decode(&self, cursor: &str) -> Result<usize>;
}
/// Cursor encoder that emits URL-safe base64 cursors signed with HMAC-SHA256.
///
/// Each cursor embeds a position and the Unix timestamp (seconds) at which it
/// was issued, plus an HMAC over both, so clients cannot alter the position
/// without detection and cursors older than `expiry_seconds` are rejected.
///
/// `Debug` is implemented by hand so the HMAC key is never written to logs.
#[derive(Clone)]
pub struct Base64CursorEncoder {
    /// Maximum accepted cursor age, in seconds.
    pub expiry_seconds: u64,
    /// HMAC-SHA256 signing key. Redacted in `Debug` output.
    secret_key: Vec<u8>,
}

impl std::fmt::Debug for Base64CursorEncoder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print key material: anyone who sees it can forge cursors.
        f.debug_struct("Base64CursorEncoder")
            .field("expiry_seconds", &self.expiry_seconds)
            .field("secret_key", &"<redacted>")
            .finish()
    }
}
impl Base64CursorEncoder {
    /// Default maximum cursor age: 86 400 seconds (24 hours).
    const DEFAULT_EXPIRY_SECONDS: u64 = 86_400;

    /// Creates an encoder with a freshly generated random 32-byte key and the
    /// default 24-hour expiry.
    ///
    /// The key exists only inside this instance, so cursors it issues will not
    /// verify under any other encoder (e.g. another process, or this one after
    /// a restart). Use [`with_secret_key`](Self::with_secret_key) when cursors
    /// must survive across instances.
    pub fn new() -> Self {
        use rand::RngCore;
        let mut random_key = [0u8; 32];
        rand::rng().fill_bytes(&mut random_key);
        Self::with_secret_key(&random_key)
    }

    /// Creates an encoder that signs with a copy of `key`, using the default
    /// 24-hour expiry.
    ///
    /// HMAC accepts keys of any length (including empty), so key strength is
    /// entirely the caller's responsibility.
    pub fn with_secret_key(key: &[u8]) -> Self {
        Self {
            secret_key: key.to_vec(),
            expiry_seconds: Self::DEFAULT_EXPIRY_SECONDS,
        }
    }

    /// Builder-style setter that replaces the maximum cursor age, in seconds.
    pub fn expiry_seconds(self, seconds: u64) -> Self {
        Self {
            expiry_seconds: seconds,
            ..self
        }
    }

    /// Returns the raw HMAC-SHA256 tag of `message` under this encoder's key.
    fn compute_hmac(&self, message: &[u8]) -> Vec<u8> {
        let mut signer =
            HmacSha256::new_from_slice(&self.secret_key).expect("HMAC accepts any key length");
        signer.update(message);
        let tag = signer.finalize().into_bytes();
        tag.to_vec()
    }
}
impl Default for Base64CursorEncoder {
fn default() -> Self {
Self::new()
}
}
impl CursorEncoder for Base64CursorEncoder {
    /// Encodes `position` into a signed, URL-safe cursor.
    ///
    /// The cursor is the unpadded URL-safe base64 encoding of
    /// `"{position}:{timestamp}:{signature}"`. Here `timestamp` is the current
    /// Unix time in seconds, and `signature` is the hex-encoded HMAC-SHA256 of
    /// `"{position}:{timestamp}"` under this encoder's key.
    ///
    /// # Errors
    ///
    /// This implementation does not fail; the `Result` is required by the trait.
    fn encode(&self, position: usize) -> Result<String> {
        use base64::{Engine as _, engine::general_purpose};
        let signed_payload = format!("{}:{}", position, unix_now_secs());
        let signature_hex = hex::encode(self.compute_hmac(signed_payload.as_bytes()));
        let cursor_data = format!("{}:{}", signed_payload, signature_hex);
        Ok(general_purpose::URL_SAFE_NO_PAD.encode(cursor_data.as_bytes()))
    }

    /// Validates `cursor` and returns the position it carries.
    ///
    /// The signature is checked with `Mac::verify_slice`. The check runs over the
    /// exact `position:timestamp` text found in the cursor, *before* either field
    /// is parsed. The text must therefore match what was signed byte-for-byte:
    /// e.g. `+42` or `042` is rejected where `42` was signed.
    ///
    /// A timestamp in the future is treated as age 0, i.e. not expired.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidPage` if the cursor is not valid base64 or UTF-8, lacks
    ///   three `:`-separated fields, has a non-hex signature, fails the integrity
    ///   check, or carries a position/timestamp that does not parse.
    /// * `Error::Validation` if the cursor is older than `expiry_seconds`.
    fn decode(&self, cursor: &str) -> Result<usize> {
        use base64::{Engine as _, engine::general_purpose};
        let decoded = general_purpose::URL_SAFE_NO_PAD
            .decode(cursor)
            .map_err(|_| Error::InvalidPage("Invalid cursor".to_string()))?;
        let cursor_data = String::from_utf8(decoded)
            .map_err(|_| Error::InvalidPage("Invalid cursor encoding".to_string()))?;
        let fields: Vec<&str> = cursor_data.splitn(3, ':').collect();
        if fields.len() != 3 {
            return Err(Error::InvalidPage("Malformed cursor".to_string()));
        }
        let (position_str, timestamp_str, signature_hex) = (fields[0], fields[1], fields[2]);
        let provided_hmac = hex::decode(signature_hex)
            .map_err(|_| Error::InvalidPage("Invalid cursor signature".to_string()))?;

        // Authenticate the received bytes verbatim. Re-serializing the *parsed*
        // values would let non-canonical spellings ("+42", "042") reuse the
        // signature of "42", and would interpret unauthenticated input.
        let signed_payload = format!("{}:{}", position_str, timestamp_str);
        let mut mac =
            HmacSha256::new_from_slice(&self.secret_key).expect("HMAC accepts any key length");
        mac.update(signed_payload.as_bytes());
        mac.verify_slice(&provided_hmac)
            .map_err(|_| Error::InvalidPage("Cursor integrity check failed".to_string()))?;

        // Only reachable with a valid signature; kept for defence in depth
        // (e.g. a large position issued on a 64-bit host decoded on 32-bit).
        let position: usize = position_str
            .parse()
            .map_err(|_| Error::InvalidPage("Invalid cursor value".to_string()))?;
        let timestamp: u64 = timestamp_str
            .parse()
            .map_err(|_| Error::InvalidPage("Invalid cursor timestamp".to_string()))?;

        // saturating_sub: a future timestamp (clock skew) counts as age 0.
        if unix_now_secs().saturating_sub(timestamp) > self.expiry_seconds {
            return Err(Error::Validation("Cursor expired".to_string()));
        }
        Ok(position)
    }
}

/// Current Unix time in whole seconds; a system clock set before the epoch yields 0.
fn unix_now_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;
    use rstest::rstest;

    const TEST_KEY: &[u8] = b"test-secret-key-for-unit-tests!!";

    /// Current Unix time in whole seconds, computed the same way the encoder does.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Builds a correctly signed cursor for arbitrary `position` text and
    /// `timestamp`, so expiry and clock-skew cases can be tested without sleeping.
    fn signed_cursor(encoder: &Base64CursorEncoder, position: &str, timestamp: u64) -> String {
        let payload = format!("{}:{}", position, timestamp);
        let hmac_hex = hex::encode(encoder.compute_hmac(payload.as_bytes()));
        let cursor_data = format!("{}:{}", payload, hmac_hex);
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(cursor_data.as_bytes())
    }

    #[rstest]
    fn test_base64_encoder_encode_decode() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY);
        let position = 42;
        let cursor = encoder.encode(position).unwrap();
        let decoded = encoder.decode(&cursor).unwrap();
        assert_eq!(decoded, position);
    }

    #[rstest]
    fn test_base64_encoder_invalid_cursor() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY);
        let result = encoder.decode("not-valid-base64!!!");
        assert!(result.is_err());
    }

    #[rstest]
    fn test_base64_encoder_tampered_cursor() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY);
        let cursor = encoder.encode(42).unwrap();
        let mut tampered = cursor.clone();
        tampered.push('X');
        let result = encoder.decode(&tampered);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_base64_encoder_modified_position_fails_integrity_check() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY);
        let cursor = encoder.encode(42).unwrap();
        let raw = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(cursor.as_bytes())
            .unwrap();
        let cursor_data = String::from_utf8(raw).unwrap();
        // Keep the original timestamp and signature but claim a different position.
        let rest = cursor_data.strip_prefix("42:").unwrap();
        let forged_data = format!("43:{}", rest);
        let forged =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(forged_data.as_bytes());
        let result = encoder.decode(&forged);
        if let Err(Error::InvalidPage(msg)) = result {
            assert_eq!(msg, "Cursor integrity check failed");
        } else {
            panic!("Expected InvalidPage error for modified position");
        }
    }

    #[rstest]
    fn test_base64_encoder_different_key_rejects_cursor() {
        let encoder_a = Base64CursorEncoder::with_secret_key(b"secret-key-a-for-testing-only!!");
        let encoder_b = Base64CursorEncoder::with_secret_key(b"secret-key-b-for-testing-only!!");
        let cursor = encoder_a.encode(42).unwrap();
        let result = encoder_b.decode(&cursor);
        if let Err(Error::InvalidPage(msg)) = result {
            assert_eq!(msg, "Cursor integrity check failed");
        } else {
            panic!("Expected InvalidPage error for key mismatch");
        }
    }

    #[rstest]
    fn test_base64_encoder_custom_expiry() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY).expiry_seconds(1);
        // Backdate the cursor instead of sleeping, keeping the test fast.
        let cursor = signed_cursor(&encoder, "42", now_secs().saturating_sub(2));
        let result = encoder.decode(&cursor);
        if let Err(Error::Validation(msg)) = result {
            assert_eq!(msg, "Cursor expired");
        } else {
            panic!("Expected Validation error");
        }
    }

    #[rstest]
    fn test_base64_encoder_backdated_cursor_within_expiry_accepted() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY).expiry_seconds(60);
        let cursor = signed_cursor(&encoder, "7", now_secs().saturating_sub(30));
        assert_eq!(encoder.decode(&cursor).unwrap(), 7);
    }

    #[rstest]
    fn test_base64_encoder_with_secret_key() {
        let key = b"my-secret-key-at-least-32-bytes!";
        let encoder = Base64CursorEncoder::with_secret_key(key);
        let cursor = encoder.encode(100).unwrap();
        let decoded = encoder.decode(&cursor).unwrap();
        assert_eq!(decoded, 100);
    }

    #[rstest]
    fn test_base64_encoder_multiple_positions() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY);
        for position in [0, 1, 100, 999, 10000, usize::MAX / 2] {
            let cursor = encoder.encode(position).unwrap();
            let decoded = encoder.decode(&cursor).unwrap();
            assert_eq!(decoded, position);
        }
    }

    #[rstest]
    fn test_base64_encoder_future_timestamp_no_underflow() {
        let encoder = Base64CursorEncoder::with_secret_key(TEST_KEY);
        let position: usize = 42;
        let cursor = signed_cursor(&encoder, &position.to_string(), now_secs() + 3600);
        let result = encoder.decode(&cursor);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), position);
    }
}