use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use super::{FetchError, Fetcher};
/// [`Fetcher`] that obtains a subject's `session_version` by calling a
/// `/userinfo` endpoint with a bearer token.
///
/// Construct with [`UserinfoFetcher::new`], then supply a token via
/// [`UserinfoFetcher::with_access_token`] or [`UserinfoFetcher::with_token_source`].
pub struct UserinfoFetcher {
    /// Endpoint root with trailing `/` removed; `/userinfo` is appended per request.
    base_url: String,
    /// HTTP client used to issue the request (replaceable via `with_http_client`).
    http: Client,
    /// Where the bearer token comes from; `Unset` makes every fetch fail.
    token_source: TokenSource,
}
/// Origin of the bearer token attached to each userinfo request.
enum TokenSource {
    /// A fixed token captured when the fetcher was configured.
    Static(String),
    /// A callback evaluated on every fetch, so the caller can hand out a fresh
    /// token per request.
    Dynamic(std::sync::Arc<dyn Fn() -> String + Send + Sync + 'static>),
    /// No token configured; fetching reports an error without sending anything.
    Unset,
}
impl UserinfoFetcher {
    /// Creates a fetcher that targets `{base_url}/userinfo`.
    ///
    /// All trailing `/` characters are stripped from `base_url`, so
    /// `"http://idp/"`, `"http://idp///"` and `"http://idp"` all resolve to
    /// `http://idp/userinfo`. No token is configured yet; call
    /// [`Self::with_access_token`] or [`Self::with_token_source`] before fetching.
    ///
    /// # Panics
    /// Inherits the panic conditions of `reqwest::Client::new()`, which is
    /// used to build the default HTTP client.
    #[must_use]
    pub fn new(base_url: impl Into<String>) -> Self {
        let mut base_url = base_url.into();
        let trimmed_len = base_url.trim_end_matches('/').len();
        base_url.truncate(trimmed_len);
        Self {
            base_url,
            http: Client::new(),
            token_source: TokenSource::Unset,
        }
    }

    /// Sends the fixed `token` as the bearer credential on every request.
    ///
    /// Replaces any previously configured token source. An empty token is
    /// treated as "not configured".
    #[must_use]
    pub fn with_access_token(mut self, token: impl Into<String>) -> Self {
        self.token_source = TokenSource::Static(token.into());
        self
    }

    /// Calls `source` once per fetch to obtain the bearer token.
    ///
    /// Replaces any previously configured token source. If `source` returns an
    /// empty string, that fetch fails as "no access_token configured".
    #[must_use]
    pub fn with_token_source(
        mut self,
        source: std::sync::Arc<dyn Fn() -> String + Send + Sync>,
    ) -> Self {
        self.token_source = TokenSource::Dynamic(source);
        self
    }

    /// Replaces the default `Client::new()` HTTP client, for example with one
    /// that has custom timeouts or TLS settings.
    #[must_use]
    pub fn with_http_client(mut self, http: Client) -> Self {
        self.http = http;
        self
    }

