use futures::future::BoxFuture;
use heck::ToSnakeCase;
use sqlx::{Any, Arguments, Decode, Encode, Type, any::AnyArguments};
use std::marker::PhantomData;
use std::collections::{HashMap, HashSet};
use crate::{
AnyImpl, Error,
any_struct::FromAnyRow,
database::{Connection, Drivers},
model::{ColumnInfo, Model},
temporal::{self, is_temporal_type},
value_binding::ValueBinder,
};
/// A deferred SQL fragment writer.
///
/// Each closure appends its SQL text to the query buffer, pushes any values it needs onto the
/// argument list, and — for the Postgres driver — advances the `$n` placeholder counter so that
/// placeholder numbering follows the order in which arguments are added.
pub type FilterFn = Box<dyn Fn(&mut String, &mut AnyArguments<'_>, &Drivers, &mut usize) + Send + Sync>;
/// Converts a value into an optional string representation.
///
/// The impls generated by `impl_update_value!` return `Some(value.to_string())` for plain
/// values and map `Option<T>` so that `None` stays `None` (presumably rendered as SQL `NULL`
/// by the caller — confirm at the use site).
pub trait ToUpdateValue {
fn to_update_value(self) -> Option<String>;
}
/// Implements `ToUpdateValue` for every listed type and for `Option` of that type.
///
/// Plain values are rendered through their `Display` impl; `None` maps to `None`.
macro_rules! impl_update_value {
    ($($t:ty),*) => {
        $(
            impl ToUpdateValue for $t {
                fn to_update_value(self) -> Option<String> {
                    Some(ToString::to_string(&self))
                }
            }

            impl ToUpdateValue for Option<$t> {
                fn to_update_value(self) -> Option<String> {
                    self.as_ref().map(ToString::to_string)
                }
            }
        )*
    };
}
impl_update_value!(i8, i16, i32, i64, isize, u8, u16, u32, u64, usize, f32, f64, bool, String, &str, uuid::Uuid);
impl_update_value!(chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::FixedOffset>, chrono::NaiveDateTime, chrono::NaiveDate, chrono::NaiveTime);
/// Comparison operators accepted by the filter helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Op {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN`
    In,
    /// `NOT IN`
    NotIn,
    /// `BETWEEN`
    Between,
    /// `NOT BETWEEN`
    NotBetween,
}

