use std::{collections::HashSet, sync::Arc};
use anyhow::Context;
use axum::{response::IntoResponse, routing::get, Json, Router};
use composable_tower_http::{
authorize::{
header::bearer::DefaultBearerExtractor,
jwt::{
jwk_set::{fetch::HttpJwkSetFetcher, rotating::RotatingJwkSetProvider},
DefaultJwtAuthorizerBuilder, Validation,
},
},
extension::{ExtensionLayerExt, ModificationLayerExt},
extract::Extracted,
modify::Modifier,
};
use http::StatusCode;
use reqwest::Client;
use serde::{Deserialize, Serialize};
#[path = "../util/util.rs"]
mod util;
/// Claims decoded from the bearer JWT by the authorizer built with
/// `.build::<Claims>()`, and echoed back as JSON by the `claims` handler.
///
/// Every field is required: no field is `Option` or `#[serde(default)]`, so a
/// token whose payload lacks any of these keys fails to deserialize.
/// Field declaration order is also the JSON output order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Claims {
    pub email_verified: bool,
    pub name: String,
    pub preferred_username: String,
    pub given_name: String,
    pub family_name: String,
    pub email: String,
    /// Group names the token holder belongs to; compared by exact string
    /// equality against the allow-lists held by `GroupsValidator` (e.g. `/admins`).
    pub groups: Vec<String>,
}
/// Route handler that returns the caller's extracted [`Claims`] as a JSON body.
async fn claims(Extracted(token_claims): Extracted<Claims>) -> impl IntoResponse {
    Json(token_claims)
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Reads `key` from the environment, returning `default` when the variable
    // is unset or not valid Unicode (any `env::var` error).
    fn env_or(key: &str, default: &str) -> String {
        std::env::var(key).unwrap_or_else(|_| default.to_owned())
    }

    // Turns a list of group-name literals into the owned set a validator expects.
    fn group_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    util::init("jwt_groups")?;

    let jwks_uri = env_or(
        "JWKS_URI",
        "https://keycloak.com/realms/master/protocol/openid-connect/certs",
    );
    let iss = env_or("ISS", "https://keycloak.com/realms/master");
    tracing::info!(%jwks_uri, %iss);

    // Construction is async and fallible, so it is awaited and given context up front.
    // NOTE(review): the meaning/unit of `30` is defined by
    // `RotatingJwkSetProvider::new` — presumably a rotation interval; confirm.
    let jwk_set_provider =
        RotatingJwkSetProvider::new(30, HttpJwkSetFetcher::new(jwks_uri, Client::new()))
            .await
            .context("Failed to create jwk set provider")?;

    // Router-wide layer: authorizes the bearer JWT (audience "account", issuer `iss`)
    // and exposes the decoded `Claims` to the routes as an extension.
    let auth_layer = DefaultJwtAuthorizerBuilder::new(
        DefaultBearerExtractor::new(),
        jwk_set_provider,
        Validation::new().aud(&["account"]).iss(&[iss]),
    )
    .build::<Claims>()
    .extension_layer();

    // Route-level layers that additionally require membership in a given group.
    let admins_modify_layer =
        GroupsValidator::new(group_set(&["/admins"])).modification_layer::<Claims>();
    let super_admins_modify_layer =
        GroupsValidator::new(group_set(&["/super-admins"])).modification_layer::<Claims>();

    let app = Router::new()
        .route("/super-admins", get(claims).layer(super_admins_modify_layer))
        .route("/admins", get(claims).layer(admins_modify_layer))
        .route("/", get(claims))
        .layer(auth_layer)
        .layer(util::trace_layer());

    util::serve(app).await
}
/// Allow-list check applied to decoded [`Claims`]: a request passes when the
/// token lists at least one of these groups.
///
/// The set sits behind an `Arc` so cloning the validator does not copy the
/// group names.
#[derive(Clone, Debug)]
struct GroupsValidator {
    groups: Arc<HashSet<String>>,
}
impl GroupsValidator {
fn new(groups: HashSet<String>) -> Self {
Self {
groups: Arc::new(groups),
}
}
}
impl Modifier<Claims> for GroupsValidator {
    type Modified = Claims;
    type Error = GroupsValidationError;

    /// Passes `claims` through unchanged when any entry of `claims.groups` is in
    /// this validator's allow-list.
    ///
    /// # Errors
    ///
    /// Returns [`GroupsValidationError`] when none of the token's groups match,
    /// including when the token lists no groups at all.
    async fn modify(&self, claims: Claims) -> Result<Claims, Self::Error> {
        let in_allowed_group = claims
            .groups
            .iter()
            .any(|group| self.groups.contains(group));

        if in_allowed_group {
            Ok(claims)
        } else {
            Err(GroupsValidationError)
        }
    }
}
#[derive(Debug, thiserror::Error)]
#[error("Not in groups")]
struct GroupsValidationError;
/// Maps a failed group check to an HTTP response.
///
/// Responds with `403 Forbidden` rather than `401 Unauthorized`. This error is only
/// produced after the router-wide JWT layer has authenticated the bearer token and
/// decoded its `Claims`, so the caller's credentials are valid but lack permission.
/// A `401` would tell clients to re-authenticate, and per HTTP semantics it must
/// carry a `WWW-Authenticate` challenge, which this response does not.
///
/// The body is the error's `Display` text, so the message cannot drift from the
/// `#[error(...)]` attribute.
impl IntoResponse for GroupsValidationError {
    fn into_response(self) -> axum::response::Response {
        (StatusCode::FORBIDDEN, self.to_string()).into_response()
    }
}