1use diesel::associations::HasTable;
6use diesel::dsl;
7use diesel::expression::is_aggregate;
8use diesel::expression::ValidGrouping;
9use diesel::pg::Pg;
10use diesel::query_builder::AsQuery;
11use diesel::query_builder::AstPass;
12use diesel::query_builder::Query;
13use diesel::query_builder::QueryFragment;
14use diesel::query_builder::QueryId;
15use diesel::query_dsl::methods::FilterDsl;
16use diesel::query_dsl::methods::SelectDsl;
17use diesel::result::Error;
18use diesel::sql_types;
19use diesel::Column;
20use diesel::Expression;
21use diesel::Queryable;
22use diesel::RunQueryDsl;
23use diesel::Table;
24use std::any::TypeId;
25use tuplex::IntoArray;
26
27pub fn uplete<Q>(query: Q) -> UpleteBuilder<dsl::Select<Q::Query, <Q::Table as Table>::PrimaryKey>>
34where
35 Q: AsQuery + HasTable,
36 Q::Table: Default,
37 Q::Query: SelectDsl<<Q::Table as Table>::PrimaryKey>,
38
39 UpleteBuilder<Q>: AsQuery,
41{
42 UpleteBuilder {
43 query: query.as_query().select(Q::Table::default().primary_key()),
44 columns_being_set_to_null: Vec::new(),
45 }
46}
47
/// Builder returned by [`uplete`]; collects the columns to set to `NULL`.
///
/// Converting it with `AsQuery::as_query` produces the executable [`UpleteQuery`].
pub struct UpleteBuilder<Q> {
    /// Query selecting the primary key of every candidate row (built by `uplete`).
    query: Q,
    /// Columns registered through `set_null`, in registration order.
    columns_being_set_to_null: Vec<DynColumn>,
}
52
53impl<Q: HasTable> UpleteBuilder<Q> {
54 pub fn set_null<C: Column<Table = Q::Table> + Into<DynColumn>>(mut self, column: C) -> Self {
55 self.columns_being_set_to_null.push(column.into());
56 self
57 }
58}
59
60impl<Q> AsQuery for UpleteBuilder<Q>
61where
62 Q: HasTable,
63 Q::Table: SupportedTable,
64 <Q::Table as Table>::AllColumns: IntoArray<DynColumn>,
65 <<Q::Table as SupportedTable>::Key as IntoArray<DynColumn>>::Output:
66 IntoIterator<Item = DynColumn>,
67 <<Q::Table as SupportedTable>::AdditionalIgnoredColumns as IntoArray<DynColumn>>::Output:
68 IntoIterator<Item = DynColumn>,
69 <<Q::Table as Table>::AllColumns as IntoArray<DynColumn>>::Output:
70 IntoIterator<Item = DynColumn>,
71 Q: Clone + FilterDsl<AllNull> + FilterDsl<dsl::not<AllNull>>,
72 dsl::Filter<Q, AllNull>: QueryFragment<Pg> + Send + 'static,
73 dsl::Filter<Q, dsl::not<AllNull>>: QueryFragment<Pg> + Send + 'static,
74{
75 type Query = UpleteQuery;
76
77 type SqlType = (sql_types::BigInt, sql_types::BigInt);
78
79 fn as_query(self) -> Self::Query {
80 let table = Q::Table::default;
81 let deletion_condition = AllNull(
82 Q::Table::all_columns()
83 .into_array()
84 .into_iter()
85 .filter(|c: &DynColumn| {
86 self.columns_being_set_to_null
87 .iter()
88 .cloned()
89 .chain(<Q::Table as SupportedTable>::Key::default().into_array())
90 .chain(
91 <Q::Table as SupportedTable>::AdditionalIgnoredColumns::default()
92 .into_array(),
93 )
94 .all(|excluded_column| excluded_column.type_id != c.type_id)
95 })
96 .collect::<Vec<_>>(),
97 );
98 UpleteQuery {
99 update_subquery: Box::new(
109 self.query
110 .clone()
111 .filter(dsl::not(deletion_condition.clone())),
112 ),
113 delete_subquery: Box::new(self.query.filter(deletion_condition)),
114 table: Box::new(table()),
115 key: Box::new(<Q::Table as SupportedTable>::Key::default()),
116 columns_being_set_to_null: self.columns_being_set_to_null,
117 }
118 }
119}
120
/// Fully built uplete statement, produced by `UpleteBuilder::as_query`.
///
/// All parts are type-erased so the struct is not generic over the table or query type.
pub struct UpleteQuery {
    /// Key-selecting query for rows that will be updated (remaining columns not all NULL).
    update_subquery: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// Key-selecting query for rows that will be deleted (remaining columns all NULL).
    delete_subquery: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// The target table, rendered after `UPDATE` / `DELETE FROM`.
    table: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// The table's `SupportedTable::Key` columns, compared against the selected keys.
    key: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// Columns rendered in the `SET ... = NULL` list; must not be empty when rendered.
    columns_being_set_to_null: Vec<DynColumn>,
}

