1use super::connection::PooledConnection;
4use super::lifecycle::MAX_HOT_STATEMENTS;
5use crate::driver::{
6 PgConnection, PgError, PgResult, ResultFormat,
7 extended_flow::{ExtendedFlowConfig, ExtendedFlowTracker},
8 is_ignorable_session_message, unexpected_backend_message,
9};
10use std::sync::Arc;
11
12#[inline]
13fn rollback_cache_miss_statement_registration(
14 conn: &mut PgConnection,
15 is_cache_miss: bool,
16 sql_hash: u64,
17 stmt_name: &str,
18) {
19 if is_cache_miss {
20 conn.stmt_cache.remove(&sql_hash);
21 conn.prepared_statements.remove(stmt_name);
22 conn.column_info_cache.remove(&sql_hash);
23 }
24}
25
26#[inline]
27fn register_hot_statement_after_parse_success(
28 pool: &super::lifecycle::PgPoolInner,
29 sql_hash: u64,
30 stmt_name: &str,
31 sql: &str,
32) {
33 if let Ok(mut hot) = pool.hot_statements.write()
34 && (hot.contains_key(&sql_hash) || hot.len() < MAX_HOT_STATEMENTS)
35 {
36 hot.insert(sql_hash, (stmt_name.to_string(), sql.to_string()));
37 }
38}
39
40#[inline]
41fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
42 if matches!(
43 err,
44 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
45 ) {
46 conn.mark_io_desynced();
47 }
48 Err(err)
49}
50
51#[inline]
52fn encoded_sql_str(sql_buf: &[u8]) -> PgResult<&str> {
53 std::str::from_utf8(sql_buf)
54 .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))
55}
56
57async fn drain_extended_responses_after_rls_setup_error(conn: &mut PgConnection) -> PgResult<()> {
58 loop {
59 let msg = conn.recv().await?;
60 match msg {
61 crate::protocol::BackendMessage::ReadyForQuery(_) => return Ok(()),
62 crate::protocol::BackendMessage::ErrorResponse(_) => {}
63 msg if is_ignorable_session_message(&msg) => {}
64 _ => {}
66 }
67 }
68}
69
70fn copy_export_table_sql(table: &str, columns: &[String]) -> PgResult<String> {
71 let cols: Vec<String> = columns
72 .iter()
73 .map(|c| crate::driver::copy::quote_copy_column_ident(c))
74 .collect::<PgResult<_>>()?;
75
76 Ok(format!(
77 "COPY {} ({}) TO STDOUT",
78 crate::driver::copy::quote_copy_table_ref(table)?,
79 cols.join(", ")
80 ))
81}
82
83impl PooledConnection {
84 pub async fn fetch_all_uncached(
87 &mut self,
88 cmd: &qail_core::ast::Qail,
89 ) -> PgResult<Vec<crate::driver::PgRow>> {
90 self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
91 .await
92 }
93
94 pub async fn query_raw_with_params(
102 &mut self,
103 sql: &str,
104 params: &[Option<Vec<u8>>],
105 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
106 let conn = self.conn_mut()?;
107 conn.query(sql, params).await
108 }
109
110 pub async fn query_rows_with_params(
116 &mut self,
117 sql: &str,
118 params: &[Option<Vec<u8>>],
119 ) -> PgResult<Vec<crate::driver::PgRow>> {
120 self.query_rows_with_params_with_format(sql, params, ResultFormat::Text)
121 .await
122 }
123
124 pub async fn query_rows_with_params_with_format(
127 &mut self,
128 sql: &str,
129 params: &[Option<Vec<u8>>],
130 result_format: ResultFormat,
131 ) -> PgResult<Vec<crate::driver::PgRow>> {
132 let conn = self.conn_mut()?;
133 conn.query_rows_with_result_format(sql, params, result_format.as_wire_code())
134 .await
135 }
136
137 pub async fn query_rows_with_param_types_with_format(
140 &mut self,
141 sql: &str,
142 param_types: &[u32],
143 params: &[Option<Vec<u8>>],
144 result_format: ResultFormat,
145 ) -> PgResult<Vec<crate::driver::PgRow>> {
146 let conn = self.conn_mut()?;
147 conn.query_rows_with_param_types_and_result_format(
148 sql,
149 param_types,
150 params,
151 result_format.as_wire_code(),
152 )
153 .await
154 }
155
156 pub async fn probe_query_with_param_types(
159 &mut self,
160 sql: &str,
161 param_types: &[u32],
162 params: &[Option<Vec<u8>>],
163 ) -> PgResult<()> {
164 let conn = self.conn_mut()?;
165 conn.probe_query_with_param_types(sql, param_types, params)
166 .await
167 }
168
169 pub async fn copy_export(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<Vec<String>>> {
171 self.conn_mut()?.copy_export(cmd).await
172 }
173
174 pub async fn copy_export_stream_raw<F, Fut>(
176 &mut self,
177 cmd: &qail_core::ast::Qail,
178 on_chunk: F,
179 ) -> PgResult<()>
180 where
181 F: FnMut(Vec<u8>) -> Fut,
182 Fut: std::future::Future<Output = PgResult<()>>,
183 {
184 self.conn_mut()?.copy_export_stream_raw(cmd, on_chunk).await
185 }
186
187 pub async fn copy_export_stream_rows<F>(
189 &mut self,
190 cmd: &qail_core::ast::Qail,
191 on_row: F,
192 ) -> PgResult<()>
193 where
194 F: FnMut(Vec<String>) -> PgResult<()>,
195 {
196 self.conn_mut()?.copy_export_stream_rows(cmd, on_row).await
197 }
198
199 pub async fn copy_export_table(
201 &mut self,
202 table: &str,
203 columns: &[String],
204 ) -> PgResult<Vec<u8>> {
205 let sql = copy_export_table_sql(table, columns)?;
206 self.conn_mut()?.copy_out_raw(&sql).await
207 }
208
209 pub async fn copy_export_table_stream<F, Fut>(
211 &mut self,
212 table: &str,
213 columns: &[String],
214 on_chunk: F,
215 ) -> PgResult<()>
216 where
217 F: FnMut(Vec<u8>) -> Fut,
218 Fut: std::future::Future<Output = PgResult<()>>,
219 {
220 let sql = copy_export_table_sql(table, columns)?;
221 self.conn_mut()?.copy_out_raw_stream(&sql, on_chunk).await
222 }
223
224 pub async fn fetch_all_uncached_with_format(
226 &mut self,
227 cmd: &qail_core::ast::Qail,
228 result_format: ResultFormat,
229 ) -> PgResult<Vec<crate::driver::PgRow>> {
230 use crate::driver::ColumnInfo;
231 use crate::protocol::AstEncoder;
232
233 let conn = self.conn_mut()?;
234
235 AstEncoder::encode_cmd_reuse_into_with_result_format(
236 cmd,
237 &mut conn.sql_buf,
238 &mut conn.params_buf,
239 &mut conn.write_buf,
240 result_format.as_wire_code(),
241 )
242 .map_err(|e| PgError::Encode(e.to_string()))?;
243
244 conn.flush_write_buf().await?;
245
246 let mut rows: Vec<crate::driver::PgRow> = Vec::new();
247 let mut column_info: Option<Arc<ColumnInfo>> = None;
248 let mut error: Option<PgError> = None;
249 let mut flow =
250 ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_describe_portal_execute());
251
252 loop {
253 let msg = conn.recv().await?;
254 if let Err(err) = flow.validate(&msg, "pool fetch_all execute", error.is_some()) {
255 return return_with_desync(conn, err);
256 }
257 match msg {
258 crate::protocol::BackendMessage::ParseComplete
259 | crate::protocol::BackendMessage::BindComplete => {}
260 crate::protocol::BackendMessage::RowDescription(fields) => {
261 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
262 }
263 crate::protocol::BackendMessage::DataRow(data) => {
264 if error.is_none() {
265 rows.push(crate::driver::PgRow {
266 columns: data,
267 column_info: column_info.clone(),
268 });
269 }
270 }
271 crate::protocol::BackendMessage::NoData => {}
272 crate::protocol::BackendMessage::CommandComplete(_) => {}
273 crate::protocol::BackendMessage::ReadyForQuery(_) => {
274 if let Some(err) = error {
275 return Err(err);
276 }
277 return Ok(rows);
278 }
279 crate::protocol::BackendMessage::ErrorResponse(err) => {
280 if error.is_none() {
281 error = Some(PgError::QueryServer(err.into()));
282 }
283 }
284 msg if is_ignorable_session_message(&msg) => {}
285 other => {
286 return return_with_desync(
287 conn,
288 unexpected_backend_message("pool fetch_all execute", &other),
289 );
290 }
291 }
292 }
293 }
294
295 pub async fn fetch_all_fast(
299 &mut self,
300 cmd: &qail_core::ast::Qail,
301 ) -> PgResult<Vec<crate::driver::PgRow>> {
302 self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
303 .await
304 }
305
306 pub async fn fetch_all_fast_with_format(
308 &mut self,
309 cmd: &qail_core::ast::Qail,
310 result_format: ResultFormat,
311 ) -> PgResult<Vec<crate::driver::PgRow>> {
312 use crate::protocol::AstEncoder;
313
314 let conn = self.conn_mut()?;
315
316 AstEncoder::encode_cmd_reuse_into_with_result_format(
317 cmd,
318 &mut conn.sql_buf,
319 &mut conn.params_buf,
320 &mut conn.write_buf,
321 result_format.as_wire_code(),
322 )
323 .map_err(|e| PgError::Encode(e.to_string()))?;
324
325 conn.flush_write_buf().await?;
326
327 let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
328 let mut error: Option<PgError> = None;
329 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(true));
330
331 loop {
332 let res = conn.recv_with_data_fast().await;
333 match res {
334 Ok((msg_type, data)) => {
335 if let Err(err) = flow.validate_msg_type(
336 msg_type,
337 "pool fetch_all_fast execute",
338 error.is_some(),
339 ) {
340 return return_with_desync(conn, err);
341 }
342 match msg_type {
343 b'D' => {
344 if error.is_none()
345 && let Some(columns) = data
346 {
347 rows.push(crate::driver::PgRow {
348 columns,
349 column_info: None,
350 });
351 }
352 }
353 b'Z' => {
354 if let Some(err) = error {
355 return Err(err);
356 }
357 return Ok(rows);
358 }
359 _ => {}
360 }
361 }
362 Err(e) => {
363 if matches!(&e, PgError::QueryServer(_)) {
364 if error.is_none() {
365 error = Some(e);
366 }
367 continue;
368 }
369 return Err(e);
370 }
371 }
372 }
373 }
374
375 pub async fn fetch_all_cached(
380 &mut self,
381 cmd: &qail_core::ast::Qail,
382 ) -> PgResult<Vec<crate::driver::PgRow>> {
383 self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
384 .await
385 }
386
387 pub async fn fetch_all_cached_with_format(
389 &mut self,
390 cmd: &qail_core::ast::Qail,
391 result_format: ResultFormat,
392 ) -> PgResult<Vec<crate::driver::PgRow>> {
393 let mut retried = false;
394 loop {
395 match self
396 .fetch_all_cached_with_format_once(cmd, result_format)
397 .await
398 {
399 Ok(rows) => return Ok(rows),
400 Err(err)
401 if !retried
402 && (err.is_prepared_statement_retryable()
403 || err.is_prepared_statement_already_exists()) =>
404 {
405 retried = true;
406 if err.is_prepared_statement_retryable()
407 && let Some(conn) = self.conn.as_mut()
408 {
409 conn.clear_prepared_statement_state();
410 }
411 }
412 Err(err) => return Err(err),
413 }
414 }
415 }
416
417 pub async fn fetch_typed<T: crate::driver::row::QailRow>(
419 &mut self,
420 cmd: &qail_core::ast::Qail,
421 ) -> PgResult<Vec<T>> {
422 self.fetch_typed_with_format(cmd, ResultFormat::Text).await
423 }
424
425 pub async fn fetch_typed_with_format<T: crate::driver::row::QailRow>(
430 &mut self,
431 cmd: &qail_core::ast::Qail,
432 result_format: ResultFormat,
433 ) -> PgResult<Vec<T>> {
434 let rows = self
435 .fetch_all_cached_with_format(cmd, result_format)
436 .await?;
437 Ok(rows.iter().map(T::from_row).collect())
438 }
439
440 pub async fn fetch_one_typed<T: crate::driver::row::QailRow>(
442 &mut self,
443 cmd: &qail_core::ast::Qail,
444 ) -> PgResult<Option<T>> {
445 self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
446 .await
447 }
448
449 pub async fn fetch_one_typed_with_format<T: crate::driver::row::QailRow>(
451 &mut self,
452 cmd: &qail_core::ast::Qail,
453 result_format: ResultFormat,
454 ) -> PgResult<Option<T>> {
455 let rows = self
456 .fetch_all_cached_with_format(cmd, result_format)
457 .await?;
458 Ok(rows.first().map(T::from_row))
459 }
460
461 async fn fetch_all_cached_with_format_once(
462 &mut self,
463 cmd: &qail_core::ast::Qail,
464 result_format: ResultFormat,
465 ) -> PgResult<Vec<crate::driver::PgRow>> {
466 use crate::driver::ColumnInfo;
467 use std::collections::hash_map::DefaultHasher;
468 use std::hash::{Hash, Hasher};
469
470 let pool = std::sync::Arc::clone(&self.pool);
471 let conn = self.conn.as_mut().ok_or_else(|| {
472 PgError::Connection("Connection already released back to pool".into())
473 })?;
474
475 conn.sql_buf.clear();
476 conn.params_buf.clear();
477
478 match cmd.action {
480 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
481 crate::protocol::ast_encoder::dml::encode_select(
482 cmd,
483 &mut conn.sql_buf,
484 &mut conn.params_buf,
485 )?;
486 }
487 qail_core::ast::Action::Add => {
488 crate::protocol::ast_encoder::dml::encode_insert(
489 cmd,
490 &mut conn.sql_buf,
491 &mut conn.params_buf,
492 )?;
493 }
494 qail_core::ast::Action::Set => {
495 crate::protocol::ast_encoder::dml::encode_update(
496 cmd,
497 &mut conn.sql_buf,
498 &mut conn.params_buf,
499 )?;
500 }
501 qail_core::ast::Action::Del => {
502 crate::protocol::ast_encoder::dml::encode_delete(
503 cmd,
504 &mut conn.sql_buf,
505 &mut conn.params_buf,
506 )?;
507 }
508 _ => {
509 return self
511 .fetch_all_uncached_with_format(cmd, result_format)
512 .await;
513 }
514 }
515
516 let mut hasher = DefaultHasher::new();
517 conn.sql_buf.hash(&mut hasher);
518 let sql_hash = hasher.finish();
519
520 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
521
522 conn.write_buf.clear();
523
524 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
525 name
526 } else {
527 let name = format!("qail_{:x}", sql_hash);
528
529 conn.evict_prepared_if_full();
530
531 let sql_str = encoded_sql_str(&conn.sql_buf)?;
532
533 use crate::protocol::PgEncoder;
534 let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
535 let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
536 conn.write_buf.extend_from_slice(&parse_msg);
537 conn.write_buf.extend_from_slice(&describe_msg);
538
539 conn.stmt_cache.put(sql_hash, name.clone());
540 conn.prepared_statements
541 .insert(name.clone(), sql_str.to_string());
542
543 name
544 };
545
546 use crate::protocol::PgEncoder;
547 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
548 &mut conn.write_buf,
549 &stmt_name,
550 &conn.params_buf,
551 result_format.as_wire_code(),
552 ) {
553 if is_cache_miss {
554 conn.stmt_cache.remove(&sql_hash);
555 conn.prepared_statements.remove(&stmt_name);
556 conn.column_info_cache.remove(&sql_hash);
557 }
558 return Err(PgError::Encode(e.to_string()));
559 }
560 PgEncoder::encode_execute_to(&mut conn.write_buf);
561 PgEncoder::encode_sync_to(&mut conn.write_buf);
562
563 if let Err(err) = conn.flush_write_buf().await {
564 if is_cache_miss {
565 conn.stmt_cache.remove(&sql_hash);
566 conn.prepared_statements.remove(&stmt_name);
567 conn.column_info_cache.remove(&sql_hash);
568 }
569 return Err(err);
570 }
571
572 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
573
574 let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
575 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
576 let mut error: Option<PgError> = None;
577 let mut flow = ExtendedFlowTracker::new(
578 ExtendedFlowConfig::parse_describe_statement_bind_execute(is_cache_miss),
579 );
580
581 loop {
582 let msg = match conn.recv().await {
583 Ok(msg) => msg,
584 Err(err) => {
585 if is_cache_miss && !flow.saw_parse_complete() {
586 conn.stmt_cache.remove(&sql_hash);
587 conn.prepared_statements.remove(&stmt_name);
588 conn.column_info_cache.remove(&sql_hash);
589 }
590 return Err(err);
591 }
592 };
593 if let Err(err) = flow.validate(&msg, "pool fetch_all_cached execute", error.is_some())
594 {
595 if is_cache_miss && !flow.saw_parse_complete() {
596 conn.stmt_cache.remove(&sql_hash);
597 conn.prepared_statements.remove(&stmt_name);
598 conn.column_info_cache.remove(&sql_hash);
599 }
600 return return_with_desync(conn, err);
601 }
602 match msg {
603 crate::protocol::BackendMessage::ParseComplete => {}
604 crate::protocol::BackendMessage::BindComplete => {}
605 crate::protocol::BackendMessage::ParameterDescription(_) => {}
606 crate::protocol::BackendMessage::RowDescription(fields) => {
607 let info = Arc::new(ColumnInfo::from_fields(&fields));
608 if is_cache_miss {
609 conn.column_info_cache.insert(sql_hash, Arc::clone(&info));
610 }
611 column_info = Some(info);
612 }
613 crate::protocol::BackendMessage::DataRow(data) => {
614 if error.is_none() {
615 rows.push(crate::driver::PgRow {
616 columns: data,
617 column_info: column_info.clone(),
618 });
619 }
620 }
621 crate::protocol::BackendMessage::CommandComplete(_) => {}
622 crate::protocol::BackendMessage::ReadyForQuery(_) => {
623 if let Some(err) = error {
624 if is_cache_miss
625 && !flow.saw_parse_complete()
626 && !err.is_prepared_statement_already_exists()
627 {
628 conn.stmt_cache.remove(&sql_hash);
629 conn.prepared_statements.remove(&stmt_name);
630 conn.column_info_cache.remove(&sql_hash);
631 }
632 return Err(err);
633 }
634 if is_cache_miss && !flow.saw_parse_complete() {
635 conn.stmt_cache.remove(&sql_hash);
636 conn.prepared_statements.remove(&stmt_name);
637 conn.column_info_cache.remove(&sql_hash);
638 return return_with_desync(
639 conn,
640 PgError::Protocol(
641 "Cache miss query reached ReadyForQuery without ParseComplete"
642 .to_string(),
643 ),
644 );
645 }
646 if is_cache_miss && let Some(sql) = conn.prepared_statements.get(&stmt_name) {
647 register_hot_statement_after_parse_success(
648 &pool, sql_hash, &stmt_name, sql,
649 );
650 }
651 return Ok(rows);
652 }
653 crate::protocol::BackendMessage::ErrorResponse(err) => {
654 if error.is_none() {
655 error = Some(PgError::QueryServer(err.into()));
656 }
657 }
658 msg if is_ignorable_session_message(&msg) => {}
659 other => {
660 if is_cache_miss && !flow.saw_parse_complete() {
661 conn.stmt_cache.remove(&sql_hash);
662 conn.prepared_statements.remove(&stmt_name);
663 conn.column_info_cache.remove(&sql_hash);
664 }
665 return return_with_desync(
666 conn,
667 unexpected_backend_message("pool fetch_all_cached execute", &other),
668 );
669 }
670 }
671 }
672 }
673
674 pub async fn fetch_all_with_rls(
693 &mut self,
694 cmd: &qail_core::ast::Qail,
695 rls_sql: &str,
696 ) -> PgResult<Vec<crate::driver::PgRow>> {
697 self.fetch_all_with_rls_with_format(cmd, rls_sql, ResultFormat::Text)
698 .await
699 }
700
701 pub async fn fetch_all_with_rls_with_format(
703 &mut self,
704 cmd: &qail_core::ast::Qail,
705 rls_sql: &str,
706 result_format: ResultFormat,
707 ) -> PgResult<Vec<crate::driver::PgRow>> {
708 let mut retried = false;
709 loop {
710 match self
711 .fetch_all_with_rls_with_format_once(cmd, rls_sql, result_format)
712 .await
713 {
714 Ok(rows) => return Ok(rows),
715 Err(err)
716 if !retried
717 && (err.is_prepared_statement_retryable()
718 || err.is_prepared_statement_already_exists()) =>
719 {
720 retried = true;
721 if let Some(conn) = self.conn.as_mut() {
722 if err.is_prepared_statement_retryable() {
723 conn.clear_prepared_statement_state();
724 }
725 let _ = conn.execute_simple("ROLLBACK").await;
728 }
729 self.rls_dirty = false;
730 }
731 Err(err) => return Err(err),
732 }
733 }
734 }
735
736 async fn fetch_all_with_rls_with_format_once(
737 &mut self,
738 cmd: &qail_core::ast::Qail,
739 rls_sql: &str,
740 result_format: ResultFormat,
741 ) -> PgResult<Vec<crate::driver::PgRow>> {
742 use crate::driver::ColumnInfo;
743 use std::collections::hash_map::DefaultHasher;
744 use std::hash::{Hash, Hasher};
745
746 let pool = std::sync::Arc::clone(&self.pool);
747 let conn = self.conn.as_mut().ok_or_else(|| {
748 PgError::Connection("Connection already released back to pool".into())
749 })?;
750
751 if !crate::protocol::AstEncoder::encode_cacheable_cmd_sql_to(
752 cmd,
753 &mut conn.sql_buf,
754 &mut conn.params_buf,
755 )? {
756 conn.execute_simple(rls_sql).await?;
758 self.rls_dirty = true;
759 return self
760 .fetch_all_uncached_with_format(cmd, result_format)
761 .await;
762 }
763
764 let mut hasher = DefaultHasher::new();
765 conn.sql_buf.hash(&mut hasher);
766 let sql_hash = hasher.finish();
767
768 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
769
770 conn.write_buf.clear();
771
772 let rls_msg = crate::protocol::PgEncoder::try_encode_query_string(rls_sql)?;
777 conn.write_buf.extend_from_slice(&rls_msg);
778
779 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
781 name
782 } else {
783 let name = format!("qail_{:x}", sql_hash);
784
785 conn.evict_prepared_if_full();
786
787 let sql_str = encoded_sql_str(&conn.sql_buf)?;
788
789 use crate::protocol::PgEncoder;
790 let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
791 let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
792 conn.write_buf.extend_from_slice(&parse_msg);
793 conn.write_buf.extend_from_slice(&describe_msg);
794
795 conn.stmt_cache.put(sql_hash, name.clone());
796 conn.prepared_statements
797 .insert(name.clone(), sql_str.to_string());
798
799 name
800 };
801
802 use crate::protocol::PgEncoder;
803 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
804 &mut conn.write_buf,
805 &stmt_name,
806 &conn.params_buf,
807 result_format.as_wire_code(),
808 ) {
809 rollback_cache_miss_statement_registration(conn, is_cache_miss, sql_hash, &stmt_name);
810 return Err(PgError::Encode(e.to_string()));
811 }
812 PgEncoder::encode_execute_to(&mut conn.write_buf);
813 PgEncoder::encode_sync_to(&mut conn.write_buf);
814
815 if let Err(err) = conn.flush_write_buf().await {
817 rollback_cache_miss_statement_registration(conn, is_cache_miss, sql_hash, &stmt_name);
818 return Err(err);
819 }
820
821 self.rls_dirty = true;
823
824 let mut rls_error: Option<PgError> = None;
828 loop {
829 let msg = match conn.recv().await {
830 Ok(msg) => msg,
831 Err(err) => {
832 rollback_cache_miss_statement_registration(
833 conn,
834 is_cache_miss,
835 sql_hash,
836 &stmt_name,
837 );
838 return Err(err);
839 }
840 };
841 match msg {
842 crate::protocol::BackendMessage::ReadyForQuery(_) => {
843 if let Some(err) = rls_error {
845 rollback_cache_miss_statement_registration(
846 conn,
847 is_cache_miss,
848 sql_hash,
849 &stmt_name,
850 );
851 if let Err(drain_err) =
852 drain_extended_responses_after_rls_setup_error(conn).await
853 {
854 tracing::warn!(
855 error = %drain_err,
856 "failed to drain pipelined extended responses after RLS setup error"
857 );
858 }
859 return Err(err);
860 }
861 break;
862 }
863 crate::protocol::BackendMessage::ErrorResponse(err) => {
864 if rls_error.is_none() {
865 rls_error = Some(PgError::QueryServer(err.into()));
866 }
867 }
868 crate::protocol::BackendMessage::CommandComplete(_)
870 | crate::protocol::BackendMessage::DataRow(_)
871 | crate::protocol::BackendMessage::RowDescription(_)
872 | crate::protocol::BackendMessage::ParseComplete
873 | crate::protocol::BackendMessage::BindComplete => {}
874 msg if is_ignorable_session_message(&msg) => {}
875 other => {
876 rollback_cache_miss_statement_registration(
877 conn,
878 is_cache_miss,
879 sql_hash,
880 &stmt_name,
881 );
882 return return_with_desync(
883 conn,
884 unexpected_backend_message("pool rls setup", &other),
885 );
886 }
887 }
888 }
889
890 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
892
893 let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
894 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
895 let mut error: Option<PgError> = None;
896 let mut flow = ExtendedFlowTracker::new(
897 ExtendedFlowConfig::parse_describe_statement_bind_execute(is_cache_miss),
898 );
899
900 loop {
901 let msg = match conn.recv().await {
902 Ok(msg) => msg,
903 Err(err) => {
904 if is_cache_miss && !flow.saw_parse_complete() {
905 rollback_cache_miss_statement_registration(
906 conn,
907 is_cache_miss,
908 sql_hash,
909 &stmt_name,
910 );
911 }
912 return Err(err);
913 }
914 };
915 if let Err(err) =
916 flow.validate(&msg, "pool fetch_all_with_rls execute", error.is_some())
917 {
918 if is_cache_miss && !flow.saw_parse_complete() {
919 rollback_cache_miss_statement_registration(
920 conn,
921 is_cache_miss,
922 sql_hash,
923 &stmt_name,
924 );
925 }
926 return return_with_desync(conn, err);
927 }
928 match msg {
929 crate::protocol::BackendMessage::ParseComplete => {}
930 crate::protocol::BackendMessage::BindComplete => {}
931 crate::protocol::BackendMessage::ParameterDescription(_) => {}
932 crate::protocol::BackendMessage::RowDescription(fields) => {
933 let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
934 if is_cache_miss {
935 conn.column_info_cache
936 .insert(sql_hash, std::sync::Arc::clone(&info));
937 }
938 column_info = Some(info);
939 }
940 crate::protocol::BackendMessage::DataRow(data) => {
941 if error.is_none() {
942 rows.push(crate::driver::PgRow {
943 columns: data,
944 column_info: column_info.clone(),
945 });
946 }
947 }
948 crate::protocol::BackendMessage::CommandComplete(_) => {}
949 crate::protocol::BackendMessage::ReadyForQuery(_) => {
950 if let Some(err) = error {
951 if is_cache_miss
952 && !flow.saw_parse_complete()
953 && !err.is_prepared_statement_already_exists()
954 {
955 rollback_cache_miss_statement_registration(
956 conn,
957 is_cache_miss,
958 sql_hash,
959 &stmt_name,
960 );
961 }
962 return Err(err);
963 }
964 if is_cache_miss && !flow.saw_parse_complete() {
965 rollback_cache_miss_statement_registration(
966 conn,
967 is_cache_miss,
968 sql_hash,
969 &stmt_name,
970 );
971 return return_with_desync(
972 conn,
973 PgError::Protocol(
974 "Cache miss query reached ReadyForQuery without ParseComplete"
975 .to_string(),
976 ),
977 );
978 }
979 if is_cache_miss && let Some(sql) = conn.prepared_statements.get(&stmt_name) {
980 register_hot_statement_after_parse_success(
981 &pool, sql_hash, &stmt_name, sql,
982 );
983 }
984 return Ok(rows);
985 }
986 crate::protocol::BackendMessage::ErrorResponse(err) => {
987 if error.is_none() {
988 error = Some(PgError::QueryServer(err.into()));
989 }
990 }
991 msg if is_ignorable_session_message(&msg) => {}
992 other => {
993 if is_cache_miss && !flow.saw_parse_complete() {
994 rollback_cache_miss_statement_registration(
995 conn,
996 is_cache_miss,
997 sql_hash,
998 &stmt_name,
999 );
1000 }
1001 return return_with_desync(
1002 conn,
1003 unexpected_backend_message("pool fetch_all_with_rls execute", &other),
1004 );
1005 }
1006 }
1007 }
1008 }
1009}
1010
#[cfg(test)]
mod tests {
    use super::{copy_export_table_sql, encoded_sql_str, return_with_desync};

