// pgx_sql_entity_graph/extension_sql/entity.rs

/*
Portions Copyright 2019-2021 ZomboDB, LLC.
Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>

All rights reserved.

Use of this source code is governed by the MIT license that can be found in the LICENSE file.
*/
/*!

`pgx::extension_sql!()` related entities for Rust to SQL translation

> Like all of the [`sql_entity_graph`][crate::pgx_sql_entity_graph] APIs, this is considered **internal**
to the `pgx` framework and highly subject to change between versions. While you may use this, please do so with caution.


*/
18use crate::extension_sql::SqlDeclared;
19use crate::pgx_sql::PgxSql;
20use crate::positioning_ref::PositioningRef;
21use crate::to_sql::ToSql;
22use crate::{SqlGraphEntity, SqlGraphIdentifier};
23
24use std::fmt::Display;
25
26/// The output of a [`ExtensionSql`](crate::ExtensionSql) from `quote::ToTokens::to_tokens`.
27#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
28pub struct ExtensionSqlEntity {
29    pub module_path: &'static str,
30    pub full_path: &'static str,
31    pub sql: &'static str,
32    pub file: &'static str,
33    pub line: u32,
34    pub name: &'static str,
35    pub bootstrap: bool,
36    pub finalize: bool,
37    pub requires: Vec<PositioningRef>,
38    pub creates: Vec<SqlDeclaredEntity>,
39}
40
41impl ExtensionSqlEntity {
42    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
43        self.creates.iter().find(|created| created.has_sql_declared_entity(identifier))
44    }
45}
46
47impl From<ExtensionSqlEntity> for SqlGraphEntity {
48    fn from(val: ExtensionSqlEntity) -> Self {
49        SqlGraphEntity::CustomSql(val)
50    }
51}
52
53impl SqlGraphIdentifier for ExtensionSqlEntity {
54    fn dot_identifier(&self) -> String {
55        format!("sql {}", self.name)
56    }
57    fn rust_identifier(&self) -> String {
58        self.name.to_string()
59    }
60
61    fn file(&self) -> Option<&'static str> {
62        Some(self.file)
63    }
64
65    fn line(&self) -> Option<u32> {
66        Some(self.line)
67    }
68}
69
70impl ToSql for ExtensionSqlEntity {
71    fn to_sql(&self, _context: &PgxSql) -> eyre::Result<String> {
72        let sql = format!(
73            "\n\
74                -- {file}:{line}\n\
75                {bootstrap}\
76                {creates}\
77                {requires}\
78                {finalize}\
79                {sql}\
80                ",
81            file = self.file,
82            line = self.line,
83            bootstrap = if self.bootstrap { "-- bootstrap\n" } else { "" },
84            creates = if !self.creates.is_empty() {
85                format!(
86                    "\
87                    -- creates:\n\
88                    {}\n\
89                ",
90                    self.creates
91                        .iter()
92                        .map(|i| format!("--   {}", i))
93                        .collect::<Vec<_>>()
94                        .join("\n")
95                ) + "\n"
96            } else {
97                "".to_string()
98            },
99            requires = if !self.requires.is_empty() {
100                format!(
101                    "\
102                   -- requires:\n\
103                    {}\n\
104                ",
105                    self.requires
106                        .iter()
107                        .map(|i| format!("--   {}", i))
108                        .collect::<Vec<_>>()
109                        .join("\n")
110                ) + "\n"
111            } else {
112                "".to_string()
113            },
114            finalize = if self.finalize { "-- finalize\n" } else { "" },
115            sql = self.sql,
116        );
117        Ok(sql)
118    }
119}
120
/// Precomputed spellings of a declared Rust name, filled in by `SqlDeclaredEntity::build`.
///
/// `sql` is the SQL-side name and `name` the Rust path as given. The remaining fields
/// hold `name` wrapped in container types; `SqlDeclaredEntity::has_sql_declared_entity`
/// compares incoming identifiers against them.
///
/// Field order is significant for the derived `PartialOrd`/`Ord`/`Hash` impls.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SqlDeclaredEntityData {
    /// Last `::` segment of `name`.
    sql: String,
    /// The name exactly as passed to `build`.
    name: String,
    // Wrapped spellings of `name`, one per supported container shape.
    option: String,
    vec: String,
    vec_option: String,
    option_vec: String,
    option_vec_option: String,
    array: String,
    option_array: String,
    varlena: String,
    /// `pgx::pgbox::PgBox<..>` spellings, with and without an allocator marker.
    pg_box: Vec<String>,
}
135#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
136pub enum SqlDeclaredEntity {
137    Type(SqlDeclaredEntityData),
138    Enum(SqlDeclaredEntityData),
139    Function(SqlDeclaredEntityData),
140}
141
142impl Display for SqlDeclaredEntity {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        match self {
145            SqlDeclaredEntity::Type(data) => {
146                f.write_str(&(String::from("Type(") + &data.name + ")"))
147            }
148            SqlDeclaredEntity::Enum(data) => {
149                f.write_str(&(String::from("Enum(") + &data.name + ")"))
150            }
151            SqlDeclaredEntity::Function(data) => {
152                f.write_str(&(String::from("Function ") + &data.name + ")"))
153            }
154        }
155    }
156}
157
158impl SqlDeclaredEntity {
159    pub fn build(variant: impl AsRef<str>, name: impl AsRef<str>) -> eyre::Result<Self> {
160        let name = name.as_ref();
161        let data = SqlDeclaredEntityData {
162            sql: name
163                .split("::")
164                .last()
165                .ok_or_else(|| eyre::eyre!("Did not get SQL for `{}`", name))?
166                .to_string(),
167            name: name.to_string(),
168            option: format!("Option<{}>", name),
169            vec: format!("Vec<{}>", name),
170            vec_option: format!("Vec<Option<{}>>", name),
171            option_vec: format!("Option<Vec<{}>>", name),
172            option_vec_option: format!("Option<Vec<Option<{}>>", name),
173            array: format!("Array<{}>", name),
174            option_array: format!("Option<{}>", name),
175            varlena: format!("Varlena<{}>", name),
176            pg_box: vec![
177                format!("pgx::pgbox::PgBox<{}>", name),
178                format!("pgx::pgbox::PgBox<{}, pgx::pgbox::AllocatedByRust>", name),
179                format!("pgx::pgbox::PgBox<{}, pgx::pgbox::AllocatedByPostgres>", name),
180            ],
181        };
182        let retval = match variant.as_ref() {
183            "Type" => Self::Type(data),
184            "Enum" => Self::Enum(data),
185            "Function" => Self::Function(data),
186            _ => {
187                return Err(eyre::eyre!(
188                    "Can only declare `Type(Ident)`, `Enum(Ident)` or `Function(Ident)`"
189                ))
190            }
191        };
192        Ok(retval)
193    }
194    pub fn sql(&self) -> String {
195        match self {
196            SqlDeclaredEntity::Type(data) => data.sql.clone(),
197            SqlDeclaredEntity::Enum(data) => data.sql.clone(),
198            SqlDeclaredEntity::Function(data) => data.sql.clone(),
199        }
200    }
201
202    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> bool {
203        match (&identifier, &self) {
204            (SqlDeclared::Type(identifier_name), &SqlDeclaredEntity::Type(data))
205            | (SqlDeclared::Enum(identifier_name), &SqlDeclaredEntity::Enum(data))
206            | (SqlDeclared::Function(identifier_name), &SqlDeclaredEntity::Function(data)) => {
207                let matches = |identifier_name: &str| {
208                    identifier_name == &data.name
209                        || identifier_name == &data.option
210                        || identifier_name == &data.vec
211                        || identifier_name == &data.vec_option
212                        || identifier_name == &data.option_vec
213                        || identifier_name == &data.option_vec_option
214                        || identifier_name == &data.array
215                        || identifier_name == &data.option_array
216                        || identifier_name == &data.varlena
217                };
218                if matches(&*identifier_name) || data.pg_box.contains(&identifier_name) {
219                    return true;
220                }
221                // there are cases where the identifier is
222                // `core::option::Option<Foo>` while the data stores
223                // `Option<Foo>` check again for this
224                let generics_start = match identifier_name.find('<') {
225                    None => return false,
226                    Some(idx) => idx,
227                };
228                let qualification_end = match identifier_name[..generics_start].rfind("::") {
229                    None => return false,
230                    Some(idx) => idx,
231                };
232                matches(&identifier_name[qualification_end + 2..])
233            }
234            _ => false,
235        }
236    }
237}