#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::{string::String, vec::Vec};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use core::fmt;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Failure modes when turning a cursor string back into payload bytes.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum CursorError {
    /// The cursor string is not valid unpadded URL-safe base64.
    InvalidBase64,
    /// The trailing HMAC tag does not verify against the payload for the supplied key.
    InvalidSignature,
    /// The decoded bytes are fewer than the 32-byte HMAC-SHA256 tag length.
    TooShort,
}
impl fmt::Display for CursorError {
    /// Renders a short human-readable message; every message is prefixed with `cursor:`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidBase64 => "cursor: invalid base64 encoding",
            Self::InvalidSignature => "cursor: invalid HMAC signature",
            Self::TooShort => "cursor: payload too short to contain signature",
        };
        f.write_str(message)
    }
}

/// `CursorError` wraps no underlying error, so the trait's default methods are sufficient.
#[cfg(feature = "std")]
impl std::error::Error for CursorError {}
/// Codec for opaque pagination cursors.
///
/// Cursors are encoded as unpadded URL-safe base64 (`-` and `_` instead of `+` and `/`).
/// With the `hmac` feature, [`Cursor::encode_signed`] appends an HMAC-SHA256 tag, and
/// [`Cursor::decode_signed`] verifies that tag before returning the payload.
pub struct Cursor;

impl Cursor {
    /// Length in bytes of the HMAC-SHA256 tag appended by [`Cursor::encode_signed`].
    #[cfg(feature = "hmac")]
    const SIGNATURE_LEN: usize = 32;

    /// Encodes `payload` as an unpadded URL-safe base64 string.
    #[must_use]
    pub fn encode(payload: &[u8]) -> String {
        URL_SAFE_NO_PAD.encode(payload)
    }

    /// Decodes a cursor produced by [`Cursor::encode`] back into its raw payload bytes.
    ///
    /// # Errors
    ///
    /// Returns [`CursorError::InvalidBase64`] if `cursor` is not valid unpadded
    /// URL-safe base64.
    pub fn decode(cursor: &str) -> Result<Vec<u8>, CursorError> {
        URL_SAFE_NO_PAD
            .decode(cursor)
            .map_err(|_| CursorError::InvalidBase64)
    }

    /// Encodes `payload` followed by its HMAC-SHA256 tag computed with `key`.
    ///
    /// The encoded bytes are laid out as `payload || tag`. The tag has a fixed length,
    /// so [`Cursor::decode_signed`] can split the two without a separator.
    ///
    /// # Panics
    ///
    /// Does not panic in practice. HMAC accepts keys of any length, so key setup
    /// cannot fail.
    #[cfg(feature = "hmac")]
    #[must_use]
    pub fn encode_signed(payload: &[u8], key: &[u8]) -> String {
        use hmac::{Hmac, KeyInit, Mac};
        use sha2::Sha256;

        let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC accepts any key length");
        mac.update(payload);
        let sig = mac.finalize().into_bytes();

        let mut combined = Vec::with_capacity(payload.len() + Self::SIGNATURE_LEN);
        combined.extend_from_slice(payload);
        combined.extend_from_slice(&sig);
        Self::encode(&combined)
    }

