pub mod config;
#[cfg(feature = "service")]
mod service;
use std::collections::HashMap;
#[cfg(test)]
use std::time::Duration as StdDuration;
pub use config::{
NoopSessionContextConfigValidator, ResolvedSessionContextConfig, SessionContextConfig,
SessionContextConfigSource, SessionContextConfigValidationError,
SessionContextConfigValidationFailure, SessionContextConfigValidator,
SessionContextFixedPostAuthRedirectValidator,
};
use http::StatusCode;
#[cfg(test)]
use securitydept_utils::redirect::RedirectTargetConfig;
use securitydept_utils::{
error::{ErrorPresentation, ToErrorPresentation, UserRecovery},
principal::AuthenticatedPrincipal,
redirect::RedirectTargetError,
};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::Value;
#[cfg(feature = "service")]
pub use service::{
DevSessionAuthService, OidcSessionAuthService, OidcSessionAuthServiceConfig,
SessionAuthServiceError, SessionAuthServiceTrait,
};
use snafu::Snafu;
use tower_sessions::{
Expiry, Session, SessionManagerLayer, SessionStore,
cookie::{SameSite, time::Duration},
};
use typed_builder::TypedBuilder;
/// Default value of `SessionContextConfig::cookie_name`, i.e. the session cookie's name.
pub const DEFAULT_COOKIE_NAME: &str = "securitydept_session";
/// Default key under which the serialized [`SessionContext`] is stored inside the session
/// (used by `SessionContextConfig::default()` and by `From<Session> for SessionContextSession`).
pub const DEFAULT_SESSION_CONTEXT_KEY: &str = "securitydept.session_context";
/// Principal type carried by a [`SessionContext`]; an alias of [`AuthenticatedPrincipal`].
pub type SessionPrincipal = AuthenticatedPrincipal;
/// Authenticated state that [`SessionContextSession`] stores in the session under its
/// configured key.
///
/// `Extra` is an application-defined payload type; when not specified it is a free-form JSON
/// object (`HashMap<String, Value>`).
#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, TypedBuilder)]
pub struct SessionContext<Extra = HashMap<String, Value>> {
    /// The principal this session is authenticated as.
    pub principal: SessionPrincipal,
    /// Free-form JSON attributes. Omitted from the serialized form when empty, and defaulted
    /// to an empty map when absent during deserialization. Optional in the builder.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    #[builder(default)]
    pub attributes: HashMap<String, Value>,
    /// Optional application payload. Omitted from the serialized form when `None`; the builder
    /// setter takes a bare `Extra` and wraps it in `Some`.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[builder(default, setter(strip_option))]
    pub extra: Option<Extra>,
}
/// Serializable `SameSite` setting for the session cookie.
///
/// Serialized in snake_case (`"strict"`, `"lax"`, `"none"`) and defaults to [`Self::Lax`].
/// Converted into the cookie crate's [`SameSite`] via `From`.
///
/// NOTE(review): browsers generally reject `SameSite=None` cookies that are not also `Secure`;
/// confirm config validation enforces `secure = true` alongside [`Self::None`].
#[derive(Debug, Clone, Copy, Serialize, serde::Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum SessionCookieSameSite {
    /// Maps to [`SameSite::Strict`].
    Strict,
    /// Maps to [`SameSite::Lax`]; the default.
    #[default]
    Lax,
    /// Maps to [`SameSite::None`].
    None,
}
impl From<SessionCookieSameSite> for SameSite {
fn from(value: SessionCookieSameSite) -> Self {
match value {
SessionCookieSameSite::Strict => SameSite::Strict,
SessionCookieSameSite::Lax => SameSite::Lax,
SessionCookieSameSite::None => SameSite::None,
}
}
}
/// Builds the tower-sessions [`SessionManagerLayer`] for `store` from resolved cookie settings.
///
/// Cookie name, path, `SameSite`, `HttpOnly` and `Secure` are copied from `config`.
///
/// When `config.ttl` is `Some`, sessions expire after that much inactivity
/// ([`Expiry::OnInactivity`]). Only whole seconds are used, so any sub-second part of the TTL
/// is dropped. A TTL too large to represent as `i64` seconds saturates to `i64::MAX` seconds.
///
/// When `config.ttl` is `None`, no expiry is configured and the layer's default is kept.
pub fn build_session_layer<Store>(
    config: &ResolvedSessionContextConfig,
    store: Store,
) -> SessionManagerLayer<Store>
where
    Store: SessionStore,
{
    let layer = SessionManagerLayer::new(store)
        .with_name(config.cookie_name.clone())
        .with_path(config.cookie_path.clone())
        .with_same_site(config.same_site.into())
        .with_http_only(config.http_only)
        .with_secure(config.secure);

    match config.ttl {
        Some(ttl) => {
            // A plain `as i64` cast wraps values above `i64::MAX` into a negative duration,
            // which would make every session expire immediately. Saturate instead.
            let seconds = i64::try_from(ttl.as_secs()).unwrap_or(i64::MAX);
            layer.with_expiry(Expiry::OnInactivity(Duration::seconds(seconds)))
        }
        None => layer,
    }
}
/// Errors produced while reading, writing or interpreting the session context.
#[derive(Debug, Snafu)]
pub enum SessionContextError {
    /// No session context is stored for the current session (see
    /// `SessionContextSession::require`).
    #[snafu(display("session context is missing"))]
    MissingContext,
    /// A tower-sessions operation (get/insert/remove/cycle/flush) returned an error.
    #[snafu(display("session operation failed: {source}"))]
    Session {
        source: tower_sessions::session::Error,
    },
    /// Wraps a [`RedirectTargetError`] for an invalid post-auth redirect.
    /// NOTE(review): not constructed in this file; presumably raised during post-auth redirect
    /// resolution in the config module — confirm.
    #[snafu(display("post-auth redirect is invalid: {source}"))]
    RedirectTarget { source: RedirectTargetError },
}
/// Result alias used by session-context operations.
pub type SessionContextResult<T> = Result<T, SessionContextError>;
impl SessionContextError {
    /// Returns the HTTP status to use when this error is turned into a response.
    ///
    /// A missing session context maps to `401 Unauthorized`. Session-operation failures and
    /// invalid post-auth redirects both map to `500 Internal Server Error`.
    pub fn status_code(&self) -> StatusCode {
        match self {
            Self::MissingContext => StatusCode::UNAUTHORIZED,
            Self::Session { .. } => StatusCode::INTERNAL_SERVER_ERROR,
            Self::RedirectTarget { .. } => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
impl ToErrorPresentation for SessionContextError {
    /// Describes the error for clients.
    ///
    /// Each variant maps to three parts:
    /// - an error code
    /// - a user-facing message
    /// - the suggested [`UserRecovery`] action
    fn to_error_presentation(&self) -> ErrorPresentation {
        let (code, message, recovery) = match self {
            Self::MissingContext => (
                "authentication_required",
                "Sign in to continue.",
                UserRecovery::Reauthenticate,
            ),
            Self::Session { .. } => (
                "session_unavailable",
                "The session is temporarily unavailable.",
                UserRecovery::Retry,
            ),
            Self::RedirectTarget { .. } => (
                "session_post_auth_redirect_invalid",
                "The configured post-auth redirect is invalid.",
                UserRecovery::ContactSupport,
            ),
        };
        ErrorPresentation::new(code, message, recovery)
    }
}
/// Handle that reads and writes a [`SessionContext`] stored under a single key of a
/// tower-sessions [`Session`].
#[derive(Clone)]
pub struct SessionContextSession {
    /// Underlying tower-sessions handle.
    session: Session,
    /// Session key holding the serialized [`SessionContext`].
    session_context_key: String,
}
impl From<Session> for SessionContextSession {
    /// Wraps `session`, using [`DEFAULT_SESSION_CONTEXT_KEY`] as the storage key.
    fn from(session: Session) -> Self {
        let session_context_key = String::from(DEFAULT_SESSION_CONTEXT_KEY);
        Self {
            session,
            session_context_key,
        }
    }
}
impl SessionContextSession {
pub fn new(session: Session) -> Self {
Self::from(session)
}
pub fn from_resolved_config(session: Session, config: &ResolvedSessionContextConfig) -> Self {
Self {
session,
session_context_key: config.session_context_key.clone(),
}
}
pub fn with_key(session: Session, session_context_key: impl Into<String>) -> Self {
Self {
session,
session_context_key: session_context_key.into(),
}
}
pub fn raw_session(&self) -> &Session {
&self.session
}
pub async fn insert<Extra>(&self, context: &SessionContext<Extra>) -> SessionContextResult<()>
where
Extra: Serialize,
{
self.session
.insert(&self.session_context_key, context)
.await
.map_err(|source| SessionContextError::Session { source })
}
pub async fn get<Extra>(&self) -> SessionContextResult<Option<SessionContext<Extra>>>
where
Extra: DeserializeOwned,
{
self.session
.get(&self.session_context_key)
.await
.map_err(|source| SessionContextError::Session { source })
}
pub async fn require<Extra>(&self) -> SessionContextResult<SessionContext<Extra>>
where
Extra: DeserializeOwned,
{
self.get().await?.ok_or(SessionContextError::MissingContext)
}
pub async fn clear(&self) -> SessionContextResult<()> {
self.session
.remove_value(&self.session_context_key)
.await
.map(|_| ())
.map_err(|source| SessionContextError::Session { source })
}
pub async fn is_authenticated<Extra>(&self) -> SessionContextResult<bool>
where
Extra: DeserializeOwned,
{
Ok(self.get::<Extra>().await?.is_some())
}
pub async fn cycle_id(&self) -> SessionContextResult<()> {
self.session
.cycle_id()
.await
.map_err(|source| SessionContextError::Session { source })
}
pub async fn flush(&self) -> SessionContextResult<()> {
self.session
.flush()
.await
.map_err(|source| SessionContextError::Session { source })
}
}
#[cfg(test)]
mod tests {
    use securitydept_utils::redirect::RedirectTargetRule;

    use super::*;

    /// `SessionContextConfig::default()` should produce the documented defaults:
    /// - cookie naming and attributes
    /// - a one-day TTL
    /// - `/` as the post-auth redirect fallback
    #[test]
    fn test_default_config() {
        let defaults = SessionContextConfig::default();

        // Cookie identity and storage key.
        assert_eq!(defaults.cookie_name, DEFAULT_COOKIE_NAME);
        assert_eq!(defaults.session_context_key, DEFAULT_SESSION_CONTEXT_KEY);
        assert_eq!(defaults.cookie_path, "/");

        // Cookie attributes.
        assert!(defaults.http_only);
        assert!(!defaults.secure);
        assert_eq!(defaults.same_site, SessionCookieSameSite::Lax);

        // Lifetime and redirect fallback.
        let one_day = StdDuration::from_secs(24 * 60 * 60);
        assert_eq!(defaults.ttl, Some(one_day));
        assert_eq!(
            defaults.post_auth_redirect.default_redirect_target.as_deref(),
            Some("/")
        );
    }

    /// The builder populates the principal, attributes and the optional extra payload.
    #[test]
    fn test_context_with_extra_data() {
        let principal = SessionPrincipal::builder()
            .subject("dev-session")
            .display_name("dev")
            .build();
        let attributes = HashMap::from([(
            "mode".to_string(),
            Value::String("dev".to_string()),
        )]);
        let extra = HashMap::from([(
            "provider".to_string(),
            Value::String("local".to_string()),
        )]);

        let context = SessionContext::builder()
            .principal(principal)
            .attributes(attributes)
            .extra(extra)
            .build();

        assert_eq!(context.principal.subject, "dev-session");
        assert_eq!(context.principal.display_name, "dev");

        let expected_mode = Value::String("dev".to_string());
        assert_eq!(context.attributes.get("mode"), Some(&expected_mode));

        let provider = context.extra.as_ref().and_then(|map| map.get("provider"));
        assert_eq!(provider, Some(&Value::String("local".to_string())));
    }

    /// A dynamic redirect config should handle two cases:
    /// - with no request, fall back to its default target
    /// - accept a requested target that matches an allowed rule
    #[test]
    fn test_post_auth_redirect_resolution() {
        let allowed_targets = [RedirectTargetRule::Strict {
            value: "/app".to_string(),
        }];
        let raw_config = SessionContextConfig::builder()
            .post_auth_redirect(RedirectTargetConfig::dynamic_default_and_dynamic_targets(
                "/",
                allowed_targets,
            ))
            .build();
        let resolved = SessionContextConfigSource::resolve_all(&raw_config)
            .expect("session context config should resolve");

        let fallback = resolved
            .resolve_post_auth_redirect(None)
            .expect("default redirect should resolve");
        assert_eq!(fallback, "/");

        let requested = resolved
            .resolve_post_auth_redirect(Some("/app"))
            .expect("dynamic redirect should resolve");
        assert_eq!(requested, "/app");
    }

    /// A validator that pins the redirect to `/` must reject a config asking for `/admin`.
    #[test]
    fn fixed_post_auth_redirect_validator_rejects_override() {
        let validator = SessionContextFixedPostAuthRedirectValidator::new(
            RedirectTargetConfig::strict_default("/"),
        );
        let config = SessionContextConfig::builder()
            .post_auth_redirect(RedirectTargetConfig::strict_default("/admin"))
            .build();

        let failure = SessionContextConfigSource::resolve_all_with_validator(&config, &validator)
            .expect_err("unexpected session post_auth_redirect should be rejected");

        let rejected_as_conflict = matches!(
            failure,
            SessionContextConfigValidationFailure::Validation { source }
                if source.field_path == "post_auth_redirect"
                    && source.code == "fixed_post_auth_redirect_conflict"
        );
        assert!(rejected_as_conflict);
    }
}