Skip to main content

qail_pg/driver/
fetch.rs

//! PgDriver fetch methods: fetch_all (cached/uncached/fast), prepared AST
//! handles (prepare_ast_query / fetch_all_prepared_ast), fetch_typed,
//! fetch_one, execute, and query_ast.
3
4use super::core::PgDriver;
5use super::prepared::PreparedAstQuery;
6use super::types::*;
7use qail_core::ast::Qail;
8use std::sync::Arc;
9use std::{
10    collections::hash_map::DefaultHasher,
11    hash::{Hash, Hasher},
12};
13
14#[inline]
15fn return_with_desync<T>(driver: &mut PgDriver, err: PgError) -> PgResult<T> {
16    if matches!(
17        err,
18        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
19    ) {
20        driver.connection.mark_io_desynced();
21    }
22    Err(err)
23}
24
25#[inline]
26fn encoded_sql_str(sql_buf: &[u8]) -> PgResult<&str> {
27    std::str::from_utf8(sql_buf)
28        .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))
29}
30
31impl PgDriver {
32    /// Execute a QAIL command and fetch all rows (CACHED + ZERO-ALLOC).
33    /// **Default method** - uses prepared statement caching for best performance.
34    /// On first call: sends Parse + Bind + Execute + Sync
35    /// On subsequent calls with same SQL: sends only Bind + Execute (SKIPS Parse!)
36    /// Uses per-connection LRU cache with max 100 statements (auto-evicts oldest),
37    /// with a hard prepared-statement cap of 128 per connection.
38    pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
39        self.fetch_all_with_format(cmd, ResultFormat::Text).await
40    }
41
42    /// Execute a QAIL command and fetch all rows using a specific result format.
43    ///
44    /// `result_format` controls server result-column encoding:
45    /// - [`ResultFormat::Text`] for standard text decoding.
46    /// - [`ResultFormat::Binary`] for binary wire values.
47    pub async fn fetch_all_with_format(
48        &mut self,
49        cmd: &Qail,
50        result_format: ResultFormat,
51    ) -> PgResult<Vec<PgRow>> {
52        // Delegate to cached-by-default behavior.
53        self.fetch_all_cached_with_format(cmd, result_format).await
54    }
55
56    /// Prepare an AST query once and return a reusable frozen handle.
57    ///
58    /// This is the lowest-overhead path for repeating the **exact same** AST
59    /// command (same SQL text and same bind values). It avoids per-call AST
60    /// encoding and statement-cache hash/lookup in `fetch_all_cached`.
61    pub async fn prepare_ast_query(&mut self, cmd: &Qail) -> PgResult<PreparedAstQuery> {
62        use crate::protocol::AstEncoder;
63
64        let (sql, params) =
65            AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
66        let stmt = self.connection.prepare(&sql).await?;
67
68        let mut hasher = DefaultHasher::new();
69        sql.hash(&mut hasher);
70        let sql_hash = hasher.finish();
71
72        self.connection
73            .stmt_cache
74            .put(sql_hash, stmt.name().to_string());
75        self.connection
76            .prepared_statements
77            .insert(stmt.name().to_string(), sql.clone());
78
79        Ok(PreparedAstQuery {
80            stmt,
81            params,
82            sql,
83            sql_hash,
84        })
85    }
86
87    /// Execute a precompiled AST query handle and return rows.
88    ///
89    /// Rows are returned without `ColumnInfo` metadata (`column_info = None`),
90    /// so prefer positional access (`row.text(0)`, `row.get_i64(1)`, ...).
91    pub async fn fetch_all_prepared_ast(
92        &mut self,
93        prepared: &PreparedAstQuery,
94    ) -> PgResult<Vec<PgRow>> {
95        self.fetch_all_prepared_ast_with_format(prepared, ResultFormat::Text)
96            .await
97    }
98
99    /// Execute a precompiled AST query handle with explicit result format.
100    pub async fn fetch_all_prepared_ast_with_format(
101        &mut self,
102        prepared: &PreparedAstQuery,
103        result_format: ResultFormat,
104    ) -> PgResult<Vec<PgRow>> {
105        let mut retried = false;
106
107        loop {
108            self.connection.stmt_cache.touch_key(prepared.sql_hash);
109            self.connection.write_buf.clear();
110            if let Err(e) = crate::protocol::PgEncoder::encode_bind_to_with_result_format(
111                &mut self.connection.write_buf,
112                prepared.stmt.name(),
113                &prepared.params,
114                result_format.as_wire_code(),
115            ) {
116                return Err(PgError::Encode(e.to_string()));
117            }
118            crate::protocol::PgEncoder::encode_execute_to(&mut self.connection.write_buf);
119            crate::protocol::PgEncoder::encode_sync_to(&mut self.connection.write_buf);
120
121            if let Err(err) = self.connection.flush_write_buf().await {
122                if !retried && err.is_prepared_statement_retryable() {
123                    retried = true;
124                    let stmt = self.connection.prepare(&prepared.sql).await?;
125                    self.connection
126                        .stmt_cache
127                        .put(prepared.sql_hash, stmt.name().to_string());
128                    self.connection
129                        .prepared_statements
130                        .insert(stmt.name().to_string(), prepared.sql.clone());
131                    continue;
132                }
133                return Err(err);
134            }
135
136            let mut rows: Vec<PgRow> = Vec::with_capacity(32);
137            let mut error: Option<PgError> = None;
138            let mut flow = super::extended_flow::ExtendedFlowTracker::new(
139                super::extended_flow::ExtendedFlowConfig::parse_bind_execute(false),
140            );
141
142            loop {
143                let msg = self.connection.recv().await?;
144                if let Err(err) = flow.validate(
145                    &msg,
146                    "driver fetch_all_prepared_ast execute",
147                    error.is_some(),
148                ) {
149                    return return_with_desync(self, err);
150                }
151                match msg {
152                    crate::protocol::BackendMessage::BindComplete => {}
153                    crate::protocol::BackendMessage::RowDescription(_) => {}
154                    crate::protocol::BackendMessage::DataRow(data) => {
155                        if error.is_none() {
156                            rows.push(PgRow {
157                                columns: data,
158                                column_info: None,
159                            });
160                        }
161                    }
162                    crate::protocol::BackendMessage::CommandComplete(_) => {}
163                    crate::protocol::BackendMessage::NoData => {}
164                    crate::protocol::BackendMessage::ReadyForQuery(_) => {
165                        if let Some(err) = error {
166                            if !retried && err.is_prepared_statement_retryable() {
167                                retried = true;
168                                let stmt = self.connection.prepare(&prepared.sql).await?;
169                                self.connection
170                                    .stmt_cache
171                                    .put(prepared.sql_hash, stmt.name().to_string());
172                                self.connection
173                                    .prepared_statements
174                                    .insert(stmt.name().to_string(), prepared.sql.clone());
175                                break;
176                            }
177                            return Err(err);
178                        }
179                        return Ok(rows);
180                    }
181                    crate::protocol::BackendMessage::ErrorResponse(err) => {
182                        if error.is_none() {
183                            error = Some(PgError::QueryServer(err.into()));
184                        }
185                    }
186                    msg if is_ignorable_session_message(&msg) => {}
187                    other => {
188                        return return_with_desync(
189                            self,
190                            unexpected_backend_message(
191                                "driver fetch_all_prepared_ast execute",
192                                &other,
193                            ),
194                        );
195                    }
196                }
197            }
198        }
199    }
200
201    /// Execute a QAIL command and fetch all rows as a typed struct (text format).
202    /// Requires the target type to implement `QailRow` trait.
203    ///
204    /// # Example
205    /// ```ignore
206    /// let users: Vec<User> = driver.fetch_typed::<User>(&query).await?;
207    /// ```
208    pub async fn fetch_typed<T: super::row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
209        self.fetch_typed_with_format(cmd, ResultFormat::Text).await
210    }
211
212    /// Execute a QAIL command and fetch all rows as a typed struct with explicit result format.
213    ///
214    /// Use [`ResultFormat::Binary`] to get binary wire values; row decoding should use
215    /// metadata-aware accessors such as `PgRow::try_get()` / `try_get_by_name()`.
216    pub async fn fetch_typed_with_format<T: super::row::QailRow>(
217        &mut self,
218        cmd: &Qail,
219        result_format: ResultFormat,
220    ) -> PgResult<Vec<T>> {
221        let rows = self.fetch_all_with_format(cmd, result_format).await?;
222        Ok(rows.iter().map(T::from_row).collect())
223    }
224
225    /// Execute a QAIL command and fetch a single row as a typed struct (text format).
226    /// Returns None if no rows are returned.
227    pub async fn fetch_one_typed<T: super::row::QailRow>(
228        &mut self,
229        cmd: &Qail,
230    ) -> PgResult<Option<T>> {
231        self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
232            .await
233    }
234
235    /// Execute a QAIL command and fetch a single row as a typed struct with explicit result format.
236    pub async fn fetch_one_typed_with_format<T: super::row::QailRow>(
237        &mut self,
238        cmd: &Qail,
239        result_format: ResultFormat,
240    ) -> PgResult<Option<T>> {
241        let rows = self.fetch_all_with_format(cmd, result_format).await?;
242        Ok(rows.first().map(T::from_row))
243    }
244
245    /// Execute a QAIL command and fetch all rows (UNCACHED).
246    /// Sends Parse + Bind + Execute on every call.
247    /// Use for one-off queries or when caching is not desired.
248    ///
249    /// Optimized: encodes wire bytes into reusable write_buf (zero-alloc).
250    pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
251        self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
252            .await
253    }
254
255    /// Execute a QAIL command and fetch all rows (UNCACHED) with explicit result format.
256    pub async fn fetch_all_uncached_with_format(
257        &mut self,
258        cmd: &Qail,
259        result_format: ResultFormat,
260    ) -> PgResult<Vec<PgRow>> {
261        use crate::protocol::AstEncoder;
262
263        AstEncoder::encode_cmd_reuse_into_with_result_format(
264            cmd,
265            &mut self.connection.sql_buf,
266            &mut self.connection.params_buf,
267            &mut self.connection.write_buf,
268            result_format.as_wire_code(),
269        )
270        .map_err(|e| PgError::Encode(e.to_string()))?;
271
272        self.connection.flush_write_buf().await?;
273
274        let mut rows: Vec<PgRow> = Vec::with_capacity(32);
275        let mut column_info: Option<Arc<ColumnInfo>> = None;
276
277        let mut error: Option<PgError> = None;
278        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
279            super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
280        );
281
282        loop {
283            let msg = self.connection.recv().await?;
284            if let Err(err) = flow.validate(&msg, "driver fetch_all execute", error.is_some()) {
285                return return_with_desync(self, err);
286            }
287            match msg {
288                crate::protocol::BackendMessage::ParseComplete
289                | crate::protocol::BackendMessage::BindComplete => {}
290                crate::protocol::BackendMessage::RowDescription(fields) => {
291                    column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
292                }
293                crate::protocol::BackendMessage::DataRow(data) => {
294                    if error.is_none() {
295                        rows.push(PgRow {
296                            columns: data,
297                            column_info: column_info.clone(),
298                        });
299                    }
300                }
301                crate::protocol::BackendMessage::NoData => {}
302                crate::protocol::BackendMessage::CommandComplete(_) => {}
303                crate::protocol::BackendMessage::ReadyForQuery(_) => {
304                    if let Some(err) = error {
305                        return Err(err);
306                    }
307                    return Ok(rows);
308                }
309                crate::protocol::BackendMessage::ErrorResponse(err) => {
310                    if error.is_none() {
311                        error = Some(PgError::QueryServer(err.into()));
312                    }
313                }
314                msg if is_ignorable_session_message(&msg) => {}
315                other => {
316                    return return_with_desync(
317                        self,
318                        unexpected_backend_message("driver fetch_all execute", &other),
319                    );
320                }
321            }
322        }
323    }
324
325    /// Execute a QAIL command and fetch all rows (FAST VERSION).
326    /// Uses optimized recv_with_data_fast for faster response parsing.
327    /// Skips column metadata collection for maximum speed.
328    pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
329        self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
330            .await
331    }
332
333    /// Execute a QAIL command and fetch all rows (FAST VERSION) with explicit result format.
334    pub async fn fetch_all_fast_with_format(
335        &mut self,
336        cmd: &Qail,
337        result_format: ResultFormat,
338    ) -> PgResult<Vec<PgRow>> {
339        use crate::protocol::AstEncoder;
340
341        AstEncoder::encode_cmd_reuse_into_with_result_format(
342            cmd,
343            &mut self.connection.sql_buf,
344            &mut self.connection.params_buf,
345            &mut self.connection.write_buf,
346            result_format.as_wire_code(),
347        )
348        .map_err(|e| PgError::Encode(e.to_string()))?;
349
350        self.connection.flush_write_buf().await?;
351
352        // Collect results using FAST receiver
353        let mut rows: Vec<PgRow> = Vec::with_capacity(32);
354        let mut error: Option<PgError> = None;
355        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
356            super::extended_flow::ExtendedFlowConfig::parse_bind_execute(true),
357        );
358
359        loop {
360            let res = self.connection.recv_with_data_fast().await;
361            match res {
362                Ok((msg_type, data)) => {
363                    if let Err(err) = flow.validate_msg_type(
364                        msg_type,
365                        "driver fetch_all_fast execute",
366                        error.is_some(),
367                    ) {
368                        return return_with_desync(self, err);
369                    }
370                    match msg_type {
371                        b'D' => {
372                            if error.is_none()
373                                && let Some(columns) = data
374                            {
375                                rows.push(PgRow {
376                                    columns,
377                                    column_info: None,
378                                });
379                            }
380                        }
381                        b'Z' => {
382                            if let Some(err) = error {
383                                return Err(err);
384                            }
385                            return Ok(rows);
386                        }
387                        _ => {}
388                    }
389                }
390                Err(e) => {
391                    // QueryServer means backend sent ErrorResponse; keep draining to ReadyForQuery.
392                    if matches!(&e, PgError::QueryServer(_)) {
393                        if error.is_none() {
394                            error = Some(e);
395                        }
396                        continue;
397                    }
398                    return Err(e);
399                }
400            }
401        }
402    }
403
404    /// Execute a QAIL command and fetch one row.
405    pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
406        let rows = self.fetch_all(cmd).await?;
407        rows.into_iter().next().ok_or(PgError::NoRows)
408    }
409
410    /// Execute a QAIL command with PREPARED STATEMENT CACHING.
411    /// Like fetch_all(), but caches the prepared statement on the server.
412    /// On first call: sends Parse + Describe + Bind + Execute + Sync
413    /// On subsequent calls: sends only Bind + Execute + Sync (SKIPS Parse!)
414    /// Column metadata (RowDescription) is cached alongside the statement
415    /// so that by-name column access works on every call.
416    ///
417    /// Optimized: all wire messages are batched into a single write_all syscall.
418    pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
419        self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
420            .await
421    }
422
423    /// Execute a QAIL command with prepared statement caching and explicit result format.
424    pub async fn fetch_all_cached_with_format(
425        &mut self,
426        cmd: &Qail,
427        result_format: ResultFormat,
428    ) -> PgResult<Vec<PgRow>> {
429        let mut retried = false;
430        loop {
431            match self
432                .fetch_all_cached_with_format_once(cmd, result_format)
433                .await
434            {
435                Ok(rows) => return Ok(rows),
436                Err(err)
437                    if !retried
438                        && (err.is_prepared_statement_retryable()
439                            || err.is_prepared_statement_already_exists()) =>
440                {
441                    retried = true;
442                    if err.is_prepared_statement_retryable() {
443                        self.connection.clear_prepared_statement_state();
444                    }
445                }
446                Err(err) => return Err(err),
447            }
448        }
449    }
450
451    async fn fetch_all_cached_with_format_once(
452        &mut self,
453        cmd: &Qail,
454        result_format: ResultFormat,
455    ) -> PgResult<Vec<PgRow>> {
456        use crate::protocol::AstEncoder;
457        use std::collections::hash_map::DefaultHasher;
458        use std::hash::{Hash, Hasher};
459
460        if !AstEncoder::encode_cacheable_cmd_sql_to(
461            cmd,
462            &mut self.connection.sql_buf,
463            &mut self.connection.params_buf,
464        )? {
465            // Fallback for unsupported actions
466            let (sql, params) =
467                AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
468            let raw_rows = self
469                .connection
470                .query_cached_with_result_format(&sql, &params, result_format.as_wire_code())
471                .await?;
472            return Ok(raw_rows
473                .into_iter()
474                .map(|data| PgRow {
475                    columns: data,
476                    column_info: None,
477                })
478                .collect());
479        }
480
481        let mut hasher = DefaultHasher::new();
482        self.connection.sql_buf.hash(&mut hasher);
483        let sql_hash = hasher.finish();
484
485        let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
486
487        // Build ALL wire messages into write_buf (single syscall)
488        self.connection.write_buf.clear();
489
490        let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
491            name
492        } else {
493            let name = format!("qail_{:x}", sql_hash);
494
495            // Evict LRU before borrowing sql_buf to avoid borrow conflict
496            self.connection.evict_prepared_if_full();
497
498            let sql_str = encoded_sql_str(&self.connection.sql_buf)?;
499
500            // Buffer Parse + Describe(Statement) for first call
501            use crate::protocol::PgEncoder;
502            let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
503            let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
504            self.connection.write_buf.extend_from_slice(&parse_msg);
505            self.connection.write_buf.extend_from_slice(&describe_msg);
506
507            self.connection.stmt_cache.put(sql_hash, name.clone());
508            self.connection
509                .prepared_statements
510                .insert(name.clone(), sql_str.to_string());
511
512            name
513        };
514
515        // Append Bind + Execute + Sync to same buffer
516        use crate::protocol::PgEncoder;
517        if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
518            &mut self.connection.write_buf,
519            &stmt_name,
520            &self.connection.params_buf,
521            result_format.as_wire_code(),
522        ) {
523            if is_cache_miss {
524                self.connection.stmt_cache.remove(&sql_hash);
525                self.connection.prepared_statements.remove(&stmt_name);
526                self.connection.column_info_cache.remove(&sql_hash);
527            }
528            return Err(PgError::Encode(e.to_string()));
529        }
530        PgEncoder::encode_execute_to(&mut self.connection.write_buf);
531        PgEncoder::encode_sync_to(&mut self.connection.write_buf);
532
533        // Single write_all syscall for all messages
534        if let Err(err) = self.connection.flush_write_buf().await {
535            if is_cache_miss {
536                self.connection.stmt_cache.remove(&sql_hash);
537                self.connection.prepared_statements.remove(&stmt_name);
538                self.connection.column_info_cache.remove(&sql_hash);
539            }
540            return Err(err);
541        }
542
543        // On cache hit, use the previously cached ColumnInfo
544        let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
545
546        let mut rows: Vec<PgRow> = Vec::with_capacity(32);
547        let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
548        let mut error: Option<PgError> = None;
549        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
550            super::extended_flow::ExtendedFlowConfig::parse_describe_statement_bind_execute(
551                is_cache_miss,
552            ),
553        );
554
555        loop {
556            let msg = match self.connection.recv().await {
557                Ok(msg) => msg,
558                Err(err) => {
559                    if is_cache_miss && !flow.saw_parse_complete() {
560                        self.connection.stmt_cache.remove(&sql_hash);
561                        self.connection.prepared_statements.remove(&stmt_name);
562                        self.connection.column_info_cache.remove(&sql_hash);
563                    }
564                    return Err(err);
565                }
566            };
567            if let Err(err) =
568                flow.validate(&msg, "driver fetch_all_cached execute", error.is_some())
569            {
570                if is_cache_miss && !flow.saw_parse_complete() {
571                    self.connection.stmt_cache.remove(&sql_hash);
572                    self.connection.prepared_statements.remove(&stmt_name);
573                    self.connection.column_info_cache.remove(&sql_hash);
574                }
575                return return_with_desync(self, err);
576            }
577            match msg {
578                crate::protocol::BackendMessage::ParseComplete => {}
579                crate::protocol::BackendMessage::BindComplete => {}
580                crate::protocol::BackendMessage::ParameterDescription(_) => {
581                    // Sent after Describe(Statement) — ignore
582                }
583                crate::protocol::BackendMessage::RowDescription(fields) => {
584                    // Received after Describe(Statement) on cache miss
585                    let info = Arc::new(ColumnInfo::from_fields(&fields));
586                    if is_cache_miss {
587                        self.connection
588                            .column_info_cache
589                            .insert(sql_hash, Arc::clone(&info));
590                    }
591                    column_info = Some(info);
592                }
593                crate::protocol::BackendMessage::DataRow(data) => {
594                    if error.is_none() {
595                        rows.push(PgRow {
596                            columns: data,
597                            column_info: column_info.clone(),
598                        });
599                    }
600                }
601                crate::protocol::BackendMessage::CommandComplete(_) => {}
602                crate::protocol::BackendMessage::NoData => {
603                    // Sent by Describe for statements that return no data (e.g. pure UPDATE without RETURNING)
604                }
605                crate::protocol::BackendMessage::ReadyForQuery(_) => {
606                    if let Some(err) = error {
607                        if is_cache_miss
608                            && !flow.saw_parse_complete()
609                            && !err.is_prepared_statement_already_exists()
610                        {
611                            self.connection.stmt_cache.remove(&sql_hash);
612                            self.connection.prepared_statements.remove(&stmt_name);
613                            self.connection.column_info_cache.remove(&sql_hash);
614                        }
615                        return Err(err);
616                    }
617                    if is_cache_miss && !flow.saw_parse_complete() {
618                        self.connection.stmt_cache.remove(&sql_hash);
619                        self.connection.prepared_statements.remove(&stmt_name);
620                        self.connection.column_info_cache.remove(&sql_hash);
621                        return return_with_desync(
622                            self,
623                            PgError::Protocol(
624                                "Cache miss query reached ReadyForQuery without ParseComplete"
625                                    .to_string(),
626                            ),
627                        );
628                    }
629                    return Ok(rows);
630                }
631                crate::protocol::BackendMessage::ErrorResponse(err) => {
632                    if error.is_none() {
633                        let query_err = PgError::QueryServer(err.into());
634                        if query_err.is_prepared_statement_retryable() {
635                            self.connection.clear_prepared_statement_state();
636                        }
637                        error = Some(query_err);
638                    }
639                }
640                msg if is_ignorable_session_message(&msg) => {}
641                other => {
642                    if is_cache_miss && !flow.saw_parse_complete() {
643                        self.connection.stmt_cache.remove(&sql_hash);
644                        self.connection.prepared_statements.remove(&stmt_name);
645                        self.connection.column_info_cache.remove(&sql_hash);
646                    }
647                    return return_with_desync(
648                        self,
649                        unexpected_backend_message("driver fetch_all_cached execute", &other),
650                    );
651                }
652            }
653        }
654    }
655
656    /// Execute a QAIL command (for mutations) - ZERO-ALLOC.
657    pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
658        use crate::protocol::AstEncoder;
659
660        let wire_bytes = AstEncoder::encode_cmd_reuse(
661            cmd,
662            &mut self.connection.sql_buf,
663            &mut self.connection.params_buf,
664        )
665        .map_err(|e| PgError::Encode(e.to_string()))?;
666
667        self.connection.send_bytes(&wire_bytes).await?;
668
669        let mut affected = 0u64;
670        let mut error: Option<PgError> = None;
671        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
672            super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
673        );
674
675        loop {
676            let msg = self.connection.recv().await?;
677            if let Err(err) = flow.validate(&msg, "driver execute mutation", error.is_some()) {
678                return return_with_desync(self, err);
679            }
680            match msg {
681                crate::protocol::BackendMessage::ParseComplete
682                | crate::protocol::BackendMessage::BindComplete => {}
683                crate::protocol::BackendMessage::RowDescription(_) => {}
684                crate::protocol::BackendMessage::DataRow(_) => {}
685                crate::protocol::BackendMessage::NoData => {}
686                crate::protocol::BackendMessage::CommandComplete(tag) => {
687                    if error.is_none() {
688                        match super::parse_affected_rows(&tag) {
689                            Ok(parsed) => affected = parsed,
690                            Err(err) => return return_with_desync(self, err),
691                        }
692                    }
693                }
694                crate::protocol::BackendMessage::ReadyForQuery(_) => {
695                    if let Some(err) = error {
696                        return Err(err);
697                    }
698                    return Ok(affected);
699                }
700                crate::protocol::BackendMessage::ErrorResponse(err) => {
701                    if error.is_none() {
702                        error = Some(PgError::QueryServer(err.into()));
703                    }
704                }
705                msg if is_ignorable_session_message(&msg) => {}
706                other => {
707                    return return_with_desync(
708                        self,
709                        unexpected_backend_message("driver execute mutation", &other),
710                    );
711                }
712            }
713        }
714    }
715
716    /// Query a QAIL command and return rows (for SELECT/GET queries).
717    /// Like `execute()` but collects RowDescription + DataRow messages
718    /// instead of discarding them.
719    pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
720        self.query_ast_with_format(cmd, ResultFormat::Text).await
721    }
722
723    /// Query a QAIL command and return rows using an explicit result format.
724    pub async fn query_ast_with_format(
725        &mut self,
726        cmd: &Qail,
727        result_format: ResultFormat,
728    ) -> PgResult<QueryResult> {
729        use crate::protocol::AstEncoder;
730
731        let wire_bytes = AstEncoder::encode_cmd_reuse_with_result_format(
732            cmd,
733            &mut self.connection.sql_buf,
734            &mut self.connection.params_buf,
735            result_format.as_wire_code(),
736        )
737        .map_err(|e| PgError::Encode(e.to_string()))?;
738
739        self.connection.send_bytes(&wire_bytes).await?;
740
741        let mut columns: Vec<String> = Vec::new();
742        let mut rows: Vec<Vec<Option<String>>> = Vec::new();
743        let mut error: Option<PgError> = None;
744        let mut flow = super::extended_flow::ExtendedFlowTracker::new(
745            super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
746        );
747
748        loop {
749            let msg = self.connection.recv().await?;
750            if let Err(err) = flow.validate(&msg, "driver query_ast", error.is_some()) {
751                return return_with_desync(self, err);
752            }
753            match msg {
754                crate::protocol::BackendMessage::ParseComplete
755                | crate::protocol::BackendMessage::BindComplete => {}
756                crate::protocol::BackendMessage::RowDescription(fields) => {
757                    columns = fields.into_iter().map(|f| f.name).collect();
758                }
759                crate::protocol::BackendMessage::DataRow(data) => {
760                    if error.is_none() {
761                        let row: Vec<Option<String>> = data
762                            .into_iter()
763                            .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
764                            .collect();
765                        rows.push(row);
766                    }
767                }
768                crate::protocol::BackendMessage::CommandComplete(_) => {}
769                crate::protocol::BackendMessage::NoData => {}
770                crate::protocol::BackendMessage::ReadyForQuery(_) => {
771                    if let Some(err) = error {
772                        return Err(err);
773                    }
774                    return Ok(QueryResult { columns, rows });
775                }
776                crate::protocol::BackendMessage::ErrorResponse(err) => {
777                    if error.is_none() {
778                        error = Some(PgError::QueryServer(err.into()));
779                    }
780                }
781                msg if is_ignorable_session_message(&msg) => {}
782                other => {
783                    return return_with_desync(
784                        self,
785                        unexpected_backend_message("driver query_ast", &other),
786                    );
787                }
788            }
789        }
790    }
791}
792
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn driver_encoded_sql_str_rejects_invalid_utf8() {
        let invalid: &[u8] = &[0xff];
        let message = encoded_sql_str(invalid)
            .expect_err("invalid SQL UTF-8 must fail")
            .to_string();
        assert!(message.contains("encoded SQL is not UTF-8"));
    }

    /// Builds a driver over one end of a Unix socket pair. The other end is
    /// returned so the caller can keep the socket open for the whole test.
    #[cfg(unix)]
    fn test_driver_with_peer() -> (PgDriver, tokio::net::UnixStream) {
        use super::super::PgConnection;
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (local, peer) = UnixStream::pair().expect("unix stream pair");
        let cache_capacity = NonZeroUsize::new(2).expect("non-zero");
        let connection = PgConnection {
            stream: PgStream::Unix(local),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(cache_capacity),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        (PgDriver::new(connection), peer)
    }

    /// Appends one raw backend frame to the driver's receive buffer: the type
    /// byte, a big-endian `u32` length that counts its own 4 bytes, then the body.
    #[cfg(unix)]
    fn queue_frame(driver: &mut PgDriver, kind: u8, body: &[u8]) {
        let frame_len = (body.len() + 4) as u32;
        let buf = &mut driver.connection.buffer;
        buf.extend_from_slice(&[kind]);
        buf.extend_from_slice(&frame_len.to_be_bytes());
        buf.extend_from_slice(body);
    }

    /// Queues a `CommandComplete` ('C') frame carrying `tag` as a NUL-terminated string.
    #[cfg(unix)]
    fn queue_command_complete(driver: &mut PgDriver, tag: &str) {
        let mut body = tag.as_bytes().to_vec();
        body.push(0);
        queue_frame(driver, b'C', &body);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn fetch_fast_protocol_error_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        // A zero-column DataRow ('D') with no preceding BindComplete.
        queue_frame(&mut driver, b'D', &0i16.to_be_bytes());

        let err = driver
            .fetch_all_fast(&Qail::get("users"))
            .await
            .err()
            .expect("out-of-order DataRow must fail");

        assert!(err.to_string().contains("DataRow before BindComplete"));
        assert!(driver.connection.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_bad_command_tag_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        // ParseComplete ('1'), BindComplete ('2') and NoData ('n') carry no body.
        for kind in [b'1', b'2', b'n'] {
            queue_frame(&mut driver, kind, &[]);
        }
        // "UPDATE" with no row count is the malformed tag under test.
        queue_command_complete(&mut driver, "UPDATE");
        queue_frame(&mut driver, b'Z', b"I");

        let message = driver
            .execute(&Qail::get("users"))
            .await
            .expect_err("malformed CommandComplete tag must fail")
            .to_string();

        let tag_rejected = message.contains("missing affected row count")
            || message.contains("invalid affected row count");
        assert!(tag_rejected);
        assert!(driver.connection.is_io_desynced());
    }
}