1use super::{
14 PgConnection, PgError, PgResult, is_ignorable_session_message, is_ignorable_session_msg_type,
15 unexpected_backend_message, unexpected_backend_msg_type,
16};
17use crate::protocol::{AstEncoder, BackendMessage, PgEncoder};
18use bytes::BytesMut;
19
/// Strategy used by `PgConnection::pipeline_execute_count_ast_with_mode` to run a batch of
/// AST commands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AstPipelineMode {
    /// Choose by batch length: `Cached` when the batch has at least
    /// `AstPipelineMode::AUTO_CACHE_MIN_BATCH` (8) commands, otherwise `OneShot`.
    #[default]
    Auto,
    /// Always use the one-shot path (`pipeline_execute_count_ast_oneshot`).
    OneShot,
    /// Always use the statement-cache path (`pipeline_execute_count_ast_cached`).
    Cached,
}
36
37impl AstPipelineMode {
38 const AUTO_CACHE_MIN_BATCH: usize = 8;
39
40 #[inline]
41 fn resolve_for_batch_len(self, batch_len: usize) -> Self {
42 match self {
43 Self::Auto => {
44 if batch_len >= Self::AUTO_CACHE_MIN_BATCH {
45 Self::Cached
46 } else {
47 Self::OneShot
48 }
49 }
50 mode => mode,
51 }
52 }
53}
54
55#[inline]
56fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
57 if matches!(
58 err,
59 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
60 ) {
61 conn.mark_io_desynced();
62 }
63 Err(err)
64}
65
66#[inline]
67fn capture_query_server_error(conn: &mut PgConnection, slot: &mut Option<PgError>, err: PgError) {
68 if slot.is_some() {
69 return;
70 }
71 if err.is_prepared_statement_retryable() {
72 conn.clear_prepared_statement_state();
73 }
74 *slot = Some(err);
75}
76
77#[inline]
78fn rollback_new_cached_statements(conn: &mut PgConnection, new_stmt_hashes: &[u64]) {
79 for sql_hash in new_stmt_hashes {
80 if let Some(stmt_name) = conn.stmt_cache.remove(sql_hash) {
81 conn.prepared_statements.remove(&stmt_name);
82 }
83 }
84}
85
/// Per-call rules for `FastExtendedFlowTracker`: which backend messages an extended-protocol
/// pipeline may produce and how many completions it must reach.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowConfig {
    /// Number of query completions required before ReadyForQuery is accepted.
    expected_queries: usize,
    /// Whether ParseComplete ('1') may appear at all.
    allow_parse_complete: bool,
    /// Whether each BindComplete ('2') must follow a ParseComplete for the same query.
    require_parse_before_bind: bool,
    /// Whether NoData ('n') finishes the current query and counts toward `expected_queries`.
    no_data_counts_as_completion: bool,
    /// When NoData does not complete a query: whether it is tolerated (true) or a protocol
    /// error (false).
    allow_no_data_nonterminal: bool,
    /// Exact ParseComplete count required at ReadyForQuery, or `None` to skip that check.
    expected_parse_completes: Option<usize>,
}
95
/// Incremental validator for the message-type sequence of an extended-protocol pipeline,
/// governed by a `FastExtendedFlowConfig`.
#[derive(Debug, Clone, Copy)]
struct FastExtendedFlowTracker {
    /// Rules for this pipeline.
    cfg: FastExtendedFlowConfig,
    /// Queries finished so far (CommandComplete, or NoData when configured to count).
    completed_queries: usize,
    /// Total ParseComplete messages accepted in this batch.
    parse_completes: usize,
    /// ParseComplete already seen for the query in progress.
    current_parse_seen: bool,
    /// BindComplete already seen for the query in progress.
    current_bind_seen: bool,
}
104
impl FastExtendedFlowTracker {
    /// Creates a tracker for `cfg` with all counters at zero and no query in progress.
    fn new(cfg: FastExtendedFlowConfig) -> Self {
        Self {
            cfg,
            completed_queries: 0,
            parse_completes: 0,
            current_parse_seen: false,
            current_bind_seen: false,
        }
    }

    /// Number of queries completed so far (CommandComplete, plus NoData when
    /// `cfg.no_data_counts_as_completion` is set).
    fn completed_queries(&self) -> usize {
        self.completed_queries
    }

