Skip to main content

qail_pg/driver/
query.rs

1//! Query execution methods for PostgreSQL connection.
2//!
3//! This module provides query, query_cached, and execute_simple.
4
5use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, PgEncoder};
7use bytes::BytesMut;
8use tokio::io::AsyncWriteExt;
9
10impl PgConnection {
11    /// Execute a query with binary parameters (crate-internal).
12    /// This uses the Extended Query Protocol (Parse/Bind/Execute/Sync):
13    /// - Parameters are sent as binary bytes, skipping the string layer
14    /// - No SQL injection possible - parameters are never interpolated
15    /// - Better performance via prepared statement reuse
16    pub(crate) async fn query(
17        &mut self,
18        sql: &str,
19        params: &[Option<Vec<u8>>],
20    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
21        self.query_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
22            .await
23    }
24
25    /// Execute a query with binary parameters and explicit result-column format.
26    pub(crate) async fn query_with_result_format(
27        &mut self,
28        sql: &str,
29        params: &[Option<Vec<u8>>],
30        result_format: i16,
31    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
32        let bytes = PgEncoder::encode_extended_query_with_result_format(sql, params, result_format)
33            .map_err(|e| PgError::Encode(e.to_string()))?;
34        self.stream.write_all(&bytes).await?;
35
36        let mut rows = Vec::new();
37
38        let mut error: Option<PgError> = None;
39
40        loop {
41            let msg = self.recv().await?;
42            match msg {
43                BackendMessage::ParseComplete => {}
44                BackendMessage::BindComplete => {}
45                BackendMessage::RowDescription(_) => {}
46                BackendMessage::DataRow(data) => {
47                    // Only collect rows if no error occurred
48                    if error.is_none() {
49                        rows.push(data);
50                    }
51                }
52                BackendMessage::CommandComplete(_) => {}
53                BackendMessage::NoData => {}
54                BackendMessage::ReadyForQuery(_) => {
55                    if let Some(err) = error {
56                        return Err(err);
57                    }
58                    return Ok(rows);
59                }
60                BackendMessage::ErrorResponse(err) => {
61                    if error.is_none() {
62                        error = Some(PgError::QueryServer(err.into()));
63                    }
64                }
65                _ => {}
66            }
67        }
68    }
69
70    /// Execute a query with cached prepared statement.
71    /// Like `query()`, but reuses prepared statements across calls.
72    /// The statement name is derived from a hash of the SQL text.
73    /// OPTIMIZED: Pre-allocated buffer + ultra-fast encoders.
74    pub async fn query_cached(
75        &mut self,
76        sql: &str,
77        params: &[Option<Vec<u8>>],
78    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
79        self.query_cached_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
80            .await
81    }
82
83    /// Execute a query with cached prepared statement and explicit result-column format.
84    pub async fn query_cached_with_result_format(
85        &mut self,
86        sql: &str,
87        params: &[Option<Vec<u8>>],
88        result_format: i16,
89    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
90        let mut retried = false;
91        loop {
92            match self
93                .query_cached_with_result_format_once(sql, params, result_format)
94                .await
95            {
96                Ok(rows) => return Ok(rows),
97                Err(err) if !retried && err.is_prepared_statement_retryable() => {
98                    retried = true;
99                    self.clear_prepared_statement_state();
100                }
101                Err(err) => return Err(err),
102            }
103        }
104    }
105
106    async fn query_cached_with_result_format_once(
107        &mut self,
108        sql: &str,
109        params: &[Option<Vec<u8>>],
110        result_format: i16,
111    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
112        let stmt_name = Self::sql_to_stmt_name(sql);
113        let is_new = !self.prepared_statements.contains_key(&stmt_name);
114
115        // Pre-calculate buffer size for single allocation
116        let params_size: usize = params
117            .iter()
118            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
119            .sum();
120
121        let estimated_size = if is_new {
122            50 + sql.len() + stmt_name.len() * 2 + params_size
123        } else {
124            30 + stmt_name.len() + params_size
125        };
126
127        let mut buf = BytesMut::with_capacity(estimated_size);
128
129        if is_new {
130            // Evict LRU prepared statement if at capacity. This prevents
131            // unbounded memory growth from dynamic batch filters while
132            // preserving hot statements (unlike the old nuclear `.clear()`).
133            self.evict_prepared_if_full();
134            buf.extend(PgEncoder::encode_parse(&stmt_name, sql, &[]));
135            // Cache the SQL for debugging
136            self.prepared_statements
137                .insert(stmt_name.clone(), sql.to_string());
138        }
139
140        // Use ULTRA-OPTIMIZED encoders - write directly to buffer
141        PgEncoder::encode_bind_to_with_result_format(&mut buf, &stmt_name, params, result_format)
142            .map_err(|e| PgError::Encode(e.to_string()))?;
143        PgEncoder::encode_execute_to(&mut buf);
144        PgEncoder::encode_sync_to(&mut buf);
145
146        self.stream.write_all(&buf).await?;
147
148        let mut rows = Vec::new();
149
150        let mut error: Option<PgError> = None;
151
152        loop {
153            let msg = self.recv().await?;
154            match msg {
155                BackendMessage::ParseComplete => {
156                    // Already cached in is_new block above
157                }
158                BackendMessage::BindComplete => {}
159                BackendMessage::RowDescription(_) => {}
160                BackendMessage::DataRow(data) => {
161                    if error.is_none() {
162                        rows.push(data);
163                    }
164                }
165                BackendMessage::CommandComplete(_) => {}
166                BackendMessage::NoData => {}
167                BackendMessage::ReadyForQuery(_) => {
168                    if let Some(err) = error {
169                        return Err(err);
170                    }
171                    return Ok(rows);
172                }
173                BackendMessage::ErrorResponse(err) => {
174                    if error.is_none() {
175                        error = Some(PgError::QueryServer(err.into()));
176                        // Invalidate cache to prevent "prepared statement does not exist"
177                        // on next retry.
178                        self.prepared_statements.remove(&stmt_name);
179                    }
180                }
181                _ => {}
182            }
183        }
184    }
185
186    /// Generate a statement name from SQL hash.
187    /// Uses a simple hash to create a unique name like "stmt_12345abc".
188    pub(crate) fn sql_to_stmt_name(sql: &str) -> String {
189        use std::collections::hash_map::DefaultHasher;
190        use std::hash::{Hash, Hasher};
191
192        let mut hasher = DefaultHasher::new();
193        sql.hash(&mut hasher);
194        format!("s{:016x}", hasher.finish())
195    }
196
197    /// Execute a simple SQL statement (no parameters).
198    pub async fn execute_simple(&mut self, sql: &str) -> PgResult<()> {
199        let bytes = PgEncoder::encode_query_string(sql);
200        self.stream.write_all(&bytes).await?;
201
202        let mut error: Option<PgError> = None;
203
204        loop {
205            let msg = self.recv().await?;
206            match msg {
207                BackendMessage::CommandComplete(_) => {}
208                BackendMessage::ReadyForQuery(_) => {
209                    if let Some(err) = error {
210                        return Err(err);
211                    }
212                    return Ok(());
213                }
214                BackendMessage::ErrorResponse(err) => {
215                    if error.is_none() {
216                        error = Some(PgError::QueryServer(err.into()));
217                    }
218                }
219                _ => {}
220            }
221        }
222    }
223
224    /// Execute a simple SQL query and return rows (Simple Query Protocol).
225    ///
226    /// Unlike `execute_simple`, this collects and returns data rows.
227    /// Used for branch management and other administrative queries.
228    ///
229    /// SECURITY: Capped at 10,000 rows to prevent OOM from unbounded results.
230    pub async fn simple_query(&mut self, sql: &str) -> PgResult<Vec<super::PgRow>> {
231        use std::sync::Arc;
232
233        /// Safety cap to prevent OOM from unbounded result accumulation.
234        /// Simple Query Protocol has no streaming; all rows are buffered in memory.
235        const MAX_SIMPLE_QUERY_ROWS: usize = 10_000;
236
237        let bytes = PgEncoder::encode_query_string(sql);
238        self.stream.write_all(&bytes).await?;
239
240        let mut rows: Vec<super::PgRow> = Vec::new();
241        let mut column_info: Option<Arc<super::ColumnInfo>> = None;
242        let mut error: Option<PgError> = None;
243
244        loop {
245            let msg = self.recv().await?;
246            match msg {
247                BackendMessage::RowDescription(fields) => {
248                    column_info = Some(Arc::new(super::ColumnInfo::from_fields(&fields)));
249                }
250                BackendMessage::DataRow(data) => {
251                    if error.is_none() {
252                        if rows.len() >= MAX_SIMPLE_QUERY_ROWS {
253                            if error.is_none() {
254                                error = Some(PgError::Query(format!(
255                                    "simple_query exceeded {} row safety cap",
256                                    MAX_SIMPLE_QUERY_ROWS,
257                                )));
258                            }
259                            // Continue draining to reach ReadyForQuery
260                        } else {
261                            rows.push(super::PgRow {
262                                columns: data,
263                                column_info: column_info.clone(),
264                            });
265                        }
266                    }
267                }
268                BackendMessage::CommandComplete(_) => {}
269                BackendMessage::ReadyForQuery(_) => {
270                    if let Some(err) = error {
271                        return Err(err);
272                    }
273                    return Ok(rows);
274                }
275                BackendMessage::ErrorResponse(err) => {
276                    if error.is_none() {
277                        error = Some(PgError::QueryServer(err.into()));
278                    }
279                }
280                _ => {}
281            }
282        }
283    }
284
285    /// ZERO-HASH sequential query using pre-computed PreparedStatement.
286    /// This is the FASTEST sequential path because it skips:
287    /// - SQL generation from AST (done once outside loop)
288    /// - Hash computation for statement name (pre-computed in PreparedStatement)
289    /// - HashMap lookup for is_new check (statement already prepared)
290    /// # Example
291    /// ```ignore
292    /// let stmt = conn.prepare("SELECT * FROM users WHERE id = $1").await?;
293    /// for id in 1..10000 {
294    ///     let rows = conn.query_prepared_single(&stmt, &[Some(id.to_string().into_bytes())]).await?;
295    /// }
296    /// ```
297    #[inline]
298    pub async fn query_prepared_single(
299        &mut self,
300        stmt: &super::PreparedStatement,
301        params: &[Option<Vec<u8>>],
302    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
303        self.query_prepared_single_with_result_format(stmt, params, PgEncoder::FORMAT_TEXT)
304            .await
305    }
306
307    /// ZERO-HASH sequential query with explicit result-column format.
308    #[inline]
309    pub async fn query_prepared_single_with_result_format(
310        &mut self,
311        stmt: &super::PreparedStatement,
312        params: &[Option<Vec<u8>>],
313        result_format: i16,
314    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
315        // Pre-calculate buffer size for single allocation
316        let params_size: usize = params
317            .iter()
318            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
319            .sum();
320
321        // Bind: ~15 + stmt.name.len() + params_size, Execute: 10, Sync: 5
322        let mut buf = BytesMut::with_capacity(30 + stmt.name.len() + params_size);
323
324        // ZERO HASH, ZERO LOOKUP - just encode and send!
325        PgEncoder::encode_bind_to_with_result_format(&mut buf, &stmt.name, params, result_format)
326            .map_err(|e| PgError::Encode(e.to_string()))?;
327        PgEncoder::encode_execute_to(&mut buf);
328        PgEncoder::encode_sync_to(&mut buf);
329
330        self.stream.write_all(&buf).await?;
331
332        let mut rows = Vec::new();
333
334        let mut error: Option<PgError> = None;
335
336        loop {
337            let msg = self.recv().await?;
338            match msg {
339                BackendMessage::BindComplete => {}
340                BackendMessage::RowDescription(_) => {}
341                BackendMessage::DataRow(data) => {
342                    if error.is_none() {
343                        rows.push(data);
344                    }
345                }
346                BackendMessage::CommandComplete(_) => {}
347                BackendMessage::NoData => {}
348                BackendMessage::ReadyForQuery(_) => {
349                    if let Some(err) = error {
350                        return Err(err);
351                    }
352                    return Ok(rows);
353                }
354                BackendMessage::ErrorResponse(err) => {
355                    if error.is_none() {
356                        error = Some(PgError::QueryServer(err.into()));
357                    }
358                }
359                _ => {}
360            }
361        }
362    }
363}