use crate::Authentication;
use std::sync::Arc;
/// Holder for the [`Authentication`] of a principal, if one has been stored.
///
/// The authentication sits behind an async [`tokio::sync::RwLock`], so every
/// accessor is `async`. Readers share the lock, while
/// [`set_authentication`](Self::set_authentication) and [`clear`](Self::clear)
/// take it exclusively.
pub struct SecurityContext {
    // A plain lock rather than `Arc<RwLock<..>>`: `SecurityContext` is not
    // `Clone`, so the lock is never shared on its own. Sharing happens by
    // wrapping the whole context in an `Arc` (see `SecurityContextGuard`) or by
    // storing it in a `static`.
    authentication: tokio::sync::RwLock<Option<Authentication>>,
}

impl SecurityContext {
    /// Creates a context with no authentication stored.
    pub fn new() -> Self {
        Self {
            authentication: tokio::sync::RwLock::new(None),
        }
    }

    /// Returns a clone of the stored authentication, or `None` if nothing is stored.
    pub async fn get_authentication(&self) -> Option<Authentication> {
        self.authentication.read().await.clone()
    }

    /// Stores `auth`, replacing any previously stored authentication.
    pub async fn set_authentication(&self, auth: Authentication) {
        *self.authentication.write().await = Some(auth);
    }

    /// Removes any stored authentication.
    pub async fn clear(&self) {
        *self.authentication.write().await = None;
    }

    /// Returns `true` only if an authentication is stored and its
    /// `authenticated` flag is set.
    pub async fn is_authenticated(&self) -> bool {
        self.authentication
            .read()
            .await
            .as_ref()
            .is_some_and(|a| a.authenticated)
    }

    /// Returns a clone of the stored principal, or `None` if nothing is stored.
    ///
    /// The `authenticated` flag is not consulted; an unauthenticated
    /// authentication still yields its principal.
    pub async fn get_username(&self) -> Option<String> {
        self.authentication
            .read()
            .await
            .as_ref()
            .map(|a| a.principal.clone())
    }

    /// Returns `true` if an authentication is stored and it reports holding
    /// `authority`. Returns `false` when nothing is stored.
    // NOTE(review): unlike `is_authenticated`, this does not check the
    // `authenticated` flag here. Confirm whether `Authentication::has_authority`
    // does so itself.
    pub async fn has_authority(&self, authority: &crate::Authority) -> bool {
        self.authentication
            .read()
            .await
            .as_ref()
            .is_some_and(|a| a.has_authority(authority))
    }

    /// Returns `true` if an authentication is stored and it reports holding
    /// `role`. Returns `false` when nothing is stored.
    // NOTE(review): the `authenticated` flag is not checked here either.
    // Confirm whether `Authentication::has_role` does so itself.
    pub async fn has_role(&self, role: &crate::Role) -> bool {
        self.authentication
            .read()
            .await
            .as_ref()
            .is_some_and(|a| a.has_role(role))
    }
}

impl Default for SecurityContext {
    /// Same as [`SecurityContext::new`]: an empty context.
    fn default() -> Self {
        Self::new()
    }
}
/// Process-wide security context behind [`context`], created empty on first access.
static GLOBAL_CONTEXT: std::sync::LazyLock<SecurityContext> =
    std::sync::LazyLock::new(SecurityContext::default);
