Skip to main content

hyperdb_api_core/client/
prepare.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
//! Prepared statement support using the extended query protocol.
//!
//! # Parameter Encoding
//!
//! Use the [`params!`] macro for ergonomic parameter encoding:
//!
//! ```no_run
//! # use hyperdb_api_core::{params, client::{Client, Config}};
//! # fn example(client: &Client) -> hyperdb_api_core::client::Result<()> {
//! let stmt = client.prepare("SELECT * FROM users WHERE id = $1 AND name = $2")?;
//! let rows = client.execute(&stmt, params![42_i32, "Alice"])?;
//! # Ok(())
//! # }
//! ```
18
19use std::sync::atomic::{AtomicU64, Ordering};
20use std::sync::{Arc, Mutex, Weak};
21
22use crate::protocol::message::{backend::Message, frontend};
23use crate::types::Oid;
24use tracing::{trace, warn};
25
26use super::connection::RawConnection;
27use super::error::{Error, Result};
28use super::row::Row;
29use super::statement::Column;
30use super::sync_stream::SyncStream;
31
// =============================================================================
// SqlParam trait - binary parameter encoding
// =============================================================================
35
/// Conversion of a Rust value into the byte payload of a prepared-statement
/// parameter.
///
/// The [`params!`](crate::params!) macro calls [`SqlParam::encode`] on each of
/// its arguments. Numeric types are written as their little-endian bytes,
/// `bool` as a single `0`/`1` byte, strings as their UTF-8 bytes, and byte
/// buffers verbatim. Every call returns a freshly allocated `Vec<u8>`.
pub trait SqlParam {
    /// Returns the binary encoding of `self`.
    fn encode(&self) -> Vec<u8>;
}

/// Implements [`SqlParam`] for primitive numeric types by emitting their
/// little-endian byte representation.
macro_rules! impl_sql_param_via_le_bytes {
    ($($ty:ty),+ $(,)?) => {
        $(
            impl SqlParam for $ty {
                #[inline]
                fn encode(&self) -> Vec<u8> {
                    Vec::from(self.to_le_bytes())
                }
            }
        )+
    };
}

impl_sql_param_via_le_bytes!(i16, i32, i64, f32, f64);

impl SqlParam for bool {
    #[inline]
    fn encode(&self) -> Vec<u8> {
        // `true` encodes as [1], `false` as [0].
        Vec::from([u8::from(*self)])
    }
}

impl SqlParam for &str {
    #[inline]
    fn encode(&self) -> Vec<u8> {
        Vec::from(self.as_bytes())
    }
}

impl SqlParam for String {
    #[inline]
    fn encode(&self) -> Vec<u8> {
        Vec::from(self.as_str().as_bytes())
    }
}

impl SqlParam for &String {
    #[inline]
    fn encode(&self) -> Vec<u8> {
        Vec::from(self.as_bytes())
    }
}

impl SqlParam for Vec<u8> {
    #[inline]
    fn encode(&self) -> Vec<u8> {
        self.to_vec()
    }
}

