Skip to main content

pgrx_sql_entity_graph/
pgrx_sql.rs

//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
//LICENSE
//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
//LICENSE
//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
//LICENSE
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
/*!

Rust to SQL mapping support.

> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.

*/
18
19use eyre::eyre;
20use petgraph::dot::Dot;
21use petgraph::graph::NodeIndex;
22use petgraph::stable_graph::StableGraph;
23use petgraph::visit::EdgeRef;
24use std::collections::{BTreeMap, HashMap};
25use std::fmt::Debug;
26use std::path::Path;
27
28use crate::aggregate::entity::PgAggregateEntity;
29use crate::control_file::ControlFile;
30use crate::extension_sql::SqlDeclared;
31use crate::extension_sql::entity::{ExtensionSqlEntity, SqlDeclaredEntity};
32use crate::metadata::TypeOrigin;
33use crate::pg_extern::entity::PgExternEntity;
34use crate::pg_trigger::entity::PgTriggerEntity;
35use crate::positioning_ref::PositioningRef;
36use crate::postgres_enum::entity::PostgresEnumEntity;
37use crate::postgres_hash::entity::PostgresHashEntity;
38use crate::postgres_ord::entity::PostgresOrdEntity;
39use crate::postgres_type::entity::PostgresTypeEntity;
40use crate::schema::entity::SchemaEntity;
41use crate::to_sql::ToSql;
42use crate::type_keyed;
43use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
44
45use super::{PgExternReturnEntity, PgExternReturnEntityIteratedItem};
46
/// The kind of dependency recorded on an edge of the [`PgrxSql`] graph.
///
/// Edges point from the entity that has to exist first to the entity that depends on it
/// (for example from the extension root to every entity, or from a `requires = [...]` target
/// to the `extension_sql!()` that named it).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum SqlGraphRequires {
    /// A plain "required by" dependency. Root, bootstrap, finalize and `requires` edges use
    /// this kind, and it is the only kind consulted when ordering entities inside a cycle.
    By,
    /// Rendered as a black edge by [`PgrxSql::to_dot`].
    ByArg,
    /// Rendered as a reversed (`dir = "back"`) black edge by [`PgrxSql::to_dot`].
    ByReturn,
}
53
54/// A generator for SQL.
55///
56/// Consumes a base mapping of types (typically `pgrx::DEFAULT_TYPEID_SQL_MAPPING`), a
57/// [`ControlFile`], and collections of each SQL entity.
58///
59/// During construction, a Directed Acyclic Graph is formed out the dependencies. For example,
60/// an item `detect_dog(x: &[u8]) -> animals::Dog` would have have a relationship with
61/// `animals::Dog`.
62///
63/// Typically, [`PgrxSql`] types are constructed in a `pgrx::pg_binary_magic!()` call in a binary
64/// out of entities collected during a `pgrx::pg_module_magic!()` call in a library.
65#[derive(Debug, Clone)]
66pub struct PgrxSql<'a> {
67    pub control: ControlFile,
68    pub graph: StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
69    pub graph_root: NodeIndex,
70    pub graph_bootstrap: Option<NodeIndex>,
71    pub graph_finalize: Option<NodeIndex>,
72    pub schemas: HashMap<SchemaEntity<'a>, NodeIndex>,
73    pub extension_sqls: HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
74    pub externs: HashMap<PgExternEntity<'a>, NodeIndex>,
75    pub types: HashMap<PostgresTypeEntity<'a>, NodeIndex>,
76    pub builtin_types: HashMap<String, NodeIndex>,
77    pub enums: HashMap<PostgresEnumEntity<'a>, NodeIndex>,
78    pub ords: HashMap<PostgresOrdEntity<'a>, NodeIndex>,
79    pub hashes: HashMap<PostgresHashEntity<'a>, NodeIndex>,
80    pub aggregates: HashMap<PgAggregateEntity<'a>, NodeIndex>,
81    pub triggers: HashMap<PgTriggerEntity<'a>, NodeIndex>,
82    pub extension_name: String,
83    pub versioned_so: bool,
84}
85
86impl<'a> PgrxSql<'a> {
87    pub fn build(
88        entities: impl Iterator<Item = SqlGraphEntity<'a>>,
89        extension_name: String,
90        versioned_so: bool,
91    ) -> eyre::Result<Self> {
92        let mut graph = StableGraph::new();
93
94        let mut entities = entities.collect::<Vec<_>>();
95        entities.sort();
96        // Split up things into their specific types:
97        let mut control: Option<ControlFile> = None;
98        let mut schemas: Vec<SchemaEntity<'a>> = Vec::default();
99        let mut extension_sqls: Vec<ExtensionSqlEntity<'a>> = Vec::default();
100        let mut externs: Vec<PgExternEntity<'a>> = Vec::default();
101        let mut types: Vec<PostgresTypeEntity<'a>> = Vec::default();
102        let mut enums: Vec<PostgresEnumEntity<'a>> = Vec::default();
103        let mut ords: Vec<PostgresOrdEntity<'a>> = Vec::default();
104        let mut hashes: Vec<PostgresHashEntity<'a>> = Vec::default();
105        let mut aggregates: Vec<PgAggregateEntity<'a>> = Vec::default();
106        let mut triggers: Vec<PgTriggerEntity<'a>> = Vec::default();
107        for entity in entities {
108            match entity {
109                SqlGraphEntity::ExtensionRoot(input_control) => {
110                    control = Some(input_control);
111                }
112                SqlGraphEntity::Schema(input_schema) => {
113                    schemas.push(input_schema);
114                }
115                SqlGraphEntity::CustomSql(input_sql) => {
116                    extension_sqls.push(input_sql);
117                }
118                SqlGraphEntity::Function(input_function) => {
119                    externs.push(input_function);
120                }
121                SqlGraphEntity::Type(input_type) => {
122                    types.push(input_type);
123                }
124                SqlGraphEntity::BuiltinType(_) => (),
125                SqlGraphEntity::Enum(input_enum) => {
126                    enums.push(input_enum);
127                }
128                SqlGraphEntity::Ord(input_ord) => {
129                    ords.push(input_ord);
130                }
131                SqlGraphEntity::Hash(input_hash) => {
132                    hashes.push(input_hash);
133                }
134                SqlGraphEntity::Aggregate(input_aggregate) => {
135                    aggregates.push(input_aggregate);
136                }
137                SqlGraphEntity::Trigger(input_trigger) => {
138                    triggers.push(input_trigger);
139                }
140            }
141        }
142
143        let control: ControlFile = control.expect("No control file found");
144        let root = graph.add_node(SqlGraphEntity::ExtensionRoot(control.clone()));
145
146        // The initial build phase.
147        //
148        // Notably, we do not set non-root edges here. We do that in a second step. This is
149        // primarily because externs, types, operators, and the like tend to intertwine. If we tried
150        // to do it here, we'd find ourselves trying to create edges to non-existing entities.
151
152        // Both of these must be unique, so we can only hold one.
153        // Populate nodes, but don't build edges until we know if there is a bootstrap/finalize.
154        let (mapped_extension_sqls, bootstrap, finalize) =
155            initialize_extension_sqls(&mut graph, root, extension_sqls)?;
156        let mapped_schemas = initialize_schemas(&mut graph, bootstrap, finalize, schemas)?;
157        let mapped_enums = initialize_enums(&mut graph, root, bootstrap, finalize, enums)?;
158        let mapped_types = initialize_types(&mut graph, root, bootstrap, finalize, types)?;
159        ensure_unique_type_targets(&mapped_types, &mapped_enums, &mapped_extension_sqls)?;
160        let (mapped_externs, mut mapped_builtin_types) = initialize_externs(
161            &mut graph,
162            root,
163            bootstrap,
164            finalize,
165            externs,
166            &mapped_types,
167            &mapped_enums,
168            &mapped_extension_sqls,
169        )?;
170        let mapped_ords = initialize_ords(&mut graph, root, bootstrap, finalize, ords)?;
171        let mapped_hashes = initialize_hashes(&mut graph, root, bootstrap, finalize, hashes)?;
172        let mapped_aggregates = initialize_aggregates(
173            &mut graph,
174            root,
175            bootstrap,
176            finalize,
177            aggregates,
178            &mut mapped_builtin_types,
179            &mapped_enums,
180            &mapped_types,
181            &mapped_extension_sqls,
182        )?;
183        let mapped_triggers = initialize_triggers(&mut graph, root, bootstrap, finalize, triggers)?;
184
185        // Now we can circle back and build up the edge sets.
186        connect_schemas(&mut graph, &mapped_schemas, root);
187        connect_extension_sqls(
188            &mut graph,
189            &mapped_extension_sqls,
190            &mapped_schemas,
191            &mapped_types,
192            &mapped_enums,
193            &mapped_externs,
194            &mapped_triggers,
195        )?;
196        connect_enums(&mut graph, &mapped_enums, &mapped_schemas);
197        connect_types(&mut graph, &mapped_types, &mapped_schemas, &mapped_externs)?;
198        connect_externs(
199            &mut graph,
200            &mapped_externs,
201            &mapped_hashes,
202            &mapped_schemas,
203            &mapped_types,
204            &mapped_enums,
205            &mapped_builtin_types,
206            &mapped_extension_sqls,
207            &mapped_triggers,
208        )?;
209        connect_ords(
210            &mut graph,
211            &mapped_ords,
212            &mapped_schemas,
213            &mapped_types,
214            &mapped_enums,
215            &mapped_externs,
216        );
217        connect_hashes(
218            &mut graph,
219            &mapped_hashes,
220            &mapped_schemas,
221            &mapped_types,
222            &mapped_enums,
223            &mapped_externs,
224        );
225        connect_aggregates(
226            &mut graph,
227            &mapped_aggregates,
228            &mapped_schemas,
229            &mapped_types,
230            &mapped_enums,
231            &mapped_builtin_types,
232            &mapped_externs,
233            &mapped_extension_sqls,
234        )?;
235        connect_triggers(&mut graph, &mapped_triggers, &mapped_schemas);
236
237        let this = Self {
238            control,
239            schemas: mapped_schemas,
240            extension_sqls: mapped_extension_sqls,
241            externs: mapped_externs,
242            types: mapped_types,
243            builtin_types: mapped_builtin_types,
244            enums: mapped_enums,
245            ords: mapped_ords,
246            hashes: mapped_hashes,
247            aggregates: mapped_aggregates,
248            triggers: mapped_triggers,
249            graph,
250            graph_root: root,
251            graph_bootstrap: bootstrap,
252            graph_finalize: finalize,
253            extension_name,
254            versioned_so,
255        };
256        Ok(this)
257    }
258
259    // NOTE: this signature is demanded by the codegen we embed via cargo-pgrx
260    pub fn to_file(&self, file: impl AsRef<Path> + Debug) -> eyre::Result<()> {
261        use std::fs::{File, create_dir_all};
262        use std::io::Write;
263        let generated = self.to_sql()?;
264        let path = Path::new(file.as_ref());
265
266        let parent = path.parent();
267        if let Some(parent) = parent {
268            create_dir_all(parent)?;
269        }
270        let mut out = File::create(path)?;
271        write!(out, "{generated}")?;
272        Ok(())
273    }
274
275    pub fn write(&self, out: &mut impl std::io::Write) -> eyre::Result<()> {
276        let generated = self.to_sql()?;
277
278        #[cfg(feature = "syntax-highlighting")]
279        {
280            use std::io::{IsTerminal, stdout};
281            if stdout().is_terminal() {
282                self.write_highlighted(out, &generated)?;
283            } else {
284                write!(*out, "{}", generated)?;
285            }
286        }
287
288        #[cfg(not(feature = "syntax-highlighting"))]
289        {
290            write!(*out, "{generated}")?;
291        }
292
293        Ok(())
294    }
295
296    #[cfg(feature = "syntax-highlighting")]
297    fn write_highlighted(&self, out: &mut dyn std::io::Write, generated: &str) -> eyre::Result<()> {
298        use eyre::WrapErr as _;
299        use owo_colors::{OwoColorize, XtermColors};
300        use syntect::easy::HighlightLines;
301        use syntect::highlighting::{Style, ThemeSet};
302        use syntect::parsing::SyntaxSet;
303        use syntect::util::LinesWithEndings;
304        let ps = SyntaxSet::load_defaults_newlines();
305        let theme_bytes = include_str!("../assets/ansi.tmTheme").as_bytes();
306        let mut theme_reader = std::io::Cursor::new(theme_bytes);
307        let theme = ThemeSet::load_from_reader(&mut theme_reader)
308            .wrap_err("Couldn't parse theme for SQL highlighting, try piping to a file")?;
309
310        if let Some(syntax) = ps.find_syntax_by_extension("sql") {
311            let mut h = HighlightLines::new(syntax, &theme);
312            for line in LinesWithEndings::from(&generated) {
313                let ranges: Vec<(Style, &str)> = h.highlight_line(line, &ps)?;
314                // Concept from https://github.com/sharkdp/bat/blob/1b030dc03b906aa345f44b8266bffeea77d763fe/src/terminal.rs#L6
315                for (style, content) in ranges {
316                    if style.foreground.a == 0x01 {
317                        write!(*out, "{}", content)?;
318                    } else {
319                        write!(*out, "{}", content.color(XtermColors::from(style.foreground.r)))?;
320                    }
321                }
322                write!(*out, "\x1b[0m")?;
323            }
324        } else {
325            write!(*out, "{}", generated)?;
326        }
327        Ok(())
328    }
329
330    // NOTE: this signature is demanded by the codegen we embed via cargo-pgrx
331    pub fn to_dot(&self, file: impl AsRef<Path> + Debug) -> eyre::Result<()> {
332        use std::fs::{File, create_dir_all};
333        use std::io::Write;
334        let generated = Dot::with_attr_getters(
335            &self.graph,
336            &[petgraph::dot::Config::EdgeNoLabel, petgraph::dot::Config::NodeNoLabel],
337            &|_graph, edge| {
338                match edge.weight() {
339                    SqlGraphRequires::By => r#"color = "gray""#,
340                    SqlGraphRequires::ByArg => r#"color = "black""#,
341                    SqlGraphRequires::ByReturn => r#"dir = "back", color = "black""#,
342                }
343                .to_owned()
344            },
345            &|_graph, (_index, node)| {
346                let dot_id = node.dot_identifier();
347                match node {
348                    // Colors derived from https://www.schemecolor.com/touch-of-creativity.php
349                    SqlGraphEntity::Schema(_item) => {
350                        format!("label = \"{dot_id}\", weight = 6, shape = \"tab\"")
351                    }
352                    SqlGraphEntity::Function(_item) => format!(
353                        "label = \"{dot_id}\", penwidth = 0, style = \"filled\", fillcolor = \"#ADC7C6\", weight = 4, shape = \"box\"",
354                    ),
355                    SqlGraphEntity::Type(_item) => format!(
356                        "label = \"{dot_id}\", penwidth = 0, style = \"filled\", fillcolor = \"#AE9BBD\", weight = 5, shape = \"oval\"",
357                    ),
358                    SqlGraphEntity::BuiltinType(_item) => {
359                        format!("label = \"{dot_id}\", shape = \"plain\"")
360                    }
361                    SqlGraphEntity::Enum(_item) => format!(
362                        "label = \"{dot_id}\", penwidth = 0, style = \"filled\", fillcolor = \"#C9A7C8\", weight = 5, shape = \"oval\""
363                    ),
364                    SqlGraphEntity::Ord(_item) => format!(
365                        "label = \"{dot_id}\", penwidth = 0, style = \"filled\", fillcolor = \"#FFCFD3\", weight = 5, shape = \"diamond\""
366                    ),
367                    SqlGraphEntity::Hash(_item) => format!(
368                        "label = \"{dot_id}\", penwidth = 0, style = \"filled\", fillcolor = \"#FFE4E0\", weight = 5, shape = \"diamond\""
369                    ),
370                    SqlGraphEntity::Aggregate(_item) => format!(
371                        "label = \"{dot_id}\", penwidth = 0, style = \"filled\", fillcolor = \"#FFE4E0\", weight = 5, shape = \"diamond\""
372                    ),
373                    SqlGraphEntity::Trigger(_item) => format!(
374                        "label = \"{dot_id}\", penwidth = 0, style = \"filled\", fillcolor = \"#FFE4E0\", weight = 5, shape = \"diamond\""
375                    ),
376                    SqlGraphEntity::CustomSql(_item) => {
377                        format!("label = \"{dot_id}\", weight = 3, shape = \"signature\"")
378                    }
379                    SqlGraphEntity::ExtensionRoot(_item) => {
380                        format!("label = \"{dot_id}\", shape = \"cylinder\"")
381                    }
382                }
383            },
384        );
385        let path = Path::new(file.as_ref());
386
387        let parent = path.parent();
388        if let Some(parent) = parent {
389            create_dir_all(parent)?;
390        }
391        let mut out = File::create(path)?;
392        write!(out, "{generated:?}")?;
393        Ok(())
394    }
395
396    pub fn schema_alias_of(&self, item_index: &NodeIndex) -> Option<String> {
397        self.graph
398            .neighbors_undirected(*item_index)
399            .flat_map(|neighbor_index| match &self.graph[neighbor_index] {
400                SqlGraphEntity::Schema(s) => Some(String::from(s.name)),
401                SqlGraphEntity::ExtensionRoot(_control) => None,
402                _ => None,
403            })
404            .next()
405    }
406
407    pub fn schema_prefix_for(&self, target: &NodeIndex) -> String {
408        self.schema_alias_of(target).map(|v| (v + ".").to_string()).unwrap_or_default()
409    }
410
411    pub fn find_type_dependency(
412        &self,
413        owner: &NodeIndex,
414        ty: &dyn crate::TypeIdentifiable,
415    ) -> Option<NodeIndex> {
416        self.graph
417            .neighbors_undirected(*owner)
418            .find(|neighbor| self.graph[*neighbor].type_matches(ty))
419    }
420
421    pub fn schema_prefix_for_used_type(
422        &self,
423        owner: &NodeIndex,
424        slot: &str,
425        used_ty: &UsedTypeEntity<'_>,
426    ) -> eyre::Result<String> {
427        if !used_ty.needs_type_resolution() {
428            return Ok(String::new());
429        }
430
431        let graph_index = self
432            .find_type_dependency(owner, used_ty)
433            .ok_or_else(|| eyre!("Could not find {slot} in graph. Got: {used_ty:?}"))?;
434        Ok(self.schema_prefix_for(&graph_index))
435    }
436
437    pub fn to_sql(&self) -> eyre::Result<String> {
438        let mut full_sql = String::new();
439
440        // NB:  A properly we'd *like* to maintain is that the schema generator outputs
441        // consistent results from run-to-run when there are no changes to the schema.
442        // This is to improve change detection using simple tools like `diff`.
443        //
444        // Historically, we used [`petgraph::algo:toposort`] but its ordering is not at all
445        // consistent.
446        //
447        // [`petgraph::algo::tarjan_scc`] appears to be consistent, although it's not exactly
448        // clear if this is due to an implementation detail or specifics of the algorithm itself.
449        // (I, eeeebbbbrrrr, am not a graph theory expert)
450        //
451        // In any event, if in the future schema generation stops being consistent, this is the
452        // place to look.
453        //
454        // We have no tests around this as it's really just a property we'd like to have, and
455        // it does seem ensuring it is a bit of black magic.
456        for nodes in petgraph::algo::tarjan_scc(&self.graph).iter().rev() {
457            let mut inner_sql = Vec::with_capacity(nodes.len());
458
459            for node in self.connected_component_emit_order(nodes) {
460                let step = &self.graph[node];
461                let sql = step.to_sql(self)?;
462
463                let trimmed = sql.trim();
464                if !trimmed.is_empty() {
465                    inner_sql.push(format!("{trimmed}\n"))
466                }
467            }
468
469            if !inner_sql.is_empty() {
470                full_sql.push_str("/* <begin connected objects> */\n");
471                full_sql.push_str(&inner_sql.join("\n\n"));
472                full_sql.push_str("/* </end connected objects> */\n\n");
473            }
474        }
475
476        Ok(full_sql)
477    }
478
479    fn connected_component_emit_order(&self, nodes: &[NodeIndex]) -> Vec<NodeIndex> {
480        if nodes.len() <= 1 {
481            return nodes.to_vec();
482        }
483
484        // When a connected component contains a cycle, user-authored `requires = [...]`
485        // edges are the strongest ordering signal we have. Type-resolution edges may still
486        // point back into the declaration that ultimately creates the type, such as shell-type
487        // bootstrap patterns for manual `extension_sql!()` types.
488        let mut explicit_dependents = HashMap::<NodeIndex, Vec<NodeIndex>>::new();
489        let mut remaining_explicit_dependencies = HashMap::<NodeIndex, usize>::new();
490        let mut has_explicit_edges = false;
491
492        for &node in nodes {
493            explicit_dependents.insert(node, Vec::new());
494            remaining_explicit_dependencies.insert(node, 0);
495        }
496
497        for &node in nodes {
498            for edge in self.graph.edges(node) {
499                if edge.weight() != &SqlGraphRequires::By {
500                    continue;
501                }
502
503                let dependent = edge.target();
504                if !remaining_explicit_dependencies.contains_key(&dependent) {
505                    continue;
506                }
507
508                has_explicit_edges = true;
509                explicit_dependents
510                    .get_mut(&node)
511                    .expect("component members should be initialized")
512                    .push(dependent);
513                *remaining_explicit_dependencies
514                    .get_mut(&dependent)
515                    .expect("component members should be initialized") += 1;
516            }
517        }
518
519        if !has_explicit_edges {
520            return nodes.to_vec();
521        }
522
523        let mut ready = remaining_explicit_dependencies
524            .iter()
525            .filter_map(|(node, count)| (*count == 0).then_some(*node))
526            .collect::<Vec<_>>();
527        let mut ordered = Vec::with_capacity(nodes.len());
528
529        while !ready.is_empty() {
530            ready.sort_unstable_by(|left, right| {
531                self.graph[*left]
532                    .cmp(&self.graph[*right])
533                    .then_with(|| left.index().cmp(&right.index()))
534            });
535            let next = ready.remove(0);
536            ordered.push(next);
537
538            if let Some(dependents) = explicit_dependents.get(&next) {
539                for dependent in dependents {
540                    let remaining = remaining_explicit_dependencies
541                        .get_mut(dependent)
542                        .expect("component members should be initialized");
543                    *remaining -= 1;
544                    if *remaining == 0 {
545                        ready.push(*dependent);
546                    }
547                }
548            }
549        }
550
551        if ordered.len() == nodes.len() { ordered } else { nodes.to_vec() }
552    }
553
554    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
555        self.extension_sqls.iter().find_map(|(item, _index)| {
556            item.creates
557                .iter()
558                .find(|create_entity| create_entity.has_sql_declared_entity(identifier))
559        })
560    }
561
562    pub fn get_module_pathname(&self) -> String {
563        if self.versioned_so {
564            let extname = &self.extension_name;
565            let extver = &self.control.default_version;
566            // Note: versioned so-name format must agree with cargo pgrx
567            format!("{extname}-{extver}")
568        } else {
569            String::from("MODULE_PATHNAME")
570        }
571    }
572
573    pub fn find_matching_fn(&self, name: &str) -> Option<&PgExternEntity<'a>> {
574        self.externs.keys().find(|key| key.full_path.ends_with(name))
575    }
576}
577
578fn build_base_edges<'a>(
579    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
580    index: NodeIndex,
581    root: NodeIndex,
582    bootstrap: Option<NodeIndex>,
583    finalize: Option<NodeIndex>,
584) {
585    graph.add_edge(root, index, SqlGraphRequires::By);
586    if let Some(bootstrap) = bootstrap {
587        graph.add_edge(bootstrap, index, SqlGraphRequires::By);
588    }
589    if let Some(finalize) = finalize {
590        graph.add_edge(index, finalize, SqlGraphRequires::By);
591    }
592}
593
594#[allow(clippy::type_complexity)]
595fn initialize_extension_sqls<'a>(
596    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
597    root: NodeIndex,
598    extension_sqls: Vec<ExtensionSqlEntity<'a>>,
599) -> eyre::Result<(HashMap<ExtensionSqlEntity<'a>, NodeIndex>, Option<NodeIndex>, Option<NodeIndex>)>
600{
601    let mut bootstrap = None;
602    let mut finalize = None;
603    let mut mapped_extension_sqls = HashMap::default();
604    for item in extension_sqls {
605        let entity: SqlGraphEntity = item.clone().into();
606        let index = graph.add_node(entity);
607        mapped_extension_sqls.insert(item.clone(), index);
608
609        if item.bootstrap {
610            if let Some(existing_index) = bootstrap {
611                let existing: &SqlGraphEntity = &graph[existing_index];
612                return Err(eyre!(
613                    "Cannot have multiple `extension_sql!()` with `bootstrap` positioning, found `{}`, other was `{}`",
614                    item.rust_identifier(),
615                    existing.rust_identifier(),
616                ));
617            }
618            bootstrap = Some(index)
619        }
620        if item.finalize {
621            if let Some(existing_index) = finalize {
622                let existing: &SqlGraphEntity = &graph[existing_index];
623                return Err(eyre!(
624                    "Cannot have multiple `extension_sql!()` with `finalize` positioning, found `{}`, other was `{}`",
625                    item.rust_identifier(),
626                    existing.rust_identifier(),
627                ));
628            }
629            finalize = Some(index)
630        }
631    }
632    for (item, index) in &mapped_extension_sqls {
633        graph.add_edge(root, *index, SqlGraphRequires::By);
634        if !item.bootstrap
635            && let Some(bootstrap) = bootstrap
636        {
637            graph.add_edge(bootstrap, *index, SqlGraphRequires::By);
638        }
639        if !item.finalize
640            && let Some(finalize) = finalize
641        {
642            graph.add_edge(*index, finalize, SqlGraphRequires::By);
643        }
644    }
645    Ok((mapped_extension_sqls, bootstrap, finalize))
646}
647
648/// A best effort attempt to find the related [`NodeIndex`] for some [`PositioningRef`].
649pub fn find_positioning_ref_target<'a, 'b>(
650    positioning_ref: &'b PositioningRef,
651    types: &'b HashMap<PostgresTypeEntity<'a>, NodeIndex>,
652    enums: &'b HashMap<PostgresEnumEntity<'a>, NodeIndex>,
653    externs: &'b HashMap<PgExternEntity<'a>, NodeIndex>,
654    schemas: &'b HashMap<SchemaEntity<'a>, NodeIndex>,
655    extension_sqls: &'b HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
656    triggers: &'b HashMap<PgTriggerEntity<'a>, NodeIndex>,
657) -> Option<&'b NodeIndex> {
658    match positioning_ref {
659        PositioningRef::FullPath(path) => {
660            // The best we can do here is a fuzzy search.
661            let segments = path.split("::").collect::<Vec<_>>();
662            let last_segment = segments.last().expect("Expected at least one segment.");
663            let rest = &segments[..segments.len() - 1];
664            let module_path = rest.join("::");
665
666            for (other, other_index) in types {
667                if *last_segment == other.name && other.module_path.ends_with(&module_path) {
668                    return Some(other_index);
669                }
670            }
671            for (other, other_index) in enums {
672                if last_segment == &other.name && other.module_path.ends_with(&module_path) {
673                    return Some(other_index);
674                }
675            }
676            for (other, other_index) in externs {
677                if *last_segment == other.unaliased_name
678                    && other.module_path.ends_with(&module_path)
679                {
680                    return Some(other_index);
681                }
682            }
683            for (other, other_index) in schemas {
684                if other.module_path.ends_with(path) {
685                    return Some(other_index);
686                }
687            }
688
689            for (other, other_index) in triggers {
690                if last_segment == &other.function_name && other.module_path.ends_with(&module_path)
691                {
692                    return Some(other_index);
693                }
694            }
695        }
696        PositioningRef::Name(name) => {
697            for (other, other_index) in extension_sqls {
698                if other.name == name {
699                    return Some(other_index);
700                }
701            }
702        }
703    };
704    None
705}
706
707fn connect_extension_sqls<'a>(
708    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
709    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
710    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
711    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
712    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
713    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
714    triggers: &HashMap<PgTriggerEntity<'a>, NodeIndex>,
715) -> eyre::Result<()> {
716    for (item, &index) in extension_sqls {
717        make_schema_connection(
718            graph,
719            "Extension SQL",
720            index,
721            &item.rust_identifier(),
722            item.module_path,
723            schemas,
724        );
725
726        for requires in &item.requires {
727            if let Some(target) = find_positioning_ref_target(
728                requires,
729                types,
730                enums,
731                externs,
732                schemas,
733                extension_sqls,
734                triggers,
735            ) {
736                graph.add_edge(*target, index, SqlGraphRequires::By);
737            } else {
738                return Err(eyre!(
739                    "Could not find `requires` target of `{}`{}: {}",
740                    item.rust_identifier(),
741                    match (item.file(), item.line()) {
742                        (Some(file), Some(line)) => format!(" ({file}:{line})"),
743                        _ => "".to_string(),
744                    },
745                    match requires {
746                        PositioningRef::FullPath(path) => path.to_string(),
747                        PositioningRef::Name(name) => format!(r#""{name}""#),
748                    },
749                ));
750            }
751        }
752    }
753    Ok(())
754}
755
756fn initialize_schemas<'a>(
757    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
758    bootstrap: Option<NodeIndex>,
759    finalize: Option<NodeIndex>,
760    schemas: Vec<SchemaEntity<'a>>,
761) -> eyre::Result<HashMap<SchemaEntity<'a>, NodeIndex>> {
762    let mut mapped_schemas = HashMap::default();
763    for item in schemas {
764        let entity = item.clone().into();
765        let index = graph.add_node(entity);
766        mapped_schemas.insert(item, index);
767        if let Some(bootstrap) = bootstrap {
768            graph.add_edge(bootstrap, index, SqlGraphRequires::By);
769        }
770        if let Some(finalize) = finalize {
771            graph.add_edge(index, finalize, SqlGraphRequires::By);
772        }
773    }
774    Ok(mapped_schemas)
775}
776
777fn connect_schemas<'a>(
778    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
779    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
780    root: NodeIndex,
781) {
782    for index in schemas.values().copied() {
783        graph.add_edge(root, index, SqlGraphRequires::By);
784    }
785}
786
787fn initialize_enums<'a>(
788    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
789    root: NodeIndex,
790    bootstrap: Option<NodeIndex>,
791    finalize: Option<NodeIndex>,
792    enums: Vec<PostgresEnumEntity<'a>>,
793) -> eyre::Result<HashMap<PostgresEnumEntity<'a>, NodeIndex>> {
794    let mut mapped_enums = HashMap::default();
795    for item in enums {
796        let entity: SqlGraphEntity = item.clone().into();
797        let index = graph.add_node(entity);
798        mapped_enums.insert(item, index);
799        build_base_edges(graph, index, root, bootstrap, finalize);
800    }
801    Ok(mapped_enums)
802}
803
804fn connect_enums<'a>(
805    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
806    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
807    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
808) {
809    for (item, &index) in enums {
810        make_schema_connection(
811            graph,
812            "Enum",
813            index,
814            &item.rust_identifier(),
815            item.module_path,
816            schemas,
817        );
818    }
819}
820
821fn initialize_types<'a>(
822    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
823    root: NodeIndex,
824    bootstrap: Option<NodeIndex>,
825    finalize: Option<NodeIndex>,
826    types: Vec<PostgresTypeEntity<'a>>,
827) -> eyre::Result<HashMap<PostgresTypeEntity<'a>, NodeIndex>> {
828    let mut mapped_types = HashMap::default();
829    for item in types {
830        let entity = item.clone().into();
831        let index = graph.add_node(entity);
832        mapped_types.insert(item, index);
833        build_base_edges(graph, index, root, bootstrap, finalize);
834    }
835    Ok(mapped_types)
836}
837
838fn connect_types<'a>(
839    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
840    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
841    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
842    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
843) -> eyre::Result<()> {
844    for (item, &index) in types {
845        make_schema_connection(
846            graph,
847            "Type",
848            index,
849            &item.rust_identifier(),
850            item.module_path,
851            schemas,
852        );
853
854        make_extern_connection(
855            graph,
856            "Type",
857            index,
858            &item.rust_identifier(),
859            &resolve_function_path(item.module_path, item.in_fn_path),
860            externs,
861        )?;
862        make_extern_connection(
863            graph,
864            "Type",
865            index,
866            &item.rust_identifier(),
867            &resolve_function_path(item.module_path, item.out_fn_path),
868            externs,
869        )?;
870        if let Some(path) = item.receive_fn_path {
871            make_extern_connection(
872                graph,
873                "Type",
874                index,
875                &item.rust_identifier(),
876                &resolve_function_path(item.module_path, path),
877                externs,
878            )?;
879        }
880        if let Some(path) = item.send_fn_path {
881            make_extern_connection(
882                graph,
883                "Type",
884                index,
885                &item.rust_identifier(),
886                &resolve_function_path(item.module_path, path),
887                externs,
888            )?;
889        }
890    }
891    Ok(())
892}
893
894fn initialize_externs<'a>(
895    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
896    root: NodeIndex,
897    bootstrap: Option<NodeIndex>,
898    finalize: Option<NodeIndex>,
899    externs: Vec<PgExternEntity<'a>>,
900    mapped_types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
901    mapped_enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
902    mapped_extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
903) -> eyre::Result<(HashMap<PgExternEntity<'a>, NodeIndex>, HashMap<String, NodeIndex>)> {
904    let mut mapped_externs = HashMap::default();
905    let mut mapped_builtin_types = HashMap::default();
906    for item in externs {
907        let entity: SqlGraphEntity = item.clone().into();
908        let index = graph.add_node(entity.clone());
909        mapped_externs.insert(item.clone(), index);
910        build_base_edges(graph, index, root, bootstrap, finalize);
911
912        for arg in &item.fn_args {
913            if !arg.used_ty.emits_argument_sql() || !arg.used_ty.needs_type_resolution() {
914                continue;
915            }
916            let slot = format!("argument `{}`", arg.pattern);
917            let (type_ident, type_origin) = arg
918                .used_ty
919                .resolution()
920                .expect("SQL-visible extern arguments should carry resolution metadata");
921            initialize_resolved_type(
922                graph,
923                &mut mapped_builtin_types,
924                type_ident,
925                type_origin,
926                mapped_types,
927                mapped_enums,
928                mapped_extension_sqls,
929                "Function",
930                item.full_path,
931                &slot,
932                arg.used_ty.full_path,
933            )?;
934        }
935
936        match &item.fn_return {
937            PgExternReturnEntity::None | PgExternReturnEntity::Trigger => (),
938            PgExternReturnEntity::Type { ty, .. } | PgExternReturnEntity::SetOf { ty, .. } => {
939                if let Some((type_ident, type_origin)) = ty.resolution() {
940                    initialize_resolved_type(
941                        graph,
942                        &mut mapped_builtin_types,
943                        type_ident,
944                        type_origin,
945                        mapped_types,
946                        mapped_enums,
947                        mapped_extension_sqls,
948                        "Function",
949                        item.full_path,
950                        "return type",
951                        ty.full_path,
952                    )?;
953                }
954            }
955            PgExternReturnEntity::Iterated { tys: iterated_returns, .. } => {
956                for PgExternReturnEntityIteratedItem { ty, .. } in iterated_returns {
957                    if let Some((type_ident, type_origin)) = ty.resolution() {
958                        initialize_resolved_type(
959                            graph,
960                            &mut mapped_builtin_types,
961                            type_ident,
962                            type_origin,
963                            mapped_types,
964                            mapped_enums,
965                            mapped_extension_sqls,
966                            "Function",
967                            item.full_path,
968                            "table return column",
969                            ty.full_path,
970                        )?;
971                    }
972                }
973            }
974        }
975    }
976    Ok((mapped_externs, mapped_builtin_types))
977}
978
979fn connect_externs<'a>(
980    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
981    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
982    hashes: &HashMap<PostgresHashEntity<'a>, NodeIndex>,
983    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
984    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
985    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
986    builtin_types: &HashMap<String, NodeIndex>,
987    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
988    triggers: &HashMap<PgTriggerEntity<'a>, NodeIndex>,
989) -> eyre::Result<()> {
990    for (item, &index) in externs {
991        let mut found_schema_declaration = false;
992        for extern_attr in &item.extern_attrs {
993            match extern_attr {
994                crate::ExternArgs::Requires(requirements) => {
995                    for requires in requirements {
996                        if let Some(target) = find_positioning_ref_target(
997                            requires,
998                            types,
999                            enums,
1000                            externs,
1001                            schemas,
1002                            extension_sqls,
1003                            triggers,
1004                        ) {
1005                            graph.add_edge(*target, index, SqlGraphRequires::By);
1006                        } else {
1007                            return Err(eyre!("Could not find `requires` target: {:?}", requires));
1008                        }
1009                    }
1010                }
1011                crate::ExternArgs::Support(support_fn) => {
1012                    if let Some(target) = find_positioning_ref_target(
1013                        support_fn,
1014                        types,
1015                        enums,
1016                        externs,
1017                        schemas,
1018                        extension_sqls,
1019                        triggers,
1020                    ) {
1021                        graph.add_edge(*target, index, SqlGraphRequires::By);
1022                    }
1023                }
1024                crate::ExternArgs::Schema(declared_schema_name) => {
1025                    for (schema, schema_index) in schemas {
1026                        if schema.name == declared_schema_name {
1027                            graph.add_edge(*schema_index, index, SqlGraphRequires::By);
1028                            found_schema_declaration = true;
1029                        }
1030                    }
1031                    if !found_schema_declaration {
1032                        return Err(eyre!(
1033                            "Got manual `schema = \"{declared_schema_name}\"` setting, but that schema did not exist."
1034                        ));
1035                    }
1036                }
1037                _ => (),
1038            }
1039        }
1040
1041        if !found_schema_declaration {
1042            make_schema_connection(
1043                graph,
1044                "Extern",
1045                index,
1046                &item.rust_identifier(),
1047                item.module_path,
1048                schemas,
1049            );
1050        }
1051
1052        // The hash function must be defined after the {typename}_eq function.
1053        for (hash_item, &hash_index) in hashes {
1054            if item.module_path == hash_item.module_path
1055                && item.name == hash_item.name.to_lowercase() + "_eq"
1056            {
1057                graph.add_edge(index, hash_index, SqlGraphRequires::By);
1058            }
1059        }
1060
1061        for arg in &item.fn_args {
1062            if !arg.used_ty.emits_argument_sql() || !arg.used_ty.needs_type_resolution() {
1063                continue;
1064            }
1065            let slot = format!("argument `{}`", arg.pattern);
1066            let (type_ident, type_origin) = arg
1067                .used_ty
1068                .resolution()
1069                .expect("SQL-visible extern arguments should carry resolution metadata");
1070            connect_resolved_type(
1071                graph,
1072                index,
1073                SqlGraphRequires::ByArg,
1074                type_ident,
1075                type_origin,
1076                types,
1077                enums,
1078                builtin_types,
1079                extension_sqls,
1080                "Function",
1081                item.full_path,
1082                &slot,
1083                arg.used_ty.full_path,
1084            )?;
1085        }
1086
1087        match &item.fn_return {
1088            PgExternReturnEntity::None | PgExternReturnEntity::Trigger => (),
1089            PgExternReturnEntity::Type { ty, .. } | PgExternReturnEntity::SetOf { ty, .. } => {
1090                if let Some((type_ident, type_origin)) = ty.resolution() {
1091                    connect_resolved_type(
1092                        graph,
1093                        index,
1094                        SqlGraphRequires::ByReturn,
1095                        type_ident,
1096                        type_origin,
1097                        types,
1098                        enums,
1099                        builtin_types,
1100                        extension_sqls,
1101                        "Function",
1102                        item.full_path,
1103                        "return type",
1104                        ty.full_path,
1105                    )?;
1106                }
1107            }
1108            PgExternReturnEntity::Iterated { tys: iterated_returns, .. } => {
1109                for PgExternReturnEntityIteratedItem { ty, .. } in iterated_returns {
1110                    if let Some((type_ident, type_origin)) = ty.resolution() {
1111                        connect_resolved_type(
1112                            graph,
1113                            index,
1114                            SqlGraphRequires::ByReturn,
1115                            type_ident,
1116                            type_origin,
1117                            types,
1118                            enums,
1119                            builtin_types,
1120                            extension_sqls,
1121                            "Function",
1122                            item.full_path,
1123                            "table return column",
1124                            ty.full_path,
1125                        )?;
1126                    }
1127                }
1128            }
1129        }
1130    }
1131    Ok(())
1132}
1133
1134fn initialize_ords<'a>(
1135    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1136    root: NodeIndex,
1137    bootstrap: Option<NodeIndex>,
1138    finalize: Option<NodeIndex>,
1139    ords: Vec<PostgresOrdEntity<'a>>,
1140) -> eyre::Result<HashMap<PostgresOrdEntity<'a>, NodeIndex>> {
1141    let mut mapped_ords = HashMap::default();
1142    for item in ords {
1143        let entity = item.clone().into();
1144        let index = graph.add_node(entity);
1145        mapped_ords.insert(item.clone(), index);
1146        build_base_edges(graph, index, root, bootstrap, finalize);
1147    }
1148    Ok(mapped_ords)
1149}
1150
1151fn connect_ords<'a>(
1152    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1153    ords: &HashMap<PostgresOrdEntity<'a>, NodeIndex>,
1154    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
1155    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1156    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1157    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
1158) {
1159    for (item, &index) in ords {
1160        make_schema_connection(
1161            graph,
1162            "Ord",
1163            index,
1164            &item.rust_identifier(),
1165            item.module_path,
1166            schemas,
1167        );
1168
1169        make_type_or_enum_connection(graph, index, item.type_ident, types, enums);
1170
1171        // Make PostgresOrdEntities (which will be translated into `CREATE OPERATOR CLASS` statements) depend
1172        // on the operators which they will reference. For example, a pgrx-defined Postgres type `parakeet`
1173        // which has `#[derive(PostgresOrd)]` will emit a `parakeet_btree_ops` operator class, which references
1174        // a definition of a < operator (among others) on the `parakeet` type. This code should ensure that the
1175        // < operator (along with all the others) is emitted before the `OPERATOR CLASS` itself.
1176
1177        for (extern_item, &extern_index) in externs {
1178            let fn_matches = |fn_name| {
1179                item.module_path == extern_item.module_path && extern_item.name == fn_name
1180            };
1181            let cmp_fn_matches = fn_matches(item.cmp_fn_name());
1182            let lt_fn_matches = fn_matches(item.lt_fn_name());
1183            let lte_fn_matches = fn_matches(item.le_fn_name());
1184            let eq_fn_matches = fn_matches(item.eq_fn_name());
1185            let gt_fn_matches = fn_matches(item.gt_fn_name());
1186            let gte_fn_matches = fn_matches(item.ge_fn_name());
1187            if cmp_fn_matches
1188                || lt_fn_matches
1189                || lte_fn_matches
1190                || eq_fn_matches
1191                || gt_fn_matches
1192                || gte_fn_matches
1193            {
1194                graph.add_edge(extern_index, index, SqlGraphRequires::By);
1195            }
1196        }
1197    }
1198}
1199
1200fn initialize_hashes<'a>(
1201    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1202    root: NodeIndex,
1203    bootstrap: Option<NodeIndex>,
1204    finalize: Option<NodeIndex>,
1205    hashes: Vec<PostgresHashEntity<'a>>,
1206) -> eyre::Result<HashMap<PostgresHashEntity<'a>, NodeIndex>> {
1207    let mut mapped_hashes = HashMap::default();
1208    for item in hashes {
1209        let entity: SqlGraphEntity = item.clone().into();
1210        let index = graph.add_node(entity);
1211        mapped_hashes.insert(item, index);
1212        build_base_edges(graph, index, root, bootstrap, finalize);
1213    }
1214    Ok(mapped_hashes)
1215}
1216
1217fn connect_hashes<'a>(
1218    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1219    hashes: &HashMap<PostgresHashEntity<'a>, NodeIndex>,
1220    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
1221    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1222    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1223    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
1224) {
1225    for (item, &index) in hashes {
1226        make_schema_connection(
1227            graph,
1228            "Hash",
1229            index,
1230            &item.rust_identifier(),
1231            item.module_path,
1232            schemas,
1233        );
1234
1235        make_type_or_enum_connection(graph, index, item.type_ident, types, enums);
1236
1237        if let Some((_, extern_index)) = externs.iter().find(|(extern_item, _)| {
1238            item.module_path == extern_item.module_path && extern_item.name == item.fn_name()
1239        }) {
1240            graph.add_edge(*extern_index, index, SqlGraphRequires::By);
1241        }
1242    }
1243}
1244
1245fn initialize_aggregates<'a>(
1246    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1247    root: NodeIndex,
1248    bootstrap: Option<NodeIndex>,
1249    finalize: Option<NodeIndex>,
1250    aggregates: Vec<PgAggregateEntity<'a>>,
1251    mapped_builtin_types: &mut HashMap<String, NodeIndex>,
1252    mapped_enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1253    mapped_types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1254    mapped_extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1255) -> eyre::Result<HashMap<PgAggregateEntity<'a>, NodeIndex>> {
1256    let mut mapped_aggregates = HashMap::default();
1257    for item in aggregates {
1258        let entity: SqlGraphEntity = item.clone().into();
1259        let index = graph.add_node(entity);
1260
1261        for arg in &item.args {
1262            if !arg.used_ty.needs_type_resolution() {
1263                continue;
1264            }
1265            let slot = aggregate_slot(arg.name, "argument");
1266            let (type_ident, type_origin) = arg
1267                .used_ty
1268                .resolution()
1269                .expect("aggregate arguments should carry resolution metadata");
1270            initialize_resolved_type(
1271                graph,
1272                mapped_builtin_types,
1273                type_ident,
1274                type_origin,
1275                mapped_types,
1276                mapped_enums,
1277                mapped_extension_sqls,
1278                "Aggregate",
1279                item.full_path,
1280                &slot,
1281                arg.used_ty.full_path,
1282            )?;
1283        }
1284
1285        for arg in item.direct_args.as_ref().unwrap_or(&vec![]) {
1286            if !arg.used_ty.needs_type_resolution() {
1287                continue;
1288            }
1289            let slot = aggregate_slot(arg.name, "direct argument");
1290            let (type_ident, type_origin) = arg
1291                .used_ty
1292                .resolution()
1293                .expect("aggregate direct arguments should carry resolution metadata");
1294            initialize_resolved_type(
1295                graph,
1296                mapped_builtin_types,
1297                type_ident,
1298                type_origin,
1299                mapped_types,
1300                mapped_enums,
1301                mapped_extension_sqls,
1302                "Aggregate",
1303                item.full_path,
1304                &slot,
1305                arg.used_ty.full_path,
1306            )?;
1307        }
1308
1309        if let Some((type_ident, type_origin)) = item.stype.used_ty.resolution() {
1310            initialize_resolved_type(
1311                graph,
1312                mapped_builtin_types,
1313                type_ident,
1314                type_origin,
1315                mapped_types,
1316                mapped_enums,
1317                mapped_extension_sqls,
1318                "Aggregate",
1319                item.full_path,
1320                "STYPE",
1321                item.stype.used_ty.full_path,
1322            )?;
1323        }
1324
1325        if let Some(arg) = &item.mstype
1326            && let Some((type_ident, type_origin)) = arg.resolution()
1327        {
1328            initialize_resolved_type(
1329                graph,
1330                mapped_builtin_types,
1331                type_ident,
1332                type_origin,
1333                mapped_types,
1334                mapped_enums,
1335                mapped_extension_sqls,
1336                "Aggregate",
1337                item.full_path,
1338                "MSTYPE",
1339                arg.full_path,
1340            )?;
1341        }
1342
1343        mapped_aggregates.insert(item, index);
1344        build_base_edges(graph, index, root, bootstrap, finalize);
1345    }
1346    Ok(mapped_aggregates)
1347}
1348
1349fn connect_aggregate<'a>(
1350    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1351    item: &PgAggregateEntity<'a>,
1352    index: NodeIndex,
1353    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
1354    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1355    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1356    builtin_types: &HashMap<String, NodeIndex>,
1357    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
1358    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1359) -> eyre::Result<()> {
1360    make_schema_connection(
1361        graph,
1362        "Aggregate",
1363        index,
1364        &item.rust_identifier(),
1365        item.module_path,
1366        schemas,
1367    );
1368
1369    for arg in &item.args {
1370        if !arg.used_ty.needs_type_resolution() {
1371            continue;
1372        }
1373        let slot = aggregate_slot(arg.name, "argument");
1374        let (type_ident, type_origin) =
1375            arg.used_ty.resolution().expect("aggregate arguments should carry resolution metadata");
1376        connect_resolved_type(
1377            graph,
1378            index,
1379            SqlGraphRequires::ByArg,
1380            type_ident,
1381            type_origin,
1382            types,
1383            enums,
1384            builtin_types,
1385            extension_sqls,
1386            "Aggregate",
1387            item.full_path,
1388            &slot,
1389            arg.used_ty.full_path,
1390        )?;
1391    }
1392
1393    for arg in item.direct_args.as_ref().unwrap_or(&vec![]) {
1394        if !arg.used_ty.needs_type_resolution() {
1395            continue;
1396        }
1397        let slot = aggregate_slot(arg.name, "direct argument");
1398        let (type_ident, type_origin) = arg
1399            .used_ty
1400            .resolution()
1401            .expect("aggregate direct arguments should carry resolution metadata");
1402        connect_resolved_type(
1403            graph,
1404            index,
1405            SqlGraphRequires::ByArg,
1406            type_ident,
1407            type_origin,
1408            types,
1409            enums,
1410            builtin_types,
1411            extension_sqls,
1412            "Aggregate",
1413            item.full_path,
1414            &slot,
1415            arg.used_ty.full_path,
1416        )?;
1417    }
1418
1419    if let Some(arg) = &item.mstype
1420        && let Some((type_ident, type_origin)) = arg.resolution()
1421    {
1422        connect_resolved_type(
1423            graph,
1424            index,
1425            SqlGraphRequires::ByArg,
1426            type_ident,
1427            type_origin,
1428            types,
1429            enums,
1430            builtin_types,
1431            extension_sqls,
1432            "Aggregate",
1433            item.full_path,
1434            "MSTYPE",
1435            arg.full_path,
1436        )?;
1437    }
1438
1439    if let Some((type_ident, type_origin)) = item.stype.used_ty.resolution() {
1440        connect_resolved_type(
1441            graph,
1442            index,
1443            SqlGraphRequires::ByArg,
1444            type_ident,
1445            type_origin,
1446            types,
1447            enums,
1448            builtin_types,
1449            extension_sqls,
1450            "Aggregate",
1451            item.full_path,
1452            "STYPE",
1453            item.stype.used_ty.full_path,
1454        )?;
1455    }
1456
1457    make_extern_connection(
1458        graph,
1459        "Aggregate",
1460        index,
1461        &item.rust_identifier(),
1462        &(item.module_path.to_string() + "::" + item.sfunc),
1463        externs,
1464    )?;
1465
1466    if let Some(value) = item.finalfunc {
1467        make_extern_connection(
1468            graph,
1469            "Aggregate",
1470            index,
1471            &item.rust_identifier(),
1472            &(item.module_path.to_string() + "::" + value),
1473            externs,
1474        )?;
1475    }
1476    if let Some(value) = item.combinefunc {
1477        make_extern_connection(
1478            graph,
1479            "Aggregate",
1480            index,
1481            &item.rust_identifier(),
1482            &(item.module_path.to_string() + "::" + value),
1483            externs,
1484        )?;
1485    }
1486    if let Some(value) = item.serialfunc {
1487        make_extern_connection(
1488            graph,
1489            "Aggregate",
1490            index,
1491            &item.rust_identifier(),
1492            &(item.module_path.to_string() + "::" + value),
1493            externs,
1494        )?;
1495    }
1496    if let Some(value) = item.deserialfunc {
1497        make_extern_connection(
1498            graph,
1499            "Aggregate",
1500            index,
1501            &item.rust_identifier(),
1502            &(item.module_path.to_string() + "::" + value),
1503            externs,
1504        )?;
1505    }
1506    if let Some(value) = item.msfunc {
1507        make_extern_connection(
1508            graph,
1509            "Aggregate",
1510            index,
1511            &item.rust_identifier(),
1512            &(item.module_path.to_string() + "::" + value),
1513            externs,
1514        )?;
1515    }
1516    if let Some(value) = item.minvfunc {
1517        make_extern_connection(
1518            graph,
1519            "Aggregate",
1520            index,
1521            &item.rust_identifier(),
1522            &(item.module_path.to_string() + "::" + value),
1523            externs,
1524        )?;
1525    }
1526    if let Some(value) = item.mfinalfunc {
1527        make_extern_connection(
1528            graph,
1529            "Aggregate",
1530            index,
1531            &item.rust_identifier(),
1532            &(item.module_path.to_string() + "::" + value),
1533            externs,
1534        )?;
1535    }
1536    if let Some(value) = item.sortop {
1537        make_extern_connection(
1538            graph,
1539            "Aggregate",
1540            index,
1541            &item.rust_identifier(),
1542            &(item.module_path.to_string() + "::" + value),
1543            externs,
1544        )?;
1545    }
1546    Ok(())
1547}
1548
1549fn connect_aggregates<'a>(
1550    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1551    aggregates: &HashMap<PgAggregateEntity<'a>, NodeIndex>,
1552    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
1553    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1554    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1555    builtin_types: &HashMap<String, NodeIndex>,
1556    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
1557    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1558) -> eyre::Result<()> {
1559    for (item, &index) in aggregates {
1560        connect_aggregate(
1561            graph,
1562            item,
1563            index,
1564            schemas,
1565            types,
1566            enums,
1567            builtin_types,
1568            externs,
1569            extension_sqls,
1570        )?
1571    }
1572    Ok(())
1573}
1574
1575fn initialize_triggers<'a>(
1576    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1577    root: NodeIndex,
1578    bootstrap: Option<NodeIndex>,
1579    finalize: Option<NodeIndex>,
1580    triggers: Vec<PgTriggerEntity<'a>>,
1581) -> eyre::Result<HashMap<PgTriggerEntity<'a>, NodeIndex>> {
1582    let mut mapped_triggers = HashMap::default();
1583    for item in triggers {
1584        let entity: SqlGraphEntity = item.clone().into();
1585        let index = graph.add_node(entity);
1586
1587        mapped_triggers.insert(item, index);
1588        build_base_edges(graph, index, root, bootstrap, finalize);
1589    }
1590    Ok(mapped_triggers)
1591}
1592
1593fn connect_triggers<'a>(
1594    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1595    triggers: &HashMap<PgTriggerEntity<'a>, NodeIndex>,
1596    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
1597) {
1598    for (item, &index) in triggers {
1599        make_schema_connection(
1600            graph,
1601            "Trigger",
1602            index,
1603            &item.rust_identifier(),
1604            item.module_path,
1605            schemas,
1606        );
1607    }
1608}
1609
1610fn make_schema_connection<'a>(
1611    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1612    _kind: &str,
1613    index: NodeIndex,
1614    _rust_identifier: &str,
1615    module_path: &str,
1616    schemas: &HashMap<SchemaEntity<'a>, NodeIndex>,
1617) -> bool {
1618    let mut found = false;
1619    for (schema_item, &schema_index) in schemas {
1620        if module_path == schema_item.module_path {
1621            graph.add_edge(schema_index, index, SqlGraphRequires::By);
1622            found = true;
1623            break;
1624        }
1625    }
1626    found
1627}
1628
1629fn make_extern_connection<'a>(
1630    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1631    _kind: &str,
1632    index: NodeIndex,
1633    _rust_identifier: &str,
1634    full_path: &str,
1635    externs: &HashMap<PgExternEntity<'a>, NodeIndex>,
1636) -> eyre::Result<()> {
1637    match externs.iter().find(|(extern_item, _)| full_path == extern_item.full_path) {
1638        Some((_, extern_index)) => {
1639            graph.add_edge(*extern_index, index, SqlGraphRequires::By);
1640            Ok(())
1641        }
1642        None => Err(eyre!("Did not find connection `{full_path}` in {:#?}", {
1643            let mut paths = externs.keys().map(|v| v.full_path).collect::<Vec<_>>();
1644            paths.sort();
1645            paths
1646        })),
1647    }
1648}
1649
/// Qualifies `path` with `module_path` unless it already contains a `::` separator.
///
/// A bare name such as `state_fn` becomes `{module_path}::state_fn`. Any path with
/// `::` is returned unchanged.
fn resolve_function_path(module_path: &str, path: &str) -> String {
    if !path.contains("::") {
        return format!("{module_path}::{path}");
    }
    path.to_owned()
}
1653
/// Describes an aggregate slot for diagnostics.
///
/// Returns ``{kind} `{name}` `` when the slot has a name, and just `kind` when it
/// does not.
fn aggregate_slot(name: Option<&str>, kind: &str) -> String {
    match name {
        Some(slot_name) => format!("{kind} `{slot_name}`"),
        None => kind.to_owned(),
    }
}
1657
1658fn find_type_or_enum<'a>(
1659    type_ident: &str,
1660    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1661    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1662) -> Option<NodeIndex> {
1663    types
1664        .iter()
1665        .map(type_keyed)
1666        .chain(enums.iter().map(type_keyed))
1667        .find(|(ty, _)| ty.matches_type_ident(type_ident))
1668        .map(|(_, index)| *index)
1669}
1670
1671fn find_declared_type_or_enum<'a>(
1672    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1673    type_ident: &str,
1674) -> Option<NodeIndex> {
1675    extension_sqls.iter().find_map(|(item, index)| {
1676        item.creates
1677            .iter()
1678            .any(|declared| declared.matches_type_ident(type_ident))
1679            .then_some(*index)
1680    })
1681}
1682
1683fn find_graph_type_target<'a>(
1684    type_ident: &str,
1685    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1686    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1687    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1688) -> Option<NodeIndex> {
1689    find_type_or_enum(type_ident, types, enums)
1690        .or_else(|| find_declared_type_or_enum(extension_sqls, type_ident))
1691}
1692
1693fn ensure_unique_type_targets<'a>(
1694    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1695    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1696    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1697) -> eyre::Result<()> {
1698    let mut seen = BTreeMap::<String, Vec<String>>::new();
1699
1700    for item in types.keys() {
1701        seen.entry(item.type_ident.to_string())
1702            .or_default()
1703            .push(format!("type `{}`", item.full_path));
1704    }
1705
1706    for item in enums.keys() {
1707        seen.entry(item.type_ident.to_string())
1708            .or_default()
1709            .push(format!("enum `{}`", item.full_path));
1710    }
1711
1712    for item in extension_sqls.keys() {
1713        for declared in &item.creates {
1714            if let Some(type_ident) = declared.type_ident() {
1715                seen.entry(type_ident.to_string())
1716                    .or_default()
1717                    .push(format!("extension_sql `{}` ({declared})", item.name));
1718            }
1719        }
1720    }
1721
1722    for locations in seen.values_mut() {
1723        locations.sort();
1724    }
1725
1726    if let Some((type_ident, locations)) =
1727        seen.into_iter().find(|(_, locations)| locations.len() > 1)
1728    {
1729        return Err(eyre!(
1730            "type ident `{type_ident}` matched multiple SQL entities: {}",
1731            locations.join(", ")
1732        ));
1733    }
1734
1735    Ok(())
1736}
1737
1738fn unresolved_type_ident(
1739    owner_kind: &str,
1740    owner_name: &str,
1741    slot: &str,
1742    ty_name: &str,
1743    type_ident: &str,
1744) -> eyre::Report {
1745    eyre!(
1746        "{owner_kind} `{owner_name}` uses `{ty_name}` as {slot}, but type ident `{type_ident}` did not resolve. use `pgrx::pgrx_resolved_type!(T)` together with a matching `#[derive(PostgresType)]`, `#[derive(PostgresEnum)]`, or `extension_sql!(..., creates = [Type(T)]/[Enum(T)])`. for a manual mapping to an existing SQL type, set `TYPE_ORIGIN = TypeOrigin::External`."
1747    )
1748}
1749
1750fn initialize_resolved_type<'a>(
1751    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1752    builtin_types: &mut HashMap<String, NodeIndex>,
1753    type_ident: &str,
1754    type_origin: TypeOrigin,
1755    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1756    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1757    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1758    owner_kind: &str,
1759    owner_name: &str,
1760    slot: &str,
1761    ty_name: &str,
1762) -> eyre::Result<()> {
1763    if find_graph_type_target(type_ident, types, enums, extension_sqls).is_some() {
1764        return Ok(());
1765    }
1766
1767    if matches!(type_origin, TypeOrigin::External) {
1768        builtin_types
1769            .entry(type_ident.to_string())
1770            .or_insert_with(|| graph.add_node(SqlGraphEntity::BuiltinType(type_ident.to_string())));
1771        return Ok(());
1772    }
1773
1774    Err(unresolved_type_ident(owner_kind, owner_name, slot, ty_name, type_ident))
1775}
1776
1777fn connect_resolved_type<'a>(
1778    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1779    index: NodeIndex,
1780    requires: SqlGraphRequires,
1781    type_ident: &str,
1782    type_origin: TypeOrigin,
1783    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1784    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1785    builtin_types: &HashMap<String, NodeIndex>,
1786    extension_sqls: &HashMap<ExtensionSqlEntity<'a>, NodeIndex>,
1787    owner_kind: &str,
1788    owner_name: &str,
1789    slot: &str,
1790    ty_name: &str,
1791) -> eyre::Result<()> {
1792    if let Some(ty_index) = find_graph_type_target(type_ident, types, enums, extension_sqls) {
1793        graph.add_edge(ty_index, index, requires);
1794        return Ok(());
1795    }
1796
1797    if let Some(builtin_index) = builtin_types.get(type_ident) {
1798        graph.add_edge(*builtin_index, index, requires);
1799        return Ok(());
1800    }
1801
1802    if matches!(type_origin, TypeOrigin::External) {
1803        return Err(eyre!(
1804            "missing external-type placeholder for type ident `{type_ident}` while connecting {owner_kind} `{owner_name}` {slot}"
1805        ));
1806    }
1807
1808    Err(unresolved_type_ident(owner_kind, owner_name, slot, ty_name, type_ident))
1809}
1810
1811fn make_type_or_enum_connection<'a>(
1812    graph: &mut StableGraph<SqlGraphEntity<'a>, SqlGraphRequires>,
1813    index: NodeIndex,
1814    type_ident: &str,
1815    types: &HashMap<PostgresTypeEntity<'a>, NodeIndex>,
1816    enums: &HashMap<PostgresEnumEntity<'a>, NodeIndex>,
1817) -> bool {
1818    find_type_or_enum(type_ident, types, enums)
1819        .map(|ty_index| graph.add_edge(ty_index, index, SqlGraphRequires::By))
1820        .is_some()
1821}
1822
#[cfg(test)]
mod tests {
    // Unit tests for `PgrxSql::build` (graph construction and type-ident resolution)
    // and for `to_sql` (content and ordering of the rendered SQL). The helpers below
    // build entity fixtures directly, so each test controls exactly which entities
    // enter the graph.
    use super::*;
    use crate::UsedTypeEntity;
    use crate::aggregate::entity::{AggregateTypeEntity, PgAggregateEntity};
    use crate::extension_sql::entity::{ExtensionSqlEntity, SqlDeclaredTypeEntityData};
    use crate::extern_args::ExternArgs;
    use crate::metadata::{FunctionMetadataTypeEntity, Returns, SqlArrayMapping, SqlMapping};
    use crate::pg_extern::entity::{PgExternArgumentEntity, PgExternEntity, PgExternReturnEntity};
    use crate::postgres_type::entity::PostgresTypeEntity;
    use crate::schema::entity::SchemaEntity;
    use crate::to_sql::entity::ToSqlConfigEntity;

