use std::collections::BTreeMap;
#[cfg(any(test, feature = "test-support"))]
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use super::RateLimitKey;
/// Destination for audit records of failed verifications.
///
/// Implementors must be `Debug + Send + Sync`; this lets a sink be shared as
/// `Arc<dyn AuditSink>`.
#[async_trait]
pub trait AuditSink: std::fmt::Debug + Send + Sync {
/// Records one failure event, taking ownership of it.
///
/// Returns `()`: an implementation has no way to report a failure to persist
/// the event back to the caller, so any error handling is the sink's concern.
async fn record_failure(&self, event: AuditEvent);
}
/// One recorded verification failure.
///
/// Can be built with [`AuditEvent::from_hints`] or
/// [`AuditEvent::from_id_token_hints`], which derive `source_id` from the hints.
/// All fields are public, so a directly constructed event may carry a
/// `source_id` unrelated to its hints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEvent {
/// Which check failed.
pub kind: VerifyErrorKind,
/// When the failure occurred; (de)serialized as an RFC 3339 string.
#[serde(with = "time::serde::rfc3339")]
pub occurred_at: OffsetDateTime,
/// Bucket identifier such as `"client::kid"` or `"azp::aud::kid"`; it is the
/// input to [`AuditEvent::rate_limit_key`].
pub source_id: String,
/// Client identifier hint, if any (for ID tokens this holds the `azp` hint).
/// NOTE(review): presumably taken from an unverified token — confirm at
/// call sites before trusting it.
pub client_id_hint: Option<String>,
/// Key-id (`kid`) hint, if any.
pub kid_hint: Option<String>,
/// Free-form extra context. `BTreeMap` keeps keys sorted, so serialized
/// output has a stable key order.
pub metadata: BTreeMap<String, serde_json::Value>,
}
impl AuditEvent {
    /// Builds an event for a non-ID-token verification failure.
    ///
    /// `source_id` is derived with [`compose_source_id`] as `"{client}::{kid}"`,
    /// using the placeholders `"anon"` / `"nokid"` for missing hints. The hints
    /// themselves are stored verbatim and `metadata` is kept as given.
    #[must_use]
    pub fn from_hints(
        kind: VerifyErrorKind,
        occurred_at: OffsetDateTime,
        client_id_hint: Option<String>,
        kid_hint: Option<String>,
        metadata: BTreeMap<String, serde_json::Value>,
    ) -> Self {
        let source_id = compose_source_id(client_id_hint.as_deref(), kid_hint.as_deref());
        Self {
            kind,
            occurred_at,
            source_id,
            client_id_hint,
            kid_hint,
            metadata,
        }
    }

    /// Builds an event for an ID-token verification failure.
    ///
    /// `source_id` is derived with [`compose_id_token_source_id`] as
    /// `"{azp}::{aud}::{kid}"`. The `azp` hint is stored as `client_id_hint`.
    /// The event has no dedicated audience field, so when `aud_hint` is present
    /// it is recorded in `metadata` under the key `"aud_hint"`, replacing any
    /// value the caller supplied under that key. When `aud_hint` is `None`,
    /// `metadata` is kept as given.
    #[must_use]
    pub fn from_id_token_hints(
        kind: VerifyErrorKind,
        occurred_at: OffsetDateTime,
        azp_hint: Option<String>,
        aud_hint: Option<String>,
        kid_hint: Option<String>,
        mut metadata: BTreeMap<String, serde_json::Value>,
    ) -> Self {
        let source_id = compose_id_token_source_id(
            azp_hint.as_deref(),
            aud_hint.as_deref(),
            kid_hint.as_deref(),
        );
        // `aud_hint` is not needed once the source id is composed, so move the
        // string into the metadata instead of cloning it.
        if let Some(aud) = aud_hint {
            metadata.insert("aud_hint".to_owned(), serde_json::Value::String(aud));
        }
        Self {
            kind,
            occurred_at,
            source_id,
            client_id_hint: azp_hint,
            kid_hint,
            metadata,
        }
    }

