pub mod api_key;
pub mod authentication_context;
pub mod jwk;
pub mod jwt;
use actix_web::http::header::{self, HeaderMap};
use async_trait::async_trait;
pub use authentication_context::{AuthenticationContext, authentication_context_provider};
use serde_json::Value;
use sha2::Digest;
use thiserror::Error;
use tracing::debug;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fmt::Display};
use utoipa::IntoResponses;
use crate::security::{
api_key::ApiKeyAuthContextConfig,
jwt::{JwtAuthContextConfig, JwtClaims},
};
/// Error produced when a request cannot be authenticated.
///
/// Wraps a human-readable message, which is also its `Display` output (see the
/// `#[error]` attribute). It is declared to utoipa's `IntoResponses` as an HTTP 401
/// Unauthorized response.
#[derive(Debug, Error, IntoResponses)]
#[error("{}", _0)]
#[response(status = StatusCode::UNAUTHORIZED)]
pub struct AuthorizationError(pub String);
/// Credentials presented by a client, typically parsed from the `Authorization` header.
///
/// Each variant carries the raw secret (token or key). `Debug` is implemented by hand so
/// that the secret is redacted and never ends up in logs, panic messages or assertion
/// output.
#[derive(Clone, Deserialize, PartialEq, Serialize)]
pub enum Credentials {
    /// A JSON Web Token, presented with the `Bearer` scheme.
    JsonWebToken(String),
    /// An API key, presented with the `ApiKey` scheme.
    ApiKey(String),
}

impl std::fmt::Debug for Credentials {
    /// Prints the variant name with a redacted payload, e.g. `ApiKey(<redacted>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Credentials::JsonWebToken(_) => "JsonWebToken",
            Credentials::ApiKey(_) => "ApiKey",
        };
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
impl Display for Credentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let text = match self {
Credentials::ApiKey(_) => "ApiKey",
Credentials::JsonWebToken(_) => "JsonWebToken",
};
write!(f, "{text}")
}
}
impl Credentials {
    /// Returns the raw credential payload (the token or key), whatever the variant.
    pub fn credential_data(&self) -> &str {
        match self {
            Credentials::ApiKey(val) | Credentials::JsonWebToken(val) => val,
        }
    }

    /// Parses the `Authorization` request header into [`Credentials`].
    ///
    /// Recognised forms:
    /// * `Bearer <token>` yields [`Credentials::JsonWebToken`];
    /// * `ApiKey <key>` yields [`Credentials::ApiKey`].
    ///
    /// The scheme name is matched case-insensitively, because HTTP authentication
    /// schemes are case-insensitive (RFC 7235 §2.1). Whitespace around the header value
    /// and around the credential is ignored.
    ///
    /// # Errors
    ///
    /// Returns `AuthorizationError("Invalid credentials")` when any of these holds:
    /// * the header is missing or is not visible ASCII;
    /// * the header has no `<scheme> <credential>` separator;
    /// * the scheme is unknown;
    /// * the credential is empty.
    pub fn parse_credentials_from_request_headers(
        headers: &HeaderMap,
    ) -> Result<Credentials, AuthorizationError> {
        let auth_header = headers.get(header::AUTHORIZATION);
        auth_header
            .and_then(|header| header.to_str().ok())
            .and_then(|hval| hval.trim().split_once(' '))
            .and_then(|(credential_type, credential_data)| {
                let credential_data = credential_data.trim();
                if credential_data.is_empty() {
                    return None;
                }
                if credential_type.eq_ignore_ascii_case("Bearer") {
                    Some(Credentials::JsonWebToken(credential_data.to_string()))
                } else if credential_type.eq_ignore_ascii_case("ApiKey") {
                    Some(Credentials::ApiKey(credential_data.to_string()))
                } else {
                    None
                }
            })
            .ok_or_else(|| {
                // Never log the header value itself: even a malformed header may carry a
                // secret token or key.
                debug!(
                    header_present = auth_header.is_some(),
                    "Could not parse Authorization header"
                );
                AuthorizationError("Invalid credentials".to_string())
            })
    }
}
/// An authenticator that turns credentials of type `C` into a result of type `R`.
///
/// Failures are reported as `E`, which must be convertible into [`AuthorizationError`].
/// The method is async via `#[async_trait]`; `C` and `R` are required to be `Send`.
#[async_trait]
pub trait Authenticating<C: Send, R: Send, E: Into<AuthorizationError>> {
    /// Authenticates `credentials`, returning the resolved result or an error.
    async fn authenticate(&self, credentials: &C) -> Result<R, E>;
}
/// An authenticated identity: a string id plus free-form JSON attributes.
///
/// On deserialization the id is read from `id` or any of its aliases (`sub`, `user_id`,
/// `userId`). All other top-level fields are collected into `attributes`
/// (`#[serde(flatten)]`), which defaults to an empty map.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct Principal {
    /// Identifier of the principal.
    #[serde(alias = "sub")]
    #[serde(alias = "user_id")]
    #[serde(alias = "userId")]
    id: String,
    /// Additional attributes keyed by name, stored as raw JSON values.
    #[serde(default, flatten)]
    attributes: HashMap<String, Value>,
}
impl Principal {
    /// Returns the principal's identifier.
    pub fn id(&self) -> &str {
        &self.id
    }

