1mod argument;
18mod operator;
19mod returning;
20
21pub use argument::PgExternArgumentEntity;
22pub use operator::PgOperatorEntity;
23pub use returning::{PgExternReturnEntity, PgExternReturnEntityIteratedItem};
24
25use crate::sql_entity_graph::metadata::{Returns, SqlMapping};
26use crate::sql_entity_graph::pgx_sql::PgxSql;
27use crate::sql_entity_graph::to_sql::entity::ToSqlConfigEntity;
28use crate::sql_entity_graph::to_sql::ToSql;
29use crate::sql_entity_graph::{SqlGraphEntity, SqlGraphIdentifier};
30use crate::ExternArgs;
31
32use eyre::{eyre, WrapErr};
33use std::cmp::Ordering;
34
35#[derive(Debug, Clone)]
37pub struct PgExternEntity {
38 pub name: &'static str,
39 pub unaliased_name: &'static str,
40 pub module_path: &'static str,
41 pub full_path: &'static str,
42 pub metadata: crate::sql_entity_graph::metadata::FunctionMetadataEntity,
43 pub fn_args: Vec<PgExternArgumentEntity>,
44 pub fn_return: PgExternReturnEntity,
45 pub schema: Option<&'static str>,
46 pub file: &'static str,
47 pub line: u32,
48 pub extern_attrs: Vec<ExternArgs>,
49 pub search_path: Option<Vec<&'static str>>,
50 pub operator: Option<PgOperatorEntity>,
51 pub to_sql_config: ToSqlConfigEntity,
52}
53
54impl std::hash::Hash for PgExternEntity {
55 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
56 self.full_path.hash(state);
57 }
58}
59
60impl PartialEq for PgExternEntity {
61 fn eq(&self, other: &Self) -> bool {
62 self.full_path.eq(other.full_path)
63 }
64}
65
66impl Eq for PgExternEntity {}
67
68impl Ord for PgExternEntity {
69 fn cmp(&self, other: &Self) -> Ordering {
70 self.full_path.cmp(&other.full_path)
71 }
72}
73
74impl PartialOrd for PgExternEntity {
75 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
76 Some(self.cmp(other))
77 }
78}
79
80impl From<PgExternEntity> for SqlGraphEntity {
81 fn from(val: PgExternEntity) -> Self {
82 SqlGraphEntity::Function(val)
83 }
84}
85
86impl SqlGraphIdentifier for PgExternEntity {
87 fn dot_identifier(&self) -> String {
88 format!("fn {}", self.name)
89 }
90 fn rust_identifier(&self) -> String {
91 self.full_path.to_string()
92 }
93
94 fn file(&self) -> Option<&'static str> {
95 Some(self.file)
96 }
97
98 fn line(&self) -> Option<u32> {
99 Some(self.line)
100 }
101}
102
103impl ToSql for PgExternEntity {
104 #[tracing::instrument(
105 level = "error",
106 skip(self, context),
107 fields(identifier = %self.rust_identifier()),
108 )]
109 fn to_sql(&self, context: &PgxSql) -> eyre::Result<String> {
110 let self_index = context.externs[self];
111 let mut extern_attrs = self.extern_attrs.clone();
112 let mut strict_upgrade = !extern_attrs.iter().any(|i| i == &ExternArgs::Strict);
115 if strict_upgrade {
116 for arg in &self.metadata.arguments {
119 if arg.optional {
120 strict_upgrade = false;
121 }
122 }
123 }
124
125 if strict_upgrade {
126 extern_attrs.push(ExternArgs::Strict);
127 }
128 extern_attrs.sort();
129 extern_attrs.dedup();
130
131 let module_pathname = &context.get_module_pathname();
132
133 let fn_sql = format!(
134 "\
135 CREATE {or_replace} FUNCTION {schema}\"{name}\"({arguments}) {returns}\n\
136 {extern_attrs}\
137 {search_path}\
138 LANGUAGE c /* Rust */\n\
139 AS '{module_pathname}', '{unaliased_name}_wrapper';\
140 ",
141 or_replace = if extern_attrs.contains(&ExternArgs::CreateOrReplace) { "OR REPLACE" } else { "" },
142 schema = self
143 .schema
144 .map(|schema| format!("{}.", schema))
145 .unwrap_or_else(|| context.schema_prefix_for(&self_index)),
146 name = self.name,
147 module_pathname = module_pathname,
148 arguments = if !self.fn_args.is_empty() {
149 let mut args = Vec::new();
150 let metadata_without_arg_skips = &self
151 .metadata
152 .arguments
153 .iter()
154 .filter(|v| v.argument_sql != Ok(SqlMapping::Skip))
155 .collect::<Vec<_>>();
156 for (idx, arg) in self.fn_args.iter().enumerate() {
157 let graph_index = context
158 .graph
159 .neighbors_undirected(self_index)
160 .find(|neighbor| match &context.graph[*neighbor] {
161 SqlGraphEntity::Type(ty) => ty.id_matches(&arg.used_ty.ty_id),
162 SqlGraphEntity::Enum(en) => en.id_matches(&arg.used_ty.ty_id),
163 SqlGraphEntity::BuiltinType(defined) => {
164 defined == arg.used_ty.full_path
165 }
166 _ => false,
167 })
168 .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
169 let needs_comma = idx < (metadata_without_arg_skips.len().saturating_sub(1));
170 let metadata_argument = &self.metadata.arguments[idx];
171 match metadata_argument.argument_sql {
172 Ok(SqlMapping::As(ref argument_sql)) => {
173 let buf = format!("\
174 \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
175 ",
176 pattern = arg.pattern,
177 schema_prefix = context.schema_prefix_for(&graph_index),
178 sql_type = argument_sql,
180 default = if let Some(def) = arg.used_ty.default { format!(" DEFAULT {}", def) } else { String::from("") },
181 variadic = if metadata_argument.variadic { "VARIADIC " } else { "" },
182 maybe_comma = if needs_comma { ", " } else { " " },
183 type_name = metadata_argument.type_name,
184 );
185 args.push(buf);
186 }
187 Ok(SqlMapping::Composite {
188 array_brackets,
189 }) => {
190 let sql = self.fn_args[idx]
191 .used_ty
192 .composite_type
193 .map(|v| {
194 if array_brackets {
195 format!("{v}[]")
196 } else {
197 format!("{v}")
198 }
199 })
200 .ok_or_else(|| {
201 eyre!(
202 "Macro expansion time suggested a composite_type!() in return"
203 )
204 })?;
205 let buf = format!("\
206 \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
207 ",
208 pattern = arg.pattern,
209 schema_prefix = context.schema_prefix_for(&graph_index),
210 sql_type = sql,
212 default = if let Some(def) = arg.used_ty.default { format!(" DEFAULT {}", def) } else { String::from("") },
213 variadic = if metadata_argument.variadic { "VARIADIC " } else { "" },
214 maybe_comma = if needs_comma { ", " } else { " " },
215 type_name = metadata_argument.type_name,
216 );
217 args.push(buf);
218 }
219 Ok(SqlMapping::Source {
220 array_brackets,
221 }) => {
222 let sql = context
223 .source_only_to_sql_type(arg.used_ty.ty_source)
224 .map(|v| {
225 if array_brackets {
226 format!("{v}[]")
227 } else {
228 format!("{v}")
229 }
230 })
231 .ok_or_else(|| {
232 eyre!(
233 "Macro expansion time suggested a source only mapping in return"
234 )
235 })?;
236 let buf = format!("\
237 \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
238 ",
239 pattern = arg.pattern,
240 schema_prefix = context.schema_prefix_for(&graph_index),
241 sql_type = sql,
243 default = if let Some(def) = arg.used_ty.default { format!(" DEFAULT {}", def) } else { String::from("") },
244 variadic = if metadata_argument.variadic { "VARIADIC " } else { "" },
245 maybe_comma = if needs_comma { ", " } else { " " },
246 type_name = metadata_argument.type_name,
247 );
248 args.push(buf);
249 }
250 Ok(SqlMapping::Skip) => (),
251 Err(err) => {
252 match context.source_only_to_sql_type(arg.used_ty.ty_source) {
253 Some(source_only_mapping) => {
254 let buf = format!("\
255 \t\"{pattern}\" {variadic}{schema_prefix}{sql_type}{default}{maybe_comma}/* {type_name} */\
256 ",
257 pattern = arg.pattern,
258 schema_prefix = context.schema_prefix_for(&graph_index),
259 sql_type = source_only_mapping,
261 default = if let Some(def) = arg.used_ty.default { format!(" DEFAULT {}", def) } else { String::from("") },
262 variadic = if metadata_argument.variadic { "VARIADIC " } else { "" },
263 maybe_comma = if needs_comma { ", " } else { " " },
264 type_name = metadata_argument.type_name,
265 );
266 args.push(buf);
267 }
268 None => return Err(err).wrap_err("While mapping argument"),
269 }
270 }
271 }
272 }
273 String::from("\n") + &args.join("\n") + "\n"
274 } else {
275 Default::default()
276 },
277 returns = match &self.fn_return {
278 PgExternReturnEntity::None => String::from("RETURNS void"),
279 PgExternReturnEntity::Type { ty } => {
280 let graph_index = context
281 .graph
282 .neighbors_undirected(self_index)
283 .find(|neighbor| match &context.graph[*neighbor] {
284 SqlGraphEntity::Type(neighbor_ty) => neighbor_ty.id_matches(&ty.ty_id),
285 SqlGraphEntity::Enum(neighbor_en) => neighbor_en.id_matches(&ty.ty_id),
286 SqlGraphEntity::BuiltinType(defined) => &*defined == ty.full_path,
287 _ => false,
288 })
289 .ok_or_else(|| eyre!("Could not find return type in graph."))?;
290 let metadata_retval = self.metadata.retval.clone().ok_or_else(|| eyre!("Macro expansion time and SQL resolution time had differing opinions about the return value existing"))?;
291 let metadata_retval_sql = match metadata_retval.return_sql {
292 Ok(Returns::One(SqlMapping::As(ref sql))) => sql.clone(),
293 Ok(Returns::One(SqlMapping::Composite { array_brackets })) => ty.composite_type.unwrap().to_string()
294 + if array_brackets {
295 "[]"
296 } else {
297 ""
298 },
299 Ok(Returns::SetOf(SqlMapping::Source { array_brackets })) =>
300 context.source_only_to_sql_type(ty.ty_source).unwrap().to_string() + if array_brackets {
301 "[]"
302 } else {
303 ""
304 },
305 Ok(other) => return Err(eyre!("Got non-plain mapped/composite return variant SQL in what macro-expansion thought was a type, got: {other:?}")),
306 Err(err) => {
307 match context.source_only_to_sql_type(ty.ty_source) {
308 Some(source_only_mapping) => source_only_mapping,
309 None => return Err(err).wrap_err("Error mapping return SQL")
310 }
311 },
312 };
313 format!(
314 "RETURNS {schema_prefix}{sql_type} /* {full_path} */",
315 sql_type = metadata_retval_sql,
316 schema_prefix = context.schema_prefix_for(&graph_index),
317 full_path = ty.full_path
318 )
319 }
320 PgExternReturnEntity::SetOf { ty, optional: _ } => {
321 let graph_index = context
322 .graph
323 .neighbors_undirected(self_index)
324 .find(|neighbor| match &context.graph[*neighbor] {
325 SqlGraphEntity::Type(neighbor_ty) => neighbor_ty.id_matches(&ty.ty_id),
326 SqlGraphEntity::Enum(neighbor_en) => neighbor_en.id_matches(&ty.ty_id),
327 SqlGraphEntity::BuiltinType(defined) => defined == ty.full_path,
328 _ => false,
329 })
330 .ok_or_else(|| eyre!("Could not find return type in graph."))?;
331 let metadata_retval = self.metadata.retval.clone().ok_or_else(|| eyre!("Macro expansion time and SQL resolution time had differing opinions about the return value existing"))?;
332 let metadata_retval_sql = match metadata_retval.return_sql {
333 Ok(Returns::SetOf(SqlMapping::As(ref sql))) => sql.clone(),
334 Ok(Returns::SetOf(SqlMapping::Composite { array_brackets })) =>
335 ty.composite_type.unwrap().to_string() + if array_brackets {
336 "[]"
337 } else {
338 ""
339 },
340 Ok(Returns::SetOf(SqlMapping::Source { array_brackets })) =>
341 context.source_only_to_sql_type(ty.ty_source).unwrap().to_string() + if array_brackets {
342 "[]"
343 } else {
344 ""
345 },
346 Ok(_other) => return Err(eyre!("Got non-setof mapped/composite return variant SQL in what macro-expansion thought was a setof")),
347 Err(err) => return Err(err).wrap_err("Error mapping return SQL"),
348 };
349 format!(
350 "RETURNS SETOF {schema_prefix}{sql_type} /* {full_path} */",
351 sql_type = metadata_retval_sql,
352 schema_prefix = context.schema_prefix_for(&graph_index),
353 full_path = ty.full_path
354 )
355 }
356 PgExternReturnEntity::Iterated {
357 tys: table_items,
358 optional: _,
359 } => {
360 let mut items = String::new();
361 let metadata_retval = self.metadata.retval.clone().ok_or_else(|| eyre!("Macro expansion time and SQL resolution time had differing opinions about the return value existing"))?;
362 let metadata_retval_sqls = match metadata_retval.return_sql {
363 Ok(Returns::Table(variants)) => {
364 let mut retval_sqls = vec![];
365 for (idx, variant) in variants.iter().enumerate() {
366 let sql = match variant {
367 SqlMapping::As(sql) => sql.clone(),
368 SqlMapping::Composite { array_brackets } => {
369 let composite = table_items[idx].ty.composite_type.unwrap().to_string();
370 composite + if *array_brackets {
371 "[]"
372 } else {
373 ""
374 }
375 },
376 SqlMapping::Source { array_brackets } =>
377 context.source_only_to_sql_type(table_items[idx].ty.ty_source).unwrap() + if *array_brackets {
378 "[]"
379 } else {
380 ""
381 },
382 SqlMapping::Skip => todo!(),
383 };
384 retval_sqls.push(sql)
385 }
386 retval_sqls
387 },
388 Ok(_other) => return Err(eyre!("Got non-table return variant SQL in what macro-expansion thought was a table")),
389 Err(err) => return Err(err).wrap_err("Error mapping return SQL"),
390 };
391
392 for (idx, returning::PgExternReturnEntityIteratedItem { ty, name: col_name }) in
393 table_items.iter().enumerate()
394 {
395 let graph_index =
396 context
397 .graph
398 .neighbors_undirected(self_index)
399 .find(|neighbor| match &context.graph[*neighbor] {
400 SqlGraphEntity::Type(neightbor_ty) => {
401 neightbor_ty.id_matches(&ty.ty_id)
402 }
403 SqlGraphEntity::Enum(neightbor_en) => {
404 neightbor_en.id_matches(&ty.ty_id)
405 }
406 SqlGraphEntity::BuiltinType(defined) => defined == ty.ty_source,
407 _ => false,
408 });
409
410 let needs_comma = idx < (table_items.len() - 1);
411 let item = format!(
412 "\n\t{col_name} {schema_prefix}{ty_resolved}{needs_comma} /* {ty_name} */",
413 col_name = col_name.expect("An iterator of tuples should have `named!()` macro declarations."),
414 schema_prefix = if let Some(graph_index) = graph_index {
415 context.schema_prefix_for(&graph_index)
416 } else { "".into() },
417 ty_resolved = metadata_retval_sqls[idx],
418 needs_comma = if needs_comma { ", " } else { " " },
419 ty_name = ty.full_path
420 );
421 items.push_str(&item);
422 }
423 format!("RETURNS TABLE ({}\n)", items)
424 }
425 PgExternReturnEntity::Trigger => String::from("RETURNS trigger"),
426 },
427 search_path = if let Some(search_path) = &self.search_path {
428 let retval = format!("SET search_path TO {}", search_path.join(", "));
429 retval + "\n"
430 } else {
431 Default::default()
432 },
433 extern_attrs = if extern_attrs.is_empty() {
434 String::default()
435 } else {
436 let mut retval = extern_attrs
437 .iter()
438 .filter(|attr| **attr != ExternArgs::CreateOrReplace)
439 .map(|attr| format!("{}", attr).to_uppercase())
440 .collect::<Vec<_>>()
441 .join(" ");
442 retval.push('\n');
443 retval
444 },
445 unaliased_name = self.unaliased_name,
446 );
447
448 let ext_sql = format!(
449 "\n\
450 -- {file}:{line}\n\
451 -- {module_path}::{name}\n\
452 {requires}\
453 {fn_sql}\
454 ",
455 name = self.name,
456 module_path = self.module_path,
457 file = self.file,
458 line = self.line,
459 fn_sql = fn_sql,
460 requires = {
461 let requires_attrs = self
462 .extern_attrs
463 .iter()
464 .filter_map(|x| match x {
465 ExternArgs::Requires(requirements) => Some(requirements),
466 _ => None,
467 })
468 .flatten()
469 .collect::<Vec<_>>();
470 if !requires_attrs.is_empty() {
471 format!(
472 "\
473 -- requires:\n\
474 {}\n\
475 ",
476 requires_attrs
477 .iter()
478 .map(|i| format!("-- {}", i))
479 .collect::<Vec<_>>()
480 .join("\n")
481 )
482 } else {
483 "".to_string()
484 }
485 },
486 );
487 tracing::trace!(sql = %ext_sql);
488
489 let rendered = if let Some(op) = &self.operator {
490 let mut optionals = vec![];
491 if let Some(it) = op.commutator {
492 optionals.push(format!("\tCOMMUTATOR = {}", it));
493 };
494 if let Some(it) = op.negator {
495 optionals.push(format!("\tNEGATOR = {}", it));
496 };
497 if let Some(it) = op.restrict {
498 optionals.push(format!("\tRESTRICT = {}", it));
499 };
500 if let Some(it) = op.join {
501 optionals.push(format!("\tJOIN = {}", it));
502 };
503 if op.hashes {
504 optionals.push(String::from("\tHASHES"));
505 };
506 if op.merges {
507 optionals.push(String::from("\tMERGES"));
508 };
509
510 let left_arg =
511 self.metadata.arguments.get(0).ok_or_else(|| {
512 eyre!("Did not find `left_arg` for operator `{}`.", self.name)
513 })?;
514 let left_fn_arg = self
515 .fn_args
516 .get(0)
517 .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
518 let left_arg_graph_index = context
519 .graph
520 .neighbors_undirected(self_index)
521 .find(|neighbor| match &context.graph[*neighbor] {
522 SqlGraphEntity::Type(ty) => ty.id_matches(&left_fn_arg.used_ty.ty_id),
523 SqlGraphEntity::Enum(en) => en.id_matches(&left_fn_arg.used_ty.ty_id),
524 SqlGraphEntity::BuiltinType(defined) => defined == &left_arg.type_name,
525 _ => false,
526 })
527 .ok_or_else(|| {
528 eyre!("Could not find left arg type in graph. Got: {:?}", left_arg)
529 })?;
530 let left_arg_sql = match left_arg.argument_sql {
531 Ok(SqlMapping::As(ref sql)) => sql.clone(),
532 Ok(SqlMapping::Composite { array_brackets }) => {
533 if array_brackets {
534 let composite_type = self.fn_args[0].used_ty.composite_type
535 .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgx::composite_type!()`"))?;
536 format!("{composite_type}[]")
537 } else {
538 self.fn_args[0].used_ty.composite_type
539 .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgx::composite_type!()`"))?.to_string()
540 }
541 }
542 Ok(SqlMapping::Source { array_brackets }) => {
543 if array_brackets {
544 let composite_type = context
545 .source_only_to_sql_type(self.fn_args[0].used_ty.ty_source)
546 .ok_or(eyre!(
547 "Found a source only mapping but no source mapping exists for this"
548 ))?;
549 format!("{composite_type}[]")
550 } else {
551 context.source_only_to_sql_type(self.fn_args[0].used_ty.ty_source)
552 .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgx::composite_type!()`"))?.to_string()
553 }
554 }
555 Ok(SqlMapping::Skip) => {
556 return Err(eyre!(
557 "Found an skipped SQL type in an operator, this is not valid"
558 ))
559 }
560 Err(err) => return Err(err.into()),
561 };
562
563 let right_arg =
564 self.metadata.arguments.get(1).ok_or_else(|| {
565 eyre!("Did not find `left_arg` for operator `{}`.", self.name)
566 })?;
567 let right_fn_arg = self
568 .fn_args
569 .get(1)
570 .ok_or_else(|| eyre!("Did not find `left_arg` for operator `{}`.", self.name))?;
571 let right_arg_graph_index = context
572 .graph
573 .neighbors_undirected(self_index)
574 .find(|neighbor| match &context.graph[*neighbor] {
575 SqlGraphEntity::Type(ty) => ty.id_matches(&right_fn_arg.used_ty.ty_id),
576 SqlGraphEntity::Enum(en) => en.id_matches(&right_fn_arg.used_ty.ty_id),
577 SqlGraphEntity::BuiltinType(defined) => defined == &right_arg.type_name,
578 _ => false,
579 })
580 .ok_or_else(|| {
581 eyre!("Could not find right arg type in graph. Got: {:?}", right_arg)
582 })?;
583 let right_arg_sql = match right_arg.argument_sql {
584 Ok(SqlMapping::As(ref sql)) => sql.clone(),
585 Ok(SqlMapping::Composite { array_brackets }) => {
586 if array_brackets {
587 let composite_type = self.fn_args[1].used_ty.composite_type
588 .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgx::composite_type!()`"))?;
589 format!("{composite_type}[]")
590 } else {
591 self.fn_args[0].used_ty.composite_type
592 .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgx::composite_type!()`"))?.to_string()
593 }
594 }
595 Ok(SqlMapping::Source { array_brackets }) => {
596 if array_brackets {
597 let composite_type = context
598 .source_only_to_sql_type(self.fn_args[1].used_ty.ty_source)
599 .ok_or(eyre!(
600 "Found a source only mapping but no source mapping exists for this"
601 ))?;
602 format!("{composite_type}[]")
603 } else {
604 context.source_only_to_sql_type(self.fn_args[1].used_ty.ty_source)
605 .ok_or(eyre!("Found a composite type but macro expansion time did not reveal a name, use `pgx::composite_type!()`"))?.to_string()
606 }
607 }
608 Ok(SqlMapping::Skip) => {
609 return Err(eyre!(
610 "Found an skipped SQL type in an operator, this is not valid"
611 ))
612 }
613 Err(err) => return Err(err.into()),
614 };
615
616 let operator_sql = format!("\n\n\
617 -- {file}:{line}\n\
618 -- {module_path}::{name}\n\
619 CREATE OPERATOR {opname} (\n\
620 \tPROCEDURE=\"{name}\",\n\
621 \tLEFTARG={schema_prefix_left}{left_arg}, /* {left_name} */\n\
622 \tRIGHTARG={schema_prefix_right}{right_arg}{maybe_comma} /* {right_name} */\n\
623 {optionals}\
624 );\
625 ",
626 opname = op.opname.unwrap(),
627 file = self.file,
628 line = self.line,
629 name = self.name,
630 module_path = self.module_path,
631 left_name = left_arg.type_name,
632 right_name = right_arg.type_name,
633 schema_prefix_left = context.schema_prefix_for(&left_arg_graph_index),
634 left_arg = left_arg_sql,
635 schema_prefix_right = context.schema_prefix_for(&right_arg_graph_index),
636 right_arg = right_arg_sql,
637 maybe_comma = if optionals.len() >= 1 { "," } else { "" },
638 optionals = if !optionals.is_empty() { optionals.join(",\n") + "\n" } else { "".to_string() },
639 );
640 tracing::trace!(sql = %operator_sql);
641 ext_sql + &operator_sql
642 } else {
643 ext_sql
644 };
645 Ok(rendered)
646 }
647}