use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use chrono::{DateTime, Duration, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::error::Result;
use crate::orm::{Db, Row};
use super::role::Role;
use super::users::Identity;
/// Name of the cookie that carries the random session token
/// (looked up by `session_token_from_cookie`).
pub const SESSION_COOKIE: &str = "rustio_session";
/// Lifetime of a newly created session, in days (applied by `create_session`).
const SESSION_LENGTH_DAYS: i64 = 14;
/// Assurance level attached to a session, persisted in `rustio_sessions.trust_level`.
///
/// Levels are ordered by `SessionTrust::rank`. Serde uses the snake_case names
/// (`authenticated`, `elevated`, `mfa_verified`), which are the same values the
/// column's CHECK constraint allows.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionTrust {
/// Weakest level; also the column default for new rows.
Authenticated,
/// NOTE(review): presumably a temporary step-up paired with `elevated_until` — confirm.
Elevated,
/// Strongest level (rank 3); presumably reached after an MFA check — confirm.
MfaVerified,
}
impl SessionTrust {
pub const fn as_str(self) -> &'static str {
match self {
Self::Authenticated => "authenticated",
Self::Elevated => "elevated",
Self::MfaVerified => "mfa_verified",
}
}
pub const fn rank(self) -> u8 {
match self {
Self::Authenticated => 1,
Self::Elevated => 2,
Self::MfaVerified => 3,
}
}
pub const fn satisfies(self, other: SessionTrust) -> bool {
self.rank() >= other.rank()
}
pub fn parse(s: &str) -> Self {
match s {
"elevated" => Self::Elevated,
"mfa_verified" => Self::MfaVerified,
_ => Self::Authenticated,
}
}
}
/// Why a session was revoked.
///
/// `SessionInvalidationReason::as_str` gives the label that `invalidate_sessions`
/// writes to `rustio_sessions.revoked_reason`. Inside this file only `Logout` is
/// raised (by `logout_session`). The remaining variants are presumably raised by
/// callers elsewhere; their names are the only description available here.
/// NOTE(review): the `*ByOther` variants presumably mark actions taken by someone
/// other than the session owner — confirm with their callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionInvalidationReason {
/// Explicit sign-out through `logout_session`.
Logout,
Expired,
UserRequested,
AdministrativeRevoke,
PasswordReset,
PasswordResetByOther,
MfaEnabled,
MfaDisabled,
MfaDisabledByOther,
AuthorityEscalation,
EmergencyRecovery,
TrustEscalation,
}
impl SessionInvalidationReason {
pub const fn as_str(self) -> &'static str {
match self {
Self::Logout => "logout",
Self::Expired => "expired",
Self::UserRequested => "user_requested",
Self::AdministrativeRevoke => "administrative_revoke",
Self::PasswordReset => "password_reset",
Self::PasswordResetByOther => "password_reset_by_other",
Self::MfaEnabled => "mfa_enabled",
Self::MfaDisabled => "mfa_disabled",
Self::MfaDisabledByOther => "mfa_disabled_by_other",
Self::AuthorityEscalation => "authority_escalation",
Self::EmergencyRecovery => "emergency_recovery",
Self::TrustEscalation => "trust_escalation",
}
}
}
/// Selects which sessions `invalidate_sessions` revokes.
///
/// Every variant considers only sessions that are not already revoked
/// (`revoked_at IS NULL`).
#[derive(Debug, Clone, Copy)]
pub enum SessionTarget {
/// All of the user's non-revoked sessions.
User { user_id: i64 },
/// All of the user's non-revoked sessions except `current_session_id`.
UserExceptCurrent {
user_id: i64,
current_session_id: i64,
},
/// A single session, addressed by its numeric `session_id`.
Single { session_id: i64 },
}
/// Session metadata returned to callers (built by `list_active_for_user`).
///
/// Holds no token material, neither the token nor its hash.
#[derive(Debug, Clone, Serialize)]
pub struct Session {
/// Numeric id assigned from `rustio_sessions_session_id_seq`; used by `SessionTarget`.
pub session_id: i64,
/// Id of the owning row in `rustio_users`.
pub user_id: i64,
/// Assurance level, parsed from the `trust_level` column.
pub trust_level: SessionTrust,
pub created_at: DateTime<Utc>,
/// Refreshed whenever `identity_from_session` resolves this session.
pub last_seen: DateTime<Utc>,
/// Sessions past this instant are treated as expired.
pub expires_at: DateTime<Utc>,
/// NOTE(review): presumably the end of a temporary elevation window — confirm.
pub elevated_until: Option<DateTime<Utc>>,
/// Client address, if one was recorded (`create_session` does not set it).
pub ip: Option<String>,
/// NOTE(review): presumably the client's User-Agent header, if recorded.
pub user_agent: Option<String>,
}
/// Result of `invalidate_sessions`.
#[derive(Debug, Clone, Default)]
pub struct InvalidationOutcome {
/// `session_id`s revoked by this call; rows that were already revoked are not listed.
pub revoked_session_ids: Vec<i64>,
/// Reason written to the revoked rows. `invalidate_sessions` always fills this
/// with `Some`; `None` comes from `Default`.
pub reason: Option<SessionInvalidationReason>,
}
/// Creates the base `rustio_sessions` table and its `user_id` / `expires_at` indexes.
///
/// Every statement uses `IF NOT EXISTS`, so calling this on an existing schema
/// is harmless. Statements run in order and the first failure is returned.
///
/// # Errors
/// Propagates the first database error encountered.
pub async fn init_session_tables(db: &Db) -> Result<()> {
    const STATEMENTS: [&str; 3] = [
        "CREATE TABLE IF NOT EXISTS rustio_sessions (\n\
         token TEXT PRIMARY KEY,\n\
         user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,\n\
         expires_at TIMESTAMPTZ NOT NULL,\n\
         created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),\n\
         last_seen TIMESTAMPTZ NOT NULL DEFAULT NOW()\n\
         )",
        "CREATE INDEX IF NOT EXISTS rustio_sessions_user_idx ON rustio_sessions (user_id)",
        "CREATE INDEX IF NOT EXISTS rustio_sessions_expires_idx ON rustio_sessions (expires_at)",
    ];
    for statement in STATEMENTS {
        sqlx::query(statement).execute(db.pool()).await?;
    }
    Ok(())
}
/// Adds the nullable `ip` and `user_agent` TEXT columns to `rustio_sessions`.
///
/// Both statements use `ADD COLUMN IF NOT EXISTS`, so re-running is harmless.
///
/// # Errors
/// Propagates the first database error encountered.
pub(crate) async fn migrate_session_schema(db: &Db) -> Result<()> {
    const STATEMENTS: [&str; 2] = [
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS ip TEXT",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS user_agent TEXT",
    ];
    for statement in STATEMENTS {
        sqlx::query(statement).execute(db.pool()).await?;
    }
    Ok(())
}
/// Upgrades `rustio_sessions` to the lifecycle schema.
///
/// The upgrade adds a numeric `session_id`, hashed tokens, a device id, trust
/// levels (with a CHECK constraint), an elevation deadline, parent links and
/// soft revocation. It also creates the supporting indexes.
///
/// Statements run one by one, in the listed order, and are not wrapped in a
/// transaction. The order matters: the sequence must exist before the column
/// default uses it, and columns must exist before the constraint and indexes
/// reference them. Each step is idempotent (`IF NOT EXISTS`, or a `DO` block that
/// checks `pg_constraint`). A run that fails partway can therefore be repeated
/// safely.
///
/// # Errors
/// Propagates the first database error; earlier steps stay applied.
pub(crate) async fn migrate_session_lifecycle(db: &Db) -> Result<()> {
    const STATEMENTS: &[&str] = &[
        "CREATE SEQUENCE IF NOT EXISTS rustio_sessions_session_id_seq",
        "ALTER TABLE rustio_sessions \
         ADD COLUMN IF NOT EXISTS session_id BIGINT NOT NULL DEFAULT \
         nextval('rustio_sessions_session_id_seq')",
        "ALTER SEQUENCE rustio_sessions_session_id_seq OWNED BY rustio_sessions.session_id",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS token_hash TEXT",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS device_id TEXT",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS trust_level TEXT \
         NOT NULL DEFAULT 'authenticated'",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS elevated_until TIMESTAMPTZ",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS parent_session_id BIGINT",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_at TIMESTAMPTZ",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS revoked_reason TEXT",
        // Postgres has no `ADD CONSTRAINT IF NOT EXISTS`, so guard it manually.
        "DO $$ BEGIN \
         IF NOT EXISTS ( \
         SELECT 1 FROM pg_constraint \
         WHERE conname = 'rustio_sessions_trust_level_check' \
         ) THEN \
         ALTER TABLE rustio_sessions \
         ADD CONSTRAINT rustio_sessions_trust_level_check \
         CHECK (trust_level IN ('authenticated', 'elevated', 'mfa_verified')); \
         END IF; \
         END $$",
        "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_session_id_uq \
         ON rustio_sessions (session_id)",
        // Digest uniqueness is enforced only among non-revoked rows that have one.
        "CREATE UNIQUE INDEX IF NOT EXISTS rustio_sessions_token_hash_uq \
         ON rustio_sessions (token_hash) \
         WHERE revoked_at IS NULL AND token_hash IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS rustio_sessions_user_active_idx \
         ON rustio_sessions (user_id) WHERE revoked_at IS NULL",
        "CREATE INDEX IF NOT EXISTS rustio_sessions_parent_idx \
         ON rustio_sessions (parent_session_id) WHERE parent_session_id IS NOT NULL",
    ];
    for &statement in STATEMENTS {
        sqlx::query(statement).execute(db.pool()).await?;
    }
    Ok(())
}
/// Creates a session for `user_id` and returns the plaintext token for the cookie.
///
/// Only the SHA-256 digest of the token is written to the database. It goes into
/// `token_hash`, which every lookup in this module checks first. It also goes
/// into the legacy `token` column, which is a NOT NULL primary key and so still
/// needs a unique value. Rows created here have `token_hash` set, so the legacy
/// plaintext fallbacks (`token_hash IS NULL AND token = ...`) never match them.
///
/// The session expires `SESSION_LENGTH_DAYS` days from now. `trust_level`,
/// `created_at` and `last_seen` take their column defaults.
///
/// # Errors
/// Propagates database errors from the INSERT.
pub async fn create_session(db: &Db, user_id: i64) -> Result<String> {
    let token = random_token();
    let token_hash = hash_token_for_storage(&token);
    let expires = Utc::now() + Duration::days(SESSION_LENGTH_DAYS);
    sqlx::query(
        "INSERT INTO rustio_sessions (token, token_hash, user_id, expires_at) \
         VALUES ($1, $2, $3, $4)",
    )
    // Never persist the bearer token itself. Anyone who can read the table could
    // otherwise replay live sessions, which would defeat storing a hash at all.
    .bind(&token_hash)
    .bind(&token_hash)
    .bind(user_id)
    .bind(expires)
    .execute(db.pool())
    .await?;
    Ok(token)
}
/// Hard-deletes the session identified by the presented cookie `token`.
///
/// Rows are matched on the token's SHA-256 digest. The raw `token` column is
/// compared only for legacy rows that have no `token_hash`, which is the same
/// predicate the other lookups in this module use. Because of that restriction,
/// a hashed session cannot be deleted by presenting a value copied from its
/// `token` column; only the real token (whose digest matches) works.
///
/// # Errors
/// Propagates database errors from the DELETE.
pub async fn delete_session(db: &Db, token: &str) -> Result<()> {
    let token_hash = hash_token_for_storage(token);
    sqlx::query(
        "DELETE FROM rustio_sessions \
         WHERE token_hash = $1 OR (token_hash IS NULL AND token = $2)",
    )
    .bind(&token_hash)
    .bind(token)
    .execute(db.pool())
    .await?;
    Ok(())
}
/// Soft-revokes the sessions selected by `target` and tags each with `reason`.
///
/// No rows are deleted. Each matching row gets `revoked_at = NOW()` and
/// `revoked_reason = reason.as_str()`. Rows that are already revoked are skipped
/// and do not appear in the outcome. Expiry is not checked, so expired rows that
/// were never revoked are revoked as well.
///
/// # Errors
/// Propagates database errors from the UPDATE.
pub async fn invalidate_sessions(
    db: &Db,
    target: SessionTarget,
    reason: SessionInvalidationReason,
) -> Result<InvalidationOutcome> {
    let label = reason.as_str();
    let revoked_session_ids = match target {
        SessionTarget::Single { session_id } => {
            sqlx::query_scalar::<_, i64>(
                "UPDATE rustio_sessions \
                 SET revoked_at = NOW(), revoked_reason = $2 \
                 WHERE session_id = $1 AND revoked_at IS NULL \
                 RETURNING session_id",
            )
            .bind(session_id)
            .bind(label)
            .fetch_all(db.pool())
            .await?
        }
        SessionTarget::User { user_id } => {
            sqlx::query_scalar::<_, i64>(
                "UPDATE rustio_sessions \
                 SET revoked_at = NOW(), revoked_reason = $2 \
                 WHERE user_id = $1 AND revoked_at IS NULL \
                 RETURNING session_id",
            )
            .bind(user_id)
            .bind(label)
            .fetch_all(db.pool())
            .await?
        }
        SessionTarget::UserExceptCurrent {
            user_id,
            current_session_id,
        } => {
            sqlx::query_scalar::<_, i64>(
                "UPDATE rustio_sessions \
                 SET revoked_at = NOW(), revoked_reason = $3 \
                 WHERE user_id = $1 AND session_id <> $2 AND revoked_at IS NULL \
                 RETURNING session_id",
            )
            .bind(user_id)
            .bind(current_session_id)
            .bind(label)
            .fetch_all(db.pool())
            .await?
        }
    };
    Ok(InvalidationOutcome {
        revoked_session_ids,
        reason: Some(reason),
    })
}
/// Revokes the non-revoked session that owns the presented cookie `token`.
///
/// The revocation records `SessionInvalidationReason::Logout`. A token that
/// matches no non-revoked session is a silent no-op. Expiry is not checked.
///
/// # Errors
/// Propagates database errors from the lookup or the revocation.
pub async fn logout_session(db: &Db, token: &str) -> Result<()> {
    let digest = hash_token_for_storage(token);
    let found = sqlx::query_scalar::<_, i64>(
        "SELECT session_id FROM rustio_sessions \
         WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
         AND revoked_at IS NULL \
         LIMIT 1",
    )
    .bind(&digest)
    .bind(token)
    .fetch_optional(db.pool())
    .await?;
    match found {
        Some(session_id) => {
            invalidate_sessions(
                db,
                SessionTarget::Single { session_id },
                SessionInvalidationReason::Logout,
            )
            .await?;
            Ok(())
        }
        None => Ok(()),
    }
}
pub async fn list_active_for_user(db: &Db, user_id: i64) -> Result<Vec<Session>> {
let rows = sqlx::query(
"SELECT session_id, user_id, trust_level, created_at, last_seen, expires_at, \
elevated_until, ip, user_agent \
FROM rustio_sessions \
WHERE user_id = $1 AND revoked_at IS NULL AND expires_at > NOW() \
ORDER BY last_seen DESC",
)
.bind(user_id)
.fetch_all(db.pool())
.await?;
rows.iter()
.map(|r| {
let r = Row::from_pg(r);
Ok(Session {
session_id: r.get_i64("session_id")?,
user_id: r.get_i64("user_id")?,
trust_level: SessionTrust::parse(&r.get_string("trust_level")?),
created_at: r.get_datetime("created_at")?,
last_seen: r.get_datetime("last_seen")?,
expires_at: r.get_datetime("expires_at")?,
elevated_until: None, ip: r.get_optional_string("ip")?,
user_agent: r.get_optional_string("user_agent")?,
})
})
.collect()
}
/// Resolves the presented cookie `token` to its numeric `session_id`.
///
/// Returns `None` unless the session is live (not revoked and not expired).
/// Digest matches are used first; legacy rows without a `token_hash` are matched
/// on the raw token.
///
/// # Errors
/// Propagates database errors from the lookup.
pub async fn current_session_id(db: &Db, token: &str) -> Result<Option<i64>> {
    let digest = hash_token_for_storage(token);
    Ok(sqlx::query_scalar::<_, i64>(
        "SELECT session_id FROM rustio_sessions \
         WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
         AND revoked_at IS NULL AND expires_at > NOW() \
         LIMIT 1",
    )
    .bind(&digest)
    .bind(token)
    .fetch_optional(db.pool())
    .await?)
}
/// Resolves the presented cookie `token` to the owning user's `Identity`.
///
/// The lookup tries the token's SHA-256 digest (`token_hash`) first. Only if that
/// misses does it fall back to legacy rows that store the raw token in `token`
/// and have no `token_hash`. Revoked rows never match.
///
/// Returns `Ok(None)` when nothing matches or the session has expired. An expired
/// session is also deleted, best-effort. On success, a detached Tokio task bumps
/// `last_seen`; its result is ignored, so the caller neither waits on that write
/// nor fails because of it. The user's `is_active` flag is returned as-is and is
/// not enforced here.
///
/// # Errors
/// Propagates database errors from the lookups, column decode errors, and
/// errors from `Role::parse`.
pub async fn identity_from_session(db: &Db, token: &str) -> Result<Option<Identity>> {
    let token_hash = hash_token_for_storage(token);

    let by_hash = sqlx::query(
        "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
         s.expires_at, s.token_hash IS NOT NULL AS hashed \
         FROM rustio_sessions s \
         JOIN rustio_users u ON u.id = s.user_id \
         WHERE s.token_hash = $1 AND s.revoked_at IS NULL",
    )
    .bind(&token_hash)
    .fetch_optional(db.pool())
    .await?;

    // Consult the legacy plaintext column only when the digest lookup missed.
    let found = if by_hash.is_some() {
        by_hash
    } else {
        sqlx::query(
            "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, \
             s.expires_at, FALSE AS hashed \
             FROM rustio_sessions s \
             JOIN rustio_users u ON u.id = s.user_id \
             WHERE s.token = $1 AND s.token_hash IS NULL AND s.revoked_at IS NULL",
        )
        .bind(token)
        .fetch_optional(db.pool())
        .await?
    };

    let pg_row = match found {
        Some(pg_row) => pg_row,
        None => return Ok(None),
    };
    let row = Row::from_pg(&pg_row);

    if row.get_datetime("expires_at")? < Utc::now() {
        // Cleanup is best-effort: a failed delete still reports "no identity".
        let _ = delete_session(db, token).await;
        return Ok(None);
    }

    // Refresh `last_seen` off the request path; the task owns everything it uses.
    let touch_db = db.clone();
    let touch_token = token.to_string();
    let touch_hash = token_hash;
    tokio::spawn(async move {
        let _ = sqlx::query(
            "UPDATE rustio_sessions SET last_seen = NOW() \
             WHERE (token_hash = $1 OR (token_hash IS NULL AND token = $2)) \
             AND revoked_at IS NULL",
        )
        .bind(&touch_hash)
        .bind(&touch_token)
        .execute(touch_db.pool())
        .await;
    });

    Ok(Some(Identity {
        user_id: row.get_i64("id")?,
        email: row.get_string("email")?,
        role: Role::parse(&row.get_string("role")?)?,
        is_active: row.get_bool("is_active")?,
        is_demo: row.get_bool("is_demo")?,
        demo_label: row.get_optional_string("demo_label")?,
    }))
}
/// Hard-deletes every session whose `expires_at` is in the past.
///
/// Returns the number of rows removed. Revoked sessions that have not yet
/// expired are kept.
///
/// # Errors
/// Propagates database errors from the DELETE.
pub async fn purge_expired_sessions(db: &Db) -> Result<u64> {
    let deleted = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at < NOW()")
        .execute(db.pool())
        .await?
        .rows_affected();
    Ok(deleted)
}
/// Extracts the session token from a raw `Cookie` request header.
///
/// The header is split on `;` and each pair is trimmed. The value of the first
/// pair named exactly `SESSION_COOKIE` is returned verbatim, with no unquoting or
/// percent-decoding. Returns `None` when the cookie is absent. A present but empty
/// value yields `Some(String::new())`.
pub fn session_token_from_cookie(cookie_header: &str) -> Option<String> {
    let wanted = format!("{SESSION_COOKIE}=");
    cookie_header
        .split(';')
        .map(str::trim)
        .find_map(|pair| pair.strip_prefix(wanted.as_str()))
        .map(str::to_owned)
}
/// Generates a fresh session token.
///
/// The token is 32 random bytes from `rand::thread_rng()`, encoded as unpadded
/// URL-safe base64 (43 characters).
pub(crate) fn random_token() -> String {
    let mut entropy = [0u8; 32];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut entropy);
    URL_SAFE_NO_PAD.encode(entropy)
}
/// Derives the at-rest form of a session token.
///
/// The result is the SHA-256 of the token's UTF-8 bytes, encoded as unpadded
/// URL-safe base64 (43 characters). It is deterministic, so the same cookie value
/// always maps to the same `token_hash`.
pub(crate) fn hash_token_for_storage(token: &str) -> String {
    URL_SAFE_NO_PAD.encode(Sha256::digest(token.as_bytes()))
}
// Unit tests for the pure helpers in this file; none of them touch the database.
#[cfg(test)]
mod tests {
use super::*;
// Cookie parsing: the named cookie is found among other pairs.
#[test]
fn extracts_token_from_cookie_header() {
let h = "foo=bar; rustio_session=abc123; other=x";
assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
}
#[test]
fn returns_none_when_cookie_missing() {
let h = "foo=bar; other=x";
assert!(session_token_from_cookie(h).is_none());
}
// Smoke check only: two consecutive tokens must differ.
#[test]
fn random_token_has_reasonable_entropy() {
assert_ne!(random_token(), random_token());
}
// Token hashing: deterministic, input-sensitive, URL-safe, opaque.
#[test]
fn hash_token_is_deterministic() {
let token = random_token();
assert_eq!(
hash_token_for_storage(&token),
hash_token_for_storage(&token)
);
}
#[test]
fn hash_token_differs_per_token() {
let a = hash_token_for_storage("aaaa");
let b = hash_token_for_storage("aaab");
assert_ne!(a, b);
}
// A 32-byte SHA-256 digest in unpadded base64 is exactly 43 characters.
#[test]
fn hash_token_output_is_url_safe_base64() {
let h = hash_token_for_storage("anything");
assert_eq!(h.len(), 43);
assert!(h
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
}
#[test]
fn hash_token_does_not_leak_plaintext() {
let plaintext = "secret-cookie-value-12345";
let h = hash_token_for_storage(plaintext);
assert!(!h.contains("secret"));
assert!(!h.contains("12345"));
}
// Trust levels: ordering, label round-trip, and fail-safe parsing.
#[test]
fn session_trust_orders_correctly() {
assert!(SessionTrust::Authenticated.rank() < SessionTrust::Elevated.rank());
assert!(SessionTrust::Elevated.rank() < SessionTrust::MfaVerified.rank());
assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Elevated));
assert!(SessionTrust::MfaVerified.satisfies(SessionTrust::Authenticated));
assert!(SessionTrust::Authenticated.satisfies(SessionTrust::Authenticated));
assert!(!SessionTrust::Authenticated.satisfies(SessionTrust::Elevated));
assert!(!SessionTrust::Elevated.satisfies(SessionTrust::MfaVerified));
}
#[test]
fn session_trust_round_trips_through_sql() {
for tier in [
SessionTrust::Authenticated,
SessionTrust::Elevated,
SessionTrust::MfaVerified,
] {
assert_eq!(SessionTrust::parse(tier.as_str()), tier);
}
}
#[test]
fn session_trust_parse_defaults_safely_on_unknown() {
assert_eq!(SessionTrust::parse("garbage"), SessionTrust::Authenticated);
assert_eq!(SessionTrust::parse(""), SessionTrust::Authenticated);
}
// Stored `revoked_reason` labels must be unambiguous; keep this list in sync
// with the enum when adding variants.
#[test]
fn invalidation_reason_strings_are_distinct() {
let reasons = [
SessionInvalidationReason::Logout,
SessionInvalidationReason::Expired,
SessionInvalidationReason::UserRequested,
SessionInvalidationReason::AdministrativeRevoke,
SessionInvalidationReason::PasswordReset,
SessionInvalidationReason::PasswordResetByOther,
SessionInvalidationReason::MfaEnabled,
SessionInvalidationReason::MfaDisabled,
SessionInvalidationReason::MfaDisabledByOther,
SessionInvalidationReason::AuthorityEscalation,
SessionInvalidationReason::EmergencyRecovery,
SessionInvalidationReason::TrustEscalation,
];
let mut set = std::collections::HashSet::new();
for r in reasons {
assert!(set.insert(r.as_str()), "duplicate as_str() for {r:?}");
}
assert_eq!(set.len(), reasons.len());
}
}