Skip to main content

pgrx_sql_entity_graph/metadata/
sql_translatable.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
A trait denoting that a type can (possibly) be mapped to an SQL type, along with the const-friendly metadata types and helpers that describe that mapping.
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and highly subject to change between versions. While you may use this, please do so with caution.
16
17*/
18use std::any::Any;
19use std::ffi::{CStr, CString};
20use std::fmt::Display;
21use thiserror::Error;
22
23use super::return_variant::ReturnsError;
24use super::{FunctionMetadataTypeEntity, Returns, TypeOrigin};
25
/// Reasons a Rust type cannot be used as an SQL function argument.
///
/// This is the error half of `SqlTranslatable::ARGUMENT_SQL`.
#[derive(Clone, Copy, Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Error)]
pub enum ArgumentError {
    /// Per its message: `SetOfIterator` is not allowed in argument position.
    #[error("Cannot use SetOfIterator as an argument")]
    SetOf,
    /// Per its message: `TableIterator` is not allowed in argument position.
    #[error("Cannot use TableIterator as an argument")]
    Table,
    /// Returned by `array_argument_sql` when the element mapping is itself an array
    /// (e.g. `Vec<Vec<i32>>`).
    #[error("Nested arrays are not supported in arguments")]
    NestedArray,
    /// Reported by the `u8` impl. `Vec<T>` intercepts this error and maps to `bytea` instead.
    #[error("Cannot use bare u8")]
    BareU8,
    /// Returned by `array_argument_sql` when the element mapping is `SqlMappingRef::Skip`.
    #[error("SqlMapping::Skip inside Array is not valid")]
    SkipInArray,
    /// Per its message: a raw Datum argument needs an explicit `sql = "..."` in its declaration.
    #[error("A Datum as an argument means that `sql = \"...\"` must be set in the declaration")]
    Datum,
    /// The named type cannot appear as an argument at all; the `()` impl uses this with
    /// the payload `"()"`.
    #[error("`{0}` is not able to be used as a function argument")]
    NotValidAsArgument(&'static str),
}
43
/// Describes ways that Rust types are mapped into SQL
///
/// This is the owned counterpart of `SqlMappingRef`. Converting from `SqlMappingRef`
/// renders its `Numeric` variant into an `As("NUMERIC...")` string, which is why there is
/// no `Numeric` variant here.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum SqlMapping {
    /// Explicit mappings provided by PGRX
    ///
    /// Holds the SQL type text verbatim (e.g. `"TEXT"` or `"NUMERIC(10, 2)"`).
    As(String),
    /// A composite type. NOTE(review): no SQL type name is stored here; presumably it is
    /// resolved elsewhere — confirm.
    Composite,
    /// An SQL array whose element mapping is described by `SqlArrayMapping`.
    Array(SqlArrayMapping),
    /// A type which does not actually appear in SQL
    Skip,
}
54
55impl SqlMapping {
56    pub fn literal(s: &'static str) -> SqlMapping {
57        SqlMapping::As(String::from(s))
58    }
59}
60
/// Owned element mapping carried by `SqlMapping::Array`.
///
/// Produced from `SqlArrayMappingRef`. Numeric elements are rendered to an `As` string
/// during that conversion, so there is no `Numeric` variant here.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum SqlArrayMapping {
    /// Explicit mappings provided by PGRX
    As(String),
    /// An array of a composite type.
    Composite,
}
67
/// Const-friendly SQL mapping metadata.
///
/// This is the `Copy`, `'static`-borrowing counterpart of `SqlMapping`. It is the type stored
/// in `SqlTranslatable::ARGUMENT_SQL` / `RETURN_SQL` and is converted to `SqlMapping` via
/// `From`.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum SqlMappingRef {
    /// Explicit mappings provided by PGRX
    As(&'static str),
    /// `NUMERIC` with an optional precision and scale. On conversion to `SqlMapping` this is
    /// rendered as `NUMERIC`, `NUMERIC(p)` or `NUMERIC(p, s)`; a scale given without a
    /// precision is dropped.
    Numeric {
        precision: Option<u32>,
        scale: Option<u32>,
    },
    /// A composite type (converts to `SqlMapping::Composite`).
    Composite,
    /// An SQL array of the given element mapping, as built by `array_argument_sql` /
    /// `array_return_sql`.
    Array(SqlArrayMappingRef),
    /// A type which does not actually appear in SQL
    Skip,
}
82
83impl SqlMappingRef {
84    pub const fn literal(s: &'static str) -> Self {
85        Self::As(s)
86    }
87}
88
/// Const-friendly element mapping for `SqlMappingRef::Array`.
///
/// It mirrors `SqlMappingRef` minus the variants that are invalid inside an array (`Skip`
/// and a nested `Array`), which `array_argument_sql` / `array_return_sql` reject.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum SqlArrayMappingRef {
    /// Explicit mappings provided by PGRX
    As(&'static str),
    /// `NUMERIC` element with an optional precision and scale; rendered to an
    /// `SqlArrayMapping::As` string on conversion.
    Numeric {
        precision: Option<u32>,
        scale: Option<u32>,
    },
    /// An array of a composite type.
    Composite,
}
99
/// Renders a `NUMERIC` type name with its optional precision and scale.
///
/// Produces `NUMERIC` when `precision` is `None`, `NUMERIC(p)` when only the precision is
/// known, and `NUMERIC(p, s)` when both are. A `scale` supplied without a `precision` is
/// ignored, because there is no precision to attach it to.
pub(crate) fn numeric_sql_string(precision: Option<u32>, scale: Option<u32>) -> String {
    let Some(p) = precision else {
        return String::from("NUMERIC");
    };
    match scale {
        Some(s) => format!("NUMERIC({p}, {s})"),
        None => format!("NUMERIC({p})"),
    }
}
107
108impl From<SqlArrayMappingRef> for SqlArrayMapping {
109    fn from(value: SqlArrayMappingRef) -> Self {
110        match value {
111            SqlArrayMappingRef::As(value) => SqlArrayMapping::As(String::from(value)),
112            SqlArrayMappingRef::Numeric { precision, scale } => {
113                SqlArrayMapping::As(numeric_sql_string(precision, scale))
114            }
115            SqlArrayMappingRef::Composite => SqlArrayMapping::Composite,
116        }
117    }
118}
119
120impl From<SqlMappingRef> for SqlMapping {
121    fn from(value: SqlMappingRef) -> Self {
122        match value {
123            SqlMappingRef::As(value) => SqlMapping::literal(value),
124            SqlMappingRef::Numeric { precision, scale } => {
125                SqlMapping::As(numeric_sql_string(precision, scale))
126            }
127            SqlMappingRef::Composite => SqlMapping::Composite,
128            SqlMappingRef::Array(value) => SqlMapping::Array(value.into()),
129            SqlMappingRef::Skip => SqlMapping::Skip,
130        }
131    }
132}
133
/// Const-friendly return metadata.
///
/// This is the `Copy` counterpart of `Returns`; `From<ReturnsRef> for Returns` converts each
/// variant into the same-named `Returns` variant.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum ReturnsRef {
    /// A single value of the given mapping.
    One(SqlMappingRef),
    /// A set of values of the given mapping; `setof_return_sql` builds this from `One`.
    SetOf(SqlMappingRef),
    /// A table result. NOTE(review): presumably one mapping per output column, in order —
    /// confirm.
    Table(&'static [SqlMappingRef]),
}
141
142impl From<ReturnsRef> for Returns {
143    fn from(value: ReturnsRef) -> Self {
144        match value {
145            ReturnsRef::One(value) => Returns::One(value.into()),
146            ReturnsRef::SetOf(value) => Returns::SetOf(value.into()),
147            ReturnsRef::Table(values) => {
148                Returns::Table(values.iter().copied().map(Into::into).collect())
149            }
150        }
151    }
152}
153
154pub const fn array_argument_sql(
155    mapping: Result<SqlMappingRef, ArgumentError>,
156) -> Result<SqlMappingRef, ArgumentError> {
157    match mapping {
158        Ok(SqlMappingRef::As(sql)) => Ok(SqlMappingRef::Array(SqlArrayMappingRef::As(sql))),
159        Ok(SqlMappingRef::Numeric { precision, scale }) => {
160            Ok(SqlMappingRef::Array(SqlArrayMappingRef::Numeric { precision, scale }))
161        }
162        Ok(SqlMappingRef::Composite) => Ok(SqlMappingRef::Array(SqlArrayMappingRef::Composite)),
163        Ok(SqlMappingRef::Skip) => Err(ArgumentError::SkipInArray),
164        Ok(SqlMappingRef::Array(_)) => Err(ArgumentError::NestedArray),
165        Err(err) => Err(err),
166    }
167}
168
169pub const fn array_return_sql(
170    returns: Result<ReturnsRef, ReturnsError>,
171) -> Result<ReturnsRef, ReturnsError> {
172    match returns {
173        Ok(ReturnsRef::One(SqlMappingRef::As(sql))) => {
174            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::As(sql))))
175        }
176        Ok(ReturnsRef::One(SqlMappingRef::Numeric { precision, scale })) => {
177            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
178                precision,
179                scale,
180            })))
181        }
182        Ok(ReturnsRef::One(SqlMappingRef::Composite)) => {
183            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Composite)))
184        }
185        Ok(ReturnsRef::One(SqlMappingRef::Skip)) => Err(ReturnsError::SkipInArray),
186        Ok(ReturnsRef::One(SqlMappingRef::Array(_))) => Err(ReturnsError::NestedArray),
187        Ok(ReturnsRef::SetOf(_)) => Err(ReturnsError::SetOfInArray),
188        Ok(ReturnsRef::Table(_)) => Err(ReturnsError::TableInArray),
189        Err(err) => Err(err),
190    }
191}
192
193pub const fn setof_return_sql(
194    returns: Result<ReturnsRef, ReturnsError>,
195) -> Result<ReturnsRef, ReturnsError> {
196    match returns {
197        Ok(ReturnsRef::One(sql)) => Ok(ReturnsRef::SetOf(sql)),
198        Ok(ReturnsRef::SetOf(_)) => Err(ReturnsError::NestedSetOf),
199        Ok(ReturnsRef::Table(_)) => Err(ReturnsError::SetOfContainingTable),
200        Err(err) => Err(err),
201    }
202}
203
204pub const fn table_item_sql(
205    returns: Result<ReturnsRef, ReturnsError>,
206) -> Result<SqlMappingRef, ReturnsError> {
207    match returns {
208        Ok(ReturnsRef::One(sql)) => Ok(sql),
209        Ok(ReturnsRef::SetOf(_)) => Err(ReturnsError::TableContainingSetOf),
210        Ok(ReturnsRef::Table(_)) => Err(ReturnsError::NestedTable),
211        Err(err) => Err(err),
212    }
213}
214
/// Implements `SqlTranslatable` for a type with a fixed external SQL mapping.
///
/// This macro uses `pgrx_resolved_type!(T)` for `TYPE_IDENT`, sets
/// `TYPE_ORIGIN` to `TypeOrigin::External`, and fills in the const SQL metadata
/// for the common "map this Rust wrapper to an existing SQL type" case.
///
/// Two forms are accepted:
///
/// - `impl_sql_translatable!(Type, "sql")` maps `Type` to `sql` both as a function argument
///   and as a return value.
/// - `impl_sql_translatable!(Type, arg_only = "sql")` maps `Type` to `sql` as an argument
///   only; its `RETURN_SQL` is `Err(ReturnsError::Datum)`.
///
/// Spell out the `unsafe impl SqlTranslatable` instead when (1) the type is owned by
/// this extension or (2) when its argument and return SQL need different mappings.
///
/// The generated code names `Result` by its absolute path, so it keeps working in modules
/// that shadow `Result` (e.g. with a one-parameter `type Result<T> = ...` alias).
///
/// # Safety
///
/// The macro expands to an `unsafe impl SqlTranslatable`, even though invoking it does not
/// require writing `unsafe`. Invoking it carries the same obligation as writing that impl by
/// hand: the named SQL type must be a valid translation of the Rust type.
///
/// This macro is re-exported by `pgrx` and is also available through
/// `pgrx::prelude::*`.
///
/// # Examples
///
/// A wrapper that maps to the existing `uuid` type:
///
/// ```ignore
/// use pgrx::prelude::*;
///
/// pub struct UuidWrapper(uuid::Uuid);
///
/// impl_sql_translatable!(UuidWrapper, "uuid");
/// ```
///
/// An argument-only wrapper for a pseudo-type:
///
/// ```ignore
/// use pgrx::prelude::*;
///
/// pub struct InternalArg(*mut core::ffi::c_void);
///
/// impl_sql_translatable!(InternalArg, arg_only = "internal");
/// ```
#[macro_export]
macro_rules! impl_sql_translatable {
    // Internal arm shared by the public forms below; not part of the documented interface.
    // It is listed first so that a leading `@` can never be offered to the `$ty:ty` matchers.
    (@emit $ty:ty, $argument_sql:expr, $return_sql:expr) => {
        unsafe impl $crate::metadata::SqlTranslatable for $ty {
            const TYPE_IDENT: &'static str = $crate::pgrx_resolved_type!($ty);
            const TYPE_ORIGIN: $crate::metadata::TypeOrigin =
                $crate::metadata::TypeOrigin::External;
            const ARGUMENT_SQL: ::core::result::Result<
                $crate::metadata::SqlMappingRef,
                $crate::metadata::ArgumentError,
            > = $argument_sql;
            const RETURN_SQL: ::core::result::Result<
                $crate::metadata::ReturnsRef,
                $crate::metadata::ReturnsError,
            > = $return_sql;
        }
    };
    ($ty:ty, $sql:literal) => {
        $crate::impl_sql_translatable!(
            @emit $ty,
            ::core::result::Result::Ok($crate::metadata::SqlMappingRef::literal($sql)),
            ::core::result::Result::Ok($crate::metadata::ReturnsRef::One(
                $crate::metadata::SqlMappingRef::literal($sql),
            ))
        );
    };
    ($ty:ty, arg_only = $sql:literal) => {
        $crate::impl_sql_translatable!(
            @emit $ty,
            ::core::result::Result::Ok($crate::metadata::SqlMappingRef::literal($sql)),
            ::core::result::Result::Err($crate::metadata::ReturnsError::Datum)
        );
    };
}
279
280/**
281A value which can be represented in SQL
282
283If you need the common "fixed external SQL type" case, prefer
284`impl_sql_translatable!`. Spell out this trait impl when (1) the type is owned
285by this extension or (2) when the argument or return SQL is unusual.
286
287# Safety
288
289By implementing this, you assert you are not lying to either Postgres or Rust in doing so.
290This trait asserts a safe translation exists between values of this type from Rust to SQL,
291or from SQL into Rust. If you are mistaken about how this works, either the Postgres C API
292or the Rust handling in PGRX may emit undefined behavior.
293
294It cannot be made private or sealed due to details of the structure of the PGRX framework.
295Nonetheless, if you are not confident the translation is valid: do not implement this trait.
296*/
297#[diagnostic::on_unimplemented(
298    message = "`{Self}` has no representation in SQL",
299    label = "non-SQL type"
300)]
301pub unsafe trait SqlTranslatable {
302    const TYPE_IDENT: &'static str;
303    const TYPE_ORIGIN: TypeOrigin;
304    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError>;
305    const RETURN_SQL: Result<ReturnsRef, ReturnsError>;
306
307    fn type_name() -> &'static str {
308        core::any::type_name::<Self>()
309    }
310    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
311        Self::ARGUMENT_SQL.map(Into::into)
312    }
313    fn return_sql() -> Result<Returns, ReturnsError> {
314        Self::RETURN_SQL.map(Into::into)
315    }
316    fn entity() -> FunctionMetadataTypeEntity<'static> {
317        FunctionMetadataTypeEntity::resolved(
318            Self::TYPE_IDENT,
319            Self::TYPE_ORIGIN,
320            Self::argument_sql(),
321            Self::return_sql(),
322        )
323    }
324}
325
326unsafe impl SqlTranslatable for () {
327    const TYPE_IDENT: &'static str = crate::pgrx_resolved_type!(());
328    const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
329    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
330        Err(ArgumentError::NotValidAsArgument("()"));
331    const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
332        Ok(ReturnsRef::One(SqlMappingRef::literal("VOID")));
333}
334
335unsafe impl<T> SqlTranslatable for Option<T>
336where
337    T: SqlTranslatable,
338{
339    const TYPE_IDENT: &'static str = T::TYPE_IDENT;
340    const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
341    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
342    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
343}
344
345unsafe impl<T> SqlTranslatable for *mut T
346where
347    T: SqlTranslatable,
348{
349    const TYPE_IDENT: &'static str = T::TYPE_IDENT;
350    const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
351    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
352    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
353}
354
355unsafe impl<T, E> SqlTranslatable for Result<T, E>
356where
357    T: SqlTranslatable,
358    E: Any + Display,
359{
360    const TYPE_IDENT: &'static str = T::TYPE_IDENT;
361    const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
362    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
363    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
364}
365
366unsafe impl<T> SqlTranslatable for Vec<T>
367where
368    T: SqlTranslatable,
369{
370    const TYPE_IDENT: &'static str = T::TYPE_IDENT;
371    const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
372    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = match T::ARGUMENT_SQL {
373        Err(ArgumentError::BareU8) => Ok(SqlMappingRef::As("bytea")),
374        other => array_argument_sql(other),
375    };
376    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = match T::RETURN_SQL {
377        Err(ReturnsError::BareU8) => Ok(ReturnsRef::One(SqlMappingRef::As("bytea"))),
378        other => array_return_sql(other),
379    };
380}
381
382unsafe impl SqlTranslatable for u8 {
383    const TYPE_IDENT: &'static str = crate::pgrx_resolved_type!(u8);
384    const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
385    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = Err(ArgumentError::BareU8);
386    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = Err(ReturnsError::BareU8);
387}
388
// Implements `SqlTranslatable` for `$ty` as an external type that maps to the SQL type
// `$sql` in both argument and return position.
macro_rules! simple_sql_type {
    ($ty:ty, $sql:literal) => {
        unsafe impl SqlTranslatable for $ty {
            const TYPE_IDENT: &'static str = $crate::pgrx_resolved_type!($ty);
            const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
            const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
                Ok(SqlMappingRef::As($sql));
            const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
                Ok(ReturnsRef::One(SqlMappingRef::As($sql)));
        }
    };
}
401
402simple_sql_type!(i32, "INT");
403simple_sql_type!(String, "TEXT");
404simple_sql_type!(str, "TEXT");
405simple_sql_type!([u8], "bytea");
406simple_sql_type!(i8, "\"char\"");
407simple_sql_type!(i16, "smallint");
408simple_sql_type!(i64, "bigint");
409simple_sql_type!(bool, "bool");
410simple_sql_type!(char, "varchar");
411simple_sql_type!(f32, "real");
412simple_sql_type!(f64, "double precision");
413simple_sql_type!(CString, "cstring");
414simple_sql_type!(CStr, "cstring");
415
416unsafe impl<T> SqlTranslatable for &T
417where
418    T: ?Sized + SqlTranslatable,
419{
420    const TYPE_IDENT: &'static str = T::TYPE_IDENT;
421    const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
422    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
423    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    struct MacroExternalType;
    impl_sql_translatable!(MacroExternalType, "uuid");

    struct MacroArgOnlyType;
    impl_sql_translatable!(MacroArgOnlyType, arg_only = "internal");

    #[test]
    fn impl_sql_translatable_sets_external_defaults() {
        assert_eq!(
            <MacroExternalType as SqlTranslatable>::TYPE_IDENT,
            concat!(module_path!(), "::", "MacroExternalType")
        );
        assert_eq!(<MacroExternalType as SqlTranslatable>::TYPE_ORIGIN, TypeOrigin::External);
        assert_eq!(
            <MacroExternalType as SqlTranslatable>::ARGUMENT_SQL,
            Ok(SqlMappingRef::literal("uuid"))
        );
        assert_eq!(
            <MacroExternalType as SqlTranslatable>::RETURN_SQL,
            Ok(ReturnsRef::One(SqlMappingRef::literal("uuid")))
        );
    }

    #[test]
    fn impl_sql_translatable_supports_arg_only_types() {
        assert_eq!(
            <MacroArgOnlyType as SqlTranslatable>::TYPE_IDENT,
            concat!(module_path!(), "::", "MacroArgOnlyType")
        );
        assert_eq!(<MacroArgOnlyType as SqlTranslatable>::TYPE_ORIGIN, TypeOrigin::External);
        assert_eq!(
            <MacroArgOnlyType as SqlTranslatable>::ARGUMENT_SQL,
            Ok(SqlMappingRef::literal("internal"))
        );
        assert_eq!(<MacroArgOnlyType as SqlTranslatable>::RETURN_SQL, Err(ReturnsError::Datum));
    }

    #[test]
    fn array_argument_sql_wraps_scalar_kinds() {
        assert_eq!(
            array_argument_sql(Ok(SqlMappingRef::literal("INT"))),
            Ok(SqlMappingRef::Array(SqlArrayMappingRef::As("INT")))
        );
        assert_eq!(
            array_argument_sql(Ok(SqlMappingRef::Numeric { precision: Some(10), scale: Some(2) })),
            Ok(SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
                precision: Some(10),
                scale: Some(2),
            }))
        );
        assert_eq!(
            array_argument_sql(Ok(SqlMappingRef::Composite)),
            Ok(SqlMappingRef::Array(SqlArrayMappingRef::Composite))
        );
    }

    #[test]
    fn array_return_sql_wraps_scalar_kinds() {
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::literal("INT")))),
            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::As("INT"))))
        );
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::Numeric {
                precision: Some(10),
                scale: Some(2),
            }))),
            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
                precision: Some(10),
                scale: Some(2),
            })))
        );
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::Composite))),
            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Composite)))
        );
    }

    #[test]
    fn nested_vec_arrays_fail_fast() {
        assert_eq!(
            <Vec<Vec<i32>> as SqlTranslatable>::ARGUMENT_SQL,
            Err(ArgumentError::NestedArray)
        );
        assert_eq!(<Vec<Vec<i32>> as SqlTranslatable>::RETURN_SQL, Err(ReturnsError::NestedArray));
    }

    #[test]
    fn nested_numeric_arrays_fail_fast() {
        let numeric = SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
            precision: Some(10),
            scale: Some(2),
        });
        assert_eq!(array_argument_sql(Ok(numeric)), Err(ArgumentError::NestedArray));
    }

    #[test]
    fn nested_composite_arrays_fail_fast() {
        let composite = SqlMappingRef::Array(SqlArrayMappingRef::Composite);
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(composite))),
            Err(ReturnsError::NestedArray)
        );
    }

    // `Vec<u8>` must become `bytea`, not an array of a (non-existent) bare-u8 type.
    #[test]
    fn vec_u8_maps_to_bytea() {
        assert_eq!(<Vec<u8> as SqlTranslatable>::ARGUMENT_SQL, Ok(SqlMappingRef::As("bytea")));
        assert_eq!(
            <Vec<u8> as SqlTranslatable>::RETURN_SQL,
            Ok(ReturnsRef::One(SqlMappingRef::As("bytea")))
        );
    }

    #[test]
    fn bare_u8_is_rejected() {
        assert_eq!(<u8 as SqlTranslatable>::ARGUMENT_SQL, Err(ArgumentError::BareU8));
        assert_eq!(<u8 as SqlTranslatable>::RETURN_SQL, Err(ReturnsError::BareU8));
    }

    #[test]
    fn unit_is_a_void_return_only() {
        assert_eq!(
            <() as SqlTranslatable>::ARGUMENT_SQL,
            Err(ArgumentError::NotValidAsArgument("()"))
        );
        assert_eq!(
            <() as SqlTranslatable>::RETURN_SQL,
            Ok(ReturnsRef::One(SqlMappingRef::literal("VOID")))
        );
    }

    #[test]
    fn wrappers_are_transparent() {
        assert_eq!(
            <Option<i32> as SqlTranslatable>::ARGUMENT_SQL,
            <i32 as SqlTranslatable>::ARGUMENT_SQL
        );
        assert_eq!(<&str as SqlTranslatable>::RETURN_SQL, <str as SqlTranslatable>::RETURN_SQL);
        assert_eq!(
            <Result<i64, String> as SqlTranslatable>::RETURN_SQL,
            <i64 as SqlTranslatable>::RETURN_SQL
        );
    }

    #[test]
    fn numeric_sql_string_formats_precision_and_scale() {
        assert_eq!(numeric_sql_string(None, None), "NUMERIC");
        assert_eq!(numeric_sql_string(Some(10), None), "NUMERIC(10)");
        assert_eq!(numeric_sql_string(Some(10), Some(2)), "NUMERIC(10, 2)");
    }

    #[test]
    fn const_mappings_convert_to_owned_sql() {
        assert_eq!(
            SqlMapping::from(SqlMappingRef::Numeric { precision: Some(10), scale: Some(2) }),
            SqlMapping::As(String::from("NUMERIC(10, 2)"))
        );
        assert_eq!(
            SqlMapping::from(SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
                precision: Some(5),
                scale: None,
            })),
            SqlMapping::Array(SqlArrayMapping::As(String::from("NUMERIC(5)")))
        );
        assert_eq!(SqlMapping::from(SqlMappingRef::literal("TEXT")), SqlMapping::literal("TEXT"));
        assert_eq!(SqlMapping::from(SqlMappingRef::Skip), SqlMapping::Skip);
    }

    #[test]
    fn skip_cannot_be_an_array_element() {
        assert_eq!(array_argument_sql(Ok(SqlMappingRef::Skip)), Err(ArgumentError::SkipInArray));
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::Skip))),
            Err(ReturnsError::SkipInArray)
        );
    }

    #[test]
    fn setof_and_table_helpers_reject_invalid_nesting() {
        let int = SqlMappingRef::literal("INT");
        assert_eq!(setof_return_sql(Ok(ReturnsRef::One(int))), Ok(ReturnsRef::SetOf(int)));
        assert_eq!(setof_return_sql(Ok(ReturnsRef::SetOf(int))), Err(ReturnsError::NestedSetOf));
        assert_eq!(table_item_sql(Ok(ReturnsRef::One(int))), Ok(int));
        assert_eq!(
            table_item_sql(Ok(ReturnsRef::SetOf(int))),
            Err(ReturnsError::TableContainingSetOf)
        );
        assert_eq!(array_return_sql(Ok(ReturnsRef::SetOf(int))), Err(ReturnsError::SetOfInArray));
    }

    #[test]
    fn incoming_errors_pass_through_unchanged() {
        assert_eq!(array_argument_sql(Err(ArgumentError::Datum)), Err(ArgumentError::Datum));
        assert_eq!(array_return_sql(Err(ReturnsError::Datum)), Err(ReturnsError::Datum));
        assert_eq!(setof_return_sql(Err(ReturnsError::Datum)), Err(ReturnsError::Datum));
        assert_eq!(table_item_sql(Err(ReturnsError::Datum)), Err(ReturnsError::Datum));
    }
}