use std::{collections::BTreeMap, fmt::Display, time::SystemTime};
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use scion_sdk_token_validator::validator::Token;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Subscription identifier carried in the `pssid` claim of [`SnapTokenClaims`].
///
/// Wraps a subscription [`Uuid`]. Its textual form (via [`Display`], and therefore
/// also its serde representation) is the unpadded URL-safe base64 encoding of a
/// `0x00` version byte followed by the 16 raw UUID bytes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Pssid(Uuid);
impl Pssid {
pub fn new(subscription_id: Uuid) -> Self {
Self(subscription_id)
}
}
/// Formats the PSSID as unpadded URL-safe base64 of `[version byte, 16 UUID bytes]`.
impl Display for Pssid {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only layout version 0 exists: one version byte, then the UUID bytes.
        const VERSION: u8 = 0x00;

        // The payload is always exactly 1 + 16 bytes, so a stack array suffices
        // and no heap buffer is needed for it.
        let mut raw = [0u8; 17];
        raw[0] = VERSION;
        raw[1..].copy_from_slice(self.0.as_bytes());

        f.write_str(&URL_SAFE_NO_PAD.encode(&raw))
    }
}
/// Serializes a [`Pssid`] as a string containing its [`Display`] form.
impl Serialize for Pssid {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `collect_str` lets serializers that support it write the `Display`
        // output directly. Its default implementation is exactly
        // `serialize_str(&self.to_string())`, so the produced value is the same.
        serializer.collect_str(self)
    }
}
impl<'de> Deserialize<'de> for Pssid {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let bytes = URL_SAFE_NO_PAD
.decode(&s)
.map_err(serde::de::Error::custom)?;
if bytes.len() != 17 {
return Err(serde::de::Error::custom(format!(
"invalid PSSID length: expected 17, got {}",
bytes.len()
)));
}
if bytes[0] != 0x00 {
return Err(serde::de::Error::custom(format!(
"invalid PSSID version: expected 0, got {}",
bytes[0]
)));
}
let uuid_bytes: [u8; 16] = bytes[1..].try_into().unwrap();
Ok(Pssid(Uuid::from_bytes(uuid_bytes)))
}
}
/// Claims set of a SNAP token.
///
/// The named fields are the registered claims. Any additional application-specific
/// claims live in `private_claims` and are serialized at the same (top) level.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SnapTokenClaims {
    /// Claims format version; `SnapTokenClaims::new` sets it to `1`.
    pub ver: usize,
    /// Issuer; `SnapTokenClaims::new` sets it to `"ssr"`.
    pub iss: String,
    /// Audience; `SnapTokenClaims::new` sets it to `"snap"`.
    pub aud: String,
    /// Expiration time, in whole seconds since the Unix epoch.
    pub exp: u64,
    /// Not-before time, in whole seconds since the Unix epoch.
    pub nbf: u64,
    /// Issued-at time, in whole seconds since the Unix epoch.
    pub iat: u64,
    /// Token id; `SnapTokenClaims::new` generates a random v4 UUID string.
    pub jti: String,
    /// Subscription identifier, serialized in its versioned base64 string form.
    pub pssid: Pssid,
    /// Application-specific claims. Flattened by serde: each entry is emitted as a
    /// top-level key next to the fields above. When deserializing, keys that match
    /// none of the named fields are collected here. Set via `set_private_claim`,
    /// which rejects the registered claim names.
    #[serde(flatten)]
    private_claims: BTreeMap<String, serde_json::Value>,
}
impl SnapTokenClaims {
    /// Claim names owned by the token format itself.
    ///
    /// [`SnapTokenClaims::set_private_claim`] refuses to store a private claim
    /// under any of these keys.
    const RESERVED_CLAIMS: [&'static str; 8] =
        ["ver", "iss", "aud", "exp", "nbf", "iat", "jti", "pssid"];

    /// Creates version-1 claims for `pssid`.
    ///
    /// The issuer is set to `"ssr"`, the audience to `"snap"`, and `jti` to a
    /// freshly generated random (v4) UUID. `iat`, `nbf` and `exp` are stored as
    /// whole seconds since the Unix epoch; any sub-second part is dropped. The
    /// returned claims contain no private claims.
    ///
    /// # Panics
    ///
    /// Panics if `iat`, `nbf` or `exp` is earlier than the Unix epoch.
    pub fn new(pssid: Pssid, iat: SystemTime, nbf: SystemTime, exp: SystemTime) -> Self {
        // Converted in this order so that, when several timestamps are invalid,
        // the panic reports the issued-at time first, then not-before.
        let iat_secs = Self::unix_secs(iat, "issued at is before epoch");
        let nbf_secs = Self::unix_secs(nbf, "not before is before epoch");
        let exp_secs = Self::unix_secs(exp, "expiration is before epoch");
        Self {
            ver: 1,
            iss: "ssr".to_string(),
            aud: "snap".to_string(),
            exp: exp_secs,
            nbf: nbf_secs,
            iat: iat_secs,
            jti: Uuid::new_v4().to_string(),
            pssid,
            private_claims: BTreeMap::new(),
        }
    }

    /// Returns `time` as whole seconds since the Unix epoch.
    ///
    /// Panics with `context` (followed by the underlying error) if `time` is
    /// earlier than the epoch.
    fn unix_secs(time: SystemTime, context: &str) -> u64 {
        time.duration_since(SystemTime::UNIX_EPOCH)
            .expect(context)
            .as_secs()
    }

    /// Adds an application-specific claim, replacing any previous value stored
    /// under the same `key`.
    ///
    /// Private claims are serialized as top-level keys alongside the registered
    /// claims.
    ///
    /// # Errors
    ///
    /// Returns an error message, leaving the claims unchanged, if `key` is one of
    /// the registered claim names (`ver`, `iss`, `aud`, `exp`, `nbf`, `iat`, `jti`,
    /// `pssid`).
    pub fn set_private_claim(
        &mut self,
        key: String,
        value: serde_json::Value,
    ) -> Result<(), String> {
        if Self::RESERVED_CLAIMS.contains(&key.as_str()) {
            return Err(format!("claim key '{}' is reserved", key));
        }
        self.private_claims.insert(key, value);
        Ok(())
    }

    /// Returns all private (application-specific) claims, ordered by key.
    pub fn private_claims(&self) -> &BTreeMap<String, serde_json::Value> {
        &self.private_claims
    }
}
impl Token for SnapTokenClaims {
    /// Returns the token id (the `jti` claim).
    fn id(&self) -> String {
        self.jti.clone()
    }

