use crate::{
FromDatum, IntoDatum, PgMemoryContexts, StringInfo, pg_sys, rust_regtypein, set_varsize_4b,
set_varsize_short, vardata_any, varsize_any, varsize_any_exhdr, void_mut_ptr,
};
use pgrx_sql_entity_graph::metadata::{
ArgumentError, ReturnsError, ReturnsRef, SqlMappingRef, SqlTranslatable,
};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
/// A raw pointer to a `palloc`'d varlena plus its total size in bytes (header included).
///
/// This record never frees `ptr` itself; ownership is decided by the [`PgVarlena`] that
/// holds it.
struct PallocdVarlena {
    ptr: *mut pg_sys::varlena,
    len: usize,
}

impl Clone for PallocdVarlena {
    /// Copies the `len` bytes behind `ptr` into a new allocation made in the memory
    /// context that owns the source pointer.
    ///
    /// NOTE(review): `PgMemoryContexts::of` has to discover the owning context from the
    /// pointer, so this assumes `ptr` was allocated by `palloc` — confirm that callers
    /// never clone a varlena that points into non-palloc'd memory.
    fn clone(&self) -> Self {
        let (src, len) = (self.ptr as void_mut_ptr, self.len);
        // SAFETY: `src` points at `len` readable bytes for as long as `self` is alive.
        let copied = unsafe {
            PgMemoryContexts::of(src)
                .expect("could not determine owning memory context")
                .copy_ptr_into(src, len)
        };
        Self { ptr: copied as *mut pg_sys::varlena, len }
    }
}
/// A Postgres varlena whose payload is a single `T`, read and written in place through
/// `Deref`/`AsRef`/`AsMut`.
pub struct PgVarlena<T>
where
    T: Copy + Sized,
{
    // Heap record created with `Box::leak` in `from_datum` when the caller's pointer is
    // borrowed rather than copied; it backs the `Cow::Borrowed` in `varlena` and is
    // reclaimed with `Box::from_raw` in `Drop`.
    leaked: Option<*mut PallocdVarlena>,
    // The varlena itself. `Borrowed` when it is the caller's datum memory; `Owned` when
    // this wrapper allocated it, received a detoasted copy, or cloned it on mutation.
    varlena: Cow<'static, PallocdVarlena>,
    // Whether `Drop` should `pfree` `varlena.ptr`; cleared by `into_pg`.
    need_free: bool,
    __marker: PhantomData<T>,
}
impl<T> PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Allocates a new, zero-filled varlena sized to hold exactly one `T`.
    ///
    /// The memory comes from `palloc0` in the current memory context. It is `pfree`d
    /// when the returned value is dropped, unless it is handed to Postgres with
    /// [`PgVarlena::into_pg`].
    ///
    /// A 1-byte ("short") varlena header is used only when `T` has an alignment of 1.
    /// With a short header the payload starts at byte offset 1. For any `T` with a larger
    /// alignment, the `&T`/`&mut T` handed out by `AsRef`/`AsMut` would then be
    /// misaligned, which is undefined behavior. Such types therefore get a regular
    /// 4-byte header.
    ///
    /// NOTE(review): with a 4-byte header the payload sits at offset 4 from the start of
    /// the allocation, so a `T` that requires 8-byte alignment may still be
    /// under-aligned — confirm against the types actually used with `PgVarlena`.
    pub fn new() -> Self {
        let size_of = std::mem::size_of::<T>();
        // SAFETY: plain FFI allocation of `VARHDRSZ + size_of` zeroed bytes.
        let ptr = unsafe { pg_sys::palloc0(pg_sys::VARHDRSZ + size_of) as *mut pg_sys::varlena };

        let fits_short_header =
            size_of + pg_sys::VARHDRSZ_SHORT <= pg_sys::VARATT_SHORT_MAX as usize;
        // A short header puts the payload at offset 1, which is only valid for byte-aligned T.
        let payload_is_byte_aligned = std::mem::align_of::<T>() == 1;

        // SAFETY: `ptr` covers `VARHDRSZ + size_of` bytes, enough for either header form
        // followed by the payload.
        unsafe {
            if fits_short_header && payload_is_byte_aligned {
                set_varsize_short(ptr, (size_of + pg_sys::VARHDRSZ_SHORT) as i32);
            } else {
                set_varsize_4b(ptr, (size_of + pg_sys::VARHDRSZ) as i32);
            }
        }

        PgVarlena {
            leaked: None,
            // SAFETY: the varlena header was initialized just above.
            varlena: Cow::Owned(PallocdVarlena { ptr, len: unsafe { varsize_any(ptr) } }),
            need_free: true,
            __marker: PhantomData,
        }
    }

    /// Wraps a varlena `Datum`, detoasting it first.
    ///
    /// If `pg_detoast_datum` returns the same pointer it was given, the caller's memory
    /// is borrowed. It is not freed here, and a later mutable access copies it first.
    /// Otherwise the detoasted result is owned by the returned value and `pfree`d on drop.
    ///
    /// # Safety
    ///
    /// `datum` must point to a valid varlena whose payload is a `T`. That memory must
    /// stay alive for as long as the returned value (and any reference obtained from
    /// it) is in use.
    pub unsafe fn from_datum(datum: pg_sys::Datum) -> Self {
        let ptr = pg_sys::pg_detoast_datum(datum.cast_mut_ptr());
        let len = varsize_any(ptr);
        if ptr == datum.cast_mut_ptr() {
            // Borrow the caller's memory. The bookkeeping record is boxed and leaked so the
            // `Cow` can hold a `'static` reference to it; `Drop` reclaims the box.
            let leaked = Box::leak(Box::new(PallocdVarlena { ptr, len }));
            PgVarlena {
                leaked: Some(leaked),
                varlena: Cow::Borrowed(leaked),
                need_free: false,
                __marker: PhantomData,
            }
        } else {
            // Detoasting produced a new allocation that nobody else references.
            PgVarlena {
                leaked: None,
                varlena: Cow::Owned(PallocdVarlena { ptr, len }),
                need_free: true,
                __marker: PhantomData,
            }
        }
    }

    /// Gives up ownership of the varlena and returns its raw pointer.
    ///
    /// The pointed-to memory is no longer `pfree`d by this wrapper; it becomes the
    /// caller's (typically Postgres') responsibility.
    pub fn into_pg(mut self) -> *mut pg_sys::varlena {
        self.need_free = false;
        self.varlena.ptr
    }
}
impl<T> Drop for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Releases what this wrapper owns:
    /// - the varlena allocation, but only while `need_free` is set;
    /// - the boxed bookkeeping record that `from_datum` leaked.
    fn drop(&mut self) {
        if self.need_free {
            let owned = self.varlena.ptr as void_mut_ptr;
            // SAFETY: `need_free` is only set for allocations this wrapper owns.
            unsafe { pg_sys::pfree(owned) };
        }

        // The leaked record backs `Cow::Borrowed`; it came from `Box::leak` and is
        // reclaimed exactly once, here.
        if let Some(record) = self.leaked.take() {
            let boxed = unsafe { Box::from_raw(record) };
            drop(boxed);
        }
    }
}
impl<T> Eq for PgVarlena<T> where T: Eq + Copy + Sized {}