    /// Validates one backend message type against the expected extended-protocol flow.
    ///
    /// `context` prefixes every error message. While `error_pending` is true (the caller has
    /// already captured an ErrorResponse), every message is accepted without any state change
    /// until `'Z'`. Presumably this is because the server skips the rest of the batch until
    /// Sync.
    ///
    /// Returns `Ok(ReadyForQuery)` for an accepted `'Z'`, `Ok(Continue)` for any other
    /// accepted message, and an error for a message that breaks the flow. The checks run in a
    /// fixed order, which decides which error is reported when several rules are broken.
    fn validate_msg_type(
        &mut self,
        msg_type: u8,
        context: &'static str,
        error_pending: bool,
    ) -> PgResult<FastPipelineEvent> {
        // Types the session layer classifies as ignorable never affect the flow.
        if is_ignorable_session_msg_type(msg_type) {
            return Ok(FastPipelineEvent::Continue);
        }

        if error_pending {
            if msg_type == b'Z' {
                return Ok(FastPipelineEvent::ReadyForQuery);
            }
            return Ok(FastPipelineEvent::Continue);
        }

        // ReadyForQuery: every query must be complete, with no half-finished query and (if
        // configured) the exact number of ParseCompletes.
        if msg_type == b'Z' {
            if self.completed_queries != self.cfg.expected_queries {
                return Err(PgError::Protocol(format!(
                    "{}: Pipeline completion mismatch: expected {}, got {}",
                    context, self.cfg.expected_queries, self.completed_queries
                )));
            }
            if self.current_parse_seen || self.current_bind_seen {
                return Err(PgError::Protocol(format!(
                    "{}: ReadyForQuery with incomplete query state",
                    context
                )));
            }
            if let Some(expected) = self.cfg.expected_parse_completes
                && self.parse_completes != expected
            {
                return Err(PgError::Protocol(format!(
                    "{}: ParseComplete mismatch: expected {}, got {}",
                    context, expected, self.parse_completes
                )));
            }
            return Ok(FastPipelineEvent::ReadyForQuery);
        }

        // Once the last query completed, only ReadyForQuery (handled above) is legal.
        if self.completed_queries >= self.cfg.expected_queries {
            return Err(PgError::Protocol(format!(
                "{}: unexpected message '{}' after all queries completed",
                context, msg_type as char
            )));
        }

        match msg_type {
            // ParseComplete: at most one per query, before its BindComplete.
            b'1' => {
                if !self.cfg.allow_parse_complete {
                    return Err(PgError::Protocol(format!(
                        "{}: unexpected ParseComplete",
                        context
                    )));
                }
                if self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: ParseComplete after BindComplete",
                        context
                    )));
                }
                if self.current_parse_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate ParseComplete",
                        context
                    )));
                }
                self.current_parse_seen = true;
                self.parse_completes += 1;
                // Fail early instead of waiting for ReadyForQuery to notice the surplus.
                if let Some(expected) = self.cfg.expected_parse_completes
                    && self.parse_completes > expected
                {
                    return Err(PgError::Protocol(format!(
                        "{}: ParseComplete mismatch: expected {}, got at least {}",
                        context, expected, self.parse_completes
                    )));
                }
            }
            // BindComplete: once per query, optionally only after ParseComplete.
            b'2' => {
                if self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: duplicate BindComplete",
                        context
                    )));
                }
                if self.cfg.require_parse_before_bind && !self.current_parse_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: BindComplete before ParseComplete",
                        context
                    )));
                }
                self.current_bind_seen = true;
            }
            // RowDescription / ParameterDescription / PortalSuspended are accepted only inside
            // a bound query and do not complete it.
            // NOTE(review): an Execute with a row limit ends in PortalSuspended, which would
            // then fail the completion count at 'Z'. Confirm no caller uses a row limit.
            b'T' | b't' | b's' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: '{}' before BindComplete",
                        context, msg_type as char
                    )));
                }
            }
            // DataRow
            b'D' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: DataRow before BindComplete",
                        context
                    )));
                }
            }
            // NoData: completes the query, is tolerated, or is rejected, depending on config.
            b'n' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: NoData before BindComplete",
                        context
                    )));
                }
                if self.cfg.no_data_counts_as_completion {
                    self.complete_current();
                } else if !self.cfg.allow_no_data_nonterminal {
                    return Err(PgError::Protocol(format!("{}: unexpected NoData", context)));
                }
            }
            // CommandComplete always finishes the current query.
            b'C' => {
                if !self.current_bind_seen {
                    return Err(PgError::Protocol(format!(
                        "{}: CommandComplete before BindComplete",
                        context
                    )));
                }
                self.complete_current();
            }
            // EmptyQueryResponse is never accepted in an extended pipeline.
            b'I' => {
                return Err(PgError::Protocol(format!(
                    "{}: unexpected EmptyQueryResponse in extended pipeline",
                    context
                )));
            }
            other => return Err(unexpected_backend_msg_type(context, other)),
        }

        Ok(FastPipelineEvent::Continue)
    }

    /// Counts the in-progress query as finished and resets the per-query flags.
    fn complete_current(&mut self) {
        self.completed_queries += 1;
        self.current_parse_seen = false;
        self.current_bind_seen = false;
    }
}
270
/// Incremental validator for the message-type sequence of a simple-query batch: optional
/// RowDescription/DataRow streams, each statement ending in CommandComplete or
/// EmptyQueryResponse.
#[derive(Debug, Clone, Copy)]
struct FastSimpleFlowTracker {
    /// Statements that must complete before ReadyForQuery is accepted.
    expected_queries: usize,
    /// Statements completed so far.
    completed_queries: usize,
    /// Whether the current statement has sent RowDescription (DataRow is legal only after it).
    current_row_description_seen: bool,
}
277
278impl FastSimpleFlowTracker {
279 fn new(expected_queries: usize) -> Self {
280 Self {
281 expected_queries,
282 completed_queries: 0,
283 current_row_description_seen: false,
284 }
285 }
286
287 fn completed_queries(&self) -> usize {
288 self.completed_queries
289 }
290
291 fn validate_msg_type(
292 &mut self,
293 msg_type: u8,
294 context: &'static str,
295 error_pending: bool,
296 ) -> PgResult<FastPipelineEvent> {
297 if is_ignorable_session_msg_type(msg_type) {
298 return Ok(FastPipelineEvent::Continue);
299 }
300
301 if error_pending {
302 if msg_type == b'Z' {
303 return Ok(FastPipelineEvent::ReadyForQuery);
304 }
305 return Ok(FastPipelineEvent::Continue);
306 }
307
308 if msg_type == b'Z' {
309 if self.completed_queries != self.expected_queries {
310 return Err(PgError::Protocol(format!(
311 "{}: Pipeline completion mismatch: expected {}, got {}",
312 context, self.expected_queries, self.completed_queries
313 )));
314 }
315 if self.current_row_description_seen {
316 return Err(PgError::Protocol(format!(
317 "{}: ReadyForQuery with incomplete row stream",
318 context
319 )));
320 }
321 return Ok(FastPipelineEvent::ReadyForQuery);
322 }
323
324 if self.completed_queries >= self.expected_queries {
325 return Err(PgError::Protocol(format!(
326 "{}: unexpected message '{}' after all queries completed",
327 context, msg_type as char
328 )));
329 }
330
331 match msg_type {
332 b'T' => {
333 if self.current_row_description_seen {
334 return Err(PgError::Protocol(format!(
335 "{}: duplicate RowDescription",
336 context
337 )));
338 }
339 self.current_row_description_seen = true;
340 }
341 b'D' => {
342 if !self.current_row_description_seen {
343 return Err(PgError::Protocol(format!(
344 "{}: DataRow before RowDescription",
345 context
346 )));
347 }
348 }
349 b'C' | b'I' => {
350 self.completed_queries += 1;
351 self.current_row_description_seen = false;
352 }
353 b'1' | b'2' | b'n' | b't' | b's' => {
354 return Err(PgError::Protocol(format!(
355 "{}: unexpected '{}' in simple pipeline",
356 context, msg_type as char
357 )));
358 }
359 other => return Err(unexpected_backend_msg_type(context, other)),
360 }
361
362 Ok(FastPipelineEvent::Continue)
363 }
364}
365
/// Outcome of validating one backend message type in a fast pipeline receive loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FastPipelineEvent {
    /// The message was accepted or ignored; keep reading.
    Continue,
    /// ReadyForQuery ('Z') was accepted; the caller should finish the batch.
    ReadyForQuery,
}
371
372#[inline]
373fn backend_msg_type_for_flow(msg: &BackendMessage) -> Option<u8> {
374 match msg {
375 BackendMessage::ParseComplete => Some(b'1'),
376 BackendMessage::BindComplete => Some(b'2'),
377 BackendMessage::ParameterDescription(_) => Some(b't'),
378 BackendMessage::RowDescription(_) => Some(b'T'),
379 BackendMessage::NoData => Some(b'n'),
380 BackendMessage::PortalSuspended => Some(b's'),
381 BackendMessage::DataRow(_) => Some(b'D'),
382 BackendMessage::CommandComplete(_) => Some(b'C'),
383 BackendMessage::EmptyQueryResponse => Some(b'I'),
384 BackendMessage::ReadyForQuery(_) => Some(b'Z'),
385 _ => None,
386 }
387}
388
389impl PgConnection {
390 pub async fn query_pipeline(
392 &mut self,
393 queries: &[(&str, &[Option<Vec<u8>>])],
394 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
395 let mut buf = BytesMut::new();
397 for (sql, params) in queries {
398 buf.extend_from_slice(
399 &PgEncoder::encode_extended_query(sql, params)
400 .map_err(|e| PgError::Encode(e.to_string()))?,
401 );
402 }
403
404 self.write_all_with_timeout(&buf, "stream write").await?;
406
407 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(queries.len());
409 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
410 let mut error: Option<PgError> = None;
411 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
412 expected_queries: queries.len(),
413 allow_parse_complete: true,
414 require_parse_before_bind: true,
415 no_data_counts_as_completion: true,
416 allow_no_data_nonterminal: false,
417 expected_parse_completes: Some(queries.len()),
418 });
419
420 loop {
421 let msg = self.recv().await?;
422 if is_ignorable_session_message(&msg) {
423 continue;
424 }
425 if let BackendMessage::ErrorResponse(err) = msg {
426 if error.is_none() {
427 error = Some(PgError::QueryServer(err.into()));
428 }
429 continue;
430 }
431 let msg_type = backend_msg_type_for_flow(&msg)
432 .ok_or_else(|| unexpected_backend_message("pipeline query", &msg));
433 let msg_type = match msg_type {
434 Ok(msg_type) => msg_type,
435 Err(err) => return return_with_desync(self, err),
436 };
437 if let Err(err) = flow.validate_msg_type(msg_type, "pipeline query", error.is_some()) {
438 return return_with_desync(self, err);
439 }
440 match msg {
441 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
442 BackendMessage::RowDescription(_) => {}
443 BackendMessage::DataRow(data) => {
444 if error.is_none() {
445 current_rows.push(data);
446 }
447 }
448 BackendMessage::CommandComplete(_) => {
449 all_results.push(std::mem::take(&mut current_rows));
450 }
451 BackendMessage::NoData => {
452 all_results.push(Vec::new());
453 }
454 BackendMessage::ReadyForQuery(_) => {
455 if all_results.len() != queries.len() {
456 return Err(error.unwrap_or_else(|| {
457 PgError::Protocol(format!(
458 "Pipeline completion mismatch: expected {}, got {}",
459 queries.len(),
460 all_results.len()
461 ))
462 }));
463 }
464 if let Some(err) = error {
465 return Err(err);
466 }
467 return Ok(all_results);
468 }
469 other => {
470 return return_with_desync(
471 self,
472 unexpected_backend_message("pipeline query", &other),
473 );
474 }
475 }
476 }
477 }
478
479 pub async fn pipeline_execute_rows_ast(
481 &mut self,
482 cmds: &[qail_core::ast::Qail],
483 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
484 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
485 self.write_all_with_timeout(&buf, "stream write").await?;
486
487 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> = Vec::with_capacity(cmds.len());
488 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
489 let mut error: Option<PgError> = None;
490 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
491 expected_queries: cmds.len(),
492 allow_parse_complete: true,
493 require_parse_before_bind: true,
494 no_data_counts_as_completion: true,
495 allow_no_data_nonterminal: false,
496 expected_parse_completes: Some(cmds.len()),
497 });
498
499 loop {
500 let msg = self.recv().await?;
501 if is_ignorable_session_message(&msg) {
502 continue;
503 }
504 if let BackendMessage::ErrorResponse(err) = msg {
505 if error.is_none() {
506 error = Some(PgError::QueryServer(err.into()));
507 }
508 continue;
509 }
510 let msg_type = backend_msg_type_for_flow(&msg)
511 .ok_or_else(|| unexpected_backend_message("pipeline ast", &msg));
512 let msg_type = match msg_type {
513 Ok(msg_type) => msg_type,
514 Err(err) => return return_with_desync(self, err),
515 };
516 if let Err(err) = flow.validate_msg_type(msg_type, "pipeline ast", error.is_some()) {
517 return return_with_desync(self, err);
518 }
519 match msg {
520 BackendMessage::ParseComplete | BackendMessage::BindComplete => {}
521 BackendMessage::RowDescription(_) => {}
522 BackendMessage::DataRow(data) => {
523 if error.is_none() {
524 current_rows.push(data);
525 }
526 }
527 BackendMessage::CommandComplete(_) => {
528 all_results.push(std::mem::take(&mut current_rows));
529 }
530 BackendMessage::NoData => {
531 all_results.push(Vec::new());
532 }
533 BackendMessage::ReadyForQuery(_) => {
534 if all_results.len() != cmds.len() {
535 return Err(error.unwrap_or_else(|| {
536 PgError::Protocol(format!(
537 "Pipeline completion mismatch: expected {}, got {}",
538 cmds.len(),
539 all_results.len()
540 ))
541 }));
542 }
543 if let Some(err) = error {
544 return Err(err);
545 }
546 return Ok(all_results);
547 }
548 other => {
549 return return_with_desync(
550 self,
551 unexpected_backend_message("pipeline ast", &other),
552 );
553 }
554 }
555 }
556 }
557
558 pub async fn pipeline_execute_count_ast_oneshot(
560 &mut self,
561 cmds: &[qail_core::ast::Qail],
562 ) -> PgResult<usize> {
563 let buf = AstEncoder::encode_batch(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
564
565 self.write_all_with_timeout(&buf, "stream write").await?;
566 self.flush_with_timeout("stream flush").await?;
567
568 let mut error: Option<PgError> = None;
569 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
570 expected_queries: cmds.len(),
571 allow_parse_complete: true,
572 require_parse_before_bind: true,
573 no_data_counts_as_completion: true,
574 allow_no_data_nonterminal: false,
575 expected_parse_completes: Some(cmds.len()),
576 });
577
578 loop {
579 match self.recv_msg_type_fast().await {
580 Ok(msg_type) => {
581 let event = match flow.validate_msg_type(
582 msg_type,
583 "pipeline_execute_count_ast_oneshot",
584 error.is_some(),
585 ) {
586 Ok(event) => event,
587 Err(err) => return return_with_desync(self, err),
588 };
589 match event {
590 FastPipelineEvent::Continue => {}
591 FastPipelineEvent::ReadyForQuery => {
592 if let Some(err) = error {
593 return Err(err);
594 }
595 return Ok(flow.completed_queries());
596 }
597 }
598 }
599 Err(e) => {
600 if matches!(&e, PgError::QueryServer(_)) {
601 capture_query_server_error(self, &mut error, e);
602 continue;
603 }
604 return Err(e);
605 }
606 }
607 }
608 }
609
610 #[inline]
616 pub async fn pipeline_execute_count_ast_with_mode(
617 &mut self,
618 cmds: &[qail_core::ast::Qail],
619 mode: AstPipelineMode,
620 ) -> PgResult<usize> {
621 if cmds.is_empty() {
622 return Ok(0);
623 }
624
625 match mode.resolve_for_batch_len(cmds.len()) {
626 AstPipelineMode::OneShot => self.pipeline_execute_count_ast_oneshot(cmds).await,
627 AstPipelineMode::Cached => self.pipeline_execute_count_ast_cached(cmds).await,
628 AstPipelineMode::Auto => unreachable!("Auto mode must resolve to concrete strategy"),
629 }
630 }
631
632 #[inline]
634 pub async fn pipeline_execute_count_wire(
635 &mut self,
636 wire_bytes: &[u8],
637 expected_queries: usize,
638 ) -> PgResult<usize> {
639 self.write_all_with_timeout(wire_bytes, "stream write")
640 .await?;
641 self.flush_with_timeout("stream flush").await?;
642
643 let mut error: Option<PgError> = None;
644 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
645 expected_queries,
646 allow_parse_complete: true,
647 require_parse_before_bind: false,
648 no_data_counts_as_completion: true,
649 allow_no_data_nonterminal: false,
650 expected_parse_completes: None,
651 });
652
653 loop {
654 match self.recv_msg_type_fast().await {
655 Ok(msg_type) => {
656 let event = match flow.validate_msg_type(
657 msg_type,
658 "pipeline_execute_count_wire",
659 error.is_some(),
660 ) {
661 Ok(event) => event,
662 Err(err) => return return_with_desync(self, err),
663 };
664 match event {
665 FastPipelineEvent::Continue => {}
666 FastPipelineEvent::ReadyForQuery => {
667 if let Some(err) = error {
668 return Err(err);
669 }
670 return Ok(flow.completed_queries());
671 }
672 }
673 }
674 Err(e) => {
675 if matches!(&e, PgError::QueryServer(_)) {
676 capture_query_server_error(self, &mut error, e);
677 continue;
678 }
679 return Err(e);
680 }
681 }
682 }
683 }
684
685 #[inline]
687 pub async fn pipeline_execute_count_simple_ast(
688 &mut self,
689 cmds: &[qail_core::ast::Qail],
690 ) -> PgResult<usize> {
691 if cmds.is_empty() {
692 return Ok(0);
693 }
694
695 let buf =
696 AstEncoder::encode_batch_simple(cmds).map_err(|e| PgError::Encode(e.to_string()))?;
697 self.write_all_with_timeout(&buf, "stream write").await?;
698 self.flush_with_timeout("stream flush").await?;
699
700 let mut error: Option<PgError> = None;
701 let mut flow = FastSimpleFlowTracker::new(cmds.len());
702
703 loop {
704 match self.recv_msg_type_fast().await {
705 Ok(msg_type) => {
706 let event = match flow.validate_msg_type(
707 msg_type,
708 "pipeline_execute_count_simple_ast",
709 error.is_some(),
710 ) {
711 Ok(event) => event,
712 Err(err) => return return_with_desync(self, err),
713 };
714 match event {
715 FastPipelineEvent::Continue => {}
716 FastPipelineEvent::ReadyForQuery => {
717 if let Some(err) = error {
718 return Err(err);
719 }
720 return Ok(flow.completed_queries());
721 }
722 }
723 }
724 Err(e) => {
725 if matches!(&e, PgError::QueryServer(_)) {
726 capture_query_server_error(self, &mut error, e);
727 continue;
728 }
729 return Err(e);
730 }
731 }
732 }
733 }
734
735 #[inline]
737 pub async fn pipeline_execute_count_simple_wire(
738 &mut self,
739 wire_bytes: &[u8],
740 expected_queries: usize,
741 ) -> PgResult<usize> {
742 self.write_all_with_timeout(wire_bytes, "stream write")
743 .await?;
744 self.flush_with_timeout("stream flush").await?;
745
746 let mut error: Option<PgError> = None;
747 let mut flow = FastSimpleFlowTracker::new(expected_queries);
748
749 loop {
750 match self.recv_msg_type_fast().await {
751 Ok(msg_type) => {
752 let event = match flow.validate_msg_type(
753 msg_type,
754 "pipeline_execute_count_simple_wire",
755 error.is_some(),
756 ) {
757 Ok(event) => event,
758 Err(err) => return return_with_desync(self, err),
759 };
760 match event {
761 FastPipelineEvent::Continue => {}
762 FastPipelineEvent::ReadyForQuery => {
763 if let Some(err) = error {
764 return Err(err);
765 }
766 return Ok(flow.completed_queries());
767 }
768 }
769 }
770 Err(e) => {
771 if matches!(&e, PgError::QueryServer(_)) {
772 capture_query_server_error(self, &mut error, e);
773 continue;
774 }
775 return Err(e);
776 }
777 }
778 }
779 }
780
781 #[inline]
786 pub async fn pipeline_execute_count_ast_cached(
787 &mut self,
788 cmds: &[qail_core::ast::Qail],
789 ) -> PgResult<usize> {
790 if cmds.is_empty() {
791 return Ok(0);
792 }
793
794 use super::prepared::{sql_bytes_hash, stmt_name_from_hash};
795
796 let mut buf = BytesMut::with_capacity(cmds.len() * 64);
797 let mut sql_buf = BytesMut::with_capacity(256);
798 let mut params: Vec<Option<Vec<u8>>> = Vec::new();
799 let mut new_stmt_hashes: Vec<u64> = Vec::new();
800
801 for cmd in cmds {
802 if let Err(e) = AstEncoder::encode_cmd_sql_reuse(cmd, &mut sql_buf, &mut params) {
803 rollback_new_cached_statements(self, &new_stmt_hashes);
804 return Err(PgError::Encode(e.to_string()));
805 }
806
807 let sql_hash = sql_bytes_hash(sql_buf.as_ref());
808
809 if self.stmt_cache.contains(&sql_hash) {
810 self.stmt_cache.touch_key(sql_hash);
811 } else {
812 let stmt_name = stmt_name_from_hash(sql_hash);
813 if self.prepared_statements.contains_key(&stmt_name) {
814 self.stmt_cache.put(sql_hash, stmt_name.clone());
817 } else {
818 self.evict_prepared_if_full();
819
820 let sql = String::from_utf8_lossy(sql_buf.as_ref()).to_string();
821 let parse_msg = match PgEncoder::try_encode_parse(&stmt_name, &sql, &[]) {
822 Ok(msg) => msg,
823 Err(e) => {
824 rollback_new_cached_statements(self, &new_stmt_hashes);
825 return Err(PgError::Encode(e.to_string()));
826 }
827 };
828 buf.extend(parse_msg);
829 self.stmt_cache.put(sql_hash, stmt_name.clone());
830 self.prepared_statements.insert(stmt_name.clone(), sql);
831 new_stmt_hashes.push(sql_hash);
832 }
833 }
834
835 let Some(stmt_name) = self.stmt_cache.peek(&sql_hash) else {
836 rollback_new_cached_statements(self, &new_stmt_hashes);
837 return Err(PgError::Protocol(
838 "stmt_cache lookup failed after statement registration".to_string(),
839 ));
840 };
841
842 if let Err(e) = PgEncoder::encode_bind_to(&mut buf, stmt_name, ¶ms) {
843 rollback_new_cached_statements(self, &new_stmt_hashes);
844 return Err(PgError::Encode(e.to_string()));
845 }
846 PgEncoder::encode_execute_to(&mut buf);
847 }
848
849 PgEncoder::encode_sync_to(&mut buf);
850
851 if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
852 rollback_new_cached_statements(self, &new_stmt_hashes);
853 return Err(err);
854 }
855 if let Err(err) = self.flush_with_timeout("stream flush").await {
856 rollback_new_cached_statements(self, &new_stmt_hashes);
857 return Err(err);
858 }
859
860 let mut error: Option<PgError> = None;
861 let expected_parse_completes = new_stmt_hashes.len();
862 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
863 expected_queries: cmds.len(),
864 allow_parse_complete: true,
865 require_parse_before_bind: false,
866 no_data_counts_as_completion: true,
867 allow_no_data_nonterminal: false,
868 expected_parse_completes: Some(expected_parse_completes),
869 });
870
871 loop {
872 match self.recv_msg_type_fast().await {
873 Ok(msg_type) => {
874 match flow.validate_msg_type(
875 msg_type,
876 "pipeline_execute_count_ast_cached",
877 error.is_some(),
878 ) {
879 Ok(FastPipelineEvent::Continue) => {}
880 Ok(FastPipelineEvent::ReadyForQuery) => {
881 if let Some(err) = error {
882 rollback_new_cached_statements(self, &new_stmt_hashes);
883 return Err(err);
884 }
885 return Ok(flow.completed_queries());
886 }
887 Err(err) => {
888 rollback_new_cached_statements(self, &new_stmt_hashes);
889 return return_with_desync(self, err);
890 }
891 }
892 }
893 Err(e) => {
894 if matches!(&e, PgError::QueryServer(_)) {
895 capture_query_server_error(self, &mut error, e);
896 continue;
897 }
898 rollback_new_cached_statements(self, &new_stmt_hashes);
899 return Err(e);
900 }
901 }
902 }
903 }
904 #[inline]
919 pub async fn pipeline_execute_prepared_count(
920 &mut self,
921 stmt: &super::PreparedStatement,
922 params_batch: &[Vec<Option<Vec<u8>>>],
923 ) -> PgResult<usize> {
924 if params_batch.is_empty() {
925 return Ok(0);
926 }
927
928 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
930
931 let is_new = !self.prepared_statements.contains_key(&stmt.name);
932
933 if is_new {
934 return Err(PgError::Query(
935 "Statement not prepared. Call prepare() first.".to_string(),
936 ));
937 }
938
939 for params in params_batch {
941 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
942 .map_err(|e| PgError::Encode(e.to_string()))?;
943 PgEncoder::encode_execute_to(&mut buf);
944 }
945
946 PgEncoder::encode_sync_to(&mut buf);
947
948 self.write_all_with_timeout(&buf, "stream write").await?;
949 self.flush_with_timeout("stream flush").await?;
950
951 let mut error: Option<PgError> = None;
952 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
953 expected_queries: params_batch.len(),
954 allow_parse_complete: false,
955 require_parse_before_bind: false,
956 no_data_counts_as_completion: true,
957 allow_no_data_nonterminal: false,
958 expected_parse_completes: Some(0),
959 });
960
961 loop {
962 match self.recv_msg_type_fast().await {
963 Ok(msg_type) => {
964 let event = match flow.validate_msg_type(
965 msg_type,
966 "pipeline_execute_prepared_count",
967 error.is_some(),
968 ) {
969 Ok(event) => event,
970 Err(err) => return return_with_desync(self, err),
971 };
972 match event {
973 FastPipelineEvent::Continue => {}
974 FastPipelineEvent::ReadyForQuery => {
975 if let Some(err) = error {
976 return Err(err);
977 }
978 return Ok(flow.completed_queries());
979 }
980 }
981 }
982 Err(e) => {
983 if matches!(&e, PgError::QueryServer(_)) {
984 capture_query_server_error(self, &mut error, e);
985 continue;
986 }
987 return Err(e);
988 }
989 }
990 }
991 }
992
993 pub async fn prepare(&mut self, sql: &str) -> PgResult<super::PreparedStatement> {
996 use super::prepared::sql_bytes_to_stmt_name;
997
998 let stmt_name = sql_bytes_to_stmt_name(sql.as_bytes());
999
1000 if !self.prepared_statements.contains_key(&stmt_name) {
1001 self.evict_prepared_if_full();
1002 let mut buf = BytesMut::with_capacity(sql.len() + 32);
1003 buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
1004 buf.extend(PgEncoder::encode_sync());
1005
1006 self.write_all_with_timeout(&buf, "stream write").await?;
1007 self.flush_with_timeout("stream flush").await?;
1008
1009 let mut error: Option<PgError> = None;
1011 let mut saw_parse_complete = false;
1012 loop {
1013 match self.recv_msg_type_fast().await {
1014 Ok(msg_type) => match msg_type {
1015 b'1' => {
1016 if saw_parse_complete {
1017 return Err(PgError::Protocol(
1018 "prepare received duplicate ParseComplete".to_string(),
1019 ));
1020 }
1021 saw_parse_complete = true;
1022 self.prepared_statements
1023 .insert(stmt_name.clone(), sql.to_string());
1024 }
1025 b'Z' => {
1026 if let Some(err) = error {
1027 return Err(err);
1028 }
1029 if !saw_parse_complete {
1030 return Err(PgError::Protocol(
1031 "prepare reached ReadyForQuery without ParseComplete"
1032 .to_string(),
1033 ));
1034 }
1035 break;
1036 }
1037 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1038 other => {
1039 return return_with_desync(
1040 self,
1041 unexpected_backend_msg_type("prepare", other),
1042 );
1043 }
1044 },
1045 Err(e) => {
1046 if matches!(&e, PgError::QueryServer(_)) {
1047 capture_query_server_error(self, &mut error, e);
1048 continue;
1049 }
1050 return Err(e);
1051 }
1052 }
1053 }
1054 }
1055
1056 Ok(super::PreparedStatement {
1057 name: stmt_name,
1058 param_count: sql.matches('$').count(),
1059 })
1060 }
1061
1062 pub async fn pipeline_execute_prepared_rows(
1064 &mut self,
1065 stmt: &super::PreparedStatement,
1066 params_batch: &[Vec<Option<Vec<u8>>>],
1067 ) -> PgResult<Vec<Vec<Vec<Option<Vec<u8>>>>>> {
1068 if params_batch.is_empty() {
1069 return Ok(Vec::new());
1070 }
1071
1072 if !self.prepared_statements.contains_key(&stmt.name) {
1073 return Err(PgError::Query(
1074 "Statement not prepared. Call prepare() first.".to_string(),
1075 ));
1076 }
1077
1078 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
1079
1080 for params in params_batch {
1081 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
1082 .map_err(|e| PgError::Encode(e.to_string()))?;
1083 PgEncoder::encode_execute_to(&mut buf);
1084 }
1085
1086 PgEncoder::encode_sync_to(&mut buf);
1087
1088 self.write_all_with_timeout(&buf, "stream write").await?;
1089 self.flush_with_timeout("stream flush").await?;
1090
1091 let mut all_results: Vec<Vec<Vec<Option<Vec<u8>>>>> =
1093 Vec::with_capacity(params_batch.len());
1094 let mut current_rows: Vec<Vec<Option<Vec<u8>>>> = Vec::new();
1095 let mut error: Option<PgError> = None;
1096 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1097 expected_queries: params_batch.len(),
1098 allow_parse_complete: false,
1099 require_parse_before_bind: false,
1100 no_data_counts_as_completion: true,
1101 allow_no_data_nonterminal: false,
1102 expected_parse_completes: Some(0),
1103 });
1104
1105 loop {
1106 match self.recv_with_data_fast().await {
1107 Ok((msg_type, data)) => {
1108 if let Err(err) = flow.validate_msg_type(
1109 msg_type,
1110 "pipeline_execute_prepared_rows",
1111 error.is_some(),
1112 ) {
1113 return return_with_desync(self, err);
1114 }
1115 match msg_type {
1116 b'2' => {} b'T' => {} b'D' => {
1119 if error.is_none()
1121 && let Some(row) = data
1122 {
1123 current_rows.push(row);
1124 }
1125 }
1126 b'C' => {
1127 all_results.push(std::mem::take(&mut current_rows));
1129 }
1130 b'n' => {
1131 all_results.push(Vec::new());
1133 }
1134 b'Z' => {
1135 if all_results.len() != params_batch.len() {
1137 return Err(error.unwrap_or_else(|| {
1138 PgError::Protocol(format!(
1139 "Pipeline completion mismatch: expected {}, got {}",
1140 params_batch.len(),
1141 all_results.len()
1142 ))
1143 }));
1144 }
1145 if let Some(err) = error {
1146 return Err(err);
1147 }
1148 return Ok(all_results);
1149 }
1150 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1151 other => {
1152 return return_with_desync(
1153 self,
1154 unexpected_backend_msg_type(
1155 "pipeline_execute_prepared_rows",
1156 other,
1157 ),
1158 );
1159 }
1160 }
1161 }
1162 Err(e) => {
1163 if matches!(&e, PgError::QueryServer(_)) {
1164 capture_query_server_error(self, &mut error, e);
1165 continue;
1166 }
1167 return Err(e);
1168 }
1169 }
1170 }
1171 }
1172
1173 pub async fn pipeline_execute_prepared_rows_bytes(
1175 &mut self,
1176 stmt: &super::PreparedStatement,
1177 params_batch: &[Vec<Option<Vec<u8>>>],
1178 ) -> PgResult<Vec<Vec<Vec<Option<bytes::Bytes>>>>> {
1179 if params_batch.is_empty() {
1180 return Ok(Vec::new());
1181 }
1182
1183 if !self.prepared_statements.contains_key(&stmt.name) {
1184 return Err(PgError::Query(
1185 "Statement not prepared. Call prepare() first.".to_string(),
1186 ));
1187 }
1188
1189 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
1190
1191 for params in params_batch {
1192 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
1193 .map_err(|e| PgError::Encode(e.to_string()))?;
1194 PgEncoder::encode_execute_to(&mut buf);
1195 }
1196
1197 PgEncoder::encode_sync_to(&mut buf);
1198
1199 self.write_all_with_timeout(&buf, "stream write").await?;
1200 self.flush_with_timeout("stream flush").await?;
1201
1202 let mut all_results: Vec<Vec<Vec<Option<bytes::Bytes>>>> =
1204 Vec::with_capacity(params_batch.len());
1205 let mut current_rows: Vec<Vec<Option<bytes::Bytes>>> = Vec::new();
1206 let mut error: Option<PgError> = None;
1207 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1208 expected_queries: params_batch.len(),
1209 allow_parse_complete: false,
1210 require_parse_before_bind: false,
1211 no_data_counts_as_completion: true,
1212 allow_no_data_nonterminal: false,
1213 expected_parse_completes: Some(0),
1214 });
1215
1216 loop {
1217 match self.recv_data_zerocopy().await {
1218 Ok((msg_type, data)) => {
1219 if let Err(err) = flow.validate_msg_type(
1220 msg_type,
1221 "pipeline_execute_prepared_rows_bytes",
1222 error.is_some(),
1223 ) {
1224 return return_with_desync(self, err);
1225 }
1226 match msg_type {
1227 b'2' => {} b'T' => {} b'D' => {
1230 if error.is_none()
1232 && let Some(row) = data
1233 {
1234 current_rows.push(row);
1235 }
1236 }
1237 b'C' => {
1238 all_results.push(std::mem::take(&mut current_rows));
1240 }
1241 b'n' => {
1242 all_results.push(Vec::new());
1244 }
1245 b'Z' => {
1246 if all_results.len() != params_batch.len() {
1248 return Err(error.unwrap_or_else(|| {
1249 PgError::Protocol(format!(
1250 "Pipeline completion mismatch: expected {}, got {}",
1251 params_batch.len(),
1252 all_results.len()
1253 ))
1254 }));
1255 }
1256 if let Some(err) = error {
1257 return Err(err);
1258 }
1259 return Ok(all_results);
1260 }
1261 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1262 other => {
1263 return return_with_desync(
1264 self,
1265 unexpected_backend_msg_type(
1266 "pipeline_execute_prepared_rows_bytes",
1267 other,
1268 ),
1269 );
1270 }
1271 }
1272 }
1273 Err(e) => {
1274 if matches!(&e, PgError::QueryServer(_)) {
1275 capture_query_server_error(self, &mut error, e);
1276 continue;
1277 }
1278 return Err(e);
1279 }
1280 }
1281 }
1282 }
1283
1284 pub async fn pipeline_execute_prepared_rows_2cols_bytes(
1286 &mut self,
1287 stmt: &super::PreparedStatement,
1288 params_batch: &[Vec<Option<Vec<u8>>>],
1289 ) -> PgResult<Vec<Vec<(bytes::Bytes, bytes::Bytes)>>> {
1290 if params_batch.is_empty() {
1291 return Ok(Vec::new());
1292 }
1293
1294 if !self.prepared_statements.contains_key(&stmt.name) {
1295 return Err(PgError::Query(
1296 "Statement not prepared. Call prepare() first.".to_string(),
1297 ));
1298 }
1299
1300 let mut buf = BytesMut::with_capacity(params_batch.len() * 64);
1301
1302 for params in params_batch {
1303 PgEncoder::encode_bind_to(&mut buf, &stmt.name, params)
1304 .map_err(|e| PgError::Encode(e.to_string()))?;
1305 PgEncoder::encode_execute_to(&mut buf);
1306 }
1307
1308 PgEncoder::encode_sync_to(&mut buf);
1309
1310 self.write_all_with_timeout(&buf, "stream write").await?;
1311 self.flush_with_timeout("stream flush").await?;
1312
1313 let mut all_results: Vec<Vec<(bytes::Bytes, bytes::Bytes)>> =
1315 Vec::with_capacity(params_batch.len());
1316 let mut current_rows: Vec<(bytes::Bytes, bytes::Bytes)> = Vec::with_capacity(16);
1317 let mut error: Option<PgError> = None;
1318 let mut flow = FastExtendedFlowTracker::new(FastExtendedFlowConfig {
1319 expected_queries: params_batch.len(),
1320 allow_parse_complete: false,
1321 require_parse_before_bind: false,
1322 no_data_counts_as_completion: true,
1323 allow_no_data_nonterminal: false,
1324 expected_parse_completes: Some(0),
1325 });
1326
1327 loop {
1328 match self.recv_data_ultra().await {
1329 Ok((msg_type, data)) => {
1330 if let Err(err) = flow.validate_msg_type(
1331 msg_type,
1332 "pipeline_execute_prepared_rows_2cols_bytes",
1333 error.is_some(),
1334 ) {
1335 return return_with_desync(self, err);
1336 }
1337 match msg_type {
1338 b'2' | b'T' => {} b'D' => {
1340 if error.is_none()
1341 && let Some(row) = data
1342 {
1343 current_rows.push(row);
1344 }
1345 }
1346 b'C' => {
1347 all_results.push(std::mem::take(&mut current_rows));
1348 current_rows = Vec::with_capacity(16);
1349 }
1350 b'n' => {
1351 all_results.push(Vec::new());
1352 }
1353 b'Z' => {
1354 if all_results.len() != params_batch.len() {
1355 return Err(error.unwrap_or_else(|| {
1356 PgError::Protocol(format!(
1357 "Pipeline completion mismatch: expected {}, got {}",
1358 params_batch.len(),
1359 all_results.len()
1360 ))
1361 }));
1362 }
1363 if let Some(err) = error {
1364 return Err(err);
1365 }
1366 return Ok(all_results);
1367 }
1368 msg_type if is_ignorable_session_msg_type(msg_type) => {}
1369 other => {
1370 return return_with_desync(
1371 self,
1372 unexpected_backend_msg_type(
1373 "pipeline_execute_prepared_rows_2cols_bytes",
1374 other,
1375 ),
1376 );
1377 }
1378 }
1379 }
1380 Err(e) => {
1381 if matches!(&e, PgError::QueryServer(_)) {
1382 capture_query_server_error(self, &mut error, e);
1383 continue;
1384 }
1385 return Err(e);
1386 }
1387 }
1388 }
1389 }
1390}
1391
#[cfg(test)]
mod tests {
    use super::*;
    use qail_core::ast::Qail;

