pgrx_sql_entity_graph/metadata/
sql_translatable.rs1use std::any::Any;
19use std::ffi::{CStr, CString};
20use std::fmt::Display;
21use thiserror::Error;
22
23use super::return_variant::ReturnsError;
24use super::{FunctionMetadataTypeEntity, Returns};
25
26#[derive(Clone, Copy, Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Error)]
27pub enum ArgumentError {
28 #[error("Cannot use SetOfIterator as an argument")]
29 SetOf,
30 #[error("Cannot use TableIterator as an argument")]
31 Table,
32 #[error("Cannot use bare u8")]
33 BareU8,
34 #[error("SqlMapping::Skip inside Array is not valid")]
35 SkipInArray,
36 #[error("A Datum as an argument means that `sql = \"...\"` must be set in the declaration")]
37 Datum,
38 #[error("`{0}` is not able to be used as a function argument")]
39 NotValidAsArgument(&'static str),
40}
41
/// How a Rust type is spelled in generated SQL.
///
/// Variant order is significant: it determines the derived `Ord`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum SqlMapping {
    /// A literal SQL type name such as `TEXT` or `bigint`.
    As(String),
    /// A composite type whose SQL name is not carried in this value.
    /// NOTE(review): presumably the name is resolved elsewhere — confirm.
    Composite {
        /// Whether the composite is used as an array (`Vec<T>` sets this to `true`).
        array_brackets: bool,
    },
    /// The type contributes no SQL type name of its own.
    /// NOTE(review): presumably the value is omitted from the SQL signature — confirm.
    Skip,
}

