use crate::web::{Error, RequestContext};
pub use crate::resilience::rate_limit::RateLimit;
/// A check that is evaluated against a request's context.
///
/// Implementors must be `Send + Sync + 'static`, as required by the
/// supertrait bounds on this trait.
pub trait Guard: Send + Sync + 'static {
/// Inspects `ctx` and returns `Ok(())` when the request passes this guard,
/// or an [`Error`] describing why it was rejected.
fn check(&self, ctx: &RequestContext) -> Result<(), Error>;
}
/// Guard that compares a presented bearer token against one fixed value.
pub struct BearerAuth {
    /// The token a request must present to pass the guard.
    expected: &'static str,
}

impl BearerAuth {
    /// Builds a guard that accepts exactly `token`.
    ///
    /// This is a `const fn`, so the guard can be created in `const` and
    /// `static` items.
    pub const fn new(token: &'static str) -> Self {
        BearerAuth { expected: token }
    }
}
impl Guard for BearerAuth {
fn check(&self, ctx: &RequestContext) -> Result<(), Error> {
let h = ctx.header("authorization").ok_or(Error::Unauthorized)?;
let token = h.strip_prefix("Bearer ").ok_or(Error::Unauthorized)?;
if ct_eq(token.as_bytes(), self.expected.as_bytes()) {
Ok(())
} else {
Err(Error::Unauthorized)
}
}
}
/// Byte-slice equality that inspects every byte pair before answering.
///
/// Slices of different lengths compare unequal immediately, so the length
/// itself is not hidden. For equal-length inputs, all XOR differences are
/// OR-ed into one accumulator, so the loop never exits early at the first
/// mismatching position.
#[inline]
fn ct_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
/// Guard that passes only when the request context exposes claims.
///
/// The type carries no state. NOTE(review): the name suggests the claims
/// come from a validated JWT; that validation is not visible here — confirm
/// against whatever populates the request context.
pub struct JwtAuthGuard;
/// Ready-made [`JwtAuthGuard`] value that can be referenced wherever a guard
/// instance is needed.
pub static JWT_AUTH: JwtAuthGuard = JwtAuthGuard;
impl Guard for JwtAuthGuard {
fn check(&self, ctx: &RequestContext) -> Result<(), Error> {
ctx.claims().map(|_| ()).ok_or(Error::Unauthorized)
}
}
/// Guard that demands a specific value in the request's `role` claim.
pub struct RoleGuard {
    /// The role a request's claims must carry to pass the guard.
    pub role: &'static str,
}

impl RoleGuard {
    /// Builds a guard that requires `role`.
    ///
    /// This is a `const fn`, so the guard can be created in `const` and
    /// `static` items.
    pub const fn require(role: &'static str) -> Self {
        RoleGuard { role }
    }
}
impl Guard for RoleGuard {
fn check(&self, ctx: &RequestContext) -> Result<(), Error> {
let claims = ctx.claims().ok_or(Error::Unauthorized)?;
match claims.get("role").and_then(|v| v.as_str()) {
Some(r) if r == self.role => Ok(()),
Some(_) => Err(Error::Forbidden),
None => Err(Error::Forbidden),
}
}
}
/// Guard that passes only when the request context exposes a session.
///
/// The type carries no state.
pub struct SessionAuthGuard;
/// Ready-made [`SessionAuthGuard`] value that can be referenced wherever a
/// guard instance is needed.
pub static SESSION_AUTH: SessionAuthGuard = SessionAuthGuard;
impl Guard for SessionAuthGuard {
fn check(&self, ctx: &RequestContext) -> Result<(), Error> {
ctx.session().map(|_| ()).ok_or(Error::Unauthorized)
}
}