diesel_uplete/
lib.rs

1// SPDX-FileCopyrightText: 2024, 2025 Joseph Silva
2//
3// SPDX-License-Identifier: MIT OR Apache-2.0
4
5use diesel::associations::HasTable;
6use diesel::dsl;
7use diesel::expression::is_aggregate;
8use diesel::expression::ValidGrouping;
9use diesel::pg::Pg;
10use diesel::query_builder::AsQuery;
11use diesel::query_builder::AstPass;
12use diesel::query_builder::Query;
13use diesel::query_builder::QueryFragment;
14use diesel::query_builder::QueryId;
15use diesel::query_dsl::methods::FilterDsl;
16use diesel::query_dsl::methods::SelectDsl;
17use diesel::result::Error;
18use diesel::sql_types;
19use diesel::Column;
20use diesel::Expression;
21use diesel::Queryable;
22use diesel::RunQueryDsl;
23use diesel::Table;
24use std::any::TypeId;
25use tuplex::IntoArray;
26
27/// Set columns (each specified with [`UpleteBuilder::set_null`]) to null in the rows found by
28/// `query`, and delete rows that have no remaining non-null values in non-ignored columns
29///
30/// # Panics
31///
32/// Running the built query panics if [`UpleteBuilder::set_null`] is not called at least once.
33pub fn uplete<Q>(query: Q) -> UpleteBuilder<dsl::Select<Q::Query, <Q::Table as Table>::PrimaryKey>>
34where
35    Q: AsQuery + HasTable,
36    Q::Table: Default,
37    Q::Query: SelectDsl<<Q::Table as Table>::PrimaryKey>,
38
39    // For better error messages
40    UpleteBuilder<Q>: AsQuery,
41{
42    UpleteBuilder {
43        query: query.as_query().select(Q::Table::default().primary_key()),
44        columns_being_set_to_null: Vec::new(),
45    }
46}
47
/// Builder returned by [`uplete`]. It is turned into an [`UpleteQuery`] by its [`AsQuery`]
/// implementation.
pub struct UpleteBuilder<Q> {
    /// Query that finds the rows to operate on. [`uplete`] builds it so that it selects the
    /// primary key of the table.
    query: Q,
    /// Columns registered with [`UpleteBuilder::set_null`], in the order they were added.
    columns_being_set_to_null: Vec<DynColumn>,
}
52
53impl<Q: HasTable> UpleteBuilder<Q> {
54    pub fn set_null<C: Column<Table = Q::Table> + Into<DynColumn>>(mut self, column: C) -> Self {
55        self.columns_being_set_to_null.push(column.into());
56        self
57    }
58}
59
60impl<Q> AsQuery for UpleteBuilder<Q>
61where
62    Q: HasTable,
63    Q::Table: SupportedTable,
64    <Q::Table as Table>::AllColumns: IntoArray<DynColumn>,
65    <<Q::Table as SupportedTable>::Key as IntoArray<DynColumn>>::Output:
66        IntoIterator<Item = DynColumn>,
67    <<Q::Table as SupportedTable>::AdditionalIgnoredColumns as IntoArray<DynColumn>>::Output:
68        IntoIterator<Item = DynColumn>,
69    <<Q::Table as Table>::AllColumns as IntoArray<DynColumn>>::Output:
70        IntoIterator<Item = DynColumn>,
71    Q: Clone + FilterDsl<AllNull> + FilterDsl<dsl::not<AllNull>>,
72    dsl::Filter<Q, AllNull>: QueryFragment<Pg> + Send + 'static,
73    dsl::Filter<Q, dsl::not<AllNull>>: QueryFragment<Pg> + Send + 'static,
74{
75    type Query = UpleteQuery;
76
77    type SqlType = (sql_types::BigInt, sql_types::BigInt);
78
79    fn as_query(self) -> Self::Query {
80        let table = Q::Table::default;
81        let deletion_condition = AllNull(
82            Q::Table::all_columns()
83                .into_array()
84                .into_iter()
85                .filter(|c: &DynColumn| {
86                    self.columns_being_set_to_null
87                        .iter()
88                        .cloned()
89                        .chain(<Q::Table as SupportedTable>::Key::default().into_array())
90                        .chain(
91                            <Q::Table as SupportedTable>::AdditionalIgnoredColumns::default()
92                                .into_array(),
93                        )
94                        .all(|excluded_column| excluded_column.type_id != c.type_id)
95                })
96                .collect::<Vec<_>>(),
97        );
98        UpleteQuery {
99            // Updated rows and deleted rows must not overlap, so updating all rows and using the returned
100            // new rows to determine which ones to delete is not an option.
101            //
102            // https://www.postgresql.org/docs/16/queries-with.html#QUERIES-WITH-MODIFYING
103            //
104            // "Trying to update the same row twice in a single statement is not supported. Only one of
105            // the modifications takes place, but it is not easy (and sometimes not possible) to reliably
106            // predict which one. This also applies to deleting a row that was already updated in the same
107            // statement: only the update is performed."
108            update_subquery: Box::new(
109                self.query
110                    .clone()
111                    .filter(dsl::not(deletion_condition.clone())),
112            ),
113            delete_subquery: Box::new(self.query.filter(deletion_condition)),
114            table: Box::new(table()),
115            key: Box::new(<Q::Table as SupportedTable>::Key::default()),
116            columns_being_set_to_null: self.columns_being_set_to_null,
117        }
118    }
119}
120
/// The query built from an [`UpleteBuilder`]. Its SQL is generated by its
/// [`QueryFragment`] implementation.
pub struct UpleteQuery {
    /// Selects the keys of the rows to update (rows that do not meet the deletion condition).
    update_subquery: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// Selects the keys of the rows to delete (rows that meet the deletion condition).
    delete_subquery: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// The table that is updated and deleted from.
    table: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// The key columns that are compared with the results of the subqueries.
    key: Box<dyn QueryFragment<Pg> + Send + 'static>,
    /// Columns that the `UPDATE` statement sets to null.
    columns_being_set_to_null: Vec<DynColumn>,
}
128
// The generated SQL depends on the columns that were registered at runtime, so the query is
// declared as having no static query ID.
impl QueryId for UpleteQuery {
    type QueryId = ();

