1use std::any::Any;
19use std::ffi::{CStr, CString};
20use std::fmt::Display;
21use thiserror::Error;
22
23use super::return_variant::ReturnsError;
24use super::{FunctionMetadataTypeEntity, Returns, TypeOrigin};
25
26#[derive(Clone, Copy, Debug, Hash, Ord, PartialOrd, PartialEq, Eq, Error)]
27pub enum ArgumentError {
28 #[error("Cannot use SetOfIterator as an argument")]
29 SetOf,
30 #[error("Cannot use TableIterator as an argument")]
31 Table,
32 #[error("Nested arrays are not supported in arguments")]
33 NestedArray,
34 #[error("Cannot use bare u8")]
35 BareU8,
36 #[error("SqlMapping::Skip inside Array is not valid")]
37 SkipInArray,
38 #[error("A Datum as an argument means that `sql = \"...\"` must be set in the declaration")]
39 Datum,
40 #[error("`{0}` is not able to be used as a function argument")]
41 NotValidAsArgument(&'static str),
42}
43
44#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
46pub enum SqlMapping {
47 As(String),
49 Composite,
50 Array(SqlArrayMapping),
51 Skip,
53}
54
55impl SqlMapping {
56 pub fn literal(s: &'static str) -> SqlMapping {
57 SqlMapping::As(String::from(s))
58 }
59}
60
/// Owned description of the element type of an SQL array.
///
/// Unlike `SqlMapping` there is no array or skip variant here, so an array of
/// arrays cannot be represented.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlArrayMapping {
    /// SQL text of the element type, used as given.
    As(String),
    /// NOTE(review): presumably an element that is a composite type — confirm.
    Composite,
}
67
68#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
70pub enum SqlMappingRef {
71 As(&'static str),
73 Numeric {
74 precision: Option<u32>,
75 scale: Option<u32>,
76 },
77 Composite,
78 Array(SqlArrayMappingRef),
79 Skip,
81}
82
83impl SqlMappingRef {
84 pub const fn literal(s: &'static str) -> Self {
85 Self::As(s)
86 }
87}
88
/// Borrowed, `Copy` description of the element type of an SQL array.
///
/// Converting it into `SqlArrayMapping` produces the owned form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlArrayMappingRef {
    /// SQL text of the element type, used as given.
    As(&'static str),
    /// A `NUMERIC` element; rendered to text when converted into `SqlArrayMapping`.
    Numeric {
        /// Optional precision placed first inside the parentheses.
        precision: Option<u32>,
        /// Optional scale placed second; only rendered when `precision` is set.
        scale: Option<u32>,
    },
    /// NOTE(review): presumably an element that is a composite type — confirm.
    Composite,
}
99
/// Renders the SQL spelling of a `NUMERIC` type.
///
/// * no precision → `NUMERIC` (a scale given without a precision is ignored)
/// * precision only → `NUMERIC(p)`
/// * precision and scale → `NUMERIC(p, s)`
pub(crate) fn numeric_sql_string(precision: Option<u32>, scale: Option<u32>) -> String {
    // Without a precision there is nothing to parenthesise, whatever `scale` is.
    let Some(digits) = precision else {
        return String::from("NUMERIC");
    };
    match scale {
        Some(fraction) => format!("NUMERIC({digits}, {fraction})"),
        None => format!("NUMERIC({digits})"),
    }
}
107
108impl From<SqlArrayMappingRef> for SqlArrayMapping {
109 fn from(value: SqlArrayMappingRef) -> Self {
110 match value {
111 SqlArrayMappingRef::As(value) => SqlArrayMapping::As(String::from(value)),
112 SqlArrayMappingRef::Numeric { precision, scale } => {
113 SqlArrayMapping::As(numeric_sql_string(precision, scale))
114 }
115 SqlArrayMappingRef::Composite => SqlArrayMapping::Composite,
116 }
117 }
118}
119
120impl From<SqlMappingRef> for SqlMapping {
121 fn from(value: SqlMappingRef) -> Self {
122 match value {
123 SqlMappingRef::As(value) => SqlMapping::literal(value),
124 SqlMappingRef::Numeric { precision, scale } => {
125 SqlMapping::As(numeric_sql_string(precision, scale))
126 }
127 SqlMappingRef::Composite => SqlMapping::Composite,
128 SqlMappingRef::Array(value) => SqlMapping::Array(value.into()),
129 SqlMappingRef::Skip => SqlMapping::Skip,
130 }
131 }
132}
133
134#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
136pub enum ReturnsRef {
137 One(SqlMappingRef),
138 SetOf(SqlMappingRef),
139 Table(&'static [SqlMappingRef]),
140}
141
142impl From<ReturnsRef> for Returns {
143 fn from(value: ReturnsRef) -> Self {
144 match value {
145 ReturnsRef::One(value) => Returns::One(value.into()),
146 ReturnsRef::SetOf(value) => Returns::SetOf(value.into()),
147 ReturnsRef::Table(values) => {
148 Returns::Table(values.iter().copied().map(Into::into).collect())
149 }
150 }
151 }
152}
153
154pub const fn array_argument_sql(
155 mapping: Result<SqlMappingRef, ArgumentError>,
156) -> Result<SqlMappingRef, ArgumentError> {
157 match mapping {
158 Ok(SqlMappingRef::As(sql)) => Ok(SqlMappingRef::Array(SqlArrayMappingRef::As(sql))),
159 Ok(SqlMappingRef::Numeric { precision, scale }) => {
160 Ok(SqlMappingRef::Array(SqlArrayMappingRef::Numeric { precision, scale }))
161 }
162 Ok(SqlMappingRef::Composite) => Ok(SqlMappingRef::Array(SqlArrayMappingRef::Composite)),
163 Ok(SqlMappingRef::Skip) => Err(ArgumentError::SkipInArray),
164 Ok(SqlMappingRef::Array(_)) => Err(ArgumentError::NestedArray),
165 Err(err) => Err(err),
166 }
167}
168
169pub const fn array_return_sql(
170 returns: Result<ReturnsRef, ReturnsError>,
171) -> Result<ReturnsRef, ReturnsError> {
172 match returns {
173 Ok(ReturnsRef::One(SqlMappingRef::As(sql))) => {
174 Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::As(sql))))
175 }
176 Ok(ReturnsRef::One(SqlMappingRef::Numeric { precision, scale })) => {
177 Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
178 precision,
179 scale,
180 })))
181 }
182 Ok(ReturnsRef::One(SqlMappingRef::Composite)) => {
183 Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Composite)))
184 }
185 Ok(ReturnsRef::One(SqlMappingRef::Skip)) => Err(ReturnsError::SkipInArray),
186 Ok(ReturnsRef::One(SqlMappingRef::Array(_))) => Err(ReturnsError::NestedArray),
187 Ok(ReturnsRef::SetOf(_)) => Err(ReturnsError::SetOfInArray),
188 Ok(ReturnsRef::Table(_)) => Err(ReturnsError::TableInArray),
189 Err(err) => Err(err),
190 }
191}
192
193pub const fn setof_return_sql(
194 returns: Result<ReturnsRef, ReturnsError>,
195) -> Result<ReturnsRef, ReturnsError> {
196 match returns {
197 Ok(ReturnsRef::One(sql)) => Ok(ReturnsRef::SetOf(sql)),
198 Ok(ReturnsRef::SetOf(_)) => Err(ReturnsError::NestedSetOf),
199 Ok(ReturnsRef::Table(_)) => Err(ReturnsError::SetOfContainingTable),
200 Err(err) => Err(err),
201 }
202}
203
204pub const fn table_item_sql(
205 returns: Result<ReturnsRef, ReturnsError>,
206) -> Result<SqlMappingRef, ReturnsError> {
207 match returns {
208 Ok(ReturnsRef::One(sql)) => Ok(sql),
209 Ok(ReturnsRef::SetOf(_)) => Err(ReturnsError::TableContainingSetOf),
210 Ok(ReturnsRef::Table(_)) => Err(ReturnsError::NestedTable),
211 Err(err) => Err(err),
212 }
213}
214
/// Implements `SqlTranslatable` for a type whose SQL name is a fixed string
/// literal.
///
/// Two forms are accepted:
///
/// * `impl_sql_translatable!(Type, "sql")` uses `"sql"` in both argument and
///   return position.
/// * `impl_sql_translatable!(Type, arg_only = "sql")` uses `"sql"` in argument
///   position only; `RETURN_SQL` is `Err(ReturnsError::Datum)`.
///
/// Either form sets `TYPE_IDENT` through `pgrx_resolved_type!` and `TYPE_ORIGIN`
/// to `TypeOrigin::External`.
///
/// ```ignore
/// struct MyUuid;
/// impl_sql_translatable!(MyUuid, "uuid");
/// ```
///
/// The expansion is an `unsafe impl`, so invoking the macro asserts whatever
/// contract `SqlTranslatable` places on its implementors.
///
/// `Result`, `Ok` and `Err` are written with absolute `::std::result` paths.
/// Item paths in a `macro_rules!` expansion resolve at the call site, so the
/// bare names would pick up a caller's own `Result` alias and fail to compile.
#[macro_export]
macro_rules! impl_sql_translatable {
    ($ty:ty, $sql:literal) => {
        unsafe impl $crate::metadata::SqlTranslatable for $ty {
            const TYPE_IDENT: &'static str = $crate::pgrx_resolved_type!($ty);
            const TYPE_ORIGIN: $crate::metadata::TypeOrigin =
                $crate::metadata::TypeOrigin::External;
            const ARGUMENT_SQL: ::std::result::Result<
                $crate::metadata::SqlMappingRef,
                $crate::metadata::ArgumentError,
            > = ::std::result::Result::Ok($crate::metadata::SqlMappingRef::literal($sql));
            const RETURN_SQL: ::std::result::Result<
                $crate::metadata::ReturnsRef,
                $crate::metadata::ReturnsError,
            > = ::std::result::Result::Ok($crate::metadata::ReturnsRef::One(
                $crate::metadata::SqlMappingRef::literal($sql),
            ));
        }
    };
    ($ty:ty, arg_only = $sql:literal) => {
        unsafe impl $crate::metadata::SqlTranslatable for $ty {
            const TYPE_IDENT: &'static str = $crate::pgrx_resolved_type!($ty);
            const TYPE_ORIGIN: $crate::metadata::TypeOrigin =
                $crate::metadata::TypeOrigin::External;
            const ARGUMENT_SQL: ::std::result::Result<
                $crate::metadata::SqlMappingRef,
                $crate::metadata::ArgumentError,
            > = ::std::result::Result::Ok($crate::metadata::SqlMappingRef::literal($sql));
            // No return mapping exists for an argument-only type.
            const RETURN_SQL: ::std::result::Result<
                $crate::metadata::ReturnsRef,
                $crate::metadata::ReturnsError,
            > = ::std::result::Result::Err($crate::metadata::ReturnsError::Datum);
        }
    };
}
279
280#[diagnostic::on_unimplemented(
298 message = "`{Self}` has no representation in SQL",
299 label = "non-SQL type"
300)]
301pub unsafe trait SqlTranslatable {
302 const TYPE_IDENT: &'static str;
303 const TYPE_ORIGIN: TypeOrigin;
304 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError>;
305 const RETURN_SQL: Result<ReturnsRef, ReturnsError>;
306
307 fn type_name() -> &'static str {
308 core::any::type_name::<Self>()
309 }
310 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
311 Self::ARGUMENT_SQL.map(Into::into)
312 }
313 fn return_sql() -> Result<Returns, ReturnsError> {
314 Self::RETURN_SQL.map(Into::into)
315 }
316 fn entity() -> FunctionMetadataTypeEntity<'static> {
317 FunctionMetadataTypeEntity::resolved(
318 Self::TYPE_IDENT,
319 Self::TYPE_ORIGIN,
320 Self::argument_sql(),
321 Self::return_sql(),
322 )
323 }
324}
325
326unsafe impl SqlTranslatable for () {
327 const TYPE_IDENT: &'static str = crate::pgrx_resolved_type!(());
328 const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
329 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
330 Err(ArgumentError::NotValidAsArgument("()"));
331 const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
332 Ok(ReturnsRef::One(SqlMappingRef::literal("VOID")));
333}
334
335unsafe impl<T> SqlTranslatable for Option<T>
336where
337 T: SqlTranslatable,
338{
339 const TYPE_IDENT: &'static str = T::TYPE_IDENT;
340 const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
341 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
342 const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
343}
344
345unsafe impl<T> SqlTranslatable for *mut T
346where
347 T: SqlTranslatable,
348{
349 const TYPE_IDENT: &'static str = T::TYPE_IDENT;
350 const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
351 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
352 const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
353}
354
355unsafe impl<T, E> SqlTranslatable for Result<T, E>
356where
357 T: SqlTranslatable,
358 E: Any + Display,
359{
360 const TYPE_IDENT: &'static str = T::TYPE_IDENT;
361 const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
362 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
363 const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
364}
365
366unsafe impl<T> SqlTranslatable for Vec<T>
367where
368 T: SqlTranslatable,
369{
370 const TYPE_IDENT: &'static str = T::TYPE_IDENT;
371 const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
372 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = match T::ARGUMENT_SQL {
373 Err(ArgumentError::BareU8) => Ok(SqlMappingRef::As("bytea")),
374 other => array_argument_sql(other),
375 };
376 const RETURN_SQL: Result<ReturnsRef, ReturnsError> = match T::RETURN_SQL {
377 Err(ReturnsError::BareU8) => Ok(ReturnsRef::One(SqlMappingRef::As("bytea"))),
378 other => array_return_sql(other),
379 };
380}
381
382unsafe impl SqlTranslatable for u8 {
383 const TYPE_IDENT: &'static str = crate::pgrx_resolved_type!(u8);
384 const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
385 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = Err(ArgumentError::BareU8);
386 const RETURN_SQL: Result<ReturnsRef, ReturnsError> = Err(ReturnsError::BareU8);
387}
388
/// Implements `SqlTranslatable` for a type in this crate whose SQL name is the
/// same fixed literal in argument and return position.
///
/// `TYPE_IDENT` comes from `pgrx_resolved_type!` and `TYPE_ORIGIN` is always
/// `TypeOrigin::External`.
macro_rules! simple_sql_type {
    ($ty:ty, $sql:literal) => {
        unsafe impl SqlTranslatable for $ty {
            const TYPE_IDENT: &'static str = $crate::pgrx_resolved_type!($ty);
            const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
            const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
                Ok(SqlMappingRef::As($sql));
            const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
                Ok(ReturnsRef::One(SqlMappingRef::As($sql)));
        }
    };
}
401
402simple_sql_type!(i32, "INT");
403simple_sql_type!(String, "TEXT");
404simple_sql_type!(str, "TEXT");
405simple_sql_type!([u8], "bytea");
406simple_sql_type!(i8, "\"char\"");
407simple_sql_type!(i16, "smallint");
408simple_sql_type!(i64, "bigint");
409simple_sql_type!(bool, "bool");
410simple_sql_type!(char, "varchar");
411simple_sql_type!(f32, "real");
412simple_sql_type!(f64, "double precision");
413simple_sql_type!(CString, "cstring");
414simple_sql_type!(CStr, "cstring");
415
416unsafe impl<T> SqlTranslatable for &T
417where
418 T: ?Sized + SqlTranslatable,
419{
420 const TYPE_IDENT: &'static str = T::TYPE_IDENT;
421 const TYPE_ORIGIN: TypeOrigin = T::TYPE_ORIGIN;
422 const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> = T::ARGUMENT_SQL;
423 const RETURN_SQL: Result<ReturnsRef, ReturnsError> = T::RETURN_SQL;
424}
425
#[cfg(test)]
mod tests {
    //! Unit tests for the mapping helpers, the `impl_sql_translatable!` macro and
    //! the blanket `Vec<T>` implementation.

