use crate::config::HydraConfig;
use crate::error::{ServiceError, ServiceResult};
use super::{AuthContext, extract_subject};
use axum::body::Body;
use axum::response::{IntoResponse, Response};
use connectrpc::{ConnectError, ErrorCode};
use serde::Deserialize;
use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::Arc,
task::{Context as TaskContext, Poll},
};
use tower::{Layer, Service};
/// Claims parsed from the JSON body returned by the token introspection endpoint.
///
/// Only `active` is required. Every other declared claim is optional, and any
/// key that is not declared here is collected into `extra`.
#[derive(Debug, Clone, Deserialize)]
pub struct IntrospectionResponse {
/// Whether the token is active; the middleware rejects the request when this is `false`.
pub active: bool,
/// Subject claim, used as the authenticated subject when present.
#[serde(default)]
pub sub: Option<String>,
/// Expiry claim, forwarded unchanged to the auth context.
/// NOTE(review): presumably seconds since the Unix epoch (RFC 7662) — confirm.
#[serde(default)]
pub exp: Option<i64>,
/// Email claim; also used as the subject when `sub` is absent.
#[serde(default)]
pub email: Option<String>,
/// Display name claim.
#[serde(default)]
pub name: Option<String>,
/// Roles claim, expected as a JSON array of strings.
#[serde(default)]
pub roles: Option<Vec<String>>,
/// Scope claim, kept as the raw string returned by the endpoint.
#[serde(default)]
pub scope: Option<String>,
/// Client identifier claim.
#[serde(default)]
pub client_id: Option<String>,
/// Every remaining key of the response, consulted by the accessor methods as a fallback.
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
impl IntrospectionResponse {
pub fn email(&self) -> Option<String> {
self.email
.clone()
.or_else(|| extra_string(&self.extra, "email"))
}
pub fn display_name(&self) -> Option<String> {
self.name
.clone()
.or_else(|| extra_string(&self.extra, "name"))
.or_else(|| extra_string(&self.extra, "preferred_username"))
}
pub fn resolved_roles(&self) -> Vec<String> {
if let Some(ref roles) = self.roles {
return roles.clone();
}
if let Some(roles) = extra_roles(&self.extra, "roles") {
return roles;
}
extra_roles(&self.extra, "role").unwrap_or_default()
}
pub fn resolved_scope(&self) -> Option<String> {
self.scope
.clone()
.or_else(|| extra_string(&self.extra, "scope"))
}
pub fn resolved_client_id(&self) -> Option<String> {
self.client_id
.clone()
.or_else(|| extra_string(&self.extra, "client_id"))
}
}
/// Looks up `key` among the undeclared claims and renders its value as a string.
///
/// String values are returned verbatim. Any other non-null JSON value is
/// rendered as its JSON text (for example the number `42` becomes `"42"`).
///
/// Returns `None` when the key is absent or its value is JSON `null`, so that
/// a null claim is never surfaced as the literal text `"null"`.
fn extra_string(extra: &HashMap<String, serde_json::Value>, key: &str) -> Option<String> {
    match extra.get(key)? {
        serde_json::Value::Null => None,
        serde_json::Value::String(text) => Some(text.clone()),
        other => Some(other.to_string()),
    }
}
/// Looks up `key` among the undeclared claims and interprets it as a role list.
///
/// * A JSON array yields one role per element: strings verbatim, other
///   non-null values as their JSON text. `null` elements are skipped so they
///   never turn into a role literally named `"null"`.
/// * A JSON string yields a single-element list.
/// * Any other JSON value yields an empty list.
///
/// Returns `None` only when the key is absent.
fn extra_roles(extra: &HashMap<String, serde_json::Value>, key: &str) -> Option<Vec<String>> {
    let roles = match extra.get(key)? {
        serde_json::Value::Array(items) => items
            .iter()
            .filter_map(|item| match item {
                serde_json::Value::Null => None,
                serde_json::Value::String(role) => Some(role.clone()),
                other => Some(other.to_string()),
            })
            .collect(),
        serde_json::Value::String(role) => vec![role.clone()],
        _ => Vec::new(),
    };
    Some(roles)
}
fn unauthorized(message: &str) -> Response {
ConnectError::new(ErrorCode::Unauthenticated, message).into_response()
}
/// HTTP client for the token introspection endpoint described by a `HydraConfig`.
#[derive(Debug, Clone)]
pub struct IntrospectionClient {
/// Supplies the introspection URL and the client credentials sent with each request.
config: HydraConfig,
/// Shared `reqwest` client; clones of this struct reuse the same instance.
http: Arc<reqwest::Client>,
}
impl IntrospectionClient {
    /// Upper bound on one introspection round trip (connect, send, read body).
    ///
    /// Without it a stalled introspection endpoint would keep every request
    /// that passes through the middleware pending indefinitely.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    /// Creates a client that talks to the endpoint described by `config`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be constructed, which is
    /// the same condition under which `reqwest::Client::new()` panics.
    pub fn new(config: HydraConfig) -> Self {
        let http = reqwest::Client::builder()
            .timeout(Self::REQUEST_TIMEOUT)
            .build()
            .expect("failed to initialise HTTP client for token introspection");
        Self {
            config,
            http: Arc::new(http),
        }
    }

