// qail_pg/driver/pipeline.rs

1//! High-performance pipelining methods for PostgreSQL connection.
2//!
3//!
4//! Performance hierarchy (fastest to slowest):
5//! 1. `pipeline_execute_count_ast_cached` - Parse once, Bind+Execute many (275k q/s)
6//! 2. `pipeline_execute_count_simple_wire` - Pre-encoded simple query
7//! 3. `pipeline_execute_count_wire` - Pre-encoded extended query
8//! 4. `pipeline_execute_count_simple_ast` - Simple query protocol (~99k q/s)
9//! 5. `pipeline_execute_count_ast_oneshot` - Fast extended query, count only
10//! 6. `pipeline_execute_rows_ast` - Full results collection
11//! 7. `query_pipeline` - SQL-based pipelining
12
13use super::{
14    PgConnection, PgError, PgResult, is_ignorable_session_message, is_ignorable_session_msg_type,
15    unexpected_backend_message, unexpected_backend_msg_type,
16};
17use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
18use bytes::BytesMut;
19
/// Strategy for AST pipeline execution.
///
/// `Auto` favors the cached prepared-statement path for large batches and
/// one-shot execution for tiny batches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AstPipelineMode {
    /// Heuristic strategy:
    /// - small batch => `OneShot`
    /// - larger batch => `Cached`
    #[default]
    Auto,
    /// Parse+Bind+Execute for each command in the batch.
    OneShot,
    /// Cache prepared SQL templates and execute Bind+Execute in hot path.
    Cached,
}

impl AstPipelineMode {
    /// Smallest batch size for which `Auto` chooses the cached path.
    const AUTO_CACHE_MIN_BATCH: usize = 8;

    /// Collapse `Auto` into a concrete strategy for a batch of `batch_len`.
    #[inline]
    fn resolve_for_batch_len(self, batch_len: usize) -> Self {
        // Explicitly requested modes pass through untouched.
        if self != Self::Auto {
            return self;
        }
        // Heuristic: the Parse/cache overhead only amortizes on larger batches.
        if batch_len >= Self::AUTO_CACHE_MIN_BATCH {
            Self::Cached
        } else {
            Self::OneShot
        }
    }
}
54
55#[inline]
56fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
57    if matches!(
58        err,
59        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
60    ) {
61        conn.mark_io_desynced();
62    }
63    Err(err)
64}
65
66#[inline]
67fn capture_query_server_error(conn: &mut PgConnection, slot: &mut Option<PgError>, err: PgError) {
68    if slot.is_some() {
69        return;
70    }
71    if err.is_prepared_statement_retryable() {
72        conn.clear_prepared_statement_state();
73    }
74    *slot = Some(err);
75}
76
77#[inline]
78fn rollback_new_cached_statements(conn: &mut PgConnection, new_stmt_hashes: &[u64]) {
79    for sql_hash in new_stmt_hashes {
80        if let Some(stmt_name) = conn.stmt_cache.remove(sql_hash) {
81            conn.prepared_statements.remove(&stmt_name);
82        }
83    }
84}
85
/// Knobs controlling which extended-protocol messages a
/// `FastExtendedFlowTracker` accepts and how it counts completions.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowConfig {
    /// Number of queries the pipeline must complete before ReadyForQuery.
    expected_queries: usize,
    /// Whether ParseComplete ('1') messages are legal in this flow.
    allow_parse_complete: bool,
    /// Require ParseComplete before BindComplete for each query.
    require_parse_before_bind: bool,
    /// Treat NoData ('n') as completing the current query.
    no_data_counts_as_completion: bool,
    /// Permit NoData as a non-terminal message (when it does not count as a
    /// completion); otherwise NoData in that position is a protocol error.
    allow_no_data_nonterminal: bool,
    /// Exact total of ParseComplete messages expected, when known up front.
    expected_parse_completes: Option<usize>,
}
95
/// Per-message state machine that validates an extended-protocol pipeline
/// response stream using only message-type bytes.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowTracker {
    cfg: FastExtendedFlowConfig,
    /// Queries fully completed so far ('C', or 'n' when configured to count).
    completed_queries: usize,
    /// Total ParseComplete ('1') messages observed.
    parse_completes: usize,
    /// ParseComplete seen for the in-flight query.
    current_parse_seen: bool,
    /// BindComplete seen for the in-flight query.
    current_bind_seen: bool,
}
104
impl FastExtendedFlowTracker {
    /// Start tracking a fresh extended-protocol response stream.
    fn new(cfg: FastExtendedFlowConfig) -> Self {
        Self {
            cfg,
            completed_queries: 0,
            parse_completes: 0,
            current_parse_seen: false,
            current_bind_seen: false,
        }
    }

    /// Queries fully completed so far.
    fn completed_queries(&self) -> usize {
        self.completed_queries
    }

