use crate::{pg_sys, PgMemoryContexts};
use pgx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
};
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
/// A nullable pointer wrapper whose drop behavior is selected by the `AllocatedBy` policy.
///
/// With [`AllocatedByRust`] the pointee's allocation is released through `pg_sys::pfree` when
/// the box is dropped; with [`AllocatedByPostgres`] (the default) dropping the box frees nothing.
/// `repr(transparent)` gives the box exactly the layout of `Option<NonNull<T>>`.
#[repr(transparent)]
pub struct PgBox<T, AllocatedBy = AllocatedByPostgres>
where
    AllocatedBy: WhoAllocated,
{
    /// `None` stands for a null pointer.
    ptr: Option<NonNull<T>>,
    /// Records the allocation policy at the type level; zero-sized.
    __marker: PhantomData<AllocatedBy>,
}
/// Allocation policy used by `PgBox` to decide whether its pointer gets freed on drop.
pub trait WhoAllocated {
    /// Releases `ptr` if this policy owns the allocation; otherwise does nothing.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for whatever release operation the implementing policy performs,
    /// and must not be used again if that policy frees it.
    unsafe fn maybe_pfree(ptr: *mut std::ffi::c_void);
}
/// Policy marker: the memory belongs to Postgres, so `PgBox` never frees it on drop.
pub struct AllocatedByPostgres;
/// Policy marker: the memory was allocated from Rust and `PgBox` releases it with
/// `pg_sys::pfree` on drop.
pub struct AllocatedByRust;
impl WhoAllocated for AllocatedByPostgres {
    /// Deliberately does nothing: pointers under this policy are never freed by `PgBox`.
    #[inline]
    unsafe fn maybe_pfree(_ptr: *mut std::os::raw::c_void) {
        // Intentionally empty.
    }
}
impl WhoAllocated for AllocatedByRust {
    /// Frees `ptr` with `pg_sys::pfree`.
    ///
    /// # Safety
    ///
    /// `ptr` must be an allocation that `pg_sys::pfree` accepts and must not be used afterwards.
    #[inline]
    unsafe fn maybe_pfree(ptr: *mut std::os::raw::c_void) {
        let target = ptr.cast();
        // SAFETY: upheld by the caller per this function's contract.
        unsafe { pg_sys::pfree(target) };
    }
}
impl<T> PgBox<T, AllocatedByPostgres> {
#[inline]
pub unsafe fn from_pg(ptr: *mut T) -> PgBox<T, AllocatedByPostgres> {
PgBox::<T, AllocatedByPostgres> { ptr: NonNull::new(ptr), __marker: PhantomData }
}
}
impl<T, AllocatedBy: WhoAllocated> PgBox<T, AllocatedBy> {
    /// Wraps a pointer that will be released with `pg_sys::pfree` when the box is dropped.
    ///
    /// A null `ptr` produces a null box, which frees nothing.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or point to a valid, properly aligned `T` whose allocation may be
    /// passed to `pg_sys::pfree`, and nothing else may free it.
    #[inline]
    pub unsafe fn from_rust(ptr: *mut T) -> PgBox<T, AllocatedByRust> {
        PgBox::<T, AllocatedByRust> { ptr: NonNull::new(ptr), __marker: PhantomData }
    }

    /// Allocates **uninitialized** storage for one `T` with `pg_sys::palloc`.
    ///
    /// The returned box frees the allocation when dropped. Use [`PgBox::alloc0`] for zeroed
    /// memory.
    ///
    /// # Safety
    ///
    /// The contents are uninitialized; the caller must initialize the `T` before reading it.
    #[inline]
    pub unsafe fn alloc() -> PgBox<T, AllocatedByRust> {
        PgBox::<T, AllocatedByRust> {
            // SAFETY: relies on `palloc` never returning NULL.
            // NOTE(review): Postgres reports out-of-memory by raising an error — confirm.
            ptr: Some(unsafe {
                NonNull::new_unchecked(pg_sys::palloc(std::mem::size_of::<T>()) as *mut T)
            }),
            __marker: PhantomData,
        }
    }

