Skip to main content

qail_pg/driver/
fetch.rs

1//! PgDriver fetch methods: fetch_all (cached/uncached/fast), fetch_typed,
2//! fetch_one, execute, and query_ast.
3
4use super::core::PgDriver;
5use super::prepared::PreparedAstQuery;
6use super::types::*;
7use qail_core::ast::Qail;
8use std::sync::Arc;
9use std::{
10    collections::hash_map::DefaultHasher,
11    hash::{Hash, Hasher},
12};
13
14#[inline]
15fn return_with_desync<T>(driver: &mut PgDriver, err: PgError) -> PgResult<T> {
16    if matches!(
17        err,
18        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
19    ) {
20        driver.connection.mark_io_desynced();
21    }
22    Err(err)
23}
24
25#[inline]
26fn encoded_sql_str(sql_buf: &[u8]) -> PgResult<&str> {
27    std::str::from_utf8(sql_buf)
28        .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))
29}
30
31async fn reprepare_prepared_ast_query(
32    conn: &mut super::PgConnection,
33    prepared: &PreparedAstQuery,
34) -> PgResult<()> {
35    conn.clear_prepared_statement_state();
36    let stmt = conn.prepare(&prepared.sql).await?;
37    conn.stmt_cache
38        .put(prepared.sql_hash, stmt.name().to_string());
39    conn.prepared_statements
40        .insert(stmt.name().to_string(), prepared.sql.clone());
41    Ok(())
42}
43
44impl PgDriver {
45    /// Execute a QAIL command and fetch all rows (CACHED + ZERO-ALLOC).
46    /// **Default method** - uses prepared statement caching for best performance.
47    /// On first call: sends Parse + Bind + Execute + Sync
48    /// On subsequent calls with same SQL: sends only Bind + Execute (SKIPS Parse!)
49    /// Uses per-connection LRU cache with max 100 statements (auto-evicts oldest),
50    /// with a hard prepared-statement cap of 128 per connection.
51    pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
52        self.fetch_all_with_format(cmd, ResultFormat::Text).await
53    }
54
55    /// Execute a QAIL command and fetch all rows using a specific result format.
56    ///
57    /// `result_format` controls server result-column encoding:
58    /// - [`ResultFormat::Text`] for standard text decoding.
59    /// - [`ResultFormat::Binary`] for binary wire values.
60    pub async fn fetch_all_with_format(
61        &mut self,
62        cmd: &Qail,
63        result_format: ResultFormat,
64    ) -> PgResult<Vec<PgRow>> {
65        // Delegate to cached-by-default behavior.
66        self.fetch_all_cached_with_format(cmd, result_format).await
67    }
68
69    /// Prepare an AST query once and return a reusable frozen handle.
70    ///
71    /// This is the lowest-overhead path for repeating the **exact same** AST
72    /// command (same SQL text and same bind values). It avoids per-call AST
73    /// encoding and statement-cache hash/lookup in `fetch_all_cached`.
74    pub async fn prepare_ast_query(&mut self, cmd: &Qail) -> PgResult<PreparedAstQuery> {
75        use crate::protocol::AstEncoder;
76
77        let (sql, params) =
78            AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
79        let stmt = self.connection.prepare(&sql).await?;
80
81        let mut hasher = DefaultHasher::new();
82        sql.hash(&mut hasher);
83        let sql_hash = hasher.finish();
84
85        self.connection
86            .stmt_cache
87            .put(sql_hash, stmt.name().to_string());
88        self.connection
89            .prepared_statements
90            .insert(stmt.name().to_string(), sql.clone());
91
92        Ok(PreparedAstQuery {
93            stmt,
94            params,
95            sql,
96            sql_hash,
97        })
98    }
99
100    /// Execute a precompiled AST query handle and return rows.
101    ///
102    /// Rows are returned without `ColumnInfo` metadata (`column_info = None`),
103    /// so prefer positional access (`row.text(0)`, `row.get_i64(1)`, ...).
104    pub async fn fetch_all_prepared_ast(
105        &mut self,
106        prepared: &PreparedAstQuery,
107    ) -> PgResult<Vec<PgRow>> {
108        self.fetch_all_prepared_ast_with_format(prepared, ResultFormat::Text)
109            .await
110    }
111
112    /// Execute a precompiled AST query handle with explicit result format.
113    pub async fn fetch_all_prepared_ast_with_format(
114        &mut self,
115        prepared: &PreparedAstQuery,
116        result_format: ResultFormat,
117    ) -> PgResult<Vec<PgRow>> {
118        let mut retried = false;
119
120        loop {
121            self.connection.stmt_cache.touch_key(prepared.sql_hash);
122            self.connection.write_buf.clear();
123            if let Err(e) = crate::protocol::PgEncoder::encode_bind_to_with_result_format(
124                &mut self.connection.write_buf,
125                prepared.stmt.name(),
126                &prepared.params,
127                result_format.as_wire_code(),
128            ) {
129                return Err(PgError::Encode(e.to_string()));
130            }
131            crate::protocol::PgEncoder::encode_execute_to(&mut self.connection.write_buf);
132            crate::protocol::PgEncoder::encode_sync_to(&mut self.connection.write_buf);
133
134            if let Err(err) = self.connection.flush_write_buf().await {
135                if !retried && err.is_prepared_statement_retryable() {
136                    retried = true;
137                    reprepare_prepared_ast_query(&mut self.connection, prepared).await?;
138                    continue;
139                }
140                return Err(err);
141            }
142
143            let mut rows: Vec<PgRow> = Vec::with_capacity(32);
144            let mut error: Option<PgError> = None;
145            let mut flow = super::extended_flow::ExtendedFlowTracker::new(
146                super::extended_flow::ExtendedFlowConfig::parse_bind_execute(false),
147            );
148
149            loop {
150                let msg = self.connection.recv().await?;
151                if let Err(err) = flow.validate(
152                    &msg,
153                    "driver fetch_all_prepared_ast execute",
154                    error.is_some(),
155                ) {
156                    return return_with_desync(self, err);
157                }
158                match msg {
159                    crate::protocol::BackendMessage::BindComplete => {}
160                    crate::protocol::BackendMessage::RowDescription(_) => {}
161                    crate::protocol::BackendMessage::DataRow(data) => {
162                        if error.is_none() {
163                            rows.push(PgRow {
164                                columns: data,
165                                column_info: None,
166                            });
167                        }
168                    }
169                    crate::protocol::BackendMessage::CommandComplete(_) => {}
170                    crate::protocol::BackendMessage::NoData => {}
171                    crate::protocol::BackendMessage::ReadyForQuery(_) => {
172                        if let Some(err) = error {
173                            if !retried && err.is_prepared_statement_retryable() {
174                                retried = true;
175                                reprepare_prepared_ast_query(&mut self.connection, prepared)
176                                    .await?;
177                                break;
178                            }
179                            return Err(err);
180                        }
181                        return Ok(rows);
182                    }
183                    crate::protocol::BackendMessage::ErrorResponse(err) => {
184                        if error.is_none() {
185                            error = Some(PgError::QueryServer(err.into()));
186                        }
187                    }
188                    msg if is_ignorable_session_message(&msg) => {}
189                    other => {
190                        return return_with_desync(
191                            self,
192                            unexpected_backend_message(
193                                "driver fetch_all_prepared_ast execute",
194                                &other,
195                            ),
196                        );
197                    }
198                }
199            }
200        }
201    }
202
203    /// Execute a QAIL command and fetch all rows as a typed struct (text format).
204    /// Requires the target type to implement `QailRow` trait.
205    ///
206    /// # Example
207    /// ```ignore
208    /// let users: Vec<User> = driver.fetch_typed::<User>(&query).await?;
209    /// ```
210    pub async fn fetch_typed<T: super::row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
211        self.fetch_typed_with_format(cmd, ResultFormat::Text).await
212    }
213
214    /// Execute a QAIL command and fetch all rows as a typed struct with explicit result format.
215    ///
216    /// Use [`ResultFormat::Binary`] to get binary wire values; row decoding should use
217    /// metadata-aware accessors such as `PgRow::try_get()` / `try_get_by_name()`.
218    pub async fn fetch_typed_with_format<T: super::row::QailRow>(
219        &mut self,
220        cmd: &Qail,
221        result_format: ResultFormat,
222    ) -> PgResult<Vec<T>> {
223        let rows = self.fetch_all_with_format(cmd, result_format).await?;
224        Ok(rows.iter().map(T::from_row).collect())
225    }
226
227    /// Execute a QAIL command and fetch a single row as a typed struct (text format).
228    /// Returns None if no rows are returned.
229    pub async fn fetch_one_typed<T: super::row::QailRow>(
230        &mut self,
231        cmd: &Qail,
232    ) -> PgResult<Option<T>> {
233        self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
234            .await
235    }
236
237    /// Execute a QAIL command and fetch a single row as a typed struct with explicit result format.
238    pub async fn fetch_one_typed_with_format<T: super::row::QailRow>(
239        &mut self,
240        cmd: &Qail,
241        result_format: ResultFormat,
242    ) -> PgResult<Option<T>> {
243        let rows = self.fetch_all_with_format(cmd, result_format).await?;
244        Ok(rows.first().map(T::from_row))
245    }
246
247    /// Execute a QAIL command and fetch all rows (UNCACHED).
248    /// Sends Parse + Bind + Execute on every call.
249    /// Use for one-off queries or when caching is not desired.
250    ///
251    /// Optimized: encodes wire bytes into reusable write_buf (zero-alloc).
252    pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
253        self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
254            .await
255    }
256
257    /// Execute a QAIL command and fetch all rows (UNCACHED) with explicit result format.
258    pub async fn fetch_all_uncached_with_format(
259        &mut self,
260        cmd: &Qail,
261        result_format: ResultFormat,
262    ) -> PgResult<Vec<PgRow>> {
263        use crate::protocol::AstEncoder;
264
265        AstEncoder::encode_cmd_reuse_into_with_result_format(
266            cmd,
267            &mut self.connection.sql_buf,
268            &mut self.connection.params_buf,
269            &mut self.connection.write_buf,
270            result_format.as_wire_code(),
271        )
272        .map_err(|e| PgError::Encode(e.to_string()))?;
273
274        self.connection.flush_write_buf().await?;
275
276        let mut rows: Vec<PgRow> = Vec::with_capacity(32);
277        let mut column_info: Option<Arc<ColumnInfo>> = None;
278
279        let mut error: Option<PgError> = None;
280        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
281            super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
282        );
283
284        loop {
285            let msg = self.connection.recv().await?;
286            if let Err(err) = flow.validate(&msg, "driver fetch_all execute", error.is_some()) {
287                return return_with_desync(self, err);
288            }
289            match msg {
290                crate::protocol::BackendMessage::ParseComplete
291                | crate::protocol::BackendMessage::BindComplete => {}
292                crate::protocol::BackendMessage::RowDescription(fields) => {
293                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
294                }
295                crate::protocol::BackendMessage::DataRow(data) => {
296                    if error.is_none() {
297                        rows.push(PgRow {
298                            columns: data,
299                            column_info: column_info.clone(),
300                        });
301                    }
302                }
303                crate::protocol::BackendMessage::NoData => {}
304                crate::protocol::BackendMessage::CommandComplete(_) => {}
305                crate::protocol::BackendMessage::ReadyForQuery(_) => {
306                    if let Some(err) = error {
307                        return Err(err);
308                    }
309                    return Ok(rows);
310                }
311                crate::protocol::BackendMessage::ErrorResponse(err) => {
312                    if error.is_none() {
313                        error = Some(PgError::QueryServer(err.into()));
314                    }
315                }
316                msg if is_ignorable_session_message(&msg) => {}
317                other => {
318                    return return_with_desync(
319                        self,
320                        unexpected_backend_message("driver fetch_all execute", &other),
321                    );
322                }
323            }
324        }
325    }
326
327    /// Execute a QAIL command and fetch all rows (FAST VERSION).
328    /// Uses optimized recv_with_data_fast for faster response parsing.
329    /// Skips column metadata collection for maximum speed.
330    pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
331        self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
332            .await
333    }
334
335    /// Execute a QAIL command and fetch all rows (FAST VERSION) with explicit result format.
336    pub async fn fetch_all_fast_with_format(
337        &mut self,
338        cmd: &Qail,
339        result_format: ResultFormat,
340    ) -> PgResult<Vec<PgRow>> {
341        use crate::protocol::AstEncoder;
342
343        AstEncoder::encode_cmd_reuse_into_with_result_format(
344            cmd,
345            &mut self.connection.sql_buf,
346            &mut self.connection.params_buf,
347            &mut self.connection.write_buf,
348            result_format.as_wire_code(),
349        )
350        .map_err(|e| PgError::Encode(e.to_string()))?;
351
352        self.connection.flush_write_buf().await?;
353
354        // Collect results using FAST receiver
355        let mut rows: Vec<PgRow> = Vec::with_capacity(32);
356        let mut error: Option<PgError> = None;
357        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
358            super::extended_flow::ExtendedFlowConfig::parse_bind_execute(true),
359        );
360
361        loop {
362            let res = self.connection.recv_with_data_fast().await;
363            match res {
364                Ok((msg_type, data)) => {
365                    if let Err(err) = flow.validate_msg_type(
366                        msg_type,
367                        "driver fetch_all_fast execute",
368                        error.is_some(),
369                    ) {
370                        return return_with_desync(self, err);
371                    }
372                    match msg_type {
373                        b'D' => {
374                            if error.is_none()
375                                && let Some(columns) = data
376                            {
377                                rows.push(PgRow {
378                                    columns,
379                                    column_info: None,
380                                });
381                            }
382                        }
383                        b'Z' => {
384                            if let Some(err) = error {
385                                return Err(err);
386                            }
387                            return Ok(rows);
388                        }
389                        _ => {}
390                    }
391                }
392                Err(e) => {
393                    // QueryServer means backend sent ErrorResponse; keep draining to ReadyForQuery.
394                    if matches!(&e, PgError::QueryServer(_)) {
395                        if error.is_none() {
396                            error = Some(e);
397                        }
398                        continue;
399                    }
400                    return Err(e);
401                }
402            }
403        }
404    }
405
406    /// Execute a QAIL command and fetch one row.
407    pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
408        let rows = self.fetch_all(cmd).await?;
409        rows.into_iter().next().ok_or(PgError::NoRows)
410    }
411
412    /// Execute a QAIL command with PREPARED STATEMENT CACHING.
413    /// Like fetch_all(), but caches the prepared statement on the server.
414    /// On first call: sends Parse + Describe + Bind + Execute + Sync
415    /// On subsequent calls: sends only Bind + Execute + Sync (SKIPS Parse!)
416    /// Column metadata (RowDescription) is cached alongside the statement
417    /// so that by-name column access works on every call.
418    ///
419    /// Optimized: all wire messages are batched into a single write_all syscall.
420    pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
421        self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
422            .await
423    }
424
425    /// Execute a QAIL command with prepared statement caching and explicit result format.
426    pub async fn fetch_all_cached_with_format(
427        &mut self,
428        cmd: &Qail,
429        result_format: ResultFormat,
430    ) -> PgResult<Vec<PgRow>> {
431        let mut retried = false;
432        loop {
433            match self
434                .fetch_all_cached_with_format_once(cmd, result_format)
435                .await
436            {
437                Ok(rows) => return Ok(rows),
438                Err(err)
439                    if !retried
440                        && (err.is_prepared_statement_retryable()
441                            || err.is_prepared_statement_already_exists()) =>
442                {
443                    retried = true;
444                    if err.is_prepared_statement_retryable() {
445                        self.connection.clear_prepared_statement_state();
446                    }
447                }
448                Err(err) => return Err(err),
449            }
450        }
451    }
452
453    async fn fetch_all_cached_with_format_once(
454        &mut self,
455        cmd: &Qail,
456        result_format: ResultFormat,
457    ) -> PgResult<Vec<PgRow>> {
458        use crate::protocol::AstEncoder;
459        use std::collections::hash_map::DefaultHasher;
460        use std::hash::{Hash, Hasher};
461
462        if !AstEncoder::encode_cacheable_cmd_sql_to(
463            cmd,
464            &mut self.connection.sql_buf,
465            &mut self.connection.params_buf,
466        )? {
467            // Fallback for unsupported actions
468            let (sql, params) =
469                AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
470            let raw_rows = self
471                .connection
472                .query_cached_with_result_format(&sql, &params, result_format.as_wire_code())
473                .await?;
474            return Ok(raw_rows
475                .into_iter()
476                .map(|data| PgRow {
477                    columns: data,
478                    column_info: None,
479                })
480                .collect());
481        }
482
483        let mut hasher = DefaultHasher::new();
484        self.connection.sql_buf.hash(&mut hasher);
485        let sql_hash = hasher.finish();
486
487        let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
488
489        // Build ALL wire messages into write_buf (single syscall)
490        self.connection.write_buf.clear();
491
492        let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
493            name
494        } else {
495            let name = format!("qail_{:x}", sql_hash);
496
497            // Evict LRU before borrowing sql_buf to avoid borrow conflict
498            self.connection.evict_prepared_if_full();
499
500            let sql_str = encoded_sql_str(&self.connection.sql_buf)?;
501
502            // Buffer Parse + Describe(Statement) for first call
503            use crate::protocol::PgEncoder;
504            let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
505            let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
506            self.connection.write_buf.extend_from_slice(&parse_msg);
507            self.connection.write_buf.extend_from_slice(&describe_msg);
508
509            self.connection.stmt_cache.put(sql_hash, name.clone());
510            self.connection
511                .prepared_statements
512                .insert(name.clone(), sql_str.to_string());
513
514            name
515        };
516
517        // Append Bind + Execute + Sync to same buffer
518        use crate::protocol::PgEncoder;
519        if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
520            &mut self.connection.write_buf,
521            &stmt_name,
522            &self.connection.params_buf,
523            result_format.as_wire_code(),
524        ) {
525            if is_cache_miss {
526                self.connection.stmt_cache.remove(&sql_hash);
527                self.connection.prepared_statements.remove(&stmt_name);
528                self.connection.column_info_cache.remove(&sql_hash);
529            }
530            return Err(PgError::Encode(e.to_string()));
531        }
532        PgEncoder::encode_execute_to(&mut self.connection.write_buf);
533        PgEncoder::encode_sync_to(&mut self.connection.write_buf);
534
535        // Single write_all syscall for all messages
536        if let Err(err) = self.connection.flush_write_buf().await {
537            if is_cache_miss {
538                self.connection.stmt_cache.remove(&sql_hash);
539                self.connection.prepared_statements.remove(&stmt_name);
540                self.connection.column_info_cache.remove(&sql_hash);
541            }
542            return Err(err);
543        }
544
545        // On cache hit, use the previously cached ColumnInfo
546        let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
547
548        let mut rows: Vec<PgRow> = Vec::with_capacity(32);
549        let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
550        let mut error: Option<PgError> = None;
551        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
552            super::extended_flow::ExtendedFlowConfig::parse_describe_statement_bind_execute(
553                is_cache_miss,
554            ),
555        );
556
557        loop {
558            let msg = match self.connection.recv().await {
559                Ok(msg) => msg,
560                Err(err) => {
561                    if is_cache_miss && !flow.saw_parse_complete() {
562                        self.connection.stmt_cache.remove(&sql_hash);
563                        self.connection.prepared_statements.remove(&stmt_name);
564                        self.connection.column_info_cache.remove(&sql_hash);
565                    }
566                    return Err(err);
567                }
568            };
569            if let Err(err) =
570                flow.validate(&msg, "driver fetch_all_cached execute", error.is_some())
571            {
572                if is_cache_miss && !flow.saw_parse_complete() {
573                    self.connection.stmt_cache.remove(&sql_hash);
574                    self.connection.prepared_statements.remove(&stmt_name);
575                    self.connection.column_info_cache.remove(&sql_hash);
576                }
577                return return_with_desync(self, err);
578            }
579            match msg {
580                crate::protocol::BackendMessage::ParseComplete => {}
581                crate::protocol::BackendMessage::BindComplete => {}
582                crate::protocol::BackendMessage::ParameterDescription(_) => {
583                    // Sent after Describe(Statement) — ignore
584                }
585                crate::protocol::BackendMessage::RowDescription(fields) => {
586                    // Received after Describe(Statement) on cache miss
587                    let info = Arc::new(ColumnInfo::from_fields(&fields));
588                    if is_cache_miss {
589                        self.connection
590                            .column_info_cache
591                            .insert(sql_hash, Arc::clone(&info));
592                    }
593                    column_info = Some(info);
594                }
595                crate::protocol::BackendMessage::DataRow(data) => {
596                    if error.is_none() {
597                        rows.push(PgRow {
598                            columns: data,
599                            column_info: column_info.clone(),
600                        });
601                    }
602                }
603                crate::protocol::BackendMessage::CommandComplete(_) => {}
604                crate::protocol::BackendMessage::NoData => {
605                    // Sent by Describe for statements that return no data (e.g. pure UPDATE without RETURNING)
606                }
607                crate::protocol::BackendMessage::ReadyForQuery(_) => {
608                    if let Some(err) = error {
609                        if is_cache_miss
610                            && !flow.saw_parse_complete()
611                            && !err.is_prepared_statement_already_exists()
612                        {
613                            self.connection.stmt_cache.remove(&sql_hash);
614                            self.connection.prepared_statements.remove(&stmt_name);
615                            self.connection.column_info_cache.remove(&sql_hash);
616                        }
617                        return Err(err);
618                    }
619                    if is_cache_miss && !flow.saw_parse_complete() {
620                        self.connection.stmt_cache.remove(&sql_hash);
621                        self.connection.prepared_statements.remove(&stmt_name);
622                        self.connection.column_info_cache.remove(&sql_hash);
623                        return return_with_desync(
624                            self,
625                            PgError::Protocol(
626                                "Cache miss query reached ReadyForQuery without ParseComplete"
627                                    .to_string(),
628                            ),
629                        );
630                    }
631                    return Ok(rows);
632                }
633                crate::protocol::BackendMessage::ErrorResponse(err) => {
634                    if error.is_none() {
635                        let query_err = PgError::QueryServer(err.into());
636                        if query_err.is_prepared_statement_retryable() {
637                            self.connection.clear_prepared_statement_state();
638                        }
639                        error = Some(query_err);
640                    }
641                }
642                msg if is_ignorable_session_message(&msg) => {}
643                other => {
644                    if is_cache_miss && !flow.saw_parse_complete() {
645                        self.connection.stmt_cache.remove(&sql_hash);
646                        self.connection.prepared_statements.remove(&stmt_name);
647                        self.connection.column_info_cache.remove(&sql_hash);
648                    }
649                    return return_with_desync(
650                        self,
651                        unexpected_backend_message("driver fetch_all_cached execute", &other),
652                    );
653                }
654            }
655        }
656    }
657
658    /// Execute a QAIL command (for mutations) - ZERO-ALLOC.
659    pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
660        use crate::protocol::AstEncoder;
661
662        let wire_bytes = AstEncoder::encode_cmd_reuse(
663            cmd,
664            &mut self.connection.sql_buf,
665            &mut self.connection.params_buf,
666        )
667        .map_err(|e| PgError::Encode(e.to_string()))?;
668
669        self.connection.send_bytes(&wire_bytes).await?;
670
671        let mut affected = 0u64;
672        let mut error: Option<PgError> = None;
673        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
674            super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
675        );
676
677        loop {
678            let msg = self.connection.recv().await?;
679            if let Err(err) = flow.validate(&msg, "driver execute mutation", error.is_some()) {
680                return return_with_desync(self, err);
681            }
682            match msg {
683                crate::protocol::BackendMessage::ParseComplete
684                | crate::protocol::BackendMessage::BindComplete => {}
685                crate::protocol::BackendMessage::RowDescription(_) => {}
686                crate::protocol::BackendMessage::DataRow(_) => {}
687                crate::protocol::BackendMessage::NoData => {}
688                crate::protocol::BackendMessage::CommandComplete(tag) => {
689                    if error.is_none() {
690                        match super::parse_affected_rows(&tag) {
691                            Ok(parsed) => affected = parsed,
692                            Err(err) => return return_with_desync(self, err),
693                        }
694                    }
695                }
696                crate::protocol::BackendMessage::ReadyForQuery(_) => {
697                    if let Some(err) = error {
698                        return Err(err);
699                    }
700                    return Ok(affected);
701                }
702                crate::protocol::BackendMessage::ErrorResponse(err) => {
703                    if error.is_none() {
704                        error = Some(PgError::QueryServer(err.into()));
705                    }
706                }
707                msg if is_ignorable_session_message(&msg) => {}
708                other => {
709                    return return_with_desync(
710                        self,
711                        unexpected_backend_message("driver execute mutation", &other),
712                    );
713                }
714            }
715        }
716    }
717
718    /// Query a QAIL command and return rows (for SELECT/GET queries).
719    /// Like `execute()` but collects RowDescription + DataRow messages
720    /// instead of discarding them.
721    pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
722        self.query_ast_with_format(cmd, ResultFormat::Text).await
723    }
724
725    /// Query a QAIL command and return rows using an explicit result format.
726    pub async fn query_ast_with_format(
727        &mut self,
728        cmd: &Qail,
729        result_format: ResultFormat,
730    ) -> PgResult<QueryResult> {
731        use crate::protocol::AstEncoder;
732
733        let wire_bytes = AstEncoder::encode_cmd_reuse_with_result_format(
734            cmd,
735            &mut self.connection.sql_buf,
736            &mut self.connection.params_buf,
737            result_format.as_wire_code(),
738        )
739        .map_err(|e| PgError::Encode(e.to_string()))?;
740
741        self.connection.send_bytes(&wire_bytes).await?;
742
743        let mut columns: Vec<String> = Vec::new();
744        let mut rows: Vec<Vec<Option<String>>> = Vec::new();
745        let mut error: Option<PgError> = None;
746        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
747            super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
748        );
749
750        loop {
751            let msg = self.connection.recv().await?;
752            if let Err(err) = flow.validate(&msg, "driver query_ast", error.is_some()) {
753                return return_with_desync(self, err);
754            }
755            match msg {
756                crate::protocol::BackendMessage::ParseComplete
757                | crate::protocol::BackendMessage::BindComplete => {}
758                crate::protocol::BackendMessage::RowDescription(fields) => {
759                    columns = fields.into_iter().map(|f| f.name).collect();
760                }
761                crate::protocol::BackendMessage::DataRow(data) => {
762                    if error.is_none() {
763                        let row: Vec<Option<String>> = data
764                            .into_iter()
765                            .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
766                            .collect();
767                        rows.push(row);
768                    }
769                }
770                crate::protocol::BackendMessage::CommandComplete(_) => {}
771                crate::protocol::BackendMessage::NoData => {}
772                crate::protocol::BackendMessage::ReadyForQuery(_) => {
773                    if let Some(err) = error {
774                        return Err(err);
775                    }
776                    return Ok(QueryResult { columns, rows });
777                }
778                crate::protocol::BackendMessage::ErrorResponse(err) => {
779                    if error.is_none() {
780                        error = Some(PgError::QueryServer(err.into()));
781                    }
782                }
783                msg if is_ignorable_session_message(&msg) => {}
784                other => {
785                    return return_with_desync(
786                        self,
787                        unexpected_backend_message("driver query_ast", &other),
788                    );
789                }
790            }
791        }
792    }
793}
794
#[cfg(test)]
mod tests {
    // Unit tests for the driver fetch/execute paths. Most tests avoid a real
    // server by appending hand-built backend frames to the connection's read
    // buffer before calling the driver method under test.
    use super::*;

