pgx/datum/
array.rs

1/*
2Portions Copyright 2019-2021 ZomboDB, LLC.
3Portions Copyright 2021-2022 Technology Concepts & Design, Inc. <support@tcdi.com>
4
5All rights reserved.
6
7Use of this source code is governed by the MIT license that can be found in the LICENSE file.
8*/
9
10use crate::array::RawArray;
11use crate::layout::*;
12use crate::slice::PallocSlice;
13use crate::{pg_sys, FromDatum, IntoDatum, PgMemoryContexts};
14use bitvec::slice::BitSlice;
15use core::ptr::NonNull;
16use pgx_sql_entity_graph::metadata::{
17    ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
18};
19use serde::Serializer;
20use std::marker::PhantomData;
21use std::{mem, ptr};
22
23/** An array of some type (eg. `TEXT[]`, `int[]`)
24
25While conceptually similar to a [`Vec<T>`][std::vec::Vec], arrays are lazy.
26
27Using a [`Vec<T>`][std::vec::Vec] here means each element of the passed array will be eagerly fetched and converted into a Rust type:
28
29```rust,no_run
30use pgx::prelude::*;
31
32#[pg_extern]
33fn with_vec(elems: Vec<String>) {
34    // Elements all already converted.
35    for elem in elems {
36        todo!()
37    }
38}
39```
40
41Using an array, elements are only fetched and converted into a Rust type on demand:
42
43```rust,no_run
44use pgx::prelude::*;
45
46#[pg_extern]
47fn with_vec(elems: Array<String>) {
48    // Elements converted one by one
49    for maybe_elem in elems {
50        let elem = maybe_elem.unwrap();
51        todo!()
52    }
53}
54```
55*/
pub struct Array<'a, T: FromDatum> {
    // The underlying Postgres array. Becomes `None` once the pointer has been taken out,
    // which `Drop` and `into_array_type` both do.
    raw: Option<RawArray>,
    // Element count reported by `pg_sys::deconstruct_array` (asserted equal to `RawArray::len`).
    nelems: usize,
    // Remove this field if/when we figure out how to stop using pg_sys::deconstruct_array
    // Holds the palloc'd Datum buffer that `deconstruct_array` handed back.
    datum_slice: Option<PallocSlice<pg_sys::Datum>>,
    // True when detoasting produced a different pointer than the incoming varlena,
    // in which case `Drop` pfrees the `ArrayType`.
    needs_pfree: bool,
    // Per-element null information (see `NullKind`).
    null_slice: NullKind<'a>,
    // Layout of the element type, as passed to `deconstruct_array`.
    elem_layout: Layout,
    // Ties the element type `T` to this array without storing a `T`.
    _marker: PhantomData<T>,
}
66
67// FIXME: When Array::over gets removed, this enum can probably be dropped
68// since we won't be entertaining ArrayTypes which don't use bitslices anymore.
69// However, we could also use a static resolution? Hard to say what's best.
70enum NullKind<'a> {
71    Bits(&'a BitSlice<u8>),
72    Strict(usize),
73}
74
75impl NullKind<'_> {
76    fn get(&self, index: usize) -> Option<bool> {
77        match self {
78            Self::Bits(b1) => b1.get(index).map(|b| !b),
79            Self::Strict(len) => index.le(len).then(|| false),
80        }
81    }
82
83    fn any(&self) -> bool {
84        match self {
85            Self::Bits(b1) => !b1.all(),
86            Self::Strict(_) => false,
87        }
88    }
89}
90
91impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for Array<'a, T> {
92    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
93    where
94        S: Serializer,
95    {
96        serializer.collect_seq(self.iter())
97    }
98}
99
100impl<'a, T: FromDatum> Drop for Array<'a, T> {
101    fn drop(&mut self) {
102        // First drop the slice that references the RawArray
103        let slice = mem::take(&mut self.datum_slice);
104        mem::drop(slice);
105        if self.needs_pfree {
106            if let Some(raw) = self.raw.take().map(|r| r.into_ptr()) {
107                // SAFETY: if pgx detoasted a clone of this varlena, pfree the clone
108                unsafe { pg_sys::pfree(raw.as_ptr().cast()) }
109            }
110        }
111    }
112}
113
114#[deny(unsafe_op_in_unsafe_fn)]
115impl<'a, T: FromDatum> Array<'a, T> {
    /// Builds an [`Array`] by handing `raw` to `pg_sys::deconstruct_array`.
    ///
    /// `ptr` is the varlena pointer the caller started from. It is compared against the
    /// `ArrayType` pointer held by `raw` to decide whether that `ArrayType` must be pfree'd
    /// when the `Array` is dropped. `elem_layout` supplies the typlen, pass-by-value and
    /// typalign arguments given to `deconstruct_array`.
    ///
    /// # Panics
    ///
    /// Panics if the element count reported by `deconstruct_array` differs from `raw.len()`,
    /// or if `deconstruct_array` leaves either of its output pointers null.
    ///
    /// # Safety
    ///
    /// This function requires that the RawArray was obtained in a properly-constructed form
    /// (probably from Postgres).
    unsafe fn deconstruct_from(
        ptr: NonNull<pg_sys::varlena>,
        raw: RawArray,
        elem_layout: Layout,
    ) -> Array<'a, T> {
        // Read everything we need from `raw` before `into_ptr` consumes it.
        let oid = raw.oid();
        let len = raw.len();
        let array = raw.into_ptr().as_ptr();

        // outvals for deconstruct_array
        let mut elements = ptr::null_mut();
        let mut nulls = ptr::null_mut();
        let mut nelems = 0;

        /*
        FIXME(jubilee): This way of getting array buffers causes problems for any Drop impl,
        and clashes with assumptions of Array being a "zero-copy", lifetime-bound array,
        some of which are implicitly embedded in other methods (e.g. Array::over).
        It also risks leaking memory, as deconstruct_array calls palloc.

        SAFETY: We have already asserted the validity of the RawArray, so
        this only makes mistakes if we mix things up and pass Postgres the wrong data.
        */
        unsafe {
            pg_sys::deconstruct_array(
                array,
                oid,
                elem_layout.size.as_typlen().into(),
                matches!(elem_layout.pass, PassBy::Value),
                elem_layout.align.as_typalign(),
                &mut elements,
                &mut nulls,
                &mut nelems,
            )
        };

        let nelems = nelems as usize;

        // Check our RawArray len impl for correctness.
        assert_eq!(nelems, len);

        // Detoasting the varlena may have allocated a fresh ArrayType.
        // Check pointer equivalence, then pfree the palloc later if it's a new one.
        let needs_pfree = ptr.as_ptr().cast() != array;
        // Rebuild the RawArray from the pointer that was taken out of it above.
        let mut raw = unsafe { RawArray::from_ptr(NonNull::new_unchecked(array)) };

        // Use the array's own null bitmap when it has one; otherwise fall back to
        // `NullKind::Strict`, which reports every element as non-null.
        let null_slice = raw
            .nulls_bitslice()
            .map(|nonnull| NullKind::Bits(unsafe { &*nonnull.as_ptr() }))
            .unwrap_or(NullKind::Strict(nelems));

        // The array was just deconstructed, which allocates twice: effectively [Datum] and [bool].
        // But pgx doesn't actually need [bool] if NullKind's handling of BitSlices is correct.
        // So, assert correctness of the NullKind implementation and cleanup.
        // SAFETY: The pointer we got should be correctly constructed for slice validity.
        let pallocd_null_slice =
            unsafe { PallocSlice::from_raw_parts(NonNull::new(nulls).unwrap(), nelems) };
        // Debug builds cross-check NullKind against the [bool] buffer Postgres produced.
        #[cfg(debug_assertions)]
        for i in 0..nelems {
            assert!(null_slice.get(i).unwrap().eq(unsafe { pallocd_null_slice.get_unchecked(i) }));
        }

        // SAFETY: This was just handed over as a palloc, so of course we can do this.
        let datum_slice =
            Some(unsafe { PallocSlice::from_raw_parts(NonNull::new(elements).unwrap(), nelems) });

        Array {
            needs_pfree,
            raw: Some(raw),
            nelems,
            datum_slice,
            null_slice,
            elem_layout,
            _marker: PhantomData,
        }
    }
