use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use subtle::ConstantTimeEq;
use thiserror::Error;
use super::types::{LinkId, MsgId};
/// HMAC instantiated over SHA-256; the MAC behind every tracking tag in this module.
type HmacSha256 = Hmac<Sha256>;
/// Format-version byte prepended to every MAC input.
///
/// Tokens carry only the tag, not this byte, so changing the value makes every
/// previously issued token fail verification.
const VERSION: u8 = 1;
/// Number of leading HMAC-SHA256 output bytes kept as the tag (128 bits).
const TAG_LEN: usize = 128 / 8;
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TokenError {
#[error("token: base64 decode failed")]
BadEncoding,
#[error("token: wrong length")]
BadLength,
#[error("token: tag mismatch")]
TagMismatch,
}
/// A tracking token as carried in a tracking URL.
///
/// Tokens produced by [`TrackingTokenSigner`] are unpadded URL-safe base64.
/// Values wrapped with [`TrackingToken::from_string`] are stored as-is and are
/// not validated; use [`TrackingTokenSigner::verify`] to check them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TrackingToken(String);

impl TrackingToken {
    /// Wraps `s` as a token without inspecting its contents.
    pub fn from_string(s: impl Into<String>) -> Self {
        let raw: String = s.into();
        TrackingToken(raw)
    }

    /// Borrows the raw token text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl std::fmt::Display for TrackingToken {
    /// Writes the raw token text unchanged (width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Issues and verifies HMAC-SHA256 tracking tokens for open and click events.
///
/// The key is private and never printed: the `Debug` output reports only its
/// length.
#[derive(Clone)]
pub struct TrackingTokenSigner {
    /// Raw HMAC key bytes.
    secret: Vec<u8>,
}

impl std::fmt::Debug for TrackingTokenSigner {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Expose the key length for diagnostics, never the key material itself.
        let secret_len = self.secret.len();
        let mut out = f.debug_struct("TrackingTokenSigner");
        out.field("secret_len", &secret_len);
        out.finish()
    }
}
impl TrackingTokenSigner {
    /// Creates a signer keyed with `secret`.
    ///
    /// # Errors
    ///
    /// Returns [`TokenError::BadLength`] when `secret` is shorter than 16 bytes.
    pub fn new(secret: impl Into<Vec<u8>>) -> Result<Self, TokenError> {
        // Shortest key accepted; anything shorter is rejected outright.
        const MIN_SECRET_LEN: usize = 16;
        let key: Vec<u8> = secret.into();
        if key.len() >= MIN_SECRET_LEN {
            Ok(Self { secret: key })
        } else {
            Err(TokenError::BadLength)
        }
    }

    /// Signs an "open" event for `msg_id` within `tenant_id` (no link component).
    pub fn sign_open(&self, tenant_id: &str, msg_id: &MsgId) -> TrackingToken {
        self.sign(tenant_id, msg_id, None)
    }

    /// Signs a "click" event on `link_id` in `msg_id` within `tenant_id`.
    pub fn sign_click(&self, tenant_id: &str, msg_id: &MsgId, link_id: &LinkId) -> TrackingToken {
        self.sign(tenant_id, msg_id, Some(link_id))
    }

    /// Checks that `token` is the tag for (`tenant_id`, `msg_id`, `link_id`).
    ///
    /// Pass `None` for `link_id` to verify an open token and `Some` for a click
    /// token.
    ///
    /// # Errors
    ///
    /// - [`TokenError::BadEncoding`] if the token is not unpadded URL-safe base64.
    /// - [`TokenError::BadLength`] if it does not decode to exactly `TAG_LEN` bytes.
    /// - [`TokenError::TagMismatch`] if the decoded bytes differ from the expected tag.
    pub fn verify(
        &self,
        tenant_id: &str,
        msg_id: &MsgId,
        link_id: Option<&LinkId>,
        token: &TrackingToken,
    ) -> Result<(), TokenError> {
        let presented = URL_SAFE_NO_PAD
            .decode(token.as_str())
            .map_err(|_| TokenError::BadEncoding)?;
        if presented.len() != TAG_LEN {
            return Err(TokenError::BadLength);
        }
        let wanted = self.compute_tag(tenant_id, msg_id, link_id);
        // `subtle` comparison: avoids a short-circuiting byte-by-byte `==`.
        if !bool::from(wanted.ct_eq(&presented)) {
            return Err(TokenError::TagMismatch);
        }
        Ok(())
    }