impl SqlParam for &[u8] {
    #[inline]
    fn encode(&self) -> Vec<u8> {
        Vec::from(*self)
    }
}
121
/// Macro for building prepared statement parameters with automatic encoding.
///
/// Each argument is encoded with [`SqlParam::encode`] using method-call
/// syntax, so auto-referencing applies, and wrapped in `Some`. The macro
/// cannot express SQL `NULL`. When a parameter must be `None`, build the
/// parameter slice by hand, as shown below.
///
/// NOTE(review): the empty form expands to a `&[Option<Vec<u8>>]` slice, while
/// the non-empty form expands to a `Vec<Option<Vec<u8>>>`. Callers presumably
/// accept both; confirm against `Client::execute`.
///
/// # Examples
///
/// ```no_run
/// # use hyperdb_api_core::{params, client::{Client, Config, SqlParam}};
/// # fn example(client: &Client) -> hyperdb_api_core::client::Result<()> {
/// let stmt = client.prepare("SELECT * FROM t WHERE id = $1 AND name = $2")?;
///
/// // Pass typed values directly
/// let rows = client.execute(&stmt, params![42_i32, "Alice"])?;
///
/// // For NULL values, use None explicitly
/// let rows = client.execute(&stmt, &[Some(42_i32.encode()), None])?;
/// # Ok(())
/// # }
/// ```
#[macro_export]
macro_rules! params {
    () => {
        &[] as &[Option<Vec<u8>>]
    };
    ($($val:expr),+ $(,)?) => {{
        // Bring the trait into scope so `.encode()` resolves at the call site.
        use $crate::client::prepare::SqlParam;
        vec![$(Some($val.encode())),+]
    }};
}
149
/// Process-wide sequence number from which statement names are built.
///
/// `Relaxed` ordering is sufficient: only the uniqueness of each `fetch_add`
/// result matters, not its ordering relative to other memory operations.
static STATEMENT_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a server-side statement name of the form `__hyper_stmt_<n>`.
///
/// `<n>` is distinct for every call in this process, barring `u64`
/// wrap-around.
fn generate_statement_name() -> String {
    let next = STATEMENT_COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("__hyper_stmt_{next}")
}
158
159/// A prepared statement.
160///
161/// Prepared statements allow you to execute the same query multiple times
162/// with different parameters efficiently. The statement is prepared once on
163/// the server and can be executed many times with different parameter values.
164///
165/// For automatic cleanup, use \[`OwnedPreparedStatement`\] via \[`crate::Client::prepare`\].
166///
167/// # Example
168///
169/// ```no_run
170/// # use hyperdb_api_core::{params, client::{Client, Config}};
171/// # fn example(client: &Client) -> hyperdb_api_core::client::Result<()> {
172/// let stmt = client.prepare("SELECT * FROM users WHERE id = $1")?;
173/// let rows1 = client.execute(&stmt, params![42_i32])?;
174/// let rows2 = client.execute(&stmt, params![100_i32])?;
175/// # Ok(())
176/// # }
177/// ```
178#[derive(Debug)]
179pub struct PreparedStatement {
180    /// Statement name on the server (used for Bind/Execute messages).
181    name: String,
182    /// Original SQL query string.
183    query: String,
184    /// Parameter type OIDs (empty if types were inferred by the server).
185    param_types: Vec<Oid>,
186    /// Result column descriptions (populated after first execution).
187    columns: Vec<Column>,
188}
189
/// A prepared statement that automatically closes itself when dropped.
///
/// This is the recommended way to use prepared statements. It holds only a
/// [`Weak`] reference to the connection, so it never keeps the connection
/// alive on its own. When dropped, it sends `Close` for the statement if the
/// connection still exists. Any failure in that path is logged rather than
/// returned; use [`OwnedPreparedStatement::close`] to observe it.
///
/// # Example
///
/// ```no_run
/// # use hyperdb_api_core::{params, client::{Client, Config}};
/// # fn example(client: &Client) -> hyperdb_api_core::client::Result<()> {
/// // Statement automatically closes when it goes out of scope
/// {
///     let stmt = client.prepare("SELECT * FROM users WHERE id = $1")?;
///     let rows = client.execute(&stmt, params![42_i32])?;
/// } // Statement is automatically closed here
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct OwnedPreparedStatement {
    /// The underlying prepared statement.
    statement: PreparedStatement,
    /// Weak reference to the connection, used for cleanup. Upgrading fails
    /// once the connection itself has been dropped.
    connection: Weak<Mutex<RawConnection<SyncStream>>>,
}
215
216impl OwnedPreparedStatement {
217    /// Creates a new owned prepared statement.
218    pub(crate) fn new(
219        statement: PreparedStatement,
220        connection: &Arc<Mutex<RawConnection<SyncStream>>>,
221    ) -> Self {
222        OwnedPreparedStatement {
223            statement,
224            connection: Arc::downgrade(connection),
225        }
226    }
227
228    /// Returns the statement name.
229    #[must_use]
230    pub fn name(&self) -> &str {
231        self.statement.name()
232    }
233
234    /// Returns the original query.
235    #[must_use]
236    pub fn query(&self) -> &str {
237        self.statement.query()
238    }
239
240    /// Returns the parameter types.
241    #[must_use]
242    pub fn param_types(&self) -> &[Oid] {
243        self.statement.param_types()
244    }
245
246    /// Returns the number of parameters.
247    #[must_use]
248    pub fn param_count(&self) -> usize {
249        self.statement.param_count()
250    }
251
252    /// Returns the result column descriptions.
253    #[must_use]
254    pub fn columns(&self) -> &[Column] {
255        self.statement.columns()
256    }
257
258    /// Returns the number of result columns.
259    #[must_use]
260    pub fn column_count(&self) -> usize {
261        self.statement.column_count()
262    }
263
264    /// Returns a reference to the underlying `PreparedStatement`.
265    #[must_use]
266    pub fn statement(&self) -> &PreparedStatement {
267        &self.statement
268    }
269
270    /// Explicitly closes the statement, returning any error.
271    ///
272    /// This is called automatically when the `OwnedPreparedStatement` is
273    /// dropped, but errors are silently ignored in that case. Use this
274    /// method if you need to handle close errors.
275    ///
276    /// # Errors
277    ///
278    /// Propagates any error from [`close_statement`] — connection
279    /// mutex poisoning, server-side error during `Close`/`Sync`, or
280    /// wire I/O failure. Returns `Ok(())` without contacting the server
281    /// when the connection has already been dropped.
282    pub fn close(self) -> Result<()> {
283        if let Some(conn) = self.connection.upgrade() {
284            close_statement(&conn, &self.statement)?;
285        }
286        // Don't run Drop since we've already closed
287        std::mem::forget(self);
288        Ok(())
289    }
290}
291
292impl Drop for OwnedPreparedStatement {
293    fn drop(&mut self) {
294        // Best-effort cleanup - log errors but don't panic during drop
295        if let Some(conn) = self.connection.upgrade() {
296            if let Err(e) = close_statement_internal(&conn, &self.statement) {
297                warn!(
298                    target: "hyperdb_api",
299                    statement_name = %self.statement.name,
300                    error = %e,
301                    "failed-to-close-prepared-statement-during-drop"
302                );
303            }
304        }
305        // If the connection is already dropped, we can't close the statement
306        // but that's okay - the server will clean it up when the connection closes
307    }
308}
309
310impl PreparedStatement {
311    /// Returns the statement name.
312    #[must_use]
313    pub fn name(&self) -> &str {
314        &self.name
315    }
316
317    /// Returns the original query.
318    #[must_use]
319    pub fn query(&self) -> &str {
320        &self.query
321    }
322
323    /// Returns the parameter types.
324    #[must_use]
325    pub fn param_types(&self) -> &[Oid] {
326        &self.param_types
327    }
328
329    /// Returns the number of parameters.
330    #[must_use]
331    pub fn param_count(&self) -> usize {
332        self.param_types.len()
333    }
334
335    /// Returns the result column descriptions.
336    #[must_use]
337    pub fn columns(&self) -> &[Column] {
338        &self.columns
339    }
340
341    /// Returns the number of result columns.
342    #[must_use]
343    pub fn column_count(&self) -> usize {
344        self.columns.len()
345    }
346}
347
348/// Prepares a statement using the extended query protocol.
349///
350/// # Errors
351///
352/// - Returns [`Error`] (connection) if the connection mutex is poisoned.
353/// - Returns [`Error`] (server) if the server rejects the `Parse` request
354///   (SQL syntax error, unknown parameter OIDs, etc.).
355/// - Returns [`Error`] (I/O) / [`Error`] (closed) on wire-protocol I/O
356///   failure.
357pub fn prepare(
358    connection: &Arc<Mutex<RawConnection<SyncStream>>>,
359    query: &str,
360    param_types: &[Oid],
361) -> Result<PreparedStatement> {
362    let name = generate_statement_name();
363    let mut conn = connection
364        .lock()
365        .map_err(|_| Error::connection("connection mutex poisoned"))?;
366
367    // Send Parse message
368    frontend::parse(&name, query, param_types, conn.write_buf())?;
369
370    // Send Describe message for the statement
371    frontend::describe(b'S', &name, conn.write_buf())?;
372
373    // Send Sync to get responses
374    frontend::sync(conn.write_buf());
375    conn.flush()?;
376
377    // Process responses
378    let mut parsed_params = Vec::new();
379    let mut parsed_columns = Vec::new();
380
381    loop {
382        let msg = conn.read_message()?;
383        match msg {
384            Message::ParseComplete => {
385                // Statement parsed successfully
386            }
387            Message::ParameterDescription(desc) => {
388                for oid in desc.parameters().filter_map(|r| {
389                    r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing parameter OID")).ok()
390                }) {
391                    parsed_params.push(oid);
392                }
393            }
394            Message::RowDescription(desc) => {
395                for f in desc.fields().filter_map(|r| {
396                    r.map_err(|e| trace!(target: "hyperdb_api_core::client", error = %e, "dropped error parsing row description field")).ok()
397                }) {
398                    parsed_columns.push(Column::new(
399                        f.name().to_string(),
400                        f.type_oid(),
401                        f.type_modifier(),
402                        super::statement::ColumnFormat::from_code(f.format()),
403                    ));
404                }
405            }
406            Message::NoData => {
407                // Statement returns no data (e.g., INSERT)
408            }
409            Message::ReadyForQuery(_) => {
410                break;
411            }
412            Message::ErrorResponse(body) => {
413                return Err(conn.consume_error(&body));
414            }
415            _ => {}
416        }
417    }
418
419    Ok(PreparedStatement {
420        name,
421        query: query.to_string(),
422        param_types: parsed_params,
423        columns: parsed_columns,
424    })
425}
426
427/// Executes a prepared statement with parameters.
428///
429/// # Errors
430///
431/// - Returns [`Error`] (connection) if the connection mutex is poisoned.
432/// - Returns [`Error`] (server) if the server rejects `Bind` / `Execute`
433///   (parameter type mismatch, constraint violation, etc.).
434/// - Returns [`Error`] (I/O) / [`Error`] (closed) on wire-protocol I/O
435///   failure.
436/// - Propagates row-construction errors from `Row::new` if a
437///   `DataRow` cannot be decoded against the prepared columns.
438pub fn execute_prepared(
439    connection: &Arc<Mutex<RawConnection<SyncStream>>>,
440    statement: &PreparedStatement,
441    params: &[Option<&[u8]>],
442) -> Result<Vec<Row>> {
443    let mut conn = connection
444        .lock()
445        .map_err(|_| Error::connection("connection mutex poisoned"))?;
446
447    // Bind parameters (all in binary format)
448    let param_formats: Vec<i16> = vec![1; params.len()]; // 1 = binary
449    let result_formats: Vec<i16> = vec![1; statement.columns.len()]; // 1 = binary
450
451    frontend::bind(
452        "", // unnamed portal
453        &statement.name,
454        &param_formats,
455        params,
456        &result_formats,
457        conn.write_buf(),
458    )?;
459
460    // Execute
461    frontend::execute("", 0, conn.write_buf())?; // 0 = fetch all rows
462
463    // Sync
464    frontend::sync(conn.write_buf());
465    conn.flush()?;
466
467    // Process responses
468    let mut rows = Vec::new();
469    let columns = Arc::new(statement.columns.clone());
470
471    loop {
472        let msg = conn.read_message()?;
473        match msg {
474            Message::BindComplete => {
475                // Bind succeeded
476            }
477            Message::DataRow(data) => {
478                rows.push(Row::new(Arc::clone(&columns), data)?);
479            }
480            Message::CommandComplete(_) => {
481                // Execution complete
482            }
483            Message::EmptyQueryResponse => {
484                // Empty query
485            }
486            Message::ReadyForQuery(_) => {
487                break;
488            }
489            Message::ErrorResponse(body) => {
490                return Err(conn.consume_error(&body));
491            }
492            _ => {}
493        }
494    }
495
496    Ok(rows)
497}
498
499/// Executes a prepared statement that doesn't return rows.
500///
501/// # Errors
502///
503/// Same failure modes as [`execute_prepared`] (minus row-construction
504/// errors — this path never builds rows).
505pub fn execute_prepared_no_result(
506    connection: &Arc<Mutex<RawConnection<SyncStream>>>,
507    statement: &PreparedStatement,
508    params: &[Option<&[u8]>],
509) -> Result<u64> {
510    let mut conn = connection
511        .lock()
512        .map_err(|_| Error::connection("connection mutex poisoned"))?;
513
514    // Bind parameters
515    let param_formats: Vec<i16> = vec![1; params.len()];
516    let result_formats: Vec<i16> = vec![];
517
518    frontend::bind(
519        "",
520        &statement.name,
521        &param_formats,
522        params,
523        &result_formats,
524        conn.write_buf(),
525    )?;
526
527    // Execute
528    frontend::execute("", 0, conn.write_buf())?;
529
530    // Sync
531    frontend::sync(conn.write_buf());
532    conn.flush()?;
533
534    // Process responses
535    let mut affected_rows = 0u64;
536
537    loop {
538        let msg = conn.read_message()?;
539        match msg {
540            Message::BindComplete => {}
541            Message::CommandComplete(body) => {
542                if let Ok(tag) = body.tag() {
543                    affected_rows = parse_affected_rows(tag);
544                }
545            }
546            Message::EmptyQueryResponse => {}
547            Message::ReadyForQuery(_) => {
548                break;
549            }
550            Message::ErrorResponse(body) => {
551                return Err(conn.consume_error(&body));
552            }
553            _ => {}
554        }
555    }
556
557    Ok(affected_rows)
558}
559
560/// Closes a prepared statement on the server.
561///
562/// # Errors
563///
564/// - Returns [`Error`] (connection) if the connection mutex is poisoned.
565/// - Returns [`Error`] (server) if the server reports an `ErrorResponse`
566///   during `Close`/`Sync`.
567/// - Returns [`Error`] (I/O) / [`Error`] (closed) on wire-protocol I/O
568///   failure.
569pub fn close_statement(
570    connection: &Arc<Mutex<RawConnection<SyncStream>>>,
571    statement: &PreparedStatement,
572) -> Result<()> {
573    close_statement_internal(connection, statement)
574}
575
576/// Internal close function that can be used from Drop.
577fn close_statement_internal(
578    connection: &Arc<Mutex<RawConnection<SyncStream>>>,
579    statement: &PreparedStatement,
580) -> Result<()> {
581    let mut conn = connection
582        .lock()
583        .map_err(|_| Error::connection("connection mutex poisoned"))?;
584
585    // Send Close message for the statement
586    frontend::close(b'S', &statement.name, conn.write_buf())?;
587
588    // Sync
589    frontend::sync(conn.write_buf());
590    conn.flush()?;
591
592    // Process responses
593    loop {
594        let msg = conn.read_message()?;
595        match msg {
596            Message::CloseComplete => {}
597            Message::ReadyForQuery(_) => {
598                break;
599            }
600            Message::ErrorResponse(body) => {
601                return Err(conn.consume_error(&body));
602            }
603            _ => {}
604        }
605    }
606
607    Ok(())
608}
609
610/// Creates an owned prepared statement that automatically closes when dropped.
611///
612/// # Errors
613///
614/// Propagates any error from [`prepare`].
615pub fn prepare_owned(
616    connection: &Arc<Mutex<RawConnection<SyncStream>>>,
617    query: &str,
618    param_types: &[Oid],
619) -> Result<OwnedPreparedStatement> {
620    let statement = prepare(connection, query, param_types)?;
621    Ok(OwnedPreparedStatement::new(statement, connection))
622}
623
/// Extracts the row count from a `CommandComplete` tag.
///
/// Recognizes the tags for which PostgreSQL reports a count:
/// `INSERT <oid> <rows>`, and `<CMD> <rows>` for `UPDATE`, `DELETE`,
/// `MERGE`, `SELECT`, `MOVE`, `FETCH` and `COPY`. Any other tag, or a tag
/// whose count is missing or not a valid `u64`, yields `0`.
fn parse_affected_rows(tag: &str) -> u64 {
    let mut words = tag.split_whitespace();
    let count = match words.next() {
        // `INSERT` carries a legacy OID field before the row count.
        Some("INSERT") => words.nth(1),
        Some("UPDATE" | "DELETE" | "MERGE" | "SELECT" | "MOVE" | "FETCH" | "COPY") => {
            words.next()
        }
        _ => None,
    };
    count.and_then(|n| n.parse().ok()).unwrap_or(0)
}
636
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_param_i16() {
        assert_eq!(0_i16.encode(), vec![0, 0]);
        assert_eq!(1_i16.encode(), vec![1, 0]);
        assert_eq!((-1_i16).encode(), vec![255, 255]);
    }

    #[test]
    fn test_sql_param_i32() {
        assert_eq!(0_i32.encode(), vec![0, 0, 0, 0]);
        assert_eq!(1_i32.encode(), vec![1, 0, 0, 0]);
        assert_eq!((-1_i32).encode(), vec![255, 255, 255, 255]);
        assert_eq!(256_i32.encode(), vec![0, 1, 0, 0]);
    }

    #[test]
    fn test_sql_param_i64() {
        assert_eq!(0_i64.encode(), vec![0, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(1_i64.encode(), vec![1, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(
            (-1_i64).encode(),
            vec![255, 255, 255, 255, 255, 255, 255, 255]
        );
    }

    #[test]
    #[expect(
        clippy::float_cmp,
        reason = "1.5 is exactly representable; encode/decode must round-trip bit-for-bit"
    )]
    fn test_sql_param_f32() {
        let encoded = 1.5_f32.encode();
        assert_eq!(encoded.len(), 4);
        let decoded = f32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]);
        assert_eq!(decoded, 1.5);
    }

    #[test]
    #[expect(
        clippy::float_cmp,
        reason = "1.5 is exactly representable; encode/decode must round-trip bit-for-bit"
    )]
    fn test_sql_param_f64() {
        let encoded = 1.5_f64.encode();
        assert_eq!(encoded.len(), 8);
        let decoded = f64::from_le_bytes([
            encoded[0], encoded[1], encoded[2], encoded[3], encoded[4], encoded[5], encoded[6],
            encoded[7],
        ]);
        assert_eq!(decoded, 1.5);
    }

    #[test]
    fn test_sql_param_bool() {
        assert_eq!(true.encode(), vec![1]);
        assert_eq!(false.encode(), vec![0]);
    }

    #[test]
    fn test_sql_param_str() {
        assert_eq!("hello".encode(), b"hello".to_vec());
        assert_eq!("".encode(), Vec::<u8>::new());
        assert_eq!("héllo".encode(), "héllo".as_bytes().to_vec());
    }

    #[test]
    fn test_sql_param_string() {
        let s = String::from("hello");
        assert_eq!(s.encode(), b"hello".to_vec());
        // Exercise the `&String` impl explicitly. Plain method syntax on `&s`
        // would resolve to the `String` impl instead.
        let borrowed: &String = &s;
        assert_eq!(<&String as SqlParam>::encode(&borrowed), b"hello".to_vec());
    }

    #[test]
    fn test_sql_param_bytes() {
        let bytes: Vec<u8> = vec![1, 2, 3, 4];
        assert_eq!(bytes.encode(), vec![1, 2, 3, 4]);
        assert_eq!(bytes.as_slice().encode(), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_params_macro_empty() {
        let p = params![];
        assert!(p.is_empty());
    }

    #[test]
    fn test_params_macro_single() {
        let p = params![42_i32];
        assert_eq!(p.len(), 1);
        assert_eq!(p[0], Some(vec![42, 0, 0, 0]));
    }

    #[test]
    fn test_params_macro_multiple() {
        let p = params![42_i32, "hello", true];
        assert_eq!(p.len(), 3);
        assert_eq!(p[0], Some(vec![42, 0, 0, 0]));
        assert_eq!(p[1], Some(b"hello".to_vec()));
        assert_eq!(p[2], Some(vec![1]));
    }

    #[test]
    fn test_params_macro_trailing_comma() {
        let p = params![1_i16, 2_i16,];
        assert_eq!(p, vec![Some(vec![1, 0]), Some(vec![2, 0])]);
    }

    #[test]
    fn test_parse_affected_rows() {
        assert_eq!(parse_affected_rows("INSERT 0 5"), 5);
        assert_eq!(parse_affected_rows("UPDATE 3"), 3);
        assert_eq!(parse_affected_rows("DELETE 10"), 10);
        assert_eq!(parse_affected_rows("SELECT 42"), 42);
        assert_eq!(parse_affected_rows("COPY 7"), 7);
        assert_eq!(parse_affected_rows("CREATE TABLE"), 0);
        assert_eq!(parse_affected_rows(""), 0);
        assert_eq!(parse_affected_rows("INSERT 0"), 0);
        assert_eq!(parse_affected_rows("UPDATE abc"), 0);
    }

    #[test]
    fn test_generate_statement_name_is_unique() {
        let first = generate_statement_name();
        let second = generate_statement_name();
        assert!(first.starts_with("__hyper_stmt_"));
        assert!(second.starts_with("__hyper_stmt_"));
        assert_ne!(first, second);
    }
}