    /// Alias of [`IntrospectionClient::new`].
    pub fn from_config(config: HydraConfig) -> Self {
        Self::new(config)
    }

    /// Sends `token` to the introspection endpoint and parses the reply.
    ///
    /// The request is a form-encoded POST (`token=<token>`) authenticated
    /// with HTTP basic auth using the configured client id and secret.
    ///
    /// # Errors
    ///
    /// Returns `ServiceError::Unauthenticated` when the request cannot be
    /// completed (including a timeout), when the endpoint answers with a
    /// non-success status, or when the body is not a valid
    /// `IntrospectionResponse`.
    pub async fn introspect(&self, token: &str) -> ServiceResult<IntrospectionResponse> {
        let response = self
            .http
            .post(&self.config.introspection_url)
            .basic_auth(&self.config.client_id, Some(&self.config.client_secret))
            .form(&[("token", token)])
            .send()
            .await
            .map_err(|e| ServiceError::Unauthenticated(format!("token introspection failed: {e}")))?;

        let status = response.status();
        if !status.is_success() {
            return Err(ServiceError::Unauthenticated(format!(
                "token introspection returned {status}"
            )));
        }

        response
            .json::<IntrospectionResponse>()
            .await
            .map_err(|e| {
                ServiceError::Unauthenticated(format!(
                    "failed to parse introspection response: {e}"
                ))
            })
    }
}
/// Tower `Layer` that wraps a service in `IntrospectionService`.
#[derive(Debug, Clone)]
pub struct IntrospectionLayer {
/// Client handed (as an `Arc` clone) to every service this layer produces.
client: Arc<IntrospectionClient>,
}
impl IntrospectionLayer {
    /// Creates a layer that introspects tokens with the given `client`.
    pub fn new(client: IntrospectionClient) -> Self {
        let client = Arc::new(client);
        Self { client }
    }

