Skip to main content

pgrx_sql_entity_graph/pg_extern/entity/
mod.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`#[pg_extern]` related entities for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17*/
18mod argument;
19mod cast;
20mod operator;
21mod returning;
22
23pub use argument::PgExternArgumentEntity;
24pub use cast::PgCastEntity;
25pub use operator::PgOperatorEntity;
26pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem};
27
28use crate::fmt;
29use crate::metadata::{Returns, SqlArrayMapping, SqlMapping};
30use crate::pgrx_sql::PgrxSql;
31use crate::to_sql::ToSql;
32use crate::to_sql::entity::ToSqlConfigEntity;
33use crate::{ExternArgs, SqlGraphEntity, SqlGraphIdentifier};
34
35use eyre::{WrapErr, eyre};
36
37/// The output of a [`PgExtern`](crate::pg_extern::PgExtern) from `quote::ToTokens::to_tokens`.
38#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
39pub struct PgExternEntity<'a> {
40    pub name: &'a str,
41    pub unaliased_name: &'a str,
42    pub module_path: &'a str,
43    pub full_path: &'a str,
44    pub fn_args: Vec<PgExternArgumentEntity<'a>>,
45    pub fn_return: PgExternReturnEntity<'a>,
46    pub schema: Option<&'a str>,
47    pub file: &'a str,
48    pub line: u32,
49    pub extern_attrs: Vec<ExternArgs>,
50    pub search_path: Option<Vec<&'a str>>,
51    pub operator: Option<PgOperatorEntity<'a>>,
52    pub cast: Option<PgCastEntity>,
53    pub to_sql_config: ToSqlConfigEntity<'a>,
54}
55
56impl<'a> From<PgExternEntity<'a>> for SqlGraphEntity<'a> {
57    fn from(val: PgExternEntity<'a>) -> Self {
58        SqlGraphEntity::Function(val)
59    }
60}
61
62impl SqlGraphIdentifier for PgExternEntity<'_> {
63    fn dot_identifier(&self) -> String {
64        format!("fn {}", self.name)
65    }
66    fn rust_identifier(&self) -> String {
67        self.full_path.to_string()
68    }
69
70    fn file(&self) -> Option<&str> {
71        Some(self.file)
72    }
73
74    fn line(&self) -> Option<u32> {
75        Some(self.line)
76    }
77}
78
79impl PgExternEntity<'_> {
80    fn sql_name(&self, context: &PgrxSql) -> String {
81        let self_index = context.externs[self];
82        let schema = self
83            .schema
84            .map(|schema| format!("{schema}."))
85            .unwrap_or_else(|| context.schema_prefix_for(&self_index));
86
87        format!("{schema}\"{}\"", self.name)
88    }
89}
90
91fn composite_sql_type(composite_type: Option<&str>) -> eyre::Result<String> {
92    composite_type
93        .map(ToString::to_string)
94        .ok_or_else(|| eyre!("Composite mapping requires composite_type"))
95}
96
97fn array_sql_type(mapping: &SqlArrayMapping, composite_type: Option<&str>) -> eyre::Result<String> {
98    Ok(match mapping {
99        SqlArrayMapping::As(sql) => fmt::with_array_brackets(sql.clone(), 1),
100        SqlArrayMapping::Composite => {
101            fmt::with_array_brackets(composite_sql_type(composite_type)?, 1)
102        }
103    })
104}
105
106fn sql_type(mapping: &SqlMapping, composite_type: Option<&str>) -> eyre::Result<String> {
107    match mapping {
108        SqlMapping::As(sql) => Ok(sql.clone()),
109        SqlMapping::Composite => composite_sql_type(composite_type),
110        SqlMapping::Array(value) => array_sql_type(value, composite_type),
111        SqlMapping::Skip => Err(eyre!("Found a skipped SQL type where SQL should be emitted")),
112    }
113}
114
115impl ToSql for PgExternEntity<'_> {
116    fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
117        let self_index = context.externs[self];
118        let mut extern_attrs = self.extern_attrs.clone();
119        // if we already have a STRICT marker we do not need to add it
120        // presume we can upgrade, then disprove it
121        let mut strict_upgrade = !extern_attrs.iter().any(|i| i == &ExternArgs::Strict);
122        if strict_upgrade {
123            // It may be possible to infer a `STRICT` marker though.
124            // But we can only do that if the user hasn't used a nullable argument wrapper.
125            for arg in &self.fn_args {
126                if arg.used_ty.optional {
127                    strict_upgrade = false;
128                }
129            }
130        }
131
132        if strict_upgrade {
133            extern_attrs.push(ExternArgs::Strict);
134        }
135        extern_attrs.sort();
136        extern_attrs.dedup();
137
138        let module_pathname = &context.get_module_pathname();
139        let schema = self
140            .schema
141            .map(|schema| format!("{schema}."))
142            .unwrap_or_else(|| context.schema_prefix_for(&self_index));
143        let arguments = if !self.fn_args.is_empty() {
144            let mut args = Vec::new();
145            let sql_args = self
146                .fn_args
147                .iter()
148                .filter(|arg| arg.used_ty.emits_argument_sql())
149                .collect::<Vec<_>>();
150            for (idx, arg) in sql_args.iter().enumerate() {
151                let needs_comma = idx < (sql_args.len().saturating_sub(1));
152                let schema_prefix = context.schema_prefix_for_used_type(
153                    &self_index,
154                    &format!("argument `{}`", arg.pattern),
155                    &arg.used_ty,
156                )?;
157                match arg.used_ty.metadata.argument_sql {
158                    Ok(SqlMapping::As(ref argument_sql)) => {
159                        let buf = format!(
160                            "\
161                                            \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
162                                        ",
163                            pattern = arg.pattern,
164                            schema_prefix = schema_prefix,
165                            // The SQL spelling comes from the embedded schema metadata.
166                            sql_type = argument_sql,
167                            default = if let Some(def) = arg.used_ty.default {
168                                format!(" DEFAULT {def}")
169                            } else {
170                                String::from("")
171                            },
172                            variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
173                            maybe_comma = if needs_comma { ", " } else { " " },
174                            type_name = arg.used_ty.full_path,
175                        );
176                        args.push(buf);
177                    }
178                    Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
179                        let sql = sql_type(mapping, arg.used_ty.composite_type)?;
180                        let buf = format!(
181                            "\
182                            \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
183                        ",
184                            pattern = arg.pattern,
185                            schema_prefix = schema_prefix,
186                            // The SQL spelling comes from the embedded schema metadata.
187                            sql_type = sql,
188                            default = if let Some(def) = arg.used_ty.default {
189                                format!(" DEFAULT {def}")
190                            } else {
191                                String::from("")
192                            },
193                            variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
194                            maybe_comma = if needs_comma { ", " } else { " " },
195                            type_name = arg.used_ty.full_path,
196                        );
197                        args.push(buf);
198                    }
199                    Ok(SqlMapping::Skip) => (),
200                    Err(err) => return Err(err).wrap_err("While mapping argument"),
201                }
202            }
203            String::from("\n") + &args.join("\n") + "\n"
204        } else {
205            Default::default()
206        };
207
208        let returns = match &self.fn_return {
209            PgExternReturnEntity::None => String::from("RETURNS void"),
210            PgExternReturnEntity::Type { ty } => {
211                let (schema_prefix, sql_type) = match &ty.metadata.return_sql {
212                    Ok(Returns::One(SqlMapping::As(sql))) => (
213                        context.schema_prefix_for_used_type(&self_index, "return type", ty)?,
214                        sql.clone(),
215                    ),
216                    Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => (
217                        context.schema_prefix_for_used_type(&self_index, "return type", ty)?,
218                        sql_type(mapping, ty.composite_type)?,
219                    ),
220                    Ok(other) => {
221                        return Err(eyre!(
222                            "Got non-plain mapped/composite return variant SQL in what macro-expansion thought was a type, got: {other:?}"
223                        ));
224                    }
225                    Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
226                };
227                format!(
228                    "RETURNS {schema_prefix}{sql_type} /* {full_path} */",
229                    full_path = ty.full_path
230                )
231            }
232            PgExternReturnEntity::SetOf { ty, .. } => {
233                let (schema_prefix, sql_type) = match &ty.metadata.return_sql {
234                    Ok(Returns::One(SqlMapping::As(sql)))
235                    | Ok(Returns::SetOf(SqlMapping::As(sql))) => (
236                        context.schema_prefix_for_used_type(
237                            &self_index,
238                            "setof return type",
239                            ty,
240                        )?,
241                        sql.clone(),
242                    ),
243                    Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_))))
244                    | Ok(Returns::SetOf(
245                        mapping @ (SqlMapping::Composite | SqlMapping::Array(_)),
246                    )) => (
247                        context.schema_prefix_for_used_type(
248                            &self_index,
249                            "setof return type",
250                            ty,
251                        )?,
252                        sql_type(mapping, ty.composite_type)?,
253                    ),
254                    Ok(other) => {
255                        return Err(eyre!(
256                            "Got non-scalar mapped/composite return variant SQL in what macro-expansion thought was a setof item, got: {other:?}"
257                        ));
258                    }
259                    Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
260                };
261                format!(
262                    "RETURNS SETOF {schema_prefix}{sql_type} /* {full_path} */",
263                    full_path = ty.full_path
264                )
265            }
266            PgExternReturnEntity::Iterated { tys: table_items, .. } => {
267                let mut items = String::new();
268                for (idx, PgExternReturnEntityIteratedItem { ty, name: col_name }) in
269                    table_items.iter().enumerate()
270                {
271                    let needs_comma = idx < (table_items.len() - 1);
272                    let (schema_prefix, ty_resolved) = match &ty.metadata.return_sql {
273                        Ok(Returns::One(SqlMapping::As(sql))) => (
274                            context.schema_prefix_for_used_type(
275                                &self_index,
276                                "table return column",
277                                ty,
278                            )?,
279                            sql.clone(),
280                        ),
281                        Ok(Returns::One(
282                            mapping @ (SqlMapping::Composite | SqlMapping::Array(_)),
283                        )) => (
284                            context.schema_prefix_for_used_type(
285                                &self_index,
286                                "table return column",
287                                ty,
288                            )?,
289                            sql_type(mapping, ty.composite_type)?,
290                        ),
291                        Ok(other) => {
292                            return Err(eyre!(
293                                "Got non-scalar table return item SQL in what macro-expansion thought was a table, got: {other:?}"
294                            ));
295                        }
296                        Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
297                    };
298                    let item = format!(
299                        "\n\t{col_name} {schema_prefix}{ty_resolved}{needs_comma} /* {ty_name} */",
300                        col_name = col_name.expect(
301                            "An iterator of tuples should have `named!()` macro declarations."
302                        ),
303                        schema_prefix = schema_prefix,
304                        ty_resolved = ty_resolved,
305                        needs_comma = if needs_comma { ", " } else { " " },
306                        ty_name = ty.full_path
307                    );
308                    items.push_str(&item);
309                }
310                format!("RETURNS TABLE ({items}\n)")
311            }
312            PgExternReturnEntity::Trigger => String::from("RETURNS trigger"),
313        };
314        let PgExternEntity { name, module_path, file, line, .. } = self;
315
316        let fn_sql = format!(
317            "\
318                CREATE {or_replace} FUNCTION {schema}\"{name}\"({arguments}) {returns}\n\
319                {extern_attrs}\
320                {search_path}\
321                LANGUAGE c /* Rust */\n\
322                AS '{module_pathname}', '{unaliased_name}_wrapper';\
323            ",
324            or_replace =
325                if extern_attrs.contains(&ExternArgs::CreateOrReplace) { "OR REPLACE" } else { "" },
326            search_path = if let Some(search_path) = &self.search_path {
327                let retval = format!("SET search_path TO {}", search_path.join(", "));
328                retval + "\n"
329            } else {
330                Default::default()
331            },
332            extern_attrs = if extern_attrs.is_empty() {
333                String::default()
334            } else {
335                let mut retval = extern_attrs
336                    .iter()
337                    .filter(|attr| **attr != ExternArgs::CreateOrReplace)
338                    .map(|attr| {
339                        if matches!(attr, ExternArgs::Support(..)) {
340                            let support_fn_name = attr.to_string();
341
342                            let support_fn_name =
343                            if let Some(entity) = context.find_matching_fn(&support_fn_name) {
344                                entity.sql_name(context)
345                            } else {
346                                panic!("cannot locate SUPPORT function `{support_fn_name}` attached to function `{}`", self.full_path)
347                            };
348
349                            format!("SUPPORT {support_fn_name}")
350                        } else {
351                            attr.to_string().to_uppercase()
352                        }
353                    })
354                    .collect::<Vec<_>>()
355                    .join(" ");
356                retval.push('\n');
357                retval
358            },
359            unaliased_name = self.unaliased_name,
360        );
361
362        let requires = {
363            let requires_attrs = self
364                .extern_attrs
365                .iter()
366                .filter_map(|x| match x {
367                    ExternArgs::Requires(requirements) => Some(requirements.clone()),
368                    ExternArgs::Support(support_fn) => Some(vec![support_fn.clone()]),
369                    _ => None,
370                })
371                .flatten()
372                .collect::<Vec<_>>();
373
374            if !requires_attrs.is_empty() {
375                format!(
376                    "-- requires:\n{}\n",
377                    requires_attrs
378                        .iter()
379                        .map(|i| format!("--   {i}"))
380                        .collect::<Vec<_>>()
381                        .join("\n")
382                )
383            } else {
384                "".to_string()
385            }
386        };
387
388        let mut ext_sql = format!(
389            "\n\
390            -- {file}:{line}\n\
391            -- {module_path}::{name}\n\
392            {requires}\
393            {fn_sql}"
394        );
395
396        if let Some(op) = &self.operator {
397            let mut optionals = vec![];
398            if let Some(it) = op.commutator {
399                optionals.push(format!("\tCOMMUTATOR = {it}"));
400            };
401            if let Some(it) = op.negator {
402                optionals.push(format!("\tNEGATOR = {it}"));
403            };
404            if let Some(it) = op.restrict {
405                optionals.push(format!("\tRESTRICT = {it}"));
406            };
407            if let Some(it) = op.join {
408                optionals.push(format!("\tJOIN = {it}"));
409            };
410            if op.hashes {
411                optionals.push(String::from("\tHASHES"));
412            };
413            if op.merges {
414                optionals.push(String::from("\tMERGES"));
415            };
416
417            let left_arg = self
418                .fn_args
419                .first()
420                .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
421            let left_arg_schema_prefix = context.schema_prefix_for_used_type(
422                &self_index,
423                "operator left argument",
424                &left_arg.used_ty,
425            )?;
426            let left_arg_sql = match left_arg.used_ty.metadata.argument_sql {
427                Ok(SqlMapping::As(ref sql)) => sql.clone(),
428                Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
429                    sql_type(mapping, left_arg.used_ty.composite_type)?
430                }
431                Ok(SqlMapping::Skip) => {
432                    return Err(eyre!(
433                        "Found an skipped SQL type in an operator, this is not valid"
434                    ));
435                }
436                Err(err) => return Err(err.into()),
437            };
438
439            let right_arg = self
440                .fn_args
441                .get(1)
442                .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
443            let right_arg_schema_prefix = context.schema_prefix_for_used_type(
444                &self_index,
445                "operator right argument",
446                &right_arg.used_ty,
447            )?;
448            let right_arg_sql = match right_arg.used_ty.metadata.argument_sql {
449                Ok(SqlMapping::As(ref sql)) => sql.clone(),
450                Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
451                    sql_type(mapping, right_arg.used_ty.composite_type)?
452                }
453                Ok(SqlMapping::Skip) => {
454                    return Err(eyre!(
455                        "Found an skipped SQL type in an operator, this is not valid"
456                    ));
457                }
458                Err(err) => return Err(err.into()),
459            };
460
461            let schema = self
462                .schema
463                .map(|schema| format!("{schema}."))
464                .unwrap_or_else(|| context.schema_prefix_for(&self_index));
465
466            let operator_sql = format!(
467                "\n\n\
468                                                    -- {file}:{line}\n\
469                                                    -- {module_path}::{name}\n\
470                                                    CREATE OPERATOR {schema}{opname} (\n\
471                                                        \tPROCEDURE={schema}\"{name}\",\n\
472                                                        \tLEFTARG={schema_prefix_left}{left_arg_sql}, /* {left_name} */\n\
473                                                        \tRIGHTARG={schema_prefix_right}{right_arg_sql}{maybe_comma} /* {right_name} */\n\
474                                                        {optionals}\
475                                                    );\
476                                                    ",
477                opname = op.opname.unwrap(),
478                left_name = left_arg.used_ty.full_path,
479                right_name = right_arg.used_ty.full_path,
480                schema_prefix_left = left_arg_schema_prefix,
481                schema_prefix_right = right_arg_schema_prefix,
482                maybe_comma = if !optionals.is_empty() { "," } else { "" },
483                optionals = if !optionals.is_empty() {
484                    optionals.join(",\n") + "\n"
485                } else {
486                    "".to_string()
487                },
488            );
489            ext_sql += &operator_sql
490        };
491        if let Some(cast) = &self.cast {
492            let target_fn_arg = &self.fn_return;
493            let target_ty = match target_fn_arg {
494                PgExternReturnEntity::Type { ty } => ty,
495                other => {
496                    return Err(eyre!("Casts must return a plain type, got: {other:?}"));
497                }
498            };
499            let target_arg_schema_prefix =
500                context.schema_prefix_for_used_type(&self_index, "cast target type", target_ty)?;
501            let target_arg_sql = match &target_ty.metadata.return_sql {
502                Ok(Returns::One(SqlMapping::As(sql))) => sql.clone(),
503                Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => {
504                    sql_type(mapping, target_ty.composite_type)?
505                }
506                Ok(Returns::One(SqlMapping::Skip)) => {
507                    return Err(eyre!("Found an skipped SQL type in a cast, this is not valid"));
508                }
509                Err(err) => return Err((*err).into()),
510                Ok(other) => {
511                    return Err(eyre!("Casts must return a plain SQL type, got: {other:?}"));
512                }
513            };
514            let source_arg = self
515                .fn_args
516                .first()
517                .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?;
518            let source_arg_schema_prefix = context.schema_prefix_for_used_type(
519                &self_index,
520                "cast source type",
521                &source_arg.used_ty,
522            )?;
523            let source_arg_sql = match source_arg.used_ty.metadata.argument_sql {
524                Ok(SqlMapping::As(ref sql)) => sql.clone(),
525                Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
526                    sql_type(mapping, source_arg.used_ty.composite_type)?
527                }
528                Ok(SqlMapping::Skip) => {
529                    return Err(eyre!("Found an skipped SQL type in a cast, this is not valid"));
530                }
531                Err(err) => return Err(err.into()),
532            };
533            let optional = match cast {
534                PgCastEntity::Default => String::from(""),
535                PgCastEntity::Assignment => String::from(" AS ASSIGNMENT"),
536                PgCastEntity::Implicit => String::from(" AS IMPLICIT"),
537            };
538
539            let cast_sql = format!(
540                "\n\n\
541                                                    -- {file}:{line}\n\
542                                                    -- {module_path}::{name}\n\
543                                                    CREATE CAST (\n\
544                                                        \t{schema_prefix_source}{source_arg_sql} /* {source_name} */\n\
545                                                        \tAS\n\
546                                                        \t{schema_prefix_target}{target_arg_sql} /* {target_name} */\n\
547                                                    )\n\
548                                                    WITH FUNCTION {function_name}{optional};\
549                                                    ",
550                file = self.file,
551                line = self.line,
552                name = self.name,
553                module_path = self.module_path,
554                schema_prefix_source = source_arg_schema_prefix,
555                source_name = source_arg.used_ty.full_path,
556                schema_prefix_target = target_arg_schema_prefix,
557                target_name = target_ty.full_path,
558                function_name = self.name,
559            );
560            ext_sql += &cast_sql
561        };
562        Ok(ext_sql)
563    }
564}