use std::collections::HashMap;
use std::marker::PhantomData;
use async_trait::async_trait;
use ppoppo_token::id_token::Nonce;
use super::port::{IdAssertion, IdTokenVerifier, IdVerifyError, ScopePiiReader};
use crate::VerifyConfig;
/// In-memory [`IdTokenVerifier`] that answers from a preloaded table of
/// `id_token -> assertion` pairs instead of performing any signature or claim
/// validation.
///
/// NOTE(review): presumably intended for tests and local development, where a
/// real verifier is unavailable — confirm against callers.
pub struct MemoryIdTokenVerifier<S: ScopePiiReader> {
/// Raw ID-token string -> assertion that `verify` returns (cloned) for it.
assertions: HashMap<String, IdAssertion<S>>,
/// When set, `verify` returns a clone of this error for every token, taking
/// precedence over `assertions`.
default_failure: Option<IdVerifyError>,
// Recorded by `with_expectations` but not read by `verify`; the allow keeps
// the dead_code lint quiet while the field is only stored.
#[allow(dead_code)]
expectations: Option<VerifyConfig>,
/// Binds the scope type parameter `S` without storing a value of it.
_scope: PhantomData<S>,
}
// Written by hand rather than derived: `#[derive(Default)]` would add an
// `S: Default` bound that the scope marker type has no reason to satisfy.
impl<S: ScopePiiReader> Default for MemoryIdTokenVerifier<S> {
    /// An empty verifier: no registered tokens, no forced failure and no
    /// recorded expectations.
    fn default() -> Self {
        let assertions = HashMap::default();
        Self {
            _scope: PhantomData,
            expectations: None,
            default_failure: None,
            assertions,
        }
    }
}
impl<S: ScopePiiReader> MemoryIdTokenVerifier<S> {
    /// Creates a verifier with no registered tokens and no forced failure.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Records the verification configuration this verifier stands in for.
    ///
    /// The value is only stored; `verify` does not consult it.
    #[must_use]
    pub fn with_expectations(self, expectations: VerifyConfig) -> Self {
        Self {
            expectations: Some(expectations),
            ..self
        }
    }

    /// Registers `assertion` as the result `verify` yields for `id_token`.
    ///
    /// Registering the same token again replaces the earlier assertion.
    /// Returns `self` so registrations can be chained.
    pub fn insert(
        &mut self,
        id_token: impl Into<String>,
        assertion: IdAssertion<S>,
    ) -> &mut Self {
        let key: String = id_token.into();
        self.assertions.insert(key, assertion);
        self
    }

    /// Makes every subsequent `verify` call fail with a clone of `err`,
    /// regardless of which tokens are registered.
    pub fn fail_with(&mut self, err: IdVerifyError) -> &mut Self {
        let forced = Some(err);
        self.default_failure = forced;
        self
    }
}
#[async_trait]
impl<S: ScopePiiReader> IdTokenVerifier<S> for MemoryIdTokenVerifier<S> {
    /// Resolves `id_token` against the preloaded table.
    ///
    /// A failure forced via `fail_with` wins over any registered token.
    /// Unknown tokens are reported as `IdVerifyError::SignatureInvalid`.
    /// The expected nonce is ignored: no claim checks are performed here.
    async fn verify(
        &self,
        id_token: &str,
        _expected_nonce: &Nonce,
    ) -> Result<IdAssertion<S>, IdVerifyError> {
        if let Some(forced) = &self.default_failure {
            return Err(forced.clone());
        }
        match self.assertions.get(id_token) {
            Some(assertion) => Ok(assertion.clone()),
            None => Err(IdVerifyError::SignatureInvalid),
        }
    }
}
#[cfg(feature = "oauth")]
mod state_store_impl {
    use std::collections::{HashMap, VecDeque};
    use std::sync::Arc;
    use std::time::Duration;
    use async_trait::async_trait;
    use ppoppo_clock::ArcClock;
    use ppoppo_clock::native::WallClock;
    use time::OffsetDateTime;
    use tokio::sync::Mutex;
    use crate::oidc::state_store::{
        PendingAuthRequest, State, StateStore, StateStoreError,
    };

    /// A stored pending request plus the instant after which it is no longer
    /// handed out.
    struct Entry {
        pending: PendingAuthRequest,
        expires_at: OffsetDateTime,
    }

