Skip to main content

qail_pg/driver/
query.rs

//! Query execution methods for PostgreSQL connection.
//!
//! This module provides `query`, `query_cached`, `query_prepared_single`,
//! `execute_simple`, and `simple_query`.
4
5use super::{PgConnection, PgError, PgResult};
6use crate::protocol::{BackendMessage, PgEncoder};
7use bytes::BytesMut;
8use tokio::io::AsyncWriteExt;
9
10impl PgConnection {
11    /// Execute a query with binary parameters (crate-internal).
12    /// This uses the Extended Query Protocol (Parse/Bind/Execute/Sync):
13    /// - Parameters are sent as binary bytes, skipping the string layer
14    /// - No SQL injection possible - parameters are never interpolated
15    /// - Better performance via prepared statement reuse
16    pub(crate) async fn query(
17        &mut self,
18        sql: &str,
19        params: &[Option<Vec<u8>>],
20    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
21        let bytes = PgEncoder::encode_extended_query(sql, params)
22            .map_err(|e| PgError::Encode(e.to_string()))?;
23        self.stream.write_all(&bytes).await?;
24
25        let mut rows = Vec::new();
26
27        let mut error: Option<PgError> = None;
28
29        loop {
30            let msg = self.recv().await?;
31            match msg {
32                BackendMessage::ParseComplete => {}
33                BackendMessage::BindComplete => {}
34                BackendMessage::RowDescription(_) => {}
35                BackendMessage::DataRow(data) => {
36                    // Only collect rows if no error occurred
37                    if error.is_none() {
38                        rows.push(data);
39                    }
40                }
41                BackendMessage::CommandComplete(_) => {}
42                BackendMessage::NoData => {}
43                BackendMessage::ReadyForQuery(_) => {
44                    if let Some(err) = error {
45                        return Err(err);
46                    }
47                    return Ok(rows);
48                }
49                BackendMessage::ErrorResponse(err) => {
50                    if error.is_none() {
51                        error = Some(PgError::Query(err.message));
52                    }
53                }
54                _ => {}
55            }
56        }
57    }
58
59    /// Execute a query with cached prepared statement.
60    /// Like `query()`, but reuses prepared statements across calls.
61    /// The statement name is derived from a hash of the SQL text.
62    /// OPTIMIZED: Pre-allocated buffer + ultra-fast encoders.
63    pub async fn query_cached(
64        &mut self,
65        sql: &str,
66        params: &[Option<Vec<u8>>],
67    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
68        let stmt_name = Self::sql_to_stmt_name(sql);
69        let is_new = !self.prepared_statements.contains_key(&stmt_name);
70
71        // Pre-calculate buffer size for single allocation
72        let params_size: usize = params
73            .iter()
74            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
75            .sum();
76        
77        let estimated_size = if is_new {
78            50 + sql.len() + stmt_name.len() * 2 + params_size
79        } else {
80            30 + stmt_name.len() + params_size
81        };
82        
83        let mut buf = BytesMut::with_capacity(estimated_size);
84
85        if is_new {
86            // Evict LRU prepared statement if at capacity. This prevents
87            // unbounded memory growth from dynamic batch filters while
88            // preserving hot statements (unlike the old nuclear `.clear()`).
89            self.evict_prepared_if_full();
90            buf.extend(PgEncoder::encode_parse(&stmt_name, sql, &[]));
91            // Cache the SQL for debugging
92            self.prepared_statements.insert(stmt_name.clone(), sql.to_string());
93        }
94
95        // Use ULTRA-OPTIMIZED encoders - write directly to buffer
96        PgEncoder::encode_bind_to(&mut buf, &stmt_name, params)
97            .map_err(|e| PgError::Encode(e.to_string()))?;
98        PgEncoder::encode_execute_to(&mut buf);
99        PgEncoder::encode_sync_to(&mut buf);
100
101        self.stream.write_all(&buf).await?;
102
103        let mut rows = Vec::new();
104
105        let mut error: Option<PgError> = None;
106
107        loop {
108            let msg = self.recv().await?;
109            match msg {
110                BackendMessage::ParseComplete => {
111                    // Already cached in is_new block above
112                }
113                BackendMessage::BindComplete => {}
114                BackendMessage::RowDescription(_) => {}
115                BackendMessage::DataRow(data) => {
116                    if error.is_none() {
117                        rows.push(data);
118                    }
119                }
120                BackendMessage::CommandComplete(_) => {}
121                BackendMessage::NoData => {}
122                BackendMessage::ReadyForQuery(_) => {
123                    if let Some(err) = error {
124                        return Err(err);
125                    }
126                    return Ok(rows);
127                }
128                BackendMessage::ErrorResponse(err) => {
129                    if error.is_none() {
130                        error = Some(PgError::Query(err.message));
131                        // Invalidate cache to prevent "prepared statement does not exist"
132                        // on next retry.
133                        self.prepared_statements.remove(&stmt_name);
134                    }
135                }
136                _ => {}
137            }
138        }
139    }
140
141    /// Generate a statement name from SQL hash.
142    /// Uses a simple hash to create a unique name like "stmt_12345abc".
143    pub(crate) fn sql_to_stmt_name(sql: &str) -> String {
144        use std::collections::hash_map::DefaultHasher;
145        use std::hash::{Hash, Hasher};
146
147        let mut hasher = DefaultHasher::new();
148        sql.hash(&mut hasher);
149        format!("s{:016x}", hasher.finish())
150    }
151
152    /// Execute a simple SQL statement (no parameters).
153    pub async fn execute_simple(&mut self, sql: &str) -> PgResult<()> {
154        let bytes = PgEncoder::encode_query_string(sql);
155        self.stream.write_all(&bytes).await?;
156
157        let mut error: Option<PgError> = None;
158
159        loop {
160            let msg = self.recv().await?;
161            match msg {
162                BackendMessage::CommandComplete(_) => {}
163                BackendMessage::ReadyForQuery(_) => {
164                    if let Some(err) = error {
165                        return Err(err);
166                    }
167                    return Ok(());
168                }
169                BackendMessage::ErrorResponse(err) => {
170                    if error.is_none() {
171                        error = Some(PgError::Query(err.message));
172                    }
173                }
174                _ => {}
175            }
176        }
177    }
178
179    /// Execute a simple SQL query and return rows (Simple Query Protocol).
180    ///
181    /// Unlike `execute_simple`, this collects and returns data rows.
182    /// Used for branch management and other administrative queries.
183    ///
184    /// SECURITY: Capped at 10,000 rows to prevent OOM from unbounded results.
185    pub async fn simple_query(&mut self, sql: &str) -> PgResult<Vec<super::PgRow>> {
186        use std::sync::Arc;
187
188        /// Safety cap to prevent OOM from unbounded result accumulation.
189        /// Simple Query Protocol has no streaming; all rows are buffered in memory.
190        const MAX_SIMPLE_QUERY_ROWS: usize = 10_000;
191
192        let bytes = PgEncoder::encode_query_string(sql);
193        self.stream.write_all(&bytes).await?;
194
195        let mut rows: Vec<super::PgRow> = Vec::new();
196        let mut column_info: Option<Arc<super::ColumnInfo>> = None;
197        let mut error: Option<PgError> = None;
198
199        loop {
200            let msg = self.recv().await?;
201            match msg {
202                BackendMessage::RowDescription(fields) => {
203                    column_info = Some(Arc::new(super::ColumnInfo::from_fields(&fields)));
204                }
205                BackendMessage::DataRow(data) => {
206                    if error.is_none() {
207                        if rows.len() >= MAX_SIMPLE_QUERY_ROWS {
208                            if error.is_none() {
209                                error = Some(PgError::Query(format!(
210                                    "simple_query exceeded {} row safety cap",
211                                    MAX_SIMPLE_QUERY_ROWS,
212                                )));
213                            }
214                            // Continue draining to reach ReadyForQuery
215                        } else {
216                            rows.push(super::PgRow {
217                                columns: data,
218                                column_info: column_info.clone(),
219                            });
220                        }
221                    }
222                }
223                BackendMessage::CommandComplete(_) => {}
224                BackendMessage::ReadyForQuery(_) => {
225                    if let Some(err) = error {
226                        return Err(err);
227                    }
228                    return Ok(rows);
229                }
230                BackendMessage::ErrorResponse(err) => {
231                    if error.is_none() {
232                        error = Some(PgError::Query(err.message));
233                    }
234                }
235                _ => {}
236            }
237        }
238    }
239
240    /// ZERO-HASH sequential query using pre-computed PreparedStatement.
241    /// This is the FASTEST sequential path because it skips:
242    /// - SQL generation from AST (done once outside loop)
243    /// - Hash computation for statement name (pre-computed in PreparedStatement)
244    /// - HashMap lookup for is_new check (statement already prepared)
245    /// # Example
246    /// ```ignore
247    /// let stmt = conn.prepare("SELECT * FROM users WHERE id = $1").await?;
248    /// for id in 1..10000 {
249    ///     let rows = conn.query_prepared_single(&stmt, &[Some(id.to_string().into_bytes())]).await?;
250    /// }
251    /// ```
252    #[inline]
253    pub async fn query_prepared_single(
254        &mut self,
255        stmt: &super::PreparedStatement,
256        params: &[Option<Vec<u8>>],
257    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
258        // Pre-calculate buffer size for single allocation
259        let params_size: usize = params
260            .iter()
261            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
262            .sum();
263        
264        // Bind: ~15 + stmt.name.len() + params_size, Execute: 10, Sync: 5
265        let mut buf = BytesMut::with_capacity(30 + stmt.name.len() + params_size);
266
267        // ZERO HASH, ZERO LOOKUP - just encode and send!
268        PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
269            .map_err(|e| PgError::Encode(e.to_string()))?;
270        PgEncoder::encode_execute_to(&mut buf);
271        PgEncoder::encode_sync_to(&mut buf);
272
273        self.stream.write_all(&buf).await?;
274
275        let mut rows = Vec::new();
276
277        let mut error: Option<PgError> = None;
278
279        loop {
280            let msg = self.recv().await?;
281            match msg {
282                BackendMessage::BindComplete => {}
283                BackendMessage::RowDescription(_) => {}
284                BackendMessage::DataRow(data) => {
285                    if error.is_none() {
286                        rows.push(data);
287                    }
288                }
289                BackendMessage::CommandComplete(_) => {}
290                BackendMessage::NoData => {}
291                BackendMessage::ReadyForQuery(_) => {
292                    if let Some(err) = error {
293                        return Err(err);
294                    }
295                    return Ok(rows);
296                }
297                BackendMessage::ErrorResponse(err) => {
298                    if error.is_none() {
299                        error = Some(PgError::Query(err.message));
300                    }
301                }
302                _ => {}
303            }
304        }
305    }
306}