use std::ffi::{CStr, CString};
use std::marker::PhantomData;
use std::ops::Deref;
use std::ptr::NonNull;
use libc::c_char;
use super::{Spi, SpiClient, SpiCursor, SpiError, SpiResult, SpiTupleTable};
use crate::{
datum::DatumWithOid,
pg_sys::{self, PgOid},
};
/// A query that can be run through SPI, either for a result set or as a cursor.
///
/// This file implements it for borrowed query text (`&CStr`, `&CString`, `&String`, `&str`)
/// and for prepared statements, both owned and borrowed.
pub trait Query<'conn>: Sized {
/// Runs the query and returns the resulting tuple table.
///
/// `client` carries the connection lifetime `'conn` that the result is tied to.
/// `limit` is forwarded to SPI as its row-count argument; the implementations in this
/// file send `None` as `0` (per the PostgreSQL SPI documentation, 0 means "no limit").
/// `args` supplies the values (with their type OIDs) for the query parameters.
fn execute<'mcx>(
self,
client: &SpiClient<'conn>,
limit: Option<libc::c_long>,
args: &[DatumWithOid<'mcx>],
) -> SpiResult<SpiTupleTable<'conn>>;
/// Opens a cursor for the query, panicking on failure.
///
/// Deprecated since 0.12.2 (the deprecation note reads "undefined behavior");
/// use [`Query::try_open_cursor`] instead.
///
/// # Panics
///
/// Panics if [`Query::try_open_cursor`] returns an error, because the result is unwrapped.
#[deprecated(since = "0.12.2", note = "undefined behavior")]
fn open_cursor<'mcx>(
self,
client: &SpiClient<'conn>,
args: &[DatumWithOid<'mcx>],
) -> SpiCursor<'conn> {
self.try_open_cursor(client, args).unwrap()
}
/// Opens a cursor for the query, reporting failure through the returned `SpiResult`.
///
/// `args` supplies the values (with their type OIDs) for the query parameters.
fn try_open_cursor<'mcx>(
self,
client: &SpiClient<'conn>,
args: &[DatumWithOid<'mcx>],
) -> SpiResult<SpiCursor<'conn>>;
}
/// A [`Query`] that can additionally be turned into a [`PreparedStatement`].
pub trait PreparableQuery<'conn>: Query<'conn> {
/// Prepares the query, with `args` listing the type OIDs of its parameters.
///
/// The implementations in this file create the statement with `mutating = false`.
fn prepare(
self,
client: &SpiClient<'conn>,
args: &[PgOid],
) -> SpiResult<PreparedStatement<'conn>>;
/// Prepares the query like [`PreparableQuery::prepare`], but marks it as mutating.
///
/// The implementations in this file create the statement with `mutating = true`; a
/// mutating statement always passes `read_only = false` to SPI when it is executed.
fn prepare_mut(
self,
client: &SpiClient<'conn>,
args: &[PgOid],
) -> SpiResult<PreparedStatement<'conn>>;
}
/// Executes `cmd` through SPI and converts the status code into an [`SpiTupleTable`].
///
/// With no `args` the command goes through `SPI_execute`; otherwise the arguments are split
/// into type/value/null arrays and handed to `SPI_execute_with_args`. A `limit` of `None`
/// is sent to SPI as `0`.
fn execute<'conn, 'mcx>(
    cmd: &CStr,
    args: &[DatumWithOid<'mcx>],
    limit: Option<libc::c_long>,
) -> SpiResult<SpiTupleTable<'conn>> {
    // Reset the global tuple-table pointer before running the command.
    // SAFETY (assumed — confirm): nothing else accesses the SPI globals concurrently.
    unsafe {
        pg_sys::SPI_tuptable = std::ptr::null_mut();
    }

    let tcount = limit.unwrap_or(0);

    let status_code = if args.is_empty() {
        // SAFETY (assumed — confirm): `cmd` is a valid NUL-terminated string for the
        // duration of the call.
        unsafe { pg_sys::SPI_execute(cmd.as_ptr(), Spi::is_xact_still_immutable(), tcount) }
    } else {
        // The three vectors must stay alive until the FFI call returns.
        let (mut argtypes, mut datums, nulls) = args_to_datums(args);
        // SAFETY (assumed — confirm): all three vectors hold exactly `args.len()` entries
        // and outlive the call.
        unsafe {
            pg_sys::SPI_execute_with_args(
                cmd.as_ptr(),
                args.len() as i32,
                argtypes.as_mut_ptr(),
                datums.as_mut_ptr(),
                nulls.as_ptr(),
                Spi::is_xact_still_immutable(),
                tcount,
            )
        }
    };

    SpiClient::prepare_tuple_table(status_code)
}
/// Opens an SPI cursor over `cmd` with the given arguments.
///
/// The cursor name is passed as a null pointer and the cursor options as `0`. The pointer
/// returned by `SPI_cursor_open_with_args` is wrapped without a null check.
fn open_cursor<'conn, 'mcx>(
    cmd: &CStr,
    args: &[DatumWithOid<'mcx>],
) -> SpiResult<SpiCursor<'conn>> {
    let arg_count = args.len();
    // The three vectors must stay alive until the FFI call returns.
    let (mut argtypes, mut datums, nulls) = args_to_datums(args);

    // SAFETY (assumed — confirm): `cmd` is NUL-terminated and the vectors each hold
    // `arg_count` entries that outlive the call.
    let portal = unsafe {
        pg_sys::SPI_cursor_open_with_args(
            std::ptr::null_mut(),
            cmd.as_ptr(),
            arg_count as i32,
            argtypes.as_mut_ptr(),
            datums.as_mut_ptr(),
            nulls.as_ptr(),
            Spi::is_xact_still_immutable(),
            0,
        )
    };

    // SAFETY (assumed — confirm against the PostgreSQL SPI docs): SPI reports failures by
    // raising an error rather than returning a null portal.
    let ptr = unsafe { NonNull::new_unchecked(portal) };

    Ok(SpiCursor { ptr, __marker: PhantomData })
}
fn args_to_datums<'mcx>(
args: &[DatumWithOid<'mcx>],
) -> (Vec<pg_sys::Oid>, Vec<pg_sys::Datum>, Vec<c_char>) {
let mut argtypes = Vec::with_capacity(args.len());
let mut datums = Vec::with_capacity(args.len());
let mut nulls = Vec::with_capacity(args.len());
for arg in args {
let (datum, null) = prepare_datum(arg);
argtypes.push(arg.oid());
datums.push(datum);
nulls.push(null);
}
(argtypes, datums, nulls)
}
/// Converts one argument into the `(value, null flag)` pair passed to SPI.
///
/// A present value is returned with the flag `' '`; an absent value is returned as a zero
/// datum with the flag `'n'`.
fn prepare_datum<'mcx>(datum: &DatumWithOid<'mcx>) -> (pg_sys::Datum, std::os::raw::c_char) {
    const PRESENT: std::os::raw::c_char = ' ' as std::os::raw::c_char;
    const ABSENT: std::os::raw::c_char = 'n' as std::os::raw::c_char;

    match datum.datum() {
        None => (pg_sys::Datum::from(0usize), ABSENT),
        Some(value) => (value.sans_lifetime(), PRESENT),
    }
}
/// Prepares `cmd` via `SPI_prepare`, with `args` giving the parameter type OIDs.
///
/// `mutating` is stored on the resulting [`PreparedStatement`] unchanged.
///
/// # Errors
///
/// If `SPI_prepare` returns a null plan, the error derived from `SPI_result` by
/// `Spi::check_status` is returned.
///
/// # Panics
///
/// Panics if the plan is null but `Spi::check_status` does not report an error for
/// `SPI_result`.
fn prepare<'conn>(
    cmd: &CStr,
    args: &[PgOid],
    mutating: bool,
) -> SpiResult<PreparedStatement<'conn>> {
    // Keep the OID array in a named local so it visibly outlives the FFI call.
    let mut argtypes: Vec<_> = args.iter().map(|arg| arg.value()).collect();

    // SAFETY (assumed — confirm): `cmd` is NUL-terminated and `argtypes` holds exactly
    // `args.len()` entries for the duration of the call.
    let raw_plan =
        unsafe { pg_sys::SPI_prepare(cmd.as_ptr(), args.len() as i32, argtypes.as_mut_ptr()) };

    let plan = NonNull::new(raw_plan).ok_or_else(|| {
        // SAFETY (assumed — confirm): nothing else accesses `SPI_result` concurrently.
        let status = unsafe { pg_sys::SPI_result };
        Spi::check_status(status).err().unwrap()
    })?;

    Ok(PreparedStatement { plan, __marker: PhantomData, mutating })
}
/// Implements [`Query`] and [`PreparableQuery`] for `&$query_ty`.
///
/// `$to_cstr` names a function that turns the borrowed query into something whose
/// `as_ref()` yields a `&CStr`; the generated methods then delegate to the free functions
/// `execute`, `open_cursor` and `prepare` in this module. The `_client` argument is unused.
macro_rules! impl_prepared_query {
    ($query_ty:ident, $to_cstr:ident) => {
        impl<'conn> Query<'conn> for &$query_ty {
            #[inline]
            fn execute(
                self,
                _client: &SpiClient<'conn>,
                limit: Option<libc::c_long>,
                args: &[DatumWithOid],
            ) -> SpiResult<SpiTupleTable<'conn>> {
                let cmd = $to_cstr(self);
                execute(cmd.as_ref(), args, limit)
            }

            #[inline]
            fn try_open_cursor(
                self,
                _client: &SpiClient<'conn>,
                args: &[DatumWithOid],
            ) -> SpiResult<SpiCursor<'conn>> {
                let cmd = $to_cstr(self);
                open_cursor(cmd.as_ref(), args)
            }
        }

        impl<'conn> PreparableQuery<'conn> for &$query_ty {
            fn prepare(
                self,
                _client: &SpiClient<'conn>,
                args: &[PgOid],
            ) -> SpiResult<PreparedStatement<'conn>> {
                let cmd = $to_cstr(self);
                // `false`: the statement is not marked as mutating.
                prepare(cmd.as_ref(), args, false)
            }

            fn prepare_mut(
                self,
                _client: &SpiClient<'conn>,
                args: &[PgOid],
            ) -> SpiResult<PreparedStatement<'conn>> {
                let cmd = $to_cstr(self);
                // `true`: the statement is marked as mutating.
                prepare(cmd.as_ref(), args, true)
            }
        }
    };
}
/// Returns its argument unchanged.
///
/// Used as the conversion step for query types that are already C strings.
#[inline]
fn pass_as_is<T>(s: T) -> T {
    std::convert::identity(s)
}
/// Copies a Rust string into a freshly allocated, NUL-terminated `CString`.
///
/// # Panics
///
/// Panics with "query contained a null byte" if the text contains an interior NUL byte.
#[inline]
fn pass_with_conv<T>(s: T) -> CString
where
    T: AsRef<str>,
{
    let text: &str = s.as_ref();
    CString::new(text).expect("query contained a null byte")
}
// C strings are handed to the SPI helpers unchanged.
impl_prepared_query!(CStr, pass_as_is);
impl_prepared_query!(CString, pass_as_is);
// Rust strings are first copied into a `CString`; the conversion panics if the text
// contains a NUL byte.
impl_prepared_query!(String, pass_with_conv);
impl_prepared_query!(str, pass_with_conv);
/// A statement prepared through SPI, tied to the connection lifetime `'conn`.
pub struct PreparedStatement<'conn> {
/// Non-null pointer to the underlying SPI plan.
pub(super) plan: NonNull<pg_sys::_SPI_plan>,
/// Zero-sized marker that carries the `'conn` lifetime.
pub(super) __marker: PhantomData<&'conn ()>,
/// When `true`, executing the statement always passes `read_only = false` to SPI.
pub(super) mutating: bool,
}
/// A [`PreparedStatement`] with a `'static` lifetime, produced by
/// [`PreparedStatement::keep`] (which calls `SPI_keepplan` on the plan).
///
/// Dropping this value calls `SPI_freeplan` on the plan.
pub struct OwnedPreparedStatement(PreparedStatement<'static>);
impl Deref for OwnedPreparedStatement {
type Target = PreparedStatement<'static>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Releases the wrapped plan through `SPI_freeplan` when the owner goes away.
impl Drop for OwnedPreparedStatement {
    fn drop(&mut self) {
        let raw_plan = self.0.plan.as_ptr();
        // SAFETY (assumed — confirm): `raw_plan` is the plan kept by `SPI_keepplan` and has
        // not been freed elsewhere. The status code returned by SPI is ignored.
        unsafe {
            pg_sys::SPI_freeplan(raw_plan);
        }
    }
}
/// Runs a borrowed owned statement by delegating to the `Query` implementation for a
/// borrowed [`PreparedStatement`].
impl<'conn> Query<'conn> for &OwnedPreparedStatement {
    /// Executes the wrapped statement; see the borrowed `PreparedStatement` implementation.
    fn execute<'mcx>(
        self,
        client: &SpiClient<'conn>,
        limit: Option<libc::c_long>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiTupleTable<'conn>> {
        let statement = &self.0;
        statement.execute(client, limit, args)
    }

    /// Opens a cursor on the wrapped statement.
    fn try_open_cursor<'mcx>(
        self,
        client: &SpiClient<'conn>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiCursor<'conn>> {
        let statement = &self.0;
        statement.try_open_cursor(client, args)
    }
}
/// Runs an owned statement by value, delegating to the borrowed [`PreparedStatement`]
/// implementation.
///
/// `self` is consumed, so it is dropped (calling `SPI_freeplan` on the plan) when the
/// method returns.
impl<'conn> Query<'conn> for OwnedPreparedStatement {
    /// Executes the wrapped statement, then drops the owner.
    fn execute<'mcx>(
        self,
        client: &SpiClient<'conn>,
        limit: Option<libc::c_long>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiTupleTable<'conn>> {
        // Borrow the inner statement; `self` stays alive until the call has returned.
        let Self(statement) = &self;
        statement.execute(client, limit, args)
    }

    /// Opens a cursor on the wrapped statement, then drops the owner.
    fn try_open_cursor<'mcx>(
        self,
        client: &SpiClient<'conn>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiCursor<'conn>> {
        let Self(statement) = &self;
        statement.try_open_cursor(client, args)
    }
}
impl<'conn> PreparedStatement<'conn> {
    /// Converts this statement into an [`OwnedPreparedStatement`] with a `'static` lifetime.
    ///
    /// Calls `SPI_keepplan` on the plan and rewraps the same plan pointer and `mutating`
    /// flag. The status code returned by `SPI_keepplan` is ignored.
    pub fn keep(self) -> OwnedPreparedStatement {
        let plan = self.plan;
        let mutating = self.mutating;

        // SAFETY (assumed — confirm): `plan` points to a live plan created by SPI.
        unsafe {
            pg_sys::SPI_keepplan(plan.as_ptr());
        }

        OwnedPreparedStatement(PreparedStatement { plan, mutating, __marker: PhantomData })
    }

