Skip to main content

qail_pg/driver/
query.rs

1//! Query execution methods for PostgreSQL connection.
2//!
//! This module provides `query`, `query_cached`, `execute_simple`, `simple_query`, and
//! `query_prepared_single` (plus their explicit result-format variants).
4
5use super::{
6    PgConnection, PgError, PgResult,
7    extended_flow::{ExtendedFlowConfig, ExtendedFlowTracker},
8    is_ignorable_session_message, unexpected_backend_message,
9};
10use crate::protocol::{BackendMessage, PgEncoder};
11use bytes::BytesMut;
12
/// Where the current statement sits within a Simple Query response stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SimpleStatementState {
    /// No row stream is open (the initial state, and the state after `CommandComplete`).
    AwaitingResult,
    /// A `RowDescription` was received; `DataRow`s may follow until `CommandComplete`.
    InRowStream,
}
18
19#[derive(Debug, Clone, Copy)]
20struct SimpleFlowTracker {
21    state: SimpleStatementState,
22    saw_completion: bool,
23}
24
25impl SimpleFlowTracker {
26    fn new() -> Self {
27        Self {
28            state: SimpleStatementState::AwaitingResult,
29            saw_completion: false,
30        }
31    }
32
33    fn on_row_description(&mut self, context: &'static str) -> PgResult<()> {
34        if self.state == SimpleStatementState::InRowStream {
35            return Err(PgError::Protocol(format!(
36                "{}: duplicate RowDescription before statement completion",
37                context
38            )));
39        }
40        self.state = SimpleStatementState::InRowStream;
41        self.saw_completion = false;
42        Ok(())
43    }
44
45    fn on_data_row(&self, context: &'static str) -> PgResult<()> {
46        if self.state != SimpleStatementState::InRowStream {
47            return Err(PgError::Protocol(format!(
48                "{}: DataRow before RowDescription",
49                context
50            )));
51        }
52        Ok(())
53    }
54
55    fn on_command_complete(&mut self) {
56        self.state = SimpleStatementState::AwaitingResult;
57        self.saw_completion = true;
58    }
59
60    fn on_empty_query_response(&mut self, context: &'static str) -> PgResult<()> {
61        if self.state == SimpleStatementState::InRowStream {
62            return Err(PgError::Protocol(format!(
63                "{}: EmptyQueryResponse during active row stream",
64                context
65            )));
66        }
67        self.saw_completion = true;
68        Ok(())
69    }
70
71    fn on_ready_for_query(&self, context: &'static str, error_pending: bool) -> PgResult<()> {
72        if error_pending {
73            return Ok(());
74        }
75        if self.state == SimpleStatementState::InRowStream {
76            return Err(PgError::Protocol(format!(
77                "{}: ReadyForQuery before CommandComplete",
78                context
79            )));
80        }
81        if !self.saw_completion {
82            return Err(PgError::Protocol(format!(
83                "{}: ReadyForQuery before completion",
84                context
85            )));
86        }
87        Ok(())
88    }
89}
90
91impl PgConnection {
92    /// Execute a query with binary parameters (crate-internal).
93    /// This uses the Extended Query Protocol (Parse/Bind/Execute/Sync):
94    /// - Parameters are sent as binary bytes, skipping the string layer
95    /// - No SQL injection possible - parameters are never interpolated
96    /// - Better performance via prepared statement reuse
97    pub(crate) async fn query(
98        &mut self,
99        sql: &str,
100        params: &[Option<Vec<u8>>],
101    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
102        self.query_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
103            .await
104    }
105
106    /// Execute a query with binary parameters and explicit result-column format.
107    pub(crate) async fn query_with_result_format(
108        &mut self,
109        sql: &str,
110        params: &[Option<Vec<u8>>],
111        result_format: i16,
112    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
113        let bytes = PgEncoder::encode_extended_query_with_result_format(sql, params, result_format)
114            .map_err(|e| PgError::Encode(e.to_string()))?;
115        self.write_all_with_timeout(&bytes, "stream write").await?;
116
117        let mut rows = Vec::new();
118
119        let mut error: Option<PgError> = None;
120        let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(true));
121
122        loop {
123            let msg = self.recv().await?;
124            flow.validate(&msg, "extended-query execute", error.is_some())?;
125            match msg {
126                BackendMessage::ParseComplete => {}
127                BackendMessage::BindComplete => {}
128                BackendMessage::RowDescription(_) => {}
129                BackendMessage::DataRow(data) => {
130                    // Only collect rows if no error occurred
131                    if error.is_none() {
132                        rows.push(data);
133                    }
134                }
135                BackendMessage::CommandComplete(_) => {}
136                BackendMessage::NoData => {}
137                BackendMessage::ReadyForQuery(_) => {
138                    if let Some(err) = error {
139                        return Err(err);
140                    }
141                    return Ok(rows);
142                }
143                BackendMessage::ErrorResponse(err) => {
144                    if error.is_none() {
145                        error = Some(PgError::QueryServer(err.into()));
146                    }
147                }
148                msg if is_ignorable_session_message(&msg) => {}
149                other => {
150                    return Err(unexpected_backend_message("extended-query execute", &other));
151                }
152            }
153        }
154    }
155
156    /// Execute a query with cached prepared statement.
157    /// Like `query()`, but reuses prepared statements across calls.
158    /// The statement name is derived from a hash of the SQL text.
159    /// OPTIMIZED: Pre-allocated buffer + ultra-fast encoders.
160    pub async fn query_cached(
161        &mut self,
162        sql: &str,
163        params: &[Option<Vec<u8>>],
164    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
165        self.query_cached_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
166            .await
167    }
168
169    /// Execute a query with cached prepared statement and explicit result-column format.
170    pub async fn query_cached_with_result_format(
171        &mut self,
172        sql: &str,
173        params: &[Option<Vec<u8>>],
174        result_format: i16,
175    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
176        let mut retried = false;
177        loop {
178            match self
179                .query_cached_with_result_format_once(sql, params, result_format)
180                .await
181            {
182                Ok(rows) => return Ok(rows),
183                Err(err)
184                    if !retried
185                        && (err.is_prepared_statement_retryable()
186                            || err.is_prepared_statement_already_exists()) =>
187                {
188                    retried = true;
189                    if err.is_prepared_statement_retryable() {
190                        self.clear_prepared_statement_state();
191                    }
192                }
193                Err(err) => return Err(err),
194            }
195        }
196    }
197
198    async fn query_cached_with_result_format_once(
199        &mut self,
200        sql: &str,
201        params: &[Option<Vec<u8>>],
202        result_format: i16,
203    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
204        let stmt_name = Self::sql_to_stmt_name(sql);
205        let is_new = !self.prepared_statements.contains_key(&stmt_name);
206
207        // Pre-calculate buffer size for single allocation
208        let params_size: usize = params
209            .iter()
210            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
211            .sum();
212
213        let estimated_size = if is_new {
214            50 + sql.len() + stmt_name.len() * 2 + params_size
215        } else {
216            30 + stmt_name.len() + params_size
217        };
218
219        let mut buf = BytesMut::with_capacity(estimated_size);
220
221        if is_new {
222            // Evict LRU prepared statement if at capacity. This prevents
223            // unbounded memory growth from dynamic batch filters while
224            // preserving hot statements (unlike the old nuclear `.clear()`).
225            self.evict_prepared_if_full();
226            buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
227            // Cache the SQL for debugging
228            self.prepared_statements
229                .insert(stmt_name.clone(), sql.to_string());
230        }
231
232        // Use ULTRA-OPTIMIZED encoders - write directly to buffer
233        if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
234            &mut buf,
235            &stmt_name,
236            params,
237            result_format,
238        ) {
239            if is_new {
240                self.prepared_statements.remove(&stmt_name);
241            }
242            return Err(PgError::Encode(e.to_string()));
243        }
244        PgEncoder::encode_execute_to(&mut buf);
245        PgEncoder::encode_sync_to(&mut buf);
246
247        if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
248            if is_new {
249                self.prepared_statements.remove(&stmt_name);
250            }
251            return Err(err);
252        }
253
254        let mut rows = Vec::new();
255
256        let mut error: Option<PgError> = None;
257        let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(is_new));
258
259        loop {
260            let msg = match self.recv().await {
261                Ok(msg) => msg,
262                Err(err) => {
263                    if is_new && !flow.saw_parse_complete() {
264                        self.prepared_statements.remove(&stmt_name);
265                    }
266                    return Err(err);
267                }
268            };
269            if let Err(err) = flow.validate(&msg, "extended-query cached execute", error.is_some())
270            {
271                if is_new && !flow.saw_parse_complete() {
272                    self.prepared_statements.remove(&stmt_name);
273                }
274                return Err(err);
275            }
276            match msg {
277                BackendMessage::ParseComplete => {
278                    // Already cached in is_new block above.
279                }
280                BackendMessage::BindComplete => {}
281                BackendMessage::RowDescription(_) => {}
282                BackendMessage::DataRow(data) => {
283                    if error.is_none() {
284                        rows.push(data);
285                    }
286                }
287                BackendMessage::CommandComplete(_) => {}
288                BackendMessage::NoData => {}
289                BackendMessage::ReadyForQuery(_) => {
290                    if let Some(err) = error {
291                        if is_new
292                            && !flow.saw_parse_complete()
293                            && !err.is_prepared_statement_already_exists()
294                        {
295                            self.prepared_statements.remove(&stmt_name);
296                        }
297                        return Err(err);
298                    }
299                    if is_new && !flow.saw_parse_complete() {
300                        self.prepared_statements.remove(&stmt_name);
301                        return Err(PgError::Protocol(
302                            "Cache miss query reached ReadyForQuery without ParseComplete"
303                                .to_string(),
304                        ));
305                    }
306                    return Ok(rows);
307                }
308                BackendMessage::ErrorResponse(err) => {
309                    if error.is_none() {
310                        let query_err = PgError::QueryServer(err.into());
311                        if !query_err.is_prepared_statement_already_exists() {
312                            // Invalidate cache to prevent stale local mapping after parse failure.
313                            self.prepared_statements.remove(&stmt_name);
314                        }
315                        error = Some(query_err);
316                    }
317                }
318                msg if is_ignorable_session_message(&msg) => {}
319                other => {
320                    if is_new && !flow.saw_parse_complete() {
321                        self.prepared_statements.remove(&stmt_name);
322                    }
323                    return Err(unexpected_backend_message(
324                        "extended-query cached execute",
325                        &other,
326                    ));
327                }
328            }
329        }
330    }
331
332    /// Generate a statement name from SQL hash.
333    /// Uses a simple hash to create a unique name like "stmt_12345abc".
334    pub(crate) fn sql_to_stmt_name(sql: &str) -> String {
335        use std::collections::hash_map::DefaultHasher;
336        use std::hash::{Hash, Hasher};
337
338        let mut hasher = DefaultHasher::new();
339        sql.hash(&mut hasher);
340        format!("s{:016x}", hasher.finish())
341    }
342
343    /// Execute a simple SQL statement (no parameters).
344    pub async fn execute_simple(&mut self, sql: &str) -> PgResult<()> {
345        let bytes = PgEncoder::try_encode_query_string(sql)?;
346        self.write_all_with_timeout(&bytes, "stream write").await?;
347
348        let mut error: Option<PgError> = None;
349        let mut flow = SimpleFlowTracker::new();
350
351        loop {
352            let msg = self.recv().await?;
353            match msg {
354                BackendMessage::RowDescription(_) => {
355                    // Some callers use execute_simple() with session-shaping SQL that
356                    // can legally return rows (e.g., SELECT set_config(...)).
357                    // Drain and ignore row data while preserving protocol ordering checks.
358                    flow.on_row_description("simple-query execute")?;
359                }
360                BackendMessage::DataRow(_) => {
361                    flow.on_data_row("simple-query execute")?;
362                }
363                BackendMessage::CommandComplete(_) => {
364                    flow.on_command_complete();
365                }
366                BackendMessage::EmptyQueryResponse => {
367                    flow.on_empty_query_response("simple-query execute")?;
368                }
369                BackendMessage::ReadyForQuery(_) => {
370                    if let Some(err) = error {
371                        return Err(err);
372                    }
373                    flow.on_ready_for_query("simple-query execute", error.is_some())?;
374                    return Ok(());
375                }
376                BackendMessage::ErrorResponse(err) => {
377                    if error.is_none() {
378                        error = Some(PgError::QueryServer(err.into()));
379                    }
380                }
381                msg if is_ignorable_session_message(&msg) => {}
382                other => {
383                    return Err(unexpected_backend_message("simple-query execute", &other));
384                }
385            }
386        }
387    }
388
389    /// Execute a simple SQL query and return rows (Simple Query Protocol).
390    ///
391    /// Unlike `execute_simple`, this collects and returns data rows.
392    /// Used for branch management and other administrative queries.
393    ///
394    /// SECURITY: Capped at 10,000 rows to prevent OOM from unbounded results.
395    pub async fn simple_query(&mut self, sql: &str) -> PgResult<Vec<super::PgRow>> {
396        use std::sync::Arc;
397
398        /// Safety cap to prevent OOM from unbounded result accumulation.
399        /// Simple Query Protocol has no streaming; all rows are buffered in memory.
400        const MAX_SIMPLE_QUERY_ROWS: usize = 10_000;
401
402        let bytes = PgEncoder::try_encode_query_string(sql)?;
403        self.write_all_with_timeout(&bytes, "stream write").await?;
404
405        let mut rows: Vec<super::PgRow> = Vec::new();
406        let mut column_info: Option<Arc<super::ColumnInfo>> = None;
407        let mut error: Option<PgError> = None;
408        let mut flow = SimpleFlowTracker::new();
409
410        loop {
411            let msg = self.recv().await?;
412            match msg {
413                BackendMessage::RowDescription(fields) => {
414                    flow.on_row_description("simple-query read")?;
415                    column_info = Some(Arc::new(super::ColumnInfo::from_fields(&fields)));
416                }
417                BackendMessage::DataRow(data) => {
418                    flow.on_data_row("simple-query read")?;
419                    if error.is_none() {
420                        if rows.len() >= MAX_SIMPLE_QUERY_ROWS {
421                            if error.is_none() {
422                                error = Some(PgError::Query(format!(
423                                    "simple_query exceeded {} row safety cap",
424                                    MAX_SIMPLE_QUERY_ROWS,
425                                )));
426                            }
427                            // Continue draining to reach ReadyForQuery
428                        } else {
429                            rows.push(super::PgRow {
430                                columns: data,
431                                column_info: column_info.clone(),
432                            });
433                        }
434                    }
435                }
436                BackendMessage::CommandComplete(_) => {
437                    flow.on_command_complete();
438                    column_info = None;
439                }
440                BackendMessage::EmptyQueryResponse => {
441                    flow.on_empty_query_response("simple-query read")?;
442                    column_info = None;
443                }
444                BackendMessage::ReadyForQuery(_) => {
445                    if let Some(err) = error {
446                        return Err(err);
447                    }
448                    flow.on_ready_for_query("simple-query read", error.is_some())?;
449                    return Ok(rows);
450                }
451                BackendMessage::ErrorResponse(err) => {
452                    if error.is_none() {
453                        error = Some(PgError::QueryServer(err.into()));
454                    }
455                }
456                msg if is_ignorable_session_message(&msg) => {}
457                other => {
458                    return Err(unexpected_backend_message("simple-query read", &other));
459                }
460            }
461        }
462    }
463
464    /// ZERO-HASH sequential query using pre-computed PreparedStatement.
465    /// This is the FASTEST sequential path because it skips:
466    /// - SQL generation from AST (done once outside loop)
467    /// - Hash computation for statement name (pre-computed in PreparedStatement)
468    /// - HashMap lookup for is_new check (statement already prepared)
469    /// # Example
470    /// ```ignore
471    /// let stmt = conn.prepare("SELECT * FROM users WHERE id = $1").await?;
472    /// for id in 1..10000 {
473    ///     let rows = conn.query_prepared_single(&stmt, &[Some(id.to_string().into_bytes())]).await?;
474    /// }
475    /// ```
476    #[inline]
477    pub async fn query_prepared_single(
478        &mut self,
479        stmt: &super::PreparedStatement,
480        params: &[Option<Vec<u8>>],
481    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
482        self.query_prepared_single_with_result_format(stmt, params, PgEncoder::FORMAT_TEXT)
483            .await
484    }
485
486    /// ZERO-HASH sequential query with explicit result-column format.
487    #[inline]
488    pub async fn query_prepared_single_with_result_format(
489        &mut self,
490        stmt: &super::PreparedStatement,
491        params: &[Option<Vec<u8>>],
492        result_format: i16,
493    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
494        // Pre-calculate buffer size for single allocation
495        let params_size: usize = params
496            .iter()
497            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
498            .sum();
499
500        // Bind: ~15 + stmt.name.len() + params_size, Execute: 10, Sync: 5
501        let mut buf = BytesMut::with_capacity(30 + stmt.name.len() + params_size);
502
503        // ZERO HASH, ZERO LOOKUP - just encode and send!
504        PgEncoder::encode_bind_to_with_result_format(&mut buf, &stmt.name, params, result_format)
505            .map_err(|e| PgError::Encode(e.to_string()))?;
506        PgEncoder::encode_execute_to(&mut buf);
507        PgEncoder::encode_sync_to(&mut buf);
508
509        self.write_all_with_timeout(&buf, "stream write").await?;
510
511        let mut rows = Vec::new();
512
513        let mut error: Option<PgError> = None;
514        let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(false));
515
516        loop {
517            let msg = self.recv().await?;
518            flow.validate(&msg, "prepared single execute", error.is_some())?;
519            match msg {
520                BackendMessage::BindComplete => {}
521                BackendMessage::RowDescription(_) => {}
522                BackendMessage::DataRow(data) => {
523                    if error.is_none() {
524                        rows.push(data);
525                    }
526                }
527                BackendMessage::CommandComplete(_) => {}
528                BackendMessage::NoData => {}
529                BackendMessage::ReadyForQuery(_) => {
530                    if let Some(err) = error {
531                        return Err(err);
532                    }
533                    return Ok(rows);
534                }
535                BackendMessage::ErrorResponse(err) => {
536                    if error.is_none() {
537                        error = Some(PgError::QueryServer(err.into()));
538                    }
539                }
540                msg if is_ignorable_session_message(&msg) => {}
541                other => {
542                    return Err(unexpected_backend_message(
543                        "prepared single execute",
544                        &other,
545                    ));
546                }
547            }
548        }
549    }
550}