impl Op {
    /// Returns the SQL operator text for this variant.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Gte => ">=",
            Self::Lt => "<",
            Self::Lte => "<=",
            Self::Like => "LIKE",
            Self::NotLike => "NOT LIKE",
            Self::In => "IN",
            Self::NotIn => "NOT IN",
            Self::Between => "BETWEEN",
            Self::NotBetween => "NOT BETWEEN",
        }
    }
}
/// Fluent SQL query builder for model `T`, executed through connection/transaction `E`.
///
/// Clause lists hold deferred [`FilterFn`] writers. They are only rendered — and their values
/// bound — when the SQL is produced, so placeholder numbering and argument order come from the
/// same pass.
pub struct QueryBuilder<T, E> {
/// Connection or transaction handle that statements are executed on.
pub(crate) tx: E,
/// Target backend; selects placeholder syntax (`$n` for Postgres, `?` otherwise) and dialect.
pub(crate) driver: Drivers,
/// Model table name (rendered snake_cased).
pub(crate) table_name: &'static str,
/// Optional alias; when set it qualifies main-table columns instead of the table name.
pub(crate) alias: Option<String>,
/// Column metadata for the model (name, SQL type, `omit` flag, ...).
pub(crate) columns_info: Vec<ColumnInfo>,
/// Model column names; a snake_cased filter column found here is qualified with the table identifier.
pub(crate) columns: Vec<String>,
/// Explicit SELECT expressions added by `select`; aggregate helpers overwrite this list.
pub(crate) select_columns: Vec<String>,
/// WHERE predicates; each writer emits its own leading joiner (` AND `, ` OR `, ...) after `WHERE 1=1`.
pub where_clauses: Vec<FilterFn>,
/// Raw ORDER BY expressions, joined with `, `.
pub order_clauses: Vec<String>,
/// JOIN clause writers, rendered after the FROM clause.
pub joins_clauses: Vec<FilterFn>,
/// Relation names requested through `with` / `with_query`.
pub with_relations: Vec<String>,
/// Relation name -> type-erased `QueryModifier` registered by `with_query`.
pub with_modifiers: std::collections::HashMap<String, std::sync::Arc<dyn std::any::Any + Send + Sync>>,
/// Snake_cased joined table name -> alias (or the raw table string when no alias was given).
pub join_aliases: std::collections::HashMap<String, String>,
/// LIMIT value, bound as an `i64` parameter when rendered.
pub limit: Option<usize>,
/// OFFSET value, bound as an `i64` parameter when rendered.
pub offset: Option<usize>,
/// When true, generated SQL is logged with `log::debug!` by the executing methods.
pub(crate) debug_mode: bool,
/// Raw GROUP BY expressions, joined with `, `.
pub(crate) group_by_clauses: Vec<String>,
/// HAVING predicate writers, rendered after `HAVING 1=1`.
pub(crate) having_clauses: Vec<FilterFn>,
/// Emit `SELECT DISTINCT`.
pub(crate) is_distinct: bool,
/// Snake_cased columns to leave out of the SELECT list (presumably applied in `select_args_sql` — confirm).
pub(crate) omit_columns: Vec<String>,
/// Set by `with_deleted`; presumably suppresses the soft-delete filter (logic not visible here).
pub(crate) with_deleted: bool,
/// `(operator, writer)` pairs for UNION / UNION ALL; the writer emits the operator and the other SELECT.
pub(crate) union_clauses: Vec<(String, FilterFn)>,
/// Ties the builder to model type `T` without storing one.
pub(crate) _marker: PhantomData<T>,
}
/// Type-erased callback that customises the query for an eager-loaded relation.
///
/// Built by `QueryBuilder::with_query` and stored (as `dyn Any`) in `with_modifiers`, keyed by
/// relation name; presumably applied when that relation is loaded (not visible here — confirm).
pub struct QueryModifier {
pub modifier: std::sync::Arc<dyn Fn(QueryBuilder<crate::any_struct::AnyImplStruct, crate::Database>) -> QueryBuilder<crate::any_struct::AnyImplStruct, crate::Database> + Send + Sync + 'static>,
}
impl<T, E> QueryBuilder<T, E>
where
T: Model + Send + Sync + Unpin + AnyImpl,
E: Connection,
{
pub fn new(
tx: E,
driver: Drivers,
table_name: &'static str,
columns_info: Vec<ColumnInfo>,
columns: Vec<String>,
) -> Self {
let omit_columns: Vec<String> =
columns_info.iter().filter(|c| c.omit).map(|c| c.name.to_snake_case()).collect();
Self {
tx,
alias: None,
driver,
table_name,
columns_info,
columns,
debug_mode: false,
select_columns: Vec::new(),
where_clauses: Vec::new(),
order_clauses: Vec::new(),
joins_clauses: Vec::new(),
join_aliases: std::collections::HashMap::new(),
group_by_clauses: Vec::new(),
having_clauses: Vec::new(),
is_distinct: false,
omit_columns,
limit: None,
offset: None,
with_deleted: false,
union_clauses: Vec::new(),
with_relations: Vec::new(),
with_modifiers: std::collections::HashMap::new(),
_marker: PhantomData,
}
}
/// Name used to qualify main-table columns: the alias if set, else the snake_cased table name.
pub(crate) fn get_table_identifier(&self) -> String {
    match &self.alias {
        Some(alias) => alias.clone(),
        None => self.table_name.to_snake_case(),
    }
}
/// Requests eager loading of `relation`.
pub fn with(mut self, relation: &str) -> Self {
    let name = String::from(relation);
    self.with_relations.push(name);
    self
}
/// Requests eager loading of `relation` and registers `modifier` to customise its query.
///
/// The modifier is wrapped in a [`QueryModifier`] and stored type-erased under the relation name.
pub fn with_query<F>(mut self, relation: &str, modifier: F) -> Self
where
    F: Fn(QueryBuilder<crate::any_struct::AnyImplStruct, crate::Database>) -> QueryBuilder<crate::any_struct::AnyImplStruct, crate::Database> + Send + Sync + 'static,
{
    let key = relation.to_string();
    self.with_relations.push(key.clone());
    let wrapped = QueryModifier {
        modifier: std::sync::Arc::new(modifier),
    };
    self.with_modifiers.insert(key, std::sync::Arc::new(wrapped));
    self
}
fn filter_internal<V>(mut self, joiner: &str, col: &'static str, op: Op, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
let op_str = op.as_sql();
let table_id = self.get_table_identifier();
let is_main_col = self.columns.contains(&col.to_snake_case());
let joiner_owned = joiner.to_string();
let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
query.push_str(&joiner_owned);
if let Some((table, column)) = col.split_once(".") {
query.push_str(&format!("\"{}\".\"{}\"", table, column));
} else if is_main_col {
query.push_str(&format!("\"{}\".\"{}\"", table_id, col));
} else {
query.push_str(&format!("\"{}\"", col));
}
query.push(' ');
query.push_str(op_str);
query.push(' ');
match driver {
Drivers::Postgres => {
query.push_str(&format!("${}", arg_counter));
*arg_counter += 1;
}
_ => query.push('?'),
}
let _ = args.add(value.clone());
});
self.where_clauses.push(clause);
self
}
/// Appends ` AND <col> <op> (<subquery SELECT>)`.
///
/// The subquery has its soft-delete filter applied up front. Its placeholders and arguments
/// are rendered inline, sharing this query's argument list and counter.
pub fn filter_subquery<S, SE>(mut self, col: &'static str, op: Op, mut subquery: QueryBuilder<S, SE>) -> Self
where
    S: Model + Send + Sync + Unpin + AnyImpl + 'static,
    SE: Connection + 'static,
{
    subquery.apply_soft_delete_filter();
    let table_id = self.get_table_identifier();
    let is_main_col = self.columns.contains(&col.to_snake_case());
    let op_sql = op.as_sql();
    let clause: FilterFn = Box::new(move |query, args, _driver, arg_counter| {
        let target = match col.split_once('.') {
            Some((table, column)) => format!("\"{}\".\"{}\"", table, column),
            None if is_main_col => format!("\"{}\".\"{}\"", table_id, col),
            None => format!("\"{}\"", col),
        };
        query.push_str(" AND ");
        query.push_str(&target);
        query.push_str(&format!(" {} (", op_sql));
        subquery.write_select_sql::<S>(query, args, arg_counter);
        query.push(')');
    });
    self.where_clauses.push(clause);
    self
}
/// Removes every row from the model's table.
///
/// Postgres and MySQL use `TRUNCATE TABLE`. SQLite has no TRUNCATE, so it uses `DELETE FROM`
/// followed by a best-effort reset of the table's `sqlite_sequence` entry. That reset errors
/// (and is ignored) when the database has no AUTOINCREMENT tables.
///
/// The table name in the reset is bound as a parameter. It used to be interpolated as a
/// double-quoted token, which SQLite resolves as an identifier first: a table named like a
/// `sqlite_sequence` column (`name`, `seq`) matched every row, and builds with
/// double-quoted string literals disabled rejected it.
///
/// # Errors
/// Returns the database error from the main TRUNCATE/DELETE statement.
pub async fn truncate(self) -> Result<(), sqlx::Error> {
    let table_name = self.table_name.to_snake_case();
    let query = match self.driver {
        Drivers::Postgres | Drivers::MySQL => format!("TRUNCATE TABLE \"{}\"", table_name),
        Drivers::SQLite => format!("DELETE FROM \"{}\"", table_name),
    };
    if self.debug_mode {
        log::debug!("SQL: {}", query);
    }
    self.tx.execute(&query, AnyArguments::default()).await?;
    if matches!(self.driver, Drivers::SQLite) {
        let reset_sql = "DELETE FROM sqlite_sequence WHERE name = ?".to_string();
        let mut reset_args = AnyArguments::default();
        let _ = reset_args.add(table_name);
        // Best effort: the table only exists once an AUTOINCREMENT table has been created.
        let _ = self.tx.execute(&reset_sql, reset_args).await;
    }
    Ok(())
}
pub fn union(self, other: QueryBuilder<T, E>) -> Self where T: AnyImpl + 'static, E: 'static {
self.union_internal("UNION", other)
}
pub fn union_all(self, other: QueryBuilder<T, E>) -> Self where T: AnyImpl + 'static, E: 'static {
self.union_internal("UNION ALL", other)
}
fn union_internal(mut self, op: &str, mut other: QueryBuilder<T, E>) -> Self where T: AnyImpl + 'static, E: 'static {
other.apply_soft_delete_filter();
let op_owned = op.to_string();
self.union_clauses.push((op_owned.clone(), Box::new(move |query: &mut String, args: &mut AnyArguments<'_>, _driver: &Drivers, arg_counter: &mut usize| {
query.push_str(" ");
query.push_str(&op_owned);
query.push_str(" ");
other.write_select_sql::<T>(query, args, arg_counter);
})));
self
}
pub(crate) fn write_select_sql<R: AnyImpl>(
&self,
query: &mut String,
args: &mut AnyArguments,
arg_counter: &mut usize,
) {
query.push_str("SELECT ");
if self.is_distinct {
query.push_str("DISTINCT ");
}
query.push_str(&self.select_args_sql::<R>().join(", "));
query.push_str(" FROM \"");
query.push_str(&self.table_name.to_snake_case());
query.push_str("\" ");
if let Some(alias) = &self.alias {
query.push_str(&format!("\"{}\" ", alias));
}
if !self.joins_clauses.is_empty() {
for join_clause in &self.joins_clauses {
query.push(' ');
join_clause(query, args, &self.driver, arg_counter);
}
}
query.push_str(" WHERE 1=1");
for clause in &self.where_clauses {
clause(query, args, &self.driver, arg_counter);
}
if !self.group_by_clauses.is_empty() {
query.push_str(&format!(" GROUP BY {}", self.group_by_clauses.join(", ")));
}
if !self.having_clauses.is_empty() {
query.push_str(" HAVING 1=1");
for clause in &self.having_clauses {
clause(query, args, &self.driver, arg_counter);
}
}
if !self.order_clauses.is_empty() {
query.push_str(&format!(" ORDER BY {}", self.order_clauses.join(", ")));
}
if let Some(limit) = self.limit {
query.push_str(" LIMIT ");
match self.driver {
Drivers::Postgres => {
query.push_str(&format!("${}", arg_counter));
*arg_counter += 1;
}
_ => query.push('?'),
}
let _ = args.add(limit as i64);
}
if let Some(offset) = self.offset {
query.push_str(" OFFSET ");
match self.driver {
Drivers::Postgres => {
query.push_str(&format!("${}", arg_counter));
*arg_counter += 1;
}
_ => query.push('?'),
}
let _ = args.add(offset as i64);
}
for (_op, clause) in &self.union_clauses {
clause(query, args, &self.driver, arg_counter);
}
}
pub fn filter<V>(self, col: &'static str, op: Op, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.filter_internal(" AND ", col, op, value)
}
pub fn or_filter<V>(self, col: &'static str, op: Op, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.filter_internal(" OR ", col, op, value)
}
pub fn not_filter<V>(self, col: &'static str, op: Op, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.filter_internal(" AND NOT ", col, op, value)
}
pub fn or_not_filter<V>(self, col: &'static str, op: Op, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.filter_internal(" OR NOT ", col, op, value)
}
/// Adds ` AND <col> BETWEEN <start> AND <end>`; both bounds are bound parameters.
pub fn between<V>(mut self, col: &'static str, start: V, end: V) -> Self
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    let table_id = self.get_table_identifier();
    let is_main_col = self.columns.contains(&col.to_snake_case());
    let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
        let target = match col.split_once('.') {
            Some((table, column)) => format!("\"{}\".\"{}\"", table, column),
            None if is_main_col => format!("\"{}\".\"{}\"", table_id, col),
            None => format!("\"{}\"", col),
        };
        let bounds = if matches!(driver, Drivers::Postgres) {
            let first = *arg_counter;
            *arg_counter += 2;
            format!("${} AND ${}", first, first + 1)
        } else {
            String::from("? AND ?")
        };
        query.push_str(&format!(" AND {} BETWEEN {}", target, bounds));
        let _ = args.add(start.clone());
        let _ = args.add(end.clone());
    });
    self.where_clauses.push(clause);
    self
}
/// Adds ` OR <col> BETWEEN <start> AND <end>`; both bounds are bound parameters.
pub fn or_between<V>(mut self, col: &'static str, start: V, end: V) -> Self
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    let table_id = self.get_table_identifier();
    let is_main_col = self.columns.contains(&col.to_snake_case());
    let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
        let target = match col.split_once('.') {
            Some((table, column)) => format!("\"{}\".\"{}\"", table, column),
            None if is_main_col => format!("\"{}\".\"{}\"", table_id, col),
            None => format!("\"{}\"", col),
        };
        let bounds = if matches!(driver, Drivers::Postgres) {
            let first = *arg_counter;
            *arg_counter += 2;
            format!("${} AND ${}", first, first + 1)
        } else {
            String::from("? AND ?")
        };
        query.push_str(&format!(" OR {} BETWEEN {}", target, bounds));
        let _ = args.add(start.clone());
        let _ = args.add(end.clone());
    });
    self.where_clauses.push(clause);
    self
}
/// Adds ` AND <col> IN (<v1>, <v2>, ...)`, one bound parameter per value.
///
/// An empty list yields ` AND 1=0`, so the query matches nothing instead of producing `IN ()`.
pub fn in_list<V>(mut self, col: &'static str, values: Vec<V>) -> Self
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    if values.is_empty() {
        let never: FilterFn = Box::new(|query, _, _, _| query.push_str(" AND 1=0"));
        self.where_clauses.push(never);
        return self;
    }
    let table_id = self.get_table_identifier();
    let is_main_col = self.columns.contains(&col.to_snake_case());
    let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
        let target = match col.split_once('.') {
            Some((table, column)) => format!("\"{}\".\"{}\"", table, column),
            None if is_main_col => format!("\"{}\".\"{}\"", table_id, col),
            None => format!("\"{}\"", col),
        };
        let markers: Vec<String> = values
            .iter()
            .map(|_| match driver {
                Drivers::Postgres => {
                    let marker = format!("${}", arg_counter);
                    *arg_counter += 1;
                    marker
                }
                _ => "?".to_string(),
            })
            .collect();
        query.push_str(&format!(" AND {} IN ({})", target, markers.join(", ")));
        for item in values.iter().cloned() {
            let _ = args.add(item);
        }
    });
    self.where_clauses.push(clause);
    self
}
/// Adds ` OR <col> IN (<v1>, <v2>, ...)`, one bound parameter per value.
///
/// An empty list adds nothing; `OR <false>` would not change the result anyway.
pub fn or_in_list<V>(mut self, col: &'static str, values: Vec<V>) -> Self
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    if values.is_empty() {
        return self;
    }
    let table_id = self.get_table_identifier();
    let is_main_col = self.columns.contains(&col.to_snake_case());
    let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
        let target = match col.split_once('.') {
            Some((table, column)) => format!("\"{}\".\"{}\"", table, column),
            None if is_main_col => format!("\"{}\".\"{}\"", table_id, col),
            None => format!("\"{}\"", col),
        };
        let markers: Vec<String> = values
            .iter()
            .map(|_| match driver {
                Drivers::Postgres => {
                    let marker = format!("${}", arg_counter);
                    *arg_counter += 1;
                    marker
                }
                _ => "?".to_string(),
            })
            .collect();
        query.push_str(&format!(" OR {} IN ({})", target, markers.join(", ")));
        for item in values.iter().cloned() {
            let _ = args.add(item);
        }
    });
    self.where_clauses.push(clause);
    self
}
/// Runs `f` on the builder and wraps the WHERE clauses it adds in ` AND ( ... )`.
///
/// Clauses added inside the group are rendered into a scratch buffer, and the leading joiner of
/// the first one (` AND ` or ` OR `) is dropped. The old `(1=1 ...)` prefix turned a group that
/// starts with an OR clause into `(1=1 OR ...)`, which is always true and disables the whole
/// group. If the first fragment has neither joiner, the old `1=1` prefix is kept so the output
/// stays well-formed. Arguments and Postgres placeholder numbers are still produced in text order.
pub fn group<F>(mut self, f: F) -> Self
where
    F: FnOnce(Self) -> Self,
{
    let old_clauses = std::mem::take(&mut self.where_clauses);
    self = f(self);
    let group_clauses = std::mem::take(&mut self.where_clauses);
    self.where_clauses = old_clauses;
    if !group_clauses.is_empty() {
        let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
            let mut inner = String::new();
            for c in &group_clauses {
                c(&mut inner, args, driver, arg_counter);
            }
            query.push_str(" AND (");
            match inner.strip_prefix(" AND ").or_else(|| inner.strip_prefix(" OR ")) {
                Some(rest) => query.push_str(rest),
                None => {
                    query.push_str("1=1");
                    query.push_str(&inner);
                }
            }
            query.push(')');
        });
        self.where_clauses.push(clause);
    }
    self
}
/// Runs `f` on the builder and wraps the WHERE clauses it adds in ` OR ( ... )`.
///
/// Clauses added inside the group are rendered into a scratch buffer, and the leading joiner of
/// the first one (` AND ` or ` OR `) is dropped. The old `(1=1 ...)` prefix made
/// `(1=1 OR ...)` always true when the group started with an OR clause. If the first fragment
/// has neither joiner, the old `1=1` prefix is kept. Arguments and Postgres placeholder numbers
/// are still produced in text order.
pub fn or_group<F>(mut self, f: F) -> Self
where
    F: FnOnce(Self) -> Self,
{
    let old_clauses = std::mem::take(&mut self.where_clauses);
    self = f(self);
    let group_clauses = std::mem::take(&mut self.where_clauses);
    self.where_clauses = old_clauses;
    if !group_clauses.is_empty() {
        let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
            let mut inner = String::new();
            for c in &group_clauses {
                c(&mut inner, args, driver, arg_counter);
            }
            query.push_str(" OR (");
            match inner.strip_prefix(" AND ").or_else(|| inner.strip_prefix(" OR ")) {
                Some(rest) => query.push_str(rest),
                None => {
                    query.push_str("1=1");
                    query.push_str(&inner);
                }
            }
            query.push(')');
        });
        self.where_clauses.push(clause);
    }
    self
}
/// Adds a raw SQL predicate joined with ` AND `, binding `value` (see `create_raw_clause`).
pub fn where_raw<V>(mut self, sql: &str, value: V) -> Self
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    let clause = self.create_raw_clause(" AND ", sql, value);
    self.where_clauses.push(clause);
    self
}
/// Adds a raw SQL predicate joined with ` OR `, binding `value` (see `create_raw_clause`).
pub fn or_where_raw<V>(mut self, sql: &str, value: V) -> Self
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    let clause = self.create_raw_clause(" OR ", sql, value);
    self.where_clauses.push(clause);
    self
}
/// Builds a writer that emits `joiner` followed by the raw `sql`, binding `value` once.
///
/// If `sql` contains no `?`, a placeholder is inferred:
/// - ` ?` is appended when the trimmed text ends in `=`, `>`, `<` or ` LIKE`;
/// - ` = ?` is appended when it is a lone token (no space, no `(`);
/// - otherwise nothing is appended.
///
/// For Postgres every `?` is renumbered to `$n` left to right. Only one value is bound, so the
/// text should contain exactly one placeholder.
fn create_raw_clause<V>(&self, joiner: &'static str, sql: &str, value: V) -> FilterFn
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    let template = sql.to_string();
    Box::new(move |query, args, driver, arg_counter| {
        query.push_str(joiner);
        let mut rendered = template.clone();
        if !template.contains('?') {
            let trimmed = template.trim();
            let ends_with_operator = trimmed.ends_with(|ch: char| matches!(ch, '=' | '>' | '<'))
                || trimmed.to_uppercase().ends_with(" LIKE");
            if ends_with_operator {
                rendered.push_str(" ?");
            } else if !trimmed.contains(' ') && !trimmed.contains('(') {
                rendered.push_str(" = ?");
            }
        }
        if matches!(driver, Drivers::Postgres) {
            let mut numbered = String::with_capacity(rendered.len() + 4);
            for ch in rendered.chars() {
                if ch == '?' {
                    numbered.push_str(&format!("${}", arg_counter));
                    *arg_counter += 1;
                } else {
                    numbered.push(ch);
                }
            }
            rendered = numbered;
        }
        query.push_str(&rendered);
        let _ = args.add(value.clone());
    })
}
pub fn equals<V>(self, col: &'static str, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.filter(col, Op::Eq, value)
}
/// Appends a raw ORDER BY expression (for example `"created_at DESC"`).
pub fn order(mut self, order: &str) -> Self {
    let expression = String::from(order);
    self.order_clauses.push(expression);
    self
}
/// Sets the alias used in `FROM` and to qualify main-table columns.
pub fn alias(self, alias: &str) -> Self {
    let mut builder = self;
    builder.alias = Some(String::from(alias));
    builder
}
/// Enables `log::debug!` output of the SQL generated by executing methods.
pub fn debug(self) -> Self {
    let mut builder = self;
    builder.debug_mode = true;
    builder
}
pub fn is_null(mut self, col: &str) -> Self {
let col_owned = col.to_string();
let table_id = self.get_table_identifier();
let is_main_col = self.columns.contains(&col_owned.to_snake_case());
let clause: FilterFn = Box::new(move |query, _args, _driver, _arg_counter| {
query.push_str(" AND ");
if let Some((table, column)) = col_owned.split_once(".") {
query.push_str(&format!("\"{}\".\"{}\"", table, column));
} else if is_main_col {
query.push_str(&format!("\"{}\".\"{}\"", table_id, col_owned));
} else {
query.push_str(&format!("\"{}\"", col_owned));
}
query.push_str(" IS NULL");
});
self.where_clauses.push(clause);
self
}
pub fn is_not_null(mut self, col: &str) -> Self {
let col_owned = col.to_string();
let table_id = self.get_table_identifier();
let is_main_col = self.columns.contains(&col_owned.to_snake_case());
let clause: FilterFn = Box::new(move |query, _args, _driver, _arg_counter| {
query.push_str(" AND ");
if let Some((table, column)) = col_owned.split_once(".") {
query.push_str(&format!("\"{}\".\"{}\"", table, column));
} else if is_main_col {
query.push_str(&format!("\"{}\".\"{}\"", table_id, col_owned));
} else {
query.push_str(&format!("\"{}\"", col_owned));
}
query.push_str(" IS NOT NULL");
});
self.where_clauses.push(clause);
self
}
/// Sets the `with_deleted` flag (presumably includes soft-deleted rows — filter logic not shown here).
pub fn with_deleted(self) -> Self {
    let mut builder = self;
    builder.with_deleted = true;
    builder
}
pub fn join(self, table: &str, s_query: &str) -> Self {
self.join_generic("", table, s_query)
}
/// Registers a `<join_type> JOIN "<table>" ["<alias>"] ON <condition>` clause.
///
/// `table` may be `"name"` or `"name alias"`; the alias is recorded in `join_aliases` under the
/// snake_cased name. Without an alias the raw table string is recorded.
///
/// Only a simple `t1.c1 = t2.c2` condition, where every part is a plain identifier, is rewritten
/// into quoted identifiers. Anything else is passed through verbatim. The old rewrite split on
/// the first `=` and the first `.` regardless, which mangled compound conditions
/// (`a.x = b.y AND ...`), operators like `>=` (giving `"x>"`), and already-quoted names.
fn join_generic(mut self, join_type: &str, table: &str, s_query: &str) -> Self {
    // Quotes a simple `t1.c1 = t2.c2` equality, or returns `None` if the condition is not that shape.
    fn quote_simple_equality(on: &str) -> Option<String> {
        let is_ident = |part: &str| {
            !part.is_empty() && part.chars().all(|ch| ch.is_alphanumeric() || ch == '_')
        };
        let compact = on.replace(" ", "");
        let (lhs, rhs) = compact.split_once('=')?;
        let (t1, c1) = lhs.split_once('.')?;
        let (t2, c2) = rhs.split_once('.')?;
        if [t1, c1, t2, c2].iter().all(|part| is_ident(*part)) {
            Some(format!("\"{}\".\"{}\" = \"{}\".\"{}\"", t1, c1, t2, c2))
        } else {
            None
        }
    }

    let table_owned = table.to_string();
    let join_type_owned = join_type.to_string();
    let parsed_query = quote_simple_equality(s_query).unwrap_or_else(|| s_query.to_string());
    if let Some((table_name, alias)) = table.split_once(" ") {
        self.join_aliases.insert(table_name.to_snake_case(), alias.to_string());
    } else {
        self.join_aliases.insert(table.to_snake_case(), table.to_string());
    }
    self.joins_clauses.push(Box::new(move |query, _args, _driver, _arg_counter| {
        if let Some((table_name, alias)) = table_owned.split_once(" ") {
            query.push_str(&format!("{} JOIN \"{}\" \"{}\" ON {}", join_type_owned, table_name, alias, parsed_query));
        } else {
            query.push_str(&format!("{} JOIN \"{}\" ON {}", join_type_owned, table_owned, parsed_query));
        }
    }));
    self
}
pub fn join_raw<V>(self, table: &str, on: &str, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.join_generic_raw("", table, on, value)
}
pub fn left_join_raw<V>(self, table: &str, on: &str, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.join_generic_raw("LEFT", table, on, value)
}
pub fn right_join_raw<V>(self, table: &str, on: &str, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.join_generic_raw("RIGHT", table, on, value)
}
pub fn inner_join_raw<V>(self, table: &str, on: &str, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.join_generic_raw("INNER", table, on, value)
}
pub fn full_join_raw<V>(self, table: &str, on: &str, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
self.join_generic_raw("FULL", table, on, value)
}
/// Registers a `<join_type> JOIN "<table>" [alias] ON <on>` clause that binds one value.
///
/// Only the first `?` in `on` is treated as the placeholder (renumbered to `$n` on Postgres).
/// `value` is always bound, even if `on` has no `?`. The alias (from `"name alias"`) is emitted
/// unquoted and recorded in `join_aliases`.
fn join_generic_raw<V>(mut self, join_type: &str, table: &str, on: &str, value: V) -> Self
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    let (alias_key, alias_value) = match table.split_once(" ") {
        Some((table_name, alias)) => (table_name.to_snake_case(), alias.to_string()),
        None => (table.to_snake_case(), table.to_string()),
    };
    self.join_aliases.insert(alias_key, alias_value);
    let table_owned = table.to_string();
    let on_owned = on.to_string();
    let join_type_owned = join_type.to_string();
    self.joins_clauses.push(Box::new(move |query, args, driver, arg_counter| {
        let head = match table_owned.split_once(" ") {
            Some((table_name, alias)) => format!("{} JOIN \"{}\" {} ON ", join_type_owned, table_name, alias),
            None => format!("{} JOIN \"{}\" ON ", join_type_owned, table_owned),
        };
        query.push_str(&head);
        let condition = if on_owned.contains('?') {
            let marker = if let Drivers::Postgres = driver {
                let numbered = format!("${}", arg_counter);
                *arg_counter += 1;
                numbered
            } else {
                "?".to_string()
            };
            on_owned.replacen('?', &marker, 1)
        } else {
            on_owned.clone()
        };
        query.push_str(&condition);
        let _ = args.add(value.clone());
    }));
    self
}
pub fn left_join(self, table: &str, on: &str) -> Self {
self.join_generic("LEFT", table, on)
}
pub fn right_join(self, table: &str, on: &str) -> Self {
self.join_generic("RIGHT", table, on)
}
pub fn inner_join(self, table: &str, on: &str) -> Self {
self.join_generic("INNER", table, on)
}
pub fn full_join(self, table: &str, on: &str) -> Self {
self.join_generic("FULL", table, on)
}
/// Renders the query as `SELECT DISTINCT ...`.
pub fn distinct(self) -> Self {
    let mut builder = self;
    builder.is_distinct = true;
    builder
}
/// Appends a raw GROUP BY expression; multiple calls are joined with `, `.
pub fn group_by(mut self, columns: &str) -> Self {
    let expression = String::from(columns);
    self.group_by_clauses.push(expression);
    self
}
pub fn having<V>(mut self, col: &'static str, op: Op, value: V) -> Self
where
V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
let op_str = op.as_sql();
let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
query.push_str(" AND ");
query.push_str(col);
query.push(' ');
query.push_str(op_str);
query.push(' ');
match driver {
Drivers::Postgres => {
query.push_str(&format!("${}", arg_counter));
*arg_counter += 1;
}
_ => query.push('?'),
}
let _ = args.add(value.clone());
});
self.having_clauses.push(clause);
self
}
/// Returns `COUNT(*)` over the rows matched by this query.
///
/// Without GROUP BY the aggregate yields exactly one row, so ORDER BY, LIMIT and OFFSET are
/// dropped. Postgres, and MySQL under `ONLY_FULL_GROUP_BY`, reject ORDER BY on non-aggregated
/// columns, and a non-zero OFFSET (or LIMIT 0) would discard the only row. With GROUP BY they
/// are kept.
///
/// # Errors
/// Propagates database and decoding errors from `scalar`.
pub async fn count(mut self) -> Result<i64, sqlx::Error> {
    if self.group_by_clauses.is_empty() {
        self.order_clauses.clear();
        self.limit = None;
        self.offset = None;
    }
    self.select_columns = vec!["COUNT(*)".to_string()];
    self.scalar::<i64>().await
}
/// Returns `SUM(column)` over the rows matched by this query.
///
/// `column` may be `col` or `table.col`; surrounding double quotes on each part are stripped
/// before re-quoting.
///
/// Without GROUP BY the aggregate yields exactly one row, so ORDER BY, LIMIT and OFFSET are
/// dropped. Postgres, and MySQL under `ONLY_FULL_GROUP_BY`, reject ORDER BY on non-aggregated
/// columns, and a non-zero OFFSET would discard the only row.
///
/// # Errors
/// Propagates database and decoding errors from `scalar`. SUM over no rows yields SQL `NULL`;
/// how that decodes into `N` depends on `N` — confirm.
pub async fn sum<N>(mut self, column: &str) -> Result<N, sqlx::Error>
where
    N: FromAnyRow + AnyImpl + for<'r> Decode<'r, Any> + Type<Any> + Send + Unpin,
{
    let quoted_col = if column.contains('.') {
        let parts: Vec<&str> = column.split('.').collect();
        format!("\"{}\".\"{}\"", parts[0].trim_matches('"'), parts[1].trim_matches('"'))
    } else {
        format!("\"{}\"", column.trim_matches('"'))
    };
    if self.group_by_clauses.is_empty() {
        self.order_clauses.clear();
        self.limit = None;
        self.offset = None;
    }
    self.select_columns = vec![format!("SUM({})", quoted_col)];
    self.scalar::<N>().await
}
/// Returns `AVG(column)` over the rows matched by this query.
///
/// `column` may be `col` or `table.col`; surrounding double quotes on each part are stripped
/// before re-quoting.
///
/// Without GROUP BY the aggregate yields exactly one row, so ORDER BY, LIMIT and OFFSET are
/// dropped. Postgres, and MySQL under `ONLY_FULL_GROUP_BY`, reject ORDER BY on non-aggregated
/// columns, and a non-zero OFFSET would discard the only row.
///
/// # Errors
/// Propagates database and decoding errors from `scalar`.
pub async fn avg<N>(mut self, column: &str) -> Result<N, sqlx::Error>
where
    N: FromAnyRow + AnyImpl + for<'r> Decode<'r, Any> + Type<Any> + Send + Unpin,
{
    let quoted_col = if column.contains('.') {
        let parts: Vec<&str> = column.split('.').collect();
        format!("\"{}\".\"{}\"", parts[0].trim_matches('"'), parts[1].trim_matches('"'))
    } else {
        format!("\"{}\"", column.trim_matches('"'))
    };
    if self.group_by_clauses.is_empty() {
        self.order_clauses.clear();
        self.limit = None;
        self.offset = None;
    }
    self.select_columns = vec![format!("AVG({})", quoted_col)];
    self.scalar::<N>().await
}
/// Returns `MIN(column)` over the rows matched by this query.
///
/// `column` may be `col` or `table.col`; surrounding double quotes on each part are stripped
/// before re-quoting.
///
/// Without GROUP BY the aggregate yields exactly one row, so ORDER BY, LIMIT and OFFSET are
/// dropped. Postgres, and MySQL under `ONLY_FULL_GROUP_BY`, reject ORDER BY on non-aggregated
/// columns, and a non-zero OFFSET would discard the only row.
///
/// # Errors
/// Propagates database and decoding errors from `scalar`.
pub async fn min<N>(mut self, column: &str) -> Result<N, sqlx::Error>
where
    N: FromAnyRow + AnyImpl + for<'r> Decode<'r, Any> + Type<Any> + Send + Unpin,
{
    let quoted_col = if column.contains('.') {
        let parts: Vec<&str> = column.split('.').collect();
        format!("\"{}\".\"{}\"", parts[0].trim_matches('"'), parts[1].trim_matches('"'))
    } else {
        format!("\"{}\"", column.trim_matches('"'))
    };
    if self.group_by_clauses.is_empty() {
        self.order_clauses.clear();
        self.limit = None;
        self.offset = None;
    }
    self.select_columns = vec![format!("MIN({})", quoted_col)];
    self.scalar::<N>().await
}
/// Returns `MAX(column)` over the rows matched by this query.
///
/// `column` may be `col` or `table.col`; surrounding double quotes on each part are stripped
/// before re-quoting.
///
/// Without GROUP BY the aggregate yields exactly one row, so ORDER BY, LIMIT and OFFSET are
/// dropped. Postgres, and MySQL under `ONLY_FULL_GROUP_BY`, reject ORDER BY on non-aggregated
/// columns, and a non-zero OFFSET would discard the only row.
///
/// # Errors
/// Propagates database and decoding errors from `scalar`.
pub async fn max<N>(mut self, column: &str) -> Result<N, sqlx::Error>
where
    N: FromAnyRow + AnyImpl + for<'r> Decode<'r, Any> + Type<Any> + Send + Unpin,
{
    let quoted_col = if column.contains('.') {
        let parts: Vec<&str> = column.split('.').collect();
        format!("\"{}\".\"{}\"", parts[0].trim_matches('"'), parts[1].trim_matches('"'))
    } else {
        format!("\"{}\"", column.trim_matches('"'))
    };
    if self.group_by_clauses.is_empty() {
        self.order_clauses.clear();
        self.limit = None;
        self.offset = None;
    }
    self.select_columns = vec![format!("MAX({})", quoted_col)];
    self.scalar::<N>().await
}
pub fn pagination(mut self, max_value: usize, default: usize, page: usize, value: isize) -> Result<Self, Error> {
if value < 0 {
return Err(Error::InvalidArgument("value cannot be negative".into()));
}
let mut f_value = value as usize;
if f_value > max_value {
f_value = default;
}
self = self.offset(f_value * page);
self = self.limit(f_value);
Ok(self)
}
/// Appends a raw SELECT expression (or comma-separated list) to the select list.
pub fn select(mut self, columns: &str) -> Self {
    let expression = String::from(columns);
    self.select_columns.push(expression);
    self
}
/// Adds comma-separated column names (trimmed, snake_cased) to the omit list.
pub fn omit(mut self, columns: &str) -> Self {
    let names = columns.split(',').map(|name| name.trim().to_snake_case());
    self.omit_columns.extend(names);
    self
}
/// Sets the OFFSET, bound as a parameter when the query is rendered.
pub fn offset(self, offset: usize) -> Self {
    let mut builder = self;
    builder.offset = Some(offset);
    builder
}
/// Sets the LIMIT, bound as a parameter when the query is rendered.
pub fn limit(self, limit: usize) -> Self {
    let mut builder = self;
    builder.limit = Some(limit);
    builder
}
/// Inserts `model` as a single row built from `Model::to_map`. An empty map is a no-op.
///
/// Column names are stripped of a leading `r#` and snake_cased. On Postgres, each placeholder
/// gets an explicit cast derived from the column's SQL type: temporal types, `UUID`, `JSONB`,
/// and `...[]` arrays.
///
/// The type lookup tries an exact name match first, as before, then falls back to the
/// raw-identifier / snake_case-tolerant match that `upsert` uses. Previously a name mismatch
/// silently fell back to `TEXT` and lost the cast.
///
/// `None` values are bound as typed NULLs when the SQL type is a known numeric/bool type, and as
/// a NULL string otherwise.
///
/// # Errors
/// Returns the database error from executing the INSERT.
pub fn insert<'b>(&'b mut self, model: &'b T) -> BoxFuture<'b, Result<(), sqlx::Error>> {
    Box::pin(async move {
        let data_map = Model::to_map(model);
        if data_map.is_empty() {
            return Ok(());
        }
        let table_name = self.table_name.to_snake_case();
        let columns_info = <T as Model>::columns();
        let mut target_columns = Vec::new();
        let mut bindings: Vec<(Option<String>, &str)> = Vec::new();
        for (col_name, value) in data_map {
            let col_name_clean = col_name.strip_prefix("r#").unwrap_or(&col_name).to_snake_case();
            target_columns.push(format!("\"{}\"", col_name_clean));
            let sql_type = columns_info
                .iter()
                .find(|c| c.name == col_name)
                .or_else(|| {
                    columns_info.iter().find(|c| {
                        let c_clean = c.name.strip_prefix("r#").unwrap_or(c.name);
                        c_clean.to_snake_case() == col_name_clean
                    })
                })
                .map(|c| c.sql_type)
                .unwrap_or("TEXT");
            bindings.push((value, sql_type));
        }
        let placeholders: Vec<String> = bindings
            .iter()
            .enumerate()
            .map(|(i, (_, sql_type))| match self.driver {
                Drivers::Postgres => {
                    let idx = i + 1;
                    if temporal::is_temporal_type(sql_type) {
                        format!("${}{}", idx, temporal::get_postgres_type_cast(sql_type))
                    } else {
                        match *sql_type {
                            "UUID" => format!("${}::UUID", idx),
                            "JSONB" | "jsonb" => format!("${}::JSONB", idx),
                            s if s.ends_with("[]") => format!("${}::{}", idx, s),
                            _ => format!("${}", idx),
                        }
                    }
                }
                _ => "?".to_string(),
            })
            .collect();
        let query_str = format!(
            "INSERT INTO \"{}\" ({}) VALUES ({})",
            table_name,
            target_columns.join(", "),
            placeholders.join(", ")
        );
        if self.debug_mode {
            log::debug!("SQL: {}", query_str);
        }
        let mut args = AnyArguments::default();
        for (val_opt, sql_type) in bindings {
            if let Some(val_str) = val_opt {
                // Fall back to a plain string bind if typed binding fails.
                if args.bind_value(&val_str, sql_type, &self.driver).is_err() {
                    let _ = args.add(val_str);
                }
            } else {
                match sql_type {
                    "INTEGER" | "INT" | "INT4" | "SERIAL" => { let _ = args.add(None::<i32>); }
                    "BIGINT" | "INT8" | "BIGSERIAL" => { let _ = args.add(None::<i64>); }
                    "REAL" | "FLOAT4" => { let _ = args.add(None::<f32>); }
                    "DOUBLE PRECISION" | "FLOAT8" | "FLOAT" => { let _ = args.add(None::<f64>); }
                    "BOOLEAN" | "BOOL" => { let _ = args.add(None::<bool>); }
                    _ => { let _ = args.add(None::<String>); }
                }
            }
        }
        self.tx.execute(&query_str, args).await?;
        Ok(())
    })
}
/// Inserts all `models` with one multi-row `INSERT ... VALUES (...), (...)`. An empty slice is
/// a no-op.
///
/// Every column from `Model::columns()` is written for every row. Missing or `None` values are
/// bound as typed NULLs.
///
/// Postgres placeholders get the same casts as `insert`: temporal types, `UUID`, `JSONB`, and
/// `...[]` arrays. The array cast was previously missing here, so array columns failed on
/// Postgres with a text-vs-array type mismatch.
///
/// NOTE(review): a large batch can exceed the backend's bind-parameter limit; callers may need to
/// chunk.
///
/// # Errors
/// Returns the database error from executing the INSERT.
pub fn batch_insert<'b>(&'b mut self, models: &'b [T]) -> BoxFuture<'b, Result<(), sqlx::Error>> {
    Box::pin(async move {
        if models.is_empty() {
            return Ok(());
        }
        let table_name = self.table_name.to_snake_case();
        let columns_info = <T as Model>::columns();
        let target_columns: Vec<String> = columns_info
            .iter()
            .map(|c| {
                let col_name_clean = c.name.strip_prefix("r#").unwrap_or(c.name).to_snake_case();
                format!("\"{}\"", col_name_clean)
            })
            .collect();
        let mut value_groups = Vec::new();
        let mut bind_index = 1;
        for _ in models {
            let mut placeholders = Vec::new();
            for col in &columns_info {
                match self.driver {
                    Drivers::Postgres => {
                        let p = if temporal::is_temporal_type(col.sql_type) {
                            format!("${}{}", bind_index, temporal::get_postgres_type_cast(col.sql_type))
                        } else {
                            match col.sql_type {
                                "UUID" => format!("${}::UUID", bind_index),
                                "JSONB" | "jsonb" => format!("${}::JSONB", bind_index),
                                s if s.ends_with("[]") => format!("${}::{}", bind_index, s),
                                _ => format!("${}", bind_index),
                            }
                        };
                        placeholders.push(p);
                        bind_index += 1;
                    }
                    _ => {
                        placeholders.push("?".to_string());
                    }
                }
            }
            value_groups.push(format!("({})", placeholders.join(", ")));
        }
        let query_str = format!(
            "INSERT INTO \"{}\" ({}) VALUES {}",
            table_name,
            target_columns.join(", "),
            value_groups.join(", ")
        );
        if self.debug_mode {
            log::debug!("SQL Batch: {}", query_str);
        }
        let mut args = AnyArguments::default();
        for model in models {
            let data_map = Model::to_map(model);
            for col in &columns_info {
                let val_opt = data_map.get(col.name);
                let sql_type = col.sql_type;
                if let Some(Some(val_str)) = val_opt {
                    // Fall back to a plain string bind if typed binding fails.
                    if args.bind_value(val_str, sql_type, &self.driver).is_err() {
                        let _ = args.add(val_str.clone());
                    }
                } else {
                    match sql_type {
                        "INTEGER" | "INT" | "INT4" | "SERIAL" => { let _ = args.add(None::<i32>); }
                        "BIGINT" | "INT8" | "BIGSERIAL" => { let _ = args.add(None::<i64>); }
                        "REAL" | "FLOAT4" => { let _ = args.add(None::<f32>); }
                        "DOUBLE PRECISION" | "FLOAT8" | "FLOAT" => { let _ = args.add(None::<f64>); }
                        "BOOLEAN" | "BOOL" => { let _ = args.add(None::<bool>); }
                        _ => { let _ = args.add(None::<String>); }
                    }
                }
            }
        }
        self.tx.execute(&query_str, args).await?;
        Ok(())
    })
}
/// Inserts `model`, or updates the existing row on a key conflict. Returns the affected-row count.
///
/// Postgres and SQLite use `ON CONFLICT (<conflict_columns>)`. When at least one
/// `update_columns` entry is present in the model map, it renders `DO UPDATE SET col = <bound
/// value>, ...`; otherwise it renders `DO NOTHING`.
///
/// Fixes over the previous version:
/// - An empty update set produced `DO UPDATE SET  NOTHING`, which is invalid SQL.
/// - An empty `conflict_columns` produced `ON CONFLICT ()`; the target is now omitted.
/// - With MySQL and an empty `update_columns`, the statement ended in a dangling
///   `ON DUPLICATE KEY UPDATE`. It now assigns the first inserted column to itself, which leaves
///   the existing row unchanged.
///
/// MySQL otherwise uses `ON DUPLICATE KEY UPDATE col = VALUES(col)`.
///
/// Postgres placeholders are cast the same way `insert` casts them, now including `...[]` arrays.
///
/// # Errors
/// Returns the database error from executing the statement.
pub fn upsert<'b>(
    &'b mut self,
    model: &'b T,
    conflict_columns: &'b [&'b str],
    update_columns: &'b [&'b str],
) -> BoxFuture<'b, Result<u64, sqlx::Error>> {
    // Postgres placeholder `$idx` with the explicit cast the column's SQL type needs.
    fn pg_placeholder(idx: usize, sql_type: &str) -> String {
        if temporal::is_temporal_type(sql_type) {
            format!("${}{}", idx, temporal::get_postgres_type_cast(sql_type))
        } else {
            match sql_type {
                "UUID" => format!("${}::UUID", idx),
                "JSONB" | "jsonb" => format!("${}::JSONB", idx),
                s if s.ends_with("[]") => format!("${}::{}", idx, s),
                _ => format!("${}", idx),
            }
        }
    }

    Box::pin(async move {
        let data_map = Model::to_map(model);
        if data_map.is_empty() {
            return Ok(0);
        }
        let table_name = self.table_name.to_snake_case();
        let columns_info = <T as Model>::columns();
        let is_postgres = matches!(self.driver, Drivers::Postgres);
        let mut target_columns = Vec::new();
        let mut bindings: Vec<(Option<String>, &str)> = Vec::new();
        for (col_name, value) in &data_map {
            let col_name_clean = col_name.strip_prefix("r#").unwrap_or(col_name).to_snake_case();
            target_columns.push(format!("\"{}\"", col_name_clean));
            let sql_type = columns_info.iter().find(|c| {
                let c_clean = c.name.strip_prefix("r#").unwrap_or(c.name);
                c_clean == *col_name || c_clean.to_snake_case() == col_name_clean
            }).map(|c| c.sql_type).unwrap_or("TEXT");
            bindings.push((value.clone(), sql_type));
        }
        let mut arg_counter = 1;
        let mut placeholders = Vec::new();
        for (_, sql_type) in &bindings {
            if is_postgres {
                placeholders.push(pg_placeholder(arg_counter, sql_type));
                arg_counter += 1;
            } else {
                placeholders.push("?".to_string());
            }
        }
        let mut query_str = format!(
            "INSERT INTO \"{}\" ({}) VALUES ({})",
            table_name,
            target_columns.join(", "),
            placeholders.join(", ")
        );
        match self.driver {
            Drivers::Postgres | Drivers::SQLite => {
                let conflict_target = if conflict_columns.is_empty() {
                    String::new()
                } else {
                    let cols = conflict_columns
                        .iter()
                        .map(|c| format!("\"{}\"", c.to_snake_case()))
                        .collect::<Vec<_>>()
                        .join(", ");
                    format!(" ({})", cols)
                };
                let mut update_clauses = Vec::new();
                let mut update_bindings = Vec::new();
                for col in update_columns {
                    let col_snake = col.to_snake_case();
                    if let Some((_key, val_opt)) = data_map.iter().find(|(k, _)| {
                        let k_clean = k.strip_prefix("r#").unwrap_or(*k);
                        k_clean == *col || k_clean.to_snake_case() == col_snake
                    }) {
                        let sql_type_opt = columns_info.iter().find(|c| {
                            let c_clean = c.name.strip_prefix("r#").unwrap_or(c.name);
                            c_clean == *col || c_clean.to_snake_case() == col_snake
                        }).map(|c| c.sql_type);
                        let sql_type = match sql_type_opt {
                            Some(t) => t,
                            None => continue,
                        };
                        let placeholder = if is_postgres {
                            let p = pg_placeholder(arg_counter, sql_type);
                            arg_counter += 1;
                            p
                        } else {
                            "?".to_string()
                        };
                        update_clauses.push(format!("\"{}\" = {}", col_snake, placeholder));
                        update_bindings.push((val_opt.clone(), sql_type));
                    }
                }
                if update_clauses.is_empty() {
                    query_str.push_str(&format!(" ON CONFLICT{} DO NOTHING", conflict_target));
                } else {
                    query_str.push_str(&format!(
                        " ON CONFLICT{} DO UPDATE SET {}",
                        conflict_target,
                        update_clauses.join(", ")
                    ));
                }
                bindings.extend(update_bindings);
            }
            Drivers::MySQL => {
                let update_clauses: Vec<String> = update_columns
                    .iter()
                    .map(|col| {
                        let col_snake = col.to_snake_case();
                        format!("\"{}\" = VALUES(\"{}\")", col_snake, col_snake)
                    })
                    .collect();
                query_str.push_str(" ON DUPLICATE KEY UPDATE ");
                if update_clauses.is_empty() {
                    // MySQL has no DO NOTHING; a self-assignment leaves the existing row unchanged.
                    if let Some(first) = target_columns.first() {
                        query_str.push_str(&format!("{} = {}", first, first));
                    }
                } else {
                    query_str.push_str(&update_clauses.join(", "));
                }
            }
        }
        if self.debug_mode {
            log::debug!("SQL Upsert: {}", query_str);
        }
        let mut args = AnyArguments::default();
        for (val_opt, sql_type) in bindings {
            if let Some(val_str) = val_opt {
                // Fall back to a plain string bind if typed binding fails.
                if args.bind_value(&val_str, sql_type, &self.driver).is_err() {
                    let _ = args.add(val_str);
                }
            } else {
                match sql_type {
                    "INTEGER" | "INT" | "INT4" | "SERIAL" => { let _ = args.add(None::<i32>); }
                    "BIGINT" | "INT8" | "BIGSERIAL" => { let _ = args.add(None::<i64>); }
                    "REAL" | "FLOAT4" => { let _ = args.add(None::<f32>); }
                    "DOUBLE PRECISION" | "FLOAT8" | "FLOAT" => { let _ = args.add(None::<f64>); }
                    "BOOLEAN" | "BOOL" => { let _ = args.add(None::<bool>); }
                    _ => { let _ = args.add(None::<String>); }
                }
            }
        }
        let result = self.tx.execute(&query_str, args).await?;
        Ok(result.rows_affected())
    })
}
/// Renders the SELECT statement this builder would run for `T`, without executing it.
///
/// Bound values are collected into a throwaway argument buffer and discarded; only the
/// SQL text is returned. This does not call `apply_soft_delete_filter`, so the text
/// omits the soft-delete predicate that `scan`/`first`/`scalar` add before executing.
pub fn to_sql(&self) -> String {
    let (mut sql, mut discarded_args, mut next_placeholder) =
        (String::new(), AnyArguments::default(), 1);
    self.write_select_sql::<T>(&mut sql, &mut discarded_args, &mut next_placeholder);
    sql
}
/// Builds the list of SELECT-list expressions for a query decoded into `R`.
///
/// Two modes, chosen by whether `R::columns()` is empty:
/// - No struct column metadata: returns `["*"]` when nothing was selected. Otherwise, on
///   Postgres, it splits each `select_columns` entry on `,`, quotes bare `table.column` /
///   `column` references (the table defaults to `get_table_identifier()`), and wraps columns
///   whose `columns_info` entry has a temporal SQL type in `to_json(...) #>> '{}' AS "col"`.
///   On other drivers it returns `select_columns` unchanged.
/// - With struct column metadata: maps each struct field to a select item via
///   `format_select_field`. It honours explicit `*` / `table.*` wildcards and manually listed
///   columns, and passes through (quoted when it is a bare identifier) anything it cannot match.
///
/// Never returns an empty vector; falls back to `["*"]`.
///
/// NOTE(review): matching is heuristic string work. It uses lower-cased ` as ` detection and
/// suffix matches on `.column`, and splits every entry on `,`, which also splits a function
/// call such as `COALESCE(a, b)` into broken items. Confirm callers only pass simple lists.
fn select_args_sql<R: AnyImpl>(&self) -> Vec<String> {
let struct_cols = R::columns();
let table_id = self.get_table_identifier();
let main_table_snake = self.table_name.to_snake_case();
// Mode 1: `R` exposes no column metadata (e.g. ad-hoc selects).
if struct_cols.is_empty() {
if self.select_columns.is_empty() { return vec!["*".to_string()]; }
if matches!(self.driver, Drivers::Postgres) {
let mut args = Vec::new();
for s in &self.select_columns {
for sub in s.split(',') {
let s_trim = sub.trim();
// Anything with a space or a call (aliases, expressions) is passed through verbatim.
if s_trim.contains(' ') || s_trim.contains('(') {
args.push(s_trim.to_string());
continue;
}
let (t, c) = if let Some((t, c)) = s_trim.split_once('.') {
(t.trim().trim_matches('"'), c.trim().trim_matches('"'))
} else {
(table_id.as_str(), s_trim.trim_matches('"'))
};
let c_snake = c.to_snake_case();
// Only the main table's `columns_info` is consulted, regardless of `t`.
let mut is_temporal = false;
if let Some(info) = self.columns_info.iter().find(|info| {
info.name.to_snake_case() == c_snake
}) {
if is_temporal_type(info.sql_type) {
is_temporal = true;
}
}
if is_temporal {
args.push(format!("to_json(\"{}\".\"{}\") #>> '{{}}' AS \"{}\"", t, c, c));
} else {
args.push(format!("\"{}\".\"{}\"", t, c));
}
}
}
return args;
}
return self.select_columns.clone();
}
// Mode 2: map struct fields onto the requested select list.
let mut flat_selects = Vec::new();
for s in &self.select_columns {
for sub in s.split(',') { flat_selects.push(sub.trim().to_string()); }
}
// Tables whose columns were requested via `*` or `table.*` (both raw and snake_case spellings).
let mut expanded_tables = HashSet::new();
for s in &flat_selects {
if s == "*" { expanded_tables.insert(table_id.clone()); expanded_tables.insert(main_table_snake.clone()); }
else if let Some(t) = s.strip_suffix(".*") { let t_clean = t.trim().trim_matches('"'); expanded_tables.insert(t_clean.to_string()); expanded_tables.insert(t_clean.to_snake_case()); }
}
// How many struct fields share each snake_case name; >1 forces a table-qualified alias.
let mut col_counts = HashMap::new();
for col_info in &struct_cols {
let col_snake = col_info.column.strip_prefix("r#").unwrap_or(col_info.column).to_snake_case();
*col_counts.entry(col_snake).or_insert(0) += 1;
}
// Tuple targets are detected by a `(` in the type name; the `{:?}` only adds surrounding quotes.
let is_tuple = format!("{:?}", std::any::type_name::<R>()).contains('(');
// Map each explicitly listed (non-wildcard) select item to the first unmatched struct field it names.
let mut matched_s_indices = HashSet::new();
let mut manual_field_map = HashMap::new();
for (f_idx, s) in flat_selects.iter().enumerate() {
if s == "*" || s.ends_with(".*") { continue; }
let s_lower = s.to_lowercase();
for (s_idx, col_info) in struct_cols.iter().enumerate() {
if matched_s_indices.contains(&s_idx) { continue; }
let col_snake = col_info.column.strip_prefix("r#").unwrap_or(col_info.column).to_snake_case();
let mut m = false;
if let Some((_, alias)) = s_lower.split_once(" as ") {
let ca = alias.trim().trim_matches('"').trim_matches('\'');
if ca == col_info.column || ca == &col_snake { m = true; }
} else if s == col_info.column || s == &col_snake || s.ends_with(&format!(".{}", col_info.column)) || s.ends_with(&format!(".{}", col_snake)) {
m = true;
}
if m { manual_field_map.insert(f_idx, s_idx); matched_s_indices.insert(s_idx); break; }
}
}
let mut args = Vec::new();
if self.select_columns.is_empty() {
// No explicit selection: emit every struct field, resolving its table through join aliases.
for (s_idx, col_info) in struct_cols.iter().enumerate() {
let mut t_use = table_id.clone();
if !col_info.table.is_empty() {
let c_snake = col_info.table.to_snake_case();
if c_snake == main_table_snake { t_use = table_id.clone(); }
else if let Some(alias) = self.join_aliases.get(&c_snake) { t_use = alias.clone(); }
else if self.join_aliases.values().any(|a| a == &col_info.table) { t_use = col_info.table.to_string(); }
}
args.push(self.format_select_field::<R>(s_idx, &t_use, &main_table_snake, &col_counts, is_tuple));
}
} else {
// Walk the explicit selection in order, expanding wildcards into the still-unmatched fields.
for (f_idx, s) in flat_selects.iter().enumerate() {
let s_trim = s.trim();
if s_trim == "*" || s_trim.ends_with(".*") {
let mut t_exp = if s_trim == "*" { String::new() } else { s_trim.strip_suffix(".*").unwrap_or(s_trim).trim().trim_matches('"').to_string() };
if !t_exp.is_empty() && (t_exp.to_snake_case() == main_table_snake || t_exp == table_id) { t_exp = table_id.clone(); }
for (s_idx, col_info) in struct_cols.iter().enumerate() {
if matched_s_indices.contains(&s_idx) { continue; }
let mut t_col = table_id.clone(); let mut known = false;
if !col_info.table.is_empty() {
let c_snake = col_info.table.to_snake_case();
if c_snake == main_table_snake { t_col = table_id.clone(); known = true; }
else if let Some(alias) = self.join_aliases.get(&c_snake) { t_col = alias.clone(); known = true; }
else if self.join_aliases.values().any(|a| a == &col_info.table) { t_col = col_info.table.to_string(); known = true; }
}
// With exactly one wildcard in the list, fields with no known table are attributed to it.
if !known && !t_exp.is_empty() && flat_selects.iter().filter(|x| x.ends_with(".*") || *x == "*").count() == 1 { t_col = t_exp.clone(); known = true; }
if (t_exp.is_empty() && known) || (!t_exp.is_empty() && t_col == t_exp) {
args.push(self.format_select_field::<R>(s_idx, &t_col, &main_table_snake, &col_counts, is_tuple));
matched_s_indices.insert(s_idx);
}
}
} else if let Some(s_idx) = manual_field_map.get(&f_idx) {
// Explicit `... AS alias` items are kept verbatim; bare names get the standard field formatting.
if s.to_lowercase().contains(" as ") { args.push(s_trim.to_string()); }
else {
let mut t = table_id.clone();
if let Some((prefix, _)) = s_trim.split_once('.') { t = prefix.trim().trim_matches('"').to_string(); }
args.push(self.format_select_field::<R>(*s_idx, &t, &main_table_snake, &col_counts, is_tuple));
}
} else {
// Unmatched item: quote bare identifiers, pass expressions through unchanged.
if !s_trim.contains(' ') && !s_trim.contains('(') {
if let Some((t, c)) = s_trim.split_once('.') { args.push(format!("\"{}\".\"{}\"", t.trim().trim_matches('"'), c.trim().trim_matches('"'))); }
else { args.push(format!("\"{}\"", s_trim.trim_matches('"'))); }
} else { args.push(s_trim.to_string()); }
}
}
}
if args.is_empty() { vec!["*".to_string()] } else { args }
}
/// Renders one struct field of `R` as a quoted, aliased SELECT-list item.
///
/// `s_idx` indexes into `R::columns()`. The alias is normally the lower-cased snake_case
/// field name. It becomes `<table>__<field>` (lower-cased) when decoding into a tuple, or
/// when `col_counts` reports that snake_case name more than once. On Postgres, temporal
/// columns are wrapped in `to_json(...) #>> '{}'`, which yields their text form.
fn format_select_field<R: AnyImpl>(&self, s_idx: usize, table_to_use: &str, main_table_snake: &str, col_counts: &HashMap<String, usize>, is_tuple: bool) -> String {
    let fields = R::columns();
    let field = &fields[s_idx];
    let column = field.column.strip_prefix("r#").unwrap_or(field.column).to_snake_case();
    let needs_qualified_alias = is_tuple || col_counts.get(&column).map_or(false, |&n| n > 1);
    let alias = if needs_qualified_alias {
        let owner = if field.table.is_empty() {
            main_table_snake.to_string()
        } else {
            field.table.to_snake_case()
        };
        format!("{}__{}", owner.to_lowercase(), column.to_lowercase())
    } else {
        column.to_lowercase()
    };
    let as_json_text = is_temporal_type(field.sql_type) && matches!(self.driver, Drivers::Postgres);
    if as_json_text {
        format!("to_json(\"{}\".\"{}\") #>> '{{}}' AS \"{}\"", table_to_use, column, alias)
    } else {
        format!("\"{}\".\"{}\" AS \"{}\"", table_to_use, column, alias)
    }
}
pub async fn scan<R>(mut self) -> Result<Vec<R>, sqlx::Error>
where
R: FromAnyRow + AnyImpl + Send + Unpin,
{
self.apply_soft_delete_filter();
let mut query = String::new();
let mut args = AnyArguments::default();
let mut arg_counter = 1;
self.write_select_sql::<R>(&mut query, &mut args, &mut arg_counter);
if self.debug_mode {
log::debug!("SQL: {}", query);
}
let rows = self.tx.fetch_all(&query, args).await?;
let mut result = Vec::with_capacity(rows.len());
for row in rows {
result.push(R::from_any_row(&row)?);
}
Ok(result)
}
pub async fn scan_with(self) -> Result<Vec<T>, sqlx::Error>
where
T: FromAnyRow + AnyImpl + crate::model::Model + Send + Unpin + 'static,
E: Connection + Clone,
{
self.scan_as_with::<T>().await
}
/// Executes the SELECT into `R`, then eager-loads every relation in `with_relations`.
///
/// Relation paths are grouped by their first `.`-separated segment. A base with one
/// non-empty nested path is loaded as `base.nested`. A base with several nested paths is
/// loaded as `base.(a|b)`. A base with none is loaded as just `base`. The modifier
/// registered under the base name (if any) is passed to `R::load_relations`.
/// Nothing is loaded when the main query returns no rows.
pub async fn scan_as_with<R>(mut self) -> Result<Vec<R>, sqlx::Error>
where
    R: FromAnyRow + AnyImpl + crate::model::Model + Send + Unpin + 'static,
    E: Clone,
{
    // Detach the eager-load configuration before `scan_as` consumes the builder.
    let relations = std::mem::take(&mut self.with_relations);
    let modifiers = std::mem::take(&mut self.with_modifiers);
    let conn = self.tx.clone();
    let mut rows: Vec<R> = self.scan_as::<R>().await?;
    if rows.is_empty() || relations.is_empty() {
        return Ok(rows);
    }
    let mut by_base: std::collections::HashMap<String, Vec<String>> =
        std::collections::HashMap::new();
    for path in relations {
        let split = path
            .split_once('.')
            .map(|(head, tail)| (head.to_string(), tail.to_string()));
        match split {
            Some((head, tail)) => by_base.entry(head).or_default().push(tail),
            None => {
                by_base.entry(path).or_default();
            }
        }
    }
    for (base, nested) in by_base {
        let modifier = modifiers.get(&base).cloned();
        let nested: Vec<String> = nested.into_iter().filter(|p| !p.is_empty()).collect();
        let full_rel = match nested.len() {
            0 => base,
            1 => format!("{}.{}", base, nested[0]),
            _ => format!("{}.({})", base, nested.join("|")),
        };
        R::load_relations(&full_rel, &mut rows, &conn, modifier).await?;
    }
    Ok(rows)
}
/// Executes the SELECT and decodes every returned row into `R`.
///
/// Applies the soft-delete filter first and logs the SQL when `debug_mode` is on.
///
/// # Errors
/// Propagates database errors from `fetch_all`; decoding stops at the first row that
/// `R::from_any_row` rejects.
pub async fn scan_as<R>(mut self) -> Result<Vec<R>, sqlx::Error>
where
    R: FromAnyRow + AnyImpl + Send + Unpin,
{
    self.apply_soft_delete_filter();
    let (mut sql, mut bound, mut next_placeholder) =
        (String::new(), AnyArguments::default(), 1);
    self.write_select_sql::<R>(&mut sql, &mut bound, &mut next_placeholder);
    if self.debug_mode {
        log::debug!("SQL: {}", sql);
    }
    self.tx
        .fetch_all(&sql, bound)
        .await?
        .into_iter()
        .map(|row| R::from_any_row(&row))
        .collect()
}
/// Fetches a single row decoded into `R`.
///
/// Applies the soft-delete filter and sets `LIMIT 1` unless a limit was already set.
/// When no ordering was requested, it orders by the model's primary-key columns ascending.
///
/// # Errors
/// Returns whatever `fetch_one` reports when no row matches, plus any decoding error
/// from `R::from_any_row`.
pub async fn first<R>(mut self) -> Result<R, sqlx::Error>
where
    R: FromAnyRow + AnyImpl + Send + Unpin,
{
    self.apply_soft_delete_filter();
    self.limit = self.limit.or(Some(1));
    if self.order_clauses.is_empty() {
        let table_id = self.get_table_identifier();
        let pk_order: Vec<String> = <T as Model>::columns()
            .iter()
            .filter(|column| column.is_primary_key)
            .map(|column| {
                let name = column.name.strip_prefix("r#").unwrap_or(column.name).to_snake_case();
                format!("\"{}\".\"{}\" ASC", table_id, name)
            })
            .collect();
        if !pk_order.is_empty() {
            self.order_clauses.push(pk_order.join(", "));
        }
    }
    let mut sql = String::new();
    let mut bound = AnyArguments::default();
    let mut next_placeholder = 1;
    self.write_select_sql::<R>(&mut sql, &mut bound, &mut next_placeholder);
    if self.debug_mode {
        log::debug!("SQL: {}", sql);
    }
    let row = self.tx.fetch_one(&sql, bound).await?;
    R::from_any_row(&row)
}
/// Fetches one row and decodes it into `O`, typically a single aggregated value.
///
/// Applies the soft-delete filter and sets `LIMIT 1` unless a limit was already set.
/// Unlike `first`, no default ordering is added.
///
/// # Errors
/// Propagates `fetch_one` errors (including the no-row case) and decoding errors.
pub async fn scalar<O>(mut self) -> Result<O, sqlx::Error>
where
    O: FromAnyRow + AnyImpl + Send + Unpin,
{
    self.apply_soft_delete_filter();
    self.limit = self.limit.or(Some(1));
    let mut sql = String::new();
    let mut bound = AnyArguments::default();
    let mut next_placeholder = 1;
    self.write_select_sql::<O>(&mut sql, &mut bound, &mut next_placeholder);
    if self.debug_mode {
        log::debug!("SQL: {}", sql);
    }
    let row = self.tx.fetch_one(&sql, bound).await?;
    O::from_any_row(&row)
}
/// Updates a single column on every row matching the builder's filters.
///
/// `value` is converted to its textual form via [`ToUpdateValue`]; `None` becomes SQL NULL.
/// Resolves to the number of affected rows.
pub fn update<'b, V>(&'b mut self, col: &str, value: V) -> BoxFuture<'b, Result<u64, sqlx::Error>>
where
    V: ToUpdateValue + Send + Sync,
{
    let single: std::collections::HashMap<String, Option<String>> =
        std::iter::once((col.to_string(), value.to_update_value())).collect();
    self.execute_update(single)
}
pub fn updates<'b>(&'b mut self, model: &T) -> BoxFuture<'b, Result<u64, sqlx::Error>> {
self.execute_update(Model::to_map(model))
}
/// Updates the columns produced by a partial struct's `AnyImpl::to_map` on all matching rows.
pub fn update_partial<'b, P: AnyImpl>(&'b mut self, partial: &P) -> BoxFuture<'b, Result<u64, sqlx::Error>> {
    let values = <P as AnyImpl>::to_map(partial);
    self.execute_update(values)
}
/// Sets `col` to a raw SQL expression, e.g. `update_raw("views", "views + ?", 1)`.
///
/// Every `?` in `expr` stands for `value`:
/// - on Postgres all occurrences are rewritten to the same `$n` placeholder and the value
///   is bound once;
/// - on other drivers the `?` markers are left as-is and the value is bound once per marker.
///
/// When `expr` contains no `?`, `value` is not bound at all. The value's arguments precede
/// the WHERE-clause arguments, matching their placeholder order in the SQL.
///
/// NOTE(review): `?` is matched textually, so a `?` inside a string literal or a Postgres
/// JSONB operator (`?`, `?|`, `?&`) in `expr` is also treated as a placeholder.
pub fn update_raw<'b, V>(
    &'b mut self,
    col: &str,
    expr: &str,
    value: V,
) -> BoxFuture<'b, Result<u64, sqlx::Error>>
where
    V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
{
    self.apply_soft_delete_filter();
    let col_name_clean = col.strip_prefix("r#").unwrap_or(col).to_snake_case();
    let expr_owned = expr.to_string();
    Box::pin(async move {
        let table_name = self.table_name.to_snake_case();
        let mut query = format!("UPDATE \"{}\" ", table_name);
        if let Some(alias) = &self.alias {
            query.push_str(&format!("AS {} ", alias));
        }
        query.push_str("SET ");
        let mut arg_counter = 1;
        let mut args = AnyArguments::default();
        let placeholder_count = expr_owned.matches('?').count();
        let processed_expr = if placeholder_count == 0 {
            expr_owned
        } else if matches!(self.driver, Drivers::Postgres) {
            // Postgres can reference one numbered parameter many times: bind once, reuse `$n`.
            let placeholder = format!("${}", arg_counter);
            arg_counter += 1;
            let _ = args.add(value);
            expr_owned.replace('?', &placeholder)
        } else {
            // Positional `?` drivers consume one argument per marker.
            for _ in 1..placeholder_count {
                let _ = args.add(value.clone());
            }
            let _ = args.add(value);
            expr_owned
        };
        query.push_str(&format!("\"{}\" = {}", col_name_clean, processed_expr));
        query.push_str(" WHERE 1=1");
        for clause in &self.where_clauses {
            clause(&mut query, &mut args, &self.driver, &mut arg_counter);
        }
        if self.debug_mode {
            log::debug!("SQL: {}", query);
        }
        let result = self.tx.execute(&query, args).await?;
        Ok(result.rows_affected())
    })
}
/// Appends an `AND "<col>" IS NULL` WHERE filter for the model's soft-delete column.
///
/// Does nothing when `with_deleted` is set or when no entry in `columns_info` is marked
/// `soft_delete`. The column name is emitted unqualified and as stored in `columns_info`.
/// Each call pushes a new closure, so calling it repeatedly on the same builder repeats
/// the predicate.
fn apply_soft_delete_filter(&mut self) {
    if self.with_deleted {
        return;
    }
    let marker = self
        .columns_info
        .iter()
        .find(|info| info.soft_delete)
        .map(|info| info.name.to_string());
    if let Some(column) = marker {
        let predicate: FilterFn = Box::new(move |sql, _args, _driver, _next| {
            sql.push_str(&format!(" AND \"{}\" IS NULL", column));
        });
        self.where_clauses.push(predicate);
    }
}
/// Runs `UPDATE <table> SET col = value, ... WHERE <filters>` for the given column map.
///
/// Keys may be raw field names (optionally `r#`-prefixed) or snake_case column names. Keys
/// that match no entry in `columns_info` are skipped. If nothing remains, the future
/// resolves to `Ok(0)` without touching the database. On Postgres, placeholders get the
/// same type casts used elsewhere in this builder (temporal, `UUID`, `JSONB`, array types).
/// `None` values are bound as typed NULLs chosen from the column's SQL type.
///
/// The soft-delete filter is applied eagerly, when this method is called rather than when
/// the future is awaited.
fn execute_update<'b>(
    &'b mut self,
    data_map: std::collections::HashMap<String, Option<String>>,
) -> BoxFuture<'b, Result<u64, sqlx::Error>> {
    self.apply_soft_delete_filter();
    Box::pin(async move {
        let table_name = self.table_name.to_snake_case();
        let mut query = format!("UPDATE \"{}\" ", table_name);
        if let Some(alias) = &self.alias {
            // SQLite's UPDATE grammar requires `AS` before a table alias; Postgres and
            // MySQL accept it too (and `update_raw` already emits it).
            query.push_str(&format!("AS {} ", alias));
        }
        query.push_str("SET ");
        let mut bindings: Vec<(Option<String>, &str)> = Vec::new();
        let mut set_clauses = Vec::new();
        let mut arg_counter = 1;
        for (col_name, value) in data_map {
            let col_name_clean = col_name.strip_prefix("r#").unwrap_or(&col_name).to_snake_case();
            // Same matching rule as the upsert path: also accept `columns_info` names that
            // carry an `r#` prefix or differ only by snake_case normalisation.
            let sql_type_opt = self
                .columns_info
                .iter()
                .find(|c| {
                    let c_clean = c.name.strip_prefix("r#").unwrap_or(c.name);
                    c.name == col_name
                        || c.name == col_name_clean
                        || c_clean.to_snake_case() == col_name_clean
                })
                .map(|c| c.sql_type);
            let sql_type = match sql_type_opt {
                Some(t) => t,
                None => continue,
            };
            let placeholder = match self.driver {
                Drivers::Postgres => {
                    let idx = arg_counter;
                    arg_counter += 1;
                    if temporal::is_temporal_type(sql_type) {
                        format!("${}{}", idx, temporal::get_postgres_type_cast(sql_type))
                    } else {
                        match sql_type {
                            "UUID" => format!("${}::UUID", idx),
                            "JSONB" | "jsonb" => format!("${}::JSONB", idx),
                            s if s.ends_with("[]") => format!("${}::{}", idx, s),
                            _ => format!("${}", idx),
                        }
                    }
                }
                _ => "?".to_string(),
            };
            set_clauses.push(format!("\"{}\" = {}", col_name_clean, placeholder));
            bindings.push((value, sql_type));
        }
        if set_clauses.is_empty() {
            return Ok(0);
        }
        query.push_str(&set_clauses.join(", "));
        query.push_str(" WHERE 1=1");
        let mut args = AnyArguments::default();
        // SET values are bound before WHERE values, matching their placeholder order.
        for (val_opt, sql_type) in bindings {
            if let Some(val_str) = val_opt {
                if args.bind_value(&val_str, sql_type, &self.driver).is_err() {
                    let _ = args.add(val_str);
                }
            } else {
                match sql_type {
                    "INTEGER" | "INT" | "INT4" | "SERIAL" => { let _ = args.add(None::<i32>); }
                    "BIGINT" | "INT8" | "BIGSERIAL" => { let _ = args.add(None::<i64>); }
                    "REAL" | "FLOAT4" => { let _ = args.add(None::<f32>); }
                    "DOUBLE PRECISION" | "FLOAT8" | "FLOAT" => { let _ = args.add(None::<f64>); }
                    "BOOLEAN" | "BOOL" => { let _ = args.add(None::<bool>); }
                    _ => { let _ = args.add(None::<String>); }
                }
            }
        }
        for clause in &self.where_clauses {
            clause(&mut query, &mut args, &self.driver, &mut arg_counter);
        }
        if self.debug_mode {
            log::debug!("SQL: {}", query);
        }
        let result = self.tx.execute(&query, args).await?;
        Ok(result.rows_affected())
    })
}
pub async fn delete(self) -> Result<u64, sqlx::Error> {
let soft_delete_col = self.columns_info.iter().find(|c| c.soft_delete).map(|c| c.name);
if let Some(col) = soft_delete_col {
let table_name = self.table_name.to_snake_case();
let mut query = format!("UPDATE \"{}\" ", table_name);
if let Some(alias) = &self.alias {
query.push_str(&format!("{} ", alias));
}
query.push_str(&format!("SET \"{}\" = ", col));
match self.driver {
Drivers::Postgres => query.push_str("NOW()"),
Drivers::SQLite => query.push_str("strftime('%Y-%m-%dT%H:%M:%SZ', 'now')"),
Drivers::MySQL => query.push_str("NOW()"),
}
query.push_str(" WHERE 1=1");
let mut args = AnyArguments::default();
let mut arg_counter = 1;
for clause in &self.where_clauses {
clause(&mut query, &mut args, &self.driver, &mut arg_counter);
}
if self.debug_mode {
log::debug!("SQL: {}", query);
}
let result = self.tx.execute(&query, args).await?;
Ok(result.rows_affected())
} else {
let mut query = String::from("DELETE FROM \"");
query.push_str(&self.table_name.to_snake_case());
query.push_str("\" WHERE 1=1");
let mut args = AnyArguments::default();
let mut arg_counter = 1;
for clause in &self.where_clauses {
clause(&mut query, &mut args, &self.driver, &mut arg_counter);
}
if self.debug_mode {
log::debug!("SQL: {}", query);
}
let result = self.tx.execute(&query, args).await?;
Ok(result.rows_affected())
}
}
/// Permanently deletes every row matching the accumulated WHERE clauses.
///
/// This never takes a soft-delete path and does not add the soft-delete filter, so rows
/// that were already soft-deleted are removed as well. The builder's alias is not emitted.
/// Returns the number of affected rows.
pub async fn hard_delete(self) -> Result<u64, sqlx::Error> {
    let mut sql = format!("DELETE FROM \"{}\" WHERE 1=1", self.table_name.to_snake_case());
    let mut bound = AnyArguments::default();
    let mut next_placeholder = 1;
    self.where_clauses
        .iter()
        .for_each(|filter| filter(&mut sql, &mut bound, &self.driver, &mut next_placeholder));
    if self.debug_mode {
        log::debug!("SQL: {}", sql);
    }
    let outcome = self.tx.execute(&sql, bound).await?;
    Ok(outcome.rows_affected())
}
}