    /// Returns the rate-limit key for this event, built from a copy of
    /// `source_id`.
    #[must_use]
    pub fn rate_limit_key(&self) -> RateLimitKey {
        RateLimitKey::new(self.source_id.clone())
    }
}
/// Builds the two-part source id `"{client}::{kid}"` used to bucket failures.
///
/// A missing client hint becomes `"anon"` and a missing key-id hint becomes
/// `"nokid"`, so every fully anonymous failure lands in `"anon::nokid"`.
///
/// NOTE: hints are joined verbatim with no escaping. A hint that itself
/// contains `::`, or equals a placeholder, can yield the same id as a different
/// combination of hints.
#[must_use]
pub fn compose_source_id(client_id_hint: Option<&str>, kid_hint: Option<&str>) -> String {
    let segments = [client_id_hint.unwrap_or("anon"), kid_hint.unwrap_or("nokid")];
    segments.join("::")
}
/// Builds the three-part ID-token source id `"{azp}::{aud}::{kid}"`.
///
/// Missing hints are replaced by the placeholders `"anon"` (azp), `"noaud"`
/// (audience) and `"nokid"` (key id), so a token with no hints at all maps to
/// `"anon::noaud::nokid"`.
///
/// NOTE: hints are joined verbatim with no escaping, so a hint containing
/// `::` can produce the same id as a different combination of hints.
#[must_use]
pub fn compose_id_token_source_id(
    azp_hint: Option<&str>,
    aud_hint: Option<&str>,
    kid_hint: Option<&str>,
) -> String {
    [
        azp_hint.unwrap_or("anon"),
        aud_hint.unwrap_or("noaud"),
        kid_hint.unwrap_or("nokid"),
    ]
    .join("::")
}
/// Category of verification failure carried by an [`AuditEvent`].
///
/// Serialized with serde's adjacently tagged representation: the variant name
/// goes under `"tag"` and any payload under `"claim"`, e.g.
/// `{"tag":"MissingClaim","claim":"aud"}`. Unit variants carry only the tag,
/// e.g. `{"tag":"Expired"}`. The payload key is `"claim"` even for
/// [`VerifyErrorKind::IdToken`], whose payload is an [`IdTokenFailureKind`]
/// rather than a claim name. Renaming variants or these keys changes the
/// serialized format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "tag", content = "claim")]
pub enum VerifyErrorKind {
InvalidFormat,
SignatureInvalid,
Expired,
IssuerInvalid,
AudienceInvalid,
/// Payload is presumably the name of the missing claim (the tests use `"aud"`).
MissingClaim(String),
KeysetUnavailable,
IdTokenAsBearer,
/// A failure specific to ID-token validation; see [`IdTokenFailureKind`].
IdToken(IdTokenFailureKind),
Other,
}
/// The specific ID-token check that failed, carried by
/// [`VerifyErrorKind::IdToken`].
///
/// Uses serde's default (externally tagged) representation: unit variants
/// serialize as a bare string (e.g. `"NonceMissing"`), and payload variants as
/// a single-key object (e.g. `{"UnknownClaim":"backdoor"}`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum IdTokenFailureKind {
NonceMissing,
NonceMismatch,
AtHashMissing,
AtHashMismatch,
CHashMissing,
CHashMismatch,
AzpMissing,
AzpMismatch,
AuthTimeMissing,
AuthTimeStale,
AcrMissing,
AcrNotAllowed,
/// Payload presumably names the unexpected claim — confirm at construction sites.
UnknownClaim(String),
/// NOTE(review): payload meaning is not visible here (the tests use an empty
/// string) — confirm at construction sites.
CatMismatch(String),
}
/// An [`AuditSink`] that discards every event.
#[derive(Debug, Default, Clone)]
pub struct NoopAuditSink;
#[async_trait]
impl AuditSink for NoopAuditSink {
/// Drops `event` without recording it anywhere.
async fn record_failure(&self, _event: AuditEvent) {}
}
/// An [`AuditSink`] that keeps every event in memory for later inspection.
///
/// Only compiled for tests or with the `test-support` feature. The derived
/// `Clone` clones the inner `Arc`, so clones share one event buffer.
#[cfg(any(test, feature = "test-support"))]
#[derive(Debug, Default, Clone)]
pub struct MemoryAuditSink {
/// Recorded events in the order they were pushed. The `Mutex` allows
/// mutation through `&self`.
events: Arc<Mutex<Vec<AuditEvent>>>,
}
#[cfg(any(test, feature = "test-support"))]
impl MemoryAuditSink {
    /// Creates a sink with an empty event buffer.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a snapshot (a clone) of every recorded event, in the order the
    /// events were pushed. Later recordings do not affect the returned `Vec`.
    #[must_use]
    pub fn events(&self) -> Vec<AuditEvent> {
        self.guard().clone()
    }

