Skip to main content

qail_pg/driver/pool/
fetch.rs

1//! Fetch methods for PooledConnection: uncached, fast, cached, typed, and pipelined-RLS variants.
2
3use super::connection::PooledConnection;
4use super::lifecycle::MAX_HOT_STATEMENTS;
5use crate::driver::{
6    PgConnection, PgError, PgResult, ResultFormat,
7    extended_flow::{ExtendedFlowConfig, ExtendedFlowTracker},
8    is_ignorable_session_message, unexpected_backend_message,
9};
10use std::sync::Arc;
11
12#[inline]
13fn rollback_cache_miss_statement_registration(
14    conn: &mut PgConnection,
15    is_cache_miss: bool,
16    sql_hash: u64,
17    stmt_name: &str,
18) {
19    if is_cache_miss {
20        conn.stmt_cache.remove(&sql_hash);
21        conn.prepared_statements.remove(stmt_name);
22        conn.column_info_cache.remove(&sql_hash);
23    }
24}
25
26#[inline]
27fn register_hot_statement_after_parse_success(
28    pool: &super::lifecycle::PgPoolInner,
29    sql_hash: u64,
30    stmt_name: &str,
31    sql: &str,
32) {
33    if let Ok(mut hot) = pool.hot_statements.write()
34        && (hot.contains_key(&sql_hash) || hot.len() < MAX_HOT_STATEMENTS)
35    {
36        hot.insert(sql_hash, (stmt_name.to_string(), sql.to_string()));
37    }
38}
39
40#[inline]
41fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
42    if matches!(
43        err,
44        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
45    ) {
46        conn.mark_io_desynced();
47    }
48    Err(err)
49}
50
51#[inline]
52fn encoded_sql_str(sql_buf: &[u8]) -> PgResult<&str> {
53    std::str::from_utf8(sql_buf)
54        .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))
55}
56
57async fn drain_extended_responses_after_rls_setup_error(conn: &mut PgConnection) -> PgResult<()> {
58    loop {
59        let msg = conn.recv().await?;
60        match msg {
61            crate::protocol::BackendMessage::ReadyForQuery(_) => return Ok(()),
62            crate::protocol::BackendMessage::ErrorResponse(_) => {}
63            msg if is_ignorable_session_message(&msg) => {}
64            // Best-effort drain: consume everything until Sync's ReadyForQuery.
65            _ => {}
66        }
67    }
68}
69
70fn copy_export_table_sql(table: &str, columns: &[String]) -> PgResult<String> {
71    let cols: Vec<String> = columns
72        .iter()
73        .map(|c| crate::driver::copy::quote_copy_column_ident(c))
74        .collect::<PgResult<_>>()?;
75
76    Ok(format!(
77        "COPY {} ({}) TO STDOUT",
78        crate::driver::copy::quote_copy_table_ref(table)?,
79        cols.join(", ")
80    ))
81}
82
83impl PooledConnection {
84    /// Execute a QAIL command and fetch all rows (UNCACHED).
85    /// Returns rows with column metadata for JSON serialization.
86    pub async fn fetch_all_uncached(
87        &mut self,
88        cmd: &qail_core::ast::Qail,
89    ) -> PgResult<Vec<crate::driver::PgRow>> {
90        self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
91            .await
92    }
93
94    /// Execute raw SQL with bind parameters and return raw row data.
95    ///
96    /// Uses the Extended Query Protocol so parameters are never interpolated
97    /// into the SQL string. Intended for EXPLAIN or other SQL that can't be
98    /// represented as a `Qail` AST but still needs parameterized execution.
99    ///
100    /// Returns raw column bytes; callers must decode as needed.
101    pub async fn query_raw_with_params(
102        &mut self,
103        sql: &str,
104        params: &[Option<Vec<u8>>],
105    ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
106        let conn = self.conn_mut()?;
107        conn.query(sql, params).await
108    }
109
110    /// Execute raw SQL with bind parameters and return rows with column metadata.
111    ///
112    /// Uses the Extended Query Protocol so parameters are never interpolated
113    /// into the SQL string. Intended for compatibility paths that need
114    /// `PgRow` decoding and stable column names, not just raw bytes.
115    pub async fn query_rows_with_params(
116        &mut self,
117        sql: &str,
118        params: &[Option<Vec<u8>>],
119    ) -> PgResult<Vec<crate::driver::PgRow>> {
120        self.query_rows_with_params_with_format(sql, params, ResultFormat::Text)
121            .await
122    }
123
124    /// Execute raw SQL with bind parameters and explicit result format,
125    /// returning rows with column metadata.
126    pub async fn query_rows_with_params_with_format(
127        &mut self,
128        sql: &str,
129        params: &[Option<Vec<u8>>],
130        result_format: ResultFormat,
131    ) -> PgResult<Vec<crate::driver::PgRow>> {
132        let conn = self.conn_mut()?;
133        conn.query_rows_with_result_format(sql, params, result_format.as_wire_code())
134            .await
135    }
136
137    /// Execute raw SQL with explicit PostgreSQL parameter type OIDs and return
138    /// rows with column metadata.
139    pub async fn query_rows_with_param_types_with_format(
140        &mut self,
141        sql: &str,
142        param_types: &[u32],
143        params: &[Option<Vec<u8>>],
144        result_format: ResultFormat,
145    ) -> PgResult<Vec<crate::driver::PgRow>> {
146        let conn = self.conn_mut()?;
147        conn.query_rows_with_param_types_and_result_format(
148            sql,
149            param_types,
150            params,
151            result_format.as_wire_code(),
152        )
153        .await
154    }
155
156    /// Validate raw SQL with explicit PostgreSQL parameter type OIDs without
157    /// executing it.
158    pub async fn probe_query_with_param_types(
159        &mut self,
160        sql: &str,
161        param_types: &[u32],
162        params: &[Option<Vec<u8>>],
163    ) -> PgResult<()> {
164        let conn = self.conn_mut()?;
165        conn.probe_query_with_param_types(sql, param_types, params)
166            .await
167    }
168
169    /// Export data using AST-native COPY TO STDOUT and collect parsed rows.
170    pub async fn copy_export(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<Vec<String>>> {
171        self.conn_mut()?.copy_export(cmd).await
172    }
173
174    /// Stream AST-native COPY TO STDOUT chunks with bounded memory usage.
175    pub async fn copy_export_stream_raw<F, Fut>(
176        &mut self,
177        cmd: &qail_core::ast::Qail,
178        on_chunk: F,
179    ) -> PgResult<()>
180    where
181        F: FnMut(Vec<u8>) -> Fut,
182        Fut: std::future::Future<Output = PgResult<()>>,
183    {
184        self.conn_mut()?.copy_export_stream_raw(cmd, on_chunk).await
185    }
186
187    /// Stream AST-native COPY TO STDOUT rows with bounded memory usage.
188    pub async fn copy_export_stream_rows<F>(
189        &mut self,
190        cmd: &qail_core::ast::Qail,
191        on_row: F,
192    ) -> PgResult<()>
193    where
194        F: FnMut(Vec<String>) -> PgResult<()>,
195    {
196        self.conn_mut()?.copy_export_stream_rows(cmd, on_row).await
197    }
198
199    /// Export a table using COPY TO STDOUT and collect raw bytes.
200    pub async fn copy_export_table(
201        &mut self,
202        table: &str,
203        columns: &[String],
204    ) -> PgResult<Vec<u8>> {
205        let sql = copy_export_table_sql(table, columns)?;
206        self.conn_mut()?.copy_out_raw(&sql).await
207    }
208
209    /// Stream a table export using COPY TO STDOUT with bounded memory usage.
210    pub async fn copy_export_table_stream<F, Fut>(
211        &mut self,
212        table: &str,
213        columns: &[String],
214        on_chunk: F,
215    ) -> PgResult<()>
216    where
217        F: FnMut(Vec<u8>) -> Fut,
218        Fut: std::future::Future<Output = PgResult<()>>,
219    {
220        let sql = copy_export_table_sql(table, columns)?;
221        self.conn_mut()?.copy_out_raw_stream(&sql, on_chunk).await
222    }
223
224    /// Execute a QAIL command and fetch all rows (UNCACHED) with explicit result format.
225    pub async fn fetch_all_uncached_with_format(
226        &mut self,
227        cmd: &qail_core::ast::Qail,
228        result_format: ResultFormat,
229    ) -> PgResult<Vec<crate::driver::PgRow>> {
230        use crate::driver::ColumnInfo;
231        use crate::protocol::AstEncoder;
232
233        let conn = self.conn_mut()?;
234
235        AstEncoder::encode_cmd_reuse_into_with_result_format(
236            cmd,
237            &mut conn.sql_buf,
238            &mut conn.params_buf,
239            &mut conn.write_buf,
240            result_format.as_wire_code(),
241        )
242        .map_err(|e| PgError::Encode(e.to_string()))?;
243
244        conn.flush_write_buf().await?;
245
246        let mut rows: Vec<crate::driver::PgRow> = Vec::new();
247        let mut column_info: Option<Arc<ColumnInfo>> = None;
248        let mut error: Option<PgError> = None;
249        let mut flow =
250            ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_describe_portal_execute());
251
252        loop {
253            let msg = conn.recv().await?;
254            if let Err(err) = flow.validate(&msg, "pool fetch_all execute", error.is_some()) {
255                return return_with_desync(conn, err);
256            }
257            match msg {
258                crate::protocol::BackendMessage::ParseComplete
259                | crate::protocol::BackendMessage::BindComplete => {}
260                crate::protocol::BackendMessage::RowDescription(fields) => {
261                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
262                }
263                crate::protocol::BackendMessage::DataRow(data) => {
264                    if error.is_none() {
265                        rows.push(crate::driver::PgRow {
266                            columns: data,
267                            column_info: column_info.clone(),
268                        });
269                    }
270                }
271                crate::protocol::BackendMessage::NoData => {}
272                crate::protocol::BackendMessage::CommandComplete(_) => {}
273                crate::protocol::BackendMessage::ReadyForQuery(_) => {
274                    if let Some(err) = error {
275                        return Err(err);
276                    }
277                    return Ok(rows);
278                }
279                crate::protocol::BackendMessage::ErrorResponse(err) => {
280                    if error.is_none() {
281                        error = Some(PgError::QueryServer(err.into()));
282                    }
283                }
284                msg if is_ignorable_session_message(&msg) => {}
285                other => {
286                    return return_with_desync(
287                        conn,
288                        unexpected_backend_message("pool fetch_all execute", &other),
289                    );
290                }
291            }
292        }
293    }
294
295    /// Execute a QAIL command and fetch all rows (FAST VERSION).
296    /// Uses native AST-to-wire encoding and optimized recv_with_data_fast.
297    /// Skips column metadata for maximum speed.
298    pub async fn fetch_all_fast(
299        &mut self,
300        cmd: &qail_core::ast::Qail,
301    ) -> PgResult<Vec<crate::driver::PgRow>> {
302        self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
303            .await
304    }
305
306    /// Execute a QAIL command and fetch all rows (FAST VERSION) with explicit result format.
307    pub async fn fetch_all_fast_with_format(
308        &mut self,
309        cmd: &qail_core::ast::Qail,
310        result_format: ResultFormat,
311    ) -> PgResult<Vec<crate::driver::PgRow>> {
312        use crate::protocol::AstEncoder;
313
314        let conn = self.conn_mut()?;
315
316        AstEncoder::encode_cmd_reuse_into_with_result_format(
317            cmd,
318            &mut conn.sql_buf,
319            &mut conn.params_buf,
320            &mut conn.write_buf,
321            result_format.as_wire_code(),
322        )
323        .map_err(|e| PgError::Encode(e.to_string()))?;
324
325        conn.flush_write_buf().await?;
326
327        let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
328        let mut error: Option<PgError> = None;
329        let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(true));
330
331        loop {
332            let res = conn.recv_with_data_fast().await;
333            match res {
334                Ok((msg_type, data)) => {
335                    if let Err(err) = flow.validate_msg_type(
336                        msg_type,
337                        "pool fetch_all_fast execute",
338                        error.is_some(),
339                    ) {
340                        return return_with_desync(conn, err);
341                    }
342                    match msg_type {
343                        b'D' => {
344                            if error.is_none()
345                                && let Some(columns) = data
346                            {
347                                rows.push(crate::driver::PgRow {
348                                    columns,
349                                    column_info: None,
350                                });
351                            }
352                        }
353                        b'Z' => {
354                            if let Some(err) = error {
355                                return Err(err);
356                            }
357                            return Ok(rows);
358                        }
359                        _ => {}
360                    }
361                }
362                Err(e) => {
363                    if matches!(&e, PgError::QueryServer(_)) {
364                        if error.is_none() {
365                            error = Some(e);
366                        }
367                        continue;
368                    }
369                    return Err(e);
370                }
371            }
372        }
373    }
374
375    /// Execute a QAIL command and fetch all rows (CACHED).
376    /// Uses prepared statement caching: Parse+Describe on first call,
377    /// then Bind+Execute only on subsequent calls with the same SQL shape.
378    /// This matches PostgREST's behavior for fair benchmarks.
379    pub async fn fetch_all_cached(
380        &mut self,
381        cmd: &qail_core::ast::Qail,
382    ) -> PgResult<Vec<crate::driver::PgRow>> {
383        self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
384            .await
385    }
386
387    /// Execute a QAIL command and fetch all rows (CACHED) with explicit result format.
388    pub async fn fetch_all_cached_with_format(
389        &mut self,
390        cmd: &qail_core::ast::Qail,
391        result_format: ResultFormat,
392    ) -> PgResult<Vec<crate::driver::PgRow>> {
393        let mut retried = false;
394        loop {
395            match self
396                .fetch_all_cached_with_format_once(cmd, result_format)
397                .await
398            {
399                Ok(rows) => return Ok(rows),
400                Err(err)
401                    if !retried
402                        && (err.is_prepared_statement_retryable()
403                            || err.is_prepared_statement_already_exists()) =>
404                {
405                    retried = true;
406                    if err.is_prepared_statement_retryable()
407                        && let Some(conn) = self.conn.as_mut()
408                    {
409                        conn.clear_prepared_statement_state();
410                    }
411                }
412                Err(err) => return Err(err),
413            }
414        }
415    }
416
417    /// Execute a QAIL command and decode rows into typed structs (CACHED, text format).
418    pub async fn fetch_typed<T: crate::driver::row::QailRow>(
419        &mut self,
420        cmd: &qail_core::ast::Qail,
421    ) -> PgResult<Vec<T>> {
422        self.fetch_typed_with_format(cmd, ResultFormat::Text).await
423    }
424
425    /// Execute a QAIL command and decode rows into typed structs with explicit result format.
426    ///
427    /// Use [`ResultFormat::Binary`] for binary wire values; row decoders should use
428    /// metadata-aware helpers like `PgRow::try_get()` / `try_get_by_name()`.
429    pub async fn fetch_typed_with_format<T: crate::driver::row::QailRow>(
430        &mut self,
431        cmd: &qail_core::ast::Qail,
432        result_format: ResultFormat,
433    ) -> PgResult<Vec<T>> {
434        let rows = self
435            .fetch_all_cached_with_format(cmd, result_format)
436            .await?;
437        Ok(rows.iter().map(T::from_row).collect())
438    }
439
440    /// Execute a QAIL command and decode one typed row (CACHED, text format).
441    pub async fn fetch_one_typed<T: crate::driver::row::QailRow>(
442        &mut self,
443        cmd: &qail_core::ast::Qail,
444    ) -> PgResult<Option<T>> {
445        self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
446            .await
447    }
448
449    /// Execute a QAIL command and decode one typed row with explicit result format.
450    pub async fn fetch_one_typed_with_format<T: crate::driver::row::QailRow>(
451        &mut self,
452        cmd: &qail_core::ast::Qail,
453        result_format: ResultFormat,
454    ) -> PgResult<Option<T>> {
455        let rows = self
456            .fetch_all_cached_with_format(cmd, result_format)
457            .await?;
458        Ok(rows.first().map(T::from_row))
459    }
460
461    async fn fetch_all_cached_with_format_once(
462        &mut self,
463        cmd: &qail_core::ast::Qail,
464        result_format: ResultFormat,
465    ) -> PgResult<Vec<crate::driver::PgRow>> {
466        use crate::driver::ColumnInfo;
467        use std::collections::hash_map::DefaultHasher;
468        use std::hash::{Hash, Hasher};
469
470        let pool = std::sync::Arc::clone(&self.pool);
471        let conn = self.conn.as_mut().ok_or_else(|| {
472            PgError::Connection("Connection already released back to pool".into())
473        })?;
474
475        conn.sql_buf.clear();
476        conn.params_buf.clear();
477
478        // Encode SQL + params to reusable buffers
479        match cmd.action {
480            qail_core::ast::Action::Get | qail_core::ast::Action::With => {
481                crate::protocol::ast_encoder::dml::encode_select(
482                    cmd,
483                    &mut conn.sql_buf,
484                    &mut conn.params_buf,
485                )?;
486            }
487            qail_core::ast::Action::Add => {
488                crate::protocol::ast_encoder::dml::encode_insert(
489                    cmd,
490                    &mut conn.sql_buf,
491                    &mut conn.params_buf,
492                )?;
493            }
494            qail_core::ast::Action::Set => {
495                crate::protocol::ast_encoder::dml::encode_update(
496                    cmd,
497                    &mut conn.sql_buf,
498                    &mut conn.params_buf,
499                )?;
500            }
501            qail_core::ast::Action::Del => {
502                crate::protocol::ast_encoder::dml::encode_delete(
503                    cmd,
504                    &mut conn.sql_buf,
505                    &mut conn.params_buf,
506                )?;
507            }
508            _ => {
509                // Fallback: unsupported actions go through uncached path
510                return self
511                    .fetch_all_uncached_with_format(cmd, result_format)
512                    .await;
513            }
514        }
515
516        let mut hasher = DefaultHasher::new();
517        conn.sql_buf.hash(&mut hasher);
518        let sql_hash = hasher.finish();
519
520        let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
521
522        conn.write_buf.clear();
523
524        let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
525            name
526        } else {
527            let name = format!("qail_{:x}", sql_hash);
528
529            conn.evict_prepared_if_full();
530
531            let sql_str = encoded_sql_str(&conn.sql_buf)?;
532
533            use crate::protocol::PgEncoder;
534            let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
535            let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
536            conn.write_buf.extend_from_slice(&parse_msg);
537            conn.write_buf.extend_from_slice(&describe_msg);
538
539            conn.stmt_cache.put(sql_hash, name.clone());
540            conn.prepared_statements
541                .insert(name.clone(), sql_str.to_string());
542
543            name
544        };
545
546        use crate::protocol::PgEncoder;
547        if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
548            &mut conn.write_buf,
549            &stmt_name,
550            &conn.params_buf,
551            result_format.as_wire_code(),
552        ) {
553            if is_cache_miss {
554                conn.stmt_cache.remove(&sql_hash);
555                conn.prepared_statements.remove(&stmt_name);
556                conn.column_info_cache.remove(&sql_hash);
557            }
558            return Err(PgError::Encode(e.to_string()));
559        }
560        PgEncoder::encode_execute_to(&mut conn.write_buf);
561        PgEncoder::encode_sync_to(&mut conn.write_buf);
562
563        if let Err(err) = conn.flush_write_buf().await {
564            if is_cache_miss {
565                conn.stmt_cache.remove(&sql_hash);
566                conn.prepared_statements.remove(&stmt_name);
567                conn.column_info_cache.remove(&sql_hash);
568            }
569            return Err(err);
570        }
571
572        let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
573
574        let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
575        let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
576        let mut error: Option<PgError> = None;
577        let mut flow = ExtendedFlowTracker::new(
578            ExtendedFlowConfig::parse_describe_statement_bind_execute(is_cache_miss),
579        );
580
581        loop {
582            let msg = match conn.recv().await {
583                Ok(msg) => msg,
584                Err(err) => {
585                    if is_cache_miss && !flow.saw_parse_complete() {
586                        conn.stmt_cache.remove(&sql_hash);
587                        conn.prepared_statements.remove(&stmt_name);
588                        conn.column_info_cache.remove(&sql_hash);
589                    }
590                    return Err(err);
591                }
592            };
593            if let Err(err) = flow.validate(&msg, "pool fetch_all_cached execute", error.is_some())
594            {
595                if is_cache_miss && !flow.saw_parse_complete() {
596                    conn.stmt_cache.remove(&sql_hash);
597                    conn.prepared_statements.remove(&stmt_name);
598                    conn.column_info_cache.remove(&sql_hash);
599                }
600                return return_with_desync(conn, err);
601            }
602            match msg {
603                crate::protocol::BackendMessage::ParseComplete => {}
604                crate::protocol::BackendMessage::BindComplete => {}
605                crate::protocol::BackendMessage::ParameterDescription(_) => {}
606                crate::protocol::BackendMessage::RowDescription(fields) => {
607                    let info = Arc::new(ColumnInfo::from_fields(&fields));
608                    if is_cache_miss {
609                        conn.column_info_cache.insert(sql_hash, Arc::clone(&info));
610                    }
611                    column_info = Some(info);
612                }
613                crate::protocol::BackendMessage::DataRow(data) => {
614                    if error.is_none() {
615                        rows.push(crate::driver::PgRow {
616                            columns: data,
617                            column_info: column_info.clone(),
618                        });
619                    }
620                }
621                crate::protocol::BackendMessage::CommandComplete(_) => {}
622                crate::protocol::BackendMessage::ReadyForQuery(_) => {
623                    if let Some(err) = error {
624                        if is_cache_miss
625                            && !flow.saw_parse_complete()
626                            && !err.is_prepared_statement_already_exists()
627                        {
628                            conn.stmt_cache.remove(&sql_hash);
629                            conn.prepared_statements.remove(&stmt_name);
630                            conn.column_info_cache.remove(&sql_hash);
631                        }
632                        return Err(err);
633                    }
634                    if is_cache_miss && !flow.saw_parse_complete() {
635                        conn.stmt_cache.remove(&sql_hash);
636                        conn.prepared_statements.remove(&stmt_name);
637                        conn.column_info_cache.remove(&sql_hash);
638                        return return_with_desync(
639                            conn,
640                            PgError::Protocol(
641                                "Cache miss query reached ReadyForQuery without ParseComplete"
642                                    .to_string(),
643                            ),
644                        );
645                    }
646                    if is_cache_miss && let Some(sql) = conn.prepared_statements.get(&stmt_name) {
647                        register_hot_statement_after_parse_success(
648                            &pool, sql_hash, &stmt_name, sql,
649                        );
650                    }
651                    return Ok(rows);
652                }
653                crate::protocol::BackendMessage::ErrorResponse(err) => {
654                    if error.is_none() {
655                        error = Some(PgError::QueryServer(err.into()));
656                    }
657                }
658                msg if is_ignorable_session_message(&msg) => {}
659                other => {
660                    if is_cache_miss && !flow.saw_parse_complete() {
661                        conn.stmt_cache.remove(&sql_hash);
662                        conn.prepared_statements.remove(&stmt_name);
663                        conn.column_info_cache.remove(&sql_hash);
664                    }
665                    return return_with_desync(
666                        conn,
667                        unexpected_backend_message("pool fetch_all_cached execute", &other),
668                    );
669                }
670            }
671        }
672    }
673
674    /// Execute a QAIL command with RLS context in a SINGLE roundtrip.
675    ///
676    /// Pipelines the RLS setup (BEGIN + set_config) and the query
677    /// (Parse/Bind/Execute/Sync) into one `write_all` syscall.
678    /// PG processes messages in order, so the BEGIN + set_config
679    /// completes before the query executes — security is preserved.
680    ///
681    /// Wire layout:
682    /// ```text
683    /// [SimpleQuery: "BEGIN; SET LOCAL...; SELECT set_config(...)"]
684    /// [Parse (if cache miss)]
685    /// [Describe (if cache miss)]
686    /// [Bind]
687    /// [Execute]
688    /// [Sync]
689    /// ```
690    ///
691    /// Response processing: consume 2× ReadyForQuery (SimpleQuery + Sync).
692    pub async fn fetch_all_with_rls(
693        &mut self,
694        cmd: &qail_core::ast::Qail,
695        rls_sql: &str,
696    ) -> PgResult<Vec<crate::driver::PgRow>> {
697        self.fetch_all_with_rls_with_format(cmd, rls_sql, ResultFormat::Text)
698            .await
699    }
700
701    /// Execute a QAIL command with RLS context in a SINGLE roundtrip with explicit result format.
702    pub async fn fetch_all_with_rls_with_format(
703        &mut self,
704        cmd: &qail_core::ast::Qail,
705        rls_sql: &str,
706        result_format: ResultFormat,
707    ) -> PgResult<Vec<crate::driver::PgRow>> {
708        let mut retried = false;
709        loop {
710            match self
711                .fetch_all_with_rls_with_format_once(cmd, rls_sql, result_format)
712                .await
713            {
714                Ok(rows) => return Ok(rows),
715                Err(err)
716                    if !retried
717                        && (err.is_prepared_statement_retryable()
718                            || err.is_prepared_statement_already_exists()) =>
719                {
720                    retried = true;
721                    if let Some(conn) = self.conn.as_mut() {
722                        if err.is_prepared_statement_retryable() {
723                            conn.clear_prepared_statement_state();
724                        }
725                        // Always rollback transaction state before a retried RLS pipeline
726                        // attempt, including 42P05 "prepared statement already exists".
727                        let _ = conn.execute_simple("ROLLBACK").await;
728                    }
729                    self.rls_dirty = false;
730                }
731                Err(err) => return Err(err),
732            }
733        }
734    }
735
736    async fn fetch_all_with_rls_with_format_once(
737        &mut self,
738        cmd: &qail_core::ast::Qail,
739        rls_sql: &str,
740        result_format: ResultFormat,
741    ) -> PgResult<Vec<crate::driver::PgRow>> {
742        use crate::driver::ColumnInfo;
743        use std::collections::hash_map::DefaultHasher;
744        use std::hash::{Hash, Hasher};
745
746        let pool = std::sync::Arc::clone(&self.pool);
747        let conn = self.conn.as_mut().ok_or_else(|| {
748            PgError::Connection("Connection already released back to pool".into())
749        })?;
750
751        if !crate::protocol::AstEncoder::encode_cacheable_cmd_sql_to(
752            cmd,
753            &mut conn.sql_buf,
754            &mut conn.params_buf,
755        )? {
756            // Fallback: RLS setup must happen synchronously for unsupported actions
757            conn.execute_simple(rls_sql).await?;
758            self.rls_dirty = true;
759            return self
760                .fetch_all_uncached_with_format(cmd, result_format)
761                .await;
762        }
763
764        let mut hasher = DefaultHasher::new();
765        conn.sql_buf.hash(&mut hasher);
766        let sql_hash = hasher.finish();
767
768        let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
769
770        conn.write_buf.clear();
771
772        // ── Prepend RLS Simple Query message ─────────────────────────
773        // NOTE: this is PostgreSQL SimpleQuery text, so the backend still
774        // parses this segment on every request. The optimization here is
775        // batching RLS + query protocol messages into one network flush.
776        let rls_msg = crate::protocol::PgEncoder::try_encode_query_string(rls_sql)?;
777        conn.write_buf.extend_from_slice(&rls_msg);
778
779        // ── Then append the query messages (same as fetch_all_cached) ──
780        let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
781            name
782        } else {
783            let name = format!("qail_{:x}", sql_hash);
784
785            conn.evict_prepared_if_full();
786
787            let sql_str = encoded_sql_str(&conn.sql_buf)?;
788
789            use crate::protocol::PgEncoder;
790            let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
791            let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
792            conn.write_buf.extend_from_slice(&parse_msg);
793            conn.write_buf.extend_from_slice(&describe_msg);
794
795            conn.stmt_cache.put(sql_hash, name.clone());
796            conn.prepared_statements
797                .insert(name.clone(), sql_str.to_string());
798
799            name
800        };
801
802        use crate::protocol::PgEncoder;
803        if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
804            &mut conn.write_buf,
805            &stmt_name,
806            &conn.params_buf,
807            result_format.as_wire_code(),
808        ) {
809            rollback_cache_miss_statement_registration(conn, is_cache_miss, sql_hash, &stmt_name);
810            return Err(PgError::Encode(e.to_string()));
811        }
812        PgEncoder::encode_execute_to(&mut conn.write_buf);
813        PgEncoder::encode_sync_to(&mut conn.write_buf);
814
815        // ── Single write_all for RLS + Query ────────────────────────
816        if let Err(err) = conn.flush_write_buf().await {
817            rollback_cache_miss_statement_registration(conn, is_cache_miss, sql_hash, &stmt_name);
818            return Err(err);
819        }
820
821        // Mark connection as RLS-dirty (needs COMMIT on release)
822        self.rls_dirty = true;
823
824        // ── Phase 1: Consume Simple Query responses (RLS setup) ─────
825        // Simple Query produces: CommandComplete × N, then ReadyForQuery.
826        // set_config results and BEGIN/SET LOCAL responses are all here.
827        let mut rls_error: Option<PgError> = None;
828        loop {
829            let msg = match conn.recv().await {
830                Ok(msg) => msg,
831                Err(err) => {
832                    rollback_cache_miss_statement_registration(
833                        conn,
834                        is_cache_miss,
835                        sql_hash,
836                        &stmt_name,
837                    );
838                    return Err(err);
839                }
840            };
841            match msg {
842                crate::protocol::BackendMessage::ReadyForQuery(_) => {
843                    // RLS setup done — break to Extended Query phase
844                    if let Some(err) = rls_error {
845                        rollback_cache_miss_statement_registration(
846                            conn,
847                            is_cache_miss,
848                            sql_hash,
849                            &stmt_name,
850                        );
851                        if let Err(drain_err) =
852                            drain_extended_responses_after_rls_setup_error(conn).await
853                        {
854                            tracing::warn!(
855                                error = %drain_err,
856                                "failed to drain pipelined extended responses after RLS setup error"
857                            );
858                        }
859                        return Err(err);
860                    }
861                    break;
862                }
863                crate::protocol::BackendMessage::ErrorResponse(err) => {
864                    if rls_error.is_none() {
865                        rls_error = Some(PgError::QueryServer(err.into()));
866                    }
867                }
868                // CommandComplete, DataRow (from set_config), RowDescription — ignore
869                crate::protocol::BackendMessage::CommandComplete(_)
870                | crate::protocol::BackendMessage::DataRow(_)
871                | crate::protocol::BackendMessage::RowDescription(_)
872                | crate::protocol::BackendMessage::ParseComplete
873                | crate::protocol::BackendMessage::BindComplete => {}
874                msg if is_ignorable_session_message(&msg) => {}
875                other => {
876                    rollback_cache_miss_statement_registration(
877                        conn,
878                        is_cache_miss,
879                        sql_hash,
880                        &stmt_name,
881                    );
882                    return return_with_desync(
883                        conn,
884                        unexpected_backend_message("pool rls setup", &other),
885                    );
886                }
887            }
888        }
889
890        // ── Phase 2: Consume Extended Query responses (actual data) ──
891        let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
892
893        let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
894        let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
895        let mut error: Option<PgError> = None;
896        let mut flow = ExtendedFlowTracker::new(
897            ExtendedFlowConfig::parse_describe_statement_bind_execute(is_cache_miss),
898        );
899
900        loop {
901            let msg = match conn.recv().await {
902                Ok(msg) => msg,
903                Err(err) => {
904                    if is_cache_miss && !flow.saw_parse_complete() {
905                        rollback_cache_miss_statement_registration(
906                            conn,
907                            is_cache_miss,
908                            sql_hash,
909                            &stmt_name,
910                        );
911                    }
912                    return Err(err);
913                }
914            };
915            if let Err(err) =
916                flow.validate(&msg, "pool fetch_all_with_rls execute", error.is_some())
917            {
918                if is_cache_miss && !flow.saw_parse_complete() {
919                    rollback_cache_miss_statement_registration(
920                        conn,
921                        is_cache_miss,
922                        sql_hash,
923                        &stmt_name,
924                    );
925                }
926                return return_with_desync(conn, err);
927            }
928            match msg {
929                crate::protocol::BackendMessage::ParseComplete => {}
930                crate::protocol::BackendMessage::BindComplete => {}
931                crate::protocol::BackendMessage::ParameterDescription(_) => {}
932                crate::protocol::BackendMessage::RowDescription(fields) => {
933                    let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
934                    if is_cache_miss {
935                        conn.column_info_cache
936                            .insert(sql_hash, std::sync::Arc::clone(&info));
937                    }
938                    column_info = Some(info);
939                }
940                crate::protocol::BackendMessage::DataRow(data) => {
941                    if error.is_none() {
942                        rows.push(crate::driver::PgRow {
943                            columns: data,
944                            column_info: column_info.clone(),
945                        });
946                    }
947                }
948                crate::protocol::BackendMessage::CommandComplete(_) => {}
949                crate::protocol::BackendMessage::ReadyForQuery(_) => {
950                    if let Some(err) = error {
951                        if is_cache_miss
952                            && !flow.saw_parse_complete()
953                            && !err.is_prepared_statement_already_exists()
954                        {
955                            rollback_cache_miss_statement_registration(
956                                conn,
957                                is_cache_miss,
958                                sql_hash,
959                                &stmt_name,
960                            );
961                        }
962                        return Err(err);
963                    }
964                    if is_cache_miss && !flow.saw_parse_complete() {
965                        rollback_cache_miss_statement_registration(
966                            conn,
967                            is_cache_miss,
968                            sql_hash,
969                            &stmt_name,
970                        );
971                        return return_with_desync(
972                            conn,
973                            PgError::Protocol(
974                                "Cache miss query reached ReadyForQuery without ParseComplete"
975                                    .to_string(),
976                            ),
977                        );
978                    }
979                    if is_cache_miss && let Some(sql) = conn.prepared_statements.get(&stmt_name) {
980                        register_hot_statement_after_parse_success(
981                            &pool, sql_hash, &stmt_name, sql,
982                        );
983                    }
984                    return Ok(rows);
985                }
986                crate::protocol::BackendMessage::ErrorResponse(err) => {
987                    if error.is_none() {
988                        error = Some(PgError::QueryServer(err.into()));
989                    }
990                }
991                msg if is_ignorable_session_message(&msg) => {}
992                other => {
993                    if is_cache_miss && !flow.saw_parse_complete() {
994                        rollback_cache_miss_statement_registration(
995                            conn,
996                            is_cache_miss,
997                            sql_hash,
998                            &stmt_name,
999                        );
1000                    }
1001                    return return_with_desync(
1002                        conn,
1003                        unexpected_backend_message("pool fetch_all_with_rls execute", &other),
1004                    );
1005                }
1006            }
1007        }
1008    }
1009}
1010
#[cfg(test)]
mod tests {
    //! Unit tests for the helpers in this file: COPY export SQL construction,
    //! UTF-8 validation of the encoded SQL buffer, and desync marking when a
    //! protocol error is returned.