196
197    pub fn into_array_type(mut self) -> *const pg_sys::ArrayType {
198        let ptr = mem::take(&mut self.raw).map(|raw| raw.into_ptr().as_ptr() as _);
199        mem::forget(self);
200        ptr.unwrap_or(ptr::null())
201    }
202
    /// Exposes the array's data as a `&[T]`.
    ///
    /// # Panics
    ///
    /// Panics if it detects the slightest misalignment between types,
    /// or if a valid slice contains nulls, which may be uninit data.
    /// Concretely: it panics when the null information reports any null element, when
    /// `elem_layout.size_matches::<T>()` does not report a size of 1, 2, 4, or
    /// `size_of::<pg_sys::Datum>()`, or when this `Array` holds no `RawArray`.
    #[deprecated(
        since = "0.5.0",
        note = "this function cannot be safe and is not generically sound\n\
        even `unsafe fn as_slice(&self) -> &[T]` is not sound for all `&[T]`\n\
        if you are sure your usage is sound, consider RawArray"
    )]
    pub fn as_slice(&self) -> &[T] {
        const DATUM_SIZE: usize = mem::size_of::<pg_sys::Datum>();
        // Refuse outright if any element is null.
        if self.null_slice.any() {
            panic!("null detected: can't expose potentially uninit data as a slice!")
        }
        match (self.elem_layout.size_matches::<T>(), self.raw.as_ref()) {
            // SAFETY: Rust slice layout matches Postgres data layout and this array is "owned"
            #[allow(unreachable_patterns)] // happens on 32-bit when DATUM_SIZE = 4
            (Some(1 | 2 | 4 | DATUM_SIZE), Some(raw)) => unsafe {
                raw.assume_init_data_slice::<T>()
            },
            // Either the element size is unsupported or there is no RawArray to borrow from.
            (_, _) => panic!("no correctly-sized slice exists"),
        }
    }
