Skip to main content

pgrx_sql_entity_graph/extension_sql/
entity.rs

//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
//LICENSE
//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
//LICENSE
//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
//LICENSE
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
/*!

`pgrx::extension_sql!()`-related entities for Rust-to-SQL translation.

> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and highly subject to change between versions. While you may use this, please do it with caution.


*/
19use crate::extension_sql::SqlDeclared;
20use crate::metadata::{SqlMapping, SqlTranslatable, TypeOrigin};
21use crate::pgrx_sql::PgrxSql;
22use crate::positioning_ref::PositioningRef;
23use crate::to_sql::ToSql;
24use crate::{SqlGraphEntity, SqlGraphIdentifier};
25
26use std::fmt::Display;
27
/// The output of an [`ExtensionSql`](crate::ExtensionSql) from `quote::ToTokens::to_tokens`.
///
/// Describes one `pgrx::extension_sql!()` block: the raw SQL it emits, where it was declared,
/// and the `bootstrap`/`finalize`/`requires`/`creates` metadata that [`ToSql::to_sql`] renders
/// into the comment header preceding that SQL.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare fields top to bottom.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ExtensionSqlEntity<'a> {
    // NOTE(review): presumably the Rust module path of the macro invocation — confirm.
    pub module_path: &'a str,
    // NOTE(review): presumably the fully-qualified path of this block — confirm.
    pub full_path: &'a str,
    /// SQL text emitted verbatim after the comment header.
    pub sql: &'a str,
    /// Source file of the declaration; rendered as the `-- {file}:{line}` header.
    pub file: &'a str,
    /// Source line of the declaration; rendered as the `-- {file}:{line}` header.
    pub line: u32,
    /// Block name; used as the Rust identifier and in the `sql {name}` dot identifier.
    pub name: &'a str,
    /// When `true`, a `-- bootstrap` marker is emitted in the header.
    pub bootstrap: bool,
    /// When `true`, a `-- finalize` marker is emitted in the header.
    pub finalize: bool,
    /// Entities this block requires; listed under `-- requires:` in the header.
    pub requires: Vec<PositioningRef>,
    /// Entities this block declares it creates; listed under `-- creates:` in the header.
    pub creates: Vec<SqlDeclaredEntity>,
}
42
43impl ExtensionSqlEntity<'_> {
44    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
45        self.creates.iter().find(|created| created.has_sql_declared_entity(identifier))
46    }
47}
48
49impl<'a> From<ExtensionSqlEntity<'a>> for SqlGraphEntity<'a> {
50    fn from(val: ExtensionSqlEntity<'a>) -> Self {
51        SqlGraphEntity::CustomSql(val)
52    }
53}
54
55impl SqlGraphIdentifier for ExtensionSqlEntity<'_> {
56    fn dot_identifier(&self) -> String {
57        format!("sql {}", self.name)
58    }
59    fn rust_identifier(&self) -> String {
60        self.name.to_string()
61    }
62
63    fn file(&self) -> Option<&str> {
64        Some(self.file)
65    }
66
67    fn line(&self) -> Option<u32> {
68        Some(self.line)
69    }
70}
71
72impl ToSql for ExtensionSqlEntity<'_> {
73    fn to_sql(&self, _context: &PgrxSql) -> eyre::Result<String> {
74        let ExtensionSqlEntity { file, line, sql, creates, requires, .. } = self;
75        let creates = if !creates.is_empty() {
76            let joined = creates.iter().map(|i| format!("--   {i}")).collect::<Vec<_>>().join("\n");
77            format!(
78                "\
79                -- creates:\n\
80                {joined}\n\n"
81            )
82        } else {
83            "".to_string()
84        };
85        let requires = if !requires.is_empty() {
86            let joined =
87                requires.iter().map(|i| format!("--   {i}")).collect::<Vec<_>>().join("\n");
88            format!(
89                "\
90               -- requires:\n\
91                {joined}\n\n"
92            )
93        } else {
94            "".to_string()
95        };
96        let sql = format!(
97            "\n\
98                -- {file}:{line}\n\
99                {bootstrap}\
100                {creates}\
101                {requires}\
102                {finalize}\
103                {sql}\
104                ",
105            bootstrap = if self.bootstrap { "-- bootstrap\n" } else { "" },
106            finalize = if self.finalize { "-- finalize\n" } else { "" },
107        );
108        Ok(sql)
109    }
110}
111
/// Payload of a `Type` or `Enum` entry in an `extension_sql!()` block's `creates = [...]` list.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare fields top to bottom.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct SqlDeclaredTypeEntityData {
    /// SQL name of the type. `SqlDeclaredEntity::build` uses the last `::` segment of `name`;
    /// `SqlDeclaredEntity::build_type` uses the type's `argument_sql()` mapping.
    pub(crate) sql: String,
    /// Name as written in the declaration.
    pub(crate) name: String,
    /// Rust type identifier. `build` uses `name` itself; `build_type` uses `T::TYPE_IDENT`.
    pub(crate) type_ident: String,
}
118
/// Payload of a `Function` entry in an `extension_sql!()` block's `creates = [...]` list.
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare fields top to bottom.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct SqlDeclaredFunctionEntityData {
    /// SQL name of the function: the last `::` segment of `name`.
    pub(crate) sql: String,
    /// Name (Rust path) as written in the declaration.
    pub(crate) name: String,
}
124
/// An entity that an `extension_sql!()` block declares it creates (`creates = [...]`).
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` order variants by
/// declaration order before comparing their payloads.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub enum SqlDeclaredEntity {
    /// Declared with `Type(...)`.
    Type(SqlDeclaredTypeEntityData),
    /// Declared with `Enum(...)`.
    Enum(SqlDeclaredTypeEntityData),
    /// Declared with `Function(...)`.
    Function(SqlDeclaredFunctionEntityData),
}
131
132impl Display for SqlDeclaredEntity {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        match self {
135            SqlDeclaredEntity::Type(data) => {
136                write!(f, "Type({})", data.name)
137            }
138            SqlDeclaredEntity::Enum(data) => {
139                write!(f, "Enum({})", data.name)
140            }
141            SqlDeclaredEntity::Function(data) => {
142                write!(f, "Function({})", data.name)
143            }
144        }
145    }
146}
147
148impl SqlDeclaredEntity {
149    pub fn build(variant: &str, name: &str) -> eyre::Result<Self> {
150        let sql = name
151            .split("::")
152            .last()
153            .ok_or_else(|| eyre::eyre!("Did not get SQL for `{}`", name))?
154            .to_string();
155        let retval = match variant {
156            "Type" => Self::Type(SqlDeclaredTypeEntityData {
157                sql,
158                name: name.to_string(),
159                type_ident: name.to_string(),
160            }),
161            "Enum" => Self::Enum(SqlDeclaredTypeEntityData {
162                sql,
163                name: name.to_string(),
164                type_ident: name.to_string(),
165            }),
166            "Function" => {
167                Self::Function(SqlDeclaredFunctionEntityData { sql, name: name.to_string() })
168            }
169            _ => {
170                return Err(eyre::eyre!(
171                    "Can only declare `Type(Ident)`, `Enum(Ident)` or `Function(Ident)`"
172                ));
173            }
174        };
175        Ok(retval)
176    }
177
178    pub fn build_type<T: SqlTranslatable>(variant: &str, name: &str) -> eyre::Result<Self> {
179        let make_declared = match variant {
180            "Type" => Self::Type,
181            "Enum" => Self::Enum,
182            _ => {
183                return Err(eyre::eyre!(
184                    "Can only declare `Type(Ident)` or `Enum(Ident)` with type metadata"
185                ));
186            }
187        };
188
189        if matches!(T::TYPE_ORIGIN, TypeOrigin::External) {
190            return Err(eyre::eyre!(
191                "`creates = [{variant}(...)]` is only valid for extension-owned SQL types"
192            ));
193        }
194
195        let sql = match T::argument_sql() {
196            Ok(SqlMapping::As(sql)) => sql,
197            Ok(SqlMapping::Composite | SqlMapping::Array(_)) => {
198                return Err(eyre::eyre!(
199                    "`creates = [{variant}(...)]` requires a concrete SQL type name"
200                ));
201            }
202            Ok(SqlMapping::Skip) => {
203                return Err(eyre::eyre!(
204                    "`creates = [{variant}(...)]` cannot use a skipped SQL type"
205                ));
206            }
207            Err(err) => return Err(err.into()),
208        };
209        let data = SqlDeclaredTypeEntityData {
210            sql,
211            name: name.to_string(),
212            type_ident: T::TYPE_IDENT.to_string(),
213        };
214        Ok(make_declared(data))
215    }
216
217    pub fn sql(&self) -> String {
218        match self {
219            SqlDeclaredEntity::Type(data) => data.sql.clone(),
220            SqlDeclaredEntity::Enum(data) => data.sql.clone(),
221            SqlDeclaredEntity::Function(data) => data.sql.clone(),
222        }
223    }
224
225    pub fn type_ident(&self) -> Option<&str> {
226        match self {
227            SqlDeclaredEntity::Type(data) | SqlDeclaredEntity::Enum(data) => {
228                Some(data.type_ident.as_str())
229            }
230            SqlDeclaredEntity::Function(_) => None,
231        }
232    }
233
234    pub fn matches_type_ident(&self, type_ident: &str) -> bool {
235        matches!(self.type_ident(), Some(value) if value == type_ident)
236    }
237
238    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> bool {
239        match (&identifier, &self) {
240            (SqlDeclared::Type(ident_name), &SqlDeclaredEntity::Type(data))
241            | (SqlDeclared::Enum(ident_name), &SqlDeclaredEntity::Enum(data)) => {
242                if ident_name == &data.name || ident_name == &data.type_ident {
243                    return true;
244                }
245                false
246            }
247            (SqlDeclared::Function(ident_name), &SqlDeclaredEntity::Function(data)) => {
248                ident_name == &data.name
249            }
250            _ => false,
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{ArgumentError, ReturnsError, ReturnsRef, SqlMappingRef, TypeOrigin};

    /// Fixture owned by this extension; maps to the SQL type `extension_owned`.
    struct ExtensionOwnedType;
    /// Fixture with an external origin; maps to the SQL type `text`.
    struct ExternalType;

    unsafe impl SqlTranslatable for ExtensionOwnedType {
        const TYPE_IDENT: &'static str = "tests::ExtensionOwnedType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::ThisExtension;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("extension_owned"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("extension_owned")));
    }

    unsafe impl SqlTranslatable for ExternalType {
        const TYPE_IDENT: &'static str = "tests::ExternalType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("text"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("text")));
    }

    #[test]
    fn build_type_accepts_extension_owned_types() {
        let declared = SqlDeclaredEntity::build_type::<ExtensionOwnedType>(
            "Type",
            "tests::ExtensionOwnedType",
        )
        .unwrap();

        assert_eq!(declared.type_ident(), Some("tests::ExtensionOwnedType"));
        assert_eq!(declared.sql(), "extension_owned");
    }

    #[test]
    fn build_type_rejects_external_types() {
        let error = SqlDeclaredEntity::build_type::<ExternalType>("Type", "tests::ExternalType")
            .unwrap_err();
        assert!(error.to_string().contains("only valid for extension-owned SQL types"));

        let error = SqlDeclaredEntity::build_type::<ExternalType>("Enum", "tests::ExternalType")
            .unwrap_err();
        assert!(error.to_string().contains("only valid for extension-owned SQL types"));
    }

    #[test]
    fn build_type_rejects_function_declarations() {
        let error = SqlDeclaredEntity::build_type::<ExtensionOwnedType>(
            "Function",
            "tests::helper_fn",
        )
        .unwrap_err();
        assert!(error.to_string().contains("with type metadata"));
    }

    #[test]
    fn build_rejects_unknown_variants() {
        let error = SqlDeclaredEntity::build("Trait", "tests::Thing").unwrap_err();
        assert!(error.to_string().contains("Can only declare"));
    }

    #[test]
    fn function_declarations_do_not_carry_type_idents() {
        let declared = SqlDeclaredEntity::build("Function", "tests::helper_fn").unwrap();

        assert_eq!(declared.type_ident(), None);
        assert_eq!(declared.sql(), "helper_fn");
        assert!(
            declared.has_sql_declared_entity(&SqlDeclared::Function("tests::helper_fn".into()))
        );
    }

    // `build_type` records the declared name and `T::TYPE_IDENT` separately; lookups must
    // succeed through either, but only for the same declaration kind.
    #[test]
    fn type_declarations_match_by_name_or_type_ident() {
        let declared =
            SqlDeclaredEntity::build_type::<ExtensionOwnedType>("Enum", "alias::Owned").unwrap();

        assert!(declared.has_sql_declared_entity(&SqlDeclared::Enum("alias::Owned".into())));
        assert!(declared
            .has_sql_declared_entity(&SqlDeclared::Enum("tests::ExtensionOwnedType".into())));
        assert!(!declared.has_sql_declared_entity(&SqlDeclared::Type("alias::Owned".into())));
        assert!(declared.matches_type_ident("tests::ExtensionOwnedType"));
        assert!(!declared.matches_type_ident("alias::Owned"));
    }

    #[test]
    fn display_uses_kind_and_declared_name() {
        let declared = SqlDeclaredEntity::build("Function", "tests::helper_fn").unwrap();
        assert_eq!(declared.to_string(), "Function(tests::helper_fn)");
    }
}