1use super::{
14 PgConnection, PgError, PgResult, is_ignorable_session_message, is_ignorable_session_msg_type,
15 unexpected_backend_message, unexpected_backend_msg_type,
16};
17use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
18use bytes::BytesMut;
19
20#[inline]
21fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
22 if matches!(
23 err,
24 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
25 ) {
26 conn.mark_io_desynced();
27 }
28 Err(err)
29}
30
31#[inline]
32fn capture_query_server_error(conn: &mut PgConnection, slot: &mut Option<PgError>, err: PgError) {
33 if slot.is_some() {
34 return;
35 }
36 if err.is_prepared_statement_retryable() {
37 conn.clear_prepared_statement_state();
38 }
39 *slot = Some(err);
40}
41
/// Tuning knobs for [`FastExtendedFlowTracker`], describing which backend
/// messages a given pipeline shape is allowed to produce.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowConfig {
    // Total number of queries submitted in the pipeline.
    expected_queries: usize,
    // Whether ParseComplete ('1') messages may appear at all.
    allow_parse_complete: bool,
    // Whether BindComplete ('2') must be preceded by ParseComplete for the
    // same query.
    require_parse_before_bind: bool,
    // Whether NoData ('n') terminates the current query (statement produced
    // no result set).
    no_data_counts_as_completion: bool,
    // Whether NoData may appear without terminating the current query.
    allow_no_data_nonterminal: bool,
    // Exact number of ParseComplete messages expected, when known up front.
    expected_parse_completes: Option<usize>,
}
51
/// State machine validating the backend message sequence of an
/// extended-protocol pipeline (Parse/Bind/Execute batches ended by Sync).
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowTracker {
    cfg: FastExtendedFlowConfig,
    // Queries that reached a terminal message (CommandComplete, or NoData
    // when configured as terminal).
    completed_queries: usize,
    // Total ParseComplete messages observed so far.
    parse_completes: usize,
    // Per-query flags; reset whenever the current query completes.
    current_parse_seen: bool,
    current_bind_seen: bool,
}
60
impl FastExtendedFlowTracker {
    /// Creates a tracker with no progress recorded yet.
    fn new(cfg: FastExtendedFlowConfig) -> Self {
        Self {
            cfg,
            completed_queries: 0,
            parse_completes: 0,
            current_parse_seen: false,
            current_bind_seen: false,
        }
    }

    /// Number of queries that have reached a terminal message so far.
    fn completed_queries(&self) -> usize {
        self.completed_queries
    }

    /// Validates one backend message tag against the expected extended-flow
    /// sequence and advances the internal state.
    ///
    /// Returns `ReadyForQuery` when the pipeline has finished, `Continue` for
    /// any other accepted message, and `PgError::Protocol` for out-of-order
    /// or unexpected messages. When `error_pending` is set, validation is
    /// relaxed: after an ErrorResponse the server skips to Sync, so anything
    /// except ReadyForQuery is simply ignored.
    fn validate_msg_type(
        &mut self,
        msg_type: u8,
        context: &'static str,
        error_pending: bool,
    ) -> PgResult<FastPipelineEvent> {
        // Asynchronous session messages (e.g. parameter status) may arrive at
        // any point and never affect flow state.
        if is_ignorable_session_msg_type(msg_type) {
            return Ok(FastPipelineEvent::Continue);
        }

        if error_pending {
            if msg_type == b'Z' {
                return Ok(FastPipelineEvent::ReadyForQuery);
            }
            return Ok(FastPipelineEvent::Continue);
        }

        // ReadyForQuery: final consistency checks for the whole pipeline.
        if msg_type == b'Z' {
            if self.completed_queries != self.cfg.expected_queries {
                return Err(PgError::Protocol(format!(
                    "{}: Pipeline completion mismatch: expected {}, got {}",
                    context, self.cfg.expected_queries, self.completed_queries
                )));
            }
            if self.current_parse_seen || self.current_bind_seen {
                return Err(PgError::Protocol(format!(
                    "{}: ReadyForQuery with incomplete query state",
                    context
                )));
            }
            if let Some(expected) = self.cfg.expected_parse_completes
                && self.parse_completes != expected
            {
                return Err(PgError::Protocol(format!(
                    "{}: ParseComplete mismatch: expected {}, got {}",
                    context, expected, self.parse_completes
                )));
            }
            return Ok(FastPipelineEvent::ReadyForQuery);
        }

        // Nothing but session messages and ReadyForQuery may follow the last
        // expected query.
        if self.completed_queries >= self.cfg.expected_queries {
            return Err(PgError::Protocol(format!(
                "{}: unexpected message '{}' after all queries completed",
                context, msg_type as char
            )));
        }

        match msg_type {
            // ParseComplete
            b'1' => {
                if !self.cfg.allow_parse_complete {
                    return Err(PgError::Protocol(format!(
                        "{}: unexpected ParseComplete",
                        context
                    )));
                }
                if self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: ParseComplete after BindComplete",
                        context
                    )));
                }
                if self.current_parse_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate ParseComplete",
                        context
                    )));
                }
                self.current_parse_seen = true;
                self.parse_completes += 1;
                // Catch overshoot early instead of waiting for ReadyForQuery.
                if let Some(expected) = self.cfg.expected_parse_completes
                    && self.parse_completes > expected
                {
                    return Err(PgError::Protocol(format!(
                        "{}: ParseComplete mismatch: expected {}, got at least {}",
                        context, expected, self.parse_completes
                    )));
                }
            }
            // BindComplete
            b'2' => {
                if self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate BindComplete",
                        context
                    )));
                }
                if self.cfg.require_parse_before_bind && !self.current_parse_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: BindComplete before ParseComplete",
                        context
                    )));
                }
                self.current_bind_seen = true;
            }
            // RowDescription / ParameterDescription / PortalSuspended: only
            // valid once the current query's Bind has completed.
            b'T' | b't' | b's' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: '{}' before BindComplete",
                        context, msg_type as char
                    )));
                }
            }
            // DataRow
            b'D' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: DataRow before BindComplete",
                        context
                    )));
                }
            }
            // NoData: terminal for no-result statements when configured so.
            b'n' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: NoData before BindComplete",
                        context
                    )));
                }
                if self.cfg.no_data_counts_as_completion {
                    self.complete_current();
                } else if !self.cfg.allow_no_data_nonterminal {
                    return Err(PgError::Protocol(format!("{}: unexpected NoData", context)));
                }
            }
            // CommandComplete: terminal for the current query.
            b'C' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: CommandComplete before BindComplete",
                        context
                    )));
                }
                self.complete_current();
            }
            // EmptyQueryResponse never occurs in the extended protocol flows
            // produced by this driver.
            b'I' => {
                return Err(PgError::Protocol(format!(
                    "{}: unexpected EmptyQueryResponse in extended pipeline",
                    context
                )));
            }
            other => return Err(unexpected_backend_msg_type(context, other)),
        }

        Ok(FastPipelineEvent::Continue)
    }

    /// Marks the in-flight query as finished and resets per-query flags.
    fn complete_current(&mut self) {
        self.completed_queries += 1;
        self.current_parse_seen = false;
        self.current_bind_seen = false;
    }
}
226
/// State machine validating the backend message sequence of a simple-protocol
/// query batch.
#[derive(Debug, Clone, Copy)]
struct FastSimpleFlowTracker {
    // Number of queries submitted in the batch.
    expected_queries: usize,
    // Queries terminated by CommandComplete or EmptyQueryResponse so far.
    completed_queries: usize,
    // True between a RowDescription and its terminating CommandComplete.
    current_row_description_seen: bool,
}
233
impl FastSimpleFlowTracker {
    /// Creates a tracker with no progress recorded yet.
    fn new(expected_queries: usize) -> Self {
        Self {
            expected_queries,
            completed_queries: 0,
            current_row_description_seen: false,
        }
    }

