Skip to main content

saola_query_builder/
lib.rs

1use query_structure::{
2    AggregationSelection, FieldSelection, Filter, Model, Placeholder, PrismaValue, QueryArguments, RecordFilter,
3    RelationField, RelationLoadStrategy, ScalarCondition, ScalarField, SelectedField, SelectionResult, WriteArgs,
4};
5use serde::Serialize;
6use std::collections::BTreeMap;
7use std::fmt::Formatter;
8use std::{collections::HashMap, fmt};
9
10mod query_arguments_ext;
11
12pub use query_arguments_ext::QueryArgumentsExt;
13use query_template::{Fragment, PlaceholderFormat};
14
/// Translates query-engine operations into [`DbQuery`] values.
///
/// Every method reports failure as a `Box<dyn std::error::Error + Send + Sync>`, so errors
/// from any implementation can be propagated across threads.
pub trait QueryBuilder {
    /// Builds the query or queries that read records of `model` according to
    /// `query_arguments`, selecting `selected_fields`.
    ///
    /// May return several queries (hence the `Vec`).
    // NOTE(review): `relation_load_strategy` presumably chooses between loading relations with
    // joins or with separate queries — confirm against the implementations.
    fn build_get_records(
        &self,
        model: &Model,
        query_arguments: QueryArguments,
        selected_fields: &FieldSelection,
        relation_load_strategy: RelationLoadStrategy,
    ) -> Result<Vec<DbQuery>, Box<dyn std::error::Error + Send + Sync>>;

