use crate::pgrx_attribute::{ArgValue, PgrxArg, PgrxAttribute};
use crate::pgrx_sql::PgrxSql;
use crate::to_sql::ToSql;
use crate::to_sql::entity::ToSqlConfigEntity;
use crate::{SqlGraphEntity, SqlGraphIdentifier, TypeMatch};
use eyre::eyre;
use proc_macro2::TokenStream;
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
use syn::spanned::Spanned;
use syn::{AttrStyle, Attribute, Lit};
/// Whether a type opted in to explicit alignment via `#[pgrx(alignment = "...")]`.
///
/// Variant order is significant: the derived ordering places `On` before `Off`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Alignment {
    /// Written as `alignment = "on"` in the attribute.
    On,
    /// Written as `alignment = "off"`; also the fallback when no attribute sets it.
    Off,
}
/// Error text reported whenever the value given to `alignment` is not the string
/// literal `"on"` or `"off"`.
const INVALID_ATTR_CONTENT: &str =
    "expected `#[pgrx(alignment = align)]`, where `align` is \"on\", or \"off\"";
/// Emits the fully qualified path of the matching variant
/// (`::pgrx::pgrx_sql_entity_graph::Alignment::On` / `::Off`) into a token stream.
impl ToTokens for Alignment {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        // Pick the variant name first, then turn it into an identifier for interpolation.
        let variant_name = match self {
            Alignment::On => "On",
            Alignment::Off => "Off",
        };
        let variant = format_ident!("{}", variant_name);
        tokens.append_all(quote! {
            ::pgrx::pgrx_sql_entity_graph::Alignment::#variant
        });
    }
}
impl Alignment {
pub fn from_attribute(attr: &Attribute) -> Result<Option<Self>, syn::Error> {
if attr.style != AttrStyle::Outer {
return Err(syn::Error::new(
attr.span(),
"#[pgrx(alignment = ..)] is only valid in an outer context",
));
}
let attr = attr.parse_args::<PgrxAttribute>()?;
for arg in attr.args.iter() {
let PgrxArg::NameValue(nv) = arg;
if !nv.path.is_ident("alignment") {
continue;
}
return match nv.value {
ArgValue::Lit(Lit::Str(ref s)) => match s.value().as_ref() {
"on" => Ok(Some(Self::On)),
"off" => Ok(Some(Self::Off)),
_ => Err(syn::Error::new(s.span(), INVALID_ATTR_CONTENT)),
},
ArgValue::Path(ref p) => Err(syn::Error::new(p.span(), INVALID_ATTR_CONTENT)),
ArgValue::Lit(ref l) => Err(syn::Error::new(l.span(), INVALID_ATTR_CONTENT)),
};
}
Ok(None)
}
pub fn from_attributes(attrs: &[Attribute]) -> Result<Self, syn::Error> {
for attr in attrs {
if attr.path().is_ident("pgrx")
&& let Some(v) = Self::from_attribute(attr)?
{
return Ok(v);
}
}
Ok(Self::Off)
}
}
/// Metadata describing one custom Postgres type, used to generate its
/// `CREATE TYPE` SQL.
///
/// Field order is significant for the derived comparison and hashing impls.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PostgresTypeEntity<'a> {
    /// Type name as written in the generated `CREATE TYPE` statements.
    pub name: &'a str,
    /// File name emitted in the `-- file:line` comment above the generated SQL.
    pub file: &'a str,
    /// Line number emitted in the `-- file:line` comment above the generated SQL.
    pub line: u32,
    /// Full path of the type; used as its graph identifier and echoed in SQL comments.
    pub full_path: &'a str,
    /// Module path prepended to function paths that contain no `::`.
    pub module_path: &'a str,
    /// Identifier returned by the `TypeMatch` implementation.
    pub type_ident: &'a str,
    /// Path of the function used for the type's `INPUT` clause.
    pub in_fn_path: &'a str,
    /// Path of the function used for the type's `OUTPUT` clause.
    pub out_fn_path: &'a str,
    /// Optional path of the `RECEIVE` function; only emitted when a send function is also set.
    pub receive_fn_path: Option<&'a str>,
    /// Optional path of the `SEND` function; only emitted when a receive function is also set.
    pub send_fn_path: Option<&'a str>,
    /// SQL generation configuration for this entity.
    // NOTE(review): not consumed by the code visible in this file — confirm where it is used.
    pub to_sql_config: ToSqlConfigEntity<'a>,
    /// Requested alignment; SQL generation accepts only 1, 2, 4 or 8 here.
    // NOTE(review): presumably a byte count — confirm against the producer of this value.
    pub alignment: Option<usize>,
}
/// Exposes the stored `type_ident` field through the `TypeMatch` trait.
impl TypeMatch for PostgresTypeEntity<'_> {
    /// Returns the `type_ident` field unchanged.
    fn type_ident(&self) -> &str {
        let ident: &str = self.type_ident;
        ident
    }
}
impl<'a> From<PostgresTypeEntity<'a>> for SqlGraphEntity<'a> {
fn from(val: PostgresTypeEntity<'a>) -> Self {
SqlGraphEntity::Type(val)
}
}
/// Identifies a type entity by its full path and the file and line stored on it.
impl SqlGraphIdentifier for PostgresTypeEntity<'_> {
    /// Returns `"type "` followed by the full path.
    fn dot_identifier(&self) -> String {
        let mut label = String::from("type ");
        label.push_str(self.full_path);
        label
    }

    /// Returns the full path as an owned string.
    fn rust_identifier(&self) -> String {
        String::from(self.full_path)
    }

    /// Returns the stored file name; always `Some` for this entity.
    fn file(&self) -> Option<&str> {
        let file: &str = self.file;
        Some(file)
    }

    /// Returns the stored line number; always `Some` for this entity.
    fn line(&self) -> Option<u32> {
        let line: u32 = self.line;
        Some(line)
    }
}
/// SQL generation for a custom Postgres type.
///
/// The emitted text is, in order: a shell `CREATE TYPE name;`, the SQL of the input
/// function, the SQL of the output function, the SQL of the receive/send functions
/// (only when both are configured), and finally the full `CREATE TYPE name (...)`.
impl ToSql for PostgresTypeEntity<'_> {
/// Renders the SQL for this type using its node in `context.graph`.
///
/// # Errors
/// Returns an error when the graph node for this type is not a
/// `SqlGraphEntity::Type`, when a referenced function is not found in
/// `context.externs`, when it is not a function node neighboring this type in the
/// graph, or when rendering a function's own SQL fails.
///
/// # Panics
/// Panics when `alignment` is set to a value that is not a power of two, or that
/// is not one of 1, 2, 4 or 8.
// NOTE(review): `context.types[self]` presumably also panics when this entity was
// never registered in `context.types` — confirm against the container's `Index` impl.
fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
// Fetch this type's graph node and destructure the entity stored there; the fields
// used below come from the graph's copy of the entity.
let self_index = context.types[self];
let item_node = &context.graph[self_index];
let SqlGraphEntity::Type(PostgresTypeEntity {
name,
file,
line,
module_path,
full_path,
in_fn_path,
out_fn_path,
receive_fn_path,
send_fn_path,
alignment,
..
}) = item_node
else {
return Err(eyre!("Was not called on a Type. Got: {:?}", item_node));
};
// INPUT function: a path without `::` is resolved relative to the type's module.
let in_fn_path = resolve_function_path(module_path, in_fn_path);
let in_fn = function_name(&in_fn_path);
// The lookup result is discarded; this only verifies the function is a known extern.
let (_, _index) = context
.externs
.iter()
.find(|(k, _v)| k.full_path == in_fn_path)
.ok_or_else(|| eyre::eyre!("Did not find `in_fn: {}`.", in_fn_path))?;
// Find the function among this type's graph neighbors; on failure the error lists
// every neighboring function to aid debugging.
let (in_fn_graph_index, in_fn_entity) = context
.graph
.neighbors_undirected(self_index)
.find_map(|neighbor| match &context.graph[neighbor] {
SqlGraphEntity::Function(func) if func.full_path == in_fn_path => {
Some((neighbor, func))
}
_ => None,
})
.ok_or_else(|| {
let neighbors = context
.graph
.neighbors_undirected(self_index)
.filter_map(|neighbor| match &context.graph[neighbor] {
SqlGraphEntity::Function(func) => Some(func.full_path),
_ => None,
})
.collect::<Vec<_>>();
eyre!(
"Could not find in_fn graph entity for type `{full_path}` (`{in_fn_path}`). neighboring functions: {:?}",
neighbors
)
})?;
let in_fn_sql = in_fn_entity.to_sql(context)?;
// OUTPUT function: same resolution, extern check and neighbor lookup as for INPUT.
let out_fn_path = resolve_function_path(module_path, out_fn_path);
let out_fn = function_name(&out_fn_path);
let (_, _index) = context
.externs
.iter()
.find(|(k, _v)| k.full_path == out_fn_path)
.ok_or_else(|| eyre::eyre!("Did not find `out_fn: {}`.", out_fn_path))?;
let (out_fn_graph_index, out_fn_entity) = context
.graph
.neighbors_undirected(self_index)
.find_map(|neighbor| match &context.graph[neighbor] {
SqlGraphEntity::Function(func) if func.full_path == out_fn_path => {
Some((neighbor, func))
}
_ => None,
})
.ok_or_else(|| {
let neighbors = context
.graph
.neighbors_undirected(self_index)
.filter_map(|neighbor| match &context.graph[neighbor] {
SqlGraphEntity::Function(func) => Some(func.full_path),
_ => None,
})
.collect::<Vec<_>>();
eyre!(
"Could not find out_fn graph entity for type `{full_path}` (`{out_fn_path}`). neighboring functions: {:?}",
neighbors
)
})?;
let out_fn_sql = out_fn_entity.to_sql(context)?;
// RECEIVE function is optional: when present it yields (graph index, SQL, resolved
// path); `transpose` turns the `Option<Result<..>>` into a `Result<Option<..>>` so
// lookup failures still propagate with `?`.
let receive_fn_graph_index_and_receive_fn_sql = receive_fn_path
.as_ref()
.map(|receive_fn_path| {
let receive_fn_path = resolve_function_path(module_path, receive_fn_path);
let (_, _index) = context
.externs
.iter()
.find(|(k, _v)| k.full_path == receive_fn_path)
.ok_or_else(|| eyre::eyre!("Did not find `receive_fn`: {receive_fn_path}."))?;
let (receive_fn_graph_index, receive_fn_entity) = context
.graph
.neighbors_undirected(self_index)
.find_map(|neighbor| match &context.graph[neighbor] {
SqlGraphEntity::Function(func) if func.full_path == receive_fn_path => {
Some((neighbor, func))
}
_ => None,
})
.ok_or_else(|| eyre!("Could not find receive_fn graph entity."))?;
let receive_fn_sql = receive_fn_entity.to_sql(context)?;
Ok::<_, eyre::Report>((receive_fn_graph_index, receive_fn_sql, receive_fn_path))
})
.transpose()?;
// SEND function is optional and handled exactly like RECEIVE.
let send_fn_graph_index_and_send_fn_sql = send_fn_path
.as_ref()
.map(|send_fn_path| {
let send_fn_path = resolve_function_path(module_path, send_fn_path);
let (_, _index) = context
.externs
.iter()
.find(|(k, _v)| k.full_path == send_fn_path)
.ok_or_else(|| eyre::eyre!("Did not find `send_fn: {}`.", send_fn_path))?;
let (send_fn_graph_index, send_fn_entity) = context
.graph
.neighbors_undirected(self_index)
.find_map(|neighbor| match &context.graph[neighbor] {
SqlGraphEntity::Function(func) if func.full_path == send_fn_path => {
Some((neighbor, func))
}
_ => None,
})
.ok_or_else(|| eyre!("Could not find send_fn graph entity."))?;
let send_fn_sql = send_fn_entity.to_sql(context)?;
Ok::<_, eyre::Report>((send_fn_graph_index, send_fn_sql, send_fn_path))
})
.transpose()?;
// Shell type declaration, emitted ahead of the function SQL.
let shell_type = format!(
"\n\
-- {file}:{line}\n\
-- {full_path}\n\
CREATE TYPE {schema}{name};\
",
schema = context.schema_prefix_for(&self_index),
);
// Optional `ALIGNMENT` clause: 1, 2, 4 and 8 map to `char`, `int2`, `int4` and
// `double`; any other value panics. Without an alignment this is the empty string.
let alignment = alignment
.map(|alignment| {
assert!(alignment.is_power_of_two());
let alignment = match alignment {
1 => "char",
2 => "int2",
4 => "int4",
8 => "double",
_ => panic!("type '{name}' wants unsupported alignment '{alignment}'"),
};
format!(
",\n\
\tALIGNMENT = {alignment}"
)
})
.unwrap_or_default();
// `Option::zip` means RECEIVE/SEND are emitted only when BOTH are configured;
// if either is missing, both the attribute lines and the function SQL are empty.
let (receive_send_attributes, receive_send_sql) = receive_fn_graph_index_and_receive_fn_sql
.zip(send_fn_graph_index_and_send_fn_sql)
.map(|((receive_fn_graph_index, receive_fn_sql, receive_fn_path), (send_fn_graph_index, send_fn_sql, send_fn_path))| {
let receive_fn = function_name(&receive_fn_path);
let send_fn = function_name(&send_fn_path);
(
format! {
"\
\tRECEIVE = {schema_prefix_receive_fn}{receive_fn}, /* {receive_fn_path} */\n\
\tSEND = {schema_prefix_send_fn}{send_fn}, /* {send_fn_path} */\n\
",
schema_prefix_receive_fn = context.schema_prefix_for(&receive_fn_graph_index),
schema_prefix_send_fn = context.schema_prefix_for(&send_fn_graph_index),
},
format! {
"\n\
{receive_fn_sql}\n\
{send_fn_sql}\n\
"
}
)
}).unwrap_or_default();
// Full type definition referencing the functions by schema-prefixed name.
let materialized_type = format! {
"\n\
-- {file}:{line}\n\
-- {full_path}\n\
CREATE TYPE {schema}{name} (\n\
\tINTERNALLENGTH = variable,\n\
\tINPUT = {schema_prefix_in_fn}{in_fn}, /* {in_fn_path} */\n\
\tOUTPUT = {schema_prefix_out_fn}{out_fn}, /* {out_fn_path} */\n\
{receive_send_attributes}\
\tSTORAGE = extended{alignment}\n\
);\
",
schema = context.schema_prefix_for(&self_index),
schema_prefix_in_fn = context.schema_prefix_for(&in_fn_graph_index),
schema_prefix_out_fn = context.schema_prefix_for(&out_fn_graph_index)
};
// Assemble in order: shell type, input fn, output fn, receive/send fns, full type.
let result = shell_type
+ "\n"
+ &in_fn_sql
+ "\n"
+ &out_fn_sql
+ &receive_send_sql
+ "\n"
+ &materialized_type;
Ok(result)
}
}
/// Qualifies a function path with `module_path` when it is not already qualified.
///
/// A `path` containing `::` is returned unchanged; any other `path` is returned as
/// `module_path::path`.
fn resolve_function_path(module_path: &str, path: &str) -> String {
    if path.contains("::") {
        return path.to_owned();
    }
    [module_path, path].join("::")
}
/// Returns the final `::`-separated segment of `path`, or all of `path` when it
/// contains no `::`.
fn function_name(path: &str) -> &str {
    path.rsplit_once("::").map_or(path, |(_, last_segment)| last_segment)
}