// The rendered SQL depends on runtime data (boxed fragments and the column list), so no
// static query id is claimed.
impl QueryId for UpleteQuery {
    type QueryId = ();

    const HAS_STATIC_QUERY_ID: bool = false;
}

// Result row: (number of updated rows, number of deleted rows).
impl Query for UpleteQuery {
    type SqlType = (sql_types::BigInt, sql_types::BigInt);
}
138
impl QueryFragment<Pg> for UpleteQuery {
    /// Renders the uplete as one statement of the form:
    ///
    /// ```text
    /// /**/WITH update_keys AS (<update_subquery> FOR UPDATE),
    ///   delete_keys AS (<delete_subquery> FOR UPDATE),
    ///   update_result AS (UPDATE <table> SET <c1> = NULL,<c2> = NULL
    ///     WHERE (<key>) = ANY (SELECT * FROM update_keys) RETURNING 1),
    ///   delete_result AS (DELETE FROM <table>
    ///     WHERE (<key>) = ANY (SELECT * FROM delete_keys) RETURNING 1)
    /// SELECT (SELECT count(*) FROM update_result), (SELECT count(*) FROM delete_result)
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if no column was registered with `set_null`.
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> Result<(), Error> {
        // An `UPDATE` with an empty `SET` list is not valid SQL.
        assert_ne!(
            self.columns_being_set_to_null.len(),
            0,
            "`set_null` was not called"
        );

        // NOTE(review): leading marker comment, pinned by `test_generated_sql`. Presumably
        // something outside this file (e.g. a database trigger) looks for it; confirm before
        // changing or removing.
        out.push_sql("/**/");

        // Two CTEs holding the keys of rows to update and rows to delete. `FOR UPDATE` row-locks
        // them, presumably so a concurrent transaction cannot change the update/delete decision
        // before it is applied.
        for (prefix, subquery) in [
            ("WITH update_keys", &self.update_subquery),
            (", delete_keys", &self.delete_subquery),
        ] {
            out.push_sql(prefix);
            out.push_sql(" AS (");
            subquery.walk_ast(out.reborrow())?;
            out.push_sql(" FOR UPDATE)");
        }

        // Null out the requested columns in the "update" rows. `RETURNING 1` yields one row per
        // updated row so it can be counted below.
        out.push_sql(", update_result AS (UPDATE ");
        self.table.walk_ast(out.reborrow())?;
        let mut item_prefix = " SET ";
        for column in &self.columns_being_set_to_null {
            out.push_sql(item_prefix);
            out.push_identifier(column.name)?;
            out.push_sql(" = NULL");
            // Every assignment after the first is separated by a bare comma.
            item_prefix = ",";
        }
        out.push_sql(" WHERE (");
        self.key.walk_ast(out.reborrow())?;
        out.push_sql(") = ANY (SELECT * FROM update_keys) RETURNING 1)");

        // Delete the "delete" rows, again returning one row per affected row.
        out.push_sql(", delete_result AS (DELETE FROM ");
        self.table.walk_ast(out.reborrow())?;
        out.push_sql(" WHERE (");
        self.key.walk_ast(out.reborrow())?;
        out.push_sql(") = ANY (SELECT * FROM delete_keys) RETURNING 1)");

        // Final result row: (updated count, deleted count), matching `UpleteCount`.
        out.push_sql(" SELECT (SELECT count(*) FROM update_result)");
        out.push_sql(", (SELECT count(*) FROM delete_result)");

        Ok(())
    }
}
189
/// SQL boolean expression that is true when every item in the list is `NULL`.
///
/// Its `QueryFragment` impl renders `(TRUE AND (item1 IS NULL) AND (item2 IS NULL) ...)`, so
/// an empty list renders as `(TRUE)`. Because it only combines `IS NULL` tests with `AND`, it
/// never evaluates to SQL `NULL`.
///
/// Items default to [`DynColumn`]; the tests in this file also use bound values.
#[derive(Clone)]
pub struct AllNull<T = DynColumn>(Vec<T>);

