use crate::rule::BitmaskAuth;
use http::Request;
use std::sync::Arc;
/// Outcome of asking a role extractor for the roles attached to a request.
#[derive(Debug, Clone)]
pub enum RoleExtractionResult {
    /// A role bitmask was obtained.
    Roles(u32),
    /// The request carried no role information.
    Anonymous,
    /// Extraction failed; the payload describes why.
    Error(String),
}
impl RoleExtractionResult {
    /// Returns the extracted bitmask, or `default` when the result is
    /// `Anonymous` or `Error`.
    pub fn roles_or(&self, default: u32) -> u32 {
        if let Self::Roles(found) = self {
            *found
        } else {
            default
        }
    }
    /// Returns the extracted bitmask, or `0` (no roles) when the result is
    /// `Anonymous` or `Error`.
    pub fn roles_or_none(&self) -> u32 {
        let no_roles = 0;
        self.roles_or(no_roles)
    }
}
/// Determines the role bitmask for an incoming HTTP request with body type `B`.
///
/// Implementors must be `Send + Sync`.
pub trait RoleExtractor<B>: Send + Sync {
    /// Inspects `request` and reports the roles it carries.
    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult;
}
/// Forwards to the shared extractor. `?Sized` (matching the `Box` impl below)
/// lets trait objects such as `Arc<dyn RoleExtractor<B>>` be used directly.
impl<B, T: RoleExtractor<B> + ?Sized> RoleExtractor<B> for Arc<T> {
    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
        (**self).extract_roles(request)
    }
}
/// Forwards to the boxed extractor, including `Box<dyn RoleExtractor<B>>`.
impl<B, T: RoleExtractor<B> + ?Sized> RoleExtractor<B> for Box<T> {
    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
        (**self).extract_roles(request)
    }
}
/// Reads a role bitmask from a named request header.
///
/// The (whitespace-trimmed) header value may be a decimal `u32` such as `"5"`,
/// or a hexadecimal `u32` with a `0x`/`0X` prefix such as `"0x1F"`.
#[derive(Debug, Clone)]
pub struct HeaderRoleExtractor {
    header_name: String,
    default_roles: u32,
}
impl HeaderRoleExtractor {
    /// Creates an extractor reading `header_name`, with no default roles.
    pub fn new(header_name: impl Into<String>) -> Self {
        Self {
            header_name: header_name.into(),
            default_roles: 0,
        }
    }
    /// Sets the roles reported when the header is missing, empty, not
    /// convertible to a string (`HeaderValue::to_str` fails), or unparseable.
    ///
    /// A value of `0` means "report `Anonymous` instead".
    #[must_use]
    pub fn with_default_roles(mut self, roles: u32) -> Self {
        self.default_roles = roles;
        self
    }
    /// Parses a trimmed decimal or `0x`/`0X`-prefixed hexadecimal bitmask.
    fn parse_roles(value: &str) -> Option<u32> {
        let value = value.trim();
        match value.strip_prefix("0x").or_else(|| value.strip_prefix("0X")) {
            Some(hex) => u32::from_str_radix(hex, 16).ok(),
            // Decimal cannot start with "0x"/"0X", so checking hex first is
            // equivalent to trying decimal first.
            None => value.parse::<u32>().ok(),
        }
    }
    /// Result used whenever the header yields no usable bitmask.
    fn fallback(&self) -> RoleExtractionResult {
        if self.default_roles != 0 {
            RoleExtractionResult::Roles(self.default_roles)
        } else {
            RoleExtractionResult::Anonymous
        }
    }
}
impl<B> RoleExtractor<B> for HeaderRoleExtractor {
    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
        request
            .headers()
            .get(&self.header_name)
            .and_then(|value| value.to_str().ok())
            .and_then(Self::parse_roles)
            .map_or_else(|| self.fallback(), RoleExtractionResult::Roles)
    }
}
/// Derives roles from a value of type `T` stored in the request's extensions,
/// using a caller-supplied mapping function.
pub struct ExtensionRoleExtractor<T> {
    extract_fn: Box<dyn Fn(&T) -> u32 + Send + Sync>,
}
impl<T> ExtensionRoleExtractor<T> {
    /// Wraps `extract_fn`, which maps the extension value to a role bitmask.
    pub fn new<F>(extract_fn: F) -> Self
    where
        F: Fn(&T) -> u32 + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn(&T) -> u32 + Send + Sync> = Box::new(extract_fn);
        Self { extract_fn: boxed }
    }
}
impl<T> std::fmt::Debug for ExtensionRoleExtractor<T> {
    /// The closure is opaque, so only the extension type name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = std::any::type_name::<T>();
        let mut out = f.debug_struct("ExtensionRoleExtractor");
        out.field("type", &type_name);
        out.finish()
    }
}
impl<B, T: Clone + Send + Sync + 'static> RoleExtractor<B> for ExtensionRoleExtractor<T> {
    /// Reports `Roles` computed from the extension when present; otherwise `Anonymous`.
    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
        request
            .extensions()
            .get::<T>()
            .map_or(RoleExtractionResult::Anonymous, |ext| {
                RoleExtractionResult::Roles((self.extract_fn)(ext))
            })
    }
}
/// Reports the same role bitmask for every request.
#[derive(Debug, Clone)]
pub struct FixedRoleExtractor {
    roles: u32,
}
impl FixedRoleExtractor {
    /// Creates an extractor that always yields `roles`.
    pub fn new(roles: u32) -> Self {
        FixedRoleExtractor { roles }
    }
}
impl<B> RoleExtractor<B> for FixedRoleExtractor {
    /// Ignores the request and returns the configured bitmask.
    fn extract_roles(&self, _req: &Request<B>) -> RoleExtractionResult {
        let Self { roles } = *self;
        RoleExtractionResult::Roles(roles)
    }
}
/// Treats every request as anonymous (no role information).
#[derive(Debug, Clone, Default)]
pub struct AnonymousRoleExtractor;
impl AnonymousRoleExtractor {
    /// Creates the extractor; equivalent to `AnonymousRoleExtractor::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<B> RoleExtractor<B> for AnonymousRoleExtractor {
    /// Always returns `RoleExtractionResult::Anonymous`.
    fn extract_roles(&self, _req: &Request<B>) -> RoleExtractionResult {
        RoleExtractionResult::Anonymous
    }
}
/// An ordered list of role extractors consulted one after another.
///
/// Its `RoleExtractor` impl returns the first `Roles` result found.
pub struct ChainedRoleExtractor<B> {
    // Consulted front to back.
    extractors: Vec<Box<dyn RoleExtractor<B>>>,
}
/// Outcome of asking an ID extractor for the identity behind a request.
#[derive(Debug, Clone)]
pub enum IdExtractionResult {
    /// An identifier was obtained.
    Id(String),
    /// The request carried no identity.
    Anonymous,
    /// Extraction failed; the payload describes why.
    Error(String),
}
impl IdExtractionResult {
    /// Returns a copy of the extracted ID, or `default` converted to a `String`
    /// when the result is `Anonymous` or `Error`.
    pub fn id_or(&self, default: impl Into<String>) -> String {
        match self {
            Self::Id(found) => found.to_owned(),
            Self::Anonymous | Self::Error(_) => default.into(),
        }
    }
    /// Returns the extracted ID, or `"*"` when none was obtained.
    pub fn id_or_wildcard(&self) -> String {
        const WILDCARD: &str = "*";
        self.id_or(WILDCARD)
    }
}
/// Determines the caller identity for an incoming HTTP request with body type `B`.
///
/// Implementors must be `Send + Sync`.
pub trait IdExtractor<B>: Send + Sync {
    /// Inspects `request` and reports the identity it carries.
    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult;
}
/// Forwards to the shared extractor. `?Sized` (matching the `Box` impl below)
/// lets trait objects such as `Arc<dyn IdExtractor<B>>` be used directly.
impl<B, T: IdExtractor<B> + ?Sized> IdExtractor<B> for Arc<T> {
    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
        (**self).extract_id(request)
    }
}
/// Forwards to the boxed extractor, including `Box<dyn IdExtractor<B>>`.
impl<B, T: IdExtractor<B> + ?Sized> IdExtractor<B> for Box<T> {
    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
        (**self).extract_id(request)
    }
}
/// Reads the caller identity from a named request header.
#[derive(Debug, Clone)]
pub struct HeaderIdExtractor {
    header_name: String,
}
impl HeaderIdExtractor {
    /// Creates an extractor reading `header_name`.
    pub fn new(header_name: impl Into<String>) -> Self {
        Self {
            header_name: header_name.into(),
        }
    }
}
impl<B> IdExtractor<B> for HeaderIdExtractor {
    /// Returns the trimmed header value as the ID.
    ///
    /// Returns `Anonymous` when the header is missing, not convertible to a
    /// string (`HeaderValue::to_str` fails), or empty after trimming.
    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
        request
            .headers()
            .get(&self.header_name)
            .and_then(|value| value.to_str().ok())
            // Trim before the emptiness check. Otherwise a whitespace-only
            // header would produce an empty `Id("")`.
            .map(str::trim)
            .filter(|id| !id.is_empty())
            .map_or(IdExtractionResult::Anonymous, |id| {
                IdExtractionResult::Id(id.to_owned())
            })
    }
}
/// Derives the caller ID from a value of type `T` stored in the request's
/// extensions, using a caller-supplied mapping function.
pub struct ExtensionIdExtractor<T> {
    extract_fn: Box<dyn Fn(&T) -> String + Send + Sync>,
}
impl<T> ExtensionIdExtractor<T> {
    /// Wraps `extract_fn`, which maps the extension value to an ID string.
    pub fn new<F>(extract_fn: F) -> Self
    where
        F: Fn(&T) -> String + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn(&T) -> String + Send + Sync> = Box::new(extract_fn);
        Self { extract_fn: boxed }
    }
}
impl<T> std::fmt::Debug for ExtensionIdExtractor<T> {
    /// The closure is opaque, so only the extension type name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = std::any::type_name::<T>();
        let mut out = f.debug_struct("ExtensionIdExtractor");
        out.field("type", &type_name);
        out.finish()
    }
}
impl<B, T: Clone + Send + Sync + 'static> IdExtractor<B> for ExtensionIdExtractor<T> {
    /// Reports an `Id` computed from the extension when present; otherwise `Anonymous`.
    fn extract_id(&self, request: &Request<B>) -> IdExtractionResult {
        request
            .extensions()
            .get::<T>()
            .map_or(IdExtractionResult::Anonymous, |ext| {
                IdExtractionResult::Id((self.extract_fn)(ext))
            })
    }
}
/// Reports the same ID for every request.
#[derive(Debug, Clone)]
pub struct FixedIdExtractor {
    id: String,
}
impl FixedIdExtractor {
    /// Creates an extractor that always yields `id`.
    pub fn new(id: impl Into<String>) -> Self {
        let id: String = id.into();
        FixedIdExtractor { id }
    }
}
impl<B> IdExtractor<B> for FixedIdExtractor {
    /// Ignores the request and returns a copy of the configured ID.
    fn extract_id(&self, _req: &Request<B>) -> IdExtractionResult {
        IdExtractionResult::Id(self.id.to_owned())
    }
}
/// Treats every request as having no identity.
#[derive(Debug, Clone, Default)]
pub struct AnonymousIdExtractor;
impl AnonymousIdExtractor {
    /// Creates the extractor; equivalent to `AnonymousIdExtractor::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
impl<B> IdExtractor<B> for AnonymousIdExtractor {
    /// Always returns `IdExtractionResult::Anonymous`.
    fn extract_id(&self, _req: &Request<B>) -> IdExtractionResult {
        IdExtractionResult::Anonymous
    }
}
/// Outcome of an authentication extraction, generic over the auth payload `A`.
#[derive(Debug, Clone)]
pub enum AuthResult<A> {
    /// Authentication data was produced.
    Auth(A),
    /// The request is unauthenticated.
    Anonymous,
    /// Extraction failed; the payload describes why.
    Error(String),
}
/// Produces authentication data of type `A` from a request with body type `B`.
///
/// Implementors must be `Send + Sync`.
pub trait AuthExtractor<A, B>: Send + Sync {
    /// Inspects `request` and reports its authentication outcome.
    fn extract_auth(&self, request: &Request<B>) -> AuthResult<A>;
}
pub struct BitmaskAuthExtractor<E, I> {
role_extractor: E,
id_extractor: I,
anonymous_roles: u32,
default_id: String,
}
impl<E, I> BitmaskAuthExtractor<E, I> {
pub fn new(role_extractor: E, id_extractor: I) -> Self {
Self {
role_extractor,
id_extractor,
anonymous_roles: 0,
default_id: "*".to_string(),
}
}
pub fn with_anonymous_roles(mut self, roles: u32) -> Self {
self.anonymous_roles = roles;
self
}
pub fn with_default_id(mut self, id: impl Into<String>) -> Self {
self.default_id = id.into();
self
}
}
impl<E: std::fmt::Debug, I: std::fmt::Debug> std::fmt::Debug for BitmaskAuthExtractor<E, I> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BitmaskAuthExtractor")
.field("role_extractor", &self.role_extractor)
.field("id_extractor", &self.id_extractor)
.finish()
}
}
impl<E, I, B> AuthExtractor<BitmaskAuth, B> for BitmaskAuthExtractor<E, I>
where
E: RoleExtractor<B>,
I: IdExtractor<B>,
{
fn extract_auth(&self, request: &Request<B>) -> AuthResult<BitmaskAuth> {
let roles = self.role_extractor.extract_roles(request).roles_or(self.anonymous_roles);
let id = self.id_extractor.extract_id(request).id_or(&self.default_id);
AuthResult::Auth(BitmaskAuth { roles, id })
}
}
impl<B> ChainedRoleExtractor<B> {
    /// Creates an empty chain. An empty chain always reports `Anonymous`.
    pub fn new() -> Self {
        Self {
            extractors: Vec::new(),
        }
    }
    /// Appends `extractor` to the chain; it is consulted after all earlier ones.
    ///
    /// Consumes and returns the chain for builder-style use.
    #[must_use]
    pub fn push<E: RoleExtractor<B> + 'static>(mut self, extractor: E) -> Self {
        self.extractors.push(Box::new(extractor));
        self
    }
}
impl<B> Default for ChainedRoleExtractor<B> {
    fn default() -> Self {
        Self::new()
    }
}
impl<B> std::fmt::Debug for ChainedRoleExtractor<B> {
    /// Boxed trait objects are not `Debug`, so only the count is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChainedRoleExtractor")
            .field("extractors_count", &self.extractors.len())
            .finish()
    }
}
// No `B: Send + Sync` bound: `Box<dyn RoleExtractor<B>>` is already Send + Sync
// through the trait's supertraits. Requiring it of `B` would wrongly exclude
// body types that are `Send` but not `Sync`.
impl<B> RoleExtractor<B> for ChainedRoleExtractor<B> {
    /// Returns the first `Roles` result in chain order.
    ///
    /// `Anonymous` results are skipped. `Error` results are logged at warn
    /// level and skipped. If no extractor yields roles, returns `Anonymous`.
    fn extract_roles(&self, request: &Request<B>) -> RoleExtractionResult {
        self.extractors
            .iter()
            .find_map(|extractor| match extractor.extract_roles(request) {
                RoleExtractionResult::Roles(roles) => Some(roles),
                RoleExtractionResult::Error(e) => {
                    tracing::warn!(error = %e, "Role extractor failed, trying next");
                    None
                }
                RoleExtractionResult::Anonymous => None,
            })
            .map_or(RoleExtractionResult::Anonymous, RoleExtractionResult::Roles)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::Request;

    /// Builds a body-less request carrying the given headers.
    fn request_with(headers: &[(&str, &str)]) -> Request<()> {
        headers
            .iter()
            .fold(Request::builder(), |builder, (name, value)| {
                builder.header(*name, *value)
            })
            .body(())
            .unwrap()
    }

    /// Unwraps a `Roles` result, panicking with the actual variant otherwise.
    fn expect_roles(result: RoleExtractionResult) -> u32 {
        match result {
            RoleExtractionResult::Roles(roles) => roles,
            other => panic!("Expected Roles, got {:?}", other),
        }
    }

    #[test]
    fn test_header_extractor_decimal() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        let req = request_with(&[("X-Roles", "5")]);
        assert_eq!(expect_roles(extractor.extract_roles(&req)), 5);
    }