    use super::{copy_export_table_sql, encoded_sql_str, return_with_desync};

    /// Builds a `PgConnection` fixture backed by one half of a Unix socket pair.
    ///
    /// All caches and queues start empty, the statement cache holds at most two
    /// entries, and every state flag (including `io_desynced`) starts cleared.
    /// The other half of the socket pair is dropped when this function returns;
    /// the tests in this module never perform I/O on the connection.
    #[cfg(unix)]
    fn test_conn() -> crate::driver::PgConnection {
        use crate::driver::PgConnection;
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (local_end, _remote_end) = UnixStream::pair().expect("unix stream pair");

        PgConnection {
            // Transport and scratch buffers.
            stream: PgStream::Unix(local_end),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            // Prepared-statement bookkeeping.
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
            column_info_cache: HashMap::new(),
            // Backend identity and protocol negotiation.
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            // Session state: nothing queued, nothing active, not desynced.
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        }
    }

    /// A `schema.table` name must come out as two separately quoted identifiers,
    /// and a double quote embedded in a column name must be doubled.
    #[test]
    fn pool_copy_export_table_sql_preserves_schema_qualified_table() {
        let columns = ["id".to_string(), "display\"name".to_string()];

        let sql = copy_export_table_sql("tenant_a.users", &columns).unwrap();

        assert_eq!(
            sql,
            "COPY \"tenant_a\".\"users\" (\"id\", \"display\"\"name\") TO STDOUT"
        );
    }

