use std::collections::HashSet;
use ciborium::value::{Integer, Value};
use serde::de::DeserializeOwned;
use serde::Serialize;
use uuid::Uuid;
use crate::error::{MintError, Reason};
use crate::reserved::{MandateFields, ManifestFields};
use crate::types::NumericDate;
// Reserved top-level map keys. All reserved keys are negative integers; the
// application payload may only use non-negative integers or text as keys.

/// Token id: a 16-byte UUID stored as a CBOR byte string.
const KEY_TID: i8 = -1;
/// Expiry: an integer `NumericDate`.
const KEY_EXP: i8 = -2;
/// Audience: an array of text strings.
const KEY_AUD: i8 = -3;
/// Subject: a text string.
const KEY_SUB: i8 = -4;
/// Issuer: a text string.
const KEY_ISS: i8 = -5;

/// Every reserved key this module understands; any other negative integer key
/// is classified as an unknown reserved key.
const RESERVED: [i8; 5] = [KEY_TID, KEY_EXP, KEY_AUD, KEY_SUB, KEY_ISS];
#[inline]
fn ikey(n: i8) -> Value {
Value::Integer(Integer::from(n))
}
/// Serializes `value` to CBOR bytes.
///
/// # Panics
/// Only if ciborium reports an error while writing into an in-memory `Vec`,
/// which this module treats as impossible.
fn encode(value: &Value) -> Vec<u8> {
    let mut bytes: Vec<u8> = Vec::new();
    ciborium::into_writer(value, &mut bytes).expect("CBOR encode to Vec is infallible");
    bytes
}
/// Rewrites `value` in place so every map (at any depth) has its entries ordered
/// by the bytewise order of each key's CBOR encoding.
///
/// Keys and values are canonicalized before their map is sorted. Entries with
/// equal keys keep their relative order, because `sort_by_cached_key` is stable.
fn canonicalize(value: &mut Value) {
    match value {
        Value::Tag(_, boxed) => canonicalize(boxed),
        Value::Array(elems) => {
            for elem in elems.iter_mut() {
                canonicalize(elem);
            }
        }
        Value::Map(pairs) => {
            pairs.iter_mut().for_each(|(key, val)| {
                canonicalize(key);
                canonicalize(val);
            });
            // Each key is encoded once and then compared as raw bytes.
            pairs.sort_by_cached_key(|(key, _)| encode(key));
        }
        _ => {}
    }
}
/// Returns true if `k` is a CBOR integer that encodes as a negative integer.
fn is_negative_int(k: &Value) -> bool {
    // The top three bits of the initial byte hold the CBOR major type; 1 is a negative int.
    const MAJOR_NEGATIVE: u8 = 1;
    if let Value::Integer(_) = k {
        matches!(encode(k).first(), Some(&head) if head >> 5 == MAJOR_NEGATIVE)
    } else {
        false
    }
}
/// Returns true if a NaN float appears anywhere in `value`.
///
/// The search covers array items, map keys and values, and tagged contents.
fn contains_nan(value: &Value) -> bool {
    match value {
        Value::Float(x) => x.is_nan(),
        Value::Tag(_, inner) => contains_nan(inner),
        Value::Array(items) => items.iter().any(contains_nan),
        // Check each key before its value, entry by entry.
        Value::Map(entries) => entries.iter().flat_map(|(k, v)| [k, v]).any(contains_nan),
        _ => false,
    }
}
/// Returns true if any map inside `value` has a key that is neither an integer
/// nor a text string.
///
/// The search recurses through map values, array items and tagged contents. It
/// does not recurse into keys: a key that is itself a container is already invalid.
fn has_invalid_map_key(value: &Value) -> bool {
    let key_ok = |k: &Value| matches!(k, Value::Integer(_) | Value::Text(_));
    match value {
        Value::Tag(_, inner) => has_invalid_map_key(inner),
        Value::Array(items) => items.iter().any(has_invalid_map_key),
        Value::Map(entries) => entries
            .iter()
            .any(|(k, v)| !key_ok(k) || has_invalid_map_key(v)),
        _ => false,
    }
}
/// How a top-level map key is treated when decoding a plaintext (see `classify`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum KeyKind {
    /// One of the keys listed in `RESERVED`; carries that key's value.
    Reserved(i8),
    /// A negative integer key that is not listed in `RESERVED`.
    UnknownReserved,
    /// A non-negative integer or text key, which belongs to the application payload.
    App,
    /// Any other key type (byte string, float, array, map, ...).
    Invalid,
}
/// Buckets a top-level map key for the decoders.
///
/// Keys listed in `RESERVED` are checked first. Any other negative integer is an
/// unknown reserved key. Remaining integers and all text keys belong to the app
/// payload, and every other key type is invalid.
fn classify(k: &Value) -> KeyKind {
    if let Some(&known) = RESERVED.iter().find(|&&n| *k == ikey(n)) {
        return KeyKind::Reserved(known);
    }
    match k {
        _ if is_negative_int(k) => KeyKind::UnknownReserved,
        Value::Integer(_) | Value::Text(_) => KeyKind::App,
        _ => KeyKind::Invalid,
    }
}
fn app_entries<T: Serialize>(app: &T) -> Result<Vec<(Value, Value)>, MintError> {
let value = Value::serialized(app).map_err(|e| MintError::Serialization(e.to_string()))?;
let entries = match value {
Value::Map(entries) => entries,
_ => return Err(MintError::AppNotMap),
};
if entries.iter().any(|(k, _)| is_negative_int(k)) {
return Err(MintError::ReservedKey);
}
if entries.iter().any(|(_, v)| contains_nan(v)) {
return Err(MintError::Nan);
}
Ok(entries)
}
/// Wraps `entries` in a map, puts it into canonical key order, and encodes it.
fn assemble(entries: Vec<(Value, Value)>) -> Vec<u8> {
    let mut map = Value::Map(entries);
    canonicalize(&mut map);
    let bytes = encode(&map);
    bytes
}
/// Builds the canonical CBOR plaintext for a mandate.
///
/// The plaintext is the app payload's map entries plus the reserved keys.
/// `tid` and `exp` are always written; `aud`, `sub` and `iss` are written only
/// when `Some`.
///
/// # Errors
/// Propagates any `MintError` from serializing and validating `app`.
pub(crate) fn to_mandate_plaintext<T: Serialize>(
    exp: NumericDate,
    tid: Uuid,
    iss: Option<&str>,
    aud: Option<&[String]>,
    sub: Option<&str>,
    app: &T,
) -> Result<Vec<u8>, MintError> {
    let mut entries = app_entries(app)?;

    // Mandatory keys.
    entries.push((ikey(KEY_TID), Value::Bytes(tid.as_bytes().to_vec())));
    entries.push((ikey(KEY_EXP), Value::Integer(Integer::from(exp))));

    // Optional keys: an absent value leaves no entry at all.
    entries.extend(aud.map(|list| {
        let texts = list.iter().cloned().map(Value::Text).collect();
        (ikey(KEY_AUD), Value::Array(texts))
    }));
    entries.extend(sub.map(|s| (ikey(KEY_SUB), Value::Text(s.to_owned()))));
    entries.extend(iss.map(|s| (ikey(KEY_ISS), Value::Text(s.to_owned()))));

    Ok(assemble(entries))
}
/// Builds the canonical CBOR plaintext for a manifest: the app payload's map
/// entries plus the issuer key.
///
/// # Errors
/// Propagates any `MintError` from serializing and validating `app`.
pub(crate) fn to_manifest_plaintext<T: Serialize>(
    iss: &str,
    app: &T,
) -> Result<Vec<u8>, MintError> {
    let mut entries = app_entries(app)?;
    let issuer = (ikey(KEY_ISS), Value::Text(iss.into()));
    entries.push(issuer);
    let plaintext = assemble(entries);
    Ok(plaintext)
}
/// Decodes `plain` as one CBOR map and enforces the deterministic-encoding rules
/// this module mints with.
///
/// On success, returns the top-level entries in the order they were encoded.
///
/// # Errors
/// - [`Reason::Malformed`] if `plain` does not decode as CBOR, or if the
///   top-level item is not a map.
/// - [`Reason::NonCanonical`] if anywhere in the tree:
///   - a float is NaN;
///   - a map key is not an integer or text string;
///   - a map repeats a key.
/// - [`Reason::NonCanonical`] also if re-encoding the canonicalized value does
///   not reproduce `plain` byte-for-byte, so trailing bytes are not accepted either.
fn strict_map(plain: &[u8]) -> Result<Vec<(Value, Value)>, Reason> {
    /// Returns true if any map in `value`'s tree contains the same key twice
    /// (keys compared by their CBOR encoding).
    ///
    /// This matters for nested maps: `canonicalize` keeps duplicate keys (its
    /// sort is stable) and they re-encode identically, so the byte comparison
    /// below cannot catch them.
    fn has_duplicate_key(value: &Value) -> bool {
        match value {
            Value::Map(entries) => {
                let mut seen = HashSet::with_capacity(entries.len());
                entries
                    .iter()
                    .any(|(k, v)| !seen.insert(encode(k)) || has_duplicate_key(v))
            }
            Value::Array(items) => items.iter().any(has_duplicate_key),
            Value::Tag(_, inner) => has_duplicate_key(inner),
            _ => false,
        }
    }

    let value: Value = ciborium::from_reader(plain).map_err(|_| Reason::Malformed)?;
    let Value::Map(entries) = value else {
        return Err(Reason::Malformed);
    };
    // Every check walks the whole tree, so nested maps obey the same rules as the
    // top level. NaN is refused because it has more than one CBOR encoding.
    let mut canon = Value::Map(entries.clone());
    if contains_nan(&canon) || has_invalid_map_key(&canon) || has_duplicate_key(&canon) {
        return Err(Reason::NonCanonical);
    }
    canonicalize(&mut canon);
    if encode(&canon) != plain {
        return Err(Reason::NonCanonical);
    }
    Ok(entries)
}
/// Reads a token id, which must be a CBOR byte string of exactly 16 bytes.
///
/// # Errors
/// [`Reason::BadType`] for any other value type or length.
fn read_tid(v: &Value) -> Result<Uuid, Reason> {
    let Value::Bytes(raw) = v else {
        return Err(Reason::BadType);
    };
    <[u8; 16]>::try_from(raw.as_slice())
        .map(Uuid::from_bytes)
        .map_err(|_| Reason::BadType)
}
/// Reads a CBOR integer as a `NumericDate`.
///
/// # Errors
/// [`Reason::BadType`] if `v` is not an integer, or does not fit `NumericDate`.
fn read_int(v: &Value) -> Result<NumericDate, Reason> {
    let Value::Integer(raw) = v else {
        return Err(Reason::BadType);
    };
    NumericDate::try_from(*raw).map_err(|_| Reason::BadType)
}
/// Reads a CBOR text string into an owned `String`.
///
/// # Errors
/// [`Reason::BadType`] if `v` is not a text string.
fn read_text(v: &Value) -> Result<String, Reason> {
    if let Value::Text(text) = v {
        Ok(text.to_owned())
    } else {
        Err(Reason::BadType)
    }
}
fn read_aud(v: &Value) -> Result<Vec<String>, Reason> {
match v {
Value::Array(items) => items
.iter()
.map(|it| match it {
Value::Text(s) => Ok(s.clone()),
_ => Err(Reason::BadType),
})
.collect(),
_ => Err(Reason::BadType),
}
}
/// Deserializes the collected application entries, as one map, into `T`.
///
/// # Errors
/// [`Reason::Malformed`] if the map does not deserialize as `T`.
fn decode_app<T: DeserializeOwned>(entries: Vec<(Value, Value)>) -> Result<T, Reason> {
    let payload = Value::Map(entries);
    payload.deserialized::<T>().map_err(|_| Reason::Malformed)
}
/// Decodes a mandate plaintext into its reserved fields and application payload.
///
/// Reserved fields missing from the map come back as `None`. Non-reserved
/// integer and text keys are gathered and deserialized as `T`.
///
/// # Errors
/// - Any `Reason` from `strict_map` (malformed or non-canonical input).
/// - [`Reason::BadType`] if a reserved field has the wrong CBOR type.
/// - [`Reason::UnknownReservedKey`] for a negative key outside `RESERVED`.
/// - [`Reason::NonCanonical`] for a key that is neither an integer nor text.
/// - [`Reason::Malformed`] if the app entries do not deserialize as `T`.
pub(crate) fn from_mandate_plaintext<T: DeserializeOwned>(
    plain: &[u8],
) -> Result<MandateFields<T>, Reason> {
    let (mut tid, mut exp, mut aud, mut sub, mut iss) = (None, None, None, None, None);
    let mut app_pairs = Vec::new();

    for (key, val) in strict_map(plain)? {
        match classify(&key) {
            KeyKind::App => app_pairs.push((key, val)),
            KeyKind::Reserved(KEY_TID) => tid = Some(read_tid(&val)?),
            KeyKind::Reserved(KEY_EXP) => exp = Some(read_int(&val)?),
            KeyKind::Reserved(KEY_AUD) => aud = Some(read_aud(&val)?),
            KeyKind::Reserved(KEY_SUB) => sub = Some(read_text(&val)?),
            KeyKind::Reserved(KEY_ISS) => iss = Some(read_text(&val)?),
            KeyKind::Reserved(_) => unreachable!("RESERVED covers every match arm"),
            KeyKind::UnknownReserved => return Err(Reason::UnknownReservedKey),
            KeyKind::Invalid => return Err(Reason::NonCanonical),
        }
    }

    let app = decode_app(app_pairs)?;
    Ok(MandateFields {
        exp,
        tid,
        iss,
        aud,
        sub,
        app,
    })
}
pub(crate) fn from_manifest_plaintext<T: DeserializeOwned>(
plain: &[u8],
) -> Option<ManifestFields<T>> {
let entries = strict_map(plain).ok()?;
let mut iss = None;
let mut exp = None;
let mut app = Vec::new();
for (k, v) in entries {
match classify(&k) {
KeyKind::Reserved(KEY_ISS) => iss = Some(read_text(&v).ok()?),
KeyKind::Reserved(KEY_EXP) => exp = Some(read_int(&v).ok()?),
KeyKind::Reserved(_) | KeyKind::UnknownReserved | KeyKind::Invalid => return None,
KeyKind::App => app.push((k, v)),
}
}
Some(ManifestFields {
iss: iss?,
exp,
app: Value::Map(app).deserialized::<T>().ok()?,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;
    use std::collections::BTreeMap;

    /// Small application payload shared by the tests.
    #[derive(Serialize, Deserialize, PartialEq, Debug, Default)]
    struct App {
        role: String,
        n: u32,
    }

    /// A fresh UUIDv7 token id.
    fn tid7() -> Uuid {
        Uuid::now_v7()
    }

    #[test]
    fn mandate_round_trips_reserved_and_app() {
        let tid = tid7();
        let pt = to_mandate_plaintext(
            1000,
            tid,
            Some("auth.example"),
            Some(&["api".to_string(), "admin".to_string()]),
            Some("user-1"),
            &App {
                role: "admin".into(),
                n: 7,
            },
        )
        .unwrap();
        let c: MandateFields<App> = from_mandate_plaintext(&pt).unwrap();
        assert_eq!(c.exp, Some(1000));
        assert_eq!(c.tid, Some(tid));
        assert_eq!(c.iss.as_deref(), Some("auth.example"));
        assert_eq!(
            c.aud.as_deref(),
            Some(&["api".to_string(), "admin".to_string()][..])
        );
        assert_eq!(c.sub.as_deref(), Some("user-1"));
        assert_eq!(
            c.app,
            App {
                role: "admin".into(),
                n: 7
            }
        );
    }

    #[test]
    fn encoding_is_canonical_and_deterministic() {
        let tid = tid7();
        let app = App {
            role: "x".into(),
            n: 1,
        };
        let a = to_mandate_plaintext(1, tid, None, None, None, &app).unwrap();
        let b = to_mandate_plaintext(1, tid, None, None, None, &app).unwrap();
        assert_eq!(a, b);
        let c: MandateFields<App> = from_mandate_plaintext(&a).unwrap();
        assert_eq!(c.app, app);
    }

    #[test]
    fn rejects_non_canonical_unsorted_keys() {
        let tid = tid7();
        let unsorted = Value::Map(vec![
            (ikey(KEY_EXP), Value::Integer(Integer::from(5))),
            (ikey(KEY_TID), Value::Bytes(tid.as_bytes().to_vec())),
        ]);
        let bytes = encode(&unsorted);
        assert_eq!(
            from_mandate_plaintext::<crate::reserved::NoApp>(&bytes).unwrap_err(),
            Reason::NonCanonical
        );
    }

    #[test]
    fn rejects_duplicate_key() {
        let dup = Value::Map(vec![
            (ikey(KEY_EXP), Value::Integer(Integer::from(1))),
            (ikey(KEY_EXP), Value::Integer(Integer::from(2))),
        ]);
        let bytes = encode(&dup);
        assert_eq!(
            from_mandate_plaintext::<crate::reserved::NoApp>(&bytes).unwrap_err(),
            Reason::NonCanonical
        );
    }

    #[test]
    fn unknown_negative_key_fails_closed() {
        let m = Value::Map(vec![(ikey(-9), Value::Integer(Integer::from(1)))]);
        let bytes = encode(&m);
        assert_eq!(
            from_mandate_plaintext::<crate::reserved::NoApp>(&bytes).unwrap_err(),
            Reason::UnknownReservedKey
        );
    }

    #[test]
    fn wrong_type_reserved_field_rejected() {
        let m = Value::Map(vec![(ikey(KEY_TID), Value::Text("nope".into()))]);
        let bytes = encode(&m);
        assert_eq!(
            from_mandate_plaintext::<crate::reserved::NoApp>(&bytes).unwrap_err(),
            Reason::BadType
        );
    }

    #[test]
    fn rejects_nan_float() {
        let m = Value::Map(vec![(
            Value::Integer(Integer::from(0)),
            Value::Float(f64::NAN),
        )]);
        let bytes = encode(&m);
        assert_eq!(
            from_mandate_plaintext::<crate::reserved::NoApp>(&bytes).unwrap_err(),
            Reason::NonCanonical
        );
        let app = BTreeMap::from([("x".to_string(), f64::NAN)]);
        assert!(matches!(
            to_mandate_plaintext(1, tid7(), None, None, None, &app),
            Err(MintError::Nan)
        ));
    }

    #[test]
    fn app_reserved_key_rejected_on_mint() {
        let app = BTreeMap::from([(-1i64, 5u32)]);
        assert!(matches!(
            to_manifest_plaintext("iss", &app),
            Err(MintError::ReservedKey)
        ));
        assert!(matches!(
            to_mandate_plaintext(1, tid7(), None, None, None, &app),
            Err(MintError::ReservedKey)
        ));
    }

    #[test]
    fn app_not_map_rejected_on_mint() {
        #[derive(Serialize)]
        struct Bad;
        assert!(matches!(
            to_manifest_plaintext("iss", &Bad),
            Err(MintError::AppNotMap)
        ));
    }

    #[test]
    fn manifest_round_trips() {
        let pt = to_manifest_plaintext(
            "auth.example",
            &App {
                role: "ui".into(),
                n: 2,
            },
        )
        .unwrap();
        let c: ManifestFields<App> = from_manifest_plaintext(&pt).unwrap();
        assert_eq!(c.iss, "auth.example");
        assert_eq!(
            c.app,
            App {
                role: "ui".into(),
                n: 2
            }
        );
    }

    /// Bytes after the encoded map must make both decoders fail.
    #[test]
    fn rejects_trailing_bytes() {
        let mut manifest = to_manifest_plaintext("iss", &App::default()).unwrap();
        manifest.push(0x00);
        assert!(from_manifest_plaintext::<App>(&manifest).is_none());

        let mut mandate =
            to_mandate_plaintext(1, tid7(), None, None, None, &App::default()).unwrap();
        mandate.push(0x00);
        assert!(from_mandate_plaintext::<App>(&mandate).is_err());
    }

    /// A mandate carries the token-id key, which the manifest decoder does not accept.
    #[test]
    fn manifest_rejects_mandate_only_keys() {
        let pt = to_mandate_plaintext(1, tid7(), None, None, None, &App::default()).unwrap();
        assert!(from_manifest_plaintext::<App>(&pt).is_none());
    }
}