    /// Validate the next backend message-type byte against the configured
    /// extended-protocol flow.
    ///
    /// Returns `Ok(FastPipelineEvent::ReadyForQuery)` on the terminal 'Z'
    /// message, `Ok(FastPipelineEvent::Continue)` for any other acceptable
    /// message, and `PgError::Protocol` on an ordering violation. With
    /// `error_pending` set (a server error was already captured), every
    /// message is skipped until 'Z' so the stream can drain.
    fn validate_msg_type(
        &mut self,
        msg_type: u8,
        context: &'static str,
        error_pending: bool,
    ) -> PgResult<FastPipelineEvent> {
        // Session chatter never affects pipeline progress.
        if is_ignorable_session_msg_type(msg_type) {
            return Ok(FastPipelineEvent::Continue);
        }

        // After a server error the backend discards the rest of the batch;
        // only the terminal ReadyForQuery matters.
        if error_pending {
            if msg_type == b'Z' {
                return Ok(FastPipelineEvent::ReadyForQuery);
            }
            return Ok(FastPipelineEvent::Continue);
        }

        if msg_type == b'Z' {
            // Terminal message: all bookkeeping must balance exactly.
            if self.completed_queries != self.cfg.expected_queries {
                return Err(PgError::Protocol(format!(
                    "{}: Pipeline completion mismatch: expected {}, got {}",
                    context, self.cfg.expected_queries, self.completed_queries
                )));
            }
            if self.current_parse_seen || self.current_bind_seen {
                return Err(PgError::Protocol(format!(
                    "{}: ReadyForQuery with incomplete query state",
                    context
                )));
            }
            if let Some(expected) = self.cfg.expected_parse_completes
                && self.parse_completes != expected
            {
                return Err(PgError::Protocol(format!(
                    "{}: ParseComplete mismatch: expected {}, got {}",
                    context, expected, self.parse_completes
                )));
            }
            return Ok(FastPipelineEvent::ReadyForQuery);
        }

        // No non-terminal message may follow the final query's completion.
        if self.completed_queries >= self.cfg.expected_queries {
            return Err(PgError::Protocol(format!(
                "{}: unexpected message '{}' after all queries completed",
                context, msg_type as char
            )));
        }

        match msg_type {
            // ParseComplete: must be allowed, at most one per query, and
            // must precede BindComplete.
            b'1' => {
                if !self.cfg.allow_parse_complete {
                    return Err(PgError::Protocol(format!(
                        "{}: unexpected ParseComplete",
                        context
                    )));
                }
                if self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: ParseComplete after BindComplete",
                        context
                    )));
                }
                if self.current_parse_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate ParseComplete",
                        context
                    )));
                }
                self.current_parse_seen = true;
                self.parse_completes += 1;
                // Fail fast when we already exceed the expected Parse total.
                if let Some(expected) = self.cfg.expected_parse_completes
                    && self.parse_completes > expected
                {
                    return Err(PgError::Protocol(format!(
                        "{}: ParseComplete mismatch: expected {}, got at least {}",
                        context, expected, self.parse_completes
                    )));
                }
            }
            // BindComplete: at most one per query; optionally gated on a
            // preceding ParseComplete.
            b'2' => {
                if self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate BindComplete",
                        context
                    )));
                }
                if self.cfg.require_parse_before_bind && !self.current_parse_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: BindComplete before ParseComplete",
                        context
                    )));
                }
                self.current_bind_seen = true;
            }
            // RowDescription / ParameterDescription / PortalSuspended are
            // only valid once the portal is bound.
            b'T' | b't' | b's' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: '{}' before BindComplete",
                        context, msg_type as char
                    )));
                }
            }
            b'D' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: DataRow before BindComplete",
                        context
                    )));
                }
            }
            // NoData may terminate a rowless query, depending on config.
            b'n' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: NoData before BindComplete",
                        context
                    )));
                }
                if self.cfg.no_data_counts_as_completion {
                    self.complete_current();
                } else if !self.cfg.allow_no_data_nonterminal {
                    return Err(PgError::Protocol(format!("{}: unexpected NoData", context)));
                }
            }
            // CommandComplete finishes the in-flight query.
            b'C' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: CommandComplete before BindComplete",
                        context
                    )));
                }
                self.complete_current();
            }
            // EmptyQueryResponse belongs to the simple protocol only.
            b'I' => {
                return Err(PgError::Protocol(format!(
                    "{}: unexpected EmptyQueryResponse in extended pipeline",
                    context
                )));
            }
            other => return Err(unexpected_backend_msg_type(context, other)),
        }

        Ok(FastPipelineEvent::Continue)
    }

    /// Mark the in-flight query finished and reset its per-query flags.
    fn complete_current(&mut self) {
        self.completed_queries += 1;
        self.current_parse_seen = false;
        self.current_bind_seen = false;
    }
}
270
/// Minimal state machine validating a simple-protocol ('Q') pipeline
/// response stream by message-type byte.
#[derive(Debug, Clone, Copy)]
struct FastSimpleFlowTracker {
    /// Queries the batch must complete before ReadyForQuery.
    expected_queries: usize,
    /// Queries completed so far ('C' or 'I').
    completed_queries: usize,
    /// RowDescription seen for the in-flight query's row stream.
    current_row_description_seen: bool,
}

impl FastSimpleFlowTracker {
    /// Start tracking a fresh simple-protocol response stream.
    fn new(expected_queries: usize) -> Self {
        Self {
            expected_queries,
            completed_queries: 0,
            current_row_description_seen: false,
        }
    }

    /// Queries fully completed so far.
    fn completed_queries(&self) -> usize {
        self.completed_queries
    }

