#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic, clippy::unimplemented)]
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use base64::{Engine, engine::general_purpose::STANDARD};
use pas_external::middleware::{
NewSession, SessionResolution, SessionResolver, SessionStore, SvAware, SessionValidator,
SvCachePort,
};
use pas_external::oauth::TokenResponse;
use pas_external::pas_port::{MemoryPasAuth, PasFailure};
use pas_external::session_liveness::{EncryptedRefreshToken, TokenCipher};
use pas_external::types::SessionId;
#[derive(Debug, thiserror::Error)]
#[error("fake error: {0}")]
struct FakeError(String);
/// Minimal auth context kept by [`FakeStore`]: the owning ppnum id and the
/// session-version snapshot, which may be absent.
#[derive(Debug, Clone)]
struct FakeAuthContext {
    /// Value reported through `SvAware::ppnum_id`.
    ppnum_id: String,
    /// Value reported through `SvAware::sv`.
    sv: Option<i64>,
}
impl SvAware for FakeAuthContext {
    /// Borrows the stored ppnum id.
    fn ppnum_id(&self) -> &str {
        self.ppnum_id.as_str()
    }

    /// Copies out the stored session version, if any.
    fn sv(&self) -> Option<i64> {
        let Self { sv, .. } = self;
        *sv
    }
}
/// In-memory `SessionStore` double with switches for injecting failures and
/// counters for asserting which store operations the validator performed.
#[derive(Default)]
struct FakeStore {
    /// Auth contexts keyed by raw session id.
    contexts: Mutex<HashMap<String, FakeAuthContext>>,
    /// Encrypted refresh tokens keyed by raw session id.
    ciphertexts: Mutex<HashMap<String, EncryptedRefreshToken>>,
    /// When set, `update_sv` returns an error (the call is still counted).
    update_sv_should_fail: Mutex<bool>,
    /// Number of `update_sv` invocations, failed ones included.
    update_sv_calls: Mutex<u32>,
    /// When set, `get_refresh_ciphertext` returns an error.
    ciphertext_lookup_should_fail: Mutex<bool>,
    /// When set, every `find` after the first reports the session as missing.
    drop_session_on_refetch: Mutex<bool>,
    /// Number of `find` invocations so far.
    find_calls: Mutex<u32>,
}
impl FakeStore {
fn put(&self, id: &str, ctx: FakeAuthContext) {
self.contexts.lock().unwrap().insert(id.to_string(), ctx);
}
fn put_ciphertext(&self, id: &str, ct: EncryptedRefreshToken) {
self.ciphertexts.lock().unwrap().insert(id.to_string(), ct);
}
fn fail_update_sv(&self) {
*self.update_sv_should_fail.lock().unwrap() = true;
}
fn fail_ciphertext_lookup(&self) {
*self.ciphertext_lookup_should_fail.lock().unwrap() = true;
}
fn drop_session_on_refetch(&self) {
*self.drop_session_on_refetch.lock().unwrap() = true;
}
fn update_sv_call_count(&self) -> u32 {
*self.update_sv_calls.lock().unwrap()
}
}
impl SessionStore for FakeStore {
    type Error = FakeError;
    type AuthContext = FakeAuthContext;

    /// Not exercised by these tests.
    async fn create(&self, _session: NewSession) -> Result<SessionId, FakeError> {
        unimplemented!("not used by these tests")
    }

    /// Counts the call and returns the stored context for `id`. Once the
    /// refetch switch is armed, the second and later calls report `None`.
    async fn find(&self, id: &SessionId) -> Result<Option<FakeAuthContext>, FakeError> {
        let call_number = {
            let mut counter = self.find_calls.lock().unwrap();
            *counter += 1;
            *counter
        };
        let is_refetch = call_number >= 2;
        if is_refetch && *self.drop_session_on_refetch.lock().unwrap() {
            return Ok(None);
        }
        let contexts = self.contexts.lock().unwrap();
        Ok(contexts.get(&id.0).cloned())
    }

    /// Always succeeds without touching any state.
    async fn delete(&self, _id: &SessionId) -> Result<(), FakeError> {
        Ok(())
    }

    /// Counts the call, then either fails (if armed) or overwrites the stored
    /// sv of an existing session; unknown ids are silently ignored.
    async fn update_sv(&self, id: &SessionId, new_sv: i64) -> Result<(), FakeError> {
        {
            let mut calls = self.update_sv_calls.lock().unwrap();
            *calls += 1;
        }
        let should_fail = *self.update_sv_should_fail.lock().unwrap();
        if should_fail {
            return Err(FakeError("update_sv DB failure".into()));
        }
        if let Some(ctx) = self.contexts.lock().unwrap().get_mut(&id.0) {
            ctx.sv = Some(new_sv);
        }
        Ok(())
    }

