use crate::extension_sql::SqlDeclared;
use crate::metadata::{SqlMapping, SqlTranslatable, TypeOrigin};
use crate::pgrx_sql::PgrxSql;
use crate::positioning_ref::PositioningRef;
use crate::to_sql::ToSql;
use crate::{SqlGraphEntity, SqlGraphIdentifier};
use std::fmt::Display;
/// A block of hand-written SQL together with the bookkeeping needed to place it in the SQL
/// graph (it is wrapped as `SqlGraphEntity::CustomSql` when converted).
///
/// Field order is significant: the derived `Ord`/`PartialOrd`/`Hash` implementations walk the
/// fields in declaration order.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ExtensionSqlEntity<'a> {
    /// NOTE(review): presumably the Rust module path of the declaring item — not read in this
    /// file, confirm against the producer.
    pub module_path: &'a str,
    /// NOTE(review): presumably the fully qualified Rust path of the declaring item — not read
    /// in this file, confirm against the producer.
    pub full_path: &'a str,
    /// The SQL text itself; `to_sql` emits it verbatim after the generated comment header.
    pub sql: &'a str,
    /// Source file reported by `SqlGraphIdentifier::file` and printed in the `-- file:line`
    /// header comment.
    pub file: &'a str,
    /// Source line reported by `SqlGraphIdentifier::line` and printed in the `-- file:line`
    /// header comment.
    pub line: u32,
    /// Name used to build both the dot identifier (`sql <name>`) and the Rust identifier.
    pub name: &'a str,
    /// When set, `to_sql` emits a `-- bootstrap` marker comment.
    pub bootstrap: bool,
    /// When set, `to_sql` emits a `-- finalize` marker comment.
    pub finalize: bool,
    /// Positioning references listed under `-- requires:` in the generated SQL.
    pub requires: Vec<PositioningRef>,
    /// Entities this SQL declares; listed under `-- creates:` and searched by
    /// `has_sql_declared_entity`.
    pub creates: Vec<SqlDeclaredEntity>,
}
impl ExtensionSqlEntity<'_> {
    /// Returns the first entry of `creates` that matches `identifier`, or `None` when this SQL
    /// block does not declare such an entity.
    ///
    /// Matching is delegated to [`SqlDeclaredEntity::has_sql_declared_entity`].
    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> Option<&SqlDeclaredEntity> {
        for declared in &self.creates {
            if declared.has_sql_declared_entity(identifier) {
                return Some(declared);
            }
        }
        None
    }
}
impl<'a> From<ExtensionSqlEntity<'a>> for SqlGraphEntity<'a> {
fn from(val: ExtensionSqlEntity<'a>) -> Self {
SqlGraphEntity::CustomSql(val)
}
}
impl SqlGraphIdentifier for ExtensionSqlEntity<'_> {
    /// Identifier used for graph output: the entity name prefixed with `sql `.
    fn dot_identifier(&self) -> String {
        ["sql ", self.name].concat()
    }

    /// The entity name, as an owned string.
    fn rust_identifier(&self) -> String {
        String::from(self.name)
    }

    /// The recorded source file; always present for this entity kind.
    fn file(&self) -> Option<&str> {
        let file: &str = self.file;
        Some(file)
    }

    /// The recorded source line; always present for this entity kind.
    fn line(&self) -> Option<u32> {
        let line = self.line;
        Some(line)
    }
}
impl ToSql for ExtensionSqlEntity<'_> {
    /// Renders the SQL block preceded by a comment header.
    ///
    /// The output is, in order: a leading newline, a `-- file:line` comment, an optional
    /// `-- bootstrap` marker, the optional `-- creates:` and `-- requires:` listings (each
    /// followed by a blank line), an optional `-- finalize` marker, and finally the raw SQL.
    /// The context argument is not consulted, and this implementation never returns `Err`.
    fn to_sql(&self, _context: &PgrxSql) -> eyre::Result<String> {
        /// Renders `header` followed by one `-- item` comment line per entry and a trailing
        /// blank line; yields an empty string when there are no entries.
        fn comment_section<T: std::fmt::Display>(header: &str, items: &[T]) -> String {
            if items.is_empty() {
                return String::new();
            }
            let mut section = String::from(header);
            for item in items {
                section.push_str(&format!("-- {item}\n"));
            }
            section.push('\n');
            section
        }

        let mut rendered = format!("\n-- {}:{}\n", self.file, self.line);
        if self.bootstrap {
            rendered.push_str("-- bootstrap\n");
        }
        rendered.push_str(&comment_section("-- creates:\n", &self.creates));
        rendered.push_str(&comment_section("-- requires:\n", &self.requires));
        if self.finalize {
            rendered.push_str("-- finalize\n");
        }
        rendered.push_str(self.sql);
        Ok(rendered)
    }
}
/// Payload shared by the `Type` and `Enum` variants of `SqlDeclaredEntity`.
///
/// Field order is significant: the derived ordering and hashing follow declaration order.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct SqlDeclaredTypeEntityData {
    /// SQL-side name of the declared type.
    pub(crate) sql: String,
    /// Name exactly as it was supplied in the declaration.
    pub(crate) name: String,
    /// Identifier used to match this declaration against a Rust type.
    pub(crate) type_ident: String,
}
/// Payload of the `Function` variant of `SqlDeclaredEntity`.
///
/// Field order is significant: the derived ordering and hashing follow declaration order.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub struct SqlDeclaredFunctionEntityData {
    /// SQL-side name of the declared function.
    pub(crate) sql: String,
    /// Name exactly as it was supplied in the declaration.
    pub(crate) name: String,
}
/// Something a custom SQL block announces that it creates.
///
/// Variant order is significant for the derived `Ord`/`PartialOrd` implementations.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Ord, PartialOrd)]
pub enum SqlDeclaredEntity {
    /// A declared SQL type.
    Type(SqlDeclaredTypeEntityData),
    /// A declared SQL enum; carries the same payload shape as `Type`.
    Enum(SqlDeclaredTypeEntityData),
    /// A declared SQL function; carries no type identifier.
    Function(SqlDeclaredFunctionEntityData),
}
/// Formats the entity as `Variant(name)`, e.g. `Type(my::Thing)`.
impl Display for SqlDeclaredEntity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label and the declared name first so there is a single write site.
        let (label, declared_name) = match self {
            SqlDeclaredEntity::Type(data) => ("Type", &data.name),
            SqlDeclaredEntity::Enum(data) => ("Enum", &data.name),
            SqlDeclaredEntity::Function(data) => ("Function", &data.name),
        };
        write!(f, "{label}({declared_name})")
    }
}
impl SqlDeclaredEntity {
    /// Builds a declaration from a variant label and a (possibly `::`-qualified) name.
    ///
    /// The SQL name is the text after the final `::` separator of `name` (or all of `name`
    /// when it contains none). For `Type` and `Enum` the type identifier is `name` itself.
    ///
    /// # Errors
    ///
    /// Fails when `variant` is not one of `"Type"`, `"Enum"` or `"Function"`.
    pub fn build(variant: &str, name: &str) -> eyre::Result<Self> {
        // Forward `split` + `last` is kept on purpose: reverse splitting can pick a different
        // segment when separators overlap (e.g. `":::"`).
        let sql = match name.split("::").last() {
            Some(segment) => segment.to_string(),
            None => return Err(eyre::eyre!("Did not get SQL for `{}`", name)),
        };
        let declared_name = name.to_string();

        match variant {
            "Type" => Ok(Self::Type(SqlDeclaredTypeEntityData {
                sql,
                type_ident: declared_name.clone(),
                name: declared_name,
            })),
            "Enum" => Ok(Self::Enum(SqlDeclaredTypeEntityData {
                sql,
                type_ident: declared_name.clone(),
                name: declared_name,
            })),
            "Function" => {
                Ok(Self::Function(SqlDeclaredFunctionEntityData { sql, name: declared_name }))
            }
            _ => Err(eyre::eyre!(
                "Can only declare `Type(Ident)`, `Enum(Ident)` or `Function(Ident)`"
            )),
        }
    }

