use std::{env, fmt::Display};
use anyhow::{bail, Context, Result};
use base64::{engine::general_purpose::STANDARD_NO_PAD as b64, Engine};
use orion::aead::SecretKey;
use sqlx::{sqlite::SqliteRow, FromRow, Row, SqlitePool};
use time::OffsetDateTime;
use uuid::Uuid;
use crate::crypto;
/// Error message used when a session token is too short to parse or when its
/// session key cannot be loaded from the database.
const INVALID_TOKEN: &str = "Invalid session token";
/// Lifetime applied to a new session token when the caller passes no explicit
/// lifetime (8 hours).
const DEFAULT_SESSION_LIFETIME: time::Duration = time::Duration::hours(8);
/// A per-session secret key as stored in the `session_keys` table.
pub struct SessionKey {
/// Identifier of the key; read from and written to the `id` column.
id: Uuid,
/// Key material passed to `crypto::encrypt` / `crypto::decrypt` for session tokens.
key: SecretKey,
/// Instant that `delete_expired` compares against the current UTC time.
expire_time: OffsetDateTime,
}
impl FromRow<'_, SqliteRow> for SessionKey {
    /// Builds a [`SessionKey`] from a row exposing `id`, `key` and `expire_time` columns.
    ///
    /// # Errors
    /// Returns the `sqlx::Error` produced by a failed column lookup/decoding, or
    /// `sqlx::Error::Decode` when the stored bytes are rejected by
    /// `SecretKey::from_slice`.
    fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
        // `try_get` reports a missing or mistyped column as an error instead of
        // panicking, which is what a function returning `sqlx::Result` should do.
        let id: Uuid = row.try_get("id")?;
        let key: Vec<u8> = row.try_get("key")?;
        let expire_time: OffsetDateTime = row.try_get("expire_time")?;
        let key = SecretKey::from_slice(&key).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;
        Ok(Self {
            id,
            key,
            expire_time,
        })
    }
}
impl SessionKey {
fn new(expire_time: &OffsetDateTime) -> Self {
Self {
id: uuid::Uuid::new_v4(),
key: SecretKey::default(),
expire_time: *expire_time,
}
}
pub async fn get(db: &SqlitePool, id: &Uuid) -> Result<Self> {
struct SessionKeyDB {
id: Uuid,
key: Vec<u8>,
expire_time: OffsetDateTime,
}
let session = sqlx::query_as!(
SessionKeyDB,
r#"select id as "id: _", key, expire_time from session_keys where id = ?"#,
id,
)
.fetch_one(db)
.await?;
Ok(Self {
id: session.id,
key: SecretKey::from_slice(&session.key)?,
expire_time: session.expire_time,
})
}
async fn insert(&self, db: &SqlitePool) -> Result<()> {
let key = self.key.unprotected_as_bytes();
sqlx::query!(
"insert into session_keys (id, key, expire_time) values (?, ?, ?)",
self.id,
key,
self.expire_time
)
.execute(db)
.await
.context("Failed to insert session key")?;
Ok(())
}
pub async fn delete(&self, db: &SqlitePool) -> Result<()> {
sqlx::query!("delete from session_keys where id = ?", self.id)
.execute(db)
.await
.context("Failed to delete session key")?;
Ok(())
}
pub async fn delete_expired(db: &SqlitePool) -> Result<()> {
let now = OffsetDateTime::now_utc();
sqlx::query!("delete from session_keys where expire_time < ?", now)
.execute(db)
.await
.context("Failed to delete expired session key")?;
Ok(())
}
}
/// A session token held as a string; tokens produced by `SessionToken::new`
/// are base64-encoded.
pub struct SessionToken(String);
impl SessionToken {
pub fn from_env() -> Result<Self> {
Ok(Self(env::var("RUDRIC_SESSION")?))
}
pub async fn new(
db: &SqlitePool,
master_key: SecretKey,
lifetime: Option<time::Duration>,
) -> Result<Self> {
let expire_time = OffsetDateTime::now_utc() + lifetime.unwrap_or(DEFAULT_SESSION_LIFETIME);
let session_key = SessionKey::new(&expire_time);
session_key.insert(db).await?;
let timed_key = [
&expire_time.unix_timestamp().to_be_bytes(),
master_key.unprotected_as_bytes(),
]
.concat();
let encrypted_timed_key = crypto::encrypt(&session_key.key, &timed_key)?;
let session_token = [session_key.id.as_bytes(), encrypted_timed_key.as_slice()].concat();
if let Err(e) = SessionKey::delete_expired(db).await {
eprintln!("Error deleting expired session tokens: {e}");
}
Ok(Self(b64.encode(session_token)))
}
pub async fn get_expire_time(&self, db: &SqlitePool) -> Result<OffsetDateTime> {
let (_, decrypted_timed_key) = self.decrypt_timed_key(db).await?;
let (expire_time, _) = split_timed_key(&decrypted_timed_key)?;
Ok(expire_time)
}
pub async fn into_master_key(self, db: &SqlitePool) -> Result<SecretKey> {
let (session_key, decrypted_timed_key) = self.decrypt_timed_key(db).await?;
let (expire_time, secret_key) = split_timed_key(&decrypted_timed_key)?;
if expire_time < OffsetDateTime::now_utc() {
session_key.delete(db).await?;
bail!("Session key has expired");
}
if let Err(e) = SessionKey::delete_expired(db).await {
eprintln!("Error deleting expired session tokens: {e}");
}
Ok(secret_key)
}
async fn decrypt_timed_key(&self, db: &SqlitePool) -> Result<(SessionKey, Vec<u8>)> {
let (session_id, encrypted_timed_key) = self.split_id()?;
let session_key = match SessionKey::get(db, &session_id).await {
Ok(s) => s,
Err(_) => bail!(INVALID_TOKEN),
};
let decrypted_timed_key = crypto::decrypt(&session_key.key, &encrypted_timed_key)?;
Ok((session_key, decrypted_timed_key))
}
pub fn split_id(&self) -> Result<(Uuid, Vec<u8>)> {
let session_token_bytes = b64.decode(self.0.clone())?;
if session_token_bytes.len() < 16 {
bail!(INVALID_TOKEN)
}
let (session_id, encrypted_timed_key) = session_token_bytes.split_at(16);
let session_id = Uuid::from_bytes(session_id.try_into()?);
Ok((session_id, encrypted_timed_key.to_vec()))
}
}
impl Display for SessionToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
/// Splits a decrypted token payload into its expiry time and secret key.
///
/// The first 8 bytes are read as a big-endian Unix timestamp; the remaining
/// bytes are handed to `SecretKey::from_slice`.
///
/// # Errors
/// Fails with `INVALID_TOKEN` when fewer than 8 bytes are supplied, and
/// propagates the error when the timestamp or the key bytes are rejected.
fn split_timed_key(timed_key: &[u8]) -> Result<(OffsetDateTime, SecretKey)> {
    /// Width of the big-endian `i64` timestamp prefix.
    const TIMESTAMP_LEN: usize = 8;

    if timed_key.len() < TIMESTAMP_LEN {
        bail!(INVALID_TOKEN)
    }
    let (prefix, key_bytes) = timed_key.split_at(TIMESTAMP_LEN);
    let seconds = i64::from_be_bytes(prefix.try_into()?);
    Ok((
        OffsetDateTime::from_unix_timestamp(seconds)?,
        SecretKey::from_slice(key_bytes)?,
    ))
}
#[cfg(test)]
mod session_tests {
    use super::*;
    use anyhow::Result;
    use sqlx::SqlitePool;

