restrepo 0.5.12

A collection of components for building RESTful web services with actix-web
Documentation
//! Provides authentication-related functionality,
//! such as session user representation and
//! authentication middleware.
pub mod api_key;
pub mod authentication_context;
pub mod jwk;
pub mod jwt;

use actix_web::http::header::{self, HeaderMap};
use async_trait::async_trait;
pub use authentication_context::{AuthenticationContext, authentication_context_provider};
use serde_json::Value;
use sha2::Digest;
use thiserror::Error;
use tracing::debug;

use serde::{Deserialize, Serialize};
use std::{collections::HashMap, fmt::Display};
use utoipa::IntoResponses;

use crate::security::{
    api_key::ApiKeyAuthContextConfig,
    jwt::{JwtAuthContextConfig, JwtClaims},
};

/// The server could not authenticate the request
#[derive(Debug, Error, IntoResponses)]
#[error("{}", _0)]
#[response(status = StatusCode::UNAUTHORIZED)]
pub struct AuthorizationError(pub String);

/// Represents methods of authentication and holds raw credentials extracted from the
/// [Authorization header](actix_web::http::header::AUTHORIZATION).
///
/// The [`Debug`] implementation redacts the credential payload so tokens and API keys
/// are not written to logs by accident; use [`Credentials::credential_data`] to read it.
#[derive(Clone, Deserialize, PartialEq, Serialize)]
pub enum Credentials {
    /// A JSON Web Token, presented with the `Bearer` scheme.
    JsonWebToken(String),
    /// An API key, presented with the `ApiKey` scheme.
    ApiKey(String),
}

impl std::fmt::Debug for Credentials {
    /// Prints the variant name with a `<redacted>` placeholder instead of the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Credentials::JsonWebToken(_) => "JsonWebToken",
            Credentials::ApiKey(_) => "ApiKey",
        };
        f.debug_tuple(variant)
            .field(&format_args!("<redacted>"))
            .finish()
    }
}

impl Display for Credentials {
    /// Writes only the kind of credential (`ApiKey` / `JsonWebToken`), never the payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Credentials::JsonWebToken(_) => "JsonWebToken",
            Credentials::ApiKey(_) => "ApiKey",
        })
    }
}

impl Credentials {
    /// Borrow the raw credential payload (the token or key) without its scheme prefix.
    pub fn credential_data(&self) -> &str {
        match self {
            Credentials::ApiKey(val) | Credentials::JsonWebToken(val) => val,
        }
    }

    /// Attempt to parse the value of a request's [Authorization header](header::AUTHORIZATION)
    /// into [Credentials].
    ///
    /// The header must have the form `<scheme> <payload>`. `Bearer` yields
    /// [Credentials::JsonWebToken] and `ApiKey` yields [Credentials::ApiKey]. The scheme is
    /// compared case-insensitively, as HTTP authentication schemes are case-insensitive
    /// (RFC 9110 §11.1). Whitespace around the payload is ignored.
    ///
    /// # Errors
    /// Returns [AuthorizationError] if the header is missing, is not valid visible ASCII,
    /// uses an unsupported scheme, or carries an empty payload.
    pub fn parse_credentials_from_request_headers(
        headers: &HeaderMap,
    ) -> Result<Credentials, AuthorizationError> {
        let header_value = headers.get(header::AUTHORIZATION);
        header_value
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.trim().split_once(' '))
            .and_then(|(scheme, data)| {
                let data = data.trim();
                if data.is_empty() {
                    None
                } else if scheme.eq_ignore_ascii_case("Bearer") {
                    Some(Credentials::JsonWebToken(data.to_string()))
                } else if scheme.eq_ignore_ascii_case("ApiKey") {
                    Some(Credentials::ApiKey(data.to_string()))
                } else {
                    None
                }
            })
            .ok_or_else(|| {
                // Only record whether the header was present: its value carries the
                // secret token / API key and must never end up in the logs.
                debug!(
                    "Could not parse Authorization header (header present: {})",
                    header_value.is_some()
                );
                AuthorizationError("Invalid credentials".to_string())
            })
    }
}

/// Request authentication: turns presented credentials into an authenticated result.
///
/// * `C` – the credential type presented by the caller (e.g. [Credentials])
/// * `R` – the value produced on success (e.g. [Principal])
/// * `E` – the failure type, which must be convertible into [AuthorizationError]
#[async_trait]
pub trait Authenticating<C, R, E>
where
    C: Send,
    R: Send,
    E: Into<AuthorizationError>,
{
    /// Verify `credentials`, returning the authenticated result or the rejection error.
    async fn authenticate(&self, credentials: &C) -> Result<R, E>;
}

