rustio-core 0.3.1

Runtime core for RustIO: HTTP server, router, middleware, ORM, admin, and migrations.
Documentation
//! SQLite-backed ORM.
//!
//! Implement [`Model`] on your struct to get `find / all / create / update /
//! delete` for free. SQLx is used internally; user code never references it.
//!
//! Phase 1 supports `i32`, `i64`, `String`, and `bool` field types; the id
//! column is required to be `i64`.

use sqlx::sqlite::{SqlitePool, SqlitePoolOptions, SqliteRow};
use sqlx::Row as _;

use crate::error::Error;

#[derive(Clone)]
pub struct Db {
    pool: SqlitePool,
}

impl Db {
    pub async fn connect(url: &str) -> Result<Self, Error> {
        let pool = SqlitePool::connect(url).await?;
        Ok(Self { pool })
    }

    pub async fn memory() -> Result<Self, Error> {
        let pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await?;
        Ok(Self { pool })
    }

    pub async fn execute(&self, sql: &str) -> Result<(), Error> {
        sqlx::query(sql).execute(&self.pool).await?;
        Ok(())
    }

    pub(crate) fn pool(&self) -> &SqlitePool {
        &self.pool
    }
}

pub struct Row<'a> {
    inner: &'a SqliteRow,
}

impl<'a> Row<'a> {
    pub(crate) fn new(inner: &'a SqliteRow) -> Self {
        Self { inner }
    }

    pub fn get_i32(&self, name: &str) -> Result<i32, Error> {
        self.inner.try_get(name).map_err(Error::from)
    }

    pub fn get_i64(&self, name: &str) -> Result<i64, Error> {
        self.inner.try_get(name).map_err(Error::from)
    }

    pub fn get_string(&self, name: &str) -> Result<String, Error> {
        self.inner.try_get(name).map_err(Error::from)
    }

    pub fn get_bool(&self, name: &str) -> Result<bool, Error> {
        self.inner.try_get(name).map_err(Error::from)
    }
}

#[derive(Debug)]
pub enum Value {
    I32(i32),
    I64(i64),
    String(String),
    Bool(bool),
    Null,
}

impl From<i32> for Value {
    fn from(v: i32) -> Self {
        Value::I32(v)
    }
}

impl From<i64> for Value {
    fn from(v: i64) -> Self {
        Value::I64(v)
    }
}

impl From<String> for Value {
    fn from(v: String) -> Self {
        Value::String(v)
    }
}

impl From<&str> for Value {
    fn from(v: &str) -> Self {
        Value::String(v.to_owned())
    }
}

impl From<bool> for Value {
    fn from(v: bool) -> Self {
        Value::Bool(v)
    }
}

pub trait Model: Sized + Send + Sync + Unpin + 'static {
    const TABLE: &'static str;
    const COLUMNS: &'static [&'static str];
    const INSERT_COLUMNS: &'static [&'static str];

    fn id(&self) -> i64;
    fn from_row(row: Row<'_>) -> Result<Self, Error>;
    fn insert_values(&self) -> Vec<Value>;

    fn find(
        db: &Db,
        id: i64,
    ) -> impl std::future::Future<Output = Result<Option<Self>, Error>> + Send
    where
        Self: Send,
    {
        async move {
            let sql = format!(
                "SELECT {} FROM {} WHERE id = ?",
                Self::COLUMNS.join(", "),
                Self::TABLE,
            );
            let row = sqlx::query(&sql).bind(id).fetch_optional(db.pool()).await?;
            match row {
                Some(r) => Ok(Some(Self::from_row(Row::new(&r))?)),
                None => Ok(None),
            }
        }
    }

    fn all(db: &Db) -> impl std::future::Future<Output = Result<Vec<Self>, Error>> + Send {
        async move {
            let sql = format!("SELECT {} FROM {}", Self::COLUMNS.join(", "), Self::TABLE);
            let rows = sqlx::query(&sql).fetch_all(db.pool()).await?;
            rows.iter().map(|r| Self::from_row(Row::new(r))).collect()
        }
    }

    fn create<'a>(
        &'a self,
        db: &'a Db,
    ) -> impl std::future::Future<Output = Result<i64, Error>> + Send + 'a {
        async move {
            let placeholders = vec!["?"; Self::INSERT_COLUMNS.len()].join(", ");
            let sql = format!(
                "INSERT INTO {} ({}) VALUES ({})",
                Self::TABLE,
                Self::INSERT_COLUMNS.join(", "),
                placeholders,
            );
            let mut query = sqlx::query(&sql);
            for v in self.insert_values() {
                query = bind_value(query, v);
            }
            let result = query.execute(db.pool()).await?;
            Ok(result.last_insert_rowid())
        }
    }

    fn update<'a>(
        &'a self,
        db: &'a Db,
    ) -> impl std::future::Future<Output = Result<(), Error>> + Send + 'a {
        async move {
            let assignments: Vec<String> = Self::INSERT_COLUMNS
                .iter()
                .map(|c| format!("{c} = ?"))
                .collect();
            let sql = format!(
                "UPDATE {} SET {} WHERE id = ?",
                Self::TABLE,
                assignments.join(", "),
            );
            let mut query = sqlx::query(&sql);
            for v in self.insert_values() {
                query = bind_value(query, v);
            }
            query = query.bind(self.id());
            query.execute(db.pool()).await?;
            Ok(())
        }
    }

    fn delete(db: &Db, id: i64) -> impl std::future::Future<Output = Result<(), Error>> + Send {
        async move {
            let sql = format!("DELETE FROM {} WHERE id = ?", Self::TABLE);
            sqlx::query(&sql).bind(id).execute(db.pool()).await?;
            Ok(())
        }
    }
}

