use crate::{PgMemoryContexts, pg_sys};
use core::fmt::{Debug, Display, Formatter};
use pgrx_sql_entity_graph::metadata::{
ArgumentError, ReturnsError, ReturnsRef, SqlMappingRef, SqlTranslatable,
};
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::ptr::NonNull;
/// A smart-pointer-like wrapper around a possibly-null pointer to a `T`.
///
/// The pointer is stored as `Option<NonNull<T>>`, with `None` standing for the
/// null pointer. The `AllocatedBy` type parameter (defaulting to
/// [`AllocatedByPostgres`]) carries no data; it only selects which
/// [`WhoAllocated::maybe_pfree`] implementation runs when the box is dropped.
#[repr(transparent)]
pub struct PgBox<T, AllocatedBy = AllocatedByPostgres>
where
    AllocatedBy: WhoAllocated,
{
    /// The wrapped pointer; `None` represents null.
    ptr: Option<NonNull<T>>,
    /// Zero-sized tag recording the `AllocatedBy` policy at the type level.
    __marker: PhantomData<AllocatedBy>,
}
/// Drop-time policy for a [`PgBox`]: decides whether the wrapped pointer is freed.
///
/// `PgBox`'s `Drop` implementation calls [`WhoAllocated::maybe_pfree`] with the
/// wrapped pointer whenever that pointer is non-null.
pub trait WhoAllocated {
    /// Frees `ptr`, or deliberately leaves it alone, depending on the implementor.
    ///
    /// # Safety
    ///
    /// Implementations may release the memory behind `ptr`, so the caller must
    /// not use `ptr` afterwards and must pass a pointer the implementor is
    /// allowed to free.
    unsafe fn maybe_pfree(ptr: *mut std::ffi::c_void);
}
/// Marker type for a [`PgBox`] whose pointer is left untouched on drop: its
/// [`WhoAllocated::maybe_pfree`] implementation does nothing.
pub struct AllocatedByPostgres;
/// Marker type for a [`PgBox`] whose pointer is released on drop: its
/// [`WhoAllocated::maybe_pfree`] implementation calls `pg_sys::pfree`.
pub struct AllocatedByRust;
impl WhoAllocated for AllocatedByPostgres {
    /// Intentionally a no-op: the pointer is ignored and nothing is freed.
    unsafe fn maybe_pfree(_ptr: *mut std::ffi::c_void) {
        // Nothing to do; this policy never releases the pointer.
    }
}
impl WhoAllocated for AllocatedByRust {
    /// Releases `ptr` by handing it to `pg_sys::pfree`.
    ///
    /// # Safety
    ///
    /// `ptr` must be a pointer that `pg_sys::pfree` may be called on, and it
    /// must not be used again afterwards.
    #[inline]
    unsafe fn maybe_pfree(ptr: *mut std::os::raw::c_void) {
        // SAFETY: the caller guarantees `ptr` is valid to pass to `pfree`.
        unsafe {
            pg_sys::pfree(ptr.cast());
        }
    }
}
impl<T> PgBox<T, AllocatedByPostgres> {
#[inline]
pub unsafe fn from_pg(ptr: *mut T) -> PgBox<T, AllocatedByPostgres> {
PgBox::<T, AllocatedByPostgres> { ptr: NonNull::new(ptr), __marker: PhantomData }
}
}
impl<T, AllocatedBy: WhoAllocated> PgBox<T, AllocatedBy> {
    /// Wraps `ptr` in a box tagged [`AllocatedByRust`]; a null `ptr` yields a null box.
    ///
    /// # Safety
    ///
    /// A non-null `ptr` is passed to `pg_sys::pfree` when the returned box is
    /// dropped (unless it is released first with [`PgBox::into_pg`] or
    /// [`PgBox::into_pg_boxed`]), and it is dereferenced by the box's `AsRef`,
    /// `Deref` and `DerefMut` implementations. The caller must guarantee both
    /// uses are valid for `ptr`.
    #[inline]
    pub unsafe fn from_rust(ptr: *mut T) -> PgBox<T, AllocatedByRust> {
        PgBox::<T, AllocatedByRust> { ptr: NonNull::new(ptr), __marker: PhantomData }
    }

