use crate::pg_sys::{VARATT_SHORT_MAX, VARHDRSZ_SHORT};
use crate::{
pg_sys, rust_regtypein, set_varsize, set_varsize_short, vardata_any, varsize_any,
varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, PostgresType,
StringInfo,
};
use pgx_pg_sys::varlena;
use pgx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
/// A raw pointer to a varlena together with its total size in bytes (header included,
/// as computed with `varsize_any` by the constructors in this file).
struct PallocdVarlena {
    ptr: *mut pg_sys::varlena,
    len: usize,
}

impl Clone for PallocdVarlena {
    /// Copies all `len` bytes of the varlena into a new allocation in the same memory
    /// context that owns the original pointer.
    ///
    /// # Panics
    ///
    /// Panics if the memory context owning `ptr` cannot be determined.
    fn clone(&self) -> Self {
        let source = self.ptr as void_mut_ptr;
        let byte_count = self.len;
        // SAFETY: assumes `ptr` is a live allocation of at least `len` bytes that
        // `PgMemoryContexts::of` can attribute to a memory context.
        let duplicate = unsafe {
            PgMemoryContexts::of(source)
                .expect("could not determine owning memory context")
                .copy_ptr_into(source, byte_count)
        };
        Self {
            ptr: duplicate as *mut pg_sys::varlena,
            len: byte_count,
        }
    }
}
/// A varlena datum whose payload is a single `T`, read and written in place.
///
/// Instances come either from [`PgVarlena::new`] (a fresh zeroed allocation) or from
/// [`PgVarlena::from_datum`] (an existing datum, detoasted when necessary).
pub struct PgVarlena<T: Copy + Sized> {
    /// Wrapper box leaked by `from_datum` to back a `Cow::Borrowed` `varlena`;
    /// reclaimed with `Box::from_raw` on drop.
    leaked: Option<*mut PallocdVarlena>,
    /// The varlena being viewed. A borrowed varlena is cloned into an owned copy on the
    /// first mutable access (`Cow::to_mut` in `as_mut`).
    varlena: Cow<'static, PallocdVarlena>,
    /// Whether `Drop` must `pfree` the pointer currently held in `varlena`.
    need_free: bool,
    __marker: PhantomData<T>,
}
impl<T> PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Allocates a zero-initialized varlena in the current memory context with room for
    /// exactly one `T`.
    ///
    /// The result owns its allocation. It is `pfree`'d on drop unless ownership is handed
    /// over with [`PgVarlena::into_pg`] (or `into_datum`).
    pub fn new() -> Self {
        let size_of = std::mem::size_of::<T>();

        // Allocate for the larger 4-byte header so either header form fits.
        let ptr = unsafe { pg_sys::palloc0(pg_sys::VARHDRSZ + size_of) as *mut pg_sys::varlena };

        unsafe {
            // A 1-byte ("short") header places the payload at offset 1. `as_ref`/`as_mut`
            // hand out `&T`/`&mut T` into that payload, and a reference to a misaligned `T`
            // is undefined behavior. So the short form is only used for byte-aligned types.
            // Everything else gets the 4-byte header, which is presumably also the layout
            // `from_datum` sees after `pg_detoast_datum`.
            //
            // NOTE(review): palloc is expected to return MAXALIGN'd memory, so a payload at
            // offset VARHDRSZ is presumably only 4-byte aligned. Types that require 8-byte
            // alignment may still be misaligned; confirm before relying on them.
            if std::mem::align_of::<T>() == 1
                && size_of + VARHDRSZ_SHORT() <= VARATT_SHORT_MAX as usize
            {
                set_varsize_short(ptr, (size_of + VARHDRSZ_SHORT()) as i32);
            } else {
                set_varsize(ptr, (size_of + pg_sys::VARHDRSZ) as i32);
            }
        }

        PgVarlena {
            leaked: None,
            varlena: Cow::Owned(PallocdVarlena { ptr, len: unsafe { varsize_any(ptr) } }),
            need_free: true,
            __marker: PhantomData,
        }
    }

    /// Wraps an existing varlena datum, detoasting it first if necessary.
    ///
    /// If `pg_detoast_datum` returns the datum's own pointer, the memory is borrowed and is
    /// not freed on drop. Otherwise the detoasted copy is owned and `pfree`'d on drop.
    ///
    /// # Safety
    ///
    /// `datum` must be a non-null pointer to a valid varlena whose payload is a `T`. In the
    /// borrowed case, that memory must outlive the returned value.
    pub unsafe fn from_datum(datum: pg_sys::Datum) -> Self {
        let ptr = pg_sys::pg_detoast_datum(datum.cast_mut_ptr());
        let len = varsize_any(ptr);
        if ptr == datum.cast_mut_ptr() {
            // The Cow needs a `'static` reference, so the wrapper is leaked onto the Rust
            // heap here and reclaimed in `Drop`.
            let leaked = Box::leak(Box::new(PallocdVarlena { ptr, len }));
            PgVarlena {
                leaked: Some(leaked),
                varlena: Cow::Borrowed(leaked),
                need_free: false,
                __marker: PhantomData,
            }
        } else {
            PgVarlena {
                leaked: None,
                varlena: Cow::Owned(PallocdVarlena { ptr, len }),
                need_free: true,
                __marker: PhantomData,
            }
        }
    }

    /// Gives up ownership of the underlying varlena and returns its pointer, e.g. to hand
    /// it back to Postgres. The memory is no longer freed when `self` is dropped.
    pub fn into_pg(mut self) -> *mut pg_sys::varlena {
        self.need_free = false;
        self.varlena.ptr
    }
}
impl<T> Drop for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Frees the varlena when this value owns it. Then reclaims the wrapper box that
    /// `from_datum` leaked for borrowed datums.
    fn drop(&mut self) {
        if self.need_free {
            let raw = self.varlena.ptr as void_mut_ptr;
            // SAFETY: `need_free` is only true while `varlena` holds an allocation this
            // value owns and has not handed out via `into_pg`.
            unsafe { pg_sys::pfree(raw) };
        }
        if let Some(wrapper) = self.leaked.take() {
            // SAFETY: `wrapper` came from `Box::leak` in `from_datum` and is reclaimed
            // exactly once, here.
            unsafe { drop(Box::from_raw(wrapper)) };
        }
    }
}
impl<T> Deref for PgVarlena<T>
where
    T: Copy + Sized,
{
    type Target = T;

    /// Borrows the `T` stored in the varlena payload (delegates to `AsRef`).
    fn deref(&self) -> &T {
        <Self as AsRef<T>>::as_ref(self)
    }
}
impl<T> DerefMut for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Mutably borrows the `T` stored in the varlena payload (delegates to `AsMut`).
    fn deref_mut(&mut self) -> &mut T {
        <Self as AsMut<T>>::as_mut(self)
    }
}
impl<T> AsRef<T> for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Reinterprets the varlena payload (the bytes after its header) as a `T`.
    ///
    /// # Panics
    ///
    /// Panics if the payload pointer is null.
    fn as_ref(&self) -> &T {
        let header = self.varlena.ptr;
        // SAFETY: the payload was sized for one `T` by the constructors.
        // NOTE(review): this also assumes the payload is suitably aligned for `T`.
        let payload = unsafe { vardata_any(header) } as *const T;
        unsafe { payload.as_ref() }.unwrap()
    }
}
impl<T> Default for PgVarlena<T>
where
    T: Default + Copy,
{
    /// Allocates a new varlena (see [`PgVarlena::new`]) and stores `T::default()` in it.
    fn default() -> Self {
        let mut value = Self::new();
        *value.as_mut() = T::default();
        value
    }
}
impl<T> AsMut<T> for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Mutably reinterprets the varlena payload as a `T`.
    ///
    /// A borrowed varlena (see `from_datum`) is first cloned into an owned copy via
    /// `Cow::to_mut`, so the original datum is left untouched. `need_free` is not changed
    /// here, so such a copy is not `pfree`'d on drop.
    ///
    /// # Panics
    ///
    /// Panics if the payload pointer is null, or if cloning cannot find the owning
    /// memory context.
    fn as_mut(&mut self) -> &mut T {
        let owned = self.varlena.to_mut();
        // NOTE(review): assumes the payload is suitably aligned for `T`.
        let payload = unsafe { vardata_any(owned.ptr) } as *mut T;
        unsafe { payload.as_mut() }.unwrap()
    }
}
impl<T> From<PgVarlena<T>> for Option<pg_sys::Datum>
where
    T: Copy + Sized,
{
    /// Hands the varlena over as a datum (always `Some`). The memory is no longer freed
    /// when the `PgVarlena` goes away.
    fn from(val: PgVarlena<T>) -> Self {
        let datum: pg_sys::Datum = val.into_pg().into();
        Some(datum)
    }
}
impl<T> IntoDatum for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Releases the varlena (see `into_pg`) and returns it as a non-null datum.
    fn into_datum(self) -> Option<pg_sys::Datum> {
        let datum: pg_sys::Datum = self.into_pg().into();
        Some(datum)
    }

    /// The SQL type OID registered for `T`.
    fn type_oid() -> pg_sys::Oid {
        rust_regtypein::<T>()
    }
}
impl<T> FromDatum for PgVarlena<T>
where
    T: Copy + Sized,
{
    /// Wraps a non-null datum with [`PgVarlena::from_datum`]; SQL NULL yields `None`.
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        _typoid: pg_sys::Oid,
    ) -> Option<Self> {
        if is_null {
            return None;
        }
        Some(PgVarlena::<T>::from_datum(datum))
    }

    /// Like `from_polymorphic_datum`, but first makes a detoasted copy of the datum while
    /// `memory_context` is active, so the copy is allocated there. SQL NULL yields `None`.
    unsafe fn from_datum_in_memory_context(
        mut memory_context: PgMemoryContexts,
        datum: pg_sys::Datum,
        is_null: bool,
        _typoid: pg_sys::Oid,
    ) -> Option<Self> {
        if is_null {
            return None;
        }
        memory_context.switch_to(|_| {
            let copied = pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
            let packed = pg_sys::pg_detoast_datum_packed(copied);
            Some(PgVarlena::<T>::from_datum(packed.into()))
        })
    }
}
impl<T> IntoDatum for T
where
    T: PostgresType + Serialize,
{
    /// Serializes `self` as CBOR into a new varlena and returns it as a non-null datum.
    fn into_datum(self) -> Option<pg_sys::Datum> {
        let encoded = cbor_encode(&self);
        Some(encoded.into())
    }

    /// The SQL type OID registered for `T`.
    fn type_oid() -> pg_sys::Oid {
        crate::rust_regtypein::<T>()
    }
}
impl<'de, T> FromDatum for T
where
    T: PostgresType + Deserialize<'de>,
{
    /// Decodes a CBOR-encoded varlena datum. SQL NULL yields `None`.
    ///
    /// The decode target is `Option<Self>` (inferred from the return type). The payload
    /// itself may therefore also decode to `None`, presumably when it encodes a CBOR null;
    /// confirm.
    ///
    /// # Panics
    ///
    /// Panics if the payload is not valid CBOR for `Option<Self>` (via `cbor_decode`).
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        _typoid: pg_sys::Oid,
    ) -> Option<Self> {
        if is_null {
            return None;
        }
        cbor_decode(datum.cast_mut_ptr())
    }

    /// Like `from_polymorphic_datum`, but copies the datum into `memory_context` before
    /// decoding (via `cbor_decode_into_context`).
    unsafe fn from_datum_in_memory_context(
        memory_context: PgMemoryContexts,
        datum: pg_sys::Datum,
        is_null: bool,
        _typoid: pg_sys::Oid,
    ) -> Option<Self> {
        if is_null {
            return None;
        }
        cbor_decode_into_context(memory_context, datum.cast_mut_ptr())
    }
}
/// Serializes `input` as CBOR into a freshly allocated varlena.
///
/// The first `VARHDRSZ` bytes of the buffer are reserved for the header. That header is
/// then set (with `set_varsize`) to the full buffer length.
///
/// # Panics
///
/// Panics if `input` cannot be encoded as CBOR.
fn cbor_encode<T>(input: T) -> *const pg_sys::varlena
where
    T: Serialize,
{
    let mut buffer = StringInfo::new();
    // Reserve room for the varlena header; it is filled in once the final size is known.
    buffer.push_bytes(&[0u8; pg_sys::VARHDRSZ]);
    serde_cbor::to_writer(&mut buffer, &input).expect("failed to encode as CBOR");

    let total_len = buffer.len() as usize;
    let raw = buffer.into_char_ptr() as *mut pg_sys::varlena;
    // SAFETY: `raw` points to the buffer built above, whose first VARHDRSZ bytes were
    // reserved for this header.
    unsafe { set_varsize(raw, total_len as i32) };
    raw as *const pg_sys::varlena
}
/// Decodes a CBOR-encoded value from a varlena. The varlena is first passed through
/// `pg_detoast_datum_packed`.
///
/// # Safety
///
/// `input` must point to a valid varlena. The byte slice given to the deserializer has an
/// unchecked lifetime `'de`, so a `T` that borrows from its input must not outlive that
/// memory.
///
/// # Panics
///
/// Panics if the payload is not valid CBOR for `T`.
pub unsafe fn cbor_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
    T: Deserialize<'de>,
{
    let unpacked = pg_sys::pg_detoast_datum_packed(input);
    let payload: &'de [u8] = std::slice::from_raw_parts(
        vardata_any(unpacked) as *const u8,
        varsize_any_exhdr(unpacked),
    );
    serde_cbor::from_slice(payload).expect("failed to decode CBOR")
}
/// Like [`cbor_decode`], but runs with `memory_context` active. The datum is first
/// copied with `pg_detoast_datum_copy`, so the copy is allocated in that context.
///
/// # Safety
///
/// Same requirements as [`cbor_decode`].
///
/// # Panics
///
/// Panics if the payload is not valid CBOR for `T`.
pub unsafe fn cbor_decode_into_context<'de, T>(
    mut memory_context: PgMemoryContexts,
    input: *mut pg_sys::varlena,
) -> T
where
    T: Deserialize<'de>,
{
    memory_context.switch_to(|_| cbor_decode(pg_sys::pg_detoast_datum_copy(input)))
}
/// Serializes `input` as JSON into a freshly allocated varlena. The header is set to the
/// full buffer length once encoding is done.
///
/// # Panics
///
/// Panics if `input` cannot be encoded as JSON.
// Not referenced in this file; kept as a JSON alternative to `cbor_encode`.
#[allow(dead_code)]
fn json_encode<T>(input: T) -> *const varlena
where
    T: Serialize,
{
    let mut buffer = StringInfo::new();
    // Reserve room for the varlena header; it is filled in once the final size is known.
    buffer.push_bytes(&[0u8; pg_sys::VARHDRSZ]);
    serde_json::to_writer(&mut buffer, &input).expect("failed to encode as JSON");

    let total_len = buffer.len() as usize;
    let raw = buffer.into_char_ptr() as *mut pg_sys::varlena;
    // SAFETY: `raw` points to the buffer built above, whose first VARHDRSZ bytes were
    // reserved for this header.
    unsafe { set_varsize(raw, total_len as i32) };
    raw as *const pg_sys::varlena
}
/// Decodes a JSON-encoded value from a varlena. The varlena is first passed through
/// `pg_detoast_datum_packed`.
///
/// # Safety
///
/// `input` must point to a valid varlena. The byte slice given to the deserializer has an
/// unchecked lifetime `'de`, so a `T` that borrows from its input must not outlive that
/// memory.
///
/// # Panics
///
/// Panics if the payload is not valid JSON for `T`.
// Not referenced in this file; kept as a JSON alternative to `cbor_decode`.
#[allow(dead_code)]
unsafe fn json_decode<'de, T>(input: *mut pg_sys::varlena) -> T
where
    T: Deserialize<'de>,
{
    let unpacked = pg_sys::pg_detoast_datum_packed(input);
    let payload: &'de [u8] = std::slice::from_raw_parts(
        vardata_any(unpacked) as *const u8,
        varsize_any_exhdr(unpacked),
    );
    serde_json::from_slice(payload).expect("failed to decode JSON")
}
// SAFETY: the SQL mapping is delegated entirely to `T`, so this impl is exactly as
// accurate as `T`'s own `SqlTranslatable` impl.
unsafe impl<T> SqlTranslatable for PgVarlena<T>
where
    T: SqlTranslatable + Copy,
{
    /// Uses `T`'s SQL argument mapping unchanged.
    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
        <T as SqlTranslatable>::argument_sql()
    }

    /// Uses `T`'s SQL return mapping unchanged.
    fn return_sql() -> Result<Returns, ReturnsError> {
        <T as SqlTranslatable>::return_sql()
    }
}