use crate::{
Compatible, FromRow, HasDialect, LowerCompatible, Qrafting, Query, QueryOf, RpnInstr, TypeMeta,
alias::Aliased,
cte::IntoCtes,
emitter::Emitter,
expression::{As, Binary, Column, op},
impl_for_all_tuples,
lower::{Instructions, LowerCtx},
param::encode_param,
query::{LowerFilter, LowerFrom, LowerProject, Table, TypedCompiled, rewrite_params},
relation::PersistedField,
};
/// Anything that can be the target table of an `UPDATE` statement.
///
/// Implemented for plain [`Table`]s (no alias) and for aliased tables.
pub trait UpdateTable<M> {
    /// SQL name of the table being updated.
    fn table_name(&self) -> &'static str;
    /// Alias to render after the table name, if any.
    fn table_alias(&self) -> Option<&'static str>;
}

impl<M> UpdateTable<M> for Table<M> {
    fn table_name(&self) -> &'static str {
        self.name()
    }

    /// A bare table never carries an alias.
    fn table_alias(&self) -> Option<&'static str> {
        None
    }
}

impl<M> UpdateTable<M> for Aliased<Table<M>> {
    /// The name comes from the wrapped table, via its own [`UpdateTable`] impl.
    fn table_name(&self) -> &'static str {
        self.inner.table_name()
    }

    fn table_alias(&self) -> Option<&'static str> {
        Some(self.alias)
    }
}
/// Initial state of the UPDATE builder returned by [`update`].
///
/// It knows the target table and its optional alias but has no SET clause yet.
/// Call [`UpdateFrom::set`] to obtain an [`Update`].
pub struct UpdateFrom<M> {
    /// SQL name of the target table.
    table: &'static str,
    /// Optional alias rendered after the table name.
    alias: Option<&'static str>,
    /// Query state (e.g. CTEs added via `with`) that `set` carries into the statement.
    query: Query,
    marker: std::marker::PhantomData<M>,
}

/// Starts an `UPDATE` of `table`, which may be a plain [`Table`] or an aliased one.
///
/// The builder starts with an empty (`Query::default()`) query.
pub fn update<M, T>(table: T) -> UpdateFrom<M>
where
    T: UpdateTable<M>,
{
    let (name, alias) = (table.table_name(), table.table_alias());
    UpdateFrom {
        table: name,
        alias,
        query: Query::default(),
        marker: std::marker::PhantomData,
    }
}
/// Builds an [`Update`] of `table` whose SET clause is produced imperatively by `build`.
///
/// `build` receives an [`UpdateVisitor`] that lowers assignments into a fresh
/// instruction list. Any parameters and data it produces go into `query`.
/// Unlike [`UpdateFrom::set`], the assigned columns can be decided at runtime,
/// and an already-prepared `query` is used as-is.
pub fn build_dynamic_update<M, T, F>(table: T, mut query: Query, build: F) -> Update<M>
where
    T: UpdateTable<M>,
    F: for<'a> FnOnce(&mut UpdateVisitor<'a>),
{
    let mut assignments = Vec::new();
    {
        let mut lowering = LowerCtx {
            instrs: &mut assignments,
            params: &mut query.params,
            data: &mut query.data,
        };
        let mut visitor = UpdateVisitor::new(&mut lowering);
        build(&mut visitor);
        // `visitor` is dropped at the end of this scope. Its `Drop` impl pushes the
        // separator covering every assignment `build` recorded, and that must happen
        // before `assignments` is moved into the result below.
    }
    Update {
        table: table.table_name(),
        alias: table.table_alias(),
        set: assignments,
        from: Vec::new(),
        returning: Vec::new(),
        query,
        marker: std::marker::PhantomData,
    }
}
/// A complete `UPDATE` statement.
///
/// It holds the target table, the lowered SET / FROM / RETURNING instruction
/// streams, and the shared [`Query`]. That query carries filters, CTEs,
/// parameters and data.
pub struct Update<M> {
/// SQL name of the table being updated.
pub table: &'static str,
/// Optional alias rendered after the table name (e.g. `"users" as "u"`).
pub alias: Option<&'static str>,
/// Lowered `column = value` assignments of the SET clause.
pub set: Vec<RpnInstr>,
/// Lowered FROM clause; empty unless [`Update::from`] was called.
pub from: Vec<RpnInstr>,
/// Lowered RETURNING projection; empty unless [`Update::returning`] was called.
pub returning: Vec<RpnInstr>,
/// Filters, CTEs, parameters and data shared by every clause of the statement.
pub query: Query,
pub marker: std::marker::PhantomData<M>,
}
/// An [`Update`] with a RETURNING projection.
///
/// `T` is the row type that compiled statements are typed with
/// (`TypedCompiled<T>`). It starts as the model type and can be changed with
/// [`ReturningUpdate::typed`].
pub struct ReturningUpdate<M, T> {
pub update: Update<M>,
pub marker: std::marker::PhantomData<T>,
}
impl<M> Update<M> {
    /// Adds `ctes` to the statement's `WITH` clause (delegates to `Query::with`).
    pub fn with<C>(self, ctes: C) -> Self
    where
        C: IntoCtes,
    {
        Self {
            query: self.query.with(ctes),
            ..self
        }
    }

