Skip to main content

pgrx_sql_entity_graph/pg_extern/entity/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_extern]` related entities for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17*/
18mod argument;
19mod cast;
20mod operator;
21mod returning;
22
23pub use argument::PgExternArgumentEntity;
24pub use cast::PgCastEntity;
25pub use operator::PgOperatorEntity;
26pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem};
27
28use crate::UsedTypeEntity;
29use crate::fmt;
30use crate::metadata::{Returns, SqlArrayMapping, SqlMapping};
31use crate::pgrx_sql::PgrxSql;
32use crate::to_sql::ToSql;
33use crate::to_sql::entity::ToSqlConfigEntity;
34use crate::{ExternArgs, SqlGraphEntity, SqlGraphIdentifier};
35
36use eyre::{WrapErr, eyre};
37use petgraph::graph::NodeIndex;
38
39/// The output of a [`PgExtern`](crate::pg_extern::PgExtern) from `quote::ToTokens::to_tokens`.
40#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
41pub struct PgExternEntity<'a> {
42    pub name: &'a str,
43    pub unaliased_name: &'a str,
44    pub module_path: &'a str,
45    pub full_path: &'a str,
46    pub fn_args: Vec<PgExternArgumentEntity<'a>>,
47    pub fn_return: PgExternReturnEntity<'a>,
48    pub schema: Option<&'a str>,
49    pub file: &'a str,
50    pub line: u32,
51    pub extern_attrs: Vec<ExternArgs>,
52    pub search_path: Option<Vec<&'a str>>,
53    pub operator: Option<PgOperatorEntity<'a>>,
54    pub cast: Option<PgCastEntity>,
55    pub to_sql_config: ToSqlConfigEntity<'a>,
56}
57
58impl<'a> From<PgExternEntity<'a>> for SqlGraphEntity<'a> {
59    fn from(val: PgExternEntity<'a>) -> Self {
60        SqlGraphEntity::Function(val)
61    }
62}
63
64impl SqlGraphIdentifier for PgExternEntity<'_> {
65    fn dot_identifier(&self) -> String {
66        format!("fn {}", self.name)
67    }
68    fn rust_identifier(&self) -> String {
69        self.full_path.to_string()
70    }
71
72    fn file(&self) -> Option<&str> {
73        Some(self.file)
74    }
75
76    fn line(&self) -> Option<u32> {
77        Some(self.line)
78    }
79}
80
81impl PgExternEntity<'_> {
82    fn sql_name(&self, context: &PgrxSql) -> String {
83        let self_index = context.externs[self];
84        let schema = self
85            .schema
86            .map(|schema| format!("{schema}."))
87            .unwrap_or_else(|| context.schema_prefix_for(&self_index));
88
89        format!("{schema}\"{}\"", self.name)
90    }
91}
92
93fn composite_sql_type(composite_type: Option<&str>) -> eyre::Result<String> {
94    composite_type
95        .map(ToString::to_string)
96        .ok_or_else(|| eyre!("Composite mapping requires composite_type"))
97}
98
99fn array_sql_type(mapping: &SqlArrayMapping, composite_type: Option<&str>) -> eyre::Result<String> {
100    Ok(match mapping {
101        SqlArrayMapping::As(sql) => fmt::with_array_brackets(sql.clone(), 1),
102        SqlArrayMapping::Composite => {
103            fmt::with_array_brackets(composite_sql_type(composite_type)?, 1)
104        }
105    })
106}
107
108fn sql_type(mapping: &SqlMapping, composite_type: Option<&str>) -> eyre::Result<String> {
109    match mapping {
110        SqlMapping::As(sql) => Ok(sql.clone()),
111        SqlMapping::Composite => composite_sql_type(composite_type),
112        SqlMapping::Array(value) => array_sql_type(value, composite_type),
113        SqlMapping::Skip => Err(eyre!("Found a skipped SQL type where SQL should be emitted")),
114    }
115}
116
117/// Render the SQL spelling of one `UsedType`, with schema prefix applied.
118/// This is the bit that sits right after `"name" VARIADIC ` in a CREATE
119/// FUNCTION signature.
120pub(crate) fn render_used_type_sql(
121    context: &PgrxSql,
122    owner: NodeIndex,
123    slot: &str,
124    used_ty: &UsedTypeEntity,
125) -> eyre::Result<String> {
126    let schema_prefix = context.schema_prefix_for_used_type(&owner, slot, used_ty)?;
127    let body = match used_ty.metadata.argument_sql {
128        Ok(SqlMapping::As(ref sql)) => sql.clone(),
129        Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
130            sql_type(mapping, used_ty.composite_type)?
131        }
132        Ok(SqlMapping::Skip) => {
133            return Err(eyre!("Found a skipped SQL type where SQL should be emitted"));
134        }
135        Err(err) => return Err(err.into()),
136    };
137    Ok(format!("{schema_prefix}{body}"))
138}
139
140/// Render the comma-separated argument-type list for a pg_extern, matching
141/// the positional shape of CREATE FUNCTION but without names or defaults.
142/// Skipped args are filtered out; variadic args get a `VARIADIC ` prefix.
143pub(crate) fn render_function_argtypes(
144    context: &PgrxSql,
145    owner: NodeIndex,
146    f: &PgExternEntity,
147) -> eyre::Result<String> {
148    let mut pieces = Vec::new();
149    for arg in f.fn_args.iter().filter(|a| a.used_ty.emits_argument_sql()) {
150        let slot = format!("argument `{}`", arg.pattern);
151        let rendered = render_used_type_sql(context, owner, &slot, &arg.used_ty)?;
152        if arg.used_ty.variadic {
153            pieces.push(format!("VARIADIC {rendered}"));
154        } else {
155            pieces.push(rendered);
156        }
157    }
158    Ok(pieces.join(", "))
159}
160
161/// Render the return type of a pg_extern for contexts that need a plain
162/// scalar, such as CREATE CAST. Errors if the function returns a set or a
163/// table.
164pub(crate) fn render_function_return_type(
165    context: &PgrxSql,
166    owner: NodeIndex,
167    f: &PgExternEntity,
168) -> eyre::Result<String> {
169    let ty = match &f.fn_return {
170        PgExternReturnEntity::Type { ty } => ty,
171        PgExternReturnEntity::None => {
172            return Err(eyre!("Cannot render return type for a function with no return"));
173        }
174        other => {
175            return Err(eyre!("Cannot render a scalar return type for {other:?}"));
176        }
177    };
178    let schema_prefix = context.schema_prefix_for_used_type(&owner, "return type", ty)?;
179    let body = match &ty.metadata.return_sql {
180        Ok(Returns::One(SqlMapping::As(sql))) => sql.clone(),
181        Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => {
182            sql_type(mapping, ty.composite_type)?
183        }
184        Ok(Returns::One(SqlMapping::Skip)) => {
185            return Err(eyre!("Return type was SqlMapping::Skip"));
186        }
187        Ok(other) => {
188            return Err(eyre!("Return type is not a scalar: {other:?}"));
189        }
190        Err(err) => return Err((*err).into()),
191    };
192    Ok(format!("{schema_prefix}{body}"))
193}
194
195impl ToSql for PgExternEntity<'_> {
196    fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
197        let self_index = context.externs[self];
198        let mut extern_attrs = self.extern_attrs.clone();
199        // if we already have a STRICT marker we do not need to add it
200        // presume we can upgrade, then disprove it
201        let mut strict_upgrade = !extern_attrs.iter().any(|i| i == &ExternArgs::Strict);
202        if strict_upgrade {
203            // It may be possible to infer a `STRICT` marker though.
204            // But we can only do that if the user hasn't used a nullable argument wrapper.
205            for arg in &self.fn_args {
206                if arg.used_ty.optional {
207                    strict_upgrade = false;
208                }
209            }
210        }
211
212        if strict_upgrade {
213            extern_attrs.push(ExternArgs::Strict);
214        }
215        extern_attrs.sort();
216        extern_attrs.dedup();
217
218        let module_pathname = &context.get_module_pathname();
219        let schema = self
220            .schema
221            .map(|schema| format!("{schema}."))
222            .unwrap_or_else(|| context.schema_prefix_for(&self_index));
223        let arguments = if !self.fn_args.is_empty() {
224            let mut args = Vec::new();
225            let sql_args = self
226                .fn_args
227                .iter()
228                .filter(|arg| arg.used_ty.emits_argument_sql())
229                .collect::<Vec<_>>();
230            for (idx, arg) in sql_args.iter().enumerate() {
231                let needs_comma = idx < (sql_args.len().saturating_sub(1));
232                let schema_prefix = context.schema_prefix_for_used_type(
233                    &self_index,
234                    &format!("argument `{}`", arg.pattern),
235                    &arg.used_ty,
236                )?;
237                match arg.used_ty.metadata.argument_sql {
238                    Ok(SqlMapping::As(ref argument_sql)) => {
239                        let buf = format!(
240                            "\
241                                            \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
242                                        ",
243                            pattern = arg.pattern,
244                            schema_prefix = schema_prefix,
245                            // The SQL spelling comes from the embedded schema metadata.
246                            sql_type = argument_sql,
247                            default = if let Some(def) = arg.used_ty.default {
248                                format!(" DEFAULT {def}")
249                            } else {
250                                String::from("")
251                            },
252                            variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
253                            maybe_comma = if needs_comma { ", " } else { " " },
254                            type_name = arg.used_ty.full_path,
255                        );
256                        args.push(buf);
257                    }
258                    Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
259                        let sql = sql_type(mapping, arg.used_ty.composite_type)?;
260                        let buf = format!(
261                            "\
262                            \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
263                        ",
264                            pattern = arg.pattern,
265                            schema_prefix = schema_prefix,
266                            // The SQL spelling comes from the embedded schema metadata.
267                            sql_type = sql,
268                            default = if let Some(def) = arg.used_ty.default {
269                                format!(" DEFAULT {def}")
270                            } else {
271                                String::from("")
272                            },
273                            variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
274                            maybe_comma = if needs_comma { ", " } else { " " },
275                            type_name = arg.used_ty.full_path,
276                        );
277                        args.push(buf);
278                    }
279                    Ok(SqlMapping::Skip) => (),
280                    Err(err) => return Err(err).wrap_err("While mapping argument"),
281                }
282            }
283            String::from("\n") + &args.join("\n") + "\n"
284        } else {
285            Default::default()
286        };
287
288        let returns = match &self.fn_return {
289            PgExternReturnEntity::None => String::from("RETURNS void"),
290            PgExternReturnEntity::Type { ty } => {
291                let (schema_prefix, sql_type) = match &ty.metadata.return_sql {
292                    Ok(Returns::One(SqlMapping::As(sql))) => (
293                        context.schema_prefix_for_used_type(&self_index, "return type", ty)?,
294                        sql.clone(),
295                    ),
296                    Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => (
297                        context.schema_prefix_for_used_type(&self_index, "return type", ty)?,
298                        sql_type(mapping, ty.composite_type)?,
299                    ),
300                    Ok(other) => {
301                        return Err(eyre!(
302                            "Got non-plain mapped/composite return variant SQL in what macro-expansion thought was a type, got: {other:?}"
303                        ));
304                    }
305                    Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
306                };
307                format!(
308                    "RETURNS {schema_prefix}{sql_type} /* {full_path} */",
309                    full_path = ty.full_path
310                )
311            }
312            PgExternReturnEntity::SetOf { ty, .. } => {
313                let (schema_prefix, sql_type) = match &ty.metadata.return_sql {
314                    Ok(Returns::One(SqlMapping::As(sql)))
315                    | Ok(Returns::SetOf(SqlMapping::As(sql))) => (
316                        context.schema_prefix_for_used_type(
317                            &self_index,
318                            "setof return type",
319                            ty,
320                        )?,
321                        sql.clone(),
322                    ),
323                    Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_))))
324                    | Ok(Returns::SetOf(
325                        mapping @ (SqlMapping::Composite | SqlMapping::Array(_)),
326                    )) => (
327                        context.schema_prefix_for_used_type(
328                            &self_index,
329                            "setof return type",
330                            ty,
331                        )?,
332                        sql_type(mapping, ty.composite_type)?,
333                    ),
334                    Ok(other) => {
335                        return Err(eyre!(
336                            "Got non-scalar mapped/composite return variant SQL in what macro-expansion thought was a setof item, got: {other:?}"
337                        ));
338                    }
339                    Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
340                };
341                format!(
342                    "RETURNS SETOF {schema_prefix}{sql_type} /* {full_path} */",
343                    full_path = ty.full_path
344                )
345            }
346            PgExternReturnEntity::Iterated { tys: table_items, .. } => {
347                let mut items = String::new();
348                for (idx, PgExternReturnEntityIteratedItem { ty, name: col_name }) in
349                    table_items.iter().enumerate()
350                {
351                    let needs_comma = idx < (table_items.len() - 1);
352                    let (schema_prefix, ty_resolved) = match &ty.metadata.return_sql {
353                        Ok(Returns::One(SqlMapping::As(sql))) => (
354                            context.schema_prefix_for_used_type(
355                                &self_index,
356                                "table return column",
357                                ty,
358                            )?,
359                            sql.clone(),
360                        ),
361                        Ok(Returns::One(
362                            mapping @ (SqlMapping::Composite | SqlMapping::Array(_)),
363                        )) => (
364                            context.schema_prefix_for_used_type(
365                                &self_index,
366                                "table return column",
367                                ty,
368                            )?,
369                            sql_type(mapping, ty.composite_type)?,
370                        ),
371                        Ok(other) => {
372                            return Err(eyre!(
373                                "Got non-scalar table return item SQL in what macro-expansion thought was a table, got: {other:?}"
374                            ));
375                        }
376                        Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
377                    };
378                    let item = format!(
379                        "\n\t{col_name} {schema_prefix}{ty_resolved}{needs_comma} /* {ty_name} */",
380                        col_name = col_name.expect(
381                            "An iterator of tuples should have `named!()` macro declarations."
382                        ),
383                        schema_prefix = schema_prefix,
384                        ty_resolved = ty_resolved,
385                        needs_comma = if needs_comma { ", " } else { " " },
386                        ty_name = ty.full_path
387                    );
388                    items.push_str(&item);
389                }
390                format!("RETURNS TABLE ({items}\n)")
391            }
392            PgExternReturnEntity::Trigger => String::from("RETURNS trigger"),
393        };
394        let PgExternEntity { name, module_path, file, line, .. } = self;
395
396        let fn_sql = format!(
397            "\
398                CREATE {or_replace} FUNCTION {schema}\"{name}\"({arguments}) {returns}\n\
399                {extern_attrs}\
400                {search_path}\
401                LANGUAGE c /* Rust */\n\
402                AS '{module_pathname}', '{unaliased_name}_wrapper';\
403            ",
404            or_replace =
405                if extern_attrs.contains(&ExternArgs::CreateOrReplace) { "OR REPLACE" } else { "" },
406            search_path = if let Some(search_path) = &self.search_path {
407                let retval = format!("SET search_path TO {}", search_path.join(", "));
408                retval + "\n"
409            } else {
410                Default::default()
411            },
412            extern_attrs = if extern_attrs.is_empty() {
413                String::default()
414            } else {
415                let mut retval = extern_attrs
416                    .iter()
417                    .filter(|attr| **attr != ExternArgs::CreateOrReplace)
418                    .map(|attr| {
419                        if matches!(attr, ExternArgs::Support(..)) {
420                            let support_fn_name = attr.to_string();
421
422                            let support_fn_name =
423                            if let Some(entity) = context.find_matching_fn(&support_fn_name) {
424                                entity.sql_name(context)
425                            } else {
426                                panic!("cannot locate SUPPORT function `{support_fn_name}` attached to function `{}`", self.full_path)
427                            };
428
429                            format!("SUPPORT {support_fn_name}")
430                        } else {
431                            attr.to_string().to_uppercase()
432                        }
433                    })
434                    .collect::<Vec<_>>()
435                    .join(" ");
436                retval.push('\n');
437                retval
438            },
439            unaliased_name = self.unaliased_name,
440        );
441
442        let requires = {
443            let requires_attrs = self
444                .extern_attrs
445                .iter()
446                .filter_map(|x| match x {
447                    ExternArgs::Requires(requirements) => Some(requirements.clone()),
448                    ExternArgs::Support(support_fn) => Some(vec![support_fn.clone()]),
449                    _ => None,
450                })
451                .flatten()
452                .collect::<Vec<_>>();
453
454            if !requires_attrs.is_empty() {
455                format!(
456                    "-- requires:\n{}\n",
457                    requires_attrs
458                        .iter()
459                        .map(|i| format!("--   {i}"))
460                        .collect::<Vec<_>>()
461                        .join("\n")
462                )
463            } else {
464                "".to_string()
465            }
466        };
467
468        let mut ext_sql = format!(
469            "\n\
470            -- {file}:{line}\n\
471            -- {module_path}::{name}\n\
472            {requires}\
473            {fn_sql}"
474        );
475
476        if let Some(op) = &self.operator {
477            let mut optionals = vec![];
478            if let Some(it) = op.commutator {
479                optionals.push(format!("\tCOMMUTATOR = {it}"));
480            };
481            if let Some(it) = op.negator {
482                optionals.push(format!("\tNEGATOR = {it}"));
483            };
484            if let Some(it) = op.restrict {
485                optionals.push(format!("\tRESTRICT = {it}"));
486            };
487            if let Some(it) = op.join {
488                optionals.push(format!("\tJOIN = {it}"));
489            };
490            if op.hashes {
491                optionals.push(String::from("\tHASHES"));
492            };
493            if op.merges {
494                optionals.push(String::from("\tMERGES"));
495            };
496
497            let left_arg = self
498                .fn_args
499                .first()
500                .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
501            let left_arg_sql = render_used_type_sql(
502                context,
503                self_index,
504                "operator left argument",
505                &left_arg.used_ty,
506            )?;
507
508            let right_arg = self
509                .fn_args
510                .get(1)
511                .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
512            let right_arg_sql = render_used_type_sql(
513                context,
514                self_index,
515                "operator right argument",
516                &right_arg.used_ty,
517            )?;
518
519            let schema = self
520                .schema
521                .map(|schema| format!("{schema}."))
522                .unwrap_or_else(|| context.schema_prefix_for(&self_index));
523
524            let operator_sql = format!(
525                "\n\n\
526                                                    -- {file}:{line}\n\
527                                                    -- {module_path}::{name}\n\
528                                                    CREATE OPERATOR {schema}{opname} (\n\
529                                                        \tPROCEDURE={schema}\"{name}\",\n\
530                                                        \tLEFTARG={left_arg_sql}, /* {left_name} */\n\
531                                                        \tRIGHTARG={right_arg_sql}{maybe_comma} /* {right_name} */\n\
532                                                        {optionals}\
533                                                    );\
534                                                    ",
535                opname = op.opname.unwrap(),
536                left_name = left_arg.used_ty.full_path,
537                right_name = right_arg.used_ty.full_path,
538                maybe_comma = if !optionals.is_empty() { "," } else { "" },
539                optionals = if !optionals.is_empty() {
540                    optionals.join(",\n") + "\n"
541                } else {
542                    "".to_string()
543                },
544            );
545            ext_sql += &operator_sql
546        };
547        if let Some(cast) = &self.cast {
548            let target_ty = match &self.fn_return {
549                PgExternReturnEntity::Type { ty } => ty,
550                other => {
551                    return Err(eyre!("Casts must return a plain type, got: {other:?}"));
552                }
553            };
554            let target_arg_sql = render_function_return_type(context, self_index, self)?;
555            let source_arg = self
556                .fn_args
557                .first()
558                .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?;
559            let source_arg_sql =
560                render_used_type_sql(context, self_index, "cast source type", &source_arg.used_ty)?;
561            let optional = match cast {
562                PgCastEntity::Default => String::from(""),
563                PgCastEntity::Assignment => String::from(" AS ASSIGNMENT"),
564                PgCastEntity::Implicit => String::from(" AS IMPLICIT"),
565            };
566
567            let cast_sql = format!(
568                "\n\n\
569                                                    -- {file}:{line}\n\
570                                                    -- {module_path}::{name}\n\
571                                                    CREATE CAST (\n\
572                                                        \t{source_arg_sql} /* {source_name} */\n\
573                                                        \tAS\n\
574                                                        \t{target_arg_sql} /* {target_name} */\n\
575                                                    )\n\
576                                                    WITH FUNCTION {function_name}{optional};\
577                                                    ",
578                file = self.file,
579                line = self.line,
580                name = self.name,
581                module_path = self.module_path,
582                source_name = source_arg.used_ty.full_path,
583                target_name = target_ty.full_path,
584                function_name = self.name,
585            );
586            ext_sql += &cast_sql
587        };
588        Ok(ext_sql)
589    }
590}