Skip to main content

pgrx_sql_entity_graph/aggregate/
entity.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_aggregate]` related entities for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17
18*/
19use crate::aggregate::options::{FinalizeModify, ParallelOption};
20use crate::fmt;
21use crate::metadata::{SqlArrayMapping, SqlMapping};
22use crate::pgrx_sql::PgrxSql;
23use crate::to_sql::ToSql;
24use crate::to_sql::entity::ToSqlConfigEntity;
25use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
26use eyre::{WrapErr, eyre};
27use petgraph::graph::NodeIndex;
28
29#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
30pub struct AggregateTypeEntity<'a> {
31    pub used_ty: UsedTypeEntity<'a>,
32    pub name: Option<&'a str>,
33}
34
35#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
36pub struct PgAggregateEntity<'a> {
37    pub full_path: &'a str,
38    pub module_path: &'a str,
39    pub file: &'a str,
40    pub line: u32,
41
42    pub name: &'a str,
43
44    /// If the aggregate is an ordered set aggregate.
45    ///
46    /// See [the PostgreSQL ordered set docs](https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES).
47    pub ordered_set: bool,
48
49    /// The `arg_data_type` list.
50    ///
51    /// Corresponds to `Args` in `pgrx::aggregate::Aggregate`.
52    pub args: Vec<AggregateTypeEntity<'a>>,
53
54    /// The direct argument list, appearing before `ORDER BY` in ordered set aggregates.
55    ///
56    /// Corresponds to `OrderBy` in `pgrx::aggregate::Aggregate`.
57    pub direct_args: Option<Vec<AggregateTypeEntity<'a>>>,
58
59    /// The `STYPE` and `name` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
60    ///
61    /// The implementor of an `pgrx::aggregate::Aggregate`.
62    pub stype: AggregateTypeEntity<'a>,
63
64    /// The `SFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
65    ///
66    /// Corresponds to `state` in `pgrx::aggregate::Aggregate`.
67    pub sfunc: &'a str,
68
69    /// The `FINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
70    ///
71    /// Corresponds to `finalize` in `pgrx::aggregate::Aggregate`.
72    pub finalfunc: Option<&'a str>,
73
74    /// The `FINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
75    ///
76    /// Corresponds to `FINALIZE_MODIFY` in `pgrx::aggregate::Aggregate`.
77    pub finalfunc_modify: Option<FinalizeModify>,
78
79    /// The `COMBINEFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
80    ///
81    /// Corresponds to `combine` in `pgrx::aggregate::Aggregate`.
82    pub combinefunc: Option<&'a str>,
83
84    /// The `SERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
85    ///
86    /// Corresponds to `serial` in `pgrx::aggregate::Aggregate`.
87    pub serialfunc: Option<&'a str>,
88
89    /// The `DESERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
90    ///
91    /// Corresponds to `deserial` in `pgrx::aggregate::Aggregate`.
92    pub deserialfunc: Option<&'a str>,
93
94    /// The `INITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
95    ///
96    /// Corresponds to `INITIAL_CONDITION` in `pgrx::aggregate::Aggregate`.
97    pub initcond: Option<&'a str>,
98
99    /// The `MSFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
100    ///
101    /// Corresponds to `moving_state` in `pgrx::aggregate::Aggregate`.
102    pub msfunc: Option<&'a str>,
103
104    /// The `MINVFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
105    ///
106    /// Corresponds to `moving_state_inverse` in `pgrx::aggregate::Aggregate`.
107    pub minvfunc: Option<&'a str>,
108
109    /// The `MSTYPE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
110    ///
111    /// Corresponds to `MovingState` in `pgrx::aggregate::Aggregate`.
112    pub mstype: Option<UsedTypeEntity<'a>>,
113
114    // The `MSSPACE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
115    //
116    // TODO: Currently unused.
117    // pub msspace: &'a str,
118    /// The `MFINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
119    ///
120    /// Corresponds to `moving_state_finalize` in `pgrx::aggregate::Aggregate`.
121    pub mfinalfunc: Option<&'a str>,
122
123    /// The `MFINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
124    ///
125    /// Corresponds to `MOVING_FINALIZE_MODIFY` in `pgrx::aggregate::Aggregate`.
126    pub mfinalfunc_modify: Option<FinalizeModify>,
127
128    /// The `MINITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
129    ///
130    /// Corresponds to `MOVING_INITIAL_CONDITION` in `pgrx::aggregate::Aggregate`.
131    pub minitcond: Option<&'a str>,
132
133    /// The `SORTOP` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
134    ///
135    /// Corresponds to `SORT_OPERATOR` in `pgrx::aggregate::Aggregate`.
136    pub sortop: Option<&'a str>,
137
138    /// The `PARALLEL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
139    ///
140    /// Corresponds to `PARALLEL` in `pgrx::aggregate::Aggregate`.
141    pub parallel: Option<ParallelOption>,
142
143    /// The `HYPOTHETICAL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
144    ///
145    /// Corresponds to `hypothetical` in `pgrx::aggregate::Aggregate`.
146    pub hypothetical: bool,
147    pub to_sql_config: ToSqlConfigEntity<'a>,
148}
149
150impl<'a> From<PgAggregateEntity<'a>> for SqlGraphEntity<'a> {
151    fn from(val: PgAggregateEntity<'a>) -> Self {
152        SqlGraphEntity::Aggregate(val)
153    }
154}
155
156impl SqlGraphIdentifier for PgAggregateEntity<'_> {
157    fn dot_identifier(&self) -> String {
158        format!("aggregate {}", self.full_path)
159    }
160    fn rust_identifier(&self) -> String {
161        self.full_path.to_string()
162    }
163    fn file(&self) -> Option<&str> {
164        Some(self.file)
165    }
166    fn line(&self) -> Option<u32> {
167        Some(self.line)
168    }
169}
170
171fn aggregate_sql_type(mapping: &SqlMapping, composite_type: Option<&str>) -> eyre::Result<String> {
172    match mapping {
173        SqlMapping::As(sql) => Ok(sql.clone()),
174        SqlMapping::Composite => composite_type
175            .map(ToString::to_string)
176            .ok_or_else(|| eyre!("Composite mapping requires composite_type")),
177        SqlMapping::Array(SqlArrayMapping::As(sql)) => Ok(fmt::with_array_brackets(sql.clone(), 1)),
178        SqlMapping::Array(SqlArrayMapping::Composite) => composite_type
179            .map(ToString::to_string)
180            .map(|sql| fmt::with_array_brackets(sql, 1))
181            .ok_or_else(|| eyre!("Composite mapping requires composite_type")),
182        SqlMapping::Skip => {
183            Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
184        }
185    }
186}
187
188/// Render the positional argument-type signature for an aggregate as it
189/// would appear inside `ALTER EXTENSION … ADD AGGREGATE name(…)`. For
190/// ordered-set aggregates the rendering is `(direct ORDER BY args)`;
191/// otherwise it is `(args)`. Matches the shape produced by
192/// `PgAggregateEntity::to_sql`.
193pub(crate) fn render_aggregate_argtypes(
194    context: &PgrxSql,
195    owner: NodeIndex,
196    a: &PgAggregateEntity,
197) -> eyre::Result<String> {
198    let render_slot = |arg: &AggregateTypeEntity| -> eyre::Result<String> {
199        let slot = arg.name.unwrap_or("aggregate argument");
200        let prefix = context.schema_prefix_for_used_type(&owner, slot, &arg.used_ty)?;
201        let sql = match arg.used_ty.metadata.argument_sql {
202            Ok(ref mapping) => aggregate_sql_type(mapping, arg.used_ty.composite_type)?,
203            Err(err) => return Err(err.into()),
204        };
205        let variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" };
206        Ok(format!("{variadic}{prefix}{sql}"))
207    };
208
209    let args = a.args.iter().map(render_slot).collect::<eyre::Result<Vec<_>>>()?.join(", ");
210    let direct = a.direct_args.as_deref().unwrap_or(&[]);
211
212    if a.ordered_set {
213        let direct_rendered =
214            direct.iter().map(render_slot).collect::<eyre::Result<Vec<_>>>()?.join(", ");
215        Ok(format!("({direct_rendered} ORDER BY {args})"))
216    } else {
217        Ok(format!("({args})"))
218    }
219}
220
221impl ToSql for PgAggregateEntity<'_> {
222    fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
223        let self_index = context.aggregates[self];
224        let mut optional_attributes = Vec::new();
225        let schema = context.schema_prefix_for(&self_index);
226
227        if let Some(value) = self.finalfunc {
228            optional_attributes.push((
229                format!("\tFINALFUNC = {schema}\"{value}\""),
230                format!("/* {}::final */", self.full_path),
231            ));
232        }
233        if let Some(value) = self.finalfunc_modify {
234            optional_attributes.push((
235                format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
236                format!("/* {}::FINALIZE_MODIFY */", self.full_path),
237            ));
238        }
239        if let Some(value) = self.combinefunc {
240            optional_attributes.push((
241                format!("\tCOMBINEFUNC = {schema}\"{value}\""),
242                format!("/* {}::combine */", self.full_path),
243            ));
244        }
245        if let Some(value) = self.serialfunc {
246            optional_attributes.push((
247                format!("\tSERIALFUNC = {schema}\"{value}\""),
248                format!("/* {}::serial */", self.full_path),
249            ));
250        }
251        if let Some(value) = self.deserialfunc {
252            optional_attributes.push((
253                format!("\tDESERIALFUNC ={schema} \"{value}\""),
254                format!("/* {}::deserial */", self.full_path),
255            ));
256        }
257        if let Some(value) = self.initcond {
258            optional_attributes.push((
259                format!("\tINITCOND = '{value}'"),
260                format!("/* {}::INITIAL_CONDITION */", self.full_path),
261            ));
262        }
263        if let Some(value) = self.msfunc {
264            optional_attributes.push((
265                format!("\tMSFUNC = {schema}\"{value}\""),
266                format!("/* {}::moving_state */", self.full_path),
267            ));
268        }
269        if let Some(value) = self.minvfunc {
270            optional_attributes.push((
271                format!("\tMINVFUNC = {schema}\"{value}\""),
272                format!("/* {}::moving_state_inverse */", self.full_path),
273            ));
274        }
275        if let Some(value) = self.mfinalfunc {
276            optional_attributes.push((
277                format!("\tMFINALFUNC = {schema}\"{value}\""),
278                format!("/* {}::moving_state_finalize */", self.full_path),
279            ));
280        }
281        if let Some(value) = self.mfinalfunc_modify {
282            optional_attributes.push((
283                format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
284                format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
285            ));
286        }
287        if let Some(value) = self.minitcond {
288            optional_attributes.push((
289                format!("\tMINITCOND = '{value}'"),
290                format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
291            ));
292        }
293        if let Some(value) = self.sortop {
294            optional_attributes.push((
295                format!("\tSORTOP = \"{value}\""),
296                format!("/* {}::SORT_OPERATOR */", self.full_path),
297            ));
298        }
299        if let Some(value) = self.parallel {
300            optional_attributes.push((
301                format!("\tPARALLEL = {}", value.to_sql(context)?),
302                format!("/* {}::PARALLEL */", self.full_path),
303            ));
304        }
305        if self.hypothetical {
306            optional_attributes.push((
307                String::from("\tHYPOTHETICAL"),
308                format!("/* {}::hypothetical */", self.full_path),
309            ))
310        }
311
312        let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
313            match used_ty.metadata.argument_sql {
314                Ok(ref mapping) => aggregate_sql_type(mapping, used_ty.composite_type),
315                Err(err) => Err(err).wrap_err("While mapping argument"),
316            }
317        };
318
319        let sql_type_for_slot = |slot: &str,
320                                 used_ty: &UsedTypeEntity|
321         -> eyre::Result<(String, String)> {
322            let sql = map_ty(used_ty).wrap_err_with(|| format!("Mapping {slot}"))?;
323            let schema_prefix = context.schema_prefix_for_used_type(&self_index, slot, used_ty)?;
324            Ok((schema_prefix, sql))
325        };
326        let (stype_schema, stype_sql) = sql_type_for_slot("STYPE", &self.stype.used_ty)?;
327
328        if let Some(value) = &self.mstype {
329            let (mstype_schema, mstype_sql) = sql_type_for_slot("MSTYPE", value)?;
330            optional_attributes.push((
331                format!("\tMSTYPE = {mstype_schema}{mstype_sql}"),
332                format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
333            ));
334        }
335
336        let mut optional_attributes_string = String::new();
337        for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
338            let optional_attribute_string = format!(
339                "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
340                optional_attribute = optional_attribute,
341                maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
342                comment = comment,
343                maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
344            );
345            optional_attributes_string += &optional_attribute_string;
346        }
347
348        let args = {
349            let mut args = Vec::new();
350            for (idx, arg) in self.args.iter().enumerate() {
351                let needs_comma = idx < (self.args.len() - 1);
352                let schema_prefix = context.schema_prefix_for_used_type(
353                    &self_index,
354                    arg.name.unwrap_or("aggregate argument"),
355                    &arg.used_ty,
356                )?;
357                let buf = format!(
358                    "\
359                       \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
360                   ",
361                    schema_prefix = schema_prefix,
362                    // The SQL spelling comes from the embedded schema metadata.
363                    sql_type = match arg.used_ty.metadata.argument_sql {
364                        Ok(ref mapping) => aggregate_sql_type(mapping, arg.used_ty.composite_type)?,
365                        Err(err) => return Err(err).wrap_err("While mapping argument"),
366                    },
367                    variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
368                    maybe_comma = if needs_comma { ", " } else { " " },
369                    full_path = arg.used_ty.full_path,
370                    name = if let Some(name) = arg.name {
371                        format!(r#""{name}" "#)
372                    } else {
373                        "".to_string()
374                    },
375                );
376                args.push(buf);
377            }
378            "\n".to_string() + &args.join("\n") + "\n"
379        };
380        let direct_args = if let Some(direct_args) = &self.direct_args {
381            let mut args = Vec::new();
382            for (idx, arg) in direct_args.iter().enumerate() {
383                let schema_prefix = context.schema_prefix_for_used_type(
384                    &self_index,
385                    arg.name.unwrap_or("aggregate direct argument"),
386                    &arg.used_ty,
387                )?;
388                let needs_comma = idx < (direct_args.len() - 1);
389                let buf = format!(
390                    "\
391                    \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
392                   ",
393                    schema_prefix = schema_prefix,
394                    // The SQL spelling comes from the embedded schema metadata.
395                    sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
396                    maybe_name = if let Some(name) = arg.name {
397                        "\"".to_string() + name + "\" "
398                    } else {
399                        "".to_string()
400                    },
401                    maybe_comma = if needs_comma { ", " } else { " " },
402                    full_path = arg.used_ty.full_path,
403                );
404                args.push(buf);
405            }
406            "\n".to_string() + &args.join("\n") + "\n"
407        } else {
408            String::default()
409        };
410
411        let PgAggregateEntity { name, full_path, file, line, sfunc, .. } = self;
412
413        let sql = format!(
414            "\n\
415                -- {file}:{line}\n\
416                -- {full_path}\n\
417                CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
418                (\n\
419                    \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
420                    \tSTYPE = {stype_schema}{stype_sql}{maybe_comma_after_stype} /* {stype_full_path} */\
421                    {optional_attributes}\
422                );\
423            ",
424            stype_full_path = self.stype.used_ty.full_path,
425            maybe_comma_after_stype = if optional_attributes.is_empty() { "" } else { "," },
426            maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
427            optional_attributes = String::from("\n")
428                + &optional_attributes_string
429                + if optional_attributes.is_empty() { "" } else { "\n" },
430        );
431        Ok(sql)
432    }
433}