impl<T> PartialEq for PgVarlena<T>
where
    T: PartialEq + Copy + Sized,
{
    /// Compares the `T` payloads; the varlena headers are not inspected.
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.as_ref() == other.as_ref()
    }
}

impl<T> Ord for PgVarlena<T>
where
    T: Ord + Copy + Sized,
{
    /// Total ordering of the `T` payloads.
    #[inline]
    fn cmp(&self, other: &Self) -> Ordering {
        self.as_ref().cmp(other.as_ref())
    }
}

/// Partial ordering of the `T` payloads.
///
/// This only requires `T: PartialOrd`, mirroring how `PartialEq`/`Eq` are split above.
/// Payloads such as floating-point structs can therefore be compared too. When `T: Ord`,
/// `T::partial_cmp` agrees with `T::cmp` by contract, so this stays consistent with the
/// `Ord` impl.
impl<T> PartialOrd for PgVarlena<T>
where
    T: PartialOrd + Copy + Sized,
{
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.as_ref().partial_cmp(other.as_ref())
    }
}

impl<T> Hash for PgVarlena<T>
where
    T: Hash + Copy + Sized,
{
    /// Hashes the `T` payload only, consistent with `PartialEq`.
    #[inline]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_ref().hash(state)
    }
}
impl<T> Deref for PgVarlena<T>
where
    T: Copy + Sized,
{
    type Target = T;

    /// Borrows the `T` payload; same as [`AsRef::as_ref`].
    fn deref(&self) -> &T {
        <Self as AsRef<T>>::as_ref(self)
    }
}

impl<T> DerefMut for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Mutably borrows the `T` payload; same as [`AsMut::as_mut`].
    fn deref_mut(&mut self) -> &mut T {
        <Self as AsMut<T>>::as_mut(self)
    }
}

impl<T> AsRef<T> for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Reinterprets the varlena's data area (just past its header) as a `T`.
    ///
    /// NOTE(review): turning the pointer into `&T` assumes the data area is suitably
    /// aligned for `T` — confirm for types with alignment greater than 1.
    fn as_ref(&self) -> &T {
        // SAFETY: `varlena.ptr` is a valid varlena whose payload holds a `T`.
        let payload = unsafe { vardata_any(self.varlena.ptr) } as *const T;
        unsafe { payload.as_ref() }.unwrap()
    }
}
impl<T> Default for PgVarlena<T>
where
T: Default + Copy,
{
fn default() -> Self {
let mut ptr = Self::new();
*ptr = T::default();
ptr
}
}
impl<T> AsMut<T> for PgVarlena<T>
where
T: Copy + Sized,
{
fn as_mut(&mut self) -> &mut T {
unsafe {
let ptr = vardata_any(self.varlena.to_mut().ptr) as *mut T;
ptr.as_mut().unwrap()
}
}
}
impl<T> From<PgVarlena<T>> for Option<pg_sys::Datum>
where
    T: Copy + Sized,
{
    /// Hands the varlena over with `PgVarlena::into_pg` and wraps its pointer as a
    /// non-null `Datum`.
    fn from(val: PgVarlena<T>) -> Self {
        let raw = val.into_pg();
        let datum: pg_sys::Datum = raw.into();
        Some(datum)
    }
}

