use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::{Request, Response};
use subtle::ConstantTimeEq as _;
use tonic::Status;
use tonic::body::Body as TonicBody;
use tower::{Layer, Service};
use tracing::warn;
use webgates::accounts::Account;
use webgates::authz::access_hierarchy::AccessHierarchy;
use webgates::authz::access_policy::AccessPolicy;
use webgates::codecs::Codec;
use webgates::codecs::jwt::JwtClaims;
use crate::context::{JwtAuthContext, OptionalJwtAuthContext, StaticTokenAuthorized};
use crate::errors::AuthError;
/// Mode configuration used by [`BearerGate`] when it validates JWTs.
#[derive(Clone)]
pub struct JwtConfig<R: AccessHierarchy + Eq + std::fmt::Display, G: Eq> {
    /// Access policy handed to the JWT runtime for every request.
    policy: AccessPolicy<R, G>,
    /// When `true`, `Layer::layer` builds the optional (anonymous-allowed) service variant.
    optional: bool,
}

impl<R: AccessHierarchy + Eq + std::fmt::Display, G: Eq> std::fmt::Debug for JwtConfig<R, G> {
    /// Prints only `optional`; the policy is omitted and signalled with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("JwtConfig");
        out.field("optional", &self.optional);
        out.finish_non_exhaustive()
    }
}
/// Mode configuration used by [`BearerGate`] when it checks a shared static token.
#[derive(Clone)]
pub struct StaticTokenConfig {
    /// Expected bearer token; never printed by the `Debug` impl.
    token: String,
    /// When `true`, requests are forwarded even without a matching token.
    optional: bool,
}

impl std::fmt::Debug for StaticTokenConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited,
        // so a new secret cannot leak into logs unnoticed.
        let Self { token: _, optional } = self;
        f.debug_struct("StaticTokenConfig")
            .field("token", &"<redacted>")
            .field("optional", optional)
            .finish_non_exhaustive()
    }
}
/// Tower layer that authenticates gRPC requests carrying a bearer credential.
///
/// The mode type `M` selects the behaviour: [`JwtConfig`] validates JWTs with
/// `codec` against `issuer`; [`StaticTokenConfig`] compares against a shared secret.
pub struct BearerGate<C, R, G, M>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq,
{
    /// Expected JWT issuer; only consulted in JWT mode.
    issuer: String,
    /// Shared codec; cloning the gate only bumps the reference count.
    codec: Arc<C>,
    /// Mode-specific configuration.
    mode: M,
    _phantom: std::marker::PhantomData<(R, G)>,
}