fn bind_value<'q>(
    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
    value: Value,
) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
    match value {
        Value::I32(v) => query.bind(v),
        Value::I64(v) => query.bind(v),
        Value::String(v) => query.bind(v),
        Value::Bool(v) => query.bind(v),
        Value::Null => query.bind(Option::<i64>::None),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, PartialEq)]
    struct User {
        id: i64,
        name: String,
        is_admin: bool,
    }

    impl Model for User {
        const TABLE: &'static str = "users";
        const COLUMNS: &'static [&'static str] = &["id", "name", "is_admin"];
        const INSERT_COLUMNS: &'static [&'static str] = &["name", "is_admin"];

        fn id(&self) -> i64 {
            self.id
        }

        fn from_row(row: Row<'_>) -> Result<Self, Error> {
            Ok(Self {
                id: row.get_i64("id")?,
                name: row.get_string("name")?,
                is_admin: row.get_bool("is_admin")?,
            })
        }

        fn insert_values(&self) -> Vec<Value> {
            vec![self.name.clone().into(), self.is_admin.into()]
        }
    }

    async fn setup() -> Db {
        let db = Db::memory().await.unwrap();
        db.execute(
            "CREATE TABLE users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                is_admin INTEGER NOT NULL
            )",
        )
        .await
        .unwrap();
        db
    }

    #[tokio::test]
    async fn create_assigns_new_id_and_find_reads_it_back() {
        let db = setup().await;
        let u = User {
            id: 0,
            name: "Alice".into(),
            is_admin: false,
        };
        let id = u.create(&db).await.unwrap();
        assert!(id >= 1);
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "Alice");
        assert!(!back.is_admin);
        assert_eq!(back.id, id);
    }

    #[tokio::test]
    async fn find_missing_returns_none() {
        let db = setup().await;
        assert!(User::find(&db, 42).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn all_returns_every_row() {
        let db = setup().await;
        User {
            id: 0,
            name: "a".into(),
            is_admin: false,
        }
        .create(&db)
        .await
        .unwrap();
        User {
            id: 0,
            name: "b".into(),
            is_admin: true,
        }
        .create(&db)
        .await
        .unwrap();
        User {
            id: 0,
            name: "c".into(),
            is_admin: false,
        }
        .create(&db)
        .await
        .unwrap();
        let rows = User::all(&db).await.unwrap();
        assert_eq!(rows.len(), 3);
        let names: Vec<&str> = rows.iter().map(|u| u.name.as_str()).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[tokio::test]
    async fn update_modifies_row_in_place() {
        let db = setup().await;
        let id = User {
            id: 0,
            name: "old".into(),
            is_admin: false,
        }
        .create(&db)
        .await
        .unwrap();
        let updated = User {
            id,
            name: "new".into(),
            is_admin: true,
        };
        updated.update(&db).await.unwrap();
        let back = User::find(&db, id).await.unwrap().unwrap();
        assert_eq!(back.name, "new");
        assert!(back.is_admin);
    }

    #[tokio::test]
    async fn delete_removes_row() {
        let db = setup().await;
        let id = User {
            id: 0,
            name: "x".into(),
            is_admin: false,
        }
        .create(&db)
        .await
        .unwrap();
        assert!(User::find(&db, id).await.unwrap().is_some());
        User::delete(&db, id).await.unwrap();
        assert!(User::find(&db, id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn row_getters_handle_wrong_type_gracefully() {
        let db = setup().await;
        User {
            id: 0,
            name: "a".into(),
            is_admin: false,
        }
        .create(&db)
        .await
        .unwrap();
        let row = sqlx::query("SELECT id, name, is_admin FROM users LIMIT 1")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let wrapped = Row::new(&row);
        assert!(wrapped.get_i64("id").is_ok());
        assert!(wrapped.get_string("nonexistent_column").is_err());
    }
}