    /// Number of queries completed so far.
    fn completed_queries(&self) -> usize {
        self.completed_queries
    }

    /// Validates one backend message tag against the expected simple-protocol
    /// sequence and advances the internal state.
    ///
    /// See [`FastExtendedFlowTracker::validate_msg_type`] for the meaning of
    /// the return values and of `error_pending`.
    fn validate_msg_type(
        &mut self,
        msg_type: u8,
        context: &'static str,
        error_pending: bool,
    ) -> PgResult<FastPipelineEvent> {
        // Asynchronous session messages never affect flow state.
        if is_ignorable_session_msg_type(msg_type) {
            return Ok(FastPipelineEvent::Continue);
        }

        if error_pending {
            // After an error, only ReadyForQuery matters.
            if msg_type == b'Z' {
                return Ok(FastPipelineEvent::ReadyForQuery);
            }
            return Ok(FastPipelineEvent::Continue);
        }

        // ReadyForQuery: final consistency checks.
        if msg_type == b'Z' {
            if self.completed_queries != self.expected_queries {
                return Err(PgError::Protocol(format!(
                    "{}: Pipeline completion mismatch: expected {}, got {}",
                    context, self.expected_queries, self.completed_queries
                )));
            }
            if self.current_row_description_seen {
                return Err(PgError::Protocol(format!(
                    "{}: ReadyForQuery with incomplete row stream",
                    context
                )));
            }
            return Ok(FastPipelineEvent::ReadyForQuery);
        }

        if self.completed_queries >= self.expected_queries {
            return Err(PgError::Protocol(format!(
                "{}: unexpected message '{}' after all queries completed",
                context, msg_type as char
            )));
        }

        match msg_type {
            // RowDescription: opens the current query's row stream.
            b'T' => {
                if self.current_row_description_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate RowDescription",
                        context
                    )));
                }
                self.current_row_description_seen = true;
            }
            // DataRow: only valid inside an open row stream.
            b'D' => {
                if !self.current_row_description_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: DataRow before RowDescription",
                        context
                    )));
                }
            }
            // CommandComplete / EmptyQueryResponse both terminate a query.
            b'C' | b'I' => {
                self.completed_queries += 1;
                self.current_row_description_seen = false;
            }
            // Extended-protocol-only tags are illegal in a simple pipeline.
            b'1' | b'2' | b'n' | b't' | b's' => {
                return Err(PgError::Protocol(format!(
                    "{}: unexpected '{}' in simple pipeline",
                    context, msg_type as char
                )));
            }
            other => return Err(unexpected_backend_msg_type(context, other)),
        }

        Ok(FastPipelineEvent::Continue)
    }
}
321
/// Outcome of feeding one message tag to a flow tracker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FastPipelineEvent {
    // Message accepted; keep reading.
    Continue,
    // ReadyForQuery observed; the pipeline is finished.
    ReadyForQuery,
}
327
328#[inline]
329fn backend_msg_type_for_flow(msg: &BackendMessage) -> Option<u8> {
330 match msg {
331 BackendMessage::ParseComplete => Some(b'1'),
332 BackendMessage::BindComplete => Some(b'2'),
333 BackendMessage::ParameterDescription(_) => Some(b't'),
334 BackendMessage::RowDescription(_) => Some(b'T'),
335 BackendMessage::NoData => Some(b'n'),
336 BackendMessage::PortalSuspended => Some(b's'),
337 BackendMessage::DataRow(_) => Some(b'D'),
338 BackendMessage::CommandComplete(_) => Some(b'C'),
339 BackendMessage::EmptyQueryResponse => Some(b'I'),
340 BackendMessage::ReadyForQuery(_) => Some(b'Z'),
341 _ => None,
342 }
343}
344
345impl PgConnection {
346 pub async fn query_pipeline(
348 &mut self,
349 queries: &[(&str, &[Option<Vec<u8>>])],
350 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
351 let mut buf = BytesMut::new();
353 for (sql, params) in queries {
354 buf.extend_from_slice(
355 &PgEncoder::encode_extended_query(sql, params)
356 .map_err(|e| PgError::Encode(e.to_string()))?,
357 );
358 }
359
360 self.write_all_with_timeout(&buf, "stream write").await?;
362
363 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
365 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
366 let mut error: Option<PgError> = None;
367 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
368 expected_queries: queries.len(),
369 allow_parse_complete: true,
370 require_parse_before_bind: true,
371 no_data_counts_as_completion: true,
372 allow_no_data_nonterminal: false,
373 expected_parse_completes: Some(queries.len()),
374 });
375
376 loop {
377 let msg = self.recv().await?;
378 if is_ignorable_session_message(&msg) {
379 continue;
380 }
381 if let BackendMessage::ErrorResponse(err) = msg {
382 if error.is_none() {
383 error = Some(PgError::QueryServer(err.into()));
384 }
385 continue;
386 }
387 let msg_type = backend_msg_type_for_flow(&msg)
388 .ok_or_else(|| unexpected_backend_message("pipeline query", &msg));
389 let msg_type = match msg_type {
390 Ok(msg_type) => msg_type,
391 Err(err) => return return_with_desync(self, err),
392 };
393 if let Err(err) = flow.validate_msg_type(msg_type, "pipeline query", error.is_some()) {
394 return return_with_desync(self, err);
395 }
396 match msg {
397 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
398 BackendMessage::RowDescription(_) => {}
399 BackendMessage::DataRow(data) => {
400 if error.is_none() {
401 current_rows.push(data);
402 }
403 }
404 BackendMessage::CommandComplete(_) => {
405 all_results.push(std::mem::take(&mut current_rows));
406 }
407 BackendMessage::NoData => {
408 all_results.push(Vec::new());
409 }
410 BackendMessage::ReadyForQuery(_) => {
411 if all_results.len() != queries.len() {
412 return Err(error.unwrap_or_else(|| {
413 PgError::Protocol(format!(
414 "Pipeline completion mismatch: expected {}, got {}",
415 queries.len(),
416 all_results.len()
417 ))
418 }));
419 }
420 if let Some(err) = error {
421 return Err(err);
422 }
423 return Ok(all_results);
424 }
425 other => {
426 return return_with_desync(
427 self,
428 unexpected_backend_message("pipeline query", &other),
429 );
430 }
431 }
432 }
433 }
434
435 pub async fn pipeline_ast(
437 &mut self,
438 cmds: &[qail_core::ast::Qail],
439 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
440 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
441 self.write_all_with_timeout(&buf, "stream write").await?;
442
443 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
444 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
445 let mut error: Option<PgError> = None;
446 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
447 expected_queries: cmds.len(),
448 allow_parse_complete: true,
449 require_parse_before_bind: true,
450 no_data_counts_as_completion: true,
451 allow_no_data_nonterminal: false,
452 expected_parse_completes: Some(cmds.len()),
453 });
454
455 loop {
456 let msg = self.recv().await?;
457 if is_ignorable_session_message(&msg) {
458 continue;
459 }
460 if let BackendMessage::ErrorResponse(err) = msg {
461 if error.is_none() {
462 error = Some(PgError::QueryServer(err.into()));
463 }
464 continue;
465 }
466 let msg_type = backend_msg_type_for_flow(&msg)
467 .ok_or_else(|| unexpected_backend_message("pipeline ast", &msg));
468 let msg_type = match msg_type {
469 Ok(msg_type) => msg_type,
470 Err(err) => return return_with_desync(self, err),
471 };
472 if let Err(err) = flow.validate_msg_type(msg_type, "pipeline ast", error.is_some()) {
473 return return_with_desync(self, err);
474 }
475 match msg {
476 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
477 BackendMessage::RowDescription(_) => {}
478 BackendMessage::DataRow(data) => {
479 if error.is_none() {
480 current_rows.push(data);
481 }
482 }
483 BackendMessage::CommandComplete(_) => {
484 all_results.push(std::mem::take(&mut current_rows));
485 }
486 BackendMessage::NoData => {
487 all_results.push(Vec::new());
488 }
489 BackendMessage::ReadyForQuery(_) => {
490 if all_results.len() != cmds.len() {
491 return Err(error.unwrap_or_else(|| {
492 PgError::Protocol(format!(
493 "Pipeline completion mismatch: expected {}, got {}",
494 cmds.len(),
495 all_results.len()
496 ))
497 }));
498 }
499 if let Some(err) = error {
500 return Err(err);
501 }
502 return Ok(all_results);
503 }
504 other => {
505 return return_with_desync(
506 self,
507 unexpected_backend_message("pipeline ast", &other),
508 );
509 }
510 }
511 }
512 }
513
514 pub async fn pipeline_ast_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
516 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
517
518 self.write_all_with_timeout(&buf, "stream write").await?;
519 self.flush_with_timeout("stream flush").await?;
520
521 let mut error: Option<PgError> = None;
522 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
523 expected_queries: cmds.len(),
524 allow_parse_complete: true,
525 require_parse_before_bind: true,
526 no_data_counts_as_completion: true,
527 allow_no_data_nonterminal: false,
528 expected_parse_completes: Some(cmds.len()),
529 });
530
531 loop {
532 match self.recv_msg_type_fast().await {
533 Ok(msg_type) => {
534 let event = match flow.validate_msg_type(
535 msg_type,
536 "pipeline_ast_fast",
537 error.is_some(),
538 ) {
539 Ok(event) => event,
540 Err(err) => return return_with_desync(self, err),
541 };
542 match event {
543 FastPipelineEvent::Continue => {}
544 FastPipelineEvent::ReadyForQuery => {
545 if let Some(err) = error {
546 return Err(err);
547 }
548 return Ok(flow.completed_queries());
549 }
550 }
551 }
552 Err(e) => {
553 if matches!(&e, PgError::QueryServer(_)) {
554 capture_query_server_error(self, &mut error, e);
555 continue;
556 }
557 return Err(e);
558 }
559 }
560 }
561 }
562
563 #[inline]
565 pub async fn pipeline_bytes_fast(
566 &mut self,
567 wire_bytes: &[u8],
568 expected_queries: usize,
569 ) -> PgResult<usize> {
570 self.write_all_with_timeout(wire_bytes, "stream write")
571 .await?;
572 self.flush_with_timeout("stream flush").await?;
573
574 let mut error: Option<PgError> = None;
575 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
576 expected_queries,
577 allow_parse_complete: true,
578 require_parse_before_bind: false,
579 no_data_counts_as_completion: true,
580 allow_no_data_nonterminal: false,
581 expected_parse_completes: None,
582 });
583
584 loop {
585 match self.recv_msg_type_fast().await {
586 Ok(msg_type) => {
587 let event = match flow.validate_msg_type(
588 msg_type,
589 "pipeline_bytes_fast",
590 error.is_some(),
591 ) {
592 Ok(event) => event,
593 Err(err) => return return_with_desync(self, err),
594 };
595 match event {
596 FastPipelineEvent::Continue => {}
597 FastPipelineEvent::ReadyForQuery => {
598 if let Some(err) = error {
599 return Err(err);
600 }
601 return Ok(flow.completed_queries());
602 }
603 }
604 }
605 Err(e) => {
606 if matches!(&e, PgError::QueryServer(_)) {
607 capture_query_server_error(self, &mut error, e);
608 continue;
609 }
610 return Err(e);
611 }
612 }
613 }
614 }
615
616 #[inline]
618 pub async fn pipeline_simple_fast(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
619 let buf =
620 AstEncoder::encode_batch_simple(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
621 self.write_all_with_timeout(&buf, "stream write").await?;
622 self.flush_with_timeout("stream flush").await?;
623
624 let mut error: Option<PgError> = None;
625 let mut flow = FastSimpleFlowTracker::new(cmds.len());
626
627 loop {
628 match self.recv_msg_type_fast().await {
629 Ok(msg_type) => {
630 let event = match flow.validate_msg_type(
631 msg_type,
632 "pipeline_simple_fast",
633 error.is_some(),
634 ) {
635 Ok(event) => event,
636 Err(err) => return return_with_desync(self, err),
637 };
638 match event {
639 FastPipelineEvent::Continue => {}
640 FastPipelineEvent::ReadyForQuery => {
641 if let Some(err) = error {
642 return Err(err);
643 }
644 return Ok(flow.completed_queries());
645 }
646 }
647 }
648 Err(e) => {
649 if matches!(&e, PgError::QueryServer(_)) {
650 capture_query_server_error(self, &mut error, e);
651 continue;
652 }
653 return Err(e);
654 }
655 }
656 }
657 }
658
659 #[inline]
661 pub async fn pipeline_simple_bytes_fast(
662 &mut self,
663 wire_bytes: &[u8],
664 expected_queries: usize,
665 ) -> PgResult<usize> {
666 self.write_all_with_timeout(wire_bytes, "stream write")
667 .await?;
668 self.flush_with_timeout("stream flush").await?;
669
670 let mut error: Option<PgError> = None;
671 let mut flow = FastSimpleFlowTracker::new(expected_queries);
672
673 loop {
674 match self.recv_msg_type_fast().await {
675 Ok(msg_type) => {
676 let event = match flow.validate_msg_type(
677 msg_type,
678 "pipeline_simple_bytes_fast",
679 error.is_some(),
680 ) {
681 Ok(event) => event,
682 Err(err) => return return_with_desync(self, err),
683 };
684 match event {
685 FastPipelineEvent::Continue => {}
686 FastPipelineEvent::ReadyForQuery => {
687 if let Some(err) = error {
688 return Err(err);
689 }
690 return Ok(flow.completed_queries());
691 }
692 }
693 }
694 Err(e) => {
695 if matches!(&e, PgError::QueryServer(_)) {
696 capture_query_server_error(self, &mut error, e);
697 continue;
698 }
699 return Err(e);
700 }
701 }
702 }
703 }
704
    /// Executes AST commands using the per-connection prepared-statement
    /// cache: unseen SQL is Parsed under a derived statement name, cached SQL
    /// is only Bind/Execute'd. Returns the count of completed queries.
    ///
    /// Cache entries added by this call are rolled back on every failure path
    /// so the local cache never claims a statement the server may not have.
    #[inline]
    pub async fn pipeline_ast_cached(&mut self, cmds: &[qail_core::ast::Qail]) -> PgResult<usize> {
        if cmds.is_empty() {
            return Ok(0);
        }

        let mut buf = BytesMut::with_capacity(cmds.len() * 64);
        // Statement names inserted optimistically during this call; used to
        // undo the cache insertions on any error below.
        let mut new_stmt_names: Vec<String> = Vec::new();

        for cmd in cmds {
            let (sql, params) =
                AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
            let stmt_name = Self::sql_to_stmt_name(&sql);

            if !self.prepared_statements.contains_key(&stmt_name) {
                self.evict_prepared_if_full();
                buf.extend(PgEncoder::try_encode_parse(&stmt_name, &sql, &[])?);
                // Inserted before server confirmation; rolled back on failure.
                self.prepared_statements.insert(stmt_name.clone(), sql);
                new_stmt_names.push(stmt_name.clone());
            }

            let bind_msg = match PgEncoder::encode_bind("", &stmt_name, &params) {
                Ok(msg) => msg,
                Err(e) => {
                    // Undo cache entries added earlier in this loop.
                    for stmt in &new_stmt_names {
                        self.prepared_statements.remove(stmt);
                    }
                    return Err(PgError::Encode(e.to_string()));
                }
            };
            buf.extend_from_slice(&bind_msg);
            buf.extend(PgEncoder::try_encode_execute("", 0)?);
        }

        buf.extend(PgEncoder::encode_sync());

        if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
            for stmt in &new_stmt_names {
                self.prepared_statements.remove(stmt);
            }
            return Err(err);
        }
        if let Err(err) = self.flush_with_timeout("stream flush").await {
            for stmt in &new_stmt_names {
                self.prepared_statements.remove(stmt);
            }
            return Err(err);
        }

        // First server error of the pipeline; reported at ReadyForQuery.
        let mut error: Option<PgError> = None;
        // Only the statements actually Parsed this call produce ParseComplete.
        let expected_parse_completes = new_stmt_names.len();
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: cmds.len(),
            allow_parse_complete: true,
            require_parse_before_bind: false,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(expected_parse_completes),
        });

        loop {
            match self.recv_msg_type_fast().await {
                Ok(msg_type) => {
                    match flow.validate_msg_type(msg_type, "pipeline_ast_cached", error.is_some()) {
                        Ok(FastPipelineEvent::Continue) => {}
                        Ok(FastPipelineEvent::ReadyForQuery) => {
                            if let Some(err) = error {
                                // A failed pipeline may not have created the
                                // statements server-side; drop them locally.
                                for stmt in &new_stmt_names {
                                    self.prepared_statements.remove(stmt);
                                }
                                return Err(err);
                            }
                            return Ok(flow.completed_queries());
                        }
                        Err(err) => {
                            for stmt in &new_stmt_names {
                                self.prepared_statements.remove(stmt);
                            }
                            return return_with_desync(self, err);
                        }
                    }
                }
                Err(e) => {
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    for stmt in &new_stmt_names {
                        self.prepared_statements.remove(stmt);
                    }
                    return Err(e);
                }
            }
        }
    }
