1use std::ffi::{CStr, CString};
2use std::marker::PhantomData;
3use std::ops::Deref;
4use std::ptr::NonNull;
5
6use libc::c_char;
7
8use super::{Spi, SpiClient, SpiCursor, SpiError, SpiResult, SpiTupleTable};
9use crate::{
10 datum::DatumWithOid,
11 pg_sys::{self, PgOid},
12};
13
14pub trait Query<'conn>: Sized {
20 fn execute<'mcx>(
22 self,
23 client: &SpiClient<'conn>,
24 limit: Option<libc::c_long>,
25 args: &[DatumWithOid<'mcx>],
26 ) -> SpiResult<SpiTupleTable<'conn>>;
27
28 #[deprecated(since = "0.12.2", note = "undefined behavior")]
34 fn open_cursor<'mcx>(
35 self,
36 client: &SpiClient<'conn>,
37 args: &[DatumWithOid<'mcx>],
38 ) -> SpiCursor<'conn> {
39 self.try_open_cursor(client, args).unwrap()
40 }
41
42 fn try_open_cursor<'mcx>(
44 self,
45 client: &SpiClient<'conn>,
46 args: &[DatumWithOid<'mcx>],
47 ) -> SpiResult<SpiCursor<'conn>>;
48}
49
50pub trait PreparableQuery<'conn>: Query<'conn> {
52 fn prepare(
54 self,
55 client: &SpiClient<'conn>,
56 args: &[PgOid],
57 ) -> SpiResult<PreparedStatement<'conn>>;
58
59 fn prepare_mut(
61 self,
62 client: &SpiClient<'conn>,
63 args: &[PgOid],
64 ) -> SpiResult<PreparedStatement<'conn>>;
65}
66
67fn execute<'conn>(
68 cmd: &CStr,
69 args: &[DatumWithOid<'_>],
70 limit: Option<libc::c_long>,
71) -> SpiResult<SpiTupleTable<'conn>> {
72 unsafe {
74 pg_sys::SPI_tuptable = std::ptr::null_mut();
75 }
76
77 let status_code = match args.len() {
78 0 => unsafe {
80 pg_sys::SPI_execute(cmd.as_ptr(), Spi::is_xact_still_immutable(), limit.unwrap_or(0))
81 },
82 nargs => {
83 let (mut argtypes, mut datums, nulls) = args_to_datums(args);
84
85 unsafe {
87 pg_sys::SPI_execute_with_args(
88 cmd.as_ptr(),
89 nargs as i32,
90 argtypes.as_mut_ptr(),
91 datums.as_mut_ptr(),
92 nulls.as_ptr(),
93 Spi::is_xact_still_immutable(),
94 limit.unwrap_or(0),
95 )
96 }
97 }
98 };
99
100 SpiClient::prepare_tuple_table(status_code)
101}
102
103fn open_cursor<'conn>(cmd: &CStr, args: &[DatumWithOid<'_>]) -> SpiResult<SpiCursor<'conn>> {
104 let nargs = args.len();
105 let (mut argtypes, mut datums, nulls) = args_to_datums(args);
106
107 let ptr = unsafe {
108 NonNull::new_unchecked(pg_sys::SPI_cursor_open_with_args(
111 std::ptr::null_mut(), cmd.as_ptr(),
113 nargs as i32,
114 argtypes.as_mut_ptr(),
115 datums.as_mut_ptr(),
116 nulls.as_ptr(),
117 Spi::is_xact_still_immutable(),
118 0,
119 ))
120 };
121
122 Ok(SpiCursor { ptr, __marker: PhantomData })
123}
124
125fn args_to_datums(
126 args: &[DatumWithOid<'_>],
127) -> (Vec<pg_sys::Oid>, Vec<pg_sys::Datum>, Vec<c_char>) {
128 let mut argtypes = Vec::with_capacity(args.len());
129 let mut datums = Vec::with_capacity(args.len());
130 let mut nulls = Vec::with_capacity(args.len());
131
132 for arg in args {
133 let (datum, null) = prepare_datum(arg);
134
135 argtypes.push(arg.oid());
136 datums.push(datum);
137 nulls.push(null);
138 }
139
140 (argtypes, datums, nulls)
141}
142
143fn prepare_datum(datum: &DatumWithOid<'_>) -> (pg_sys::Datum, std::os::raw::c_char) {
144 match datum.datum() {
145 Some(datum) => (datum.sans_lifetime(), ' ' as std::os::raw::c_char),
146 None => (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char),
147 }
148}
149
150fn prepare<'conn>(
151 cmd: &CStr,
152 args: &[PgOid],
153 mutating: bool,
154) -> SpiResult<PreparedStatement<'conn>> {
155 let plan = unsafe {
157 pg_sys::SPI_prepare(
158 cmd.as_ptr(),
159 args.len() as i32,
160 args.iter().map(|arg| arg.value()).collect::<Vec<_>>().as_mut_ptr(),
161 )
162 };
163 Ok(PreparedStatement {
164 plan: NonNull::new(plan).ok_or_else(|| {
165 Spi::check_status(unsafe {
166 pg_sys::SPI_result
168 })
169 .err()
170 .unwrap()
171 })?,
172 __marker: PhantomData,
173 mutating,
174 })
175}
176
/// Implements [`Query`] and [`PreparableQuery`] for `&$ty`.
///
/// `$conv` is applied to the reference, and its result is viewed as a `&CStr` through `AsRef`
/// before being handed to the free `execute`, `open_cursor` or `prepare` helper. The `client`
/// argument of every generated method is unused.
macro_rules! impl_prepared_query {
    ($ty:ident, $conv:ident) => {
        impl<'conn> PreparableQuery<'conn> for &$ty {
            fn prepare(
                self,
                _client: &SpiClient<'conn>,
                args: &[PgOid],
            ) -> SpiResult<PreparedStatement<'conn>> {
                let query = $conv(self);
                prepare(query.as_ref(), args, false)
            }

            fn prepare_mut(
                self,
                _client: &SpiClient<'conn>,
                args: &[PgOid],
            ) -> SpiResult<PreparedStatement<'conn>> {
                let query = $conv(self);
                prepare(query.as_ref(), args, true)
            }
        }

        impl<'conn> Query<'conn> for &$ty {
            #[inline]
            fn execute(
                self,
                _client: &SpiClient<'conn>,
                limit: Option<libc::c_long>,
                args: &[DatumWithOid<'_>],
            ) -> SpiResult<SpiTupleTable<'conn>> {
                let query = $conv(self);
                execute(query.as_ref(), args, limit)
            }

            #[inline]
            fn try_open_cursor(
                self,
                _client: &SpiClient<'conn>,
                args: &[DatumWithOid<'_>],
            ) -> SpiResult<SpiCursor<'conn>> {
                let query = $conv(self);
                open_cursor(query.as_ref(), args)
            }
        }
    };
}
219
/// Identity conversion used by `impl_prepared_query!` for types that already are C strings.
#[inline]
fn pass_as_is<T>(value: T) -> T {
    std::convert::identity(value)
}
224
/// Copies a Rust string into a freshly allocated, NUL-terminated [`CString`].
///
/// # Panics
///
/// Panics if the string contains an interior NUL byte.
#[inline]
fn pass_with_conv<S: AsRef<str>>(text: S) -> CString {
    let borrowed: &str = text.as_ref();
    CString::new(borrowed).expect("query contained a null byte")
}
229
230impl_prepared_query!(CStr, pass_as_is);
231impl_prepared_query!(CString, pass_as_is);
232impl_prepared_query!(String, pass_with_conv);
233impl_prepared_query!(str, pass_with_conv);
234
235pub struct PreparedStatement<'conn> {
237 pub(super) plan: NonNull<pg_sys::_SPI_plan>,
238 pub(super) __marker: PhantomData<&'conn ()>,
239 pub(super) mutating: bool,
240}
241
242pub struct OwnedPreparedStatement(PreparedStatement<'static>);
244
245impl Deref for OwnedPreparedStatement {
246 type Target = PreparedStatement<'static>;
247
248 fn deref(&self) -> &Self::Target {
249 &self.0
250 }
251}
252
253impl Drop for OwnedPreparedStatement {
254 fn drop(&mut self) {
255 unsafe {
256 pg_sys::SPI_freeplan(self.0.plan.as_ptr());
257 }
258 }
259}
260
261impl<'conn> Query<'conn> for &OwnedPreparedStatement {
262 fn execute<'mcx>(
263 self,
264 client: &SpiClient<'conn>,
265 limit: Option<libc::c_long>,
266 args: &[DatumWithOid<'mcx>],
267 ) -> SpiResult<SpiTupleTable<'conn>> {
268 (&self.0).execute(client, limit, args)
269 }
270
271 fn try_open_cursor<'mcx>(
272 self,
273 client: &SpiClient<'conn>,
274 args: &[DatumWithOid<'mcx>],
275 ) -> SpiResult<SpiCursor<'conn>> {
276 (&self.0).try_open_cursor(client, args)
277 }
278}
279
280impl<'conn> Query<'conn> for OwnedPreparedStatement {
281 fn execute<'mcx>(
282 self,
283 client: &SpiClient<'conn>,
284 limit: Option<libc::c_long>,
285 args: &[DatumWithOid<'mcx>],
286 ) -> SpiResult<SpiTupleTable<'conn>> {
287 (&self.0).execute(client, limit, args)
288 }
289
290 fn try_open_cursor<'mcx>(
291 self,
292 client: &SpiClient<'conn>,
293 args: &[DatumWithOid<'mcx>],
294 ) -> SpiResult<SpiCursor<'conn>> {
295 (&self.0).try_open_cursor(client, args)
296 }
297}
298
299impl PreparedStatement<'_> {
300 pub fn keep(self) -> OwnedPreparedStatement {
304 unsafe {
308 pg_sys::SPI_keepplan(self.plan.as_ptr());
309 }
310 OwnedPreparedStatement(PreparedStatement {
311 __marker: PhantomData,
312 plan: self.plan,
313 mutating: self.mutating,
314 })
315 }
316
317 fn args_to_datums(
318 &self,
319 args: &[DatumWithOid<'_>],
320 ) -> SpiResult<(Vec<pg_sys::Datum>, Vec<std::os::raw::c_char>)> {
321 let actual = args.len();
322 let expected = unsafe { pg_sys::SPI_getargcount(self.plan.as_ptr()) } as usize;
323
324 if expected == actual {
325 Ok(args.iter().map(prepare_datum).unzip())
326 } else {
327 Err(SpiError::PreparedStatementArgumentMismatch { expected, got: actual })
328 }
329 }
330}
331
332impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
333 fn execute<'mcx>(
334 self,
335 _client: &SpiClient<'conn>,
336 limit: Option<libc::c_long>,
337 args: &[DatumWithOid<'mcx>],
338 ) -> SpiResult<SpiTupleTable<'conn>> {
339 unsafe {
341 pg_sys::SPI_tuptable = std::ptr::null_mut();
342 }
343
344 let (mut datums, mut nulls) = self.args_to_datums(args)?;
345
346 let status_code = unsafe {
348 pg_sys::SPI_execute_plan(
349 self.plan.as_ptr(),
350 datums.as_mut_ptr(),
351 nulls.as_mut_ptr(),
352 !self.mutating && Spi::is_xact_still_immutable(),
353 limit.unwrap_or(0),
354 )
355 };
356
357 SpiClient::prepare_tuple_table(status_code)
358 }
359
360 fn try_open_cursor<'mcx>(
361 self,
362 _client: &SpiClient<'conn>,
363 args: &[DatumWithOid<'mcx>],
364 ) -> SpiResult<SpiCursor<'conn>> {
365 let (mut datums, nulls) = self.args_to_datums(args)?;
366
367 let ptr = unsafe {
370 NonNull::new_unchecked(pg_sys::SPI_cursor_open(
371 std::ptr::null_mut(), self.plan.as_ptr(),
373 datums.as_mut_ptr(),
374 nulls.as_ptr(),
375 !self.mutating && Spi::is_xact_still_immutable(),
376 ))
377 };
378 Ok(SpiCursor { ptr, __marker: PhantomData })
379 }
380}
381
382impl<'conn> Query<'conn> for PreparedStatement<'conn> {
383 fn execute<'mcx>(
384 self,
385 client: &SpiClient<'conn>,
386 limit: Option<libc::c_long>,
387 args: &[DatumWithOid<'mcx>],
388 ) -> SpiResult<SpiTupleTable<'conn>> {
389 (&self).execute(client, limit, args)
390 }
391
392 fn try_open_cursor<'mcx>(
393 self,
394 client: &SpiClient<'conn>,
395 args: &[DatumWithOid<'mcx>],
396 ) -> SpiResult<SpiCursor<'conn>> {
397 (&self).try_open_cursor(client, args)
398 }
399}