    const HAS_STATIC_QUERY_ID: bool = false;
}
134
// The query returns two `bigint` values: the number of updated rows, then the number of deleted
// rows.
impl Query for UpleteQuery {
    type SqlType = (sql_types::BigInt, sql_types::BigInt);
}
138
139impl QueryFragment<Pg> for UpleteQuery {
140    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> Result<(), Error> {
141        assert_ne!(
142            self.columns_being_set_to_null.len(),
143            0,
144            "`set_null` was not called"
145        );
146
147        // This is checked by require_uplete triggers
148        out.push_sql("/**/");
149
150        // Declare `update_keys` and `delete_keys` CTEs, which select primary keys
151        for (prefix, subquery) in [
152            ("WITH update_keys", &self.update_subquery),
153            (", delete_keys", &self.delete_subquery),
154        ] {
155            out.push_sql(prefix);
156            out.push_sql(" AS (");
157            subquery.walk_ast(out.reborrow())?;
158            out.push_sql(" FOR UPDATE)");
159        }
160
161        // Update rows that are referenced in `update_keys`
162        out.push_sql(", update_result AS (UPDATE ");
163        self.table.walk_ast(out.reborrow())?;
164        let mut item_prefix = " SET ";
165        for column in &self.columns_being_set_to_null {
166            out.push_sql(item_prefix);
167            out.push_identifier(column.name)?;
168            out.push_sql(" = NULL");
169            item_prefix = ",";
170        }
171        out.push_sql(" WHERE (");
172        self.key.walk_ast(out.reborrow())?;
173        out.push_sql(") = ANY (SELECT * FROM update_keys) RETURNING 1)");
174
175        // Delete rows that are referenced in `delete_keys`
176        out.push_sql(", delete_result AS (DELETE FROM ");
177        self.table.walk_ast(out.reborrow())?;
178        out.push_sql(" WHERE (");
179        self.key.walk_ast(out.reborrow())?;
180        out.push_sql(") = ANY (SELECT * FROM delete_keys) RETURNING 1)");
181
182        // Count updated rows and deleted rows (`RETURNING 1` makes this possible)
183        out.push_sql(" SELECT (SELECT count(*) FROM update_result)");
184        out.push_sql(", (SELECT count(*) FROM delete_result)");
185
186        Ok(())
187    }
188}
189
/// Boolean SQL expression that is true if every item in the list is null (and also true if the
/// list is empty).
///
/// Appears in some trait bounds.
#[derive(Clone)]
pub struct AllNull<T = DynColumn>(Vec<T>);
193
// `AllNull` is a boolean expression
impl<T> Expression for AllNull<T> {
    type SqlType = sql_types::Bool;
}
197
// `AllNull` is declared as a non-aggregate expression
impl<T> ValidGrouping<()> for AllNull<T> {
    type IsAggregate = is_aggregate::No;
}
201
202impl<T: QueryFragment<Pg>> QueryFragment<Pg> for AllNull<T> {
203    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> Result<(), Error> {
204        // Must produce a valid expression even if `self.0` is empty
205        out.push_sql("(TRUE");
206        for item in &self.0 {
207            out.push_sql(" AND (");
208            item.walk_ast(out.reborrow())?;
209            out.push_sql(" IS NULL)");
210        }
211        out.push_sql(")");
212
213        Ok(())
214    }
215}
216
/// A column whose identity is stored as runtime values instead of a type, which allows columns of
/// different types to be kept in the same collection.
///
/// Appears in some trait bounds.
#[derive(Clone)]
pub struct DynColumn {
    /// `TypeId` of the column's type. Used to check whether two values refer to the same column.
    type_id: TypeId,
    /// The column's name (`Column::NAME`).
    name: &'static str,
}
223
224impl<T: Column + 'static> From<T> for DynColumn {
225    fn from(_value: T) -> Self {
226        DynColumn {
227            type_id: TypeId::of::<T>(),
228            name: T::NAME,
229        }
230    }
231}
232
impl QueryFragment<Pg> for DynColumn {
    /// Writes the column name as an identifier, without a table name in front of it.
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> Result<(), Error> {
        out.push_identifier(self.name)
    }
}
238
/// The output of an uplete.
///
/// Each row in the query given to [`uplete`] is either updated or deleted. No row is both updated and deleted.
///
/// The generated SQL selects the updated count first and the deleted count second, which is the
/// same order as the fields below.
#[derive(Queryable, PartialEq, Eq, Debug)]
pub struct UpleteCount {
    /// Number of rows that were updated.
    pub updated: i64,
    /// Number of rows that were deleted.
    pub deleted: i64,
}
249
250impl UpleteCount {
251    /// Returns `UpleteCount { updated: n, deleted: 0 }`.
252    pub fn only_updated(n: i64) -> Self {
253        UpleteCount {
254            updated: n,
255            deleted: 0,
256        }
257    }
258
259    /// Returns `UpleteCount { updated: 0, deleted: n }`.
260    pub fn only_deleted(n: i64) -> Self {
261        UpleteCount {
262            updated: 0,
263            deleted: n,
264        }
265    }
266}
267
/// Trait that you must implement for a table in order to do upletes on it.
pub trait SupportedTable: Default + QueryFragment<Pg> + Send + 'static {
    /// A tuple of columns that, as a whole, must always have a unique value for each row. Subqueries in the generated SQL rely on this requirement being met.
    ///
    /// Each column in `Key` is also ignored. See the documentation of [`uplete`] for what that means.
    ///
    /// `Key` usually should be the same as the primary key.
    ///
    /// To use only one column, use a trailing comma: `type Key = (my_table::my_column,);`
    type Key: Default + IntoArray<DynColumn> + QueryFragment<Pg> + Send + 'static;
    /// Each column in `AdditionalIgnoredColumns` is ignored. See the documentation of [`uplete`] for what that means.
    ///
    /// May be `()` if `Key` has all columns that should be ignored.
    ///
    /// To use only one column, use a trailing comma: `type AdditionalIgnoredColumns = (my_table::my_column,);`
    type AdditionalIgnoredColumns: Default + IntoArray<DynColumn>;
}
285
// Allows methods such as `get_result` to be called directly on a builder that can be converted
// to a query
impl<Conn, Q> RunQueryDsl<Conn> for UpleteBuilder<Q> where Self: AsQuery {}