    /// Like [`Update::with`], but delegates to `Query::with_recursive`
    /// (presumably rendered as `WITH RECURSIVE`).
    pub fn with_recursive<C>(self, ctes: C) -> Self
    where
        C: IntoCtes,
    {
        Self {
            query: self.query.with_recursive(ctes),
            ..self
        }
    }

    /// Lowers `from` into the statement's `FROM` clause.
    ///
    /// Any parameters and data it produces are appended to this statement's query.
    pub fn from<F>(mut self, from: F) -> Self
    where
        F: LowerFrom,
    {
        from.lower_from(&mut LowerCtx {
            instrs: &mut self.from,
            params: &mut self.query.params,
            data: &mut self.query.data,
        });
        self
    }

    /// Lowers `project` into the RETURNING clause.
    ///
    /// The result is initially typed with the model `M`. Use
    /// [`ReturningUpdate::typed`] to decode the returned rows as another type.
    pub fn returning<P>(mut self, project: P) -> ReturningUpdate<M, M>
    where
        P: LowerProject,
        M: FromRow,
    {
        project.lower_project(&mut LowerCtx {
            instrs: &mut self.returning,
            params: &mut self.query.params,
            data: &mut self.query.data,
        });
        ReturningUpdate {
            update: self,
            marker: std::marker::PhantomData,
        }
    }

    /// ORs `e` with the existing WHERE condition (delegates to `Query::or_filter`).
    pub fn or_filter<E>(self, e: E) -> Self
    where
        E: LowerFilter,
    {
        Self {
            query: self.query.or_filter(e),
            ..self
        }
    }

    /// ANDs `e` onto the WHERE condition (delegates to `Query::filter`).
    pub fn filter<E>(self, e: E) -> Self
    where
        E: LowerFilter,
    {
        Self {
            query: self.query.filter(e),
            ..self
        }
    }

    /// Renders the statement for dialect `D` and packages the SQL with its params and data.
    ///
    /// # Panics
    /// Panics if SQL emission fails (see [`Update::to_sql`]).
    pub fn into_compiled<D: HasDialect>(mut self) -> TypedCompiled<M> {
        // Render first: `to_sql` rewrites `query.params`, which are moved out below.
        let rendered = self.to_sql::<D>();
        let query = self.query;
        TypedCompiled {
            sql: rendered,
            params: query.params,
            data: query.data,
            marker: std::marker::PhantomData,
        }
    }

    /// Renders the statement for `D` and appends a parameter listing via
    /// `Query::debug_params` (e.g. `...; params=[10]`).
    ///
    /// # Panics
    /// Panics if emission or `debug_params` fails.
    pub fn to_debug_sql<D: HasDialect>(&mut self) -> String {
        let mut rendered = self.to_sql::<D>();
        self.query.debug_params(&mut rendered).unwrap();
        rendered
    }

