pgx_sql_entity_graph/metadata/
sql_translatable.rs

1/*!
2
3A trait denoting a type can possibly be mapped to an SQL type
4
5> Like all of the [`sql_entity_graph`][crate::pgx_sql_entity_graph] APIs, this is considered **internal**
6to the `pgx` framework and very subject to change between versions. While you may use this, please do it with caution.
7
8*/
9use std::any::Any;
10use std::error::Error;
11use std::fmt::Display;
12
13use super::return_variant::ReturnsError;
14use super::{FunctionMetadataTypeEntity, Returns};
15
/// Reasons a Rust type cannot be rendered as the SQL type of a function argument.
///
/// Variant order is significant: the derived `Ord`/`PartialOrd` follow declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ArgumentError {
    /// A `SetOfIterator` was used in argument position.
    SetOf,
    /// A `TableIterator` was used in argument position.
    Table,
    /// A bare `u8` was used.
    BareU8,
    /// `SqlMapping::Skip` appeared inside an array.
    SkipInArray,
    /// A `Datum` was used without `sql = "..."` being set in the declaration.
    Datum,
    /// The named type cannot be used as a function argument; carries the type's name.
    NotValidAsArgument(&'static str),
}

impl std::fmt::Display for ArgumentError {
    /// Writes the human-readable explanation for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant but one has a fixed message; the parameterised one returns early.
        let message = match *self {
            Self::SetOf => "Cannot use SetOfIterator as an argument",
            Self::Table => "Cannot use TableIterator as an argument",
            Self::BareU8 => "Cannot use bare u8",
            Self::SkipInArray => "SqlMapping::Skip inside Array is not valid",
            Self::Datum => {
                "A Datum as an argument means that `sql = \"...\"` must be set in the declaration"
            }
            Self::NotValidAsArgument(type_name) => {
                return write!(f, "`{type_name}` is not able to be used as a function argument");
            }
        };
        f.write_str(message)
    }
}
50
/// Describes ways that Rust types are mapped into SQL
///
/// Variant order is significant: the derived `Ord`/`PartialOrd` follow declaration order.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlMapping {
    /// Explicit mappings provided by PGX; holds the SQL type text
    As(String),
    /// A composite mapping; `array_brackets` is set to `true` by the `Vec<T>` impl in this file
    Composite { array_brackets: bool },
    /// Some types are still directly from source; `array_brackets` is set to `true` by the
    /// `Vec<T>` impl in this file
    Source { array_brackets: bool },
    /// Placeholder for some types with no simple translation
    Skip,
}