    /// Creates a layer backed by a fresh `IntrospectionClient` built from `config`.
    pub fn from_config(config: HydraConfig) -> Self {
        let client = IntrospectionClient::new(config);
        Self::new(client)
    }
}
impl<S> Layer<S> for IntrospectionLayer {
type Service = IntrospectionService<S>;
fn layer(&self, inner: S) -> Self::Service {
IntrospectionService {
inner,
client: Arc::clone(&self.client),
}
}
}
/// Middleware service that introspects the request's token before calling `inner`.
#[derive(Debug, Clone)]
pub struct IntrospectionService<S> {
/// Wrapped service that receives the request once the token has been accepted.
inner: S,
/// Client used to introspect the token taken from the request headers.
client: Arc<IntrospectionClient>,
}
impl<S> Service<http::Request<Body>> for IntrospectionService<S>
where
    S: Service<http::Request<Body>, Response = Response> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send + 'static,
{
    type Response = Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Authenticates the request and forwards it to the wrapped service.
    ///
    /// The request is answered with a Connect `Unauthenticated` error, without
    /// reaching the wrapped service, when no token is found in the headers,
    /// when introspection fails, or when the token is reported inactive.
    /// Otherwise an `AuthContext` built from the introspection claims is
    /// stored in the request extensions before the wrapped service is called.
    fn call(&mut self, req: http::Request<Body>) -> Self::Future {
        let token = match extract_subject(req.headers()) {
            Some(token) => token,
            None => {
                let resp = unauthorized("missing authorization token");
                return Box::pin(async move { Ok(resp) });
            }
        };

        let client = Arc::clone(&self.client);

        // `poll_ready` was driven on `self.inner`, so that exact instance is the
        // one allowed to handle this request. A fresh clone is not guaranteed
        // to be ready; leave the clone behind for the next `poll_ready` cycle
        // and move the ready instance into the future.
        let not_ready = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, not_ready);

        Box::pin(async move {
            let introspection = match client.introspect(&token).await {
                Ok(resp) => resp,
                Err(_) => {
                    return Ok(unauthorized("token introspection failed"));
                }
            };
            if !introspection.active {
                return Ok(unauthorized("token is inactive or expired"));
            }

            // NOTE(review): an active token with neither `sub` nor `email` is
            // accepted with an empty subject; confirm this is intended.
            let subject = introspection
                .sub
                .clone()
                .or_else(|| introspection.email.clone())
                .unwrap_or_default();

            let mut ctx = AuthContext::authenticated(subject, None)
                .with_roles(introspection.resolved_roles());
            if let Some(email) = introspection.email() {
                ctx = ctx.with_email(email);
            }
            if let Some(name) = introspection.display_name() {
                ctx = ctx.with_name(name);
            }
            if let Some(exp) = introspection.exp {
                ctx = ctx.with_exp(exp);
            }
            if let Some(scope) = introspection.resolved_scope() {
                ctx = ctx.with_scope(scope);
            }
            if let Some(client_id) = introspection.resolved_client_id() {
                ctx = ctx.with_client_id(client_id);
            }

            let mut req = req;
            req.extensions_mut().insert(ctx);
            inner.call(req).await
        })
    }
}
// Tests for the introspection response accessors, the HTTP client and the tower layer.
// Layer and client tests talk to a mock introspection endpoint served on a local port.
#[cfg(test)]
mod tests {
use super::*;
use axum::{
body::Body,
extract::State,
response::{IntoResponse, Response},
routing::post,
Json, Router,
};
use serde_json::json;
use std::future::Future;
use tower::{ServiceBuilder, ServiceExt};
// Inner service that ignores the request and always answers `200 OK` with an empty body.
fn ok_service() -> impl Service<
http::Request<Body>,
Response = Response,
Error = std::convert::Infallible,
Future = impl Future<Output = Result<Response, std::convert::Infallible>>,
> + Clone {
tower::service_fn(|_req: http::Request<Body>| async {
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(http::StatusCode::OK)
.body(Body::empty())
.unwrap()
.into_response(),
)
})
}
// Starts a mock introspection endpoint that answers every POST with `response` and `200 OK`.
// Returns the endpoint URL and the handle of the spawned server task.
async fn mock_hydra_server(response: serde_json::Value) -> (String, tokio::task::JoinHandle<()>) {
mock_hydra_server_with_status(response, http::StatusCode::OK).await
}
// Same as `mock_hydra_server`, but the endpoint answers with the given `status`.
async fn mock_hydra_server_with_status(
response: serde_json::Value,
status: http::StatusCode,
) -> (String, tokio::task::JoinHandle<()>) {
// Replies with the status and JSON body stored in the router state.
async fn handler(
State((body, status)): State<(serde_json::Value, http::StatusCode)>,
) -> (http::StatusCode, Json<serde_json::Value>) {
(status, Json(body))
}
let app = Router::new()
.route("/oauth2/introspect", post(handler))
.with_state((response, status));
// Port 0 lets the OS pick a free port; the real address is read back below.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
(format!("http://{}/oauth2/introspect", addr), handle)
}
// Starts a mock introspection endpoint whose body is plain text rather than JSON.
async fn mock_hydra_malformed_server() -> (String, tokio::task::JoinHandle<()>) {
let app = Router::new().route(
"/oauth2/introspect",
post(|| async { "this is not json" }),
);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let handle = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
(format!("http://{}/oauth2/introspect", addr), handle)
}
// The declared `roles` array is returned as-is.
#[test]
fn test_response_resolved_roles_from_explicit_field() {
let resp: IntrospectionResponse = serde_json::from_value(json!({
"active": true,
"sub": "user-1",
"roles": ["admin", "user"],
}))
.unwrap();
assert_eq!(resp.resolved_roles(), vec!["admin", "user"]);
}
// A singular `role` string is not a declared field and is picked up through `extra`.
#[test]
fn test_response_resolved_roles_from_extra_string() {
let resp: IntrospectionResponse = serde_json::from_value(json!({
"active": true,
"sub": "user-1",
"role": "admin",
}))
.unwrap();
assert_eq!(resp.resolved_roles(), vec!["admin"]);
}
// NOTE(review): `email` is a declared field of `IntrospectionResponse`, so this input
// presumably lands in `self.email` rather than in `extra`; the `extra` fallback may not
// be exercised here despite the test name — confirm.
#[test]
fn test_response_email_from_extra() {
let resp: IntrospectionResponse = serde_json::from_value(json!({
"active": true,
"sub": "user-1",
"email": "alice@example.com",
}))
.unwrap();
assert_eq!(resp.email(), Some("alice@example.com".to_string()));
}
// A request without an `Authorization` header is rejected before any introspection call.
#[tokio::test]
async fn test_introspection_layer_missing_token_returns_401() {
let client = IntrospectionClient::new(HydraConfig::default());
let layer = IntrospectionLayer::new(client);
let mut svc = ServiceBuilder::new().layer(layer).service(ok_service());
let req = http::Request::builder()
.uri("/")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
}
// An `active: false` introspection result is rejected with 401.
#[tokio::test]
async fn test_introspection_layer_inactive_token_returns_401() {
let (url, _handle) = mock_hydra_server(json!({ "active": false })).await;
let client = IntrospectionClient::new(HydraConfig {
introspection_url: url,
client_id: "client".to_string(),
client_secret: "secret".to_string(),
});
let layer = IntrospectionLayer::new(client);
let mut svc = ServiceBuilder::new().layer(layer).service(ok_service());
let req = http::Request::builder()
.uri("/")
.header("Authorization", "Bearer invalid-token")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
}
// An active token with every claim set yields a fully populated `AuthContext`
// in the request extensions seen by the inner service.
#[tokio::test]
async fn test_introspection_layer_active_token_injects_context() {
let (url, _handle) = mock_hydra_server(json!({
"active": true,
"sub": "user-123",
"email": "alice@example.com",
"name": "Alice",
"roles": ["admin"],
"scope": "read write",
"client_id": "my-client",
"exp": 1893456000,
}))
.await;
let client = IntrospectionClient::new(HydraConfig {
introspection_url: url,
client_id: "client".to_string(),
client_secret: "secret".to_string(),
});
let layer = IntrospectionLayer::new(client);
// The assertions run inside the inner service, where the injected context is visible.
let mut svc = ServiceBuilder::new().layer(layer).service(
tower::service_fn(|req: http::Request<Body>| async move {
let ctx = req.extensions().get::<AuthContext>().cloned().unwrap();
assert!(ctx.is_authenticated());
assert_eq!(ctx.subject(), Some(&"user-123".to_string()));
assert_eq!(ctx.email(), Some(&"alice@example.com".to_string()));
assert_eq!(ctx.name(), Some(&"Alice".to_string()));
assert_eq!(ctx.roles(), &["admin"]);
assert_eq!(ctx.scope(), Some(&"read write".to_string()));
assert_eq!(ctx.client_id(), Some(&"my-client".to_string()));
assert_eq!(ctx.exp, Some(1893456000));
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(http::StatusCode::OK)
.body(Body::empty())
.unwrap()
.into_response(),
)
}),
);
let req = http::Request::builder()
.uri("/")
.header("Authorization", "Bearer valid-token")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::OK);
}
// With the declared `roles` field unset, a `roles` array placed in `extra` is used.
// The struct is built by hand so that the key really is in `extra`.
#[test]
fn test_response_resolved_roles_from_extra_array() {
let mut extra = std::collections::HashMap::new();
extra.insert("roles".to_string(), json![["admin", "user"]]);
let resp = IntrospectionResponse {
active: true,
sub: Some("user-1".to_string()),
exp: None,
email: None,
name: None,
roles: None,
scope: None,
client_id: None,
extra,
};
assert_eq!(resp.resolved_roles(), vec!["admin", "user"]);
}
// Non-string array elements are rendered as their JSON text; a scalar that is
// neither an array nor a string yields an empty role list.
#[test]
fn test_extra_roles_array_and_non_string() {
let mut extra = std::collections::HashMap::new();
extra.insert("roles".to_string(), json!(["admin", 42, true]));
assert_eq!(
extra_roles(&extra, "roles"),
Some(vec!["admin".to_string(), "42".to_string(), "true".to_string()])
);
extra.insert("count".to_string(), json!(7));
assert_eq!(extra_roles(&extra, "count"), Some(Vec::<String>::new()));
}
// Non-string values are rendered as their JSON text; a missing key yields `None`.
#[test]
fn test_extra_string_non_string_value() {
let mut extra = std::collections::HashMap::new();
extra.insert("num".to_string(), json!(42));
assert_eq!(extra_string(&extra, "num"), Some("42".to_string()));
assert_eq!(extra_string(&extra, "missing"), None);
}
// `preferred_username` is used as the display name when no `name` is present.
#[test]
fn test_response_display_name_from_preferred_username() {
let resp: IntrospectionResponse = serde_json::from_value(json!({
"active": true,
"sub": "user-1",
"preferred_username": "bob",
}))
.unwrap();
assert_eq!(resp.display_name(), Some("bob".to_string()));
}
// Scope and client id supplied in the response are returned by their accessors.
#[test]
fn test_response_scope_and_client_id() {
let resp: IntrospectionResponse = serde_json::from_value(json!({
"active": true,
"sub": "user-1",
"scope": "read write",
"client_id": "client-1",
}))
.unwrap();
assert_eq!(resp.resolved_scope(), Some("read write".to_string()));
assert_eq!(resp.resolved_client_id(), Some("client-1".to_string()));
}
// With the declared fields unset, scope and client id fall back to `extra`;
// the numeric client id is rendered as a string.
#[test]
fn test_response_scope_and_client_id_from_extra() {
let mut extra = std::collections::HashMap::new();
extra.insert("scope".to_string(), json!("read"));
extra.insert("client_id".to_string(), json!(42));
let resp = IntrospectionResponse {
active: true,
sub: Some("user-1".to_string()),
exp: None,
email: None,
name: None,
roles: None,
scope: None,
client_id: None,
extra,
};
assert_eq!(resp.resolved_scope(), Some("read".to_string()));
assert_eq!(resp.resolved_client_id(), Some("42".to_string()));
}
// Both `from_config` constructors build usable values from a `HydraConfig`.
#[test]
fn test_from_config_constructors() {
let config = HydraConfig {
introspection_url: "http://hydra:4445/oauth2/introspect".to_string(),
client_id: "client".to_string(),
client_secret: "secret".to_string(),
};
let client = IntrospectionClient::from_config(config.clone());
assert_eq!(client.config.introspection_url, config.introspection_url);
let layer = IntrospectionLayer::from_config(config);
assert!(Arc::strong_count(&layer.client) >= 1);
}
// A non-2xx answer from the endpoint is reported as `Unauthenticated`.
#[tokio::test]
async fn test_introspection_client_non_success_status_returns_unauthenticated() {
let (url, _handle) = mock_hydra_server_with_status(
json!({ "active": false }),
http::StatusCode::FORBIDDEN,
)
.await;
let client = IntrospectionClient::new(HydraConfig {
introspection_url: url,
client_id: "client".to_string(),
client_secret: "secret".to_string(),
});
let result = client.introspect("token").await;
assert!(matches!(result, Err(ServiceError::Unauthenticated(_))));
}
// A body that is not valid JSON is reported as `Unauthenticated`.
#[tokio::test]
async fn test_introspection_client_malformed_response_returns_unauthenticated() {
let (url, _handle) = mock_hydra_malformed_server().await;
let client = IntrospectionClient::new(HydraConfig {
introspection_url: url,
client_id: "client".to_string(),
client_secret: "secret".to_string(),
});
let result = client.introspect("token").await;
assert!(matches!(result, Err(ServiceError::Unauthenticated(_))));
}
// A failed introspection request surfaces as 401 from the layer.
// Port 1 on localhost is presumably not listening, so the request should fail
// at the transport level.
#[tokio::test]
async fn test_introspection_layer_network_error_returns_401() {
let client = IntrospectionClient::new(HydraConfig {
introspection_url: "http://127.0.0.1:1/oauth2/introspect".to_string(),
client_id: "client".to_string(),
client_secret: "secret".to_string(),
});
let layer = IntrospectionLayer::new(client);
let mut svc = ServiceBuilder::new().layer(layer).service(ok_service());
let req = http::Request::builder()
.uri("/")
.header("Authorization", "Bearer token")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::UNAUTHORIZED);
}
// When `sub` is absent the email claim becomes the subject of the context.
#[tokio::test]
async fn test_introspection_layer_uses_email_when_sub_missing() {
let (url, _handle) = mock_hydra_server(json!({
"active": true,
"email": "bob@example.com",
"name": "Bob",
"roles": ["user"],
}))
.await;
let client = IntrospectionClient::new(HydraConfig {
introspection_url: url,
client_id: "client".to_string(),
client_secret: "secret".to_string(),
});
let layer = IntrospectionLayer::new(client);
let mut svc = ServiceBuilder::new().layer(layer).service(
tower::service_fn(|req: http::Request<Body>| async move {
let ctx = req.extensions().get::<AuthContext>().cloned().unwrap();
assert!(ctx.is_authenticated());
assert_eq!(ctx.subject(), Some(&"bob@example.com".to_string()));
assert_eq!(ctx.email(), Some(&"bob@example.com".to_string()));
assert_eq!(ctx.name(), Some(&"Bob".to_string()));
assert_eq!(ctx.roles(), &["user"]);
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(http::StatusCode::OK)
.body(Body::empty())
.unwrap()
.into_response(),
)
}),
);
let req = http::Request::builder()
.uri("/")
.header("Authorization", "Bearer valid-token")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::OK);
}
// A response carrying only `active` and `sub` yields a context whose optional
// attributes are all unset.
#[tokio::test]
async fn test_introspection_layer_minimal_active_response() {
let (url, _handle) = mock_hydra_server(json!({ "active": true, "sub": "user-99" })).await;
let client = IntrospectionClient::new(HydraConfig {
introspection_url: url,
client_id: "client".to_string(),
client_secret: "secret".to_string(),
});
let layer = IntrospectionLayer::new(client);
let mut svc = ServiceBuilder::new().layer(layer).service(
tower::service_fn(|req: http::Request<Body>| async move {
let ctx = req.extensions().get::<AuthContext>().cloned().unwrap();
assert!(ctx.is_authenticated());
assert_eq!(ctx.subject(), Some(&"user-99".to_string()));
assert!(ctx.email().is_none());
assert!(ctx.name().is_none());
assert!(ctx.roles().is_empty());
assert!(ctx.scope().is_none());
assert!(ctx.client_id().is_none());
assert!(ctx.exp.is_none());
Ok::<_, std::convert::Infallible>(
http::Response::builder()
.status(http::StatusCode::OK)
.body(Body::empty())
.unwrap()
.into_response(),
)
}),
);
let req = http::Request::builder()
.uri("/")
.header("Authorization", "Bearer valid-token")
.body(Body::empty())
.unwrap();
let resp = svc.ready().await.unwrap().call(req).await.unwrap();
assert_eq!(resp.status(), http::StatusCode::OK);
}
}