// Typed as a SQL boolean so it can be used as a `filter` predicate.
impl<T> Expression for AllNull<T> {
    type SqlType = sql_types::Bool;
}

// Not an aggregate expression.
impl<T> ValidGrouping<()> for AllNull<T> {
    type IsAggregate = is_aggregate::No;
}
201
202impl<T: QueryFragment<Pg>> QueryFragment<Pg> for AllNull<T> {
203 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> Result<(), Error> {
204 out.push_sql("(TRUE");
206 for item in &self.0 {
207 out.push_sql(" AND (");
208 item.walk_ast(out.reborrow())?;
209 out.push_sql(" IS NULL)");
210 }
211 out.push_sql(")");
212
213 Ok(())
214 }
215}
216
217#[derive(Clone)]
219pub struct DynColumn {
220 type_id: TypeId,
221 name: &'static str,
222}
223
224impl<T: Column + 'static> From<T> for DynColumn {
225 fn from(_value: T) -> Self {
226 DynColumn {
227 type_id: TypeId::of::<T>(),
228 name: T::NAME,
229 }
230 }
231}
232
233impl QueryFragment<Pg> for DynColumn {
234 fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> Result<(), Error> {
235 out.push_identifier(self.name)
236 }
237}
238
/// Row counts returned by running an uplete.
///
/// diesel's `Queryable` derive maps result columns to fields by position. `updated` must
/// therefore stay before `deleted`, matching the final `SELECT` rendered by `UpleteQuery`.
#[derive(Queryable, PartialEq, Eq, Debug)]
pub struct UpleteCount {
    /// Number of rows whose columns were set to NULL.
    pub updated: i64,

    /// Number of rows that were deleted.
    pub deleted: i64,
}
249
250impl UpleteCount {
251 pub fn only_updated(n: i64) -> Self {
253 UpleteCount {
254 updated: n,
255 deleted: 0,
256 }
257 }
258
259 pub fn only_deleted(n: i64) -> Self {
261 UpleteCount {
262 updated: 0,
263 deleted: n,
264 }
265 }
266}
267
/// Tables that can be targeted by [`uplete`].
///
/// The table value itself is rendered as the target of `UPDATE` / `DELETE FROM`.
pub trait SupportedTable: Default + QueryFragment<Pg> + Send + 'static {
    /// Columns used to match rows (`WHERE (<Key>) = ANY (...)`); they are also excluded from
    /// the all-NULL deletion check.
    ///
    /// NOTE(review): the keys being matched are selected via the table's diesel primary key. This
    /// should therefore list the same columns in the same order; the type system does not
    /// enforce that.
    type Key: Default + IntoArray<DynColumn> + QueryFragment<Pg> + Send + 'static;
    /// Extra columns that are ignored when deciding whether a row becomes all-NULL and is
    /// deleted.
    type AdditionalIgnoredColumns: Default + IntoArray<DynColumn>;
}

// Allows `uplete(...).set_null(...).get_result(conn)` directly on the builder.
impl<Conn, Q> RunQueryDsl<Conn> for UpleteBuilder<Q> where Self: AsQuery {}

impl<Conn> RunQueryDsl<Conn> for UpleteQuery {}
289
#[cfg(test)]
mod tests {
    use crate::AllNull;
    use diesel::connection::SimpleConnection;
    use diesel::debug_query;
    use diesel::insert_into;
    use diesel::pg::Pg;
    use diesel::query_builder::AsQuery;
    use diesel::query_builder::QueryId;
    use diesel::select;
    use diesel::sql_types;
    use diesel::AppearsOnTable;
    use diesel::Connection;
    use diesel::ExpressionMethods;
    use diesel::IntoSql;
    use diesel::PgConnection;
    use diesel::QueryDsl;
    use diesel::QueryResult;
    use diesel::RunQueryDsl;
    use diesel::SelectableExpression;
    use std::env;
    use std::path::PathBuf;
    use std::process::Child;
    use std::process::Command;
    use std::thread;
    use std::time::Duration;
    use std::time::Instant;
    use tempfile::tempdir_in;