    /// Validate the next backend message-type byte.
    ///
    /// Returns `Ok(ReadyForQuery)` on the terminal 'Z', `Ok(Continue)` for
    /// any other acceptable message, or `PgError::Protocol` on an ordering
    /// violation. With `error_pending` set, messages are skipped until 'Z'.
    fn validate_msg_type(
        &mut self,
        msg_type: u8,
        context: &'static str,
        error_pending: bool,
    ) -> PgResult<FastPipelineEvent> {
        // Session chatter never affects pipeline progress.
        if is_ignorable_session_msg_type(msg_type) {
            return Ok(FastPipelineEvent::Continue);
        }

        // After a server error, just drain until ReadyForQuery.
        if error_pending {
            if msg_type == b'Z' {
                return Ok(FastPipelineEvent::ReadyForQuery);
            }
            return Ok(FastPipelineEvent::Continue);
        }

        if msg_type == b'Z' {
            // Terminal message: completion count must balance and no row
            // stream may be left open.
            if self.completed_queries != self.expected_queries {
                return Err(PgError::Protocol(format!(
                    "{}: Pipeline completion mismatch: expected {}, got {}",
                    context, self.expected_queries, self.completed_queries
                )));
            }
            if self.current_row_description_seen {
                return Err(PgError::Protocol(format!(
                    "{}: ReadyForQuery with incomplete row stream",
                    context
                )));
            }
            return Ok(FastPipelineEvent::ReadyForQuery);
        }

        // No non-terminal message may follow the final query's completion.
        if self.completed_queries >= self.expected_queries {
            return Err(PgError::Protocol(format!(
                "{}: unexpected message '{}' after all queries completed",
                context, msg_type as char
            )));
        }

        match msg_type {
            // RowDescription opens the row stream; at most one per query.
            b'T' => {
                if self.current_row_description_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate RowDescription",
                        context
                    )));
                }
                self.current_row_description_seen = true;
            }
            // DataRow is only valid inside an open row stream.
            b'D' => {
                if !self.current_row_description_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: DataRow before RowDescription",
                        context
                    )));
                }
            }
            // CommandComplete / EmptyQueryResponse both finish a query.
            b'C' | b'I' => {
                self.completed_queries += 1;
                self.current_row_description_seen = false;
            }
            // Extended-protocol messages never appear in a simple pipeline.
            b'1' | b'2' | b'n' | b't' | b's' => {
                return Err(PgError::Protocol(format!(
                    "{}: unexpected '{}' in simple pipeline",
                    context, msg_type as char
                )));
            }
            other => return Err(unexpected_backend_msg_type(context, other)),
        }

        Ok(FastPipelineEvent::Continue)
    }
}
365
/// Outcome of validating one backend message in a fast pipeline loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FastPipelineEvent {
    /// Keep reading messages.
    Continue,
    /// Terminal ReadyForQuery observed; the response stream is finished.
    ReadyForQuery,
}
371
/// Map a decoded `BackendMessage` to its wire message-type byte for the flow
/// trackers; `None` means the message has no place in any pipeline flow and
/// the caller should report it as unexpected.
#[inline]
fn backend_msg_type_for_flow(msg: &BackendMessage) -> Option<u8> {
    match msg {
        BackendMessage::ParseComplete => Some(b'1'),
        BackendMessage::BindComplete => Some(b'2'),
        BackendMessage::ParameterDescription(_) => Some(b't'),
        BackendMessage::RowDescription(_) => Some(b'T'),
        BackendMessage::NoData => Some(b'n'),
        BackendMessage::PortalSuspended => Some(b's'),
        BackendMessage::DataRow(_) => Some(b'D'),
        BackendMessage::CommandComplete(_) => Some(b'C'),
        BackendMessage::EmptyQueryResponse => Some(b'I'),
        BackendMessage::ReadyForQuery(_) => Some(b'Z'),
        _ => None,
    }
}
388
389impl PgConnection {
390    /// Execute multiple SQL queries in a single network round-trip (PIPELINING).
391    pub async fn query_pipeline(
392        &mut self,
393        queries: &[(&str, &[Option<Vec<u8>>])],
394    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
395        // Encode all queries into a single buffer
396        let mut buf = BytesMut::new();
397        for (sql, params) in queries {
398            buf.extend_from_slice(
399                &PgEncoder::encode_extended_query(sql, params)
400                    .map_err(|e| PgError::Encode(e.to_string()))?,
401            );
402        }
403
404        // Send all queries in ONE write
405        self.write_all_with_timeout(&buf, "stream write").await?;
406
407        // Collect all results
408        let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
409        let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
410        let mut error: Option<PgError> = None;
411        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
412            expected_queries: queries.len(),
413            allow_parse_complete: true,
414            require_parse_before_bind: true,
415            no_data_counts_as_completion: true,
416            allow_no_data_nonterminal: false,
417            expected_parse_completes: Some(queries.len()),
418        });
419
420        loop {
421            let msg = self.recv().await?;
422            if is_ignorable_session_message(&msg) {
423                continue;
424            }
425            if let BackendMessage::ErrorResponse(err) = msg {
426                if error.is_none() {
427                    error = Some(PgError::QueryServer(err.into()));
428                }
429                continue;
430            }
431            let msg_type = backend_msg_type_for_flow(&msg)
432                .ok_or_else(|| unexpected_backend_message("pipeline query", &msg));
433            let msg_type = match msg_type {
434                Ok(msg_type) => msg_type,
435                Err(err) => return return_with_desync(self, err),
436            };
437            if let Err(err) = flow.validate_msg_type(msg_type, "pipeline query", error.is_some()) {
438                return return_with_desync(self, err);
439            }
440            match msg {
441                BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
442                BackendMessage::RowDescription(_) => {}
443                BackendMessage::DataRow(data) => {
444                    if error.is_none() {
445                        current_rows.push(data);
446                    }
447                }
448                BackendMessage::CommandComplete(_) => {
449                    all_results.push(std::mem::take(&mut current_rows));
450                }
451                BackendMessage::NoData => {
452                    all_results.push(Vec::new());
453                }
454                BackendMessage::ReadyForQuery(_) => {
455                    if all_results.len() != queries.len() {
456                        return Err(error.unwrap_or_else(|| {
457                            PgError::Protocol(format!(
458                                "Pipeline completion mismatch: expected {}, got {}",
459                                queries.len(),
460                                all_results.len()
461                            ))
462                        }));
463                    }
464                    if let Some(err) = error {
465                        return Err(err);
466                    }
467                    return Ok(all_results);
468                }
469                other => {
470                    return return_with_desync(
471                        self,
472                        unexpected_backend_message("pipeline query", &other),
473                    );
474                }
475            }
476        }
477    }
478
479    /// Execute multiple Qail ASTs in a single network round-trip.
480    pub async fn pipeline_execute_rows_ast(
481        &mut self,
482        cmds: &[qail_core::ast::Qail],
483    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
484        let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
485        self.write_all_with_timeout(&buf, "stream write").await?;
486
487        let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
488        let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
489        let mut error: Option<PgError> = None;
490        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
491            expected_queries: cmds.len(),
492            allow_parse_complete: true,
493            require_parse_before_bind: true,
494            no_data_counts_as_completion: true,
495            allow_no_data_nonterminal: false,
496            expected_parse_completes: Some(cmds.len()),
497        });
498
499        loop {
500            let msg = self.recv().await?;
501            if is_ignorable_session_message(&msg) {
502                continue;
503            }
504            if let BackendMessage::ErrorResponse(err) = msg {
505                if error.is_none() {
506                    error = Some(PgError::QueryServer(err.into()));
507                }
508                continue;
509            }
510            let msg_type = backend_msg_type_for_flow(&msg)
511                .ok_or_else(|| unexpected_backend_message("pipeline ast", &msg));
512            let msg_type = match msg_type {
513                Ok(msg_type) => msg_type,
514                Err(err) => return return_with_desync(self, err),
515            };
516            if let Err(err) = flow.validate_msg_type(msg_type, "pipeline ast", error.is_some()) {
517                return return_with_desync(self, err);
518            }
519            match msg {
520                BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
521                BackendMessage::RowDescription(_) => {}
522                BackendMessage::DataRow(data) => {
523                    if error.is_none() {
524                        current_rows.push(data);
525                    }
526                }
527                BackendMessage::CommandComplete(_) => {
528                    all_results.push(std::mem::take(&mut current_rows));
529                }
530                BackendMessage::NoData => {
531                    all_results.push(Vec::new());
532                }
533                BackendMessage::ReadyForQuery(_) => {
534                    if all_results.len() != cmds.len() {
535                        return Err(error.unwrap_or_else(|| {
536                            PgError::Protocol(format!(
537                                "Pipeline completion mismatch: expected {}, got {}",
538                                cmds.len(),
539                                all_results.len()
540                            ))
541                        }));
542                    }
543                    if let Some(err) = error {
544                        return Err(err);
545                    }
546                    return Ok(all_results);
547                }
548                other => {
549                    return return_with_desync(
550                        self,
551                        unexpected_backend_message("pipeline ast", &other),
552                    );
553                }
554            }
555        }
556    }
557
    /// FAST AST pipeline - returns only query count, no result parsing.
    ///
    /// Encodes every command via `AstEncoder::encode_batch`, writes and
    /// flushes the batch once, then validates the response stream by
    /// message-type byte only (`recv_msg_type_fast`) without decoding row
    /// payloads.
    ///
    /// # Errors
    /// The first server error is captured and returned once the stream drains
    /// to `ReadyForQuery`; protocol violations mark the connection desynced.
    pub async fn pipeline_execute_count_ast_oneshot(
        &mut self,
        cmds: &[qail_core::ast::Qail],
    ) -> PgResult<usize> {
        let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;

        self.write_all_with_timeout(&buf, "stream write").await?;
        self.flush_with_timeout("stream flush").await?;

        let mut error: Option<PgError> = None;
        // One Parse per command, so the tracker can demand an exact
        // ParseComplete total.
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: cmds.len(),
            allow_parse_complete: true,
            require_parse_before_bind: true,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(cmds.len()),
        });

        loop {
            match self.recv_msg_type_fast().await {
                Ok(msg_type) => {
                    let event = match flow.validate_msg_type(
                        msg_type,
                        "pipeline_execute_count_ast_oneshot",
                        error.is_some(),
                    ) {
                        Ok(event) => event,
                        Err(err) => return return_with_desync(self, err),
                    };
                    match event {
                        FastPipelineEvent::Continue => {}
                        FastPipelineEvent::ReadyForQuery => {
                            // Surface the first captured server error, if any.
                            if let Some(err) = error {
                                return Err(err);
                            }
                            return Ok(flow.completed_queries());
                        }
                    }
                }
                Err(e) => {
                    // Server errors are captured so the stream can drain;
                    // anything else aborts immediately.
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    return Err(e);
                }
            }
        }
    }