    /// Builds a `Type` or `Enum` declaration using the metadata of `T`.
    ///
    /// The SQL name comes from `T::argument_sql()` and the type identifier from
    /// `T::TYPE_IDENT`; `name` is stored as given.
    ///
    /// # Errors
    ///
    /// Checks run in this order, each producing an error on failure:
    /// 1. `variant` must be `"Type"` or `"Enum"`.
    /// 2. `T::TYPE_ORIGIN` must not be `TypeOrigin::External`.
    /// 3. `T::argument_sql()` must succeed and yield `SqlMapping::As`; composite, array and
    ///    skipped mappings are rejected.
    pub fn build_type<T: SqlTranslatable>(variant: &str, name: &str) -> eyre::Result<Self> {
        let declares_enum = match variant {
            "Type" => false,
            "Enum" => true,
            _ => {
                return Err(eyre::eyre!(
                    "Can only declare `Type(Ident)` or `Enum(Ident)` with type metadata"
                ));
            }
        };

        if let TypeOrigin::External = T::TYPE_ORIGIN {
            return Err(eyre::eyre!(
                "`creates = [{variant}(...)]` is only valid for extension-owned SQL types"
            ));
        }

        let sql = match T::argument_sql()? {
            SqlMapping::As(sql) => sql,
            SqlMapping::Composite | SqlMapping::Array(_) => {
                return Err(eyre::eyre!(
                    "`creates = [{variant}(...)]` requires a concrete SQL type name"
                ));
            }
            SqlMapping::Skip => {
                return Err(eyre::eyre!(
                    "`creates = [{variant}(...)]` cannot use a skipped SQL type"
                ));
            }
        };

        let data = SqlDeclaredTypeEntityData {
            sql,
            name: name.to_string(),
            type_ident: T::TYPE_IDENT.to_string(),
        };
        Ok(if declares_enum { Self::Enum(data) } else { Self::Type(data) })
    }