    /// Looks up attribute `key` and deserializes it into `T`.
    ///
    /// Returns `None` when the attribute is absent or cannot be deserialized as `T`.
    pub fn get_attribute<T>(&self, key: &str) -> Option<T>
    where
        T: for<'de> Deserialize<'de>,
    {
        // `&serde_json::Value` is itself a serde `Deserializer`, so the stored value is
        // read in place instead of deep-cloning the JSON tree first.
        self.attributes
            .get(key)
            .and_then(|v| T::deserialize(v).ok())
    }

    /// Returns the stored JSON value of attribute `key`, if present.
    pub fn get_raw_attribute(&self, key: &str) -> Option<&serde_json::Value> {
        self.attributes.get(key)
    }

    /// Returns `true` if an attribute named `key` is present, whatever its value.
    pub fn has_attribute(&self, key: &str) -> bool {
        self.attributes.contains_key(key)
    }

    /// Returns all attributes as raw JSON values.
    pub fn attributes(&self) -> &HashMap<String, serde_json::Value> {
        &self.attributes
    }

    /// Starts a [`PrincipalBuilder`] for a principal with the given `id`.
    pub fn builder(id: impl Into<String>) -> PrincipalBuilder {
        PrincipalBuilder::new(id)
    }
}
impl From<JwtClaims> for Principal {
fn from(value: JwtClaims) -> Self {
Principal::builder(value.user_id)
.with_attribute("audience", value.audience)
.with_attribute("issued_at", value.issued_at)
.with_attribute("issuer", value.issuer)
.with_attribute("token_id", value.token_id)
.with_attribute("not_before", value.not_before)
.with_attribute("authorizing_party", value.authorizing_party)
.with_attribute("expiration", value.expiration)
.with_attributes(value.additional_claims)
.build()
}
}
impl Display for Principal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let name = if let Some(name) = self.get_attribute("name") {
name
} else if let Some(username) = self.get_attribute("preferred_username") {
username
} else {
self.id.clone()
};
write!(f, "Principal({name})")
}
}
/// Incrementally assembles a [`Principal`].
///
/// Create one with [`PrincipalBuilder::new`] or [`Principal::builder`], chain
/// [`with_attribute`](PrincipalBuilder::with_attribute) /
/// [`with_attributes`](PrincipalBuilder::with_attributes), then call
/// [`build`](PrincipalBuilder::build). Every method consumes and returns the builder,
/// so an unused result is always a mistake.
#[must_use = "a PrincipalBuilder does nothing until `build` is called"]
#[derive(Debug, Clone)]
pub struct PrincipalBuilder {
    id: String,
    attributes: HashMap<String, serde_json::Value>,
}