804
805 #[inline]
820 pub async fn pipeline_prepared_fast(
821 &mut self,
822 stmt: &super::PreparedStatement,
823 params_batch: &[Vec<Option<Vec<u8>>>],
824 ) -> PgResult<usize> {
825 if params_batch.is_empty() {
826 return Ok(0);
827 }
828
829 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
831
832 let is_new = !self.prepared_statements.contains_key(&stmt.name);
833
834 if is_new {
835 return Err(PgError::Query(
836 "Statement not prepared. Call prepare() first.".to_string(),
837 ));
838 }
839
840 for params in params_batch {
842 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
843 .map_err(|e| PgError::Encode(e.to_string()))?;
844 PgEncoder::encode_execute_to(&mut buf);
845 }
846
847 PgEncoder::encode_sync_to(&mut buf);
848
849 self.write_all_with_timeout(&buf, "stream write").await?;
850 self.flush_with_timeout("stream flush").await?;
851
852 let mut error: Option<PgError> = None;
853 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
854 expected_queries: params_batch.len(),
855 allow_parse_complete: false,
856 require_parse_before_bind: false,
857 no_data_counts_as_completion: true,
858 allow_no_data_nonterminal: false,
859 expected_parse_completes: Some(0),
860 });
861
862 loop {
863 match self.recv_msg_type_fast().await {
864 Ok(msg_type) => {
865 let event = match flow.validate_msg_type(
866 msg_type,
867 "pipeline_prepared_fast",
868 error.is_some(),
869 ) {
870 Ok(event) => event,
871 Err(err) => return return_with_desync(self, err),
872 };
873 match event {
874 FastPipelineEvent::Continue => {}
875 FastPipelineEvent::ReadyForQuery => {
876 if let Some(err) = error {
877 return Err(err);
878 }
879 return Ok(flow.completed_queries());
880 }
881 }
882 }
883 Err(e) => {
884 if matches!(&e, PgError::QueryServer(_)) {
885 capture_query_server_error(self, &mut error, e);
886 continue;
887 }
888 return Err(e);
889 }
890 }
891 }
892 }
893
    /// Parses `sql` on the server as a named prepared statement (unless it is
    /// already in the local cache) and returns a handle to it.
    pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
        use super::prepared::sql_bytes_to_stmt_name;

        // The statement name is derived deterministically from the SQL text,
        // so re-preparing the same query hits the local cache.
        let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());

        if !self.prepared_statements.contains_key(&stmt_name) {
            self.evict_prepared_if_full();
            let mut buf = BytesMut::with_capacity(sql.len() + 32);
            buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
            buf.extend(PgEncoder::encode_sync());

            self.write_all_with_timeout(&buf, "stream write").await?;
            self.flush_with_timeout("stream flush").await?;

            // First server error seen; reported at ReadyForQuery.
            let mut error: Option<PgError> = None;
            let mut saw_parse_complete = false;
            loop {
                match self.recv_msg_type_fast().await {
                    Ok(msg_type) => match msg_type {
                        b'1' => {
                            if saw_parse_complete {
                                return Err(PgError::Protocol(
                                    "prepare received duplicate ParseComplete".to_string(),
                                ));
                            }
                            saw_parse_complete = true;
                            // Cache only after the server confirmed the parse.
                            self.prepared_statements
                                .insert(stmt_name.clone(), sql.to_string());
                        }
                        b'Z' => {
                            if let Some(err) = error {
                                return Err(err);
                            }
                            if !saw_parse_complete {
                                return Err(PgError::Protocol(
                                    "prepare reached ReadyForQuery without ParseComplete"
                                        .to_string(),
                                ));
                            }
                            break;
                        }
                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
                        other => {
                            return return_with_desync(
                                self,
                                unexpected_backend_msg_type("prepare", other),
                            );
                        }
                    },
                    Err(e) => {
                        if matches!(&e, PgError::QueryServer(_)) {
                            capture_query_server_error(self, &mut error, e);
                            continue;
                        }
                        return Err(e);
                    }
                }
            }
        }

        Ok(super::PreparedStatement {
            name: stmt_name,
            // NOTE(review): counts '$' characters, so repeated placeholders
            // ("$1 ... $1") or a literal '$' inside the SQL skew this value —
            // confirm callers treat it only as an approximation.
            param_count: sql.matches('$').count(),
        })
    }