    /// Returns an owned copy of the SQL name stored in the variant payload.
    pub fn sql(&self) -> String {
        let sql = match self {
            Self::Type(data) | Self::Enum(data) => &data.sql,
            Self::Function(data) => &data.sql,
        };
        sql.clone()
    }

    /// Returns the type identifier for `Type` and `Enum` declarations; `None` for functions.
    pub fn type_ident(&self) -> Option<&str> {
        match self {
            Self::Function(_) => None,
            Self::Type(data) | Self::Enum(data) => Some(data.type_ident.as_str()),
        }
    }

    /// Reports whether this declaration carries exactly the given type identifier.
    ///
    /// Always `false` for function declarations, which have no type identifier.
    pub fn matches_type_ident(&self, type_ident: &str) -> bool {
        self.type_ident() == Some(type_ident)
    }

    /// Reports whether `identifier` refers to this declaration.
    ///
    /// The kinds must agree (`Type` with `Type`, `Enum` with `Enum`, `Function` with
    /// `Function`). Types and enums match on either the declared name or the type
    /// identifier; functions match on the declared name only.
    pub fn has_sql_declared_entity(&self, identifier: &SqlDeclared) -> bool {
        match (identifier, self) {
            (SqlDeclared::Type(wanted), Self::Type(data))
            | (SqlDeclared::Enum(wanted), Self::Enum(data)) => {
                wanted == &data.name || wanted == &data.type_ident
            }
            (SqlDeclared::Function(wanted), Self::Function(data)) => wanted == &data.name,
            // Mismatched kinds never match, regardless of the names involved.
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::metadata::{ArgumentError, ReturnsError, ReturnsRef, SqlMappingRef, TypeOrigin};

    /// Fixture: a type whose metadata marks it as owned by this extension.
    struct ExtensionOwnedType;

    /// Fixture: a type whose metadata marks it as external.
    struct ExternalType;

    unsafe impl SqlTranslatable for ExtensionOwnedType {
        const TYPE_IDENT: &'static str = "tests::ExtensionOwnedType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::ThisExtension;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("extension_owned"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("extension_owned")));
    }

    unsafe impl SqlTranslatable for ExternalType {
        const TYPE_IDENT: &'static str = "tests::ExternalType";
        const TYPE_ORIGIN: TypeOrigin = TypeOrigin::External;
        const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
            Ok(SqlMappingRef::literal("text"));
        const RETURN_SQL: Result<ReturnsRef, ReturnsError> =
            Ok(ReturnsRef::One(SqlMappingRef::literal("text")));
    }

    /// An extension-owned type is accepted and exposes its metadata-derived fields.
    #[test]
    fn build_type_accepts_extension_owned_types() {
        let entity = SqlDeclaredEntity::build_type::<ExtensionOwnedType>(
            "Type",
            "tests::ExtensionOwnedType",
        )
        .unwrap();

        assert_eq!(entity.type_ident(), Some("tests::ExtensionOwnedType"));
        assert_eq!(entity.sql(), "extension_owned");
    }

    /// External types are rejected for both of the type-like variants.
    #[test]
    fn build_type_rejects_external_types() {
        for variant in ["Type", "Enum"] {
            let message =
                SqlDeclaredEntity::build_type::<ExternalType>(variant, "tests::ExternalType")
                    .unwrap_err()
                    .to_string();
            assert!(message.contains("only valid for extension-owned SQL types"));
        }
    }

    /// Function declarations have no type identifier and match on their declared name.
    #[test]
    fn function_declarations_do_not_carry_type_idents() {
        let entity = SqlDeclaredEntity::build("Function", "tests::helper_fn").unwrap();
        let wanted = SqlDeclared::Function("tests::helper_fn".into());

        assert_eq!(entity.type_ident(), None);
        assert_eq!(entity.sql(), "helper_fn");
        assert!(entity.has_sql_declared_entity(&wanted));
    }
}