Skip to main content

qail_pg/driver/
query.rs

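//! Query execution methods for PostgreSQL connection.
//!
//! This module provides the extended-protocol query paths (`query`,
//! `query_cached`, `query_prepared_single`, and their `*_with_result_format`
//! variants) and the simple-protocol paths (`execute_simple`, `simple_query`).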
4
5use super::{
6    PgConnection, PgError, PgResult,
7    extended_flow::{ExtendedFlowConfig, ExtendedFlowTracker},
8    is_ignorable_session_message, unexpected_backend_message,
9};
10use crate::protocol::{BackendMessage, PgEncoder};
11use bytes::BytesMut;
12
/// Position of a Simple Query response relative to the statement being read.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SimpleStatementState {
    /// No row set is open; the next message may start a result set or
    /// complete a statement.
    AwaitingResult,
    /// A `RowDescription` was received and `DataRow`s may follow until the
    /// statement's `CommandComplete`.
    InRowStream,
}
18
19#[derive(Debug, Clone, Copy)]
20struct SimpleFlowTracker {
21    state: SimpleStatementState,
22    saw_completion: bool,
23}
24
25impl SimpleFlowTracker {
26    fn new() -> Self {
27        Self {
28            state: SimpleStatementState::AwaitingResult,
29            saw_completion: false,
30        }
31    }
32
33    fn on_row_description(&mut self, context: &'static str) -> PgResult<()> {
34        if self.state == SimpleStatementState::InRowStream {
35            return Err(PgError::Protocol(format!(
36                "{}: duplicate RowDescription before statement completion",
37                context
38            )));
39        }
40        self.state = SimpleStatementState::InRowStream;
41        self.saw_completion = false;
42        Ok(())
43    }
44
45    fn on_data_row(&self, context: &'static str) -> PgResult<()> {
46        if self.state != SimpleStatementState::InRowStream {
47            return Err(PgError::Protocol(format!(
48                "{}: DataRow before RowDescription",
49                context
50            )));
51        }
52        Ok(())
53    }
54
55    fn on_command_complete(&mut self) {
56        self.state = SimpleStatementState::AwaitingResult;
57        self.saw_completion = true;
58    }
59
60    fn on_empty_query_response(&mut self, context: &'static str) -> PgResult<()> {
61        if self.state == SimpleStatementState::InRowStream {
62            return Err(PgError::Protocol(format!(
63                "{}: EmptyQueryResponse during active row stream",
64                context
65            )));
66        }
67        self.saw_completion = true;
68        Ok(())
69    }
70
71    fn on_ready_for_query(&self, context: &'static str, error_pending: bool) -> PgResult<()> {
72        if error_pending {
73            return Ok(());
74        }
75        if self.state == SimpleStatementState::InRowStream {
76            return Err(PgError::Protocol(format!(
77                "{}: ReadyForQuery before CommandComplete",
78                context
79            )));
80        }
81        if !self.saw_completion {
82            return Err(PgError::Protocol(format!(
83                "{}: ReadyForQuery before completion",
84                context
85            )));
86        }
87        Ok(())
88    }
89}
90
91impl PgConnection {
92    /// Execute a query with binary parameters (crate-internal).
93    /// This uses the Extended Query Protocol (Parse/Bind/Execute/Sync):
94    /// - Parameters are sent as binary bytes, skipping the string layer
95    /// - No SQL injection possible - parameters are never interpolated
96    /// - Better performance via prepared statement reuse
97    pub(crate) async fn query(
98        &mut self,
99        sql: &str,
100        params: &[Option<Vec<u8>>],
101    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
102        self.query_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
103            .await
104    }
105
106    /// Execute a query with binary parameters and explicit result-column format.
107    pub(crate) async fn query_with_result_format(
108        &mut self,
109        sql: &str,
110        params: &[Option<Vec<u8>>],
111        result_format: i16,
112    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
113        let bytes = PgEncoder::encode_extended_query_with_result_format(sql, params, result_format)
114            .map_err(|e| PgError::Encode(e.to_string()))?;
115        self.write_all_with_timeout(&bytes, "stream write").await?;
116
117        let mut rows = Vec::new();
118
119        let mut error: Option<PgError> = None;
120        let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(true));
121
122        loop {
123            let msg = self.recv().await?;
124            flow.validate(&msg, "extended-query execute", error.is_some())?;
125            match msg {
126                BackendMessage::ParseComplete => {}
127                BackendMessage::BindComplete => {}
128                BackendMessage::RowDescription(_) => {}
129                BackendMessage::DataRow(data) => {
130                    // Only collect rows if no error occurred
131                    if error.is_none() {
132                        rows.push(data);
133                    }
134                }
135                BackendMessage::CommandComplete(_) => {}
136                BackendMessage::NoData => {}
137                BackendMessage::ReadyForQuery(_) => {
138                    if let Some(err) = error {
139                        return Err(err);
140                    }
141                    return Ok(rows);
142                }
143                BackendMessage::ErrorResponse(err) => {
144                    if error.is_none() {
145                        error = Some(PgError::QueryServer(err.into()));
146                    }
147                }
148                msg if is_ignorable_session_message(&msg) => {}
149                other => {
150                    return Err(unexpected_backend_message("extended-query execute", &other));
151                }
152            }
153        }
154    }
155
156    /// Execute a query with cached prepared statement.
157    /// Like `query()`, but reuses prepared statements across calls.
158    /// The statement name is derived from a hash of the SQL text.
159    /// OPTIMIZED: Pre-allocated buffer + ultra-fast encoders.
160    pub async fn query_cached(
161        &mut self,
162        sql: &str,
163        params: &[Option<Vec<u8>>],
164    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
165        self.query_cached_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
166            .await
167    }
168
169    /// Execute a query with cached prepared statement and explicit result-column format.
170    pub async fn query_cached_with_result_format(
171        &mut self,
172        sql: &str,
173        params: &[Option<Vec<u8>>],
174        result_format: i16,
175    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
176        let mut retried = false;
177        loop {
178            match self
179                .query_cached_with_result_format_once(sql, params, result_format)
180                .await
181            {
182                Ok(rows) => return Ok(rows),
183                Err(err)
184                    if !retried
185                        && (err.is_prepared_statement_retryable()
186                            || err.is_prepared_statement_already_exists()) =>
187                {
188                    retried = true;
189                    if err.is_prepared_statement_retryable() {
190                        self.clear_prepared_statement_state();
191                    }
192                }
193                Err(err) => return Err(err),
194            }
195        }
196    }
197
198    async fn query_cached_with_result_format_once(
199        &mut self,
200        sql: &str,
201        params: &[Option<Vec<u8>>],
202        result_format: i16,
203    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
204        let stmt_name = Self::sql_to_stmt_name(sql);
205        let is_new = !self.prepared_statements.contains_key(&stmt_name);
206
207        // Pre-calculate buffer size for single allocation
208        let params_size: usize = params
209            .iter()
210            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
211            .sum();
212
213        let estimated_size = if is_new {
214            50 + sql.len() + stmt_name.len() * 2 + params_size
215        } else {
216            30 + stmt_name.len() + params_size
217        };
218
219        let mut buf = BytesMut::with_capacity(estimated_size);
220
221        if is_new {
222            // Evict LRU prepared statement if at capacity. This prevents
223            // unbounded memory growth from dynamic batch filters while
224            // preserving hot statements (unlike the old nuclear `.clear()`).
225            self.evict_prepared_if_full();
226            buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
227            // Cache the SQL for debugging
228            self.prepared_statements
229                .insert(stmt_name.clone(), sql.to_string());
230        }
231
232        // Use ULTRA-OPTIMIZED encoders - write directly to buffer
233        if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
234            &mut buf,
235            &stmt_name,
236            params,
237            result_format,
238        ) {
239            if is_new {
240                self.prepared_statements.remove(&stmt_name);
241            }
242            return Err(PgError::Encode(e.to_string()));
243        }
244        PgEncoder::encode_execute_to(&mut buf);
245        PgEncoder::encode_sync_to(&mut buf);
246
247        if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
248            if is_new {
249                self.prepared_statements.remove(&stmt_name);
250            }
251            return Err(err);
252        }
253
254        let mut rows = Vec::new();
255
256        let mut error: Option<PgError> = None;
257        let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(is_new));
258
259        loop {
260            let msg = match self.recv().await {
261                Ok(msg) => msg,
262                Err(err) => {
263                    if is_new && !flow.saw_parse_complete() {
264                        self.prepared_statements.remove(&stmt_name);
265                    }
266                    return Err(err);
267                }
268            };
269            if let Err(err) = flow.validate(&msg, "extended-query cached execute", error.is_some())
270            {
271                if is_new && !flow.saw_parse_complete() {
272                    self.prepared_statements.remove(&stmt_name);
273                }
274                return Err(err);
275            }
276            match msg {
277                BackendMessage::ParseComplete => {
278                    // Already cached in is_new block above.
279                }
280                BackendMessage::BindComplete => {}
281                BackendMessage::RowDescription(_) => {}
282                BackendMessage::DataRow(data) => {
283                    if error.is_none() {
284                        rows.push(data);
285                    }
286                }
287                BackendMessage::CommandComplete(_) => {}
288                BackendMessage::NoData => {}
289                BackendMessage::ReadyForQuery(_) => {
290                    if let Some(err) = error {
291                        if is_new
292                            && !flow.saw_parse_complete()
293                            && !err.is_prepared_statement_already_exists()
294                        {
295                            self.prepared_statements.remove(&stmt_name);
296                        }
297                        return Err(err);
298                    }
299                    if is_new && !flow.saw_parse_complete() {
300                        self.prepared_statements.remove(&stmt_name);
301                        return Err(PgError::Protocol(
302                            "Cache miss query reached ReadyForQuery without ParseComplete"
303                                .to_string(),
304                        ));
305                    }
306                    return Ok(rows);
307                }
308                BackendMessage::ErrorResponse(err) => {
309                    if error.is_none() {
310                        let query_err = PgError::QueryServer(err.into());
311                        if !query_err.is_prepared_statement_already_exists() {
312                            // Invalidate cache to prevent stale local mapping after parse failure.
313                            self.prepared_statements.remove(&stmt_name);
314                        }
315                        error = Some(query_err);
316                    }
317                }
318                msg if is_ignorable_session_message(&msg) => {}
319                other => {
320                    if is_new && !flow.saw_parse_complete() {
321                        self.prepared_statements.remove(&stmt_name);
322                    }
323                    return Err(unexpected_backend_message(
324                        "extended-query cached execute",
325                        &other,
326                    ));
327                }
328            }
329        }
330    }
331
332    /// Generate a statement name from SQL hash.
333    /// Uses a simple hash to create a unique name like "stmt_12345abc".
334    pub(crate) fn sql_to_stmt_name(sql: &str) -> String {
335        use std::collections::hash_map::DefaultHasher;
336        use std::hash::{Hash, Hasher};
337
338        let mut hasher = DefaultHasher::new();
339        sql.hash(&mut hasher);
340        format!("s{:016x}", hasher.finish())
341    }
342
343    /// Execute a simple SQL statement (no parameters).
344    pub async fn execute_simple(&mut self, sql: &str) -> PgResult<()> {
345        let bytes = PgEncoder::try_encode_query_string(sql)?;
346        self.write_all_with_timeout(&bytes, "stream write").await?;
347
348        let mut error: Option<PgError> = None;
349        let mut flow = SimpleFlowTracker::new();
350
351        loop {
352            let msg = self.recv().await?;
353            match msg {
354                BackendMessage::CommandComplete(_) => {
355                    flow.on_command_complete();
356                }
357                BackendMessage::EmptyQueryResponse => {
358                    flow.on_empty_query_response("simple-query execute")?;
359                }
360                BackendMessage::ReadyForQuery(_) => {
361                    if let Some(err) = error {
362                        return Err(err);
363                    }
364                    flow.on_ready_for_query("simple-query execute", error.is_some())?;
365                    return Ok(());
366                }
367                BackendMessage::ErrorResponse(err) => {
368                    if error.is_none() {
369                        error = Some(PgError::QueryServer(err.into()));
370                    }
371                }
372                msg if is_ignorable_session_message(&msg) => {}
373                other => {
374                    return Err(unexpected_backend_message("simple-query execute", &other));
375                }
376            }
377        }
378    }
379
380    /// Execute a simple SQL query and return rows (Simple Query Protocol).
381    ///
382    /// Unlike `execute_simple`, this collects and returns data rows.
383    /// Used for branch management and other administrative queries.
384    ///
385    /// SECURITY: Capped at 10,000 rows to prevent OOM from unbounded results.
386    pub async fn simple_query(&mut self, sql: &str) -> PgResult<Vec<super::PgRow>> {
387        use std::sync::Arc;
388
389        /// Safety cap to prevent OOM from unbounded result accumulation.
390        /// Simple Query Protocol has no streaming; all rows are buffered in memory.
391        const MAX_SIMPLE_QUERY_ROWS: usize = 10_000;
392
393        let bytes = PgEncoder::try_encode_query_string(sql)?;
394        self.write_all_with_timeout(&bytes, "stream write").await?;
395
396        let mut rows: Vec<super::PgRow> = Vec::new();
397        let mut column_info: Option<Arc<super::ColumnInfo>> = None;
398        let mut error: Option<PgError> = None;
399        let mut flow = SimpleFlowTracker::new();
400
401        loop {
402            let msg = self.recv().await?;
403            match msg {
404                BackendMessage::RowDescription(fields) => {
405                    flow.on_row_description("simple-query read")?;
406                    column_info = Some(Arc::new(super::ColumnInfo::from_fields(&fields)));
407                }
408                BackendMessage::DataRow(data) => {
409                    flow.on_data_row("simple-query read")?;
410                    if error.is_none() {
411                        if rows.len() >= MAX_SIMPLE_QUERY_ROWS {
412                            if error.is_none() {
413                                error = Some(PgError::Query(format!(
414                                    "simple_query exceeded {} row safety cap",
415                                    MAX_SIMPLE_QUERY_ROWS,
416                                )));
417                            }
418                            // Continue draining to reach ReadyForQuery
419                        } else {
420                            rows.push(super::PgRow {
421                                columns: data,
422                                column_info: column_info.clone(),
423                            });
424                        }
425                    }
426                }
427                BackendMessage::CommandComplete(_) => {
428                    flow.on_command_complete();
429                    column_info = None;
430                }
431                BackendMessage::EmptyQueryResponse => {
432                    flow.on_empty_query_response("simple-query read")?;
433                    column_info = None;
434                }
435                BackendMessage::ReadyForQuery(_) => {
436                    if let Some(err) = error {
437                        return Err(err);
438                    }
439                    flow.on_ready_for_query("simple-query read", error.is_some())?;
440                    return Ok(rows);
441                }
442                BackendMessage::ErrorResponse(err) => {
443                    if error.is_none() {
444                        error = Some(PgError::QueryServer(err.into()));
445                    }
446                }
447                msg if is_ignorable_session_message(&msg) => {}
448                other => {
449                    return Err(unexpected_backend_message("simple-query read", &other));
450                }
451            }
452        }
453    }
454
455    /// ZERO-HASH sequential query using pre-computed PreparedStatement.
456    /// This is the FASTEST sequential path because it skips:
457    /// - SQL generation from AST (done once outside loop)
458    /// - Hash computation for statement name (pre-computed in PreparedStatement)
459    /// - HashMap lookup for is_new check (statement already prepared)
460    /// # Example
461    /// ```ignore
462    /// let stmt = conn.prepare("SELECT * FROM users WHERE id = $1").await?;
463    /// for id in 1..10000 {
464    ///     let rows = conn.query_prepared_single(&stmt, &[Some(id.to_string().into_bytes())]).await?;
465    /// }
466    /// ```
467    #[inline]
468    pub async fn query_prepared_single(
469        &mut self,
470        stmt: &super::PreparedStatement,
471        params: &[Option<Vec<u8>>],
472    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
473        self.query_prepared_single_with_result_format(stmt, params, PgEncoder::FORMAT_TEXT)
474            .await
475    }
476
477    /// ZERO-HASH sequential query with explicit result-column format.
478    #[inline]
479    pub async fn query_prepared_single_with_result_format(
480        &mut self,
481        stmt: &super::PreparedStatement,
482        params: &[Option<Vec<u8>>],
483        result_format: i16,
484    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
485        // Pre-calculate buffer size for single allocation
486        let params_size: usize = params
487            .iter()
488            .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
489            .sum();
490
491        // Bind: ~15 + stmt.name.len() + params_size, Execute: 10, Sync: 5
492        let mut buf = BytesMut::with_capacity(30 + stmt.name.len() + params_size);
493
494        // ZERO HASH, ZERO LOOKUP - just encode and send!
495        PgEncoder::encode_bind_to_with_result_format(&mut buf, &stmt.name, params, result_format)
496            .map_err(|e| PgError::Encode(e.to_string()))?;
497        PgEncoder::encode_execute_to(&mut buf);
498        PgEncoder::encode_sync_to(&mut buf);
499
500        self.write_all_with_timeout(&buf, "stream write").await?;
501
502        let mut rows = Vec::new();
503
504        let mut error: Option<PgError> = None;
505        let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(false));
506
507        loop {
508            let msg = self.recv().await?;
509            flow.validate(&msg, "prepared single execute", error.is_some())?;
510            match msg {
511                BackendMessage::BindComplete => {}
512                BackendMessage::RowDescription(_) => {}
513                BackendMessage::DataRow(data) => {
514                    if error.is_none() {
515                        rows.push(data);
516                    }
517                }
518                BackendMessage::CommandComplete(_) => {}
519                BackendMessage::NoData => {}
520                BackendMessage::ReadyForQuery(_) => {
521                    if let Some(err) = error {
522                        return Err(err);
523                    }
524                    return Ok(rows);
525                }
526                BackendMessage::ErrorResponse(err) => {
527                    if error.is_none() {
528                        error = Some(PgError::QueryServer(err.into()));
529                    }
530                }
531                msg if is_ignorable_session_message(&msg) => {}
532                other => {
533                    return Err(unexpected_backend_message(
534                        "prepared single execute",
535                        &other,
536                    ));
537                }
538            }
539        }
540    }
541}