    use super::*;

    struct MacroExternalType;
    impl_sql_translatable!(MacroExternalType, "uuid");

    struct MacroArgOnlyType;
    impl_sql_translatable!(MacroArgOnlyType, arg_only = "internal");

    #[test]
    fn impl_sql_translatable_sets_external_defaults() {
        assert_eq!(
            <MacroExternalType as SqlTranslatable>::TYPE_IDENT,
            concat!(module_path!(), "::", "MacroExternalType")
        );
        assert_eq!(<MacroExternalType as SqlTranslatable>::TYPE_ORIGIN, TypeOrigin::External);
        assert_eq!(
            <MacroExternalType as SqlTranslatable>::ARGUMENT_SQL,
            Ok(SqlMappingRef::literal("uuid"))
        );
        assert_eq!(
            <MacroExternalType as SqlTranslatable>::RETURN_SQL,
            Ok(ReturnsRef::One(SqlMappingRef::literal("uuid")))
        );
    }

    #[test]
    fn impl_sql_translatable_supports_arg_only_types() {
        assert_eq!(
            <MacroArgOnlyType as SqlTranslatable>::TYPE_IDENT,
            concat!(module_path!(), "::", "MacroArgOnlyType")
        );
        assert_eq!(<MacroArgOnlyType as SqlTranslatable>::TYPE_ORIGIN, TypeOrigin::External);
        assert_eq!(
            <MacroArgOnlyType as SqlTranslatable>::ARGUMENT_SQL,
            Ok(SqlMappingRef::literal("internal"))
        );
        assert_eq!(<MacroArgOnlyType as SqlTranslatable>::RETURN_SQL, Err(ReturnsError::Datum));
    }

