1use crate::array::RawArray;
11use crate::layout::*;
12use crate::slice::PallocSlice;
13use crate::{pg_sys, FromDatum, IntoDatum, PgMemoryContexts};
14use bitvec::slice::BitSlice;
15use core::ptr::NonNull;
16use pgx_sql_entity_graph::metadata::{
17 ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable,
18};
19use serde::Serializer;
20use std::marker::PhantomData;
21use std::{mem, ptr};
22
/// A Postgres array datum deconstructed into per-element `Datum`s plus NULL
/// information, yielding elements as `T` on access.
pub struct Array<'a, T: FromDatum> {
    // The array being read; taken (set to `None`) by `into_array_type` and by `Drop`.
    raw: Option<RawArray>,
    // Element count reported by `deconstruct_array`.
    nelems: usize,
    // One `Datum` per element, as produced by `deconstruct_array`.
    datum_slice: Option<PallocSlice<pg_sys::Datum>>,
    // True when `raw` points somewhere other than the datum's original pointer
    // (e.g. a detoasted copy); `Drop` then pfree's it.
    needs_pfree: bool,
    // Per-element NULL information (bitmap view or "no NULLs").
    null_slice: NullKind<'a>,
    // Storage layout of the element type, used for `deconstruct_array` and `as_slice`.
    elem_layout: Layout,
    _marker: PhantomData<T>,
}
66
/// Per-element NULL information for an [`Array`].
enum NullKind<'a> {
    /// A view of the array's own null bitmap; a set bit marks a non-NULL element
    /// (which is why `get` inverts it).
    Bits(&'a BitSlice<u8>),
    /// The array carries no null bitmap: holds the element count, and every
    /// element below it is non-NULL.
    Strict(usize),
}
74
75impl NullKind<'_> {
76 fn get(&self, index: usize) -> Option<bool> {
77 match self {
78 Self::Bits(b1) => b1.get(index).map(|b| !b),
79 Self::Strict(len) => index.le(len).then(|| false),
80 }
81 }
82
83 fn any(&self) -> bool {
84 match self {
85 Self::Bits(b1) => !b1.all(),
86 Self::Strict(_) => false,
87 }
88 }
89}
90
91impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for Array<'a, T> {
92 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
93 where
94 S: Serializer,
95 {
96 serializer.collect_seq(self.iter())
97 }
98}
99
100impl<'a, T: FromDatum> Drop for Array<'a, T> {
101 fn drop(&mut self) {
102 let slice = mem::take(&mut self.datum_slice);
104 mem::drop(slice);
105 if self.needs_pfree {
106 if let Some(raw) = self.raw.take().map(|r| r.into_ptr()) {
107 unsafe { pg_sys::pfree(raw.as_ptr().cast()) }
109 }
110 }
111 }
112}
113
#[deny(unsafe_op_in_unsafe_fn)]
impl<'a, T: FromDatum> Array<'a, T> {
    /// Builds an [`Array`] by running Postgres' `deconstruct_array` over `raw`.
    ///
    /// `ptr` is the pointer the datum originally carried. When it differs from the
    /// array pointer held by `raw` (e.g. because detoasting produced a copy), the
    /// result is marked `needs_pfree` so that `Drop` releases the copy.
    ///
    /// # Safety
    ///
    /// `raw` must refer to a valid `ArrayType`, and `elem_layout` must describe its
    /// element type: both are handed directly to `pg_sys::deconstruct_array`.
    ///
    /// # Panics
    ///
    /// Panics if `deconstruct_array`'s element count differs from `raw.len()`, or
    /// if it leaves the elements/nulls out-pointers null.
    unsafe fn deconstruct_from(
        ptr: NonNull<pg_sys::varlena>,
        raw: RawArray,
        elem_layout: Layout,
    ) -> Array<'a, T> {
        let oid = raw.oid();
        let len = raw.len();
        let array = raw.into_ptr().as_ptr();

        // Out-parameters filled in by `deconstruct_array`.
        let mut elements = ptr::null_mut();
        let mut nulls = ptr::null_mut();
        let mut nelems = 0;

        // SAFETY: the caller guarantees `array` is a valid ArrayType whose element
        // type is described by `elem_layout`.
        unsafe {
            pg_sys::deconstruct_array(
                array,
                oid,
                elem_layout.size.as_typlen().into(),
                matches!(elem_layout.pass, PassBy::Value),
                elem_layout.align.as_typalign(),
                &mut elements,
                &mut nulls,
                &mut nelems,
            )
        };

        let nelems = nelems as usize;

