1use super::core::PgDriver;
5use super::prepared::PreparedAstQuery;
6use super::types::*;
7use qail_core::ast::Qail;
8use std::sync::Arc;
9use std::{
10 collections::hash_map::DefaultHasher,
11 hash::{Hash, Hasher},
12};
13
14#[inline]
15fn return_with_desync<T>(driver: &mut PgDriver, err: PgError) -> PgResult<T> {
16 if matches!(
17 err,
18 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
19 ) {
20 driver.connection.mark_io_desynced();
21 }
22 Err(err)
23}
24
25#[inline]
26fn encoded_sql_str(sql_buf: &[u8]) -> PgResult<&str> {
27 std::str::from_utf8(sql_buf)
28 .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))
29}
30
31async fn reprepare_prepared_ast_query(
32 conn: &mut super::PgConnection,
33 prepared: &PreparedAstQuery,
34) -> PgResult<()> {
35 conn.clear_prepared_statement_state();
36 let stmt = conn.prepare(&prepared.sql).await?;
37 conn.stmt_cache
38 .put(prepared.sql_hash, stmt.name().to_string());
39 conn.prepared_statements
40 .insert(stmt.name().to_string(), prepared.sql.clone());
41 Ok(())
42}
43
44impl PgDriver {
45 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
52 self.fetch_all_with_format(cmd, ResultFormat::Text).await
53 }
54
55 pub async fn fetch_all_with_format(
61 &mut self,
62 cmd: &Qail,
63 result_format: ResultFormat,
64 ) -> PgResult<Vec<PgRow>> {
65 self.fetch_all_cached_with_format(cmd, result_format).await
67 }
68
69 pub async fn prepare_ast_query(&mut self, cmd: &Qail) -> PgResult<PreparedAstQuery> {
75 use crate::protocol::AstEncoder;
76
77 let (sql, params) =
78 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
79 let stmt = self.connection.prepare(&sql).await?;
80
81 let mut hasher = DefaultHasher::new();
82 sql.hash(&mut hasher);
83 let sql_hash = hasher.finish();
84
85 self.connection
86 .stmt_cache
87 .put(sql_hash, stmt.name().to_string());
88 self.connection
89 .prepared_statements
90 .insert(stmt.name().to_string(), sql.clone());
91
92 Ok(PreparedAstQuery {
93 stmt,
94 params,
95 sql,
96 sql_hash,
97 })
98 }
99
100 pub async fn fetch_all_prepared_ast(
105 &mut self,
106 prepared: &PreparedAstQuery,
107 ) -> PgResult<Vec<PgRow>> {
108 self.fetch_all_prepared_ast_with_format(prepared, ResultFormat::Text)
109 .await
110 }
111
112 pub async fn fetch_all_prepared_ast_with_format(
114 &mut self,
115 prepared: &PreparedAstQuery,
116 result_format: ResultFormat,
117 ) -> PgResult<Vec<PgRow>> {
118 let mut retried = false;
119
120 loop {
121 self.connection.stmt_cache.touch_key(prepared.sql_hash);
122 self.connection.write_buf.clear();
123 if let Err(e) = crate::protocol::PgEncoder::encode_bind_to_with_result_format(
124 &mut self.connection.write_buf,
125 prepared.stmt.name(),
126 &prepared.params,
127 result_format.as_wire_code(),
128 ) {
129 return Err(PgError::Encode(e.to_string()));
130 }
131 crate::protocol::PgEncoder::encode_execute_to(&mut self.connection.write_buf);
132 crate::protocol::PgEncoder::encode_sync_to(&mut self.connection.write_buf);
133
134 if let Err(err) = self.connection.flush_write_buf().await {
135 if !retried && err.is_prepared_statement_retryable() {
136 retried = true;
137 reprepare_prepared_ast_query(&mut self.connection, prepared).await?;
138 continue;
139 }
140 return Err(err);
141 }
142
143 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
144 let mut error: Option<PgError> = None;
145 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
146 super::extended_flow::ExtendedFlowConfig::parse_bind_execute(false),
147 );
148
149 loop {
150 let msg = self.connection.recv().await?;
151 if let Err(err) = flow.validate(
152 &msg,
153 "driver fetch_all_prepared_ast execute",
154 error.is_some(),
155 ) {
156 return return_with_desync(self, err);
157 }
158 match msg {
159 crate::protocol::BackendMessage::BindComplete => {}
160 crate::protocol::BackendMessage::RowDescription(_) => {}
161 crate::protocol::BackendMessage::DataRow(data) => {
162 if error.is_none() {
163 rows.push(PgRow {
164 columns: data,
165 column_info: None,
166 });
167 }
168 }
169 crate::protocol::BackendMessage::CommandComplete(_) => {}
170 crate::protocol::BackendMessage::NoData => {}
171 crate::protocol::BackendMessage::ReadyForQuery(_) => {
172 if let Some(err) = error {
173 if !retried && err.is_prepared_statement_retryable() {
174 retried = true;
175 reprepare_prepared_ast_query(&mut self.connection, prepared)
176 .await?;
177 break;
178 }
179 return Err(err);
180 }
181 return Ok(rows);
182 }
183 crate::protocol::BackendMessage::ErrorResponse(err) => {
184 if error.is_none() {
185 error = Some(PgError::QueryServer(err.into()));
186 }
187 }
188 msg if is_ignorable_session_message(&msg) => {}
189 other => {
190 return return_with_desync(
191 self,
192 unexpected_backend_message(
193 "driver fetch_all_prepared_ast execute",
194 &other,
195 ),
196 );
197 }
198 }
199 }
200 }
201 }
202
203 pub async fn fetch_typed<T: super::row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
211 self.fetch_typed_with_format(cmd, ResultFormat::Text).await
212 }
213
214 pub async fn fetch_typed_with_format<T: super::row::QailRow>(
219 &mut self,
220 cmd: &Qail,
221 result_format: ResultFormat,
222 ) -> PgResult<Vec<T>> {
223 let rows = self.fetch_all_with_format(cmd, result_format).await?;
224 Ok(rows.iter().map(T::from_row).collect())
225 }
226
227 pub async fn fetch_one_typed<T: super::row::QailRow>(
230 &mut self,
231 cmd: &Qail,
232 ) -> PgResult<Option<T>> {
233 self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
234 .await
235 }
236
237 pub async fn fetch_one_typed_with_format<T: super::row::QailRow>(
239 &mut self,
240 cmd: &Qail,
241 result_format: ResultFormat,
242 ) -> PgResult<Option<T>> {
243 let rows = self.fetch_all_with_format(cmd, result_format).await?;
244 Ok(rows.first().map(T::from_row))
245 }
246
247 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
253 self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
254 .await
255 }
256
257 pub async fn fetch_all_uncached_with_format(
259 &mut self,
260 cmd: &Qail,
261 result_format: ResultFormat,
262 ) -> PgResult<Vec<PgRow>> {
263 use crate::protocol::AstEncoder;
264
265 AstEncoder::encode_cmd_reuse_into_with_result_format(
266 cmd,
267 &mut self.connection.sql_buf,
268 &mut self.connection.params_buf,
269 &mut self.connection.write_buf,
270 result_format.as_wire_code(),
271 )
272 .map_err(|e| PgError::Encode(e.to_string()))?;
273
274 self.connection.flush_write_buf().await?;
275
276 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
277 let mut column_info: Option<Arc<ColumnInfo>> = None;
278
279 let mut error: Option<PgError> = None;
280 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
281 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
282 );
283
284 loop {
285 let msg = self.connection.recv().await?;
286 if let Err(err) = flow.validate(&msg, "driver fetch_all execute", error.is_some()) {
287 return return_with_desync(self, err);
288 }
289 match msg {
290 crate::protocol::BackendMessage::ParseComplete
291 | crate::protocol::BackendMessage::BindComplete => {}
292 crate::protocol::BackendMessage::RowDescription(fields) => {
293 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
294 }
295 crate::protocol::BackendMessage::DataRow(data) => {
296 if error.is_none() {
297 rows.push(PgRow {
298 columns: data,
299 column_info: column_info.clone(),
300 });
301 }
302 }
303 crate::protocol::BackendMessage::NoData => {}
304 crate::protocol::BackendMessage::CommandComplete(_) => {}
305 crate::protocol::BackendMessage::ReadyForQuery(_) => {
306 if let Some(err) = error {
307 return Err(err);
308 }
309 return Ok(rows);
310 }
311 crate::protocol::BackendMessage::ErrorResponse(err) => {
312 if error.is_none() {
313 error = Some(PgError::QueryServer(err.into()));
314 }
315 }
316 msg if is_ignorable_session_message(&msg) => {}
317 other => {
318 return return_with_desync(
319 self,
320 unexpected_backend_message("driver fetch_all execute", &other),
321 );
322 }
323 }
324 }
325 }
326
327 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
331 self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
332 .await
333 }
334
335 pub async fn fetch_all_fast_with_format(
337 &mut self,
338 cmd: &Qail,
339 result_format: ResultFormat,
340 ) -> PgResult<Vec<PgRow>> {
341 use crate::protocol::AstEncoder;
342
343 AstEncoder::encode_cmd_reuse_into_with_result_format(
344 cmd,
345 &mut self.connection.sql_buf,
346 &mut self.connection.params_buf,
347 &mut self.connection.write_buf,
348 result_format.as_wire_code(),
349 )
350 .map_err(|e| PgError::Encode(e.to_string()))?;
351
352 self.connection.flush_write_buf().await?;
353
354 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
356 let mut error: Option<PgError> = None;
357 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
358 super::extended_flow::ExtendedFlowConfig::parse_bind_execute(true),
359 );
360
361 loop {
362 let res = self.connection.recv_with_data_fast().await;
363 match res {
364 Ok((msg_type, data)) => {
365 if let Err(err) = flow.validate_msg_type(
366 msg_type,
367 "driver fetch_all_fast execute",
368 error.is_some(),
369 ) {
370 return return_with_desync(self, err);
371 }
372 match msg_type {
373 b'D' => {
374 if error.is_none()
375 && let Some(columns) = data
376 {
377 rows.push(PgRow {
378 columns,
379 column_info: None,
380 });
381 }
382 }
383 b'Z' => {
384 if let Some(err) = error {
385 return Err(err);
386 }
387 return Ok(rows);
388 }
389 _ => {}
390 }
391 }
392 Err(e) => {
393 if matches!(&e, PgError::QueryServer(_)) {
395 if error.is_none() {
396 error = Some(e);
397 }
398 continue;
399 }
400 return Err(e);
401 }
402 }
403 }
404 }
405
406 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
408 let rows = self.fetch_all(cmd).await?;
409 rows.into_iter().next().ok_or(PgError::NoRows)
410 }
411
412 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
421 self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
422 .await
423 }
424
425 pub async fn fetch_all_cached_with_format(
427 &mut self,
428 cmd: &Qail,
429 result_format: ResultFormat,
430 ) -> PgResult<Vec<PgRow>> {
431 let mut retried = false;
432 loop {
433 match self
434 .fetch_all_cached_with_format_once(cmd, result_format)
435 .await
436 {
437 Ok(rows) => return Ok(rows),
438 Err(err)
439 if !retried
440 && (err.is_prepared_statement_retryable()
441 || err.is_prepared_statement_already_exists()) =>
442 {
443 retried = true;
444 if err.is_prepared_statement_retryable() {
445 self.connection.clear_prepared_statement_state();
446 }
447 }
448 Err(err) => return Err(err),
449 }
450 }
451 }
452
453 async fn fetch_all_cached_with_format_once(
454 &mut self,
455 cmd: &Qail,
456 result_format: ResultFormat,
457 ) -> PgResult<Vec<PgRow>> {
458 use crate::protocol::AstEncoder;
459 use std::collections::hash_map::DefaultHasher;
460 use std::hash::{Hash, Hasher};
461
462 if !AstEncoder::encode_cacheable_cmd_sql_to(
463 cmd,
464 &mut self.connection.sql_buf,
465 &mut self.connection.params_buf,
466 )? {
467 let (sql, params) =
469 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
470 let raw_rows = self
471 .connection
472 .query_cached_with_result_format(&sql, ¶ms, result_format.as_wire_code())
473 .await?;
474 return Ok(raw_rows
475 .into_iter()
476 .map(|data| PgRow {
477 columns: data,
478 column_info: None,
479 })
480 .collect());
481 }
482
483 let mut hasher = DefaultHasher::new();
484 self.connection.sql_buf.hash(&mut hasher);
485 let sql_hash = hasher.finish();
486
487 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
488
489 self.connection.write_buf.clear();
491
492 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
493 name
494 } else {
495 let name = format!("qail_{:x}", sql_hash);
496
497 self.connection.evict_prepared_if_full();
499
500 let sql_str = encoded_sql_str(&self.connection.sql_buf)?;
501
502 use crate::protocol::PgEncoder;
504 let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
505 let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
506 self.connection.write_buf.extend_from_slice(&parse_msg);
507 self.connection.write_buf.extend_from_slice(&describe_msg);
508
509 self.connection.stmt_cache.put(sql_hash, name.clone());
510 self.connection
511 .prepared_statements
512 .insert(name.clone(), sql_str.to_string());
513
514 name
515 };
516
517 use crate::protocol::PgEncoder;
519 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
520 &mut self.connection.write_buf,
521 &stmt_name,
522 &self.connection.params_buf,
523 result_format.as_wire_code(),
524 ) {
525 if is_cache_miss {
526 self.connection.stmt_cache.remove(&sql_hash);
527 self.connection.prepared_statements.remove(&stmt_name);
528 self.connection.column_info_cache.remove(&sql_hash);
529 }
530 return Err(PgError::Encode(e.to_string()));
531 }
532 PgEncoder::encode_execute_to(&mut self.connection.write_buf);
533 PgEncoder::encode_sync_to(&mut self.connection.write_buf);
534
535 if let Err(err) = self.connection.flush_write_buf().await {
537 if is_cache_miss {
538 self.connection.stmt_cache.remove(&sql_hash);
539 self.connection.prepared_statements.remove(&stmt_name);
540 self.connection.column_info_cache.remove(&sql_hash);
541 }
542 return Err(err);
543 }
544
545 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
547
548 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
549 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
550 let mut error: Option<PgError> = None;
551 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
552 super::extended_flow::ExtendedFlowConfig::parse_describe_statement_bind_execute(
553 is_cache_miss,
554 ),
555 );
556
557 loop {
558 let msg = match self.connection.recv().await {
559 Ok(msg) => msg,
560 Err(err) => {
561 if is_cache_miss && !flow.saw_parse_complete() {
562 self.connection.stmt_cache.remove(&sql_hash);
563 self.connection.prepared_statements.remove(&stmt_name);
564 self.connection.column_info_cache.remove(&sql_hash);
565 }
566 return Err(err);
567 }
568 };
569 if let Err(err) =
570 flow.validate(&msg, "driver fetch_all_cached execute", error.is_some())
571 {
572 if is_cache_miss && !flow.saw_parse_complete() {
573 self.connection.stmt_cache.remove(&sql_hash);
574 self.connection.prepared_statements.remove(&stmt_name);
575 self.connection.column_info_cache.remove(&sql_hash);
576 }
577 return return_with_desync(self, err);
578 }
579 match msg {
580 crate::protocol::BackendMessage::ParseComplete => {}
581 crate::protocol::BackendMessage::BindComplete => {}
582 crate::protocol::BackendMessage::ParameterDescription(_) => {
583 }
585 crate::protocol::BackendMessage::RowDescription(fields) => {
586 let info = Arc::new(ColumnInfo::from_fields(&fields));
588 if is_cache_miss {
589 self.connection
590 .column_info_cache
591 .insert(sql_hash, Arc::clone(&info));
592 }
593 column_info = Some(info);
594 }
595 crate::protocol::BackendMessage::DataRow(data) => {
596 if error.is_none() {
597 rows.push(PgRow {
598 columns: data,
599 column_info: column_info.clone(),
600 });
601 }
602 }
603 crate::protocol::BackendMessage::CommandComplete(_) => {}
604 crate::protocol::BackendMessage::NoData => {
605 }
607 crate::protocol::BackendMessage::ReadyForQuery(_) => {
608 if let Some(err) = error {
609 if is_cache_miss
610 && !flow.saw_parse_complete()
611 && !err.is_prepared_statement_already_exists()
612 {
613 self.connection.stmt_cache.remove(&sql_hash);
614 self.connection.prepared_statements.remove(&stmt_name);
615 self.connection.column_info_cache.remove(&sql_hash);
616 }
617 return Err(err);
618 }
619 if is_cache_miss && !flow.saw_parse_complete() {
620 self.connection.stmt_cache.remove(&sql_hash);
621 self.connection.prepared_statements.remove(&stmt_name);
622 self.connection.column_info_cache.remove(&sql_hash);
623 return return_with_desync(
624 self,
625 PgError::Protocol(
626 "Cache miss query reached ReadyForQuery without ParseComplete"
627 .to_string(),
628 ),
629 );
630 }
631 return Ok(rows);
632 }
633 crate::protocol::BackendMessage::ErrorResponse(err) => {
634 if error.is_none() {
635 let query_err = PgError::QueryServer(err.into());
636 if query_err.is_prepared_statement_retryable() {
637 self.connection.clear_prepared_statement_state();
638 }
639 error = Some(query_err);
640 }
641 }
642 msg if is_ignorable_session_message(&msg) => {}
643 other => {
644 if is_cache_miss && !flow.saw_parse_complete() {
645 self.connection.stmt_cache.remove(&sql_hash);
646 self.connection.prepared_statements.remove(&stmt_name);
647 self.connection.column_info_cache.remove(&sql_hash);
648 }
649 return return_with_desync(
650 self,
651 unexpected_backend_message("driver fetch_all_cached execute", &other),
652 );
653 }
654 }
655 }
656 }
657
658 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
660 use crate::protocol::AstEncoder;
661
662 let wire_bytes = AstEncoder::encode_cmd_reuse(
663 cmd,
664 &mut self.connection.sql_buf,
665 &mut self.connection.params_buf,
666 )
667 .map_err(|e| PgError::Encode(e.to_string()))?;
668
669 self.connection.send_bytes(&wire_bytes).await?;
670
671 let mut affected = 0u64;
672 let mut error: Option<PgError> = None;
673 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
674 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
675 );
676
677 loop {
678 let msg = self.connection.recv().await?;
679 if let Err(err) = flow.validate(&msg, "driver execute mutation", error.is_some()) {
680 return return_with_desync(self, err);
681 }
682 match msg {
683 crate::protocol::BackendMessage::ParseComplete
684 | crate::protocol::BackendMessage::BindComplete => {}
685 crate::protocol::BackendMessage::RowDescription(_) => {}
686 crate::protocol::BackendMessage::DataRow(_) => {}
687 crate::protocol::BackendMessage::NoData => {}
688 crate::protocol::BackendMessage::CommandComplete(tag) => {
689 if error.is_none() {
690 match super::parse_affected_rows(&tag) {
691 Ok(parsed) => affected = parsed,
692 Err(err) => return return_with_desync(self, err),
693 }
694 }
695 }
696 crate::protocol::BackendMessage::ReadyForQuery(_) => {
697 if let Some(err) = error {
698 return Err(err);
699 }
700 return Ok(affected);
701 }
702 crate::protocol::BackendMessage::ErrorResponse(err) => {
703 if error.is_none() {
704 error = Some(PgError::QueryServer(err.into()));
705 }
706 }
707 msg if is_ignorable_session_message(&msg) => {}
708 other => {
709 return return_with_desync(
710 self,
711 unexpected_backend_message("driver execute mutation", &other),
712 );
713 }
714 }
715 }
716 }
717
718 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
722 self.query_ast_with_format(cmd, ResultFormat::Text).await
723 }
724
725 pub async fn query_ast_with_format(
727 &mut self,
728 cmd: &Qail,
729 result_format: ResultFormat,
730 ) -> PgResult<QueryResult> {
731 use crate::protocol::AstEncoder;
732
733 let wire_bytes = AstEncoder::encode_cmd_reuse_with_result_format(
734 cmd,
735 &mut self.connection.sql_buf,
736 &mut self.connection.params_buf,
737 result_format.as_wire_code(),
738 )
739 .map_err(|e| PgError::Encode(e.to_string()))?;
740
741 self.connection.send_bytes(&wire_bytes).await?;
742
743 let mut columns: Vec<String> = Vec::new();
744 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
745 let mut error: Option<PgError> = None;
746 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
747 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
748 );
749
750 loop {
751 let msg = self.connection.recv().await?;
752 if let Err(err) = flow.validate(&msg, "driver query_ast", error.is_some()) {
753 return return_with_desync(self, err);
754 }
755 match msg {
756 crate::protocol::BackendMessage::ParseComplete
757 | crate::protocol::BackendMessage::BindComplete => {}
758 crate::protocol::BackendMessage::RowDescription(fields) => {
759 columns = fields.into_iter().map(|f| f.name).collect();
760 }
761 crate::protocol::BackendMessage::DataRow(data) => {
762 if error.is_none() {
763 let row: Vec<Option<String>> = data
764 .into_iter()
765 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
766 .collect();
767 rows.push(row);
768 }
769 }
770 crate::protocol::BackendMessage::CommandComplete(_) => {}
771 crate::protocol::BackendMessage::NoData => {}
772 crate::protocol::BackendMessage::ReadyForQuery(_) => {
773 if let Some(err) = error {
774 return Err(err);
775 }
776 return Ok(QueryResult { columns, rows });
777 }
778 crate::protocol::BackendMessage::ErrorResponse(err) => {
779 if error.is_none() {
780 error = Some(PgError::QueryServer(err.into()));
781 }
782 }
783 msg if is_ignorable_session_message(&msg) => {}
784 other => {
785 return return_with_desync(
786 self,
787 unexpected_backend_message("driver query_ast", &other),
788 );
789 }
790 }
791 }
792 }
793}
794
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn driver_encoded_sql_str_rejects_invalid_utf8() {
        // 0xff never appears in well-formed UTF-8.
        let outcome = encoded_sql_str(&[0xff]);
        let err = outcome.expect_err("invalid SQL UTF-8 must fail");
        assert!(err.to_string().contains("encoded SQL is not UTF-8"));
    }

    /// Builds a driver over one end of a Unix socket pair, with empty caches.
    ///
    /// Backend replies are injected straight into `connection.buffer` by the helpers
    /// below. Callers keep the returned peer bound so the socket stays open.
    #[cfg(unix)]
    fn test_driver_with_peer() -> (PgDriver, tokio::net::UnixStream) {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (driver_side, peer_side) = UnixStream::pair().expect("unix stream pair");
        let cache_capacity = NonZeroUsize::new(2).expect("non-zero");
        let connection = super::super::PgConnection {
            stream: PgStream::Unix(driver_side),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(cache_capacity),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: super::super::PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: super::super::PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        (PgDriver::new(connection), peer_side)
    }

    /// Appends one backend frame to the driver's receive buffer.
    ///
    /// Frame layout: a 1-byte tag, then a big-endian u32 length (which counts itself),
    /// then the payload.
    #[cfg(unix)]
    fn push_backend_frame(driver: &mut PgDriver, msg_type: u8, payload: &[u8]) {
        let declared_len = (payload.len() + 4) as u32;
        let inbox = &mut driver.connection.buffer;
        inbox.extend_from_slice(&[msg_type]);
        inbox.extend_from_slice(&declared_len.to_be_bytes());
        inbox.extend_from_slice(payload);
    }

    /// Builds an ErrorResponse body with severity, SQLSTATE code, and message fields.
    #[cfg(unix)]
    fn error_response_payload(code: &str, message: &str) -> Vec<u8> {
        let fields: [(u8, &[u8]); 3] = [
            (b'S', b"ERROR"),
            (b'C', code.as_bytes()),
            (b'M', message.as_bytes()),
        ];
        let mut payload = Vec::new();
        for (field_type, value) in fields {
            payload.push(field_type);
            payload.extend_from_slice(value);
            payload.push(0);
        }
        // The field list is terminated by an extra NUL byte.
        payload.push(0);
        payload
    }

    /// Appends a CommandComplete frame carrying `tag` as a NUL-terminated string.
    #[cfg(unix)]
    fn push_command_complete(driver: &mut PgDriver, tag: &str) {
        let payload: Vec<u8> = tag.bytes().chain(std::iter::once(0)).collect();
        push_backend_frame(driver, b'C', &payload);
    }

    /// Builds a parameterless prepared handle for `sql`, hashing the SQL the same way
    /// `prepare_ast_query` does.
    #[cfg(unix)]
    fn prepared_ast_for_sql(sql: &str) -> PreparedAstQuery {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let sql_hash = {
            let mut state = DefaultHasher::new();
            sql.hash(&mut state);
            state.finish()
        };

        PreparedAstQuery {
            stmt: crate::driver::PreparedStatement::from_sql(sql),
            params: Vec::new(),
            sql: sql.to_owned(),
            sql_hash,
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn fetch_fast_protocol_error_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        // A zero-column DataRow that arrives before any BindComplete.
        push_backend_frame(&mut driver, b'D', &0i16.to_be_bytes());

        let Err(err) = driver.fetch_all_fast(&Qail::get("users")).await else {
            panic!("out-of-order DataRow must fail");
        };

        assert!(err.to_string().contains("DataRow before BindComplete"));
        assert!(driver.connection.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_bad_command_tag_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        // ParseComplete, BindComplete, NoData.
        for tag in [b'1', b'2', b'n'] {
            push_backend_frame(&mut driver, tag, &[]);
        }
        // "UPDATE" without a row count is malformed.
        push_command_complete(&mut driver, "UPDATE");
        push_backend_frame(&mut driver, b'Z', b"I");

        let message = driver
            .execute(&Qail::get("users"))
            .await
            .expect_err("malformed CommandComplete tag must fail")
            .to_string();

        assert!(
            message.contains("missing affected row count")
                || message.contains("invalid affected row count")
        );
        assert!(driver.connection.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn prepared_ast_retry_reparses_after_missing_server_statement() {
        let (mut driver, _peer) = test_driver_with_peer();
        let prepared = prepared_ast_for_sql("SELECT 1");
        let name = prepared.stmt.name().to_string();

        // Pretend the statement was already prepared on this connection.
        driver
            .connection
            .stmt_cache
            .put(prepared.sql_hash, name.clone());
        driver
            .connection
            .prepared_statements
            .insert(name.clone(), prepared.sql.clone());

        let missing = error_response_payload(
            "26000",
            &format!("prepared statement \"{}\" does not exist", name),
        );

        // First attempt: the server no longer knows the statement.
        push_backend_frame(&mut driver, b'E', &missing);
        push_backend_frame(&mut driver, b'Z', b"I");
        // Re-prepare round trip: ParseComplete, ReadyForQuery.
        push_backend_frame(&mut driver, b'1', &[]);
        push_backend_frame(&mut driver, b'Z', b"I");
        // Retried execution: BindComplete, empty result, ReadyForQuery.
        push_backend_frame(&mut driver, b'2', &[]);
        push_command_complete(&mut driver, "SELECT 0");
        push_backend_frame(&mut driver, b'Z', b"I");

        let rows = driver
            .fetch_all_prepared_ast(&prepared)
            .await
            .expect("stale prepared AST handle should reparse and retry once");

        assert!(rows.is_empty());
        assert!(driver.connection.prepared_statements.contains_key(&name));
        assert!(driver.connection.stmt_cache.contains(&prepared.sql_hash));
        assert!(!driver.connection.is_io_desynced());
    }
}
972}