    #[test]
    fn test_header_extractor_hex() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        let req = request_with(&[("X-Roles", "0x1F")]);
        assert_eq!(expect_roles(extractor.extract_roles(&req)), 0x1F);
    }

    #[test]
    fn test_header_extractor_missing() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        let req = request_with(&[]);
        assert!(matches!(
            extractor.extract_roles(&req),
            RoleExtractionResult::Anonymous
        ));
    }

    #[test]
    fn test_header_extractor_default() {
        let extractor = HeaderRoleExtractor::new("X-Roles").with_default_roles(0b100);
        let req = request_with(&[]);
        assert_eq!(expect_roles(extractor.extract_roles(&req)), 0b100);
    }

    #[test]
    fn test_header_extractor_invalid_hex_uses_default() {
        let extractor = HeaderRoleExtractor::new("X-Roles").with_default_roles(0b10);
        let req = request_with(&[("X-Roles", "0xZZ")]);
        assert_eq!(expect_roles(extractor.extract_roles(&req)), 0b10);
    }

    #[test]
    fn test_header_extractor_non_numeric_is_anonymous() {
        let extractor = HeaderRoleExtractor::new("X-Roles");
        let req = request_with(&[("X-Roles", "admin")]);
        assert!(matches!(
            extractor.extract_roles(&req),
            RoleExtractionResult::Anonymous
        ));
    }

    #[test]
    fn test_fixed_extractor() {
        let extractor = FixedRoleExtractor::new(0b11);
        let req = request_with(&[]);
        assert_eq!(expect_roles(extractor.extract_roles(&req)), 0b11);
    }

    #[test]
    fn test_chained_extractor_first_roles_wins() {
        let extractor: ChainedRoleExtractor<()> = ChainedRoleExtractor::new()
            .push(AnonymousRoleExtractor::new())
            .push(HeaderRoleExtractor::new("X-Roles"))
            .push(FixedRoleExtractor::new(7));
        let with_header = request_with(&[("X-Roles", "3")]);
        assert_eq!(expect_roles(extractor.extract_roles(&with_header)), 3);
        let without_header = request_with(&[]);
        assert_eq!(expect_roles(extractor.extract_roles(&without_header)), 7);
    }

    #[test]
    fn test_chained_extractor_empty_is_anonymous() {
        let extractor: ChainedRoleExtractor<()> = ChainedRoleExtractor::default();
        let req = request_with(&[]);
        assert!(matches!(
            extractor.extract_roles(&req),
            RoleExtractionResult::Anonymous
        ));
    }

    #[test]
    fn test_header_id_extractor_trims() {
        let extractor = HeaderIdExtractor::new("X-User");
        let req = request_with(&[("X-User", "  alice ")]);
        match extractor.extract_id(&req) {
            IdExtractionResult::Id(id) => assert_eq!(id, "alice"),
            other => panic!("Expected Id, got {:?}", other),
        }
    }

    #[test]
    fn test_header_id_extractor_missing() {
        let extractor = HeaderIdExtractor::new("X-User");
        let req = request_with(&[]);
        assert!(matches!(
            extractor.extract_id(&req),
            IdExtractionResult::Anonymous
        ));
    }

    #[test]
    fn test_bitmask_auth_extractor_defaults() {
        let extractor = BitmaskAuthExtractor::new(
            HeaderRoleExtractor::new("X-Roles"),
            HeaderIdExtractor::new("X-User"),
        )
        .with_anonymous_roles(0b1);
        let req = request_with(&[]);
        match AuthExtractor::<BitmaskAuth, ()>::extract_auth(&extractor, &req) {
            AuthResult::Auth(auth) => {
                assert_eq!(auth.roles, 0b1);
                assert_eq!(auth.id, "*");
            }
            _ => panic!("Expected Auth"),
        }
    }
}