/// Returns the single process-wide [`SecurityContext`], initializing it on first use.
///
/// This is a separate object from the task-local context returned by
/// [`get_security_context`]; installing a task-local context does not change it.
pub fn context() -> &'static SecurityContext {
    std::sync::LazyLock::force(&GLOBAL_CONTEXT)
}
/// Returns a clone of the authentication held by the process-wide [`context`].
///
/// The task-local context installed by [`with_security_context`] and similar
/// helpers is not consulted.
pub async fn get_authentication() -> Option<Authentication> {
    let global = context();
    global.get_authentication().await
}
/// Stores `auth` in the process-wide [`context`], replacing whatever was there.
pub async fn set_authentication(auth: Authentication) {
    let global = context();
    global.set_authentication(auth).await
}
/// Removes any authentication from the process-wide [`context`].
pub async fn clear_context() {
    let global = context();
    global.clear().await
}
/// Reports whether the process-wide [`context`] holds an authenticated principal.
///
/// The task-local context is not consulted.
pub async fn is_authenticated() -> bool {
    let global = context();
    global.is_authenticated().await
}
/// Returns the principal stored in the process-wide [`context`], if any.
pub async fn get_username() -> Option<String> {
    let global = context();
    global.get_username().await
}
/// Reports whether the process-wide [`context`]'s authentication grants `authority`.
pub async fn has_authority(authority: &crate::Authority) -> bool {
    let global = context();
    global.has_authority(authority).await
}
/// Reports whether the process-wide [`context`]'s authentication holds `role`.
pub async fn has_role(role: &crate::Role) -> bool {
    let global = context();
    global.has_role(role).await
}
// Task-local slot holding the "current" security context. In this module it is
// only populated for the duration of `SecurityContextGuard::scope`/`scope_async`
// and the `set_security_context`/`with_security_context*` helpers. It is read
// (non-panicking) by `get_security_context`.
tokio::task_local! {
    static CURRENT_SECURITY_CONTEXT: Arc<SecurityContext>;
}
/// Owns a shared [`SecurityContext`] and installs it as the task-local
/// "current" context while a closure or future runs.
///
/// The context is visible to [`get_security_context`] only inside
/// [`scope`](Self::scope) / [`scope_async`](Self::scope_async), not for the
/// whole lifetime of the guard.
pub struct SecurityContextGuard {
    ctx: Arc<SecurityContext>,
}

impl SecurityContextGuard {
    /// Wraps `ctx` in an [`Arc`] so it can be installed into any number of scopes.
    pub fn new(ctx: SecurityContext) -> Self {
        Self { ctx: Arc::new(ctx) }
    }

    /// Runs the synchronous closure `f` with this guard's context installed as
    /// the current task-local context, and returns `f`'s result.
    pub fn scope<F, R>(&self, f: F) -> R
    where
        F: FnOnce() -> R,
    {
        CURRENT_SECURITY_CONTEXT.sync_scope(Arc::clone(&self.ctx), f)
    }

    /// Runs the future produced by `f` with this guard's context installed as
    /// the current task-local context, across every `.await` inside it.
    ///
    /// `f` is itself invoked inside the scope, so synchronous work a closure
    /// performs before returning its future also sees the context.
    pub async fn scope_async<F, Fut, R>(&self, f: F) -> R
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = R>,
    {
        // Calling `f()` directly as the argument to `scope` would run it before
        // the task-local is set. Deferring it into the scoped future makes it
        // run on the first poll, when the context is already installed.
        CURRENT_SECURITY_CONTEXT
            .scope(Arc::clone(&self.ctx), async move { f().await })
            .await
    }