    /// Minimal control file. Every test graph uses it as its `ExtensionRoot` entity.
    fn control_file() -> ControlFile {
        ControlFile {
            comment: "test".into(),
            default_version: "1.0".into(),
            module_pathname: None,
            relocatable: false,
            superuser: true,
            schema: None,
            trusted: false,
        }
    }

    /// `ToSqlConfigEntity` with `enabled: true` and no `content`.
    fn to_sql_config() -> ToSqlConfigEntity<'static> {
        ToSqlConfigEntity { enabled: true, content: None }
    }

    /// Builds a `UsedTypeEntity` whose metadata is already resolved to `type_ident`
    /// with the given `type_origin`. It maps to the literal SQL type `sql` both as an
    /// argument and as a single return value. `ty_source` and `full_path` are both
    /// set to `full_path`.
    fn used_type(
        full_path: &'static str,
        type_ident: &'static str,
        sql: &'static str,
        type_origin: TypeOrigin,
    ) -> UsedTypeEntity<'static> {
        UsedTypeEntity {
            ty_source: full_path,
            full_path,
            composite_type: None,
            variadic: false,
            default: None,
            optional: false,
            metadata: FunctionMetadataTypeEntity::resolved(
                type_ident,
                type_origin,
                Ok(SqlMapping::literal(sql)),
                Ok(Returns::One(SqlMapping::literal(sql))),
            ),
        }
    }

    /// `used_type` with `TypeOrigin::External`. If no graph entity matches the ident,
    /// the build falls back to a builtin placeholder node.
    fn external_type(
        full_path: &'static str,
        type_ident: &'static str,
        sql: &'static str,
    ) -> UsedTypeEntity<'static> {
        used_type(full_path, type_ident, sql, TypeOrigin::External)
    }

    /// `used_type` with `TypeOrigin::ThisExtension`. The ident must be defined by a
    /// type, enum, or `extension_sql` declaration in the graph, or the build fails.
    fn extension_owned_type(
        full_path: &'static str,
        type_ident: &'static str,
        sql: &'static str,
    ) -> UsedTypeEntity<'static> {
        used_type(full_path, type_ident, sql, TypeOrigin::ThisExtension)
    }

    /// `pg_extern` fixture in module `tests` with full path `tests::{name}`. The full
    /// path is leaked (`Box::leak`) to get the `'static` string the entity stores.
    fn function_entity(
        name: &'static str,
        fn_args: Vec<PgExternArgumentEntity<'static>>,
        fn_return: PgExternReturnEntity<'static>,
    ) -> PgExternEntity<'static> {
        PgExternEntity {
            name,
            unaliased_name: name,
            module_path: "tests",
            full_path: Box::leak(format!("tests::{name}").into_boxed_str()),
            fn_args,
            fn_return,
            schema: None,
            file: "test.rs",
            line: 1,
            extern_attrs: vec![],
            search_path: None,
            operator: None,
            cast: None,
            to_sql_config: to_sql_config(),
        }
    }

    /// Aggregate fixture in module `tests` with `sfunc = "state_fn"` and every other
    /// optional clause unset. Tests that build one also include `state_function()`,
    /// presumably so that `sfunc` resolves to the extern `tests::state_fn`.
    fn aggregate_entity(
        name: &'static str,
        args: Vec<AggregateTypeEntity<'static>>,
        stype: UsedTypeEntity<'static>,
        mstype: Option<UsedTypeEntity<'static>>,
    ) -> PgAggregateEntity<'static> {
        PgAggregateEntity {
            full_path: Box::leak(format!("tests::{name}").into_boxed_str()),
            module_path: "tests",
            file: "test.rs",
            line: 1,
            name,
            ordered_set: false,
            args,
            direct_args: None,
            stype: AggregateTypeEntity { used_ty: stype, name: None },
            sfunc: "state_fn",
            finalfunc: None,
            finalfunc_modify: None,
            combinefunc: None,
            serialfunc: None,
            deserialfunc: None,
            initcond: None,
            msfunc: None,
            minvfunc: None,
            mstype,
            mfinalfunc: None,
            mfinalfunc_modify: None,
            minitcond: None,
            sortop: None,
            parallel: None,
            hypothetical: false,
            to_sql_config: to_sql_config(),
        }
    }

    /// `extension_sql` fixture whose `creates` list declares a single type built from
    /// `name`, `type_ident` and `sql`.
    ///
    /// The entity's own `sql` text is always `"CREATE TYPE custom_type;"`, whatever
    /// `sql` is. The parameter only feeds the declared-type data. A test that needs
    /// specific SQL text overwrites `.sql` after construction.
    fn declared_type_sql(
        module_path: &'static str,
        full_path: &'static str,
        declaration_name: &'static str,
        name: &'static str,
        type_ident: &'static str,
        sql: &'static str,
    ) -> ExtensionSqlEntity<'static> {
        ExtensionSqlEntity {
            module_path,
            full_path,
            sql: "CREATE TYPE custom_type;",
            file: "test.rs",
            line: 1,
            name: declaration_name,
            bootstrap: false,
            finalize: false,
            requires: vec![],
            creates: vec![SqlDeclaredEntity::Type(SqlDeclaredTypeEntityData {
                sql: sql.into(),
                name: name.into(),
                type_ident: type_ident.into(),
            })],
        }
    }

    /// Schema fixture. Graph entities whose `module_path` equals `module_path` are
    /// connected to it (see `make_schema_connection`).
    fn schema_entity(module_path: &'static str, name: &'static str) -> SchemaEntity<'static> {
        SchemaEntity { module_path, name, file: "test.rs", line: 1 }
    }

    /// `PostgresTypeEntity` fixture in module `tests` with placeholder in/out function
    /// paths and no send/receive functions.
    fn type_entity(
        name: &'static str,
        full_path: &'static str,
        type_ident: &'static str,
    ) -> PostgresTypeEntity<'static> {
        PostgresTypeEntity {
            name,
            file: "test.rs",
            line: 1,
            full_path,
            module_path: "tests",
            type_ident,
            in_fn_path: "in_fn",
            out_fn_path: "out_fn",
            receive_fn_path: None,
            send_fn_path: None,
            to_sql_config: to_sql_config(),
            alignment: None,
        }
    }

    /// The `state_fn` extern that `aggregate_entity` names as its `sfunc`.
    fn state_function() -> PgExternEntity<'static> {
        function_entity("state_fn", vec![], PgExternReturnEntity::None)
    }

    // An `External`-origin type that matches no graph entity gets a builtin
    // placeholder keyed by its type ident.
    #[test]
    fn external_function_type_resolution_succeeds() {
        let manual_text =
            used_type("tests::ManualText", "tests::ManualText", "TEXT", TypeOrigin::External);
        let function = function_entity(
            "manual_text_echo",
            vec![PgExternArgumentEntity { pattern: "value", used_ty: manual_text.clone() }],
            PgExternReturnEntity::Type { ty: manual_text.clone() },
        );

        let sql = PgrxSql::build(
            vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(function)]
                .into_iter(),
            "test".into(),
            false,
        )
        .unwrap();

        assert!(sql.builtin_types.contains_key("tests::ManualText"));
    }

    /// `ThisExtension` type whose SQL mapping is `SqlMapping::Skip` in both argument
    /// and return position.
    fn skipped_type(full_path: &'static str, type_ident: &'static str) -> UsedTypeEntity<'static> {
        UsedTypeEntity {
            ty_source: full_path,
            full_path,
            composite_type: None,
            variadic: false,
            default: None,
            optional: false,
            metadata: FunctionMetadataTypeEntity::resolved(
                type_ident,
                TypeOrigin::ThisExtension,
                Ok(SqlMapping::Skip),
                Ok(Returns::One(SqlMapping::Skip)),
            ),
        }
    }

    /// `PgHeapTuple` type with an explicit composite type `name` and `sql_only`
    /// metadata (no type ident is supplied).
    fn explicit_composite_type(name: &'static str) -> UsedTypeEntity<'static> {
        UsedTypeEntity {
            ty_source: "pgrx::heap_tuple::PgHeapTuple<'static, AllocatedByRust>",
            full_path: "pgrx::heap_tuple::PgHeapTuple<'static, AllocatedByRust>",
            composite_type: Some(name),
            variadic: false,
            default: None,
            optional: false,
            metadata: FunctionMetadataTypeEntity::sql_only(
                Ok(SqlMapping::Composite),
                Ok(Returns::One(SqlMapping::Composite)),
            ),
        }
    }

    /// Same as `explicit_composite_type`, but mapped to an array of the composite.
    fn explicit_composite_array_type(name: &'static str) -> UsedTypeEntity<'static> {
        UsedTypeEntity {
            ty_source: "pgrx::heap_tuple::PgHeapTuple<'static, AllocatedByRust>",
            full_path: "pgrx::heap_tuple::PgHeapTuple<'static, AllocatedByRust>",
            composite_type: Some(name),
            variadic: false,
            default: None,
            optional: false,
            metadata: FunctionMetadataTypeEntity::sql_only(
                Ok(SqlMapping::Array(SqlArrayMapping::Composite)),
                Ok(Returns::One(SqlMapping::Array(SqlArrayMapping::Composite))),
            ),
        }
    }

    // A type declared through `extension_sql` `creates` satisfies a `ThisExtension`
    // ident. No builtin placeholder is created, and the declaration gets an edge to
    // both the function and the aggregate that use the type.
    #[test]
    fn extension_sql_declared_type_orders_before_function_and_aggregate() {
        let custom_type = extension_owned_type("tests::HexInt", "tests::HexInt", "hexint");
        let declared_type = declared_type_sql(
            "tests",
            "tests::concrete_type",
            "concrete_type",
            "tests::HexInt",
            "tests::HexInt",
            "hexint",
        );
        let function = function_entity(
            "takes_hexint",
            vec![PgExternArgumentEntity { pattern: "value", used_ty: custom_type.clone() }],
            PgExternReturnEntity::None,
        );
        let aggregate = aggregate_entity(
            "hexint_accum",
            vec![AggregateTypeEntity { used_ty: custom_type.clone(), name: Some("value") }],
            custom_type.clone(),
            Some(custom_type.clone()),
        );
        let state_fn = state_function();

        let sql = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::CustomSql(declared_type.clone()),
                SqlGraphEntity::Function(state_fn),
                SqlGraphEntity::Function(function.clone()),
                SqlGraphEntity::Aggregate(aggregate.clone()),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .unwrap();

        let declared_index = sql.extension_sqls[&declared_type];
        let function_index = sql.externs[&function];
        let aggregate_index = sql.aggregates[&aggregate];

        assert!(!sql.builtin_types.contains_key("tests::HexInt"));
        assert!(sql.graph.find_edge(declared_index, function_index).is_some());
        assert!(sql.graph.find_edge(declared_index, aggregate_index).is_some());
    }

    // `hexint_in`/`hexint_out` use `HexInt`, and the concrete `CREATE TYPE hexint (...)`
    // declaration both creates `HexInt` and requires those functions, which forms a
    // cycle. The rendered SQL must follow the explicit `requires`: shell type first,
    // then the I/O functions, then the concrete type.
    #[test]
    fn declared_type_cycle_prefers_explicit_requirements_with_shell_type() {
        let custom_type = extension_owned_type("tests::HexInt", "tests::HexInt", "hexint");
        let text_type = external_type("alloc::string::String", "alloc::string::String", "text");

        let shell_type = ExtensionSqlEntity {
            module_path: "tests",
            full_path: "tests::shell_type",
            sql: "CREATE TYPE hexint;",
            file: "test.rs",
            line: 1,
            name: "shell_type",
            bootstrap: true,
            finalize: false,
            requires: vec![],
            creates: vec![],
        };

        let mut hexint_in = function_entity(
            "hexint_in",
            vec![],
            PgExternReturnEntity::Type { ty: custom_type.clone() },
        );
        hexint_in.extern_attrs =
            vec![ExternArgs::Requires(vec![PositioningRef::Name("shell_type".into())])];

        let mut hexint_out = function_entity(
            "hexint_out",
            vec![PgExternArgumentEntity { pattern: "value", used_ty: custom_type.clone() }],
            PgExternReturnEntity::Type { ty: text_type },
        );
        hexint_out.extern_attrs =
            vec![ExternArgs::Requires(vec![PositioningRef::Name("shell_type".into())])];

        let mut declared_type = declared_type_sql(
            "tests",
            "tests::concrete_type",
            "concrete_type",
            "tests::HexInt",
            "tests::HexInt",
            "hexint",
        );
        declared_type.sql = "CREATE TYPE hexint (\n    INPUT = hexint_in,\n    OUTPUT = hexint_out,\n    LIKE = int8\n);";
        declared_type.requires = vec![
            PositioningRef::Name("shell_type".into()),
            PositioningRef::FullPath("tests::hexint_in".into()),
            PositioningRef::FullPath("tests::hexint_out".into()),
        ];

        let sql = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::CustomSql(shell_type),
                SqlGraphEntity::CustomSql(declared_type),
                SqlGraphEntity::Function(hexint_in),
                SqlGraphEntity::Function(hexint_out),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .unwrap()
        .to_sql()
        .unwrap();

        let shell = sql.find("CREATE TYPE hexint;").unwrap();
        let input = sql.find("-- tests::hexint_in").unwrap();
        let output = sql.find("-- tests::hexint_out").unwrap();
        let concrete = sql.find("CREATE TYPE hexint (\n").unwrap();

        assert!(shell < input);
        assert!(shell < output);
        assert!(input < concrete);
        assert!(output < concrete);
    }

    // A declared type whose `extension_sql` lives in a schema module is rendered
    // schema-qualified in the aggregate's `STYPE`/`MSTYPE`.
    #[test]
    fn extension_sql_declared_type_in_custom_schema_prefixes_aggregate_state_type() {
        let custom_type = extension_owned_type("tests::HexInt", "tests::HexInt", "hexint");
        let declared_type = declared_type_sql(
            "tests::custom_schema",
            "tests::custom_schema::hexint_sql",
            "hexint_sql",
            "tests::HexInt",
            "tests::HexInt",
            "hexint",
        );
        let aggregate =
            aggregate_entity("hexint_accum", vec![], custom_type.clone(), Some(custom_type));
        let state_fn = state_function();
        let schema = schema_entity("tests::custom_schema", "custom_schema");

        let sql = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::Schema(schema),
                SqlGraphEntity::CustomSql(declared_type),
                SqlGraphEntity::Function(state_fn),
                SqlGraphEntity::Aggregate(aggregate),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .unwrap()
        .to_sql()
        .unwrap();

        assert!(sql.contains("STYPE = custom_schema.hexint"));
        assert!(sql.contains("MSTYPE = custom_schema.hexint"));
    }

    // A `SqlMapping::Skip` argument does not need to resolve, and neither its name
    // nor its Rust type appears in the rendered SQL.
    #[test]
    fn skipped_function_argument_does_not_require_schema_resolution() {
        let function = function_entity(
            "skipped_arg",
            vec![PgExternArgumentEntity {
                pattern: "virtual_arg",
                used_ty: skipped_type("tests::VirtualArg", "tests::VirtualArg"),
            }],
            PgExternReturnEntity::None,
        );

        let sql = PgrxSql::build(
            vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(function)]
                .into_iter(),
            "test".into(),
            false,
        )
        .unwrap()
        .to_sql()
        .unwrap();

        assert!(sql.contains("skipped_arg"));
        assert!(!sql.contains("virtual_arg"));
        assert!(!sql.contains("tests::VirtualArg"));
    }

    // An explicitly named composite return type skips type-ident resolution and is
    // rendered by name.
    #[test]
    fn explicit_composite_type_does_not_require_schema_resolution() {
        let dog = explicit_composite_type("Dog");
        assert!(!dog.needs_type_resolution());

        let function = function_entity("make_dog", vec![], PgExternReturnEntity::Type { ty: dog });

        let sql = PgrxSql::build(
            vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(function)]
                .into_iter(),
            "test".into(),
            false,
        )
        .unwrap()
        .to_sql()
        .unwrap();

        assert!(sql.contains("RETURNS Dog"));
    }

    // Array-of-composite variant of the test above.
    #[test]
    fn explicit_composite_array_type_does_not_require_schema_resolution() {
        let dog_pack = explicit_composite_array_type("Dog");
        assert!(!dog_pack.needs_type_resolution());

        let function =
            function_entity("make_dog_pack", vec![], PgExternReturnEntity::Type { ty: dog_pack });

        let sql = PgrxSql::build(
            vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(function)]
                .into_iter(),
            "test".into(),
            false,
        )
        .unwrap()
        .to_sql()
        .unwrap();

        assert!(sql.contains("RETURNS Dog[]"));
    }

    // Explicit composite-array aggregate state types skip resolution too, and are
    // rendered by name in `STYPE`/`MSTYPE`.
    #[test]
    fn explicit_composite_array_aggregate_state_does_not_require_schema_resolution() {
        let stype = explicit_composite_array_type("Dog");
        assert!(!stype.needs_type_resolution());
        let mstype = explicit_composite_array_type("Dog");
        assert!(!mstype.needs_type_resolution());

        let aggregate = aggregate_entity("pack_dogs", vec![], stype, Some(mstype));

        let sql = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::Function(state_function()),
                SqlGraphEntity::Aggregate(aggregate),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .unwrap()
        .to_sql()
        .unwrap();

        assert!(sql.contains("STYPE = Dog[]"));
        assert!(sql.contains("MSTYPE = Dog[]"));
    }

    // Two `PostgresType` entities that share a type ident are rejected, and the error
    // names the ident and both entities.
    #[test]
    fn duplicate_type_ident_errors() {
        let left = type_entity("LeftType", "tests::LeftType", "tests::SharedType");
        let right = type_entity("RightType", "tests::RightType", "tests::SharedType");

        let error = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::Type(left),
                SqlGraphEntity::Type(right),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .expect_err("duplicate type idents should fail");

        assert!(error.to_string().contains("tests::SharedType"));
        assert!(error.to_string().contains("tests::LeftType"));
        assert!(error.to_string().contains("tests::RightType"));
    }

    // The remaining tests check that a `ThisExtension` ident that nothing declares
    // fails the build, with an error naming the owner, the slot, and the ident.
    // This one covers a function argument.
    #[test]
    fn unresolved_function_argument_type_ident_errors() {
        let bad_type = extension_owned_type("tests::BadArg", "tests::BadArg", "TEXT");
        let function = function_entity(
            "bad_arg",
            vec![PgExternArgumentEntity { pattern: "value", used_ty: bad_type }],
            PgExternReturnEntity::None,
        );

        let error = PgrxSql::build(
            vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(function)]
                .into_iter(),
            "test".into(),
            false,
        )
        .expect_err("function argument should fail");

        assert!(error.to_string().contains("Function `tests::bad_arg`"));
        assert!(error.to_string().contains("argument `value`"));
        assert!(error.to_string().contains("tests::BadArg"));
    }

    // Unresolved function return type.
    #[test]
    fn unresolved_function_return_type_ident_errors() {
        let bad_type = extension_owned_type("tests::BadReturn", "tests::BadReturn", "TEXT");
        let function =
            function_entity("bad_return", vec![], PgExternReturnEntity::Type { ty: bad_type });

        let error = PgrxSql::build(
            vec![SqlGraphEntity::ExtensionRoot(control_file()), SqlGraphEntity::Function(function)]
                .into_iter(),
            "test".into(),
            false,
        )
        .expect_err("function return should fail");

        assert!(error.to_string().contains("Function `tests::bad_return`"));
        assert!(error.to_string().contains("return type"));
        assert!(error.to_string().contains("tests::BadReturn"));
    }

    // Unresolved aggregate argument. The state type is `External`, so only the
    // argument can be the cause.
    #[test]
    fn unresolved_aggregate_argument_type_ident_errors() {
        let aggregate = aggregate_entity(
            "bad_aggregate_arg",
            vec![AggregateTypeEntity {
                used_ty: extension_owned_type("tests::BadArg", "tests::BadArg", "TEXT"),
                name: Some("value"),
            }],
            external_type("tests::State", "tests::State", "TEXT"),
            None,
        );

        let error = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::Function(state_function()),
                SqlGraphEntity::Aggregate(aggregate),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .expect_err("aggregate argument should fail");

        assert!(error.to_string().contains("Aggregate `tests::bad_aggregate_arg`"));
        assert!(error.to_string().contains("argument `value`"));
        assert!(error.to_string().contains("tests::BadArg"));
    }

    // Unresolved aggregate `STYPE`.
    #[test]
    fn unresolved_aggregate_stype_type_ident_errors() {
        let aggregate = aggregate_entity(
            "bad_aggregate_stype",
            vec![],
            extension_owned_type("tests::BadState", "tests::BadState", "TEXT"),
            None,
        );

        let error = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::Function(state_function()),
                SqlGraphEntity::Aggregate(aggregate),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .expect_err("aggregate stype should fail");

        assert!(error.to_string().contains("Aggregate `tests::bad_aggregate_stype`"));
        assert!(error.to_string().contains("STYPE"));
        assert!(error.to_string().contains("tests::BadState"));
    }

    // Unresolved aggregate `MSTYPE`. `STYPE` is `External`, so only `MSTYPE` can be
    // the cause.
    #[test]
    fn unresolved_aggregate_mstype_type_ident_errors() {
        let aggregate = aggregate_entity(
            "bad_aggregate_mstype",
            vec![],
            external_type("tests::State", "tests::State", "TEXT"),
            Some(extension_owned_type("tests::BadMovingState", "tests::BadMovingState", "TEXT")),
        );

        let error = PgrxSql::build(
            vec![
                SqlGraphEntity::ExtensionRoot(control_file()),
                SqlGraphEntity::Function(state_function()),
                SqlGraphEntity::Aggregate(aggregate),
            ]
            .into_iter(),
            "test".into(),
            false,
        )
        .expect_err("aggregate mstype should fail");

        assert!(error.to_string().contains("Aggregate `tests::bad_aggregate_mstype`"));
        assert!(error.to_string().contains("MSTYPE"));
        assert!(error.to_string().contains("tests::BadMovingState"));
    }
}