        // `raw.len()` and deconstruct_array's count must agree.
        assert_eq!(nelems, len);

        // The array pointer differs from the incoming one when it is a separate
        // copy (presumably from detoasting), which this `Array` then owns.
        let needs_pfree = ptr.as_ptr().cast() != array;
        // SAFETY: `array` was taken from `raw.into_ptr()` above and is assumed non-null.
        let mut raw = unsafe { RawArray::from_ptr(NonNull::new_unchecked(array)) };

        // Prefer the array's own null bitmap; without one, every index below
        // `nelems` counts as non-NULL.
        let null_slice = raw
            .nulls_bitslice()
            .map(|nonnull| NullKind::Bits(unsafe { &*nonnull.as_ptr() }))
            .unwrap_or(NullKind::Strict(nelems));

        // Take ownership of the bool array deconstruct_array wrote into `nulls`.
        // NOTE(review): assumes PallocSlice frees its buffer when dropped — confirm.
        let pallocd_null_slice =
            unsafe { PallocSlice::from_raw_parts(NonNull::new(nulls).unwrap(), nelems) };
        // Debug-only cross-check: the bitmap view must agree with deconstruct_array's flags.
        #[cfg(debug_assertions)]
        for i in 0..nelems {
            assert!(null_slice.get(i).unwrap().eq(unsafe { pallocd_null_slice.get_unchecked(i) }));
        }

        let datum_slice =
            Some(unsafe { PallocSlice::from_raw_parts(NonNull::new(elements).unwrap(), nelems) });