    #[test]
    fn array_argument_sql_wraps_scalar_kinds() {
        assert_eq!(
            array_argument_sql(Ok(SqlMappingRef::literal("INT"))),
            Ok(SqlMappingRef::Array(SqlArrayMappingRef::As("INT")))
        );
        assert_eq!(
            array_argument_sql(Ok(SqlMappingRef::Numeric { precision: Some(10), scale: Some(2) })),
            Ok(SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
                precision: Some(10),
                scale: Some(2),
            }))
        );
        assert_eq!(
            array_argument_sql(Ok(SqlMappingRef::Composite)),
            Ok(SqlMappingRef::Array(SqlArrayMappingRef::Composite))
        );
    }

    #[test]
    fn array_argument_sql_rejects_skip_and_forwards_errors() {
        assert_eq!(array_argument_sql(Ok(SqlMappingRef::Skip)), Err(ArgumentError::SkipInArray));
        // An element that is already an error keeps its original error.
        assert_eq!(array_argument_sql(Err(ArgumentError::Datum)), Err(ArgumentError::Datum));
    }

    #[test]
    fn array_return_sql_wraps_scalar_kinds() {
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::literal("INT")))),
            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::As("INT"))))
        );
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::Numeric {
                precision: Some(10),
                scale: Some(2),
            }))),
            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
                precision: Some(10),
                scale: Some(2),
            })))
        );
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::Composite))),
            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::Composite)))
        );
    }

    #[test]
    fn array_return_sql_rejects_non_scalar_returns() {
        let int = SqlMappingRef::literal("INT");
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(SqlMappingRef::Skip))),
            Err(ReturnsError::SkipInArray)
        );
        assert_eq!(array_return_sql(Ok(ReturnsRef::SetOf(int))), Err(ReturnsError::SetOfInArray));
        assert_eq!(array_return_sql(Ok(ReturnsRef::Table(&[]))), Err(ReturnsError::TableInArray));
        // An element that is already an error keeps its original error.
        assert_eq!(array_return_sql(Err(ReturnsError::Datum)), Err(ReturnsError::Datum));
    }

    #[test]
    fn setof_return_sql_wraps_only_single_values() {
        let int = SqlMappingRef::literal("INT");
        assert_eq!(setof_return_sql(Ok(ReturnsRef::One(int))), Ok(ReturnsRef::SetOf(int)));
        assert_eq!(setof_return_sql(Ok(ReturnsRef::SetOf(int))), Err(ReturnsError::NestedSetOf));
        assert_eq!(
            setof_return_sql(Ok(ReturnsRef::Table(&[]))),
            Err(ReturnsError::SetOfContainingTable)
        );
        assert_eq!(setof_return_sql(Err(ReturnsError::Datum)), Err(ReturnsError::Datum));
    }

    #[test]
    fn table_item_sql_accepts_only_single_values() {
        let int = SqlMappingRef::literal("INT");
        assert_eq!(table_item_sql(Ok(ReturnsRef::One(int))), Ok(int));
        assert_eq!(
            table_item_sql(Ok(ReturnsRef::SetOf(int))),
            Err(ReturnsError::TableContainingSetOf)
        );
        assert_eq!(table_item_sql(Ok(ReturnsRef::Table(&[]))), Err(ReturnsError::NestedTable));
        assert_eq!(table_item_sql(Err(ReturnsError::Datum)), Err(ReturnsError::Datum));
    }

    #[test]
    fn numeric_sql_string_renders_precision_and_scale() {
        assert_eq!(numeric_sql_string(None, None), "NUMERIC");
        // A scale without a precision is dropped.
        assert_eq!(numeric_sql_string(None, Some(2)), "NUMERIC");
        assert_eq!(numeric_sql_string(Some(10), None), "NUMERIC(10)");
        assert_eq!(numeric_sql_string(Some(10), Some(2)), "NUMERIC(10, 2)");
    }

    #[test]
    fn numeric_refs_become_text_when_owned() {
        assert_eq!(
            SqlMapping::from(SqlMappingRef::Numeric { precision: Some(10), scale: Some(2) }),
            SqlMapping::As(String::from("NUMERIC(10, 2)"))
        );
        assert_eq!(
            SqlArrayMapping::from(SqlArrayMappingRef::Numeric { precision: None, scale: None }),
            SqlArrayMapping::As(String::from("NUMERIC"))
        );
    }

    #[test]
    fn vec_of_u8_maps_to_bytea() {
        assert_eq!(<u8 as SqlTranslatable>::ARGUMENT_SQL, Err(ArgumentError::BareU8));
        assert_eq!(<u8 as SqlTranslatable>::RETURN_SQL, Err(ReturnsError::BareU8));
        assert_eq!(<Vec<u8> as SqlTranslatable>::ARGUMENT_SQL, Ok(SqlMappingRef::As("bytea")));
        assert_eq!(
            <Vec<u8> as SqlTranslatable>::RETURN_SQL,
            Ok(ReturnsRef::One(SqlMappingRef::As("bytea")))
        );
    }

    #[test]
    fn vec_of_scalar_maps_to_array() {
        assert_eq!(
            <Vec<i32> as SqlTranslatable>::ARGUMENT_SQL,
            Ok(SqlMappingRef::Array(SqlArrayMappingRef::As("INT")))
        );
        assert_eq!(
            <Vec<i32> as SqlTranslatable>::RETURN_SQL,
            Ok(ReturnsRef::One(SqlMappingRef::Array(SqlArrayMappingRef::As("INT"))))
        );
    }

    #[test]
    fn nested_vec_arrays_fail_fast() {
        assert_eq!(
            <Vec<Vec<i32>> as SqlTranslatable>::ARGUMENT_SQL,
            Err(ArgumentError::NestedArray)
        );
        assert_eq!(<Vec<Vec<i32>> as SqlTranslatable>::RETURN_SQL, Err(ReturnsError::NestedArray));
    }

    #[test]
    fn nested_numeric_arrays_fail_fast() {
        let numeric = SqlMappingRef::Array(SqlArrayMappingRef::Numeric {
            precision: Some(10),
            scale: Some(2),
        });
        assert_eq!(array_argument_sql(Ok(numeric)), Err(ArgumentError::NestedArray));
    }

    #[test]
    fn nested_composite_arrays_fail_fast() {
        let composite = SqlMappingRef::Array(SqlArrayMappingRef::Composite);
        assert_eq!(
            array_return_sql(Ok(ReturnsRef::One(composite))),
            Err(ReturnsError::NestedArray)
        );
    }
}