    /// Returns the expiration instant encoded in the `exp` claim.
    ///
    /// `exp` is interpreted as whole seconds since the Unix epoch. `exp` can come
    /// from a deserialized token. If the value is too large to be represented as
    /// a `SystemTime` on this platform, the Unix epoch itself is returned instead
    /// of panicking on the overflow. This makes such a token look long expired
    /// (fail closed).
    // NOTE(review): assumes the validator rejects tokens whose `exp_time` is in
    // the past; confirm against scion_sdk_token_validator.
    fn exp_time(&self) -> SystemTime {
        SystemTime::UNIX_EPOCH
            .checked_add(std::time::Duration::from_secs(self.exp))
            .unwrap_or(SystemTime::UNIX_EPOCH)
    }

    /// Names of the claims this token type declares as required.
    fn required_claims() -> Vec<&'static str> {
        vec!["ver", "iss", "aud", "exp", "nbf", "iat", "jti", "pssid"]
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    const SUBSCRIPTION_ID: &str = "123e4567-e89b-12d3-a456-426614174000";
    const ISSUED_AT_SECS: u64 = 1_700_000_000;

    /// PSSID derived from the fixed test subscription id.
    fn sample_pssid() -> Pssid {
        Pssid::new(Uuid::parse_str(SUBSCRIPTION_ID).unwrap())
    }

    /// Claims issued and valid at `ISSUED_AT_SECS`, expiring one hour later.
    fn sample_claims() -> SnapTokenClaims {
        let iat = SystemTime::UNIX_EPOCH + Duration::from_secs(ISSUED_AT_SECS);
        SnapTokenClaims::new(sample_pssid(), iat, iat, iat + Duration::from_secs(3_600))
    }

