use crate::array::RawArray;
use crate::datum::array::casper::ChaChaSlide;
use crate::layout::*;
use crate::toast::Toast;
use crate::{pg_sys, FromDatum, IntoDatum, PgMemoryContexts};
use bitvec::slice::BitSlice;
use core::fmt::{Debug, Formatter};
use core::ops::DerefMut;
use core::ptr::NonNull;
use pgrx_pg_sys::{Datum, Oid};
use pgrx_sql_entity_graph::metadata::{
ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
};
use serde::Serializer;
/// A typed view over a Postgres array held in a `Toast<RawArray>`.
///
/// Elements are decoded to `T` on demand through [`FromDatum`]; NULL elements
/// are surfaced as `None` by the accessors and iterators.
pub struct Array<'a, T: FromDatum> {
// Per-element NULL information: either a borrowed bitmap or a "no NULLs" marker.
// NOTE(review): the bitmap is obtained from `raw` in `deconstruct_from`, so it
// presumably points into `raw`'s allocation -- confirm before reordering fields.
null_slice: NullKind<'a>,
// Layout-specific strategy used to decode an element and to step past it.
slide_impl: ChaChaSlideImpl<T>,
// The backing array storage.
raw: Toast<RawArray>,
}
/// Renders the array as a debug list of its `Option<T>` elements.
impl<'a, T: FromDatum + Debug> Debug for Array<'a, T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let mut list = f.debug_list();
        list.entries(self.iter());
        list.finish()
    }
}
/// Boxed, dynamically-dispatched element-walking strategy for an `Array<T>`.
type ChaChaSlideImpl<T> = Box<dyn ChaChaSlide<T>>;
/// How NULL-ness of array elements is tracked.
enum NullKind<'a> {
// A borrowed bitmap with one bit per element; `NullKind::get` inverts each bit,
// so a set bit means "not NULL".
Bits(&'a BitSlice<u8>),
// No bitmap is present, so no element is NULL; the payload is the element count.
Strict(usize),
}
impl NullKind<'_> {
    /// Reports whether the element at `index` is NULL.
    ///
    /// Returns `None` when `index` is past the end of the array.
    fn get(&self, index: usize) -> Option<bool> {
        match self {
            // A set bit marks a present value, so invert it to get "is null".
            Self::Bits(bitmap) => {
                let bit = bitmap.get(index)?;
                Some(!bit)
            }
            Self::Strict(len) => {
                if index < *len {
                    Some(false)
                } else {
                    None
                }
            }
        }
    }

    /// Reports whether at least one element is NULL.
    fn any(&self) -> bool {
        if let Self::Bits(bitmap) = self {
            // Reversed polarity: NULLs exist unless every bit is set.
            !bitmap.all()
        } else {
            false
        }
    }
}
impl<'a, T: FromDatum + serde::Serialize + 'a> serde::Serialize for Array<'a, T> {
fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
where
S: Serializer,
{
serializer.collect_seq(self.iter())
}
}
// Core construction, access, and element-walking logic for `Array`.
// Every unsafe operation in this impl must sit in an explicit `unsafe` block.
#[deny(unsafe_op_in_unsafe_fn)]
impl<'a, T: FromDatum> Array<'a, T> {
/// Builds an `Array` from raw array storage.
///
/// Looks up the element layout from the array's element oid, captures the
/// NULL bitmap (if any), and selects the matching element-walking strategy.
///
/// # Panics
/// Panics if the element is pass-by-value with a size other than 1, 2, 4
/// (or 8 on 64-bit targets).
///
/// # Safety
/// NOTE(review): presumably `raw` must be a valid array whose elements are
/// decodable as `T`, and must outlive `'a` -- confirm against callers.
unsafe fn deconstruct_from(mut raw: Toast<RawArray>) -> Array<'a, T> {
let oid = raw.oid();
let elem_layout = Layout::lookup_oid(oid);
let nelems = raw.len();
// No bitmap means no NULLs: remember only the element count in that case.
let null_slice = raw
.nulls_bitslice()
.map(|nonnull| NullKind::Bits(unsafe { &*nonnull.as_ptr() }))
.unwrap_or(NullKind::Strict(nelems));
// Pick the walking strategy that matches how the element type is passed and sized.
let slide_impl: ChaChaSlideImpl<T> = match elem_layout.pass {
PassBy::Value => match elem_layout.size {
Size::Fixed(1) => Box::new(casper::FixedSizeByVal::<1>),
Size::Fixed(2) => Box::new(casper::FixedSizeByVal::<2>),
Size::Fixed(4) => Box::new(casper::FixedSizeByVal::<4>),
// 8-byte by-value elements are only handled on 64-bit targets.
#[cfg(target_pointer_width = "64")]
Size::Fixed(8) => Box::new(casper::FixedSizeByVal::<8>),
_ => {
panic!("unrecognized pass-by-value array element layout: {:?}", elem_layout)
}
},
PassBy::Ref => match elem_layout.size {
Size::Varlena => Box::new(casper::PassByVarlena { align: elem_layout.align }),
Size::CStr => Box::new(casper::PassByCStr),
Size::Fixed(size) => Box::new(casper::PassByFixed {
padded_size: elem_layout.align.pad(size.into()),
}),
},
};
Array { raw, slide_impl, null_slice }
}
/// Consumes the array and returns a pointer to the underlying `ArrayType`.
#[inline]
pub fn into_array_type(self) -> *const pg_sys::ArrayType {
let Array { raw, .. } = self;
// `ManuallyDrop` keeps the `Toast` wrapper's destructor from running.
// NOTE(review): presumably so a detoasted copy is not freed while the
// returned pointer is still in use -- confirm against `Toast`'s `Drop`.
let mut raw = core::mem::ManuallyDrop::new(raw);
let ptr = raw.deref_mut().deref_mut() as *mut RawArray;
// Move the `RawArray` out by value; the original is never dropped (see above).
unsafe { ptr.read() }.into_ptr().as_ptr() as _
}
/// Returns an iterator over the elements, yielding `None` for NULLs.
#[inline]
pub fn iter(&self) -> ArrayIterator<'_, T> {
let ptr = self.raw.data_ptr();
ArrayIterator { array: self, curr: 0, ptr }
}
/// Returns an iterator that yields `T` directly instead of `Option<T>`.
///
/// # Panics
/// Panics with "array contains NULL" if any element is NULL.
#[inline]
pub fn iter_deny_null(&self) -> ArrayTypedIterator<'_, T> {
if self.null_slice.any() {
panic!("array contains NULL");
}
let ptr = self.raw.data_ptr();
ArrayTypedIterator { array: self, curr: 0, ptr }
}
/// Returns `true` if at least one element is NULL.
#[inline]
pub fn contains_nulls(&self) -> bool {
self.null_slice.any()
}
/// Returns the number of elements reported by the underlying raw array.
#[inline]
pub fn len(&self) -> usize {
self.raw.len()
}
/// Returns `true` if the array has no elements.
#[inline]
pub fn is_empty(&self) -> bool {
self.raw.len() == 0
}
/// Returns the element at `index`.
///
/// * `None` -- `index` is out of bounds.
/// * `Some(None)` -- the element is NULL.
/// * `Some(Some(v))` -- the decoded element (the inner value is whatever
///   `T::from_polymorphic_datum` produced).
///
/// The data buffer is walked from the start, hopping once per preceding
/// non-NULL element, so the cost grows with `index`.
// The nested Option is intentional: it separates "out of bounds" from "NULL".
#[allow(clippy::option_option)]
#[inline]
pub fn get(&self, index: usize) -> Option<Option<T>> {
let Some(is_null) = self.null_slice.get(index) else { return None };
if is_null {
return Some(None);
}
let mut at_byte = self.raw.data_ptr();
for i in 0..index {
match self.null_slice.get(i) {
None => unreachable!("array was exceeded while walking to known non-null index???"),
// NULLs do not advance the data pointer.
Some(true) => continue,
Some(false) => {
at_byte = unsafe { self.one_hop_this_time(at_byte) };
}
}
}
// `index` is known to be non-NULL here, and `at_byte` has been advanced past
// every preceding non-NULL element.
Some(unsafe { self.bring_it_back_now(at_byte, false) })
}
/// Decodes the element at `ptr`, or returns `None` when `is_null` is set.
///
/// # Safety
/// NOTE(review): `ptr` presumably must point at the start of a valid element
/// inside this array's data buffer whenever `is_null` is false -- confirm.
#[inline]
unsafe fn bring_it_back_now(&self, ptr: *const u8, is_null: bool) -> Option<T> {
match is_null {
true => None,
false => unsafe { self.slide_impl.bring_it_back_now(self, ptr) },
}
}
/// Returns the pointer advanced past the element that starts at `ptr`.
///
/// # Safety
/// NOTE(review): `ptr` presumably must point at the start of a valid element
/// inside this array's data buffer -- confirm.
#[inline]
unsafe fn one_hop_this_time(&self, ptr: *const u8) -> *const u8 {
unsafe {
let offset = self.slide_impl.hop_size(ptr);
// Debug-only check that the hop does not go past the array's end pointer.
debug_assert!(ptr.wrapping_add(offset) <= self.raw.end_ptr());
ptr.add(offset)
}
}
}
/// Reasons an `Array` cannot be borrowed as a plain Rust slice.
#[derive(thiserror::Error, Debug, Copy, Clone, Eq, PartialEq)]
pub enum ArraySliceError {
/// The array has at least one NULL element.
#[error("Cannot create a slice of an Array that contains nulls")]
ContainsNulls,
}
#[cfg(target_pointer_width = "64")]
impl Array<'_, f64> {
    /// Borrows the array's elements as a `&[f64]`.
    ///
    /// # Errors
    /// Returns [`ArraySliceError::ContainsNulls`] if any element is NULL.
    #[inline]
    pub fn as_slice(&self) -> Result<&[f64], ArraySliceError> {
        as_slice::<f64>(self)
    }
}
impl Array<'_, f32> {
    /// Borrows the array's elements as a `&[f32]`.
    ///
    /// # Errors
    /// Returns [`ArraySliceError::ContainsNulls`] if any element is NULL.
    #[inline]
    pub fn as_slice(&self) -> Result<&[f32], ArraySliceError> {
        as_slice::<f32>(self)
    }
}
#[cfg(target_pointer_width = "64")]
impl Array<'_, i64> {
    /// Borrows the array's elements as a `&[i64]`.
    ///
    /// # Errors
    /// Returns [`ArraySliceError::ContainsNulls`] if any element is NULL.
    #[inline]
    pub fn as_slice(&self) -> Result<&[i64], ArraySliceError> {
        as_slice::<i64>(self)
    }
}
impl Array<'_, i32> {
    /// Borrows the array's elements as a `&[i32]`.
    ///
    /// # Errors
    /// Returns [`ArraySliceError::ContainsNulls`] if any element is NULL.
    #[inline]
    pub fn as_slice(&self) -> Result<&[i32], ArraySliceError> {
        as_slice::<i32>(self)
    }
}
impl Array<'_, i16> {
    /// Borrows the array's elements as a `&[i16]`.
    ///
    /// # Errors
    /// Returns [`ArraySliceError::ContainsNulls`] if any element is NULL.
    #[inline]
    pub fn as_slice(&self) -> Result<&[i16], ArraySliceError> {
        as_slice::<i16>(self)
    }
}
impl Array<'_, i8> {
    /// Borrows the array's elements as a `&[i8]`.
    ///
    /// # Errors
    /// Returns [`ArraySliceError::ContainsNulls`] if any element is NULL.
    #[inline]
    pub fn as_slice(&self) -> Result<&[i8], ArraySliceError> {
        as_slice::<i8>(self)
    }
}
/// Shared implementation behind the typed `Array::as_slice` methods.
///
/// Reinterprets the array's data buffer as `array.len()` contiguous values of `T`.
///
/// # Errors
/// Returns [`ArraySliceError::ContainsNulls`] if the array contains a NULL.
#[inline(always)]
fn as_slice<'a, T: Sized + FromDatum>(array: &'a Array<'_, T>) -> Result<&'a [T], ArraySliceError> {
    if array.contains_nulls() {
        Err(ArraySliceError::ContainsNulls)
    } else {
        let data = array.raw.data_ptr() as *const T;
        let len = array.len();
        // SAFETY (assumed, confirm against callers): only invoked for element
        // types whose in-array representation matches `T`, with no NULL gaps.
        Ok(unsafe { std::slice::from_raw_parts(data, len) })
    }
}
// Element-walking strategies: how to decode one array element from the data
// buffer and how many bytes to advance to reach the next one.
mod casper {
use crate::layout::Align;
use crate::{pg_sys, varlena, Array, FromDatum};
/// Strategy for reading elements out of an array's data buffer.
pub(super) trait ChaChaSlide<T: FromDatum> {
/// Decodes the element starting at `ptr` into a `T`.
///
/// # Safety
/// NOTE(review): `ptr` presumably must point at a valid element of `array` -- confirm.
unsafe fn bring_it_back_now(&self, array: &Array<T>, ptr: *const u8) -> Option<T>;
/// Returns how many bytes to advance from `ptr` to step past this element.
///
/// # Safety
/// NOTE(review): implementations may read through `ptr`, so it presumably
/// must point at a valid element -- confirm.
unsafe fn hop_size(&self, ptr: *const u8) -> usize;
}
/// Returns `true` if the address of `p` is a multiple of `align_of::<T>()`.
#[inline(always)]
fn is_aligned<T>(p: *const T) -> bool {
(p as usize) & (core::mem::align_of::<T>() - 1) == 0
}
/// Reads a `T` by value from `ptr`; alignment is only checked in debug builds.
///
/// # Safety
/// NOTE(review): `ptr` presumably must be valid for reading a `T` and
/// suitably aligned -- confirm.
#[track_caller]
#[inline(always)]
pub(super) unsafe fn byval_read<T: Copy>(ptr: *const u8) -> T {
let ptr = ptr.cast::<T>();
debug_assert!(is_aligned(ptr), "not aligned to {}: {ptr:p}", std::mem::align_of::<T>());
ptr.read()
}
/// Pass-by-value elements exactly `N` bytes wide.
pub(super) struct FixedSizeByVal<const N: usize>;
impl<T: FromDatum, const N: usize> ChaChaSlide<T> for FixedSizeByVal<N> {
#[inline(always)]
unsafe fn bring_it_back_now(&self, array: &Array<T>, ptr: *const u8) -> Option<T> {
// Read the value as an unsigned integer of width `N` and widen it into a Datum.
let datum = match N {
1 => pg_sys::Datum::from(byval_read::<u8>(ptr)),
2 => pg_sys::Datum::from(byval_read::<u16>(ptr)),
4 => pg_sys::Datum::from(byval_read::<u32>(ptr)),
8 => pg_sys::Datum::from(byval_read::<u64>(ptr)),
_ => unreachable!("`N` must be 1, 2, 4, or 8 (got {N})"),
};
T::from_polymorphic_datum(datum, false, array.raw.oid())
}
// Fixed-width elements advance by exactly their size.
#[inline(always)]
unsafe fn hop_size(&self, _ptr: *const u8) -> usize {
N
}
}
/// Pass-by-reference varlena elements, padded to `align`.
pub(super) struct PassByVarlena {
pub(super) align: Align,
}
impl<T: FromDatum> ChaChaSlide<T> for PassByVarlena {
#[inline]
unsafe fn bring_it_back_now(&self, array: &Array<T>, ptr: *const u8) -> Option<T> {
// The Datum is the pointer to the element itself.
let datum = pg_sys::Datum::from(ptr);
unsafe { T::from_polymorphic_datum(datum, false, array.raw.oid()) }
}
#[inline]
unsafe fn hop_size(&self, ptr: *const u8) -> usize {
// The element's own size, padded up to the element alignment.
let varsize = varlena::varsize_any(ptr.cast());
self.align.pad(varsize)
}
}
/// Pass-by-reference NUL-terminated C string elements.
pub(super) struct PassByCStr;
impl<T: FromDatum> ChaChaSlide<T> for PassByCStr {
#[inline]
unsafe fn bring_it_back_now(&self, array: &Array<T>, ptr: *const u8) -> Option<T> {
// The Datum is the pointer to the element itself.
let datum = pg_sys::Datum::from(ptr);
unsafe { T::from_polymorphic_datum(datum, false, array.raw.oid()) }
}
#[inline]
unsafe fn hop_size(&self, ptr: *const u8) -> usize {
let strlen = core::ffi::CStr::from_ptr(ptr.cast()).to_bytes().len();
// `to_bytes` excludes the terminating NUL, so add one to step past it.
strlen + 1
}
}
/// Pass-by-reference fixed-size elements; `padded_size` is the per-element stride.
pub(super) struct PassByFixed {
pub(super) padded_size: usize,
}
impl<T: FromDatum> ChaChaSlide<T> for PassByFixed {
#[inline]
unsafe fn bring_it_back_now(&self, array: &Array<T>, ptr: *const u8) -> Option<T> {
// The Datum is the pointer to the element itself.
let datum = pg_sys::Datum::from(ptr);
unsafe { T::from_polymorphic_datum(datum, false, array.raw.oid()) }
}
#[inline]
unsafe fn hop_size(&self, _ptr: *const u8) -> usize {
self.padded_size
}
}
}
/// Newtype around [`Array`] whose `SqlTranslatable::variadic()` returns `true`.
pub struct VariadicArray<'a, T: FromDatum>(Array<'a, T>);
impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for VariadicArray<'a, T> {
fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
where
S: Serializer,
{
serializer.collect_seq(self.0.iter())
}
}
impl<'a, T: FromDatum> VariadicArray<'a, T> {
#[inline]
pub fn into_array_type(self) -> *const pg_sys::ArrayType {
self.0.into_array_type()
}
#[inline]
pub fn iter(&self) -> ArrayIterator<'_, T> {
self.0.iter()
}
#[inline]
pub fn iter_deny_null(&self) -> ArrayTypedIterator<'_, T> {
self.0.iter_deny_null()
}
#[inline]
pub fn len(&self) -> usize {
self.0.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
#[allow(clippy::option_option)]
#[inline]
pub fn get(&self, i: usize) -> Option<Option<T>> {
self.0.get(i)
}
}
/// Borrowing iterator that yields `T` directly; created by `Array::iter_deny_null`,
/// which panics up front if the array contains a NULL.
pub struct ArrayTypedIterator<'a, T: 'a + FromDatum> {
// The array being iterated.
array: &'a Array<'a, T>,
// Index of the next element to yield.
curr: usize,
// Position of the next element inside the array's data buffer.
ptr: *const u8,
}
impl<'a, T: FromDatum> Iterator for ArrayTypedIterator<'a, T> {
    type Item = T;

