use super::SqlGenContext;
use super::common::{value_expr_to_expr, where_to_expr_validated};
use crate::QError;
use dibs_query_schema::Update;
use dibs_sql::{ColumnName, ParamName, UpdateStmt, render};
/// Output of [`generate_update_sql`]: the rendered statement plus the metadata produced alongside it.
#[derive(Debug, Clone)]
pub struct GeneratedUpdate {
/// SQL text, taken from the `sql` field of the `render` result.
pub sql: String,
/// Parameter names, taken from the `params` field of the `render` result.
/// NOTE(review): presumably listed in placeholder order — confirm against `dibs_sql::render`.
pub params: Vec<ParamName>,
/// Column names listed in the query's `returning` clause; empty when the query has none.
pub returning_columns: Vec<ColumnName>,
}
/// Builds and renders the `UPDATE` statement described by a parsed `@update` declaration.
///
/// The statement is assembled in three steps:
/// 1. one `SET` assignment per entry of `update.set.columns`,
/// 2. the `WHERE` expression, when a `where` clause is present and
///    `where_to_expr_validated` yields an expression for it,
/// 3. one `RETURNING` entry per key of the `returning` clause's column map.
///
/// The finished statement is passed to [`render`]; its SQL text and parameter
/// list are returned together with the collected returning columns.
///
/// # Errors
///
/// Propagates any error returned by `where_to_expr_validated`.
pub fn generate_update_sql(
    ctx: &SqlGenContext,
    update: &Update,
) -> Result<GeneratedUpdate, QError> {
    let mut stmt = UpdateStmt::new(update.table.value.clone());

    // SET: translate every assigned value into an expression for its column.
    for (column, value) in &update.set.columns {
        let target = &column.value;
        let assignment = value_expr_to_expr(target, value, update.params.as_ref());
        stmt = stmt.set(target.clone(), assignment);
    }

    // WHERE: validation may fail (propagated) or produce no expression at all.
    let filter = match &update.where_clause {
        Some(clause) => where_to_expr_validated(ctx, clause)?,
        None => None,
    };
    if let Some(condition) = filter {
        stmt = stmt.where_(condition);
    }

    // RETURNING: keep the column list so it can be reported to the caller as well.
    let returning_columns: Vec<ColumnName> = match &update.returning {
        Some(clause) => clause.columns.keys().map(|key| key.value.clone()).collect(),
        None => Vec::new(),
    };
    for column in returning_columns.iter().cloned() {
        stmt = stmt.returning([column]);
    }

    let rendered = render(&stmt);
    Ok(GeneratedUpdate {
        sql: rendered.sql,
        params: rendered.params,
        returning_columns,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parse_query_file;
use dibs_db_schema::Schema;
fn get_first_update(source: &str) -> (Update, crate::QSource) {
let (file, qsource) = parse_query_file(camino::Utf8Path::new("<test>"), source).unwrap();
for (_, decl) in file.0.iter() {
if let dibs_query_schema::Decl::Update(u) = decl {
return (u.clone(), (*qsource).clone());
}
}
panic!("No update found in source");
}
#[test]
fn test_simple_update() {
let source = r#"
UpdateUserEmail @update{
params {id @uuid, email @string}
table users
set {email $email}
where {id $id}
returning {id, email}
}
"#;
let (update, qsource) = get_first_update(source);
let schema = Schema::default();
let ctx = SqlGenContext::new(&schema, std::sync::Arc::new(qsource));
let result = generate_update_sql(&ctx, &update).unwrap();
insta::assert_snapshot!(result.sql);
}
#[test]
fn test_update_with_function() {
let source = r#"
UpdateUser @update{
params {id @uuid, name @string}
table users
set {name $name, updated_at @now}
where {id $id}
returning {id, name, updated_at}
}
"#;
let (update, qsource) = get_first_update(source);
let schema = Schema::default();
let ctx = SqlGenContext::new(&schema, std::sync::Arc::new(qsource));
let result = generate_update_sql(&ctx, &update).unwrap();
insta::assert_snapshot!(result.sql);
}
#[test]
fn test_update_set_null() {
let source = r#"
ClearDeletedAt @update{
params {id @uuid}
table products
set {deleted_at @null}
where {id $id}
}
"#;
let (update, qsource) = get_first_update(source);
let schema = Schema::default();
let ctx = SqlGenContext::new(&schema, std::sync::Arc::new(qsource));
let result = generate_update_sql(&ctx, &update).unwrap();
assert!(
result.sql.contains(r#"SET "deleted_at" = NULL"#),
"expected `= NULL`, got: {}",
result.sql
);
assert!(
!result.sql.contains("NULL()"),
"must not emit the invalid `NULL()`: {}",
result.sql
);
}
#[test]
fn test_update_multiple_conditions() {
let source = r#"
UpdateProductStatus @update{
params {user_id @uuid, old_status @string, new_status @string}
table products
set {status $new_status, updated_at @now}
where {user_id $user_id, status $old_status}
returning {id, status}
}
"#;
let (update, qsource) = get_first_update(source);
let schema = Schema::default();
let ctx = SqlGenContext::new(&schema, std::sync::Arc::new(qsource));
let result = generate_update_sql(&ctx, &update).unwrap();
insta::assert_snapshot!(result.sql);
}
#[test]
fn test_update_shorthand_params() {
let source = r#"
UpdateUser @update{
params {id @uuid, name @string, email @string}
table users
set {name, email}
where {id}
returning {id}
}
"#;
let (update, qsource) = get_first_update(source);
let schema = Schema::default();
let ctx = SqlGenContext::new(&schema, std::sync::Arc::new(qsource));
let result = generate_update_sql(&ctx, &update).unwrap();
insta::assert_snapshot!(result.sql);
}
}