// Implemented by hand because `#[derive(Clone)]` would also require `C: Clone`,
// `R: Clone` and `G: Clone`. None of these is needed: `C` lives behind an `Arc`
// and `R`/`G` only appear in `PhantomData`. This impl is a strict superset of the
// derived one, so every previously clonable gate remains clonable.
impl<C, R, G, M> Clone for BearerGate<C, R, G, M>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq,
    M: Clone,
{
    fn clone(&self) -> Self {
        Self {
            issuer: self.issuer.clone(),
            codec: Arc::clone(&self.codec),
            mode: self.mode.clone(),
            _phantom: std::marker::PhantomData,
        }
    }
}
// One generic impl replaces the two per-mode copies, which were identical. It
// applies to any mode type that is itself `Debug`, including `JwtConfig` and
// `StaticTokenConfig`. The codec is omitted, as signalled by `..`.
impl<C, R, G, M> std::fmt::Debug for BearerGate<C, R, G, M>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq,
    M: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BearerGate")
            .field("issuer", &self.issuer)
            .field("mode", &self.mode)
            .finish_non_exhaustive()
    }
}
impl<C, R, G> BearerGate<C, R, G, JwtConfig<R, G>>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    /// Creates a strict JWT gate expecting tokens from `issuer`, decoded with `codec`.
    ///
    /// The policy starts as `AccessPolicy::deny_all()`. In strict mode, requests are
    /// therefore rejected until [`with_policy`](Self::with_policy) or
    /// [`require_login`](Self::require_login) installs a real policy.
    pub(crate) fn new_with_codec(issuer: &str, codec: Arc<C>) -> Self {
        let mode = JwtConfig {
            policy: AccessPolicy::deny_all(),
            optional: false,
        };
        Self {
            issuer: String::from(issuer),
            codec,
            mode,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Replaces the access policy evaluated for authenticated callers.
    pub fn with_policy(self, policy: AccessPolicy<R, G>) -> Self {
        self.update_mode(|mode| mode.policy = policy)
    }

    /// Installs `AccessPolicy::require_role_or_supervisor(R::default())`.
    ///
    /// This presumably admits any logged-in account holding the baseline role or
    /// above (depends on `R`'s `Default` and hierarchy — confirm).
    pub fn require_login(self) -> Self
    where
        R: Default,
    {
        self.update_mode(|mode| {
            mode.policy = AccessPolicy::require_role_or_supervisor(R::default());
        })
    }

    /// Switches to optional mode. The layer builds `JwtBearerService::new_optional`,
    /// which attaches an `OptionalJwtAuthContext` instead of rejecting anonymous callers.
    pub fn allow_anonymous_with_optional_user(self) -> Self {
        self.update_mode(|mode| mode.optional = true)
    }

    /// Converts this gate into a strict static-token gate that keeps the issuer and codec.
    ///
    /// NOTE(review): the service splits the `authorization` header on whitespace, so
    /// a `token` containing whitespace can never match.
    pub fn with_static_token(
        self,
        token: impl Into<String>,
    ) -> BearerGate<C, R, G, StaticTokenConfig> {
        let Self { issuer, codec, .. } = self;
        BearerGate {
            issuer,
            codec,
            mode: StaticTokenConfig {
                token: token.into(),
                optional: false,
            },
            _phantom: std::marker::PhantomData,
        }
    }

    /// Applies `edit` to the JWT mode configuration and hands the gate back.
    fn update_mode(mut self, edit: impl FnOnce(&mut JwtConfig<R, G>)) -> Self {
        edit(&mut self.mode);
        self
    }
}
impl<C, R, G> BearerGate<C, R, G, StaticTokenConfig>
where
C: Codec,
R: AccessHierarchy + Eq + std::fmt::Display,
G: Eq + Clone,
{
pub fn allow_anonymous_with_optional_user(mut self) -> Self {
self.mode.optional = true;
self
}
}
impl<S, C, R, G> Layer<S> for BearerGate<C, R, G, JwtConfig<R, G>>
where
    C: Codec<Payload = JwtClaims<Account<R, G>>> + Send + Sync + 'static,
    R: AccessHierarchy + Eq + std::fmt::Display + Clone + Send + Sync + 'static,
    G: Eq + Clone + Send + Sync + 'static,
{
    type Service = JwtBearerService<C, R, G, S>;

    /// Wraps `inner`, choosing the optional or strict service variant from the mode.
    /// Each call clones the policy and shares the codec through its `Arc`.
    fn layer(&self, inner: S) -> Self::Service {
        let policy = self.mode.policy.clone();
        let codec = Arc::clone(&self.codec);
        if self.mode.optional {
            JwtBearerService::new_optional(inner, &self.issuer, policy, codec)
        } else {
            JwtBearerService::new(inner, &self.issuer, policy, codec)
        }
    }
}
impl<S, C, R, G> Layer<S> for BearerGate<C, R, G, StaticTokenConfig>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    type Service = StaticTokenService<S>;

    /// Wraps `inner` in a static-token check; each service gets its own copy of the secret.
    fn layer(&self, inner: S) -> Self::Service {
        let secret = self.mode.token.clone();
        if self.mode.optional {
            StaticTokenService::new_optional(inner, secret)
        } else {
            StaticTokenService::new(inner, secret)
        }
    }
}
/// Tower service produced by [`BearerGate`] in JWT mode.
///
/// Token validation and policy evaluation are delegated to `runtime`. This type
/// only maps the runtime's verdict onto request extensions or gRPC error responses.
#[doc(hidden)]
#[derive(Clone)]
pub struct JwtBearerService<C, R, G, S>
where
    C: Codec<Payload = JwtClaims<Account<R, G>>>,
    R: AccessHierarchy + Eq + std::fmt::Display + Clone,
    G: Eq + Clone,
{
    /// Wrapped service that receives requests the runtime lets through.
    inner: S,
    /// Shared JWT evaluation logic from `webgates`.
    runtime: webgates::gate::bearer::JwtBearerRuntime<C, R, G>,
}
impl<C, R, G, S> JwtBearerService<C, R, G, S>
where
    C: Codec<Payload = JwtClaims<Account<R, G>>>,
    R: AccessHierarchy + Eq + std::fmt::Display + Clone,
    G: Eq + Clone,
{
    /// Wraps `inner` in a strict JWT gate (runtime built with `optional = false`).
    fn new(inner: S, issuer: &str, policy: AccessPolicy<R, G>, codec: Arc<C>) -> Self {
        Self::with_mode(inner, issuer, policy, codec, false)
    }

    /// Wraps `inner` in an optional JWT gate (runtime built with `optional = true`).
    fn new_optional(inner: S, issuer: &str, policy: AccessPolicy<R, G>, codec: Arc<C>) -> Self {
        Self::with_mode(inner, issuer, policy, codec, true)
    }

    /// Shared constructor; `optional` is passed through to the runtime unchanged.
    fn with_mode(
        inner: S,
        issuer: &str,
        policy: AccessPolicy<R, G>,
        codec: Arc<C>,
        optional: bool,
    ) -> Self {
        Self {
            inner,
            runtime: webgates::gate::bearer::JwtBearerRuntime::new(
                issuer, policy, codec, optional,
            ),
        }
    }

    /// Extracts the bearer token from the `authorization` header.
    ///
    /// Returns `Ok(None)` when the header is absent. Returns `Ok(Some(token))` for a
    /// value of the form `Bearer <token>`, where the scheme is compared
    /// case-insensitively and surrounding whitespace is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::MalformedAuthorizationMetadata`] when the header is not
    /// visible ASCII, uses another scheme, has no token, or has anything after the
    /// token. RFC 6750 §2.1 allows exactly one token after the scheme, so trailing
    /// data is rejected instead of being silently discarded.
    fn extract_bearer_token(req: &Request<TonicBody>) -> Result<Option<&str>, AuthError> {
        let Some(value) = req.headers().get(http::header::AUTHORIZATION) else {
            return Ok(None);
        };
        let text = value
            .to_str()
            .map_err(|_| AuthError::MalformedAuthorizationMetadata)?;
        let mut parts = text.split_whitespace();
        match (parts.next(), parts.next(), parts.next()) {
            (Some(scheme), Some(token), None) if scheme.eq_ignore_ascii_case("Bearer") => {
                Ok(Some(token))
            }
            _ => Err(AuthError::MalformedAuthorizationMetadata),
        }
    }
}
impl<C, R, G, S> Service<Request<TonicBody>> for JwtBearerService<C, R, G, S>
where
S: Service<Request<TonicBody>, Response = Response<TonicBody>, Error = Infallible>
+ Send
+ 'static,
S::Future: Send + 'static,
C: Codec<Payload = JwtClaims<Account<R, G>>> + Send + Sync + 'static,
R: AccessHierarchy + Eq + std::fmt::Display + Clone + Send + Sync + 'static,
G: Eq + Clone + Send + Sync + 'static,
{
type Response = Response<TonicBody>;
type Error = Infallible;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<TonicBody>) -> Self::Future {
#[cfg(feature = "audit-logging")]
use webgates::audit;
#[cfg(feature = "audit-logging")]
let _span = audit::request_span(req.method().as_str(), req.uri().path(), None);
let token_result = Self::extract_bearer_token(&req);
let token = match token_result {
Ok(t) => t,
Err(err) => {
let status = err.into_status();
return Box::pin(async move { Ok(status_to_response(status)) });
}
};
let eval = self.runtime.evaluate(token);
match eval {
webgates::gate::bearer::BearerEvaluation::JwtOptionalAnonymous => {
req.extensions_mut()
.insert(OptionalJwtAuthContext::<R, G>::anonymous());
let fut = self.inner.call(req);
Box::pin(fut)
}
webgates::gate::bearer::BearerEvaluation::JwtOptionalAuthorized {
account,
registered_claims,
} => {
req.extensions_mut()
.insert(OptionalJwtAuthContext::<R, G>::authenticated(
account,
registered_claims,
));
let fut = self.inner.call(req);
Box::pin(fut)
}
webgates::gate::bearer::BearerEvaluation::JwtDenyAllPolicy => {
#[cfg(feature = "audit-logging")]
audit::denied(None, "policy_denies_all");
let status = AuthError::PolicyDeniesAll.into_status();
Box::pin(async move { Ok(status_to_response(status)) })
}
webgates::gate::bearer::BearerEvaluation::JwtMissingToken => {
#[cfg(feature = "audit-logging")]
audit::denied(None, "missing_authorization_header");
let status = AuthError::MissingAuthorizationMetadata.into_status();
Box::pin(async move { Ok(status_to_response(status)) })
}
webgates::gate::bearer::BearerEvaluation::JwtInvalidToken => {
#[cfg(feature = "audit-logging")]
audit::jwt_invalid_token("validation_failed");
let status = AuthError::InvalidToken.into_status();
Box::pin(async move { Ok(status_to_response(status)) })
}
webgates::gate::bearer::BearerEvaluation::JwtInvalidIssuer { expected, actual } => {
#[cfg(feature = "audit-logging")]
audit::jwt_invalid_issuer(&expected, &actual);
warn!(
"JWT issuer mismatch. Expected='{}', Actual='{}'",
expected, actual
);
let status = AuthError::InvalidIssuer.into_status();
Box::pin(async move { Ok(status_to_response(status)) })
}
webgates::gate::bearer::BearerEvaluation::JwtPolicyDenied { account_id } => {
#[cfg(not(feature = "audit-logging"))]
let _ = account_id;
#[cfg(feature = "audit-logging")]
audit::denied(Some(&account_id), "policy_denied");
let status = AuthError::PolicyDenied.into_status();
Box::pin(async move { Ok(status_to_response(status)) })
}
webgates::gate::bearer::BearerEvaluation::JwtAuthorized {
account,
registered_claims,
} => {
#[cfg(feature = "audit-logging")]
audit::authorized(&account.account_id, None);
req.extensions_mut()
.insert(JwtAuthContext::new(account, registered_claims));
let fut = self.inner.call(req);
Box::pin(fut)
}
_ => {
let status = AuthError::Internal.into_status();
Box::pin(async move { Ok(status_to_response(status)) })
}
}
}
}
/// Tower service produced by [`BearerGate`] in static-token mode.
#[doc(hidden)]
#[derive(Clone)]
pub struct StaticTokenService<S> {
    /// Wrapped service that receives forwarded requests.
    inner: S,
    /// Expected bearer token, compared in constant time by `call`.
    token: String,
    /// When `true`, forward every request and record the match result instead of rejecting.
    optional: bool,
}
impl<S> StaticTokenService<S> {
    /// Wraps `inner` in a strict static-token check.
    fn new(inner: S, token: String) -> Self {
        Self {
            inner,
            token,
            optional: false,
        }
    }