    /// Decodes the element under the cursor and then steps past it.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.curr >= self.array.raw.len() {
            return None;
        }
        let here = self.ptr;
        // The constructor rejects arrays containing NULLs, so every element is
        // treated as non-null and occupies space in the data buffer.
        let element = unsafe { self.array.bring_it_back_now(here, false) };
        self.curr += 1;
        self.ptr = unsafe { self.array.one_hop_this_time(here) };
        element
    }

    /// Remaining element count, as both the lower and the upper bound.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.array.raw.len().saturating_sub(self.curr);
        (remaining, Some(remaining))
    }
}
// Marker impls: `size_hint` returns the remaining count as both bounds, and
// `next` returns `None` once `curr` reaches the array length.
// NOTE(review): `next` also returns `None` if element decoding yields `None`
// before the end -- confirm that cannot happen for a non-NULL element.
impl<'a, T: FromDatum> ExactSizeIterator for ArrayTypedIterator<'a, T> {}
impl<'a, T: FromDatum> core::iter::FusedIterator for ArrayTypedIterator<'a, T> {}
impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for ArrayTypedIterator<'a, T> {
fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
where
S: Serializer,
{
serializer.collect_seq(self.array.iter())
}
}
/// Borrowing iterator over an [`Array`] that yields `Option<T>` (`None` for NULLs).
pub struct ArrayIterator<'a, T: 'a + FromDatum> {
// The array being iterated.
array: &'a Array<'a, T>,
// Index of the next element to yield.
curr: usize,
// Position of the next non-NULL element inside the array's data buffer.
ptr: *const u8,
}
impl<'a, T: FromDatum> Iterator for ArrayIterator<'a, T> {
    type Item = Option<T>;