    /// Allocates `size_of::<T>()` bytes with `pg_sys::palloc` and wraps the result.
    ///
    /// This function does not write to the allocation.
    ///
    /// # Safety
    ///
    /// The caller must be in a state where calling `pg_sys::palloc` is valid,
    /// and must write a valid `T` into the allocation before reading one
    /// through the box.
    #[inline]
    pub unsafe fn alloc() -> PgBox<T, AllocatedByRust> {
        let raw = unsafe { pg_sys::palloc(std::mem::size_of::<T>()) } as *mut T;
        PgBox::<T, AllocatedByRust> {
            // SAFETY: `new_unchecked` requires a non-null pointer. This relies on
            // `palloc` never returning null.
            // NOTE(review): Postgres documents that it raises an error instead — confirm.
            ptr: Some(unsafe { NonNull::new_unchecked(raw) }),
            __marker: PhantomData,
        }
    }

    /// Allocates `size_of::<T>()` bytes with `pg_sys::palloc0` and wraps the result.
    ///
    /// # Safety
    ///
    /// The caller must be in a state where calling `pg_sys::palloc0` is valid.
    /// NOTE(review): `palloc0` presumably zero-fills the allocation; if so, an
    /// all-zero bit pattern must be acceptable for how the caller uses `T`.
    #[inline]
    pub unsafe fn alloc0() -> PgBox<T, AllocatedByRust> {
        let raw = unsafe { pg_sys::palloc0(std::mem::size_of::<T>()) } as *mut T;
        PgBox::<T, AllocatedByRust> {
            // SAFETY: relies on `palloc0` never returning null (see `alloc`).
            ptr: Some(unsafe { NonNull::new_unchecked(raw) }),
            __marker: PhantomData,
        }
    }

    /// Allocates `size_of::<T>()` bytes with `pg_sys::MemoryContextAlloc`, using the
    /// context obtained from `memory_context.value()`, and wraps the result.
    ///
    /// This function does not write to the allocation.
    ///
    /// # Safety
    ///
    /// `memory_context.value()` must be a context that is valid to allocate
    /// in, and the caller must write a valid `T` into the allocation before
    /// reading one through the box.
    #[inline]
    pub unsafe fn alloc_in_context(memory_context: PgMemoryContexts) -> PgBox<T, AllocatedByRust> {
        let raw = unsafe {
            pg_sys::MemoryContextAlloc(memory_context.value(), std::mem::size_of::<T>())
        } as *mut T;
        PgBox::<T, AllocatedByRust> {
            // SAFETY: relies on `MemoryContextAlloc` never returning null (see `alloc`).
            ptr: Some(unsafe { NonNull::new_unchecked(raw) }),
            __marker: PhantomData,
        }
    }

    /// Allocates `size_of::<T>()` bytes with `pg_sys::MemoryContextAllocZero`, using
    /// the context obtained from `memory_context.value()`, and wraps the result.
    ///
    /// # Safety
    ///
    /// `memory_context.value()` must be a context that is valid to allocate in.
    /// NOTE(review): the allocation is presumably zero-filled; if so, an
    /// all-zero bit pattern must be acceptable for how the caller uses `T`.
    #[inline]
    pub unsafe fn alloc0_in_context(memory_context: PgMemoryContexts) -> PgBox<T, AllocatedByRust> {
        let raw = unsafe {
            pg_sys::MemoryContextAllocZero(memory_context.value(), std::mem::size_of::<T>())
        } as *mut T;
        PgBox::<T, AllocatedByRust> {
            // SAFETY: relies on `MemoryContextAllocZero` never returning null (see `alloc`).
            ptr: Some(unsafe { NonNull::new_unchecked(raw) }),
            __marker: PhantomData,
        }
    }

    /// Allocates a `T` with [`PgBox::alloc0`] and stores `node_tag` in the `type_`
    /// field of the allocation viewed as a `pg_sys::Node`.
    ///
    /// # Safety
    ///
    /// Same requirements as [`PgBox::alloc0`]. The allocation is reinterpreted
    /// as a `pg_sys::Node`, so `T` must start with a compatible header.
    /// NOTE(review): the `T: pg_sys::PgNode` bound presumably guarantees this —
    /// confirm. `node_tag` is not checked against `T`.
    #[inline]
    pub unsafe fn alloc_node(node_tag: pg_sys::NodeTag) -> PgBox<T, AllocatedByRust>
    where
        T: pg_sys::PgNode,
    {
        unsafe {
            let node = PgBox::<T>::alloc0();
            // Write through the raw pointer; no intermediate reference is needed.
            (*node.as_ptr().cast::<pg_sys::Node>()).type_ = node_tag;
            node
        }
    }