    #[test]
    fn test_pssid_derivation() {
        let subscription_id = Uuid::parse_str(SUBSCRIPTION_ID).unwrap();
        let pssid = Pssid::new(subscription_id);
        assert_eq!(pssid.to_string(), "ABI-RWfomxLTpFZCZhQXQAA");
        let decoded = URL_SAFE_NO_PAD.decode(pssid.to_string()).unwrap();
        assert_eq!(decoded.len(), 17);
        assert_eq!(decoded[0], 0x00);
        let uuid_bytes: [u8; 16] = decoded[1..].try_into().unwrap();
        let derived_uuid = Uuid::from_bytes(uuid_bytes);
        assert_eq!(derived_uuid, subscription_id);
        let json = serde_json::to_string(&pssid).unwrap();
        assert_eq!(json, "\"ABI-RWfomxLTpFZCZhQXQAA\"");
        let pssid2: Pssid = serde_json::from_str(&json).unwrap();
        assert_eq!(pssid2, pssid);
    }

    #[test]
    fn pssid_rejects_invalid_base64() {
        assert!(serde_json::from_str::<Pssid>("\"not base64!\"").is_err());
    }

    #[test]
    fn pssid_rejects_wrong_length() {
        // "AAAA" decodes to three zero bytes.
        let err = serde_json::from_str::<Pssid>("\"AAAA\"").unwrap_err();
        assert!(err
            .to_string()
            .contains("invalid PSSID length: expected 17, got 3"));
    }

    #[test]
    fn pssid_rejects_unknown_version() {
        let encoded = URL_SAFE_NO_PAD.encode([0x01u8; 17]);
        let json = format!("\"{}\"", encoded);
        let err = serde_json::from_str::<Pssid>(&json).unwrap_err();
        assert!(err
            .to_string()
            .contains("invalid PSSID version: expected 0, got 1"));
    }

    #[test]
    fn claims_new_sets_registered_claims() {
        let claims = sample_claims();
        assert_eq!(claims.ver, 1);
        assert_eq!(claims.iss, "ssr");
        assert_eq!(claims.aud, "snap");
        assert_eq!(claims.iat, ISSUED_AT_SECS);
        assert_eq!(claims.nbf, ISSUED_AT_SECS);
        assert_eq!(claims.exp, ISSUED_AT_SECS + 3_600);
        assert_eq!(claims.pssid, sample_pssid());
        assert!(Uuid::parse_str(&claims.jti).is_ok());
        assert_eq!(claims.id(), claims.jti);
        assert_eq!(
            claims.exp_time(),
            SystemTime::UNIX_EPOCH + Duration::from_secs(ISSUED_AT_SECS + 3_600)
        );
        assert!(claims.private_claims().is_empty());
    }

    #[test]
    fn private_claims_reject_reserved_keys_and_round_trip() {
        let mut claims = sample_claims();
        for key in SnapTokenClaims::required_claims() {
            assert!(claims
                .set_private_claim(key.to_string(), serde_json::json!(0))
                .is_err());
        }
        assert!(claims.private_claims().is_empty());

        claims
            .set_private_claim("tier".to_string(), serde_json::json!("gold"))
            .unwrap();

        // Private claims are flattened next to the registered claims.
        let value = serde_json::to_value(&claims).unwrap();
        assert_eq!(value["tier"], "gold");
        assert_eq!(value["pssid"], "ABI-RWfomxLTpFZCZhQXQAA");

        // On the way back, only unknown keys end up in `private_claims`.
        let decoded: SnapTokenClaims = serde_json::from_value(value).unwrap();
        assert_eq!(decoded.private_claims().len(), 1);
        assert_eq!(decoded.private_claims()["tier"], "gold");
        assert_eq!(decoded.jti, claims.jti);
        assert_eq!(decoded.pssid, claims.pssid);
        assert_eq!(decoded.exp, claims.exp);
    }
}