use crate::validators::{
authorizer::{ClerkAuthorizer, ClerkError, ClerkRequest},
jwks::JwksProvider,
};
use rocket::{
http::Status,
request::{FromRequest, Outcome},
Request,
};
use super::authorizer::ClerkJwt;
/// Adapter that exposes a borrowed Rocket [`Request`] through the
/// [`ClerkRequest`] trait, giving access to its headers and cookies as owned
/// strings.
impl<'r> ClerkRequest for &'r Request<'_> {
    /// Looks up the header named `key` on the request and returns an owned
    /// copy of its value, or `None` when the header is absent.
    fn get_header(&self, key: &str) -> Option<String> {
        let value = self.headers().get_one(key)?;
        Some(value.to_owned())
    }

    /// Looks up the cookie named `key` in the request's cookie jar and returns
    /// an owned copy of its value, or `None` when no such cookie exists.
    fn get_cookie(&self, key: &str) -> Option<String> {
        let cookie = self.cookies().get(key)?;
        Some(cookie.value().to_owned())
    }
}
/// Configuration consumed by [`ClerkGuard`]; it is looked up from Rocket's
/// managed state when the guard runs.
pub struct ClerkGuardConfig<J>
where
    J: JwksProvider,
{
    /// Authorizer the guard delegates to when a request has to be checked.
    pub authorizer: ClerkAuthorizer<J>,
    /// Optional list of request paths to protect. When `Some`, only requests
    /// whose path exactly equals one of the entries are authorized; every
    /// other request passes the guard without a JWT. When `None`, every
    /// request is authorized.
    pub routes: Option<Vec<String>>,
}
impl<J> ClerkGuardConfig<J>
where
    J: JwksProvider,
{
    /// Builds a guard configuration.
    ///
    /// * `jwks_provider` - handed to [`ClerkAuthorizer::new`] to construct the
    ///   authorizer stored in the configuration.
    /// * `routes` - stored as-is; see the `routes` field for how it is
    ///   interpreted.
    /// * `validate_session_cookie` - forwarded unchanged to
    ///   [`ClerkAuthorizer::new`]. NOTE(review): presumably toggles checking
    ///   of the session cookie in addition to headers; confirm against the
    ///   authorizer.
    pub fn new(jwks_provider: J, routes: Option<Vec<String>>, validate_session_cookie: bool) -> Self {
        Self {
            authorizer: ClerkAuthorizer::new(jwks_provider, validate_session_cookie),
            routes,
        }
    }
}
/// Rocket request guard that carries the outcome of Clerk authorization for
/// the current request.
pub struct ClerkGuard<J>
where
    J: JwksProvider + Send + Sync,
{
    /// The JWT returned by the authorizer, or `None` when the request path was
    /// not among the configured routes and authorization was skipped.
    pub jwt: Option<ClerkJwt>,
    /// Zero-sized marker tying the guard to the provider type `J`, which
    /// selects the `ClerkGuardConfig<J>` read from managed state.
    _marker: std::marker::PhantomData<J>,
}
#[rocket::async_trait]
impl<'r, J: JwksProvider + Send + Sync + 'static> FromRequest<'r> for ClerkGuard<J> {
type Error = ClerkError;
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let config = request
.rocket()
.state::<ClerkGuardConfig<J>>()
.expect("ClerkGuardConfig not found in managed state");
match &config.routes {
Some(route_matches) => {
let path = request.uri().path();
let path_not_in_specified_routes = route_matches.iter().find(|&route| route == &path.to_string()).is_none();
if path_not_in_specified_routes {
return Outcome::Success(ClerkGuard {
jwt: None,
_marker: std::marker::PhantomData,
});
}
}
None => {}
}
match config.authorizer.authorize(&request).await {
Ok(jwt) => {
request.local_cache(|| jwt.clone());
return Outcome::Success(ClerkGuard {
jwt: Some(jwt),
_marker: std::marker::PhantomData,
});
}
Err(error) => match error {
ClerkError::Unauthorized(msg) => Outcome::Error((Status::Unauthorized, ClerkError::Unauthorized(msg))),
ClerkError::InternalServerError(msg) => Outcome::Error((Status::InternalServerError, ClerkError::InternalServerError(msg))),
},
}
}
}