rudric 0.1.7

CLI tool for managing secrets in a secure way
Documentation
use std::{env, fmt::Display};

use anyhow::{bail, Context, Result};
use base64::{engine::general_purpose::STANDARD_NO_PAD as b64, Engine};
use orion::aead::SecretKey;
use sqlx::{sqlite::SqliteRow, FromRow, Row, SqlitePool};
use time::OffsetDateTime;
use uuid::Uuid;

use crate::crypto;

/// Error message used when a session token is too short to parse or its session key cannot
/// be loaded from the database.
const INVALID_TOKEN: &str = "Invalid session token";
/// Lifetime given to a new session token when the caller does not supply one.
const DEFAULT_SESSION_LIFETIME: time::Duration = time::Duration::hours(8);

/// A per-session symmetric key, stored in the `session_keys` table, used to encrypt the
/// user's master key inside a session token.
pub struct SessionKey {
    /// Row identifier; its 16 raw bytes are also prepended to the session token.
    id: Uuid,
    /// Key material used to encrypt and decrypt the token's timed key.
    key: SecretKey,
    /// Rows whose expiration time is earlier than "now" are removed by `delete_expired`.
    expire_time: OffsetDateTime,
}

impl FromRow<'_, SqliteRow> for SessionKey {
    /// Builds a [`SessionKey`] from a `session_keys` row with `id`, `key` and `expire_time`
    /// columns.
    ///
    /// Uses `try_get` rather than `get` so that a missing column or an undecodable value is
    /// reported as an `sqlx::Error` instead of panicking.
    fn from_row(row: &SqliteRow) -> sqlx::Result<Self> {
        let id: Uuid = row.try_get("id")?;
        let key: Vec<u8> = row.try_get("key")?;
        let expire_time: OffsetDateTime = row.try_get("expire_time")?;

        // Stored bytes that are not acceptable as a `SecretKey` are a decode error.
        let key = SecretKey::from_slice(&key).map_err(|e| sqlx::Error::Decode(Box::new(e)))?;

        Ok(Self {
            id,
            key,
            expire_time,
        })
    }
}

impl SessionKey {
    fn new(expire_time: &OffsetDateTime) -> Self {
        Self {
            id: uuid::Uuid::new_v4(),
            key: SecretKey::default(),
            expire_time: *expire_time,
        }
    }

    pub async fn get(db: &SqlitePool, id: &Uuid) -> Result<Self> {
        struct SessionKeyDB {
            id: Uuid,
            key: Vec<u8>,
            expire_time: OffsetDateTime,
        }

        let session = sqlx::query_as!(
            SessionKeyDB,
            r#"select id as "id: _", key, expire_time from session_keys where id = ?"#,
            id,
        )
        .fetch_one(db)
        .await?;

        Ok(Self {
            id: session.id,
            key: SecretKey::from_slice(&session.key)?,
            expire_time: session.expire_time,
        })
    }

    async fn insert(&self, db: &SqlitePool) -> Result<()> {
        let key = self.key.unprotected_as_bytes();

        sqlx::query!(
            "insert into session_keys (id, key, expire_time) values (?, ?, ?)",
            self.id,
            key,
            self.expire_time
        )
        .execute(db)
        .await
        .context("Failed to insert session key")?;

        Ok(())
    }

    pub async fn delete(&self, db: &SqlitePool) -> Result<()> {
        sqlx::query!("delete from session_keys where id = ?", self.id)
            .execute(db)
            .await
            .context("Failed to delete session key")?;

        Ok(())
    }

    pub async fn delete_expired(db: &SqlitePool) -> Result<()> {
        let now = OffsetDateTime::now_utc();

        sqlx::query!("delete from session_keys where expire_time < ?", now)
            .execute(db)
            .await
            .context("Failed to delete expired session key")?;

        Ok(())
    }
}

/// A session token as handed to the user: the session key ID followed by the encrypted
/// timed key, base64 encoded with the standard alphabet and no padding. It can be read back
/// from the `RUDRIC_SESSION` environment variable.
pub struct SessionToken(String);

impl SessionToken {
    pub fn from_env() -> Result<Self> {
        Ok(Self(env::var("RUDRIC_SESSION")?))
    }

