1use crate::{
13 pg_sys, register_xact_callback, FromDatum, IntoDatum, Json, PgMemoryContexts, PgOid,
14 PgXactCallbackEvent, TryFromDatumError,
15};
16use core::fmt::Formatter;
17use pgx_pg_sys::panic::ErrorReportable;
18use std::collections::HashMap;
19use std::ffi::{CStr, CString};
20use std::fmt::Debug;
21use std::marker::PhantomData;
22use std::mem;
23use std::ops::{Deref, Index};
24use std::ptr::NonNull;
25use std::sync::atomic::{AtomicBool, Ordering};
26
27pub type Result<T> = std::result::Result<T, Error>;
28
/// Success status codes reported by PostgreSQL's SPI functions.
///
/// Each discriminant is the raw positive integer SPI hands back (the `SPI_OK_*` family),
/// which is what allows `SpiOkCodes::try_from` to classify a raw code.
#[derive(Debug, PartialEq)]
#[repr(i32)]
#[non_exhaustive]
pub enum SpiOkCodes {
    Connect = 1,
    Finish = 2,
    Fetch = 3,
    Utility = 4,
    Select = 5,
    SelInto = 6,
    Insert = 7,
    Delete = 8,
    Update = 9,
    Cursor = 10,
    InsertReturning = 11,
    DeleteReturning = 12,
    UpdateReturning = 13,
    Rewritten = 14,
    RelRegister = 15,
    RelUnregister = 16,
    TdRegister = 17,
    Merge = 18,
}
54
55#[derive(thiserror::Error, Debug, PartialEq)]
59#[repr(i32)]
60pub enum SpiErrorCodes {
61 Connect = -1,
62 Copy = -2,
63 OpUnknown = -3,
64 Unconnected = -4,
65 #[allow(dead_code)]
66 Cursor = -5, Argument = -6,
68 Param = -7,
69 Transaction = -8,
70 NoAttribute = -9,
71 NoOutFunc = -10,
72 TypUnknown = -11,
73 RelDuplicate = -12,
74 RelNotFound = -13,
75}
76
77impl std::fmt::Display for SpiErrorCodes {
78 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
79 f.write_fmt(format_args!("{:?}", self))
80 }
81}
82
83pub fn quote_identifier<StringLike: AsRef<str>>(ident: StringLike) -> String {
86 let ident_cstr = CString::new(ident.as_ref()).unwrap();
87 let quoted_cstr = unsafe {
89 let quoted_ptr = pg_sys::quote_identifier(ident_cstr.as_ptr());
90 CStr::from_ptr(quoted_ptr)
91 };
92 quoted_cstr.to_str().unwrap().to_string()
93}
94
95pub fn quote_qualified_identifier<StringLike: AsRef<str>>(
99 qualifier: StringLike,
100 ident: StringLike,
101) -> String {
102 let qualifier_cstr = CString::new(qualifier.as_ref()).unwrap();
103 let ident_cstr = CString::new(ident.as_ref()).unwrap();
104 let quoted_cstr = unsafe {
106 let quoted_ptr =
107 pg_sys::quote_qualified_identifier(qualifier_cstr.as_ptr(), ident_cstr.as_ptr());
108 CStr::from_ptr(quoted_ptr)
109 };
110 quoted_cstr.to_str().unwrap().to_string()
111}
112
113pub fn quote_literal<StringLike: AsRef<str>>(literal: StringLike) -> String {
116 let literal_cstr = CString::new(literal.as_ref()).unwrap();
117 let quoted_cstr = unsafe {
119 let quoted_ptr = pg_sys::quote_literal_cstr(literal_cstr.as_ptr());
120 CStr::from_ptr(quoted_ptr)
121 };
122 quoted_cstr.to_str().unwrap().to_string()
123}
124
/// Marker error for a raw SPI status code that is neither a known success code nor a
/// known error code.
#[derive(Debug)]
pub struct UnknownVariant;
127
128impl TryFrom<libc::c_int> for SpiOkCodes {
129 type Error = std::result::Result<SpiErrorCodes, UnknownVariant>;
131
132 fn try_from(code: libc::c_int) -> std::result::Result<SpiOkCodes, Self::Error> {
133 match code as i32 {
136 err @ -13..=-1 => Err(Ok(
137 unsafe { mem::transmute::<i32, SpiErrorCodes>(err) },
139 )),
140 ok @ 1..=18 => Ok(
141 unsafe { mem::transmute::<i32, SpiOkCodes>(ok) },
143 ),
144 _unknown => Err(Err(UnknownVariant)),
145 }
146 }
147}
148
149impl TryFrom<libc::c_int> for SpiErrorCodes {
150 type Error = std::result::Result<SpiOkCodes, UnknownVariant>;
152
153 fn try_from(code: libc::c_int) -> std::result::Result<SpiErrorCodes, Self::Error> {
154 match SpiOkCodes::try_from(code) {
155 Ok(ok) => Err(Ok(ok)),
156 Err(Ok(err)) => Ok(err),
157 Err(Err(unknown)) => Err(Err(unknown)),
158 }
159 }
160}
161
162#[derive(thiserror::Error, Debug, PartialEq)]
164pub enum Error {
165 #[error("SPI error: {0:?}")]
167 SpiError(#[from] SpiErrorCodes),
168
169 #[error("Datum error: {0}")]
171 DatumError(#[from] TryFromDatumError),
172
173 #[error("Argument count mismatch (expected {expected}, got {got})")]
175 PreparedStatementArgumentMismatch { expected: usize, got: usize },
176
177 #[error("SpiTupleTable positioned before the start or after the end")]
179 InvalidPosition,
180
181 #[error("Cursor named {0} not found")]
183 CursorNotFound(String),
184
185 #[error("The active `SPI_tuptable` is NULL")]
187 NoTupleTable,
188}
189
190pub struct Spi;
191
192static MUTABLE_MODE: AtomicBool = AtomicBool::new(false);
193impl Spi {
194 #[inline]
195 fn is_read_only() -> bool {
196 MUTABLE_MODE.load(Ordering::Relaxed) == false
197 }
198
199 #[inline]
200 fn clear_mutable() {
201 MUTABLE_MODE.store(false, Ordering::Relaxed)
202 }
203
204 fn mark_mutable() {
217 if Spi::is_read_only() {
218 register_xact_callback(PgXactCallbackEvent::Commit, || Spi::clear_mutable());
219 register_xact_callback(PgXactCallbackEvent::Abort, || Spi::clear_mutable());
220
221 MUTABLE_MODE.store(true, Ordering::Relaxed)
222 }
223 }
224}
225
226pub struct SpiClient<'conn> {
228 __marker: PhantomData<&'conn SpiConnection>,
229}
230
231struct SpiConnection(PhantomData<*mut ()>);
233
234impl SpiConnection {
235 fn connect() -> Result<Self> {
237 Spi::check_status(unsafe { pg_sys::SPI_connect() })?;
243 Ok(SpiConnection(PhantomData))
244 }
245}
246
247impl Drop for SpiConnection {
248 fn drop(&mut self) {
250 Spi::check_status(unsafe { pg_sys::SPI_finish() }).ok();
254 }
255}
256
257impl SpiConnection {
258 fn client(&self) -> SpiClient<'_> {
260 SpiClient { __marker: PhantomData }
261 }
262}
263
264pub trait Query {
270 type Arguments;
271 type Result;
272
273 fn execute(
275 self,
276 client: &SpiClient,
277 limit: Option<libc::c_long>,
278 arguments: Self::Arguments,
279 ) -> Self::Result;
280
281 fn open_cursor<'c: 'cc, 'cc>(
283 self,
284 client: &'cc SpiClient<'c>,
285 args: Self::Arguments,
286 ) -> SpiCursor<'c>;
287}
288
289impl<'a> Query for &'a String {
290 type Arguments = Option<Vec<(PgOid, Option<pg_sys::Datum>)>>;
291 type Result = Result<SpiTupleTable>;
292
293 fn execute(
294 self,
295 client: &SpiClient,
296 limit: Option<libc::c_long>,
297 arguments: Self::Arguments,
298 ) -> Self::Result {
299 self.as_str().execute(client, limit, arguments)
300 }
301
302 fn open_cursor<'c: 'cc, 'cc>(
303 self,
304 client: &'cc SpiClient<'c>,
305 args: Self::Arguments,
306 ) -> SpiCursor<'c> {
307 self.as_str().open_cursor(client, args)
308 }
309}
310
311fn prepare_datum(datum: Option<pg_sys::Datum>) -> (pg_sys::Datum, std::os::raw::c_char) {
312 match datum {
313 Some(datum) => (datum, ' ' as std::os::raw::c_char),
314 None => (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char),
315 }
316}
317
318impl<'a> Query for &'a str {
319 type Arguments = Option<Vec<(PgOid, Option<pg_sys::Datum>)>>;
320 type Result = Result<SpiTupleTable>;
321
322 fn execute(
326 self,
327 _client: &SpiClient,
328 limit: Option<libc::c_long>,
329 arguments: Self::Arguments,
330 ) -> Self::Result {
331 unsafe {
333 pg_sys::SPI_tuptable = std::ptr::null_mut();
334 }
335
336 let src = CString::new(self).expect("query contained a null byte");
337 let status_code = match arguments {
338 Some(args) => {
339 let nargs = args.len();
340 let (types, data): (Vec<_>, Vec<_>) = args.into_iter().unzip();
341 let mut argtypes = types.into_iter().map(PgOid::value).collect::<Vec<_>>();
342 let (mut datums, nulls): (Vec<_>, Vec<_>) =
343 data.into_iter().map(prepare_datum).unzip();
344
345 unsafe {
347 pg_sys::SPI_execute_with_args(
348 src.as_ptr(),
349 nargs as i32,
350 argtypes.as_mut_ptr(),
351 datums.as_mut_ptr(),
352 nulls.as_ptr(),
353 Spi::is_read_only(),
354 limit.unwrap_or(0),
355 )
356 }
357 }
358 None => unsafe {
360 pg_sys::SPI_execute(src.as_ptr(), Spi::is_read_only(), limit.unwrap_or(0))
361 },
362 };
363
364 Ok(SpiClient::prepare_tuple_table(status_code)?)
365 }
366
367 fn open_cursor<'c: 'cc, 'cc>(
368 self,
369 _client: &'cc SpiClient<'c>,
370 args: Self::Arguments,
371 ) -> SpiCursor<'c> {
372 let src = CString::new(self).expect("query contained a null byte");
373 let args = args.unwrap_or_default();
374
375 let nargs = args.len();
376 let (types, data): (Vec<_>, Vec<_>) = args.into_iter().unzip();
377 let mut argtypes = types.into_iter().map(PgOid::value).collect::<Vec<_>>();
378 let (mut datums, nulls): (Vec<_>, Vec<_>) = data.into_iter().map(prepare_datum).unzip();
379
380 let ptr = unsafe {
381 NonNull::new_unchecked(pg_sys::SPI_cursor_open_with_args(
384 std::ptr::null_mut(), src.as_ptr(),
386 nargs as i32,
387 argtypes.as_mut_ptr(),
388 datums.as_mut_ptr(),
389 nulls.as_ptr(),
390 Spi::is_read_only(),
391 0,
392 ))
393 };
394 SpiCursor { ptr, __marker: PhantomData }
395 }
396}
397
398#[derive(Debug)]
399pub struct SpiTupleTable {
400 #[allow(dead_code)]
401 status_code: SpiOkCodes,
402 table: Option<*mut pg_sys::SPITupleTable>,
403 size: usize,
404 current: isize,
405}
406
407pub struct SpiHeapTupleDataEntry {
409 datum: Option<pg_sys::Datum>,
410 type_oid: pg_sys::Oid,
411}
412
413pub struct SpiHeapTupleData {
415 tupdesc: NonNull<pg_sys::TupleDescData>,
416 entries: HashMap<usize, SpiHeapTupleDataEntry>,
417}
418
419impl Spi {
420 pub fn get_one<A: FromDatum + IntoDatum>(query: &str) -> Result<Option<A>> {
421 Spi::connect(|mut client| client.update(query, Some(1), None)?.first().get_one())
422 }
423
424 pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
425 query: &str,
426 ) -> Result<(Option<A>, Option<B>)> {
427 Spi::connect(|mut client| client.update(query, Some(1), None)?.first().get_two::<A, B>())
428 }
429
430 pub fn get_three<
431 A: FromDatum + IntoDatum,
432 B: FromDatum + IntoDatum,
433 C: FromDatum + IntoDatum,
434 >(
435 query: &str,
436 ) -> Result<(Option<A>, Option<B>, Option<C>)> {
437 Spi::connect(|mut client| {
438 client.update(query, Some(1), None)?.first().get_three::<A, B, C>()
439 })
440 }
441
442 pub fn get_one_with_args<A: FromDatum + IntoDatum>(
443 query: &str,
444 args: Vec<(PgOid, Option<pg_sys::Datum>)>,
445 ) -> Result<Option<A>> {
446 Spi::connect(|mut client| client.update(query, Some(1), Some(args))?.first().get_one())
447 }
448
449 pub fn get_two_with_args<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
450 query: &str,
451 args: Vec<(PgOid, Option<pg_sys::Datum>)>,
452 ) -> Result<(Option<A>, Option<B>)> {
453 Spi::connect(|mut client| {
454 client.update(query, Some(1), Some(args))?.first().get_two::<A, B>()
455 })
456 }
457
458 pub fn get_three_with_args<
459 A: FromDatum + IntoDatum,
460 B: FromDatum + IntoDatum,
461 C: FromDatum + IntoDatum,
462 >(
463 query: &str,
464 args: Vec<(PgOid, Option<pg_sys::Datum>)>,
465 ) -> Result<(Option<A>, Option<B>, Option<C>)> {
466 Spi::connect(|mut client| {
467 client.update(query, Some(1), Some(args))?.first().get_three::<A, B, C>()
468 })
469 }
470
471 pub fn run(query: &str) -> std::result::Result<(), Error> {
477 Spi::run_with_args(query, None)
478 }
479
480 pub fn run_with_args(
486 query: &str,
487 args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>,
488 ) -> std::result::Result<(), Error> {
489 Spi::connect(|mut client| client.update(query, None, args)).map(|_| ())
490 }
491
492 pub fn explain(query: &str) -> Result<Json> {
494 Spi::explain_with_args(query, None)
495 }
496
497 pub fn explain_with_args(
499 query: &str,
500 args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>,
501 ) -> Result<Json> {
502 Ok(Spi::connect(|mut client| {
503 client
504 .update(&format!("EXPLAIN (format json) {}", query), None, args)?
505 .first()
506 .get_one::<Json>()
507 })?
508 .unwrap())
509 }
510
511 pub fn connect<R, F: FnOnce(SpiClient<'_>) -> R>(f: F) -> R {
547 let connection =
562 SpiConnection::connect().expect("SPI_connect indicated an unexpected failure");
563
564 f(connection.client())
569 }
570
571 #[track_caller]
572 pub fn check_status(status_code: i32) -> std::result::Result<SpiOkCodes, Error> {
573 match SpiOkCodes::try_from(status_code) {
574 Ok(ok) => Ok(ok),
575 Err(Err(UnknownVariant)) => panic!("unrecognized SPI status code: {status_code}"),
576 Err(Ok(code)) => Err(Error::SpiError(code)),
577 }
578 }
579}
580
581impl<'a> SpiClient<'a> {
582 pub fn select<Q: Query>(
584 &self,
585 query: Q,
586 limit: Option<libc::c_long>,
587 args: Q::Arguments,
588 ) -> Q::Result {
589 self.execute(query, limit, args)
590 }
591
592 pub fn update<Q: Query>(
594 &mut self,
595 query: Q,
596 limit: Option<libc::c_long>,
597 args: Q::Arguments,
598 ) -> Q::Result {
599 Spi::mark_mutable();
600 self.execute(query, limit, args)
601 }
602
603 fn execute<Q: Query>(
604 &self,
605 query: Q,
606 limit: Option<libc::c_long>,
607 args: Q::Arguments,
608 ) -> Q::Result {
609 query.execute(&self, limit, args)
610 }
611
612 fn prepare_tuple_table(status_code: i32) -> std::result::Result<SpiTupleTable, Error> {
613 Ok(SpiTupleTable {
614 status_code: Spi::check_status(status_code)?,
615 table: unsafe {
617 if pg_sys::SPI_tuptable.is_null() {
618 None
619 } else {
620 Some(pg_sys::SPI_tuptable)
621 }
622 },
623 #[cfg(any(feature = "pg11", feature = "pg12"))]
624 size: unsafe { pg_sys::SPI_processed as usize },
625 #[cfg(not(any(feature = "pg11", feature = "pg12")))]
626 size: unsafe {
628 if pg_sys::SPI_tuptable.is_null() {
629 pg_sys::SPI_processed as usize
630 } else {
631 (*pg_sys::SPI_tuptable).numvals as usize
632 }
633 },
634 current: -1,
635 })
636 }
637
638 pub fn open_cursor<Q: Query>(&self, query: Q, args: Q::Arguments) -> SpiCursor {
644 query.open_cursor(&self, args)
645 }
646
647 pub fn open_cursor_mut<Q: Query>(&mut self, query: Q, args: Q::Arguments) -> SpiCursor {
653 Spi::mark_mutable();
654 query.open_cursor(&self, args)
655 }
656
657 pub fn find_cursor(&self, name: &str) -> Result<SpiCursor> {
665 use pgx_pg_sys::AsPgCStr;
666
667 let ptr = NonNull::new(unsafe { pg_sys::SPI_cursor_find(name.as_pg_cstr()) })
668 .ok_or(Error::CursorNotFound(name.to_string()))?;
669 Ok(SpiCursor { ptr, __marker: PhantomData })
670 }
671}
672
673type CursorName = String;
674
675pub struct SpiCursor<'client> {
731 ptr: NonNull<pg_sys::PortalData>,
732 __marker: PhantomData<&'client SpiClient<'client>>,
733}
734
735impl SpiCursor<'_> {
736 pub fn fetch(&mut self, count: libc::c_long) -> std::result::Result<SpiTupleTable, Error> {
740 unsafe {
742 pg_sys::SPI_tuptable = std::ptr::null_mut();
743 }
744 unsafe { pg_sys::SPI_cursor_fetch(self.ptr.as_mut(), true, count) }
746 Ok(SpiClient::prepare_tuple_table(SpiOkCodes::Fetch as i32)?)
747 }
748
749 pub fn detach_into_name(self) -> CursorName {
759 let cursor_ptr = unsafe { self.ptr.as_ref() };
761 std::mem::forget(self);
764 unsafe { CStr::from_ptr(cursor_ptr.name) }
766 .to_str()
767 .expect("cursor name is not valid UTF8")
768 .to_string()
769 }
770}
771
772impl Drop for SpiCursor<'_> {
773 fn drop(&mut self) {
774 unsafe {
776 pg_sys::SPI_cursor_close(self.ptr.as_mut());
777 }
778 }
779}
780
781pub struct PreparedStatement<'a> {
783 plan: NonNull<pg_sys::_SPI_plan>,
784 __marker: PhantomData<&'a ()>,
785}
786
787pub struct OwnedPreparedStatement(PreparedStatement<'static>);
789
790impl Deref for OwnedPreparedStatement {
791 type Target = PreparedStatement<'static>;
792
793 fn deref(&self) -> &Self::Target {
794 &self.0
795 }
796}
797
798impl Drop for OwnedPreparedStatement {
799 fn drop(&mut self) {
800 unsafe {
801 pg_sys::SPI_freeplan(self.0.plan.as_ptr());
802 }
803 }
804}
805
806impl<'a> Query for &'a OwnedPreparedStatement {
807 type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
808 type Result = Result<SpiTupleTable>;
809
810 fn execute(
811 self,
812 client: &SpiClient,
813 limit: Option<libc::c_long>,
814 arguments: Self::Arguments,
815 ) -> Self::Result {
816 (&self.0).execute(client, limit, arguments)
817 }
818
819 fn open_cursor<'c: 'cc, 'cc>(
820 self,
821 client: &'cc SpiClient<'c>,
822 args: Self::Arguments,
823 ) -> SpiCursor<'c> {
824 (&self.0).open_cursor(client, args)
825 }
826}
827
828impl Query for OwnedPreparedStatement {
829 type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
830 type Result = Result<SpiTupleTable>;
831
832 fn execute(
833 self,
834 client: &SpiClient,
835 limit: Option<libc::c_long>,
836 arguments: Self::Arguments,
837 ) -> Self::Result {
838 (&self.0).execute(client, limit, arguments)
839 }
840
841 fn open_cursor<'c: 'cc, 'cc>(
842 self,
843 client: &'cc SpiClient<'c>,
844 args: Self::Arguments,
845 ) -> SpiCursor<'c> {
846 (&self.0).open_cursor(client, args)
847 }
848}
849
850impl<'a> PreparedStatement<'a> {
851 pub fn keep(self) -> OwnedPreparedStatement {
855 unsafe {
859 pg_sys::SPI_keepplan(self.plan.as_ptr());
860 }
861 OwnedPreparedStatement(PreparedStatement { __marker: PhantomData, plan: self.plan })
862 }
863}
864
865impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> {
866 type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
867 type Result = Result<SpiTupleTable>;
868
869 fn execute(
870 self,
871 _client: &SpiClient,
872 limit: Option<libc::c_long>,
873 arguments: Self::Arguments,
874 ) -> Self::Result {
875 unsafe {
877 pg_sys::SPI_tuptable = std::ptr::null_mut();
878 }
879 let args = arguments.unwrap_or_default();
880 let nargs = args.len();
881
882 let expected = unsafe { pg_sys::SPI_getargcount(self.plan.as_ptr()) } as usize;
883
884 if nargs != expected {
885 return Err(Error::PreparedStatementArgumentMismatch { expected, got: nargs });
886 }
887
888 let (mut datums, mut nulls): (Vec<_>, Vec<_>) = args.into_iter().map(prepare_datum).unzip();
889
890 let status_code = unsafe {
892 pg_sys::SPI_execute_plan(
893 self.plan.as_ptr(),
894 datums.as_mut_ptr(),
895 nulls.as_mut_ptr(),
896 Spi::is_read_only(),
897 limit.unwrap_or(0),
898 )
899 };
900
901 Ok(SpiClient::prepare_tuple_table(status_code)?)
902 }
903
904 fn open_cursor<'c: 'cc, 'cc>(
905 self,
906 _client: &'cc SpiClient<'c>,
907 args: Self::Arguments,
908 ) -> SpiCursor<'c> {
909 let args = args.unwrap_or_default();
910
911 let (mut datums, nulls): (Vec<_>, Vec<_>) = args.into_iter().map(prepare_datum).unzip();
912
913 let ptr = unsafe {
916 NonNull::new_unchecked(pg_sys::SPI_cursor_open(
917 std::ptr::null_mut(), self.plan.as_ptr(),
919 datums.as_mut_ptr(),
920 nulls.as_ptr(),
921 Spi::is_read_only(),
922 ))
923 };
924 SpiCursor { ptr, __marker: PhantomData }
925 }
926}
927
928impl<'a> Query for PreparedStatement<'a> {
929 type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
930 type Result = Result<SpiTupleTable>;
931
932 fn execute(
933 self,
934 client: &SpiClient,
935 limit: Option<libc::c_long>,
936 arguments: Self::Arguments,
937 ) -> Self::Result {
938 (&self).execute(client, limit, arguments)
939 }
940
941 fn open_cursor<'c: 'cc, 'cc>(
942 self,
943 client: &'cc SpiClient<'c>,
944 args: Self::Arguments,
945 ) -> SpiCursor<'c> {
946 (&self).open_cursor(client, args)
947 }
948}
949
950impl<'a> SpiClient<'a> {
951 pub fn prepare(&self, query: &str, args: Option<Vec<PgOid>>) -> Result<PreparedStatement> {
957 let src = CString::new(query).expect("query contained a null byte");
958 let args = args.unwrap_or_default();
959 let nargs = args.len();
960
961 let plan = unsafe {
963 pg_sys::SPI_prepare(
964 src.as_ptr(),
965 nargs as i32,
966 args.into_iter().map(PgOid::value).collect::<Vec<_>>().as_mut_ptr(),
967 )
968 };
969 Ok(PreparedStatement {
970 plan: NonNull::new(plan).ok_or_else(|| {
971 Spi::check_status(unsafe {
972 pg_sys::SPI_result
974 })
975 .err()
976 .unwrap()
977 })?,
978 __marker: PhantomData,
979 })
980 }
981}
982
983impl SpiTupleTable {
984 pub fn first(mut self) -> Self {
989 self.current = 0;
990 self
991 }
992
993 pub fn rewind(mut self) -> Self {
997 self.current = -1;
998 self
999 }
1000
1001 pub fn len(&self) -> usize {
1003 self.size
1004 }
1005
1006 pub fn is_empty(&self) -> bool {
1007 self.len() == 0
1008 }
1009
1010 pub fn get_one<A: FromDatum + IntoDatum>(&self) -> Result<Option<A>> {
1011 self.get(1)
1012 }
1013
1014 pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
1015 &self,
1016 ) -> Result<(Option<A>, Option<B>)> {
1017 let a = self.get::<A>(1)?;
1018 let b = self.get::<B>(2)?;
1019 Ok((a, b))
1020 }
1021
1022 pub fn get_three<
1023 A: FromDatum + IntoDatum,
1024 B: FromDatum + IntoDatum,
1025 C: FromDatum + IntoDatum,
1026 >(
1027 &self,
1028 ) -> Result<(Option<A>, Option<B>, Option<C>)> {
1029 let a = self.get::<A>(1)?;
1030 let b = self.get::<B>(2)?;
1031 let c = self.get::<C>(3)?;
1032 Ok((a, b, c))
1033 }
1034
1035 #[inline(always)]
1036 fn get_spi_tuptable(&self) -> Result<(*mut pg_sys::SPITupleTable, *mut pg_sys::TupleDescData)> {
1037 let table = *self.table.as_ref().ok_or(Error::NoTupleTable)?;
1038 unsafe {
1039 Ok((table, (*table).tupdesc))
1041 }
1042 }
1043
1044 pub fn get_heap_tuple(&self) -> Result<Option<SpiHeapTupleData>> {
1045 if self.size == 0 || self.table.is_none() {
1046 Ok(None)
1052 } else if self.current < 0 || self.current as usize >= self.size {
1053 Err(Error::InvalidPosition)
1054 } else {
1055 let (table, tupdesc) = self.get_spi_tuptable()?;
1056 unsafe {
1057 let heap_tuple =
1058 std::slice::from_raw_parts((*table).vals, self.size)[self.current as usize];
1059
1060 SpiHeapTupleData::new(tupdesc, heap_tuple)
1062 }
1063 }
1064 }
1065
1066 pub fn get<T: IntoDatum + FromDatum>(&self, ordinal: usize) -> Result<Option<T>> {
1080 let (_, tupdesc) = self.get_spi_tuptable()?;
1081 let datum = self.get_datum_by_ordinal(ordinal)?;
1082 let is_null = datum.is_none();
1083 let datum = datum.unwrap_or_else(|| pg_sys::Datum::from(0));
1084
1085 unsafe {
1086 Ok(T::try_from_datum_in_memory_context(
1089 PgMemoryContexts::CurrentMemoryContext
1090 .parent()
1091 .expect("parent memory context is absent"),
1092 datum,
1093 is_null,
1094 pg_sys::SPI_gettypeid(tupdesc, ordinal as _),
1097 )?)
1098 }
1099 }
1100
1101 pub fn get_by_name<T: IntoDatum + FromDatum, S: AsRef<str>>(
1108 &self,
1109 name: S,
1110 ) -> Result<Option<T>> {
1111 self.get(self.column_ordinal(name)?)
1112 }
1113
1114 pub fn get_datum_by_ordinal(&self, ordinal: usize) -> Result<Option<pg_sys::Datum>> {
1123 self.check_ordinal_bounds(ordinal)?;
1124
1125 let (table, tupdesc) = self.get_spi_tuptable()?;
1126 if self.current < 0 || self.current as usize >= self.size {
1127 return Err(Error::InvalidPosition);
1128 }
1129 unsafe {
1130 let heap_tuple =
1131 std::slice::from_raw_parts((*table).vals, self.size)[self.current as usize];
1132 let mut is_null = false;
1133 let datum = pg_sys::SPI_getbinval(heap_tuple, tupdesc, ordinal as _, &mut is_null);
1134
1135 if is_null {
1136 Ok(None)
1137 } else {
1138 Ok(Some(datum))
1139 }
1140 }
1141 }
1142
1143 pub fn get_datum_by_name<S: AsRef<str>>(&self, name: S) -> Result<Option<pg_sys::Datum>> {
1150 self.get_datum_by_ordinal(self.column_ordinal(name)?)
1151 }
1152
1153 pub fn columns(&self) -> Result<usize> {
1155 let (_, tupdesc) = self.get_spi_tuptable()?;
1156 Ok(unsafe { (*tupdesc).natts as _ })
1158 }
1159
1160 #[inline]
1162 fn check_ordinal_bounds(&self, ordinal: usize) -> Result<()> {
1163 if ordinal < 1 || ordinal > self.columns()? {
1164 Err(Error::SpiError(SpiErrorCodes::NoAttribute))
1165 } else {
1166 Ok(())
1167 }
1168 }
1169
1170 pub fn column_type_oid(&self, ordinal: usize) -> Result<PgOid> {
1174 self.check_ordinal_bounds(ordinal)?;
1175
1176 let (_, tupdesc) = self.get_spi_tuptable()?;
1177 unsafe {
1178 let oid = pg_sys::SPI_gettypeid(tupdesc, ordinal as i32);
1180 Ok(PgOid::from(oid))
1181 }
1182 }
1183
1184 pub fn column_name(&self, ordinal: usize) -> Result<String> {
1196 self.check_ordinal_bounds(ordinal)?;
1197 let (_, tupdesc) = self.get_spi_tuptable()?;
1198 unsafe {
1199 let name = pg_sys::SPI_fname(tupdesc, ordinal as i32);
1201
1202 let str =
1205 CStr::from_ptr(name).to_str().expect("column name is not value UTF8").to_string();
1206
1207 pg_sys::pfree(name as *mut _);
1209 Ok(str)
1210 }
1211 }
1212
1213 pub fn column_ordinal<S: AsRef<str>>(&self, name: S) -> Result<usize> {
1224 let (_, tupdesc) = self.get_spi_tuptable()?;
1225 unsafe {
1226 let name_cstr = CString::new(name.as_ref()).expect("name contained a null byte");
1227 let fnumber = pg_sys::SPI_fnumber(tupdesc, name_cstr.as_ptr());
1228
1229 if fnumber == pg_sys::SPI_ERROR_NOATTRIBUTE {
1230 Err(Error::SpiError(SpiErrorCodes::NoAttribute))
1231 } else {
1232 Ok(fnumber as usize)
1233 }
1234 }
1235 }
1236}
1237
1238impl SpiHeapTupleData {
1239 pub unsafe fn new(
1246 tupdesc: pg_sys::TupleDesc,
1247 htup: *mut pg_sys::HeapTupleData,
1248 ) -> Result<Option<Self>> {
1249 let tupdesc = NonNull::new(tupdesc).ok_or(Error::NoTupleTable)?;
1250 let mut data = SpiHeapTupleData { tupdesc, entries: HashMap::default() };
1251 let tupdesc = tupdesc.as_ptr();
1252
1253 unsafe {
1254 for i in 1..=tupdesc.as_ref().unwrap().natts {
1256 let mut is_null = false;
1257 let datum = pg_sys::SPI_getbinval(htup, tupdesc, i, &mut is_null);
1258
1259 data.entries.entry(i as usize).or_insert_with(|| SpiHeapTupleDataEntry {
1260 datum: if is_null { None } else { Some(datum) },
1261 type_oid: pg_sys::SPI_gettypeid(tupdesc, i),
1262 });
1263 }
1264 }
1265
1266 Ok(Some(data))
1267 }
1268
1269 pub fn get<T: IntoDatum + FromDatum>(&self, ordinal: usize) -> Result<Option<T>> {
1278 self.get_datum_by_ordinal(ordinal).map(|entry| entry.value())?
1279 }
1280
1281 pub fn get_by_name<T: IntoDatum + FromDatum, S: AsRef<str>>(
1288 &self,
1289 name: S,
1290 ) -> Result<Option<T>> {
1291 self.get_datum_by_name(name.as_ref()).map(|entry| entry.value())?
1292 }
1293
1294 pub fn get_datum_by_ordinal(
1302 &self,
1303 ordinal: usize,
1304 ) -> std::result::Result<&SpiHeapTupleDataEntry, Error> {
1305 self.entries.get(&ordinal).ok_or_else(|| Error::SpiError(SpiErrorCodes::NoAttribute))
1306 }
1307
1308 pub fn get_datum_by_name<S: AsRef<str>>(
1318 &self,
1319 name: S,
1320 ) -> std::result::Result<&SpiHeapTupleDataEntry, Error> {
1321 unsafe {
1322 let name_cstr = CString::new(name.as_ref()).expect("name contained a null byte");
1323 let fnumber = pg_sys::SPI_fnumber(self.tupdesc.as_ptr(), name_cstr.as_ptr());
1324
1325 if fnumber == pg_sys::SPI_ERROR_NOATTRIBUTE {
1326 Err(Error::SpiError(SpiErrorCodes::NoAttribute))
1327 } else {
1328 self.get_datum_by_ordinal(fnumber as usize)
1329 }
1330 }
1331 }
1332
1333 pub fn set_by_ordinal<T: IntoDatum>(
1339 &mut self,
1340 ordinal: usize,
1341 datum: T,
1342 ) -> std::result::Result<(), Error> {
1343 self.check_ordinal_bounds(ordinal)?;
1344 self.entries.insert(
1345 ordinal,
1346 SpiHeapTupleDataEntry { datum: datum.into_datum(), type_oid: T::type_oid() },
1347 );
1348 Ok(())
1349 }
1350
1351 pub fn set_by_name<T: IntoDatum>(
1361 &mut self,
1362 name: &str,
1363 datum: T,
1364 ) -> std::result::Result<(), Error> {
1365 unsafe {
1366 let name_cstr = CString::new(name).expect("name contained a null byte");
1367 let fnumber = pg_sys::SPI_fnumber(self.tupdesc.as_ptr(), name_cstr.as_ptr());
1368 if fnumber == pg_sys::SPI_ERROR_NOATTRIBUTE {
1369 Err(Error::SpiError(SpiErrorCodes::NoAttribute))
1370 } else {
1371 self.set_by_ordinal(fnumber as usize, datum)
1372 }
1373 }
1374 }
1375
1376 #[inline]
1377 pub fn columns(&self) -> usize {
1378 unsafe {
1379 (*self.tupdesc.as_ptr()).natts as usize
1381 }
1382 }
1383
1384 #[inline]
1386 fn check_ordinal_bounds(&self, ordinal: usize) -> std::result::Result<(), Error> {
1387 if ordinal < 1 || ordinal > self.columns() {
1388 Err(Error::SpiError(SpiErrorCodes::NoAttribute))
1389 } else {
1390 Ok(())
1391 }
1392 }
1393}
1394
1395impl SpiHeapTupleDataEntry {
1396 pub fn value<T: IntoDatum + FromDatum>(&self) -> Result<Option<T>> {
1397 match self.datum.as_ref() {
1398 Some(datum) => unsafe {
1399 T::try_from_datum(*datum, false, self.type_oid).map_err(|e| Error::DatumError(e))
1400 },
1401 None => Ok(None),
1402 }
1403 }
1404
1405 pub fn oid(&self) -> pg_sys::Oid {
1406 self.type_oid
1407 }
1408}
1409
1410impl Index<usize> for SpiHeapTupleData {
1414 type Output = SpiHeapTupleDataEntry;
1415
1416 fn index(&self, index: usize) -> &Self::Output {
1417 self.get_datum_by_ordinal(index).expect("invalid ordinal value")
1418 }
1419}
1420
1421impl Index<&str> for SpiHeapTupleData {
1425 type Output = SpiHeapTupleDataEntry;
1426
1427 fn index(&self, index: &str) -> &Self::Output {
1428 self.get_datum_by_name(index).expect("invalid field name")
1429 }
1430}
1431
1432impl Iterator for SpiTupleTable {
1433 type Item = SpiHeapTupleData;
1434
1435 #[inline]
1439 fn next(&mut self) -> Option<Self::Item> {
1440 self.current += 1;
1441 if self.current >= self.size as isize {
1442 None
1443 } else {
1444 assert!(self.current >= 0);
1445 self.get_heap_tuple().report()
1446 }
1447 }
1448
1449 #[inline]
1450 fn size_hint(&self) -> (usize, Option<usize>) {
1451 (0, Some(self.size))
1452 }
1453
1454 #[inline]
1455 fn count(self) -> usize
1456 where
1457 Self: Sized,
1458 {
1459 self.size
1460 }
1461}