impl PrincipalBuilder {
    /// Starts a builder for a principal with the given `id` and no attributes.
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            attributes: HashMap::default(),
        }
    }

    /// Adds, or replaces, attribute `key` with `value` serialized to JSON.
    ///
    /// Best-effort: if `value` fails to serialize, the attribute is silently skipped.
    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Serialize) -> Self {
        if let Ok(json) = serde_json::to_value(value) {
            self.attributes.insert(key.into(), json);
        }
        self
    }

    /// Adds, or replaces, every `(key, value)` pair from `attrs`.
    ///
    /// Like [`with_attribute`](PrincipalBuilder::with_attribute), pairs whose value
    /// fails to serialize are silently skipped.
    pub fn with_attributes(
        mut self,
        attrs: impl IntoIterator<Item = (impl Into<String>, impl Serialize)>,
    ) -> Self {
        self.attributes.extend(
            attrs
                .into_iter()
                .filter_map(|(k, v)| serde_json::to_value(v).ok().map(|v| (k.into(), v))),
        );
        self
    }

    /// Finishes the builder and returns the assembled [`Principal`].
    #[must_use]
    pub fn build(self) -> Principal {
        Principal {
            id: self.id,
            attributes: self.attributes,
        }
    }
}
/// Selects which authentication mechanisms are enabled.
///
/// A mechanism is enabled by supplying its configuration. `None` (the `Default`) leaves
/// it disabled, and `authenticate` rejects credentials of that kind.
#[derive(Clone, Debug, Default)]
pub struct AuthenticationConfig {
    /// Configuration for API-key authentication; `None` disables it.
    api_key: Option<ApiKeyAuthContextConfig>,
    /// Configuration for JWT authentication; `None` disables it.
    jwt: Option<JwtAuthContextConfig>,
}
impl AuthenticationConfig {
    /// Enables API-key authentication with `config`, replacing any previous API-key config.
    #[must_use]
    pub fn with_api_key_enabled(mut self, config: ApiKeyAuthContextConfig) -> Self {
        self.api_key = Some(config);
        self
    }

    /// Enables JWT authentication with `config`, replacing any previous JWT config.
    #[must_use]
    pub fn with_jwt_enabled(mut self, config: JwtAuthContextConfig) -> Self {
        self.jwt = Some(config);
        self
    }

    /// Returns the API-key configuration, or `None` if API-key authentication is disabled.
    pub fn api_key(&self) -> Option<&ApiKeyAuthContextConfig> {
        self.api_key.as_ref()
    }