    /// Generate a new session token by establishing a token expiration time, generating a
    /// session key and encrypting the user's master key with it. A session token is
    /// generated by concatenating the expiration time (as bytes) with the encrypted master key
    /// (as bytes). This is base64 encoded and returned to the user as a session token.
    pub async fn new(
        db: &SqlitePool,
        master_key: SecretKey,
        lifetime: Option<time::Duration>,
    ) -> Result<Self> {
        // Convert from std::time::Duration to time::Duration
        // let lifetime: Option<time::Duration> =
        //     lifetime.map(|d| time::Duration::seconds_f64(d.as_secs_f64()));

        let expire_time = OffsetDateTime::now_utc() + lifetime.unwrap_or(DEFAULT_SESSION_LIFETIME);

        let session_key = SessionKey::new(&expire_time);
        session_key.insert(db).await?;

        // The timed key is a [u8] where the first 8 bytes are the expiration time as a
        // unix timestamp in the form of a big endian byte slice. The remaining bytes are the user's
        // master key.
        let timed_key = [
            &expire_time.unix_timestamp().to_be_bytes(),
            master_key.unprotected_as_bytes(),
        ]
        .concat();

        let encrypted_timed_key = crypto::encrypt(&session_key.key, &timed_key)?;

        // The session key ID is prepended to the encrypted timed key.
        let session_token = [session_key.id.as_bytes(), encrypted_timed_key.as_slice()].concat();

        if let Err(e) = SessionKey::delete_expired(db).await {
            eprintln!("Error deleting expired session tokens: {e}");
        }

        Ok(Self(b64.encode(session_token)))
    }

    /// Returns the expiration time of the token
    pub async fn get_expire_time(&self, db: &SqlitePool) -> Result<OffsetDateTime> {
        let (_, decrypted_timed_key) = self.decrypt_timed_key(db).await?;

        let (expire_time, _) = split_timed_key(&decrypted_timed_key)?;

        Ok(expire_time)
    }

    /// Gets the user's master key from the session token. First, it splits off the first 16 bytes
    /// of the session token to get the ID of the session key. Then fetches the session key from
    /// the database using this ID. This session key is used to decrypt the timed key. The first 8
    /// bytes from the timed key are split off and converted to the expiration time of the token.
    /// If the token is not expired, the decrypted master key is returned.
    ///
    /// Additionally, any expired session keys in the database are also deleted.
    pub async fn into_master_key(self, db: &SqlitePool) -> Result<SecretKey> {
        let (session_key, decrypted_timed_key) = self.decrypt_timed_key(db).await?;

        let (expire_time, secret_key) = split_timed_key(&decrypted_timed_key)?;

        if expire_time < OffsetDateTime::now_utc() {
            session_key.delete(db).await?;
            bail!("Session key has expired");
        }

        if let Err(e) = SessionKey::delete_expired(db).await {
            eprintln!("Error deleting expired session tokens: {e}");
        }

        Ok(secret_key)
    }

    async fn decrypt_timed_key(&self, db: &SqlitePool) -> Result<(SessionKey, Vec<u8>)> {
        let (session_id, encrypted_timed_key) = self.split_id()?;

        // Fetch the session key from the database
        let session_key = match SessionKey::get(db, &session_id).await {
            Ok(s) => s,
            Err(_) => bail!(INVALID_TOKEN),
        };

        let decrypted_timed_key = crypto::decrypt(&session_key.key, &encrypted_timed_key)?;

        Ok((session_key, decrypted_timed_key))
    }

    pub fn split_id(&self) -> Result<(Uuid, Vec<u8>)> {
        let session_token_bytes = b64.decode(self.0.clone())?;
        if session_token_bytes.len() < 16 {
            bail!(INVALID_TOKEN)
        }
        let (session_id, encrypted_timed_key) = session_token_bytes.split_at(16);
        let session_id = Uuid::from_bytes(session_id.try_into()?);
        Ok((session_id, encrypted_timed_key.to_vec()))
    }
}

impl Display for SessionToken {
    /// Writes the encoded token string exactly as stored, ignoring width/fill flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let encoded: &str = &self.0;
        f.write_str(encoded)
    }
}