        Array {
            needs_pfree,
            raw: Some(raw),
            nelems,
            datum_slice,
            null_slice,
            elem_layout,
            _marker: PhantomData,
        }
    }

    /// Consumes the `Array` and returns the underlying `ArrayType` pointer, or
    /// null if none is held.
    ///
    /// `self` is forgotten rather than dropped, so nothing is pfree'd here.
    /// NOTE(review): this also means `datum_slice` is never dropped.
    pub fn into_array_type(mut self) -> *const pg_sys::ArrayType {
        let ptr = mem::take(&mut self.raw).map(|raw| raw.into_ptr().as_ptr() as _);
        mem::forget(self);
        ptr.unwrap_or(ptr::null())
    }

    /// Reinterprets the array's data as `&[T]`.
    ///
    /// # Panics
    ///
    /// Panics if any element is NULL, or if `T`'s size does not match the element
    /// layout as one of 1, 2, 4 or `size_of::<Datum>()` bytes.
    #[deprecated(
        since = "0.5.0",
        note = "this function cannot be safe and is not generically sound\n\
        even `unsafe fn as_slice(&self) -> &[T]` is not sound for all `&[T]`\n\
        if you are sure your usage is sound, consider RawArray"
    )]
    pub fn as_slice(&self) -> &[T] {
        const DATUM_SIZE: usize = mem::size_of::<pg_sys::Datum>();
        if self.null_slice.any() {
            panic!("null detected: can't expose potentially uninit data as a slice!")
        }
        match (self.elem_layout.size_matches::<T>(), self.raw.as_ref()) {
            // DATUM_SIZE may coincide with 4 on some targets, hence the allow.
            #[allow(unreachable_patterns)]
            (Some(1 | 2 | 4 | DATUM_SIZE), Some(raw)) => unsafe {
                raw.assume_init_data_slice::<T>()
            },
            (_, _) => panic!("no correctly-sized slice exists"),
        }
    }

    /// Iterates over the elements as `Option<T>`, starting at index 0.
    pub fn iter(&self) -> ArrayIterator<'_, T> {
        ArrayIterator { array: self, curr: 0 }
    }

    /// Iterates over the elements as plain `T`.
    ///
    /// # Panics
    ///
    /// Panics up front if the array contains a NULL or no array is held.
    pub fn iter_deny_null(&self) -> ArrayTypedIterator<'_, T> {
        if let Some(at) = &self.raw {
            if unsafe { at.any_nulls() } {
                panic!("array contains NULL");
            }
        } else {
            panic!("array is NULL");
        };

        ArrayTypedIterator { array: self, curr: 0 }
    }

    /// Number of elements.
    #[inline]
    pub fn len(&self) -> usize {
        self.nelems
    }

    /// Returns `true` if the array has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nelems == 0
    }

    /// Returns `None` when `i` is out of bounds. Otherwise returns `Some` of
    /// `T::from_polymorphic_datum`'s result for element `i`, which is typically
    /// `None` for a NULL element.
    #[allow(clippy::option_option)]
    #[inline]
    pub fn get(&self, i: usize) -> Option<Option<T>> {
        if i >= self.nelems {
            None
        } else {
            Some(unsafe {
                T::from_polymorphic_datum(
                    *(self.datum_slice.as_ref()?.get(i)?),
                    self.null_slice.get(i)?,
                    self.raw.as_ref().map(|r| r.oid()).unwrap_or_default(),
                )
            })
        }
    }
}
275
/// An [`Array`] accepted as a SQL `VARIADIC` argument (its `SqlTranslatable`
/// impl reports `variadic() == true`).
pub struct VariadicArray<'a, T: FromDatum>(Array<'a, T>);
277
278impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for VariadicArray<'a, T> {
279 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
280 where
281 S: Serializer,
282 {
283 serializer.collect_seq(self.0.iter())
284 }
285}
286
287impl<'a, T: FromDatum> VariadicArray<'a, T> {
288 pub fn into_array_type(self) -> *const pg_sys::ArrayType {
289 self.0.into_array_type()
290 }
291
292 #[deprecated(
297 since = "0.5.0",
298 note = "this function cannot be safe and is not generically sound\n\
299 even `unsafe fn as_slice(&self) -> &[T]` is not sound for all `&[T]`\n\
300 if you are sure your usage is sound, consider RawArray"
301 )]
302 #[allow(deprecated)]
303 pub fn as_slice(&self) -> &[T] {
304 self.0.as_slice()
305 }
306
307 pub fn iter(&self) -> ArrayIterator<'_, T> {
309 self.0.iter()
310 }
311
312 pub fn iter_deny_null(&self) -> ArrayTypedIterator<'_, T> {
316 self.0.iter_deny_null()
317 }
318
319 #[inline]
320 pub fn len(&self) -> usize {
321 self.0.len()
322 }
323
324 #[inline]
325 pub fn is_empty(&self) -> bool {
326 self.0.is_empty()
327 }
328
329 #[allow(clippy::option_option)]
330 #[inline]
331 pub fn get(&self, i: usize) -> Option<Option<T>> {
332 self.0.get(i)
333 }
334}
335
/// Borrowing iterator yielding plain `T`, produced by `Array::iter_deny_null`.
pub struct ArrayTypedIterator<'a, T: 'a + FromDatum> {
    // The array being iterated.
    array: &'a Array<'a, T>,
    // Index of the next element to yield.
    curr: usize,
}
340
341impl<'a, T: FromDatum> Iterator for ArrayTypedIterator<'a, T> {
342 type Item = T;
343
344 #[inline]
345 fn next(&mut self) -> Option<Self::Item> {
346 if self.curr >= self.array.nelems {
347 None
348 } else {
349 let element = self
350 .array
351 .get(self.curr)
352 .expect("array index out of bounds")
353 .expect("array element was unexpectedly NULL during iteration");
354 self.curr += 1;
355 Some(element)
356 }
357 }
358}
359
360impl<'a, T: FromDatum + serde::Serialize> serde::Serialize for ArrayTypedIterator<'a, T> {
361 fn serialize<S>(&self, serializer: S) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
362 where
363 S: Serializer,
364 {
365 serializer.collect_seq(self.array.iter())
366 }
367}
368
/// Borrowing iterator yielding `Option<T>` per element, produced by `Array::iter`.
pub struct ArrayIterator<'a, T: 'a + FromDatum> {
    // The array being iterated.
    array: &'a Array<'a, T>,
    // Index of the next element to yield.
    curr: usize,
}
373
374impl<'a, T: FromDatum> Iterator for ArrayIterator<'a, T> {
375 type Item = Option<T>;
376
377 #[inline]
378 fn next(&mut self) -> Option<Self::Item> {
379 if self.curr >= self.array.nelems {
380 None
381 } else {
382 let element = self.array.get(self.curr).unwrap();
383 self.curr += 1;
384 Some(element)
385 }
386 }
387}
388
/// Owning iterator over an [`Array`], yielding `Option<T>` per element.
pub struct ArrayIntoIterator<'a, T: FromDatum> {
    // The array being consumed; dropped together with the iterator.
    array: Array<'a, T>,
    // Index of the next element to yield.
    curr: usize,
}
393
394impl<'a, T: FromDatum> IntoIterator for Array<'a, T> {
395 type Item = Option<T>;
396 type IntoIter = ArrayIntoIterator<'a, T>;
397
398 fn into_iter(self) -> Self::IntoIter {
399 ArrayIntoIterator { array: self, curr: 0 }
400 }
401}
402
403impl<'a, T: FromDatum> IntoIterator for VariadicArray<'a, T> {
404 type Item = Option<T>;
405 type IntoIter = ArrayIntoIterator<'a, T>;
406
407 fn into_iter(self) -> Self::IntoIter {
408 ArrayIntoIterator { array: self.0, curr: 0 }
409 }
410}
411
412impl<'a, T: FromDatum> Iterator for ArrayIntoIterator<'a, T> {
413 type Item = Option<T>;
414
415 #[inline]
416 fn next(&mut self) -> Option<Self::Item> {
417 if self.curr >= self.array.nelems {
418 None
419 } else {
420 let element = self.array.get(self.curr).unwrap();
421 self.curr += 1;
422 Some(element)
423 }
424 }
425
426 fn size_hint(&self) -> (usize, Option<usize>) {
427 (0, Some(self.array.nelems))
428 }
429
430 fn count(self) -> usize
431 where
432 Self: Sized,
433 {
434 self.array.nelems
435 }
436
437 fn nth(&mut self, n: usize) -> Option<Self::Item> {
438 self.array.get(n)
439 }
440}
441
442impl<'a, T: FromDatum> FromDatum for VariadicArray<'a, T> {
443 #[inline]
444 unsafe fn from_polymorphic_datum(
445 datum: pg_sys::Datum,
446 is_null: bool,
447 oid: pg_sys::Oid,
448 ) -> Option<VariadicArray<'a, T>> {
449 Array::from_polymorphic_datum(datum, is_null, oid).map(Self)
450 }
451}
452
453impl<'a, T: FromDatum> FromDatum for Array<'a, T> {
454 #[inline]
455 unsafe fn from_polymorphic_datum(
456 datum: pg_sys::Datum,
457 is_null: bool,
458 _typoid: pg_sys::Oid,
459 ) -> Option<Array<'a, T>> {
460 if is_null {
461 None
462 } else {
463 let ptr = NonNull::new(datum.cast_mut_ptr())?;
464 let array = pg_sys::pg_detoast_datum(datum.cast_mut_ptr()) as *mut pg_sys::ArrayType;
465 let raw =
466 RawArray::from_ptr(NonNull::new(array).expect("detoast returned null ArrayType*"));
467 let oid = raw.oid();
468 let layout = Layout::lookup_oid(oid);
469
470 Some(Array::deconstruct_from(ptr, raw, layout))
471 }
472 }
473}
474
475impl<T: FromDatum> FromDatum for Vec<T> {
476 #[inline]
477 unsafe fn from_polymorphic_datum(
478 datum: pg_sys::Datum,
479 is_null: bool,
480 typoid: pg_sys::Oid,
481 ) -> Option<Vec<T>> {
482 if is_null {
483 None
484 } else {
485 let array = Array::<T>::from_polymorphic_datum(datum, is_null, typoid).unwrap();
486 let mut v = Vec::with_capacity(array.len());
487
488 for element in array.iter() {
489 v.push(element.expect("array element was NULL"))
490 }
491 Some(v)
492 }
493 }
494}
495
496impl<T: FromDatum> FromDatum for Vec<Option<T>> {
497 #[inline]
498 unsafe fn from_polymorphic_datum(
499 datum: pg_sys::Datum,
500 is_null: bool,
501 typoid: pg_sys::Oid,
502 ) -> Option<Vec<Option<T>>> {
503 if is_null || datum.is_null() {
504 None
505 } else {
506 let array = Array::<T>::from_polymorphic_datum(datum, is_null, typoid).unwrap();
507 Some(array.iter().collect::<Vec<_>>())
508 }
509 }
510}
511
512impl<T> IntoDatum for Vec<T>
513where
514 T: IntoDatum,
515{
516 fn into_datum(self) -> Option<pg_sys::Datum> {
517 let mut state = unsafe {
518 pg_sys::initArrayResult(
519 T::type_oid(),
520 PgMemoryContexts::CurrentMemoryContext.value(),
521 false,
522 )
523 };
524 for s in self {
525 let datum = s.into_datum();
526 let isnull = datum.is_none();
527
528 unsafe {
529 state = pg_sys::accumArrayResult(
530 state,
531 datum.unwrap_or(0.into()),
532 isnull,
533 T::type_oid(),
534 PgMemoryContexts::CurrentMemoryContext.value(),
535 );
536 }
537 }
538
539 if state.is_null() {
540 None
542 } else {
543 Some(unsafe {
544 pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
545 })
546 }
547 }
548
549 fn type_oid() -> pg_sys::Oid {
550 unsafe { pg_sys::get_array_type(T::type_oid()) }
551 }
552
553 #[inline]
554 fn is_compatible_with(other: pg_sys::Oid) -> bool {
555 Self::type_oid() == other || other == unsafe { pg_sys::get_array_type(T::type_oid()) }
556 }
557}
558
559impl<'a, T> IntoDatum for &'a [T]
560where
561 T: IntoDatum + Copy + 'a,
562{
563 fn into_datum(self) -> Option<pg_sys::Datum> {
564 let mut state = unsafe {
565 pg_sys::initArrayResult(
566 T::type_oid(),
567 PgMemoryContexts::CurrentMemoryContext.value(),
568 false,
569 )
570 };
571 for s in self {
572 let datum = s.into_datum();
573 let isnull = datum.is_none();
574
575 unsafe {
576 state = pg_sys::accumArrayResult(
577 state,
578 datum.unwrap_or(0.into()),
579 isnull,
580 T::type_oid(),
581 PgMemoryContexts::CurrentMemoryContext.value(),
582 );
583 }
584 }
585
586 if state.is_null() {
587 None
589 } else {
590 Some(unsafe {
591 pg_sys::makeArrayResult(state, PgMemoryContexts::CurrentMemoryContext.value())
592 })
593 }
594 }
595
596 fn type_oid() -> pg_sys::Oid {
597 unsafe { pg_sys::get_array_type(T::type_oid()) }
598 }
599
600 #[inline]
601 fn is_compatible_with(other: pg_sys::Oid) -> bool {
602 Self::type_oid() == other || other == unsafe { pg_sys::get_array_type(T::type_oid()) }
603 }
604}
605
606unsafe impl<'a, T> SqlTranslatable for Array<'a, T>
607where
608 T: SqlTranslatable + FromDatum,
609{
610 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
611 match T::argument_sql()? {
612 SqlMapping::As(sql) => Ok(SqlMapping::As(format!("{sql}[]"))),
613 SqlMapping::Skip => Err(ArgumentError::SkipInArray),
614 SqlMapping::Composite { .. } => Ok(SqlMapping::Composite { array_brackets: true }),
615 SqlMapping::Source { .. } => Ok(SqlMapping::Source { array_brackets: true }),
616 }
617 }
618
619 fn return_sql() -> Result<Returns, ReturnsError> {
620 match T::return_sql()? {
621 Returns::One(SqlMapping::As(sql)) => {
622 Ok(Returns::One(SqlMapping::As(format!("{sql}[]"))))
623 }
624 Returns::One(SqlMapping::Composite { array_brackets: _ }) => {
625 Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
626 }
627 Returns::One(SqlMapping::Source { array_brackets: _ }) => {
628 Ok(Returns::One(SqlMapping::Source { array_brackets: true }))
629 }
630 Returns::One(SqlMapping::Skip) => Err(ReturnsError::SkipInArray),
631 Returns::SetOf(_) => Err(ReturnsError::SetOfInArray),
632 Returns::Table(_) => Err(ReturnsError::TableInArray),
633 }
634 }
635}
636
637unsafe impl<'a, T> SqlTranslatable for VariadicArray<'a, T>
638where
639 T: SqlTranslatable + FromDatum,
640{
641 fn argument_sql() -> Result<SqlMapping, ArgumentError> {
642 match T::argument_sql()? {
643 SqlMapping::As(sql) => Ok(SqlMapping::As(format!("{sql}[]"))),
644 SqlMapping::Skip => Err(ArgumentError::SkipInArray),
645 SqlMapping::Composite { .. } => Ok(SqlMapping::Composite { array_brackets: true }),
646 SqlMapping::Source { .. } => Ok(SqlMapping::Source { array_brackets: true }),
647 }
648 }
649
650 fn return_sql() -> Result<Returns, ReturnsError> {
651 match T::return_sql()? {
652 Returns::One(SqlMapping::As(sql)) => {
653 Ok(Returns::One(SqlMapping::As(format!("{sql}[]"))))
654 }
655 Returns::One(SqlMapping::Composite { array_brackets: _ }) => {
656 Ok(Returns::One(SqlMapping::Composite { array_brackets: true }))
657 }
658 Returns::One(SqlMapping::Source { array_brackets: _ }) => {
659 Ok(Returns::One(SqlMapping::Source { array_brackets: true }))
660 }
661 Returns::One(SqlMapping::Skip) => Err(ReturnsError::SkipInArray),
662 Returns::SetOf(_) => Err(ReturnsError::SetOfInArray),
663 Returns::Table(_) => Err(ReturnsError::TableInArray),
664 }
665 }
666
667 fn variadic() -> bool {
668 true
669 }
670}