use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::de::DeserializeOwned;
use crate::aead::open;
use crate::encoding::decode;
use crate::error::{Error, Reason};
use crate::key::MandateKey;
use crate::reserved::{Claims, Clauses, Mandate, Manifest};
use crate::serial;
use crate::token::parse;
use crate::types::{Alg, NumericDate, MANIFEST_KEY};
/// Minimum length, in bytes, of a decoded sealed half; shorter input is
/// rejected as malformed before any decryption is attempted.
const MIN_HALF_BYTES: usize = 17;
/// Largest clock-skew tolerance, in seconds, that `Verifier::leeway` accepts;
/// larger requests are clamped down to this value.
const MAX_LEEWAY: NumericDate = 60;
/// Default cap, in bytes, on the decoded length of a sealed half (64 KiB).
const DEFAULT_MAX_DECODED_LEN: usize = 64 << 10;
/// Configuration for verifying mandate tokens, built with chained setters.
///
/// Keys are borrowed for `'a`; the verifier never owns key material.
pub struct Verifier<'a> {
    /// Candidate keys, tried in insertion order when opening the sealed
    /// mandate half.
    keys: Vec<&'a MandateKey>,
    /// This verifier's identifier, matched against a token's `aud` clause.
    audience: Option<String>,
    /// Fixed "current time" for `exp` checks; `None` means read the system
    /// clock at verification time.
    now: Option<NumericDate>,
    /// Tolerance added to `exp`, in whole seconds (the setter clamps it to
    /// `MAX_LEEWAY`).
    leeway: NumericDate,
    /// Maximum accepted decoded length, in bytes, of a sealed half.
    max_decoded_len: usize,
}
impl Default for Verifier<'_> {
    /// An unconfigured verifier: no keys, no audience identifier, zero
    /// leeway, the live system clock, and the `DEFAULT_MAX_DECODED_LEN` cap.
    fn default() -> Self {
        Self {
            max_decoded_len: DEFAULT_MAX_DECODED_LEN,
            now: None,
            leeway: 0,
            audience: None,
            keys: Vec::new(),
        }
    }
}
impl<'a> Verifier<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn key(mut self, key: &'a MandateKey) -> Self {
self.keys.push(key);
self
}
pub fn keys<I: IntoIterator<Item = &'a MandateKey>>(mut self, keys: I) -> Self {
self.keys.extend(keys);
self
}
pub fn audience(mut self, id: impl Into<String>) -> Self {
self.audience = Some(id.into());
self
}
pub fn leeway(mut self, leeway: Duration) -> Self {
self.leeway = leeway.as_secs().min(MAX_LEEWAY as u64) as NumericDate;
self
}
pub fn now(mut self, now: NumericDate) -> Self {
self.now = Some(now);
self
}
pub fn max_decoded_len(mut self, max: usize) -> Self {
self.max_decoded_len = max;
self
}
pub fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<Mandate<T>, Error> {
self.verify_inner(token).map_err(Error::new)
}
fn verify_inner<T: DeserializeOwned>(&self, token: &str) -> Result<Mandate<T>, Reason> {
let parsed = parse(token).map_err(|_| Reason::Malformed)?;
let half = parsed.mandate.ok_or(Reason::EmptyMandate)?;
let alg = Alg::from_code(half.alg_code).ok_or(Reason::Unsupported)?;
if half.text.len() > self.max_decoded_len.saturating_mul(2).saturating_add(8) {
return Err(Reason::Malformed);
}
let sealed = decode(half.text, parsed.encoding).ok_or(Reason::Malformed)?;
if sealed.len() < MIN_HALF_BYTES || sealed.len() > self.max_decoded_len {
return Err(Reason::Malformed);
}
let plain = self
.keys
.iter()
.find_map(|k| open(&sealed, k.bytes(), alg))
.map(zeroize::Zeroizing::new)
.ok_or(Reason::AuthFailed)?;
let clauses: Clauses<T> = serial::from_mandate_plaintext(&plain)?;
let tid = clauses.tid.ok_or(Reason::BadTid)?;
if tid.get_version_num() != 7 || tid.get_variant() != uuid::Variant::RFC4122 {
return Err(Reason::BadTid);
}
let exp = clauses.exp.ok_or(Reason::MissingClause)?;
let now = self.now.unwrap_or_else(now_unix);
if now >= exp.saturating_add(self.leeway) {
return Err(Reason::Expired);
}
if let Some(aud) = &clauses.aud {
if aud.is_empty() {
return Err(Reason::AudienceMismatch);
}
let me = self.audience.as_deref().ok_or(Reason::AudienceMismatch)?;
if !aud_contains(aud, me) {
return Err(Reason::AudienceMismatch);
}
}
Ok(Mandate { inner: clauses })
}
}
/// Reports whether `me` appears anywhere in `aud`.
///
/// Each entry is compared with a constant-time equality check. The results
/// are OR-ed without short-circuiting, so the scan never stops early at the
/// first match.
fn aud_contains(aud: &[String], me: &str) -> bool {
    use subtle::ConstantTimeEq;

    let needle = me.as_bytes();
    aud.iter()
        .map(|entry| bool::from(entry.as_bytes().ct_eq(needle)))
        .fold(false, |found, matched| found | matched)
}
fn now_unix() -> NumericDate {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as NumericDate)
.unwrap_or(0)
}
/// Opens the manifest half of `token` and decodes its claims.
///
/// Unlike [`Verifier::verify`], this needs no caller-supplied key: the half
/// is opened with the built-in `MANIFEST_KEY`. It performs no `exp`, `tid` or
/// audience checks, and it always uses the `DEFAULT_MAX_DECODED_LEN` cap.
///
/// Returns `None` in any of these cases:
/// - the token does not parse or has no manifest half;
/// - the algorithm code is unknown;
/// - the encoded or decoded length bounds are exceeded;
/// - the half fails to open;
/// - the claims fail to decode.
pub fn open_manifest<T: DeserializeOwned>(token: &str) -> Option<Manifest<T>> {
    let cap = DEFAULT_MAX_DECODED_LEN;

    let parsed = parse(token).ok()?;
    let half = parsed.manifest?;
    let alg = Alg::from_code(half.alg_code)?;

    // Reject oversized encoded text before spending effort on decoding it.
    let encoded_cap = cap.saturating_mul(2).saturating_add(8);
    if half.text.len() > encoded_cap {
        return None;
    }

    let sealed = decode(half.text, parsed.encoding)?;
    if !(MIN_HALF_BYTES..=cap).contains(&sealed.len()) {
        return None;
    }

    let plain = open(&sealed, &MANIFEST_KEY, alg)?;
    serial::from_manifest_plaintext(&plain).map(|claims: Claims<T>| Manifest { inner: claims })
}