use std::{
ops::Deref,
pin::Pin,
sync::{atomic::AtomicU16, Arc, LazyLock},
};
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use sqlx::{
pool::PoolConnection,
postgres::{PgArguments, PgRow, PgValueRef},
query::Query,
Column, Executor, PgPool, Postgres, Row, Statement, TypeInfo, Value, ValueRef,
};
use sqlx_postgres::PgConnection;
use thiserror::Error;
use tokio::sync::Mutex;
use crate::{
query::{BooleanExpression, Expression, ExpressionType, Join, Sort, SortDirection},
Database, DatabaseError, DatabaseValue, DeleteStatement, InsertStatement, SelectQuery,
UpdateStatement, UpsertMultiStatement, UpsertStatement,
};
trait ToSql {
fn to_sql(&self, index: &AtomicU16) -> String;
}
impl<T: Expression + ?Sized> ToSql for T {
#[allow(clippy::too_many_lines)]
fn to_sql(&self, index: &AtomicU16) -> String {
match self.expression_type() {
ExpressionType::Eq(value) => {
if value.right.is_null() {
format!(
"({} IS {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
} else {
format!(
"({} = {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
}
}
ExpressionType::Gt(value) => {
if value.right.is_null() {
panic!("Invalid > comparison with NULL");
} else {
format!(
"({} > {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
}
}
ExpressionType::In(value) => {
format!(
"{} IN ({})",
value.left.to_sql(index),
value.values.to_sql(index)
)
}
ExpressionType::NotIn(value) => {
format!(
"{} NOT IN ({})",
value.left.to_sql(index),
value.values.to_sql(index)
)
}
ExpressionType::Lt(value) => {
if value.right.is_null() {
panic!("Invalid < comparison with NULL");
} else {
format!(
"({} < {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
}
}
ExpressionType::Or(value) => format!(
"({})",
value
.conditions
.iter()
.map(|x| x.to_sql(index))
.collect::<Vec<_>>()
.join(" OR ")
),
ExpressionType::And(value) => format!(
"({})",
value
.conditions
.iter()
.map(|x| x.to_sql(index))
.collect::<Vec<_>>()
.join(" AND ")
),
ExpressionType::Gte(value) => {
if value.right.is_null() {
panic!("Invalid >= comparison with NULL");
} else {
format!(
"({} >= {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
}
}
ExpressionType::Lte(value) => {
if value.right.is_null() {
panic!("Invalid <= comparison with NULL");
} else {
format!(
"({} <= {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
}
}
ExpressionType::Join(value) => format!(
"{} JOIN {} ON {}",
if value.left { "LEFT" } else { "" },
value.table_name,
value.on
),
ExpressionType::Sort(value) => format!(
"({}) {}",
value.expression.to_sql(index),
match value.direction {
SortDirection::Asc => "ASC",
SortDirection::Desc => "DESC",
}
),
ExpressionType::NotEq(value) => {
if value.right.is_null() {
format!(
"({} IS NOT {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
} else {
format!(
"({} != {})",
value.left.to_sql(index),
value.right.to_sql(index)
)
}
}
ExpressionType::InList(value) => value
.values
.iter()
.map(|value| value.to_sql(index))
.collect::<Vec<_>>()
.join(","),
ExpressionType::Coalesce(value) => format!(
"COALESCE({})",
value
.values
.iter()
.map(|value| value.to_sql(index))
.collect::<Vec<_>>()
.join(",")
),
ExpressionType::Literal(value) => value.value.to_string(),
ExpressionType::Identifier(value) => format_identifier(&value.value),
ExpressionType::SelectQuery(value) => {
let joins = value.joins.as_ref().map_or_else(String::new, |joins| {
joins
.iter()
.map(|x| x.to_sql(index))
.collect::<Vec<_>>()
.join(" ")
});
let where_clause = value.filters.as_ref().map_or_else(String::new, |filters| {
if filters.is_empty() {
String::new()
} else {
format!(
"WHERE {}",
filters
.iter()
.map(|x| format!("({})", x.to_sql(index)))
.collect::<Vec<_>>()
.join(" AND ")
)
}
});
let sort_clause = value.sorts.as_ref().map_or_else(String::new, |sorts| {
if sorts.is_empty() {
String::new()
} else {
format!(
"ORDER BY {}",
sorts
.iter()
.map(|x| x.to_sql(index))
.collect::<Vec<_>>()
.join(", ")
)
}
});
let limit = value
.limit
.map_or_else(String::new, |limit| format!("LIMIT {limit}"));
format!(
"SELECT {} {} FROM {} {} {} {} {}",
if value.distinct { "DISTINCT" } else { "" },
value
.columns
.iter()
.map(|x| format_identifier(x))
.collect::<Vec<_>>()
.join(", "),
value.table_name,
joins,
where_clause,
sort_clause,
limit
)
}
ExpressionType::DatabaseValue(value) => match value {
DatabaseValue::Null
| DatabaseValue::BoolOpt(None)
| DatabaseValue::StringOpt(None)
| DatabaseValue::NumberOpt(None)
| DatabaseValue::UNumberOpt(None)
| DatabaseValue::RealOpt(None) => "NULL".to_string(),
DatabaseValue::Now => "NOW()".to_string(),
DatabaseValue::NowAdd(ref add) => format!("NOW() + {add}"),
_ => {
let pos = index.fetch_add(1, std::sync::atomic::Ordering::SeqCst) + 1;
format!("${pos}")
}
},
}
}
}
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct PostgresSqlxDatabase {
pool: Arc<Mutex<PgPool>>,
#[allow(clippy::type_complexity)]
connection: Arc<Mutex<Option<Arc<Mutex<PoolConnection<Postgres>>>>>>,
}
impl PostgresSqlxDatabase {
pub fn new(pool: Arc<Mutex<PgPool>>) -> Self {
Self {
pool,
connection: Arc::new(Mutex::new(None)),
}
}
pub async fn get_connection(
&self,
) -> Result<Arc<Mutex<PoolConnection<Postgres>>>, SqlxDatabaseError> {
let connection = { self.connection.lock().await.clone() };
if let Some(connection) = connection {
log::trace!("Returning existing connection from postgres db pool");
return Ok(connection);
}
log::debug!("Fetching new connection from postgres db pool");
let connection = Arc::new(Mutex::new(self.pool.lock().await.acquire().await?));
self.connection.lock().await.replace(connection.clone());
Ok(connection)
}
}
#[derive(Debug, Error)]
pub enum SqlxDatabaseError {
#[error(transparent)]
Sqlx(#[from] sqlx::Error),
#[error("No ID")]
NoId,
#[error("No row")]
NoRow,
#[error("Invalid request")]
InvalidRequest,
#[error("Missing unique")]
MissingUnique,
}
impl From<SqlxDatabaseError> for DatabaseError {
fn from(value: SqlxDatabaseError) -> Self {
Self::PostgresSqlx(value)
}
}
#[async_trait]
impl Database for PostgresSqlxDatabase {
async fn query(&self, query: &SelectQuery<'_>) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(select(
self.get_connection().await?.lock().await.as_mut(),
query.table_name,
query.distinct,
query.columns,
query.filters.as_deref(),
query.joins.as_deref(),
query.sorts.as_deref(),
query.limit,
)
.await?)
}
async fn query_first(
&self,
query: &SelectQuery<'_>,
) -> Result<Option<crate::Row>, DatabaseError> {
Ok(find_row(
self.get_connection().await?.lock().await.as_mut(),
query.table_name,
query.distinct,
query.columns,
query.filters.as_deref(),
query.joins.as_deref(),
query.sorts.as_deref(),
)
.await?)
}
async fn exec_delete(
&self,
statement: &DeleteStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(delete(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
statement.filters.as_deref(),
statement.limit,
)
.await?)
}
async fn exec_delete_first(
&self,
statement: &DeleteStatement<'_>,
) -> Result<Option<crate::Row>, DatabaseError> {
Ok(delete(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
statement.filters.as_deref(),
Some(1),
)
.await?
.into_iter()
.next())
}
async fn exec_insert(
&self,
statement: &InsertStatement<'_>,
) -> Result<crate::Row, DatabaseError> {
Ok(insert_and_get_row(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
&statement.values,
)
.await?)
}
async fn exec_update(
&self,
statement: &UpdateStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(update_and_get_rows(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)
.await?)
}
async fn exec_update_first(
&self,
statement: &UpdateStatement<'_>,
) -> Result<Option<crate::Row>, DatabaseError> {
Ok(update_and_get_row(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)
.await?)
}
async fn exec_upsert(
&self,
statement: &UpsertStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
Ok(upsert(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)
.await?)
}
async fn exec_upsert_first(
&self,
statement: &UpsertStatement<'_>,
) -> Result<crate::Row, DatabaseError> {
Ok(upsert_and_get_row(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
&statement.values,
statement.filters.as_deref(),
statement.limit,
)
.await?)
}
async fn exec_upsert_multi(
&self,
statement: &UpsertMultiStatement<'_>,
) -> Result<Vec<crate::Row>, DatabaseError> {
let rows = {
upsert_multi(
self.get_connection().await?.lock().await.as_mut(),
statement.table_name,
statement
.unique
.as_ref()
.ok_or(SqlxDatabaseError::MissingUnique)?,
&statement.values,
)
.await?
};
Ok(rows)
}
}
fn column_value(value: &PgValueRef<'_>) -> Result<DatabaseValue, sqlx::Error> {
if value.is_null() {
return Ok(DatabaseValue::Null);
}
let owned = sqlx::ValueRef::to_owned(value);
match value.type_info().name() {
"BOOL" => Ok(DatabaseValue::Bool(owned.try_decode()?)),
"CHAR" | "SMALLINT" | "SMALLSERIAL" | "INT2" | "INT" | "SERIAL" | "INT4" | "BIGINT"
| "BIGSERIAL" | "INT8" => Ok(DatabaseValue::Number(owned.try_decode()?)),
"REAL" | "FLOAT4" | "DOUBLE PRECISION" | "FLOAT8" => {
Ok(DatabaseValue::Real(owned.try_decode()?))
}
"VARCHAR" | "CHAR(N)" | "TEXT" | "NAME" | "CITEXT" => {
Ok(DatabaseValue::String(owned.try_decode()?))
}
"TIMESTAMP" => Ok(DatabaseValue::DateTime(owned.try_decode()?)),
_ => Err(sqlx::Error::TypeNotFound {
type_name: value.type_info().name().to_string(),
}),
}
}
fn from_row(column_names: &[String], row: &PgRow) -> Result<crate::Row, SqlxDatabaseError> {
let mut columns = vec![];
for column in column_names {
columns.push((
column.to_string(),
column_value(&row.try_get_raw(column.as_str())?)?,
));
}
Ok(crate::Row { columns })
}
async fn update_and_get_row(
connection: &mut PgConnection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<Option<crate::Row>, SqlxDatabaseError> {
let index = AtomicU16::new(0);
let query = format!(
"UPDATE {table_name} {} {} RETURNING *",
build_set_clause(values, &index),
build_update_where_clause(
filters,
limit,
limit
.map(|_| {
format!(
"SELECT CTID FROM {table_name} {}",
build_where_clause(filters, &index),
)
})
.as_deref(),
&index
),
);
let all_values = values
.iter()
.flat_map(|(_, value)| {
value
.params()
.unwrap_or(vec![])
.into_iter()
.cloned()
.map(std::convert::Into::into)
})
.collect::<Vec<_>>();
let mut all_filter_values = filters
.map(|filters| {
filters
.iter()
.flat_map(|value| {
value
.params()
.unwrap_or_default()
.into_iter()
.cloned()
.map(std::convert::Into::into)
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if limit.is_some() {
all_filter_values.extend(all_filter_values.clone());
}
let all_values = [all_values, all_filter_values].concat();
log::trace!(
"Running update query: {query} with params: {:?}",
all_values
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let query = bind_values(statement.query(), Some(&all_values))?;
let mut stream = query.fetch(connection);
let pg_row: Option<PgRow> = stream.next().await.transpose()?;
pg_row.map(|row| from_row(&column_names, &row)).transpose()
}
async fn update_and_get_rows(
connection: &mut PgConnection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let index = AtomicU16::new(0);
let query = format!(
"UPDATE {table_name} {} {} RETURNING *",
build_set_clause(values, &index),
build_update_where_clause(
filters,
limit,
limit
.map(|_| {
format!(
"SELECT CTID FROM {table_name} {}",
build_where_clause(filters, &index),
)
})
.as_deref(),
&index
),
);
let all_values = values
.iter()
.flat_map(|(_, value)| {
value
.params()
.unwrap_or(vec![])
.into_iter()
.cloned()
.map(std::convert::Into::into)
})
.collect::<Vec<PgDatabaseValue>>();
let mut all_filter_values = filters
.map(|filters| {
filters
.iter()
.flat_map(|value| {
value
.params()
.unwrap_or_default()
.into_iter()
.cloned()
.map(std::convert::Into::into)
})
.collect::<Vec<PgDatabaseValue>>()
})
.unwrap_or_default();
if limit.is_some() {
all_filter_values.extend(all_filter_values.clone());
}
let all_values = [all_values, all_filter_values].concat();
log::trace!(
"Running update query: {query} with params: {:?}",
all_values
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let query = bind_values(statement.query(), Some(&all_values))?;
to_rows(&column_names, query.fetch(connection)).await
}
fn build_join_clauses(joins: Option<&[Join]>) -> String {
joins.map_or_else(String::new, |joins| {
joins
.iter()
.map(|join| {
format!(
"{}JOIN {} ON {}",
if join.left { "LEFT " } else { "" },
join.table_name,
join.on
)
})
.collect::<Vec<_>>()
.join(" ")
})
}
fn build_where_clause(filters: Option<&[Box<dyn BooleanExpression>]>, index: &AtomicU16) -> String {
filters.map_or_else(String::new, |filters| {
if filters.is_empty() {
String::new()
} else {
let filters = build_where_props(filters, index);
format!("WHERE {}", filters.join(" AND "))
}
})
}
fn build_where_props(filters: &[Box<dyn BooleanExpression>], index: &AtomicU16) -> Vec<String> {
filters
.iter()
.map(|filter| filter.deref().to_sql(index))
.collect()
}
fn build_sort_clause(sorts: Option<&[Sort]>, index: &AtomicU16) -> String {
sorts.map_or_else(String::new, |sorts| {
if sorts.is_empty() {
String::new()
} else {
format!("ORDER BY {}", build_sort_props(sorts, index).join(", "))
}
})
}
fn build_sort_props(sorts: &[Sort], index: &AtomicU16) -> Vec<String> {
sorts.iter().map(|sort| sort.to_sql(index)).collect()
}
fn build_update_where_clause(
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
query: Option<&str>,
index: &AtomicU16,
) -> String {
let clause = build_where_clause(filters, index);
let limit_clause = build_update_limit_clause(limit, query);
let clause = if limit_clause.is_empty() {
clause
} else if clause.is_empty() {
"WHERE".into()
} else {
clause + " AND"
};
format!("{clause} {limit_clause}").trim().to_string()
}
fn build_update_limit_clause(limit: Option<usize>, query: Option<&str>) -> String {
limit.map_or_else(String::new, |limit| {
query.map_or_else(String::new, |query| {
format!("CTID IN ({query} LIMIT {limit})")
})
})
}
fn build_set_clause(values: &[(&str, Box<dyn Expression>)], index: &AtomicU16) -> String {
if values.is_empty() {
String::new()
} else {
format!("SET {}", build_set_props(values, index).join(", "))
}
}
fn build_set_props(values: &[(&str, Box<dyn Expression>)], index: &AtomicU16) -> Vec<String> {
values
.iter()
.map(|(name, value)| {
format!(
"{}={}",
format_identifier(name),
value.deref().to_sql(index)
)
})
.collect()
}
fn build_values_clause(values: &[(&str, Box<dyn Expression>)], index: &AtomicU16) -> String {
if values.is_empty() {
"DEFAULT VALUES".to_string()
} else {
let filters = build_values_props(values, index).join(", ");
format!("VALUES({filters})")
}
}
fn build_values_props(values: &[(&str, Box<dyn Expression>)], index: &AtomicU16) -> Vec<String> {
values
.iter()
.map(|(_, value)| value.deref().to_sql(index))
.collect()
}
fn format_identifier(identifier: &str) -> String {
static NON_ALPHA_NUMERIC_REGEX: LazyLock<regex::Regex> =
LazyLock::new(|| regex::Regex::new(r"[^A-Za-z0-9_]").expect("Invalid Regex"));
if NON_ALPHA_NUMERIC_REGEX.is_match(identifier) {
identifier.to_string()
} else {
format!("\"{identifier}\"")
}
}
fn bind_values<'a, 'b>(
mut query: Query<'a, Postgres, PgArguments>,
values: Option<&'b [PgDatabaseValue]>,
) -> Result<Query<'a, Postgres, PgArguments>, SqlxDatabaseError>
where
'b: 'a,
{
if let Some(values) = values {
for value in values {
match &**value {
DatabaseValue::String(value) | DatabaseValue::StringOpt(Some(value)) => {
query = query.bind(value);
}
DatabaseValue::Null
| DatabaseValue::StringOpt(None)
| DatabaseValue::BoolOpt(None)
| DatabaseValue::NumberOpt(None)
| DatabaseValue::UNumberOpt(None)
| DatabaseValue::RealOpt(None)
| DatabaseValue::Now => (),
DatabaseValue::Bool(value) | DatabaseValue::BoolOpt(Some(value)) => {
query = query.bind(value);
}
DatabaseValue::Number(value) | DatabaseValue::NumberOpt(Some(value)) => {
query = query.bind(*value);
}
DatabaseValue::UNumber(value) | DatabaseValue::UNumberOpt(Some(value)) => {
query = query.bind(
i64::try_from(*value).map_err(|_| SqlxDatabaseError::InvalidRequest)?,
);
}
DatabaseValue::Real(value) | DatabaseValue::RealOpt(Some(value)) => {
query = query.bind(*value);
}
DatabaseValue::NowAdd(_add) => (),
DatabaseValue::DateTime(value) => {
query = query.bind(value);
}
}
}
}
Ok(query)
}
async fn to_rows<'a>(
column_names: &[String],
mut rows: Pin<Box<(dyn Stream<Item = Result<PgRow, sqlx::Error>> + Send + 'a)>>,
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let mut results = vec![];
while let Some(row) = rows.next().await.transpose()? {
results.push(from_row(column_names, &row)?);
}
log::trace!(
"Got {} row{}",
results.len(),
if results.len() == 1 { "" } else { "s" }
);
Ok(results)
}
fn to_values(values: &[(&str, DatabaseValue)]) -> Vec<PgDatabaseValue> {
values
.iter()
.map(|(_key, value)| value.clone().into())
.collect::<Vec<_>>()
}
fn exprs_to_values(values: &[(&str, Box<dyn Expression>)]) -> Vec<PgDatabaseValue> {
values
.iter()
.flat_map(|value| value.1.values().into_iter())
.flatten()
.cloned()
.map(std::convert::Into::into)
.collect::<Vec<_>>()
}
fn bexprs_to_values(values: &[Box<dyn BooleanExpression>]) -> Vec<PgDatabaseValue> {
values
.iter()
.flat_map(|value| value.values().into_iter())
.flatten()
.cloned()
.map(std::convert::Into::into)
.collect::<Vec<_>>()
}
#[allow(unused)]
fn to_values_opt(values: Option<&[(&str, DatabaseValue)]>) -> Option<Vec<PgDatabaseValue>> {
values.map(to_values)
}
#[allow(unused)]
fn exprs_to_values_opt(
values: Option<&[(&str, Box<dyn Expression>)]>,
) -> Option<Vec<PgDatabaseValue>> {
values.map(exprs_to_values)
}
fn bexprs_to_values_opt(
values: Option<&[Box<dyn BooleanExpression>]>,
) -> Option<Vec<PgDatabaseValue>> {
values.map(bexprs_to_values)
}
#[allow(clippy::too_many_arguments)]
async fn select(
connection: &mut PgConnection,
table_name: &str,
distinct: bool,
columns: &[&str],
filters: Option<&[Box<dyn BooleanExpression>]>,
joins: Option<&[Join<'_>]>,
sort: Option<&[Sort]>,
limit: Option<usize>,
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let index = AtomicU16::new(0);
let query = format!(
"SELECT {} {} FROM {table_name} {} {} {} {}",
if distinct { "DISTINCT" } else { "" },
columns
.iter()
.map(|x| format_identifier(x))
.collect::<Vec<_>>()
.join(", "),
build_join_clauses(joins),
build_where_clause(filters, &index),
build_sort_clause(sort, &index),
limit.map_or_else(String::new, |limit| format!("LIMIT {limit}"))
);
log::trace!(
"Running select query: {query} with params: {:?}",
filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let filters = bexprs_to_values_opt(filters);
let query = bind_values(statement.query(), filters.as_deref())?;
to_rows(&column_names, query.fetch(connection)).await
}
async fn delete(
connection: &mut PgConnection,
table_name: &str,
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let index = AtomicU16::new(0);
let query = format!(
"DELETE FROM {table_name} {} RETURNING * {}",
build_where_clause(filters, &index),
limit.map_or_else(String::new, |limit| format!("LIMIT {limit}"))
);
log::trace!(
"Running delete query: {query} with params: {:?}",
filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let filters = bexprs_to_values_opt(filters);
let query = bind_values(statement.query(), filters.as_deref())?;
to_rows(&column_names, query.fetch(connection)).await
}
async fn find_row(
connection: &mut PgConnection,
table_name: &str,
distinct: bool,
columns: &[&str],
filters: Option<&[Box<dyn BooleanExpression>]>,
joins: Option<&[Join<'_>]>,
sort: Option<&[Sort]>,
) -> Result<Option<crate::Row>, SqlxDatabaseError> {
let index = AtomicU16::new(0);
let query = format!(
"SELECT {} {} FROM {table_name} {} {} {} LIMIT 1",
if distinct { "DISTINCT" } else { "" },
columns
.iter()
.map(|x| format_identifier(x))
.collect::<Vec<_>>()
.join(", "),
build_join_clauses(joins),
build_where_clause(filters, &index),
build_sort_clause(sort, &index),
);
log::trace!(
"Running find_row query: {query} with params: {:?}",
filters.map(|f| f.iter().filter_map(|x| x.params()).collect::<Vec<_>>())
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let filters = bexprs_to_values_opt(filters);
let query = bind_values(statement.query(), filters.as_deref())?;
let mut query = query.fetch(connection);
query
.next()
.await
.transpose()?
.map(|row| from_row(&column_names, &row))
.transpose()
}
async fn insert_and_get_row(
connection: &mut PgConnection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
) -> Result<crate::Row, SqlxDatabaseError> {
let column_names = values
.iter()
.map(|(key, _v)| format_identifier(key))
.collect::<Vec<_>>()
.join(", ");
let index = AtomicU16::new(0);
let insert_columns = if values.is_empty() {
String::new()
} else {
format!("({column_names})")
};
let query = format!(
"INSERT INTO {table_name} {insert_columns} {} RETURNING *",
build_values_clause(values, &index),
);
log::trace!(
"Running insert_and_get_row query: '{query}' with params: {:?}",
values
.iter()
.filter_map(|(_, x)| x.params())
.collect::<Vec<_>>()
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let values = exprs_to_values(values);
let query = bind_values(statement.query(), Some(&values))?;
let mut stream = query.fetch(connection);
stream
.next()
.await
.transpose()?
.map(|row| from_row(&column_names, &row))
.ok_or(SqlxDatabaseError::NoRow)?
}
pub async fn update_multi(
connection: &mut PgConnection,
table_name: &str,
values: &[Vec<(&str, Box<dyn Expression>)>],
filters: Option<Vec<Box<dyn BooleanExpression>>>,
mut limit: Option<usize>,
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let mut results = vec![];
if values.is_empty() {
return Ok(results);
}
let mut pos = 0;
let mut i = 0;
let mut last_i = i;
for value in values {
let count = value.len();
if pos + count >= (i16::MAX - 1) as usize {
results.append(
&mut update_chunk(connection, table_name, &values[last_i..i], &filters, limit)
.await?,
);
last_i = i;
pos = 0;
}
i += 1;
pos += count;
if let Some(value) = limit {
if count >= value {
return Ok(results);
}
limit.replace(value - count);
}
}
if i > last_i {
results.append(
&mut update_chunk(connection, table_name, &values[last_i..], &filters, limit).await?,
);
}
Ok(results)
}
async fn update_chunk(
connection: &mut PgConnection,
table_name: &str,
values: &[Vec<(&str, Box<dyn Expression>)>],
filters: &Option<Vec<Box<dyn BooleanExpression>>>,
limit: Option<usize>,
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let first = values[0].as_slice();
let expected_value_size = first.len();
if let Some(bad_row) = values.iter().skip(1).find(|v| {
v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
}) {
log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
return Err(SqlxDatabaseError::InvalidRequest);
}
let set_clause = values[0]
.iter()
.map(|(name, _value)| {
format!(
"{} = EXCLUDED.{}",
format_identifier(name),
format_identifier(name)
)
})
.collect::<Vec<_>>()
.join(", ");
let column_names = values[0]
.iter()
.map(|(key, _v)| format_identifier(key))
.collect::<Vec<_>>()
.join(", ");
let index = AtomicU16::new(0);
let query = format!(
"
UPDATE {table_name} ({column_names})
{}
SET {set_clause}
RETURNING *",
build_update_where_clause(
filters.as_deref(),
limit,
limit
.map(|_| {
format!(
"SELECT CTID FROM {table_name} {}",
build_where_clause(filters.as_deref(), &index),
)
})
.as_deref(),
&index
),
);
let all_values = values
.iter()
.flat_map(std::iter::IntoIterator::into_iter)
.flat_map(|(_, value)| {
value
.params()
.unwrap_or(vec![])
.into_iter()
.cloned()
.map(std::convert::Into::into)
})
.collect::<Vec<_>>();
let mut all_filter_values = filters
.as_ref()
.map(|filters| {
filters
.iter()
.flat_map(|value| {
value
.params()
.unwrap_or_default()
.into_iter()
.cloned()
.map(std::convert::Into::into)
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
if limit.is_some() {
all_filter_values.extend(all_filter_values.clone());
}
let all_values = [all_values, all_filter_values].concat();
log::trace!(
"Running update chunk query: {query} with params: {:?}",
all_values
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let query = bind_values(statement.query(), Some(&all_values))?;
to_rows(&column_names, query.fetch(connection)).await
}
pub async fn upsert_multi(
connection: &mut PgConnection,
table_name: &str,
unique: &[Box<dyn Expression>],
values: &[Vec<(&str, Box<dyn Expression>)>],
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let mut results = vec![];
if values.is_empty() {
return Ok(results);
}
let mut pos = 0;
let mut i = 0;
let mut last_i = i;
for value in values {
let count = value.len();
if pos + count >= (i16::MAX - 1) as usize {
results.append(
&mut upsert_chunk(connection, table_name, unique, &values[last_i..i]).await?,
);
last_i = i;
pos = 0;
}
i += 1;
pos += count;
}
if i > last_i {
results.append(&mut upsert_chunk(connection, table_name, unique, &values[last_i..]).await?);
}
Ok(results)
}
async fn upsert_chunk(
connection: &mut PgConnection,
table_name: &str,
unique: &[Box<dyn Expression>],
values: &[Vec<(&str, Box<dyn Expression>)>],
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let first = values[0].as_slice();
let expected_value_size = first.len();
if let Some(bad_row) = values.iter().skip(1).find(|v| {
v.len() != expected_value_size || v.iter().enumerate().any(|(i, c)| c.0 != first[i].0)
}) {
log::error!("Bad row: {bad_row:?}. Expected to match schema of first row: {first:?}");
return Err(SqlxDatabaseError::InvalidRequest);
}
let set_clause = values[0]
.iter()
.map(|(name, _value)| {
format!(
"{} = EXCLUDED.{}",
format_identifier(name),
format_identifier(name)
)
})
.collect::<Vec<_>>()
.join(", ");
let column_names = values[0]
.iter()
.map(|(key, _v)| format_identifier(key))
.collect::<Vec<_>>()
.join(", ");
let index = AtomicU16::new(0);
let values_str_list = values
.iter()
.map(|v| format!("({})", build_values_props(v, &index).join(", ")))
.collect::<Vec<_>>();
let values_str = values_str_list.join(", ");
let unique_conflict = unique
.iter()
.map(|x| x.to_sql(&index))
.collect::<Vec<_>>()
.join(", ");
let query = format!(
"
INSERT INTO {table_name} ({column_names})
VALUES {values_str}
ON CONFLICT({unique_conflict}) DO UPDATE
SET {set_clause}
RETURNING *"
);
let all_values = &values
.iter()
.flat_map(std::iter::IntoIterator::into_iter)
.flat_map(|(_, value)| {
value
.params()
.unwrap_or(vec![])
.into_iter()
.cloned()
.map(std::convert::Into::into)
})
.collect::<Vec<_>>();
log::trace!(
"Running upsert chunk query: {query} with params: {:?}",
all_values
);
let statement = connection.prepare(&query).await?;
let column_names = statement
.columns()
.iter()
.map(|x| x.name().to_string())
.collect::<Vec<_>>();
let query = bind_values(statement.query(), Some(all_values))?;
to_rows(&column_names, query.fetch(connection)).await
}
async fn upsert(
connection: &mut PgConnection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<Vec<crate::Row>, SqlxDatabaseError> {
let rows = update_and_get_rows(connection, table_name, values, filters, limit).await?;
Ok(if rows.is_empty() {
vec![insert_and_get_row(connection, table_name, values).await?]
} else {
rows
})
}
async fn upsert_and_get_row(
connection: &mut PgConnection,
table_name: &str,
values: &[(&str, Box<dyn Expression>)],
filters: Option<&[Box<dyn BooleanExpression>]>,
limit: Option<usize>,
) -> Result<crate::Row, SqlxDatabaseError> {
match find_row(connection, table_name, false, &["*"], filters, None, None).await? {
Some(row) => {
let updated = update_and_get_row(connection, table_name, values, filters, limit)
.await?
.unwrap();
let str1 = format!("{row:?}");
let str2 = format!("{updated:?}");
if str1 == str2 {
log::trace!("No updates to {table_name}");
} else {
log::debug!("Changed {table_name} from {str1} to {str2}");
}
Ok(updated)
}
None => Ok(insert_and_get_row(connection, table_name, values).await?),
}
}
#[derive(Debug, Clone)]
pub struct PgDatabaseValue(DatabaseValue);
impl From<DatabaseValue> for PgDatabaseValue {
fn from(value: DatabaseValue) -> Self {
Self(value)
}
}
impl Deref for PgDatabaseValue {
type Target = DatabaseValue;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Expression for PgDatabaseValue {
fn values(&self) -> Option<Vec<&DatabaseValue>> {
Some(vec![self])
}
fn is_null(&self) -> bool {
self.0.is_null()
}
fn expression_type(&self) -> ExpressionType {
ExpressionType::DatabaseValue(self)
}
}