pgrx_sql_entity_graph/extension_sql/
entity.rs1use crate::extension_sql::SqlDeclared;
20use crate::metadata::{SqlMapping, SqlTranslatable, TypeOrigin};
21use crate::pgrx_sql::PgrxSql;
22use crate::positioning_ref::PositioningRef;
23use crate::to_sql::ToSql;
24use crate::{SqlGraphEntity, SqlGraphIdentifier};
25
26use std::fmt::Display;
27
28#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
30pub struct ExtensionSqlEntity<'a> {
31 pub module_path: &'a str,
32 pub full_path: &'a str,
33 pub sql: &'a str,
34 pub file: &'a str,
35 pub line: u32,
36 pub name: &'a str,
37 pub bootstrap: bool,
38 pub finalize: bool,
39 pub requires: Vec<PositioningRef>,
40 pub creates: Vec<SqlDeclaredEntity>,
41}
42
43impl ExtensionSqlEntity<'_> {
44 pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
45 self.creates.iter().find(|created| created.has_sql_declared_entity(identifier))
46 }
47}
48
49impl<'a> From<ExtensionSqlEntity<'a>> for SqlGraphEntity<'a> {
50 fn from(val: ExtensionSqlEntity<'a>) -> Self {
51 SqlGraphEntity::CustomSql(val)
52 }
53}
54
55impl SqlGraphIdentifier for ExtensionSqlEntity<'_> {
56 fn dot_identifier(&self) -> String {
57 format!("sql {}", self.name)
58 }
59 fn rust_identifier(&self) -> String {
60 self.name.to_string()
61 }
62
63 fn file(&self) -> Option<&str> {
64 Some(self.file)
65 }
66
67 fn line(&self) -> Option<u32> {
68 Some(self.line)
69 }
70}
71
72impl ToSql for ExtensionSqlEntity<'_> {
73 fn to_sql(&self, _context: &PgrxSql) -> eyre::Result<String> {
74 let ExtensionSqlEntity { file, line, sql, creates, requires, .. } = self;
75 let creates = if !creates.is_empty() {
76 let joined = creates.iter().map(|i| format!("-- {i}")).collect::<Vec<_>>().join("\n");
77 format!(
78 "\
79 -- creates:\n\
80 {joined}\n\n"
81 )
82 } else {
83 "".to_string()
84 };
85 let requires = if !requires.is_empty() {
86 let joined =
87 requires.iter().map(|i| format!("-- {i}")).collect::<Vec<_>>().join("\n");
88 format!(
89 "\
90 -- requires:\n\
91 {joined}\n\n"
92 )
93 } else {
94 "".to_string()
95 };
96 let sql = format!(
97 "\n\
98 -- {file}:{line}\n\
99 {bootstrap}\
100 {creates}\
101 {requires}\
102 {finalize}\
103 {sql}\
104 ",
105 bootstrap = if self.bootstrap { "-- bootstrap\n" } else { "" },
106 finalize = if self.finalize { "-- finalize\n" } else { "" },
107 );
108 Ok(sql)
109 }
110}
111
/// Payload of a `Type` or `Enum` SQL declaration.
///
/// The derived ordering compares `sql`, then `name`, then `type_ident`.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct SqlDeclaredTypeEntityData {
    /// SQL-side name of the declared type.
    pub(crate) sql: String,
    /// Name the declaration was written with.
    pub(crate) name: String,
    /// Rust type identifier the declaration can also be matched by.
    pub(crate) type_ident: String,
}
118
/// Payload of a `Function` SQL declaration.
///
/// The derived ordering compares `sql`, then `name`.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct SqlDeclaredFunctionEntityData {
    /// SQL-side name of the declared function.
    pub(crate) sql: String,
    /// Name the declaration was written with.
    pub(crate) name: String,
}
124
125#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
126pub enum SqlDeclaredEntity {
127 Type(SqlDeclaredTypeEntityData),
128 Enum(SqlDeclaredTypeEntityData),
129 Function(SqlDeclaredFunctionEntityData),
130}
131
132impl Display for SqlDeclaredEntity {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 match self {
135 Self::Type(data) => {
136 write!(f, "Type({})", data.name)
137 }
138 Self::Enum(data) => {
139 write!(f, "Enum({})", data.name)
140 }
141 Self::Function(data) => {
142 write!(f, "Function({})", data.name)
143 }
144 }
145 }
146}
147
148impl SqlDeclaredEntity {
149 pub fn build(variant: &str, name: &str) -> eyre::Result<Self> {
150 let sql = name
151 .split("::")
152 .last()
153 .ok_or_else(|| eyre::eyre!("Did not get SQL for `{}`", name))?
154 .to_string();
155 let retval = match variant {
156 "Type" => Self::Type(SqlDeclaredTypeEntityData {
157 sql,
158 name: name.to_string(),
159 type_ident: name.to_string(),
160 }),
161 "Enum" => Self::Enum(SqlDeclaredTypeEntityData {
162 sql,
163 name: name.to_string(),
164 type_ident: name.to_string(),
165 }),
166 "Function" => {
167 Self::Function(SqlDeclaredFunctionEntityData { sql, name: name.to_string() })
168 }
169 _ => {
170 return Err(eyre::eyre!(
171 "Can only declare `Type(Ident)`, `Enum(Ident)` or `Function(Ident)`"
172 ));
173 }
174 };
175 Ok(retval)
176 }
177
178 pub fn build_type<T: SqlTranslatable>(variant: &str, name: &str) -> eyre::Result<Self> {
179 let make_declared = match variant {
180 "Type" => Self::Type,
181 "Enum" => Self::Enum,
182 _ => {
183 return Err(eyre::eyre!(
184 "Can only declare `Type(Ident)` or `Enum(Ident)` with type metadata"
185 ));
186 }
187 };
188
189 if matches!(T::TYPE_ORIGIN, TypeOrigin::External) {
190 return Err(eyre::eyre!(
191 "`creates = [{variant}(...)]` is only valid for extension-owned SQL types"
192 ));
193 }
194
195 let sql = match T::argument_sql() {
196 Ok(SqlMapping::As(sql)) => sql,
197 Ok(SqlMapping::Composite | SqlMapping::Array(_)) => {
198 return Err(eyre::eyre!(
199 "`creates = [{variant}(...)]` requires a concrete SQL type name"
200 ));
201 }
202 Ok(SqlMapping::Skip) => {
203 return Err(eyre::eyre!(
204 "`creates = [{variant}(...)]` cannot use a skipped SQL type"
205 ));
206 }
207 Err(err) => return Err(err.into()),
208 };
209 let data = SqlDeclaredTypeEntityData {
210 sql,
211 name: name.to_string(),
212 type_ident: T::TYPE_IDENT.to_string(),
213 };
214 Ok(make_declared(data))
215 }
216
217 pub fn sql(&self) -> String {
218 match self {
219 Self::Type(data) => data.sql.clone(),
220 Self::Enum(data) => data.sql.clone(),
221 Self::Function(data) => data.sql.clone(),
222 }
223 }
224
225 pub fn type_ident(&self) -> Option<&str> {
226 match self {
227 Self::Type(data) | Self::Enum(data) => Some(data.type_ident.as_str()),
228 Self::Function(_) => None,
229 }
230 }
231
232 pub fn matches_type_ident(&self, type_ident: &str) -> bool {
233 matches!(self.type_ident(), Some(value) if value == type_ident)
234 }
235
236 pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> bool {
237 match (&identifier, &self) {
238 (SqlDeclared::Type(ident_name), &Self::Type(data))
239 | (SqlDeclared::Enum(ident_name), &Self::Enum(data)) => {
240 if ident_name == &data.name || ident_name == &data.type_ident {
241 return true;
242 }
243 false
244 }
245 (SqlDeclared::Function(ident_name), &Self::Function(data)) => ident_name == &data.name,
246 _ => false,
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{ArgumentError, ReturnsError, ReturnsRef, SqlMappingRef, TypeOrigin};

    /// Test type whose SQL mapping is owned by this extension.
    struct ExtensionOwnedType;
    /// Test type whose SQL mapping comes from outside this extension.
    struct ExternalType;

    // SAFETY(test): these marker types are only used to exercise metadata
    // lookups in this module; no values of them ever reach Postgres.
    unsafe impl SqlTranslatable for ExtensionOwnedType {
        const TYPE_IDENT: &'static str = "tests::ExtensionOwnedType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::ThisExtension;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("extension_owned"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("extension_owned")));
    }

    // SAFETY(test): see the note on `ExtensionOwnedType` above this impl pair.
    unsafe impl SqlTranslatable for ExternalType {
        const TYPE_IDENT: &'static str = "tests::ExternalType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("text"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("text")));
    }

    #[test]
    fn build_type_accepts_extension_owned_types() {
        let declared = SqlDeclaredEntity::build_type::<ExtensionOwnedType>(
            "Type",
            "tests::ExtensionOwnedType",
        )
        .unwrap();

        assert_eq!(declared.type_ident(), Some("tests::ExtensionOwnedType"));
        assert_eq!(declared.sql(), "extension_owned");
    }

    #[test]
    fn build_type_rejects_external_types() {
        let error = SqlDeclaredEntity::build_type::<ExternalType>("Type", "tests::ExternalType")
            .unwrap_err();
        assert!(error.to_string().contains("only valid for extension-owned SQL types"));

        let error = SqlDeclaredEntity::build_type::<ExternalType>("Enum", "tests::ExternalType")
            .unwrap_err();
        assert!(error.to_string().contains("only valid for extension-owned SQL types"));
    }

    #[test]
    fn build_type_rejects_function_variant() {
        let error = SqlDeclaredEntity::build_type::<ExtensionOwnedType>(
            "Function",
            "tests::ExtensionOwnedType",
        )
        .unwrap_err();
        assert!(error.to_string().contains("with type metadata"));
    }

    #[test]
    fn build_type_declarations_match_by_name_or_type_ident() {
        // The declared name differs from TYPE_IDENT, so both lookup paths are exercised.
        let declared =
            SqlDeclaredEntity::build_type::<ExtensionOwnedType>("Type", "ExtensionOwnedType")
                .unwrap();

        assert!(declared.has_sql_declared_entity(&SqlDeclared::Type("ExtensionOwnedType".into())));
        assert!(declared
            .has_sql_declared_entity(&SqlDeclared::Type("tests::ExtensionOwnedType".into())));
        // A matching identifier under the wrong variant must not match.
        assert!(!declared
            .has_sql_declared_entity(&SqlDeclared::Enum("tests::ExtensionOwnedType".into())));

        assert!(declared.matches_type_ident("tests::ExtensionOwnedType"));
        assert!(!declared.matches_type_ident("ExtensionOwnedType"));
    }

    #[test]
    fn build_uses_last_path_segment_as_sql() {
        let declared = SqlDeclaredEntity::build("Enum", "a::b::Color").unwrap();
        assert_eq!(declared.sql(), "Color");
        assert_eq!(declared.type_ident(), Some("a::b::Color"));

        let declared = SqlDeclaredEntity::build("Type", "Plain").unwrap();
        assert_eq!(declared.sql(), "Plain");
        assert_eq!(declared.type_ident(), Some("Plain"));
    }

    #[test]
    fn build_rejects_unknown_variants() {
        let error = SqlDeclaredEntity::build("Struct", "tests::Thing").unwrap_err();
        assert!(error.to_string().contains("Can only declare"));
    }

    #[test]
    fn function_declarations_do_not_carry_type_idents() {
        let declared = SqlDeclaredEntity::build("Function", "tests::helper_fn").unwrap();

        assert_eq!(declared.type_ident(), None);
        assert_eq!(declared.sql(), "helper_fn");
        assert!(
            declared.has_sql_declared_entity(&SqlDeclared::Function("tests::helper_fn".into()))
        );
        assert!(!declared.matches_type_ident("tests::helper_fn"));
    }
}