609
610    /// Execute AST pipeline with explicit strategy mode.
611    ///
612    /// `Auto` uses a lightweight batch-size heuristic:
613    /// - `< 8` queries: one-shot path (`pipeline_execute_count_ast_oneshot`)
614    /// - `>= 8` queries: cached path (`pipeline_execute_count_ast_cached`)
615    #[inline]
616    pub async fn pipeline_execute_count_ast_with_mode(
617        &mut self,
618        cmds: &[qail_core::ast::Qail],
619        mode: AstPipelineMode,
620    ) -> PgResult<usize> {
621        if cmds.is_empty() {
622            return Ok(0);
623        }
624
625        match mode.resolve_for_batch_len(cmds.len()) {
626            AstPipelineMode::OneShot => self.pipeline_execute_count_ast_oneshot(cmds).await,
627            AstPipelineMode::Cached => self.pipeline_execute_count_ast_cached(cmds).await,
628            AstPipelineMode::Auto => unreachable!("Auto mode must resolve to concrete strategy"),
629        }
630    }
631
632    /// FASTEST extended query pipeline - takes pre-encoded wire bytes.
633    #[inline]
634    pub async fn pipeline_execute_count_wire(
635        &mut self,
636        wire_bytes: &[u8],
637        expected_queries: usize,
638    ) -> PgResult<usize> {
639        self.write_all_with_timeout(wire_bytes, "stream write")
640            .await?;
641        self.flush_with_timeout("stream flush").await?;
642
643        let mut error: Option<PgError> = None;
644        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
645            expected_queries,
646            allow_parse_complete: true,
647            require_parse_before_bind: false,
648            no_data_counts_as_completion: true,
649            allow_no_data_nonterminal: false,
650            expected_parse_completes: None,
651        });
652
653        loop {
654            match self.recv_msg_type_fast().await {
655                Ok(msg_type) => {
656                    let event = match flow.validate_msg_type(
657                        msg_type,
658                        "pipeline_execute_count_wire",
659                        error.is_some(),
660                    ) {
661                        Ok(event) => event,
662                        Err(err) => return return_with_desync(self, err),
663                    };
664                    match event {
665                        FastPipelineEvent::Continue => {}
666                        FastPipelineEvent::ReadyForQuery => {
667                            if let Some(err) = error {
668                                return Err(err);
669                            }
670                            return Ok(flow.completed_queries());
671                        }
672                    }
673                }
674                Err(e) => {
675                    if matches!(&e, PgError::QueryServer(_)) {
676                        capture_query_server_error(self, &mut error, e);
677                        continue;
678                    }
679                    return Err(e);
680                }
681            }
682        }
683    }
684
    /// Simple query protocol pipeline - uses 'Q' message.
    ///
    /// Encodes the batch via `AstEncoder::encode_batch_simple`, writes and
    /// flushes once, then counts completed queries by validating message-type
    /// bytes only (no payload decoding).
    ///
    /// # Errors
    /// The first server error is captured and returned once the stream drains
    /// to `ReadyForQuery`; protocol violations mark the connection desynced.
    #[inline]
    pub async fn pipeline_execute_count_simple_ast(
        &mut self,
        cmds: &[qail_core::ast::Qail],
    ) -> PgResult<usize> {
        // Empty batches complete trivially without touching the wire.
        if cmds.is_empty() {
            return Ok(0);
        }

        let buf =
            AstEncoder::encode_batch_simple(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
        self.write_all_with_timeout(&buf, "stream write").await?;
        self.flush_with_timeout("stream flush").await?;

        let mut error: Option<PgError> = None;
        let mut flow = FastSimpleFlowTracker::new(cmds.len());

        loop {
            match self.recv_msg_type_fast().await {
                Ok(msg_type) => {
                    let event = match flow.validate_msg_type(
                        msg_type,
                        "pipeline_execute_count_simple_ast",
                        error.is_some(),
                    ) {
                        Ok(event) => event,
                        Err(err) => return return_with_desync(self, err),
                    };
                    match event {
                        FastPipelineEvent::Continue => {}
                        FastPipelineEvent::ReadyForQuery => {
                            // Surface the first captured server error, if any.
                            if let Some(err) = error {
                                return Err(err);
                            }
                            return Ok(flow.completed_queries());
                        }
                    }
                }
                Err(e) => {
                    // Capture server errors so the stream can drain; anything
                    // else aborts immediately.
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    return Err(e);
                }
            }
        }
    }
734
735    /// FASTEST simple query pipeline - takes pre-encoded bytes.
736    #[inline]
737    pub async fn pipeline_execute_count_simple_wire(
738        &mut self,
739        wire_bytes: &[u8],
740        expected_queries: usize,
741    ) -> PgResult<usize> {
742        self.write_all_with_timeout(wire_bytes, "stream write")
743            .await?;
744        self.flush_with_timeout("stream flush").await?;
745
746        let mut error: Option<PgError> = None;
747        let mut flow = FastSimpleFlowTracker::new(expected_queries);
748
749        loop {
750            match self.recv_msg_type_fast().await {
751                Ok(msg_type) => {
752                    let event = match flow.validate_msg_type(
753                        msg_type,
754                        "pipeline_execute_count_simple_wire",
755                        error.is_some(),
756                    ) {
757                        Ok(event) => event,
758                        Err(err) => return return_with_desync(self, err),
759                    };
760                    match event {
761                        FastPipelineEvent::Continue => {}
762                        FastPipelineEvent::ReadyForQuery => {
763                            if let Some(err) = error {
764                                return Err(err);
765                            }
766                            return Ok(flow.completed_queries());
767                        }
768                    }
769                }
770                Err(e) => {
771                    if matches!(&e, PgError::QueryServer(_)) {
772                        capture_query_server_error(self, &mut error, e);
773                        continue;
774                    }
775                    return Err(e);
776                }
777            }
778        }
779    }
780
    /// CACHED PREPARED STATEMENT pipeline - Parse once, Bind+Execute many.
    /// 1. Generate SQL template with $1, $2, etc. placeholders
    /// 2. Parse template ONCE (cached in PostgreSQL)
    /// 3. Send Bind+Execute for each instance (params differ per query)
    ///
    /// Cache bookkeeping performed here (`stmt_cache` / `prepared_statements`
    /// inserts) is undone via `rollback_new_cached_statements` on every error
    /// path, so a failed batch leaves the connection's cache state untouched.
    ///
    /// # Errors
    /// Returns `PgError::Encode` for unencodable commands, `PgError::Protocol`
    /// on unexpected backend traffic, or the first captured server error.
    #[inline]
    pub async fn pipeline_execute_count_ast_cached(
        &mut self,
        cmds: &[qail_core::ast::Qail],
    ) -> PgResult<usize> {
        if cmds.is_empty() {
            return Ok(0);
        }

        use super::prepared::{sql_bytes_hash, stmt_name_from_hash};

        let mut buf = BytesMut::with_capacity(cmds.len() * 64);
        let mut sql_buf = BytesMut::with_capacity(256);
        let mut params: Vec<Option<Vec<u8>>> = Vec::new();
        // Hashes of statements newly registered by THIS call; used to undo
        // the cache inserts if the batch fails before completing.
        let mut new_stmt_hashes: Vec<u64> = Vec::new();

        for cmd in cmds {
            // Render the command to SQL text + parameter list, reusing buffers.
            if let Err(e) = AstEncoder::encode_cmd_sql_reuse(cmd, &mut sql_buf, &mut params) {
                rollback_new_cached_statements(self, &new_stmt_hashes);
                return Err(PgError::Encode(e.to_string()));
            }

            let sql_hash = sql_bytes_hash(sql_buf.as_ref());

            if self.stmt_cache.contains(&sql_hash) {
                // Already prepared: refresh its LRU position only.
                self.stmt_cache.touch_key(sql_hash);
            } else {
                let stmt_name = stmt_name_from_hash(sql_hash);
                if self.prepared_statements.contains_key(&stmt_name) {
                    // Recover from old cache states where prepared_statements had
                    // entries that were not mirrored in stmt_cache.
                    self.stmt_cache.put(sql_hash, stmt_name.clone());
                } else {
                    self.evict_prepared_if_full();

                    let sql = String::from_utf8_lossy(sql_buf.as_ref()).to_string();
                    // Queue a Parse for the new template; it is sent once and
                    // reused by every Bind+Execute that follows.
                    let parse_msg = match PgEncoder::try_encode_parse(&stmt_name, &sql, &[]) {
                        Ok(msg) => msg,
                        Err(e) => {
                            rollback_new_cached_statements(self, &new_stmt_hashes);
                            return Err(PgError::Encode(e.to_string()));
                        }
                    };
                    buf.extend(parse_msg);
                    self.stmt_cache.put(sql_hash, stmt_name.clone());
                    self.prepared_statements.insert(stmt_name.clone(), sql);
                    new_stmt_hashes.push(sql_hash);
                }
            }

            // The entry was inserted just above, so a miss here means a cache bug.
            let Some(stmt_name) = self.stmt_cache.peek(&sql_hash) else {
                rollback_new_cached_statements(self, &new_stmt_hashes);
                return Err(PgError::Protocol(
                    "stmt_cache lookup failed after statement registration".to_string(),
                ));
            };

            if let Err(e) = PgEncoder::encode_bind_to(&mut buf, stmt_name, &params) {
                rollback_new_cached_statements(self, &new_stmt_hashes);
                return Err(PgError::Encode(e.to_string()));
            }
            PgEncoder::encode_execute_to(&mut buf);
        }

        PgEncoder::encode_sync_to(&mut buf);

        if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
            rollback_new_cached_statements(self, &new_stmt_hashes);
            return Err(err);
        }
        if let Err(err) = self.flush_with_timeout("stream flush").await {
            rollback_new_cached_statements(self, &new_stmt_hashes);
            return Err(err);
        }

        // Drain responses: one ParseComplete per newly prepared statement plus
        // per-command Bind/Execute results, terminated by ReadyForQuery.
        let mut error: Option<PgError> = None;
        let expected_parse_completes = new_stmt_hashes.len();
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: cmds.len(),
            allow_parse_complete: true,
            require_parse_before_bind: false,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(expected_parse_completes),
        });

        loop {
            match self.recv_msg_type_fast().await {
                Ok(msg_type) => {
                    match flow.validate_msg_type(
                        msg_type,
                        "pipeline_execute_count_ast_cached",
                        error.is_some(),
                    ) {
                        Ok(FastPipelineEvent::Continue) => {}
                        Ok(FastPipelineEvent::ReadyForQuery) => {
                            if let Some(err) = error {
                                // Server rejected something: undo this call's
                                // cache inserts before surfacing the error.
                                rollback_new_cached_statements(self, &new_stmt_hashes);
                                return Err(err);
                            }
                            return Ok(flow.completed_queries());
                        }
                        Err(err) => {
                            rollback_new_cached_statements(self, &new_stmt_hashes);
                            return return_with_desync(self, err);
                        }
                    }
                }
                Err(e) => {
                    // Server errors are deferred until ReadyForQuery so the
                    // stream is drained back to a synchronized state first.
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    rollback_new_cached_statements(self, &new_stmt_hashes);
                    return Err(e);
                }
            }
        }
    }