    // Test-only impls that let `AllNull` be used directly in `select(...)` (see `test_all_null`).
    impl<T, QS> AppearsOnTable<QS> for AllNull<T> {}

    impl<T, QS> SelectableExpression<QS> for AllNull<T> {}

    impl<T> QueryId for AllNull<T> {
        type QueryId = ();
        const HAS_STATIC_QUERY_ID: bool = false;
    }

    diesel::table! {
        t (id1, id2) {
            id1 -> Int4,
            id2 -> Int4,
            a -> Nullable<Int4>,
            b -> Nullable<Int4>,
        }
    }

    impl crate::SupportedTable for t::table {
        type Key = (t::id1, t::id2);
        type AdditionalIgnoredColumns = ();
    }

    /// Asserts that table `t`, ordered by `id1`, contains exactly `expected` as `(a, b)` pairs.
    fn expect_rows(
        conn: &mut PgConnection,
        expected: &[(Option<i32>, Option<i32>)],
    ) -> QueryResult<()> {
        let rows: Vec<(Option<i32>, Option<i32>)> =
            t::table.select((t::a, t::b)).order_by(t::id1).load(conn)?;
        assert_eq!(expected, &rows);

        Ok(())
    }

    /// Runs two upletes against a real table and checks both the returned counts and the
    /// remaining rows.
    fn test_count(conn: &mut PgConnection) -> QueryResult<()> {
        conn
            .batch_execute("CREATE TABLE t (id1 serial, id2 int NOT NULL DEFAULT 1, a int, b int, PRIMARY KEY (id1, id2));")?;
        expect_rows(conn, &[])?;

        insert_into(t::table)
            .values(&[
                (t::a.eq(Some(1)), t::b.eq(Some(2))),
                (t::a.eq(Some(3)), t::b.eq(None)),
                (t::a.eq(Some(4)), t::b.eq(Some(5))),
            ])
            .execute(conn)?;
        expect_rows(
            conn,
            &[(Some(1), Some(2)), (Some(3), None), (Some(4), Some(5))],
        )?;

        // The (3, NULL) row would become all-NULL, so it is deleted; the others are updated.
        let count1 = crate::uplete(t::table).set_null(t::a).get_result(conn)?;
        assert_eq!(
            crate::UpleteCount {
                updated: 2,
                deleted: 1
            },
            count1
        );
        expect_rows(conn, &[(None, Some(2)), (None, Some(5))])?;

        let count2 = crate::uplete(t::table).set_null(t::b).get_result(conn)?;
        assert_eq!(crate::UpleteCount::only_deleted(2), count2);
        expect_rows(conn, &[])?;

        conn.batch_execute("DROP TABLE t;")?;

        Ok(())
    }