    /// Computes the tag and encodes it as unpadded URL-safe base64.
    fn sign(&self, tenant_id: &str, msg_id: &MsgId, link_id: Option<&LinkId>) -> TrackingToken {
        let encoded = URL_SAFE_NO_PAD.encode(self.compute_tag(tenant_id, msg_id, link_id));
        TrackingToken(encoded)
    }

    /// Computes the truncated HMAC-SHA256 tag for a (tenant, message, link?) triple.
    ///
    /// The MAC input is
    /// `VERSION || tenant_id || 0x00 || msg_id || 0x00 [|| link_id]`.
    /// The tag is the first `TAG_LEN` bytes of the HMAC output.
    ///
    /// NOTE(review): this encoding is not injective.
    /// - A click tag for an empty link id (if `LinkId` can be empty) is identical
    ///   to the open tag for the same message, because nothing marks open vs click.
    /// - If ids may contain NUL, the field boundary can shift. For example,
    ///   tenant `"a\0b"` with msg `"c"` and tenant `"a"` with msg `"b\0c"` feed
    ///   identical bytes.
    ///
    /// A fix (e.g. a kind byte plus length-prefixed fields) changes every tag.
    /// Already-issued tokens would stop verifying, so it needs a migration plan.
    fn compute_tag(
        &self,
        tenant_id: &str,
        msg_id: &MsgId,
        link_id: Option<&LinkId>,
    ) -> [u8; TAG_LEN] {
        let mut mac = HmacSha256::new_from_slice(&self.secret)
            .expect("HmacSha256 accepts any key length");
        // HMAC input is a plain concatenation, so feeding these pieces in order
        // is the same as one update over the joined bytes.
        let prefix: [&[u8]; 5] = [
            &[VERSION],
            tenant_id.as_bytes(),
            b"\x00",
            msg_id.as_str().as_bytes(),
            b"\x00",
        ];
        for piece in prefix {
            mac.update(piece);
        }
        if let Some(link) = link_id {
            mac.update(link.as_str().as_bytes());
        }
        let digest = mac.finalize().into_bytes();
        let mut tag = [0u8; TAG_LEN];
        tag.copy_from_slice(&digest[..TAG_LEN]);
        tag
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed all-zero 32-byte key so tags are reproducible across runs.
    fn signer() -> TrackingTokenSigner {
        TrackingTokenSigner::new(vec![0u8; 32]).unwrap()
    }

    #[test]
    fn sign_open_round_trips() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let tok = s.sign_open("acme", &m);
        assert_eq!(s.verify("acme", &m, None, &tok), Ok(()));
    }