    /// Renders the statement as SQL for dialect `D`.
    ///
    /// Takes `&mut self` because `rewrite_params` updates `self.query.params`
    /// using the parameter indexes the emitter collected.
    /// NOTE(review): confirm `rewrite_params` is idempotent. If it is not,
    /// rendering the same statement twice would rewrite its params twice.
    ///
    /// # Panics
    /// Panics if `Emitter::emit_update` returns an error.
    pub fn to_sql<D: HasDialect>(&mut self) -> String {
        let mut sql = String::new();
        let (mut directives, mut param_indexes) = (Vec::new(), Vec::new());
        Emitter::new(
            &mut sql,
            &self.query.data,
            D::DIALECT,
            &mut directives,
            &mut param_indexes,
        )
        .emit_update(self)
        .unwrap();
        rewrite_params(&param_indexes, &mut self.query.params);
        sql
    }
}
impl<M, T> ReturningUpdate<M, T>
where
    T: FromRow,
{
    /// Re-types the rows returned by this statement as `R`. The SQL is unchanged.
    pub fn typed<R>(self) -> ReturningUpdate<M, R>
    where
        R: FromRow,
    {
        ReturningUpdate {
            update: self.update,
            marker: std::marker::PhantomData,
        }
    }

    /// Renders the statement for dialect `D` and packages the SQL with its params
    /// and data, typed by the returned row type `T`.
    ///
    /// # Panics
    /// Panics if SQL emission fails (see [`Update::to_sql`]).
    pub fn into_compiled<D: HasDialect>(mut self) -> TypedCompiled<T> {
        // Render first: rendering rewrites `query.params`, which are moved out below.
        let sql = self.update.to_sql::<D>();
        TypedCompiled {
            sql,
            params: self.update.query.params,
            data: self.update.query.data,
            marker: std::marker::PhantomData,
        }
    }

    /// Renders the statement for `D` with a trailing parameter listing.
    ///
    /// This delegates to [`Update::to_debug_sql`] so both builders share one
    /// implementation.
    pub fn to_debug_sql<D: HasDialect>(&mut self) -> String {
        self.update.to_debug_sql::<D>()
    }

    /// Renders the statement as SQL for dialect `D`.
    ///
    /// This delegates to [`Update::to_sql`]. The RETURNING clause is part of the
    /// wrapped `Update`, so emission is identical and cannot drift apart.
    pub fn to_sql<D: HasDialect>(&mut self) -> String {
        self.update.to_sql::<D>()
    }
}
/// Receives the `column = value` assignments of an UPDATE's SET clause.
///
/// [`Updatable`] implementations drive a visitor through this trait. The
/// implementor in this file is [`UpdateVisitor`].
pub trait VisitUpdate {
/// Records an assignment of `param` to `column`. `param` is lowered as an
/// expression compatible with the column type `T`.
fn visit_set<T, E>(&mut self, column: &'static str, param: E)
where
T: TypeMeta,
E: LowerCompatible<T>;
/// Records an assignment of `value` to `column`. [`UpdateVisitor`] encodes
/// `value` as a bound parameter.
fn visit_typed_set<T>(&mut self, column: &'static str, value: &(impl Compatible<T> + ?Sized))
where
T: TypeMeta;
}
/// [`VisitUpdate`] implementation that lowers each assignment into a
/// [`LowerCtx`] and counts the assignments recorded.
///
/// NOTE(review): `&'a mut LowerCtx<'a>` ties the borrow of the context to the
/// context's own lifetime parameter. The context therefore stays mutably
/// borrowed for as long as its inner references live. Callers in this file drop
/// the visitor explicitly before moving the borrowed buffers. A second lifetime
/// (`&'a mut LowerCtx<'b>`) would be more flexible, but would change the public type.
pub struct UpdateVisitor<'a> {
/// Lowering context the assignments are written into.
ctx: &'a mut LowerCtx<'a>,
/// Number of assignments recorded so far; read by the `Drop` impl.
count: usize,
}
impl<'a> UpdateVisitor<'a> {
pub fn new(ctx: &'a mut LowerCtx<'a>) -> Self {
Self { ctx, count: 0 }
}
}
impl<'a> Drop for UpdateVisitor<'a> {
fn drop(&mut self) {
self.ctx.instrs.push_seperated(self.count);
}
}
impl<'a> VisitUpdate for UpdateVisitor<'a> {
    /// Lowers the unqualified column and then the value expression, then pushes
    /// an assignment joining them.
    fn visit_set<T, E>(&mut self, column: &'static str, param: E)
    where
        T: TypeMeta,
        E: LowerCompatible<T>,
    {
        // SET targets are lowered without a table qualifier (`None`).
        let target = self.ctx.lower_column(None, column);
        let source = param.lower_compatible(self.ctx);
        self.ctx.instrs.push_assignment(target, source);
        self.count += 1;
    }

    /// Lowers the unqualified column, encodes `value` into the context's data
    /// buffer as a parameter, and pushes an assignment to that parameter.
    fn visit_typed_set<T>(&mut self, column: &'static str, value: &(impl Compatible<T> + ?Sized))
    where
        T: TypeMeta,
    {
        let target = self.ctx.lower_column(None, column);
        let encoded = encode_param(value, self.ctx.data);
        let placeholder = self.ctx.lower_param(encoded);
        self.ctx.instrs.push_assignment(target, placeholder);
        self.count += 1;
    }
}
/// A value that contributes one or more `column = value` assignments to the
/// SET clause of an UPDATE on model `M`.
pub trait Updatable<M> {
    /// Feeds this value's assignments to `visitor`.
    fn update<V: VisitUpdate>(self, visitor: &mut V);
}

/// `column.eq(value)` used as a SET assignment: `column = value`.
impl<M, T, R> Updatable<M> for Binary<op::Eq, Column<M, T>, R>
where
    T: TypeMeta,
    R: LowerCompatible<T>,
{
    fn update<V: VisitUpdate>(self, visitor: &mut V) {
        let column = self.left.name;
        visitor.visit_set::<T, R>(column, self.right);
    }
}

/// `field.eq(value)` on a persisted model field, used as a SET assignment.
impl<M, Value, T, K, R> Updatable<M> for Binary<op::Eq, PersistedField<M, Value, T, K>, R>
where
    T: TypeMeta,
    R: LowerCompatible<T>,
{
    fn update<V: VisitUpdate>(self, visitor: &mut V) {
        let column = self.left.name();
        visitor.visit_set::<T, R>(column, self.right);
    }
}

/// An `As`-wrapped column compared with `eq`, used as a SET assignment.
///
/// Only the inner column's name is used; SET targets are unqualified.
impl<M, T, R> Updatable<M> for Binary<op::Eq, As<Column<M, T>>, R>
where
    T: TypeMeta,
    R: LowerCompatible<T>,
{
    fn update<V: VisitUpdate>(self, visitor: &mut V) {
        let column = self.left.inner.name;
        visitor.visit_set::<T, R>(column, self.right);
    }
}

/// Lowers a SET clause directly into a [`LowerCtx`].
///
/// NOTE(review): no implementations or uses are visible in this file section.
pub trait LowerSet<M> {
    fn lower_update(self, ctx: &mut LowerCtx);
}
impl<M> UpdateFrom<M> {
    /// Adds `ctes` to the pending statement's `WITH` clause (delegates to `Query::with`).
    pub fn with<C>(mut self, ctes: C) -> Self
    where
        C: IntoCtes,
    {
        self.query = self.query.with(ctes);
        self
    }

    /// Like [`UpdateFrom::with`], but delegates to `Query::with_recursive`.
    pub fn with_recursive<C>(mut self, ctes: C) -> Self
    where
        C: IntoCtes,
    {
        self.query = self.query.with_recursive(ctes);
        self
    }

    /// Lowers `set` against an externally supplied `query`.
    ///
    /// This builder's own query is discarded. `QueryOf::update` uses this to
    /// reuse the query it has already accumulated.
    fn set_query<S>(self, set: S, query: Query) -> Update<M>
    where
        S: Updatable<M>,
    {
        Self::build_update(self.table, self.alias, set, query)
    }

    /// Lowers the SET assignments in `set` and produces the [`Update`] statement.
    ///
    /// The statement uses this builder's query, including any CTEs added via `with`.
    pub fn set<S>(self, set: S) -> Update<M>
    where
        S: Updatable<M>,
    {
        let Self {
            table,
            alias,
            query,
            ..
        } = self;
        Self::build_update(table, alias, set, query)
    }

    /// Shared implementation of `set` and `set_query`.
    ///
    /// It lowers `set` into SET-clause instructions, with parameters and data
    /// appended to `query`, and assembles the resulting [`Update`].
    fn build_update<S>(
        table: &'static str,
        alias: Option<&'static str>,
        set: S,
        mut query: Query,
    ) -> Update<M>
    where
        S: Updatable<M>,
    {
        let mut instrs = Vec::new();
        let mut ctx = LowerCtx {
            instrs: &mut instrs,
            params: &mut query.params,
            data: &mut query.data,
        };
        let mut visitor = UpdateVisitor::new(&mut ctx);
        set.update(&mut visitor);
        // `UpdateVisitor::drop` pushes the separator for the collected assignments,
        // so it must run before `instrs` is moved into the result.
        drop(visitor);
        Update {
            table,
            alias,
            set: instrs,
            from: Vec::new(),
            returning: Vec::new(),
            query,
            marker: std::marker::PhantomData,
        }
    }
}
/// Implements [`Updatable`] for tuples of [`Updatable`] values.
///
/// This lets several assignments be passed to `set` at once, e.g.
/// `(id.eq(10), name.eq("test"))`. Elements are visited left to right, which
/// fixes the order of the SET assignments and of their bound parameters.
macro_rules! impl_set_macro {
($($T:ident),+) => {
impl<M, $($T,)+> Updatable<M> for ($($T,)+)
where
$($T: Updatable<M>,)+
{
fn update<V: VisitUpdate>(self, visitor: &mut V) {
// The type-parameter identifiers double as binding names for the
// destructured tuple elements, hence the lint allowance.
#[allow(non_snake_case)]
let ($($T,)+) = self;
$(
$T.update(visitor);
)+
}
}
};
}
// NOTE(review): `impl_for_all_tuples!` presumably invokes `impl_set_macro` once
// per supported tuple arity; confirm the maximum arity at its definition.
impl_for_all_tuples!(impl_set_macro);
impl<M> QueryOf<M>
where
    M: Qrafting,
{
    /// Turns this query into an `UPDATE` of `M::TABLE` with the SET assignments in `set`.
    ///
    /// The query accumulated by this builder becomes the statement's query.
    pub fn update<S>(self, set: S) -> Update<M>
    where
        S: Updatable<M>,
    {
        update(Table::<M>::new(M::TABLE)).set_query(set, self.query)
    }
}
// Rendering tests for the UPDATE builder. Each test compares the debug SQL
// (statement plus `; params=[...]`) against an exact expected string.
#[cfg(test)]
mod tests {
use super::*;
use crate::{
FromRow, Postgres, Qrafting, Sqlite, Timestamp,
alias::Alias,
cte::{CteDefinition, with},
expression::{Column, EqExt, PredicateExt, now},
query::{Order, Table, select},
tests::{id, name, table, username},
};
/// Row type for the RETURNING clause in the CTE test, decoded by column name.
// The fields are populated by `from_row` but never read by these tests.
#[allow(dead_code)]
#[derive(Debug)]
struct ClaimedUser {
pub id: i64,
pub username: String,
}
impl FromRow for ClaimedUser {
fn from_row(row: &crate::quex::Row) -> crate::quex::Result<Self> {
Ok(Self {
id: row.get("id")?,
username: row.get("username")?,
})
}
}
// Lower-case to read like the column constants imported from `crate::tests`.
#[allow(non_upper_case_globals)]
const updated_at: Column<crate::tests::User, Timestamp> = Column::new("updated_at");
/// Minimal model standing in for the `claim` CTE's rows.
struct Claim;
impl Qrafting for Claim {
type Schema = ();
type QueryPolicy = crate::DefaultQueryPolicy<Self>;
const FIELD_COUNT: usize = 2;
const TABLE: &'static str = "claim";
}
// A single assignment renders as an unqualified SET column with one bound param.
#[test]
fn test_update_simple() {
let stmt = update(table).set(id.eq(10)).to_debug_sql::<Sqlite>();
assert_eq!(stmt, r#"update "users" set "id" = ?; params=[10]"#);
}
// Tuple assignments keep their order, and consecutive `filter`s are ANDed.
// Params follow SET order, then WHERE order.
#[test]
fn test_update_filters() {
let stmt = update(table)
.set((id.eq(10), name.eq("test"), username.eq("username")))
.filter(id.eq(1).or(id.eq(2)))
.filter(name.eq("test"))
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"update "users" set "id" = ?, "name" = ?, "username" = ? where ("users"."id" = ? or "users"."id" = ?) and "users"."name" = ?; params=[10, "test", "username", 1, 2, "test"]"#
);
}
// `or_filter` ORs with the condition already present.
#[test]
fn test_update_or_filter_combines_with_existing_where_clause() {
let stmt = update(table)
.set(username.eq("renamed"))
.filter(id.eq(1))
.or_filter(id.eq(2))
.to_debug_sql::<Sqlite>();
assert_eq!(
stmt,
r#"update "users" set "username" = ? where "users"."id" = ? or "users"."id" = ?; params=["renamed", 1, 2]"#
);
}
// Full Postgres statement: CTE with locking, aliased target, FROM, `now()` as
// `current_timestamp`, and typed RETURNING. Placeholders are numbered, and the
// CTE's params come first.
#[test]
fn test_update_with_cte_from_alias_returning_and_now() {
let claim = CteDefinition::<Claim>::new(
"claim",
&[],
select((id, username))
.from(table)
.filter(id.eq(1))
.order_by(id.asc())
.lock_for_update()
.skip_locked()
.limit(10_i64),
);
let u = table.alias("u");
let c = Table::<Claim>::new("claim").alias("c");
let stmt = with(claim)
.update(u)
.set(u.col(updated_at).eq(now()))
.from(c)
.filter(u.col(id).eq(c.col(id)))
.returning((c.col(id).alias("id"), c.col(username).alias("username")))
.typed::<ClaimedUser>()
.to_debug_sql::<Postgres>();
assert_eq!(
stmt,
r#"with "claim" as (select "users"."id", "users"."username" from "users" where "users"."id" = $1 order by "users"."id" asc limit $2 for update skip locked) update "users" as "u" set "updated_at" = current_timestamp from "claim" as "c" where "u"."id" = "c"."id" returning "c"."id" as "id", "c"."username" as "username"; params=[1, 10]"#
);
}
}