904    /// ZERO-LOOKUP prepared statement pipeline.
905    /// - Hash computation per query
906    /// - HashMap lookup per query
907    /// - String allocation for stmt_name
908    /// # Example
909    /// ```ignore
910    /// // Prepare once (outside timing loop):
911    /// let stmt = PreparedStatement::from_sql("SELECT id, name FROM harbors LIMIT $1");
912    /// let params_batch: Vec<Vec<Option<Vec<u8>>>> = (1..=1000)
913    ///     .map(|i| vec![Some(i.to_string().into_bytes())])
914    ///     .collect();
915    /// // Execute many (no hash, no lookup!):
916    /// conn.pipeline_execute_prepared_count(&stmt, &params_batch).await?;
917    /// ```
918    #[inline]
919    pub async fn pipeline_execute_prepared_count(
920        &mut self,
921        stmt: &super::PreparedStatement,
922        params_batch: &[Vec<Option<Vec<u8>>>],
923    ) -> PgResult<usize> {
924        if params_batch.is_empty() {
925            return Ok(0);
926        }
927
928        // Local buffer - faster than reusing connection buffer
929        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
930
931        let is_new = !self.prepared_statements.contains_key(&stmt.name);
932
933        if is_new {
934            return Err(PgError::Query(
935                "Statement not prepared. Call prepare() first.".to_string(),
936            ));
937        }
938
939        // ZERO ALLOCATION: write directly to local buffer
940        for params in params_batch {
941            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
942                .map_err(|e| PgError::Encode(e.to_string()))?;
943            PgEncoder::encode_execute_to(&mut buf);
944        }
945
946        PgEncoder::encode_sync_to(&mut buf);
947
948        self.write_all_with_timeout(&buf, "stream write").await?;
949        self.flush_with_timeout("stream flush").await?;
950
951        let mut error: Option<PgError> = None;
952        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
953            expected_queries: params_batch.len(),
954            allow_parse_complete: false,
955            require_parse_before_bind: false,
956            no_data_counts_as_completion: true,
957            allow_no_data_nonterminal: false,
958            expected_parse_completes: Some(0),
959        });
960
961        loop {
962            match self.recv_msg_type_fast().await {
963                Ok(msg_type) => {
964                    let event = match flow.validate_msg_type(
965                        msg_type,
966                        "pipeline_execute_prepared_count",
967                        error.is_some(),
968                    ) {
969                        Ok(event) => event,
970                        Err(err) => return return_with_desync(self, err),
971                    };
972                    match event {
973                        FastPipelineEvent::Continue => {}
974                        FastPipelineEvent::ReadyForQuery => {
975                            if let Some(err) = error {
976                                return Err(err);
977                            }
978                            return Ok(flow.completed_queries());
979                        }
980                    }
981                }
982                Err(e) => {
983                    if matches!(&e, PgError::QueryServer(_)) {
984                        capture_query_server_error(self, &mut error, e);
985                        continue;
986                    }
987                    return Err(e);
988                }
989            }
990        }
991    }
992
    /// Prepare a statement and return a handle for fast execution.
    /// PreparedStatement handle for use with pipeline_execute_prepared_count.
    ///
    /// Idempotent: if a statement with the same derived name is already in
    /// `prepared_statements`, no network round-trip is performed.
    ///
    /// # Errors
    /// Returns a server error reported during Parse, or `PgError::Protocol`
    /// if the backend's reply sequence is malformed.
    pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
        use super::prepared::sql_bytes_to_stmt_name;

        // Statement name is derived deterministically from the SQL bytes, so
        // the same SQL always maps to the same server-side statement name.
        let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());

        if !self.prepared_statements.contains_key(&stmt_name) {
            self.evict_prepared_if_full();
            let mut buf = BytesMut::with_capacity(sql.len() + 32);
            buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
            buf.extend(PgEncoder::encode_sync());

            self.write_all_with_timeout(&buf, "stream write").await?;
            self.flush_with_timeout("stream flush").await?;

            // Wait for ParseComplete
            let mut error: Option<PgError> = None;
            let mut saw_parse_complete = false;
            loop {
                match self.recv_msg_type_fast().await {
                    Ok(msg_type) => match msg_type {
                        b'1' => {
                            // ParseComplete: register locally only after the
                            // server has confirmed the statement exists.
                            if saw_parse_complete {
                                return Err(PgError::Protocol(
                                    "prepare received duplicate ParseComplete".to_string(),
                                ));
                            }
                            saw_parse_complete = true;
                            self.prepared_statements
                                .insert(stmt_name.clone(), sql.to_string());
                        }
                        b'Z' => {
                            // ReadyForQuery: a captured server error wins, and
                            // "success" without ParseComplete is a protocol bug.
                            if let Some(err) = error {
                                return Err(err);
                            }
                            if !saw_parse_complete {
                                return Err(PgError::Protocol(
                                    "prepare reached ReadyForQuery without ParseComplete"
                                        .to_string(),
                                ));
                            }
                            break;
                        }
                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
                        other => {
                            return return_with_desync(
                                self,
                                unexpected_backend_msg_type("prepare", other),
                            );
                        }
                    },
                    Err(e) => {
                        // Defer server errors until ReadyForQuery drains the stream.
                        if matches!(&e, PgError::QueryServer(_)) {
                            capture_query_server_error(self, &mut error, e);
                            continue;
                        }
                        return Err(e);
                    }
                }
            }
        }

        Ok(super::PreparedStatement {
            name: stmt_name,
            // NOTE(review): this counts every '$' byte, so a placeholder
            // reused twice (or '$' inside a literal / dollar-quoting) inflates
            // param_count — confirm callers treat this only as a hint.
            param_count: sql.matches('$').count(),
        })
    }