    /// Returns the bearer token for the next request, or `None` when no usable
    /// token is configured.
    ///
    /// Empty tokens, whether static or produced by the dynamic source, map to
    /// `None`. Sending `Authorization: Bearer ` would only yield an opaque
    /// authentication failure from the server instead of a clear
    /// configuration error.
    fn current_token(&self) -> Option<String> {
        let token = match &self.token_source {
            TokenSource::Static(s) => s.clone(),
            TokenSource::Dynamic(f) => f(),
            TokenSource::Unset => return None,
        };
        if token.is_empty() {
            None
        } else {
            Some(token)
        }
    }
}
/// Redacting `Debug`: only `base_url` is shown. The HTTP client and the token
/// source are elided (rendered as `..`), so the bearer token never appears in
/// `{:?}` output.
impl std::fmt::Debug for UserinfoFetcher {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = out.debug_struct("UserinfoFetcher");
        builder.field("base_url", &self.base_url);
        // `http` and `token_source` are intentionally omitted.
        builder.finish_non_exhaustive()
    }
}
/// The slice of the userinfo JSON body this fetcher cares about.
///
/// Any other members of the response (such as `sub`) are ignored during
/// deserialization.
#[derive(Deserialize)]
struct UserinfoSv {
    /// Subject's session version; `None` when the member is absent or `null`.
    session_version: Option<i64>,
}
#[async_trait]
impl Fetcher for UserinfoFetcher {
    /// Reads the current `session_version` from `{base_url}/userinfo`.
    ///
    /// NOTE(review): `_sub` is neither sent nor compared against the response,
    /// so the value returned belongs to whichever subject the bearer token
    /// identifies. Confirm callers rely on that.
    ///
    /// # Errors
    /// Fails closed with [`FetchError`] when:
    /// - no token is configured;
    /// - the request cannot be sent;
    /// - the status is not 2xx;
    /// - the body cannot be decoded;
    /// - `session_version` is missing.
    async fn fetch(&self, _sub: &str) -> Result<i64, FetchError> {
        let token = match self.current_token() {
            Some(token) => token,
            None => {
                return Err(FetchError(
                    "UserinfoFetcher: no access_token configured".into(),
                ));
            }
        };

        let endpoint = format!("{}/userinfo", self.base_url);
        let request = self.http.get(&endpoint).bearer_auth(&token);
        let response = match request.send().await {
            Ok(response) => response,
            Err(e) => return Err(FetchError(format!("transport: {e}"))),
        };

        // Any non-2xx status is treated as a failure; the body is not inspected.
        let status = response.status();
        if !status.is_success() {
            return Err(FetchError(format!("HTTP {status}")));
        }

        let payload = match response.json::<UserinfoSv>().await {
            Ok(payload) => payload,
            Err(e) => return Err(FetchError(format!("decode: {e}"))),
        };

        match payload.session_version {
            Some(version) => Ok(version),
            None => Err(FetchError(
                "userinfo did not return session_version (missing scope or non-Human subject)"
                    .into(),
            )),
        }
    }
}
#[cfg(test)]
// `unwrap`/`unwrap_err` is the idiomatic way to fail a test on an unexpected outcome.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn happy_path_reads_session_version() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .and(header("authorization", "Bearer svc-token"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"sub": "abc", "session_version": 42})),
            )
            .mount(&server)
            .await;
        let fetcher = UserinfoFetcher::new(server.uri()).with_access_token("svc-token");
        let sv = fetcher.fetch("abc").await.unwrap();
        assert_eq!(sv, 42);
    }

    /// The callback configured via `with_token_source` supplies the bearer token.
    #[tokio::test]
    async fn dynamic_token_source_is_used() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .and(header("authorization", "Bearer dyn-token"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"sub": "abc", "session_version": 7})),
            )
            .mount(&server)
            .await;
        let fetcher = UserinfoFetcher::new(server.uri())
            .with_token_source(std::sync::Arc::new(|| "dyn-token".to_string()));
        let sv = fetcher.fetch("abc").await.unwrap();
        assert_eq!(sv, 7);
    }

    #[tokio::test]
    async fn missing_session_version_fails_closed() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(serde_json::json!({"sub": "abc"})),
            )
            .mount(&server)
            .await;
        let fetcher = UserinfoFetcher::new(server.uri()).with_access_token("svc-token");
        let err = fetcher.fetch("abc").await.unwrap_err();
        assert!(err.0.contains("did not return session_version"), "{}", err.0);
    }

    /// A 2xx response whose body is not valid JSON is reported as a decode error.
    #[tokio::test]
    async fn undecodable_body_fails_closed() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_string("not json"))
            .mount(&server)
            .await;
        let fetcher = UserinfoFetcher::new(server.uri()).with_access_token("svc-token");
        let err = fetcher.fetch("abc").await.unwrap_err();
        assert!(err.0.starts_with("decode:"), "{}", err.0);
    }

    #[tokio::test]
    async fn http_error_fails_closed() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(ResponseTemplate::new(503))
            .mount(&server)
            .await;
        let fetcher = UserinfoFetcher::new(server.uri()).with_access_token("svc-token");
        let err = fetcher.fetch("abc").await.unwrap_err();
        assert!(err.0.contains("HTTP 503"), "{}", err.0);
    }

    #[tokio::test]
    async fn no_token_fails_closed() {
        let fetcher = UserinfoFetcher::new("http://unused");
        let err = fetcher.fetch("abc").await.unwrap_err();
        assert!(err.0.contains("no access_token configured"), "{}", err.0);
    }

    #[tokio::test]
    async fn trailing_slash_on_base_url_normalized() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"sub": "abc", "session_version": 1})),
            )
            .mount(&server)
            .await;
        let url_with_slash = format!("{}/", server.uri());
        let fetcher = UserinfoFetcher::new(url_with_slash).with_access_token("svc-token");
        let sv = fetcher.fetch("abc").await.unwrap();
        assert_eq!(sv, 1);
    }

    /// Every trailing slash is stripped, not just one.
    #[tokio::test]
    async fn multiple_trailing_slashes_normalized() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"sub": "abc", "session_version": 3})),
            )
            .mount(&server)
            .await;
        let url_with_slashes = format!("{}///", server.uri());
        let fetcher = UserinfoFetcher::new(url_with_slashes).with_access_token("svc-token");
        let sv = fetcher.fetch("abc").await.unwrap();
        assert_eq!(sv, 3);
    }
}