962
    /// Runs an already-prepared statement once per parameter set in a single
    /// round trip, collecting each execution's rows as owned byte vectors.
    ///
    /// The statement must have been prepared on this connection; the first
    /// server error is returned after the stream drains to ReadyForQuery.
    pub async fn pipeline_prepared_results(
        &mut self,
        stmt: &super::PreparedStatement,
        params_batch: &[Vec<Option<Vec<u8>>>],
    ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
        if params_batch.is_empty() {
            return Ok(Vec::new());
        }

        if !self.prepared_statements.contains_key(&stmt.name) {
            return Err(PgError::Query(
                "Statement not prepared. Call prepare() first.".to_string(),
            ));
        }

        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);

        // One Bind + Execute per parameter set, then a single Sync.
        for params in params_batch {
            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
                .map_err(|e| PgError::Encode(e.to_string()))?;
            PgEncoder::encode_execute_to(&mut buf);
        }

        PgEncoder::encode_sync_to(&mut buf);

        self.write_all_with_timeout(&buf, "stream write").await?;
        self.flush_with_timeout("stream flush").await?;

        let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
            Vec::with_capacity(params_batch.len());
        let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
        // First server error seen; reported at ReadyForQuery.
        let mut error: Option<PgError> = None;
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: params_batch.len(),
            // No Parse is sent here, so ParseComplete must not appear.
            allow_parse_complete: false,
            require_parse_before_bind: false,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(0),
        });

        loop {
            match self.recv_with_data_fast().await {
                Ok((msg_type, data)) => {
                    if let Err(err) = flow.validate_msg_type(
                        msg_type,
                        "pipeline_prepared_results",
                        error.is_some(),
                    ) {
                        return return_with_desync(self, err);
                    }
                    match msg_type {
                        b'2' => {}
                        b'T' => {}
                        b'D' => {
                            // Rows arriving after an error are discarded.
                            if error.is_none()
                                && let Some(row) = data
                            {
                                current_rows.push(row);
                            }
                        }
                        b'C' => {
                            all_results.push(std::mem::take(&mut current_rows));
                        }
                        b'n' => {
                            // No result set still completes one execution.
                            all_results.push(Vec::new());
                        }
                        b'Z' => {
                            if all_results.len() != params_batch.len() {
                                return Err(error.unwrap_or_else(|| {
                                    PgError::Protocol(format!(
                                        "Pipeline completion mismatch: expected {}, got {}",
                                        params_batch.len(),
                                        all_results.len()
                                    ))
                                }));
                            }
                            if let Some(err) = error {
                                return Err(err);
                            }
                            return Ok(all_results);
                        }
                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
                        other => {
                            return return_with_desync(
                                self,
                                unexpected_backend_msg_type("pipeline_prepared_results", other),
                            );
                        }
                    }
                }
                Err(e) => {
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    return Err(e);
                }
            }
        }
    }