227
    /// Return an Iterator of `Option<T>` over the contained Datums.
    ///
    /// The iterator borrows this array and starts at the first element.
    pub fn iter(&self) -> ArrayIterator<'_, T> {
        ArrayIterator { array: self, curr: 0 }
    }
232
233    /// Return an Iterator of the contained Datums (converted to Rust types).
234    ///
235    /// This function will panic when called if the array contains any SQL NULL values.
236    pub fn iter_deny_null(&self) -> ArrayTypedIterator<'_, T> {
237        if let Some(at) = &self.raw {
238            // SAFETY: if Some, then the ArrayType is from Postgres
239            if unsafe { at.any_nulls() } {
240                panic!("array contains NULL");
241            }
242        } else {
243            panic!("array is NULL");
244        };
245
246        ArrayTypedIterator { array: self, curr: 0 }
247    }
248
    /// Returns the number of elements in the array.
    #[inline]
    pub fn len(&self) -> usize {
        self.nelems
    }
253
    /// Returns `true` if the array has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nelems == 0
    }
258
259    #[allow(clippy::option_option)]
260    #[inline]
261    pub fn get(&self, i: usize) -> Option<Option<T>> {
262        if i >= self.nelems {
263            None
264        } else {
265            Some(unsafe {
266                T::from_polymorphic_datum(
267                    *(self.datum_slice.as_ref()?.get(i)?),
268                    self.null_slice.get(i)?,
269                    self.raw.as_ref().map(|r| r.oid()).unwrap_or_default(),
270                )
271            })
272        }
273    }
274}
275
/// A thin wrapper around [`Array`] whose `SqlTranslatable` implementation reports
/// `variadic() == true`.
pub struct VariadicArray<'a, T: FromDatum>(Array<'a, T>);
277
278impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for VariadicArray<'a, T> {
279    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
280    where
281        S: Serializer,
282    {
283        serializer.collect_seq(self.0.iter())
284    }
285}
286
// Every method here forwards to the same-named method on the wrapped `Array`.
impl<'a, T: FromDatum> VariadicArray<'a, T> {
    /// Consumes the wrapper and returns the wrapped array's `ArrayType` pointer.
    pub fn into_array_type(self) -> *const pg_sys::ArrayType {
        self.0.into_array_type()
    }

