1use super::{
14 PgConnection, PgError, PgResult, is_ignorable_session_message, is_ignorable_session_msg_type,
15 unexpected_backend_message, unexpected_backend_msg_type,
16};
17use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
18use bytes::{Bytes, BytesMut};
19
/// Strategy selector for AST count pipelines.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AstPipelineMode {
    /// Choose a concrete strategy from the batch length: batches with at
    /// least `AUTO_CACHE_MIN_BATCH` commands resolve to `Cached`, shorter
    /// batches resolve to `OneShot`.
    #[default]
    Auto,
    /// Always dispatch to `pipeline_execute_count_ast_oneshot`.
    OneShot,
    /// Always dispatch to `pipeline_execute_count_ast_cached`.
    Cached,
}

impl AstPipelineMode {
    /// Smallest batch length for which `Auto` resolves to `Cached`.
    const AUTO_CACHE_MIN_BATCH: usize = 8;

    /// Collapses `Auto` into a concrete mode for a batch of `batch_len`
    /// commands; `OneShot` and `Cached` are returned unchanged.
    #[inline]
    fn resolve_for_batch_len(self, batch_len: usize) -> Self {
        match self {
            Self::Auto if batch_len < Self::AUTO_CACHE_MIN_BATCH => Self::OneShot,
            Self::Auto => Self::Cached,
            concrete => concrete,
        }
    }
}
54
55#[inline]
56fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
57 if matches!(
58 err,
59 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
60 ) {
61 conn.mark_io_desynced();
62 }
63 Err(err)
64}
65
66#[inline]
67fn return_callback_error_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
68 conn.mark_io_desynced();
69 Err(err)
70}
71
72#[inline]
73fn capture_query_server_error(conn: &mut PgConnection, slot: &mut Option<PgError>, err: PgError) {
74 if slot.is_some() {
75 return;
76 }
77 if err.is_prepared_statement_retryable() {
78 conn.clear_prepared_statement_state();
79 }
80 *slot = Some(err);
81}
82
83#[inline]
84fn rollback_new_cached_statements_from(
85 conn: &mut PgConnection,
86 new_stmt_hashes: &[u64],
87 start_idx: usize,
88) {
89 for sql_hash in &new_stmt_hashes[start_idx.min(new_stmt_hashes.len())..] {
90 conn.stmt_cache.remove(sql_hash);
91 let stmt_name = super::prepared::stmt_name_from_hash(*sql_hash);
92 conn.prepared_statements.remove(&stmt_name);
93 conn.column_info_cache.remove(sql_hash);
94 }
95}
96
97#[inline]
98fn rollback_new_cached_statements(conn: &mut PgConnection, new_stmt_hashes: &[u64]) {
99 rollback_new_cached_statements_from(conn, new_stmt_hashes, 0);
100}
101
102#[inline]
103fn enforce_prepared_statement_cache_limit(conn: &mut PgConnection) {
104 while conn.prepared_statements.len() > PgConnection::MAX_PREPARED_PER_CONN {
105 conn.evict_prepared_if_full();
106 }
107}
108
109#[inline]
110fn reconcile_new_cached_statements_after_server_error(
111 conn: &mut PgConnection,
112 new_stmt_hashes: &[u64],
113 parse_completes: usize,
114) {
115 rollback_new_cached_statements_from(conn, new_stmt_hashes, parse_completes);
116 enforce_prepared_statement_cache_limit(conn);
117}
118
119#[inline]
120fn reserve_prepared_pipeline_write_buf(
121 conn: &mut PgConnection,
122 stmt: &super::PreparedStatement,
123 params_batch: &[Vec<Option<Vec<u8>>>],
124 result_format: i16,
125) -> PgResult<()> {
126 conn.write_buf.clear();
127 let mut needed = 5usize;
128 for params in params_batch {
129 let bind_execute = PgEncoder::bind_execute_wire_len_with_formats(
130 &stmt.name,
131 params,
132 PgEncoder::FORMAT_TEXT,
133 result_format,
134 )
135 .map_err(|e| PgError::Encode(e.to_string()))?;
136 needed = needed
137 .checked_add(bind_execute)
138 .ok_or_else(|| PgError::Encode("prepared pipeline batch too large".to_string()))?;
139 }
140 conn.write_buf.reserve(needed);
141 Ok(())
142}
143
/// Rules applied by `FastExtendedFlowTracker` when validating the backend
/// message sequence of an extended-protocol pipeline.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowConfig {
    /// Number of query completions that must have been seen when
    /// ReadyForQuery arrives.
    expected_queries: usize,
    /// Whether ParseComplete (`'1'`) is accepted at all.
    allow_parse_complete: bool,
    /// Whether BindComplete (`'2'`) must be preceded by a ParseComplete for
    /// the same query.
    require_parse_before_bind: bool,
    /// Whether NoData (`'n'`) completes the current query.
    no_data_counts_as_completion: bool,
    /// When NoData does not count as completion: whether it is tolerated
    /// (without completing the query) rather than rejected.
    allow_no_data_nonterminal: bool,
    /// Exact number of ParseComplete messages required for the whole
    /// pipeline; `None` disables the check.
    expected_parse_completes: Option<usize>,
}
153
/// State machine that validates extended-protocol pipeline responses by
/// message type tag.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowTracker {
    /// Validation rules for this pipeline.
    cfg: FastExtendedFlowConfig,
    /// Queries completed so far.
    completed_queries: usize,
    /// ParseComplete messages counted so far.
    parse_completes: usize,
    /// Whether ParseComplete was seen for the query currently in progress.
    current_parse_seen: bool,
    /// Whether BindComplete was seen for the query currently in progress.
    current_bind_seen: bool,
}
162
163impl FastExtendedFlowTracker {
164 fn new(cfg: FastExtendedFlowConfig) -> Self {
165 Self {
166 cfg,
167 completed_queries: 0,
168 parse_completes: 0,
169 current_parse_seen: false,
170 current_bind_seen: false,
171 }
172 }
173
174 fn completed_queries(&self) -> usize {
175 self.completed_queries
176 }
177
178 fn validate_msg_type(
179 &mut self,
180 msg_type: u8,
181 context: &'static str,
182 error_pending: bool,
183 ) -> PgResult<FastPipelineEvent> {
184 if is_ignorable_session_msg_type(msg_type) {
185 return Ok(FastPipelineEvent::Continue);
186 }
187
188 if error_pending {
189 if msg_type == b'Z' {
190 return Ok(FastPipelineEvent::ReadyForQuery);
191 }
192 return Ok(FastPipelineEvent::Continue);
193 }
194
195 if msg_type == b'Z' {
196 if self.completed_queries != self.cfg.expected_queries {
197 return Err(PgError::Protocol(format!(
198 "{}: Pipeline completion mismatch: expected {}, got {}",
199 context, self.cfg.expected_queries, self.completed_queries
200 )));
201 }
202 if self.current_parse_seen || self.current_bind_seen {
203 return Err(PgError::Protocol(format!(
204 "{}: ReadyForQuery with incomplete query state",
205 context
206 )));
207 }
208 if let Some(expected) = self.cfg.expected_parse_completes
209 && self.parse_completes != expected
210 {
211 return Err(PgError::Protocol(format!(
212 "{}: ParseComplete mismatch: expected {}, got {}",
213 context, expected, self.parse_completes
214 )));
215 }
216 return Ok(FastPipelineEvent::ReadyForQuery);
217 }
218
219 if self.completed_queries >= self.cfg.expected_queries {
220 return Err(PgError::Protocol(format!(
221 "{}: unexpected message '{}' after all queries completed",
222 context, msg_type as char
223 )));
224 }
225
226 match msg_type {
227 b'1' => {
228 if !self.cfg.allow_parse_complete {
229 return Err(PgError::Protocol(format!(
230 "{}: unexpected ParseComplete",
231 context
232 )));
233 }
234 if self.current_bind_seen {
235 return Err(PgError::Protocol(format!(
236 "{}: ParseComplete after BindComplete",
237 context
238 )));
239 }
240 if self.current_parse_seen {
241 return Err(PgError::Protocol(format!(
242 "{}: duplicate ParseComplete",
243 context
244 )));
245 }
246 self.current_parse_seen = true;
247 self.parse_completes += 1;
248 if let Some(expected) = self.cfg.expected_parse_completes
249 && self.parse_completes > expected
250 {
251 return Err(PgError::Protocol(format!(
252 "{}: ParseComplete mismatch: expected {}, got at least {}",
253 context, expected, self.parse_completes
254 )));
255 }
256 }
257 b'2' => {
258 if self.current_bind_seen {
259 return Err(PgError::Protocol(format!(
260 "{}: duplicate BindComplete",
261 context
262 )));
263 }
264 if self.cfg.require_parse_before_bind && !self.current_parse_seen {
265 return Err(PgError::Protocol(format!(
266 "{}: BindComplete before ParseComplete",
267 context
268 )));
269 }
270 self.current_bind_seen = true;
271 }
272 b'T' | b't' | b's' => {
273 if !self.current_bind_seen {
274 return Err(PgError::Protocol(format!(
275 "{}: '{}' before BindComplete",
276 context, msg_type as char
277 )));
278 }
279 }
280 b'D' => {
281 if !self.current_bind_seen {
282 return Err(PgError::Protocol(format!(
283 "{}: DataRow before BindComplete",
284 context
285 )));
286 }
287 }
288 b'n' => {
289 if !self.current_bind_seen {
290 return Err(PgError::Protocol(format!(
291 "{}: NoData before BindComplete",
292 context
293 )));
294 }
295 if self.cfg.no_data_counts_as_completion {
296 self.complete_current();
297 } else if !self.cfg.allow_no_data_nonterminal {
298 return Err(PgError::Protocol(format!("{}: unexpected NoData", context)));
299 }
300 }
301 b'C' => {
302 if !self.current_bind_seen {
303 return Err(PgError::Protocol(format!(
304 "{}: CommandComplete before BindComplete",
305 context
306 )));
307 }
308 self.complete_current();
309 }
310 b'I' => {
311 return Err(PgError::Protocol(format!(
312 "{}: unexpected EmptyQueryResponse in extended pipeline",
313 context
314 )));
315 }
316 other => return Err(unexpected_backend_msg_type(context, other)),
317 }
318
319 Ok(FastPipelineEvent::Continue)
320 }
321
322 fn complete_current(&mut self) {
323 self.completed_queries += 1;
324 self.current_parse_seen = false;
325 self.current_bind_seen = false;
326 }
327}
328
/// State machine that validates simple-protocol pipeline responses by
/// message type tag.
#[derive(Debug, Clone, Copy)]
struct FastSimpleFlowTracker {
    /// Number of query completions that must have been seen when
    /// ReadyForQuery arrives.
    expected_queries: usize,
    /// Queries completed so far.
    completed_queries: usize,
    /// Whether RowDescription was seen for the query currently in progress.
    current_row_description_seen: bool,
}
335
336impl FastSimpleFlowTracker {
337 fn new(expected_queries: usize) -> Self {
338 Self {
339 expected_queries,
340 completed_queries: 0,
341 current_row_description_seen: false,
342 }
343 }
344
345 fn completed_queries(&self) -> usize {
346 self.completed_queries
347 }
348
349 fn validate_msg_type(
350 &mut self,
351 msg_type: u8,
352 context: &'static str,
353 error_pending: bool,
354 ) -> PgResult<FastPipelineEvent> {
355 if is_ignorable_session_msg_type(msg_type) {
356 return Ok(FastPipelineEvent::Continue);
357 }
358
359 if error_pending {
360 if msg_type == b'Z' {
361 return Ok(FastPipelineEvent::ReadyForQuery);
362 }
363 return Ok(FastPipelineEvent::Continue);
364 }
365
366 if msg_type == b'Z' {
367 if self.completed_queries != self.expected_queries {
368 return Err(PgError::Protocol(format!(
369 "{}: Pipeline completion mismatch: expected {}, got {}",
370 context, self.expected_queries, self.completed_queries
371 )));
372 }
373 if self.current_row_description_seen {
374 return Err(PgError::Protocol(format!(
375 "{}: ReadyForQuery with incomplete row stream",
376 context
377 )));
378 }
379 return Ok(FastPipelineEvent::ReadyForQuery);
380 }
381
382 if self.completed_queries >= self.expected_queries {
383 return Err(PgError::Protocol(format!(
384 "{}: unexpected message '{}' after all queries completed",
385 context, msg_type as char
386 )));
387 }
388
389 match msg_type {
390 b'T' => {
391 if self.current_row_description_seen {
392 return Err(PgError::Protocol(format!(
393 "{}: duplicate RowDescription",
394 context
395 )));
396 }
397 self.current_row_description_seen = true;
398 }
399 b'D' => {
400 if !self.current_row_description_seen {
401 return Err(PgError::Protocol(format!(
402 "{}: DataRow before RowDescription",
403 context
404 )));
405 }
406 }
407 b'C' | b'I' => {
408 self.completed_queries += 1;
409 self.current_row_description_seen = false;
410 }
411 b'1' | b'2' | b'n' | b't' | b's' => {
412 return Err(PgError::Protocol(format!(
413 "{}: unexpected '{}' in simple pipeline",
414 context, msg_type as char
415 )));
416 }
417 other => return Err(unexpected_backend_msg_type(context, other)),
418 }
419
420 Ok(FastPipelineEvent::Continue)
421 }
422}
423
/// Outcome of validating a single backend message in a pipeline response.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FastPipelineEvent {
    /// The message was accepted (or skipped); keep reading.
    Continue,
    /// ReadyForQuery (`'Z'`) was accepted; the pipeline response is over.
    ReadyForQuery,
}
429
430#[inline]
431fn backend_msg_type_for_flow(msg: &BackendMessage) -> Option<u8> {
432 match msg {
433 BackendMessage::ParseComplete => Some(b'1'),
434 BackendMessage::BindComplete => Some(b'2'),
435 BackendMessage::ParameterDescription(_) => Some(b't'),
436 BackendMessage::RowDescription(_) => Some(b'T'),
437 BackendMessage::NoData => Some(b'n'),
438 BackendMessage::PortalSuspended => Some(b's'),
439 BackendMessage::DataRow(_) => Some(b'D'),
440 BackendMessage::CommandComplete(_) => Some(b'C'),
441 BackendMessage::EmptyQueryResponse => Some(b'I'),
442 BackendMessage::ReadyForQuery(_) => Some(b'Z'),
443 _ => None,
444 }
445}
446
447impl PgConnection {
448 pub async fn query_pipeline(
450 &mut self,
451 queries: &[(&str, &[Option<Vec<u8>>])],
452 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
453 let mut buf = BytesMut::new();
455 for (sql, params) in queries {
456 buf.extend_from_slice(
457 &PgEncoder::try_encode_parse("", sql, &[])
458 .map_err(|e| PgError::Encode(e.to_string()))?,
459 );
460 buf.extend_from_slice(
461 &PgEncoder::encode_bind("", "", params)
462 .map_err(|e| PgError::Encode(e.to_string()))?,
463 );
464 buf.extend_from_slice(
465 &PgEncoder::try_encode_execute("", 0)
466 .map_err(|e| PgError::Encode(e.to_string()))?,
467 );
468 }
469 buf.extend_from_slice(&PgEncoder::encode_sync());
470
471 self.write_all_with_timeout(&buf, "stream write").await?;
473
474 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
476 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
477 let mut error: Option<PgError> = None;
478 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
479 expected_queries: queries.len(),
480 allow_parse_complete: true,
481 require_parse_before_bind: true,
482 no_data_counts_as_completion: true,
483 allow_no_data_nonterminal: false,
484 expected_parse_completes: Some(queries.len()),
485 });
486
487 loop {
488 let msg = self.recv().await?;
489 if is_ignorable_session_message(&msg) {
490 continue;
491 }
492 if let BackendMessage::ErrorResponse(err) = msg {
493 if error.is_none() {
494 error = Some(PgError::QueryServer(err.into()));
495 }
496 continue;
497 }
498 let msg_type = backend_msg_type_for_flow(&msg)
499 .ok_or_else(|| unexpected_backend_message("pipeline query", &msg));
500 let msg_type = match msg_type {
501 Ok(msg_type) => msg_type,
502 Err(err) => return return_with_desync(self, err),
503 };
504 if let Err(err) = flow.validate_msg_type(msg_type, "pipeline query", error.is_some()) {
505 return return_with_desync(self, err);
506 }
507 match msg {
508 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
509 BackendMessage::RowDescription(_) => {}
510 BackendMessage::DataRow(data) => {
511 if error.is_none() {
512 current_rows.push(data);
513 }
514 }
515 BackendMessage::CommandComplete(_) => {
516 all_results.push(std::mem::take(&mut current_rows));
517 }
518 BackendMessage::NoData => {
519 all_results.push(Vec::new());
520 }
521 BackendMessage::ReadyForQuery(_) => {
522 if all_results.len() != queries.len() {
523 return Err(error.unwrap_or_else(|| {
524 PgError::Protocol(format!(
525 "Pipeline completion mismatch: expected {}, got {}",
526 queries.len(),
527 all_results.len()
528 ))
529 }));
530 }
531 if let Some(err) = error {
532 return Err(err);
533 }
534 return Ok(all_results);
535 }
536 other => {
537 return return_with_desync(
538 self,
539 unexpected_backend_message("pipeline query", &other),
540 );
541 }
542 }
543 }
544 }
545
546 pub async fn query_pipeline_count(
549 &mut self,
550 queries: &[(&str, &[Option<Vec<u8>>])],
551 ) -> PgResult<usize> {
552 if queries.is_empty() {
553 return Ok(0);
554 }
555
556 self.write_buf.clear();
557 for (sql, params) in queries {
558 PgEncoder::try_encode_parse_to(&mut self.write_buf, "", sql, &[])
559 .map_err(|e| PgError::Encode(e.to_string()))?;
560 PgEncoder::encode_bind_to(&mut self.write_buf, "", params)
561 .map_err(|e| PgError::Encode(e.to_string()))?;
562 PgEncoder::encode_execute_to(&mut self.write_buf);
563 }
564 PgEncoder::encode_sync_to(&mut self.write_buf);
565
566 self.flush_write_buf().await?;
567
568 let mut error: Option<PgError> = None;
569 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
570 expected_queries: queries.len(),
571 allow_parse_complete: true,
572 require_parse_before_bind: true,
573 no_data_counts_as_completion: true,
574 allow_no_data_nonterminal: false,
575 expected_parse_completes: Some(queries.len()),
576 });
577
578 loop {
579 match self.recv_msg_type_fast().await {
580 Ok(msg_type) => {
581 let event = match flow.validate_msg_type(
582 msg_type,
583 "query_pipeline_count",
584 error.is_some(),
585 ) {
586 Ok(event) => event,
587 Err(err) => return return_with_desync(self, err),
588 };
589 match event {
590 FastPipelineEvent::Continue => {}
591 FastPipelineEvent::ReadyForQuery => {
592 if let Some(err) = error {
593 return Err(err);
594 }
595 return Ok(flow.completed_queries());
596 }
597 }
598 }
599 Err(e) => {
600 if matches!(&e, PgError::QueryServer(_)) {
601 capture_query_server_error(self, &mut error, e);
602 continue;
603 }
604 return Err(e);
605 }
606 }
607 }
608 }
609
610 pub async fn query_pipeline_visit_bytes_rows<F>(
613 &mut self,
614 queries: &[(&str, &[Option<Vec<u8>>])],
615 mut on_row: F,
616 ) -> PgResult<usize>
617 where
618 F: FnMut(&super::PgBytesRow) -> PgResult<()>,
619 {
620 if queries.is_empty() {
621 return Ok(0);
622 }
623
624 self.write_buf.clear();
625 for (sql, params) in queries {
626 PgEncoder::try_encode_parse_to(&mut self.write_buf, "", sql, &[])
627 .map_err(|e| PgError::Encode(e.to_string()))?;
628 PgEncoder::encode_bind_to(&mut self.write_buf, "", params)
629 .map_err(|e| PgError::Encode(e.to_string()))?;
630 PgEncoder::encode_execute_to(&mut self.write_buf);
631 }
632 PgEncoder::encode_sync_to(&mut self.write_buf);
633
634 self.flush_write_buf().await?;
635
636 let mut row = super::PgBytesRow::default();
637 let mut error: Option<PgError> = None;
638 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
639 expected_queries: queries.len(),
640 allow_parse_complete: true,
641 require_parse_before_bind: true,
642 no_data_counts_as_completion: true,
643 allow_no_data_nonterminal: false,
644 expected_parse_completes: Some(queries.len()),
645 });
646
647 loop {
648 match self.recv_fill_zerocopy_row_fast(&mut row).await {
649 Ok(msg_type) => {
650 if let Err(err) = flow.validate_msg_type(
651 msg_type,
652 "query_pipeline_visit_bytes_rows",
653 error.is_some(),
654 ) {
655 return return_with_desync(self, err);
656 }
657 match msg_type {
658 b'1' | b'2' | b'T' | b'C' | b'n' => {}
659 b'D' => {
660 if error.is_none() {
661 if let Err(err) = on_row(&row) {
662 return return_callback_error_with_desync(self, err);
663 }
664 row.release_payload();
665 }
666 }
667 b'Z' => {
668 if let Some(err) = error {
669 return Err(err);
670 }
671 return Ok(flow.completed_queries());
672 }
673 msg_type if is_ignorable_session_msg_type(msg_type) => {}
674 other => {
675 return return_with_desync(
676 self,
677 unexpected_backend_msg_type(
678 "query_pipeline_visit_bytes_rows",
679 other,
680 ),
681 );
682 }
683 }
684 }
685 Err(e) => {
686 if matches!(&e, PgError::QueryServer(_)) {
687 capture_query_server_error(self, &mut error, e);
688 continue;
689 }
690 return Err(e);
691 }
692 }
693 }
694 }
695
696 pub async fn query_pipeline_visit_first_column_bytes<F>(
699 &mut self,
700 queries: &[(&str, &[Option<Vec<u8>>])],
701 mut on_value: F,
702 ) -> PgResult<usize>
703 where
704 F: FnMut(Option<&[u8]>) -> PgResult<()>,
705 {
706 if queries.is_empty() {
707 return Ok(0);
708 }
709
710 self.write_buf.clear();
711 for (sql, params) in queries {
712 PgEncoder::try_encode_parse_to(&mut self.write_buf, "", sql, &[])
713 .map_err(|e| PgError::Encode(e.to_string()))?;
714 PgEncoder::encode_bind_to(&mut self.write_buf, "", params)
715 .map_err(|e| PgError::Encode(e.to_string()))?;
716 PgEncoder::encode_execute_to(&mut self.write_buf);
717 }
718 PgEncoder::encode_sync_to(&mut self.write_buf);
719
720 self.flush_write_buf().await?;
721
722 let mut first_column: Option<Bytes> = None;
723 let mut error: Option<PgError> = None;
724 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
725 expected_queries: queries.len(),
726 allow_parse_complete: true,
727 require_parse_before_bind: true,
728 no_data_counts_as_completion: true,
729 allow_no_data_nonterminal: false,
730 expected_parse_completes: Some(queries.len()),
731 });
732
733 loop {
734 match self
735 .recv_fill_first_column_zerocopy_fast(&mut first_column)
736 .await
737 {
738 Ok(msg_type) => {
739 if let Err(err) = flow.validate_msg_type(
740 msg_type,
741 "query_pipeline_visit_first_column_bytes",
742 error.is_some(),
743 ) {
744 return return_with_desync(self, err);
745 }
746 match msg_type {
747 b'1' | b'2' | b'T' | b'C' | b'n' => {}
748 b'D' => {
749 if error.is_none() {
750 if let Err(err) = on_value(first_column.as_deref()) {
751 return return_callback_error_with_desync(self, err);
752 }
753 first_column = None;
754 }
755 }
756 b'Z' => {
757 if let Some(err) = error {
758 return Err(err);
759 }
760 return Ok(flow.completed_queries());
761 }
762 msg_type if is_ignorable_session_msg_type(msg_type) => {}
763 other => {
764 return return_with_desync(
765 self,
766 unexpected_backend_msg_type(
767 "query_pipeline_visit_first_column_bytes",
768 other,
769 ),
770 );
771 }
772 }
773 }
774 Err(e) => {
775 if matches!(&e, PgError::QueryServer(_)) {
776 capture_query_server_error(self, &mut error, e);
777 continue;
778 }
779 return Err(e);
780 }
781 }
782 }
783 }
784
785 pub async fn pipeline_execute_rows_ast(
787 &mut self,
788 cmds: &[qail_core::ast::Qail],
789 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
790 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
791 self.write_all_with_timeout(&buf, "stream write").await?;
792
793 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
794 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
795 let mut error: Option<PgError> = None;
796 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
797 expected_queries: cmds.len(),
798 allow_parse_complete: true,
799 require_parse_before_bind: true,
800 no_data_counts_as_completion: true,
801 allow_no_data_nonterminal: false,
802 expected_parse_completes: Some(cmds.len()),
803 });
804
805 loop {
806 let msg = self.recv().await?;
807 if is_ignorable_session_message(&msg) {
808 continue;
809 }
810 if let BackendMessage::ErrorResponse(err) = msg {
811 if error.is_none() {
812 error = Some(PgError::QueryServer(err.into()));
813 }
814 continue;
815 }
816 let msg_type = backend_msg_type_for_flow(&msg)
817 .ok_or_else(|| unexpected_backend_message("pipeline ast", &msg));
818 let msg_type = match msg_type {
819 Ok(msg_type) => msg_type,
820 Err(err) => return return_with_desync(self, err),
821 };
822 if let Err(err) = flow.validate_msg_type(msg_type, "pipeline ast", error.is_some()) {
823 return return_with_desync(self, err);
824 }
825 match msg {
826 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
827 BackendMessage::RowDescription(_) => {}
828 BackendMessage::DataRow(data) => {
829 if error.is_none() {
830 current_rows.push(data);
831 }
832 }
833 BackendMessage::CommandComplete(_) => {
834 all_results.push(std::mem::take(&mut current_rows));
835 }
836 BackendMessage::NoData => {
837 all_results.push(Vec::new());
838 }
839 BackendMessage::ReadyForQuery(_) => {
840 if all_results.len() != cmds.len() {
841 return Err(error.unwrap_or_else(|| {
842 PgError::Protocol(format!(
843 "Pipeline completion mismatch: expected {}, got {}",
844 cmds.len(),
845 all_results.len()
846 ))
847 }));
848 }
849 if let Some(err) = error {
850 return Err(err);
851 }
852 return Ok(all_results);
853 }
854 other => {
855 return return_with_desync(
856 self,
857 unexpected_backend_message("pipeline ast", &other),
858 );
859 }
860 }
861 }
862 }
863
864 pub async fn pipeline_execute_count_ast_oneshot(
866 &mut self,
867 cmds: &[qail_core::ast::Qail],
868 ) -> PgResult<usize> {
869 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
870
871 self.write_all_with_timeout(&buf, "stream write").await?;
872 self.flush_with_timeout("stream flush").await?;
873
874 let mut error: Option<PgError> = None;
875 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
876 expected_queries: cmds.len(),
877 allow_parse_complete: true,
878 require_parse_before_bind: true,
879 no_data_counts_as_completion: true,
880 allow_no_data_nonterminal: false,
881 expected_parse_completes: Some(cmds.len()),
882 });
883
884 loop {
885 match self.recv_msg_type_fast().await {
886 Ok(msg_type) => {
887 let event = match flow.validate_msg_type(
888 msg_type,
889 "pipeline_execute_count_ast_oneshot",
890 error.is_some(),
891 ) {
892 Ok(event) => event,
893 Err(err) => return return_with_desync(self, err),
894 };
895 match event {
896 FastPipelineEvent::Continue => {}
897 FastPipelineEvent::ReadyForQuery => {
898 if let Some(err) = error {
899 return Err(err);
900 }
901 return Ok(flow.completed_queries());
902 }
903 }
904 }
905 Err(e) => {
906 if matches!(&e, PgError::QueryServer(_)) {
907 capture_query_server_error(self, &mut error, e);
908 continue;
909 }
910 return Err(e);
911 }
912 }
913 }
914 }
915
916 #[inline]
922 pub async fn pipeline_execute_count_ast_with_mode(
923 &mut self,
924 cmds: &[qail_core::ast::Qail],
925 mode: AstPipelineMode,
926 ) -> PgResult<usize> {
927 if cmds.is_empty() {
928 return Ok(0);
929 }
930
931 match mode.resolve_for_batch_len(cmds.len()) {
932 AstPipelineMode::OneShot => self.pipeline_execute_count_ast_oneshot(cmds).await,
933 AstPipelineMode::Cached => self.pipeline_execute_count_ast_cached(cmds).await,
934 AstPipelineMode::Auto => Err(PgError::Protocol(
935 "auto pipeline mode did not resolve to a concrete strategy".to_string(),
936 )),
937 }
938 }
939
940 #[inline]
942 pub async fn pipeline_execute_count_wire(
943 &mut self,
944 wire_bytes: &[u8],
945 expected_queries: usize,
946 ) -> PgResult<usize> {
947 self.write_all_with_timeout(wire_bytes, "stream write")
948 .await?;
949 self.flush_with_timeout("stream flush").await?;
950
951 let mut error: Option<PgError> = None;
952 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
953 expected_queries,
954 allow_parse_complete: true,
955 require_parse_before_bind: false,
956 no_data_counts_as_completion: true,
957 allow_no_data_nonterminal: false,
958 expected_parse_completes: None,
959 });
960
961 loop {
962 match self.recv_msg_type_fast().await {
963 Ok(msg_type) => {
964 let event = match flow.validate_msg_type(
965 msg_type,
966 "pipeline_execute_count_wire",
967 error.is_some(),
968 ) {
969 Ok(event) => event,
970 Err(err) => return return_with_desync(self, err),
971 };
972 match event {
973 FastPipelineEvent::Continue => {}
974 FastPipelineEvent::ReadyForQuery => {
975 if let Some(err) = error {
976 return Err(err);
977 }
978 return Ok(flow.completed_queries());
979 }
980 }
981 }
982 Err(e) => {
983 if matches!(&e, PgError::QueryServer(_)) {
984 capture_query_server_error(self, &mut error, e);
985 continue;
986 }
987 return Err(e);
988 }
989 }
990 }
991 }
992
993 #[inline]
995 pub async fn pipeline_execute_count_simple_ast(
996 &mut self,
997 cmds: &[qail_core::ast::Qail],
998 ) -> PgResult<usize> {
999 if cmds.is_empty() {
1000 return Ok(0);
1001 }
1002
1003 let buf =
1004 AstEncoder::encode_batch_simple(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
1005 self.write_all_with_timeout(&buf, "stream write").await?;
1006 self.flush_with_timeout("stream flush").await?;
1007
1008 let mut error: Option<PgError> = None;
1009 let mut flow = FastSimpleFlowTracker::new(cmds.len());
1010
1011 loop {
1012 match self.recv_msg_type_fast().await {
1013 Ok(msg_type) => {
1014 let event = match flow.validate_msg_type(
1015 msg_type,
1016 "pipeline_execute_count_simple_ast",
1017 error.is_some(),
1018 ) {
1019 Ok(event) => event,
1020 Err(err) => return return_with_desync(self, err),
1021 };
1022 match event {
1023 FastPipelineEvent::Continue => {}
1024 FastPipelineEvent::ReadyForQuery => {
1025 if let Some(err) = error {
1026 return Err(err);
1027 }
1028 return Ok(flow.completed_queries());
1029 }
1030 }
1031 }
1032 Err(e) => {
1033 if matches!(&e, PgError::QueryServer(_)) {
1034 capture_query_server_error(self, &mut error, e);
1035 continue;
1036 }
1037 return Err(e);
1038 }
1039 }
1040 }
1041 }
1042
1043 #[inline]
1045 pub async fn pipeline_execute_count_simple_wire(
1046 &mut self,
1047 wire_bytes: &[u8],
1048 expected_queries: usize,
1049 ) -> PgResult<usize> {
1050 self.write_all_with_timeout(wire_bytes, "stream write")
1051 .await?;
1052 self.flush_with_timeout("stream flush").await?;
1053
1054 let mut error: Option<PgError> = None;
1055 let mut flow = FastSimpleFlowTracker::new(expected_queries);
1056
1057 loop {
1058 match self.recv_msg_type_fast().await {
1059 Ok(msg_type) => {
1060 let event = match flow.validate_msg_type(
1061 msg_type,
1062 "pipeline_execute_count_simple_wire",
1063 error.is_some(),
1064 ) {
1065 Ok(event) => event,
1066 Err(err) => return return_with_desync(self, err),
1067 };
1068 match event {
1069 FastPipelineEvent::Continue => {}
1070 FastPipelineEvent::ReadyForQuery => {
1071 if let Some(err) = error {
1072 return Err(err);
1073 }
1074 return Ok(flow.completed_queries());
1075 }
1076 }
1077 }
1078 Err(e) => {
1079 if matches!(&e, PgError::QueryServer(_)) {
1080 capture_query_server_error(self, &mut error, e);
1081 continue;
1082 }
1083 return Err(e);
1084 }
1085 }
1086 }
1087 }
1088
1089 #[inline]
1094 pub async fn pipeline_execute_count_ast_cached(
1095 &mut self,
1096 cmds: &[qail_core::ast::Qail],
1097 ) -> PgResult<usize> {
1098 if cmds.is_empty() {
1099 return Ok(0);
1100 }
1101
1102 use super::prepared::{sql_bytes_hash, stmt_name_from_hash};
1103
1104 let mut buf = BytesMut::with_capacity(cmds.len() * 64);
1105 let mut sql_buf = BytesMut::with_capacity(256);
1106 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
1107 let mut new_stmt_hashes: Vec<u64> = Vec::new();
1108
1109 for cmd in cmds {
1110 if let Err(e) = AstEncoder::encode_cmd_sql_reuse(cmd, &mut sql_buf, &mut params) {
1111 rollback_new_cached_statements(self, &new_stmt_hashes);
1112 return Err(PgError::Encode(e.to_string()));
1113 }
1114
1115 let sql_hash = sql_bytes_hash(sql_buf.as_ref());
1116
1117 if self.stmt_cache.contains(&sql_hash) {
1118 self.stmt_cache.touch_key(sql_hash);
1119 } else {
1120 let stmt_name = stmt_name_from_hash(sql_hash);
1121 if self.prepared_statements.contains_key(&stmt_name) {
1122 self.stmt_cache.put(sql_hash, stmt_name.clone());
1125 } else {
1126 let sql = match std::str::from_utf8(sql_buf.as_ref()) {
1127 Ok(sql) => sql.to_string(),
1128 Err(e) => {
1129 rollback_new_cached_statements(self, &new_stmt_hashes);
1130 return Err(PgError::Encode(format!(
1131 "encoded SQL is not UTF-8: {}",
1132 e
1133 )));
1134 }
1135 };
1136 let parse_msg = match PgEncoder::try_encode_parse(&stmt_name, &sql, &[]) {
1137 Ok(msg) => msg,
1138 Err(e) => {
1139 rollback_new_cached_statements(self, &new_stmt_hashes);
1140 return Err(PgError::Encode(e.to_string()));
1141 }
1142 };
1143 buf.extend(parse_msg);
1144 self.stmt_cache.put(sql_hash, stmt_name.clone());
1145 self.prepared_statements.insert(stmt_name.clone(), sql);
1146 new_stmt_hashes.push(sql_hash);
1147 }
1148 }
1149
1150 let Some(stmt_name) = self.stmt_cache.peek(&sql_hash) else {
1151 rollback_new_cached_statements(self, &new_stmt_hashes);
1152 return Err(PgError::Protocol(
1153 "stmt_cache lookup failed after statement registration".to_string(),
1154 ));
1155 };
1156
1157 if let Err(e) = PgEncoder::encode_bind_to(&mut buf, stmt_name, ¶ms) {
1158 rollback_new_cached_statements(self, &new_stmt_hashes);
1159 return Err(PgError::Encode(e.to_string()));
1160 }
1161 PgEncoder::encode_execute_to(&mut buf);
1162 }
1163
1164 PgEncoder::encode_sync_to(&mut buf);
1165
1166 if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
1167 rollback_new_cached_statements(self, &new_stmt_hashes);
1168 return Err(err);
1169 }
1170 if let Err(err) = self.flush_with_timeout("stream flush").await {
1171 rollback_new_cached_statements(self, &new_stmt_hashes);
1172 return Err(err);
1173 }
1174
1175 let mut error: Option<PgError> = None;
1176 let expected_parse_completes = new_stmt_hashes.len();
1177 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1178 expected_queries: cmds.len(),
1179 allow_parse_complete: true,
1180 require_parse_before_bind: false,
1181 no_data_counts_as_completion: true,
1182 allow_no_data_nonterminal: false,
1183 expected_parse_completes: Some(expected_parse_completes),
1184 });
1185
1186 loop {
1187 match self.recv_msg_type_fast().await {
1188 Ok(msg_type) => {
1189 match flow.validate_msg_type(
1190 msg_type,
1191 "pipeline_execute_count_ast_cached",
1192 error.is_some(),
1193 ) {
1194 Ok(FastPipelineEvent::Continue) => {}
1195 Ok(FastPipelineEvent::ReadyForQuery) => {
1196 if let Some(err) = error {
1197 reconcile_new_cached_statements_after_server_error(
1198 self,
1199 &new_stmt_hashes,
1200 flow.parse_completes,
1201 );
1202 return Err(err);
1203 }
1204 enforce_prepared_statement_cache_limit(self);
1205 return Ok(flow.completed_queries());
1206 }
1207 Err(err) => {
1208 rollback_new_cached_statements(self, &new_stmt_hashes);
1209 return return_with_desync(self, err);
1210 }
1211 }
1212 }
1213 Err(e) => {
1214 if matches!(&e, PgError::QueryServer(_)) {
1215 capture_query_server_error(self, &mut error, e);
1216 continue;
1217 }
1218 rollback_new_cached_statements(self, &new_stmt_hashes);
1219 return Err(e);
1220 }
1221 }
1222 }
1223 }
1224 #[inline]
1239 pub async fn pipeline_execute_prepared_count(
1240 &mut self,
1241 stmt: &super::PreparedStatement,
1242 params_batch: &[Vec<Option<Vec<u8>>>],
1243 ) -> PgResult<usize> {
1244 if params_batch.is_empty() {
1245 return Ok(0);
1246 }
1247
1248 let is_new = !self.prepared_statements.contains_key(&stmt.name);
1249
1250 if is_new {
1251 return Err(PgError::Query(
1252 "Statement not prepared. Call prepare() first.".to_string(),
1253 ));
1254 }
1255
1256 self.write_buf.clear();
1257 for params in params_batch {
1258 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1259 .map_err(|e| PgError::Encode(e.to_string()))?;
1260 PgEncoder::encode_execute_to(&mut self.write_buf);
1261 }
1262
1263 PgEncoder::encode_sync_to(&mut self.write_buf);
1264 self.flush_write_buf().await?;
1265
1266 let mut error: Option<PgError> = None;
1267 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1268 expected_queries: params_batch.len(),
1269 allow_parse_complete: false,
1270 require_parse_before_bind: false,
1271 no_data_counts_as_completion: true,
1272 allow_no_data_nonterminal: false,
1273 expected_parse_completes: Some(0),
1274 });
1275
1276 loop {
1277 match self.recv_msg_type_fast().await {
1278 Ok(msg_type) => {
1279 let event = match flow.validate_msg_type(
1280 msg_type,
1281 "pipeline_execute_prepared_count",
1282 error.is_some(),
1283 ) {
1284 Ok(event) => event,
1285 Err(err) => return return_with_desync(self, err),
1286 };
1287 match event {
1288 FastPipelineEvent::Continue => {}
1289 FastPipelineEvent::ReadyForQuery => {
1290 if let Some(err) = error {
1291 return Err(err);
1292 }
1293 return Ok(flow.completed_queries());
1294 }
1295 }
1296 }
1297 Err(e) => {
1298 if matches!(&e, PgError::QueryServer(_)) {
1299 capture_query_server_error(self, &mut error, e);
1300 continue;
1301 }
1302 return Err(e);
1303 }
1304 }
1305 }
1306 }
1307
1308 pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
1311 use super::prepared::sql_bytes_to_stmt_name;
1312
1313 let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());
1314
1315 if !self.prepared_statements.contains_key(&stmt_name) {
1316 self.evict_prepared_if_full();
1317 let mut buf = BytesMut::with_capacity(sql.len() + 32);
1318 buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
1319 buf.extend(PgEncoder::encode_sync());
1320
1321 self.write_all_with_timeout(&buf, "stream write").await?;
1322 self.flush_with_timeout("stream flush").await?;
1323
1324 let mut error: Option<PgError> = None;
1326 let mut saw_parse_complete = false;
1327 loop {
1328 match self.recv_msg_type_fast().await {
1329 Ok(msg_type) => match msg_type {
1330 b'1' => {
1331 if saw_parse_complete {
1332 return Err(PgError::Protocol(
1333 "prepare received duplicate ParseComplete".to_string(),
1334 ));
1335 }
1336 saw_parse_complete = true;
1337 self.prepared_statements
1338 .insert(stmt_name.clone(), sql.to_string());
1339 }
1340 b'Z' => {
1341 if let Some(err) = error {
1342 return Err(err);
1343 }
1344 if !saw_parse_complete {
1345 return Err(PgError::Protocol(
1346 "prepare reached ReadyForQuery without ParseComplete"
1347 .to_string(),
1348 ));
1349 }
1350 break;
1351 }
1352 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1353 other => {
1354 return return_with_desync(
1355 self,
1356 unexpected_backend_msg_type("prepare", other),
1357 );
1358 }
1359 },
1360 Err(e) => {
1361 if matches!(&e, PgError::QueryServer(_)) {
1362 capture_query_server_error(self, &mut error, e);
1363 continue;
1364 }
1365 return Err(e);
1366 }
1367 }
1368 }
1369 }
1370
1371 Ok(super::PreparedStatement { name: stmt_name })
1372 }
1373
1374 pub async fn pipeline_execute_prepared_rows(
1376 &mut self,
1377 stmt: &super::PreparedStatement,
1378 params_batch: &[Vec<Option<Vec<u8>>>],
1379 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
1380 if params_batch.is_empty() {
1381 return Ok(Vec::new());
1382 }
1383
1384 if !self.prepared_statements.contains_key(&stmt.name) {
1385 return Err(PgError::Query(
1386 "Statement not prepared. Call prepare() first.".to_string(),
1387 ));
1388 }
1389
1390 self.write_buf.clear();
1391 for params in params_batch {
1392 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1393 .map_err(|e| PgError::Encode(e.to_string()))?;
1394 PgEncoder::encode_execute_to(&mut self.write_buf);
1395 }
1396
1397 PgEncoder::encode_sync_to(&mut self.write_buf);
1398 self.flush_write_buf().await?;
1399
1400 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
1402 Vec::with_capacity(params_batch.len());
1403 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
1404 let mut error: Option<PgError> = None;
1405 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1406 expected_queries: params_batch.len(),
1407 allow_parse_complete: false,
1408 require_parse_before_bind: false,
1409 no_data_counts_as_completion: true,
1410 allow_no_data_nonterminal: false,
1411 expected_parse_completes: Some(0),
1412 });
1413
1414 loop {
1415 match self.recv_with_data_fast().await {
1416 Ok((msg_type, data)) => {
1417 if let Err(err) = flow.validate_msg_type(
1418 msg_type,
1419 "pipeline_execute_prepared_rows",
1420 error.is_some(),
1421 ) {
1422 return return_with_desync(self, err);
1423 }
1424 match msg_type {
1425 b'2' => {} b'T' => {} b'D' => {
1428 if error.is_none()
1430 && let Some(row) = data
1431 {
1432 current_rows.push(row);
1433 }
1434 }
1435 b'C' => {
1436 all_results.push(std::mem::take(&mut current_rows));
1438 }
1439 b'n' => {
1440 all_results.push(Vec::new());
1442 }
1443 b'Z' => {
1444 if all_results.len() != params_batch.len() {
1446 return Err(error.unwrap_or_else(|| {
1447 PgError::Protocol(format!(
1448 "Pipeline completion mismatch: expected {}, got {}",
1449 params_batch.len(),
1450 all_results.len()
1451 ))
1452 }));
1453 }
1454 if let Some(err) = error {
1455 return Err(err);
1456 }
1457 return Ok(all_results);
1458 }
1459 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1460 other => {
1461 return return_with_desync(
1462 self,
1463 unexpected_backend_msg_type(
1464 "pipeline_execute_prepared_rows",
1465 other,
1466 ),
1467 );
1468 }
1469 }
1470 }
1471 Err(e) => {
1472 if matches!(&e, PgError::QueryServer(_)) {
1473 capture_query_server_error(self, &mut error, e);
1474 continue;
1475 }
1476 return Err(e);
1477 }
1478 }
1479 }
1480 }
1481
1482 pub async fn pipeline_execute_prepared_rows_bytes(
1484 &mut self,
1485 stmt: &super::PreparedStatement,
1486 params_batch: &[Vec<Option<Vec<u8>>>],
1487 ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
1488 if params_batch.is_empty() {
1489 return Ok(Vec::new());
1490 }
1491
1492 if !self.prepared_statements.contains_key(&stmt.name) {
1493 return Err(PgError::Query(
1494 "Statement not prepared. Call prepare() first.".to_string(),
1495 ));
1496 }
1497
1498 reserve_prepared_pipeline_write_buf(self, stmt, params_batch, PgEncoder::FORMAT_TEXT)?;
1499
1500 for params in params_batch {
1501 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1502 .map_err(|e| PgError::Encode(e.to_string()))?;
1503 PgEncoder::encode_execute_to(&mut self.write_buf);
1504 }
1505
1506 PgEncoder::encode_sync_to(&mut self.write_buf);
1507 self.flush_write_buf().await?;
1508
1509 let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
1511 Vec::with_capacity(params_batch.len());
1512 let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
1513 let mut error: Option<PgError> = None;
1514 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1515 expected_queries: params_batch.len(),
1516 allow_parse_complete: false,
1517 require_parse_before_bind: false,
1518 no_data_counts_as_completion: true,
1519 allow_no_data_nonterminal: false,
1520 expected_parse_completes: Some(0),
1521 });
1522
1523 loop {
1524 match self.recv_data_zerocopy().await {
1525 Ok((msg_type, data)) => {
1526 if let Err(err) = flow.validate_msg_type(
1527 msg_type,
1528 "pipeline_execute_prepared_rows_bytes",
1529 error.is_some(),
1530 ) {
1531 return return_with_desync(self, err);
1532 }
1533 match msg_type {
1534 b'2' => {} b'T' => {} b'D' => {
1537 if error.is_none()
1539 && let Some(row) = data
1540 {
1541 current_rows.push(row);
1542 }
1543 }
1544 b'C' => {
1545 all_results.push(std::mem::take(&mut current_rows));
1547 }
1548 b'n' => {
1549 all_results.push(Vec::new());
1551 }
1552 b'Z' => {
1553 if all_results.len() != params_batch.len() {
1555 return Err(error.unwrap_or_else(|| {
1556 PgError::Protocol(format!(
1557 "Pipeline completion mismatch: expected {}, got {}",
1558 params_batch.len(),
1559 all_results.len()
1560 ))
1561 }));
1562 }
1563 if let Some(err) = error {
1564 return Err(err);
1565 }
1566 return Ok(all_results);
1567 }
1568 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1569 other => {
1570 return return_with_desync(
1571 self,
1572 unexpected_backend_msg_type(
1573 "pipeline_execute_prepared_rows_bytes",
1574 other,
1575 ),
1576 );
1577 }
1578 }
1579 }
1580 Err(e) => {
1581 if matches!(&e, PgError::QueryServer(_)) {
1582 capture_query_server_error(self, &mut error, e);
1583 continue;
1584 }
1585 return Err(e);
1586 }
1587 }
1588 }
1589 }
1590
1591 pub async fn pipeline_execute_prepared_visit_rows<F>(
1596 &mut self,
1597 stmt: &super::PreparedStatement,
1598 params_batch: &[Vec<Option<Vec<u8>>>],
1599 mut on_row: F,
1600 ) -> PgResult<usize>
1601 where
1602 F: FnMut(&[Option<Vec<u8>>]) -> PgResult<()>,
1603 {
1604 if params_batch.is_empty() {
1605 return Ok(0);
1606 }
1607
1608 if !self.prepared_statements.contains_key(&stmt.name) {
1609 return Err(PgError::Query(
1610 "Statement not prepared. Call prepare() first.".to_string(),
1611 ));
1612 }
1613
1614 reserve_prepared_pipeline_write_buf(self, stmt, params_batch, PgEncoder::FORMAT_TEXT)?;
1615
1616 for params in params_batch {
1617 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1618 .map_err(|e| PgError::Encode(e.to_string()))?;
1619 PgEncoder::encode_execute_to(&mut self.write_buf);
1620 }
1621
1622 PgEncoder::encode_sync_to(&mut self.write_buf);
1623 self.flush_write_buf().await?;
1624
1625 let mut row_buf: Vec<Option<Vec<u8>>> = Vec::new();
1626 let mut error: Option<PgError> = None;
1627 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1628 expected_queries: params_batch.len(),
1629 allow_parse_complete: false,
1630 require_parse_before_bind: false,
1631 no_data_counts_as_completion: true,
1632 allow_no_data_nonterminal: false,
1633 expected_parse_completes: Some(0),
1634 });
1635
1636 loop {
1637 match self.recv_fill_data_row_fast(&mut row_buf).await {
1638 Ok(msg_type) => {
1639 if let Err(err) = flow.validate_msg_type(
1640 msg_type,
1641 "pipeline_execute_prepared_visit_rows",
1642 error.is_some(),
1643 ) {
1644 return return_with_desync(self, err);
1645 }
1646 match msg_type {
1647 b'2' | b'T' | b'C' | b'n' => {}
1648 b'D' => {
1649 if error.is_none()
1650 && let Err(err) = on_row(row_buf.as_slice())
1651 {
1652 return return_callback_error_with_desync(self, err);
1653 }
1654 }
1655 b'Z' => {
1656 if let Some(err) = error {
1657 return Err(err);
1658 }
1659 return Ok(flow.completed_queries());
1660 }
1661 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1662 other => {
1663 return return_with_desync(
1664 self,
1665 unexpected_backend_msg_type(
1666 "pipeline_execute_prepared_visit_rows",
1667 other,
1668 ),
1669 );
1670 }
1671 }
1672 }
1673 Err(e) => {
1674 if matches!(&e, PgError::QueryServer(_)) {
1675 capture_query_server_error(self, &mut error, e);
1676 continue;
1677 }
1678 return Err(e);
1679 }
1680 }
1681 }
1682 }
1683
1684 pub async fn pipeline_execute_prepared_visit_bytes_rows<F>(
1689 &mut self,
1690 stmt: &super::PreparedStatement,
1691 params_batch: &[Vec<Option<Vec<u8>>>],
1692 mut on_row: F,
1693 ) -> PgResult<usize>
1694 where
1695 F: FnMut(&super::PgBytesRow) -> PgResult<()>,
1696 {
1697 if params_batch.is_empty() {
1698 return Ok(0);
1699 }
1700
1701 if !self.prepared_statements.contains_key(&stmt.name) {
1702 return Err(PgError::Query(
1703 "Statement not prepared. Call prepare() first.".to_string(),
1704 ));
1705 }
1706
1707 reserve_prepared_pipeline_write_buf(self, stmt, params_batch, PgEncoder::FORMAT_TEXT)?;
1708
1709 for params in params_batch {
1710 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1711 .map_err(|e| PgError::Encode(e.to_string()))?;
1712 PgEncoder::encode_execute_to(&mut self.write_buf);
1713 }
1714
1715 PgEncoder::encode_sync_to(&mut self.write_buf);
1716 self.flush_write_buf().await?;
1717
1718 let mut row = super::PgBytesRow::default();
1719 let mut error: Option<PgError> = None;
1720 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1721 expected_queries: params_batch.len(),
1722 allow_parse_complete: false,
1723 require_parse_before_bind: false,
1724 no_data_counts_as_completion: true,
1725 allow_no_data_nonterminal: false,
1726 expected_parse_completes: Some(0),
1727 });
1728
1729 loop {
1730 match self.recv_fill_zerocopy_row_fast(&mut row).await {
1731 Ok(msg_type) => {
1732 if let Err(err) = flow.validate_msg_type(
1733 msg_type,
1734 "pipeline_execute_prepared_visit_bytes_rows",
1735 error.is_some(),
1736 ) {
1737 return return_with_desync(self, err);
1738 }
1739 match msg_type {
1740 b'2' | b'T' | b'C' | b'n' => {}
1741 b'D' => {
1742 if error.is_none() {
1743 if let Err(err) = on_row(&row) {
1744 return return_callback_error_with_desync(self, err);
1745 }
1746 row.release_payload();
1747 }
1748 }
1749 b'Z' => {
1750 if let Some(err) = error {
1751 return Err(err);
1752 }
1753 return Ok(flow.completed_queries());
1754 }
1755 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1756 other => {
1757 return return_with_desync(
1758 self,
1759 unexpected_backend_msg_type(
1760 "pipeline_execute_prepared_visit_bytes_rows",
1761 other,
1762 ),
1763 );
1764 }
1765 }
1766 }
1767 Err(e) => {
1768 if matches!(&e, PgError::QueryServer(_)) {
1769 capture_query_server_error(self, &mut error, e);
1770 continue;
1771 }
1772 return Err(e);
1773 }
1774 }
1775 }
1776 }
1777
1778 pub async fn pipeline_execute_prepared_visit_first_column_bytes<F>(
1780 &mut self,
1781 stmt: &super::PreparedStatement,
1782 params_batch: &[Vec<Option<Vec<u8>>>],
1783 mut on_value: F,
1784 ) -> PgResult<usize>
1785 where
1786 F: FnMut(Option<&[u8]>) -> PgResult<()>,
1787 {
1788 if params_batch.is_empty() {
1789 return Ok(0);
1790 }
1791
1792 if !self.prepared_statements.contains_key(&stmt.name) {
1793 return Err(PgError::Query(
1794 "Statement not prepared. Call prepare() first.".to_string(),
1795 ));
1796 }
1797
1798 reserve_prepared_pipeline_write_buf(self, stmt, params_batch, PgEncoder::FORMAT_TEXT)?;
1799 for params in params_batch {
1800 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1801 .map_err(|e| PgError::Encode(e.to_string()))?;
1802 PgEncoder::encode_execute_to(&mut self.write_buf);
1803 }
1804
1805 PgEncoder::encode_sync_to(&mut self.write_buf);
1806 self.flush_write_buf().await?;
1807
1808 let mut first_column: Option<Bytes> = None;
1809 let mut error: Option<PgError> = None;
1810 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1811 expected_queries: params_batch.len(),
1812 allow_parse_complete: false,
1813 require_parse_before_bind: false,
1814 no_data_counts_as_completion: true,
1815 allow_no_data_nonterminal: false,
1816 expected_parse_completes: Some(0),
1817 });
1818
1819 loop {
1820 match self
1821 .recv_fill_first_column_zerocopy_fast(&mut first_column)
1822 .await
1823 {
1824 Ok(msg_type) => {
1825 if let Err(err) = flow.validate_msg_type(
1826 msg_type,
1827 "pipeline_execute_prepared_visit_first_column_bytes",
1828 error.is_some(),
1829 ) {
1830 return return_with_desync(self, err);
1831 }
1832 match msg_type {
1833 b'2' | b'T' | b'C' | b'n' => {}
1834 b'D' => {
1835 if error.is_none() {
1836 if let Err(err) = on_value(first_column.as_deref()) {
1837 return return_callback_error_with_desync(self, err);
1838 }
1839 first_column = None;
1840 }
1841 }
1842 b'Z' => {
1843 if let Some(err) = error {
1844 return Err(err);
1845 }
1846 return Ok(flow.completed_queries());
1847 }
1848 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1849 other => {
1850 return return_with_desync(
1851 self,
1852 unexpected_backend_msg_type(
1853 "pipeline_execute_prepared_visit_first_column_bytes",
1854 other,
1855 ),
1856 );
1857 }
1858 }
1859 }
1860 Err(e) => {
1861 if matches!(&e, PgError::QueryServer(_)) {
1862 capture_query_server_error(self, &mut error, e);
1863 continue;
1864 }
1865 return Err(e);
1866 }
1867 }
1868 }
1869 }
1870
1871 pub async fn pipeline_execute_prepared_visit_first_four_columns_bytes<F>(
1873 &mut self,
1874 stmt: &super::PreparedStatement,
1875 params_batch: &[Vec<Option<Vec<u8>>>],
1876 mut on_row: F,
1877 ) -> PgResult<usize>
1878 where
1879 F: FnMut([Option<&[u8]>; 4]) -> PgResult<()>,
1880 {
1881 if params_batch.is_empty() {
1882 return Ok(0);
1883 }
1884
1885 if !self.prepared_statements.contains_key(&stmt.name) {
1886 return Err(PgError::Query(
1887 "Statement not prepared. Call prepare() first.".to_string(),
1888 ));
1889 }
1890
1891 reserve_prepared_pipeline_write_buf(self, stmt, params_batch, PgEncoder::FORMAT_TEXT)?;
1892 for params in params_batch {
1893 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1894 .map_err(|e| PgError::Encode(e.to_string()))?;
1895 PgEncoder::encode_execute_to(&mut self.write_buf);
1896 }
1897
1898 PgEncoder::encode_sync_to(&mut self.write_buf);
1899 self.flush_write_buf().await?;
1900
1901 let mut columns = [None, None, None, None];
1902 let mut error: Option<PgError> = None;
1903 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1904 expected_queries: params_batch.len(),
1905 allow_parse_complete: false,
1906 require_parse_before_bind: false,
1907 no_data_counts_as_completion: true,
1908 allow_no_data_nonterminal: false,
1909 expected_parse_completes: Some(0),
1910 });
1911
1912 loop {
1913 match self
1914 .recv_fill_first_four_columns_zerocopy_fast(&mut columns)
1915 .await
1916 {
1917 Ok(msg_type) => {
1918 if let Err(err) = flow.validate_msg_type(
1919 msg_type,
1920 "pipeline_execute_prepared_visit_first_four_columns_bytes",
1921 error.is_some(),
1922 ) {
1923 return return_with_desync(self, err);
1924 }
1925 match msg_type {
1926 b'2' | b'T' | b'C' | b'n' => {}
1927 b'D' => {
1928 if error.is_none() {
1929 if let Err(err) = on_row([
1930 columns[0].as_deref(),
1931 columns[1].as_deref(),
1932 columns[2].as_deref(),
1933 columns[3].as_deref(),
1934 ]) {
1935 return return_callback_error_with_desync(self, err);
1936 }
1937 columns.fill(None);
1938 }
1939 }
1940 b'Z' => {
1941 if let Some(err) = error {
1942 return Err(err);
1943 }
1944 return Ok(flow.completed_queries());
1945 }
1946 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1947 other => {
1948 return return_with_desync(
1949 self,
1950 unexpected_backend_msg_type(
1951 "pipeline_execute_prepared_visit_first_four_columns_bytes",
1952 other,
1953 ),
1954 );
1955 }
1956 }
1957 }
1958 Err(e) => {
1959 if matches!(&e, PgError::QueryServer(_)) {
1960 capture_query_server_error(self, &mut error, e);
1961 continue;
1962 }
1963 return Err(e);
1964 }
1965 }
1966 }
1967 }
1968
1969 pub async fn pipeline_execute_prepared_rows_2cols_bytes(
1971 &mut self,
1972 stmt: &super::PreparedStatement,
1973 params_batch: &[Vec<Option<Vec<u8>>>],
1974 ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
1975 if params_batch.is_empty() {
1976 return Ok(Vec::new());
1977 }
1978
1979 if !self.prepared_statements.contains_key(&stmt.name) {
1980 return Err(PgError::Query(
1981 "Statement not prepared. Call prepare() first.".to_string(),
1982 ));
1983 }
1984
1985 reserve_prepared_pipeline_write_buf(self, stmt, params_batch, PgEncoder::FORMAT_TEXT)?;
1986
1987 for params in params_batch {
1988 PgEncoder::encode_bind_to(&mut self.write_buf, &stmt.name, params)
1989 .map_err(|e| PgError::Encode(e.to_string()))?;
1990 PgEncoder::encode_execute_to(&mut self.write_buf);
1991 }
1992
1993 PgEncoder::encode_sync_to(&mut self.write_buf);
1994 self.flush_write_buf().await?;
1995
1996 let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
1998 Vec::with_capacity(params_batch.len());
1999 let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
2000 let mut error: Option<PgError> = None;
2001 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
2002 expected_queries: params_batch.len(),
2003 allow_parse_complete: false,
2004 require_parse_before_bind: false,
2005 no_data_counts_as_completion: true,
2006 allow_no_data_nonterminal: false,
2007 expected_parse_completes: Some(0),
2008 });
2009
2010 loop {
2011 match self.recv_data_ultra().await {
2012 Ok((msg_type, data)) => {
2013 if let Err(err) = flow.validate_msg_type(
2014 msg_type,
2015 "pipeline_execute_prepared_rows_2cols_bytes",
2016 error.is_some(),
2017 ) {
2018 return return_with_desync(self, err);
2019 }
2020 match msg_type {
2021 b'2' | b'T' => {} b'D' => {
2023 if error.is_none()
2024 && let Some(row) = data
2025 {
2026 current_rows.push(row);
2027 }
2028 }
2029 b'C' => {
2030 all_results.push(std::mem::take(&mut current_rows));
2031 current_rows = Vec::with_capacity(16);
2032 }
2033 b'n' => {
2034 all_results.push(Vec::new());
2035 }
2036 b'Z' => {
2037 if all_results.len() != params_batch.len() {
2038 return Err(error.unwrap_or_else(|| {
2039 PgError::Protocol(format!(
2040 "Pipeline completion mismatch: expected {}, got {}",
2041 params_batch.len(),
2042 all_results.len()
2043 ))
2044 }));
2045 }
2046 if let Some(err) = error {
2047 return Err(err);
2048 }
2049 return Ok(all_results);
2050 }
2051 msg_type if is_ignorable_session_msg_type(msg_type) => {}
2052 other => {
2053 return return_with_desync(
2054 self,
2055 unexpected_backend_msg_type(
2056 "pipeline_execute_prepared_rows_2cols_bytes",
2057 other,
2058 ),
2059 );
2060 }
2061 }
2062 }
2063 Err(e) => {
2064 if matches!(&e, PgError::QueryServer(_)) {
2065 capture_query_server_error(self, &mut error, e);
2066 continue;
2067 }
2068 return Err(e);
2069 }
2070 }
2071 }
2072 }
2073}
2074
#[cfg(test)]
mod tests {
    use super::*;
    // Only the Unix-gated tests below build `Qail` commands; gating the import
    // keeps non-Unix test builds free of an unused-import warning.
    #[cfg(unix)]
    use qail_core::ast::Qail;