    /// Returns another handle to the context this guard installs.
    pub fn context(&self) -> Arc<SecurityContext> {
        Arc::clone(&self.ctx)
    }
}
/// Runs `f` with `ctx` installed as the current task-local security context.
///
/// Despite the name, nothing is set persistently: `ctx` is visible only while
/// `f` runs. Same effect as [`with_security_context`], but routed through a
/// temporary [`SecurityContextGuard`].
pub fn set_security_context<F, R>(ctx: SecurityContext, f: F) -> R
where
    F: FnOnce() -> R,
{
    SecurityContextGuard::new(ctx).scope(f)
}
/// Returns the task-local security context if the caller is inside a scope
/// that installed one, or `None` otherwise. Never panics.
pub fn get_security_context() -> Option<Arc<SecurityContext>> {
    let current = CURRENT_SECURITY_CONTEXT.try_with(Arc::clone);
    current.ok()
}
/// Runs the synchronous closure `f` with `ctx` installed as the current
/// task-local security context, and returns `f`'s result.
pub fn with_security_context<F, R>(ctx: SecurityContext, f: F) -> R
where
    F: FnOnce() -> R,
{
    let shared = Arc::new(ctx);
    CURRENT_SECURITY_CONTEXT.sync_scope(shared, f)
}
/// Runs the future produced by `f` with `ctx` installed as the current
/// task-local security context, across every `.await` inside it.
///
/// `f` is itself invoked inside the scope, so synchronous work a closure
/// performs before returning its future also sees the context.
pub async fn with_security_context_async<F, Fut, R>(ctx: SecurityContext, f: F) -> R
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = R>,
{
    // Defer `f()` into the scoped future so it runs on the first poll, after
    // the task-local is set. Passing `f()` directly would run it outside the scope.
    CURRENT_SECURITY_CONTEXT
        .scope(Arc::new(ctx), async move { f().await })
        .await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an authenticated `Authentication` for `principal` with no authorities.
    fn auth_for(principal: &str) -> Authentication {
        Authentication {
            principal: principal.to_string(),
            credentials: None,
            authorities: vec![],
            authenticated: true,
            details: None,
            login_time: chrono::Utc::now(),
        }
    }

    /// Builds a context that already holds an authenticated `principal`.
    async fn context_for(principal: &str) -> SecurityContext {
        let ctx = SecurityContext::new();
        ctx.set_authentication(auth_for(principal)).await;
        ctx
    }

    /// Reads the stored principal without awaiting, for use inside sync scopes.
    fn principal_of(ctx: &SecurityContext) -> Option<String> {
        ctx.authentication
            .try_read()
            .expect("lock should be uncontended in tests")
            .as_ref()
            .map(|a| a.principal.clone())
    }

    #[tokio::test]
    async fn test_security_context() {
        let context = SecurityContext::new();
        assert!(!context.is_authenticated().await);
        assert!(context.get_username().await.is_none());

        context.set_authentication(auth_for("john")).await;
        assert!(context.is_authenticated().await);
        assert_eq!(context.get_username().await, Some("john".to_string()));

        context.clear().await;
        assert!(!context.is_authenticated().await);
        assert!(context.get_authentication().await.is_none());
    }

    #[test]
    fn test_default_security_context() {
        let ctx = SecurityContext::default();
        let stored = ctx
            .authentication
            .try_read()
            .expect("fresh lock should be uncontended");
        assert!(stored.is_none());
    }

    #[tokio::test]
    async fn test_guard_scope_sync() {
        assert!(get_security_context().is_none());
        let guard = SecurityContextGuard::new(context_for("alice").await);
        let expected = guard.context();
        guard.scope(|| {
            let retrieved = get_security_context().expect("context should be set");
            assert!(Arc::ptr_eq(&retrieved, &expected));
            assert_eq!(principal_of(&retrieved), Some("alice".to_string()));
        });
        assert!(get_security_context().is_none());
    }

    #[tokio::test]
    async fn test_set_security_context() {
        let result = set_security_context(context_for("bob").await, || {
            let retrieved = get_security_context().expect("context should be set");
            assert_eq!(principal_of(&retrieved), Some("bob".to_string()));
            42
        });
        assert_eq!(result, 42);
        assert!(get_security_context().is_none());
    }

    #[tokio::test]
    async fn test_with_security_context() {
        let result = with_security_context(context_for("bob").await, || {
            let retrieved = get_security_context().expect("context should be set");
            assert_eq!(principal_of(&retrieved), Some("bob".to_string()));
            42
        });
        assert_eq!(result, 42);
        assert!(get_security_context().is_none());
    }

    #[tokio::test]
    async fn test_with_security_context_async() {
        let username = with_security_context_async(context_for("charlie").await, || async {
            tokio::task::yield_now().await;
            let retrieved = get_security_context().expect("context should be set");
            retrieved.get_username().await
        })
        .await;
        assert_eq!(username, Some("charlie".to_string()));
        assert!(get_security_context().is_none());
    }

    #[tokio::test]
    async fn test_context_propagates_across_await() {
        let guard = SecurityContextGuard::new(context_for("dave").await);
        let username = guard
            .scope_async(|| async {
                let before = get_security_context().unwrap().get_username().await;
                assert_eq!(before, Some("dave".to_string()));
                tokio::task::yield_now().await;
                let after = get_security_context().unwrap().get_username().await;
                assert_eq!(after, Some("dave".to_string()));
                after
            })
            .await;
        assert_eq!(username, Some("dave".to_string()));
        assert!(get_security_context().is_none());
    }

    #[tokio::test]
    async fn test_spawned_task_does_not_inherit_context() {
        let guard = SecurityContextGuard::new(context_for("eve").await);
        // Spawn inside the scope, but await outside it (the sync closure cannot
        // await). The spawned task must run without the task-local context.
        let handle =
            guard.scope(|| tokio::task::spawn(async { get_security_context().is_some() }));
        let inherited = handle.await.expect("spawned task should not panic");
        assert!(
            !inherited,
            "spawned tasks must not inherit the task-local security context"
        );
    }

    #[tokio::test]
    async fn test_guard_context_accessor() {
        let guard = SecurityContextGuard::new(context_for("frank").await);
        let arc = guard.context();
        let username = arc.get_username().await;
        assert_eq!(username, Some("frank".to_string()));
    }
}