use std::sync::Arc;
use axum::response::IntoResponse;
use glob::{MatchOptions, Pattern};
use crate::config::RoutePoliciesConfig;
use crate::middleware::common;
use modkit::api::Problem;
use modkit_security::SecurityContext;
#[derive(Clone, Debug)]
pub struct ScopeEnforcementRules {
rules: Arc<[CompiledRule]>,
enabled: bool,
}
#[derive(Clone, Debug)]
struct CompiledRule {
pattern: Pattern,
method: Option<String>,
required_scopes: Vec<String>,
}
impl ScopeEnforcementRules {
pub fn from_config(config: &RoutePoliciesConfig) -> Result<Self, anyhow::Error> {
if !config.enabled {
return Ok(Self {
rules: Arc::from([]),
enabled: false,
});
}
let mut rules = Vec::with_capacity(config.rules.len());
for rule in &config.rules {
if rule.required_scopes.is_empty() {
return Err(anyhow::anyhow!(
"Route policy rule for path '{}' has empty required_scopes. \
This would match all tokens and is likely a config mistake.",
rule.path
));
}
let pattern = Pattern::new(&rule.path).map_err(|e| {
anyhow::anyhow!(
"Invalid glob pattern '{}' in route_policies: {e}",
rule.path
)
})?;
rules.push(CompiledRule {
pattern,
method: rule.method.as_ref().map(|m| m.to_uppercase()),
required_scopes: rule.required_scopes.clone(),
});
}
tracing::info!(
rules_count = rules.len(),
"Route policy enforcement enabled with {} rules",
rules.len()
);
Ok(Self {
rules: Arc::from(rules),
enabled: true,
})
}
fn matches_protected_route(&self, path: &str, method: &str) -> bool {
if !self.enabled {
return false;
}
let match_opts = MatchOptions {
require_literal_separator: true,
..MatchOptions::default()
};
self.rules.iter().any(|rule| {
let path_matches = rule.pattern.matches_with(path, match_opts);
let method_matches = rule
.method
.as_ref()
.is_none_or(|m| m.eq_ignore_ascii_case(method));
path_matches && method_matches
})
}
#[allow(clippy::result_large_err)]
fn check(&self, path: &str, method: &str, token_scopes: &[String]) -> Result<(), Problem> {
if !self.enabled {
return Ok(());
}
if token_scopes.iter().any(|s| s == "*") {
return Ok(());
}
let match_opts = MatchOptions {
require_literal_separator: true,
..MatchOptions::default()
};
for rule in self.rules.iter() {
let path_matches = rule.pattern.matches_with(path, match_opts);
let method_matches = rule
.method
.as_ref()
.is_none_or(|m| m.eq_ignore_ascii_case(method));
if path_matches && method_matches {
let has_required_scope = rule
.required_scopes
.iter()
.any(|required| token_scopes.contains(required));
if has_required_scope {
return Ok(());
}
tracing::warn!(
path = %path,
method = %method,
pattern = %rule.pattern,
rule_method = ?rule.method,
required_scopes = ?rule.required_scopes,
token_scopes = ?token_scopes,
"Route policy enforcement denied: insufficient scopes"
);
return Err(Problem::new(
axum::http::StatusCode::FORBIDDEN,
"Forbidden",
"Insufficient token scopes for this resource",
));
}
}
Ok(())
}
}
#[derive(Clone)]
pub struct ScopeEnforcementState {
pub rules: ScopeEnforcementRules,
}
pub async fn scope_enforcement_middleware(
axum::extract::State(state): axum::extract::State<ScopeEnforcementState>,
req: axum::extract::Request,
next: axum::middleware::Next,
) -> axum::response::Response {
if !state.rules.enabled {
return next.run(req).await;
}
let path = req.uri().path();
let path = common::resolve_path(&req, path);
let method = req.method().as_str();
let Some(security_context) = req.extensions().get::<SecurityContext>() else {
if state.rules.matches_protected_route(&path, method) {
tracing::warn!(
path = %path,
method = %method,
"Route policy enforcement denied: no SecurityContext for protected route"
);
return Problem::new(
axum::http::StatusCode::UNAUTHORIZED,
"Unauthorized",
"Authentication required for this resource",
)
.into_response();
}
return next.run(req).await;
};
if let Err(problem) = state
.rules
.check(&path, method, security_context.token_scopes())
{
return problem.into_response();
}
next.run(req).await
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
use super::*;
use crate::config::RoutePolicyRule;
fn build_config(enabled: bool, routes: Vec<(&str, Vec<&str>)>) -> RoutePoliciesConfig {
build_config_with_methods(
enabled,
routes.into_iter().map(|(p, s)| (p, None, s)).collect(),
)
}
type TestRoute<'a> = (&'a str, Option<&'a str>, Vec<&'a str>);
fn build_config_with_methods(enabled: bool, routes: Vec<TestRoute<'_>>) -> RoutePoliciesConfig {
let rules = routes
.into_iter()
.map(|(path, method, scopes)| RoutePolicyRule {
path: path.to_owned(),
method: method.map(String::from),
required_scopes: scopes.into_iter().map(String::from).collect(),
})
.collect();
RoutePoliciesConfig { enabled, rules }
}
#[test]
fn disabled_enforcement_always_passes() {
let config = build_config(false, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
assert!(rules.check("/admin/users", "GET", &[]).is_ok());
}
#[test]
fn first_party_app_always_passes() {
let config = build_config(true, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["*".to_owned()];
assert!(rules.check("/admin/users", "GET", &scopes).is_ok());
}
#[test]
fn matching_scope_passes() {
let config = build_config(true, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["admin".to_owned()];
assert!(rules.check("/admin/users", "GET", &scopes).is_ok());
}
#[test]
fn any_of_required_scopes_passes() {
let config = build_config(
true,
vec![("/events/v1/*", vec!["read:events", "write:events"])],
);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["read:events".to_owned()];
assert!(rules.check("/events/v1/list", "GET", &scopes).is_ok());
let scopes = vec!["write:events".to_owned()];
assert!(rules.check("/events/v1/create", "POST", &scopes).is_ok());
}
#[test]
fn missing_scope_returns_forbidden() {
let config = build_config(true, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["read:events".to_owned()];
let result = rules.check("/admin/users", "GET", &scopes);
assert!(result.is_err());
let problem = result.unwrap_err();
assert_eq!(problem.status, axum::http::StatusCode::FORBIDDEN);
}
#[test]
fn empty_scopes_returns_forbidden() {
let config = build_config(true, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let result = rules.check("/admin/users", "GET", &[]);
assert!(result.is_err());
let problem = result.unwrap_err();
assert_eq!(problem.status, axum::http::StatusCode::FORBIDDEN);
}
#[test]
fn unmatched_route_passes() {
let config = build_config(true, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["unrelated:scope".to_owned()];
assert!(rules.check("/public/health", "GET", &scopes).is_ok());
}
#[test]
fn glob_single_star_matches_single_segment_only() {
let config = build_config(true, vec![("/api/*/items", vec!["api:read"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["api:read".to_owned()];
assert!(rules.check("/api/v1/items", "GET", &scopes).is_ok());
assert!(rules.check("/api/v2/items", "GET", &scopes).is_ok());
let unrelated_scopes = vec!["unrelated:scope".to_owned()];
assert!(
rules
.check("/api/v1/nested/items", "GET", &unrelated_scopes)
.is_ok()
); }
#[test]
fn glob_double_star_matches_multiple_segments() {
let config = build_config(true, vec![("/api/**", vec!["api:access"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["api:access".to_owned()];
assert!(rules.check("/api/v1", "GET", &scopes).is_ok());
assert!(rules.check("/api/v1/items", "GET", &scopes).is_ok());
assert!(
rules
.check("/api/v1/items/123/details", "GET", &scopes)
.is_ok()
);
}
#[test]
fn invalid_glob_pattern_returns_error() {
let config = build_config(true, vec![("/admin/[invalid", vec!["admin"])]);
let result = ScopeEnforcementRules::from_config(&config);
assert!(result.is_err());
}
#[test]
fn empty_required_scopes_returns_error() {
let config = build_config(true, vec![("/admin/*", vec![])]);
let result = ScopeEnforcementRules::from_config(&config);
let err = result.expect_err("should fail with empty required_scopes");
assert!(
err.to_string().contains("empty required_scopes"),
"error should mention empty required_scopes: {err}"
);
}
#[test]
fn multiple_non_overlapping_rules() {
let config = build_config(
true,
vec![
("/admin/*", vec!["admin"]),
("/events/**", vec!["events:read"]),
],
);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let admin_scopes = vec!["admin".to_owned()];
assert!(rules.check("/admin/users", "GET", &admin_scopes).is_ok());
let events_scopes = vec!["events:read".to_owned()];
assert!(
rules
.check("/events/v1/list", "GET", &events_scopes)
.is_ok()
);
assert!(rules.check("/admin/users", "GET", &events_scopes).is_err());
assert!(
rules
.check("/events/v1/list", "GET", &admin_scopes)
.is_err()
);
}
#[test]
fn overlapping_rules_first_match_wins() {
let config = build_config(
true,
vec![
("/api/**", vec!["basic"]), ("/api/admin/**", vec!["admin"]), ],
);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let basic_scopes = vec!["basic".to_owned()];
let admin_scopes = vec!["admin".to_owned()];
assert!(
rules
.check("/api/admin/users", "GET", &basic_scopes)
.is_ok()
);
assert!(
rules
.check("/api/admin/users", "GET", &admin_scopes)
.is_err()
);
}
#[test]
fn matches_protected_route_returns_true_for_matching_path() {
let config = build_config(true, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
assert!(rules.matches_protected_route("/admin/users", "GET"));
assert!(rules.matches_protected_route("/admin/settings", "POST"));
}
#[test]
fn matches_protected_route_returns_false_for_non_matching_path() {
let config = build_config(true, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
assert!(!rules.matches_protected_route("/public/health", "GET"));
assert!(!rules.matches_protected_route("/api/v1/users", "GET"));
}
#[test]
fn matches_protected_route_returns_false_when_disabled() {
let config = build_config(false, vec![("/admin/*", vec!["admin"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
assert!(!rules.matches_protected_route("/admin/users", "GET"));
}
#[test]
fn first_match_wins_more_specific_rule_first() {
let config = build_config(
true,
vec![
("/api/admin/**", vec!["admin"]), ("/api/**", vec!["basic"]), ],
);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let admin_scopes = vec!["admin".to_owned()];
let basic_scopes = vec!["basic".to_owned()];
assert!(
rules
.check("/api/admin/users", "GET", &admin_scopes)
.is_ok()
);
assert!(
rules
.check("/api/admin/users", "GET", &basic_scopes)
.is_err()
);
assert!(rules.check("/api/other", "GET", &basic_scopes).is_ok());
}
#[test]
fn first_match_wins_broader_rule_first() {
let config = build_config(
true,
vec![
("/api/**", vec!["basic"]), ("/api/admin/**", vec!["admin"]), ],
);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let basic_scopes = vec!["basic".to_owned()];
assert!(
rules
.check("/api/admin/users", "GET", &basic_scopes)
.is_ok()
);
}
#[test]
fn method_matching_specific_method() {
let config = build_config_with_methods(
true,
vec![
("/users/*", Some("POST"), vec!["users:write"]),
("/users/*", Some("GET"), vec!["users:read"]),
],
);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let read_scopes = vec!["users:read".to_owned()];
let write_scopes = vec!["users:write".to_owned()];
assert!(rules.check("/users/123", "GET", &read_scopes).is_ok());
assert!(rules.check("/users/123", "POST", &write_scopes).is_ok());
assert!(rules.check("/users/123", "GET", &write_scopes).is_err());
assert!(rules.check("/users/123", "POST", &read_scopes).is_err());
}
#[test]
fn method_matching_any_method() {
let config = build_config_with_methods(true, vec![("/api/**", None, vec!["api:access"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["api:access".to_owned()];
assert!(rules.check("/api/users", "GET", &scopes).is_ok());
assert!(rules.check("/api/users", "POST", &scopes).is_ok());
assert!(rules.check("/api/users", "PUT", &scopes).is_ok());
assert!(rules.check("/api/users", "DELETE", &scopes).is_ok());
}
#[test]
fn method_matching_case_insensitive() {
let config =
build_config_with_methods(true, vec![("/api/**", Some("get"), vec!["api:read"])]);
let rules = ScopeEnforcementRules::from_config(&config).unwrap();
let scopes = vec!["api:read".to_owned()];
assert!(rules.check("/api/users", "GET", &scopes).is_ok());
assert!(rules.check("/api/users", "get", &scopes).is_ok());
assert!(rules.check("/api/users", "Get", &scopes).is_ok());
}
}