Skip to main content

openauth_core/db/sql/statements.rs

1use super::*;
2
3pub fn create_statement(
4    dialect: SqlDialect,
5    schema: &DbSchema,
6    query: &Create,
7) -> Result<SqlStatement, OpenAuthError> {
8    let table = resolve_table(schema, &query.model)?;
9    let selection = selected_fields(table, &query.select)?;
10    let mut columns = Vec::new();
11    let mut placeholders = Vec::new();
12    let mut params = Vec::new();
13
14    for (field, value) in &query.data {
15        let (_, metadata) = resolve_field(table, field)?;
16        columns.push(dialect.quote_identifier(&metadata.name)?);
17        params.push(SqlParam::new(metadata, value.clone()));
18        placeholders.push(dialect.placeholder(params.len()));
19    }
20
21    let mut sql = if columns.is_empty() {
22        match dialect {
23            SqlDialect::Postgres | SqlDialect::Sqlite => format!(
24                "INSERT INTO {} DEFAULT VALUES",
25                dialect.quote_identifier(&table.name)?
26            ),
27            SqlDialect::MySql => format!(
28                "INSERT INTO {} () VALUES ()",
29                dialect.quote_identifier(&table.name)?
30            ),
31        }
32    } else {
33        format!(
34            "INSERT INTO {} ({}) VALUES ({})",
35            dialect.quote_identifier(&table.name)?,
36            columns.join(", "),
37            placeholders.join(", ")
38        )
39    };
40    if dialect.supports_insert_returning() && table_has_database_generated_id(table) {
41        sql.push_str(" RETURNING ");
42        sql.push_str(
43            &selection
44                .iter()
45                .map(|selected| dialect.quote_identifier(&selected.field.name))
46                .collect::<Result<Vec<_>, _>>()?
47                .join(", "),
48        );
49    }
50
51    Ok(SqlStatement { sql, params })
52}
53
54pub fn create_returning_selection(
55    schema: &DbSchema,
56    query: &Create,
57) -> Result<Vec<SqlSelectedField>, OpenAuthError> {
58    selected_fields(resolve_table(schema, &query.model)?, &query.select)
59}
60
61pub fn find_one_statement(
62    dialect: SqlDialect,
63    schema: &DbSchema,
64    query: &FindOne,
65) -> Result<SqlReadStatement, OpenAuthError> {
66    let mut find_many = FindMany::new(query.model.clone());
67    find_many.where_clauses = query.where_clauses.clone();
68    find_many.limit = Some(1);
69    find_many.select = query.select.clone();
70    find_many.joins = query.joins.clone();
71    find_many_statement(dialect, schema, &find_many)
72}
73
74pub fn find_many_statement(
75    dialect: SqlDialect,
76    schema: &DbSchema,
77    query: &FindMany,
78) -> Result<SqlReadStatement, OpenAuthError> {
79    let table = resolve_table(schema, &query.model)?;
80    let selection = selected_fields(table, &query.select)?;
81    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
82    let sql = format!(
83        "SELECT {} FROM {}{}{}",
84        selection
85            .iter()
86            .map(|selected| dialect.quote_identifier(&selected.field.name))
87            .collect::<Result<Vec<_>, _>>()?
88            .join(", "),
89        dialect.quote_identifier(&table.name)?,
90        where_sql.sql,
91        dialect.order_limit_offset(table, query.sort_by.as_ref(), query.limit, query.offset)?
92    );
93
94    Ok(SqlReadStatement {
95        statement: SqlStatement {
96            sql,
97            params: where_sql.params,
98        },
99        selection,
100    })
101}
102
103pub fn find_many_with_joins_statement<'a>(
104    dialect: SqlDialect,
105    schema: &'a DbSchema,
106    query: &FindMany,
107) -> Result<SqlJoinReadStatement<'a>, OpenAuthError> {
108    let (base_logical, table) = resolve_table_with_logical(schema, &query.model)?;
109    let joins = resolve_native_joins(schema, base_logical, table, &query.joins, 100)?;
110    let base_selection = internal_base_selection(table, &query.select, &joins)?;
111    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
112    let base_columns = base_selection
113        .iter()
114        .map(|(_, field)| dialect.quote_identifier(&field.name))
115        .collect::<Result<Vec<_>, _>>()?;
116    let base_sql = format!(
117        "SELECT {} FROM {}{}{}",
118        base_columns.join(", "),
119        dialect.quote_identifier(&table.name)?,
120        where_sql.sql,
121        dialect.order_limit_offset(table, query.sort_by.as_ref(), query.limit, query.offset)?
122    );
123
124    let mut selects = vec![format!(
125        "{}.{} AS {}",
126        dialect.quote_identifier("base")?,
127        dialect.quote_identifier(&resolve_field_from_selection(&base_selection, "id")?.name)?,
128        dialect.quote_identifier("__base_id")?
129    )];
130    for (index, (_, field)) in base_selection.iter().enumerate() {
131        selects.push(format!(
132            "{}.{} AS {}",
133            dialect.quote_identifier("base")?,
134            dialect.quote_identifier(&field.name)?,
135            dialect.quote_identifier(&base_alias(index))?
136        ));
137    }
138    for (join_index, join) in joins.iter().enumerate() {
139        for (field_index, (_, field)) in join.selection.iter().enumerate() {
140            selects.push(format!(
141                "{}.{} AS {}",
142                dialect.quote_identifier(&join_alias(join_index))?,
143                dialect.quote_identifier(&field.name)?,
144                dialect.quote_identifier(&join_field_alias(join_index, field_index))?
145            ));
146        }
147    }
148
149    let mut sql = format!(
150        "SELECT {} FROM ({}) AS {}",
151        selects.join(", "),
152        base_sql,
153        dialect.quote_identifier("base")?
154    );
155    for (index, join) in joins.iter().enumerate() {
156        sql.push_str(" LEFT JOIN ");
157        sql.push_str(&dialect.quote_identifier(&join.table.name)?);
158        sql.push_str(" AS ");
159        sql.push_str(&dialect.quote_identifier(&join_alias(index))?);
160        sql.push_str(" ON ");
161        sql.push_str(&dialect.quote_identifier(&join_alias(index))?);
162        sql.push('.');
163        sql.push_str(&dialect.quote_identifier(&join.to)?);
164        sql.push_str(" = ");
165        sql.push_str(&dialect.quote_identifier("base")?);
166        sql.push('.');
167        sql.push_str(&dialect.quote_identifier(&join.from)?);
168    }
169
170    Ok(SqlJoinReadStatement {
171        statement: SqlStatement {
172            sql,
173            params: where_sql.params,
174        },
175        base_selection,
176        joins,
177    })
178}
179
180pub fn count_statement(
181    dialect: SqlDialect,
182    schema: &DbSchema,
183    query: &Count,
184) -> Result<SqlStatement, OpenAuthError> {
185    let table = resolve_table(schema, &query.model)?;
186    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
187    Ok(SqlStatement {
188        sql: format!(
189            "SELECT COUNT(*) FROM {}{}",
190            dialect.quote_identifier(&table.name)?,
191            where_sql.sql
192        ),
193        params: where_sql.params,
194    })
195}
196
197pub fn update_one_plan(
198    dialect: SqlDialect,
199    schema: &DbSchema,
200    query: &Update,
201) -> Result<SqlUpdateOnePlan, OpenAuthError> {
202    let table = resolve_table(schema, &query.model)?;
203    let selection = selected_fields(table, &[])?;
204
205    match dialect {
206        SqlDialect::Postgres | SqlDialect::Sqlite => {
207            let assignment = update_assignment(dialect, table, &query.data, 1)?;
208            let where_sql =
209                dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
210            let row_id = match dialect {
211                SqlDialect::Postgres => "ctid",
212                SqlDialect::Sqlite => "rowid",
213                SqlDialect::MySql => {
214                    return Err(OpenAuthError::Adapter(
215                        "mysql update-one uses a preselect plan".to_owned(),
216                    ));
217                }
218            };
219            let mut params = assignment.params;
220            params.extend(where_sql.params);
221            Ok(SqlUpdateOnePlan::Returning(SqlReadStatement {
222                statement: SqlStatement {
223                    sql: format!(
224                        "UPDATE {} SET {} WHERE {row_id} IN (SELECT {row_id} FROM {}{} LIMIT 1) RETURNING {}",
225                        dialect.quote_identifier(&table.name)?,
226                        assignment.sql.join(", "),
227                        dialect.quote_identifier(&table.name)?,
228                        where_sql.sql,
229                        selection
230                            .iter()
231                            .map(|selected| dialect.quote_identifier(&selected.field.name))
232                            .collect::<Result<Vec<_>, _>>()?
233                            .join(", ")
234                    ),
235                    params,
236                },
237                selection,
238            }))
239        }
240        SqlDialect::MySql => {
241            let mut select_query = FindMany::new(query.model.clone());
242            select_query.where_clauses = query.where_clauses.clone();
243            select_query.limit = Some(1);
244            let select = find_many_statement(dialect, schema, &select_query)?;
245            let assignment = update_assignment(dialect, table, &query.data, 1)?;
246            let where_sql =
247                dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
248            let mut params = assignment.params;
249            params.extend(where_sql.params);
250            Ok(SqlUpdateOnePlan::PreselectThenUpdate {
251                select,
252                update: SqlStatement {
253                    sql: format!(
254                        "UPDATE {} SET {}{} LIMIT 1",
255                        dialect.quote_identifier(&table.name)?,
256                        assignment.sql.join(", "),
257                        where_sql.sql
258                    ),
259                    params,
260                },
261                data: query.data.clone(),
262            })
263        }
264    }
265}
266
267pub fn update_many_statement(
268    dialect: SqlDialect,
269    schema: &DbSchema,
270    query: &UpdateMany,
271) -> Result<SqlStatement, OpenAuthError> {
272    let table = resolve_table(schema, &query.model)?;
273    let assignment = update_assignment(dialect, table, &query.data, 1)?;
274    let where_sql =
275        dialect.where_clause_starting_at(table, &query.where_clauses, assignment.next)?;
276    let mut params = assignment.params;
277    params.extend(where_sql.params);
278    Ok(SqlStatement {
279        sql: format!(
280            "UPDATE {} SET {}{}",
281            dialect.quote_identifier(&table.name)?,
282            assignment.sql.join(", "),
283            where_sql.sql
284        ),
285        params,
286    })
287}
288
289pub fn delete_one_statement(
290    dialect: SqlDialect,
291    schema: &DbSchema,
292    query: &Delete,
293) -> Result<SqlDeleteOnePlan, OpenAuthError> {
294    let table = resolve_table(schema, &query.model)?;
295    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
296    let statement = match dialect {
297        SqlDialect::Postgres => SqlStatement {
298            sql: format!(
299                "DELETE FROM {} WHERE ctid IN (SELECT ctid FROM {}{} LIMIT 1)",
300                dialect.quote_identifier(&table.name)?,
301                dialect.quote_identifier(&table.name)?,
302                where_sql.sql
303            ),
304            params: where_sql.params,
305        },
306        SqlDialect::Sqlite => SqlStatement {
307            sql: format!(
308                "DELETE FROM {} WHERE rowid IN (SELECT rowid FROM {}{} LIMIT 1)",
309                dialect.quote_identifier(&table.name)?,
310                dialect.quote_identifier(&table.name)?,
311                where_sql.sql
312            ),
313            params: where_sql.params,
314        },
315        SqlDialect::MySql => SqlStatement {
316            sql: format!(
317                "DELETE FROM {}{} LIMIT 1",
318                dialect.quote_identifier(&table.name)?,
319                where_sql.sql
320            ),
321            params: where_sql.params,
322        },
323    };
324    let strategy = match dialect {
325        SqlDialect::Postgres | SqlDialect::Sqlite => DeleteOneStrategy::NestedId,
326        SqlDialect::MySql => DeleteOneStrategy::Limit,
327    };
328    Ok(SqlDeleteOnePlan {
329        statement,
330        strategy,
331    })
332}
333
334pub fn delete_many_statement(
335    dialect: SqlDialect,
336    schema: &DbSchema,
337    query: &DeleteMany,
338) -> Result<SqlStatement, OpenAuthError> {
339    let table = resolve_table(schema, &query.model)?;
340    let where_sql = dialect.where_clause(table, &query.where_clauses)?;
341    Ok(SqlStatement {
342        sql: format!(
343            "DELETE FROM {}{}",
344            dialect.quote_identifier(&table.name)?,
345            where_sql.sql
346        ),
347        params: where_sql.params,
348    })
349}
350
351struct UpdateAssignment {
352    sql: Vec<String>,
353    params: Vec<SqlParam>,
354    next: usize,
355}
356
357fn update_assignment(
358    dialect: SqlDialect,
359    table: &DbTable,
360    data: &DbRecord,
361    first_placeholder: usize,
362) -> Result<UpdateAssignment, OpenAuthError> {
363    let mut sql = Vec::new();
364    let mut params = Vec::new();
365    for (field, value) in data {
366        let (_, metadata) = resolve_field(table, field)?;
367        params.push(SqlParam::new(metadata, value.clone()));
368        sql.push(format!(
369            "{} = {}",
370            dialect.quote_identifier(&metadata.name)?,
371            dialect.placeholder(first_placeholder + params.len() - 1)
372        ));
373    }
374    Ok(UpdateAssignment {
375        sql,
376        next: first_placeholder + params.len(),
377        params,
378    })
379}
380
381pub fn table_has_database_generated_id(table: &DbTable) -> bool {
382    table
383        .field("id")
384        .and_then(|field| field.generated_id)
385        .is_some()
386}
387
388impl SqlDialect {
389    pub fn supports_insert_returning(self) -> bool {
390        matches!(self, Self::Postgres | Self::Sqlite)
391    }
392}