pgrx_sql_entity_graph/metadata/
sql_translatable.rs1use std::any::Any;
19use std::ffi::{CStr, CString};
20use std::fmt::Display;
21use thiserror::Error;
22
23use super::return_variant::ReturnsError;
24use super::{FunctionMetadataTypeEntity, Returns};
25
26#[derive(Clone, Copy, Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Error)]
27pub enum ArgumentError {
28    #[error("Cannot use SetOfIterator as an argument")]
29    SetOf,
30    #[error("Cannot use TableIterator as an argument")]
31    Table,
32    #[error("Cannot use bare u8")]
33    BareU8,
34    #[error("SqlMapping::Skip inside Array is not valid")]
35    SkipInArray,
36    #[error("A Datum as an argument means that `sql = \"...\"` must be set in the declaration")]
37    Datum,
38    #[error("`{0}` is not able to be used as a function argument")]
39    NotValidAsArgument(&'static str),
40}
41
/// How a Rust type is spelled on the SQL side.
///
/// The variant order is significant: the derived `Ord`/`PartialOrd` compare
/// variants in declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlMapping {
    /// The SQL type text, e.g. `"INT"` or `"bytea"`.
    As(String),
    /// A composite type; `array_brackets` is set to `true` when the mapping
    /// is wrapped in an array (see the `Vec<T>` implementation).
    Composite { array_brackets: bool },
    /// A mapping that carries no SQL type text.
    // NOTE(review): presumably means the type does not appear in SQL at all — confirm.
    Skip,
}
53
54impl SqlMapping {
55    pub fn literal(s: &'static str) -> SqlMapping {
56        SqlMapping::As(String::from(s))
57    }
58}
59
60pub unsafe trait SqlTranslatable {
74    fn type_name() -> &'static str {
75        core::any::type_name::<Self>()
76    }
77    fn argument_sql() -> Result<SqlMapping, ArgumentError>;
78    fn return_sql() -> Result<Returns, ReturnsError>;
79    fn variadic() -> bool {
80        false
81    }
82    fn optional() -> bool {
83        false
84    }
85    fn entity() -> FunctionMetadataTypeEntity {
86        FunctionMetadataTypeEntity {
87            type_name: Self::type_name(),
88            argument_sql: Self::argument_sql(),
89            return_sql: Self::return_sql(),
90            variadic: Self::variadic(),
91            optional: Self::optional(),
92        }
93    }
94}
95
96unsafe impl SqlTranslatable for () {
97    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
98        Err(ArgumentError::NotValidAsArgument("()"))
99    }
100
101    fn return_sql() -> Result<Returns, ReturnsError> {
102        Ok(Returns::One(SqlMapping::literal("VOID")))
103    }
104}
105
106unsafe impl<T> SqlTranslatable for Option<T>
107where
108    T: SqlTranslatable,
109{
110    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
111        T::argument_sql()
112    }
113    fn return_sql() -> Result<Returns, ReturnsError> {
114        T::return_sql()
115    }
116    fn optional() -> bool {
117        true
118    }
119}
120
121unsafe impl<T> SqlTranslatable for *mut T
122where
123    T: SqlTranslatable,
124{
125    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
126        T::argument_sql()
127    }
128    fn return_sql() -> Result<Returns, ReturnsError> {
129        T::return_sql()
130    }
131    fn optional() -> bool {
132        T::optional()
133    }
134}
135
136unsafe impl<T, E> SqlTranslatable for Result<T, E>
137where
138    T: SqlTranslatable,
139    E: Any + Display,
140{
141    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
142        T::argument_sql()
143    }
144    fn return_sql() -> Result<Returns, ReturnsError> {
145        T::return_sql()
146    }
147    fn optional() -> bool {
148        true
149    }
150}
151
152unsafe impl<T> SqlTranslatable for Vec<T>
153where
154    T: SqlTranslatable,
155{
156    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
157        match T::type_name() {
158            id if id == u8::type_name() => Ok(SqlMapping::As("bytea".into())),
159            _ => match T::argument_sql() {
160                Ok(SqlMapping::As(val)) => Ok(SqlMapping::As(format!("{val}[]"))),
161                Ok(SqlMapping::Composite { array_brackets: _ }) => {
162                    Ok(SqlMapping::Composite { array_brackets: true })
163                }
164                Ok(SqlMapping::Skip) => Ok(SqlMapping::Skip),
165                err @ Err(_) => err,
166            },
167        }
168    }
169
170    fn return_sql() -> Result<Returns, ReturnsError> {
171        match T::type_name() {
172            id if id == u8::type_name() => Ok(Returns::One(SqlMapping::As("bytea".into()))),
173            _ => match T::return_sql() {
174                Ok(Returns::One(SqlMapping::As(val))) => {
175                    Ok(Returns::One(SqlMapping::As(format!("{val}[]"))))
176                }
177                Ok(Returns::One(SqlMapping::Composite { array_brackets: _ })) => {
178                    Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
179                }
180                Ok(Returns::One(SqlMapping::Skip)) => Ok(Returns::One(SqlMapping::Skip)),
181                Ok(Returns::SetOf(_)) => Err(ReturnsError::SetOfInArray),
182                Ok(Returns::Table(_)) => Err(ReturnsError::TableInArray),
183                err @ Err(_) => err,
184            },
185        }
186    }
187    fn optional() -> bool {
188        T::optional()
189    }
190}
191
192unsafe impl SqlTranslatable for u8 {
193    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
194        Err(ArgumentError::BareU8)
195    }
196    fn return_sql() -> Result<Returns, ReturnsError> {
197        Err(ReturnsError::BareU8)
198    }
199}
200
201unsafe impl SqlTranslatable for i32 {
202    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
203        Ok(SqlMapping::literal("INT"))
204    }
205    fn return_sql() -> Result<Returns, ReturnsError> {
206        Ok(Returns::One(SqlMapping::literal("INT")))
207    }
208}
209
210unsafe impl SqlTranslatable for String {
211    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
212        Ok(SqlMapping::literal("TEXT"))
213    }
214    fn return_sql() -> Result<Returns, ReturnsError> {
215        Ok(Returns::One(SqlMapping::literal("TEXT")))
216    }
217}
218
219unsafe impl<T> SqlTranslatable for &T
220where
221    T: ?Sized + SqlTranslatable,
222{
223    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
224        T::argument_sql()
225    }
226    fn return_sql() -> Result<Returns, ReturnsError> {
227        T::return_sql()
228    }
229}
230
231unsafe impl SqlTranslatable for str {
232    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
233        Ok(SqlMapping::literal("TEXT"))
234    }
235    fn return_sql() -> Result<Returns, ReturnsError> {
236        Ok(Returns::One(SqlMapping::literal("TEXT")))
237    }
238}
239
240unsafe impl SqlTranslatable for [u8] {
241    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
242        Ok(SqlMapping::literal("bytea"))
243    }
244    fn return_sql() -> Result<Returns, ReturnsError> {
245        Ok(Returns::One(SqlMapping::literal("bytea")))
246    }
247}
248
249unsafe impl SqlTranslatable for i8 {
250    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
251        Ok(SqlMapping::As(String::from("\"char\"")))
252    }
253    fn return_sql() -> Result<Returns, ReturnsError> {
254        Ok(Returns::One(SqlMapping::As(String::from("\"char\""))))
255    }
256}
257
258unsafe impl SqlTranslatable for i16 {
259    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
260        Ok(SqlMapping::literal("smallint"))
261    }
262    fn return_sql() -> Result<Returns, ReturnsError> {
263        Ok(Returns::One(SqlMapping::literal("smallint")))
264    }
265}
266
267unsafe impl SqlTranslatable for i64 {
268    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
269        Ok(SqlMapping::literal("bigint"))
270    }
271    fn return_sql() -> Result<Returns, ReturnsError> {
272        Ok(Returns::One(SqlMapping::literal("bigint")))
273    }
274}
275
276unsafe impl SqlTranslatable for bool {
277    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
278        Ok(SqlMapping::literal("bool"))
279    }
280    fn return_sql() -> Result<Returns, ReturnsError> {
281        Ok(Returns::One(SqlMapping::literal("bool")))
282    }
283}
284
285unsafe impl SqlTranslatable for char {
286    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
287        Ok(SqlMapping::literal("varchar"))
288    }
289    fn return_sql() -> Result<Returns, ReturnsError> {
290        Ok(Returns::One(SqlMapping::literal("varchar")))
291    }
292}
293
294unsafe impl SqlTranslatable for f32 {
295    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
296        Ok(SqlMapping::literal("real"))
297    }
298    fn return_sql() -> Result<Returns, ReturnsError> {
299        Ok(Returns::One(SqlMapping::literal("real")))
300    }
301}
302
303unsafe impl SqlTranslatable for f64 {
304    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
305        Ok(SqlMapping::literal("double precision"))
306    }
307    fn return_sql() -> Result<Returns, ReturnsError> {
308        Ok(Returns::One(SqlMapping::literal("double precision")))
309    }
310}
311
312unsafe impl SqlTranslatable for CString {
313    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
314        Ok(SqlMapping::literal("cstring"))
315    }
316    fn return_sql() -> Result<Returns, ReturnsError> {
317        Ok(Returns::One(SqlMapping::literal("cstring")))
318    }
319}
320
321unsafe impl SqlTranslatable for CStr {
322    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
323        Ok(SqlMapping::literal("cstring"))
324    }
325    fn return_sql() -> Result<Returns, ReturnsError> {
326        Ok(Returns::One(SqlMapping::literal("cstring")))
327    }
328}