use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::collections::TtlMap;
use crate::error::{ConstraintViolationError, ConstraintViolationType, InternalError};
use crate::oauth::PendingAuthorization;
use super::{InflightOAuthRequestStore, InflightOAuthRequestStoreError};
/// Lifetime, in seconds, handed to the backing `TtlMap` when the store is
/// created with `new()`: one hour.
const PENDING_AUTHORIZATION_EXPIRATION_SECS: u64 = 60 * 60;

/// An in-process [`InflightOAuthRequestStore`] that keeps pending OAuth
/// authorizations in a `TtlMap`, keyed by request id.
///
/// The map sits behind `Arc<Mutex<…>>`, so cloning the store yields a handle
/// to the same underlying map rather than an independent copy.
#[derive(Clone)]
pub struct MemoryInflightOAuthRequestStore {
    /// Shared, lock-protected map from request id to its pending authorization.
    pending_authorizations: Arc<Mutex<TtlMap<String, PendingAuthorization>>>,
}
impl MemoryInflightOAuthRequestStore {
    /// Creates an empty store whose backing `TtlMap` is constructed with a
    /// duration of `PENDING_AUTHORIZATION_EXPIRATION_SECS` (one hour).
    pub fn new() -> Self {
        Self::with_expiration(Duration::from_secs(PENDING_AUTHORIZATION_EXPIRATION_SECS))
    }

    /// Creates an empty store whose backing `TtlMap` is constructed with the
    /// given `expiration` duration instead of the one-hour default.
    ///
    /// Useful when callers (e.g. tests or deployments with different OAuth
    /// timeouts) need a TTL other than the hard-coded default.
    pub fn with_expiration(expiration: Duration) -> Self {
        Self {
            pending_authorizations: Arc::new(Mutex::new(TtlMap::new(expiration))),
        }
    }
}
impl Default for MemoryInflightOAuthRequestStore {
fn default() -> Self {
Self::new()
}
}
impl InflightOAuthRequestStore for MemoryInflightOAuthRequestStore {
    /// Records `pending_authorization` under `request_id`.
    ///
    /// # Errors
    ///
    /// - `ConstraintViolation` (unique) if `request_id` is already present in
    ///   the map; the existing entry is left untouched.
    /// - An internal error (converted via `?`) if the mutex is poisoned.
    fn insert_request(
        &self,
        request_id: String,
        pending_authorization: PendingAuthorization,
    ) -> Result<(), InflightOAuthRequestStoreError> {
        // The guard is held for both the lookup and the insert, so no other
        // caller can slip a duplicate in between the two.
        let mut guard = self.pending_authorizations.lock().map_err(|_| {
            InternalError::with_message("pending authorizations lock was poisoned".into())
        })?;

        if guard.contains_key(&request_id) {
            return Err(InflightOAuthRequestStoreError::ConstraintViolation(
                ConstraintViolationError::with_violation_type(ConstraintViolationType::Unique),
            ));
        }

        guard.insert(request_id, pending_authorization);
        Ok(())
    }

    /// Takes the pending authorization stored under `request_id` out of the
    /// map, returning `Ok(None)` when the map has no such entry.
    ///
    /// # Errors
    ///
    /// An internal error (converted via `?`) if the mutex is poisoned.
    fn remove_request(
        &self,
        request_id: &str,
    ) -> Result<Option<PendingAuthorization>, InflightOAuthRequestStoreError> {
        let mut guard = self.pending_authorizations.lock().map_err(|_| {
            InternalError::with_message("pending authorizations lock was poisoned".into())
        })?;
        let removed = guard.remove(request_id);
        Ok(removed)
    }

    /// Boxes a clone of this store. The clone shares the same `Arc`-wrapped
    /// map, so both handles observe the same entries.
    fn clone_box(&self) -> Box<dyn InflightOAuthRequestStore> {
        let handle: MemoryInflightOAuthRequestStore = Clone::clone(self);
        Box::new(handle)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oauth::store::tests::{
        test_duplicate_id_insert, test_request_store_insert_and_remove,
    };

    /// Runs the shared insert-then-remove scenario against a fresh in-memory store.
    #[test]
    fn memory_insert_request_and_remove() {
        test_request_store_insert_and_remove(&MemoryInflightOAuthRequestStore::new());
    }

    /// Runs the shared duplicate-request-id scenario against a fresh in-memory store.
    #[test]
    fn memory_duplicate_id_insert() {
        test_duplicate_id_insert(&MemoryInflightOAuthRequestStore::new());
    }
}