Skip to main content

exarrow_rs/query/
prepared.rs

1//! Prepared statement handling for parameterized queries.
2//!
3//! This module provides the `PreparedStatement` type as a data container for
4//! server-side prepared statements. Execution is handled by Connection.
5
6use crate::error::QueryError;
7use crate::query::statement::Parameter;
8use crate::transport::messages::DataType;
9use crate::transport::protocol::PreparedStatementHandle;
10
11/// A prepared statement for parameterized query execution.
12///
13/// PreparedStatement stores the server-side statement handle and parameter values.
14/// Execution is performed by Connection, not by PreparedStatement itself.
15///
16/// # Example
17///
18pub struct PreparedStatement {
19    /// Server-side prepared statement handle
20    handle: PreparedStatementHandle,
21    /// Bound parameter values (column-major format for batch execution)
22    parameters: Vec<Option<Parameter>>,
23    /// Whether this prepared statement has been closed
24    closed: bool,
25}
26
27impl PreparedStatement {
28    /// Create a new PreparedStatement from a handle.
29    pub(crate) fn new(handle: PreparedStatementHandle) -> Self {
30        let num_params = handle.num_params as usize;
31        Self {
32            handle,
33            parameters: vec![None; num_params],
34            closed: false,
35        }
36    }
37
38    /// Get the number of parameters in this prepared statement.
39    pub fn parameter_count(&self) -> usize {
40        self.handle.num_params as usize
41    }
42
43    /// Get the parameter types (if available from server).
44    pub fn parameter_types(&self) -> &[DataType] {
45        &self.handle.parameter_types
46    }
47
48    /// Get the statement handle ID.
49    pub fn handle(&self) -> i32 {
50        self.handle.handle
51    }
52
53    /// Get the full handle (for Connection use).
54    pub(crate) fn handle_ref(&self) -> &PreparedStatementHandle {
55        &self.handle
56    }
57
58    /// Check if the prepared statement has been closed.
59    pub fn is_closed(&self) -> bool {
60        self.closed
61    }
62
63    /// Mark the prepared statement as closed (called by Connection).
64    pub(crate) fn mark_closed(&mut self) {
65        self.closed = true;
66    }
67
68    /// Bind a parameter value at the given index.
69    ///
70    /// # Arguments
71    /// * `index` - Zero-based parameter index
72    /// * `value` - Value to bind (must implement Into<Parameter>)
73    ///
74    /// # Errors
75    /// Returns `QueryError::ParameterBindingError` if index is out of bounds.
76    pub fn bind(&mut self, index: usize, value: impl Into<Parameter>) -> Result<(), QueryError> {
77        if index >= self.parameters.len() {
78            return Err(QueryError::ParameterBindingError {
79                index,
80                message: format!(
81                    "Parameter index {} out of bounds (statement has {} parameters)",
82                    index,
83                    self.parameters.len()
84                ),
85            });
86        }
87        self.parameters[index] = Some(value.into());
88        Ok(())
89    }
90
91    /// Clear all bound parameters.
92    pub fn clear_parameters(&mut self) {
93        for param in &mut self.parameters {
94            *param = None;
95        }
96    }
97
98    /// Get bound parameters.
99    pub fn parameters(&self) -> &[Option<Parameter>] {
100        &self.parameters
101    }
102
103    /// Build parameters data in column-major format for the protocol.
104    ///
105    /// This is used internally by Connection when executing prepared statements.
106    pub fn build_parameters_data(&self) -> Result<Option<Vec<Vec<serde_json::Value>>>, QueryError> {
107        if self.parameters.is_empty() {
108            return Ok(None);
109        }
110
111        // Check all parameters are bound
112        for (i, param) in self.parameters.iter().enumerate() {
113            if param.is_none() {
114                return Err(QueryError::ParameterBindingError {
115                    index: i,
116                    message: format!("Parameter {} is not bound", i),
117                });
118            }
119        }
120
121        // Convert to column-major format (each column has one value for single-row execution)
122        let columns: Vec<Vec<serde_json::Value>> = self
123            .parameters
124            .iter()
125            .map(|p| vec![parameter_to_json(p.as_ref().unwrap())])
126            .collect();
127
128        Ok(Some(columns))
129    }
130}
131
132/// Convert a Parameter to JSON value for the wire protocol.
133pub(crate) fn parameter_to_json(param: &Parameter) -> serde_json::Value {
134    match param {
135        Parameter::Null => serde_json::Value::Null,
136        Parameter::Boolean(b) => serde_json::Value::Bool(*b),
137        Parameter::Integer(i) => serde_json::json!(*i),
138        Parameter::Float(f) => serde_json::json!(*f),
139        Parameter::String(s) => serde_json::Value::String(s.clone()),
140        Parameter::Binary(b) => serde_json::Value::String(hex::encode(b)),
141    }
142}
143
144impl std::fmt::Debug for PreparedStatement {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("PreparedStatement")
147            .field("handle", &self.handle.handle)
148            .field("parameter_count", &self.parameter_count())
149            .field("closed", &self.closed)
150            .finish()
151    }
152}
153
154impl Drop for PreparedStatement {
155    fn drop(&mut self) {
156        // Note: We can't do async cleanup in Drop.
157        // Users should call Connection::close_prepared() explicitly for proper cleanup.
158        // The server will eventually clean up orphaned statements.
159        if !self.closed {
160            // Log a warning in debug builds
161            #[cfg(debug_assertions)]
162            eprintln!(
163                "Warning: PreparedStatement {} dropped without calling close_prepared()",
164                self.handle.handle
165            );
166        }
167    }
168}
169
#[cfg(test)]
// `3.14` below is an arbitrary sample float, not an approximation of PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Statement with handle ID 1, two parameter slots and no type metadata.
    fn two_param_statement() -> PreparedStatement {
        PreparedStatement::new(PreparedStatementHandle::new(1, 2, vec![]))
    }

    /// Asserts that every `(parameter, expected JSON)` pair converts as expected.
    fn assert_json_cases(cases: &[(Parameter, Value)]) {
        for (position, (param, expected)) in cases.iter().enumerate() {
            assert_eq!(
                parameter_to_json(param),
                *expected,
                "conversion case #{} produced unexpected JSON",
                position
            );
        }
    }

    #[test]
    fn test_parameter_to_json() {
        assert_json_cases(&[
            (Parameter::Null, Value::Null),
            (Parameter::Boolean(true), json!(true)),
            (Parameter::Integer(42), json!(42)),
            (Parameter::Float(3.14), json!(3.14)),
            (Parameter::String("hello".to_string()), json!("hello")),
            (Parameter::Binary(vec![0xDE, 0xAD]), json!("dead")),
        ]);
    }

    #[test]
    fn test_prepared_statement_creation() {
        let decimal = DataType {
            type_name: "DECIMAL".to_string(),
            precision: Some(18),
            scale: Some(0),
            size: None,
            character_set: None,
            with_local_time_zone: None,
            fraction: None,
        };
        let varchar = DataType {
            type_name: "VARCHAR".to_string(),
            precision: None,
            scale: None,
            size: Some(100),
            character_set: Some("UTF8".to_string()),
            with_local_time_zone: None,
            fraction: None,
        };

        let stmt = PreparedStatement::new(PreparedStatementHandle::new(
            42,
            2,
            vec![decimal, varchar],
        ));

        assert_eq!(stmt.handle(), 42);
        assert_eq!(stmt.parameter_count(), 2);
        assert_eq!(stmt.parameter_types().len(), 2);
        assert!(!stmt.is_closed());
    }

    #[test]
    fn test_prepared_statement_bind_valid() {
        let mut stmt = two_param_statement();

        assert!(stmt.bind(0, 42).is_ok());
        assert!(stmt.bind(1, "test").is_ok());
    }

    #[test]
    fn test_prepared_statement_bind_out_of_bounds() {
        let mut stmt = two_param_statement();

        assert!(matches!(
            stmt.bind(5, 42),
            Err(QueryError::ParameterBindingError { index: 5, .. })
        ));
    }

    #[test]
    fn test_prepared_statement_clear_parameters() {
        let mut stmt = two_param_statement();
        stmt.bind(0, 42).unwrap();
        stmt.bind(1, "test").unwrap();

        stmt.clear_parameters();

        // With every slot cleared, building the payload must be rejected.
        assert!(stmt.build_parameters_data().is_err());
    }

    #[test]
    fn test_prepared_statement_build_params_unbound() {
        let stmt = two_param_statement();

        assert!(matches!(
            stmt.build_parameters_data(),
            Err(QueryError::ParameterBindingError { index: 0, .. })
        ));
    }

    #[test]
    fn test_prepared_statement_build_params_success() {
        let mut stmt = two_param_statement();
        stmt.bind(0, 42).unwrap();
        stmt.bind(1, "test").unwrap();

        let columns = stmt
            .build_parameters_data()
            .unwrap()
            .expect("a statement with parameters yields a payload");

        assert_eq!(columns, vec![vec![json!(42)], vec![json!("test")]]);
    }

    #[test]
    fn test_prepared_statement_no_params() {
        let stmt = PreparedStatement::new(PreparedStatementHandle::new(1, 0, vec![]));

        assert_eq!(stmt.parameter_count(), 0);
        assert!(matches!(stmt.build_parameters_data(), Ok(None)));
    }

    #[test]
    fn test_parameter_to_json_all_types() {
        assert_json_cases(&[
            // Null
            (Parameter::Null, Value::Null),
            // Booleans
            (Parameter::Boolean(true), Value::Bool(true)),
            (Parameter::Boolean(false), Value::Bool(false)),
            // Integers
            (Parameter::Integer(0), json!(0)),
            (Parameter::Integer(-1), json!(-1)),
            (Parameter::Integer(i64::MAX), json!(i64::MAX)),
            // Floats
            (Parameter::Float(0.0), json!(0.0)),
            (Parameter::Float(-1.5), json!(-1.5)),
            // Strings
            (Parameter::String("".to_string()), json!("")),
            (
                Parameter::String("hello world".to_string()),
                json!("hello world"),
            ),
            // Binary
            (Parameter::Binary(vec![]), json!("")),
            (Parameter::Binary(vec![0x00, 0xFF]), json!("00ff")),
        ]);
    }
}