1061
1062    /// Execute a prepared statement pipeline and return all row data.
1063    pub async fn pipeline_execute_prepared_rows(
1064        &mut self,
1065        stmt: &super::PreparedStatement,
1066        params_batch: &[Vec<Option<Vec<u8>>>],
1067    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
1068        if params_batch.is_empty() {
1069            return Ok(Vec::new());
1070        }
1071
1072        if !self.prepared_statements.contains_key(&stmt.name) {
1073            return Err(PgError::Query(
1074                "Statement not prepared. Call prepare() first.".to_string(),
1075            ));
1076        }
1077
1078        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
1079
1080        for params in params_batch {
1081            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
1082                .map_err(|e| PgError::Encode(e.to_string()))?;
1083            PgEncoder::encode_execute_to(&mut buf);
1084        }
1085
1086        PgEncoder::encode_sync_to(&mut buf);
1087
1088        self.write_all_with_timeout(&buf, "stream write").await?;
1089        self.flush_with_timeout("stream flush").await?;
1090
1091        // Collect results using fast inline DataRow parsing
1092        let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
1093            Vec::with_capacity(params_batch.len());
1094        let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
1095        let mut error: Option<PgError> = None;
1096        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1097            expected_queries: params_batch.len(),
1098            allow_parse_complete: false,
1099            require_parse_before_bind: false,
1100            no_data_counts_as_completion: true,
1101            allow_no_data_nonterminal: false,
1102            expected_parse_completes: Some(0),
1103        });
1104
1105        loop {
1106            match self.recv_with_data_fast().await {
1107                Ok((msg_type, data)) => {
1108                    if let Err(err) = flow.validate_msg_type(
1109                        msg_type,
1110                        "pipeline_execute_prepared_rows",
1111                        error.is_some(),
1112                    ) {
1113                        return return_with_desync(self, err);
1114                    }
1115                    match msg_type {
1116                        b'2' => {} // BindComplete
1117                        b'T' => {} // RowDescription
1118                        b'D' => {
1119                            // DataRow
1120                            if error.is_none()
1121                                && let Some(row) = data
1122                            {
1123                                current_rows.push(row);
1124                            }
1125                        }
1126                        b'C' => {
1127                            // CommandComplete
1128                            all_results.push(std::mem::take(&mut current_rows));
1129                        }
1130                        b'n' => {
1131                            // NoData
1132                            all_results.push(Vec::new());
1133                        }
1134                        b'Z' => {
1135                            // ReadyForQuery
1136                            if all_results.len() != params_batch.len() {
1137                                return Err(error.unwrap_or_else(|| {
1138                                    PgError::Protocol(format!(
1139                                        "Pipeline completion mismatch: expected {}, got {}",
1140                                        params_batch.len(),
1141                                        all_results.len()
1142                                    ))
1143                                }));
1144                            }
1145                            if let Some(err) = error {
1146                                return Err(err);
1147                            }
1148                            return Ok(all_results);
1149                        }
1150                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
1151                        other => {
1152                            return return_with_desync(
1153                                self,
1154                                unexpected_backend_msg_type(
1155                                    "pipeline_execute_prepared_rows",
1156                                    other,
1157                                ),
1158                            );
1159                        }
1160                    }
1161                }
1162                Err(e) => {
1163                    if matches!(&e, PgError::QueryServer(_)) {
1164                        capture_query_server_error(self, &mut error, e);
1165                        continue;
1166                    }
1167                    return Err(e);
1168                }
1169            }
1170        }
1171    }
1172
    /// ZERO-COPY pipeline execution with Bytes for column data.
    ///
    /// Same wire flow as `pipeline_execute_prepared_rows`, but column values
    /// are returned as `bytes::Bytes` handles produced by
    /// `recv_data_zerocopy` instead of owned `Vec<u8>`s.
    ///
    /// # Errors
    /// Returns `PgError::Query` if `stmt` was never prepared on this
    /// connection, `PgError::Encode` on Bind encoding failure, or the first
    /// captured server/protocol error.
    pub async fn pipeline_execute_prepared_rows_bytes(
        &mut self,
        stmt: &super::PreparedStatement,
        params_batch: &[Vec<Option<Vec<u8>>>],
    ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
        if params_batch.is_empty() {
            return Ok(Vec::new());
        }

        if !self.prepared_statements.contains_key(&stmt.name) {
            return Err(PgError::Query(
                "Statement not prepared. Call prepare() first.".to_string(),
            ));
        }

        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);

        // One Bind+Execute pair per parameter set, then a single Sync.
        for params in params_batch {
            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
                .map_err(|e| PgError::Encode(e.to_string()))?;
            PgEncoder::encode_execute_to(&mut buf);
        }

        PgEncoder::encode_sync_to(&mut buf);

        self.write_all_with_timeout(&buf, "stream write").await?;
        self.flush_with_timeout("stream flush").await?;

        // Collect results using ZERO-COPY Bytes
        let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
            Vec::with_capacity(params_batch.len());
        let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
        // Server errors are deferred here and surfaced at ReadyForQuery.
        let mut error: Option<PgError> = None;
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: params_batch.len(),
            allow_parse_complete: false,
            require_parse_before_bind: false,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(0),
        });

        loop {
            match self.recv_data_zerocopy().await {
                Ok((msg_type, data)) => {
                    // Sequence-check every message before acting on it.
                    if let Err(err) = flow.validate_msg_type(
                        msg_type,
                        "pipeline_execute_prepared_rows_bytes",
                        error.is_some(),
                    ) {
                        return return_with_desync(self, err);
                    }
                    match msg_type {
                        b'2' => {} // BindComplete
                        b'T' => {} // RowDescription
                        b'D' => {
                            // DataRow — dropped when a server error is pending.
                            if error.is_none()
                                && let Some(row) = data
                            {
                                current_rows.push(row);
                            }
                        }
                        b'C' => {
                            // CommandComplete: close out the current result set.
                            all_results.push(std::mem::take(&mut current_rows));
                        }
                        b'n' => {
                            // NoData: the statement produced no result set.
                            all_results.push(Vec::new());
                        }
                        b'Z' => {
                            // ReadyForQuery: verify every query completed.
                            if all_results.len() != params_batch.len() {
                                return Err(error.unwrap_or_else(|| {
                                    PgError::Protocol(format!(
                                        "Pipeline completion mismatch: expected {}, got {}",
                                        params_batch.len(),
                                        all_results.len()
                                    ))
                                }));
                            }
                            if let Some(err) = error {
                                return Err(err);
                            }
                            return Ok(all_results);
                        }
                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
                        other => {
                            return return_with_desync(
                                self,
                                unexpected_backend_msg_type(
                                    "pipeline_execute_prepared_rows_bytes",
                                    other,
                                ),
                            );
                        }
                    }
                }
                Err(e) => {
                    // Defer server errors until ReadyForQuery drains the stream.
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    return Err(e);
                }
            }
        }
    }