    /// Creates a box that holds the null pointer.
    ///
    /// Dereferencing such a box panics; dropping it frees nothing.
    #[inline]
    pub fn null() -> PgBox<T, AllocatedBy> {
        PgBox::<T, AllocatedBy> { ptr: None, __marker: PhantomData }
    }

    /// Returns `true` when the box holds the null pointer.
    #[inline]
    pub fn is_null(&self) -> bool {
        self.ptr.is_none()
    }

    /// Returns the wrapped pointer, or a null pointer for a null box.
    ///
    /// The box keeps its drop-time responsibility for the pointer.
    #[inline]
    pub fn as_ptr(&self) -> *mut T {
        // `NonNull::as_ptr` yields the raw pointer without materialising a
        // `&mut T`. Going through a temporary mutable reference here would
        // assert exclusive access (and reference validity) from a shared borrow.
        self.ptr.map_or(std::ptr::null_mut(), NonNull::as_ptr)
    }

    /// Consumes the box and returns the wrapped pointer (null for a null box)
    /// without freeing it.
    #[inline]
    pub fn into_pg(mut self) -> *mut T {
        // `take()` leaves `None` behind, so the `Drop` that runs when `self`
        // goes out of scope has nothing to free.
        self.ptr.take().map_or(std::ptr::null_mut(), NonNull::as_ptr)
    }

    /// Consumes the box and re-wraps the same pointer as a
    /// `PgBox<T, AllocatedByPostgres>`, which never frees it on drop.
    #[inline]
    pub fn into_pg_boxed(self) -> PgBox<T, AllocatedByPostgres> {
        // SAFETY: the pointer is exactly the one this box already wrapped, so it
        // carries the same validity the caller established when creating `self`.
        unsafe { PgBox::from_pg(self.into_pg()) }
    }

    /// Wraps `ptr` with [`PgBox::from_pg`] and passes the temporary box to `func`.
    ///
    /// The temporary box is tagged [`AllocatedByPostgres`], so `ptr` is not
    /// freed when `func` returns.
    ///
    /// # Safety
    ///
    /// Same requirements as [`PgBox::from_pg`].
    #[inline]
    pub unsafe fn with<F: FnOnce(&mut PgBox<T>)>(ptr: *mut T, func: F) {
        // SAFETY: forwarded to the caller via this function's contract.
        let mut boxed = unsafe { PgBox::from_pg(ptr) };
        func(&mut boxed)
    }
}
impl<T> Clone for PgBox<T, AllocatedByPostgres>
where
T: Copy,
{
fn clone(&self) -> Self {
if self.ptr.is_none() {
PgBox { ptr: None, __marker: Default::default() }
} else {
unsafe {
let copy = PgMemoryContexts::CurrentMemoryContext
.copy_ptr_into(self.as_ptr(), std::mem::size_of::<T>());
PgBox::from_pg(copy)
}
}
}
}
/// Marks `PgBox` equality as a full equivalence relation whenever `T: Eq`.
impl<T, AllocatedBy> Eq for PgBox<T, AllocatedBy>
where
    T: Eq,
    AllocatedBy: WhoAllocated,
{
}
impl<T, AllocatedBy: WhoAllocated> PartialEq for PgBox<T, AllocatedBy>
where
    T: PartialEq,
{
    /// Compares the pointed-to values, treating null as a distinct value.
    ///
    /// Two null boxes are equal, a null box never equals a non-null one, and
    /// two non-null boxes are equal when their pointees compare equal. Unlike
    /// going through `AsRef::as_ref`, this never panics on a null box.
    fn eq(&self, other: &Self) -> bool {
        match (self.ptr, other.ptr) {
            // SAFETY: both pointers are non-null and were vouched for by
            // whoever constructed the boxes.
            (Some(lhs), Some(rhs)) => unsafe { lhs.as_ref() == rhs.as_ref() },
            (None, None) => true,
            _ => false,
        }
    }
}
impl<T, AllocatedBy: WhoAllocated> Debug for PgBox<T, AllocatedBy>
where
    T: Debug,
{
    /// Writes the pointee's `{:?}` representation.
    ///
    /// # Panics
    ///
    /// Panics if the box is null, because the pointee is obtained through
    /// `AsRef::as_ref`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let pointee: &T = self.as_ref();
        // A fresh `{:?}` is used on purpose: flags on the outer formatter are
        // not forwarded to the pointee.
        f.write_fmt(format_args!("{:?}", pointee))
    }
}
impl<T, AllocatedBy: WhoAllocated> Display for PgBox<T, AllocatedBy>
where
    T: Display,
{
    /// Writes the pointee's `{}` representation.
    ///
    /// # Panics
    ///
    /// Panics if the box is null, because the pointee is obtained through
    /// `AsRef::as_ref`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let pointee: &T = self.as_ref();
        // A fresh `{}` is used on purpose: flags on the outer formatter are
        // not forwarded to the pointee.
        f.write_fmt(format_args!("{}", pointee))
    }
}
impl<T, AllocatedBy: WhoAllocated> AsRef<T> for PgBox<T, AllocatedBy> {
    /// Borrows the pointee.
    ///
    /// # Panics
    ///
    /// Panics if the box is null.
    fn as_ref(&self) -> &T {
        if let Some(target) = self.ptr.as_ref() {
            // SAFETY: the pointer is non-null and was vouched for by whoever
            // constructed this box.
            return unsafe { target.as_ref() };
        }
        panic!("Attempt to dereference null pointer during `AsRef::as_ref()` of PgBox")
    }
}
impl<T, AllocatedBy: WhoAllocated> Deref for PgBox<T, AllocatedBy> {
    type Target = T;

