use std::sync::Arc;
use tasker_shared::types::{SecurityContext, SecurityService};
use tonic::{Request, Status};
/// String name `"security-context"` associated with a request's security context.
///
/// Not referenced by the library target itself (hence the `allow` below).
/// NOTE(review): `SecurityContextExt` looks the context up by type, not by
/// this string — confirm what the test targets use this key for.
#[allow(dead_code, reason = "dead in --lib, used by test targets")]
pub const SECURITY_CONTEXT_KEY: &str = "security-context";
/// Validates gRPC request credentials against an optional [`SecurityService`].
///
/// Cloning only bumps the [`Arc`] reference count; the underlying service is
/// shared, not copied.
#[derive(Debug, Clone)]
pub struct AuthInterceptor {
    /// Backing service. `None` means every request is let through with a
    /// disabled security context.
    security_service: Option<Arc<SecurityService>>,
}
impl AuthInterceptor {
    /// Creates an interceptor backed by `security_service`.
    ///
    /// Passing `None` yields an interceptor that never rejects a request; see
    /// [`AuthInterceptor::authenticate`].
    pub fn new(security_service: Option<Arc<SecurityService>>) -> Self {
        Self { security_service }
    }

    /// Returns `true` only when a security service is configured *and* it
    /// reports itself as enabled.
    #[allow(dead_code, reason = "dead in --lib, used by test targets")]
    pub fn is_enabled(&self) -> bool {
        self.security_service
            .as_ref()
            .is_some_and(|svc| svc.is_enabled())
    }

    /// Authenticates an incoming request and returns its security context.
    ///
    /// If no security service is configured, or the configured one is
    /// disabled, this returns [`SecurityContext::disabled_context`] without
    /// inspecting the request.
    ///
    /// Otherwise the credential is taken from request metadata, in order:
    /// 1. an `authorization` header using the `Bearer` scheme (scheme matched
    ///    case-insensitively), validated by `authenticate_bearer`;
    /// 2. the header named by `SecurityService::api_key_header`, validated by
    ///    `authenticate_api_key`.
    ///
    /// A bearer token takes precedence: if one is present and fails
    /// validation, the API key is not tried.
    ///
    /// # Errors
    ///
    /// Returns [`Status::unauthenticated`] when the supplied credential is
    /// rejected, or when no credential is present. The underlying error is
    /// logged rather than sent to the client.
    pub async fn authenticate<T>(&self, request: &Request<T>) -> Result<SecurityContext, Status> {
        let security_service = match &self.security_service {
            Some(svc) if svc.is_enabled() => svc,
            // Authentication is off: every request gets the disabled context.
            _ => return Ok(SecurityContext::disabled_context()),
        };

        let metadata = request.metadata();

        let bearer_token = metadata
            .get("authorization")
            .and_then(|v| v.to_str().ok())
            .and_then(Self::parse_bearer)
            .map(str::to_owned);

        if let Some(token) = bearer_token {
            return security_service
                .authenticate_bearer(&token)
                .await
                .map_err(|e| {
                    tracing::warn!(error = %e, "Bearer token authentication failed");
                    Status::unauthenticated("Invalid or expired credentials")
                });
        }

        // Only consult the API-key header when no bearer token was supplied.
        let api_key = metadata
            .get(security_service.api_key_header())
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned);

        match api_key {
            Some(key) => security_service.authenticate_api_key(&key).map_err(|e| {
                tracing::warn!(error = %e, "API key authentication failed");
                Status::unauthenticated("Invalid or expired credentials")
            }),
            None => Err(Status::unauthenticated(
                "Authentication required. Provide Bearer token or API key.",
            )),
        }
    }

    /// Extracts the token from an `authorization` header value that uses the
    /// `Bearer` scheme.
    ///
    /// HTTP authentication scheme names are case-insensitive (RFC 7235 §2.1).
    /// So `Bearer`, `bearer` and `BEARER` are all accepted, and whitespace
    /// around the token is ignored.
    ///
    /// Returns `None` for any other scheme (e.g. `Basic`) or when nothing
    /// follows the scheme name. A value of `"Bearer "` yields `Some("")`,
    /// which is passed on to bearer validation as before.
    fn parse_bearer(header_value: &str) -> Option<&str> {
        let (scheme, token) = header_value.trim_start().split_once(' ')?;
        scheme
            .eq_ignore_ascii_case("bearer")
            .then(|| token.trim())
    }
}
/// Accessor for the [`SecurityContext`] associated with a request.
#[allow(dead_code, reason = "dead in --lib, used by test targets")]
pub trait SecurityContextExt {
    /// Returns the attached security context, or `None` if there is none.
    fn security_context(&self) -> Option<&SecurityContext>;
}
impl<T> SecurityContextExt for Request<T> {
    /// Looks the [`SecurityContext`] up by type in the request's extensions.
    fn security_context(&self) -> Option<&SecurityContext> {
        let extensions = self.extensions();
        extensions.get::<SecurityContext>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tasker_shared::types::Permission;

    #[test]
    fn test_new_with_none() {
        let interceptor = AuthInterceptor::new(None);
        assert!(!interceptor.is_enabled());
    }

    #[test]
    fn test_clone() {
        let interceptor = AuthInterceptor::new(None);
        let cloned = interceptor.clone();
        assert!(!cloned.is_enabled());
    }

    #[test]
    fn test_debug() {
        let interceptor = AuthInterceptor::new(None);
        let debug = format!("{interceptor:?}");
        assert!(debug.contains("AuthInterceptor"));
    }

    #[tokio::test]
    async fn test_authenticate_disabled_returns_permissive_context() {
        let interceptor = AuthInterceptor::new(None);
        assert!(!interceptor.is_enabled());

        let request = Request::new(());
        let ctx = interceptor
            .authenticate(&request)
            .await
            .expect("disabled interceptor must not reject requests");
        assert!(ctx.has_permission(&Permission::TasksRead));
    }

    #[tokio::test]
    async fn test_authenticate_disabled_ignores_supplied_credentials() {
        let interceptor = AuthInterceptor::new(None);
        let mut request = Request::new(());
        request
            .metadata_mut()
            .insert("authorization", "Bearer not-a-real-token".parse().unwrap());

        // With no security service, credentials are never validated.
        let ctx = interceptor
            .authenticate(&request)
            .await
            .expect("disabled interceptor must not validate credentials");
        assert!(ctx.has_permission(&Permission::TasksRead));
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_security_context_ext_none_when_not_set() {
        let request = Request::new(());
        assert!(request.security_context().is_none());
    }

    #[test]
    fn test_security_context_key_constant() {
        assert_eq!(SECURITY_CONTEXT_KEY, "security-context");
    }
}