// pgx_sql_entity_graph/aggregate/entity.rs

/*
Portions Copyright 2019-2021 ZomboDB, LLC.
Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>

All rights reserved.

Use of this source code is governed by the MIT license that can be found in the LICENSE file.
*/
/*!

`#[pg_aggregate]`-related entities for Rust-to-SQL translation.

> Like all of the [`sql_entity_graph`][crate::pgx_sql_entity_graph] APIs, this is considered **internal**
to the `pgx` framework and is highly subject to change between versions. While you may use it, please do so with caution.

*/
18use crate::aggregate::options::{FinalizeModify, ParallelOption};
19use crate::metadata::SqlMapping;
20use crate::pgx_sql::PgxSql;
21use crate::to_sql::entity::ToSqlConfigEntity;
22use crate::to_sql::ToSql;
23use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
24use core::any::TypeId;
25use eyre::{eyre, WrapErr};
26
/// A single type slot of a `#[pg_aggregate]` (an argument, direct argument, or the state type),
/// paired with an optional SQL-level name.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct AggregateTypeEntity {
    /// The Rust type usage this slot was generated from; its SQL mapping and `TypeId` are used
    /// to render the SQL type and to locate the type in the SQL entity graph.
    pub used_ty: UsedTypeEntity,
    /// Optional SQL name. For (direct) arguments, `to_sql` renders it as a double-quoted
    /// identifier in front of the type, e.g. `"value" integer`.
    pub name: Option<&'static str>,
}
32
/// The SQL entity graph node for a `#[pg_aggregate]`, from which a
/// [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
/// statement is generated.
///
/// The function-valued fields (`sfunc`, `finalfunc`, `combinefunc`, `serialfunc`,
/// `deserialfunc`, `msfunc`, `minvfunc`, `mfinalfunc`) hold SQL function names, which the
/// `ToSql` impl emits double-quoted after the aggregate's schema prefix.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PgAggregateEntity {
    /// Full Rust path of the aggregate; used as its graph identifier and in generated SQL comments.
    pub full_path: &'static str,
    /// Rust module path of the declaration (presumably captured via `module_path!()` — TODO confirm).
    pub module_path: &'static str,
    /// Source file of the declaration; emitted in the `-- file:line` header of the generated SQL.
    pub file: &'static str,
    /// Source line of the declaration; emitted in the `-- file:line` header of the generated SQL.
    pub line: u32,
    /// `TypeId` associated with this entity.
    /// NOTE(review): presumably the aggregate implementor's type — confirm against the macro.
    pub ty_id: TypeId,

    /// SQL name of the aggregate, emitted as `CREATE AGGREGATE {schema}{name}`.
    pub name: &'static str,

    /// If the aggregate is an ordered set aggregate.
    ///
    /// See [the PostgreSQL ordered set docs](https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES).
    pub ordered_set: bool,

    /// The `arg_data_type` list.
    ///
    /// Corresponds to `Args` in [`pgx::aggregate::Aggregate`].
    pub args: Vec<AggregateTypeEntity>,

    /// The direct argument list, appearing before `ORDER BY` in ordered set aggregates.
    ///
    /// Corresponds to `OrderBy` in [`pgx::aggregate::Aggregate`].
    pub direct_args: Option<Vec<AggregateTypeEntity>>,

    /// The `STYPE` and `name` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// The implementor of an [`pgx::aggregate::Aggregate`].
    pub stype: AggregateTypeEntity,

    /// The `SFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `state` in [`pgx::aggregate::Aggregate`].
    pub sfunc: &'static str,

    /// The `FINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `finalize` in [`pgx::aggregate::Aggregate`].
    pub finalfunc: Option<&'static str>,

    /// The `FINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `FINALIZE_MODIFY` in [`pgx::aggregate::Aggregate`].
    pub finalfunc_modify: Option<FinalizeModify>,

    /// The `COMBINEFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `combine` in [`pgx::aggregate::Aggregate`].
    pub combinefunc: Option<&'static str>,

    /// The `SERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `serial` in [`pgx::aggregate::Aggregate`].
    pub serialfunc: Option<&'static str>,

    /// The `DESERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `deserial` in [`pgx::aggregate::Aggregate`].
    pub deserialfunc: Option<&'static str>,

    /// The `INITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `INITIAL_CONDITION` in [`pgx::aggregate::Aggregate`].
    /// Emitted verbatim inside single quotes; it is not escaped.
    pub initcond: Option<&'static str>,

    /// The `MSFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `moving_state` in [`pgx::aggregate::Aggregate`].
    pub msfunc: Option<&'static str>,

    /// The `MINVFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `moving_state_inverse` in [`pgx::aggregate::Aggregate`].
    pub minvfunc: Option<&'static str>,

    /// The `MSTYPE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `MovingState` in [`pgx::aggregate::Aggregate`].
    pub mstype: Option<UsedTypeEntity>,

    // The `MSSPACE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    //
    // TODO: Currently unused.
    // pub msspace: &'static str,
    /// The `MFINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `moving_state_finalize` in [`pgx::aggregate::Aggregate`].
    pub mfinalfunc: Option<&'static str>,

    /// The `MFINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `MOVING_FINALIZE_MODIFY` in [`pgx::aggregate::Aggregate`].
    pub mfinalfunc_modify: Option<FinalizeModify>,

    /// The `MINITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `MOVING_INITIAL_CONDITION` in [`pgx::aggregate::Aggregate`].
    /// Emitted verbatim inside single quotes; it is not escaped.
    pub minitcond: Option<&'static str>,

    /// The `SORTOP` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `SORT_OPERATOR` in [`pgx::aggregate::Aggregate`].
    pub sortop: Option<&'static str>,

    /// The `PARALLEL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `PARALLEL` in [`pgx::aggregate::Aggregate`].
    pub parallel: Option<ParallelOption>,

    /// The `HYPOTHETICAL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
    ///
    /// Corresponds to `hypothetical` in [`pgx::aggregate::Aggregate`].
    pub hypothetical: bool,
    /// SQL-generation configuration for this entity (not read by the `ToSql` impl in this file).
    pub to_sql_config: ToSqlConfigEntity,
}
148
149impl From<PgAggregateEntity> for SqlGraphEntity {
150    fn from(val: PgAggregateEntity) -> Self {
151        SqlGraphEntity::Aggregate(val)
152    }
153}
154
155impl SqlGraphIdentifier for PgAggregateEntity {
156    fn dot_identifier(&self) -> String {
157        format!("aggregate {}", self.full_path)
158    }
159    fn rust_identifier(&self) -> String {
160        self.full_path.to_string()
161    }
162    fn file(&self) -> Option<&'static str> {
163        Some(self.file)
164    }
165    fn line(&self) -> Option<u32> {
166        Some(self.line)
167    }
168}
169
170impl ToSql for PgAggregateEntity {
171    fn to_sql(&self, context: &PgxSql) -> eyre::Result<String> {
172        let self_index = context.aggregates[self];
173        let mut optional_attributes = Vec::new();
174        let schema = context.schema_prefix_for(&self_index);
175
176        if let Some(value) = self.finalfunc {
177            optional_attributes.push((
178                format!("\tFINALFUNC = {}\"{}\"", schema, value),
179                format!("/* {}::final */", self.full_path),
180            ));
181        }
182        if let Some(value) = self.finalfunc_modify {
183            optional_attributes.push((
184                format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
185                format!("/* {}::FINALIZE_MODIFY */", self.full_path),
186            ));
187        }
188        if let Some(value) = self.combinefunc {
189            optional_attributes.push((
190                format!("\tCOMBINEFUNC = {}\"{}\"", schema, value),
191                format!("/* {}::combine */", self.full_path),
192            ));
193        }
194        if let Some(value) = self.serialfunc {
195            optional_attributes.push((
196                format!("\tSERIALFUNC = {}\"{}\"", schema, value),
197                format!("/* {}::serial */", self.full_path),
198            ));
199        }
200        if let Some(value) = self.deserialfunc {
201            optional_attributes.push((
202                format!("\tDESERIALFUNC ={} \"{}\"", schema, value),
203                format!("/* {}::deserial */", self.full_path),
204            ));
205        }
206        if let Some(value) = self.initcond {
207            optional_attributes.push((
208                format!("\tINITCOND = '{}'", value),
209                format!("/* {}::INITIAL_CONDITION */", self.full_path),
210            ));
211        }
212        if let Some(value) = self.msfunc {
213            optional_attributes.push((
214                format!("\tMSFUNC = {}\"{}\"", schema, value),
215                format!("/* {}::moving_state */", self.full_path),
216            ));
217        }
218        if let Some(value) = self.minvfunc {
219            optional_attributes.push((
220                format!("\tMINVFUNC = {}\"{}\"", schema, value),
221                format!("/* {}::moving_state_inverse */", self.full_path),
222            ));
223        }
224        if let Some(value) = self.mfinalfunc {
225            optional_attributes.push((
226                format!("\tMFINALFUNC = {}\"{}\"", schema, value),
227                format!("/* {}::moving_state_finalize */", self.full_path),
228            ));
229        }
230        if let Some(value) = self.mfinalfunc_modify {
231            optional_attributes.push((
232                format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
233                format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
234            ));
235        }
236        if let Some(value) = self.minitcond {
237            optional_attributes.push((
238                format!("\tMINITCOND = '{}'", value),
239                format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
240            ));
241        }
242        if let Some(value) = self.sortop {
243            optional_attributes.push((
244                format!("\tSORTOP = \"{}\"", value),
245                format!("/* {}::SORT_OPERATOR */", self.full_path),
246            ));
247        }
248        if let Some(value) = self.parallel {
249            optional_attributes.push((
250                format!("\tPARALLEL = {}", value.to_sql(context)?),
251                format!("/* {}::PARALLEL */", self.full_path),
252            ));
253        }
254        if self.hypothetical {
255            optional_attributes.push((
256                String::from("\tHYPOTHETICAL"),
257                format!("/* {}::hypothetical */", self.full_path),
258            ))
259        }
260
261        let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
262            match used_ty.metadata.argument_sql {
263                Ok(SqlMapping::As(ref argument_sql)) => Ok(argument_sql.to_string()),
264                Ok(SqlMapping::Composite { array_brackets }) => used_ty
265                    .composite_type
266                    .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
267                    .ok_or_else(|| {
268                        eyre!("Macro expansion time suggested a composite_type!() in return")
269                    }),
270                Ok(SqlMapping::Source { array_brackets }) => {
271                    let sql = context
272                        .source_only_to_sql_type(used_ty.ty_source)
273                        .map(|v| if array_brackets { format!("{v}[]") } else { format!("{v}") })
274                        .ok_or_else(|| {
275                            eyre!("Macro expansion time suggested a source only mapping in return")
276                        })?;
277                    Ok(sql)
278                }
279                Ok(SqlMapping::Skip) => {
280                    Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
281                }
282                Err(err) => match context.source_only_to_sql_type(used_ty.ty_source) {
283                    Some(source_only_mapping) => Ok(source_only_mapping.to_string()),
284                    None => return Err(err).wrap_err("While mapping argument"),
285                },
286            }
287        };
288
289        let stype_sql = map_ty(&self.stype.used_ty).wrap_err("Mapping state type")?;
290        let mut stype_schema = String::from("");
291        for (ty_item, ty_index) in context.types.iter() {
292            if ty_item.id_matches(&self.stype.used_ty.ty_id) {
293                stype_schema = context.schema_prefix_for(ty_index);
294                break;
295            }
296        }
297        if String::is_empty(&stype_schema) {
298            for (ty_item, ty_index) in context.enums.iter() {
299                if ty_item.id_matches(&self.stype.used_ty.ty_id) {
300                    stype_schema = context.schema_prefix_for(ty_index);
301                    break;
302                }
303            }
304        }
305
306        if let Some(value) = &self.mstype {
307            let mstype_sql = map_ty(&value).wrap_err("Mapping moving state type")?;
308            optional_attributes.push((
309                format!("\tMSTYPE = {}", mstype_sql),
310                format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
311            ));
312        }
313
314        let mut optional_attributes_string = String::new();
315        for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
316            let optional_attribute_string = format!(
317                "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
318                optional_attribute = optional_attribute,
319                maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
320                comment = comment,
321                maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
322            );
323            optional_attributes_string += &optional_attribute_string;
324        }
325
326        let sql = format!(
327            "\n\
328                -- {file}:{line}\n\
329                -- {full_path}\n\
330                CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
331                (\n\
332                    \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
333                    \tSTYPE = {stype_schema}{stype}{maybe_comma_after_stype} /* {stype_full_path} */\
334                    {optional_attributes}\
335                );\
336            ",
337            schema = schema,
338            name = self.name,
339            full_path = self.full_path,
340            file = self.file,
341            line = self.line,
342            sfunc = self.sfunc,
343            stype_schema = stype_schema,
344            stype = stype_sql,
345            stype_full_path = self.stype.used_ty.full_path,
346            maybe_comma_after_stype = if optional_attributes.len() == 0 { "" } else { "," },
347            args = {
348                let mut args = Vec::new();
349                for (idx, arg) in self.args.iter().enumerate() {
350                    let graph_index = context
351                        .graph
352                        .neighbors_undirected(self_index)
353                        .find(|neighbor| match &context.graph[*neighbor] {
354                            SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
355                            SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
356                            SqlGraphEntity::BuiltinType(defined) => {
357                                defined == &arg.used_ty.full_path
358                            }
359                            _ => false,
360                        })
361                        .ok_or_else(|| {
362                            eyre!("Could not find arg type in graph. Got: {:?}", arg.used_ty)
363                        })?;
364                    let needs_comma = idx < (self.args.len() - 1);
365                    let buf = format!("\
366                           \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
367                       ",
368                           schema_prefix = context.schema_prefix_for(&graph_index),
369                           // First try to match on [`TypeId`] since it's most reliable.
370                           sql_type = match arg.used_ty.metadata.argument_sql {
371                                Ok(SqlMapping::As(ref argument_sql)) => {
372                                    argument_sql.to_string()
373                                }
374                                Ok(SqlMapping::Composite {
375                                    array_brackets,
376                                }) => {
377                                    arg.used_ty
378                                        .composite_type
379                                        .map(|v| {
380                                            if array_brackets {
381                                                format!("{v}[]")
382                                            } else {
383                                                format!("{v}")
384                                            }
385                                        })
386                                        .ok_or_else(|| {
387                                            eyre!(
388                                            "Macro expansion time suggested a composite_type!() in return"
389                                        )
390                                        })?
391                                }
392                                Ok(SqlMapping::Source {
393                                    array_brackets,
394                                }) => {
395                                    let sql = context
396                                        .source_only_to_sql_type(arg.used_ty.ty_source)
397                                        .map(|v| {
398                                            if array_brackets {
399                                                format!("{v}[]")
400                                            } else {
401                                                format!("{v}")
402                                            }
403                                        })
404                                        .ok_or_else(|| {
405                                            eyre!(
406                                            "Macro expansion time suggested a source only mapping in return"
407                                        )
408                                        })?;
409                                    sql
410                                }
411                                Ok(SqlMapping::Skip) => return Err(eyre!("Got a skipped SQL translatable type in aggregate args, this is not permitted")),
412                                Err(err) => {
413                                    match context.source_only_to_sql_type(arg.used_ty.ty_source) {
414                                        Some(source_only_mapping) => {
415                                            source_only_mapping.to_string()
416                                        }
417                                        None => return Err(err).wrap_err("While mapping argument"),
418                                    }
419                                }
420                            },
421                           variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
422                           maybe_comma = if needs_comma { ", " } else { " " },
423                           full_path = arg.used_ty.full_path,
424                           name = if let Some(name) = arg.name {
425                               format!(r#""{}" "#, name)
426                           } else { "".to_string() },
427                    );
428                    args.push(buf);
429                }
430                "\n".to_string() + &args.join("\n") + "\n"
431            },
432            direct_args = if let Some(direct_args) = &self.direct_args {
433                let mut args = Vec::new();
434                for (idx, arg) in direct_args.iter().enumerate() {
435                    let graph_index = context
436                        .graph
437                        .neighbors_undirected(self_index)
438                        .find(|neighbor| match &context.graph[*neighbor] {
439                            SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
440                            SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
441                            SqlGraphEntity::BuiltinType(defined) => {
442                                defined == &arg.used_ty.full_path
443                            }
444                            _ => false,
445                        })
446                        .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
447                    let needs_comma = idx < (direct_args.len() - 1);
448                    let buf = format!(
449                        "\
450                        \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
451                       ",
452                        schema_prefix = context.schema_prefix_for(&graph_index),
453                        // First try to match on [`TypeId`] since it's most reliable.
454                        sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
455                        maybe_name = if let Some(name) = arg.name {
456                            "\"".to_string() + name + "\" "
457                        } else {
458                            "".to_string()
459                        },
460                        maybe_comma = if needs_comma { ", " } else { " " },
461                        full_path = arg.used_ty.full_path,
462                    );
463                    args.push(buf);
464                }
465                "\n".to_string() + &args.join("\n,") + "\n"
466            } else {
467                String::default()
468            },
469            maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
470            optional_attributes = String::from("\n")
471                + &optional_attributes_string
472                + if optional_attributes.len() == 0 { "" } else { "\n" },
473        );
474        Ok(sql)
475    }
476}