1use super::connection::PooledConnection;
4use super::lifecycle::MAX_HOT_STATEMENTS;
5use crate::driver::{
6 PgConnection, PgError, PgResult, ResultFormat,
7 extended_flow::{ExtendedFlowConfig, ExtendedFlowTracker},
8 is_ignorable_session_message, unexpected_backend_message,
9};
10use std::sync::Arc;
11
12#[inline]
13fn rollback_cache_miss_statement_registration(
14 conn: &mut PgConnection,
15 is_cache_miss: bool,
16 sql_hash: u64,
17 stmt_name: &str,
18) {
19 if is_cache_miss {
20 conn.stmt_cache.remove(&sql_hash);
21 conn.prepared_statements.remove(stmt_name);
22 conn.column_info_cache.remove(&sql_hash);
23 }
24}
25
26#[inline]
27fn return_with_desync<T>(conn: &mut PgConnection, err: PgError) -> PgResult<T> {
28 if matches!(
29 err,
30 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
31 ) {
32 conn.mark_io_desynced();
33 }
34 Err(err)
35}
36
37#[inline]
38fn encoded_sql_str(sql_buf: &[u8]) -> PgResult<&str> {
39 std::str::from_utf8(sql_buf)
40 .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))
41}
42
43async fn drain_extended_responses_after_rls_setup_error(conn: &mut PgConnection) -> PgResult<()> {
44 loop {
45 let msg = conn.recv().await?;
46 match msg {
47 crate::protocol::BackendMessage::ReadyForQuery(_) => return Ok(()),
48 crate::protocol::BackendMessage::ErrorResponse(_) => {}
49 msg if is_ignorable_session_message(&msg) => {}
50 _ => {}
52 }
53 }
54}
55
56fn copy_export_table_sql(table: &str, columns: &[String]) -> PgResult<String> {
57 let cols: Vec<String> = columns
58 .iter()
59 .map(|c| crate::driver::copy::quote_copy_column_ident(c))
60 .collect::<PgResult<_>>()?;
61
62 Ok(format!(
63 "COPY {} ({}) TO STDOUT",
64 crate::driver::copy::quote_copy_table_ref(table)?,
65 cols.join(", ")
66 ))
67}
68
69impl PooledConnection {
70 pub async fn fetch_all_uncached(
73 &mut self,
74 cmd: &qail_core::ast::Qail,
75 ) -> PgResult<Vec<crate::driver::PgRow>> {
76 self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
77 .await
78 }
79
80 pub async fn query_raw_with_params(
88 &mut self,
89 sql: &str,
90 params: &[Option<Vec<u8>>],
91 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
92 let conn = self.conn_mut()?;
93 conn.query(sql, params).await
94 }
95
96 pub async fn query_rows_with_params(
102 &mut self,
103 sql: &str,
104 params: &[Option<Vec<u8>>],
105 ) -> PgResult<Vec<crate::driver::PgRow>> {
106 self.query_rows_with_params_with_format(sql, params, ResultFormat::Text)
107 .await
108 }
109
110 pub async fn query_rows_with_params_with_format(
113 &mut self,
114 sql: &str,
115 params: &[Option<Vec<u8>>],
116 result_format: ResultFormat,
117 ) -> PgResult<Vec<crate::driver::PgRow>> {
118 let conn = self.conn_mut()?;
119 conn.query_rows_with_result_format(sql, params, result_format.as_wire_code())
120 .await
121 }
122
123 pub async fn query_rows_with_param_types_with_format(
126 &mut self,
127 sql: &str,
128 param_types: &[u32],
129 params: &[Option<Vec<u8>>],
130 result_format: ResultFormat,
131 ) -> PgResult<Vec<crate::driver::PgRow>> {
132 let conn = self.conn_mut()?;
133 conn.query_rows_with_param_types_and_result_format(
134 sql,
135 param_types,
136 params,
137 result_format.as_wire_code(),
138 )
139 .await
140 }
141
142 pub async fn probe_query_with_param_types(
145 &mut self,
146 sql: &str,
147 param_types: &[u32],
148 params: &[Option<Vec<u8>>],
149 ) -> PgResult<()> {
150 let conn = self.conn_mut()?;
151 conn.probe_query_with_param_types(sql, param_types, params)
152 .await
153 }
154
155 pub async fn copy_export(&mut self, cmd: &qail_core::ast::Qail) -> PgResult<Vec<Vec<String>>> {
157 self.conn_mut()?.copy_export(cmd).await
158 }
159
160 pub async fn copy_export_stream_raw<F, Fut>(
162 &mut self,
163 cmd: &qail_core::ast::Qail,
164 on_chunk: F,
165 ) -> PgResult<()>
166 where
167 F: FnMut(Vec<u8>) -> Fut,
168 Fut: std::future::Future<Output = PgResult<()>>,
169 {
170 self.conn_mut()?.copy_export_stream_raw(cmd, on_chunk).await
171 }
172
173 pub async fn copy_export_stream_rows<F>(
175 &mut self,
176 cmd: &qail_core::ast::Qail,
177 on_row: F,
178 ) -> PgResult<()>
179 where
180 F: FnMut(Vec<String>) -> PgResult<()>,
181 {
182 self.conn_mut()?.copy_export_stream_rows(cmd, on_row).await
183 }
184
185 pub async fn copy_export_table(
187 &mut self,
188 table: &str,
189 columns: &[String],
190 ) -> PgResult<Vec<u8>> {
191 let sql = copy_export_table_sql(table, columns)?;
192 self.conn_mut()?.copy_out_raw(&sql).await
193 }
194
195 pub async fn copy_export_table_stream<F, Fut>(
197 &mut self,
198 table: &str,
199 columns: &[String],
200 on_chunk: F,
201 ) -> PgResult<()>
202 where
203 F: FnMut(Vec<u8>) -> Fut,
204 Fut: std::future::Future<Output = PgResult<()>>,
205 {
206 let sql = copy_export_table_sql(table, columns)?;
207 self.conn_mut()?.copy_out_raw_stream(&sql, on_chunk).await
208 }
209
210 pub async fn fetch_all_uncached_with_format(
212 &mut self,
213 cmd: &qail_core::ast::Qail,
214 result_format: ResultFormat,
215 ) -> PgResult<Vec<crate::driver::PgRow>> {
216 use crate::driver::ColumnInfo;
217 use crate::protocol::AstEncoder;
218
219 let conn = self.conn_mut()?;
220
221 AstEncoder::encode_cmd_reuse_into_with_result_format(
222 cmd,
223 &mut conn.sql_buf,
224 &mut conn.params_buf,
225 &mut conn.write_buf,
226 result_format.as_wire_code(),
227 )
228 .map_err(|e| PgError::Encode(e.to_string()))?;
229
230 conn.flush_write_buf().await?;
231
232 let mut rows: Vec<crate::driver::PgRow> = Vec::new();
233 let mut column_info: Option<Arc<ColumnInfo>> = None;
234 let mut error: Option<PgError> = None;
235 let mut flow =
236 ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_describe_portal_execute());
237
238 loop {
239 let msg = conn.recv().await?;
240 if let Err(err) = flow.validate(&msg, "pool fetch_all execute", error.is_some()) {
241 return return_with_desync(conn, err);
242 }
243 match msg {
244 crate::protocol::BackendMessage::ParseComplete
245 | crate::protocol::BackendMessage::BindComplete => {}
246 crate::protocol::BackendMessage::RowDescription(fields) => {
247 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
248 }
249 crate::protocol::BackendMessage::DataRow(data) => {
250 if error.is_none() {
251 rows.push(crate::driver::PgRow {
252 columns: data,
253 column_info: column_info.clone(),
254 });
255 }
256 }
257 crate::protocol::BackendMessage::NoData => {}
258 crate::protocol::BackendMessage::CommandComplete(_) => {}
259 crate::protocol::BackendMessage::ReadyForQuery(_) => {
260 if let Some(err) = error {
261 return Err(err);
262 }
263 return Ok(rows);
264 }
265 crate::protocol::BackendMessage::ErrorResponse(err) => {
266 if error.is_none() {
267 error = Some(PgError::QueryServer(err.into()));
268 }
269 }
270 msg if is_ignorable_session_message(&msg) => {}
271 other => {
272 return return_with_desync(
273 conn,
274 unexpected_backend_message("pool fetch_all execute", &other),
275 );
276 }
277 }
278 }
279 }
280
281 pub async fn fetch_all_fast(
285 &mut self,
286 cmd: &qail_core::ast::Qail,
287 ) -> PgResult<Vec<crate::driver::PgRow>> {
288 self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
289 .await
290 }
291
292 pub async fn fetch_all_fast_with_format(
294 &mut self,
295 cmd: &qail_core::ast::Qail,
296 result_format: ResultFormat,
297 ) -> PgResult<Vec<crate::driver::PgRow>> {
298 use crate::protocol::AstEncoder;
299
300 let conn = self.conn_mut()?;
301
302 AstEncoder::encode_cmd_reuse_into_with_result_format(
303 cmd,
304 &mut conn.sql_buf,
305 &mut conn.params_buf,
306 &mut conn.write_buf,
307 result_format.as_wire_code(),
308 )
309 .map_err(|e| PgError::Encode(e.to_string()))?;
310
311 conn.flush_write_buf().await?;
312
313 let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
314 let mut error: Option<PgError> = None;
315 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(true));
316
317 loop {
318 let res = conn.recv_with_data_fast().await;
319 match res {
320 Ok((msg_type, data)) => {
321 if let Err(err) = flow.validate_msg_type(
322 msg_type,
323 "pool fetch_all_fast execute",
324 error.is_some(),
325 ) {
326 return return_with_desync(conn, err);
327 }
328 match msg_type {
329 b'D' => {
330 if error.is_none()
331 && let Some(columns) = data
332 {
333 rows.push(crate::driver::PgRow {
334 columns,
335 column_info: None,
336 });
337 }
338 }
339 b'Z' => {
340 if let Some(err) = error {
341 return Err(err);
342 }
343 return Ok(rows);
344 }
345 _ => {}
346 }
347 }
348 Err(e) => {
349 if matches!(&e, PgError::QueryServer(_)) {
350 if error.is_none() {
351 error = Some(e);
352 }
353 continue;
354 }
355 return Err(e);
356 }
357 }
358 }
359 }
360
361 pub async fn fetch_all_cached(
366 &mut self,
367 cmd: &qail_core::ast::Qail,
368 ) -> PgResult<Vec<crate::driver::PgRow>> {
369 self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
370 .await
371 }
372
373 pub async fn fetch_all_cached_with_format(
375 &mut self,
376 cmd: &qail_core::ast::Qail,
377 result_format: ResultFormat,
378 ) -> PgResult<Vec<crate::driver::PgRow>> {
379 let mut retried = false;
380 loop {
381 match self
382 .fetch_all_cached_with_format_once(cmd, result_format)
383 .await
384 {
385 Ok(rows) => return Ok(rows),
386 Err(err)
387 if !retried
388 && (err.is_prepared_statement_retryable()
389 || err.is_prepared_statement_already_exists()) =>
390 {
391 retried = true;
392 if err.is_prepared_statement_retryable()
393 && let Some(conn) = self.conn.as_mut()
394 {
395 conn.clear_prepared_statement_state();
396 }
397 }
398 Err(err) => return Err(err),
399 }
400 }
401 }
402
403 pub async fn fetch_typed<T: crate::driver::row::QailRow>(
405 &mut self,
406 cmd: &qail_core::ast::Qail,
407 ) -> PgResult<Vec<T>> {
408 self.fetch_typed_with_format(cmd, ResultFormat::Text).await
409 }
410
411 pub async fn fetch_typed_with_format<T: crate::driver::row::QailRow>(
416 &mut self,
417 cmd: &qail_core::ast::Qail,
418 result_format: ResultFormat,
419 ) -> PgResult<Vec<T>> {
420 let rows = self
421 .fetch_all_cached_with_format(cmd, result_format)
422 .await?;
423 Ok(rows.iter().map(T::from_row).collect())
424 }
425
426 pub async fn fetch_one_typed<T: crate::driver::row::QailRow>(
428 &mut self,
429 cmd: &qail_core::ast::Qail,
430 ) -> PgResult<Option<T>> {
431 self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
432 .await
433 }
434
435 pub async fn fetch_one_typed_with_format<T: crate::driver::row::QailRow>(
437 &mut self,
438 cmd: &qail_core::ast::Qail,
439 result_format: ResultFormat,
440 ) -> PgResult<Option<T>> {
441 let rows = self
442 .fetch_all_cached_with_format(cmd, result_format)
443 .await?;
444 Ok(rows.first().map(T::from_row))
445 }
446
447 async fn fetch_all_cached_with_format_once(
448 &mut self,
449 cmd: &qail_core::ast::Qail,
450 result_format: ResultFormat,
451 ) -> PgResult<Vec<crate::driver::PgRow>> {
452 use crate::driver::ColumnInfo;
453 use std::collections::hash_map::DefaultHasher;
454 use std::hash::{Hash, Hasher};
455
456 let conn = self.conn.as_mut().ok_or_else(|| {
457 PgError::Connection("Connection already released back to pool".into())
458 })?;
459
460 conn.sql_buf.clear();
461 conn.params_buf.clear();
462
463 match cmd.action {
465 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
466 crate::protocol::ast_encoder::dml::encode_select(
467 cmd,
468 &mut conn.sql_buf,
469 &mut conn.params_buf,
470 )?;
471 }
472 qail_core::ast::Action::Add => {
473 crate::protocol::ast_encoder::dml::encode_insert(
474 cmd,
475 &mut conn.sql_buf,
476 &mut conn.params_buf,
477 )?;
478 }
479 qail_core::ast::Action::Set => {
480 crate::protocol::ast_encoder::dml::encode_update(
481 cmd,
482 &mut conn.sql_buf,
483 &mut conn.params_buf,
484 )?;
485 }
486 qail_core::ast::Action::Del => {
487 crate::protocol::ast_encoder::dml::encode_delete(
488 cmd,
489 &mut conn.sql_buf,
490 &mut conn.params_buf,
491 )?;
492 }
493 _ => {
494 return self
496 .fetch_all_uncached_with_format(cmd, result_format)
497 .await;
498 }
499 }
500
501 let mut hasher = DefaultHasher::new();
502 conn.sql_buf.hash(&mut hasher);
503 let sql_hash = hasher.finish();
504
505 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
506
507 conn.write_buf.clear();
508
509 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
510 name
511 } else {
512 let name = format!("qail_{:x}", sql_hash);
513
514 conn.evict_prepared_if_full();
515
516 let sql_str = encoded_sql_str(&conn.sql_buf)?;
517
518 use crate::protocol::PgEncoder;
519 let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
520 let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
521 conn.write_buf.extend_from_slice(&parse_msg);
522 conn.write_buf.extend_from_slice(&describe_msg);
523
524 conn.stmt_cache.put(sql_hash, name.clone());
525 conn.prepared_statements
526 .insert(name.clone(), sql_str.to_string());
527
528 if let Ok(mut hot) = self.pool.hot_statements.write()
530 && hot.len() < MAX_HOT_STATEMENTS
531 {
532 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
533 }
534
535 name
536 };
537
538 use crate::protocol::PgEncoder;
539 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
540 &mut conn.write_buf,
541 &stmt_name,
542 &conn.params_buf,
543 result_format.as_wire_code(),
544 ) {
545 if is_cache_miss {
546 conn.stmt_cache.remove(&sql_hash);
547 conn.prepared_statements.remove(&stmt_name);
548 conn.column_info_cache.remove(&sql_hash);
549 }
550 return Err(PgError::Encode(e.to_string()));
551 }
552 PgEncoder::encode_execute_to(&mut conn.write_buf);
553 PgEncoder::encode_sync_to(&mut conn.write_buf);
554
555 if let Err(err) = conn.flush_write_buf().await {
556 if is_cache_miss {
557 conn.stmt_cache.remove(&sql_hash);
558 conn.prepared_statements.remove(&stmt_name);
559 conn.column_info_cache.remove(&sql_hash);
560 }
561 return Err(err);
562 }
563
564 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
565
566 let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
567 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
568 let mut error: Option<PgError> = None;
569 let mut flow = ExtendedFlowTracker::new(
570 ExtendedFlowConfig::parse_describe_statement_bind_execute(is_cache_miss),
571 );
572
573 loop {
574 let msg = match conn.recv().await {
575 Ok(msg) => msg,
576 Err(err) => {
577 if is_cache_miss && !flow.saw_parse_complete() {
578 conn.stmt_cache.remove(&sql_hash);
579 conn.prepared_statements.remove(&stmt_name);
580 conn.column_info_cache.remove(&sql_hash);
581 }
582 return Err(err);
583 }
584 };
585 if let Err(err) = flow.validate(&msg, "pool fetch_all_cached execute", error.is_some())
586 {
587 if is_cache_miss && !flow.saw_parse_complete() {
588 conn.stmt_cache.remove(&sql_hash);
589 conn.prepared_statements.remove(&stmt_name);
590 conn.column_info_cache.remove(&sql_hash);
591 }
592 return return_with_desync(conn, err);
593 }
594 match msg {
595 crate::protocol::BackendMessage::ParseComplete => {}
596 crate::protocol::BackendMessage::BindComplete => {}
597 crate::protocol::BackendMessage::ParameterDescription(_) => {}
598 crate::protocol::BackendMessage::RowDescription(fields) => {
599 let info = Arc::new(ColumnInfo::from_fields(&fields));
600 if is_cache_miss {
601 conn.column_info_cache.insert(sql_hash, Arc::clone(&info));
602 }
603 column_info = Some(info);
604 }
605 crate::protocol::BackendMessage::DataRow(data) => {
606 if error.is_none() {
607 rows.push(crate::driver::PgRow {
608 columns: data,
609 column_info: column_info.clone(),
610 });
611 }
612 }
613 crate::protocol::BackendMessage::CommandComplete(_) => {}
614 crate::protocol::BackendMessage::ReadyForQuery(_) => {
615 if let Some(err) = error {
616 if is_cache_miss
617 && !flow.saw_parse_complete()
618 && !err.is_prepared_statement_already_exists()
619 {
620 conn.stmt_cache.remove(&sql_hash);
621 conn.prepared_statements.remove(&stmt_name);
622 conn.column_info_cache.remove(&sql_hash);
623 }
624 return Err(err);
625 }
626 if is_cache_miss && !flow.saw_parse_complete() {
627 conn.stmt_cache.remove(&sql_hash);
628 conn.prepared_statements.remove(&stmt_name);
629 conn.column_info_cache.remove(&sql_hash);
630 return return_with_desync(
631 conn,
632 PgError::Protocol(
633 "Cache miss query reached ReadyForQuery without ParseComplete"
634 .to_string(),
635 ),
636 );
637 }
638 return Ok(rows);
639 }
640 crate::protocol::BackendMessage::ErrorResponse(err) => {
641 if error.is_none() {
642 error = Some(PgError::QueryServer(err.into()));
643 }
644 }
645 msg if is_ignorable_session_message(&msg) => {}
646 other => {
647 if is_cache_miss && !flow.saw_parse_complete() {
648 conn.stmt_cache.remove(&sql_hash);
649 conn.prepared_statements.remove(&stmt_name);
650 conn.column_info_cache.remove(&sql_hash);
651 }
652 return return_with_desync(
653 conn,
654 unexpected_backend_message("pool fetch_all_cached execute", &other),
655 );
656 }
657 }
658 }
659 }
660
661 pub async fn fetch_all_with_rls(
680 &mut self,
681 cmd: &qail_core::ast::Qail,
682 rls_sql: &str,
683 ) -> PgResult<Vec<crate::driver::PgRow>> {
684 self.fetch_all_with_rls_with_format(cmd, rls_sql, ResultFormat::Text)
685 .await
686 }
687
688 pub async fn fetch_all_with_rls_with_format(
690 &mut self,
691 cmd: &qail_core::ast::Qail,
692 rls_sql: &str,
693 result_format: ResultFormat,
694 ) -> PgResult<Vec<crate::driver::PgRow>> {
695 let mut retried = false;
696 loop {
697 match self
698 .fetch_all_with_rls_with_format_once(cmd, rls_sql, result_format)
699 .await
700 {
701 Ok(rows) => return Ok(rows),
702 Err(err)
703 if !retried
704 && (err.is_prepared_statement_retryable()
705 || err.is_prepared_statement_already_exists()) =>
706 {
707 retried = true;
708 if let Some(conn) = self.conn.as_mut() {
709 if err.is_prepared_statement_retryable() {
710 conn.clear_prepared_statement_state();
711 }
712 let _ = conn.execute_simple("ROLLBACK").await;
715 }
716 self.rls_dirty = false;
717 }
718 Err(err) => return Err(err),
719 }
720 }
721 }
722
723 async fn fetch_all_with_rls_with_format_once(
724 &mut self,
725 cmd: &qail_core::ast::Qail,
726 rls_sql: &str,
727 result_format: ResultFormat,
728 ) -> PgResult<Vec<crate::driver::PgRow>> {
729 use crate::driver::ColumnInfo;
730 use std::collections::hash_map::DefaultHasher;
731 use std::hash::{Hash, Hasher};
732
733 let conn = self.conn.as_mut().ok_or_else(|| {
734 PgError::Connection("Connection already released back to pool".into())
735 })?;
736
737 if !crate::protocol::AstEncoder::encode_cacheable_cmd_sql_to(
738 cmd,
739 &mut conn.sql_buf,
740 &mut conn.params_buf,
741 )? {
742 conn.execute_simple(rls_sql).await?;
744 self.rls_dirty = true;
745 return self
746 .fetch_all_uncached_with_format(cmd, result_format)
747 .await;
748 }
749
750 let mut hasher = DefaultHasher::new();
751 conn.sql_buf.hash(&mut hasher);
752 let sql_hash = hasher.finish();
753
754 let is_cache_miss = !conn.stmt_cache.contains(&sql_hash);
755
756 conn.write_buf.clear();
757
758 let rls_msg = crate::protocol::PgEncoder::try_encode_query_string(rls_sql)?;
763 conn.write_buf.extend_from_slice(&rls_msg);
764
765 let stmt_name = if let Some(name) = conn.stmt_cache.get(&sql_hash) {
767 name
768 } else {
769 let name = format!("qail_{:x}", sql_hash);
770
771 conn.evict_prepared_if_full();
772
773 let sql_str = encoded_sql_str(&conn.sql_buf)?;
774
775 use crate::protocol::PgEncoder;
776 let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
777 let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
778 conn.write_buf.extend_from_slice(&parse_msg);
779 conn.write_buf.extend_from_slice(&describe_msg);
780
781 conn.stmt_cache.put(sql_hash, name.clone());
782 conn.prepared_statements
783 .insert(name.clone(), sql_str.to_string());
784
785 if let Ok(mut hot) = self.pool.hot_statements.write()
786 && hot.len() < MAX_HOT_STATEMENTS
787 {
788 hot.insert(sql_hash, (name.clone(), sql_str.to_string()));
789 }
790
791 name
792 };
793
794 use crate::protocol::PgEncoder;
795 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
796 &mut conn.write_buf,
797 &stmt_name,
798 &conn.params_buf,
799 result_format.as_wire_code(),
800 ) {
801 rollback_cache_miss_statement_registration(conn, is_cache_miss, sql_hash, &stmt_name);
802 return Err(PgError::Encode(e.to_string()));
803 }
804 PgEncoder::encode_execute_to(&mut conn.write_buf);
805 PgEncoder::encode_sync_to(&mut conn.write_buf);
806
807 if let Err(err) = conn.flush_write_buf().await {
809 rollback_cache_miss_statement_registration(conn, is_cache_miss, sql_hash, &stmt_name);
810 return Err(err);
811 }
812
813 self.rls_dirty = true;
815
816 let mut rls_error: Option<PgError> = None;
820 loop {
821 let msg = match conn.recv().await {
822 Ok(msg) => msg,
823 Err(err) => {
824 rollback_cache_miss_statement_registration(
825 conn,
826 is_cache_miss,
827 sql_hash,
828 &stmt_name,
829 );
830 return Err(err);
831 }
832 };
833 match msg {
834 crate::protocol::BackendMessage::ReadyForQuery(_) => {
835 if let Some(err) = rls_error {
837 rollback_cache_miss_statement_registration(
838 conn,
839 is_cache_miss,
840 sql_hash,
841 &stmt_name,
842 );
843 if let Err(drain_err) =
844 drain_extended_responses_after_rls_setup_error(conn).await
845 {
846 tracing::warn!(
847 error = %drain_err,
848 "failed to drain pipelined extended responses after RLS setup error"
849 );
850 }
851 return Err(err);
852 }
853 break;
854 }
855 crate::protocol::BackendMessage::ErrorResponse(err) => {
856 if rls_error.is_none() {
857 rls_error = Some(PgError::QueryServer(err.into()));
858 }
859 }
860 crate::protocol::BackendMessage::CommandComplete(_)
862 | crate::protocol::BackendMessage::DataRow(_)
863 | crate::protocol::BackendMessage::RowDescription(_)
864 | crate::protocol::BackendMessage::ParseComplete
865 | crate::protocol::BackendMessage::BindComplete => {}
866 msg if is_ignorable_session_message(&msg) => {}
867 other => {
868 rollback_cache_miss_statement_registration(
869 conn,
870 is_cache_miss,
871 sql_hash,
872 &stmt_name,
873 );
874 return return_with_desync(
875 conn,
876 unexpected_backend_message("pool rls setup", &other),
877 );
878 }
879 }
880 }
881
882 let cached_column_info = conn.column_info_cache.get(&sql_hash).cloned();
884
885 let mut rows: Vec<crate::driver::PgRow> = Vec::with_capacity(32);
886 let mut column_info: Option<std::sync::Arc<ColumnInfo>> = cached_column_info;
887 let mut error: Option<PgError> = None;
888 let mut flow = ExtendedFlowTracker::new(
889 ExtendedFlowConfig::parse_describe_statement_bind_execute(is_cache_miss),
890 );
891
892 loop {
893 let msg = match conn.recv().await {
894 Ok(msg) => msg,
895 Err(err) => {
896 if is_cache_miss && !flow.saw_parse_complete() {
897 rollback_cache_miss_statement_registration(
898 conn,
899 is_cache_miss,
900 sql_hash,
901 &stmt_name,
902 );
903 }
904 return Err(err);
905 }
906 };
907 if let Err(err) =
908 flow.validate(&msg, "pool fetch_all_with_rls execute", error.is_some())
909 {
910 if is_cache_miss && !flow.saw_parse_complete() {
911 rollback_cache_miss_statement_registration(
912 conn,
913 is_cache_miss,
914 sql_hash,
915 &stmt_name,
916 );
917 }
918 return return_with_desync(conn, err);
919 }
920 match msg {
921 crate::protocol::BackendMessage::ParseComplete => {}
922 crate::protocol::BackendMessage::BindComplete => {}
923 crate::protocol::BackendMessage::ParameterDescription(_) => {}
924 crate::protocol::BackendMessage::RowDescription(fields) => {
925 let info = std::sync::Arc::new(ColumnInfo::from_fields(&fields));
926 if is_cache_miss {
927 conn.column_info_cache
928 .insert(sql_hash, std::sync::Arc::clone(&info));
929 }
930 column_info = Some(info);
931 }
932 crate::protocol::BackendMessage::DataRow(data) => {
933 if error.is_none() {
934 rows.push(crate::driver::PgRow {
935 columns: data,
936 column_info: column_info.clone(),
937 });
938 }
939 }
940 crate::protocol::BackendMessage::CommandComplete(_) => {}
941 crate::protocol::BackendMessage::ReadyForQuery(_) => {
942 if let Some(err) = error {
943 if is_cache_miss
944 && !flow.saw_parse_complete()
945 && !err.is_prepared_statement_already_exists()
946 {
947 rollback_cache_miss_statement_registration(
948 conn,
949 is_cache_miss,
950 sql_hash,
951 &stmt_name,
952 );
953 }
954 return Err(err);
955 }
956 if is_cache_miss && !flow.saw_parse_complete() {
957 rollback_cache_miss_statement_registration(
958 conn,
959 is_cache_miss,
960 sql_hash,
961 &stmt_name,
962 );
963 return return_with_desync(
964 conn,
965 PgError::Protocol(
966 "Cache miss query reached ReadyForQuery without ParseComplete"
967 .to_string(),
968 ),
969 );
970 }
971 return Ok(rows);
972 }
973 crate::protocol::BackendMessage::ErrorResponse(err) => {
974 if error.is_none() {
975 error = Some(PgError::QueryServer(err.into()));
976 }
977 }
978 msg if is_ignorable_session_message(&msg) => {}
979 other => {
980 if is_cache_miss && !flow.saw_parse_complete() {
981 rollback_cache_miss_statement_registration(
982 conn,
983 is_cache_miss,
984 sql_hash,
985 &stmt_name,
986 );
987 }
988 return return_with_desync(
989 conn,
990 unexpected_backend_message("pool fetch_all_with_rls execute", &other),
991 );
992 }
993 }
994 }
995 }
996}
997
#[cfg(test)]
mod tests {
    use super::{copy_export_table_sql, encoded_sql_str, return_with_desync};