impl SqlMapping {
    /// Builds an [`SqlMapping::As`] holding an owned copy of the static SQL type name `s`.
    pub fn literal(s: &'static str) -> SqlMapping {
        SqlMapping::As(s.to_owned())
    }
}
59
60#[diagnostic::on_unimplemented(
74 message = "`{Self}` has no representation in SQL",
75 label = "non-SQL type"
76)]
77pub unsafe trait SqlTranslatable {
78 fn type_name() -> &'static str {
79 core::any::type_name::<Self>()
80 }
81 fn argument_sql() -> Result<SqlMapping, ArgumentError>;
82 fn return_sql() -> Result<Returns, ReturnsError>;
83 fn variadic() -> bool {
84 false
85 }
86 fn optional() -> bool {
87 false
88 }
89 fn entity() -> FunctionMetadataTypeEntity {
90 FunctionMetadataTypeEntity {
91 type_name: Self::type_name(),
92 argument_sql: Self::argument_sql(),
93 return_sql: Self::return_sql(),
94 variadic: Self::variadic(),
95 optional: Self::optional(),
96 }
97 }
98}
99
100unsafe impl SqlTranslatable for () {
101 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
102 Err(ArgumentError::NotValidAsArgument("()"))
103 }
104
105 fn return_sql() -> Result<Returns, ReturnsError> {
106 Ok(Returns::One(SqlMapping::literal("VOID")))
107 }
108}
109
110unsafe impl<T> SqlTranslatable for Option<T>
111where
112 T: SqlTranslatable,
113{
114 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
115 T::argument_sql()
116 }
117 fn return_sql() -> Result<Returns, ReturnsError> {
118 T::return_sql()
119 }
120 fn optional() -> bool {
121 true
122 }
123}
124
125unsafe impl<T> SqlTranslatable for *mut T
126where
127 T: SqlTranslatable,
128{
129 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
130 T::argument_sql()
131 }
132 fn return_sql() -> Result<Returns, ReturnsError> {
133 T::return_sql()
134 }
135 fn optional() -> bool {
136 T::optional()
137 }
138}
139
140unsafe impl<T, E> SqlTranslatable for Result<T, E>
141where
142 T: SqlTranslatable,
143 E: Any + Display,
144{
145 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
146 T::argument_sql()
147 }
148 fn return_sql() -> Result<Returns, ReturnsError> {
149 T::return_sql()
150 }
151 fn optional() -> bool {
152 true
153 }
154}
155
156unsafe impl<T> SqlTranslatable for Vec<T>
157where
158 T: SqlTranslatable,
159{
160 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
161 match T::type_name() {
162 id if id == u8::type_name() => Ok(SqlMapping::As("bytea".into())),
163 _ => match T::argument_sql() {
164 Ok(SqlMapping::As(val)) => Ok(SqlMapping::As(format!("{val}[]"))),
165 Ok(SqlMapping::Composite { array_brackets: _ }) => {
166 Ok(SqlMapping::Composite { array_brackets: true })
167 }
168 Ok(SqlMapping::Skip) => Ok(SqlMapping::Skip),
169 err @ Err(_) => err,
170 },
171 }
172 }
173
174 fn return_sql() -> Result<Returns, ReturnsError> {
175 match T::type_name() {
176 id if id == u8::type_name() => Ok(Returns::One(SqlMapping::As("bytea".into()))),
177 _ => match T::return_sql() {
178 Ok(Returns::One(SqlMapping::As(val))) => {
179 Ok(Returns::One(SqlMapping::As(format!("{val}[]"))))
180 }
181 Ok(Returns::One(SqlMapping::Composite { array_brackets: _ })) => {
182 Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
183 }
184 Ok(Returns::One(SqlMapping::Skip)) => Ok(Returns::One(SqlMapping::Skip)),
185 Ok(Returns::SetOf(_)) => Err(ReturnsError::SetOfInArray),
186 Ok(Returns::Table(_)) => Err(ReturnsError::TableInArray),
187 err @ Err(_) => err,
188 },
189 }
190 }
191 fn optional() -> bool {
192 T::optional()
193 }
194}
195
196unsafe impl SqlTranslatable for u8 {
197 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
198 Err(ArgumentError::BareU8)
199 }
200 fn return_sql() -> Result<Returns, ReturnsError> {
201 Err(ReturnsError::BareU8)
202 }
203}
204
205unsafe impl SqlTranslatable for i32 {
206 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
207 Ok(SqlMapping::literal("INT"))
208 }
209 fn return_sql() -> Result<Returns, ReturnsError> {
210 Ok(Returns::One(SqlMapping::literal("INT")))
211 }
212}
213
214unsafe impl SqlTranslatable for String {
215 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
216 Ok(SqlMapping::literal("TEXT"))
217 }
218 fn return_sql() -> Result<Returns, ReturnsError> {
219 Ok(Returns::One(SqlMapping::literal("TEXT")))
220 }
221}
222
223unsafe impl<T> SqlTranslatable for &T
224where
225 T: ?Sized + SqlTranslatable,
226{
227 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
228 T::argument_sql()
229 }
230 fn return_sql() -> Result<Returns, ReturnsError> {
231 T::return_sql()
232 }
233}
234
235unsafe impl SqlTranslatable for str {
236 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
237 Ok(SqlMapping::literal("TEXT"))
238 }
239 fn return_sql() -> Result<Returns, ReturnsError> {
240 Ok(Returns::One(SqlMapping::literal("TEXT")))
241 }
242}
243
244unsafe impl SqlTranslatable for [u8] {
245 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
246 Ok(SqlMapping::literal("bytea"))
247 }
248 fn return_sql() -> Result<Returns, ReturnsError> {
249 Ok(Returns::One(SqlMapping::literal("bytea")))
250 }
251}
252
253unsafe impl SqlTranslatable for i8 {
254 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
255 Ok(SqlMapping::As(String::from("\"char\"")))
256 }
257 fn return_sql() -> Result<Returns, ReturnsError> {
258 Ok(Returns::One(SqlMapping::As(String::from("\"char\""))))
259 }
260}
261
262unsafe impl SqlTranslatable for i16 {
263 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
264 Ok(SqlMapping::literal("smallint"))
265 }
266 fn return_sql() -> Result<Returns, ReturnsError> {
267 Ok(Returns::One(SqlMapping::literal("smallint")))
268 }
269}
270
271unsafe impl SqlTranslatable for i64 {
272 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
273 Ok(SqlMapping::literal("bigint"))
274 }
275 fn return_sql() -> Result<Returns, ReturnsError> {
276 Ok(Returns::One(SqlMapping::literal("bigint")))
277 }
278}
279
280unsafe impl SqlTranslatable for bool {
281 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
282 Ok(SqlMapping::literal("bool"))
283 }
284 fn return_sql() -> Result<Returns, ReturnsError> {
285 Ok(Returns::One(SqlMapping::literal("bool")))
286 }
287}
288
289unsafe impl SqlTranslatable for char {
290 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
291 Ok(SqlMapping::literal("varchar"))
292 }
293 fn return_sql() -> Result<Returns, ReturnsError> {
294 Ok(Returns::One(SqlMapping::literal("varchar")))
295 }
296}
297
298unsafe impl SqlTranslatable for f32 {
299 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
300 Ok(SqlMapping::literal("real"))
301 }
302 fn return_sql() -> Result<Returns, ReturnsError> {
303 Ok(Returns::One(SqlMapping::literal("real")))
304 }
305}
306
307unsafe impl SqlTranslatable for f64 {
308 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
309 Ok(SqlMapping::literal("double precision"))
310 }
311 fn return_sql() -> Result<Returns, ReturnsError> {
312 Ok(Returns::One(SqlMapping::literal("double precision")))
313 }
314}
315
316unsafe impl SqlTranslatable for CString {
317 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
318 Ok(SqlMapping::literal("cstring"))
319 }
320 fn return_sql() -> Result<Returns, ReturnsError> {
321 Ok(Returns::One(SqlMapping::literal("cstring")))
322 }
323}
324
325unsafe impl SqlTranslatable for CStr {
326 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
327 Ok(SqlMapping::literal("cstring"))
328 }
329 fn return_sql() -> Result<Returns, ReturnsError> {
330 Ok(Returns::One(SqlMapping::literal("cstring")))
331 }
332}