use crate::{
pg_sys, register_xact_callback, FromDatum, IntoDatum, Json, PgMemoryContexts, PgOid,
PgXactCallbackEvent, TryFromDatumError,
};
use core::fmt::Formatter;
use pgx_pg_sys::panic::ErrorReportable;
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, Index};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
/// Module-wide result alias: every fallible SPI operation here reports failures as [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Success status codes returned by SPI functions.
///
/// The variants are named after PostgreSQL's `SPI_OK_*` codes, and each discriminant is the raw
/// integer status value. Raw values are classified with `SpiOkCodes::try_from`, which accepts
/// exactly `1..=18`.
#[derive(Debug, PartialEq)]
#[repr(i32)]
#[non_exhaustive]
pub enum SpiOkCodes {
Connect = 1,
Finish = 2,
Fetch = 3,
Utility = 4,
Select = 5,
SelInto = 6,
Insert = 7,
Delete = 8,
Update = 9,
Cursor = 10,
InsertReturning = 11,
DeleteReturning = 12,
UpdateReturning = 13,
Rewritten = 14,
RelRegister = 15,
RelUnregister = 16,
TdRegister = 17,
Merge = 18,
}
/// Negative error status codes returned by SPI functions.
///
/// The variants are named after PostgreSQL's `SPI_ERROR_*` codes, and each discriminant is the
/// raw (negative) status value. `Display` is implemented by hand below and prints the variant
/// name.
#[derive(thiserror::Error, Debug, PartialEq)]
#[repr(i32)]
pub enum SpiErrorCodes {
Connect = -1,
Copy = -2,
OpUnknown = -3,
Unconnected = -4,
// NOTE(review): presumably never constructed by name anywhere in the crate, hence the allow — confirm.
#[allow(dead_code)]
Cursor = -5,
Argument = -6,
Param = -7,
Transaction = -8,
// Also used by this module for out-of-range ordinals and unknown column names.
NoAttribute = -9,
NoOutFunc = -10,
TypUnknown = -11,
RelDuplicate = -12,
RelNotFound = -13,
}
impl std::fmt::Display for SpiErrorCodes {
    /// Displays the error code as its variant name (identical to its `Debug` rendering),
    /// e.g. `NoAttribute`.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        write!(f, "{self:?}")
    }
}
/// Marker error for a raw SPI status code that is neither a known `SpiOkCodes` value nor a
/// known `SpiErrorCodes` value.
#[derive(Debug)]
pub struct UnknownVariant;
impl TryFrom<libc::c_int> for SpiOkCodes {
type Error = std::result::Result<SpiErrorCodes, UnknownVariant>;
fn try_from(code: libc::c_int) -> std::result::Result<SpiOkCodes, Self::Error> {
match code as i32 {
err @ -13..=-1 => Err(Ok(
unsafe { mem::transmute::<i32, SpiErrorCodes>(err) },
)),
ok @ 1..=18 => Ok(
unsafe { mem::transmute::<i32, SpiOkCodes>(ok) },
),
_unknown => Err(Err(UnknownVariant)),
}
}
}
impl TryFrom<libc::c_int> for SpiErrorCodes {
    type Error = std::result::Result<SpiOkCodes, UnknownVariant>;

