use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use chrono::{Duration, Utc};
use rand::RngCore;
use crate::error::Result;
use crate::orm::{Db, Row};
use super::role::Role;
use super::users::Identity;
/// How long a freshly created session stays valid, in days; `create_session`
/// adds this to the current UTC time to compute `expires_at`.
const SESSION_LENGTH_DAYS: i64 = 14;
/// Name of the cookie carrying the session token, looked up by
/// `session_token_from_cookie` in a raw `Cookie` header.
pub const SESSION_COOKIE: &str = "rustio_session";
/// Creates the `rustio_sessions` table and its `user_id` / `expires_at`
/// indexes if they are not already present.
///
/// Every statement uses `IF NOT EXISTS`, so repeated calls are harmless.
/// Statements are executed in order; the first failure is returned and the
/// remaining statements are not run.
pub async fn init_session_tables(db: &Db) -> Result<()> {
    let statements = [
        "CREATE TABLE IF NOT EXISTS rustio_sessions (
token TEXT PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES rustio_users(id) ON DELETE CASCADE,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
last_seen TIMESTAMPTZ NOT NULL DEFAULT NOW()
)",
        "CREATE INDEX IF NOT EXISTS rustio_sessions_user_idx ON rustio_sessions (user_id)",
        "CREATE INDEX IF NOT EXISTS rustio_sessions_expires_idx ON rustio_sessions (expires_at)",
    ];
    for &statement in &statements {
        sqlx::query(statement).execute(db.pool()).await?;
    }
    Ok(())
}
/// Adds the optional `ip` and `user_agent` columns to `rustio_sessions`.
///
/// Uses `ADD COLUMN IF NOT EXISTS`, so it is safe to run against a table that
/// already has them. Stops at (and returns) the first failing statement.
pub(crate) async fn migrate_session_schema(db: &Db) -> Result<()> {
    let alterations = [
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS ip TEXT",
        "ALTER TABLE rustio_sessions ADD COLUMN IF NOT EXISTS user_agent TEXT",
    ];
    for &alteration in &alterations {
        sqlx::query(alteration).execute(db.pool()).await?;
    }
    Ok(())
}
/// Opens a new session for `user_id` and returns its freshly generated token.
///
/// The session expires `SESSION_LENGTH_DAYS` days after the current UTC time;
/// `created_at` and `last_seen` are left to their column defaults.
///
/// # Errors
/// Propagates any failure from the `INSERT`.
pub async fn create_session(db: &Db, user_id: i64) -> Result<String> {
    let session_token = random_token();
    let lifetime = Duration::days(SESSION_LENGTH_DAYS);
    let expires_at = Utc::now() + lifetime;

    sqlx::query("INSERT INTO rustio_sessions (token, user_id, expires_at) VALUES ($1, $2, $3)")
        .bind(&session_token)
        .bind(user_id)
        .bind(expires_at)
        .execute(db.pool())
        .await?;

    Ok(session_token)
}
/// Deletes the session row whose token equals `token`.
///
/// The number of affected rows is not inspected, so deleting an unknown token
/// still returns `Ok(())`.
///
/// # Errors
/// Propagates any failure from the `DELETE`.
pub async fn delete_session(db: &Db, token: &str) -> Result<()> {
    let outcome = sqlx::query("DELETE FROM rustio_sessions WHERE token = $1")
        .bind(token)
        .execute(db.pool())
        .await;
    outcome?;
    Ok(())
}
/// Resolves a session token to the [`Identity`] of the user that owns it.
///
/// Returns `Ok(None)` when:
/// - no session row matches `token`, or
/// - the session's `expires_at` is earlier than this process's current UTC
///   time. In that case the row is also deleted, best-effort; a failed delete
///   is ignored.
///
/// On success, a detached tokio task bumps the session's `last_seen` column.
/// Its result is neither awaited nor reported.
///
/// # Errors
/// Returns an error if the lookup query fails, if any column cannot be read
/// from the row, or if the stored role string is rejected by `Role::parse`.
pub async fn identity_from_session(db: &Db, token: &str) -> Result<Option<Identity>> {
    let row = sqlx::query(
        "SELECT u.id, u.email, u.role, u.is_active, u.is_demo, u.demo_label, s.expires_at
FROM rustio_sessions s
JOIN rustio_users u ON u.id = s.user_id
WHERE s.token = $1",
    )
    .bind(token)
    .fetch_optional(db.pool())
    .await?;
    let row = match row {
        Some(row) => row,
        None => return Ok(None),
    };
    let r = Row::from_pg(&row);

    // NOTE(review): expiry is judged by the application clock here, whereas
    // `purge_expired_sessions` compares against the database's `NOW()`; clock
    // skew between the two hosts could make them disagree at the boundary.
    let expires_at = r.get_datetime("expires_at")?;
    if expires_at < Utc::now() {
        // Best-effort cleanup: the caller gets `None` whether or not this succeeds.
        let _ = delete_session(db, token).await;
        return Ok(None);
    }

    // Decode the identity *before* recording activity, so a row that cannot be
    // turned into an Identity (e.g. an unparseable role) returns an error
    // without marking its session as recently used.
    let identity = Identity {
        user_id: r.get_i64("id")?,
        email: r.get_string("email")?,
        role: Role::parse(&r.get_string("role")?)?,
        is_active: r.get_bool("is_active")?,
        is_demo: r.get_bool("is_demo")?,
        demo_label: r.get_optional_string("demo_label")?,
    };

    let db_clone = db.clone();
    let token_owned = token.to_string();
    tokio::spawn(async move {
        // Fire-and-forget: failing to record activity must not fail the request.
        let _ = sqlx::query("UPDATE rustio_sessions SET last_seen = NOW() WHERE token = $1")
            .bind(&token_owned)
            .execute(db_clone.pool())
            .await;
    });

    Ok(Some(identity))
}
/// Deletes every session whose `expires_at` is before the database's `NOW()`
/// and returns how many rows were removed.
///
/// # Errors
/// Propagates any failure from the `DELETE`.
pub async fn purge_expired_sessions(db: &Db) -> Result<u64> {
    let removed = sqlx::query("DELETE FROM rustio_sessions WHERE expires_at < NOW()")
        .execute(db.pool())
        .await?
        .rows_affected();
    Ok(removed)
}
/// Extracts the session token from a raw HTTP `Cookie` header value.
///
/// The header is split on `;` into pairs, each trimmed of surrounding
/// whitespace. A pair matches only if its name is exactly [`SESSION_COOKIE`],
/// so cookies such as `xrustio_session` or `rustio_session_old` are ignored.
///
/// Empty values (`rustio_session=`) are skipped. A cleared cookie can never
/// name a real session, and skipping it avoids shadowing a valid
/// later duplicate of the same name: browsers may send several cookies with
/// one name when they are scoped to different paths.
///
/// Returns the first non-empty matching value, or `None` if there is none.
pub fn session_token_from_cookie(cookie_header: &str) -> Option<String> {
    cookie_header
        .split(';')
        .filter_map(|pair| pair.trim().strip_prefix(SESSION_COOKIE)?.strip_prefix('='))
        .find(|value| !value.is_empty())
        .map(str::to_owned)
}
/// Produces a new session token: 32 bytes from `rand::thread_rng()`, encoded
/// as unpadded URL-safe base64.
fn random_token() -> String {
    let mut entropy = [0u8; 32];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut entropy);
    URL_SAFE_NO_PAD.encode(entropy)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_token_from_cookie_header() {
        let h = "foo=bar; rustio_session=abc123; other=x";
        assert_eq!(session_token_from_cookie(h), Some("abc123".into()));
    }

    #[test]
    fn extracts_token_when_session_cookie_is_first_or_alone() {
        assert_eq!(session_token_from_cookie("rustio_session=abc123"), Some("abc123".into()));
        assert_eq!(
            session_token_from_cookie("rustio_session=abc123;other=x"),
            Some("abc123".into())
        );
    }

    #[test]
    fn ignores_cookies_whose_name_only_resembles_the_session_cookie() {
        let h = "xrustio_session=nope; rustio_session_old=nope; rustio_session=yes";
        assert_eq!(session_token_from_cookie(h), Some("yes".into()));
    }

    #[test]
    fn returns_none_when_cookie_missing() {
        let h = "foo=bar; other=x";
        assert!(session_token_from_cookie(h).is_none());
        assert!(session_token_from_cookie("").is_none());
    }

    #[test]
    fn random_token_has_reasonable_entropy() {
        assert_ne!(random_token(), random_token());
    }

    #[test]
    fn random_token_is_43_url_safe_chars() {
        let t = random_token();
        // 32 bytes as unpadded base64 -> ceil(32 * 4 / 3) = 43 characters.
        assert_eq!(t.len(), 43);
        // The token travels in a cookie, so it must avoid `;`, `=` and whitespace.
        assert!(t
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'));
    }

    // NOTE(review): if an assertion below fails, the trailing cleanup never
    // runs and the test user row is left behind in the database.
    #[tokio::test]
    #[ignore = "needs `RUSTIO_TEST_DB=1` + a running postgres (URL via RUSTIO_TEST_DATABASE_URL or default)"]
    async fn existing_session_crud_unaffected_by_migration() {
        let url = std::env::var("RUSTIO_TEST_DATABASE_URL")
            .unwrap_or_else(|_| "postgres://postgres:dev@localhost:5432/rustio_dev".into());
        let opts = crate::orm::DbOptions {
            max_connections: 2,
            ..crate::orm::DbOptions::default()
        };
        let db = crate::orm::Db::connect_with(&url, opts).await.unwrap();
        crate::auth::init_tables(&db).await.unwrap();
        // pid + nanos keeps the email unique across concurrent/repeated runs.
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let email = format!("sess_smoke_{pid}_{nanos}@example.test");
        let user_id = crate::auth::create_user(&db, &email, "secret-pw-123", Role::User)
            .await
            .unwrap();
        let token = create_session(&db, user_id).await.unwrap();
        let identity = identity_from_session(&db, &token)
            .await
            .unwrap()
            .expect("session resolves to identity");
        assert_eq!(identity.user_id, user_id);
        assert_eq!(identity.email, email);
        delete_session(&db, &token).await.unwrap();
        assert!(identity_from_session(&db, &token).await.unwrap().is_none());
        let _ = sqlx::query("DELETE FROM rustio_users WHERE id = $1")
            .bind(user_id)
            .execute(db.pool())
            .await;
    }
}