impl<T> IntoDatum for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Releases ownership of the varlena to Postgres; the result is never SQL NULL.
    fn into_datum(self) -> Option<pg_sys::Datum> {
        let raw = self.into_pg();
        let datum: pg_sys::Datum = raw.into();
        Some(datum)
    }

    /// The Postgres type registered for `T`.
    fn type_oid() -> pg_sys::Oid {
        rust_regtypein::<T>()
    }
}
impl<T> FromDatum for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Wraps a possibly-NULL datum with the inherent `PgVarlena::from_datum`; the type
    /// oid is ignored.
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        _typoid: pg_sys::Oid,
    ) -> Option<Self> {
        match is_null {
            true => None,
            false => Some(PgVarlena::<T>::from_datum(datum)),
        }
    }

    /// Makes a detoasted copy of the datum while `memory_context` is current, then wraps
    /// that copy with the inherent `PgVarlena::from_datum`.
    ///
    /// NOTE(review): presumably the wrapper ends up borrowing the copy (detoasting an
    /// already-detoasted copy should return the same pointer), leaving it to
    /// `memory_context` to reclaim — confirm.
    unsafe fn from_datum_in_memory_context(
        mut memory_context: PgMemoryContexts,
        datum: pg_sys::Datum,
        is_null: bool,
        _typoid: pg_sys::Oid,
    ) -> Option<Self> {
        if is_null {
            return None;
        }
        memory_context.switch_to(|_| {
            let copy = pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
            let packed = pg_sys::pg_detoast_datum_packed(copy);
            Some(PgVarlena::<T>::from_datum(packed.into()))
        })
    }
}
/// Serializes `input` as CBOR into a freshly allocated 4-byte-header varlena.
///
/// # Safety
///
/// Allocates through Postgres (`StringInfo`), so it must be called from a backend
/// with a valid current memory context.
#[doc(hidden)]
pub unsafe fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
where
    T: Serialize,
{
    // Reserve room for the varlena header; it is filled in once the final size is known.
    let mut buf = StringInfo::new();
    buf.push_bytes(&[0u8; pg_sys::VARHDRSZ]);
    serde_cbor::to_writer(&mut buf, &input).expect("failed to encode as CBOR");

    let total_len = buf.len();
    let raw = buf.into_char_ptr() as *mut pg_sys::varlena;
    // SAFETY: `raw` begins with the `VARHDRSZ` bytes reserved above.
    unsafe { set_varsize_4b(raw, total_len as i32) };
    raw as *const pg_sys::varlena
}
/// Deserializes a CBOR payload stored in a varlena.
///
/// # Safety
///
/// `input` must point to a valid varlena. The returned value may borrow from the
/// (possibly detoasted) varlena data, which must outlive it.
///
/// # Panics
///
/// Panics if the payload is not valid CBOR for `T`.
#[doc(hidden)]
pub unsafe fn cbor_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
    T: Deserialize<'de>,
{
    let unpacked = pg_sys::pg_detoast_datum_packed(input);
    let bytes = std::slice::from_raw_parts(
        vardata_any(unpacked) as *const u8,
        varsize_any_exhdr(unpacked),
    );
    serde_cbor::from_slice(bytes).expect("failed to decode CBOR")
}
/// Like `cbor_decode`, but first copies the detoasted varlena into `memory_context`.
///
/// # Safety
///
/// Same requirements as `cbor_decode`.
#[doc(hidden)]
#[deprecated(since = "0.12.0", note = "just use the FromDatum impl")]
pub unsafe fn cbor_decode_into_context<'de, T>(
    mut memory_context: PgMemoryContexts,
    input: *mut pg_sys::varlena,
) -> T
where
    T: Deserialize<'de>,
{
    memory_context.switch_to(|_| cbor_decode(pg_sys::pg_detoast_datum_copy(input)))
}
/// `PgVarlena<T>` is presented to SQL exactly as `T` is: every mapping is forwarded
/// unchanged from `T`'s own `SqlTranslatable` impl.
unsafe impl<T> SqlTranslatable for PgVarlena<T>
where
    T: SqlTranslatable + Copy,
{
    const TYPE_IDENT: &'static str = <T as SqlTranslatable>::TYPE_IDENT;
    const TYPE_ORIGIN: pgrx_sql_entity_graph::metadata::TypeOrigin =
        <T as SqlTranslatable>::TYPE_ORIGIN;
    const ARGUMENT_SQL: Result<SqlMappingRef, ArgumentError> =
        <T as SqlTranslatable>::ARGUMENT_SQL;
    const RETURN_SQL: Result<ReturnsRef, ReturnsError> = <T as SqlTranslatable>::RETURN_SQL;
}