    /// Decodes a cursor produced by [`Cursor::encode_signed`] and verifies its tag.
    ///
    /// # Errors
    ///
    /// * [`CursorError::InvalidBase64`]: `cursor` is not valid unpadded URL-safe base64.
    /// * [`CursorError::TooShort`]: the decoded bytes are shorter than the tag.
    /// * [`CursorError::InvalidSignature`]: the tag does not match the payload under `key`.
    ///
    /// # Panics
    ///
    /// Does not panic in practice. HMAC accepts keys of any length, so key setup
    /// cannot fail.
    #[cfg(feature = "hmac")]
    pub fn decode_signed(cursor: &str, key: &[u8]) -> Result<Vec<u8>, CursorError> {
        use hmac::{Hmac, KeyInit, Mac};
        use sha2::Sha256;

        let mut combined = Self::decode(cursor)?;
        let payload_len = combined
            .len()
            .checked_sub(Self::SIGNATURE_LEN)
            .ok_or(CursorError::TooShort)?;
        let (payload, stored_sig) = combined.split_at(payload_len);

        let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC accepts any key length");
        mac.update(payload);
        // `verify_slice` compares tags in constant time, so response timing does not
        // leak how many leading tag bytes an attacker guessed correctly.
        mac.verify_slice(stored_sig)
            .map_err(|_| CursorError::InvalidSignature)?;

        // Strip the tag in place rather than copying the payload into a new Vec.
        combined.truncate(payload_len);
        Ok(combined)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_round_trip() {
        let payload = b"user:42:2024-01-01";
        let encoded = Cursor::encode(payload);
        let decoded = Cursor::decode(&encoded).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn encode_uses_url_safe_alphabet() {
        // The first output character of a one-byte input is the byte's top 6 bits, so
        // iterating over every byte value exercises all 64 alphabet symbols.
        for i in 0u8..=255 {
            let encoded = Cursor::encode(&[i]);
            assert!(
                encoded
                    .chars()
                    .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_'),
                "non-url-safe character in {encoded}"
            );
        }
    }

    #[test]
    fn decode_invalid_base64_error() {
        let result = Cursor::decode("!!!not-base64!!!");
        assert_eq!(result.unwrap_err(), CursorError::InvalidBase64);
    }

    #[test]
    fn encode_empty_payload() {
        let encoded = Cursor::encode(b"");
        assert_eq!(encoded, "");
        let decoded = Cursor::decode(&encoded).unwrap();
        assert_eq!(decoded, b"");
    }

    #[cfg(feature = "hmac")]
    #[test]
    fn signed_round_trip() {
        let key = b"test-key-very-secret";
        let payload = b"id=99&sort=asc";
        let signed = Cursor::encode_signed(payload, key);
        let out = Cursor::decode_signed(&signed, key).unwrap();
        assert_eq!(out, payload);
    }

    /// An empty payload yields a blob that is exactly one tag long, the shortest valid input.
    #[cfg(feature = "hmac")]
    #[test]
    fn signed_empty_payload_round_trip() {
        let signed = Cursor::encode_signed(b"", b"key");
        let out = Cursor::decode_signed(&signed, b"key").unwrap();
        assert!(out.is_empty());
    }

    #[cfg(feature = "hmac")]
    #[test]
    fn signed_wrong_key_fails() {
        let signed = Cursor::encode_signed(b"id=1", b"right-key");
        let result = Cursor::decode_signed(&signed, b"wrong-key");
        assert_eq!(result.unwrap_err(), CursorError::InvalidSignature);
    }

    #[cfg(feature = "hmac")]
    #[test]
    fn signed_tampered_payload_fails() {
        let signed = Cursor::encode_signed(b"id=1", b"key");
        let mut raw = Cursor::decode(&signed).unwrap();
        raw[0] ^= 0xFF;
        let tampered = Cursor::encode(&raw);
        let result = Cursor::decode_signed(&tampered, b"key");
        assert_eq!(result.unwrap_err(), CursorError::InvalidSignature);
    }

    #[cfg(feature = "hmac")]
    #[test]
    fn signed_tampered_signature_fails() {
        let signed = Cursor::encode_signed(b"id=1", b"key");
        let mut raw = Cursor::decode(&signed).unwrap();
        *raw.last_mut().expect("signed cursor is never empty") ^= 0x01;
        let tampered = Cursor::encode(&raw);
        let result = Cursor::decode_signed(&tampered, b"key");
        assert_eq!(result.unwrap_err(), CursorError::InvalidSignature);
    }

    #[cfg(feature = "hmac")]
    #[test]
    fn signed_invalid_base64_error() {
        let result = Cursor::decode_signed("!!!not-base64!!!", b"key");
        assert_eq!(result.unwrap_err(), CursorError::InvalidBase64);
    }

    #[cfg(feature = "hmac")]
    #[test]
    fn signed_too_short_error() {
        let short = Cursor::encode(b"tiny");
        let result = Cursor::decode_signed(&short, b"key");
        assert_eq!(result.unwrap_err(), CursorError::TooShort);
    }

    /// One byte below the 32-byte tag length must be rejected as too short.
    #[cfg(feature = "hmac")]
    #[test]
    fn signed_one_byte_below_tag_len_is_too_short() {
        let short = Cursor::encode(&[0u8; 31]);
        let result = Cursor::decode_signed(&short, b"key");
        assert_eq!(result.unwrap_err(), CursorError::TooShort);
    }

    #[test]
    fn cursor_error_display_all_variants() {
        assert!(!CursorError::InvalidBase64.to_string().is_empty());
        assert!(!CursorError::InvalidSignature.to_string().is_empty());
        assert!(!CursorError::TooShort.to_string().is_empty());
    }
}