    /// Yields the next element, or `None` once the NULL map reports end-of-array.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let is_null = self.array.null_slice.get(self.curr)?;
        self.curr += 1;

        let here = self.ptr;
        let element = unsafe { self.array.bring_it_back_now(here, is_null) };
        // The data pointer only moves after a non-NULL element; NULLs do not
        // advance it.
        if !is_null {
            self.ptr = unsafe { self.array.one_hop_this_time(here) };
        }
        Some(element)
    }

    /// Remaining element count, as both the lower and the upper bound.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.array.raw.len().saturating_sub(self.curr);
        (remaining, Some(remaining))
    }
}
// Marker impls: `size_hint` returns the remaining count as both bounds, and
// `next` returns `None` only when the NULL map has no entry for `curr`.
impl<'a, T: FromDatum> ExactSizeIterator for ArrayIterator<'a, T> {}
impl<'a, T: FromDatum> core::iter::FusedIterator for ArrayIterator<'a, T> {}
/// Owning iterator over an [`Array`] that yields `Option<T>` (`None` for NULLs).
pub struct ArrayIntoIterator<'a, T: FromDatum> {
// The array being consumed; owned so the data buffer stays alive while iterating.
array: Array<'a, T>,
// Index of the next element to yield.
curr: usize,
// Position of the next non-NULL element inside the array's data buffer.
ptr: *const u8,
}
impl<'a, T: FromDatum> IntoIterator for Array<'a, T> {
    type Item = Option<T>;
    type IntoIter = ArrayIntoIterator<'a, T>;