    /// A NUL byte in either the table name or a column name must be rejected.
    #[test]
    fn pool_copy_export_table_sql_rejects_nul_bytes() {
        let nul_in_table = copy_export_table_sql("tenant\0.users", &["id".to_string()]);
        assert!(nul_in_table.is_err());

        let nul_in_column = copy_export_table_sql("users", &["id\0".to_string()]);
        assert!(nul_in_column.is_err());
    }

    /// A byte sequence that is not valid UTF-8 must yield the encode error
    /// produced by `encoded_sql_str`.
    #[test]
    fn pool_encoded_sql_str_rejects_invalid_utf8() {
        let not_utf8 = [0xff_u8];

        let err = encoded_sql_str(&not_utf8).expect_err("invalid SQL UTF-8 must fail");

        let rendered = err.to_string();
        assert!(rendered.contains("encoded SQL is not UTF-8"));
    }

    /// Returning a `PgError::Protocol` through `return_with_desync` must hand the
    /// error back to the caller and flag the connection as I/O-desynced.
    ///
    /// NOTE(review): the body awaits nothing; `#[tokio::test]` is presumably
    /// needed so the tokio `UnixStream` created by `test_conn` has a runtime
    /// available. Confirm before converting to a plain `#[test]`.
    #[cfg(unix)]
    #[tokio::test]
    async fn pool_return_with_desync_marks_protocol_error() {
        let mut conn = test_conn();
        let protocol_err = crate::driver::PgError::Protocol("bad response ordering".to_string());

        let outcome = return_with_desync::<()>(&mut conn, protocol_err);
        let err = outcome.expect_err("protocol error must be returned");

        assert!(err.to_string().contains("bad response ordering"));
        assert!(conn.is_io_desynced());
    }
}