    // An encoded SQL buffer that is not valid UTF-8 (0xFF can never appear in
    // UTF-8) must be rejected with a descriptive Encode error.
    #[test]
    fn driver_encoded_sql_str_rejects_invalid_utf8() {
        let err = encoded_sql_str(&[0xff]).expect_err("invalid SQL UTF-8 must fail");
        assert!(err.to_string().contains("encoded SQL is not UTF-8"));
    }

    /// Build a `PgDriver` over one end of a Unix socket pair and return the
    /// other end as `peer`.
    ///
    /// All connection state starts empty. The statement cache holds at most 2
    /// entries. Callers keep `peer` alive (bound to `_peer`), presumably so
    /// the driver's stream is not seen as closed during the test.
    #[cfg(unix)]
    fn test_driver_with_peer() -> (PgDriver, tokio::net::UnixStream) {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, peer) = UnixStream::pair().expect("unix stream pair");
        let conn = super::super::PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: super::super::PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: super::super::PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        (PgDriver::new(conn), peer)
    }

    /// Append one framed backend message to the connection's read buffer.
    ///
    /// The frame layout is: the type byte, then a big-endian `u32` length that
    /// counts itself (4 bytes) plus the payload, then the payload. The tests
    /// rely on `recv()` consuming frames already in `buffer` before it reads
    /// the socket. NOTE(review): verify against `PgConnection::recv`.
    #[cfg(unix)]
    fn push_backend_frame(driver: &mut PgDriver, msg_type: u8, payload: &[u8]) {
        driver.connection.buffer.extend_from_slice(&[msg_type]);
        driver
            .connection
            .buffer
            .extend_from_slice(&((payload.len() + 4) as u32).to_be_bytes());
        driver.connection.buffer.extend_from_slice(payload);
    }

    /// Build an ErrorResponse payload with severity `ERROR`, the given
    /// SQLSTATE `code`, and `message`.
    ///
    /// Each field is a tag byte followed by a NUL-terminated value, and the
    /// field list ends with an extra NUL.
    #[cfg(unix)]
    fn error_response_payload(code: &str, message: &str) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.push(b'S');
        payload.extend_from_slice(b"ERROR\0");
        payload.push(b'C');
        payload.extend_from_slice(code.as_bytes());
        payload.push(0);
        payload.push(b'M');
        payload.extend_from_slice(message.as_bytes());
        payload.push(0);
        payload.push(0);
        payload
    }

    /// Push a CommandComplete (`'C'`) frame carrying `tag` as a
    /// NUL-terminated string.
    #[cfg(unix)]
    fn push_command_complete(driver: &mut PgDriver, tag: &str) {
        let mut payload = Vec::with_capacity(tag.len() + 1);
        payload.extend_from_slice(tag.as_bytes());
        payload.push(0);
        push_backend_frame(driver, b'C', &payload);
    }

    /// Build a `PreparedAstQuery` for `sql` with no parameters.
    ///
    /// `sql_hash` is computed with `DefaultHasher` over the SQL text,
    /// presumably matching how the driver keys `stmt_cache`. Verify against
    /// the production hashing path if that changes.
    #[cfg(unix)]
    fn prepared_ast_for_sql(sql: &str) -> PreparedAstQuery {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        sql.hash(&mut hasher);

        PreparedAstQuery {
            stmt: crate::driver::PreparedStatement::from_sql(sql),
            params: Vec::new(),
            sql: sql.to_string(),
            sql_hash: hasher.finish(),
        }
    }

    // A DataRow ('D') arriving before any BindComplete is an out-of-order
    // protocol message. `fetch_all_fast` must fail and mark the connection
    // I/O-desynced.
    #[cfg(unix)]
    #[tokio::test]
    async fn fetch_fast_protocol_error_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        push_backend_frame(&mut driver, b'D', &0i16.to_be_bytes());

        let err = match driver.fetch_all_fast(&Qail::get("users")).await {
            Ok(_) => panic!("out-of-order DataRow must fail"),
            Err(err) => err,
        };

        assert!(err.to_string().contains("DataRow before BindComplete"));
        assert!(driver.connection.is_io_desynced());
    }

    // A well-ordered response whose CommandComplete tag ("UPDATE") has no
    // row count must make `execute` fail. Because the affected-row parse
    // error is routed through `return_with_desync`, the connection must also
    // be marked desynced.
    #[cfg(unix)]
    #[tokio::test]
    async fn execute_bad_command_tag_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        push_backend_frame(&mut driver, b'1', &[]);
        push_backend_frame(&mut driver, b'2', &[]);
        push_backend_frame(&mut driver, b'n', &[]);
        push_command_complete(&mut driver, "UPDATE");
        push_backend_frame(&mut driver, b'Z', b"I");

        let err = driver
            .execute(&Qail::get("users"))
            .await
            .expect_err("malformed CommandComplete tag must fail");

        assert!(
            err.to_string().contains("missing affected row count")
                || err.to_string().contains("invalid affected row count")
        );
        assert!(driver.connection.is_io_desynced());
    }

    // The local caches claim the statement is prepared, but the server replies
    // with SQLSTATE 26000 ("prepared statement ... does not exist").
    // `fetch_all_prepared_ast` must re-prepare once, retry, and succeed. The
    // caches must be repopulated and the connection must stay in sync.
    #[cfg(unix)]
    #[tokio::test]
    async fn prepared_ast_retry_reparses_after_missing_server_statement() {
        let (mut driver, _peer) = test_driver_with_peer();
        let prepared = prepared_ast_for_sql("SELECT 1");
        let stmt_name = prepared.stmt.name().to_string();

        driver
            .connection
            .stmt_cache
            .put(prepared.sql_hash, stmt_name.clone());
        driver
            .connection
            .prepared_statements
            .insert(stmt_name.clone(), prepared.sql.clone());

        let missing_payload = error_response_payload(
            "26000",
            &format!("prepared statement \"{}\" does not exist", stmt_name),
        );

        // First execution: backend says local prepared state is stale.
        push_backend_frame(&mut driver, b'E', &missing_payload);
        push_backend_frame(&mut driver, b'Z', b"I");
        // Re-prepare: this must consume ParseComplete + ReadyForQuery.
        push_backend_frame(&mut driver, b'1', &[]);
        push_backend_frame(&mut driver, b'Z', b"I");
        // Retried execution succeeds.
        push_backend_frame(&mut driver, b'2', &[]);
        push_command_complete(&mut driver, "SELECT 0");
        push_backend_frame(&mut driver, b'Z', b"I");

        let rows = driver
            .fetch_all_prepared_ast(&prepared)
            .await
            .expect("stale prepared AST handle should reparse and retry once");

        assert!(rows.is_empty());
        assert!(
            driver
                .connection
                .prepared_statements
                .contains_key(&stmt_name)
        );
        assert!(driver.connection.stmt_cache.contains(&prepared.sql_hash));
        assert!(!driver.connection.is_io_desynced());
    }
}