    /// Exposes the wrapped array's data as a `&[T]`.
    ///
    /// # Panics
    ///
    /// Panics if it detects the slightest misalignment between types,
    /// or if a valid slice contains nulls, which may be uninit data.
    #[deprecated(
        since = "0.5.0",
        note = "this function cannot be safe and is not generically sound\n\
        even `unsafe fn as_slice(&self) -> &[T]` is not sound for all `&[T]`\n\
        if you are sure your usage is sound, consider RawArray"
    )]
    #[allow(deprecated)]
    pub fn as_slice(&self) -> &[T] {
        self.0.as_slice()
    }

    /// Return an Iterator of `Option<T>` over the contained Datums.
    pub fn iter(&self) -> ArrayIterator<'_, T> {
        self.0.iter()
    }

    /// Return an Iterator of the contained Datums (converted to Rust types).
    ///
    /// This function will panic when called if the array contains any SQL NULL values.
    pub fn iter_deny_null(&self) -> ArrayTypedIterator<'_, T> {
        self.0.iter_deny_null()
    }

    /// Returns the number of elements in the wrapped array.
    #[inline]
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// Returns `true` if the wrapped array has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Fetches and converts the element at index `i`; `None` when `i` is out of bounds.
    #[allow(clippy::option_option)]
    #[inline]
    pub fn get(&self, i: usize) -> Option<Option<T>> {
        self.0.get(i)
    }
}
335
/// Borrowing iterator over an [`Array`] that yields `T` directly and panics if it
/// encounters a NULL element.
pub struct ArrayTypedIterator<'a, T: 'a + FromDatum> {
    // The array being iterated.
    array: &'a Array<'a, T>,
    // Index of the next element to yield.
    curr: usize,
}
340
341impl<'a, T: FromDatum> Iterator for ArrayTypedIterator<'a, T> {
342    type Item = T;
343
344    #[inline]
345    fn next(&mut self) -> Option<Self::Item> {
346        if self.curr >= self.array.nelems {
347            None
348        } else {
349            let element = self
350                .array
351                .get(self.curr)
352                .expect("array index out of bounds")
353                .expect("array element was unexpectedly NULL during iteration");
354            self.curr += 1;
355            Some(element)
356        }
357    }
358}
359
360impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for ArrayTypedIterator<'a, T> {
361    fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
362    where
363        S: Serializer,
364    {
365        serializer.collect_seq(self.array.iter())
366    }
367}
368
/// Borrowing iterator over an [`Array`] that yields each element as an `Option<T>`.
pub struct ArrayIterator<'a, T: 'a + FromDatum> {
    // The array being iterated.
    array: &'a Array<'a, T>,
    // Index of the next element to yield.
    curr: usize,
}
373
374impl<'a, T: FromDatum> Iterator for ArrayIterator<'a, T> {
375    type Item = Option<T>;
376
377    #[inline]
378    fn next(&mut self) -> Option<Self::Item> {
379        if self.curr >= self.array.nelems {
380            None
381        } else {
382            let element = self.array.get(self.curr).unwrap();
383            self.curr += 1;
384            Some(element)
385        }
386    }
387}
388
/// Owning iterator over an [`Array`] that yields each element as an `Option<T>`.
pub struct ArrayIntoIterator<'a, T: FromDatum> {
    // The array being iterated; owned, so it is dropped together with the iterator.
    array: Array<'a, T>,
    // Index of the next element to yield.
    curr: usize,
}
393
impl<'a, T: FromDatum> IntoIterator for Array<'a, T> {
    type Item = Option<T>;
    type IntoIter = ArrayIntoIterator<'a, T>;

    /// Consumes the array and returns an owning iterator positioned at the first element.
    fn into_iter(self) -> Self::IntoIter {
        ArrayIntoIterator { array: self, curr: 0 }
    }
}
402
impl<'a, T: FromDatum> IntoIterator for VariadicArray<'a, T> {
    type Item = Option<T>;
    type IntoIter = ArrayIntoIterator<'a, T>;

