use std::{ops::Deref, sync::Arc};
use async_trait::async_trait;
use rusqlite::{types::Value, Connection, Row, Rows, Statement};
use thiserror::Error;
use tokio::sync::Mutex;
use crate::{
query::{BooleanExpression, Expression, ExpressionType, Join, Sort, SortDirection},
Database, DatabaseError, DatabaseValue, DeleteStatement, InsertStatement, SelectQuery,
UpdateStatement, UpsertMultiStatement, UpsertStatement,
};
/// Database handle backed by a `rusqlite` [`Connection`].
///
/// The connection is shared through an `Arc` and guarded by an async `Mutex`.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct RusqliteDatabase {
// Shared connection; locked for each operation that needs it.
connection: Arc<Mutex<Connection>>,
}
impl RusqliteDatabase {
/// Creates a database handle that uses the given shared connection.
pub const fn new(connection: Arc<Mutex<Connection>>) -> Self {
Self { connection }
}
}
/// Conversion of a value into SQL text.
trait ToSql {
/// Returns the SQL text for `self`.
fn to_sql(&self) -> String;
}
impl<T: Expression + ?Sized> ToSql for T {
    /// Renders this expression as SQLite SQL text.
    ///
    /// Most [`DatabaseValue`]s are emitted as a `?` placeholder. NULL-like values and the
    /// `Now` / `NowAdd` variants are emitted inline instead, so they do not produce a
    /// placeholder.
    ///
    /// # Panics
    ///
    /// Panics when an ordering comparison (`>`, `>=`, `<`, `<=`) has a NULL right-hand side.
    #[allow(clippy::too_many_lines)]
    fn to_sql(&self) -> String {
        match self.expression_type() {
            ExpressionType::Eq(value) => {
                // NULL on the right-hand side is compared with `IS` rather than `=`.
                if value.right.is_null() {
                    format!("({} IS {})", value.left.to_sql(), value.right.to_sql())
                } else {
                    format!("({} = {})", value.left.to_sql(), value.right.to_sql())
                }
            }
            ExpressionType::Gt(value) => {
                if value.right.is_null() {
                    panic!("Invalid > comparison with NULL");
                } else {
                    format!("({} > {})", value.left.to_sql(), value.right.to_sql())
                }
            }
            ExpressionType::In(value) => {
                format!("{} IN ({})", value.left.to_sql(), value.values.to_sql())
            }
            ExpressionType::NotIn(value) => {
                format!("{} NOT IN ({})", value.left.to_sql(), value.values.to_sql())
            }
            ExpressionType::Lt(value) => {
                if value.right.is_null() {
                    panic!("Invalid < comparison with NULL");
                } else {
                    format!("({} < {})", value.left.to_sql(), value.right.to_sql())
                }
            }
            ExpressionType::Or(value) => format!(
                "({})",
                value
                    .conditions
                    .iter()
                    .map(|x| x.to_sql())
                    .collect::<Vec<_>>()
                    .join(" OR ")
            ),
            ExpressionType::And(value) => format!(
                "({})",
                value
                    .conditions
                    .iter()
                    .map(|x| x.to_sql())
                    .collect::<Vec<_>>()
                    .join(" AND ")
            ),
            ExpressionType::Gte(value) => {
                if value.right.is_null() {
                    panic!("Invalid >= comparison with NULL");
                } else {
                    format!("({} >= {})", value.left.to_sql(), value.right.to_sql())
                }
            }
            ExpressionType::Lte(value) => {
                if value.right.is_null() {
                    panic!("Invalid <= comparison with NULL");
                } else {
                    format!("({} <= {})", value.left.to_sql(), value.right.to_sql())
                }
            }
            ExpressionType::Join(value) => format!(
                "{} JOIN {} ON {}",
                if value.left { "LEFT" } else { "" },
                value.table_name,
                value.on
            ),
            ExpressionType::Sort(value) => format!(
                "({}) {}",
                value.expression.to_sql(),
                match value.direction {
                    SortDirection::Asc => "ASC",
                    SortDirection::Desc => "DESC",
                }
            ),
            ExpressionType::NotEq(value) => {
                // NULL on the right-hand side is compared with `IS NOT` rather than `!=`.
                if value.right.is_null() {
                    format!("({} IS NOT {})", value.left.to_sql(), value.right.to_sql())
                } else {
                    format!("({} != {})", value.left.to_sql(), value.right.to_sql())
                }
            }
            ExpressionType::InList(value) => value
                .values
                .iter()
                .map(|value| value.to_sql())
                .collect::<Vec<_>>()
                .join(","),
            // SQLite's IFNULL accepts exactly two arguments, while COALESCE accepts two or
            // more and is equivalent to IFNULL in the two-argument case.
            ExpressionType::Coalesce(value) => format!(
                "COALESCE({})",
                value
                    .values
                    .iter()
                    .map(|value| value.to_sql())
                    .collect::<Vec<_>>()
                    .join(",")
            ),
            ExpressionType::Literal(value) => value.value.to_string(),
            ExpressionType::Identifier(value) => value.value.clone(),
            ExpressionType::SelectQuery(value) => {
                let joins = value.joins.as_ref().map_or_else(String::new, |joins| {
                    joins.iter().map(Join::to_sql).collect::<Vec<_>>().join(" ")
                });
                let where_clause = value.filters.as_ref().map_or_else(String::new, |filters| {
                    if filters.is_empty() {
                        String::new()
                    } else {
                        format!(
                            "WHERE {}",
                            filters
                                .iter()
                                .map(|x| format!("({})", x.to_sql()))
                                .collect::<Vec<_>>()
                                .join(" AND ")
                        )
                    }
                });
                let sort_clause = value.sorts.as_ref().map_or_else(String::new, |sorts| {
                    if sorts.is_empty() {
                        String::new()
                    } else {
                        format!(
                            "ORDER BY {}",
                            sorts
                                .iter()
                                .map(Sort::to_sql)
                                .collect::<Vec<_>>()
                                .join(", ")
                        )
                    }
                });
                let limit = value
                    .limit
                    .map_or_else(String::new, |limit| format!("LIMIT {limit}"));
                format!(
                    "SELECT {} {} FROM {} {} {} {} {}",
                    if value.distinct { "DISTINCT" } else { "" },
                    value.columns.join(", "),
                    value.table_name,
                    joins,
                    where_clause,
                    sort_clause,
                    limit
                )
            }
            ExpressionType::DatabaseValue(value) => match value {
                // NULL-like values are written inline rather than as a placeholder.
                DatabaseValue::Null
                | DatabaseValue::BoolOpt(None)
                | DatabaseValue::StringOpt(None)
                | DatabaseValue::NumberOpt(None)
                | DatabaseValue::UNumberOpt(None)
                | DatabaseValue::RealOpt(None) => "NULL".to_string(),
                DatabaseValue::Now => "strftime('%Y-%m-%dT%H:%M:%f', 'now')".to_string(),
                DatabaseValue::NowAdd(ref add) => {
                    format!("strftime('%Y-%m-%dT%H:%M:%f', DateTime('now', 'LocalTime', {add}))")
                }
                // Every other value becomes a positional placeholder.
                _ => "?".to_string(),
            },
        }
    }
}
/// Errors produced by the rusqlite database backend.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Error)]
pub enum RusqliteDatabaseError {
/// An error reported by `rusqlite`, forwarded transparently.
#[error(transparent)]
Rusqlite(#[from] rusqlite::Error),
/// No ID was available.
#[error("No ID")]
NoId,
/// A statement that was expected to yield a row yielded none.
#[error("No row")]
NoRow,
/// The request was rejected as invalid (e.g. rows of a multi-row statement whose columns
/// do not match the first row).
#[error("Invalid request")]
InvalidRequest,
/// A multi-row upsert was requested without its `unique` expressions.
#[error("Missing unique")]
MissingUnique,
}
impl From<RusqliteDatabaseError> for DatabaseError {
/// Wraps the backend error in the `Rusqlite` variant of [`DatabaseError`].
fn from(value: RusqliteDatabaseError) -> Self {
Self::Rusqlite(value)
}
}
#[async_trait]
impl Database for RusqliteDatabase {
async fn query(&self, query: &SelectQuery<'_>) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(select(
&*self.connection.lock().await,
query.table_name,
query.distinct,
query.columns,
query.filters.as_deref(),
query.joins.as_deref(),
query.sorts.as_deref(),
query.limit,
)?)
}
async fn query_first(
&self,
query: &SelectQuery<'_>,
) -> Result<Option<crate::Row>, DatabaseError> {
Ok(find_row(
&*self.connection.lock().await,
query.table_name,
query.distinct,
query.columns,
query.filters.as_deref(),
query.joins.as_deref(),
query.sorts.as_deref(),
)?)
}
async fn exec_delete(
&self,
statement: &DeleteStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(delete(
&*self.connection.lock().await,
statement.table_name,
statement.filters.as_deref(),
statement.limit,
)?)
}
async fn exec_delete_first(
&self,
statement: &DeleteStatement<'_>,
) -> Result<Option<crate::Row>, DatabaseError> {
Ok(delete(
&*self.connection.lock().await,
statement.table_name,
statement.filters.as_deref(),
Some(1),
)?
.into_iter()
.next())
}
async fn exec_insert(
&self,
statement: &InsertStatement<'_>,
) -> Result<crate::Row, DatabaseError> {
Ok(insert_and_get_row(
&*self.connection.lock().await,
statement.table_name,
&statement.values,
)?)
}
async fn exec_update(
&self,
statement: &UpdateStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(update_and_get_rows(
&*self.connection.lock().await,
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)?)
}
async fn exec_update_first(
&self,
statement: &UpdateStatement<'_>,
) -> Result<Option<crate::Row>, DatabaseError> {
Ok(update_and_get_row(
&*self.connection.lock().await,
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)?)
}
async fn exec_upsert(
&self,
statement: &UpsertStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(upsert(
&*self.connection.lock().await,
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)?)
}
async fn exec_upsert_first(
&self,
statement: &UpsertStatement<'_>,
) -> Result<crate::Row, DatabaseError> {
Ok(upsert_and_get_row(
&*self.connection.lock().await,
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)?)
}
async fn exec_upsert_multi(
&self,
statement: &UpsertMultiStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(upsert_multi(
&*self.connection.lock().await,
statement.table_name,
statement
.unique
.as_ref()
.ok_or(RusqliteDatabaseError::MissingUnique)?,
&statement.values,
)?)
}
}
impl From<Value> for DatabaseValue {
    /// Converts a raw `rusqlite` value into a [`DatabaseValue`].
    ///
    /// # Panics
    ///
    /// Panics for `Value::Blob`, which is not supported.
    fn from(value: Value) -> Self {
        match value {
            Value::Null => Self::Null,
            Value::Integer(integer) => Self::Number(integer),
            Value::Real(real) => Self::Real(real),
            Value::Text(text) => Self::String(text),
            Value::Blob(_) => unimplemented!("Blob types are not supported yet"),
        }
    }
}
/// Builds a [`crate::Row`] by reading each named column from a `rusqlite` row.
///
/// Columns appear in the order given by `column_names`; the first column that cannot be
/// read aborts the conversion with its error.
fn from_row(column_names: &[String], row: &Row<'_>) -> Result<crate::Row, RusqliteDatabaseError> {
    let columns = column_names
        .iter()
        .map(|name| -> Result<_, RusqliteDatabaseError> {
            let raw: Value = row.get(name.as_str())?;
            Ok((name.clone(), raw.into()))
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(crate::Row { columns })
}
fn update_and_get_row(
connection: &Connection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<Option<crate::Row>, RusqliteDatabaseError> {
let select_query = limit.map(|_| {
format!(
"SELECT rowid FROM {table_name} {}",
build_where_clause(filters),
)
});
let query = format!(
"UPDATE {table_name} {} {} RETURNING *",
build_set_clause(values),
build_update_where_clause(filters, limit, select_query.as_deref()),
);
let all_values = values
.iter()
.flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
.map(std::convert::Into::into)
.collect::<Vec<_>>();
let mut all_filter_values = filters
.map(|filters| {
filters
.iter()
.flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
.map(std::convert::Into::into)
.collect::<Vec<_>>()
})
.unwrap_or_default();
if limit.is_some() {
all_filter_values.extend(all_filter_values.clone());
}
let all_values = [all_values, all_filter_values].concat();
log::trace!(
"Running update query: {query} with params: {:?}",
all_values
);
let mut statement = connection.prepare_cached(&query)?;
bind_values(&mut statement, Some(&all_values), false, 0)?;
let column_names = statement
.column_names()
.iter()
.map(std::string::ToString::to_string)
.collect::<Vec<_>>();
let mut query = statement.raw_query();
query
.next()?
.map(|row| from_row(&column_names, row))
.transpose()
}
fn update_and_get_rows(
connection: &Connection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
let select_query = limit.map(|_| {
format!(
"SELECT rowid FROM {table_name} {}",
build_where_clause(filters),
)
});
let query = format!(
"UPDATE {table_name} {} {} RETURNING *",
build_set_clause(values),
build_update_where_clause(filters, limit, select_query.as_deref()),
);
let all_values = values
.iter()
.flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
.map(std::convert::Into::into)
.collect::<Vec<_>>();
let mut all_filter_values = filters
.map(|filters| {
filters
.iter()
.flat_map(|value| value.params().unwrap_or_default().into_iter().cloned())
.map(std::convert::Into::into)
.collect::<Vec<_>>()
})
.unwrap_or_default();
if limit.is_some() {
all_filter_values.extend(all_filter_values.clone());
}
let all_values = [all_values, all_filter_values].concat();
log::trace!(
"Running update query: {query} with params: {:?}",
all_values
);
let mut statement = connection.prepare_cached(&query)?;
bind_values(&mut statement, Some(&all_values), false, 0)?;
let column_names = statement
.column_names()
.iter()
.map(std::string::ToString::to_string)
.collect::<Vec<_>>();
to_rows(&column_names, statement.raw_query())
}
/// Renders the optional joins as space-separated `[LEFT ]JOIN <table> ON <condition>`
/// clauses; returns an empty string when there are none.
fn build_join_clauses(joins: Option<&[Join]>) -> String {
    match joins {
        None => String::new(),
        Some(joins) => {
            let clauses: Vec<String> = joins
                .iter()
                .map(|join| {
                    let kind = if join.left { "LEFT " } else { "" };
                    format!("{kind}JOIN {} ON {}", join.table_name, join.on)
                })
                .collect();
            clauses.join(" ")
        }
    }
}
/// Renders the filters as a `WHERE a AND b ...` clause; returns an empty string when the
/// filters are absent or empty.
fn build_where_clause(filters: Option<&[Box<dyn BooleanExpression>]>) -> String {
    match filters {
        Some(filters) if !filters.is_empty() => {
            let conditions = build_where_props(filters);
            format!("WHERE {}", conditions.join(" AND "))
        }
        _ => String::new(),
    }
}
/// Renders each filter expression to its SQL text, preserving order.
fn build_where_props(filters: &[Box<dyn BooleanExpression>]) -> Vec<String> {
    let mut conditions = Vec::with_capacity(filters.len());
    for filter in filters {
        // Render the boxed expression itself rather than the box.
        conditions.push((**filter).to_sql());
    }
    conditions
}
/// Renders the sorts as an `ORDER BY a, b ...` clause; returns an empty string when the
/// sorts are absent or empty.
fn build_sort_clause(sorts: Option<&[Sort]>) -> String {
    match sorts {
        Some(sorts) if !sorts.is_empty() => {
            let terms = build_sort_props(sorts);
            format!("ORDER BY {}", terms.join(", "))
        }
        _ => String::new(),
    }
}
/// Renders each sort to its SQL text, preserving order.
fn build_sort_props(sorts: &[Sort]) -> Vec<String> {
    let mut terms = Vec::with_capacity(sorts.len());
    for sort in sorts {
        terms.push(Sort::to_sql(sort));
    }
    terms
}
/// Combines the filter `WHERE` clause with the optional rowid-limit condition.
///
/// With no limit condition the plain `WHERE` clause is returned. Otherwise the limit
/// condition is appended with `AND`, or introduced with `WHERE` when there are no filters.
fn build_update_where_clause(
    filters: Option<&[Box<dyn BooleanExpression>]>,
    limit: Option<usize>,
    query: Option<&str>,
) -> String {
    let where_clause = build_where_clause(filters);
    let limit_clause = build_update_limit_clause(limit, query);
    let prefix = match (limit_clause.is_empty(), where_clause.is_empty()) {
        (true, _) => where_clause,
        (false, true) => String::from("WHERE"),
        (false, false) => where_clause + " AND",
    };
    format!("{prefix} {limit_clause}").trim().to_string()
}
/// Builds the `rowid IN (<query> LIMIT <limit>)` condition used to cap the number of rows
/// a statement touches. Returns an empty string unless both `limit` and `query` are present.
fn build_update_limit_clause(limit: Option<usize>, query: Option<&str>) -> String {
    match (limit, query) {
        (Some(limit), Some(query)) => format!("rowid IN ({query} LIMIT {limit})"),
        _ => String::new(),
    }
}
/// Renders the assignments as a `SET a=..., b=...` clause; returns an empty string when
/// there are no values.
fn build_set_clause(values: &[(&str, Box<dyn Expression>)]) -> String {
    if values.is_empty() {
        return String::new();
    }
    let assignments = build_set_props(values);
    format!("SET {}", assignments.join(", "))
}
/// Renders each `(column, expression)` pair as `column=<sql>`, preserving order.
fn build_set_props(values: &[(&str, Box<dyn Expression>)]) -> Vec<String> {
    let mut assignments = Vec::with_capacity(values.len());
    for (name, expression) in values {
        // Render the boxed expression itself rather than the box.
        let sql = (**expression).to_sql();
        assignments.push(format!("{name}={sql}"));
    }
    assignments
}
/// Renders the insert values as `VALUES(a, b, ...)`, or `DEFAULT VALUES` when there are
/// no values.
fn build_values_clause(values: &[(&str, Box<dyn Expression>)]) -> String {
    if values.is_empty() {
        return "DEFAULT VALUES".to_string();
    }
    let rendered = build_values_props(values);
    format!("VALUES({})", rendered.join(", "))
}
/// Renders the expression of each `(column, expression)` pair to SQL, preserving order.
fn build_values_props(values: &[(&str, Box<dyn Expression>)]) -> Vec<String> {
    let mut rendered = Vec::with_capacity(values.len());
    for (_, expression) in values {
        // Render the boxed expression itself rather than the box.
        rendered.push((**expression).to_sql());
    }
    rendered
}
/// Binds `values` to the statement's positional parameters, starting at `1 + offset`.
///
/// NULL-like values, `Now` and `NowAdd` are never bound. When `constant_inc` is `false`
/// the parameter index only advances after a value was bound; when it is `true` the index
/// advances for every value, bound or not.
///
/// Returns the final parameter index minus one, or `0` when `values` is `None`.
fn bind_values(
    statement: &mut Statement<'_>,
    values: Option<&[RusqliteDatabaseValue]>,
    constant_inc: bool,
    offset: usize,
) -> Result<usize, RusqliteDatabaseError> {
    let values = match values {
        Some(values) => values,
        None => return Ok(0),
    };
    let mut index = offset + 1;
    for value in values {
        // `bound` records whether this value was bound to a parameter.
        let bound = match &**value {
            DatabaseValue::Null
            | DatabaseValue::StringOpt(None)
            | DatabaseValue::BoolOpt(None)
            | DatabaseValue::NumberOpt(None)
            | DatabaseValue::UNumberOpt(None)
            | DatabaseValue::RealOpt(None)
            | DatabaseValue::Now
            | DatabaseValue::NowAdd(_) => false,
            DatabaseValue::String(text) | DatabaseValue::StringOpt(Some(text)) => {
                statement.raw_bind_parameter(index, text)?;
                true
            }
            DatabaseValue::Bool(flag) => {
                statement.raw_bind_parameter(index, i32::from(*flag))?;
                true
            }
            DatabaseValue::BoolOpt(Some(flag)) => {
                statement.raw_bind_parameter(index, flag)?;
                true
            }
            DatabaseValue::Number(number) | DatabaseValue::NumberOpt(Some(number)) => {
                statement.raw_bind_parameter(index, *number)?;
                true
            }
            DatabaseValue::UNumber(number) | DatabaseValue::UNumberOpt(Some(number)) => {
                statement.raw_bind_parameter(index, *number)?;
                true
            }
            DatabaseValue::Real(real) | DatabaseValue::RealOpt(Some(real)) => {
                statement.raw_bind_parameter(index, *real)?;
                true
            }
            DatabaseValue::DateTime(datetime) => {
                statement.raw_bind_parameter(index, datetime.to_string())?;
                true
            }
        };
        if bound || constant_inc {
            index += 1;
        }
    }
    Ok(index - 1)
}
/// Drains a `rusqlite` result set into a vector of [`crate::Row`]s.
///
/// Stops at the first row that fails to step or convert and returns that error.
fn to_rows(
    column_names: &[String],
    mut rows: Rows<'_>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let mut collected = Vec::new();
    loop {
        match rows.next()? {
            Some(row) => collected.push(from_row(column_names, row)?),
            None => break,
        }
    }
    let count = collected.len();
    let plural = if count == 1 { "" } else { "s" };
    log::trace!("Got {} row{}", count, plural);
    Ok(collected)
}
fn to_values(values: &[(&str, DatabaseValue)]) -> Vec<RusqliteDatabaseValue> {
values
.iter()
.map(|(_key, value)| value.clone())
.map(std::convert::Into::into)
.collect::<Vec<_>>()
}
/// Collects the values reported by each expression (`Expression::values`) into one flat
/// list, skipping expressions that report none.
fn exprs_to_values(values: &[(&str, Box<dyn Expression>)]) -> Vec<RusqliteDatabaseValue> {
    values
        .iter()
        .filter_map(|(_, expression)| expression.values())
        .flatten()
        .cloned()
        .map(Into::into)
        .collect()
}
/// Collects the values reported by each boolean expression (`values()`) into one flat
/// list, skipping expressions that report none.
fn bexprs_to_values(values: &[Box<dyn BooleanExpression>]) -> Vec<RusqliteDatabaseValue> {
    values
        .iter()
        .filter_map(|expression| expression.values())
        .flatten()
        .cloned()
        .map(Into::into)
        .collect()
}
#[allow(unused)]
fn to_values_opt(values: Option<&[(&str, DatabaseValue)]>) -> Option<Vec<RusqliteDatabaseValue>> {
values.map(to_values)
}
#[allow(unused)]
fn exprs_to_values_opt(
values: Option<&[(&str, Box<dyn Expression>)]>,
) -> Option<Vec<RusqliteDatabaseValue>> {
values.map(exprs_to_values)
}
fn bexprs_to_values_opt(
values: Option<&[Box<dyn BooleanExpression>]>,
) -> Option<Vec<RusqliteDatabaseValue>> {
values.map(bexprs_to_values)
}
/// Builds and runs a `SELECT` statement and returns every resulting row.
///
/// Filter values are bound as positional parameters.
#[allow(clippy::too_many_arguments)]
fn select(
    connection: &Connection,
    table_name: &str,
    distinct: bool,
    columns: &[&str],
    filters: Option<&[Box<dyn BooleanExpression>]>,
    joins: Option<&[Join<'_>]>,
    sort: Option<&[Sort]>,
    limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let distinct_keyword = if distinct { "DISTINCT" } else { "" };
    let column_list = columns.join(", ");
    let join_clauses = build_join_clauses(joins);
    let where_clause = build_where_clause(filters);
    let sort_clause = build_sort_clause(sort);
    let limit_clause = match limit {
        Some(limit) => format!("LIMIT {limit}"),
        None => String::new(),
    };
    let query = format!(
        "SELECT {distinct_keyword} {column_list} FROM {table_name} {join_clauses} {where_clause} {sort_clause} {limit_clause}"
    );

    log::trace!(
        "Running select query: {query} with params: {:?}",
        filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
    );

    let mut statement = connection.prepare_cached(&query)?;
    let column_names: Vec<String> = statement
        .column_names()
        .into_iter()
        .map(String::from)
        .collect();
    let filter_values = bexprs_to_values_opt(filters);
    bind_values(&mut statement, filter_values.as_deref(), false, 0)?;
    let rows = statement.raw_query();
    to_rows(&column_names, rows)
}
/// Runs a `DELETE ... RETURNING *` and returns the deleted rows.
///
/// When `limit` is set, the delete is restricted through a `rowid IN (SELECT ... LIMIT n)`
/// sub-select that repeats the filters.
fn delete(
    connection: &Connection,
    table_name: &str,
    filters: Option<&[Box<dyn BooleanExpression>]>,
    limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let where_clause = build_where_clause(filters);
    // The rowid sub-select is only built for limited deletes.
    let limit_subquery = match limit {
        Some(_) => Some(format!("SELECT rowid FROM {table_name} {where_clause}")),
        None => None,
    };
    let full_where = build_update_where_clause(filters, limit, limit_subquery.as_deref());
    let query = format!("DELETE FROM {table_name} {full_where} RETURNING *");

    let mut filter_params: Vec<RusqliteDatabaseValue> = filters
        .into_iter()
        .flatten()
        .flat_map(|filter| filter.params().unwrap_or_default())
        .cloned()
        .map(Into::into)
        .collect();
    if limit.is_some() {
        // The filters occur a second time inside the rowid sub-select.
        let repeated = filter_params.clone();
        filter_params.extend(repeated);
    }

    log::trace!(
        "Running delete query: {query} with params: {:?}",
        filter_params
            .iter()
            .filter_map(super::query::Expression::params)
            .collect::<Vec<_>>()
    );

    let mut statement = connection.prepare_cached(&query)?;
    let column_names: Vec<String> = statement
        .column_names()
        .into_iter()
        .map(String::from)
        .collect();
    bind_values(&mut statement, Some(&filter_params), false, 0)?;
    let rows = statement.raw_query();
    to_rows(&column_names, rows)
}
/// Builds and runs a `SELECT ... LIMIT 1` and returns the first row, if any.
///
/// Filter values are bound as positional parameters.
fn find_row(
    connection: &Connection,
    table_name: &str,
    distinct: bool,
    columns: &[&str],
    filters: Option<&[Box<dyn BooleanExpression>]>,
    joins: Option<&[Join]>,
    sort: Option<&[Sort]>,
) -> Result<Option<crate::Row>, RusqliteDatabaseError> {
    let distinct_keyword = if distinct { "DISTINCT" } else { "" };
    let column_list = columns.join(", ");
    let join_clauses = build_join_clauses(joins);
    let where_clause = build_where_clause(filters);
    let sort_clause = build_sort_clause(sort);
    let query = format!(
        "SELECT {distinct_keyword} {column_list} FROM {table_name} {join_clauses} {where_clause} {sort_clause} LIMIT 1"
    );

    let mut statement = connection.prepare_cached(&query)?;
    let column_names: Vec<String> = statement
        .column_names()
        .into_iter()
        .map(String::from)
        .collect();
    let filter_values = bexprs_to_values_opt(filters);
    bind_values(&mut statement, filter_values.as_deref(), false, 0)?;

    log::trace!(
        "Running find_row query: {query} with params: {:?}",
        filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
    );

    let mut rows = statement.raw_query();
    match rows.next()? {
        Some(row) => from_row(&column_names, row).map(Some),
        None => Ok(None),
    }
}
/// Runs an `INSERT ... RETURNING *` and returns the inserted row.
///
/// With no values the statement inserts `DEFAULT VALUES` and has no column list. Fails
/// with [`RusqliteDatabaseError::NoRow`] when the statement yields no row.
fn insert_and_get_row(
    connection: &Connection,
    table_name: &str,
    values: &[(&str, Box<dyn Expression>)],
) -> Result<crate::Row, RusqliteDatabaseError> {
    let insert_columns = if values.is_empty() {
        String::new()
    } else {
        let quoted: Vec<String> = values.iter().map(|(key, _)| format!("`{key}`")).collect();
        format!("({})", quoted.join(", "))
    };
    let values_clause = build_values_clause(values);
    let query = format!("INSERT INTO {table_name} {insert_columns} {values_clause} RETURNING *");

    let mut statement = connection.prepare_cached(&query)?;
    let column_names: Vec<String> = statement
        .column_names()
        .into_iter()
        .map(String::from)
        .collect();
    let params = exprs_to_values(values);
    bind_values(&mut statement, Some(&params), false, 0)?;

    log::trace!(
        "Running insert_and_get_row query: {query} with params: {:?}",
        values
            .iter()
            .filter_map(|(_, x)| x.params())
            .collect::<Vec<_>>()
    );

    let mut rows = statement.raw_query();
    match rows.next()? {
        Some(row) => from_row(&column_names, row),
        None => Err(RusqliteDatabaseError::NoRow),
    }
}
/// Applies `values` (one entry per row) in chunks, delegating each chunk to
/// `update_chunk` and concatenating the returned rows.
///
/// A chunk is flushed when the accumulated number of values since the last flush, plus
/// the current row's, would reach `i16::MAX - 1`. Returns an empty vector when `values`
/// is empty.
///
/// # Errors
///
/// Propagates any error returned by `update_chunk`.
pub fn update_multi(
connection: &Connection,
table_name: &str,
values: &[Vec<(&str, Box<dyn Expression>)>],
filters: &Option<Vec<Box<dyn BooleanExpression>>>,
mut limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
let mut results = vec![];
if values.is_empty() {
return Ok(results);
}
// `pos`: values accumulated since the last flush; `i`: index of the current row;
// `last_i`: index of the first row not yet flushed.
let mut pos = 0;
let mut i = 0;
let mut last_i = i;
for value in values {
let count = value.len();
if pos + count >= (i16::MAX - 1) as usize {
// NOTE(review): if this triggers on the very first row, `values[last_i..i]` is empty
// and `update_chunk` indexes `values[0]` — confirm this cannot happen in practice.
results.append(&mut update_chunk(
connection,
table_name,
&values[last_i..i],
filters,
limit,
)?);
last_i = i;
pos = 0;
}
i += 1;
pos += count;
if let Some(value) = limit {
// NOTE(review): the remaining limit is compared with and reduced by the row's value
// count, and the early return skips the final flush below, so rows buffered since the
// last flush are not sent — confirm this is intended.
if count >= value {
return Ok(results);
}
limit.replace(value - count);
}
}
// Flush whatever is left after the loop.
if i > last_i {
results.append(&mut update_chunk(
connection,
table_name,
&values[last_i..],
filters,
limit,
)?);
}
Ok(results)
}
/// Builds and runs one statement for a chunk of rows and returns the rows it yields.
///
/// Every row must have the same columns, in the same order, as the first row; otherwise
/// the offending row is logged and `RusqliteDatabaseError::InvalidRequest` is returned.
///
/// # Panics
///
/// Panics if `values` is empty (`values[0]` is indexed unconditionally).
fn update_chunk(
connection: &Connection,
table_name: &str,
values: &[Vec<(&str, Box<dyn Expression>)>],
filters: &Option<Vec<Box<dyn BooleanExpression>>>,
limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
let first = values[0].as_slice();
let expected_value_size = first.len();
// Reject any row whose length or column names differ from the first row's.
if let Some(bad_row) = values.iter().skip(1).find(|v| {
v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
}) {
log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
return Err(RusqliteDatabaseError::InvalidRequest);
}
let set_clause = values[0]
.iter()
.map(|(name, _value)| format!("`{name}` = EXCLUDED.`{name}`"))
.collect::<Vec<_>>()
.join(", ");
let column_names = values[0]
.iter()
.map(|(key, _v)| format!("`{key}`"))
.collect::<Vec<_>>()
.join(", ");
// The rowid sub-select is only built when a limit is present.
let select_query = limit.map(|_| {
format!(
"SELECT rowid FROM {table_name} {}",
build_where_clause(filters.as_deref()),
)
});
// NOTE(review): the statement below puts a column list and the WHERE clause before SET
// and assigns from `EXCLUDED`; this does not look like SQLite's UPDATE syntax — confirm
// the statement actually prepares before relying on this function.
let query = format!(
"
UPDATE {table_name} ({column_names})
{}
SET {set_clause}
RETURNING *",
build_update_where_clause(filters.as_deref(), limit, select_query.as_deref()),
);
// Parameters of every row's expressions, in row order.
let all_values = values
.iter()
.flat_map(std::iter::IntoIterator::into_iter)
.flat_map(|(_, value)| value.params().unwrap_or(vec![]).into_iter().cloned())
.map(std::convert::Into::into)
.collect::<Vec<_>>();
let mut all_filter_values = filters
.as_ref()
.map(|filters| {
filters
.iter()
.flat_map(|value| {
value
.params()
.unwrap_or_default()
.into_iter()
.cloned()
.map(std::convert::Into::into)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
// With a limit the filters appear a second time inside the rowid sub-select.
if limit.is_some() {
all_filter_values.extend(all_filter_values.clone());
}
let all_values = [all_values, all_filter_values].concat();
log::trace!(
"Running update chunk query: {query} with params: {:?}",
all_values
);
let mut statement = connection.prepare_cached(&query)?;
let column_names = statement
.column_names()
.iter()
.map(std::string::ToString::to_string)
.collect::<Vec<_>>();
// `true`: the parameter index advances for every value, including ones that are not bound.
bind_values(&mut statement, Some(&all_values), true, 0)?;
to_rows(&column_names, statement.raw_query())
}
/// Upserts `values` (one entry per row) in chunks, delegating each chunk to
/// `upsert_chunk` and concatenating the returned rows.
///
/// A chunk is flushed when the accumulated number of values since the last flush, plus
/// the current row's, would reach `i16::MAX - 1`. Returns an empty vector when `values`
/// is empty.
///
/// # Errors
///
/// Propagates any error returned by `upsert_chunk`.
pub fn upsert_multi(
    connection: &Connection,
    table_name: &str,
    unique: &[Box<dyn Expression>],
    values: &[Vec<(&str, Box<dyn Expression>)>],
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let mut rows = vec![];
    if values.is_empty() {
        return Ok(rows);
    }
    let max_values = (i16::MAX - 1) as usize;
    // Values accumulated since the last flush, and the first row not yet flushed.
    let mut pending_values = 0;
    let mut chunk_start = 0;
    for (index, row) in values.iter().enumerate() {
        let row_values = row.len();
        if pending_values + row_values >= max_values {
            let mut chunk_rows =
                upsert_chunk(connection, table_name, unique, &values[chunk_start..index])?;
            rows.append(&mut chunk_rows);
            chunk_start = index;
            pending_values = 0;
        }
        pending_values += row_values;
    }
    // Flush whatever is left after the loop.
    if values.len() > chunk_start {
        let mut chunk_rows = upsert_chunk(connection, table_name, unique, &values[chunk_start..])?;
        rows.append(&mut chunk_rows);
    }
    Ok(rows)
}
/// Builds and runs a single multi-row
/// `INSERT ... ON CONFLICT(<unique>) DO UPDATE SET ... RETURNING *` statement and returns
/// the rows it yields.
///
/// Every row must have the same columns, in the same order, as the first row; otherwise
/// the offending row is logged and [`RusqliteDatabaseError::InvalidRequest`] is returned.
/// An empty `values` slice yields an empty result without touching the database.
fn upsert_chunk(
    connection: &Connection,
    table_name: &str,
    unique: &[Box<dyn Expression>],
    values: &[Vec<(&str, Box<dyn Expression>)>],
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    // Nothing to upsert; also guards the `values[0]` access below against a panic.
    if values.is_empty() {
        return Ok(vec![]);
    }
    let first = values[0].as_slice();
    let expected_value_size = first.len();
    // Reject any row whose length or column names differ from the first row's.
    if let Some(bad_row) = values.iter().skip(1).find(|v| {
        v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
    }) {
        log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
        return Err(RusqliteDatabaseError::InvalidRequest);
    }
    // On conflict, every inserted column is overwritten with the incoming value.
    let set_clause = first
        .iter()
        .map(|(name, _value)| format!("`{name}` = EXCLUDED.`{name}`"))
        .collect::<Vec<_>>()
        .join(", ");
    let column_names = first
        .iter()
        .map(|(key, _v)| format!("`{key}`"))
        .collect::<Vec<_>>()
        .join(", ");
    // One parenthesised tuple per row.
    let values_str = values
        .iter()
        .map(|v| format!("({})", build_values_props(v).join(", ")))
        .collect::<Vec<_>>()
        .join(", ");
    let unique_conflict = unique
        .iter()
        .map(|x| x.to_sql())
        .collect::<Vec<_>>()
        .join(", ");
    let query = format!(
        "
INSERT INTO {table_name} ({column_names}) VALUES {values_str}
ON CONFLICT({unique_conflict}) DO UPDATE
SET {set_clause}
RETURNING *"
    );
    // Parameters of every row's expressions, in row order.
    let all_values = &values
        .iter()
        .flat_map(std::iter::IntoIterator::into_iter)
        .flat_map(|(_, value)| value.params().unwrap_or_default().into_iter().cloned())
        .map(std::convert::Into::into)
        .collect::<Vec<_>>();
    log::trace!(
        "Running upsert chunk query: {query} with params: {:?}",
        all_values
    );
    let mut statement = connection.prepare_cached(&query)?;
    let column_names = statement
        .column_names()
        .iter()
        .map(std::string::ToString::to_string)
        .collect::<Vec<_>>();
    // `true`: the parameter index advances for every value, including ones that are not bound.
    bind_values(&mut statement, Some(all_values), true, 0)?;
    to_rows(&column_names, statement.raw_query())
}
/// Updates the rows matching `filters`; when the update yields no rows, inserts `values`
/// as a new row instead and returns that single row.
fn upsert(
    connection: &Connection,
    table_name: &str,
    values: &[(&str, Box<dyn Expression>)],
    filters: Option<&[Box<dyn BooleanExpression>]>,
    limit: Option<usize>,
) -> Result<Vec<crate::Row>, RusqliteDatabaseError> {
    let updated = update_and_get_rows(connection, table_name, values, filters, limit)?;
    if !updated.is_empty() {
        return Ok(updated);
    }
    let inserted = insert_and_get_row(connection, table_name, values)?;
    Ok(vec![inserted])
}
#[allow(unused)]
fn upsert_and_get_row(
connection: &Connection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<crate::Row, RusqliteDatabaseError> {
match find_row(connection, table_name, false, &["*"], filters, None, None)? {
Some(row) => {
let updated =
update_and_get_row(connection, table_name, values, filters, limit)?.unwrap();
let str1 = format!("{row:?}");
let str2 = format!("{updated:?}");
if str1 == str2 {
log::trace!("No updates to {table_name}");
} else {
log::debug!("Changed {table_name} from {str1} to {str2}");
}
Ok(updated)
}
None => Ok(insert_and_get_row(connection, table_name, values)?),
}
}
/// Newtype wrapper around [`DatabaseValue`] used when binding statement parameters.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone)]
pub struct RusqliteDatabaseValue(DatabaseValue);
impl From<DatabaseValue> for RusqliteDatabaseValue {
/// Wraps a [`DatabaseValue`] without modifying it.
fn from(value: DatabaseValue) -> Self {
Self(value)
}
}
impl Deref for RusqliteDatabaseValue {
type Target = DatabaseValue;
/// Gives access to the wrapped [`DatabaseValue`].
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Expression for RusqliteDatabaseValue {
    /// Returns the wrapped value as the only value of this expression.
    fn values(&self) -> Option<Vec<&DatabaseValue>> {
        let inner: &DatabaseValue = &self.0;
        Some(vec![inner])
    }

    /// Returns `true` for `Null` and for every optional variant holding `None`.
    fn is_null(&self) -> bool {
        let inner: &DatabaseValue = &self.0;
        matches!(
            inner,
            DatabaseValue::Null
                | DatabaseValue::BoolOpt(None)
                | DatabaseValue::RealOpt(None)
                | DatabaseValue::StringOpt(None)
                | DatabaseValue::NumberOpt(None)
                | DatabaseValue::UNumberOpt(None)
        )
    }

    /// Identifies this expression as a plain database value.
    fn expression_type(&self) -> ExpressionType {
        ExpressionType::DatabaseValue(&self.0)
    }
}