    /// Returns the JWT configuration, or `None` if JWT authentication is disabled.
    pub fn jwt(&self) -> Option<&JwtAuthContextConfig> {
        self.jwt.as_ref()
    }
}
#[async_trait]
impl Authenticating<Credentials, Principal, AuthorizationError> for AuthenticationConfig {
    /// Authenticates `credentials` with the mechanism that matches its variant.
    ///
    /// * [`Credentials::ApiKey`]: the key is hashed with SHA-256 and hex-encoded in
    ///   lowercase, and that digest is looked up through the configured
    ///   `ApiKeyAuthContextConfig`.
    /// * [`Credentials::JsonWebToken`]: the token is verified by the configured
    ///   `JwtAuthContextConfig`, and its claims are converted into a [`Principal`].
    ///
    /// # Errors
    ///
    /// Returns an [`AuthorizationError`] when:
    /// * the matching mechanism is not enabled;
    /// * the lookup or verification fails;
    /// * no principal is registered for an API key.
    async fn authenticate(
        &self,
        credentials: &Credentials,
    ) -> Result<Principal, AuthorizationError> {
        match credentials {
            Credentials::ApiKey(creds) => {
                let conf = self.api_key().ok_or_else(|| {
                    AuthorizationError("Api Key authentication not enabled".to_string())
                })?;

                // Keys are looked up by the lowercase hex encoding of their SHA-256
                // digest. Encode into one preallocated buffer rather than allocating a
                // `String` per byte.
                const HEX: &[u8; 16] = b"0123456789abcdef";
                let digest = sha2::Sha256::digest(creds.as_bytes());
                let mut hashed = String::with_capacity(digest.len() * 2);
                for &byte in digest.iter() {
                    hashed.push(char::from(HEX[usize::from(byte >> 4)]));
                    hashed.push(char::from(HEX[usize::from(byte & 0x0f)]));
                }

                conf.get(&hashed)
                    .await
                    .map_err(|e| AuthorizationError(e.to_string()))?
                    .ok_or_else(|| AuthorizationError("Unknown Api Key".to_string()))
            }
            Credentials::JsonWebToken(creds) => {
                let conf = self.jwt().ok_or_else(|| {
                    AuthorizationError("JWT authentication not enabled".to_string())
                })?;
                let claims = conf
                    .verify_token(creds)
                    .await
                    .map_err(|e| AuthorizationError(e.to_string()))?
                    .claims;
                Ok(Principal::from(claims))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{HttpMessage, http::header::HeaderName, test::TestRequest};
    use uuid::Uuid;

    // Display falls back from `name` to `preferred_username` to the raw id; this walks
    // the builder through each stage and checks the rendered label.
    #[test]
    fn test_principal_builder_and_display() {
        let user_id = Uuid::new_v4().to_string();
        let mut app_user = PrincipalBuilder::new(user_id.clone());
        assert_eq!(app_user.clone().build().id(), &user_id);
        // No name-like attributes yet: the id is shown.
        assert_eq!(
            app_user.clone().build().to_string(),
            format!("Principal({})", user_id)
        );
        // `preferred_username` takes precedence over the id.
        app_user = app_user.with_attribute("preferred_username", "SArcher");
        assert_eq!(
            app_user.clone().build().to_string(),
            "Principal(SArcher)".to_string()
        );
        // `name` takes precedence over `preferred_username`.
        app_user = app_user.with_attributes(vec![("name".to_string(), "Sterling Archer")]);
        assert_eq!(app_user.build().to_string(), "Principal(Sterling Archer)");
    }

    // Display prints only the credential kind, not the secret payload.
    #[test]
    fn test_credential_type_printing() {
        let cred_type_api_key = Credentials::ApiKey("test".to_string());
        let cred_type_jwt = Credentials::JsonWebToken("test".to_string());
        assert_eq!(format!("{cred_type_api_key}"), "ApiKey");
        assert_eq!(format!("{cred_type_jwt}"), "JsonWebToken");
    }

    // Both supported `Authorization` schemes are parsed into the matching variant, and
    // the payload after the scheme is preserved.
    #[test]
    fn test_parse_authentication_header() {
        const APIKEY: &str = "ApiKey test1234";
        const JWT: &str = "Bearer a.b.c";
        const APIKEY_HEADER_PAIR: (HeaderName, &str) = (header::AUTHORIZATION, APIKEY);
        const JWT_HEADER_PAIR: (HeaderName, &str) = (header::AUTHORIZATION, JWT);
        let apikey_header_map_request = TestRequest::default()
            .insert_header(APIKEY_HEADER_PAIR)
            .to_request();
        let jwt_header_map_request = TestRequest::default()
            .insert_header(JWT_HEADER_PAIR)
            .to_request();
        let auth_creds_apikey_parse = Credentials::parse_credentials_from_request_headers(
            apikey_header_map_request.headers(),
        )
        .unwrap();
        let auth_creds_jwt_parse =
            Credentials::parse_credentials_from_request_headers(jwt_header_map_request.headers())
                .unwrap();
        assert_eq!(
            auth_creds_jwt_parse,
            Credentials::JsonWebToken("a.b.c".to_string())
        );
        assert_eq!(
            auth_creds_jwt_parse.credential_data(),
            JWT.split_whitespace().last().unwrap()
        );
        assert_eq!(
            auth_creds_apikey_parse,
            Credentials::ApiKey("test1234".to_string())
        );
        assert_eq!(auth_creds_apikey_parse.credential_data(), "test1234");
    }
}