    /// Wraps `inner` in an optional static-token check that always forwards.
    fn new_optional(inner: S, token: String) -> Self {
        Self {
            inner,
            token,
            optional: true,
        }
    }

    /// Returns the token from a well-formed `Bearer <token>` `authorization` header.
    ///
    /// Yields `None` in any of these cases:
    /// - the header is absent or not visible ASCII;
    /// - it uses a different scheme (the scheme is compared case-insensitively);
    /// - it has no token;
    /// - it carries anything after the token.
    ///
    /// RFC 6750 §2.1 allows exactly one token after the scheme, so trailing data
    /// makes the header unusable rather than being silently dropped. Because the
    /// value is split on whitespace, a configured secret containing whitespace can
    /// never match.
    fn extract_bearer_token(req: &Request<TonicBody>) -> Option<&str> {
        let text = req
            .headers()
            .get(http::header::AUTHORIZATION)?
            .to_str()
            .ok()?;
        let mut parts = text.split_whitespace();
        match (parts.next(), parts.next(), parts.next()) {
            (Some(scheme), Some(token), None) if scheme.eq_ignore_ascii_case("Bearer") => {
                Some(token)
            }
            _ => None,
        }
    }
}
impl<S> Service<Request<TonicBody>> for StaticTokenService<S>
where
S: Service<Request<TonicBody>, Response = Response<TonicBody>, Error = Infallible>
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Response = Response<TonicBody>;
type Error = Infallible;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: Request<TonicBody>) -> Self::Future {
#[cfg(feature = "audit-logging")]
use webgates::audit;
#[cfg(feature = "audit-logging")]
let _span = audit::request_span(req.method().as_str(), req.uri().path(), None);
if self.optional {
let provided = Self::extract_bearer_token(&req);
let authorized =
provided.is_some_and(|t| bool::from(t.as_bytes().ct_eq(self.token.as_bytes())));
req.extensions_mut()
.insert(StaticTokenAuthorized::new(authorized));
let fut = self.inner.call(req);
return Box::pin(fut);
}
let Some(provided) = Self::extract_bearer_token(&req) else {
#[cfg(feature = "audit-logging")]
audit::denied(None, "missing_authorization_header");
let status = AuthError::MissingAuthorizationMetadata.into_status();
return Box::pin(async move { Ok(status_to_response(status)) });
};
if !bool::from(provided.as_bytes().ct_eq(self.token.as_bytes())) {
#[cfg(feature = "audit-logging")]
audit::denied(None, "static_token_mismatch");
let status = AuthError::PolicyDenied.into_status();
return Box::pin(async move { Ok(status_to_response(status)) });
}
req.extensions_mut()
.insert(StaticTokenAuthorized::new(true));
let fut = self.inner.call(req);
Box::pin(fut)
}
}
fn status_to_response(status: Status) -> Response<TonicBody> {
status.into_http::<TonicBody>()
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    //! Tests for both gate modes.
    //!
    //! Rejections are built with `Status::into_http`, which places `grpc-status` in
    //! the response headers. The inner services used here return a bare `Response`
    //! without that header. Every test that expects the request to be forwarded
    //! therefore also asserts that `grpc-status` is absent. Without that check, a
    //! wrongly rejected request would skip the assertions inside the inner service
    //! and the test would pass vacuously.

    use std::sync::Arc;

    use chrono::Utc;
    use http::Request;
    use tower::{Layer, ServiceExt};
    use webgates::accounts::Account;
    use webgates::authz::access_policy::AccessPolicy;
    use webgates::codecs::Codec as _;
    use webgates::codecs::jwt::{JsonWebToken, JwtClaims, RegisteredClaims};
    use webgates::groups::Group;
    use webgates::roles::Role;

    use super::*;
    use crate::context::{JwtAuthContext, OptionalJwtAuthContext, StaticTokenAuthorized};

    type TestCodec = JsonWebToken<JwtClaims<Account<Role, Group>>>;
    type TestBearerGateJwt = BearerGate<TestCodec, Role, Group, JwtConfig<Role, Group>>;
    type TestBearerGateStatic = BearerGate<TestCodec, Role, Group, StaticTokenConfig>;

    /// gRPC `PERMISSION_DENIED` status code.
    const GRPC_PERMISSION_DENIED: u32 = 7;
    /// gRPC `UNAUTHENTICATED` status code.
    const GRPC_UNAUTHENTICATED: u32 = 16;

    fn install_jwt_crypto_provider() {
        use webgates::codecs::jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER as JWT_CRYPTO_PROVIDER;
        // Every test calls this and a provider may already be installed, so the
        // result is deliberately ignored.
        let _ = JWT_CRYPTO_PROVIDER.install_default();
    }

    /// Installs the crypto provider and returns a default codec.
    fn test_codec() -> Arc<TestCodec> {
        install_jwt_crypto_provider();
        Arc::new(TestCodec::default())
    }

    /// Encodes a JWT for a new `account_id` account issued by `issuer`, expiring in 60 s.
    fn signed_token(codec: &TestCodec, account_id: &str, issuer: &str) -> String {
        let account = Account::<Role, Group>::new(account_id);
        let now = u64::try_from(Utc::now().timestamp())
            .expect("system clock is after the Unix epoch");
        let claims = JwtClaims::new(account, RegisteredClaims::new(issuer, now + 60));
        let encoded = codec.encode(&claims).expect("encode jwt");
        String::from_utf8(encoded).expect("utf-8")
    }

    fn make_request_no_auth() -> Request<TonicBody> {
        Request::builder()
            .uri("/test.Service/Method")
            .body(TonicBody::empty())
            .expect("request construction should succeed")
    }

    /// Builds a request whose `authorization` header is exactly `value`.
    fn make_request_with_authorization(value: &str) -> Request<TonicBody> {
        Request::builder()
            .uri("/test.Service/Method")
            .header(http::header::AUTHORIZATION, value)
            .body(TonicBody::empty())
            .expect("request construction should succeed")
    }

    fn make_request_with_bearer(token: &str) -> Request<TonicBody> {
        make_request_with_authorization(&format!("Bearer {token}"))
    }

    fn echo_service() -> impl Service<
        Request<TonicBody>,
        Response = Response<TonicBody>,
        Error = Infallible,
        Future = impl Future<Output = Result<Response<TonicBody>, Infallible>> + Send + 'static,
    > + Clone {
        tower::service_fn(|_req: Request<TonicBody>| async {
            Ok::<_, Infallible>(Response::new(TonicBody::empty()))
        })
    }

    /// Reads the numeric `grpc-status` header, if present.
    fn grpc_status(resp: &Response<TonicBody>) -> Option<u32> {
        resp.headers()
            .get("grpc-status")
            .and_then(|v| v.to_str().ok())
            .and_then(|v| v.parse().ok())
    }

    /// Asserts that the response came from the inner service, not from the gate.
    fn assert_forwarded(resp: &Response<TonicBody>) {
        assert_eq!(
            grpc_status(resp),
            None,
            "request must reach the inner service"
        );
    }

    #[tokio::test]
    async fn strict_jwt_missing_token() {
        let gate: TestBearerGateJwt = BearerGate::new_with_codec("issuer", test_codec())
            .with_policy(AccessPolicy::require_role(Role::Admin));
        let resp = gate
            .layer(echo_service())
            .oneshot(make_request_no_auth())
            .await
            .expect("no error");
        assert_eq!(
            resp.status(),
            http::StatusCode::OK,
            "tonic uses 200 for gRPC status errors"
        );
        assert_eq!(
            grpc_status(&resp),
            Some(GRPC_UNAUTHENTICATED),
            "expected UNAUTHENTICATED grpc-status"
        );
    }

    #[tokio::test]
    async fn strict_jwt_malformed_token() {
        let gate: TestBearerGateJwt = BearerGate::new_with_codec("issuer", test_codec())
            .with_policy(AccessPolicy::require_role(Role::Admin));
        let resp = gate
            .layer(echo_service())
            .oneshot(make_request_with_authorization("Basic dXNlcjpwYXNz"))
            .await
            .expect("no error");
        assert_eq!(grpc_status(&resp), Some(GRPC_UNAUTHENTICATED));
    }

    #[tokio::test]
    async fn strict_jwt_deny_all_policy() {
        // No policy configured: `new_with_codec` starts from `AccessPolicy::deny_all()`.
        let gate: TestBearerGateJwt = BearerGate::new_with_codec("issuer", test_codec());
        let resp = gate
            .layer(echo_service())
            .oneshot(make_request_with_bearer("any-token"))
            .await
            .expect("no error");
        assert_eq!(grpc_status(&resp), Some(GRPC_UNAUTHENTICATED));
    }

    #[tokio::test]
    async fn strict_jwt_authorized() {
        let codec = test_codec();
        let token = signed_token(&codec, "user", "issuer");
        let gate: TestBearerGateJwt =
            BearerGate::new_with_codec("issuer", Arc::clone(&codec)).require_login();
        let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| {
            let ctx = req
                .extensions()
                .get::<JwtAuthContext<Role, Group>>()
                .cloned();
            async move {
                assert!(ctx.is_some(), "JwtAuthContext must be inserted on success");
                Ok::<_, Infallible>(Response::new(TonicBody::empty()))
            }
        }));
        let resp = svc
            .oneshot(make_request_with_bearer(&token))
            .await
            .expect("no error");
        assert_forwarded(&resp);
    }

    #[tokio::test]
    async fn optional_jwt_no_token() {
        let gate: TestBearerGateJwt = BearerGate::new_with_codec("issuer", test_codec())
            .allow_anonymous_with_optional_user();
        let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| async move {
            let ctx = req
                .extensions()
                .get::<OptionalJwtAuthContext<Role, Group>>()
                .cloned()
                .expect("OptionalJwtAuthContext must always be inserted");
            assert!(!ctx.is_authenticated(), "must be anonymous when no token");
            Ok::<_, Infallible>(Response::new(TonicBody::empty()))
        }));
        let resp = svc
            .oneshot(make_request_no_auth())
            .await
            .expect("no error");
        assert_forwarded(&resp);
    }

    #[tokio::test]
    async fn optional_jwt_with_valid_token() {
        let codec = test_codec();
        let token = signed_token(&codec, "user", "issuer");
        let gate: TestBearerGateJwt = BearerGate::new_with_codec("issuer", Arc::clone(&codec))
            .allow_anonymous_with_optional_user();
        let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| async move {
            let ctx = req
                .extensions()
                .get::<OptionalJwtAuthContext<Role, Group>>()
                .cloned()
                .expect("OptionalJwtAuthContext must be present");
            assert!(ctx.is_authenticated(), "must be authenticated");
            Ok::<_, Infallible>(Response::new(TonicBody::empty()))
        }));
        let resp = svc
            .oneshot(make_request_with_bearer(&token))
            .await
            .expect("no error");
        assert_forwarded(&resp);
    }

    #[tokio::test]
    async fn optional_jwt_with_invalid_token() {
        let gate: TestBearerGateJwt = BearerGate::new_with_codec("issuer", test_codec())
            .allow_anonymous_with_optional_user();
        let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| async move {
            let ctx = req
                .extensions()
                .get::<OptionalJwtAuthContext<Role, Group>>()
                .cloned()
                .expect("OptionalJwtAuthContext must be present");
            assert!(
                !ctx.is_authenticated(),
                "invalid token must result in anonymous context"
            );
            Ok::<_, Infallible>(Response::new(TonicBody::empty()))
        }));
        let resp = svc
            .oneshot(make_request_with_bearer("invalid-token-value"))
            .await
            .expect("no error");
        assert_forwarded(&resp);
    }

    #[tokio::test]
    async fn static_token_strict_missing() {
        let gate: TestBearerGateStatic =
            BearerGate::new_with_codec("issuer", test_codec()).with_static_token("secret");
        let resp = gate
            .layer(echo_service())
            .oneshot(make_request_no_auth())
            .await
            .expect("no error");
        assert_eq!(grpc_status(&resp), Some(GRPC_UNAUTHENTICATED));
    }

    #[tokio::test]
    async fn static_token_strict_wrong_token() {
        let gate: TestBearerGateStatic = BearerGate::new_with_codec("issuer", test_codec())
            .with_static_token("correct-secret");
        let resp = gate
            .layer(echo_service())
            .oneshot(make_request_with_bearer("wrong-secret"))
            .await
            .expect("no error");
        assert_eq!(grpc_status(&resp), Some(GRPC_PERMISSION_DENIED));
    }

    #[tokio::test]
    async fn static_token_strict_success() {
        let gate: TestBearerGateStatic =
            BearerGate::new_with_codec("issuer", test_codec()).with_static_token("my-secret");
        let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| async move {
            let auth = req
                .extensions()
                .get::<StaticTokenAuthorized>()
                .copied()
                .expect("StaticTokenAuthorized must be inserted");
            assert!(auth.is_authorized(), "must be authorized");
            Ok::<_, Infallible>(Response::new(TonicBody::empty()))
        }));
        let resp = svc
            .oneshot(make_request_with_bearer("my-secret"))
            .await
            .expect("no error");
        assert_forwarded(&resp);
    }

    #[tokio::test]
    async fn static_token_optional_success() {
        let gate: TestBearerGateStatic = BearerGate::new_with_codec("issuer", test_codec())
            .with_static_token("my-secret")
            .allow_anonymous_with_optional_user();
        let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| async move {
            let auth = req
                .extensions()
                .get::<StaticTokenAuthorized>()
                .copied()
                .expect("StaticTokenAuthorized must be inserted");
            assert!(auth.is_authorized());
            Ok::<_, Infallible>(Response::new(TonicBody::empty()))
        }));
        let resp = svc
            .oneshot(make_request_with_bearer("my-secret"))
            .await
            .expect("no error");
        assert_forwarded(&resp);
    }

    #[tokio::test]
    async fn static_token_optional_anonymous() {
        let gate: TestBearerGateStatic = BearerGate::new_with_codec("issuer", test_codec())
            .with_static_token("my-secret")
            .allow_anonymous_with_optional_user();
        let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| async move {
            let auth = req
                .extensions()
                .get::<StaticTokenAuthorized>()
                .copied()
                .expect("StaticTokenAuthorized must be inserted");
            assert!(!auth.is_authorized());
            Ok::<_, Infallible>(Response::new(TonicBody::empty()))
        }));
        let resp = svc
            .oneshot(make_request_no_auth())
            .await
            .expect("no error");
        assert_forwarded(&resp);
    }

    #[cfg(feature = "audit-logging")]
    mod audit_logging_tests {
        //! Re-runs the main decision paths with `audit-logging` enabled, so the
        //! feature-gated audit calls are compiled and executed. Helpers come from
        //! the parent test module.

        use std::convert::Infallible;
        use std::sync::Arc;

        use http::{Request, Response};
        use tower::{Layer, ServiceExt};
        use webgates::authz::access_policy::AccessPolicy;
        use webgates::groups::Group;
        use webgates::roles::Role;

        use super::*;
        use crate::context::{JwtAuthContext, StaticTokenAuthorized};

        #[tokio::test]
        async fn audit_jwt_authorized() {
            let codec = test_codec();
            let token = signed_token(&codec, "audit-user", "issuer");
            let gate: TestBearerGateJwt =
                BearerGate::new_with_codec("issuer", Arc::clone(&codec)).require_login();
            let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| {
                let ctx = req
                    .extensions()
                    .get::<JwtAuthContext<Role, Group>>()
                    .cloned();
                async move {
                    assert!(ctx.is_some(), "JwtAuthContext must be present");
                    Ok::<_, Infallible>(Response::new(TonicBody::empty()))
                }
            }));
            let resp = svc
                .oneshot(make_request_with_bearer(&token))
                .await
                .expect("no error");
            assert_forwarded(&resp);
        }

        #[tokio::test]
        async fn audit_jwt_policy_denied() {
            let codec = test_codec();
            let token = signed_token(&codec, "audit-user", "issuer");
            let gate: TestBearerGateJwt =
                BearerGate::new_with_codec("issuer", Arc::clone(&codec))
                    .with_policy(AccessPolicy::require_role(Role::Admin));
            let resp = gate
                .layer(echo_service())
                .oneshot(make_request_with_bearer(&token))
                .await
                .expect("no error");
            assert_eq!(grpc_status(&resp), Some(GRPC_PERMISSION_DENIED));
        }

        #[tokio::test]
        async fn audit_jwt_invalid_issuer() {
            let codec = test_codec();
            let token = signed_token(&codec, "audit-user", "other-issuer");
            let gate: TestBearerGateJwt =
                BearerGate::new_with_codec("issuer", Arc::clone(&codec)).require_login();
            let resp = gate
                .layer(echo_service())
                .oneshot(make_request_with_bearer(&token))
                .await
                .expect("no error");
            assert_eq!(grpc_status(&resp), Some(GRPC_UNAUTHENTICATED));
        }

        #[tokio::test]
        async fn audit_jwt_invalid_token() {
            let gate: TestBearerGateJwt =
                BearerGate::new_with_codec("issuer", test_codec()).require_login();
            let resp = gate
                .layer(echo_service())
                .oneshot(make_request_with_bearer("not-a-valid-jwt"))
                .await
                .expect("no error");
            assert_eq!(grpc_status(&resp), Some(GRPC_UNAUTHENTICATED));
        }

        #[tokio::test]
        async fn audit_jwt_missing_auth() {
            let gate: TestBearerGateJwt =
                BearerGate::new_with_codec("issuer", test_codec()).require_login();
            let resp = gate
                .layer(echo_service())
                .oneshot(make_request_no_auth())
                .await
                .expect("no error");
            assert_eq!(grpc_status(&resp), Some(GRPC_UNAUTHENTICATED));
        }

        #[tokio::test]
        async fn audit_static_token_mismatch() {
            let gate: TestBearerGateStatic =
                BearerGate::new_with_codec("issuer", test_codec()).with_static_token("correct");
            let resp = gate
                .layer(echo_service())
                .oneshot(make_request_with_bearer("wrong"))
                .await
                .expect("no error");
            assert_eq!(grpc_status(&resp), Some(GRPC_PERMISSION_DENIED));
        }

        #[tokio::test]
        async fn audit_static_token_authorized() {
            let gate: TestBearerGateStatic =
                BearerGate::new_with_codec("issuer", test_codec()).with_static_token("secret");
            let svc = gate.layer(tower::service_fn(|req: Request<TonicBody>| async move {
                let auth = req
                    .extensions()
                    .get::<StaticTokenAuthorized>()
                    .copied()
                    .expect("StaticTokenAuthorized must be inserted");
                assert!(auth.is_authorized());
                Ok::<_, Infallible>(Response::new(TonicBody::empty()))
            }));
            let resp = svc
                .oneshot(make_request_with_bearer("secret"))
                .await
                .expect("no error");
            assert_forwarded(&resp);
        }
    }
}