    /// Builds a `PgConnection` over one end of a Unix socket pair so helpers
    /// that need `&mut PgConnection` can run without a server.
    #[cfg(unix)]
    fn test_conn() -> crate::driver::PgConnection {
        use crate::driver::PgConnection;
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (local_end, _peer) = UnixStream::pair().expect("unix stream pair");
        let cache_capacity = NonZeroUsize::new(2).expect("non-zero");

        PgConnection {
            stream: PgStream::Unix(local_end),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(cache_capacity),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        }
    }

    #[test]
    fn pool_copy_export_table_sql_preserves_schema_qualified_table() {
        let columns = ["id".to_string(), "display\"name".to_string()];

        let sql = copy_export_table_sql("tenant_a.users", &columns).unwrap();

        assert_eq!(
            sql,
            r#"COPY "tenant_a"."users" ("id", "display""name") TO STDOUT"#
        );
    }

    #[test]
    fn pool_copy_export_table_sql_rejects_nul_bytes() {
        let plain_id = ["id".to_string()];
        let nul_id = ["id\0".to_string()];

        assert!(copy_export_table_sql("tenant\0.users", &plain_id).is_err());
        assert!(copy_export_table_sql("users", &nul_id).is_err());
    }

    #[test]
    fn pool_encoded_sql_str_rejects_invalid_utf8() {
        let message = encoded_sql_str(&[0xff])
            .expect_err("invalid SQL UTF-8 must fail")
            .to_string();

        assert!(message.contains("encoded SQL is not UTF-8"));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_return_with_desync_marks_protocol_error() {
        let mut conn = test_conn();
        let protocol_err =
            crate::driver::PgError::Protocol("bad response ordering".to_string());

        let returned = return_with_desync::<()>(&mut conn, protocol_err)
            .expect_err("protocol error must be returned");

        assert!(returned.to_string().contains("bad response ordering"));
        assert!(conn.is_io_desynced());
    }
}