pgrx_sql_entity_graph/extension_sql/
entity.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
12`pgrx::extension_sql!()` related entities for Rust to SQL translation
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17
18*/
19use crate::extension_sql::SqlDeclared;
20use crate::pgrx_sql::PgrxSql;
21use crate::positioning_ref::PositioningRef;
22use crate::to_sql::ToSql;
23use crate::{SqlGraphEntity, SqlGraphIdentifier};
24
25use std::fmt::Display;
26
27/// The output of a [`ExtensionSql`](crate::ExtensionSql) from `quote::ToTokens::to_tokens`.
28#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
29pub struct ExtensionSqlEntity {
30    pub module_path: &'static str,
31    pub full_path: &'static str,
32    pub sql: &'static str,
33    pub file: &'static str,
34    pub line: u32,
35    pub name: &'static str,
36    pub bootstrap: bool,
37    pub finalize: bool,
38    pub requires: Vec<PositioningRef>,
39    pub creates: Vec<SqlDeclaredEntity>,
40}
41
42impl ExtensionSqlEntity {
43    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
44        self.creates.iter().find(|created| created.has_sql_declared_entity(identifier))
45    }
46}
47
48impl From<ExtensionSqlEntity> for SqlGraphEntity {
49    fn from(val: ExtensionSqlEntity) -> Self {
50        SqlGraphEntity::CustomSql(val)
51    }
52}
53
54impl SqlGraphIdentifier for ExtensionSqlEntity {
55    fn dot_identifier(&self) -> String {
56        format!("sql {}", self.name)
57    }
58    fn rust_identifier(&self) -> String {
59        self.name.to_string()
60    }
61
62    fn file(&self) -> Option<&'static str> {
63        Some(self.file)
64    }
65
66    fn line(&self) -> Option<u32> {
67        Some(self.line)
68    }
69}
70
71impl ToSql for ExtensionSqlEntity {
72    fn to_sql(&self, _context: &PgrxSql) -> eyre::Result<String> {
73        let ExtensionSqlEntity { file, line, sql, creates, requires, .. } = self;
74        let creates = if !creates.is_empty() {
75            let joined =
76                creates.iter().map(|i| format!("--   {}", i)).collect::<Vec<_>>().join("\n");
77            format!(
78                "\
79                -- creates:\n\
80                {joined}\n\n"
81            )
82        } else {
83            "".to_string()
84        };
85        let requires = if !requires.is_empty() {
86            let joined =
87                requires.iter().map(|i| format!("--   {}", i)).collect::<Vec<_>>().join("\n");
88            format!(
89                "\
90               -- requires:\n\
91                {joined}\n\n"
92            )
93        } else {
94            "".to_string()
95        };
96        let sql = format!(
97            "\n\
98                -- {file}:{line}\n\
99                {bootstrap}\
100                {creates}\
101                {requires}\
102                {finalize}\
103                {sql}\
104                ",
105            bootstrap = if self.bootstrap { "-- bootstrap\n" } else { "" },
106            finalize = if self.finalize { "-- finalize\n" } else { "" },
107        );
108        Ok(sql)
109    }
110}
111
/// Precomputed spellings of one declared Rust item name, built by `SqlDeclaredEntity::build`.
///
/// Apart from `sql`, `name` and `pg_box`, each `String` field holds a wrapped spelling of
/// the name, named after its wrappers. They are compared against identifiers by
/// `SqlDeclaredEntity::has_sql_declared_entity`.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare fields in
/// declaration order.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SqlDeclaredEntityData {
    /// Last `::`-separated segment of `name`; returned by `SqlDeclaredEntity::sql`.
    sql: String,
    /// The name exactly as declared.
    name: String,
    option: String,
    vec: String,
    vec_option: String,
    option_vec: String,
    option_vec_option: String,
    array: String,
    option_array: String,
    varlena: String,
    /// `pgrx::pgbox::PgBox` spellings: default, `AllocatedByRust` and `AllocatedByPostgres`.
    pg_box: Vec<String>,
}
126#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
127pub enum SqlDeclaredEntity {
128    Type(SqlDeclaredEntityData),
129    Enum(SqlDeclaredEntityData),
130    Function(SqlDeclaredEntityData),
131}
132
133impl Display for SqlDeclaredEntity {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        match self {
136            SqlDeclaredEntity::Type(data) => {
137                write!(f, "Type({})", data.name)
138            }
139            SqlDeclaredEntity::Enum(data) => {
140                write!(f, "Enum({})", data.name)
141            }
142            SqlDeclaredEntity::Function(data) => {
143                write!(f, "Function({})", data.name)
144            }
145        }
146    }
147}
148
149impl SqlDeclaredEntity {
150    pub fn build(variant: impl AsRef<str>, name: impl AsRef<str>) -> eyre::Result<Self> {
151        let name = name.as_ref();
152        let data = SqlDeclaredEntityData {
153            sql: name
154                .split("::")
155                .last()
156                .ok_or_else(|| eyre::eyre!("Did not get SQL for `{}`", name))?
157                .to_string(),
158            name: name.to_string(),
159            option: format!("Option<{}>", name),
160            vec: format!("Vec<{}>", name),
161            vec_option: format!("Vec<Option<{}>>", name),
162            option_vec: format!("Option<Vec<{}>>", name),
163            option_vec_option: format!("Option<Vec<Option<{}>>", name),
164            array: format!("Array<{}>", name),
165            option_array: format!("Option<{}>", name),
166            varlena: format!("Varlena<{}>", name),
167            pg_box: vec![
168                format!("pgrx::pgbox::PgBox<{}>", name),
169                format!("pgrx::pgbox::PgBox<{}, pgrx::pgbox::AllocatedByRust>", name),
170                format!("pgrx::pgbox::PgBox<{}, pgrx::pgbox::AllocatedByPostgres>", name),
171            ],
172        };
173        let retval = match variant.as_ref() {
174            "Type" => Self::Type(data),
175            "Enum" => Self::Enum(data),
176            "Function" => Self::Function(data),
177            _ => {
178                return Err(eyre::eyre!(
179                    "Can only declare `Type(Ident)`, `Enum(Ident)` or `Function(Ident)`"
180                ))
181            }
182        };
183        Ok(retval)
184    }
185    pub fn sql(&self) -> String {
186        match self {
187            SqlDeclaredEntity::Type(data) => data.sql.clone(),
188            SqlDeclaredEntity::Enum(data) => data.sql.clone(),
189            SqlDeclaredEntity::Function(data) => data.sql.clone(),
190        }
191    }
192
193    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> bool {
194        match (&identifier, &self) {
195            (SqlDeclared::Type(ident_name), &SqlDeclaredEntity::Type(data))
196            | (SqlDeclared::Enum(ident_name), &SqlDeclaredEntity::Enum(data))
197            | (SqlDeclared::Function(ident_name), &SqlDeclaredEntity::Function(data)) => {
198                let matches = |identifier_name: &str| {
199                    identifier_name == data.name
200                        || identifier_name == data.option
201                        || identifier_name == data.vec
202                        || identifier_name == data.vec_option
203                        || identifier_name == data.option_vec
204                        || identifier_name == data.option_vec_option
205                        || identifier_name == data.array
206                        || identifier_name == data.option_array
207                        || identifier_name == data.varlena
208                };
209                if matches(ident_name) || data.pg_box.contains(ident_name) {
210                    return true;
211                }
212                // there are cases where the identifier is
213                // `core::option::Option<Foo>` while the data stores
214                // `Option<Foo>` check again for this
215                let Some(generics_start) = ident_name.find('<') else { return false };
216                let Some(qual_end) = ident_name[..generics_start].rfind("::") else { return false };
217                matches(&ident_name[qual_end + 2..])
218            }
219            _ => false,
220        }
221    }
222}