/// Splits a decrypted timed key into its expiration time and the user's master key.
///
/// The first 8 bytes hold the expiration time as a big-endian unix timestamp in seconds.
/// Every byte after that is master-key material.
///
/// # Errors
/// Fails with `INVALID_TOKEN` if fewer than 8 bytes are given. Also propagates the error
/// if the timestamp is out of range or the remaining bytes are rejected by
/// `SecretKey::from_slice`.
fn split_timed_key(timed_key: &[u8]) -> Result<(OffsetDateTime, SecretKey)> {
    const TIMESTAMP_LEN: usize = std::mem::size_of::<i64>();

    if timed_key.len() < TIMESTAMP_LEN {
        bail!(INVALID_TOKEN)
    }

    let mut timestamp_be = [0u8; TIMESTAMP_LEN];
    timestamp_be.copy_from_slice(&timed_key[..TIMESTAMP_LEN]);
    let expire_time = OffsetDateTime::from_unix_timestamp(i64::from_be_bytes(timestamp_be))?;

    let master_key = SecretKey::from_slice(&timed_key[TIMESTAMP_LEN..])?;

    Ok((expire_time, master_key))
}

#[cfg(test)]
mod session_tests {
    use super::*;
    use anyhow::Result;
    use sqlx::SqlitePool;

    /// Allowed deviation between the actual and expected expiration time. The timestamp
    /// embedded in a token is truncated to whole seconds, and a little time passes between
    /// capturing `now` in a test and creating the token.
    const ONE_SECOND: std::time::Duration = std::time::Duration::from_secs(1);

    /// Asserts that `actual` is within [`ONE_SECOND`] of `expected`, in either direction.
    ///
    /// A one-sided `diff < ONE_SECOND` check would also accept expiration times that are far
    /// too early, e.g. a token whose lifetime was ignored entirely.
    fn assert_close(actual: OffsetDateTime, expected: OffsetDateTime) {
        let diff = actual - expected;
        assert!(
            diff.abs() < ONE_SECOND,
            "expire time {actual} is {diff} away from expected {expected}"
        );
    }

    #[sqlx::test]
    async fn test_new_token_default_lifetime(db: SqlitePool) -> Result<()> {
        let secret_key = SecretKey::default();
        let now = OffsetDateTime::now_utc();
        let token = SessionToken::new(&db, secret_key, None).await?;

        let expire_time = token.get_expire_time(&db).await?;

        assert_close(expire_time, now + DEFAULT_SESSION_LIFETIME);

        Ok(())
    }

    #[sqlx::test]
    async fn test_new_token_custom_lifetime(db: SqlitePool) -> Result<()> {
        let session_lifetime = time::Duration::hours(4);
        let secret_key = SecretKey::default();
        let now = OffsetDateTime::now_utc();
        let token = SessionToken::new(&db, secret_key, Some(session_lifetime)).await?;

        let expire_time = token.get_expire_time(&db).await?;

        assert_close(expire_time, now + session_lifetime);

        Ok(())
    }

    #[sqlx::test]
    async fn test_token_round_trips_master_key(db: SqlitePool) -> Result<()> {
        let master_key = SecretKey::default();
        let expected = master_key.unprotected_as_bytes().to_vec();

        let token = SessionToken::new(&db, master_key, None).await?;
        let recovered = token.into_master_key(&db).await?;

        assert_eq!(recovered.unprotected_as_bytes(), expected.as_slice());

        Ok(())
    }

    #[sqlx::test]
    async fn test_expired_token_is_rejected(db: SqlitePool) -> Result<()> {
        let lifetime = time::Duration::seconds(-60);
        let token = SessionToken::new(&db, SecretKey::default(), Some(lifetime)).await?;

        assert!(token.into_master_key(&db).await.is_err());

        Ok(())
    }

    #[test]
    fn test_split_id_round_trip() -> Result<()> {
        let id = Uuid::new_v4();
        let payload = [1u8, 2, 3];
        let raw = [id.as_bytes().as_slice(), payload.as_slice()].concat();
        let token = SessionToken(b64.encode(raw));

        let (parsed_id, rest) = token.split_id()?;

        assert_eq!(parsed_id, id);
        assert_eq!(rest, payload);

        Ok(())
    }

    #[test]
    fn test_split_id_rejects_malformed_tokens() {
        assert!(SessionToken("not base64!".to_string()).split_id().is_err());
        assert!(SessionToken(b64.encode([0u8; 8])).split_id().is_err());
    }
}