1070
    /// Like [`Self::pipeline_prepared_results`], but rows are returned as
    /// `bytes::Bytes` values produced by `recv_data_zerocopy`, avoiding a
    /// per-cell copy into `Vec<u8>`.
    pub async fn pipeline_prepared_zerocopy(
        &mut self,
        stmt: &super::PreparedStatement,
        params_batch: &[Vec<Option<Vec<u8>>>],
    ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
        if params_batch.is_empty() {
            return Ok(Vec::new());
        }

        if !self.prepared_statements.contains_key(&stmt.name) {
            return Err(PgError::Query(
                "Statement not prepared. Call prepare() first.".to_string(),
            ));
        }

        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);

        // One Bind + Execute per parameter set, then a single Sync.
        for params in params_batch {
            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
                .map_err(|e| PgError::Encode(e.to_string()))?;
            PgEncoder::encode_execute_to(&mut buf);
        }

        PgEncoder::encode_sync_to(&mut buf);

        self.write_all_with_timeout(&buf, "stream write").await?;
        self.flush_with_timeout("stream flush").await?;

        let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
            Vec::with_capacity(params_batch.len());
        let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
        // First server error seen; reported at ReadyForQuery.
        let mut error: Option<PgError> = None;
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: params_batch.len(),
            // No Parse is sent here, so ParseComplete must not appear.
            allow_parse_complete: false,
            require_parse_before_bind: false,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(0),
        });

        loop {
            match self.recv_data_zerocopy().await {
                Ok((msg_type, data)) => {
                    if let Err(err) = flow.validate_msg_type(
                        msg_type,
                        "pipeline_prepared_zerocopy",
                        error.is_some(),
                    ) {
                        return return_with_desync(self, err);
                    }
                    match msg_type {
                        b'2' => {}
                        b'T' => {}
                        b'D' => {
                            // Rows arriving after an error are discarded.
                            if error.is_none()
                                && let Some(row) = data
                            {
                                current_rows.push(row);
                            }
                        }
                        b'C' => {
                            all_results.push(std::mem::take(&mut current_rows));
                        }
                        b'n' => {
                            // No result set still completes one execution.
                            all_results.push(Vec::new());
                        }
                        b'Z' => {
                            if all_results.len() != params_batch.len() {
                                return Err(error.unwrap_or_else(|| {
                                    PgError::Protocol(format!(
                                        "Pipeline completion mismatch: expected {}, got {}",
                                        params_batch.len(),
                                        all_results.len()
                                    ))
                                }));
                            }
                            if let Some(err) = error {
                                return Err(err);
                            }
                            return Ok(all_results);
                        }
                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
                        other => {
                            return return_with_desync(
                                self,
                                unexpected_backend_msg_type("pipeline_prepared_zerocopy", other),
                            );
                        }
                    }
                }
                Err(e) => {
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    return Err(e);
                }
            }
        }
    }