    /// Mirror image of `SpiOkCodes::try_from`.
    ///
    /// A recognised error code becomes `Ok`. A recognised success code is reported as
    /// `Err(Ok(..))`, and an unknown value as `Err(Err(UnknownVariant))`.
    fn try_from(code: libc::c_int) -> std::result::Result<SpiErrorCodes, Self::Error> {
        let classified = SpiOkCodes::try_from(code);
        match classified {
            Err(Ok(error_code)) => Ok(error_code),
            Err(Err(unknown)) => Err(Err(unknown)),
            Ok(ok_code) => Err(Ok(ok_code)),
        }
    }
}
/// Errors produced by the SPI wrappers in this module.
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum Error {
/// An SPI call returned one of the negative `SpiErrorCodes`; also used for bad ordinals/names.
#[error("SPI error: {0:?}")]
SpiError(#[from] SpiErrorCodes),
/// Converting a datum into the requested Rust type failed.
#[error("Datum error: {0}")]
DatumError(#[from] TryFromDatumError),
/// A prepared statement was given a different number of arguments than its plan declares.
#[error("Argument count mismatch (expected {expected}, got {got})")]
PreparedStatementArgumentMismatch { expected: usize, got: usize },
/// A row accessor was used while the table's position was outside `0..len()`.
#[error("SpiTupleTable positioned before the start or after the end")]
InvalidPosition,
/// `SPI_cursor_find` found no cursor with the given name.
#[error("Cursor named {0} not found")]
CursorNotFound(String),
/// The statement produced no tuple table (`SPI_tuptable` was NULL).
#[error("The active `SPI_tuptable` is NULL")]
NoTupleTable,
}
/// Namespace type for the SPI entry points (`Spi::connect`, `Spi::get_one`, `Spi::run`, ...).
pub struct Spi;
/// `true` once the current transaction has been switched to mutable (non-read-only) SPI mode.
/// Set by `Spi::mark_mutable`; reset to `false` by the commit/abort callbacks it registers.
static MUTABLE_MODE: AtomicBool = AtomicBool::new(false);
impl Spi {
    /// Returns `true` unless the current transaction has been switched to mutable mode.
    ///
    /// This value is passed as the `read_only` flag of every SPI execute call in this module.
    #[inline]
    fn is_read_only() -> bool {
        !MUTABLE_MODE.load(Ordering::Relaxed)
    }

    /// Puts the module back into read-only mode.
    #[inline]
    fn clear_mutable() {
        MUTABLE_MODE.store(false, Ordering::Relaxed)
    }

    /// Switches to mutable mode for the rest of the current transaction.
    ///
    /// The commit and abort callbacks that restore read-only mode are registered only on the
    /// read-only → mutable transition, so repeated calls within one transaction register them
    /// only once.
    fn mark_mutable() {
        if Spi::is_read_only() {
            register_xact_callback(PgXactCallbackEvent::Commit, Spi::clear_mutable);
            register_xact_callback(PgXactCallbackEvent::Abort, Spi::clear_mutable);
            MUTABLE_MODE.store(true, Ordering::Relaxed)
        }
    }
}
/// Handle for issuing SPI statements, handed to the closure passed to `Spi::connect`.
///
/// It borrows the underlying `SpiConnection` (via the `'conn` lifetime), so it cannot outlive
/// the connection that `SPI_finish`es on drop.
pub struct SpiClient<'conn> {
__marker: PhantomData<&'conn SpiConnection>,
}
/// Guard for one SPI connection: `SPI_connect` when created, `SPI_finish` when dropped.
///
/// The raw-pointer `PhantomData` opts the type out of `Send` and `Sync`.
struct SpiConnection(PhantomData<*mut ()>);

impl SpiConnection {
    /// Opens a connection with `SPI_connect`, turning a failing status into an [`Error`].
    fn connect() -> Result<Self> {
        let status = unsafe { pg_sys::SPI_connect() };
        Spi::check_status(status).map(|_| SpiConnection(PhantomData))
    }

    /// Creates a client whose lifetime is bounded by this connection.
    fn client(&self) -> SpiClient<'_> {
        SpiClient { __marker: PhantomData }
    }
}

impl Drop for SpiConnection {
    fn drop(&mut self) {
        let status = unsafe { pg_sys::SPI_finish() };
        // `drop` cannot propagate errors, so a failing `SPI_finish` status is discarded.
        let _ = Spi::check_status(status);
    }
}
/// Something that can be executed through an [`SpiClient`]: SQL text or a prepared plan.
pub trait Query {
/// Parameter values accepted by this kind of query.
type Arguments;
/// What `execute` returns (in this module always `Result<SpiTupleTable>`).
type Result;
/// Executes the query. `limit` is a row-count limit; implementations here pass `None` as `0`.
fn execute(
self,
client: &SpiClient,
limit: Option<i64>,
arguments: Self::Arguments,
) -> Self::Result;
/// Opens a cursor for the query; the cursor cannot outlive the client's connection (`'c`).
fn open_cursor<'c: 'cc, 'cc>(
self,
client: &'cc SpiClient<'c>,
args: Self::Arguments,
) -> SpiCursor<'c>;
}
impl<'a> Query for &'a String {
    type Arguments = Option<Vec<(PgOid, Option<pg_sys::Datum>)>>;
    type Result = Result<SpiTupleTable>;

    /// Forwards to the `&str` implementation of [`Query::execute`].
    fn execute(
        self,
        client: &SpiClient,
        limit: Option<i64>,
        arguments: Self::Arguments,
    ) -> Self::Result {
        let sql: &str = self;
        sql.execute(client, limit, arguments)
    }

    /// Forwards to the `&str` implementation of [`Query::open_cursor`].
    fn open_cursor<'c: 'cc, 'cc>(
        self,
        client: &'cc SpiClient<'c>,
        args: Self::Arguments,
    ) -> SpiCursor<'c> {
        let sql: &'a str = self;
        sql.open_cursor(client, args)
    }
}
/// Splits an optional datum into the `(value, null flag)` pair that SPI calls take.
///
/// A present value is flagged with `' '`. `None` (SQL NULL) becomes a zero datum flagged
/// with `'n'`.
fn prepare_datum(datum: Option<pg_sys::Datum>) -> (pg_sys::Datum, std::os::raw::c_char) {
    if let Some(value) = datum {
        return (value, ' ' as std::os::raw::c_char);
    }
    (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char)
}
impl<'a> Query for &'a str {
    type Arguments = Option<Vec<(PgOid, Option<pg_sys::Datum>)>>;
    type Result = Result<SpiTupleTable>;