    /// Converts `args` into parallel value / null-flag arrays for this plan.
    ///
    /// # Errors
    ///
    /// Returns `SpiError::PreparedStatementArgumentMismatch` when the number of supplied
    /// arguments differs from the count reported by `SPI_getargcount`.
    fn args_to_datums<'mcx>(
        &self,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<(Vec<pg_sys::Datum>, Vec<std::os::raw::c_char>)> {
        let got = args.len();
        // SAFETY (assumed — confirm): `self.plan` points to a live plan created by SPI.
        let expected = unsafe { pg_sys::SPI_getargcount(self.plan.as_ptr()) } as usize;

        if expected != got {
            return Err(SpiError::PreparedStatementArgumentMismatch { expected, got });
        }

        Ok(args.iter().map(prepare_datum).unzip())
    }
}
/// Runs a borrowed prepared statement through `SPI_execute_plan` / `SPI_cursor_open`.
///
/// Read-only mode is requested from SPI only when the statement is not marked `mutating`
/// and `Spi::is_xact_still_immutable()` returns `true`. The `_client` argument is unused.
impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
    /// Executes the plan with `args`; a `limit` of `None` is sent to SPI as `0`.
    ///
    /// # Errors
    ///
    /// Fails if the argument count does not match the plan, or with whatever error
    /// `SpiClient::prepare_tuple_table` derives from the SPI status code.
    fn execute<'mcx>(
        self,
        _client: &SpiClient<'conn>,
        limit: Option<libc::c_long>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiTupleTable<'conn>> {
        // Reset the global tuple-table pointer first; this happens even when the argument
        // check below fails.
        // SAFETY (assumed — confirm): nothing else accesses the SPI globals concurrently.
        unsafe {
            pg_sys::SPI_tuptable = std::ptr::null_mut();
        }

        let (mut datums, mut nulls) = self.args_to_datums(args)?;
        let tcount = limit.unwrap_or(0);

        // SAFETY (assumed — confirm): the plan is live and both vectors hold one entry per
        // plan argument for the duration of the call.
        let status_code = unsafe {
            // Short-circuits: a mutating statement never queries the transaction state.
            let read_only = !self.mutating && Spi::is_xact_still_immutable();
            pg_sys::SPI_execute_plan(
                self.plan.as_ptr(),
                datums.as_mut_ptr(),
                nulls.as_mut_ptr(),
                read_only,
                tcount,
            )
        };

        SpiClient::prepare_tuple_table(status_code)
    }

    /// Opens a cursor over the plan with `args`; the cursor name is passed as null.
    ///
    /// The pointer returned by `SPI_cursor_open` is wrapped without a null check.
    ///
    /// # Errors
    ///
    /// Fails if the argument count does not match the plan.
    fn try_open_cursor<'mcx>(
        self,
        _client: &SpiClient<'conn>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiCursor<'conn>> {
        let (mut datums, nulls) = self.args_to_datums(args)?;

        // SAFETY (assumed — confirm): the plan is live and both vectors hold one entry per
        // plan argument for the duration of the call.
        let portal = unsafe {
            let read_only = !self.mutating && Spi::is_xact_still_immutable();
            pg_sys::SPI_cursor_open(
                std::ptr::null_mut(),
                self.plan.as_ptr(),
                datums.as_mut_ptr(),
                nulls.as_ptr(),
                read_only,
            )
        };

        // SAFETY (assumed — confirm against the PostgreSQL SPI docs): SPI reports failures
        // by raising an error rather than returning a null portal.
        let ptr = unsafe { NonNull::new_unchecked(portal) };

        Ok(SpiCursor { ptr, __marker: PhantomData })
    }
}
/// Runs a prepared statement taken by value by delegating to the implementation for
/// `&PreparedStatement`.
impl<'conn> Query<'conn> for PreparedStatement<'conn> {
    /// Executes the statement; see the borrowed implementation for details.
    fn execute<'mcx>(
        self,
        client: &SpiClient<'conn>,
        limit: Option<libc::c_long>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiTupleTable<'conn>> {
        let statement = &self;
        statement.execute(client, limit, args)
    }

    /// Opens a cursor on the statement; see the borrowed implementation for details.
    fn try_open_cursor<'mcx>(
        self,
        client: &SpiClient<'conn>,
        args: &[DatumWithOid<'mcx>],
    ) -> SpiResult<SpiCursor<'conn>> {
        let statement = &self;
        statement.try_open_cursor(client, args)
    }
}