    /// Allocates zero-filled storage for one `T` with `pg_sys::palloc0`.
    ///
    /// The returned box frees the allocation when dropped.
    ///
    /// # Safety
    ///
    /// An all-zero bit pattern must be a valid `T` before the value is read.
    #[inline]
    pub unsafe fn alloc0() -> PgBox<T, AllocatedByRust> {
        PgBox::<T, AllocatedByRust> {
            // SAFETY: relies on `palloc0` never returning NULL (see `alloc`).
            ptr: Some(unsafe {
                NonNull::new_unchecked(pg_sys::palloc0(std::mem::size_of::<T>()) as *mut T)
            }),
            __marker: PhantomData,
        }
    }

    /// Allocates **uninitialized** storage for one `T` in `memory_context` via
    /// `pg_sys::MemoryContextAlloc`.
    ///
    /// The returned box frees the allocation when dropped.
    ///
    /// # Safety
    ///
    /// The contents are uninitialized, and `memory_context` must be a valid context.
    #[inline]
    pub unsafe fn alloc_in_context(memory_context: PgMemoryContexts) -> PgBox<T, AllocatedByRust> {
        PgBox::<T, AllocatedByRust> {
            // SAFETY: relies on `MemoryContextAlloc` never returning NULL.
            ptr: Some(unsafe {
                NonNull::new_unchecked(pg_sys::MemoryContextAlloc(
                    memory_context.value(),
                    std::mem::size_of::<T>(),
                ) as *mut T)
            }),
            __marker: PhantomData,
        }
    }

    /// Allocates zero-filled storage for one `T` in `memory_context` via
    /// `pg_sys::MemoryContextAllocZero`.
    ///
    /// The returned box frees the allocation when dropped.
    ///
    /// # Safety
    ///
    /// An all-zero bit pattern must be a valid `T`, and `memory_context` must be a valid
    /// context.
    #[inline]
    pub unsafe fn alloc0_in_context(memory_context: PgMemoryContexts) -> PgBox<T, AllocatedByRust> {
        PgBox::<T, AllocatedByRust> {
            // SAFETY: relies on `MemoryContextAllocZero` never returning NULL.
            ptr: Some(unsafe {
                NonNull::new_unchecked(pg_sys::MemoryContextAllocZero(
                    memory_context.value(),
                    std::mem::size_of::<T>(),
                ) as *mut T)
            }),
            __marker: PhantomData,
        }
    }

    /// Allocates a zeroed `T` (via [`PgBox::alloc0`]) and stamps its node tag with `node_tag`.
    ///
    /// # Safety
    ///
    /// Same requirements as [`PgBox::alloc0`]. `node_tag` should match the concrete type `T`.
    #[inline]
    pub unsafe fn alloc_node(node_tag: pg_sys::NodeTag) -> PgBox<T, AllocatedByRust>
    where
        T: pg_sys::PgNode,
    {
        unsafe {
            let node = PgBox::<T>::alloc0();
            // SAFETY: the allocation is non-null and exclusively owned by `node`.
            // NOTE(review): assumes every `PgNode` type begins with the same tag field as
            // `pg_sys::Node`, which is what makes this cast valid — confirm.
            // Writing through the raw pointer avoids materializing a `&mut Node`.
            (*node.as_ptr().cast::<pg_sys::Node>()).type_ = node_tag;
            node
        }
    }

    /// Returns a box holding a null pointer. Dropping it frees nothing.
    #[inline]
    pub fn null() -> PgBox<T, AllocatedBy> {
        PgBox::<T, AllocatedBy> { ptr: None, __marker: PhantomData }
    }

    /// Returns `true` if this box holds a null pointer.
    #[inline]
    pub fn is_null(&self) -> bool {
        self.ptr.is_none()
    }

    /// Returns the wrapped raw pointer, or `null_mut()` for a null box. Ownership is not
    /// transferred.
    #[inline]
    pub fn as_ptr(&self) -> *mut T {
        // Convert the `NonNull` straight to a raw pointer. Creating a `&mut T` here (from only
        // `&self`) could alias shared references handed out by `Deref`, which is undefined
        // behavior.
        self.ptr.map_or(std::ptr::null_mut(), NonNull::as_ptr)
    }

