#![allow(clippy::unwrap_used, clippy::expect_used, clippy::panic, clippy::unimplemented)]
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;

use async_trait::async_trait;
use axum_extra::extract::PrivateCookieJar;
use axum_extra::extract::cookie::{Cookie, Key};
use pas_external::middleware::{
    NewSession, RefreshTokenResolver, SessionResolution, SessionResolver, SessionStore, SvAware,
    SvAwareSessionResolver,
};
use pas_external::oauth::{TokenResponse, UserInfo};
use pas_external::pas_port::{MemoryPasAuth, PasFailure};
use pas_external::session_version::{SV_CACHE_TTL, SessionVersionCache};
use pas_external::types::{Ppnum, PpnumId, SessionId};
#[derive(Debug, thiserror::Error)]
#[error("fake error: {0}")]
struct FakeError(String);
/// Minimal auth context held by the fake session store.
#[derive(Debug, Clone)]
struct FakeAuthContext {
    /// Identifier of the session owner, kept as a raw string.
    ppnum_id: String,
    /// Session version recorded for this session, or `None` when absent.
    sv: Option<i64>,
}
impl SvAware for FakeAuthContext {
    /// Borrows the stored ppnum id.
    fn ppnum_id(&self) -> &str {
        self.ppnum_id.as_str()
    }

    /// Returns a copy of the stored session version, if any.
    fn sv(&self) -> Option<i64> {
        let Self { sv, .. } = self;
        *sv
    }
}
#[derive(Default)]
struct FakeStore {
inner: Mutex<HashMap<String, FakeAuthContext>>,
update_sv_should_fail: Mutex<bool>,
update_sv_calls: Mutex<u32>,
}
impl FakeStore {
fn put(&self, id: &str, ctx: FakeAuthContext) {
self.inner.lock().unwrap().insert(id.to_string(), ctx);
}
fn fail_update_sv(&self) {
*self.update_sv_should_fail.lock().unwrap() = true;
}
fn update_sv_call_count(&self) -> u32 {
*self.update_sv_calls.lock().unwrap()
}
}
impl SessionStore for FakeStore {
type Error = FakeError;
type AuthContext = FakeAuthContext;
async fn create(&self, _session: NewSession) -> Result<SessionId, FakeError> {
unimplemented!("not used by these tests")
}
async fn find(&self, id: &SessionId) -> Result<Option<FakeAuthContext>, FakeError> {
Ok(self.inner.lock().unwrap().get(&id.0).cloned())
}
async fn delete(&self, _id: &SessionId) -> Result<(), FakeError> {
Ok(())
}
async fn update_sv(&self, id: &SessionId, new_sv: i64) -> Result<(), FakeError> {
*self.update_sv_calls.lock().unwrap() += 1;
if *self.update_sv_should_fail.lock().unwrap() {
return Err(FakeError("update_sv DB failure".into()));
}
let mut inner = self.inner.lock().unwrap();
if let Some(ctx) = inner.get_mut(&id.0) {
ctx.sv = Some(new_sv);
}
Ok(())
}
}
/// Refresh-token resolver double that hands back one fixed plaintext token (or `None`).
struct FakeRefreshResolver {
    /// Value returned for every session id.
    plaintext: Option<String>,
}
impl RefreshTokenResolver for FakeRefreshResolver {
    type Error = FakeError;