    /// Returns how many events have been recorded so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.guard().len()
    }

    /// Returns `true` if no event has been recorded yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.guard().is_empty()
    }

    /// Locks the event buffer, recovering the data if a previous holder
    /// panicked. Recovery is deliberate: otherwise a panic in one holder would
    /// make the recorded events unreadable for every later inspection.
    fn guard(&self) -> std::sync::MutexGuard<'_, Vec<AuditEvent>> {
        self.events
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(any(test, feature = "test-support"))]
#[async_trait]
impl AuditSink for MemoryAuditSink {
    /// Appends `event` to the shared in-memory buffer.
    ///
    /// A poisoned lock is recovered instead of propagated, so recording keeps
    /// working even after another holder of the lock panicked.
    async fn record_failure(&self, event: AuditEvent) {
        match self.events.lock() {
            Ok(mut buffer) => buffer.push(event),
            Err(poisoned) => poisoned.into_inner().push(event),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event with a fixed timestamp and no hints or metadata, so a
    /// test only has to care about `kind` and `source_id`.
    fn fixture(kind: VerifyErrorKind, source_id: &str) -> AuditEvent {
        AuditEvent {
            kind,
            occurred_at: OffsetDateTime::UNIX_EPOCH,
            source_id: source_id.to_owned(),
            client_id_hint: None,
            kid_hint: None,
            metadata: BTreeMap::new(),
        }
    }

    #[test]
    fn compose_source_id_uses_compound_separator() {
        assert_eq!(
            compose_source_id(Some("rcw-client"), Some("k1")),
            "rcw-client::k1"
        );
    }

    #[test]
    fn compose_source_id_collapses_anonymous_into_canonical_bucket() {
        assert_eq!(compose_source_id(None, None), "anon::nokid");
    }

    #[test]
    fn compose_id_token_source_id_uses_three_tuple_separator() {
        assert_eq!(
            compose_id_token_source_id(Some("rp-id"), Some("rp-aud"), Some("k1")),
            "rp-id::rp-aud::k1"
        );
    }

    #[test]
    fn compose_id_token_source_id_collapses_anonymous_into_canonical() {
        assert_eq!(
            compose_id_token_source_id(None, None, None),
            "anon::noaud::nokid"
        );
    }

    #[test]
    fn compose_id_token_source_id_partial_anonymity() {
        assert_eq!(
            compose_id_token_source_id(None, Some("rp-aud"), Some("k1")),
            "anon::rp-aud::k1"
        );
        assert_eq!(
            compose_id_token_source_id(Some("rp-id"), None, Some("k1")),
            "rp-id::noaud::k1"
        );
    }

    #[test]
    fn from_id_token_hints_derives_three_tuple_source_id_and_pushes_aud_into_metadata() {
        let event = AuditEvent::from_id_token_hints(
            VerifyErrorKind::IdToken(IdTokenFailureKind::NonceMismatch),
            OffsetDateTime::UNIX_EPOCH,
            Some("rp-id".to_owned()),
            Some("rp-aud".to_owned()),
            Some("k1".to_owned()),
            BTreeMap::new(),
        );
        assert_eq!(event.source_id, "rp-id::rp-aud::k1");
        assert_eq!(event.client_id_hint.as_deref(), Some("rp-id"));
        assert_eq!(event.kid_hint.as_deref(), Some("k1"));
        assert_eq!(
            event
                .metadata
                .get("aud_hint")
                .and_then(|v| v.as_str()),
            Some("rp-aud")
        );
    }

    #[test]
    fn from_id_token_hints_without_aud_leaves_metadata_untouched() {
        let mut metadata = BTreeMap::new();
        metadata.insert("reason".to_owned(), serde_json::Value::from("replay"));
        let event = AuditEvent::from_id_token_hints(
            VerifyErrorKind::IdToken(IdTokenFailureKind::AzpMissing),
            OffsetDateTime::UNIX_EPOCH,
            None,
            None,
            Some("k1".to_owned()),
            metadata.clone(),
        );
        assert_eq!(event.source_id, "anon::noaud::k1");
        assert_eq!(event.client_id_hint, None);
        assert_eq!(event.metadata, metadata);
    }

    #[test]
    fn compose_source_id_partial_anonymity() {
        assert_eq!(compose_source_id(None, Some("k1")), "anon::k1");
        assert_eq!(compose_source_id(Some("rcw"), None), "rcw::nokid");
    }

    #[test]
    fn from_hints_derives_source_id_from_hints() {
        let event = AuditEvent::from_hints(
            VerifyErrorKind::SignatureInvalid,
            OffsetDateTime::UNIX_EPOCH,
            Some("rcw".to_owned()),
            Some("k1".to_owned()),
            BTreeMap::new(),
        );
        assert_eq!(event.source_id, "rcw::k1");
        assert_eq!(event.client_id_hint.as_deref(), Some("rcw"));
        assert_eq!(event.kid_hint.as_deref(), Some("k1"));
    }

    #[test]
    fn rate_limit_key_round_trip() {
        let event = fixture(VerifyErrorKind::SignatureInvalid, "rcw::k1");
        assert_eq!(event.rate_limit_key().as_str(), "rcw::k1");
    }

    // `expect` is intended here: a serde failure should fail the test loudly.
    #[test]
    #[allow(clippy::expect_used)]
    fn verify_error_kind_round_trips_through_serde() {
        let kind = VerifyErrorKind::MissingClaim("aud".to_owned());
        let json = serde_json::to_string(&kind).expect("serialize");
        let back: VerifyErrorKind = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(kind, back);
    }

    // `expect` is intended here: a serde failure should fail the test loudly.
    #[test]
    #[allow(clippy::expect_used)]
    fn verify_error_kind_id_token_variants_round_trip_through_serde() {
        let unit = VerifyErrorKind::IdToken(IdTokenFailureKind::NonceMissing);
        let json = serde_json::to_string(&unit).expect("serialize unit");
        let back: VerifyErrorKind = serde_json::from_str(&json).expect("deserialize unit");
        assert_eq!(unit, back);
        let payload = VerifyErrorKind::IdToken(IdTokenFailureKind::UnknownClaim(
            "backdoor".to_owned(),
        ));
        let json = serde_json::to_string(&payload).expect("serialize payload");
        let back: VerifyErrorKind = serde_json::from_str(&json).expect("deserialize payload");
        assert_eq!(payload, back);
        let cat = VerifyErrorKind::IdToken(IdTokenFailureKind::CatMismatch(String::new()));
        let json = serde_json::to_string(&cat).expect("serialize cat");
        let back: VerifyErrorKind = serde_json::from_str(&json).expect("deserialize cat");
        assert_eq!(cat, back);
    }

    #[tokio::test]
    async fn noop_sink_is_a_no_op() {
        let sink = NoopAuditSink;
        let event = fixture(VerifyErrorKind::Expired, "x");
        sink.record_failure(event).await;
    }

    #[tokio::test]
    async fn memory_sink_records_events_in_insertion_order() {
        let sink = MemoryAuditSink::new();
        sink.record_failure(fixture(VerifyErrorKind::Expired, "a"))
            .await;
        sink.record_failure(fixture(VerifyErrorKind::SignatureInvalid, "b"))
            .await;
        sink.record_failure(fixture(VerifyErrorKind::IdTokenAsBearer, "c"))
            .await;
        let events = sink.events();
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].kind, VerifyErrorKind::Expired);
        assert_eq!(events[1].kind, VerifyErrorKind::SignatureInvalid);
        assert_eq!(events[2].kind, VerifyErrorKind::IdTokenAsBearer);
        assert_eq!(events[0].source_id, "a");
        assert_eq!(events[1].source_id, "b");
        assert_eq!(events[2].source_id, "c");
    }

    #[tokio::test]
    async fn memory_sink_is_empty_initially() {
        let sink = MemoryAuditSink::new();
        assert!(sink.is_empty());
        assert_eq!(sink.len(), 0);
        sink.record_failure(fixture(VerifyErrorKind::Other, "x"))
            .await;
        assert_eq!(sink.len(), 1);
        assert!(!sink.is_empty());
    }

    /// Fails to compile if `AuditSink` stops being object-safe. Running it as
    /// a real test avoids the need for an `#[allow(dead_code)]`.
    #[test]
    fn dyn_object_safety() {
        let _: Arc<dyn AuditSink> = Arc::new(NoopAuditSink);
        let _: Arc<dyn AuditSink> = Arc::new(MemoryAuditSink::new());
    }
}