    /// Consumes the box and returns the raw pointer without freeing it.
    ///
    /// The caller becomes responsible for the memory. A null box yields `null_mut()`.
    #[inline]
    pub fn into_pg(mut self) -> *mut T {
        // Taking the pointer leaves `None` behind, so `Drop` has nothing to free.
        self.ptr.take().map_or(std::ptr::null_mut(), NonNull::as_ptr)
    }

    /// Consumes the box and re-wraps its pointer as Postgres-allocated, so it is never freed by
    /// a `PgBox` drop.
    #[inline]
    pub fn into_pg_boxed(self) -> PgBox<T, AllocatedByPostgres> {
        // SAFETY: `into_pg` relinquishes ownership without freeing. The pointer is either null
        // or the same valid pointer this box already held.
        unsafe { PgBox::from_pg(self.into_pg()) }
    }

    /// Temporarily wraps `ptr` as a [`PgBox`] and calls `func` with it.
    ///
    /// The temporary box uses [`AllocatedByPostgres`] regardless of `AllocatedBy`, so `ptr` is
    /// not freed when `func` returns.
    ///
    /// # Safety
    ///
    /// `ptr` must be null or valid for reads and writes of a `T` for the duration of `func`.
    #[inline]
    pub unsafe fn with<F: FnOnce(&mut PgBox<T>)>(ptr: *mut T, func: F) {
        // SAFETY: upheld by the caller per this function's contract.
        func(&mut unsafe { PgBox::from_pg(ptr) })
    }
}
impl<T, AllocatedBy: WhoAllocated> Deref for PgBox<T, AllocatedBy> {
    type Target = T;

    /// Borrows the pointee.
    ///
    /// # Panics
    ///
    /// Panics, reporting the caller's location, if the box is null.
    #[track_caller]
    fn deref(&self) -> &Self::Target {
        if let Some(nonnull) = &self.ptr {
            // SAFETY: a non-null `PgBox` is assumed to point at a valid `T` (constructor
            // contracts).
            return unsafe { nonnull.as_ref() };
        }
        // The panic stays in this function body (not in a closure) so `#[track_caller]`
        // attributes it to the caller.
        panic!("Attempt to dereference null pointer during Deref of PgBox")
    }
}
impl<T, AllocatedBy: WhoAllocated> DerefMut for PgBox<T, AllocatedBy> {
    /// Mutably borrows the pointee.
    ///
    /// # Panics
    ///
    /// Panics, reporting the caller's location, if the box is null.
    #[track_caller]
    fn deref_mut(&mut self) -> &mut T {
        if let Some(nonnull) = &mut self.ptr {
            // SAFETY: a non-null `PgBox` is assumed to point at a valid `T`; `&mut self`
            // guarantees no other borrow obtained through this box is live.
            return unsafe { nonnull.as_mut() };
        }
        panic!("Attempt to dereference null pointer during DerefMut of PgBox")
    }
}
impl<T, AllocatedBy: WhoAllocated> Drop for PgBox<T, AllocatedBy> {
    /// Passes a non-null pointer to `AllocatedBy::maybe_pfree`; a null box does nothing.
    ///
    /// Only the allocation is handed to the policy; `T`'s own destructor is not run.
    fn drop(&mut self) {
        let raw = match self.ptr {
            Some(nonnull) => nonnull.as_ptr(),
            None => return,
        };
        // SAFETY: `raw` is the pointer this box was constructed with, and the policy type
        // decides whether it is released.
        unsafe { AllocatedBy::maybe_pfree(raw.cast()) }
    }
}
/// A Postgres-allocated `PgBox<T>` uses exactly the SQL mapping of `T`.
unsafe impl<T> SqlTranslatable for PgBox<T, AllocatedByPostgres>
where
    T: SqlTranslatable,
{
    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
        <T as SqlTranslatable>::argument_sql()
    }

    fn return_sql() -> Result<Returns, ReturnsError> {
        <T as SqlTranslatable>::return_sql()
    }
}
/// A Rust-allocated `PgBox<T>` uses exactly the SQL mapping of `T`.
unsafe impl<T> SqlTranslatable for PgBox<T, AllocatedByRust>
where
    T: SqlTranslatable,
{
    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
        <T as SqlTranslatable>::argument_sql()
    }

    fn return_sql() -> Result<Returns, ReturnsError> {
        <T as SqlTranslatable>::return_sql()
    }
}