impl SqlMapping {
    /// Builds an [`SqlMapping::As`] from a static string by copying it into an owned `String`.
    pub fn literal(s: &'static str) -> SqlMapping {
        let owned = s.to_owned();
        Self::As(owned)
    }
}
72
// Marker impl: `ArgumentError` relies entirely on the default methods of `std::error::Error`
// (its `Debug` and `Display` impls satisfy the trait's supertraits).
impl Error for ArgumentError {}
74
75/**
76A value which can be represented in SQL
77
78# Safety
79
80By implementing this, you assert you are not lying to either Postgres or Rust in doing so.
81This trait asserts a safe translation exists between values of this type from Rust to SQL,
82or from SQL into Rust. If you are mistaken about how this works, either the Postgres C API
83or the Rust handling in PGX may emit undefined behavior.
84
85It cannot be made private or sealed due to details of the structure of the PGX framework.
86Nonetheless, if you are not confident the translation is valid: do not implement this trait.
87*/
88pub unsafe trait SqlTranslatable {
89    fn type_name() -> &'static str {
90        core::any::type_name::<Self>()
91    }
92    fn argument_sql() -> Result<SqlMapping, ArgumentError>;
93    fn return_sql() -> Result<Returns, ReturnsError>;
94    fn variadic() -> bool {
95        false
96    }
97    fn optional() -> bool {
98        false
99    }
100    fn entity() -> FunctionMetadataTypeEntity {
101        FunctionMetadataTypeEntity {
102            type_name: Self::type_name(),
103            argument_sql: Self::argument_sql(),
104            return_sql: Self::return_sql(),
105            variadic: Self::variadic(),
106            optional: Self::optional(),
107        }
108    }
109}
110
111unsafe impl<E> SqlTranslatable for Result<(), E>
112where
113    E: Any + Display,
114{
115    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
116        Err(ArgumentError::NotValidAsArgument("()"))
117    }
118
119    fn return_sql() -> Result<Returns, ReturnsError> {
120        Ok(Returns::One(SqlMapping::literal("VOID")))
121    }
122}
123
124unsafe impl<T> SqlTranslatable for Option<T>
125where
126    T: SqlTranslatable,
127{
128    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
129        T::argument_sql()
130    }
131    fn return_sql() -> Result<Returns, ReturnsError> {
132        T::return_sql()
133    }
134    fn optional() -> bool {
135        true
136    }
137}
138
139unsafe impl<T> SqlTranslatable for *mut T
140where
141    T: SqlTranslatable,
142{
143    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
144        T::argument_sql()
145    }
146    fn return_sql() -> Result<Returns, ReturnsError> {
147        T::return_sql()
148    }
149    fn optional() -> bool {
150        T::optional()
151    }
152}
153
154unsafe impl<T, E> SqlTranslatable for Result<T, E>
155where
156    T: SqlTranslatable,
157    E: Any + Display,
158{
159    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
160        T::argument_sql()
161    }
162    fn return_sql() -> Result<Returns, ReturnsError> {
163        T::return_sql()
164    }
165    fn optional() -> bool {
166        true
167    }
168}
169
170unsafe impl<T> SqlTranslatable for Vec<T>
171where
172    T: SqlTranslatable,
173{
174    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
175        match T::type_name() {
176            id if id == u8::type_name() => Ok(SqlMapping::As(format!("bytea"))),
177            _ => match T::argument_sql() {
178                Ok(SqlMapping::As(val)) => Ok(SqlMapping::As(format!("{val}[]"))),
179                Ok(SqlMapping::Composite { array_brackets: _ }) => {
180                    Ok(SqlMapping::Composite { array_brackets: true })
181                }
182                Ok(SqlMapping::Source { array_brackets: _ }) => {
183                    Ok(SqlMapping::Source { array_brackets: true })
184                }
185                Ok(SqlMapping::Skip) => Ok(SqlMapping::Skip),
186                err @ Err(_) => err,
187            },
188        }
189    }
190
191    fn return_sql() -> Result<Returns, ReturnsError> {
192        match T::type_name() {
193            id if id == u8::type_name() => Ok(Returns::One(SqlMapping::As(format!("bytea")))),
194            _ => match T::return_sql() {
195                Ok(Returns::One(SqlMapping::As(val))) => {
196                    Ok(Returns::One(SqlMapping::As(format!("{val}[]"))))
197                }
198                Ok(Returns::One(SqlMapping::Composite { array_brackets: _ })) => {
199                    Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
200                }
201                Ok(Returns::One(SqlMapping::Source { array_brackets: _ })) => {
202                    Ok(Returns::One(SqlMapping::Source { array_brackets: true }))
203                }
204                Ok(Returns::One(SqlMapping::Skip)) => Ok(Returns::One(SqlMapping::Skip)),
205                Ok(Returns::SetOf(_)) => Err(ReturnsError::SetOfInArray),
206                Ok(Returns::Table(_)) => Err(ReturnsError::TableInArray),
207                err @ Err(_) => err,
208            },
209        }
210    }
211    fn optional() -> bool {
212        T::optional()
213    }
214}
215
216unsafe impl SqlTranslatable for u8 {
217    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
218        Err(ArgumentError::BareU8)
219    }
220    fn return_sql() -> Result<Returns, ReturnsError> {
221        Err(ReturnsError::BareU8)
222    }
223}
224
225unsafe impl SqlTranslatable for i32 {
226    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
227        Ok(SqlMapping::literal("INT"))
228    }
229    fn return_sql() -> Result<Returns, ReturnsError> {
230        Ok(Returns::One(SqlMapping::literal("INT")))
231    }
232}
233
234unsafe impl SqlTranslatable for u32 {
235    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
236        Ok(SqlMapping::Source { array_brackets: false })
237    }
238    fn return_sql() -> Result<Returns, ReturnsError> {
239        Ok(Returns::One(SqlMapping::Source { array_brackets: false }))
240    }
241}
242
243unsafe impl SqlTranslatable for String {
244    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
245        Ok(SqlMapping::literal("TEXT"))
246    }
247    fn return_sql() -> Result<Returns, ReturnsError> {
248        Ok(Returns::One(SqlMapping::literal("TEXT")))
249    }
250}
251
252unsafe impl<T> SqlTranslatable for &T
253where
254    T: SqlTranslatable,
255{
256    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
257        T::argument_sql()
258    }
259    fn return_sql() -> Result<Returns, ReturnsError> {
260        T::return_sql()
261    }
262}
263
264unsafe impl<'a> SqlTranslatable for &'a str {
265    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
266        Ok(SqlMapping::literal("TEXT"))
267    }
268    fn return_sql() -> Result<Returns, ReturnsError> {
269        Ok(Returns::One(SqlMapping::literal("TEXT")))
270    }
271}
272
273unsafe impl<'a> SqlTranslatable for &'a [u8] {
274    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
275        Ok(SqlMapping::literal("bytea"))
276    }
277    fn return_sql() -> Result<Returns, ReturnsError> {
278        Ok(Returns::One(SqlMapping::literal("bytea")))
279    }
280}
281
282unsafe impl SqlTranslatable for i8 {
283    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
284        Ok(SqlMapping::As(String::from("\"char\"")))
285    }
286    fn return_sql() -> Result<Returns, ReturnsError> {
287        Ok(Returns::One(SqlMapping::As(String::from("\"char\""))))
288    }
289}
290
291unsafe impl SqlTranslatable for i16 {
292    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
293        Ok(SqlMapping::literal("smallint"))
294    }
295    fn return_sql() -> Result<Returns, ReturnsError> {
296        Ok(Returns::One(SqlMapping::literal("smallint")))
297    }
298}
299
300unsafe impl SqlTranslatable for i64 {
301    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
302        Ok(SqlMapping::literal("bigint"))
303    }
304    fn return_sql() -> Result<Returns, ReturnsError> {
305        Ok(Returns::One(SqlMapping::literal("bigint")))
306    }
307}
308
309unsafe impl SqlTranslatable for bool {
310    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
311        Ok(SqlMapping::literal("bool"))
312    }
313    fn return_sql() -> Result<Returns, ReturnsError> {
314        Ok(Returns::One(SqlMapping::literal("bool")))
315    }
316}
317
318unsafe impl SqlTranslatable for char {
319    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
320        Ok(SqlMapping::literal("varchar"))
321    }
322    fn return_sql() -> Result<Returns, ReturnsError> {
323        Ok(Returns::One(SqlMapping::literal("varchar")))
324    }
325}
326
327unsafe impl SqlTranslatable for f32 {
328    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
329        Ok(SqlMapping::literal("real"))
330    }
331    fn return_sql() -> Result<Returns, ReturnsError> {
332        Ok(Returns::One(SqlMapping::literal("real")))
333    }
334}
335
336unsafe impl SqlTranslatable for f64 {
337    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
338        Ok(SqlMapping::literal("double precision"))
339    }
340    fn return_sql() -> Result<Returns, ReturnsError> {
341        Ok(Returns::One(SqlMapping::literal("double precision")))
342    }
343}
344
345unsafe impl SqlTranslatable for core::ffi::CStr {
346    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
347        Ok(SqlMapping::literal("cstring"))
348    }
349    fn return_sql() -> Result<Returns, ReturnsError> {
350        Ok(Returns::One(SqlMapping::literal("cstring")))
351    }
352}
353
354unsafe impl SqlTranslatable for &'static core::ffi::CStr {
355    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
356        Ok(SqlMapping::literal("cstring"))
357    }
358    fn return_sql() -> Result<Returns, ReturnsError> {
359        Ok(Returns::One(SqlMapping::literal("cstring")))
360    }
361}