securitydept-session-context 0.2.0-beta.6

Session Context of SecurityDept, a layered authentication and authorization toolkit built as reusable Rust crates.
Documentation
#[cfg(feature = "service")]
mod service;

use std::{collections::HashMap, time::Duration as StdDuration};

use http::StatusCode;
use securitydept_utils::{
    error::{ErrorPresentation, ToErrorPresentation, UserRecovery},
    principal::AuthenticatedPrincipal,
    redirect::{RedirectTargetConfig, RedirectTargetError, UriRelativeRedirectTargetResolver},
};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
#[cfg(feature = "service")]
pub use service::{
    DevSessionAuthService, OidcSessionAuthService, SessionAuthServiceError, SessionAuthServiceTrait,
};
use snafu::Snafu;
use tower_sessions::{
    Expiry, Session, SessionManagerLayer, SessionStore,
    cookie::{SameSite, time::Duration},
};
use typed_builder::TypedBuilder;

/// Cookie name used by [`SessionContextConfig`] when none is configured.
pub const DEFAULT_COOKIE_NAME: &str = "securitydept_session";
/// Session-store key under which the serialized [`SessionContext`] is kept when no
/// other key is configured.
pub const DEFAULT_SESSION_CONTEXT_KEY: &str = "securitydept.session_context";

/// Alias for the principal type stored inside a [`SessionContext`].
pub type SessionPrincipal = AuthenticatedPrincipal;

/// Authenticated state stored in a session: a principal plus optional metadata.
///
/// `Extra` is an application-defined payload type; when unspecified it is a
/// `HashMap<String, Value>` of JSON values.
#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, TypedBuilder)]
pub struct SessionContext<Extra = HashMap<String, Value>> {
    /// The principal recorded for this session.
    pub principal: SessionPrincipal,
    /// Free-form JSON attributes. Defaults to an empty map (in the builder and when the
    /// field is absent on deserialization) and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    #[builder(default)]
    pub attributes: HashMap<String, Value>,
    /// Optional typed payload. Omitted from serialized output when `None`; the builder
    /// setter accepts the inner value directly (`strip_option`).
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub extra: Option<Extra>,
}

/// `SameSite` policy for the session cookie.
///
/// Serialized in `snake_case` (`strict`, `lax`, `none`). The default is
/// [`Lax`](Self::Lax). Converts into the cookie crate's `SameSite` via `From`.
#[derive(Debug, Clone, Copy, Serialize, serde::Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum SessionCookieSameSite {
    /// Converts to `SameSite::Strict`.
    Strict,
    /// Converts to `SameSite::Lax`; this is the default variant.
    #[default]
    Lax,
    /// Converts to `SameSite::None`.
    None,
}

impl From<SessionCookieSameSite> for SameSite {
    fn from(value: SessionCookieSameSite) -> Self {
        match value {
            SessionCookieSameSite::Strict => SameSite::Strict,
            SessionCookieSameSite::Lax => SameSite::Lax,
            SessionCookieSameSite::None => SameSite::None,
        }
    }
}

/// Configuration for the session cookie and for where the session context is stored.
///
/// Every field has a default. The builder and serde defaults both go through the same
/// `default_*` helper functions, so the two construction paths cannot drift apart.
#[derive(Debug, Clone, Serialize, serde::Deserialize, TypedBuilder)]
pub struct SessionContextConfig {
    /// Name of the session cookie. Defaults to [`DEFAULT_COOKIE_NAME`].
    #[builder(default = default_cookie_name())]
    #[serde(default = "default_cookie_name")]
    pub cookie_name: String,
    /// Session-store key under which the context is saved. Defaults to
    /// [`DEFAULT_SESSION_CONTEXT_KEY`].
    #[builder(default = default_session_context_key())]
    #[serde(default = "default_session_context_key")]
    pub session_context_key: String,
    /// Cookie `Path` attribute. Defaults to `/`.
    #[builder(default = default_cookie_path())]
    #[serde(default = "default_cookie_path")]
    pub cookie_path: String,
    /// Whether the cookie is `HttpOnly`. Defaults to `true`.
    #[builder(default = default_true())]
    #[serde(default = "default_true")]
    pub http_only: bool,
    /// Whether the cookie is `Secure`. Defaults to `false`.
    #[builder(default = false)]
    #[serde(default)]
    pub secure: bool,
    /// Cookie `SameSite` policy. Defaults to [`SessionCookieSameSite::Lax`].
    #[builder(default)]
    #[serde(default)]
    pub same_site: SessionCookieSameSite,
    /// Inactivity timeout applied by [`build_session_layer`]. Defaults to 24 hours.
    /// `None` leaves the layer's own expiry untouched. (De)serialized through
    /// `humantime_serde`.
    #[builder(default = default_ttl())]
    #[serde(default = "default_ttl", with = "humantime_serde::option")]
    pub ttl: Option<StdDuration>,
    /// Redirect-target rules used by `resolve_post_auth_redirect`. Defaults to a strict
    /// redirect to `/`.
    #[builder(default = default_post_auth_redirect())]
    #[serde(default = "default_post_auth_redirect")]
    pub post_auth_redirect: RedirectTargetConfig,
}