    /// Consumes the array into an owning iterator starting at the first element.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let start = self.raw.data_ptr();
        ArrayIntoIterator { ptr: start, curr: 0, array: self }
    }
}
impl<'a, T: FromDatum> IntoIterator for VariadicArray<'a, T> {
type Item = Option<T>;
type IntoIter = ArrayIntoIterator<'a, T>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
let ptr = self.0.raw.data_ptr();
ArrayIntoIterator { array: self.0, curr: 0, ptr }
}
}
impl<'a, T: FromDatum> Iterator for ArrayIntoIterator<'a, T> {
    type Item = Option<T>;

    /// Yields the next element, or `None` once the NULL map reports end-of-array.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let is_null = self.array.null_slice.get(self.curr)?;
        self.curr += 1;

        let here = self.ptr;
        let element = unsafe { self.array.bring_it_back_now(here, is_null) };
        // The data pointer only moves after a non-NULL element; NULLs do not
        // advance it.
        if !is_null {
            self.ptr = unsafe { self.array.one_hop_this_time(here) };
        }
        Some(element)
    }

    /// Remaining element count, as both the lower and the upper bound.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.array.raw.len().saturating_sub(self.curr);
        (remaining, Some(remaining))
    }
}
// Marker impls: `size_hint` returns the remaining count as both bounds, and
// `next` returns `None` only when the NULL map has no entry for `curr`.
impl<'a, T: FromDatum> ExactSizeIterator for ArrayIntoIterator<'a, T> {}
impl<'a, T: FromDatum> core::iter::FusedIterator for ArrayIntoIterator<'a, T> {}
impl<'a, T: FromDatum> FromDatum for VariadicArray<'a, T> {
#[inline]
unsafe fn from_polymorphic_datum(
datum: pg_sys::Datum,
is_null: bool,
oid: pg_sys::Oid,
) -> Option<VariadicArray<'a, T>> {
Array::from_polymorphic_datum(datum, is_null, oid).map(Self)
}
}
impl<'a, T: FromDatum> FromDatum for Array<'a, T> {
    /// Detoasts the varlena behind `datum` and wraps it as an `Array`.
    ///
    /// Returns `None` when `is_null` is set or the datum is a null pointer.
    /// The supplied type oid is not consulted.
    #[inline]
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        _typoid: pg_sys::Oid,
    ) -> Option<Array<'a, T>> {
        if is_null {
            return None;
        }
        let ptr = NonNull::new(datum.cast_mut_ptr())?;
        let raw = RawArray::detoast_from_varlena(ptr);
        Some(Array::deconstruct_from(raw))
    }

    /// Like `from_polymorphic_datum`, but first makes a detoasted copy of the
    /// datum while switched into `memory_context`, and builds the array from
    /// that copy.
    unsafe fn from_datum_in_memory_context(
        mut memory_context: PgMemoryContexts,
        datum: pg_sys::Datum,
        is_null: bool,
        typoid: pg_sys::Oid,
    ) -> Option<Self>
    where
        Self: Sized,
    {
        if is_null {
            return None;
        }
        memory_context.switch_to(|_| {
            // Copy into the target context, then wrap the copy.
            let copied = pg_sys::pg_detoast_datum_copy(datum.cast_mut_ptr());
            let copied_datum = pg_sys::Datum::from(copied);
            Array::<T>::from_polymorphic_datum(copied_datum, false, typoid)
        })
    }
}
impl<T: IntoDatum + FromDatum> IntoDatum for Array<'_, T> {
    /// Converts the array into a datum pointing at its underlying `ArrayType`.
    #[inline]
    fn into_datum(self) -> Option<Datum> {
        Some(Datum::from(self.into_array_type()))
    }

    /// The array type oid reported by the element type.
    #[inline]
    fn type_oid() -> Oid {
        T::array_type_oid()
    }

    /// Array type oid derived from the first element's composite type oid, if any.
    ///
    /// Returns `None` when the array is empty or the first element reports no
    /// composite type.
    fn composite_type_oid(&self) -> Option<Oid> {
        let first = self.get(0)?;
        let element_oid = first.composite_type_oid()?;
        Some(unsafe { pg_sys::get_array_type(element_oid) })
    }
}
impl<T: FromDatum> FromDatum for Vec<T> {
    /// Decodes an array datum into a `Vec<T>`.
    ///
    /// Panics (via `Array::iter_deny_null`) if the array contains a NULL element.
    #[inline]
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        typoid: pg_sys::Oid,
    ) -> Option<Vec<T>> {
        if is_null {
            return None;
        }
        let array = Array::<T>::from_polymorphic_datum(datum, is_null, typoid)?;
        Some(array.iter_deny_null().collect::<Vec<_>>())
    }

    /// Same as `from_polymorphic_datum`, but the array is first built via
    /// `Array::from_datum_in_memory_context` with the given memory context.
    unsafe fn from_datum_in_memory_context(
        memory_context: PgMemoryContexts,
        datum: pg_sys::Datum,
        is_null: bool,
        typoid: pg_sys::Oid,
    ) -> Option<Self>
    where
        Self: Sized,
    {
        let array =
            Array::<T>::from_datum_in_memory_context(memory_context, datum, is_null, typoid)?;
        Some(array.iter_deny_null().collect::<Vec<_>>())
    }
}
impl<T: FromDatum> FromDatum for Vec<Option<T>> {
    /// Decodes an array datum into a `Vec<Option<T>>`, with `None` for NULL elements.
    #[inline]
    unsafe fn from_polymorphic_datum(
        datum: pg_sys::Datum,
        is_null: bool,
        typoid: pg_sys::Oid,
    ) -> Option<Vec<Option<T>>> {
        let array = Array::<T>::from_polymorphic_datum(datum, is_null, typoid)?;
        Some(array.iter().collect::<Vec<_>>())
    }