    /// Largest accepted gap between the expected and the actual expiry time.
    const ONE_SECOND: std::time::Duration = std::time::Duration::from_secs(1);

    /// A token created without an explicit lifetime expires after
    /// `DEFAULT_SESSION_LIFETIME`.
    #[sqlx::test]
    async fn test_new_token_default_lifetime(db: SqlitePool) -> Result<()> {
        let secret_key = SecretKey::default();
        let now = OffsetDateTime::now_utc();
        let token = SessionToken::new(&db, secret_key, None).await?;
        let expire_time = token.get_expire_time(&db).await?;
        let time_diff = expire_time - (now + DEFAULT_SESSION_LIFETIME);
        // Compare the magnitude: the token stores whole seconds, so the
        // difference may be slightly negative, and without `abs()` an expiry
        // that is far too early would pass unnoticed.
        assert!(time_diff.abs() < ONE_SECOND);
        Ok(())
    }

    /// A token created with an explicit lifetime expires after that lifetime.
    #[sqlx::test]
    async fn test_new_token_custom_lifetime(db: SqlitePool) -> Result<()> {
        let session_lifetime = time::Duration::hours(4);
        let secret_key = SecretKey::default();
        let now = OffsetDateTime::now_utc();
        let token = SessionToken::new(&db, secret_key, Some(session_lifetime)).await?;
        let expire_time = token.get_expire_time(&db).await?;
        let time_diff = expire_time - (now + session_lifetime);
        // See the note in the default-lifetime test: check the magnitude so an
        // expiry that is too early is caught as well.
        assert!(time_diff.abs() < ONE_SECOND);
        Ok(())
    }
}