pgrx/spi/
query.rs

1use std::ffi::{CStr, CString};
2use std::marker::PhantomData;
3use std::ops::Deref;
4use std::ptr::NonNull;
5
6use libc::c_char;
7
8use super::{Spi, SpiClient, SpiCursor, SpiError, SpiResult, SpiTupleTable};
9use crate::{
10    datum::DatumWithOid,
11    pg_sys::{self, PgOid},
12};
13
14/// A generalized interface to what constitutes a query
15///
16/// Its primary purpose is to abstract away differences between
17/// one-off statements and prepared statements, but it can potentially
18/// be implemented for other types, provided they can be converted into a query.
19pub trait Query<'conn>: Sized {
20    /// Execute a query given a client and other arguments.
21    fn execute<'mcx>(
22        self,
23        client: &SpiClient<'conn>,
24        limit: Option<libc::c_long>,
25        args: &[DatumWithOid<'mcx>],
26    ) -> SpiResult<SpiTupleTable<'conn>>;
27
28    /// Open a cursor for the query.
29    ///
30    /// # Panics
31    ///
32    /// Panics if a cursor wasn't opened.
33    #[deprecated(since = "0.12.2", note = "undefined behavior")]
34    fn open_cursor<'mcx>(
35        self,
36        client: &SpiClient<'conn>,
37        args: &[DatumWithOid<'mcx>],
38    ) -> SpiCursor<'conn> {
39        self.try_open_cursor(client, args).unwrap()
40    }
41
42    /// Tries to open cursor for the query.
43    fn try_open_cursor<'mcx>(
44        self,
45        client: &SpiClient<'conn>,
46        args: &[DatumWithOid<'mcx>],
47    ) -> SpiResult<SpiCursor<'conn>>;
48}
49
50/// A trait representing a query which can be prepared.
51pub trait PreparableQuery<'conn>: Query<'conn> {
52    /// Prepares a query.
53    fn prepare(
54        self,
55        client: &SpiClient<'conn>,
56        args: &[PgOid],
57    ) -> SpiResult<PreparedStatement<'conn>>;
58
59    /// Prepares a query allowed to change data
60    fn prepare_mut(
61        self,
62        client: &SpiClient<'conn>,
63        args: &[PgOid],
64    ) -> SpiResult<PreparedStatement<'conn>>;
65}
66
67fn execute<'conn>(
68    cmd: &CStr,
69    args: &[DatumWithOid<'_>],
70    limit: Option<libc::c_long>,
71) -> SpiResult<SpiTupleTable<'conn>> {
72    // SAFETY: no concurrent access
73    unsafe {
74        pg_sys::SPI_tuptable = std::ptr::null_mut();
75    }
76
77    let status_code = match args.len() {
78        // SAFETY: arguments are prepared above
79        0 => unsafe {
80            pg_sys::SPI_execute(cmd.as_ptr(), Spi::is_xact_still_immutable(), limit.unwrap_or(0))
81        },
82        nargs => {
83            let (mut argtypes, mut datums, nulls) = args_to_datums(args);
84
85            // SAFETY: arguments are prepared above
86            unsafe {
87                pg_sys::SPI_execute_with_args(
88                    cmd.as_ptr(),
89                    nargs as i32,
90                    argtypes.as_mut_ptr(),
91                    datums.as_mut_ptr(),
92                    nulls.as_ptr(),
93                    Spi::is_xact_still_immutable(),
94                    limit.unwrap_or(0),
95                )
96            }
97        }
98    };
99
100    SpiClient::prepare_tuple_table(status_code)
101}
102
103fn open_cursor<'conn>(cmd: &CStr, args: &[DatumWithOid<'_>]) -> SpiResult<SpiCursor<'conn>> {
104    let nargs = args.len();
105    let (mut argtypes, mut datums, nulls) = args_to_datums(args);
106
107    let ptr = unsafe {
108        // SAFETY: arguments are prepared above and SPI_cursor_open_with_args will never return
109        // the null pointer.  It'll raise an ERROR if something is invalid for it to create the cursor
110        NonNull::new_unchecked(pg_sys::SPI_cursor_open_with_args(
111            std::ptr::null_mut(), // let postgres assign a name
112            cmd.as_ptr(),
113            nargs as i32,
114            argtypes.as_mut_ptr(),
115            datums.as_mut_ptr(),
116            nulls.as_ptr(),
117            Spi::is_xact_still_immutable(),
118            0,
119        ))
120    };
121
122    Ok(SpiCursor { ptr, __marker: PhantomData })
123}
124
125fn args_to_datums(
126    args: &[DatumWithOid<'_>],
127) -> (Vec<pg_sys::Oid>, Vec<pg_sys::Datum>, Vec<c_char>) {
128    let mut argtypes = Vec::with_capacity(args.len());
129    let mut datums = Vec::with_capacity(args.len());
130    let mut nulls = Vec::with_capacity(args.len());
131
132    for arg in args {
133        let (datum, null) = prepare_datum(arg);
134
135        argtypes.push(arg.oid());
136        datums.push(datum);
137        nulls.push(null);
138    }
139
140    (argtypes, datums, nulls)
141}
142
143fn prepare_datum(datum: &DatumWithOid<'_>) -> (pg_sys::Datum, std::os::raw::c_char) {
144    match datum.datum() {
145        Some(datum) => (datum.sans_lifetime(), ' ' as std::os::raw::c_char),
146        None => (pg_sys::Datum::from(0usize), 'n' as std::os::raw::c_char),
147    }
148}
149
150fn prepare<'conn>(
151    cmd: &CStr,
152    args: &[PgOid],
153    mutating: bool,
154) -> SpiResult<PreparedStatement<'conn>> {
155    // SAFETY: all arguments are prepared above
156    let plan = unsafe {
157        pg_sys::SPI_prepare(
158            cmd.as_ptr(),
159            args.len() as i32,
160            args.iter().map(|arg| arg.value()).collect::<Vec<_>>().as_mut_ptr(),
161        )
162    };
163    Ok(PreparedStatement {
164        plan: NonNull::new(plan).ok_or_else(|| {
165            Spi::check_status(unsafe {
166                // SAFETY: no concurrent usage
167                pg_sys::SPI_result
168            })
169            .err()
170            .unwrap()
171        })?,
172        __marker: PhantomData,
173        mutating,
174    })
175}
176
// Implements `Query` and `PreparableQuery` for `&$ty`.
//
// `$conv` names a function applied to the reference; `as_ref()` on its result must
// yield the `&CStr` command that is handed to the free functions in this file.
macro_rules! impl_prepared_query {
    ($ty:ident, $conv:ident) => {
        impl<'conn> Query<'conn> for &$ty {
            #[inline]
            fn execute(
                self,
                _client: &SpiClient<'conn>,
                limit: Option<libc::c_long>,
                args: &[DatumWithOid],
            ) -> SpiResult<SpiTupleTable<'conn>> {
                let command = $conv(self);
                execute(command.as_ref(), args, limit)
            }

            #[inline]
            fn try_open_cursor(
                self,
                _client: &SpiClient<'conn>,
                args: &[DatumWithOid],
            ) -> SpiResult<SpiCursor<'conn>> {
                let command = $conv(self);
                open_cursor(command.as_ref(), args)
            }
        }

        impl<'conn> PreparableQuery<'conn> for &$ty {
            fn prepare(
                self,
                _client: &SpiClient<'conn>,
                args: &[PgOid],
            ) -> SpiResult<PreparedStatement<'conn>> {
                let command = $conv(self);
                // Not flagged as mutating.
                prepare(command.as_ref(), args, false)
            }

            fn prepare_mut(
                self,
                _client: &SpiClient<'conn>,
                args: &[PgOid],
            ) -> SpiResult<PreparedStatement<'conn>> {
                let command = $conv(self);
                // Flagged as mutating.
                prepare(command.as_ref(), args, true)
            }
        }
    };
}
219
/// Identity conversion: hands the value back unchanged.
///
/// Used for query types that need no conversion before being passed on.
#[inline]
fn pass_as_is<T>(s: T) -> T {
    std::convert::identity(s)
}
224
/// Copies a string-like query into a freshly allocated [`CString`].
///
/// # Panics
///
/// Panics if the query text contains a NUL byte.
#[inline]
fn pass_with_conv<T: AsRef<str>>(s: T) -> CString {
    let query: &str = s.as_ref();
    let converted = CString::new(query);
    converted.expect("query contained a null byte")
}
229
230impl_prepared_query!(CStr, pass_as_is);
231impl_prepared_query!(CString, pass_as_is);
232impl_prepared_query!(String, pass_with_conv);
233impl_prepared_query!(str, pass_with_conv);
234
235/// Client lifetime-bound prepared statement
236pub struct PreparedStatement<'conn> {
237    pub(super) plan: NonNull<pg_sys::_SPI_plan>,
238    pub(super) __marker: PhantomData<&'conn ()>,
239    pub(super) mutating: bool,
240}
241
242/// Static lifetime-bound prepared statement
243pub struct OwnedPreparedStatement(PreparedStatement<'static>);
244
245impl Deref for OwnedPreparedStatement {
246    type Target = PreparedStatement<'static>;
247
248    fn deref(&self) -> &Self::Target {
249        &self.0
250    }
251}
252
253impl Drop for OwnedPreparedStatement {
254    fn drop(&mut self) {
255        unsafe {
256            pg_sys::SPI_freeplan(self.0.plan.as_ptr());
257        }
258    }
259}
260
261impl<'conn> Query<'conn> for &OwnedPreparedStatement {
262    fn execute<'mcx>(
263        self,
264        client: &SpiClient<'conn>,
265        limit: Option<libc::c_long>,
266        args: &[DatumWithOid<'mcx>],
267    ) -> SpiResult<SpiTupleTable<'conn>> {
268        (&self.0).execute(client, limit, args)
269    }
270
271    fn try_open_cursor<'mcx>(
272        self,
273        client: &SpiClient<'conn>,
274        args: &[DatumWithOid<'mcx>],
275    ) -> SpiResult<SpiCursor<'conn>> {
276        (&self.0).try_open_cursor(client, args)
277    }
278}
279
280impl<'conn> Query<'conn> for OwnedPreparedStatement {
281    fn execute<'mcx>(
282        self,
283        client: &SpiClient<'conn>,
284        limit: Option<libc::c_long>,
285        args: &[DatumWithOid<'mcx>],
286    ) -> SpiResult<SpiTupleTable<'conn>> {
287        (&self.0).execute(client, limit, args)
288    }
289
290    fn try_open_cursor<'mcx>(
291        self,
292        client: &SpiClient<'conn>,
293        args: &[DatumWithOid<'mcx>],
294    ) -> SpiResult<SpiCursor<'conn>> {
295        (&self.0).try_open_cursor(client, args)
296    }
297}
298
299impl PreparedStatement<'_> {
300    /// Converts prepared statement into an owned prepared statement
301    ///
302    /// These statements have static lifetime and are freed only when dropped
303    pub fn keep(self) -> OwnedPreparedStatement {
304        // SAFETY: self.plan is initialized in `SpiClient::prepare` and `PreparedStatement`
305        // is consumed. If it wasn't consumed, a subsequent call to `keep` would trigger
306        // an SPI_ERROR_ARGUMENT as per `SPI_keepplan` implementation.
307        unsafe {
308            pg_sys::SPI_keepplan(self.plan.as_ptr());
309        }
310        OwnedPreparedStatement(PreparedStatement {
311            __marker: PhantomData,
312            plan: self.plan,
313            mutating: self.mutating,
314        })
315    }
316
317    fn args_to_datums(
318        &self,
319        args: &[DatumWithOid<'_>],
320    ) -> SpiResult<(Vec<pg_sys::Datum>, Vec<std::os::raw::c_char>)> {
321        let actual = args.len();
322        let expected = unsafe { pg_sys::SPI_getargcount(self.plan.as_ptr()) } as usize;
323
324        if expected == actual {
325            Ok(args.iter().map(prepare_datum).unzip())
326        } else {
327            Err(SpiError::PreparedStatementArgumentMismatch { expected, got: actual })
328        }
329    }
330}
331
332impl<'conn: 'stmt, 'stmt> Query<'conn> for &'stmt PreparedStatement<'conn> {
333    fn execute<'mcx>(
334        self,
335        _client: &SpiClient<'conn>,
336        limit: Option<libc::c_long>,
337        args: &[DatumWithOid<'mcx>],
338    ) -> SpiResult<SpiTupleTable<'conn>> {
339        // SAFETY: no concurrent access
340        unsafe {
341            pg_sys::SPI_tuptable = std::ptr::null_mut();
342        }
343
344        let (mut datums, mut nulls) = self.args_to_datums(args)?;
345
346        // SAFETY: all arguments are prepared above
347        let status_code = unsafe {
348            pg_sys::SPI_execute_plan(
349                self.plan.as_ptr(),
350                datums.as_mut_ptr(),
351                nulls.as_mut_ptr(),
352                !self.mutating && Spi::is_xact_still_immutable(),
353                limit.unwrap_or(0),
354            )
355        };
356
357        SpiClient::prepare_tuple_table(status_code)
358    }
359
360    fn try_open_cursor<'mcx>(
361        self,
362        _client: &SpiClient<'conn>,
363        args: &[DatumWithOid<'mcx>],
364    ) -> SpiResult<SpiCursor<'conn>> {
365        let (mut datums, nulls) = self.args_to_datums(args)?;
366
367        // SAFETY: arguments are prepared above and SPI_cursor_open will never return the null
368        // pointer.  It'll raise an ERROR if something is invalid for it to create the cursor
369        let ptr = unsafe {
370            NonNull::new_unchecked(pg_sys::SPI_cursor_open(
371                std::ptr::null_mut(), // let postgres assign a name
372                self.plan.as_ptr(),
373                datums.as_mut_ptr(),
374                nulls.as_ptr(),
375                !self.mutating && Spi::is_xact_still_immutable(),
376            ))
377        };
378        Ok(SpiCursor { ptr, __marker: PhantomData })
379    }
380}
381
382impl<'conn> Query<'conn> for PreparedStatement<'conn> {
383    fn execute<'mcx>(
384        self,
385        client: &SpiClient<'conn>,
386        limit: Option<libc::c_long>,
387        args: &[DatumWithOid<'mcx>],
388    ) -> SpiResult<SpiTupleTable<'conn>> {
389        (&self).execute(client, limit, args)
390    }
391
392    fn try_open_cursor<'mcx>(
393        self,
394        client: &SpiClient<'conn>,
395        args: &[DatumWithOid<'mcx>],
396    ) -> SpiResult<SpiCursor<'conn>> {
397        (&self).try_open_cursor(client, args)
398    }
399}