    /// Retrieve related records through an M2M relation.
    ///
    /// The relation and the per-field conditions come from `linkage` (see [`RelationLinkage`]).
    /// Only available when the `relation_joins` feature is enabled.
    #[cfg(feature = "relation_joins")]
    fn build_get_related_records(
        &self,
        linkage: RelationLinkage,
        query_arguments: QueryArguments,
        selected_fields: &FieldSelection,
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds a single aggregation query over `model` computing `selections`, grouped by
    /// `group_by`.
    // NOTE(review): `having` is presumably the post-grouping filter (SQL `HAVING`) — confirm.
    fn build_aggregate(
        &self,
        model: &Model,
        args: QueryArguments,
        selections: &[AggregationSelection],
        group_by: Vec<ScalarField>,
        having: Option<Filter>,
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds the insert for one record of `model` from `args`, selecting `selected_fields`.
    ///
    /// Besides the insert query, the returned [`CreateRecord`] can carry a defaults query, a
    /// last-insert-id field and values to merge into the resulting record.
    fn build_create_record(
        &self,
        model: &Model,
        args: WriteArgs,
        selected_fields: &FieldSelection,
    ) -> Result<CreateRecord, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds the insert queries for several records of `model`; may return several queries
    /// (hence the `Vec`).
    // NOTE(review): presumably one `WriteArgs` per record; `skip_duplicates` presumably ignores
    // conflicting rows and a `Some` `selected_fields` presumably requests the inserted rows
    // back — confirm.
    fn build_inserts(
        &self,
        model: &Model,
        args: Vec<WriteArgs>,
        skip_duplicates: bool,
        selected_fields: Option<&FieldSelection>,
    ) -> Result<Vec<DbQuery>, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds a single update query applying `args` to the records of `model` selected by
    /// `record_filter`.
    // NOTE(review): a `Some` `selected_fields` presumably asks for the updated rows to be
    // returned — confirm.
    fn build_update(
        &self,
        model: &Model,
        record_filter: RecordFilter,
        args: WriteArgs,
        selected_fields: Option<&FieldSelection>,
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;

    /// Like `build_update`, but may return several queries (hence the `Vec`).
    // NOTE(review): `limit` presumably caps the number of updated records — confirm.
    fn build_updates(
        &self,
        model: &Model,
        record_filter: RecordFilter,
        args: WriteArgs,
        selected_fields: Option<&FieldSelection>,
        limit: Option<usize>,
    ) -> Result<Vec<DbQuery>, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds a single upsert query for `model`: the record matched by `filter` is either
    /// created from `create_args` or updated with `update_args`, selecting `selected_fields`.
    // NOTE(review): `unique_constraints` presumably lists the fields used as the conflict
    // target — confirm.
    fn build_upsert(
        &self,
        model: &Model,
        filter: Filter,
        create_args: WriteArgs,
        update_args: WriteArgs,
        selected_fields: &FieldSelection,
        unique_constraints: &[ScalarField],
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds the query that links `parent` and `child` through the many-to-many relation of
    /// `parent_field`.
    fn build_m2m_connect(
        &self,
        parent_field: RelationField,
        parent: PrismaValue,
        child: PrismaValue,
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds the query that unlinks `parent_id` from each of `child_ids` in the many-to-many
    /// relation of `field`.
    fn build_m2m_disconnect(
        &self,
        field: RelationField,
        parent_id: &SelectionResult,
        child_ids: &[SelectionResult],
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds a single delete query for the records of `model` selected by `filter`.
    // NOTE(review): a `Some` `selected_fields` presumably asks for the deleted rows to be
    // returned — confirm.
    fn build_delete(
        &self,
        model: &Model,
        filter: RecordFilter,
        selected_fields: Option<&FieldSelection>,
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds the delete queries for the records of `model` selected by `filter`; may return
    /// several queries (hence the `Vec`).
    // NOTE(review): `limit` presumably caps the number of deleted records — confirm.
    fn build_deletes(
        &self,
        model: &Model,
        filter: RecordFilter,
        limit: Option<usize>,
    ) -> Result<Vec<DbQuery>, Box<dyn std::error::Error + Send + Sync>>;

    /// Builds a raw query from the named `inputs`, optionally associated with `model`.
    // NOTE(review): the expected keys of `inputs` and the meaning of `query_type` are not
    // visible here — document them once confirmed against the implementations.
    fn build_raw(
        &self,
        model: Option<&Model>,
        inputs: HashMap<String, PrismaValue>,
        query_type: Option<String>,
    ) -> Result<DbQuery, Box<dyn std::error::Error + Send + Sync>>;
}
119
120/// An insertion operation for a record in the database.
121pub struct CreateRecord {
122    /// The insert query to run in order to create the record.
123    pub insert_query: DbQuery,
124    /// The query to run prior to the insert in order to create default column values.
125    /// This is used in some cases where the database does not support returning default values.
126    pub select_defaults: Option<CreateRecordDefaultsQuery>,
127    /// The field in the model of the record that corresponds to the last inserted ID, if
128    /// required by the database.
129    pub last_insert_id_field: Option<ScalarField>,
130    /// The values to merge into the resulting record after insertion. These are inferred from the
131    /// input arguments.
132    pub merge_values: Vec<(SelectedField, PrismaValue)>,
133}
134
135/// A query that retrieves default values needed for an insert operation.
136pub struct CreateRecordDefaultsQuery {
137    /// The query that returns the default values.
138    pub query: DbQuery,
139    /// The fields that are selected in the query and their corresponding placeholders.
140    /// These placeholders are referred to by the subsequent insert query.
141    pub field_placeholders: Vec<(ScalarField, Placeholder)>,
142}
143
144#[derive(Debug)]
145pub struct ConditionalLink {
146    field: ScalarField,
147    conditions: Vec<ScalarCondition>,
148}
149
150impl ConditionalLink {
151    pub fn new(field: ScalarField, conditions: Vec<ScalarCondition>) -> Self {
152        Self { field, conditions }
153    }
154
155    pub fn field(&self) -> &ScalarField {
156        &self.field
157    }
158
159    pub fn into_field_and_conditions(self) -> (ScalarField, Vec<ScalarCondition>) {
160        (self.field, self.conditions)
161    }
162}
163
164#[derive(Debug)]
165pub struct RelationLinkage {
166    parent_field: RelationField,
167    conditions: BTreeMap<ScalarField, Vec<ScalarCondition>>,
168}
169
170impl RelationLinkage {
171    pub fn new(field: RelationField, links: Vec<ConditionalLink>) -> Self {
172        Self {
173            parent_field: field,
174            conditions: links
175                .into_iter()
176                .map(ConditionalLink::into_field_and_conditions)
177                .collect(),
178        }
179    }
180
181    pub fn parent_field(&self) -> &RelationField {
182        &self.parent_field
183    }
184
185    pub fn add_condition(&mut self, field: ScalarField, condition: ScalarCondition) {
186        self.conditions.entry(field).or_default().push(condition);
187    }
188
189    pub fn into_parent_field_and_conditions(
190        self,
191    ) -> (
192        RelationField,
193        impl Iterator<Item = (ScalarField, Vec<ScalarCondition>)> + fmt::Debug,
194    ) {
195        (self.parent_field, self.conditions.into_iter())
196    }
197}
198
199impl fmt::Display for RelationLinkage {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        write!(
202            f,
203            "{}@{}",
204            self.parent_field.relation().name(),
205            self.parent_field.model().name()
206        )
207    }
208}
209
/// The query type returned by [`QueryBuilder`] methods.
///
/// Serialized as an internally tagged object: the `type` key holds the variant name in
/// camelCase (`"rawSql"` or `"templateSql"`), and each variant's fields are renamed to
/// camelCase (e.g. `argTypes`, `placeholderFormat`).
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum DbQuery {
    /// A complete SQL string with its arguments; `Display` prints `sql` verbatim.
    #[serde(rename_all = "camelCase")]
    RawSql {
        sql: String,
        args: Vec<PrismaValue>,
        // NOTE(review): presumably one entry per element of `args` — confirm.
        arg_types: Vec<ArgType>,
    },
    /// SQL expressed as a sequence of [`Fragment`]s: string chunks and parameter slots.
    #[serde(rename_all = "camelCase")]
    TemplateSql {
        fragments: Vec<Fragment>,
        args: Vec<PrismaValue>,
        // NOTE(review): presumably one entry per element of `args` — confirm.
        arg_types: Vec<DynamicArgType>,
        // NOTE(review): presumably how parameter slots are rendered for the target
        // database — confirm.
        placeholder_format: PlaceholderFormat,
        /// Whether the parameters of this query can be chunked into smaller queries.
        chunkable: Chunkable,
    },
}
228
229impl DbQuery {
230    pub fn params(&self) -> &[PrismaValue] {
231        match self {
232            DbQuery::RawSql { args: params, .. } => params,
233            DbQuery::TemplateSql { args: params, .. } => params,
234        }
235    }
236}
237
238impl fmt::Display for DbQuery {
239    /// Should only be used for debugging, unit testing and playground CLI output.
240    /// The placeholder syntax does not attempt to match any actual SQL flavour.
241    fn fmt(&self, formatter: &mut Formatter<'_>) -> fmt::Result {
242        match self {
243            DbQuery::RawSql { sql, .. } => {
244                write!(formatter, "{sql}")?;
245            }
246            DbQuery::TemplateSql { fragments, .. } => {
247                let placeholder_format = PlaceholderFormat {
248                    prefix: "$",
249                    has_numbering: true,
250                };
251                let mut number = 1;
252                for fragment in fragments {
253                    match fragment {
254                        Fragment::StringChunk { chunk } => {
255                            write!(formatter, "{chunk}")?;
256                        }
257                        Fragment::Parameter => {
258                            placeholder_format.write(formatter, &mut number)?;
259                        }
260                        Fragment::ParameterTuple => {
261                            write!(formatter, "[")?;
262                            placeholder_format.write(formatter, &mut number)?;
263                            write!(formatter, "]")?;
264                        }
265                        Fragment::ParameterTupleList { .. } => {
266                            write!(formatter, "[(")?;
267                            placeholder_format.write(formatter, &mut number)?;
268                            write!(formatter, ")]")?;
269                        }
270                    };
271                }
272            }
273        }
274        Ok(())
275    }
276}
277
/// The type of an argument of a [`DbQuery::TemplateSql`] query: a tuple of element types or
/// a single type.
///
/// Serialization: `Tuple` is tagged as `{"arity": "tuple", "elements": [...]}`. `Single` is
/// untagged and flattens its [`ArgType`], so it serializes as the `ArgType` fields themselves,
/// whose own `arity` is `"scalar"` or `"list"`. The `arity` key therefore distinguishes all
/// cases.
#[derive(Debug, Serialize)]
#[serde(tag = "arity", rename_all = "camelCase")]
pub enum DynamicArgType {
    Tuple {
        elements: Vec<ArgType>,
    },
    #[serde(untagged)]
    Single {
        #[serde(flatten)]
        r#type: ArgType,
    },
}
290
/// The type of a single query argument.
///
/// Serialized with camelCase keys: `arity`, `scalarType` and `dbType` (`null` when `db_type`
/// is `None`).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ArgType {
    /// Whether the argument is a single value or a list.
    pub arity: Arity,
    /// The scalar type of the argument.
    // NOTE(review): for list arguments this is presumably the element type — confirm.
    pub scalar_type: ArgScalarType,
    /// An optional database-specific type name.
    // NOTE(review): presumably the native column type name — confirm.
    pub db_type: Option<String>,
}
298
299impl ArgType {
300    pub fn new(arity: Arity, scalar_type: ArgScalarType, db_type: Option<String>) -> Self {
301        Self {
302            arity,
303            scalar_type,
304            db_type,
305        }
306    }
307}
308
/// Whether a query argument is a single value or a list of values.
///
/// Serialized as `"scalar"` or `"list"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Arity {
    Scalar,
    List,
}
315
316#[derive(Debug, Serialize)]
317#[serde(rename_all = "camelCase")]
318pub enum ArgScalarType {
319    String,
320    Int,
321    #[serde(rename = "bigint")]
322    BigInt,
323    Float,
324    Decimal,
325    Boolean,
326    Enum,
327    Uuid,
328    Json,
329    #[serde(rename = "datetime")]
330    DateTime,
331    Bytes,
332    Unknown,
333}
334
/// Indicates whether the parameters of this query can be chunked into smaller queries.
///
/// Serialized as a plain boolean (`true` for [`Chunkable::Yes`]) through the
/// `From<Chunkable> for bool` conversion.
// `Clone` is required by `#[serde(into = "bool")]`, which converts a clone of the value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(into = "bool")]
pub enum Chunkable {
    Yes,
    No,
}
342
343impl From<Chunkable> for bool {
344    fn from(chunkable: Chunkable) -> Self {
345        matches!(chunkable, Chunkable::Yes)
346    }
347}
348
349impl From<&QueryArguments> for Chunkable {
350    fn from(args: &QueryArguments) -> Self {
351        if !args.order_by.is_empty()
352            || args.cursor.is_some()
353            || args.has_unbatchable_filters()
354            || args.has_unbatchable_ordering()
355        {
356            Chunkable::No
357        } else {
358            Chunkable::Yes
359        }
360    }
361}
362
363impl From<&Filter> for Chunkable {
364    fn from(filter: &Filter) -> Self {
365        if filter.can_batch() {
366            Chunkable::Yes
367        } else {
368            Chunkable::No
369        }
370    }
371}