shortener 0.1.1

A simple URL shortener.
Documentation
use std::path::Path;
use std::sync::{Mutex, MutexGuard};

use rusqlite::{Connection, OpenFlags, OptionalExtension, named_params};
use thiserror::Error;

use crate::api_key::{generate_api_key, hash_api_key};

/// Errors returned by [`Database`] operations.
#[derive(Debug, Error)]
pub enum DatabaseError {
  /// The requested row does not exist or is no longer active. Returned by the
  /// user, API-key and URL lookups/deletions when nothing matches.
  #[error("not found")]
  NotFound,
  /// Inserting a user hit a SQLite constraint violation (see
  /// `Database::create_user`).
  #[error("username already in use")]
  UsernameAlreadyInUse,
  /// Inserting a short code hit a SQLite constraint violation (see
  /// `Database::create_code`).
  #[error("code already in use")]
  CodeAlreadyInUse,
  /// Every insert attempt made by `Database::create_api_key` failed with a
  /// constraint violation, so no key could be stored.
  #[error("could not generate API key")]
  CouldNotGenerateApiKey,
  /// A base64 decoding failure, converted automatically via `?`.
  /// NOTE(review): presumably produced by `hash_api_key` — confirm against
  /// `crate::api_key`.
  #[error(transparent)]
  ApiKey(#[from] base64::DecodeError),
  /// Any other error reported by `rusqlite`, passed through unchanged.
  #[error(transparent)]
  Sqlite(#[from] rusqlite::Error),
}

/// Handle to the SQLite database holding users, API keys and short URLs.
///
/// All access goes through a single connection guarded by a mutex, which is
/// why the methods only need `&self`.
pub struct Database {
  /// The one SQLite connection; obtain it through the private `connection()`
  /// helper, which locks the mutex.
  connection: Mutex<Connection>,
}

impl Database {
  /// Opens the SQLite database at `path` and wraps the connection in a mutex.
  ///
  /// The database is always opened read-write. `create` additionally sets
  /// `SQLITE_OPEN_CREATE`, allowing SQLite to create the file when it does
  /// not exist yet.
  ///
  /// Foreign-key enforcement is enabled right after opening. SQLite treats
  /// `foreign_keys` as a per-connection setting, so it has to be applied to
  /// every connection rather than only when the schema is first created.
  ///
  /// # Errors
  ///
  /// Returns [`DatabaseError::Sqlite`] if the database cannot be opened or
  /// the pragma cannot be applied.
  pub fn new(
    path: impl AsRef<Path>,
    create: bool,
  ) -> Result<Self, DatabaseError> {
    let flags = if create {
      OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE
    } else {
      OpenFlags::SQLITE_OPEN_READ_WRITE
    };
    let connection = Connection::open_with_flags(path, flags)?;
    // Without this the `ApiKeys -> Users` reference is only checked on
    // connections that happen to have run `init()`.
    connection.execute_batch("PRAGMA foreign_keys = ON;")?;
    Ok(Self {
      connection: Mutex::new(connection),
    })
  }

  pub fn init(&self) -> Result<(), DatabaseError> {
    let connection = self.connection();
    connection.execute_batch(
      r#"
        PRAGMA foreign_keys = ON;

        CREATE TABLE IF NOT EXISTS Users(
          username TEXT PRIMARY KEY,
          active   INTEGER NOT NULL
        );

        CREATE TABLE IF NOT EXISTS Urls(
          code       TEXT PRIMARY KEY,
          url        TEXT NOT NULL,
          created_at INTEGER NOT NULL,
          created_by TEXT NOT NULL,
          hits       INTEGER NOT NULL,
          last_hit   INTEGER
        );

        CREATE TABLE IF NOT EXISTS ApiKeys(
          key_hash TEXT PRIMARY KEY,
          username TEXT NOT NULL,
          active   INTEGER NOT NULL,
          FOREIGN KEY(username) REFERENCES Users(username)
        );
      "#,
    )?;
    Ok(())
  }

  pub fn create_user(&self, username: &str) -> Result<(), DatabaseError> {
    let connection = self.connection();
    match connection.execute(
      r#"
        INSERT INTO Users(username, active)
        VALUES (:username, 1);
      "#,
      named_params! { ":username": username },
    ) {
      Ok(_) => Ok(()),
      Err(error) if is_constraint(&error) => {
        Err(DatabaseError::UsernameAlreadyInUse)
      }
      Err(error) => Err(error.into()),
    }
  }

  /// Returns the usernames of all users whose `active` flag is set.
  ///
  /// The query has no `ORDER BY`, so the order of the result is whatever
  /// SQLite yields.
  ///
  /// # Errors
  ///
  /// Returns [`DatabaseError::Sqlite`] if preparing, running or reading the
  /// query fails.
  pub fn list_users(&self) -> Result<Vec<String>, DatabaseError> {
    let connection = self.connection();
    let mut statement = connection.prepare(
      r#"
        SELECT username
        FROM Users
        WHERE active = 1;
      "#,
    )?;
    let mut usernames = Vec::new();
    for username in statement.query_map([], |row| row.get::<_, String>(0))? {
      // Stop at the first row that fails to decode.
      usernames.push(username?);
    }
    Ok(usernames)
  }

  /// Deactivates `username` together with all of that user's active API
  /// keys.
  ///
  /// Nothing is physically deleted; both tables only have their `active`
  /// flag cleared. The two updates run inside a single transaction so that a
  /// failure in the second statement cannot leave the keys deactivated while
  /// the user is still active.
  ///
  /// # Errors
  ///
  /// * [`DatabaseError::NotFound`] when no active user with that name exists.
  /// * [`DatabaseError::Sqlite`] for any database failure; the transaction is
  ///   rolled back in that case.
  pub fn delete_user(&self, username: &str) -> Result<(), DatabaseError> {
    let mut connection = self.connection();
    // A transaction that is dropped without `commit()` is rolled back, so
    // every early `?` return below undoes any partial work.
    let transaction = connection.transaction()?;

    transaction.execute(
      r#"
        UPDATE ApiKeys
        SET active = 0
        WHERE username = :username
          AND active = 1;
      "#,
      named_params! { ":username": username },
    )?;

    let rows = transaction.execute(
      r#"
        UPDATE Users
        SET active = 0
        WHERE username = :username
          AND active = 1;
      "#,
      named_params! { ":username": username },
    )?;

    // Commit before reporting `NotFound` so the key deactivation is kept even
    // when no user row matched, exactly as before this was transactional.
    transaction.commit()?;

    if rows == 0 {
      return Err(DatabaseError::NotFound);
    }
    Ok(())
  }

  pub fn create_api_key(
    &self,
    username: &str,
  ) -> Result<String, DatabaseError> {
    let connection = self.connection();
    let active = connection
      .query_row(
        r#"
          SELECT active
          FROM Users
          WHERE username = :username
            AND active = 1;
        "#,
        named_params! { ":username": username },
        |row| row.get::<_, i64>(0),
      )
      .optional()?;
    if active.is_none() {
      return Err(DatabaseError::NotFound);
    }

    for _attempt in 0..3 {
      let key = generate_api_key();
      let key_hash = hash_api_key(&key)?;
      match connection.execute(
        r#"
          INSERT INTO ApiKeys(key_hash, username, active)
          VALUES (:key_hash, :username, 1);
        "#,
        named_params! { ":key_hash": key_hash, ":username": username },
      ) {
        Ok(_) => return Ok(key),
        Err(error) if is_constraint(&error) => continue,
        Err(error) => return Err(error.into()),
      }
    }

    Err(DatabaseError::CouldNotGenerateApiKey)
  }

  /// Hashes `key` with `hash_api_key` and resolves the owning username via
  /// [`Database::check_api_key_by_hash`].
  ///
  /// # Errors
  ///
  /// Propagates a hashing failure, otherwise the same errors as
  /// [`Database::check_api_key_by_hash`].
  pub fn check_api_key(&self, key: &str) -> Result<String, DatabaseError> {
    match hash_api_key(key) {
      Ok(key_hash) => self.check_api_key_by_hash(&key_hash),
      Err(error) => Err(error.into()),
    }
  }

  /// Returns the username that owns the API key stored under `key_hash`.
  ///
  /// The key only counts as valid while both the key row and the owning user
  /// are active.
  ///
  /// # Errors
  ///
  /// * [`DatabaseError::NotFound`] when no active key of an active user
  ///   matches.
  /// * [`DatabaseError::Sqlite`] for any other database failure.
  pub fn check_api_key_by_hash(
    &self,
    key_hash: &str,
  ) -> Result<String, DatabaseError> {
    let connection = self.connection();
    let owner = connection
      .query_row(
        r#"
          SELECT ak.username
          FROM ApiKeys ak
          JOIN Users u
            ON ak.username = u.username
          WHERE ak.key_hash = :key_hash
            AND ak.active = 1
            AND u.active = 1;
        "#,
        named_params! { ":key_hash": key_hash },
        |row| row.get::<_, String>(0),
      )
      .optional()?;
    match owner {
      Some(username) => Ok(username),
      None => Err(DatabaseError::NotFound),
    }
  }

  /// Returns the stored hashes of all active API keys that belong to
  /// `username`.
  ///
  /// Only the `ApiKeys` table is consulted; whether the user itself is still
  /// active is not checked here.
  ///
  /// # Errors
  ///
  /// Returns [`DatabaseError::Sqlite`] if preparing, running or reading the
  /// query fails.
  pub fn list_api_keys(
    &self,
    username: &str,
  ) -> Result<Vec<String>, DatabaseError> {
    let connection = self.connection();
    let mut statement = connection.prepare(
      r#"
        SELECT key_hash
        FROM ApiKeys
        WHERE username = :username
          AND active = 1;
      "#,
    )?;
    let rows = statement
      .query_map(named_params! { ":username": username }, |row| {
        row.get::<_, String>(0)
      })?;
    let mut hashes = Vec::new();
    for hash in rows {
      // Stop at the first row that fails to decode.
      hashes.push(hash?);
    }
    Ok(hashes)
  }

  /// Hashes `key` with `hash_api_key` and deactivates the matching key via
  /// [`Database::delete_api_key_by_hash`].
  ///
  /// # Errors
  ///
  /// Propagates a hashing failure, otherwise the same errors as
  /// [`Database::delete_api_key_by_hash`].
  pub fn delete_api_key(&self, key: &str) -> Result<(), DatabaseError> {
    match hash_api_key(key) {
      Ok(key_hash) => self.delete_api_key_by_hash(&key_hash),
      Err(error) => Err(error.into()),
    }
  }

  /// Deactivates the API key stored under `key_hash`.
  ///
  /// The row is kept; only its `active` flag is cleared.
  ///
  /// # Errors
  ///
  /// * [`DatabaseError::NotFound`] when no active key has that hash.
  /// * [`DatabaseError::Sqlite`] for any other database failure.
  pub fn delete_api_key_by_hash(
    &self,
    key_hash: &str,
  ) -> Result<(), DatabaseError> {
    let connection = self.connection();
    let deactivated = connection.execute(
      r#"
        UPDATE ApiKeys
        SET active = 0
        WHERE key_hash = :key_hash
          AND active = 1;
      "#,
      named_params! { ":key_hash": key_hash },
    )?;
    // Zero affected rows means the key is unknown or was already inactive.
    match deactivated {
      0 => Err(DatabaseError::NotFound),
      _ => Ok(()),
    }
  }

  /// Resolves `code` to its target URL and records the visit.
  ///
  /// The hit counter and `last_hit` timestamp are bumped by the same
  /// statement that reads the URL (`UPDATE ... RETURNING`). The lookup and
  /// the bookkeeping are therefore atomic and need a single statement instead
  /// of a `SELECT` followed by a separate `UPDATE`.
  ///
  /// # Errors
  ///
  /// * [`DatabaseError::NotFound`] when no row has that code.
  /// * [`DatabaseError::Sqlite`] for any other database failure.
  pub fn get_url(&self, code: &str) -> Result<String, DatabaseError> {
    let connection = self.connection();
    let url = connection
      .query_row(
        // `code` is the primary key, so at most one row is updated and
        // returned. RETURNING needs SQLite 3.35+, which is already implied by
        // the use of UNIXEPOCH() (3.38+).
        r#"
          UPDATE Urls
          SET hits = hits + 1, last_hit = UNIXEPOCH()
          WHERE code = :code
          RETURNING url;
        "#,
        named_params! { ":code": code },
        |row| row.get::<_, String>(0),
      )
      .optional()?;
    url.ok_or(DatabaseError::NotFound)
  }

  pub fn create_code(
    &self,
    url: &str,
    code: &str,
    created_by: &str,
  ) -> Result<(), DatabaseError> {
    let connection = self.connection();
    match connection.execute(
      r#"
        INSERT INTO Urls(code, url, created_at, created_by, hits, last_hit)
        VALUES (:code, :url, UNIXEPOCH(), :created_by, 0, NULL);
      "#,
      named_params! { ":code": code, ":url": url, ":created_by": created_by },
    ) {
      Ok(_) => Ok(()),
      Err(error) if is_constraint(&error) => {
        Err(DatabaseError::CodeAlreadyInUse)
      }
      Err(error) => Err(error.into()),
    }
  }

  /// Locks the mutex and hands out the guard for the shared connection.
  ///
  /// # Panics
  ///
  /// Panics if the mutex is poisoned, i.e. a thread panicked while holding
  /// the lock.
  fn connection(&self) -> MutexGuard<'_, Connection> {
    let locked = self.connection.lock();
    locked.expect("database mutex poisoned")
  }
}

fn is_constraint(error: &rusqlite::Error) -> bool {
  matches!(
    error,
    rusqlite::Error::SqliteFailure(sqlite_error, _)
      if sqlite_error.code == rusqlite::ErrorCode::ConstraintViolation
  )
}