/// A generalised representation of an identity interacting with the service.
///
/// A principal consists of an (ideally unique) identifier plus a map of arbitrary
/// attributes, such as information about the time of authentication.
///
/// When deserializing, the identifier is read from `id` or one of its aliases `sub`,
/// `user_id` and `userId`; every remaining field is collected into the attribute map.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub struct Principal {
    #[serde(alias = "sub", alias = "user_id", alias = "userId")]
    id: String,
    #[serde(flatten, default)]
    attributes: HashMap<String, Value>,
}

impl Principal {
    /// Get this principals (ideally unique) identifier
    pub fn id(&self) -> &str {
        &self.id
    }

    /// Generically retrieve an attribute
    pub fn get_attribute<T>(&self, key: &str) -> Option<T>
    where
        T: for<'de> Deserialize<'de>,
    {
        self.attributes
            .get(key)
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    }

    /// Retrieve a plain attribute
    pub fn get_raw_attribute(&self, key: &str) -> Option<&serde_json::Value> {
        self.attributes.get(key)
    }

    /// Check if principals attributes contain key
    pub fn has_attribute(&self, key: &str) -> bool {
        self.attributes.contains_key(key)
    }

    /// Retrieve entire attribute map
    pub fn attributes(&self) -> &HashMap<String, serde_json::Value> {
        &self.attributes
    }

    /// Create a [PrincipalBuilder]. Equivalent to calling [PrincipalBuilder::new]
    pub fn builder(id: impl Into<String>) -> PrincipalBuilder {
        PrincipalBuilder::new(id)
    }
}

impl From<JwtClaims> for Principal {
    fn from(value: JwtClaims) -> Self {
        Principal::builder(value.user_id)
            .with_attribute("audience", value.audience)
            .with_attribute("issued_at", value.issued_at)
            .with_attribute("issuer", value.issuer)
            .with_attribute("token_id", value.token_id)
            .with_attribute("not_before", value.not_before)
            .with_attribute("authorizing_party", value.authorizing_party)
            .with_attribute("expiration", value.expiration)
            .with_attributes(value.additional_claims)
            .build()
    }
}

impl Display for Principal {
    /// Render as `Principal(<label>)`. The label is the `name` attribute, falling back to
    /// `preferred_username` and finally to the principal id.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: String = self
            .get_attribute("name")
            .or_else(|| self.get_attribute("preferred_username"))
            .unwrap_or_else(|| self.id.clone());
        write!(f, "Principal({label})")
    }
}

#[derive(Debug, Clone)]
pub struct PrincipalBuilder {
    id: String,
    attributes: HashMap<String, serde_json::Value>,
}

impl PrincipalBuilder {
    /// Start a builder for a principal identified by `id`, with no attributes.
    #[must_use]
    pub fn new(id: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            attributes: HashMap::new(),
        }
    }

    /// Insert or update a single entry in the attribute map.
    ///
    /// `value` is converted with `serde_json::to_value`. If that conversion fails, the
    /// entry is skipped and any existing value under `key` is left untouched.
    #[must_use = "the builder is consumed; use the returned builder"]
    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Serialize) -> Self {
        if let Ok(json) = serde_json::to_value(value) {
            self.attributes.insert(key.into(), json);
        }
        self
    }

    /// Update the attribute map with several entries from an iterator.
    ///
    /// Later entries overwrite earlier ones with the same key. Entries whose value fails
    /// to serialize are skipped.
    #[must_use = "the builder is consumed; use the returned builder"]
    pub fn with_attributes(
        mut self,
        attrs: impl IntoIterator<Item = (impl Into<String>, impl Serialize)>,
    ) -> Self {
        self.attributes.extend(
            attrs
                .into_iter()
                .filter_map(|(k, v)| serde_json::to_value(v).ok().map(|v| (k.into(), v))),
        );
        self
    }

    /// Consume the builder and produce the [Principal].
    #[must_use]
    pub fn build(self) -> Principal {
        Principal {
            id: self.id,
            attributes: self.attributes,
        }
    }
}

/// Selects which authentication mechanisms are enabled and holds their configuration.
///
/// The [`Default`] value enables no mechanism at all.
#[derive(Debug, Clone, Default)]
pub struct AuthenticationConfig {
    /// API key authentication settings; `None` means API keys are rejected.
    api_key: Option<ApiKeyAuthContextConfig>,
    /// JWT authentication settings; `None` means JWTs are rejected.
    jwt: Option<JwtAuthContextConfig>,
}

impl AuthenticationConfig {
    /// Enable API key authentication with `config`, replacing any previous API key config.
    #[must_use = "this consumes the configuration and returns the updated one"]
    pub fn with_api_key_enabled(mut self, config: ApiKeyAuthContextConfig) -> Self {
        self.api_key = Some(config);
        self
    }

