sqlw-backend 0.1.0

Database executor implementations for sqlw
Documentation
//! SQLite executor and row types using `tokio-rusqlite`.

use std::future::Future;

use tokio_rusqlite::Connection as AsyncConnection;
use tokio_rusqlite::rusqlite;

use crate::error::DBError;
use sqlw::{FromRow, Query, RowCell, RowError, RowLike, Value, ValueRef};

/// A borrowed view of a single row from a SQLite query result.
///
/// Wraps a `rusqlite` row so it can be read through sqlw's [`RowLike`] trait. The wrapper
/// holds only a shared reference, so it is `Copy` and cheap to pass around.
#[derive(Clone, Copy)]
pub struct SqliteRowRef<'a> {
    row: &'a tokio_rusqlite::rusqlite::Row<'a>,
}

impl<'a> SqliteRowRef<'a> {
    /// Creates a new row reference from a rusqlite row.
    pub fn new(row: &'a tokio_rusqlite::rusqlite::Row<'a>) -> Self {
        Self { row }
    }
}

// Written by hand (rather than derived) so that `Debug` does not depend on the wrapped
// `rusqlite::Row` implementing `Debug`.
impl std::fmt::Debug for SqliteRowRef<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SqliteRowRef").finish_non_exhaustive()
    }
}

impl RowLike for SqliteRowRef<'_> {
    /// Looks up the column called `name` in this row and exposes its value as a borrowed
    /// [`RowCell`] without copying text or blob data.
    ///
    /// # Errors
    ///
    /// * [`RowError::ColumnNotFound`] when rusqlite reports `InvalidColumnName` for `name`.
    /// * [`RowError::TypeMismatch`] when a `TEXT` value is not valid UTF-8.
    /// * [`RowError::Any`] carrying rusqlite's message for any other lookup failure.
    fn cell<'b>(&'b self, name: &str) -> Result<RowCell<'b>, RowError> {
        use tokio_rusqlite::rusqlite::{self, types::ValueRef as SqliteValueRef};

        let raw: SqliteValueRef = match self.row.get_ref(name) {
            Ok(found) => found,
            Err(rusqlite::Error::InvalidColumnName(column)) => {
                return Err(RowError::ColumnNotFound { name: column });
            }
            Err(other) => return Err(RowError::Any(other.to_string())),
        };

        // Translate rusqlite's borrowed value into sqlw's, keeping the borrow.
        let converted = match raw {
            SqliteValueRef::Null => ValueRef::Null,
            SqliteValueRef::Integer(int) => ValueRef::Int(int),
            SqliteValueRef::Real(real) => ValueRef::Float(real),
            SqliteValueRef::Blob(bytes) => ValueRef::Blob(bytes),
            SqliteValueRef::Text(bytes) => match std::str::from_utf8(bytes) {
                Ok(text) => ValueRef::Text(text),
                Err(_) => {
                    return Err(RowError::TypeMismatch {
                        expected: "valid UTF-8",
                        found: "invalid UTF-8".to_string(),
                    });
                }
            },
        };

        Ok(RowCell::Borrowed(converted))
    }
}

/// An async SQLite executor using `tokio-rusqlite`.
///
/// Provides async database operations via the [`QueryExecutor`](sqlw::QueryExecutor) trait.
pub struct SqliteExecutor(AsyncConnection);

impl SqliteExecutor {
    /// Creates a new SQLite executor using the given connector function.
    pub async fn new<F, Fut>(connector: F) -> Result<Self, tokio_rusqlite::Error>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<AsyncConnection, tokio_rusqlite::Error>> + Send + 'static,
    {
        let instance = connector().await?;
        Ok(SqliteExecutor(instance))
    }
}

impl sqlw::QueryExecutor for SqliteExecutor {
    type Error = DBError;

    fn query_void(&self, query: Query) -> impl Future<Output = Result<(), DBError>> {
        async move {
            let (sql, args) = query.split();
            let args: Vec<rusqlite::types::Value> =
                args.into_iter().map(sqlw_to_rusqlite_value).collect();

            self.0
                .call(
                    move |conn: &mut rusqlite::Connection| -> Result<(), rusqlite::Error> {
                        conn.execute(&sql, rusqlite::params_from_iter(args.iter()))?;
                        Ok(())
                    },
                )
                .await
                .map_err(|e: tokio_rusqlite::Error| DBError::Execution(e.into()))
        }
    }

    fn query_one<T: FromRow + Send + 'static>(
        &self,
        query: Query,
    ) -> impl Future<Output = Result<Option<T>, DBError>> {
        async move {
            let (sql, args) = query.split();
            let args: Vec<rusqlite::types::Value> =
                args.into_iter().map(sqlw_to_rusqlite_value).collect();

            self.0
                .call(
                    move |conn: &mut rusqlite::Connection| -> Result<Option<T>, rusqlite::Error> {
                        let mut stmt = conn.prepare(&sql)?;
                        let mut rows = stmt.query(rusqlite::params_from_iter(args.iter()))?;

                        if let Some(row) = rows.next()? {
                            let row_ref = SqliteRowRef::new(row);
                            let t = T::from_row(&row_ref).map_err(|e| {
                                rusqlite::Error::SqliteFailure(
                                    rusqlite::ffi::Error::new(1),
                                    Some(e.to_string()),
                                )
                            })?;
                            Ok(Some(t))
                        } else {
                            Ok(None)
                        }
                    },
                )
                .await
                .map_err(|e: tokio_rusqlite::Error| DBError::Execution(e.into()))
        }
    }

    fn query_list<T: FromRow + Send + 'static>(
        &self,
        query: Query,
    ) -> impl Future<Output = Result<Vec<T>, DBError>> {
        async move {
            let (sql, args) = query.split();
            let args: Vec<rusqlite::types::Value> =
                args.into_iter().map(sqlw_to_rusqlite_value).collect();

            self.0
                .call(
                    move |conn: &mut rusqlite::Connection| -> Result<Vec<T>, rusqlite::Error> {
                        let mut stmt = conn.prepare(&sql)?;
                        let mut rows = stmt.query(rusqlite::params_from_iter(args.iter()))?;

                        let mut results = Vec::new();
                        while let Some(row) = rows.next()? {
                            let row_ref = SqliteRowRef::new(row);
                            let t = T::from_row(&row_ref).map_err(|e| {
                                rusqlite::Error::SqliteFailure(
                                    rusqlite::ffi::Error::new(1),
                                    Some(e.to_string()),
                                )
                            })?;
                            results.push(t);
                        }
                        Ok(results)
                    },
                )
                .await
                .map_err(|e: tokio_rusqlite::Error| DBError::Execution(e.into()))
        }
    }
}

/// Converts an owned sqlw [`Value`] into the matching owned rusqlite bind value.
///
/// Text and blob payloads are moved, not copied. `Bool` is bound as the integer `1` for
/// `true` or `0` for `false`.
fn sqlw_to_rusqlite_value(value: Value) -> rusqlite::types::Value {
    use rusqlite::types::Value as Sql;

    match value {
        Value::Null => Sql::Null,
        Value::Int(int) => Sql::Integer(int),
        Value::Bool(flag) => Sql::Integer(i64::from(flag)),
        Value::Float(real) => Sql::Real(real),
        Value::Text(text) => Sql::Text(text),
        Value::Blob(bytes) => Sql::Blob(bytes),
    }
}