1178
    /// Executes an already-prepared statement once per parameter set in
    /// `params_batch`, pipelining every Bind/Execute pair into a single
    /// buffered write terminated by one Sync (one network round trip).
    ///
    /// Returns one `Vec` of `(Bytes, Bytes)` row tuples per parameter set, in
    /// batch order, as produced by `recv_data_ultra` (the tuple layout is
    /// whatever that decoder yields — presumably column key/value; confirm
    /// against `recv_data_ultra`). An empty batch short-circuits to `Ok(vec![])`.
    ///
    /// # Errors
    /// - `PgError::Query` if `stmt.name` is not in `prepared_statements`.
    /// - `PgError::Encode` if a Bind message fails to encode.
    /// - The first captured server error, surfaced only after the backend's
    ///   ReadyForQuery so the stream is fully drained and the connection
    ///   stays usable.
    /// - `PgError::Protocol` (connection marked desynced via
    ///   `return_with_desync`) on unexpected backend traffic or sequencing
    ///   violations.
    pub async fn pipeline_prepared_ultra(
        &mut self,
        stmt: &super::PreparedStatement,
        params_batch: &[Vec<Option<Vec<u8>>>],
    ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
        if params_batch.is_empty() {
            return Ok(Vec::new());
        }

        if !self.prepared_statements.contains_key(&stmt.name) {
            return Err(PgError::Query(
                "Statement not prepared. Call prepare() first.".to_string(),
            ));
        }

        // One Bind + Execute per parameter set; 64 bytes per query is a
        // sizing heuristic to avoid early buffer regrowth.
        let mut buf = BytesMut::with_capacity(params_batch.len() * 64);

        for params in params_batch {
            PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
                .map_err(|e| PgError::Encode(e.to_string()))?;
            PgEncoder::encode_execute_to(&mut buf);
        }

        // A single trailing Sync makes the backend answer the whole batch
        // with exactly one ReadyForQuery at the end.
        PgEncoder::encode_sync_to(&mut buf);

        self.write_all_with_timeout(&buf, "stream write").await?;
        self.flush_with_timeout("stream flush").await?;

        let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
            Vec::with_capacity(params_batch.len());
        let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
        // First server error seen; held until ReadyForQuery so the stream is
        // drained before returning.
        let mut error: Option<PgError> = None;
        // No Parse messages were sent, so the tracker expects zero
        // ParseCompletes and treats NoData ('n') as completing a query.
        let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
            expected_queries: params_batch.len(),
            allow_parse_complete: false,
            require_parse_before_bind: false,
            no_data_counts_as_completion: true,
            allow_no_data_nonterminal: false,
            expected_parse_completes: Some(0),
        });

        loop {
            match self.recv_data_ultra().await {
                Ok((msg_type, data)) => {
                    // Sequencing violations are protocol errors: mark the
                    // connection desynced and abort.
                    if let Err(err) =
                        flow.validate_msg_type(msg_type, "pipeline_prepared_ultra", error.is_some())
                    {
                        return return_with_desync(self, err);
                    }
                    match msg_type {
                        // '2' BindComplete / 'T' RowDescription: no payload needed here.
                        b'2' | b'T' => {} b'D' => {
                            // DataRow: collect unless an error is pending.
                            if error.is_none()
                                && let Some(row) = data
                            {
                                current_rows.push(row);
                            }
                        }
                        b'C' => {
                            // CommandComplete: seal this query's rows and start
                            // a fresh pre-sized buffer for the next one.
                            all_results.push(std::mem::take(&mut current_rows));
                            current_rows = Vec::with_capacity(16);
                        }
                        b'n' => {
                            // NoData: the statement produced no result set.
                            all_results.push(Vec::new());
                        }
                        b'Z' => {
                            // ReadyForQuery: batch fully drained. Prefer a
                            // captured server error over the mismatch report.
                            if all_results.len() != params_batch.len() {
                                return Err(error.unwrap_or_else(|| {
                                    PgError::Protocol(format!(
                                        "Pipeline completion mismatch: expected {}, got {}",
                                        params_batch.len(),
                                        all_results.len()
                                    ))
                                }));
                            }
                            if let Some(err) = error {
                                return Err(err);
                            }
                            return Ok(all_results);
                        }
                        msg_type if is_ignorable_session_msg_type(msg_type) => {}
                        other => {
                            return return_with_desync(
                                self,
                                unexpected_backend_msg_type("pipeline_prepared_ultra", other),
                            );
                        }
                    }
                }
                Err(e) => {
                    // Server errors are captured (clearing prepared-statement
                    // state when retryable) and deferred until ReadyForQuery;
                    // any other error aborts immediately.
                    if matches!(&e, PgError::QueryServer(_)) {
                        capture_query_server_error(self, &mut error, e);
                        continue;
                    }
                    return Err(e);
                }
            }
        }
    }
