1mod argument;
19mod cast;
20mod operator;
21mod returning;
22
23pub use argument::PgExternArgumentEntity;
24pub use cast::PgCastEntity;
25pub use operator::PgOperatorEntity;
26pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem};
27
28use crate::fmt;
29use crate::metadata::{Returns, SqlArrayMapping, SqlMapping};
30use crate::pgrx_sql::PgrxSql;
31use crate::to_sql::ToSql;
32use crate::to_sql::entity::ToSqlConfigEntity;
33use crate::{ExternArgs, SqlGraphEntity, SqlGraphIdentifier};
34
35use eyre::{WrapErr, eyre};
36
/// A `#[pg_extern]` function as recorded in the SQL entity graph.
///
/// Rendered by its `ToSql` impl into a `CREATE FUNCTION` statement, plus optional
/// `CREATE OPERATOR` / `CREATE CAST` statements.
// NOTE(review): field order feeds the derived `PartialOrd`/`Ord`/`Hash`; do not reorder casually.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PgExternEntity<'a> {
    /// SQL-visible function name; emitted double-quoted in `CREATE FUNCTION`.
    pub name: &'a str,
    /// Name used to build the C symbol in `AS '<module>', '<unaliased_name>_wrapper'`.
    pub unaliased_name: &'a str,
    /// Rust module path, emitted in the `-- <module_path>::<name>` header comment.
    pub module_path: &'a str,
    /// Fully qualified Rust path; used as this entity's Rust identifier.
    pub full_path: &'a str,
    /// The function's arguments, in declaration order.
    pub fn_args: Vec<PgExternArgumentEntity<'a>>,
    /// The function's return shape (void, plain type, SETOF, TABLE, or trigger).
    pub fn_return: PgExternReturnEntity<'a>,
    /// Explicit schema; when set it overrides the schema prefix derived from the graph.
    pub schema: Option<&'a str>,
    /// Source file the function was declared in (emitted in a header comment).
    pub file: &'a str,
    /// Source line the function was declared on (emitted in a header comment).
    pub line: u32,
    /// Attributes such as `Strict`, `CreateOrReplace`, `Support(..)` and `Requires(..)`.
    pub extern_attrs: Vec<ExternArgs>,
    /// When set, emitted as `SET search_path TO ...` on the function.
    pub search_path: Option<Vec<&'a str>>,
    /// When set, a `CREATE OPERATOR` backed by this function is also emitted.
    pub operator: Option<PgOperatorEntity<'a>>,
    /// When set, a `CREATE CAST` backed by this function is also emitted.
    pub cast: Option<PgCastEntity>,
    // NOTE(review): not consulted by the code visible in this file; presumably controls
    // custom/overridden SQL generation elsewhere — confirm.
    pub to_sql_config: ToSqlConfigEntity<'a>,
}
55
56impl<'a> From<PgExternEntity<'a>> for SqlGraphEntity<'a> {
57 fn from(val: PgExternEntity<'a>) -> Self {
58 SqlGraphEntity::Function(val)
59 }
60}
61
62impl SqlGraphIdentifier for PgExternEntity<'_> {
63 fn dot_identifier(&self) -> String {
64 format!("fn {}", self.name)
65 }
66 fn rust_identifier(&self) -> String {
67 self.full_path.to_string()
68 }
69
70 fn file(&self) -> Option<&str> {
71 Some(self.file)
72 }
73
74 fn line(&self) -> Option<u32> {
75 Some(self.line)
76 }
77}
78
79impl PgExternEntity<'_> {
80 fn sql_name(&self, context: &PgrxSql) -> String {
81 let self_index = context.externs[self];
82 let schema = self
83 .schema
84 .map(|schema| format!("{schema}."))
85 .unwrap_or_else(|| context.schema_prefix_for(&self_index));
86
87 format!("{schema}\"{}\"", self.name)
88 }
89}
90
91fn composite_sql_type(composite_type: Option<&str>) -> eyre::Result<String> {
92 composite_type
93 .map(ToString::to_string)
94 .ok_or_else(|| eyre!("Composite mapping requires composite_type"))
95}
96
97fn array_sql_type(mapping: &SqlArrayMapping, composite_type: Option<&str>) -> eyre::Result<String> {
98 Ok(match mapping {
99 SqlArrayMapping::As(sql) => fmt::with_array_brackets(sql.clone(), 1),
100 SqlArrayMapping::Composite => {
101 fmt::with_array_brackets(composite_sql_type(composite_type)?, 1)
102 }
103 })
104}
105
106fn sql_type(mapping: &SqlMapping, composite_type: Option<&str>) -> eyre::Result<String> {
107 match mapping {
108 SqlMapping::As(sql) => Ok(sql.clone()),
109 SqlMapping::Composite => composite_sql_type(composite_type),
110 SqlMapping::Array(value) => array_sql_type(value, composite_type),
111 SqlMapping::Skip => Err(eyre!("Found a skipped SQL type where SQL should be emitted")),
112 }
113}
114
115impl ToSql for PgExternEntity<'_> {
116 fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
117 let self_index = context.externs[self];
118 let mut extern_attrs = self.extern_attrs.clone();
119 let mut strict_upgrade = !extern_attrs.iter().any(|i| i == &ExternArgs::Strict);
122 if strict_upgrade {
123 for arg in &self.fn_args {
126 if arg.used_ty.optional {
127 strict_upgrade = false;
128 }
129 }
130 }
131
132 if strict_upgrade {
133 extern_attrs.push(ExternArgs::Strict);
134 }
135 extern_attrs.sort();
136 extern_attrs.dedup();
137
138 let module_pathname = &context.get_module_pathname();
139 let schema = self
140 .schema
141 .map(|schema| format!("{schema}."))
142 .unwrap_or_else(|| context.schema_prefix_for(&self_index));
143 let arguments = if !self.fn_args.is_empty() {
144 let mut args = Vec::new();
145 let sql_args = self
146 .fn_args
147 .iter()
148 .filter(|arg| arg.used_ty.emits_argument_sql())
149 .collect::<Vec<_>>();
150 for (idx, arg) in sql_args.iter().enumerate() {
151 let needs_comma = idx < (sql_args.len().saturating_sub(1));
152 let schema_prefix = context.schema_prefix_for_used_type(
153 &self_index,
154 &format!("argument `{}`", arg.pattern),
155 &arg.used_ty,
156 )?;
157 match arg.used_ty.metadata.argument_sql {
158 Ok(SqlMapping::As(ref argument_sql)) => {
159 let buf = format!(
160 "\
161 \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
162 ",
163 pattern = arg.pattern,
164 schema_prefix = schema_prefix,
165 sql_type = argument_sql,
167 default = if let Some(def) = arg.used_ty.default {
168 format!(" DEFAULT {def}")
169 } else {
170 String::from("")
171 },
172 variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
173 maybe_comma = if needs_comma { ", " } else { " " },
174 type_name = arg.used_ty.full_path,
175 );
176 args.push(buf);
177 }
178 Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
179 let sql = sql_type(mapping, arg.used_ty.composite_type)?;
180 let buf = format!(
181 "\
182 \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
183 ",
184 pattern = arg.pattern,
185 schema_prefix = schema_prefix,
186 sql_type = sql,
188 default = if let Some(def) = arg.used_ty.default {
189 format!(" DEFAULT {def}")
190 } else {
191 String::from("")
192 },
193 variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
194 maybe_comma = if needs_comma { ", " } else { " " },
195 type_name = arg.used_ty.full_path,
196 );
197 args.push(buf);
198 }
199 Ok(SqlMapping::Skip) => (),
200 Err(err) => return Err(err).wrap_err("While mapping argument"),
201 }
202 }
203 String::from("\n") + &args.join("\n") + "\n"
204 } else {
205 Default::default()
206 };
207
208 let returns = match &self.fn_return {
209 PgExternReturnEntity::None => String::from("RETURNS void"),
210 PgExternReturnEntity::Type { ty } => {
211 let (schema_prefix, sql_type) = match &ty.metadata.return_sql {
212 Ok(Returns::One(SqlMapping::As(sql))) => (
213 context.schema_prefix_for_used_type(&self_index, "return type", ty)?,
214 sql.clone(),
215 ),
216 Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => (
217 context.schema_prefix_for_used_type(&self_index, "return type", ty)?,
218 sql_type(mapping, ty.composite_type)?,
219 ),
220 Ok(other) => {
221 return Err(eyre!(
222 "Got non-plain mapped/composite return variant SQL in what macro-expansion thought was a type, got: {other:?}"
223 ));
224 }
225 Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
226 };
227 format!(
228 "RETURNS {schema_prefix}{sql_type} /* {full_path} */",
229 full_path = ty.full_path
230 )
231 }
232 PgExternReturnEntity::SetOf { ty, .. } => {
233 let (schema_prefix, sql_type) = match &ty.metadata.return_sql {
234 Ok(Returns::One(SqlMapping::As(sql)))
235 | Ok(Returns::SetOf(SqlMapping::As(sql))) => (
236 context.schema_prefix_for_used_type(
237 &self_index,
238 "setof return type",
239 ty,
240 )?,
241 sql.clone(),
242 ),
243 Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_))))
244 | Ok(Returns::SetOf(
245 mapping @ (SqlMapping::Composite | SqlMapping::Array(_)),
246 )) => (
247 context.schema_prefix_for_used_type(
248 &self_index,
249 "setof return type",
250 ty,
251 )?,
252 sql_type(mapping, ty.composite_type)?,
253 ),
254 Ok(other) => {
255 return Err(eyre!(
256 "Got non-scalar mapped/composite return variant SQL in what macro-expansion thought was a setof item, got: {other:?}"
257 ));
258 }
259 Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
260 };
261 format!(
262 "RETURNS SETOF {schema_prefix}{sql_type} /* {full_path} */",
263 full_path = ty.full_path
264 )
265 }
266 PgExternReturnEntity::Iterated { tys: table_items, .. } => {
267 let mut items = String::new();
268 for (idx, PgExternReturnEntityIteratedItem { ty, name: col_name }) in
269 table_items.iter().enumerate()
270 {
271 let needs_comma = idx < (table_items.len() - 1);
272 let (schema_prefix, ty_resolved) = match &ty.metadata.return_sql {
273 Ok(Returns::One(SqlMapping::As(sql))) => (
274 context.schema_prefix_for_used_type(
275 &self_index,
276 "table return column",
277 ty,
278 )?,
279 sql.clone(),
280 ),
281 Ok(Returns::One(
282 mapping @ (SqlMapping::Composite | SqlMapping::Array(_)),
283 )) => (
284 context.schema_prefix_for_used_type(
285 &self_index,
286 "table return column",
287 ty,
288 )?,
289 sql_type(mapping, ty.composite_type)?,
290 ),
291 Ok(other) => {
292 return Err(eyre!(
293 "Got non-scalar table return item SQL in what macro-expansion thought was a table, got: {other:?}"
294 ));
295 }
296 Err(err) => return Err(*err).wrap_err("Error mapping return SQL"),
297 };
298 let item = format!(
299 "\n\t{col_name} {schema_prefix}{ty_resolved}{needs_comma} /* {ty_name} */",
300 col_name = col_name.expect(
301 "An iterator of tuples should have `named!()` macro declarations."
302 ),
303 schema_prefix = schema_prefix,
304 ty_resolved = ty_resolved,
305 needs_comma = if needs_comma { ", " } else { " " },
306 ty_name = ty.full_path
307 );
308 items.push_str(&item);
309 }
310 format!("RETURNS TABLE ({items}\n)")
311 }
312 PgExternReturnEntity::Trigger => String::from("RETURNS trigger"),
313 };
314 let PgExternEntity { name, module_path, file, line, .. } = self;
315
316 let fn_sql = format!(
317 "\
318 CREATE {or_replace} FUNCTION {schema}\"{name}\"({arguments}) {returns}\n\
319 {extern_attrs}\
320 {search_path}\
321 LANGUAGE c /* Rust */\n\
322 AS '{module_pathname}', '{unaliased_name}_wrapper';\
323 ",
324 or_replace =
325 if extern_attrs.contains(&ExternArgs::CreateOrReplace) { "OR REPLACE" } else { "" },
326 search_path = if let Some(search_path) = &self.search_path {
327 let retval = format!("SET search_path TO {}", search_path.join(", "));
328 retval + "\n"
329 } else {
330 Default::default()
331 },
332 extern_attrs = if extern_attrs.is_empty() {
333 String::default()
334 } else {
335 let mut retval = extern_attrs
336 .iter()
337 .filter(|attr| **attr != ExternArgs::CreateOrReplace)
338 .map(|attr| {
339 if matches!(attr, ExternArgs::Support(..)) {
340 let support_fn_name = attr.to_string();
341
342 let support_fn_name =
343 if let Some(entity) = context.find_matching_fn(&support_fn_name) {
344 entity.sql_name(context)
345 } else {
346 panic!("cannot locate SUPPORT function `{support_fn_name}` attached to function `{}`", self.full_path)
347 };
348
349 format!("SUPPORT {support_fn_name}")
350 } else {
351 attr.to_string().to_uppercase()
352 }
353 })
354 .collect::<Vec<_>>()
355 .join(" ");
356 retval.push('\n');
357 retval
358 },
359 unaliased_name = self.unaliased_name,
360 );
361
362 let requires = {
363 let requires_attrs = self
364 .extern_attrs
365 .iter()
366 .filter_map(|x| match x {
367 ExternArgs::Requires(requirements) => Some(requirements.clone()),
368 ExternArgs::Support(support_fn) => Some(vec![support_fn.clone()]),
369 _ => None,
370 })
371 .flatten()
372 .collect::<Vec<_>>();
373
374 if !requires_attrs.is_empty() {
375 format!(
376 "-- requires:\n{}\n",
377 requires_attrs
378 .iter()
379 .map(|i| format!("-- {i}"))
380 .collect::<Vec<_>>()
381 .join("\n")
382 )
383 } else {
384 "".to_string()
385 }
386 };
387
388 let mut ext_sql = format!(
389 "\n\
390 -- {file}:{line}\n\
391 -- {module_path}::{name}\n\
392 {requires}\
393 {fn_sql}"
394 );
395
396 if let Some(op) = &self.operator {
397 let mut optionals = vec![];
398 if let Some(it) = op.commutator {
399 optionals.push(format!("\tCOMMUTATOR = {it}"));
400 };
401 if let Some(it) = op.negator {
402 optionals.push(format!("\tNEGATOR = {it}"));
403 };
404 if let Some(it) = op.restrict {
405 optionals.push(format!("\tRESTRICT = {it}"));
406 };
407 if let Some(it) = op.join {
408 optionals.push(format!("\tJOIN = {it}"));
409 };
410 if op.hashes {
411 optionals.push(String::from("\tHASHES"));
412 };
413 if op.merges {
414 optionals.push(String::from("\tMERGES"));
415 };
416
417 let left_arg = self
418 .fn_args
419 .first()
420 .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
421 let left_arg_schema_prefix = context.schema_prefix_for_used_type(
422 &self_index,
423 "operator left argument",
424 &left_arg.used_ty,
425 )?;
426 let left_arg_sql = match left_arg.used_ty.metadata.argument_sql {
427 Ok(SqlMapping::As(ref sql)) => sql.clone(),
428 Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
429 sql_type(mapping, left_arg.used_ty.composite_type)?
430 }
431 Ok(SqlMapping::Skip) => {
432 return Err(eyre!(
433 "Found an skipped SQL type in an operator, this is not valid"
434 ));
435 }
436 Err(err) => return Err(err.into()),
437 };
438
439 let right_arg = self
440 .fn_args
441 .get(1)
442 .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
443 let right_arg_schema_prefix = context.schema_prefix_for_used_type(
444 &self_index,
445 "operator right argument",
446 &right_arg.used_ty,
447 )?;
448 let right_arg_sql = match right_arg.used_ty.metadata.argument_sql {
449 Ok(SqlMapping::As(ref sql)) => sql.clone(),
450 Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
451 sql_type(mapping, right_arg.used_ty.composite_type)?
452 }
453 Ok(SqlMapping::Skip) => {
454 return Err(eyre!(
455 "Found an skipped SQL type in an operator, this is not valid"
456 ));
457 }
458 Err(err) => return Err(err.into()),
459 };
460
461 let schema = self
462 .schema
463 .map(|schema| format!("{schema}."))
464 .unwrap_or_else(|| context.schema_prefix_for(&self_index));
465
466 let operator_sql = format!(
467 "\n\n\
468 -- {file}:{line}\n\
469 -- {module_path}::{name}\n\
470 CREATE OPERATOR {schema}{opname} (\n\
471 \tPROCEDURE={schema}\"{name}\",\n\
472 \tLEFTARG={schema_prefix_left}{left_arg_sql}, /* {left_name} */\n\
473 \tRIGHTARG={schema_prefix_right}{right_arg_sql}{maybe_comma} /* {right_name} */\n\
474 {optionals}\
475 );\
476 ",
477 opname = op.opname.unwrap(),
478 left_name = left_arg.used_ty.full_path,
479 right_name = right_arg.used_ty.full_path,
480 schema_prefix_left = left_arg_schema_prefix,
481 schema_prefix_right = right_arg_schema_prefix,
482 maybe_comma = if !optionals.is_empty() { "," } else { "" },
483 optionals = if !optionals.is_empty() {
484 optionals.join(",\n") + "\n"
485 } else {
486 "".to_string()
487 },
488 );
489 ext_sql += &operator_sql
490 };
491 if let Some(cast) = &self.cast {
492 let target_fn_arg = &self.fn_return;
493 let target_ty = match target_fn_arg {
494 PgExternReturnEntity::Type { ty } => ty,
495 other => {
496 return Err(eyre!("Casts must return a plain type, got: {other:?}"));
497 }
498 };
499 let target_arg_schema_prefix =
500 context.schema_prefix_for_used_type(&self_index, "cast target type", target_ty)?;
501 let target_arg_sql = match &target_ty.metadata.return_sql {
502 Ok(Returns::One(SqlMapping::As(sql))) => sql.clone(),
503 Ok(Returns::One(mapping @ (SqlMapping::Composite | SqlMapping::Array(_)))) => {
504 sql_type(mapping, target_ty.composite_type)?
505 }
506 Ok(Returns::One(SqlMapping::Skip)) => {
507 return Err(eyre!("Found an skipped SQL type in a cast, this is not valid"));
508 }
509 Err(err) => return Err((*err).into()),
510 Ok(other) => {
511 return Err(eyre!("Casts must return a plain SQL type, got: {other:?}"));
512 }
513 };
514 let source_arg = self
515 .fn_args
516 .first()
517 .ok_or_else(|| eyre!("Did not find source type for cast `{}`.", self.name))?;
518 let source_arg_schema_prefix = context.schema_prefix_for_used_type(
519 &self_index,
520 "cast source type",
521 &source_arg.used_ty,
522 )?;
523 let source_arg_sql = match source_arg.used_ty.metadata.argument_sql {
524 Ok(SqlMapping::As(ref sql)) => sql.clone(),
525 Ok(ref mapping @ (SqlMapping::Composite | SqlMapping::Array(_))) => {
526 sql_type(mapping, source_arg.used_ty.composite_type)?
527 }
528 Ok(SqlMapping::Skip) => {
529 return Err(eyre!("Found an skipped SQL type in a cast, this is not valid"));
530 }
531 Err(err) => return Err(err.into()),
532 };
533 let optional = match cast {
534 PgCastEntity::Default => String::from(""),
535 PgCastEntity::Assignment => String::from(" AS ASSIGNMENT"),
536 PgCastEntity::Implicit => String::from(" AS IMPLICIT"),
537 };
538
539 let cast_sql = format!(
540 "\n\n\
541 -- {file}:{line}\n\
542 -- {module_path}::{name}\n\
543 CREATE CAST (\n\
544 \t{schema_prefix_source}{source_arg_sql} /* {source_name} */\n\
545 \tAS\n\
546 \t{schema_prefix_target}{target_arg_sql} /* {target_name} */\n\
547 )\n\
548 WITH FUNCTION {function_name}{optional};\
549 ",
550 file = self.file,
551 line = self.line,
552 name = self.name,
553 module_path = self.module_path,
554 schema_prefix_source = source_arg_schema_prefix,
555 source_name = source_arg.used_ty.full_path,
556 schema_prefix_target = target_arg_schema_prefix,
557 target_name = target_ty.full_path,
558 function_name = self.name,
559 );
560 ext_sql += &cast_sql
561 };
562 Ok(ext_sql)
563 }
564}