use crate::Authentication;
use hiver_http::Request;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Security context attached to a request through its extensions.
///
/// The authentication slot sits behind an `Arc`, so the derived `Clone` yields a
/// handle to the *same* slot: a write through one clone is observed by every other.
#[derive(Clone)]
pub struct SecurityContextExt {
    /// Current authentication, `None` until one is stored; guarded by an async
    /// read/write lock.
    authentication: Arc<RwLock<Option<Authentication>>>,
}
impl SecurityContextExt {
    /// Builds a context whose authentication slot starts out empty.
    pub fn new() -> Self {
        let authentication = Arc::new(RwLock::new(None));
        Self { authentication }
    }

    /// Looks up the shared context stored in `req`'s extensions.
    ///
    /// Returns `None` when no `Arc<SecurityContextExt>` has been inserted.
    pub fn from_request(req: &Request) -> Option<Arc<Self>> {
        req.extensions().get::<Arc<Self>>().map(Arc::clone)
    }

    /// Inserts a brand-new, empty context into `req`'s extensions and hands back
    /// a shared handle to it.
    pub fn set_to_request(req: &mut Request) -> Arc<Self> {
        let shared = Arc::new(Self::new());
        req.extensions_mut().insert(Arc::clone(&shared));
        shared
    }

    /// Returns a copy of the stored authentication, or `None` if the slot is empty.
    pub async fn get_authentication(&self) -> Option<Authentication> {
        let guard = self.authentication.read().await;
        guard.as_ref().cloned()
    }

    /// Replaces whatever authentication is stored with `auth`.
    pub async fn set_authentication(&self, auth: Authentication) {
        *self.authentication.write().await = Some(auth);
    }

    /// Empties the authentication slot.
    pub async fn clear(&self) {
        *self.authentication.write().await = None;
    }

    /// `true` only when an authentication is stored *and* its `authenticated`
    /// flag is set.
    pub async fn is_authenticated(&self) -> bool {
        let guard = self.authentication.read().await;
        matches!(guard.as_ref(), Some(auth) if auth.authenticated)
    }

    /// Returns the stored authentication's `principal`, if any.
    ///
    /// Note: this does not consult the `authenticated` flag.
    pub async fn get_username(&self) -> Option<String> {
        let guard = self.authentication.read().await;
        guard.as_ref().map(|auth| auth.principal.clone())
    }

    /// `true` when an authentication is stored and its `has_authority` check
    /// accepts `authority`; `false` when the slot is empty.
    pub async fn has_authority(&self, authority: &crate::Authority) -> bool {
        let guard = self.authentication.read().await;
        matches!(guard.as_ref(), Some(auth) if auth.has_authority(authority))
    }

    /// `true` when an authentication is stored and its `has_role` check accepts
    /// `role`; `false` when the slot is empty.
    pub async fn has_role(&self, role: &crate::Role) -> bool {
        let guard = self.authentication.read().await;
        matches!(guard.as_ref(), Some(auth) if auth.has_role(role))
    }
}
impl Default for SecurityContextExt {
fn default() -> Self {
Self::new()
}
}
/// Convenience wrapper: fetches the request's security context and returns a
/// copy of its authentication.
///
/// Yields `None` both when the request carries no context and when the context
/// holds no authentication.
pub async fn get_authentication_from_request(req: &Request) -> Option<Authentication> {
    let context = SecurityContextExt::from_request(req)?;
    context.get_authentication().await
}
pub fn set_authentication_to_request(
req: &mut Request,
_auth: Authentication,
) -> Arc<SecurityContextExt> {
let ctx = SecurityContextExt::set_to_request(req);
ctx
}
#[cfg(test)]
mod tests {
    use super::*;
    use hiver_http::{Method, Request};

    /// Round-trips a context through a request and checks that authentication
    /// written through the returned handle is visible through the request.
    #[tokio::test]
    async fn test_security_context_ext() {
        let mut request = Request::from_method_uri(Method::GET, "/test");

        // The handle from `set_to_request` and the one read back from the request
        // must refer to the same allocation.
        let attached = SecurityContextExt::set_to_request(&mut request);
        let fetched = SecurityContextExt::from_request(&request).unwrap();
        assert!(Arc::ptr_eq(&attached, &fetched));

        let expected = Authentication {
            principal: String::from("john"),
            credentials: None,
            authorities: Vec::new(),
            authenticated: true,
            details: None,
            login_time: chrono::Utc::now(),
        };

        attached.set_authentication(expected.clone()).await;
        assert!(attached.is_authenticated().await);
        assert_eq!(attached.get_username().await.as_deref(), Some("john"));

        // A write through one handle is observable via the request-level helper.
        assert_eq!(get_authentication_from_request(&request).await, Some(expected));
    }
}