Skip to main content

qail_pg/driver/
pipeline.rs

1//! High-performance pipelining methods for PostgreSQL connection.
2//!
3//!
4//! Performance hierarchy (fastest to slowest):
5//! 1. `pipeline_ast_cached` - Parse once, Bind+Execute many (275k q/s)
6//! 2. `pipeline_simple_bytes_fast` - Pre-encoded simple query
7//! 3. `pipeline_bytes_fast` - Pre-encoded extended query
8//! 4. `pipeline_simple_fast` - Simple query protocol (~99k q/s)
9//! 5. `pipeline_ast_fast` - Fast extended query, count only
10//! 6. `pipeline_ast` - Full results collection
11//! 7. `query_pipeline` - SQL-based pipelining
12
13use super::{
14    PgConnection, PgError, PgResult, is_ignorable_session_message, is_ignorable_session_msg_type,
15    unexpected_backend_message, unexpected_backend_msg_type,
16};
17use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
18use bytes::BytesMut;
19
20#[inline]
21fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
22    if matches!(
23        err,
24        PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
25    ) {
26        conn.mark_io_desynced();
27    }
28    Err(err)
29}
30
31#[inline]
32fn capture_query_server_error(conn: &mut PgConnection, slot: &mut Option<PgError>, err: PgError) {
33    if slot.is_some() {
34        return;
35    }
36    if err.is_prepared_statement_retryable() {
37        conn.clear_prepared_statement_state();
38    }
39    *slot = Some(err);
40}
41
/// Settings that adapt `FastExtendedFlowTracker` to one particular
/// extended-protocol pipeline shape.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowConfig {
    /// Completions that must be observed before `ReadyForQuery` is accepted.
    expected_queries: usize,
    /// Whether `ParseComplete` ('1') may appear at all.
    allow_parse_complete: bool,
    /// Whether each `BindComplete` must follow a `ParseComplete` for the same query.
    require_parse_before_bind: bool,
    /// Whether `NoData` ('n') finishes the current query like `CommandComplete`.
    no_data_counts_as_completion: bool,
    /// Whether `NoData` is tolerated without finishing the query; only
    /// consulted when `no_data_counts_as_completion` is false.
    allow_no_data_nonterminal: bool,
    /// Exact `ParseComplete` count required at `ReadyForQuery`; `None` skips the check.
    expected_parse_completes: Option<usize>,
}
51
52#[derive(Debug, Clone, Copy)]
53struct FastExtendedFlowTracker {
54    cfg: FastExtendedFlowConfig,
55    completed_queries: usize,
56    parse_completes: usize,
57    current_parse_seen: bool,
58    current_bind_seen: bool,
59}
60
61impl FastExtendedFlowTracker {
62    fn new(cfg: FastExtendedFlowConfig) -> Self {
63        Self {
64            cfg,
65            completed_queries: 0,
66            parse_completes: 0,
67            current_parse_seen: false,
68            current_bind_seen: false,
69        }
70    }
71
72    fn completed_queries(&self) -> usize {
73        self.completed_queries
74    }
75
76    fn validate_msg_type(
77        &mut self,
78        msg_type: u8,
79        context: &'static str,
80        error_pending: bool,
81    ) -> PgResult<FastPipelineEvent> {
82        if is_ignorable_session_msg_type(msg_type) {
83            return Ok(FastPipelineEvent::Continue);
84        }
85
86        if error_pending {
87            if msg_type == b'Z' {
88                return Ok(FastPipelineEvent::ReadyForQuery);
89            }
90            return Ok(FastPipelineEvent::Continue);
91        }
92
93        if msg_type == b'Z' {
94            if self.completed_queries != self.cfg.expected_queries {
95                return Err(PgError::Protocol(format!(
96                    "{}: Pipeline completion mismatch: expected {}, got {}",
97                    context, self.cfg.expected_queries, self.completed_queries
98                )));
99            }
100            if self.current_parse_seen || self.current_bind_seen {
101                return Err(PgError::Protocol(format!(
102                    "{}: ReadyForQuery with incomplete query state",
103                    context
104                )));
105            }
106            if let Some(expected) = self.cfg.expected_parse_completes
107                && self.parse_completes != expected
108            {
109                return Err(PgError::Protocol(format!(
110                    "{}: ParseComplete mismatch: expected {}, got {}",
111                    context, expected, self.parse_completes
112                )));
113            }
114            return Ok(FastPipelineEvent::ReadyForQuery);
115        }
116
117        if self.completed_queries >= self.cfg.expected_queries {
118            return Err(PgError::Protocol(format!(
119                "{}: unexpected message '{}' after all queries completed",
120                context, msg_type as char
121            )));
122        }
123
124        match msg_type {
125            b'1' => {
126                if !self.cfg.allow_parse_complete {
127                    return Err(PgError::Protocol(format!(
128                        "{}: unexpected ParseComplete",
129                        context
130                    )));
131                }
132                if self.current_bind_seen {
133                    return Err(PgError::Protocol(format!(
134                        "{}: ParseComplete after BindComplete",
135                        context
136                    )));
137                }
138                if self.current_parse_seen {
139                    return Err(PgError::Protocol(format!(
140                        "{}: duplicate ParseComplete",
141                        context
142                    )));
143                }
144                self.current_parse_seen = true;
145                self.parse_completes += 1;
146                if let Some(expected) = self.cfg.expected_parse_completes
147                    && self.parse_completes > expected
148                {
149                    return Err(PgError::Protocol(format!(
150                        "{}: ParseComplete mismatch: expected {}, got at least {}",
151                        context, expected, self.parse_completes
152                    )));
153                }
154            }
155            b'2' => {
156                if self.current_bind_seen {
157                    return Err(PgError::Protocol(format!(
158                        "{}: duplicate BindComplete",
159                        context
160                    )));
161                }
162                if self.cfg.require_parse_before_bind && !self.current_parse_seen {
163                    return Err(PgError::Protocol(format!(
164                        "{}: BindComplete before ParseComplete",
165                        context
166                    )));
167                }
168                self.current_bind_seen = true;
169            }
170            b'T' | b't' | b's' => {
171                if !self.current_bind_seen {
172                    return Err(PgError::Protocol(format!(
173                        "{}: '{}' before BindComplete",
174                        context, msg_type as char
175                    )));
176                }
177            }
178            b'D' => {
179                if !self.current_bind_seen {
180                    return Err(PgError::Protocol(format!(
181                        "{}: DataRow before BindComplete",
182                        context
183                    )));
184                }
185            }
186            b'n' => {
187                if !self.current_bind_seen {
188                    return Err(PgError::Protocol(format!(
189                        "{}: NoData before BindComplete",
190                        context
191                    )));
192                }
193                if self.cfg.no_data_counts_as_completion {
194                    self.complete_current();
195                } else if !self.cfg.allow_no_data_nonterminal {
196                    return Err(PgError::Protocol(format!("{}: unexpected NoData", context)));
197                }
198            }
199            b'C' => {
200                if !self.current_bind_seen {
201                    return Err(PgError::Protocol(format!(
202                        "{}: CommandComplete before BindComplete",
203                        context
204                    )));
205                }
206                self.complete_current();
207            }
208            b'I' => {
209                return Err(PgError::Protocol(format!(
210                    "{}: unexpected EmptyQueryResponse in extended pipeline",
211                    context
212                )));
213            }
214            other => return Err(unexpected_backend_msg_type(context, other)),
215        }
216
217        Ok(FastPipelineEvent::Continue)
218    }
219
220    fn complete_current(&mut self) {
221        self.completed_queries += 1;
222        self.current_parse_seen = false;
223        self.current_bind_seen = false;
224    }
225}
226
/// Incremental validator for the backend message types of a simple-query
/// ('Q') pipeline.
#[derive(Debug, Clone, Copy)]
struct FastSimpleFlowTracker {
    /// Completions (`CommandComplete` / `EmptyQueryResponse`) required before 'Z'.
    expected_queries: usize,
    /// Completions observed so far.
    completed_queries: usize,
    /// A `RowDescription` was seen for the statement currently in flight.
    current_row_description_seen: bool,
}
233
234impl FastSimpleFlowTracker {
235    fn new(expected_queries: usize) -> Self {
236        Self {
237            expected_queries,
238            completed_queries: 0,
239            current_row_description_seen: false,
240        }
241    }
242
243    fn completed_queries(&self) -> usize {
244        self.completed_queries
245    }
246
247    fn validate_msg_type(
248        &mut self,
249        msg_type: u8,
250        context: &'static str,
251        error_pending: bool,
252    ) -> PgResult<FastPipelineEvent> {
253        if is_ignorable_session_msg_type(msg_type) {
254            return Ok(FastPipelineEvent::Continue);
255        }
256
257        if error_pending {
258            if msg_type == b'Z' {
259                return Ok(FastPipelineEvent::ReadyForQuery);
260            }
261            return Ok(FastPipelineEvent::Continue);
262        }
263
264        if msg_type == b'Z' {
265            if self.completed_queries != self.expected_queries {
266                return Err(PgError::Protocol(format!(
267                    "{}: Pipeline completion mismatch: expected {}, got {}",
268                    context, self.expected_queries, self.completed_queries
269                )));
270            }
271            if self.current_row_description_seen {
272                return Err(PgError::Protocol(format!(
273                    "{}: ReadyForQuery with incomplete row stream",
274                    context
275                )));
276            }
277            return Ok(FastPipelineEvent::ReadyForQuery);
278        }
279
280        if self.completed_queries >= self.expected_queries {
281            return Err(PgError::Protocol(format!(
282                "{}: unexpected message '{}' after all queries completed",
283                context, msg_type as char
284            )));
285        }
286
287        match msg_type {
288            b'T' => {
289                if self.current_row_description_seen {
290                    return Err(PgError::Protocol(format!(
291                        "{}: duplicate RowDescription",
292                        context
293                    )));
294                }
295                self.current_row_description_seen = true;
296            }
297            b'D' => {
298                if !self.current_row_description_seen {
299                    return Err(PgError::Protocol(format!(
300                        "{}: DataRow before RowDescription",
301                        context
302                    )));
303                }
304            }
305            b'C' | b'I' => {
306                self.completed_queries += 1;
307                self.current_row_description_seen = false;
308            }
309            b'1' | b'2' | b'n' | b't' | b's' => {
310                return Err(PgError::Protocol(format!(
311                    "{}: unexpected '{}' in simple pipeline",
312                    context, msg_type as char
313                )));
314            }
315            other => return Err(unexpected_backend_msg_type(context, other)),
316        }
317
318        Ok(FastPipelineEvent::Continue)
319    }
320}
321
/// Outcome of validating one backend message inside a fast pipeline loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FastPipelineEvent {
    /// Keep reading: the pipeline has not finished.
    Continue,
    /// A `ReadyForQuery` ('Z') was accepted; the caller should stop reading.
    ReadyForQuery,
}
327
328#[inline]
329fn backend_msg_type_for_flow(msg: &BackendMessage) -> Option<u8> {
330    match msg {
331        BackendMessage::ParseComplete => Some(b'1'),
332        BackendMessage::BindComplete => Some(b'2'),
333        BackendMessage::ParameterDescription(_) => Some(b't'),
334        BackendMessage::RowDescription(_) => Some(b'T'),
335        BackendMessage::NoData => Some(b'n'),
336        BackendMessage::PortalSuspended => Some(b's'),
337        BackendMessage::DataRow(_) => Some(b'D'),
338        BackendMessage::CommandComplete(_) => Some(b'C'),
339        BackendMessage::EmptyQueryResponse => Some(b'I'),
340        BackendMessage::ReadyForQuery(_) => Some(b'Z'),
341        _ => None,
342    }
343}
344
345impl PgConnection {
346    /// Execute multiple SQL queries in a single network round-trip (PIPELINING).
347    pub async fn query_pipeline(
348        &mut self,
349        queries: &[(&str, &[Option<Vec<u8>>])],
350    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
351        // Encode all queries into a single buffer
352        let mut buf = BytesMut::new();
353        for (sql, params) in queries {
354            buf.extend_from_slice(
355                &PgEncoder::encode_extended_query(sql, params)
356                    .map_err(|e| PgError::Encode(e.to_string()))?,
357            );
358        }
359
360        // Send all queries in ONE write
361        self.write_all_with_timeout(&buf, "stream write").await?;
362
363        // Collect all results
364        let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
365        let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
366        let mut error: Option<PgError> = None;
367        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
368            expected_queries: queries.len(),
369            allow_parse_complete: true,
370            require_parse_before_bind: true,
371            no_data_counts_as_completion: true,
372            allow_no_data_nonterminal: false,
373            expected_parse_completes: Some(queries.len()),
374        });
375
376        loop {
377            let msg = self.recv().await?;
378            if is_ignorable_session_message(&msg) {
379                continue;
380            }
381            if let BackendMessage::ErrorResponse(err) = msg {
382                if error.is_none() {
383                    error = Some(PgError::QueryServer(err.into()));
384                }
385                continue;
386            }
387            let msg_type = backend_msg_type_for_flow(&msg)
388                .ok_or_else(|| unexpected_backend_message("pipeline query", &msg));
389            let msg_type = match msg_type {
390                Ok(msg_type) => msg_type,
391                Err(err) => return return_with_desync(self, err),
392            };
393            if let Err(err) = flow.validate_msg_type(msg_type, "pipeline query", error.is_some()) {
394                return return_with_desync(self, err);
395            }
396            match msg {
397                BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
398                BackendMessage::RowDescription(_) => {}
399                BackendMessage::DataRow(data) => {
400                    if error.is_none() {
401                        current_rows.push(data);
402                    }
403                }
404                BackendMessage::CommandComplete(_) => {
405                    all_results.push(std::mem::take(&mut current_rows));
406                }
407                BackendMessage::NoData => {
408                    all_results.push(Vec::new());
409                }
410                BackendMessage::ReadyForQuery(_) => {
411                    if all_results.len() != queries.len() {
412                        return Err(error.unwrap_or_else(|| {
413                            PgError::Protocol(format!(
414                                "Pipeline completion mismatch: expected {}, got {}",
415                                queries.len(),
416                                all_results.len()
417                            ))
418                        }));
419                    }
420                    if let Some(err) = error {
421                        return Err(err);
422                    }
423                    return Ok(all_results);
424                }
425                other => {
426                    return return_with_desync(
427                        self,
428                        unexpected_backend_message("pipeline query", &other),
429                    );
430                }
431            }
432        }
433    }
434
435    /// Execute multiple Qail ASTs in a single network round-trip.
436    pub async fn pipeline_ast(
437        &mut self,
438        cmds: &[qail_core::ast::Qail],
439    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
440        let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
441        self.write_all_with_timeout(&buf, "stream write").await?;
442
443        let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
444        let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
445        let mut error: Option<PgError> = None;
446        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
447            expected_queries: cmds.len(),
448            allow_parse_complete: true,
449            require_parse_before_bind: true,
450            no_data_counts_as_completion: true,
451            allow_no_data_nonterminal: false,
452            expected_parse_completes: Some(cmds.len()),
453        });
454
455        loop {
456            let msg = self.recv().await?;
457            if is_ignorable_session_message(&msg) {
458                continue;
459            }
460            if let BackendMessage::ErrorResponse(err) = msg {
461                if error.is_none() {
462                    error = Some(PgError::QueryServer(err.into()));
463                }
464                continue;
465            }
466            let msg_type = backend_msg_type_for_flow(&msg)
467                .ok_or_else(|| unexpected_backend_message("pipeline ast", &msg));
468            let msg_type = match msg_type {
469                Ok(msg_type) => msg_type,
470                Err(err) => return return_with_desync(self, err),
471            };
472            if let Err(err) = flow.validate_msg_type(msg_type, "pipeline ast", error.is_some()) {
473                return return_with_desync(self, err);
474            }
475            match msg {
476                BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
477                BackendMessage::RowDescription(_) => {}
478                BackendMessage::DataRow(data) => {
479                    if error.is_none() {
480                        current_rows.push(data);
481                    }
482                }
483                BackendMessage::CommandComplete(_) => {
484                    all_results.push(std::mem::take(&mut current_rows));
485                }
486                BackendMessage::NoData => {
487                    all_results.push(Vec::new());
488                }
489                BackendMessage::ReadyForQuery(_) => {
490                    if all_results.len() != cmds.len() {
491                        return Err(error.unwrap_or_else(|| {
492                            PgError::Protocol(format!(
493                                "Pipeline completion mismatch: expected {}, got {}",
494                                cmds.len(),
495                                all_results.len()
496                            ))
497                        }));
498                    }
499                    if let Some(err) = error {
500                        return Err(err);
501                    }
502                    return Ok(all_results);
503                }
504                other => {
505                    return return_with_desync(
506                        self,
507                        unexpected_backend_message("pipeline ast", &other),
508                    );
509                }
510            }
511        }
512    }
513
514    /// FAST AST pipeline - returns only query count, no result parsing.
515    pub async fn pipeline_ast_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
516        let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
517
518        self.write_all_with_timeout(&buf, "stream write").await?;
519        self.flush_with_timeout("stream flush").await?;
520
521        let mut error: Option<PgError> = None;
522        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
523            expected_queries: cmds.len(),
524            allow_parse_complete: true,
525            require_parse_before_bind: true,
526            no_data_counts_as_completion: true,
527            allow_no_data_nonterminal: false,
528            expected_parse_completes: Some(cmds.len()),
529        });
530
531        loop {
532            match self.recv_msg_type_fast().await {
533                Ok(msg_type) => {
534                    let event = match flow.validate_msg_type(
535                        msg_type,
536                        "pipeline_ast_fast",
537                        error.is_some(),
538                    ) {
539                        Ok(event) => event,
540                        Err(err) => return return_with_desync(self, err),
541                    };
542                    match event {
543                        FastPipelineEvent::Continue => {}
544                        FastPipelineEvent::ReadyForQuery => {
545                            if let Some(err) = error {
546                                return Err(err);
547                            }
548                            return Ok(flow.completed_queries());
549                        }
550                    }
551                }
552                Err(e) => {
553                    if matches!(&e, PgError::QueryServer(_)) {
554                        capture_query_server_error(self, &mut error, e);
555                        continue;
556                    }
557                    return Err(e);
558                }
559            }
560        }
561    }
562
563    /// FASTEST extended query pipeline - takes pre-encoded wire bytes.
564    #[inline]
565    pub async fn pipeline_bytes_fast(
566        &mut self,
567        wire_bytes: &[u8],
568        expected_queries: usize,
569    ) -> PgResult<usize> {
570        self.write_all_with_timeout(wire_bytes, "stream write")
571            .await?;
572        self.flush_with_timeout("stream flush").await?;
573
574        let mut error: Option<PgError> = None;
575        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
576            expected_queries,
577            allow_parse_complete: true,
578            require_parse_before_bind: false,
579            no_data_counts_as_completion: true,
580            allow_no_data_nonterminal: false,
581            expected_parse_completes: None,
582        });
583
584        loop {
585            match self.recv_msg_type_fast().await {
586                Ok(msg_type) => {
587                    let event = match flow.validate_msg_type(
588                        msg_type,
589                        "pipeline_bytes_fast",
590                        error.is_some(),
591                    ) {
592                        Ok(event) => event,
593                        Err(err) => return return_with_desync(self, err),
594                    };
595                    match event {
596                        FastPipelineEvent::Continue => {}
597                        FastPipelineEvent::ReadyForQuery => {
598                            if let Some(err) = error {
599                                return Err(err);
600                            }
601                            return Ok(flow.completed_queries());
602                        }
603                    }
604                }
605                Err(e) => {
606                    if matches!(&e, PgError::QueryServer(_)) {
607                        capture_query_server_error(self, &mut error, e);
608                        continue;
609                    }
610                    return Err(e);
611                }
612            }
613        }
614    }
615
616    /// Simple query protocol pipeline - uses 'Q' message.
617    #[inline]
618    pub async fn pipeline_simple_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
619        let buf =
620            AstEncoder::encode_batch_simple(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
621        self.write_all_with_timeout(&buf, "stream write").await?;
622        self.flush_with_timeout("stream flush").await?;
623
624        let mut error: Option<PgError> = None;
625        let mut flow = FastSimpleFlowTracker::new(cmds.len());
626
627        loop {
628            match self.recv_msg_type_fast().await {
629                Ok(msg_type) => {
630                    let event = match flow.validate_msg_type(
631                        msg_type,
632                        "pipeline_simple_fast",
633                        error.is_some(),
634                    ) {
635                        Ok(event) => event,
636                        Err(err) => return return_with_desync(self, err),
637                    };
638                    match event {
639                        FastPipelineEvent::Continue => {}
640                        FastPipelineEvent::ReadyForQuery => {
641                            if let Some(err) = error {
642                                return Err(err);
643                            }
644                            return Ok(flow.completed_queries());
645                        }
646                    }
647                }
648                Err(e) => {
649                    if matches!(&e, PgError::QueryServer(_)) {
650                        capture_query_server_error(self, &mut error, e);
651                        continue;
652                    }
653                    return Err(e);
654                }
655            }
656        }
657    }
658
659    /// FASTEST simple query pipeline - takes pre-encoded bytes.
660    #[inline]
661    pub async fn pipeline_simple_bytes_fast(
662        &mut self,
663        wire_bytes: &[u8],
664        expected_queries: usize,
665    ) -> PgResult<usize> {
666        self.write_all_with_timeout(wire_bytes, "stream write")
667            .await?;
668        self.flush_with_timeout("stream flush").await?;
669
670        let mut error: Option<PgError> = None;
671        let mut flow = FastSimpleFlowTracker::new(expected_queries);
672
673        loop {
674            match self.recv_msg_type_fast().await {
675                Ok(msg_type) => {
676                    let event = match flow.validate_msg_type(
677                        msg_type,
678                        "pipeline_simple_bytes_fast",
679                        error.is_some(),
680                    ) {
681                        Ok(event) => event,
682                        Err(err) => return return_with_desync(self, err),
683                    };
684                    match event {
685                        FastPipelineEvent::Continue => {}
686                        FastPipelineEvent::ReadyForQuery => {
687                            if let Some(err) = error {
688                                return Err(err);
689                            }
690                            return Ok(flow.completed_queries());
691                        }
692                    }
693                }
694                Err(e) => {
695                    if matches!(&e, PgError::QueryServer(_)) {
696                        capture_query_server_error(self, &mut error, e);
697                        continue;
698                    }
699                    return Err(e);
700                }
701            }
702        }
703    }
704
705    /// CACHED PREPARED STATEMENT pipeline - Parse once, Bind+Execute many.
706    /// 1. Generate SQL template with $1, $2, etc. placeholders
707    /// 2. Parse template ONCE (cached in PostgreSQL)
708    /// 3. Send Bind+Execute for each instance (params differ per query)
709    #[inline]
710    pub async fn pipeline_ast_cached(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
711        if cmds.is_empty() {
712            return Ok(0);
713        }
714
715        let mut buf = BytesMut::with_capacity(cmds.len() * 64);
716        let mut new_stmt_names: Vec<String> = Vec::new();
717
718        for cmd in cmds {
719            let (sql, params) =
720                AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
721            let stmt_name = Self::sql_to_stmt_name(&sql);
722
723            if !self.prepared_statements.contains_key(&stmt_name) {
724                self.evict_prepared_if_full();
725                buf.extend(PgEncoder::try_encode_parse(&stmt_name, &sql, &[])?);
726                self.prepared_statements.insert(stmt_name.clone(), sql);
727                new_stmt_names.push(stmt_name.clone());
728            }
729
730            let bind_msg = match PgEncoder::encode_bind("", &stmt_name, &params) {
731                Ok(msg) => msg,
732                Err(e) => {
733                    for stmt in &new_stmt_names {
734                        self.prepared_statements.remove(stmt);
735                    }
736                    return Err(PgError::Encode(e.to_string()));
737                }
738            };
739            buf.extend_from_slice(&bind_msg);
740            buf.extend(PgEncoder::try_encode_execute("", 0)?);
741        }
742
743        buf.extend(PgEncoder::encode_sync());
744
745        if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
746            for stmt in &new_stmt_names {
747                self.prepared_statements.remove(stmt);
748            }
749            return Err(err);
750        }
751        if let Err(err) = self.flush_with_timeout("stream flush").await {
752            for stmt in &new_stmt_names {
753                self.prepared_statements.remove(stmt);
754            }
755            return Err(err);
756        }
757
758        let mut error: Option<PgError> = None;
759        let expected_parse_completes = new_stmt_names.len();
760        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
761            expected_queries: cmds.len(),
762            allow_parse_complete: true,
763            require_parse_before_bind: false,
764            no_data_counts_as_completion: true,
765            allow_no_data_nonterminal: false,
766            expected_parse_completes: Some(expected_parse_completes),
767        });
768
769        loop {
770            match self.recv_msg_type_fast().await {
771                Ok(msg_type) => {
772                    match flow.validate_msg_type(msg_type, "pipeline_ast_cached", error.is_some()) {
773                        Ok(FastPipelineEvent::Continue) => {}
774                        Ok(FastPipelineEvent::ReadyForQuery) => {
775                            if let Some(err) = error {
776                                for stmt in &new_stmt_names {
777                                    self.prepared_statements.remove(stmt);
778                                }
779                                return Err(err);
780                            }
781                            return Ok(flow.completed_queries());
782                        }
783                        Err(err) => {
784                            for stmt in &new_stmt_names {
785                                self.prepared_statements.remove(stmt);
786                            }
787                            return return_with_desync(self, err);
788                        }
789                    }
790                }
791                Err(e) => {
792                    if matches!(&e, PgError::QueryServer(_)) {
793                        capture_query_server_error(self, &mut error, e);
794                        continue;
795                    }
796                    for stmt in &new_stmt_names {
797                        self.prepared_statements.remove(stmt);
798                    }
799                    return Err(e);
800                }
801            }
802        }
803    }
804
805    /// ZERO-LOOKUP prepared statement pipeline.
806    /// - Hash computation per query
807    /// - HashMap lookup per query
808    /// - String allocation for stmt_name
809    /// # Example
810    /// ```ignore
811    /// // Prepare once (outside timing loop):
812    /// let stmt = PreparedStatement::from_sql("SELECT id, name FROM harbors LIMIT $1");
813    /// let params_batch: Vec<Vec<Option<Vec<u8>>>> = (1..=1000)
814    ///     .map(|i| vec![Some(i.to_string().into_bytes())])
815    ///     .collect();
816    /// // Execute many (no hash, no lookup!):
817    /// conn.pipeline_prepared_fast(&stmt, &params_batch).await?;
818    /// ```
819    #[inline]
820    pub async fn pipeline_prepared_fast(
821        &mut self,
822        stmt: &super::PreparedStatement,
823        params_batch: &[Vec<Option<Vec<u8>>>],
824    ) -> PgResult<usize> {
825        if params_batch.is_empty() {
826            return Ok(0);
827        }
828
829        // Local buffer - faster than reusing connection buffer
830        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
831
832        let is_new = !self.prepared_statements.contains_key(&stmt.name);
833
834        if is_new {
835            return Err(PgError::Query(
836                "Statement not prepared. Call prepare() first.".to_string(),
837            ));
838        }
839
840        // ZERO ALLOCATION: write directly to local buffer
841        for params in params_batch {
842            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
843                .map_err(|e| PgError::Encode(e.to_string()))?;
844            PgEncoder::encode_execute_to(&mut buf);
845        }
846
847        PgEncoder::encode_sync_to(&mut buf);
848
849        self.write_all_with_timeout(&buf, "stream write").await?;
850        self.flush_with_timeout("stream flush").await?;
851
852        let mut error: Option<PgError> = None;
853        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
854            expected_queries: params_batch.len(),
855            allow_parse_complete: false,
856            require_parse_before_bind: false,
857            no_data_counts_as_completion: true,
858            allow_no_data_nonterminal: false,
859            expected_parse_completes: Some(0),
860        });
861
862        loop {
863            match self.recv_msg_type_fast().await {
864                Ok(msg_type) => {
865                    let event = match flow.validate_msg_type(
866                        msg_type,
867                        "pipeline_prepared_fast",
868                        error.is_some(),
869                    ) {
870                        Ok(event) => event,
871                        Err(err) => return return_with_desync(self, err),
872                    };
873                    match event {
874                        FastPipelineEvent::Continue => {}
875                        FastPipelineEvent::ReadyForQuery => {
876                            if let Some(err) = error {
877                                return Err(err);
878                            }
879                            return Ok(flow.completed_queries());
880                        }
881                    }
882                }
883                Err(e) => {
884                    if matches!(&e, PgError::QueryServer(_)) {
885                        capture_query_server_error(self, &mut error, e);
886                        continue;
887                    }
888                    return Err(e);
889                }
890            }
891        }
892    }
893
894    /// Prepare a statement and return a handle for fast execution.
895    /// PreparedStatement handle for use with pipeline_prepared_fast.
896    pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
897        use super::prepared::sql_bytes_to_stmt_name;
898
899        let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());
900
901        if !self.prepared_statements.contains_key(&stmt_name) {
902            self.evict_prepared_if_full();
903            let mut buf = BytesMut::with_capacity(sql.len() + 32);
904            buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
905            buf.extend(PgEncoder::encode_sync());
906
907            self.write_all_with_timeout(&buf, "stream write").await?;
908            self.flush_with_timeout("stream flush").await?;
909
910            // Wait for ParseComplete
911            let mut error: Option<PgError> = None;
912            let mut saw_parse_complete = false;
913            loop {
914                match self.recv_msg_type_fast().await {
915                    Ok(msg_type) => match msg_type {
916                        b'1' => {
917                            if saw_parse_complete {
918                                return Err(PgError::Protocol(
919                                    "prepare received duplicate ParseComplete".to_string(),
920                                ));
921                            }
922                            saw_parse_complete = true;
923                            self.prepared_statements
924                                .insert(stmt_name.clone(), sql.to_string());
925                        }
926                        b'Z' => {
927                            if let Some(err) = error {
928                                return Err(err);
929                            }
930                            if !saw_parse_complete {
931                                return Err(PgError::Protocol(
932                                    "prepare reached ReadyForQuery without ParseComplete"
933                                        .to_string(),
934                                ));
935                            }
936                            break;
937                        }
938                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
939                        other => {
940                            return return_with_desync(
941                                self,
942                                unexpected_backend_msg_type("prepare", other),
943                            );
944                        }
945                    },
946                    Err(e) => {
947                        if matches!(&e, PgError::QueryServer(_)) {
948                            capture_query_server_error(self, &mut error, e);
949                            continue;
950                        }
951                        return Err(e);
952                    }
953                }
954            }
955        }
956
957        Ok(super::PreparedStatement {
958            name: stmt_name,
959            param_count: sql.matches('$').count(),
960        })
961    }
962
963    /// Execute a prepared statement pipeline and return all row data.
964    pub async fn pipeline_prepared_results(
965        &mut self,
966        stmt: &super::PreparedStatement,
967        params_batch: &[Vec<Option<Vec<u8>>>],
968    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
969        if params_batch.is_empty() {
970            return Ok(Vec::new());
971        }
972
973        if !self.prepared_statements.contains_key(&stmt.name) {
974            return Err(PgError::Query(
975                "Statement not prepared. Call prepare() first.".to_string(),
976            ));
977        }
978
979        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
980
981        for params in params_batch {
982            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
983                .map_err(|e| PgError::Encode(e.to_string()))?;
984            PgEncoder::encode_execute_to(&mut buf);
985        }
986
987        PgEncoder::encode_sync_to(&mut buf);
988
989        self.write_all_with_timeout(&buf, "stream write").await?;
990        self.flush_with_timeout("stream flush").await?;
991
992        // Collect results using fast inline DataRow parsing
993        let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
994            Vec::with_capacity(params_batch.len());
995        let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
996        let mut error: Option<PgError> = None;
997        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
998            expected_queries: params_batch.len(),
999            allow_parse_complete: false,
1000            require_parse_before_bind: false,
1001            no_data_counts_as_completion: true,
1002            allow_no_data_nonterminal: false,
1003            expected_parse_completes: Some(0),
1004        });
1005
1006        loop {
1007            match self.recv_with_data_fast().await {
1008                Ok((msg_type, data)) => {
1009                    if let Err(err) = flow.validate_msg_type(
1010                        msg_type,
1011                        "pipeline_prepared_results",
1012                        error.is_some(),
1013                    ) {
1014                        return return_with_desync(self, err);
1015                    }
1016                    match msg_type {
1017                        b'2' => {} // BindComplete
1018                        b'T' => {} // RowDescription
1019                        b'D' => {
1020                            // DataRow
1021                            if error.is_none()
1022                                && let Some(row) = data
1023                            {
1024                                current_rows.push(row);
1025                            }
1026                        }
1027                        b'C' => {
1028                            // CommandComplete
1029                            all_results.push(std::mem::take(&mut current_rows));
1030                        }
1031                        b'n' => {
1032                            // NoData
1033                            all_results.push(Vec::new());
1034                        }
1035                        b'Z' => {
1036                            // ReadyForQuery
1037                            if all_results.len() != params_batch.len() {
1038                                return Err(error.unwrap_or_else(|| {
1039                                    PgError::Protocol(format!(
1040                                        "Pipeline completion mismatch: expected {}, got {}",
1041                                        params_batch.len(),
1042                                        all_results.len()
1043                                    ))
1044                                }));
1045                            }
1046                            if let Some(err) = error {
1047                                return Err(err);
1048                            }
1049                            return Ok(all_results);
1050                        }
1051                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
1052                        other => {
1053                            return return_with_desync(
1054                                self,
1055                                unexpected_backend_msg_type("pipeline_prepared_results", other),
1056                            );
1057                        }
1058                    }
1059                }
1060                Err(e) => {
1061                    if matches!(&e, PgError::QueryServer(_)) {
1062                        capture_query_server_error(self, &mut error, e);
1063                        continue;
1064                    }
1065                    return Err(e);
1066                }
1067            }
1068        }
1069    }
1070
1071    /// ZERO-COPY pipeline execution with Bytes for column data.
1072    pub async fn pipeline_prepared_zerocopy(
1073        &mut self,
1074        stmt: &super::PreparedStatement,
1075        params_batch: &[Vec<Option<Vec<u8>>>],
1076    ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
1077        if params_batch.is_empty() {
1078            return Ok(Vec::new());
1079        }
1080
1081        if !self.prepared_statements.contains_key(&stmt.name) {
1082            return Err(PgError::Query(
1083                "Statement not prepared. Call prepare() first.".to_string(),
1084            ));
1085        }
1086
1087        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
1088
1089        for params in params_batch {
1090            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
1091                .map_err(|e| PgError::Encode(e.to_string()))?;
1092            PgEncoder::encode_execute_to(&mut buf);
1093        }
1094
1095        PgEncoder::encode_sync_to(&mut buf);
1096
1097        self.write_all_with_timeout(&buf, "stream write").await?;
1098        self.flush_with_timeout("stream flush").await?;
1099
1100        // Collect results using ZERO-COPY Bytes
1101        let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
1102            Vec::with_capacity(params_batch.len());
1103        let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
1104        let mut error: Option<PgError> = None;
1105        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1106            expected_queries: params_batch.len(),
1107            allow_parse_complete: false,
1108            require_parse_before_bind: false,
1109            no_data_counts_as_completion: true,
1110            allow_no_data_nonterminal: false,
1111            expected_parse_completes: Some(0),
1112        });
1113
1114        loop {
1115            match self.recv_data_zerocopy().await {
1116                Ok((msg_type, data)) => {
1117                    if let Err(err) = flow.validate_msg_type(
1118                        msg_type,
1119                        "pipeline_prepared_zerocopy",
1120                        error.is_some(),
1121                    ) {
1122                        return return_with_desync(self, err);
1123                    }
1124                    match msg_type {
1125                        b'2' => {} // BindComplete
1126                        b'T' => {} // RowDescription
1127                        b'D' => {
1128                            // DataRow
1129                            if error.is_none()
1130                                && let Some(row) = data
1131                            {
1132                                current_rows.push(row);
1133                            }
1134                        }
1135                        b'C' => {
1136                            // CommandComplete
1137                            all_results.push(std::mem::take(&mut current_rows));
1138                        }
1139                        b'n' => {
1140                            // NoData
1141                            all_results.push(Vec::new());
1142                        }
1143                        b'Z' => {
1144                            // ReadyForQuery
1145                            if all_results.len() != params_batch.len() {
1146                                return Err(error.unwrap_or_else(|| {
1147                                    PgError::Protocol(format!(
1148                                        "Pipeline completion mismatch: expected {}, got {}",
1149                                        params_batch.len(),
1150                                        all_results.len()
1151                                    ))
1152                                }));
1153                            }
1154                            if let Some(err) = error {
1155                                return Err(err);
1156                            }
1157                            return Ok(all_results);
1158                        }
1159                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
1160                        other => {
1161                            return return_with_desync(
1162                                self,
1163                                unexpected_backend_msg_type("pipeline_prepared_zerocopy", other),
1164                            );
1165                        }
1166                    }
1167                }
1168                Err(e) => {
1169                    if matches!(&e, PgError::QueryServer(_)) {
1170                        capture_query_server_error(self, &mut error, e);
1171                        continue;
1172                    }
1173                    return Err(e);
1174                }
1175            }
1176        }
1177    }
1178
1179    /// ULTRA-FAST pipeline for 2-column SELECT queries.
1180    pub async fn pipeline_prepared_ultra(
1181        &mut self,
1182        stmt: &super::PreparedStatement,
1183        params_batch: &[Vec<Option<Vec<u8>>>],
1184    ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
1185        if params_batch.is_empty() {
1186            return Ok(Vec::new());
1187        }
1188
1189        if !self.prepared_statements.contains_key(&stmt.name) {
1190            return Err(PgError::Query(
1191                "Statement not prepared. Call prepare() first.".to_string(),
1192            ));
1193        }
1194
1195        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
1196
1197        for params in params_batch {
1198            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
1199                .map_err(|e| PgError::Encode(e.to_string()))?;
1200            PgEncoder::encode_execute_to(&mut buf);
1201        }
1202
1203        PgEncoder::encode_sync_to(&mut buf);
1204
1205        self.write_all_with_timeout(&buf, "stream write").await?;
1206        self.flush_with_timeout("stream flush").await?;
1207
1208        // Pre-allocate with expected capacity
1209        let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
1210            Vec::with_capacity(params_batch.len());
1211        let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
1212        let mut error: Option<PgError> = None;
1213        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1214            expected_queries: params_batch.len(),
1215            allow_parse_complete: false,
1216            require_parse_before_bind: false,
1217            no_data_counts_as_completion: true,
1218            allow_no_data_nonterminal: false,
1219            expected_parse_completes: Some(0),
1220        });
1221
1222        loop {
1223            match self.recv_data_ultra().await {
1224                Ok((msg_type, data)) => {
1225                    if let Err(err) =
1226                        flow.validate_msg_type(msg_type, "pipeline_prepared_ultra", error.is_some())
1227                    {
1228                        return return_with_desync(self, err);
1229                    }
1230                    match msg_type {
1231                        b'2' | b'T' => {} // BindComplete, RowDescription
1232                        b'D' => {
1233                            if error.is_none()
1234                                && let Some(row) = data
1235                            {
1236                                current_rows.push(row);
1237                            }
1238                        }
1239                        b'C' => {
1240                            all_results.push(std::mem::take(&mut current_rows));
1241                            current_rows = Vec::with_capacity(16);
1242                        }
1243                        b'n' => {
1244                            all_results.push(Vec::new());
1245                        }
1246                        b'Z' => {
1247                            if all_results.len() != params_batch.len() {
1248                                return Err(error.unwrap_or_else(|| {
1249                                    PgError::Protocol(format!(
1250                                        "Pipeline completion mismatch: expected {}, got {}",
1251                                        params_batch.len(),
1252                                        all_results.len()
1253                                    ))
1254                                }));
1255                            }
1256                            if let Some(err) = error {
1257                                return Err(err);
1258                            }
1259                            return Ok(all_results);
1260                        }
1261                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
1262                        other => {
1263                            return return_with_desync(
1264                                self,
1265                                unexpected_backend_msg_type("pipeline_prepared_ultra", other),
1266                            );
1267                        }
1268                    }
1269                }
1270                Err(e) => {
1271                    if matches!(&e, PgError::QueryServer(_)) {
1272                        capture_query_server_error(self, &mut error, e);
1273                        continue;
1274                    }
1275                    return Err(e);
1276                }
1277            }
1278        }
1279    }
1280}
1281
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a connection over one end of a Unix socket pair. The peer end is
    /// dropped when this function returns. Statement `s1` is registered in both
    /// `prepared_statements` and `stmt_cache`, so tests can observe whether that
    /// state gets cleared.
    #[cfg(unix)]
    fn make_test_conn_with_prepared() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (local_end, _peer_end) = UnixStream::pair().expect("unix stream pair");
        let cache_capacity = NonZeroUsize::new(16).expect("non-zero");

        let mut conn = PgConnection {
            stream: PgStream::Unix(local_end),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(cache_capacity),
            column_info_cache: HashMap::new(),
            process_id: 0,
            secret_key: 0,
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };

        let stmt_name = String::from("s1");
        conn.prepared_statements
            .insert(stmt_name.clone(), String::from("SELECT 1"));
        conn.stmt_cache.put(1, stmt_name);
        conn
    }

    /// Wraps `code` and `message` in an ERROR-severity `PgError::QueryServer`.
    fn server_error(code: &str, message: &str) -> PgError {
        let body = super::super::PgServerError {
            severity: String::from("ERROR"),
            code: code.to_owned(),
            message: message.to_owned(),
            detail: None,
            hint: None,
        };
        PgError::QueryServer(body)
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_clears_prepared_state_on_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot: Option<PgError> = None;

        capture_query_server_error(
            &mut conn,
            &mut slot,
            server_error("26000", "prepared statement \"s1\" does not exist"),
        );

        assert!(slot.is_some(), "first error must be recorded");
        assert!(conn.prepared_statements.is_empty());
        assert_eq!(conn.stmt_cache.len(), 0);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_preserves_prepared_state_on_non_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot: Option<PgError> = None;

        capture_query_server_error(
            &mut conn,
            &mut slot,
            server_error("23505", "duplicate key value violates unique constraint"),
        );

        assert!(slot.is_some(), "first error must be recorded");
        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_does_not_override_existing_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = Some(server_error("23505", "duplicate key"));

        // A retryable error arriving after the first one is dropped without
        // touching prepared-statement state.
        capture_query_server_error(
            &mut conn,
            &mut slot,
            server_error("26000", "prepared statement \"s1\" does not exist"),
        );

        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
        let recorded_code = slot.as_ref().and_then(|e| e.sqlstate());
        assert_eq!(recorded_code, Some("23505"));
    }
}