    /// Executes the SQL text.
    ///
    /// Without arguments the text runs with `SPI_execute`. With arguments it runs with
    /// `SPI_execute_with_args`, one `(type, value)` pair per `$n` parameter. `limit = None` is
    /// passed as `0`.
    ///
    /// # Panics
    /// Panics if the query text contains an interior NUL byte.
    fn execute(
        self,
        _client: &SpiClient,
        limit: Option<i64>,
        arguments: Self::Arguments,
    ) -> Self::Result {
        // Clear any tuple table left by a previous statement. `prepare_tuple_table` then sees
        // NULL (-> `None`) if this statement does not set one.
        unsafe {
            pg_sys::SPI_tuptable = std::ptr::null_mut();
        }
        let src = CString::new(self).expect("query contained a null byte");
        let status_code = match arguments {
            Some(args) => {
                let nargs = args.len();
                let (types, data): (Vec<_>, Vec<_>) = args.into_iter().unzip();
                let mut argtypes = types.into_iter().map(PgOid::value).collect::<Vec<_>>();
                let (mut datums, nulls): (Vec<_>, Vec<_>) =
                    data.into_iter().map(prepare_datum).unzip();
                unsafe {
                    pg_sys::SPI_execute_with_args(
                        src.as_ptr(),
                        nargs as i32,
                        argtypes.as_mut_ptr(),
                        datums.as_mut_ptr(),
                        nulls.as_ptr(),
                        Spi::is_read_only(),
                        limit.unwrap_or(0),
                    )
                }
            }
            None => unsafe {
                pg_sys::SPI_execute(src.as_ptr(), Spi::is_read_only(), limit.unwrap_or(0))
            },
        };
        SpiClient::prepare_tuple_table(status_code)
    }

    /// Opens a cursor over the SQL text with `SPI_cursor_open_with_args`, passing a NULL
    /// cursor name. `None` arguments means no parameters.
    ///
    /// # Panics
    /// Panics if the query text contains an interior NUL byte, or if SPI hands back a NULL
    /// portal.
    fn open_cursor<'c: 'cc, 'cc>(
        self,
        _client: &'cc SpiClient<'c>,
        args: Self::Arguments,
    ) -> SpiCursor<'c> {
        let src = CString::new(self).expect("query contained a null byte");
        let args = args.unwrap_or_default();
        let nargs = args.len();
        let (types, data): (Vec<_>, Vec<_>) = args.into_iter().unzip();
        let mut argtypes = types.into_iter().map(PgOid::value).collect::<Vec<_>>();
        let (mut datums, nulls): (Vec<_>, Vec<_>) = data.into_iter().map(prepare_datum).unzip();
        let portal = unsafe {
            pg_sys::SPI_cursor_open_with_args(
                std::ptr::null_mut(),
                src.as_ptr(),
                nargs as i32,
                argtypes.as_mut_ptr(),
                datums.as_mut_ptr(),
                nulls.as_ptr(),
                Spi::is_read_only(),
                0,
            )
        };
        // A checked conversion instead of `new_unchecked`: a NULL here would otherwise be UB.
        let ptr = NonNull::new(portal).expect("SPI_cursor_open_with_args returned a NULL portal");
        SpiCursor { ptr, __marker: PhantomData }
    }
}
/// Result of one SPI statement or cursor fetch, with a movable "current row" position.
#[derive(Debug)]
pub struct SpiTupleTable {
// Status code the table was built from; not read anywhere in this module.
#[allow(dead_code)]
status_code: SpiOkCodes,
// Value of `SPI_tuptable` when the table was built; `None` when it was NULL.
table: Option<*mut pg_sys::SPITupleTable>,
// Row count taken from `SPI_processed`.
size: usize,
// Index of the current row; `-1` means "before the first row".
current: isize,
}
/// One column value of a row: the datum (`None` for SQL NULL) and the type OID used to decode it.
pub struct SpiHeapTupleDataEntry {
datum: Option<pg_sys::Datum>,
type_oid: pg_sys::Oid,
}
/// A row copied out of an SPI tuple table, keyed by 1-based column ordinal.
pub struct SpiHeapTupleData {
// Descriptor used for column-name lookups and the column count.
tupdesc: NonNull<pg_sys::TupleDescData>,
entries: HashMap<usize, SpiHeapTupleDataEntry>,
}
impl Spi {
    /// Runs `query` with a one-row limit and returns column 1 of the first row
    /// (`None` for SQL NULL).
    ///
    /// The statement is issued through [`SpiClient::update`], i.e. in mutable mode.
    pub fn get_one<A: FromDatum + IntoDatum>(query: &str) -> Result<Option<A>> {
        Spi::connect(|mut client| {
            let table = client.update(query, Some(1), None)?;
            table.first().get_one()
        })
    }

