use std::time::{Duration, SystemTime, UNIX_EPOCH};
use serde::de::DeserializeOwned;
use crate::aead::open;
use crate::encoding::decode;
use crate::error::{Error, Reason};
use crate::key::MandateKey;
use crate::reserved::{Claims, Clauses, MandateFields, ManifestFields};
use crate::serial;
use crate::token::parse;
use crate::types::{Alg, NumericDate, MANIFEST_KEY};
/// Smallest decoded (still sealed) half, in bytes, that is handed to `open`;
/// anything shorter is rejected as malformed first.
// NOTE(review): 17 is presumably a 16-byte tag plus at least one payload byte — confirm
// against the `aead` module.
const MIN_HALF_BYTES: usize = 17;
/// Upper bound applied to the value passed to `Verifier::leeway`.
// NOTE(review): the setter feeds this whole seconds, so the unit is presumably seconds — confirm.
const MAX_LEEWAY: NumericDate = 60;
/// Default cap on the decoded size of a token half (64 KiB). A fresh `Verifier` starts with
/// this value, and `manifest_plaintext` always uses it.
const DEFAULT_MAX_DECODED_LEN: usize = 64 * 1024;
/// Builder-style verifier for the mandate half of a token.
///
/// Configure it with the chained setters (`key`, `keys`, `audience`, `leeway`, `now`,
/// `max_decoded_len`) and then call `clauses`, `clauses_unchecked` or `mandate_plaintext`.
pub struct Verifier<'a> {
    /// Candidate keys, tried in insertion order until one opens the mandate.
    keys: Vec<&'a MandateKey>,
    /// Identity of this verifier; compared against the token's `aud` entries when present.
    audience: Option<String>,
    /// Slack added to `exp` before the expiry comparison (capped by `MAX_LEEWAY` in the setter).
    leeway: NumericDate,
    /// Fixed "current time" override; when `None` the system clock is consulted.
    now: Option<NumericDate>,
    /// Largest decoded (sealed) mandate size, in bytes, that will be accepted.
    max_decoded_len: usize,
}
impl Default for Verifier<'_> {
    /// Builds a verifier with no keys, no audience, zero leeway, the system clock as the time
    /// source and `DEFAULT_MAX_DECODED_LEN` as the size limit.
    fn default() -> Self {
        Self {
            max_decoded_len: DEFAULT_MAX_DECODED_LEN,
            leeway: 0,
            now: None,
            audience: None,
            keys: Vec::new(),
        }
    }
}
impl<'a> Verifier<'a> {
pub fn new() -> Self {
Self::default()
}
/// Appends a single candidate key; keys are tried in the order they were added.
pub fn key(self, key: &'a MandateKey) -> Self {
    // A one-element iterator goes through the same append path as `keys`.
    self.keys(std::iter::once(key))
}
/// Appends every key yielded by `keys`, preserving iteration order, after any keys that were
/// already registered.
pub fn keys<I>(mut self, keys: I) -> Self
where
    I: IntoIterator<Item = &'a MandateKey>,
{
    let registered = &mut self.keys;
    registered.extend(keys);
    self
}
pub fn audience(mut self, id: impl Into<String>) -> Self {
self.audience = Some(id.into());
self
}
pub fn leeway(mut self, leeway: Duration) -> Self {
self.leeway = leeway.as_secs().min(MAX_LEEWAY as u64) as NumericDate;
self
}
pub fn now(mut self, now: NumericDate) -> Self {
self.now = Some(now);
self
}
pub fn max_decoded_len(mut self, max: usize) -> Self {
self.max_decoded_len = max;
self
}
pub fn clauses<T: DeserializeOwned>(&self, token: &str) -> Result<Clauses<T>, Error> {
self.clauses_inner(token).map_err(Error::new)
}
pub fn clauses_unchecked<T: DeserializeOwned>(&self, token: &str) -> Result<Clauses<T>, Error> {
self.clauses_unchecked_inner(token).map_err(Error::new)
}
pub fn mandate_plaintext(&self, token: &str) -> Result<Vec<u8>, Error> {
self.authenticate(token).map_err(Error::new)
}
/// Parses `token`, decodes its mandate half and tries to open it with each configured key.
///
/// Returns the plaintext produced by the first key for which `open` succeeds.
///
/// # Errors
/// - `Reason::Malformed`: the token does not parse, the encoded text is too long, decoding
///   fails, or the decoded length is outside `MIN_HALF_BYTES..=max_decoded_len`.
/// - `Reason::EmptyMandate`: the token has no mandate half.
/// - `Reason::Unsupported`: the algorithm code is not recognised.
/// - `Reason::AuthFailed`: no configured key opens the mandate (including when no key is set).
fn authenticate(&self, token: &str) -> Result<Vec<u8>, Reason> {
    let parsed = parse(token).map_err(|_| Reason::Malformed)?;
    let half = parsed.mandate.ok_or(Reason::EmptyMandate)?;
    let alg = Alg::from_code(half.alg_code).ok_or(Reason::Unsupported)?;

    // Bound the encoded text first so oversized input is refused before it is decoded.
    let encoded_cap = self.max_decoded_len.saturating_mul(2).saturating_add(8);
    if half.text.len() > encoded_cap {
        return Err(Reason::Malformed);
    }

    let sealed = match decode(half.text, parsed.encoding) {
        Some(bytes) if (MIN_HALF_BYTES..=self.max_decoded_len).contains(&bytes.len()) => bytes,
        _ => return Err(Reason::Malformed),
    };

    // Keys are tried in registration order; the first successful open wins.
    for key in &self.keys {
        if let Some(plaintext) = open(&sealed, key.bytes(), alg) {
            return Ok(plaintext);
        }
    }
    Err(Reason::AuthFailed)
}
/// Authenticates `token` and deserializes the mandate plaintext into `MandateFields<T>`.
///
/// The plaintext is held in a `zeroize::Zeroizing` wrapper so the buffer is wiped once
/// decoding is done and the wrapper is dropped.
fn authenticate_and_decode<T: DeserializeOwned>(
    &self,
    token: &str,
) -> Result<MandateFields<T>, Reason> {
    let plaintext = self.authenticate(token).map(zeroize::Zeroizing::new)?;
    serial::from_mandate_plaintext(&plaintext)
}
fn clauses_inner<T: DeserializeOwned>(&self, token: &str) -> Result<Clauses<T>, Reason> {
let fields = self.authenticate_and_decode::<T>(token)?;
let tid = fields.tid.ok_or(Reason::BadTid)?;
if tid.get_version_num() != 7 || tid.get_variant() != uuid::Variant::RFC4122 {
return Err(Reason::BadTid);
}
let exp = fields.exp.ok_or(Reason::MissingClause)?;
let now = self.now.unwrap_or_else(now_unix);
if now >= exp.saturating_add(self.leeway) {
return Err(Reason::Expired);
}
if let Some(aud) = &fields.aud {
if aud.is_empty() {
return Err(Reason::AudienceMismatch);
}
let me = self.audience.as_deref().ok_or(Reason::AudienceMismatch)?;
if !aud_contains(aud, me) {
return Err(Reason::AudienceMismatch);
}
}
Ok(Clauses { inner: fields })
}
/// Authenticates and decodes `token`, requiring only that `tid` and `exp` are present.
///
/// Neither the `tid` version/variant, the expiry time nor the audience is examined.
///
/// # Errors
/// `Reason::BadTid` when `tid` is absent, `Reason::MissingClause` when `exp` is absent, or any
/// reason produced while authenticating and decoding.
fn clauses_unchecked_inner<T: DeserializeOwned>(
    &self,
    token: &str,
) -> Result<Clauses<T>, Reason> {
    let fields = self.authenticate_and_decode::<T>(token)?;
    if fields.tid.is_none() {
        return Err(Reason::BadTid);
    }
    if fields.exp.is_none() {
        return Err(Reason::MissingClause);
    }
    Ok(Clauses { inner: fields })
}
}
/// Reports whether `me` equals any entry of `aud`.
///
/// Each entry is compared with `subtle`'s `ct_eq`, and every entry is visited even after a
/// match has been found (no short-circuit).
fn aud_contains(aud: &[String], me: &str) -> bool {
    use subtle::ConstantTimeEq;
    let wanted = me.as_bytes();
    aud.iter()
        .map(|candidate| candidate.as_bytes().ct_eq(wanted))
        .fold(false, |matched, equal| matched | bool::from(equal))
}
fn now_unix() -> NumericDate {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_secs() as NumericDate)
.unwrap_or(0)
}
/// Opens the manifest half of `token` and deserializes it into `Claims<T>`.
///
/// No `Verifier` key is involved; the manifest is opened via `manifest_plaintext`.
/// Returns `None` when the manifest cannot be opened or its plaintext cannot be decoded.
pub fn claims<T: DeserializeOwned>(token: &str) -> Option<Claims<T>> {
    let plaintext = manifest_plaintext(token)?;
    serial::from_manifest_plaintext(&plaintext).map(|inner: ManifestFields<T>| Claims { inner })
}
/// Returns the manifest portion of `token` as a standalone string: the manifest text followed
/// by its algorithm code and the token separator.
///
/// Returns `None` when `token` does not parse or has no manifest half.
pub fn manifest(token: &str) -> Option<String> {
    let parsed = parse(token).ok()?;
    let half = parsed.manifest?;
    // Size the buffer by the encoded width of both chars. The algorithm code was previously
    // counted as a single byte, which under-reserves for a multi-byte code.
    let capacity = half.text.len() + half.alg_code.len_utf8() + parsed.separator.len_utf8();
    let mut out = String::with_capacity(capacity);
    out.push_str(half.text);
    out.push(half.alg_code);
    out.push(parsed.separator);
    Some(out)
}
/// Returns the mandate portion of `token` as a standalone string: the token separator followed
/// by the raw mandate part.
///
/// Returns `None` when `token` does not parse or has no mandate half.
pub fn mandate(token: &str) -> Option<String> {
    let parsed = parse(token).ok()?;
    parsed.mandate.as_ref().map(|_| {
        let separator = parsed.separator;
        let tail = parsed.mandate_part;
        let mut rendered = String::with_capacity(separator.len_utf8() + tail.len());
        rendered.push(separator);
        rendered.push_str(tail);
        rendered
    })
}
/// Decodes the manifest half of `token` and opens it with `MANIFEST_KEY`.
///
/// Size limits mirror the mandate path but always use `DEFAULT_MAX_DECODED_LEN`.
/// Returns `None` when the token does not parse, has no manifest, uses an unknown algorithm
/// code, violates a size limit, fails to decode, or cannot be opened.
pub fn manifest_plaintext(token: &str) -> Option<Vec<u8>> {
    let parsed = parse(token).ok()?;
    let half = parsed.manifest?;
    let alg = Alg::from_code(half.alg_code)?;

    // Bound the encoded text first so oversized input is refused before it is decoded.
    let encoded_cap = DEFAULT_MAX_DECODED_LEN.saturating_mul(2).saturating_add(8);
    if half.text.len() > encoded_cap {
        return None;
    }

    let sealed = decode(half.text, parsed.encoding)
        .filter(|bytes| (MIN_HALF_BYTES..=DEFAULT_MAX_DECODED_LEN).contains(&bytes.len()))?;
    open(&sealed, &MANIFEST_KEY, alg)
}
/// Formats the mandate portion of `token` as an authorization header value:
/// `scheme`, a single space, then the string produced by `mandate`.
///
/// Returns `None` when `mandate` yields nothing for `token`.
pub fn authorization_header(token: &str, scheme: &str) -> Option<String> {
    mandate(token).map(|half| [scheme, " ", half.as_str()].concat())
}