    /// Unwraps the inner `Array` and returns an owning iterator positioned at its first element.
    fn into_iter(self) -> Self::IntoIter {
        ArrayIntoIterator { array: self.0, curr: 0 }
    }
}
411
412impl<'a, T: FromDatum> Iterator for ArrayIntoIterator<'a, T> {
413    type Item = Option<T>;
414
415    #[inline]
416    fn next(&mut self) -> Option<Self::Item> {
417        if self.curr >= self.array.nelems {
418            None
419        } else {
420            let element = self.array.get(self.curr).unwrap();
421            self.curr += 1;
422            Some(element)
423        }
424    }
425
426    fn size_hint(&self) -> (usize, Option<usize>) {
427        (0, Some(self.array.nelems))
428    }
429
430    fn count(self) -> usize
431    where
432        Self: Sized,
433    {
434        self.array.nelems
435    }
436
437    fn nth(&mut self, n: usize) -> Option<Self::Item> {
438        self.array.get(n)
439    }
440}
441
442impl<'a, T: FromDatum> FromDatum for VariadicArray<'a, T> {
443    #[inline]
444    unsafe fn from_polymorphic_datum(
445        datum: pg_sys::Datum,
446        is_null: bool,
447        oid: pg_sys::Oid,
448    ) -> Option<VariadicArray<'a, T>> {
449        Array::from_polymorphic_datum(datum, is_null, oid).map(Self)
450    }
451}
452
453impl<'a, T: FromDatum> FromDatum for Array<'a, T> {
454    #[inline]
455    unsafe fn from_polymorphic_datum(
456        datum: pg_sys::Datum,
457        is_null: bool,
458        _typoid: pg_sys::Oid,
459    ) -> Option<Array<'a, T>> {
460        if is_null {
461            None
462        } else {
463            let ptr = NonNull::new(datum.cast_mut_ptr())?;
464            let array = pg_sys::pg_detoast_datum(datum.cast_mut_ptr()) as *mut pg_sys::ArrayType;
465            let raw =
466                RawArray::from_ptr(NonNull::new(array).expect("detoast returned null ArrayType*"));
467            let oid = raw.oid();
468            let layout = Layout::lookup_oid(oid);
469
470            Some(Array::deconstruct_from(ptr, raw, layout))
471        }
472    }
473}
474
475impl<T: FromDatum> FromDatum for Vec<T> {
476    #[inline]
477    unsafe fn from_polymorphic_datum(
478        datum: pg_sys::Datum,
479        is_null: bool,
480        typoid: pg_sys::Oid,
481    ) -> Option<Vec<T>> {
482        if is_null {
483            None
484        } else {
485            let array = Array::<T>::from_polymorphic_datum(datum, is_null, typoid).unwrap();
486            let mut v = Vec::with_capacity(array.len());
487
488            for element in array.iter() {
489                v.push(element.expect("array element was NULL"))
490            }
491            Some(v)
492        }
493    }
494}
495
496impl<T: FromDatum> FromDatum for Vec<Option<T>> {
497    #[inline]
498    unsafe fn from_polymorphic_datum(
499        datum: pg_sys::Datum,
500        is_null: bool,
501        typoid: pg_sys::Oid,
502    ) -> Option<Vec<Option<T>>> {
503        if is_null || datum.is_null() {
504            None
505        } else {
506            let array = Array::<T>::from_polymorphic_datum(datum, is_null, typoid).unwrap();
507            Some(array.iter().collect::<Vec<_>>())
508        }
509    }
510}
511
512impl<T> IntoDatum for Vec<T>
513where
514    T: IntoDatum,
515{
516    fn into_datum(self) -> Option<pg_sys::Datum> {
517        let mut state = unsafe {
518            pg_sys::initArrayResult(
519                T::type_oid(),
520                PgMemoryContexts::CurrentMemoryContext.value(),
521                false,
522            )
523        };
524        for s in self {
525            let datum = s.into_datum();
526            let isnull = datum.is_none();
527
528            unsafe {
529                state = pg_sys::accumArrayResult(
530                    state,
531                    datum.unwrap_or(0.into()),
532                    isnull,
533                    T::type_oid(),
534                    PgMemoryContexts::CurrentMemoryContext.value(),
535                );
536            }
537        }
538
539        if state.is_null() {
540            // shouldn't happen
541            None
542        } else {
543            Some(unsafe {
544                pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
545            })
546        }
547    }
548
549    fn type_oid() -> pg_sys::Oid {
550        unsafe { pg_sys::get_array_type(T::type_oid()) }
551    }
552
553    #[inline]
554    fn is_compatible_with(other: pg_sys::Oid) -> bool {
555        Self::type_oid() == other || other == unsafe { pg_sys::get_array_type(T::type_oid()) }
556    }
557}
558
559impl<'a, T> IntoDatum for &'a [T]
560where
561    T: IntoDatum + Copy + 'a,
562{
563    fn into_datum(self) -> Option<pg_sys::Datum> {
564        let mut state = unsafe {
565            pg_sys::initArrayResult(
566                T::type_oid(),
567                PgMemoryContexts::CurrentMemoryContext.value(),
568                false,
569            )
570        };
571        for s in self {
572            let datum = s.into_datum();
573            let isnull = datum.is_none();
574
575            unsafe {
576                state = pg_sys::accumArrayResult(
577                    state,
578                    datum.unwrap_or(0.into()),
579                    isnull,
580                    T::type_oid(),
581                    PgMemoryContexts::CurrentMemoryContext.value(),
582                );
583            }
584        }
585
586        if state.is_null() {
587            // shouldn't happen
588            None
589        } else {
590            Some(unsafe {
591                pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
592            })
593        }
594    }
595
596    fn type_oid() -> pg_sys::Oid {
597        unsafe { pg_sys::get_array_type(T::type_oid()) }
598    }
599
600    #[inline]
601    fn is_compatible_with(other: pg_sys::Oid) -> bool {
602        Self::type_oid() == other || other == unsafe { pg_sys::get_array_type(T::type_oid()) }
603    }
604}
605
606unsafe impl<'a, T> SqlTranslatable for Array<'a, T>
607where
608    T: SqlTranslatable + FromDatum,
609{
610    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
611        match T::argument_sql()? {
612            SqlMapping::As(sql) => Ok(SqlMapping::As(format!("{sql}[]"))),
613            SqlMapping::Skip => Err(ArgumentError::SkipInArray),
614            SqlMapping::Composite { .. } => Ok(SqlMapping::Composite { array_brackets: true }),
615            SqlMapping::Source { .. } => Ok(SqlMapping::Source { array_brackets: true }),
616        }
617    }
618
619    fn return_sql() -> Result<Returns, ReturnsError> {
620        match T::return_sql()? {
621            Returns::One(SqlMapping::As(sql)) => {
622                Ok(Returns::One(SqlMapping::As(format!("{sql}[]"))))
623            }
624            Returns::One(SqlMapping::Composite { array_brackets: _ }) => {
625                Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
626            }
627            Returns::One(SqlMapping::Source { array_brackets: _ }) => {
628                Ok(Returns::One(SqlMapping::Source { array_brackets: true }))
629            }
630            Returns::One(SqlMapping::Skip) => Err(ReturnsError::SkipInArray),
631            Returns::SetOf(_) => Err(ReturnsError::SetOfInArray),
632            Returns::Table(_) => Err(ReturnsError::TableInArray),
633        }
634    }
635}
636
637unsafe impl<'a, T> SqlTranslatable for VariadicArray<'a, T>
638where
639    T: SqlTranslatable + FromDatum,
640{
641    fn argument_sql() -> Result<SqlMapping, ArgumentError> {
642        match T::argument_sql()? {
643            SqlMapping::As(sql) => Ok(SqlMapping::As(format!("{sql}[]"))),
644            SqlMapping::Skip => Err(ArgumentError::SkipInArray),
645            SqlMapping::Composite { .. } => Ok(SqlMapping::Composite { array_brackets: true }),
646            SqlMapping::Source { .. } => Ok(SqlMapping::Source { array_brackets: true }),
647        }
648    }
649
650    fn return_sql() -> Result<Returns, ReturnsError> {
651        match T::return_sql()? {
652            Returns::One(SqlMapping::As(sql)) => {
653                Ok(Returns::One(SqlMapping::As(format!("{sql}[]"))))
654            }
655            Returns::One(SqlMapping::Composite { array_brackets: _ }) => {
656                Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
657            }
658            Returns::One(SqlMapping::Source { array_brackets: _ }) => {
659                Ok(Returns::One(SqlMapping::Source { array_brackets: true }))
660            }
661            Returns::One(SqlMapping::Skip) => Err(ReturnsError::SkipInArray),
662            Returns::SetOf(_) => Err(ReturnsError::SetOfInArray),
663            Returns::Table(_) => Err(ReturnsError::TableInArray),
664        }
665    }
666
667    fn variadic() -> bool {
668        true
669    }
670}