    #[test]
    fn sign_click_round_trips() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let l = LinkId::new("L0");
        let tok = s.sign_click("acme", &m, &l);
        assert_eq!(s.verify("acme", &m, Some(&l), &tok), Ok(()));
    }

    #[test]
    fn cross_tenant_tag_rejected() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let tok = s.sign_open("acme", &m);
        assert_eq!(
            s.verify("globex", &m, None, &tok),
            Err(TokenError::TagMismatch),
        );
    }

    #[test]
    fn cross_msg_tag_rejected() {
        let s = signer();
        let tok = s.sign_open("acme", &MsgId::new("msg-1"));
        assert_eq!(
            s.verify("acme", &MsgId::new("msg-2"), None, &tok),
            Err(TokenError::TagMismatch),
        );
    }

    #[test]
    fn open_tag_rejected_for_click_payload() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let l = LinkId::new("L0");
        let tok = s.sign_open("acme", &m);
        assert_eq!(
            s.verify("acme", &m, Some(&l), &tok),
            Err(TokenError::TagMismatch),
        );
    }

    #[test]
    fn click_tag_rejected_for_open_payload() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let tok = s.sign_click("acme", &m, &LinkId::new("L0"));
        assert_eq!(
            s.verify("acme", &m, None, &tok),
            Err(TokenError::TagMismatch),
        );
    }

    #[test]
    fn click_tag_rejected_for_other_link() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let tok = s.sign_click("acme", &m, &LinkId::new("L0"));
        assert_eq!(
            s.verify("acme", &m, Some(&LinkId::new("L1")), &tok),
            Err(TokenError::TagMismatch),
        );
    }

    #[test]
    fn different_secrets_dont_collide() {
        let a = TrackingTokenSigner::new(vec![1u8; 32]).unwrap();
        let b = TrackingTokenSigner::new(vec![2u8; 32]).unwrap();
        let m = MsgId::new("msg-1");
        let tok = a.sign_open("acme", &m);
        assert_eq!(
            b.verify("acme", &m, None, &tok),
            Err(TokenError::TagMismatch),
        );
    }

    #[test]
    fn malformed_token_returns_bad_encoding() {
        let s = signer();
        let bad = TrackingToken::from_string("!!! not base64 !!!");
        assert_eq!(
            s.verify("acme", &MsgId::new("m"), None, &bad),
            Err(TokenError::BadEncoding),
        );
    }

    #[test]
    fn truncated_token_returns_bad_length() {
        let s = signer();
        let bad = TrackingToken::from_string("aQ");
        assert_eq!(
            s.verify("acme", &MsgId::new("m"), None, &bad),
            Err(TokenError::BadLength),
        );
    }

    #[test]
    fn oversized_token_returns_bad_length() {
        let s = signer();
        // 32 base64 characters decode cleanly to 24 bytes: valid encoding, wrong length.
        let bad = TrackingToken::from_string("A".repeat(32));
        assert_eq!(
            s.verify("acme", &MsgId::new("m"), None, &bad),
            Err(TokenError::BadLength),
        );
    }

    #[test]
    fn well_formed_wrong_tag_returns_tag_mismatch() {
        let s = signer();
        // 22 'A's decode to exactly 16 zero bytes, so only the tag comparison can fail.
        let bad = TrackingToken::from_string("A".repeat(22));
        assert_eq!(
            s.verify("acme", &MsgId::new("m"), None, &bad),
            Err(TokenError::TagMismatch),
        );
    }

    #[test]
    fn short_secret_rejected() {
        let r = TrackingTokenSigner::new(vec![0u8; 8]);
        assert_eq!(r.unwrap_err(), TokenError::BadLength);
    }

    #[test]
    fn secret_length_boundary() {
        assert!(TrackingTokenSigner::new(vec![0u8; 16]).is_ok());
        assert_eq!(
            TrackingTokenSigner::new(vec![0u8; 15]).unwrap_err(),
            TokenError::BadLength,
        );
    }

    #[test]
    fn signer_debug_does_not_leak_secret() {
        let s = TrackingTokenSigner::new(b"super-secret-value-32bytes-okay!".to_vec()).unwrap();
        let dbg = format!("{s:?}");
        assert!(!dbg.contains("super-secret"));
        assert!(dbg.contains("secret_len"));
    }

    #[test]
    fn token_format_is_url_safe() {
        let s = signer();
        let tok = s.sign_open("acme", &MsgId::new("msg-1"));
        let v = tok.as_str();
        assert_eq!(v.len(), 22);
        assert!(!v.contains('+'));
        assert!(!v.contains('/'));
        assert!(!v.contains('='));
    }

    #[test]
    fn signing_is_deterministic_and_display_matches_as_str() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let a = s.sign_open("acme", &m);
        let b = s.sign_open("acme", &m);
        assert_eq!(a, b);
        assert_eq!(a.to_string(), a.as_str());
    }

    #[test]
    fn token_survives_string_round_trip() {
        let s = signer();
        let m = MsgId::new("msg-1");
        let l = LinkId::new("L0");
        let tok = s.sign_click("acme", &m, &l);
        // Simulates the token leaving as URL text and coming back in a request.
        let parsed = TrackingToken::from_string(tok.to_string());
        assert_eq!(s.verify("acme", &m, Some(&l), &parsed), Ok(()));
    }
}