    /// Build a `PgConnection` over one end of a Unix socket pair. Helpers that
    /// only touch connection state can then be exercised without a server.
    #[cfg(unix)]
    fn test_conn() -> crate::driver::PgConnection {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (unix_stream, _peer) = UnixStream::pair().expect("unix stream pair");
        crate::driver::PgConnection {
            stream: PgStream::Unix(unix_stream),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(NonZeroUsize::new(2).expect("non-zero")),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: crate::driver::PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: crate::driver::PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        }
    }

    #[test]
    fn pool_copy_export_table_sql_preserves_schema_qualified_table() {
        let sql = copy_export_table_sql(
            "tenant_a.users",
            &["id".to_string(), "display\"name".to_string()],
        )
        .unwrap();

        assert_eq!(
            sql,
            "COPY \"tenant_a\".\"users\" (\"id\", \"display\"\"name\") TO STDOUT"
        );
    }

    #[test]
    fn pool_copy_export_table_sql_rejects_nul_bytes() {
        assert!(copy_export_table_sql("tenant\0.users", &["id".to_string()]).is_err());
        assert!(copy_export_table_sql("users", &["id\0".to_string()]).is_err());
    }

    #[test]
    fn pool_encoded_sql_str_rejects_invalid_utf8() {
        let err = encoded_sql_str(&[0xff]).expect_err("invalid SQL UTF-8 must fail");
        assert!(err.to_string().contains("encoded SQL is not UTF-8"));
    }

    #[test]
    fn pool_encoded_sql_str_accepts_valid_utf8() {
        let sql = encoded_sql_str("SELECT 'é'".as_bytes()).expect("valid UTF-8 must be accepted");
        assert_eq!(sql, "SELECT 'é'");
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_return_with_desync_marks_protocol_error() {
        let mut conn = test_conn();

        let err = return_with_desync::<()>(
            &mut conn,
            crate::driver::PgError::Protocol("bad response ordering".to_string()),
        )
        .expect_err("protocol error must be returned");

        assert!(err.to_string().contains("bad response ordering"));
        assert!(conn.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn pool_return_with_desync_keeps_sync_for_non_transport_error() {
        let mut conn = test_conn();

        let err = return_with_desync::<()>(
            &mut conn,
            crate::driver::PgError::Encode("bad parameter".to_string()),
        )
        .expect_err("encode error must be returned");

        assert!(err.to_string().contains("bad parameter"));
        assert!(!conn.is_io_desynced());
    }
}