    /// Returns the stored ciphertext for `id`, or an error when the lookup
    /// failure switch is armed.
    async fn get_refresh_ciphertext(
        &self,
        id: &SessionId,
    ) -> Result<Option<EncryptedRefreshToken>, FakeError> {
        let lookup_fails = *self.ciphertext_lookup_should_fail.lock().unwrap();
        if lookup_fails {
            Err(FakeError("ciphertext lookup DB failure".into()))
        } else {
            Ok(self.ciphertexts.lock().unwrap().get(&id.0).cloned())
        }
    }
}
/// Cloneable `SvCachePort` double. Clones share one `Arc`'d inner state, so a
/// test can keep a handle and inspect what the validator wrote.
#[derive(Default, Clone)]
struct FakeBackend {
    inner: Arc<FakeBackendInner>,
}
/// Shared state behind `FakeBackend`.
#[derive(Default)]
struct FakeBackendInner {
    /// Cached session versions keyed by cache key.
    map: Mutex<HashMap<String, i64>>,
    /// Number of `store` calls observed.
    store_calls: Mutex<u32>,
    /// TTL passed to the most recent `store` call, if there was one.
    last_ttl: Mutex<Option<Duration>>,
}
impl FakeBackend {
    /// How many times `store` has been called through any clone of this backend.
    fn store_call_count(&self) -> u32 {
        let calls = self.inner.store_calls.lock().unwrap();
        *calls
    }

    /// TTL supplied to the latest `store` call, or `None` if none happened.
    fn last_ttl(&self) -> Option<Duration> {
        let ttl = self.inner.last_ttl.lock().unwrap();
        *ttl
    }
}
#[async_trait]
impl SvCachePort for FakeBackend {
    /// Returns the sv cached under `key`, if any.
    async fn load(&self, key: &str) -> Option<i64> {
        let cache = &self.inner.map;
        cache.lock().unwrap().get(key).copied()
    }