    /// Ignores the session id and returns a copy of the configured token.
    async fn resolve_refresh_token(&self, _id: &SessionId) -> Result<Option<String>, FakeError> {
        let token = self.plaintext.clone();
        Ok(token)
    }
}
#[derive(Default)]
struct FakeCache {
inner: Mutex<HashMap<String, i64>>,
set_calls: Mutex<u32>,
}
impl FakeCache {
fn set_call_count(&self) -> u32 {
*self.set_calls.lock().unwrap()
}
}
#[async_trait]
impl SessionVersionCache for FakeCache {
async fn get(&self, key: &str) -> Option<i64> {
self.inner.lock().unwrap().get(key).copied()
}
async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
*self.set_calls.lock().unwrap() += 1;
self.inner.lock().unwrap().insert(key.to_string(), sv);
}
}
/// Name of the session cookie the resolver under test reads.
const COOKIE_NAME: &str = "test_session";
/// Session id placed in the cookie and seeded into the fake store.
const SESSION_ID: &str = "01HXYZTESTSESSION0000000000";
/// Owner id of the seeded session; must parse as a `PpnumId`.
const PPNUM_ID: &str = "01HXYZPPPN00000000000000PP";
/// Plaintext refresh token returned by the fake refresh resolver and expected by PAS.
const PLAINTEXT_RT: &str = "rt_plain_xyz";
/// Returns a `PrivateCookieJar` that uses a freshly generated `Key` and holds one cookie,
/// `COOKIE_NAME`, whose value is `SESSION_ID`.
fn jar_with_session() -> PrivateCookieJar {
    PrivateCookieJar::new(Key::generate()).add(Cookie::new(COOKIE_NAME, SESSION_ID))
}
/// Builds a `UserInfo` for `PPNUM_ID`. The session version is attached only when `sv`
/// is `Some`.
fn fake_user_info(sv: Option<i64>) -> UserInfo {
    let ppnum_id = PPNUM_ID.parse::<PpnumId>().unwrap();
    let ppnum = "10012345678".parse::<Ppnum>().unwrap();
    let base = UserInfo::new(ppnum_id, ppnum);
    match sv {
        Some(version) => base.with_session_version(version),
        None => base,
    }
}
/// Deserializes a Bearer `TokenResponse` with access token `at`, `expires_in` of 3600
/// and a null refresh token.
fn token_response_for(at: &str) -> TokenResponse {
    serde_json::from_value(serde_json::json!({
        "access_token": at,
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": null,
    }))
    .unwrap()
}
/// Wires the fakes into an `SvAwareSessionResolver` that reads the `COOKIE_NAME` cookie.
///
/// The same `store` backs both the inner `SessionResolver` and the SV-aware wrapper.
fn build_resolver(
    store: Arc<FakeStore>,
    cache: Arc<FakeCache>,
    pas: Arc<MemoryPasAuth>,
    refresh_resolver: Arc<FakeRefreshResolver>,
) -> SvAwareSessionResolver<FakeStore, FakeRefreshResolver, FakeCache, MemoryPasAuth> {
    let base = SessionResolver::new(Arc::clone(&store), Arc::<str>::from(COOKIE_NAME));
    SvAwareSessionResolver::new(base, store, refresh_resolver, pas, cache)
}
/// A 4xx rejection of the refresh grant resolves to `Expired`, and neither the SV cache
/// nor `update_sv` is touched.
#[tokio::test]
async fn sv_pas_4xx_on_refresh_yields_expired() {
    let store = Arc::new(FakeStore::default());
    let seeded = FakeAuthContext {
        ppnum_id: PPNUM_ID.into(),
        sv: Some(1),
    };
    store.put(SESSION_ID, seeded);
    let cache = Arc::new(FakeCache::default());

    let refresh_failure = PasFailure::Rejected {
        status: 400,
        detail: "invalid_grant".into(),
    };
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(PLAINTEXT_RT, Err(refresh_failure)));
    let refresh_resolver = Arc::new(FakeRefreshResolver {
        plaintext: Some(PLAINTEXT_RT.into()),
    });

    let resolver = build_resolver(
        Arc::clone(&store),
        Arc::clone(&cache),
        Arc::clone(&pas),
        refresh_resolver,
    );
    let jar = jar_with_session();
    let res = resolver.resolve(&jar).await.unwrap();

    assert!(matches!(res, SessionResolution::Expired));
    assert_eq!(cache.set_call_count(), 0, "cache must not be touched");
    assert_eq!(
        store.update_sv_call_count(),
        0,
        "update_sv must not be called"
    );
}
/// A PAS 5xx on the refresh grant fails closed: the session resolves to `Expired`.
/// No fresh session version was obtained, so neither the SV cache nor the store may be
/// written.
#[tokio::test]
async fn sv_pas_5xx_on_refresh_yields_expired_fail_closed() {
    let store = Arc::new(FakeStore::default());
    store.put(
        SESSION_ID,
        FakeAuthContext {
            ppnum_id: PPNUM_ID.into(),
            sv: Some(1),
        },
    );
    let cache = Arc::new(FakeCache::default());
    let pas = Arc::new(MemoryPasAuth::new().expect_refresh(
        PLAINTEXT_RT,
        Err(PasFailure::ServerError {
            status: 503,
            detail: "upstream".into(),
        }),
    ));
    let resolver = build_resolver(
        Arc::clone(&store),
        Arc::clone(&cache),
        Arc::clone(&pas),
        Arc::new(FakeRefreshResolver {
            plaintext: Some(PLAINTEXT_RT.into()),
        }),
    );
    let res = resolver.resolve(&jar_with_session()).await.unwrap();
    assert!(matches!(res, SessionResolution::Expired));
    // Failing closed must not leave partial writes behind on a transient upstream error.
    assert_eq!(cache.set_call_count(), 0, "cache must not be touched");
    assert_eq!(
        store.update_sv_call_count(),
        0,
        "update_sv must not be called"
    );
}
/// A successful refresh followed by userinfo with no session version resolves to
/// `Expired`, with no cache write and no `update_sv` call.
#[tokio::test]
async fn sv_userinfo_session_version_none_yields_expired() {
    let store = Arc::new(FakeStore::default());
    let seeded = FakeAuthContext {
        ppnum_id: PPNUM_ID.into(),
        sv: Some(1),
    };
    store.put(SESSION_ID, seeded);
    let cache = Arc::new(FakeCache::default());

    let refreshed = token_response_for("at_new");
    let userinfo_without_sv = fake_user_info(None);
    let pas = Arc::new(
        MemoryPasAuth::new()
            .expect_refresh(PLAINTEXT_RT, Ok(refreshed))
            .expect_userinfo("at_new", Ok(userinfo_without_sv)),
    );
    let refresh_resolver = Arc::new(FakeRefreshResolver {
        plaintext: Some(PLAINTEXT_RT.into()),
    });

    let resolver = build_resolver(
        Arc::clone(&store),
        Arc::clone(&cache),
        Arc::clone(&pas),
        refresh_resolver,
    );
    let jar = jar_with_session();
    let res = resolver.resolve(&jar).await.unwrap();

    assert!(matches!(res, SessionResolution::Expired));
    assert_eq!(cache.set_call_count(), 0);
    assert_eq!(store.update_sv_call_count(), 0);
}
/// A successful refresh followed by a 5xx from userinfo resolves to `Expired`. No new
/// session version is known, so neither the SV cache nor the store may be written.
#[tokio::test]
async fn sv_userinfo_5xx_yields_expired() {
    let store = Arc::new(FakeStore::default());
    store.put(
        SESSION_ID,
        FakeAuthContext {
            ppnum_id: PPNUM_ID.into(),
            sv: Some(1),
        },
    );
    let cache = Arc::new(FakeCache::default());
    let pas = Arc::new(
        MemoryPasAuth::new()
            .expect_refresh(PLAINTEXT_RT, Ok(token_response_for("at_new")))
            .expect_userinfo(
                "at_new",
                Err(PasFailure::ServerError {
                    status: 502,
                    detail: "bad gateway".into(),
                }),
            ),
    );
    let resolver = build_resolver(
        Arc::clone(&store),
        Arc::clone(&cache),
        Arc::clone(&pas),
        Arc::new(FakeRefreshResolver {
            plaintext: Some(PLAINTEXT_RT.into()),
        }),
    );
    let res = resolver.resolve(&jar_with_session()).await.unwrap();
    assert!(matches!(res, SessionResolution::Expired));
    // Without a session version from userinfo there is nothing valid to persist or cache.
    assert_eq!(cache.set_call_count(), 0, "cache must not be touched");
    assert_eq!(
        store.update_sv_call_count(),
        0,
        "update_sv must not be called"
    );
}
/// When persisting the new session version fails, the resolver returns `Expired`. It
/// makes exactly one `update_sv` attempt and never writes the SV cache.
#[tokio::test]
async fn sv_update_sv_db_failure_yields_expired_does_not_update_cache() {
    let store = Arc::new(FakeStore::default());
    let seeded = FakeAuthContext {
        ppnum_id: PPNUM_ID.into(),
        sv: Some(1),
    };
    store.put(SESSION_ID, seeded);
    store.fail_update_sv();
    let cache = Arc::new(FakeCache::default());

    let refreshed = token_response_for("at_new");
    let userinfo = fake_user_info(Some(2));
    let pas = Arc::new(
        MemoryPasAuth::new()
            .expect_refresh(PLAINTEXT_RT, Ok(refreshed))
            .expect_userinfo("at_new", Ok(userinfo)),
    );
    let refresh_resolver = Arc::new(FakeRefreshResolver {
        plaintext: Some(PLAINTEXT_RT.into()),
    });

    let resolver = build_resolver(
        Arc::clone(&store),
        Arc::clone(&cache),
        Arc::clone(&pas),
        refresh_resolver,
    );
    let jar = jar_with_session();
    let res = resolver.resolve(&jar).await.unwrap();

    assert!(matches!(res, SessionResolution::Expired));
    assert_eq!(store.update_sv_call_count(), 1);
    assert_eq!(
        cache.set_call_count(),
        0,
        "cache must not be touched after store failure"
    );
}
/// Happy path: the refresh succeeds and userinfo reports `sv = 2`. The resolver returns
/// `Authenticated` with the new version, persists it to the store once, and writes it
/// to the SV cache once.
#[tokio::test]
async fn sv_happy_path_returns_authenticated_with_updated_sv() {
    let store = Arc::new(FakeStore::default());
    store.put(
        SESSION_ID,
        FakeAuthContext {
            ppnum_id: PPNUM_ID.into(),
            sv: Some(1),
        },
    );
    let cache = Arc::new(FakeCache::default());
    let pas = Arc::new(
        MemoryPasAuth::new()
            .expect_refresh(PLAINTEXT_RT, Ok(token_response_for("at_new")))
            .expect_userinfo("at_new", Ok(fake_user_info(Some(2)))),
    );
    let resolver = build_resolver(
        Arc::clone(&store),
        Arc::clone(&cache),
        Arc::clone(&pas),
        Arc::new(FakeRefreshResolver {
            plaintext: Some(PLAINTEXT_RT.into()),
        }),
    );
    let res = resolver.resolve(&jar_with_session()).await.unwrap();
    let ctx = match res {
        SessionResolution::Authenticated(ctx) => ctx,
        other => panic!("expected Authenticated, got {other:?}"),
    };
    assert_eq!(ctx.sv(), Some(2));
    assert_eq!(store.update_sv_call_count(), 1);
    assert_eq!(cache.set_call_count(), 1);

    // Call counts alone do not prove the right value was written; check the contents too.
    let persisted_sv = store
        .inner
        .lock()
        .unwrap()
        .get(SESSION_ID)
        .and_then(|stored| stored.sv);
    assert_eq!(persisted_sv, Some(2), "new sv must be persisted to the store");
    let cached: Vec<i64> = cache.inner.lock().unwrap().values().copied().collect();
    assert_eq!(cached, vec![2], "new sv must be written to the cache");

    // FakeCache ignores the TTL it receives. This binding only keeps the `SV_CACHE_TTL`
    // import referenced.
    let _ttl = SV_CACHE_TTL;
}