    /// Builds the SQL `debug_query` is expected to print for an uplete on `t`.
    fn expected_sql(check_null: &str, set_null: &str) -> String {
        let with_queries = {
            let key = r#""t"."id1", "t"."id2""#;
            let t = r#""t""#;

            let update_keys =
                format!("SELECT {key} FROM {t} WHERE NOT (({check_null})) FOR UPDATE");
            let delete_keys = format!("SELECT {key} FROM {t} WHERE ({check_null}) FOR UPDATE");
            let update_result = format!(
                "UPDATE {t} SET {set_null} WHERE ({key}) = ANY (SELECT * FROM update_keys) RETURNING 1"
            );
            let delete_result = format!(
                "DELETE FROM {t} WHERE ({key}) = ANY (SELECT * FROM delete_keys) RETURNING 1"
            );

            format!("update_keys AS ({update_keys}), delete_keys AS ({delete_keys}), update_result AS ({update_result}), delete_result AS ({delete_result})")
        };
        let update_count = "SELECT count(*) FROM update_result";
        let delete_count = "SELECT count(*) FROM delete_result";

        format!(r#"/**/WITH {with_queries} SELECT ({update_count}), ({delete_count}) -- binds: []"#)
    }

    #[test]
    fn test_generated_sql() {
        assert_eq!(
            debug_query::<Pg, _>(&crate::uplete(t::table).set_null(t::b).as_query()).to_string(),
            expected_sql(r#"TRUE AND ("a" IS NULL)"#, r#""b" = NULL"#)
        );
        assert_eq!(
            debug_query::<Pg, _>(
                &crate::uplete(t::table)
                    .set_null(t::a)
                    .set_null(t::b)
                    .as_query()
            )
            .to_string(),
            expected_sql(r#"TRUE"#, r#""a" = NULL,"b" = NULL"#)
        );
    }

    #[test]
    fn test_count_methods() {
        assert_eq!(
            crate::UpleteCount::only_updated(1),
            crate::UpleteCount {
                updated: 1,
                deleted: 0
            }
        );
        assert_eq!(
            crate::UpleteCount::only_deleted(1),
            crate::UpleteCount {
                updated: 0,
                deleted: 1
            }
        );
    }

    /// Evaluates `AllNull` on the server for various mixes of NULL and non-NULL values.
    fn test_all_null(conn: &mut PgConnection) -> QueryResult<()> {
        let some = Some(1).into_sql::<sql_types::Nullable<sql_types::Integer>>();
        let none = None::<i32>.into_sql::<sql_types::Nullable<sql_types::Integer>>();

        let mut all_null = |items| select(AllNull(items)).get_result::<bool>(conn);

        assert!(all_null(vec![])?);
        assert!(all_null(vec![none])?);
        assert!(all_null(vec![none, none])?);
        assert!(all_null(vec![none, none, none])?);
        assert!(!all_null(vec![some])?);
        assert!(!all_null(vec![some, none])?);
        assert!(!all_null(vec![none, some, none])?);

        Ok(())
    }

    /// Owns the spawned `postgres` server and kills and reaps it when dropped. This keeps the
    /// server from outliving a failing test (an early `?` return or a failed assertion), and
    /// ensures it has exited before the temporary data directory is removed.
    struct KillOnDrop(Child);

    impl Drop for KillOnDrop {
        fn drop(&mut self) {
            // Errors are ignored: the process may already have exited, and panicking in `drop`
            // while unwinding would abort the test binary.
            let _ = self.0.kill();
            let _ = self.0.wait();
        }
    }

    /// Connects to `url`, retrying while the freshly spawned server is still starting up.
    ///
    /// # Panics
    ///
    /// Panics if no connection succeeds within 30 seconds.
    fn connect_with_retry(url: &str) -> PgConnection {
        let deadline = Instant::now() + Duration::from_secs(30);
        loop {
            match PgConnection::establish(url) {
                Ok(conn) => return conn,
                Err(e) if Instant::now() >= deadline => {
                    panic!("could not connect to the test PostgreSQL server: {e}")
                }
                Err(_) => thread::sleep(Duration::from_millis(100)),
            }
        }
    }

    #[test]
    fn test_db_stuff() -> QueryResult<()> {
        // A short, user-specific directory keeps the Unix socket path short; fall back to the
        // system temp dir when `XDG_RUNTIME_DIR` is not set.
        let user_specific_tmp = env::var_os("XDG_RUNTIME_DIR")
            .map(PathBuf::from)
            .unwrap_or_else(env::temp_dir);
        let tempdir = tempdir_in(&user_specific_tmp).unwrap();
        let tempdir_path = tempdir.path().to_str().unwrap();
        let pgdata = tempdir.path().join("pgdata");
        let envs = [("PGDATA", &pgdata)];
        assert!(Command::new("initdb")
            .envs(envs)
            .args(["--username=postgres", "--auth=trust", "--no-sync"])
            .status()
            .unwrap()
            .success());
        // Declared after `tempdir`, so on an early return it is dropped (and the server killed)
        // before the directory is removed.
        let postgres_process = KillOnDrop(
            Command::new("postgres")
                .envs(envs)
                .args([
                    "-c",
                    "listen_addresses=",
                    "-c",
                    &format!("unix_socket_directories={}", tempdir_path),
                    "-c",
                    "fsync=off",
                    "-c",
                    "logging_collector=off",
                    "-c",
                    "port=1",
                ])
                .spawn()
                .unwrap(),
        );
        let mut conn = connect_with_retry(&format!(
            "postgresql://postgres@{}:1",
            tempdir_path.replace('/', "%2F")
        ));
        test_count(&mut conn)?;
        test_all_null(&mut conn)?;
        drop(conn);
        drop(postgres_process);
        tempdir.close().unwrap();
        Ok(())
    }
}