1283
    /// ULTRA-FAST pipeline for 2-column SELECT queries.
    ///
    /// Each row is returned as a `(Bytes, Bytes)` pair via `recv_data_ultra`,
    /// avoiding the per-column `Vec<Option<..>>` structure of the generic path.
    ///
    /// # Errors
    /// Returns `PgError::Query` if `stmt` was never prepared on this
    /// connection, `PgError::Encode` on Bind encoding failure, or the first
    /// captured server/protocol error.
    pub async fn pipeline_execute_prepared_rows_2cols_bytes(
        &mut self,
        stmt: &super::PreparedStatement,
        params_batch: &[Vec<Option<Vec<u8>>>],
    ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
        if params_batch.is_empty() {
            return Ok(Vec::new());
        }

        if !self.prepared_statements.contains_key(&stmt.name) {
            return Err(PgError::Query(
                "Statement not prepared. Call prepare() first.".to_string(),
            ));
        }

        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);

        // One Bind+Execute pair per parameter set, then a single Sync.
        for params in params_batch {
            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
                .map_err(|e| PgError::Encode(e.to_string()))?;
            PgEncoder::encode_execute_to(&mut buf);
        }

        PgEncoder::encode_sync_to(&mut buf);

        self.write_all_with_timeout(&buf, "stream write").await?;
        self.flush_with_timeout("stream flush").await?;

        // Pre-allocate with expected capacity
        let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
            Vec::with_capacity(params_batch.len());
        let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
        // Server errors are deferred here and surfaced at ReadyForQuery.
        let mut error: Option<PgError> = None;
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: params_batch.len(),
            allow_parse_complete: false,
            require_parse_before_bind: false,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(0),
        });

        loop {
            match self.recv_data_ultra().await {
                Ok((msg_type, data)) => {
                    // Sequence-check every message before acting on it.
                    if let Err(err) = flow.validate_msg_type(
                        msg_type,
                        "pipeline_execute_prepared_rows_2cols_bytes",
                        error.is_some(),
                    ) {
                        return return_with_desync(self, err);
                    }
                    match msg_type {
                        b'2' | b'T' => {} // BindComplete, RowDescription
                        b'D' => {
                            // DataRow — dropped when a server error is pending.
                            if error.is_none()
                                && let Some(row) = data
                            {
                                current_rows.push(row);
                            }
                        }
                        b'C' => {
                            all_results.push(std::mem::take(&mut current_rows));
                            // take() leaves a zero-capacity Vec; re-seed the
                            // capacity so the next result set avoids regrowth.
                            current_rows = Vec::with_capacity(16);
                        }
                        b'n' => {
                            // NoData: the statement produced no result set.
                            all_results.push(Vec::new());
                        }
                        b'Z' => {
                            // ReadyForQuery: verify every query completed.
                            if all_results.len() != params_batch.len() {
                                return Err(error.unwrap_or_else(|| {
                                    PgError::Protocol(format!(
                                        "Pipeline completion mismatch: expected {}, got {}",
                                        params_batch.len(),
                                        all_results.len()
                                    ))
                                }));
                            }
                            if let Some(err) = error {
                                return Err(err);
                            }
                            return Ok(all_results);
                        }
                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
                        other => {
                            return return_with_desync(
                                self,
                                unexpected_backend_msg_type(
                                    "pipeline_execute_prepared_rows_2cols_bytes",
                                    other,
                                ),
                            );
                        }
                    }
                }
                Err(e) => {
                    // Defer server errors until ReadyForQuery drains the stream.
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    return Err(e);
                }
            }
        }
    }