    /// Counts the call, remembers `ttl`, and caches `sv` under `key`
    /// (the TTL is recorded only; entries never expire here).
    async fn store(&self, key: &str, sv: i64, ttl: Duration) {
        let inner = &self.inner;
        *inner.store_calls.lock().unwrap() += 1;
        *inner.last_ttl.lock().unwrap() = Some(ttl);
        inner.map.lock().unwrap().insert(key.to_owned(), sv);
    }
}
/// Name of the private cookie that carries the session id.
const COOKIE_NAME: &str = "test_session";
/// Session id planted in the cookie jar and in the fake store.
const SESSION_ID: &str = "01HXYZTESTSESSION0000000000";
/// ppnum id owning the test session; the stale-branch test also builds the
/// sv cache key from it.
const PPNUM_ID: &str = "01HXYZPPPN00000000000000PP";
/// Plaintext refresh token that is encrypted into the fake store.
const PLAINTEXT_RT: &str = "rt_plain_xyz";
/// Primary test cipher, keyed with 32 zero bytes.
fn cipher() -> TokenCipher {
    let zero_key = [0u8; 32];
    TokenCipher::from_base64_key(&STANDARD.encode(zero_key)).unwrap()
}
/// Secondary test cipher whose key differs from `cipher()`'s in the first byte
/// only; the cipher-failure test relies on it failing to decrypt tokens sealed
/// by `cipher()`.
fn other_cipher() -> TokenCipher {
    let key: [u8; 32] = {
        let mut bytes = [0u8; 32];
        bytes[0] = 0x01;
        bytes
    };
    let encoded = STANDARD.encode(key);
    TokenCipher::from_base64_key(&encoded).unwrap()
}
/// Builds a private cookie jar, under a freshly generated key, that carries
/// the `COOKIE_NAME` cookie set to `SESSION_ID`.
fn jar_with_session() -> PrivateCookieJar {
    PrivateCookieJar::new(Key::generate()).add(Cookie::new(COOKIE_NAME, SESSION_ID))
}
/// Builds a `TokenResponse` whose access token is an unsigned, JWT-shaped
/// `header.payload.sig` string. The payload holds `{"sv": v}` when `sv` is
/// `Some(v)` and `{}` otherwise; the refresh token is `null`.
fn token_response_with_sv(sv: Option<i64>) -> TokenResponse {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

    let claims = sv.map_or_else(
        || serde_json::json!({}),
        |v| serde_json::json!({ "sv": v }),
    );
    let header = URL_SAFE_NO_PAD.encode(br#"{"alg":"none"}"#);
    let payload = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&claims).unwrap());
    let access_token = format!("{header}.{payload}.sig");

    serde_json::from_value(serde_json::json!({
        "access_token": access_token,
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": null,
    }))
    .unwrap()
}
/// Creates a store holding one session (`SESSION_ID`, owned by `PPNUM_ID`,
/// sv = 1) plus `PLAINTEXT_RT` encrypted under `cipher_for_ct`.
fn standard_setup(cipher_for_ct: &TokenCipher) -> Arc<FakeStore> {
    let ciphertext = cipher_for_ct.encrypt_to_token(PLAINTEXT_RT).unwrap();
    let session = FakeAuthContext {
        ppnum_id: PPNUM_ID.into(),
        sv: Some(1),
    };

    let store = FakeStore::default();
    store.put(SESSION_ID, session);
    store.put_ciphertext(SESSION_ID, ciphertext);
    Arc::new(store)
}
/// Wires a `SessionValidator` around a `SessionResolver` that reads the
/// `COOKIE_NAME` cookie and shares `store` with the validator.
fn build_resolver(
    store: Arc<FakeStore>,
    backend: FakeBackend,
    pas: Arc<MemoryPasAuth>,
    cipher: Option<Arc<TokenCipher>>,
) -> SessionValidator<FakeStore, MemoryPasAuth, FakeBackend> {
    let base = SessionResolver::new(Arc::clone(&store), Arc::<str>::from(COOKIE_NAME));
    let cache = Arc::new(backend);
    SessionValidator::new(base, store, pas, cache, cipher)
}
#[tokio::test]
async fn sv_cipher_failure_yields_expired() {
    // The ciphertext is sealed with one key while the resolver holds another,
    // so decrypting the stored refresh token is expected to fail.
    let sealing_cipher = cipher();
    let store = standard_setup(&sealing_cipher);
    let backend = FakeBackend::default();
    let resolver = build_resolver(
        Arc::clone(&store),
        backend.clone(),
        Arc::new(MemoryPasAuth::new()),
        Some(Arc::new(other_cipher())),
    );

    let jar = jar_with_session();
    let outcome = resolver.validate(&jar).await.unwrap();

    assert!(matches!(outcome, SessionResolution::Expired));
    assert_eq!(backend.store_call_count(), 0);
    assert_eq!(store.update_sv_call_count(), 0);
}
#[tokio::test]
async fn sv_pas_4xx_on_refresh_yields_expired() {
    let shared_cipher = Arc::new(cipher());
    let store = standard_setup(&shared_cipher);
    // PAS rejects the refresh token outright.
    let rejection = PasFailure::Rejected {
        status: 400,
        detail: "invalid_grant".into(),
    };
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(PLAINTEXT_RT, Err(rejection)));
    let backend = FakeBackend::default();
    let resolver = build_resolver(Arc::clone(&store), backend.clone(), pas, Some(shared_cipher));

    let jar = jar_with_session();
    let outcome = resolver.validate(&jar).await.unwrap();

    assert!(matches!(outcome, SessionResolution::Expired));
    assert_eq!(backend.store_call_count(), 0, "cache must not be touched");
    assert_eq!(
        store.update_sv_call_count(),
        0,
        "update_sv must not be called"
    );
}
#[tokio::test]
async fn sv_pas_5xx_on_refresh_yields_expired_fail_closed() {
    let shared_cipher = Arc::new(cipher());
    let store = standard_setup(&shared_cipher);
    // An upstream outage during refresh is expected to fail closed.
    let outage = PasFailure::ServerError {
        status: 503,
        detail: "upstream".into(),
    };
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(PLAINTEXT_RT, Err(outage)));
    let backend = FakeBackend::default();
    let resolver = build_resolver(Arc::clone(&store), backend.clone(), pas, Some(shared_cipher));

    let jar = jar_with_session();
    let outcome = resolver.validate(&jar).await.unwrap();

    assert!(matches!(outcome, SessionResolution::Expired));
    assert_eq!(backend.store_call_count(), 0);
    assert_eq!(store.update_sv_call_count(), 0);
}
/// A transport-level PAS failure during refresh must fail closed, like the
/// 4xx/5xx cases: the session resolves as expired, and neither the stored sv
/// nor the sv cache is written.
#[tokio::test]
async fn sv_pas_transport_on_refresh_yields_expired() {
    let cipher_arc = Arc::new(cipher());
    let store = standard_setup(&cipher_arc);
    let backend = FakeBackend::default();
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(
        PLAINTEXT_RT,
        Err(PasFailure::Transport {
            detail: "connection reset".into(),
        }),
    ));
    let resolver = build_resolver(
        Arc::clone(&store),
        backend.clone(),
        Arc::clone(&pas),
        Some(Arc::clone(&cipher_arc)),
    );
    let res = resolver.validate(&jar_with_session()).await.unwrap();
    assert!(matches!(res, SessionResolution::Expired));
    // Without these, a fail-open regression on transport errors would go unnoticed.
    assert_eq!(backend.store_call_count(), 0, "cache must not be touched");
    assert_eq!(
        store.update_sv_call_count(),
        0,
        "update_sv must not be called"
    );
}
#[tokio::test]
async fn sv_access_token_missing_sv_yields_expired() {
    let shared_cipher = Arc::new(cipher());
    let store = standard_setup(&shared_cipher);
    // The refresh succeeds, but the new access token carries no `sv` claim.
    let refreshed = token_response_with_sv(None);
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(PLAINTEXT_RT, Ok(refreshed)));
    let backend = FakeBackend::default();
    let resolver = build_resolver(Arc::clone(&store), backend.clone(), pas, Some(shared_cipher));

    let jar = jar_with_session();
    let outcome = resolver.validate(&jar).await.unwrap();

    assert!(matches!(outcome, SessionResolution::Expired));
    assert_eq!(backend.store_call_count(), 0);
    assert_eq!(store.update_sv_call_count(), 0);
}
#[tokio::test]
async fn sv_update_sv_db_failure_yields_expired_does_not_update_cache() {
    let shared_cipher = Arc::new(cipher());
    let store = standard_setup(&shared_cipher);
    // Persisting the refreshed sv will fail.
    store.fail_update_sv();
    let refreshed = token_response_with_sv(Some(2));
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(PLAINTEXT_RT, Ok(refreshed)));
    let backend = FakeBackend::default();
    let resolver = build_resolver(Arc::clone(&store), backend.clone(), pas, Some(shared_cipher));

    let jar = jar_with_session();
    let outcome = resolver.validate(&jar).await.unwrap();

    assert!(matches!(outcome, SessionResolution::Expired));
    assert_eq!(store.update_sv_call_count(), 1);
    assert_eq!(
        backend.store_call_count(),
        0,
        "cache must not be touched after store failure"
    );
}
#[tokio::test]
async fn sv_happy_path_returns_authenticated_with_updated_sv() {
    let shared_cipher = Arc::new(cipher());
    let store = standard_setup(&shared_cipher);
    let refreshed = token_response_with_sv(Some(2));
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(PLAINTEXT_RT, Ok(refreshed)));
    let backend = FakeBackend::default();
    let resolver = build_resolver(Arc::clone(&store), backend.clone(), pas, Some(shared_cipher));

    let jar = jar_with_session();
    let ctx = match resolver.validate(&jar).await.unwrap() {
        SessionResolution::Authenticated(ctx) => ctx,
        other => panic!("expected Authenticated, got {other:?}"),
    };

    assert_eq!(ctx.sv(), Some(2));
    assert_eq!(store.update_sv_call_count(), 1);
    assert_eq!(backend.store_call_count(), 1);
    assert_eq!(backend.last_ttl(), Some(Duration::from_secs(60)));
}
/// A session with no stored refresh-token ciphertext cannot be refreshed, so
/// it resolves as expired without writing the stored sv or the sv cache.
#[tokio::test]
async fn sv_no_ciphertext_yields_expired() {
    let cipher_arc = Arc::new(cipher());
    let store = Arc::new(FakeStore::default());
    // Deliberately no `put_ciphertext`: only the auth context exists.
    store.put(
        SESSION_ID,
        FakeAuthContext {
            ppnum_id: PPNUM_ID.into(),
            sv: Some(1),
        },
    );
    let backend = FakeBackend::default();
    let pas = Arc::new(MemoryPasAuth::new());
    let resolver = build_resolver(
        Arc::clone(&store),
        backend.clone(),
        Arc::clone(&pas),
        Some(Arc::clone(&cipher_arc)),
    );
    let res = resolver.validate(&jar_with_session()).await.unwrap();
    assert!(matches!(res, SessionResolution::Expired));
    // Same side-effect guarantees as the ciphertext-lookup-failure case.
    assert_eq!(store.update_sv_call_count(), 0);
    assert_eq!(backend.store_call_count(), 0);
}
/// The cache is seeded with a higher sv (5) than the session's (1), so the
/// session is stale. The validator must refresh through PAS, persist the
/// refreshed sv (7), and record it in the cache under the same key.
#[tokio::test]
async fn sv_stale_branch_drives_refresh_and_records_new_sv() {
    let cipher_arc = Arc::new(cipher());
    let store = standard_setup(&cipher_arc);
    let backend = FakeBackend::default();
    // NOTE(review): assumes this matches the validator's cache-key format;
    // the "stale branch" premise of this test depends on it.
    let cache_key = format!("sv:{PPNUM_ID}");
    backend
        .store(&cache_key, 5, Duration::from_secs(60))
        .await;
    let pas = Arc::new(
        MemoryPasAuth::new()
            .expect_refresh(PLAINTEXT_RT, Ok(token_response_with_sv(Some(7)))),
    );
    let resolver = build_resolver(
        Arc::clone(&store),
        backend.clone(),
        Arc::clone(&pas),
        Some(Arc::clone(&cipher_arc)),
    );
    let res = resolver.validate(&jar_with_session()).await.unwrap();
    let ctx = match res {
        SessionResolution::Authenticated(ctx) => ctx,
        other => panic!("expected Authenticated, got {other:?}"),
    };
    assert_eq!(ctx.sv(), Some(7));
    assert_eq!(store.update_sv_call_count(), 1);
    // One seeding write from this test plus one write from the validator.
    assert_eq!(backend.store_call_count(), 2);
    // The test name promises the new sv is recorded, so check the cached value itself.
    assert_eq!(
        backend.load(&cache_key).await,
        Some(7),
        "cache must hold the refreshed sv"
    );
}
/// A ciphertext is stored, but the validator has no cipher to decrypt it with,
/// so the session resolves as expired and nothing is written to the store or
/// the sv cache.
#[tokio::test]
async fn sv_no_cipher_configured_with_ciphertext_yields_expired() {
    let cipher_for_ct = cipher();
    let store = standard_setup(&cipher_for_ct);
    let backend = FakeBackend::default();
    let pas = Arc::new(MemoryPasAuth::new());
    let resolver = build_resolver(
        Arc::clone(&store),
        backend.clone(),
        Arc::clone(&pas),
        None,
    );
    let res = resolver.validate(&jar_with_session()).await.unwrap();
    assert!(matches!(res, SessionResolution::Expired));
    // Same side-effect guarantees as the cipher-mismatch case.
    assert_eq!(backend.store_call_count(), 0);
    assert_eq!(store.update_sv_call_count(), 0);
}
#[tokio::test]
async fn sv_ciphertext_lookup_db_failure_yields_expired() {
    let shared_cipher = Arc::new(cipher());
    let store = standard_setup(&shared_cipher);
    // Reading the stored ciphertext will fail.
    store.fail_ciphertext_lookup();
    let backend = FakeBackend::default();
    let resolver = build_resolver(
        Arc::clone(&store),
        backend.clone(),
        Arc::new(MemoryPasAuth::new()),
        Some(shared_cipher),
    );

    let jar = jar_with_session();
    let outcome = resolver.validate(&jar).await.unwrap();

    assert!(matches!(outcome, SessionResolution::Expired));
    assert_eq!(store.update_sv_call_count(), 0);
    assert_eq!(backend.store_call_count(), 0);
}
#[tokio::test]
async fn sv_session_vanishes_between_update_and_refetch_yields_expired() {
    let shared_cipher = Arc::new(cipher());
    let store = standard_setup(&shared_cipher);
    // The first `find` sees the session; every later `find` reports it missing.
    store.drop_session_on_refetch();
    let refreshed = token_response_with_sv(Some(7));
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(PLAINTEXT_RT, Ok(refreshed)));
    let backend = FakeBackend::default();
    let resolver = build_resolver(Arc::clone(&store), backend.clone(), pas, Some(shared_cipher));

    let jar = jar_with_session();
    let outcome = resolver.validate(&jar).await.unwrap();

    assert!(matches!(outcome, SessionResolution::Expired));
    assert_eq!(
        store.update_sv_call_count(),
        1,
        "update_sv must have landed before the refetch raced",
    );
    assert_eq!(
        backend.store_call_count(),
        1,
        "policy.record must have landed before the refetch raced",
    );
}