    /// Borrows the pointee by delegating to this type's `AsRef<T>` implementation.
    ///
    /// # Panics
    ///
    /// Panics if the box is null.
    #[track_caller]
    fn deref(&self) -> &Self::Target {
        <Self as AsRef<T>>::as_ref(self)
    }
}
impl<T, AllocatedBy: WhoAllocated> DerefMut for PgBox<T, AllocatedBy> {
    /// Mutably borrows the pointee.
    ///
    /// # Panics
    ///
    /// Panics if the box is null.
    #[track_caller]
    fn deref_mut(&mut self) -> &mut T {
        if let Some(target) = self.ptr.as_mut() {
            // SAFETY: the pointer is non-null and was vouched for by whoever
            // constructed this box; `&mut self` gives exclusive access to the box.
            return unsafe { target.as_mut() };
        }
        panic!("Attempt to dereference null pointer during DerefMut of PgBox")
    }
}
impl<T, AllocatedBy: WhoAllocated> Drop for PgBox<T, AllocatedBy> {
    /// Hands a non-null pointer to `AllocatedBy::maybe_pfree`; a null box does
    /// nothing.
    fn drop(&mut self) {
        let raw = match self.ptr {
            Some(non_null) => non_null.as_ptr(),
            // Null box (or one already emptied by `into_pg`): nothing to release.
            None => return,
        };
        // SAFETY: `raw` is the non-null pointer this box wraps; whether it is
        // actually freed is decided by the `AllocatedBy` policy.
        unsafe { AllocatedBy::maybe_pfree(raw.cast()) }
    }
}
/// SQL metadata for a Postgres-allocated box: every associated constant is
/// forwarded unchanged from `T`'s own `SqlTranslatable` implementation.
unsafe impl<T> SqlTranslatable for PgBox<T, AllocatedByPostgres>
where
    T: SqlTranslatable,
{
    const TYPE_IDENT: &'static str = <T as SqlTranslatable>::TYPE_IDENT;
    const TYPE_ORIGIN: pgrx_sql_entity_graph::metadata::TypeOrigin =
        <T as SqlTranslatable>::TYPE_ORIGIN;
    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
        <T as SqlTranslatable>::ARGUMENT_SQL;
    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = <T as SqlTranslatable>::RETURN_SQL;
}
/// SQL metadata for a Rust-allocated box: every associated constant is
/// forwarded unchanged from `T`'s own `SqlTranslatable` implementation.
unsafe impl<T> SqlTranslatable for PgBox<T, AllocatedByRust>
where
    T: SqlTranslatable,
{
    const TYPE_IDENT: &'static str = <T as SqlTranslatable>::TYPE_IDENT;
    const TYPE_ORIGIN: pgrx_sql_entity_graph::metadata::TypeOrigin =
        <T as SqlTranslatable>::TYPE_ORIGIN;
    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
        <T as SqlTranslatable>::ARGUMENT_SQL;
    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = <T as SqlTranslatable>::RETURN_SQL;
}