    impl Entry {
        /// Single expiry rule shared by `put` (pruning) and `take` (lookup):
        /// an entry is still live at exactly `expires_at`.
        fn is_expired(&self, now: OffsetDateTime) -> bool {
            now > self.expires_at
        }
    }

    /// In-process [`StateStore`] backed by a `HashMap` behind a Tokio mutex.
    ///
    /// Entries expire according to the configured clock (wall clock by
    /// default; see [`InMemoryStateStore::with_clock`]). Failures can be
    /// injected for `put`/`take`; they are consumed one per call, in the
    /// order they were registered.
    pub struct InMemoryStateStore {
        inner: Mutex<Inner>,
        clock: ArcClock,
    }

    struct Inner {
        map: HashMap<State, Entry>,
        /// Errors returned by upcoming `put` calls, oldest first.
        put_failures: VecDeque<StateStoreError>,
        /// Errors returned by upcoming `take` calls, oldest first.
        take_failures: VecDeque<StateStoreError>,
    }

    impl Default for InMemoryStateStore {
        fn default() -> Self {
            Self {
                inner: Mutex::new(Inner {
                    map: HashMap::new(),
                    put_failures: VecDeque::new(),
                    take_failures: VecDeque::new(),
                }),
                clock: Arc::new(WallClock),
            }
        }
    }

    impl InMemoryStateStore {
        /// Creates an empty store that uses the wall clock.
        #[must_use]
        pub fn new() -> Self {
            Self::default()
        }

        /// Replaces the clock used to compute and check expiry.
        #[must_use]
        pub fn with_clock(mut self, clock: ArcClock) -> Self {
            self.clock = clock;
            self
        }

        /// Queues `err` to be returned by a future `put` call.
        ///
        /// Queued failures are consumed first-in, first-out, one per call.
        /// The function stays `async` for API compatibility. It owns `self`,
        /// so it borrows the state through `Mutex::get_mut` instead of locking.
        pub async fn with_put_failure(mut self, err: StateStoreError) -> Self {
            self.inner.get_mut().put_failures.push_back(err);
            self
        }

        /// Queues `err` to be returned by a future `take` call.
        ///
        /// Queued failures are consumed first-in, first-out, one per call.
        /// The function stays `async` for API compatibility. It owns `self`,
        /// so it borrows the state through `Mutex::get_mut` instead of locking.
        pub async fn with_take_failure(mut self, err: StateStoreError) -> Self {
            self.inner.get_mut().take_failures.push_back(err);
            self
        }
    }

    #[async_trait]
    impl StateStore for InMemoryStateStore {
        /// Stores `pending` under `state`, expiring `ttl` after the clock's
        /// current time. Any existing entry for `state` is replaced.
        ///
        /// # Errors
        /// Returns the oldest queued put-failure, if any, without storing.
        async fn put(
            &self,
            state: &State,
            pending: PendingAuthRequest,
            ttl: Duration,
        ) -> Result<(), StateStoreError> {
            let mut inner = self.inner.lock().await;
            if let Some(err) = inner.put_failures.pop_front() {
                return Err(err);
            }
            let now = self.clock.now_utc();
            // `take` only evicts the entry it is asked for, so abandoned
            // requests would otherwise accumulate forever. Drop everything
            // `take` would already refuse to return.
            inner.map.retain(|_, entry| !entry.is_expired(now));
            inner.map.insert(
                state.clone(),
                Entry {
                    pending,
                    expires_at: now + ttl,
                },
            );
            Ok(())
        }

        /// Removes and returns the request stored under `state`.
        ///
        /// Returns `Ok(None)` when nothing is stored or the entry has expired.
        /// An expired entry is removed as well, so a state is single-use.
        ///
        /// # Errors
        /// Returns the oldest queued take-failure, if any, leaving the map
        /// untouched.
        async fn take(
            &self,
            state: &State,
        ) -> Result<Option<PendingAuthRequest>, StateStoreError> {
            let mut inner = self.inner.lock().await;
            if let Some(err) = inner.take_failures.pop_front() {
                return Err(err);
            }
            let Some(entry) = inner.map.remove(state) else {
                return Ok(None);
            };
            if entry.is_expired(self.clock.now_utc()) {
                return Ok(None);
            }
            Ok(Some(entry.pending))
        }
    }
}
#[cfg(feature = "oauth")]
pub use state_store_impl::InMemoryStateStore;