    /// Like [`Spi::get_one`], returning columns 1 and 2 of the first row.
    pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
        query: &str,
    ) -> Result<(Option<A>, Option<B>)> {
        Spi::connect(|mut client| {
            let table = client.update(query, Some(1), None)?;
            table.first().get_two::<A, B>()
        })
    }

    /// Like [`Spi::get_one`], returning columns 1, 2 and 3 of the first row.
    pub fn get_three<
        A: FromDatum + IntoDatum,
        B: FromDatum + IntoDatum,
        C: FromDatum + IntoDatum,
    >(
        query: &str,
    ) -> Result<(Option<A>, Option<B>, Option<C>)> {
        Spi::connect(|mut client| {
            let table = client.update(query, Some(1), None)?;
            table.first().get_three::<A, B, C>()
        })
    }

    /// [`Spi::get_one`] with `$n` parameters given as `(type, value)` pairs.
    pub fn get_one_with_args<A: FromDatum + IntoDatum>(
        query: &str,
        args: Vec<(PgOid, Option<pg_sys::Datum>)>,
    ) -> Result<Option<A>> {
        Spi::connect(|mut client| {
            let table = client.update(query, Some(1), Some(args))?;
            table.first().get_one()
        })
    }

    /// [`Spi::get_two`] with `$n` parameters given as `(type, value)` pairs.
    pub fn get_two_with_args<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
        query: &str,
        args: Vec<(PgOid, Option<pg_sys::Datum>)>,
    ) -> Result<(Option<A>, Option<B>)> {
        Spi::connect(|mut client| {
            let table = client.update(query, Some(1), Some(args))?;
            table.first().get_two::<A, B>()
        })
    }

    /// [`Spi::get_three`] with `$n` parameters given as `(type, value)` pairs.
    pub fn get_three_with_args<
        A: FromDatum + IntoDatum,
        B: FromDatum + IntoDatum,
        C: FromDatum + IntoDatum,
    >(
        query: &str,
        args: Vec<(PgOid, Option<pg_sys::Datum>)>,
    ) -> Result<(Option<A>, Option<B>, Option<C>)> {
        Spi::connect(|mut client| {
            let table = client.update(query, Some(1), Some(args))?;
            table.first().get_three::<A, B, C>()
        })
    }

    /// Runs `query` in mutable mode and discards its result.
    pub fn run(query: &str) -> std::result::Result<(), Error> {
        Spi::run_with_args(query, None)
    }

    /// Runs `query` (with optional `$n` parameters) in mutable mode and discards its result.
    pub fn run_with_args(
        query: &str,
        args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>,
    ) -> std::result::Result<(), Error> {
        Spi::connect(|mut client| client.update(query, None, args))?;
        Ok(())
    }

    /// Returns `EXPLAIN (format json)` output for `query`.
    pub fn explain(query: &str) -> Result<Json> {
        Spi::explain_with_args(query, None)
    }

    /// Returns `EXPLAIN (format json)` output for `query` with optional `$n` parameters.
    ///
    /// # Panics
    /// Panics if the EXPLAIN result's first column is NULL.
    pub fn explain_with_args(
        query: &str,
        args: Option<Vec<(PgOid, Option<pg_sys::Datum>)>>,
    ) -> Result<Json> {
        let explained = Spi::connect(|mut client| {
            let sql = format!("EXPLAIN (format json) {}", query);
            client.update(&sql, None, args)?.first().get_one::<Json>()
        })?;
        Ok(explained.unwrap())
    }

    /// Opens an SPI connection, runs `f` with a client bound to it, and finishes the connection
    /// (via the guard's `Drop`) after `f` returns.
    ///
    /// # Panics
    /// Panics if `SPI_connect` reports a failure.
    pub fn connect<R, F: FnOnce(SpiClient<'_>) -> R>(f: F) -> R {
        let connection =
            SpiConnection::connect().expect("SPI_connect indicated an unexpected failure");
        let client = connection.client();
        f(client)
    }

    /// Maps a raw SPI status code to a success code, or to [`Error::SpiError`] for error codes.
    ///
    /// # Panics
    /// Panics (at the caller's location) on a status code that is neither a known success nor
    /// a known error code.
    #[track_caller]
    pub fn check_status(status_code: i32) -> std::result::Result<SpiOkCodes, Error> {
        match SpiOkCodes::try_from(status_code) {
            Err(Err(UnknownVariant)) => panic!("unrecognized SPI status code: {status_code}"),
            Err(Ok(code)) => Err(Error::from(code)),
            Ok(ok) => Ok(ok),
        }
    }
}
impl<'a> SpiClient<'a> {
pub fn select<Q: Query>(&self, query: Q, limit: Option<i64>, args: Q::Arguments) -> Q::Result {
self.execute(query, limit, args)
}
pub fn update<Q: Query>(
&mut self,
query: Q,
limit: Option<i64>,
args: Q::Arguments,
) -> Q::Result {
Spi::mark_mutable();
self.execute(query, limit, args)
}
fn execute<Q: Query>(&self, query: Q, limit: Option<i64>, args: Q::Arguments) -> Q::Result {
query.execute(&self, limit, args)
}
fn prepare_tuple_table(status_code: i32) -> std::result::Result<SpiTupleTable, Error> {
Ok(SpiTupleTable {
status_code: Spi::check_status(status_code)?,
table: unsafe {
if pg_sys::SPI_tuptable.is_null() {
None
} else {
Some(pg_sys::SPI_tuptable)
}
},
size: unsafe { pg_sys::SPI_processed as usize },
current: -1,
})
}
pub fn open_cursor<Q: Query>(&self, query: Q, args: Q::Arguments) -> SpiCursor {
query.open_cursor(&self, args)
}
pub fn open_cursor_mut<Q: Query>(&mut self, query: Q, args: Q::Arguments) -> SpiCursor {
Spi::mark_mutable();
query.open_cursor(&self, args)
}
pub fn find_cursor(&self, name: &str) -> Result<SpiCursor> {
use pgx_pg_sys::AsPgCStr;
let ptr = NonNull::new(unsafe { pg_sys::SPI_cursor_find(name.as_pg_cstr()) })
.ok_or(Error::CursorNotFound(name.to_string()))?;
Ok(SpiCursor { ptr, __marker: PhantomData })
}
}
/// Name of an SPI cursor, as returned by `SpiCursor::detach_into_name`.
type CursorName = String;
/// An open SPI cursor (portal), tied to the lifetime of the client that opened or found it.
///
/// Dropping it closes the portal with `SPI_cursor_close`; `detach_into_name` keeps it open.
pub struct SpiCursor<'client> {
ptr: NonNull<pg_sys::PortalData>,
__marker: PhantomData<&'client SpiClient<'client>>,
}
impl SpiCursor<'_> {
pub fn fetch(&mut self, count: i64) -> std::result::Result<SpiTupleTable, Error> {
unsafe {
pg_sys::SPI_tuptable = std::ptr::null_mut();
}
unsafe { pg_sys::SPI_cursor_fetch(self.ptr.as_mut(), true, count) }
Ok(SpiClient::prepare_tuple_table(SpiOkCodes::Fetch as i32)?)
}
pub fn detach_into_name(self) -> CursorName {
let cursor_ptr = unsafe { self.ptr.as_ref() };
std::mem::forget(self);
unsafe { CStr::from_ptr(cursor_ptr.name) }
.to_str()
.expect("cursor name is not valid UTF8")
.to_string()
}
}
impl Drop for SpiCursor<'_> {
    /// Closes the underlying portal with `SPI_cursor_close`.
    fn drop(&mut self) {
        let portal = self.ptr.as_ptr();
        unsafe { pg_sys::SPI_cursor_close(portal) }
    }
}
/// A plan prepared with `SpiClient::prepare`.
///
/// `'a` ties the plan to the client that prepared it; call `keep` to obtain an
/// [`OwnedPreparedStatement`] without that tie.
pub struct PreparedStatement<'a> {
plan: NonNull<pg_sys::_SPI_plan>,
__marker: PhantomData<&'a ()>,
}
/// A plan that has been passed to `SPI_keepplan`; it is freed with `SPI_freeplan` on drop.
pub struct OwnedPreparedStatement(PreparedStatement<'static>);
impl Deref for OwnedPreparedStatement {
type Target = PreparedStatement<'static>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Drop for OwnedPreparedStatement {
fn drop(&mut self) {
unsafe {
pg_sys::SPI_freeplan(self.0.plan.as_ptr());
}
}
}
impl<'a> Query for &'a OwnedPreparedStatement {
type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
type Result = Result<SpiTupleTable>;
fn execute(
self,
client: &SpiClient,
limit: Option<i64>,
arguments: Self::Arguments,
) -> Self::Result {
(&self.0).execute(client, limit, arguments)
}
fn open_cursor<'c: 'cc, 'cc>(
self,
client: &'cc SpiClient<'c>,
args: Self::Arguments,
) -> SpiCursor<'c> {
(&self.0).open_cursor(client, args)
}
}
impl Query for OwnedPreparedStatement {
type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
type Result = Result<SpiTupleTable>;
fn execute(
self,
client: &SpiClient,
limit: Option<i64>,
arguments: Self::Arguments,
) -> Self::Result {
(&self.0).execute(client, limit, arguments)
}
fn open_cursor<'c: 'cc, 'cc>(
self,
client: &'cc SpiClient<'c>,
args: Self::Arguments,
) -> SpiCursor<'c> {
(&self.0).open_cursor(client, args)
}
}
impl<'a> PreparedStatement<'a> {
    /// Calls `SPI_keepplan` on the plan and wraps it in an [`OwnedPreparedStatement`].
    ///
    /// The wrapper is no longer tied to the preparing client and frees the plan when dropped.
    pub fn keep(self) -> OwnedPreparedStatement {
        let plan = self.plan;
        unsafe {
            pg_sys::SPI_keepplan(plan.as_ptr());
        }
        OwnedPreparedStatement(PreparedStatement { plan, __marker: PhantomData })
    }
}
impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> {
type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
type Result = Result<SpiTupleTable>;
fn execute(
self,
_client: &SpiClient,
limit: Option<i64>,
arguments: Self::Arguments,
) -> Self::Result {
unsafe {
pg_sys::SPI_tuptable = std::ptr::null_mut();
}
let args = arguments.unwrap_or_default();
let nargs = args.len();
let expected = unsafe { pg_sys::SPI_getargcount(self.plan.as_ptr()) } as usize;
if nargs != expected {
return Err(Error::PreparedStatementArgumentMismatch { expected, got: nargs });
}
let (mut datums, mut nulls): (Vec<_>, Vec<_>) = args.into_iter().map(prepare_datum).unzip();
let status_code = unsafe {
pg_sys::SPI_execute_plan(
self.plan.as_ptr(),
datums.as_mut_ptr(),
nulls.as_mut_ptr(),
Spi::is_read_only(),
limit.unwrap_or(0),
)
};
Ok(SpiClient::prepare_tuple_table(status_code)?)
}
fn open_cursor<'c: 'cc, 'cc>(
self,
_client: &'cc SpiClient<'c>,
args: Self::Arguments,
) -> SpiCursor<'c> {
let args = args.unwrap_or_default();
let (mut datums, nulls): (Vec<_>, Vec<_>) = args.into_iter().map(prepare_datum).unzip();
let ptr = unsafe {
NonNull::new_unchecked(pg_sys::SPI_cursor_open(
std::ptr::null_mut(), self.plan.as_ptr(),
datums.as_mut_ptr(),
nulls.as_ptr(),
Spi::is_read_only(),
))
};
SpiCursor { ptr, __marker: PhantomData }
}
}
impl<'a> Query for PreparedStatement<'a> {
type Arguments = Option<Vec<Option<pg_sys::Datum>>>;
type Result = Result<SpiTupleTable>;
fn execute(
self,
client: &SpiClient,
limit: Option<i64>,
arguments: Self::Arguments,
) -> Self::Result {
(&self).execute(client, limit, arguments)
}
fn open_cursor<'c: 'cc, 'cc>(
self,
client: &'cc SpiClient<'c>,
args: Self::Arguments,
) -> SpiCursor<'c> {
(&self).open_cursor(client, args)
}
}
impl<'a> SpiClient<'a> {
    /// Prepares `query` with `SPI_prepare`.
    ///
    /// `args` lists the type of each `$n` parameter; `None` means no parameters.
    ///
    /// # Errors
    /// When `SPI_prepare` returns NULL, the error code in `SPI_result` is returned as
    /// [`Error::SpiError`].
    ///
    /// # Panics
    /// Panics if `query` contains an interior NUL byte, or if `SPI_prepare` returns NULL while
    /// `SPI_result` does not hold an SPI error code.
    pub fn prepare(&self, query: &str, args: Option<Vec<PgOid>>) -> Result<PreparedStatement> {
        let src = CString::new(query).expect("query contained a null byte");
        // Keep the OID array in a named local so it visibly outlives the FFI call that reads it.
        let mut argtypes: Vec<_> = args.unwrap_or_default().into_iter().map(PgOid::value).collect();
        let nargs = argtypes.len();
        let plan =
            unsafe { pg_sys::SPI_prepare(src.as_ptr(), nargs as i32, argtypes.as_mut_ptr()) };
        let plan = NonNull::new(plan).ok_or_else(|| {
            match Spi::check_status(unsafe { pg_sys::SPI_result }) {
                Err(err) => err,
                Ok(code) => panic!(
                    "SPI_prepare returned NULL but SPI_result reported success: {code:?}"
                ),
            }
        })?;
        Ok(PreparedStatement { plan, __marker: PhantomData })
    }
}
impl SpiTupleTable {
pub fn first(mut self) -> Self {
self.current = 0;
self
}
pub fn rewind(mut self) -> Self {
self.current = -1;
self
}
pub fn len(&self) -> usize {
self.size
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn get_one<A: FromDatum + IntoDatum>(&self) -> Result<Option<A>> {
self.get(1)
}
pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
&self,
) -> Result<(Option<A>, Option<B>)> {
let a = self.get::<A>(1)?;
let b = self.get::<B>(2)?;
Ok((a, b))
}
pub fn get_three<
A: FromDatum + IntoDatum,
B: FromDatum + IntoDatum,
C: FromDatum + IntoDatum,
>(
&self,
) -> Result<(Option<A>, Option<B>, Option<C>)> {
let a = self.get::<A>(1)?;
let b = self.get::<B>(2)?;
let c = self.get::<C>(3)?;
Ok((a, b, c))
}
#[inline(always)]
fn get_spi_tuptable(&self) -> Result<(*mut pg_sys::SPITupleTable, *mut pg_sys::TupleDescData)> {
let table = *self.table.as_ref().ok_or(Error::NoTupleTable)?;
unsafe {
Ok((table, (*table).tupdesc))
}
}
pub fn get_heap_tuple(&self) -> Result<Option<SpiHeapTupleData>> {
if self.size == 0 || self.table.is_none() {
Ok(None)
} else if self.current < 0 || self.current as usize >= self.size {
Err(Error::InvalidPosition)
} else {
let (table, tupdesc) = self.get_spi_tuptable()?;
unsafe {
let heap_tuple =
std::slice::from_raw_parts((*table).vals, self.size)[self.current as usize];
SpiHeapTupleData::new(tupdesc, heap_tuple)
}
}
}
pub fn get<T: IntoDatum + FromDatum>(&self, ordinal: usize) -> Result<Option<T>> {
let (_, tupdesc) = self.get_spi_tuptable()?;
let datum = self.get_datum_by_ordinal(ordinal)?;
let is_null = datum.is_none();
let datum = datum.unwrap_or_else(|| pg_sys::Datum::from(0));
unsafe {
Ok(T::try_from_datum_in_memory_context(
PgMemoryContexts::CurrentMemoryContext
.parent()
.expect("parent memory context is absent"),
datum,
is_null,
pg_sys::SPI_gettypeid(tupdesc, ordinal as _),
)?)
}
}
pub fn get_by_name<T: IntoDatum + FromDatum, S: AsRef<str>>(
&self,
name: S,
) -> Result<Option<T>> {
self.get(self.column_ordinal(name)?)
}
pub fn get_datum_by_ordinal(&self, ordinal: usize) -> Result<Option<pg_sys::Datum>> {
self.check_ordinal_bounds(ordinal)?;
let (table, tupdesc) = self.get_spi_tuptable()?;
if self.current < 0 || self.current as usize >= self.size {
return Err(Error::InvalidPosition);
}
unsafe {
let heap_tuple =
std::slice::from_raw_parts((*table).vals, self.size)[self.current as usize];
let mut is_null = false;
let datum = pg_sys::SPI_getbinval(heap_tuple, tupdesc, ordinal as _, &mut is_null);
if is_null {
Ok(None)
} else {
Ok(Some(datum))
}
}
}
pub fn get_datum_by_name<S: AsRef<str>>(&self, name: S) -> Result<Option<pg_sys::Datum>> {
self.get_datum_by_ordinal(self.column_ordinal(name)?)
}
pub fn columns(&self) -> Result<usize> {
let (_, tupdesc) = self.get_spi_tuptable()?;
Ok(unsafe { (*tupdesc).natts as _ })
}
#[inline]
fn check_ordinal_bounds(&self, ordinal: usize) -> Result<()> {
if ordinal < 1 || ordinal > self.columns()? {
Err(Error::SpiError(SpiErrorCodes::NoAttribute))
} else {
Ok(())
}
}
pub fn column_type_oid(&self, ordinal: usize) -> Result<PgOid> {
self.check_ordinal_bounds(ordinal)?;
let (_, tupdesc) = self.get_spi_tuptable()?;
unsafe {
let oid = pg_sys::SPI_gettypeid(tupdesc, ordinal as i32);
Ok(PgOid::from(oid))
}
}
pub fn column_name(&self, ordinal: usize) -> Result<String> {
self.check_ordinal_bounds(ordinal)?;
let (_, tupdesc) = self.get_spi_tuptable()?;
unsafe {
let name = pg_sys::SPI_fname(tupdesc, ordinal as i32);
let str =
CStr::from_ptr(name).to_str().expect("column name is not value UTF8").to_string();
pg_sys::pfree(name as *mut _);
Ok(str)
}
}
pub fn column_ordinal<S: AsRef<str>>(&self, name: S) -> Result<usize> {
let (_, tupdesc) = self.get_spi_tuptable()?;
unsafe {
let name_cstr = CString::new(name.as_ref()).expect("name contained a null byte");
let fnumber = pg_sys::SPI_fnumber(tupdesc, name_cstr.as_ptr());
if fnumber == pg_sys::SPI_ERROR_NOATTRIBUTE {
Err(Error::SpiError(SpiErrorCodes::NoAttribute))
} else {
Ok(fnumber as usize)
}
}
}
}
impl SpiHeapTupleData {
pub unsafe fn new(
tupdesc: pg_sys::TupleDesc,
htup: *mut pg_sys::HeapTupleData,
) -> Result<Option<Self>> {
let tupdesc = NonNull::new(tupdesc).ok_or(Error::NoTupleTable)?;
let mut data = SpiHeapTupleData { tupdesc, entries: HashMap::default() };
let tupdesc = tupdesc.as_ptr();
unsafe {
for i in 1..=tupdesc.as_ref().unwrap().natts {
let mut is_null = false;
let datum = pg_sys::SPI_getbinval(htup, tupdesc, i, &mut is_null);
data.entries.entry(i as usize).or_insert_with(|| SpiHeapTupleDataEntry {
datum: if is_null { None } else { Some(datum) },
type_oid: pg_sys::SPI_gettypeid(tupdesc, i),
});
}
}
Ok(Some(data))
}
pub fn get<T: IntoDatum + FromDatum>(&self, ordinal: usize) -> Result<Option<T>> {
self.get_datum_by_ordinal(ordinal).map(|entry| entry.value())?
}
pub fn get_by_name<T: IntoDatum + FromDatum, S: AsRef<str>>(
&self,
name: S,
) -> Result<Option<T>> {
self.get_datum_by_name(name.as_ref()).map(|entry| entry.value())?
}
pub fn get_datum_by_ordinal(
&self,
ordinal: usize,
) -> std::result::Result<&SpiHeapTupleDataEntry, Error> {
self.entries.get(&ordinal).ok_or_else(|| Error::SpiError(SpiErrorCodes::NoAttribute))
}
pub fn get_datum_by_name<S: AsRef<str>>(
&self,
name: S,
) -> std::result::Result<&SpiHeapTupleDataEntry, Error> {
unsafe {
let name_cstr = CString::new(name.as_ref()).expect("name contained a null byte");
let fnumber = pg_sys::SPI_fnumber(self.tupdesc.as_ptr(), name_cstr.as_ptr());
if fnumber == pg_sys::SPI_ERROR_NOATTRIBUTE {
Err(Error::SpiError(SpiErrorCodes::NoAttribute))
} else {
self.get_datum_by_ordinal(fnumber as usize)
}
}
}
pub fn set_by_ordinal<T: IntoDatum>(
&mut self,
ordinal: usize,
datum: T,
) -> std::result::Result<(), Error> {
self.check_ordinal_bounds(ordinal)?;
self.entries.insert(
ordinal,
SpiHeapTupleDataEntry { datum: datum.into_datum(), type_oid: T::type_oid() },
);
Ok(())
}
pub fn set_by_name<T: IntoDatum>(
&mut self,
name: &str,
datum: T,
) -> std::result::Result<(), Error> {
unsafe {
let name_cstr = CString::new(name).expect("name contained a null byte");
let fnumber = pg_sys::SPI_fnumber(self.tupdesc.as_ptr(), name_cstr.as_ptr());
if fnumber == pg_sys::SPI_ERROR_NOATTRIBUTE {
Err(Error::SpiError(SpiErrorCodes::NoAttribute))
} else {
self.set_by_ordinal(fnumber as usize, datum)
}
}
}
#[inline]
pub fn columns(&self) -> usize {
unsafe {
(*self.tupdesc.as_ptr()).natts as usize
}
}
#[inline]
fn check_ordinal_bounds(&self, ordinal: usize) -> std::result::Result<(), Error> {
if ordinal < 1 || ordinal > self.columns() {
Err(Error::SpiError(SpiErrorCodes::NoAttribute))
} else {
Ok(())
}
}
}
impl SpiHeapTupleDataEntry {
pub fn value<T: IntoDatum + FromDatum>(&self) -> Result<Option<T>> {
match self.datum.as_ref() {
Some(datum) => unsafe {
T::try_from_datum(*datum, false, self.type_oid).map_err(|e| Error::DatumError(e))
},
None => Ok(None),
}
}
pub fn oid(&self) -> pg_sys::Oid {
self.type_oid
}
}
impl Index<usize> for SpiHeapTupleData {
    type Output = SpiHeapTupleDataEntry;