// Allows the same methods to be called on a query that was already built
impl<Conn> RunQueryDsl<Conn> for UpleteQuery {}
289
290#[cfg(test)]
291mod tests {
292    use crate::AllNull;
293    use diesel::connection::SimpleConnection;
294    use diesel::debug_query;
295    use diesel::insert_into;
296    use diesel::pg::Pg;
297    use diesel::query_builder::AsQuery;
298    use diesel::query_builder::QueryId;
299    use diesel::select;
300    use diesel::sql_types;
301    use diesel::AppearsOnTable;
302    use diesel::Connection;
303    use diesel::ExpressionMethods;
304    use diesel::IntoSql;
305    use diesel::PgConnection;
306    use diesel::QueryDsl;
307    use diesel::QueryResult;
308    use diesel::RunQueryDsl;
309    use diesel::SelectableExpression;
310    use std::env;
311    use std::process::Command;
312    use tempfile::tempdir_in;
313
    // Test-only impls for `AllNull`. `test_all_null` passes `AllNull` directly to `select`
    // and runs it, which presumably is what requires them.
    impl<T, QS> AppearsOnTable<QS> for AllNull<T> {}

    impl<T, QS> SelectableExpression<QS> for AllNull<T> {}

    impl<T> QueryId for AllNull<T> {
        type QueryId = ();
        const HAS_STATIC_QUERY_ID: bool = false;
    }
322
    // Table used by the tests: a two-column primary key and two nullable columns
    diesel::table! {
      t (id1, id2) {
        // uplete doesn't work for non-tuple primary key
        id1 -> Int4,
        id2 -> Int4,
        a -> Nullable<Int4>,
        b -> Nullable<Int4>,
      }
    }
332
    // The key is the same as the primary key, and no other columns are ignored
    impl crate::SupportedTable for t::table {
        type Key = (t::id1, t::id2);
        type AdditionalIgnoredColumns = ();
    }
337
338    fn expect_rows(
339        conn: &mut PgConnection,
340        expected: &[(Option<i32>, Option<i32>)],
341    ) -> QueryResult<()> {
342        let rows: Vec<(Option<i32>, Option<i32>)> =
343            t::table.select((t::a, t::b)).order_by(t::id1).load(conn)?;
344        assert_eq!(expected, &rows);
345
346        Ok(())
347    }
348
349    // Main purpose of this test is to check accuracy of the returned `UpleteCount`, which other modules'
350    // tests rely on
351    fn test_count(conn: &mut PgConnection) -> QueryResult<()> {
352        conn
353            .batch_execute("CREATE TABLE t (id1 serial, id2 int NOT NULL DEFAULT 1, a int, b int, PRIMARY KEY (id1, id2));")?;
354        expect_rows(conn, &[])?;
355
356        insert_into(t::table)
357            .values(&[
358                (t::a.eq(Some(1)), t::b.eq(Some(2))),
359                (t::a.eq(Some(3)), t::b.eq(None)),
360                (t::a.eq(Some(4)), t::b.eq(Some(5))),
361            ])
362            .execute(conn)?;
363        expect_rows(
364            conn,
365            &[(Some(1), Some(2)), (Some(3), None), (Some(4), Some(5))],
366        )?;
367
368        let count1 = crate::uplete(t::table).set_null(t::a).get_result(conn)?;
369        assert_eq!(
370            crate::UpleteCount {
371                updated: 2,
372                deleted: 1
373            },
374            count1
375        );
376        expect_rows(conn, &[(None, Some(2)), (None, Some(5))])?;
377
378        let count2 = crate::uplete(t::table).set_null(t::b).get_result(conn)?;
379        assert_eq!(crate::UpleteCount::only_deleted(2), count2);
380        expect_rows(conn, &[])?;
381
382        conn.batch_execute("DROP TABLE t;")?;
383
384        Ok(())
385    }
386
387    fn expected_sql(check_null: &str, set_null: &str) -> String {
388        let with_queries = {
389            let key = r#""t"."id1", "t"."id2""#;
390            let t = r#""t""#;
391
392            let update_keys =
393                format!("SELECT {key} FROM {t} WHERE  NOT (({check_null})) FOR UPDATE");
394            let delete_keys = format!("SELECT {key} FROM {t} WHERE ({check_null}) FOR UPDATE");
395            let update_result = format!(
396          "UPDATE {t} SET {set_null} WHERE ({key}) = ANY (SELECT * FROM update_keys) RETURNING 1"
397        );
398            let delete_result = format!(
399                "DELETE FROM {t} WHERE ({key}) = ANY (SELECT * FROM delete_keys) RETURNING 1"
400            );
401
402            format!("update_keys AS ({update_keys}), delete_keys AS ({delete_keys}), update_result AS ({update_result}), delete_result AS ({delete_result})")
403        };
404        let update_count = "SELECT count(*) FROM update_result";
405        let delete_count = "SELECT count(*) FROM delete_result";
406
407        format!(r#"/**/WITH {with_queries} SELECT ({update_count}), ({delete_count}) -- binds: []"#)
408    }
409
410    #[test]
411    fn test_generated_sql() {
412        // Unlike the `get_result` method, `debug_query` does not automatically call `as_query`
413        assert_eq!(
414            debug_query::<Pg, _>(&crate::uplete(t::table).set_null(t::b).as_query()).to_string(),
415            expected_sql(r#"TRUE AND ("a" IS NULL)"#, r#""b" = NULL"#)
416        );
417        assert_eq!(
418            debug_query::<Pg, _>(
419                &crate::uplete(t::table)
420                    .set_null(t::a)
421                    .set_null(t::b)
422                    .as_query()
423            )
424            .to_string(),
425            expected_sql(r#"TRUE"#, r#""a" = NULL,"b" = NULL"#)
426        );
427    }
428
429    #[test]
430    fn test_count_methods() {
431        assert_eq!(
432            crate::UpleteCount::only_updated(1),
433            crate::UpleteCount {
434                updated: 1,
435                deleted: 0
436            }
437        );
438        assert_eq!(
439            crate::UpleteCount::only_deleted(1),
440            crate::UpleteCount {
441                updated: 0,
442                deleted: 1
443            }
444        );
445    }
446
447    fn test_all_null(conn: &mut PgConnection) -> QueryResult<()> {
448        let some = Some(1).into_sql::<sql_types::Nullable<sql_types::Integer>>();
449        let none = None::<i32>.into_sql::<sql_types::Nullable<sql_types::Integer>>();
450
451        // Allows type inference for `vec![]`
452        let mut all_null = |items| select(AllNull(items)).get_result::<bool>(conn);
453
454        assert!(all_null(vec![])?);
455        assert!(all_null(vec![none])?);
456        assert!(all_null(vec![none, none])?);
457        assert!(all_null(vec![none, none, none])?);
458        assert!(!all_null(vec![some])?);
459        assert!(!all_null(vec![some, none])?);
460        assert!(!all_null(vec![none, some, none])?);
461
462        Ok(())
463    }
464
465    #[test]
466    fn test_db_stuff() -> QueryResult<()> {
467        let user_specific_tmp = env::var("XDG_RUNTIME_DIR").unwrap();
468        let tempdir = tempdir_in(&user_specific_tmp).unwrap();
469        let tempdir_path = tempdir.path().to_str().unwrap();
470        let pgdata = tempdir.path().join("pgdata");
471        let envs = [("PGDATA", &pgdata)];
472        assert!(Command::new("initdb")
473            .envs(envs)
474            .args(["--username=postgres", "--auth=trust", "--no-sync"])
475            .status()
476            .unwrap()
477            .success());
478        let mut postgres_process = Command::new("postgres")
479            .envs(envs)
480            .args([
481                "-c",
482                "listen_addresses=",
483                "-c",
484                &format!("unix_socket_directories={}", tempdir_path),
485                "-c",
486                "fsync=off",
487                "-c",
488                "logging_collector=off",
489                "-c",
490                "port=1",
491            ])
492            .spawn()
493            .unwrap();
494        // TODO: kill postgres on failure
495        std::thread::sleep(std::time::Duration::from_secs(3));
496        let mut conn = PgConnection::establish(&format!(
497            "postgresql://postgres@{}:1",
498            tempdir_path.replace('/', "%2F")
499        ))
500        .unwrap();
501        test_count(&mut conn)?;
502        test_all_null(&mut conn)?;
503        postgres_process.kill().unwrap();
504        tempdir.close().unwrap();
505        Ok(())
506    }
507}