1280}
1281
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PgConnection` backed by an in-process Unix socket pair, with
    /// one prepared statement ("s1" -> "SELECT 1") registered in both
    /// `prepared_statements` and `stmt_cache`. The peer end is dropped, so no
    /// traffic is ever exchanged; the helpers under test only touch
    /// in-memory state.
    #[cfg(unix)]
    fn make_test_conn_with_prepared() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
        let mut conn = PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(16).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            secret_key: 0,
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        conn.prepared_statements
            .insert("s1".to_string(), "SELECT 1".to_string());
        conn.stmt_cache.put(1, "s1".to_string());
        conn
    }

    /// Shorthand for a `PgError::QueryServer` with the given SQLSTATE `code`
    /// and `message` (severity fixed to "ERROR", no detail/hint).
    ///
    /// Gated `#[cfg(unix)]` to match its only callers; without the gate it
    /// becomes dead code (and a warning) on non-unix targets where all the
    /// unix-only tests are compiled out.
    #[cfg(unix)]
    fn server_error(code: &str, message: &str) -> PgError {
        PgError::QueryServer(super::super::PgServerError {
            severity: "ERROR".to_string(),
            code: code.to_string(),
            message: message.to_string(),
            detail: None,
            hint: None,
        })
    }

    /// SQLSTATE 26000 ("prepared statement does not exist") is treated as
    /// prepared-statement-retryable, so capturing it must record the error
    /// and wipe both the prepared-statement map and the statement cache.
    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_clears_prepared_state_on_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        let err = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert!(conn.prepared_statements.is_empty());
        assert_eq!(conn.stmt_cache.len(), 0);
    }

    /// A non-retryable server error (here a unique-violation, 23505) must be
    /// recorded without disturbing prepared-statement state.
    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_preserves_prepared_state_on_non_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        let err = server_error("23505", "duplicate key value violates unique constraint");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
    }

    /// When the slot already holds an error, a later capture is a no-op:
    /// the first error is kept and (per the early return in
    /// `capture_query_server_error`) prepared state is left untouched even
    /// though the later error is retryable.
    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_does_not_override_existing_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = Some(server_error("23505", "duplicate key"));
        let retryable = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, retryable);

        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
        assert_eq!(
            slot.and_then(|e| e.sqlstate().map(str::to_string))
                .as_deref(),
            Some("23505")
        );
    }
}