Skip to main content

pgrx_sql_entity_graph/metadata/
sql_translatable.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
A trait denoting that a type can be mapped to an SQL type
13
> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and highly subject to change between versions. While you may use this, please do so with caution.
16
17*/
18use std::any::Any;
19use std::ffi::{CStr, CString};
20use std::fmt::Display;
21use thiserror::Error;
22
23use super::return_variant::ReturnsError;
24use super::{FunctionMetadataTypeEntity, Returns};
25
26#[derive(Clone, Copy, Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Error)]
27pub enum ArgumentError {
28    #[error("Cannot use SetOfIterator as an argument")]
29    SetOf,
30    #[error("Cannot use TableIterator as an argument")]
31    Table,
32    #[error("Cannot use bare u8")]
33    BareU8,
34    #[error("SqlMapping::Skip inside Array is not valid")]
35    SkipInArray,
36    #[error("A Datum as an argument means that `sql = \"...\"` must be set in the declaration")]
37    Datum,
38    #[error("`{0}` is not able to be used as a function argument")]
39    NotValidAsArgument(&'static str),
40}
41
/// How a Rust type is represented in generated SQL.
///
/// Variant order is significant: it determines the derived `Ord`/`PartialOrd`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlMapping {
    /// An explicit SQL type name supplied by PGRX, e.g. `TEXT` or `bigint`.
    As(String),
    /// A composite type. It carries no type name here; `array_brackets` is set when the
    /// composite is wrapped in an array (the `Vec<T>` implementation sets it to `true`).
    Composite {
        array_brackets: bool,
    },
    /// A type which has no counterpart in SQL and does not appear in it.
    Skip,
}