    /// `Auto` switches from `OneShot` to `Cached` at a batch length of 8;
    /// explicit modes are returned unchanged.
    #[test]
    fn ast_pipeline_mode_auto_resolves_by_batch_size() {
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(0),
            AstPipelineMode::OneShot
        );
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(7),
            AstPipelineMode::OneShot
        );
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(8),
            AstPipelineMode::Cached
        );
        assert_eq!(
            AstPipelineMode::Cached.resolve_for_batch_len(1),
            AstPipelineMode::Cached
        );
        assert_eq!(
            AstPipelineMode::OneShot.resolve_for_batch_len(1000),
            AstPipelineMode::OneShot
        );
    }

    /// Build a connection over one end of a Unix socket pair, pre-seeded with
    /// a single prepared statement `"s1"` (statement-cache key `1`).
    #[cfg(unix)]
    fn make_test_conn_with_prepared() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
        let mut conn = PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(16).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        conn.prepared_statements
            .insert("s1".to_string(), "SELECT 1".to_string());
        conn.stmt_cache.put(1, "s1".to_string());
        conn
    }

    /// Build a `PgError::QueryServer` with severity `"ERROR"` and the given
    /// SQLSTATE `code` and `message`.
    // Every caller is `#[cfg(unix)]`; gate the helper the same way so it is
    // not reported as dead code on other targets.
    #[cfg(unix)]
    fn server_error(code: &str, message: &str) -> PgError {
        PgError::QueryServer(super::super::PgServerError {
            severity: "ERROR".to_string(),
            code: code.to_string(),
            message: message.to_string(),
            detail: None,
            hint: None,
        })
    }

    /// A callback failure is returned unchanged and flags the connection as
    /// I/O-desynced.
    #[cfg(unix)]
    #[tokio::test]
    async fn streaming_callback_error_marks_pipeline_connection_desynced() {
        let mut conn = make_test_conn_with_prepared();

        let err = return_callback_error_with_desync::<()>(
            &mut conn,
            PgError::Query("consumer stopped".to_string()),
        )
        .expect_err("callback error should be returned");

        assert!(matches!(err, PgError::Query(msg) if msg == "consumer stopped"));
        assert!(conn.is_io_desynced());
    }

    /// Register `sql_hash` in both `stmt_cache` and `prepared_statements`
    /// under its derived statement name, and return that name.
    #[cfg(unix)]
    fn insert_cached_stmt(conn: &mut PgConnection, sql_hash: u64) -> String {
        let stmt_name = super::super::prepared::stmt_name_from_hash(sql_hash);
        conn.stmt_cache.put(sql_hash, stmt_name.clone());
        conn.prepared_statements
            .insert(stmt_name.clone(), format!("SELECT {sql_hash}"));
        stmt_name
    }

    /// SQLSTATE 26000 is expected to wipe the local prepared-statement state.
    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_clears_prepared_state_on_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        let err = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert!(conn.prepared_statements.is_empty());
        assert_eq!(conn.stmt_cache.len(), 0);
    }

    /// SQLSTATE 23505 is expected to leave the local prepared-statement state
    /// untouched.
    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_preserves_prepared_state_on_non_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        let err = server_error("23505", "duplicate key value violates unique constraint");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
    }

    /// Once the slot holds an error, later errors are dropped without side
    /// effects, even ones that would otherwise clear the prepared state.
    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_does_not_override_existing_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = Some(server_error("23505", "duplicate key"));
        let retryable = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, retryable);

        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
        assert_eq!(
            slot.and_then(|e| e.sqlstate().map(str::to_string))
                .as_deref(),
            Some("23505")
        );
    }

    /// An encode failure part-way through a cached AST batch must leave the
    /// statement bookkeeping exactly as it was before the call.
    #[cfg(unix)]
    #[tokio::test]
    async fn pipeline_ast_cached_rolls_back_new_state_on_encode_error() {
        let mut conn = make_test_conn_with_prepared();
        let baseline = conn.prepared_statements.len();
        let baseline_stmt_cache = conn.stmt_cache.len();

        let cmds = vec![
            Qail::get("harbors").columns(["id", "name"]).limit(1),
            Qail::get("bad\0table").columns(["id"]).limit(1),
        ];

        let err = conn
            .pipeline_execute_count_ast_cached(&cmds)
            .await
            .expect_err("expected encode error for NUL byte in table name");

        assert!(matches!(err, PgError::Encode(_)));
        assert_eq!(conn.prepared_statements.len(), baseline);
        assert_eq!(conn.stmt_cache.len(), baseline_stmt_cache);
        assert!(conn.prepared_statements.contains_key("s1"));
    }

    /// Rolling back from index 2 keeps the first two hashes and removes only
    /// the third.
    #[cfg(unix)]
    #[tokio::test]
    async fn rollback_new_cached_statements_preserves_server_parsed_prefix() {
        let mut conn = make_test_conn_with_prepared();
        let parsed_1 = insert_cached_stmt(&mut conn, 10);
        let parsed_2 = insert_cached_stmt(&mut conn, 11);
        let unparsed = insert_cached_stmt(&mut conn, 12);

        rollback_new_cached_statements_from(&mut conn, &[10, 11, 12], 2);

        assert!(conn.prepared_statements.contains_key(&parsed_1));
        assert!(conn.prepared_statements.contains_key(&parsed_2));
        assert!(!conn.prepared_statements.contains_key(&unparsed));
        assert!(conn.stmt_cache.contains(&10));
        assert!(conn.stmt_cache.contains(&11));
        assert!(!conn.stmt_cache.contains(&12));
    }

    /// Rollback still removes the `prepared_statements` entry when the hash
    /// has already disappeared from `stmt_cache`.
    #[cfg(unix)]
    #[tokio::test]
    async fn rollback_new_cached_statements_removes_prepared_entry_after_lru_drop() {
        let mut conn = make_test_conn_with_prepared();
        let stmt_name = insert_cached_stmt(&mut conn, 99);
        conn.stmt_cache.remove(&99);

        rollback_new_cached_statements(&mut conn, &[99]);

        assert!(!conn.prepared_statements.contains_key(&stmt_name));
        assert!(!conn.stmt_cache.contains(&99));
    }

    /// Registering three statements beyond `MAX_PREPARED_PER_CONN` and then
    /// enforcing the limit trims the map to the cap and queues three closes.
    #[cfg(unix)]
    #[tokio::test]
    async fn cached_pipeline_capacity_enforcement_queues_closes_after_registration() {
        use crate::driver::connection::StatementCache;
        use std::num::NonZeroUsize;

        let mut conn = make_test_conn_with_prepared();
        conn.prepared_statements.clear();
        conn.stmt_cache = StatementCache::new(
            NonZeroUsize::new(PgConnection::MAX_PREPARED_PER_CONN).expect("non-zero"),
        );

        for hash in 0..(PgConnection::MAX_PREPARED_PER_CONN as u64 + 3) {
            insert_cached_stmt(&mut conn, hash);
        }
        assert!(conn.pending_statement_closes.is_empty());

        enforce_prepared_statement_cache_limit(&mut conn);

        assert_eq!(
            conn.prepared_statements.len(),
            PgConnection::MAX_PREPARED_PER_CONN
        );
        assert_eq!(conn.pending_statement_closes.len(), 3);
    }

    /// An empty batch returns zero and leaves the connection not desynced.
    #[cfg(unix)]
    #[tokio::test]
    async fn pipeline_simple_ast_empty_batch_returns_zero_without_io() {
        let mut conn = make_test_conn_with_prepared();
        let res = conn
            .pipeline_execute_count_simple_ast(&[])
            .await
            .expect("empty batch should be a fast no-op");
        assert_eq!(res, 0);
        assert!(!conn.is_io_desynced());
    }
}