pgrx_sql_entity_graph/extension_sql/
entity.rs1use crate::extension_sql::SqlDeclared;
20use crate::metadata::{SqlMapping, SqlTranslatable, TypeOrigin};
21use crate::pgrx_sql::PgrxSql;
22use crate::positioning_ref::PositioningRef;
23use crate::to_sql::ToSql;
24use crate::{SqlGraphEntity, SqlGraphIdentifier};
25
26use std::fmt::Display;
27
28#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
30pub struct ExtensionSqlEntity<'a> {
31 pub module_path: &'a str,
32 pub full_path: &'a str,
33 pub sql: &'a str,
34 pub file: &'a str,
35 pub line: u32,
36 pub name: &'a str,
37 pub bootstrap: bool,
38 pub finalize: bool,
39 pub requires: Vec<PositioningRef>,
40 pub creates: Vec<SqlDeclaredEntity>,
41}
42
43impl ExtensionSqlEntity<'_> {
44 pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
45 self.creates.iter().find(|created| created.has_sql_declared_entity(identifier))
46 }
47}
48
49impl<'a> From<ExtensionSqlEntity<'a>> for SqlGraphEntity<'a> {
50 fn from(val: ExtensionSqlEntity<'a>) -> Self {
51 SqlGraphEntity::CustomSql(val)
52 }
53}
54
55impl SqlGraphIdentifier for ExtensionSqlEntity<'_> {
56 fn dot_identifier(&self) -> String {
57 format!("sql {}", self.name)
58 }
59 fn rust_identifier(&self) -> String {
60 self.name.to_string()
61 }
62
63 fn file(&self) -> Option<&str> {
64 Some(self.file)
65 }
66
67 fn line(&self) -> Option<u32> {
68 Some(self.line)
69 }
70}
71
72impl ToSql for ExtensionSqlEntity<'_> {
73 fn to_sql(&self, _context: &PgrxSql) -> eyre::Result<String> {
74 let ExtensionSqlEntity { file, line, sql, creates, requires, .. } = self;
75 let creates = if !creates.is_empty() {
76 let joined = creates.iter().map(|i| format!("-- {i}")).collect::<Vec<_>>().join("\n");
77 format!(
78 "\
79 -- creates:\n\
80 {joined}\n\n"
81 )
82 } else {
83 "".to_string()
84 };
85 let requires = if !requires.is_empty() {
86 let joined =
87 requires.iter().map(|i| format!("-- {i}")).collect::<Vec<_>>().join("\n");
88 format!(
89 "\
90 -- requires:\n\
91 {joined}\n\n"
92 )
93 } else {
94 "".to_string()
95 };
96 let sql = format!(
97 "\n\
98 -- {file}:{line}\n\
99 {bootstrap}\
100 {creates}\
101 {requires}\
102 {finalize}\
103 {sql}\
104 ",
105 bootstrap = if self.bootstrap { "-- bootstrap\n" } else { "" },
106 finalize = if self.finalize { "-- finalize\n" } else { "" },
107 );
108 Ok(sql)
109 }
110}
111
/// Payload of a `Type` or `Enum` [`SqlDeclaredEntity`].
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare `sql`, then `name`,
/// then `type_ident`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SqlDeclaredTypeEntityData {
    /// SQL name of the declared type.
    pub(crate) sql: String,
    /// Identifier the declaration was made with (possibly `::`-qualified).
    pub(crate) name: String,
    /// Type identifier this declaration stands for. Equal to `name` unless the declaration
    /// was built from type metadata.
    pub(crate) type_ident: String,
}
118
/// Payload of a `Function` [`SqlDeclaredEntity`].
///
/// Field order is significant: the derived `PartialOrd`/`Ord` compare `sql`, then `name`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct SqlDeclaredFunctionEntityData {
    /// SQL name of the declared function.
    pub(crate) sql: String,
    /// Identifier the declaration was made with (possibly `::`-qualified).
    pub(crate) name: String,
}
124
125#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
126pub enum SqlDeclaredEntity {
127 Type(SqlDeclaredTypeEntityData),
128 Enum(SqlDeclaredTypeEntityData),
129 Function(SqlDeclaredFunctionEntityData),
130}
131
132impl Display for SqlDeclaredEntity {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 match self {
135 SqlDeclaredEntity::Type(data) => {
136 write!(f, "Type({})", data.name)
137 }
138 SqlDeclaredEntity::Enum(data) => {
139 write!(f, "Enum({})", data.name)
140 }
141 SqlDeclaredEntity::Function(data) => {
142 write!(f, "Function({})", data.name)
143 }
144 }
145 }
146}
147
148impl SqlDeclaredEntity {
149 pub fn build(variant: &str, name: &str) -> eyre::Result<Self> {
150 let sql = name
151 .split("::")
152 .last()
153 .ok_or_else(|| eyre::eyre!("Did not get SQL for `{}`", name))?
154 .to_string();
155 let retval = match variant {
156 "Type" => Self::Type(SqlDeclaredTypeEntityData {
157 sql,
158 name: name.to_string(),
159 type_ident: name.to_string(),
160 }),
161 "Enum" => Self::Enum(SqlDeclaredTypeEntityData {
162 sql,
163 name: name.to_string(),
164 type_ident: name.to_string(),
165 }),
166 "Function" => {
167 Self::Function(SqlDeclaredFunctionEntityData { sql, name: name.to_string() })
168 }
169 _ => {
170 return Err(eyre::eyre!(
171 "Can only declare `Type(Ident)`, `Enum(Ident)` or `Function(Ident)`"
172 ));
173 }
174 };
175 Ok(retval)
176 }
177
178 pub fn build_type<T: SqlTranslatable>(variant: &str, name: &str) -> eyre::Result<Self> {
179 let make_declared = match variant {
180 "Type" => Self::Type,
181 "Enum" => Self::Enum,
182 _ => {
183 return Err(eyre::eyre!(
184 "Can only declare `Type(Ident)` or `Enum(Ident)` with type metadata"
185 ));
186 }
187 };
188
189 if matches!(T::TYPE_ORIGIN, TypeOrigin::External) {
190 return Err(eyre::eyre!(
191 "`creates = [{variant}(...)]` is only valid for extension-owned SQL types"
192 ));
193 }
194
195 let sql = match T::argument_sql() {
196 Ok(SqlMapping::As(sql)) => sql,
197 Ok(SqlMapping::Composite | SqlMapping::Array(_)) => {
198 return Err(eyre::eyre!(
199 "`creates = [{variant}(...)]` requires a concrete SQL type name"
200 ));
201 }
202 Ok(SqlMapping::Skip) => {
203 return Err(eyre::eyre!(
204 "`creates = [{variant}(...)]` cannot use a skipped SQL type"
205 ));
206 }
207 Err(err) => return Err(err.into()),
208 };
209 let data = SqlDeclaredTypeEntityData {
210 sql,
211 name: name.to_string(),
212 type_ident: T::TYPE_IDENT.to_string(),
213 };
214 Ok(make_declared(data))
215 }
216
217 pub fn sql(&self) -> String {
218 match self {
219 SqlDeclaredEntity::Type(data) => data.sql.clone(),
220 SqlDeclaredEntity::Enum(data) => data.sql.clone(),
221 SqlDeclaredEntity::Function(data) => data.sql.clone(),
222 }
223 }
224
225 pub fn type_ident(&self) -> Option<&str> {
226 match self {
227 SqlDeclaredEntity::Type(data) | SqlDeclaredEntity::Enum(data) => {
228 Some(data.type_ident.as_str())
229 }
230 SqlDeclaredEntity::Function(_) => None,
231 }
232 }
233
234 pub fn matches_type_ident(&self, type_ident: &str) -> bool {
235 matches!(self.type_ident(), Some(value) if value == type_ident)
236 }
237
238 pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> bool {
239 match (&identifier, &self) {
240 (SqlDeclared::Type(ident_name), &SqlDeclaredEntity::Type(data))
241 | (SqlDeclared::Enum(ident_name), &SqlDeclaredEntity::Enum(data)) => {
242 if ident_name == &data.name || ident_name == &data.type_ident {
243 return true;
244 }
245 false
246 }
247 (SqlDeclared::Function(ident_name), &SqlDeclaredEntity::Function(data)) => {
248 ident_name == &data.name
249 }
250 _ => false,
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{ArgumentError, ReturnsError, ReturnsRef, SqlMappingRef, TypeOrigin};

    /// Test type owned by this extension, with the SQL name `extension_owned`.
    struct ExtensionOwnedType;
    /// Test type marked as external, mapping to `text`.
    struct ExternalType;

    // SAFETY: test-only marker types; these impls only feed `build_type` in this module and
    // are never used to pass values to or from Postgres.
    unsafe impl SqlTranslatable for ExtensionOwnedType {
        const TYPE_IDENT: &'static str = "tests::ExtensionOwnedType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::ThisExtension;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("extension_owned"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("extension_owned")));
    }

    // SAFETY: see `ExtensionOwnedType` above; this impl is likewise test-only.
    unsafe impl SqlTranslatable for ExternalType {
        const TYPE_IDENT: &'static str = "tests::ExternalType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("text"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("text")));
    }

    #[test]
    fn build_type_accepts_extension_owned_types() {
        let declared = SqlDeclaredEntity::build_type::<ExtensionOwnedType>(
            "Type",
            "tests::ExtensionOwnedType",
        )
        .unwrap();

        assert_eq!(declared.type_ident(), Some("tests::ExtensionOwnedType"));
        assert_eq!(declared.sql(), "extension_owned");
    }

    #[test]
    fn build_type_rejects_external_types() {
        let error = SqlDeclaredEntity::build_type::<ExternalType>("Type", "tests::ExternalType")
            .unwrap_err();
        assert!(error.to_string().contains("only valid for extension-owned SQL types"));

        let error = SqlDeclaredEntity::build_type::<ExternalType>("Enum", "tests::ExternalType")
            .unwrap_err();
        assert!(error.to_string().contains("only valid for extension-owned SQL types"));
    }

    #[test]
    fn build_type_rejects_non_type_variants() {
        let error = SqlDeclaredEntity::build_type::<ExtensionOwnedType>(
            "Function",
            "tests::ExtensionOwnedType",
        )
        .unwrap_err();
        assert!(error.to_string().contains("with type metadata"));
    }

    #[test]
    fn build_rejects_unknown_variants() {
        let error = SqlDeclaredEntity::build("Table", "tests::thing").unwrap_err();
        assert!(error.to_string().contains("Can only declare"));
    }

    #[test]
    fn build_uses_last_path_segment_as_sql() {
        let declared = SqlDeclaredEntity::build("Type", "tests::nested::Thing").unwrap();

        assert_eq!(declared.sql(), "Thing");
        assert_eq!(declared.type_ident(), Some("tests::nested::Thing"));
        assert_eq!(declared.to_string(), "Type(tests::nested::Thing)");
    }

    #[test]
    fn build_type_matches_by_name_or_type_ident() {
        let declared =
            SqlDeclaredEntity::build_type::<ExtensionOwnedType>("Enum", "tests::Alias").unwrap();

        assert!(declared.has_sql_declared_entity(&SqlDeclared::Enum("tests::Alias".into())));
        assert!(declared.has_sql_declared_entity(&SqlDeclared::Enum(
            "tests::ExtensionOwnedType".into()
        )));
        // A matching identifier of the wrong kind must not match.
        assert!(!declared.has_sql_declared_entity(&SqlDeclared::Type("tests::Alias".into())));
        assert!(declared.matches_type_ident("tests::ExtensionOwnedType"));
        assert_eq!(declared.to_string(), "Enum(tests::Alias)");
    }

    #[test]
    fn function_declarations_do_not_carry_type_idents() {
        let declared = SqlDeclaredEntity::build("Function", "tests::helper_fn").unwrap();

        assert_eq!(declared.type_ident(), None);
        assert_eq!(declared.sql(), "helper_fn");
        assert!(
            declared.has_sql_declared_entity(&SqlDeclared::Function("tests::helper_fn".into()))
        );
    }
}