impl SqlMapping {
    /// Builds an [`SqlMapping::As`] by copying a static type name into an owned `String`.
    pub fn literal(s: &'static str) -> SqlMapping {
        Self::As(s.to_owned())
    }
}
59
60/**
61A value which can be represented in SQL
62
63# Safety
64
65By implementing this, you assert you are not lying to either Postgres or Rust in doing so.
66This trait asserts a safe translation exists between values of this type from Rust to SQL,
67or from SQL into Rust. If you are mistaken about how this works, either the Postgres C API
68or the Rust handling in PGRX may emit undefined behavior.
69
70It cannot be made private or sealed due to details of the structure of the PGRX framework.
71Nonetheless, if you are not confident the translation is valid: do not implement this trait.
72*/
73#[diagnostic::on_unimplemented(
74    message = "`{Self}` has no representation in SQL",
75    label = "non-SQL type"
76)]
77pub unsafe trait SqlTranslatable {
78    fn type_name() -> &'static str {
79        core::any::type_name::<Self>()
80    }
81    fn argument_sql() -> Result<SqlMapping, ArgumentError>;
82    fn return_sql() -> Result<Returns, ReturnsError>;
83    fn variadic() -> bool {
84        false
85    }
86    fn optional() -> bool {
87        false
88    }
89    fn entity() -> FunctionMetadataTypeEntity {
90        FunctionMetadataTypeEntity {
91            type_name: Self::type_name(),
92            argument_sql: Self::argument_sql(),
93            return_sql: Self::return_sql(),
94            variadic: Self::variadic(),
95            optional: Self::optional(),
96        }
97    }
98}
99
100unsafe impl SqlTranslatable for () {
101    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
102        Err(ArgumentError::NotValidAsArgument("()"))
103    }
104
105    fn return_sql() -> Result<Returns, ReturnsError> {
106        Ok(Returns::One(SqlMapping::literal("VOID")))
107    }
108}
109
110unsafe impl<T> SqlTranslatable for Option<T>
111where
112    T: SqlTranslatable,
113{
114    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
115        T::argument_sql()
116    }
117    fn return_sql() -> Result<Returns, ReturnsError> {
118        T::return_sql()
119    }
120    fn optional() -> bool {
121        true
122    }
123}
124
125unsafe impl<T> SqlTranslatable for *mut T
126where
127    T: SqlTranslatable,
128{
129    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
130        T::argument_sql()
131    }
132    fn return_sql() -> Result<Returns, ReturnsError> {
133        T::return_sql()
134    }
135    fn optional() -> bool {
136        T::optional()
137    }
138}
139
140unsafe impl<T, E> SqlTranslatable for Result<T, E>
141where
142    T: SqlTranslatable,
143    E: Any + Display,
144{
145    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
146        T::argument_sql()
147    }
148    fn return_sql() -> Result<Returns, ReturnsError> {
149        T::return_sql()
150    }
151    fn optional() -> bool {
152        true
153    }
154}
155
156unsafe impl<T> SqlTranslatable for Vec<T>
157where
158    T: SqlTranslatable,
159{
160    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
161        match T::type_name() {
162            id if id == u8::type_name() => Ok(SqlMapping::As("bytea".into())),
163            _ => match T::argument_sql() {
164                Ok(SqlMapping::As(val)) => Ok(SqlMapping::As(format!("{val}[]"))),
165                Ok(SqlMapping::Composite { array_brackets: _ }) => {
166                    Ok(SqlMapping::Composite { array_brackets: true })
167                }
168                Ok(SqlMapping::Skip) => Ok(SqlMapping::Skip),
169                err @ Err(_) => err,
170            },
171        }
172    }
173
174    fn return_sql() -> Result<Returns, ReturnsError> {
175        match T::type_name() {
176            id if id == u8::type_name() => Ok(Returns::One(SqlMapping::As("bytea".into()))),
177            _ => match T::return_sql() {
178                Ok(Returns::One(SqlMapping::As(val))) => {
179                    Ok(Returns::One(SqlMapping::As(format!("{val}[]"))))
180                }
181                Ok(Returns::One(SqlMapping::Composite { array_brackets: _ })) => {
182                    Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
183                }
184                Ok(Returns::One(SqlMapping::Skip)) => Ok(Returns::One(SqlMapping::Skip)),
185                Ok(Returns::SetOf(_)) => Err(ReturnsError::SetOfInArray),
186                Ok(Returns::Table(_)) => Err(ReturnsError::TableInArray),
187                err @ Err(_) => err,
188            },
189        }
190    }
191    fn optional() -> bool {
192        T::optional()
193    }
194}
195
196unsafe impl SqlTranslatable for u8 {
197    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
198        Err(ArgumentError::BareU8)
199    }
200    fn return_sql() -> Result<Returns, ReturnsError> {
201        Err(ReturnsError::BareU8)
202    }
203}
204
205unsafe impl SqlTranslatable for i32 {
206    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
207        Ok(SqlMapping::literal("INT"))
208    }
209    fn return_sql() -> Result<Returns, ReturnsError> {
210        Ok(Returns::One(SqlMapping::literal("INT")))
211    }
212}
213
214unsafe impl SqlTranslatable for String {
215    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
216        Ok(SqlMapping::literal("TEXT"))
217    }
218    fn return_sql() -> Result<Returns, ReturnsError> {
219        Ok(Returns::One(SqlMapping::literal("TEXT")))
220    }
221}
222
223unsafe impl<T> SqlTranslatable for &T
224where
225    T: ?Sized + SqlTranslatable,
226{
227    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
228        T::argument_sql()
229    }
230    fn return_sql() -> Result<Returns, ReturnsError> {
231        T::return_sql()
232    }
233}
234
235unsafe impl SqlTranslatable for str {
236    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
237        Ok(SqlMapping::literal("TEXT"))
238    }
239    fn return_sql() -> Result<Returns, ReturnsError> {
240        Ok(Returns::One(SqlMapping::literal("TEXT")))
241    }
242}
243
244unsafe impl SqlTranslatable for [u8] {
245    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
246        Ok(SqlMapping::literal("bytea"))
247    }
248    fn return_sql() -> Result<Returns, ReturnsError> {
249        Ok(Returns::One(SqlMapping::literal("bytea")))
250    }
251}
252
253unsafe impl SqlTranslatable for i8 {
254    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
255        Ok(SqlMapping::As(String::from("\"char\"")))
256    }
257    fn return_sql() -> Result<Returns, ReturnsError> {
258        Ok(Returns::One(SqlMapping::As(String::from("\"char\""))))
259    }
260}
261
262unsafe impl SqlTranslatable for i16 {
263    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
264        Ok(SqlMapping::literal("smallint"))
265    }
266    fn return_sql() -> Result<Returns, ReturnsError> {
267        Ok(Returns::One(SqlMapping::literal("smallint")))
268    }
269}
270
271unsafe impl SqlTranslatable for i64 {
272    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
273        Ok(SqlMapping::literal("bigint"))
274    }
275    fn return_sql() -> Result<Returns, ReturnsError> {
276        Ok(Returns::One(SqlMapping::literal("bigint")))
277    }
278}
279
280unsafe impl SqlTranslatable for bool {
281    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
282        Ok(SqlMapping::literal("bool"))
283    }
284    fn return_sql() -> Result<Returns, ReturnsError> {
285        Ok(Returns::One(SqlMapping::literal("bool")))
286    }
287}
288
289unsafe impl SqlTranslatable for char {
290    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
291        Ok(SqlMapping::literal("varchar"))
292    }
293    fn return_sql() -> Result<Returns, ReturnsError> {
294        Ok(Returns::One(SqlMapping::literal("varchar")))
295    }
296}
297
298unsafe impl SqlTranslatable for f32 {
299    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
300        Ok(SqlMapping::literal("real"))
301    }
302    fn return_sql() -> Result<Returns, ReturnsError> {
303        Ok(Returns::One(SqlMapping::literal("real")))
304    }
305}
306
307unsafe impl SqlTranslatable for f64 {
308    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
309        Ok(SqlMapping::literal("double precision"))
310    }
311    fn return_sql() -> Result<Returns, ReturnsError> {
312        Ok(Returns::One(SqlMapping::literal("double precision")))
313    }
314}
315
316unsafe impl SqlTranslatable for CString {
317    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
318        Ok(SqlMapping::literal("cstring"))
319    }
320    fn return_sql() -> Result<Returns, ReturnsError> {
321        Ok(Returns::One(SqlMapping::literal("cstring")))
322    }
323}
324
325unsafe impl SqlTranslatable for CStr {
326    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
327        Ok(SqlMapping::literal("cstring"))
328    }
329    fn return_sql() -> Result<Returns, ReturnsError> {
330        Ok(Returns::One(SqlMapping::literal("cstring")))
331    }
332}