1390}
1391
#[cfg(test)]
mod tests {
    use super::*;
    // Only the unix-gated tests build AST commands; gate the import so
    // non-unix test builds do not emit an unused-import warning.
    #[cfg(unix)]
    use qail_core::ast::Qail;

    #[test]
    fn ast_pipeline_mode_auto_resolves_by_batch_size() {
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(0),
            AstPipelineMode::OneShot
        );
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(7),
            AstPipelineMode::OneShot
        );
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(8),
            AstPipelineMode::Cached
        );
        assert_eq!(
            AstPipelineMode::Cached.resolve_for_batch_len(1),
            AstPipelineMode::Cached
        );
        assert_eq!(
            AstPipelineMode::OneShot.resolve_for_batch_len(1000),
            AstPipelineMode::OneShot
        );
    }

    /// Build a connection over a dummy Unix socket pair with one statement
    /// ("s1") registered in both `prepared_statements` and `stmt_cache`.
    #[cfg(unix)]
    fn make_test_conn_with_prepared() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
        let mut conn = PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(16).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            secret_key: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        conn.prepared_statements
            .insert("s1".to_string(), "SELECT 1".to_string());
        conn.stmt_cache.put(1, "s1".to_string());
        conn
    }

    /// Build a `PgError::QueryServer` with the given SQLSTATE and message.
    // Gated like its only callers so non-unix builds avoid a dead-code warning.
    #[cfg(unix)]
    fn server_error(code: &str, message: &str) -> PgError {
        PgError::QueryServer(super::super::PgServerError {
            severity: "ERROR".to_string(),
            code: code.to_string(),
            message: message.to_string(),
            detail: None,
            hint: None,
        })
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_clears_prepared_state_on_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        // SQLSTATE 26000 (invalid_sql_statement_name) marks the local cache stale.
        let err = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert!(conn.prepared_statements.is_empty());
        assert_eq!(conn.stmt_cache.len(), 0);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_preserves_prepared_state_on_non_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        // A constraint violation says nothing about cache validity.
        let err = server_error("23505", "duplicate key value violates unique constraint");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_does_not_override_existing_error() {
        let mut conn = make_test_conn_with_prepared();
        // The first captured error must win over later ones.
        let mut slot = Some(server_error("23505", "duplicate key"));
        let retryable = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, retryable);

        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
        assert_eq!(
            slot.and_then(|e| e.sqlstate().map(str::to_string))
                .as_deref(),
            Some("23505")
        );
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pipeline_ast_cached_rolls_back_new_state_on_encode_error() {
        let mut conn = make_test_conn_with_prepared();
        let baseline = conn.prepared_statements.len();
        let baseline_stmt_cache = conn.stmt_cache.len();

        // Second command fails to encode (NUL byte); the first command's
        // freshly registered statement must be rolled back.
        let cmds = vec![
            Qail::get("harbors").columns(["id", "name"]).limit(1),
            Qail::get("bad\0table").columns(["id"]).limit(1),
        ];

        let err = conn
            .pipeline_execute_count_ast_cached(&cmds)
            .await
            .expect_err("expected encode error for NUL byte in table name");

        assert!(matches!(err, PgError::Encode(_)));
        assert_eq!(conn.prepared_statements.len(), baseline);
        assert_eq!(conn.stmt_cache.len(), baseline_stmt_cache);
        assert!(conn.prepared_statements.contains_key("s1"));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pipeline_simple_ast_empty_batch_returns_zero_without_io() {
        let mut conn = make_test_conn_with_prepared();
        let res = conn
            .pipeline_execute_count_simple_ast(&[])
            .await
            .expect("empty batch should be a fast no-op");
        assert_eq!(res, 0);
        assert!(!conn.is_io_desynced());
    }
}