    #[test]
    fn ast_pipeline_mode_auto_resolves_by_batch_size() {
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(0),
            AstPipelineMode::OneShot
        );
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(7),
            AstPipelineMode::OneShot
        );
        assert_eq!(
            AstPipelineMode::Auto.resolve_for_batch_len(8),
            AstPipelineMode::Cached
        );
        assert_eq!(
            AstPipelineMode::Cached.resolve_for_batch_len(1),
            AstPipelineMode::Cached
        );
        assert_eq!(
            AstPipelineMode::OneShot.resolve_for_batch_len(1000),
            AstPipelineMode::OneShot
        );
    }

    /// Builds a `PgConnection` over one end of a Unix socket pair, with statement `"s1"`
    /// registered in `prepared_statements` and cached under hash `1`.
    ///
    /// The peer end is dropped when this function returns, so the fixture only suits
    /// tests that do not perform real socket I/O. It must be called inside a Tokio
    /// runtime because `tokio::net::UnixStream::pair` registers with the reactor.
    #[cfg(unix)]
    fn make_test_conn_with_prepared() -> PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
        let mut conn = PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(16).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            secret_key: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        conn.prepared_statements
            .insert("s1".to_string(), "SELECT 1".to_string());
        conn.stmt_cache.put(1, "s1".to_string());
        conn
    }

    /// Builds a `PgError::QueryServer` with severity `ERROR` and the given SQLSTATE/message.
    fn server_error(code: &str, message: &str) -> PgError {
        PgError::QueryServer(super::super::PgServerError {
            severity: "ERROR".to_string(),
            code: code.to_string(),
            message: message.to_string(),
            detail: None,
            hint: None,
        })
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_clears_prepared_state_on_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        let err = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert!(conn.prepared_statements.is_empty());
        assert_eq!(conn.stmt_cache.len(), 0);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_preserves_prepared_state_on_non_retryable_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = None;
        let err = server_error("23505", "duplicate key value violates unique constraint");
        capture_query_server_error(&mut conn, &mut slot, err);

        assert!(slot.is_some());
        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn capture_query_server_error_does_not_override_existing_error() {
        let mut conn = make_test_conn_with_prepared();
        let mut slot = Some(server_error("23505", "duplicate key"));
        let retryable = server_error("26000", "prepared statement \"s1\" does not exist");
        capture_query_server_error(&mut conn, &mut slot, retryable);

        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
        assert_eq!(
            slot.and_then(|e| e.sqlstate().map(str::to_string))
                .as_deref(),
            Some("23505")
        );
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn return_with_desync_marks_connection_on_protocol_error() {
        let mut conn = make_test_conn_with_prepared();
        let res: PgResult<()> =
            return_with_desync(&mut conn, PgError::Protocol("bad frame".to_string()));

        assert!(matches!(res, Err(PgError::Protocol(_))));
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn return_with_desync_keeps_connection_synced_on_query_error() {
        let mut conn = make_test_conn_with_prepared();
        let res: PgResult<()> =
            return_with_desync(&mut conn, PgError::Query("not prepared".to_string()));

        assert!(matches!(res, Err(PgError::Query(_))));
        assert!(!conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn rollback_new_cached_statements_removes_only_listed_hashes() {
        let mut conn = make_test_conn_with_prepared();
        conn.prepared_statements
            .insert("s2".to_string(), "SELECT 2".to_string());
        conn.stmt_cache.put(2, "s2".to_string());

        // Hash 999 was never cached; rolling it back must be a harmless no-op.
        rollback_new_cached_statements(&mut conn, &[2, 999]);

        assert!(conn.prepared_statements.contains_key("s1"));
        assert!(!conn.prepared_statements.contains_key("s2"));
        assert_eq!(conn.prepared_statements.len(), 1);
        assert_eq!(conn.stmt_cache.len(), 1);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pipeline_ast_cached_rolls_back_new_state_on_encode_error() {
        let mut conn = make_test_conn_with_prepared();
        let baseline = conn.prepared_statements.len();
        let baseline_stmt_cache = conn.stmt_cache.len();

        let cmds = vec![
            Qail::get("harbors").columns(["id", "name"]).limit(1),
            Qail::get("bad\0table").columns(["id"]).limit(1),
        ];

        let err = conn
            .pipeline_execute_count_ast_cached(&cmds)
            .await
            .expect_err("expected encode error for NUL byte in table name");

        assert!(matches!(err, PgError::Encode(_)));
        assert_eq!(conn.prepared_statements.len(), baseline);
        assert_eq!(conn.stmt_cache.len(), baseline_stmt_cache);
        assert!(conn.prepared_statements.contains_key("s1"));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pipeline_simple_ast_empty_batch_returns_zero_without_io() {
        let mut conn = make_test_conn_with_prepared();
        let res = conn
            .pipeline_execute_count_simple_ast(&[])
            .await
            .expect("empty batch should be a fast no-op");
        assert_eq!(res, 0);
        assert!(!conn.is_io_desynced());
    }
}