    /// Enable JWT authentication with `config`, replacing any previous JWT config.
    #[must_use = "this consumes the configuration and returns the updated one"]
    pub fn with_jwt_enabled(mut self, config: JwtAuthContextConfig) -> Self {
        self.jwt = Some(config);
        self
    }

    /// API key configuration, or `None` if API key authentication is disabled.
    pub fn api_key(&self) -> Option<&ApiKeyAuthContextConfig> {
        self.api_key.as_ref()
    }

    /// JWT configuration, or `None` if JWT authentication is disabled.
    pub fn jwt(&self) -> Option<&JwtAuthContextConfig> {
        self.jwt.as_ref()
    }
}

#[async_trait]
impl Authenticating<Credentials, Principal, AuthorizationError> for AuthenticationConfig {
    /// Authenticate `credentials` with the mechanism matching their type.
    ///
    /// * API keys are hashed with SHA-256, lowercase-hex encoded, and looked up through the
    ///   configured [ApiKeyAuthContextConfig].
    /// * JSON Web Tokens are verified by the configured [JwtAuthContextConfig], and their
    ///   claims are converted into a [Principal].
    ///
    /// # Errors
    /// Returns [AuthorizationError] if the matching mechanism is not enabled, if the key
    /// lookup or token verification fails, or if the API key is unknown.
    async fn authenticate(
        &self,
        credentials: &Credentials,
    ) -> Result<Principal, AuthorizationError> {
        match credentials {
            Credentials::ApiKey(creds) => {
                let conf = self.api_key().ok_or_else(|| {
                    AuthorizationError("Api Key authentication not enabled".to_string())
                })?;
                // The store is queried with the hex-encoded SHA-256 digest of the key,
                // not with the key itself.
                let hashed = to_lower_hex(&sha2::Sha256::digest(creds.as_bytes()));
                conf.get(&hashed)
                    .await
                    .map_err(|e| AuthorizationError(e.to_string()))?
                    .ok_or_else(|| AuthorizationError("Unknown Api Key".to_string()))
            }
            Credentials::JsonWebToken(creds) => {
                let conf = self.jwt().ok_or_else(|| {
                    AuthorizationError("JWT authentication not enabled".to_string())
                })?;
                let token = conf
                    .verify_token(creds)
                    .await
                    .map_err(|e| AuthorizationError(e.to_string()))?
                    .claims;
                Ok(Principal::from(token))
            }
        }
    }
}

/// Encode `bytes` as lowercase hexadecimal, two digits per byte (same output as
/// formatting each byte with `{:02x}`), using a single pre-sized allocation.
fn to_lower_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}

#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{HttpMessage, http::header::HeaderName, test::TestRequest};
    use uuid::Uuid;

    /// Send `value` as the Authorization header of a test request and parse it back.
    fn parse_authorization(value: &str) -> Credentials {
        let auth_pair: (HeaderName, &str) = (header::AUTHORIZATION, value);
        let request = TestRequest::default().insert_header(auth_pair).to_request();
        Credentials::parse_credentials_from_request_headers(request.headers()).unwrap()
    }

    #[test]
    fn test_principal_builder_and_display() {
        let user_id = Uuid::new_v4().to_string();

        // Without attributes the id doubles as the display label.
        let builder = PrincipalBuilder::new(user_id.clone());
        let bare = builder.clone().build();
        assert_eq!(bare.id(), &user_id);
        assert_eq!(bare.to_string(), format!("Principal({user_id})"));

        // `preferred_username` takes precedence over the id...
        let builder = builder.with_attribute("preferred_username", "SArcher");
        assert_eq!(builder.clone().build().to_string(), "Principal(SArcher)");

        // ...and `name` takes precedence over `preferred_username`.
        let builder = builder.with_attributes(vec![("name".to_string(), "Sterling Archer")]);
        assert_eq!(builder.build().to_string(), "Principal(Sterling Archer)");
    }

    #[test]
    fn test_credential_type_printing() {
        let cases = [
            (Credentials::ApiKey("test".to_string()), "ApiKey"),
            (Credentials::JsonWebToken("test".to_string()), "JsonWebToken"),
        ];
        for (credentials, expected) in cases {
            assert_eq!(credentials.to_string(), expected);
        }
    }

    #[test]
    fn test_parse_authentication_header() {
        let jwt = parse_authorization("Bearer a.b.c");
        assert_eq!(jwt, Credentials::JsonWebToken("a.b.c".to_string()));
        assert_eq!(jwt.credential_data(), "a.b.c");

        let api_key = parse_authorization("ApiKey test1234");
        assert_eq!(api_key, Credentials::ApiKey("test1234".to_string()));
        assert_eq!(api_key.credential_data(), "test1234");
    }
}