pgx_utils/sql_entity_graph/aggregate/
entity.rs

1/*
2Portions Copyright 2019-2021 ZomboDB, LLC.
3Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>
4
5All rights reserved.
6
7Use of this source code is governed by the MIT license that can be found in the LICENSE file.
8*/
9/*!
10
11`#[pg_aggregate]` related entities for Rust to SQL translation
12
13> Like all of the [`sql_entity_graph`][crate::sql_entity_graph] APIs, this is considered **internal**
14to the `pgx` framework and very subject to change between versions. While you may use this, please do it with caution.
15
16
17*/
18use crate::sql_entity_graph::aggregate::options::{FinalizeModify, ParallelOption};
19use crate::sql_entity_graph::metadata::SqlMapping;
20use crate::sql_entity_graph::pgx_sql::PgxSql;
21use crate::sql_entity_graph::to_sql::entity::ToSqlConfigEntity;
22use crate::sql_entity_graph::to_sql::ToSql;
23use crate::sql_entity_graph::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
24use core::any::TypeId;
25use core::cmp::Ordering;
26use eyre::{eyre, WrapErr};
27
/// One type slot of an aggregate — an aggregated argument, a direct (ordered-set) argument,
/// or the state type — together with an optional SQL-level name.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct AggregateTypeEntity {
    /// The Rust type used in this slot, including its SQL mapping metadata.
    pub used_ty: UsedTypeEntity,
    /// Optional SQL name. When set on an argument, SQL generation emits it as a
    /// double-quoted identifier in front of the argument's type.
    pub name: Option<&'static str>,
}
33
/// Metadata for one `#[pg_aggregate]` implementation, from which its
/// [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
/// statement is generated.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct PgAggregateEntity {
    /// Fully-qualified Rust path of the aggregate. Used as its identifier in the entity graph
    /// and inside the `/* ... */` comments of the generated SQL.
    pub full_path: &'static str,
    /// Rust module path of the declaration (not read by the SQL generation in this file).
    pub module_path: &'static str,
    /// Source file of the declaration; emitted in the `-- file:line` header of the SQL.
    pub file: &'static str,
    /// Source line of the declaration; emitted in the `-- file:line` header of the SQL.
    pub line: u32,
    /// `TypeId` associated with this aggregate — presumably that of the `Aggregate`
    /// implementor; TODO confirm against the `#[pg_aggregate]` macro expansion.
    pub ty_id: TypeId,

    /// SQL name of the aggregate, emitted as `CREATE AGGREGATE {schema}{name}`.
    pub name: &'static str,

    /// If the aggregate is an ordered set aggregate.
    ///
    /// When set, `ORDER BY` is emitted between the direct arguments and the aggregated
    /// arguments.
    ///
    /// See [the PostgreSQL ordered set docs](https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES).
    pub ordered_set: bool,

    /// The `arg_data_type` list.
    ///
    /// Corresponds to `Args` in [`pgx::aggregate::Aggregate`].
    pub args: Vec<AggregateTypeEntity>,

    /// The direct argument list, appearing before `ORDER BY` in ordered set aggregates.
    ///
    /// Corresponds to `OrderBy` in [`pgx::aggregate::Aggregate`].
    pub direct_args: Option<Vec<AggregateTypeEntity>>,

    /// The `STYPE` and `name` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// The implementor of an [`pgx::aggregate::Aggregate`].
    pub stype: AggregateTypeEntity,

    /// The `SFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `state` in [`pgx::aggregate::Aggregate`].
    pub sfunc: &'static str,

    /// The `FINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `finalize` in [`pgx::aggregate::Aggregate`].
    pub finalfunc: Option<&'static str>,

    /// The `FINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `FINALIZE_MODIFY` in [`pgx::aggregate::Aggregate`].
    pub finalfunc_modify: Option<FinalizeModify>,

    /// The `COMBINEFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `combine` in [`pgx::aggregate::Aggregate`].
    pub combinefunc: Option<&'static str>,

    /// The `SERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `serial` in [`pgx::aggregate::Aggregate`].
    pub serialfunc: Option<&'static str>,

    /// The `DESERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `deserial` in [`pgx::aggregate::Aggregate`].
    pub deserialfunc: Option<&'static str>,

    /// The `INITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Emitted verbatim between single quotes; no quote escaping is applied.
    ///
    /// Corresponds to `INITIAL_CONDITION` in [`pgx::aggregate::Aggregate`].
    pub initcond: Option<&'static str>,

    /// The `MSFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `moving_state` in [`pgx::aggregate::Aggregate`].
    pub msfunc: Option<&'static str>,

    /// The `MINVFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `moving_state_inverse` in [`pgx::aggregate::Aggregate`].
    pub minvfunc: Option<&'static str>,

    /// The `MSTYPE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `MovingState` in [`pgx::aggregate::Aggregate`].
    pub mstype: Option<UsedTypeEntity>,

    // The `MSSPACE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    //
    // TODO: Currently unused; no `MSSPACE` attribute is generated.
    // pub msspace: &'static str,
    /// The `MFINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `moving_state_finalize` in [`pgx::aggregate::Aggregate`].
    pub mfinalfunc: Option<&'static str>,

    /// The `MFINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `MOVING_FINALIZE_MODIFY` in [`pgx::aggregate::Aggregate`].
    pub mfinalfunc_modify: Option<FinalizeModify>,

    /// The `MINITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Emitted verbatim between single quotes; no quote escaping is applied.
    ///
    /// Corresponds to `MOVING_INITIAL_CONDITION` in [`pgx::aggregate::Aggregate`].
    pub minitcond: Option<&'static str>,

    /// The `SORTOP` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Emitted as a double-quoted name without a schema prefix.
    ///
    /// Corresponds to `SORT_OPERATOR` in [`pgx::aggregate::Aggregate`].
    pub sortop: Option<&'static str>,

    /// The `PARALLEL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `PARALLEL` in [`pgx::aggregate::Aggregate`].
    pub parallel: Option<ParallelOption>,

    /// The `HYPOTHETICAL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `hypothetical` in [`pgx::aggregate::Aggregate`].
    pub hypothetical: bool,
    /// SQL-generation configuration for this entity (not read by `to_sql` in this file).
    pub to_sql_config: ToSqlConfigEntity,
}
149
150impl Ord for PgAggregateEntity {
151    fn cmp(&self, other: &Self) -> Ordering {
152        self.file.cmp(other.full_path).then_with(|| self.file.cmp(other.full_path))
153    }
154}
155
156impl PartialOrd for PgAggregateEntity {
157    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
158        Some(self.cmp(other))
159    }
160}
161
162impl From<PgAggregateEntity> for SqlGraphEntity {
163    fn from(val: PgAggregateEntity) -> Self {
164        SqlGraphEntity::Aggregate(val)
165    }
166}
167
168impl SqlGraphIdentifier for PgAggregateEntity {
169    fn dot_identifier(&self) -> String {
170        format!("aggregate {}", self.full_path)
171    }
172    fn rust_identifier(&self) -> String {
173        self.full_path.to_string()
174    }
175    fn file(&self) -> Option<&'static str> {
176        Some(self.file)
177    }
178    fn line(&self) -> Option<u32> {
179        Some(self.line)
180    }
181}
182
183impl ToSql for PgAggregateEntity {
184    #[tracing::instrument(level = "debug", err, skip(self, context), fields(identifier = %self.rust_identifier()))]
185    fn to_sql(&self, context: &PgxSql) -> eyre::Result<String> {
186        let self_index = context.aggregates[self];
187        let mut optional_attributes = Vec::new();
188        let schema = context.schema_prefix_for(&self_index);
189
190        if let Some(value) = self.finalfunc {
191            optional_attributes.push((
192                format!("\tFINALFUNC = {}\"{}\"", schema, value),
193                format!("/* {}::final */", self.full_path),
194            ));
195        }
196        if let Some(value) = self.finalfunc_modify {
197            optional_attributes.push((
198                format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
199                format!("/* {}::FINALIZE_MODIFY */", self.full_path),
200            ));
201        }
202        if let Some(value) = self.combinefunc {
203            optional_attributes.push((
204                format!("\tCOMBINEFUNC = {}\"{}\"", schema, value),
205                format!("/* {}::combine */", self.full_path),
206            ));
207        }
208        if let Some(value) = self.serialfunc {
209            optional_attributes.push((
210                format!("\tSERIALFUNC = {}\"{}\"", schema, value),
211                format!("/* {}::serial */", self.full_path),
212            ));
213        }
214        if let Some(value) = self.deserialfunc {
215            optional_attributes.push((
216                format!("\tDESERIALFUNC ={} \"{}\"", schema, value),
217                format!("/* {}::deserial */", self.full_path),
218            ));
219        }
220        if let Some(value) = self.initcond {
221            optional_attributes.push((
222                format!("\tINITCOND = '{}'", value),
223                format!("/* {}::INITIAL_CONDITION */", self.full_path),
224            ));
225        }
226        if let Some(value) = self.msfunc {
227            optional_attributes.push((
228                format!("\tMSFUNC = {}\"{}\"", schema, value),
229                format!("/* {}::moving_state */", self.full_path),
230            ));
231        }
232        if let Some(value) = self.minvfunc {
233            optional_attributes.push((
234                format!("\tMINVFUNC = {}\"{}\"", schema, value),
235                format!("/* {}::moving_state_inverse */", self.full_path),
236            ));
237        }
238        if let Some(value) = self.mfinalfunc {
239            optional_attributes.push((
240                format!("\tMFINALFUNC = {}\"{}\"", schema, value),
241                format!("/* {}::moving_state_finalize */", self.full_path),
242            ));
243        }
244        if let Some(value) = self.mfinalfunc_modify {
245            optional_attributes.push((
246                format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
247                format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
248            ));
249        }
250        if let Some(value) = self.minitcond {
251            optional_attributes.push((
252                format!("\tMINITCOND = '{}'", value),
253                format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
254            ));
255        }
256        if let Some(value) = self.sortop {
257            optional_attributes.push((
258                format!("\tSORTOP = \"{}\"", value),
259                format!("/* {}::SORT_OPERATOR */", self.full_path),
260            ));
261        }
262        if let Some(value) = self.parallel {
263            optional_attributes.push((
264                format!("\tPARALLEL = {}", value.to_sql(context)?),
265                format!("/* {}::PARALLEL */", self.full_path),
266            ));
267        }
268        if self.hypothetical {
269            optional_attributes.push((
270                String::from("\tHYPOTHETICAL"),
271                format!("/* {}::hypothetical */", self.full_path),
272            ))
273        }
274
275        let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
276            match used_ty.metadata.argument_sql {
277                Ok(SqlMapping::As(ref argument_sql)) => Ok(argument_sql.to_string()),
278                Ok(SqlMapping::Composite { array_brackets }) => used_ty
279                    .composite_type
280                    .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
281                    .ok_or_else(|| {
282                        eyre!("Macro expansion time suggested a composite_type!() in return")
283                    }),
284                Ok(SqlMapping::Source { array_brackets }) => {
285                    let sql = context
286                        .source_only_to_sql_type(used_ty.ty_source)
287                        .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
288                        .ok_or_else(|| {
289                            eyre!("Macro expansion time suggested a source only mapping in return")
290                        })?;
291                    Ok(sql)
292                }
293                Ok(SqlMapping::Skip) => {
294                    Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
295                }
296                Err(err) => match context.source_only_to_sql_type(used_ty.ty_source) {
297                    Some(source_only_mapping) => Ok(source_only_mapping.to_string()),
298                    None => return Err(err).wrap_err("While mapping argument"),
299                },
300            }
301        };
302
303        let stype_sql = map_ty(&self.stype.used_ty).wrap_err("Mapping state type")?;
304
305        if let Some(value) = &self.mstype {
306            let mstype_sql = map_ty(&value).wrap_err("Mapping moving state type")?;
307            optional_attributes.push((
308                format!("\tMSTYPE = {}", mstype_sql),
309                format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
310            ));
311        }
312
313        let mut optional_attributes_string = String::new();
314        for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
315            let optional_attribute_string = format!(
316                "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
317                optional_attribute = optional_attribute,
318                maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
319                comment = comment,
320                maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
321            );
322            optional_attributes_string += &optional_attribute_string;
323        }
324
325        let sql = format!(
326            "\n\
327                -- {file}:{line}\n\
328                -- {full_path}\n\
329                CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
330                (\n\
331                    \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
332                    \tSTYPE = {schema}{stype}{maybe_comma_after_stype} /* {stype_full_path} */\
333                    {optional_attributes}\
334                );\
335            ",
336            schema = schema,
337            name = self.name,
338            full_path = self.full_path,
339            file = self.file,
340            line = self.line,
341            sfunc = self.sfunc,
342            stype = stype_sql,
343            stype_full_path = self.stype.used_ty.full_path,
344            maybe_comma_after_stype = if optional_attributes.len() == 0 { "" } else { "," },
345            args = {
346                let mut args = Vec::new();
347                for (idx, arg) in self.args.iter().enumerate() {
348                    let graph_index = context
349                        .graph
350                        .neighbors_undirected(self_index)
351                        .find(|neighbor| match &context.graph[*neighbor] {
352                            SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
353                            SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
354                            SqlGraphEntity::BuiltinType(defined) => {
355                                defined == &arg.used_ty.full_path
356                            }
357                            _ => false,
358                        })
359                        .ok_or_else(|| {
360                            eyre!("Could not find arg type in graph. Got: {:?}", arg.used_ty)
361                        })?;
362                    let needs_comma = idx < (self.args.len() - 1);
363                    let buf = format!("\
364                           \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
365                       ",
366                           schema_prefix = context.schema_prefix_for(&graph_index),
367                           // First try to match on [`TypeId`] since it's most reliable.
368                           sql_type = match arg.used_ty.metadata.argument_sql {
369                                Ok(SqlMapping::As(ref argument_sql)) => {
370                                    argument_sql.to_string()
371                                }
372                                Ok(SqlMapping::Composite {
373                                    array_brackets,
374                                }) => {
375                                    arg.used_ty
376                                        .composite_type
377                                        .map(|v| {
378                                            if array_brackets {
379                                                format!("{v}[]")
380                                            } else {
381                                                format!("{v}")
382                                            }
383                                        })
384                                        .ok_or_else(|| {
385                                            eyre!(
386                                            "Macro expansion time suggested a composite_type!() in return"
387                                        )
388                                        })?
389                                }
390                                Ok(SqlMapping::Source {
391                                    array_brackets,
392                                }) => {
393                                    let sql = context
394                                        .source_only_to_sql_type(arg.used_ty.ty_source)
395                                        .map(|v| {
396                                            if array_brackets {
397                                                format!("{v}[]")
398                                            } else {
399                                                format!("{v}")
400                                            }
401                                        })
402                                        .ok_or_else(|| {
403                                            eyre!(
404                                            "Macro expansion time suggested a source only mapping in return"
405                                        )
406                                        })?;
407                                    sql
408                                }
409                                Ok(SqlMapping::Skip) => return Err(eyre!("Got a skipped SQL translatable type in aggregate args, this is not permitted")),
410                                Err(err) => {
411                                    match context.source_only_to_sql_type(arg.used_ty.ty_source) {
412                                        Some(source_only_mapping) => {
413                                            source_only_mapping.to_string()
414                                        }
415                                        None => return Err(err).wrap_err("While mapping argument"),
416                                    }
417                                }
418                            },
419                           variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
420                           maybe_comma = if needs_comma { ", " } else { " " },
421                           full_path = arg.used_ty.full_path,
422                           name = if let Some(name) = arg.name {
423                               format!(r#""{}" "#, name)
424                           } else { "".to_string() },
425                    );
426                    args.push(buf);
427                }
428                "\n".to_string() + &args.join("\n") + "\n"
429            },
430            direct_args = if let Some(direct_args) = &self.direct_args {
431                let mut args = Vec::new();
432                for (idx, arg) in direct_args.iter().enumerate() {
433                    let graph_index = context
434                        .graph
435                        .neighbors_undirected(self_index)
436                        .find(|neighbor| match &context.graph[*neighbor] {
437                            SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
438                            SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
439                            SqlGraphEntity::BuiltinType(defined) => {
440                                defined == &arg.used_ty.full_path
441                            }
442                            _ => false,
443                        })
444                        .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
445                    let needs_comma = idx < (direct_args.len() - 1);
446                    let buf = format!(
447                        "\
448                        \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
449                       ",
450                        schema_prefix = context.schema_prefix_for(&graph_index),
451                        // First try to match on [`TypeId`] since it's most reliable.
452                        sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
453                        maybe_name = if let Some(name) = arg.name {
454                            "\"".to_string() + name + "\" "
455                        } else {
456                            "".to_string()
457                        },
458                        maybe_comma = if needs_comma { ", " } else { " " },
459                        full_path = arg.used_ty.full_path,
460                    );
461                    args.push(buf);
462                }
463                "\n".to_string() + &args.join("\n,") + "\n"
464            } else {
465                String::default()
466            },
467            maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
468            optional_attributes = if optional_attributes.len() == 0 {
469                String::from("\n")
470            } else {
471                String::from("\n")
472            } + &optional_attributes_string
473                + if optional_attributes.len() == 0 { "" } else { "\n" },
474        );
475        tracing::trace!(%sql);
476        Ok(sql)
477    }
478}