/// Default for `SessionContextConfig::cookie_name`: an owned copy of
/// [`DEFAULT_COOKIE_NAME`].
fn default_cookie_name() -> String {
    String::from(DEFAULT_COOKIE_NAME)
}

/// Default for `SessionContextConfig::session_context_key`: an owned copy of
/// [`DEFAULT_SESSION_CONTEXT_KEY`].
fn default_session_context_key() -> String {
    String::from(DEFAULT_SESSION_CONTEXT_KEY)
}

/// Default for `SessionContextConfig::cookie_path`: the root path `/`.
fn default_cookie_path() -> String {
    String::from("/")
}

/// Serde default helper that yields `true`; used for `SessionContextConfig::http_only`.
fn default_true() -> bool {
    true
}

/// Default for `SessionContextConfig::ttl`: 24 hours.
fn default_ttl() -> Option<StdDuration> {
    const ONE_DAY_SECS: u64 = 24 * 60 * 60;
    Some(StdDuration::from_secs(ONE_DAY_SECS))
}

/// Default for `SessionContextConfig::post_auth_redirect`: a strict default target of `/`.
fn default_post_auth_redirect() -> RedirectTargetConfig {
    const ROOT: &str = "/";
    RedirectTargetConfig::strict_default(ROOT)
}

impl Default for SessionContextConfig {
    /// Builds the configuration from the same `default_*` helpers that serde uses,
    /// with `secure` off and the default `SameSite` policy.
    fn default() -> Self {
        let cookie_name = default_cookie_name();
        let session_context_key = default_session_context_key();
        let cookie_path = default_cookie_path();
        let http_only = default_true();
        let same_site = SessionCookieSameSite::default();
        let ttl = default_ttl();
        let post_auth_redirect = default_post_auth_redirect();

        Self {
            cookie_name,
            session_context_key,
            cookie_path,
            http_only,
            secure: false,
            same_site,
            ttl,
            post_auth_redirect,
        }
    }
}

/// Builds a [`SessionManagerLayer`] over `store`, configured from `config`.
///
/// The cookie name, path, `SameSite`, `HttpOnly` and `Secure` settings are copied from
/// `config`. When `config.ttl` is `Some`, the layer expires sessions after that much
/// inactivity, at whole-second granularity (any sub-second part of the TTL is dropped).
/// When it is `None`, no expiry is set here and the layer keeps whatever expiry
/// `SessionManagerLayer::new` starts with.
pub fn build_session_layer<Store>(
    config: &SessionContextConfig,
    store: Store,
) -> SessionManagerLayer<Store>
where
    Store: SessionStore,
{
    let mut layer = SessionManagerLayer::new(store)
        .with_name(config.cookie_name.clone())
        .with_path(config.cookie_path.clone())
        .with_same_site(config.same_site.into())
        .with_http_only(config.http_only)
        .with_secure(config.secure);

    if let Some(ttl) = config.ttl {
        // Clamp before the cast: a bare `as i64` would wrap a TTL above `i64::MAX`
        // seconds into a negative inactivity window.
        let seconds = ttl.as_secs().min(i64::MAX as u64) as i64;
        layer = layer.with_expiry(Expiry::OnInactivity(Duration::seconds(seconds)));
    }

    layer
}

