use std::future::Future;
use tokio_rusqlite::Connection as AsyncConnection;
use tokio_rusqlite::rusqlite;
use crate::error::DBError;
use sqlw::{FromRow, Query, RowCell, RowError, RowLike, Value, ValueRef};
/// Borrowed view of a single rusqlite row.
///
/// Used as the `RowLike` adapter handed to `FromRow` implementations. Column
/// values are exposed by reference rather than copied.
pub struct SqliteRowRef<'a> {
    row: &'a tokio_rusqlite::rusqlite::Row<'a>,
}

impl<'a> SqliteRowRef<'a> {
    /// Wraps `row` so its columns can be looked up by name.
    pub fn new(row: &'a tokio_rusqlite::rusqlite::Row<'a>) -> Self {
        SqliteRowRef { row }
    }
}
impl<'a> RowLike for SqliteRowRef<'a> {
    /// Returns the value stored in column `name` of this row as a borrowed cell.
    ///
    /// Text and blob values borrow from the underlying rusqlite row instead of
    /// being copied.
    ///
    /// # Errors
    ///
    /// - [`RowError::ColumnNotFound`] when rusqlite reports no column called `name`.
    /// - [`RowError::TypeMismatch`] when a TEXT value is not valid UTF-8.
    /// - [`RowError::Any`] carrying rusqlite's message for any other lookup failure.
    fn cell<'b>(&'b self, name: &str) -> Result<RowCell<'b>, RowError> {
        use tokio_rusqlite::rusqlite::{self, types::ValueRef as Raw};

        let raw: Raw = match self.row.get_ref(name) {
            Ok(raw) => raw,
            Err(rusqlite::Error::InvalidColumnName(column)) => {
                return Err(RowError::ColumnNotFound { name: column });
            }
            Err(other) => return Err(RowError::Any(other.to_string())),
        };

        let borrowed = match raw {
            Raw::Null => ValueRef::Null,
            Raw::Integer(int) => ValueRef::Int(int),
            Raw::Real(float) => ValueRef::Float(float),
            Raw::Text(bytes) => match std::str::from_utf8(bytes) {
                Ok(text) => ValueRef::Text(text),
                Err(_) => {
                    return Err(RowError::TypeMismatch {
                        expected: "valid UTF-8",
                        found: "invalid UTF-8".to_string(),
                    });
                }
            },
            Raw::Blob(bytes) => ValueRef::Blob(bytes),
        };

        Ok(RowCell::Borrowed(borrowed))
    }
}
pub struct SqliteExecutor(AsyncConnection);
impl SqliteExecutor {
pub async fn new<F, Fut>(connector: F) -> Result<Self, tokio_rusqlite::Error>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<AsyncConnection, tokio_rusqlite::Error>> + Send + 'static,
{
let instance = connector().await?;
Ok(SqliteExecutor(instance))
}
}
impl sqlw::QueryExecutor for SqliteExecutor {
type Error = DBError;
fn query_void(&self, query: Query) -> impl Future<Output = Result<(), DBError>> {
async move {
let (sql, args) = query.split();
let args: Vec<rusqlite::types::Value> =
args.into_iter().map(sqlw_to_rusqlite_value).collect();
self.0
.call(
move |conn: &mut rusqlite::Connection| -> Result<(), rusqlite::Error> {
conn.execute(&sql, rusqlite::params_from_iter(args.iter()))?;
Ok(())
},
)
.await
.map_err(|e: tokio_rusqlite::Error| DBError::Execution(e.into()))
}
}
fn query_one<T: FromRow + Send + 'static>(
&self,
query: Query,
) -> impl Future<Output = Result<Option<T>, DBError>> {
async move {
let (sql, args) = query.split();
let args: Vec<rusqlite::types::Value> =
args.into_iter().map(sqlw_to_rusqlite_value).collect();
self.0
.call(
move |conn: &mut rusqlite::Connection| -> Result<Option<T>, rusqlite::Error> {
let mut stmt = conn.prepare(&sql)?;
let mut rows = stmt.query(rusqlite::params_from_iter(args.iter()))?;
if let Some(row) = rows.next()? {
let row_ref = SqliteRowRef::new(row);
let t = T::from_row(&row_ref).map_err(|e| {
rusqlite::Error::SqliteFailure(
rusqlite::ffi::Error::new(1),
Some(e.to_string()),
)
})?;
Ok(Some(t))
} else {
Ok(None)
}
},
)
.await
.map_err(|e: tokio_rusqlite::Error| DBError::Execution(e.into()))
}
}
fn query_list<T: FromRow + Send + 'static>(
&self,
query: Query,
) -> impl Future<Output = Result<Vec<T>, DBError>> {
async move {
let (sql, args) = query.split();
let args: Vec<rusqlite::types::Value> =
args.into_iter().map(sqlw_to_rusqlite_value).collect();
self.0
.call(
move |conn: &mut rusqlite::Connection| -> Result<Vec<T>, rusqlite::Error> {
let mut stmt = conn.prepare(&sql)?;
let mut rows = stmt.query(rusqlite::params_from_iter(args.iter()))?;
let mut results = Vec::new();
while let Some(row) = rows.next()? {
let row_ref = SqliteRowRef::new(row);
let t = T::from_row(&row_ref).map_err(|e| {
rusqlite::Error::SqliteFailure(
rusqlite::ffi::Error::new(1),
Some(e.to_string()),
)
})?;
results.push(t);
}
Ok(results)
},
)
.await
.map_err(|e: tokio_rusqlite::Error| DBError::Execution(e.into()))
}
}
}
/// Converts an owned `sqlw` [`Value`] into the equivalent owned rusqlite value
/// for parameter binding.
///
/// `Bool` is bound as the integer `1` (`true`) or `0` (`false`) via `i64::from`.
fn sqlw_to_rusqlite_value(value: Value) -> rusqlite::types::Value {
    use rusqlite::types::Value as Bound;

    match value {
        Value::Null => Bound::Null,
        Value::Int(int) => Bound::Integer(int),
        Value::Bool(flag) => Bound::Integer(i64::from(flag)),
        Value::Float(float) => Bound::Real(float),
        Value::Text(text) => Bound::Text(text),
        Value::Blob(bytes) => Bound::Blob(bytes),
    }
}