    /// Entry for the 1-based column `index`.
    ///
    /// # Panics
    /// Panics with "invalid ordinal value" when no such column exists.
    fn index(&self, index: usize) -> &Self::Output {
        Result::expect(self.get_datum_by_ordinal(index), "invalid ordinal value")
    }
}

impl Index<&str> for SpiHeapTupleData {
    type Output = SpiHeapTupleDataEntry;

    /// Entry for the column named `index`.
    ///
    /// # Panics
    /// Panics with "invalid field name" when no such column exists.
    fn index(&self, index: &str) -> &Self::Output {
        Result::expect(self.get_datum_by_name(index), "invalid field name")
    }
}
impl Iterator for SpiTupleTable {
    type Item = SpiHeapTupleData;

    /// Advances to the next row and returns it as an owned [`SpiHeapTupleData`].
    ///
    /// Returns `None` once the position passes the last row, or when the statement produced no
    /// tuple table. Errors from `get_heap_tuple` are handed to `ErrorReportable::report`,
    /// which presumably raises them as PostgreSQL errors — confirm.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        self.current += 1;
        if self.current >= self.size as isize {
            None
        } else {
            assert!(self.current >= 0);
            self.get_heap_tuple().report()
        }
    }

    /// The upper bound is the number of rows after the current position. The lower bound is `0`
    /// because `next` yields nothing when there is no tuple table.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(remaining_rows(self.current, self.size)))
    }

    /// Number of row positions `next` has not visited yet.
    ///
    /// Previously this returned the full `size` even after rows had been consumed. When the
    /// statement produced no tuple table, this still counts the remaining `SPI_processed`
    /// positions, even though `next` yields nothing.
    #[inline]
    fn count(self) -> usize
    where
        Self: Sized,
    {
        remaining_rows(self.current, self.size)
    }
}

/// Rows strictly after position `current` in a table of `size` rows.
///
/// `current == -1` means nothing has been visited yet.
#[inline]
fn remaining_rows(current: isize, size: usize) -> usize {
    // Positions `0..=current` have already been passed. Clamp so that positions past the end
    // (or below -1) never underflow.
    let passed = usize::try_from(current.saturating_add(1)).unwrap_or(0);
    size.saturating_sub(passed)
}