/// Errors produced while reading or writing the session context or resolving the
/// post-auth redirect.
#[derive(Debug, Snafu)]
pub enum SessionContextError {
    /// No session context is stored under the configured key.
    #[snafu(display("session context is missing"))]
    MissingContext,
    /// The underlying `tower_sessions` session operation failed.
    #[snafu(display("session operation failed: {source}"))]
    Session {
        source: tower_sessions::session::Error,
    },
    /// The post-auth redirect configuration or the requested target was rejected.
    #[snafu(display("post-auth redirect is invalid: {source}"))]
    RedirectTarget { source: RedirectTargetError },
}

/// Result alias whose error type is [`SessionContextError`].
pub type SessionContextResult<T> = Result<T, SessionContextError>;

impl SessionContextError {
    /// HTTP status to report for this error: `401 Unauthorized` when the context is
    /// missing, `500 Internal Server Error` for every other variant.
    pub fn status_code(&self) -> StatusCode {
        // One arm per variant so adding a variant forces a decision here.
        match self {
            Self::MissingContext => StatusCode::UNAUTHORIZED,
            Self::Session { .. } => StatusCode::INTERNAL_SERVER_ERROR,
            Self::RedirectTarget { .. } => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}

impl ToErrorPresentation for SessionContextError {
    /// Produces the user-facing presentation (stable code, message, suggested
    /// recovery) for each error variant.
    fn to_error_presentation(&self) -> ErrorPresentation {
        // Select the three presentation parts first, then construct once.
        let (code, message, recovery) = match self {
            Self::MissingContext => (
                "authentication_required",
                "Sign in to continue.",
                UserRecovery::Reauthenticate,
            ),
            Self::Session { .. } => (
                "session_unavailable",
                "The session is temporarily unavailable.",
                UserRecovery::Retry,
            ),
            Self::RedirectTarget { .. } => (
                "session_post_auth_redirect_invalid",
                "The configured post-auth redirect is invalid.",
                UserRecovery::ContactSupport,
            ),
        };

        ErrorPresentation::new(code, message, recovery)
    }
}

impl SessionContextConfig {
    /// Resolves the redirect target to use after authentication.
    ///
    /// Builds a [`UriRelativeRedirectTargetResolver`] from `self.post_auth_redirect` and
    /// asks it to resolve `requested_post_auth_redirect` (which may be `None`).
    ///
    /// # Errors
    ///
    /// Returns [`SessionContextError::RedirectTarget`] if the resolver cannot be built
    /// from the configuration or if it rejects the requested target.
    pub fn resolve_post_auth_redirect(
        &self,
        requested_post_auth_redirect: Option<&str>,
    ) -> SessionContextResult<String> {
        let resolver =
            UriRelativeRedirectTargetResolver::from_config(self.post_auth_redirect.clone())
                .map_err(|source| SessionContextError::RedirectTarget { source })?;

        let target = resolver
            .resolve_redirect_target(requested_post_auth_redirect)
            .map_err(|source| SessionContextError::RedirectTarget { source })?;

        Ok(target.to_string())
    }
}

/// Wrapper around a `tower_sessions` [`Session`] that reads and writes a
/// [`SessionContext`] under a single configured key.
#[derive(Clone)]
pub struct SessionContextSession {
    /// The wrapped session handle.
    session: Session,
    /// Key under which the session context is stored in `session`.
    session_context_key: String,
}

impl From<Session> for SessionContextSession {
    /// Wraps `session` using [`DEFAULT_SESSION_CONTEXT_KEY`] as the storage key.
    fn from(session: Session) -> Self {
        Self::with_key(session, DEFAULT_SESSION_CONTEXT_KEY)
    }
}

impl SessionContextSession {
    pub fn new(session: Session) -> Self {
        Self::from(session)
    }

    pub fn from_config(session: Session, config: &SessionContextConfig) -> Self {
        Self {
            session,
            session_context_key: config.session_context_key.clone(),
        }
    }

    pub fn with_key(session: Session, session_context_key: impl Into<String>) -> Self {
        Self {
            session,
            session_context_key: session_context_key.into(),
        }
    }

    pub fn raw_session(&self) -> &Session {
        &self.session
    }

    pub async fn insert<Extra>(&self, context: &SessionContext<Extra>) -> SessionContextResult<()>
    where
        Extra: Serialize,
    {
        self.session
            .insert(&self.session_context_key, context)
            .await
            .map_err(|source| SessionContextError::Session { source })
    }

    pub async fn get<Extra>(&self) -> SessionContextResult<Option<SessionContext<Extra>>>
    where
        Extra: DeserializeOwned,
    {
        self.session
            .get(&self.session_context_key)
            .await
            .map_err(|source| SessionContextError::Session { source })
    }

    pub async fn require<Extra>(&self) -> SessionContextResult<SessionContext<Extra>>
    where
        Extra: DeserializeOwned,
    {
        self.get().await?.ok_or(SessionContextError::MissingContext)
    }

    pub async fn clear(&self) -> SessionContextResult<()> {
        self.session
            .remove_value(&self.session_context_key)
            .await
            .map(|_| ())
            .map_err(|source| SessionContextError::Session { source })
    }

    pub async fn is_authenticated<Extra>(&self) -> SessionContextResult<bool>
    where
        Extra: DeserializeOwned,
    {
        Ok(self.get::<Extra>().await?.is_some())
    }

    pub async fn cycle_id(&self) -> SessionContextResult<()> {
        self.session
            .cycle_id()
            .await
            .map_err(|source| SessionContextError::Session { source })
    }

    pub async fn flush(&self) -> SessionContextResult<()> {
        self.session
            .flush()
            .await
            .map_err(|source| SessionContextError::Session { source })
    }
}

// Unit tests for the configuration defaults, the context builder, and post-auth
// redirect resolution.
#[cfg(test)]
mod tests {
    use securitydept_utils::redirect::RedirectTargetRule;

    use super::*;

    // `SessionContextConfig::default()` must yield the documented defaults for every field.
    #[test]
    fn test_default_config() {
        let config = SessionContextConfig::default();
        assert_eq!(config.cookie_name, DEFAULT_COOKIE_NAME);
        assert_eq!(config.session_context_key, DEFAULT_SESSION_CONTEXT_KEY);
        assert_eq!(config.cookie_path, "/");
        assert!(config.http_only);
        assert!(!config.secure);
        assert_eq!(config.same_site, SessionCookieSameSite::Lax);
        assert_eq!(config.ttl, Some(StdDuration::from_secs(86_400)));
        assert_eq!(
            config.post_auth_redirect.default_redirect_target.as_deref(),
            Some("/")
        );
    }

    // The builder accepts a principal, attributes, and an `extra` payload passed without
    // wrapping it in `Some`, and all three are readable afterwards.
    #[test]
    fn test_context_with_extra_data() {
        let context = SessionContext::builder()
            .principal(
                SessionPrincipal::builder()
                    .subject("dev-session")
                    .display_name("dev")
                    .build(),
            )
            .attributes(HashMap::from([(
                "mode".to_string(),
                Value::String("dev".to_string()),
            )]))
            .extra(HashMap::from([(
                "provider".to_string(),
                Value::String("local".to_string()),
            )]))
            .build();

        assert_eq!(context.principal.subject, "dev-session");
        assert_eq!(context.principal.display_name, "dev");
        assert_eq!(
            context.attributes.get("mode"),
            Some(&Value::String("dev".to_string()))
        );
        assert_eq!(
            context
                .extra
                .as_ref()
                .and_then(|extra| extra.get("provider")),
            Some(&Value::String("local".to_string()))
        );
    }

    // With a default of "/" and "/app" as an allowed target, no request resolves to the
    // default and a request for "/app" resolves to "/app".
    #[test]
    fn test_post_auth_redirect_resolution() {
        let config = SessionContextConfig::builder()
            .post_auth_redirect(RedirectTargetConfig::dynamic_default_and_dynamic_targets(
                "/",
                [RedirectTargetRule::Strict {
                    value: "/app".to_string(),
                }],
            ))
            .build();

        assert_eq!(
            config
                .resolve_post_auth_redirect(None)
                .expect("default redirect should resolve"),
            "/"
        );
        assert_eq!(
            config
                .resolve_post_auth_redirect(Some("/app"))
                .expect("dynamic redirect should resolve"),
            "/app"
        );
    }
}