    /// Same as `from_polymorphic_datum`, but the array is first built via
    /// `Array::from_datum_in_memory_context` with the given memory context.
    unsafe fn from_datum_in_memory_context(
        memory_context: PgMemoryContexts,
        datum: pg_sys::Datum,
        is_null: bool,
        typoid: pg_sys::Oid,
    ) -> Option<Self>
    where
        Self: Sized,
    {
        let array =
            Array::<T>::from_datum_in_memory_context(memory_context, datum, is_null, typoid)?;
        Some(array.iter().collect::<Vec<_>>())
    }
}
impl<T> IntoDatum for Vec<T>
where
    T: IntoDatum,
{
    /// Builds a Postgres array datum from the vector's elements.
    ///
    /// Each element is converted with `IntoDatum::into_datum`; an element that
    /// converts to `None` is accumulated as a NULL. Returns `None` only if the
    /// accumulator state is a null pointer.
    fn into_datum(self) -> Option<pg_sys::Datum> {
        let initial = unsafe {
            pg_sys::initArrayResult(
                T::type_oid(),
                PgMemoryContexts::CurrentMemoryContext.value(),
                false,
            )
        };

        // Thread the accumulator state through every element in order.
        let state = self.into_iter().fold(initial, |state, element| {
            let datum = element.into_datum();
            let isnull = datum.is_none();
            unsafe {
                pg_sys::accumArrayResult(
                    state,
                    datum.unwrap_or(0.into()),
                    isnull,
                    T::type_oid(),
                    PgMemoryContexts::CurrentMemoryContext.value(),
                )
            }
        });

        if state.is_null() {
            return None;
        }
        Some(unsafe {
            pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
        })
    }

    /// The array type oid corresponding to the element type's oid.
    fn type_oid() -> pg_sys::Oid {
        unsafe { pg_sys::get_array_type(T::type_oid()) }
    }

    /// Array type oid derived from the first element's composite type oid, if any.
    fn composite_type_oid(&self) -> Option<Oid> {
        let first = self.first()?;
        let element_oid = first.composite_type_oid()?;
        Some(unsafe { pg_sys::get_array_type(element_oid) })
    }

    /// Whether `other` equals this type's array oid.
    #[inline]
    fn is_compatible_with(other: pg_sys::Oid) -> bool {
        if Self::type_oid() == other {
            return true;
        }
        other == unsafe { pg_sys::get_array_type(T::type_oid()) }
    }
}
impl<'a, T> IntoDatum for &'a [T]
where
    T: IntoDatum + Copy + 'a,
{
    /// Builds a Postgres array datum from the slice's elements.
    ///
    /// Each element is converted with `IntoDatum::into_datum`; an element that
    /// converts to `None` is accumulated as a NULL. Returns `None` only if the
    /// accumulator state is a null pointer.
    fn into_datum(self) -> Option<pg_sys::Datum> {
        let initial = unsafe {
            pg_sys::initArrayResult(
                T::type_oid(),
                PgMemoryContexts::CurrentMemoryContext.value(),
                false,
            )
        };

        // Thread the accumulator state through every element in order.
        let state = self.iter().fold(initial, |state, element| {
            let datum = element.into_datum();
            let isnull = datum.is_none();
            unsafe {
                pg_sys::accumArrayResult(
                    state,
                    datum.unwrap_or(0.into()),
                    isnull,
                    T::type_oid(),
                    PgMemoryContexts::CurrentMemoryContext.value(),
                )
            }
        });

        if state.is_null() {
            return None;
        }
        Some(unsafe {
            pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
        })
    }

    /// The array type oid corresponding to the element type's oid.
    fn type_oid() -> pg_sys::Oid {
        unsafe { pg_sys::get_array_type(T::type_oid()) }
    }

    /// Whether `other` equals this type's array oid.
    #[inline]
    fn is_compatible_with(other: pg_sys::Oid) -> bool {
        if Self::type_oid() == other {
            return true;
        }
        other == unsafe { pg_sys::get_array_type(T::type_oid()) }
    }
}
unsafe impl<'a, T> SqlTranslatable for Array<'a, T>
where
T: SqlTranslatable + FromDatum,
{
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
match T::argument_sql()? {
SqlMapping::As(sql) => Ok(SqlMapping::As(format!("{sql}[]"))),
SqlMapping::Skip => Err(ArgumentError::SkipInArray),
SqlMapping::Composite { .. } => Ok(SqlMapping::Composite { array_brackets: true }),
SqlMapping::Source { .. } => Ok(SqlMapping::Source { array_brackets: true }),
}
}
fn return_sql() -> Result<Returns, ReturnsError> {
match T::return_sql()? {
Returns::One(SqlMapping::As(sql)) => {
Ok(Returns::One(SqlMapping::As(format!("{sql}[]"))))
}
Returns::One(SqlMapping::Composite { array_brackets: _ }) => {
Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
}
Returns::One(SqlMapping::Source { array_brackets: _ }) => {
Ok(Returns::One(SqlMapping::Source { array_brackets: true }))
}
Returns::One(SqlMapping::Skip) => Err(ReturnsError::SkipInArray),
Returns::SetOf(_) => Err(ReturnsError::SetOfInArray),
Returns::Table(_) => Err(ReturnsError::TableInArray),
}
}
}
unsafe impl<'a, T> SqlTranslatable for VariadicArray<'a, T>
where
T: SqlTranslatable + FromDatum,
{
fn argument_sql() -> Result<SqlMapping, ArgumentError> {
match T::argument_sql()? {
SqlMapping::As(sql) => Ok(SqlMapping::As(format!("{sql}[]"))),
SqlMapping::Skip => Err(ArgumentError::SkipInArray),
SqlMapping::Composite { .. } => Ok(SqlMapping::Composite { array_brackets: true }),
SqlMapping::Source { .. } => Ok(SqlMapping::Source { array_brackets: true }),
}
}
fn return_sql() -> Result<Returns, ReturnsError> {
match T::return_sql()? {
Returns::One(SqlMapping::As(sql)) => {
Ok(Returns::One(SqlMapping::As(format!("{sql}[]"))))
}
Returns::One(SqlMapping::Composite { array_brackets: _ }) => {
Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
}
Returns::One(SqlMapping::Source { array_brackets: _ }) => {
Ok(Returns::One(SqlMapping::Source { array_brackets: true }))
}
Returns::One(SqlMapping::Skip) => Err(ReturnsError::SkipInArray),
Returns::SetOf(_) => Err(ReturnsError::SetOfInArray),
Returns::Table(_) => Err(ReturnsError::TableInArray),
}
}
fn variadic() -> bool {
true
}
}