1use super::core::PgDriver;
5use super::prepared::PreparedAstQuery;
6use super::types::*;
7use qail_core::ast::Qail;
8use std::sync::Arc;
9use std::{
10 collections::hash_map::DefaultHasher,
11 hash::{Hash, Hasher},
12};
13
14#[inline]
15fn return_with_desync<T>(driver: &mut PgDriver, err: PgError) -> PgResult<T> {
16 if matches!(
17 err,
18 PgError::Protocol(_) | PgError::Connection(_) | PgError::Timeout(_)
19 ) {
20 driver.connection.mark_io_desynced();
21 }
22 Err(err)
23}
24
25#[inline]
26fn encoded_sql_str(sql_buf: &[u8]) -> PgResult<&str> {
27 std::str::from_utf8(sql_buf)
28 .map_err(|e| PgError::Encode(format!("encoded SQL is not UTF-8: {}", e)))
29}
30
31impl PgDriver {
32 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
39 self.fetch_all_with_format(cmd, ResultFormat::Text).await
40 }
41
42 pub async fn fetch_all_with_format(
48 &mut self,
49 cmd: &Qail,
50 result_format: ResultFormat,
51 ) -> PgResult<Vec<PgRow>> {
52 self.fetch_all_cached_with_format(cmd, result_format).await
54 }
55
56 pub async fn prepare_ast_query(&mut self, cmd: &Qail) -> PgResult<PreparedAstQuery> {
62 use crate::protocol::AstEncoder;
63
64 let (sql, params) =
65 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
66 let stmt = self.connection.prepare(&sql).await?;
67
68 let mut hasher = DefaultHasher::new();
69 sql.hash(&mut hasher);
70 let sql_hash = hasher.finish();
71
72 self.connection
73 .stmt_cache
74 .put(sql_hash, stmt.name().to_string());
75 self.connection
76 .prepared_statements
77 .insert(stmt.name().to_string(), sql.clone());
78
79 Ok(PreparedAstQuery {
80 stmt,
81 params,
82 sql,
83 sql_hash,
84 })
85 }
86
87 pub async fn fetch_all_prepared_ast(
92 &mut self,
93 prepared: &PreparedAstQuery,
94 ) -> PgResult<Vec<PgRow>> {
95 self.fetch_all_prepared_ast_with_format(prepared, ResultFormat::Text)
96 .await
97 }
98
99 pub async fn fetch_all_prepared_ast_with_format(
101 &mut self,
102 prepared: &PreparedAstQuery,
103 result_format: ResultFormat,
104 ) -> PgResult<Vec<PgRow>> {
105 let mut retried = false;
106
107 loop {
108 self.connection.stmt_cache.touch_key(prepared.sql_hash);
109 self.connection.write_buf.clear();
110 if let Err(e) = crate::protocol::PgEncoder::encode_bind_to_with_result_format(
111 &mut self.connection.write_buf,
112 prepared.stmt.name(),
113 &prepared.params,
114 result_format.as_wire_code(),
115 ) {
116 return Err(PgError::Encode(e.to_string()));
117 }
118 crate::protocol::PgEncoder::encode_execute_to(&mut self.connection.write_buf);
119 crate::protocol::PgEncoder::encode_sync_to(&mut self.connection.write_buf);
120
121 if let Err(err) = self.connection.flush_write_buf().await {
122 if !retried && err.is_prepared_statement_retryable() {
123 retried = true;
124 let stmt = self.connection.prepare(&prepared.sql).await?;
125 self.connection
126 .stmt_cache
127 .put(prepared.sql_hash, stmt.name().to_string());
128 self.connection
129 .prepared_statements
130 .insert(stmt.name().to_string(), prepared.sql.clone());
131 continue;
132 }
133 return Err(err);
134 }
135
136 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
137 let mut error: Option<PgError> = None;
138 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
139 super::extended_flow::ExtendedFlowConfig::parse_bind_execute(false),
140 );
141
142 loop {
143 let msg = self.connection.recv().await?;
144 if let Err(err) = flow.validate(
145 &msg,
146 "driver fetch_all_prepared_ast execute",
147 error.is_some(),
148 ) {
149 return return_with_desync(self, err);
150 }
151 match msg {
152 crate::protocol::BackendMessage::BindComplete => {}
153 crate::protocol::BackendMessage::RowDescription(_) => {}
154 crate::protocol::BackendMessage::DataRow(data) => {
155 if error.is_none() {
156 rows.push(PgRow {
157 columns: data,
158 column_info: None,
159 });
160 }
161 }
162 crate::protocol::BackendMessage::CommandComplete(_) => {}
163 crate::protocol::BackendMessage::NoData => {}
164 crate::protocol::BackendMessage::ReadyForQuery(_) => {
165 if let Some(err) = error {
166 if !retried && err.is_prepared_statement_retryable() {
167 retried = true;
168 let stmt = self.connection.prepare(&prepared.sql).await?;
169 self.connection
170 .stmt_cache
171 .put(prepared.sql_hash, stmt.name().to_string());
172 self.connection
173 .prepared_statements
174 .insert(stmt.name().to_string(), prepared.sql.clone());
175 break;
176 }
177 return Err(err);
178 }
179 return Ok(rows);
180 }
181 crate::protocol::BackendMessage::ErrorResponse(err) => {
182 if error.is_none() {
183 error = Some(PgError::QueryServer(err.into()));
184 }
185 }
186 msg if is_ignorable_session_message(&msg) => {}
187 other => {
188 return return_with_desync(
189 self,
190 unexpected_backend_message(
191 "driver fetch_all_prepared_ast execute",
192 &other,
193 ),
194 );
195 }
196 }
197 }
198 }
199 }
200
201 pub async fn fetch_typed<T: super::row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
209 self.fetch_typed_with_format(cmd, ResultFormat::Text).await
210 }
211
212 pub async fn fetch_typed_with_format<T: super::row::QailRow>(
217 &mut self,
218 cmd: &Qail,
219 result_format: ResultFormat,
220 ) -> PgResult<Vec<T>> {
221 let rows = self.fetch_all_with_format(cmd, result_format).await?;
222 Ok(rows.iter().map(T::from_row).collect())
223 }
224
225 pub async fn fetch_one_typed<T: super::row::QailRow>(
228 &mut self,
229 cmd: &Qail,
230 ) -> PgResult<Option<T>> {
231 self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
232 .await
233 }
234
235 pub async fn fetch_one_typed_with_format<T: super::row::QailRow>(
237 &mut self,
238 cmd: &Qail,
239 result_format: ResultFormat,
240 ) -> PgResult<Option<T>> {
241 let rows = self.fetch_all_with_format(cmd, result_format).await?;
242 Ok(rows.first().map(T::from_row))
243 }
244
245 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
251 self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
252 .await
253 }
254
255 pub async fn fetch_all_uncached_with_format(
257 &mut self,
258 cmd: &Qail,
259 result_format: ResultFormat,
260 ) -> PgResult<Vec<PgRow>> {
261 use crate::protocol::AstEncoder;
262
263 AstEncoder::encode_cmd_reuse_into_with_result_format(
264 cmd,
265 &mut self.connection.sql_buf,
266 &mut self.connection.params_buf,
267 &mut self.connection.write_buf,
268 result_format.as_wire_code(),
269 )
270 .map_err(|e| PgError::Encode(e.to_string()))?;
271
272 self.connection.flush_write_buf().await?;
273
274 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
275 let mut column_info: Option<Arc<ColumnInfo>> = None;
276
277 let mut error: Option<PgError> = None;
278 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
279 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
280 );
281
282 loop {
283 let msg = self.connection.recv().await?;
284 if let Err(err) = flow.validate(&msg, "driver fetch_all execute", error.is_some()) {
285 return return_with_desync(self, err);
286 }
287 match msg {
288 crate::protocol::BackendMessage::ParseComplete
289 | crate::protocol::BackendMessage::BindComplete => {}
290 crate::protocol::BackendMessage::RowDescription(fields) => {
291 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
292 }
293 crate::protocol::BackendMessage::DataRow(data) => {
294 if error.is_none() {
295 rows.push(PgRow {
296 columns: data,
297 column_info: column_info.clone(),
298 });
299 }
300 }
301 crate::protocol::BackendMessage::NoData => {}
302 crate::protocol::BackendMessage::CommandComplete(_) => {}
303 crate::protocol::BackendMessage::ReadyForQuery(_) => {
304 if let Some(err) = error {
305 return Err(err);
306 }
307 return Ok(rows);
308 }
309 crate::protocol::BackendMessage::ErrorResponse(err) => {
310 if error.is_none() {
311 error = Some(PgError::QueryServer(err.into()));
312 }
313 }
314 msg if is_ignorable_session_message(&msg) => {}
315 other => {
316 return return_with_desync(
317 self,
318 unexpected_backend_message("driver fetch_all execute", &other),
319 );
320 }
321 }
322 }
323 }
324
325 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
329 self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
330 .await
331 }
332
333 pub async fn fetch_all_fast_with_format(
335 &mut self,
336 cmd: &Qail,
337 result_format: ResultFormat,
338 ) -> PgResult<Vec<PgRow>> {
339 use crate::protocol::AstEncoder;
340
341 AstEncoder::encode_cmd_reuse_into_with_result_format(
342 cmd,
343 &mut self.connection.sql_buf,
344 &mut self.connection.params_buf,
345 &mut self.connection.write_buf,
346 result_format.as_wire_code(),
347 )
348 .map_err(|e| PgError::Encode(e.to_string()))?;
349
350 self.connection.flush_write_buf().await?;
351
352 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
354 let mut error: Option<PgError> = None;
355 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
356 super::extended_flow::ExtendedFlowConfig::parse_bind_execute(true),
357 );
358
359 loop {
360 let res = self.connection.recv_with_data_fast().await;
361 match res {
362 Ok((msg_type, data)) => {
363 if let Err(err) = flow.validate_msg_type(
364 msg_type,
365 "driver fetch_all_fast execute",
366 error.is_some(),
367 ) {
368 return return_with_desync(self, err);
369 }
370 match msg_type {
371 b'D' => {
372 if error.is_none()
373 && let Some(columns) = data
374 {
375 rows.push(PgRow {
376 columns,
377 column_info: None,
378 });
379 }
380 }
381 b'Z' => {
382 if let Some(err) = error {
383 return Err(err);
384 }
385 return Ok(rows);
386 }
387 _ => {}
388 }
389 }
390 Err(e) => {
391 if matches!(&e, PgError::QueryServer(_)) {
393 if error.is_none() {
394 error = Some(e);
395 }
396 continue;
397 }
398 return Err(e);
399 }
400 }
401 }
402 }
403
404 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
406 let rows = self.fetch_all(cmd).await?;
407 rows.into_iter().next().ok_or(PgError::NoRows)
408 }
409
410 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
419 self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
420 .await
421 }
422
423 pub async fn fetch_all_cached_with_format(
425 &mut self,
426 cmd: &Qail,
427 result_format: ResultFormat,
428 ) -> PgResult<Vec<PgRow>> {
429 let mut retried = false;
430 loop {
431 match self
432 .fetch_all_cached_with_format_once(cmd, result_format)
433 .await
434 {
435 Ok(rows) => return Ok(rows),
436 Err(err)
437 if !retried
438 && (err.is_prepared_statement_retryable()
439 || err.is_prepared_statement_already_exists()) =>
440 {
441 retried = true;
442 if err.is_prepared_statement_retryable() {
443 self.connection.clear_prepared_statement_state();
444 }
445 }
446 Err(err) => return Err(err),
447 }
448 }
449 }
450
451 async fn fetch_all_cached_with_format_once(
452 &mut self,
453 cmd: &Qail,
454 result_format: ResultFormat,
455 ) -> PgResult<Vec<PgRow>> {
456 use crate::protocol::AstEncoder;
457 use std::collections::hash_map::DefaultHasher;
458 use std::hash::{Hash, Hasher};
459
460 if !AstEncoder::encode_cacheable_cmd_sql_to(
461 cmd,
462 &mut self.connection.sql_buf,
463 &mut self.connection.params_buf,
464 )? {
465 let (sql, params) =
467 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
468 let raw_rows = self
469 .connection
470 .query_cached_with_result_format(&sql, ¶ms, result_format.as_wire_code())
471 .await?;
472 return Ok(raw_rows
473 .into_iter()
474 .map(|data| PgRow {
475 columns: data,
476 column_info: None,
477 })
478 .collect());
479 }
480
481 let mut hasher = DefaultHasher::new();
482 self.connection.sql_buf.hash(&mut hasher);
483 let sql_hash = hasher.finish();
484
485 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
486
487 self.connection.write_buf.clear();
489
490 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
491 name
492 } else {
493 let name = format!("qail_{:x}", sql_hash);
494
495 self.connection.evict_prepared_if_full();
497
498 let sql_str = encoded_sql_str(&self.connection.sql_buf)?;
499
500 use crate::protocol::PgEncoder;
502 let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
503 let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
504 self.connection.write_buf.extend_from_slice(&parse_msg);
505 self.connection.write_buf.extend_from_slice(&describe_msg);
506
507 self.connection.stmt_cache.put(sql_hash, name.clone());
508 self.connection
509 .prepared_statements
510 .insert(name.clone(), sql_str.to_string());
511
512 name
513 };
514
515 use crate::protocol::PgEncoder;
517 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
518 &mut self.connection.write_buf,
519 &stmt_name,
520 &self.connection.params_buf,
521 result_format.as_wire_code(),
522 ) {
523 if is_cache_miss {
524 self.connection.stmt_cache.remove(&sql_hash);
525 self.connection.prepared_statements.remove(&stmt_name);
526 self.connection.column_info_cache.remove(&sql_hash);
527 }
528 return Err(PgError::Encode(e.to_string()));
529 }
530 PgEncoder::encode_execute_to(&mut self.connection.write_buf);
531 PgEncoder::encode_sync_to(&mut self.connection.write_buf);
532
533 if let Err(err) = self.connection.flush_write_buf().await {
535 if is_cache_miss {
536 self.connection.stmt_cache.remove(&sql_hash);
537 self.connection.prepared_statements.remove(&stmt_name);
538 self.connection.column_info_cache.remove(&sql_hash);
539 }
540 return Err(err);
541 }
542
543 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
545
546 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
547 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
548 let mut error: Option<PgError> = None;
549 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
550 super::extended_flow::ExtendedFlowConfig::parse_describe_statement_bind_execute(
551 is_cache_miss,
552 ),
553 );
554
555 loop {
556 let msg = match self.connection.recv().await {
557 Ok(msg) => msg,
558 Err(err) => {
559 if is_cache_miss && !flow.saw_parse_complete() {
560 self.connection.stmt_cache.remove(&sql_hash);
561 self.connection.prepared_statements.remove(&stmt_name);
562 self.connection.column_info_cache.remove(&sql_hash);
563 }
564 return Err(err);
565 }
566 };
567 if let Err(err) =
568 flow.validate(&msg, "driver fetch_all_cached execute", error.is_some())
569 {
570 if is_cache_miss && !flow.saw_parse_complete() {
571 self.connection.stmt_cache.remove(&sql_hash);
572 self.connection.prepared_statements.remove(&stmt_name);
573 self.connection.column_info_cache.remove(&sql_hash);
574 }
575 return return_with_desync(self, err);
576 }
577 match msg {
578 crate::protocol::BackendMessage::ParseComplete => {}
579 crate::protocol::BackendMessage::BindComplete => {}
580 crate::protocol::BackendMessage::ParameterDescription(_) => {
581 }
583 crate::protocol::BackendMessage::RowDescription(fields) => {
584 let info = Arc::new(ColumnInfo::from_fields(&fields));
586 if is_cache_miss {
587 self.connection
588 .column_info_cache
589 .insert(sql_hash, Arc::clone(&info));
590 }
591 column_info = Some(info);
592 }
593 crate::protocol::BackendMessage::DataRow(data) => {
594 if error.is_none() {
595 rows.push(PgRow {
596 columns: data,
597 column_info: column_info.clone(),
598 });
599 }
600 }
601 crate::protocol::BackendMessage::CommandComplete(_) => {}
602 crate::protocol::BackendMessage::NoData => {
603 }
605 crate::protocol::BackendMessage::ReadyForQuery(_) => {
606 if let Some(err) = error {
607 if is_cache_miss
608 && !flow.saw_parse_complete()
609 && !err.is_prepared_statement_already_exists()
610 {
611 self.connection.stmt_cache.remove(&sql_hash);
612 self.connection.prepared_statements.remove(&stmt_name);
613 self.connection.column_info_cache.remove(&sql_hash);
614 }
615 return Err(err);
616 }
617 if is_cache_miss && !flow.saw_parse_complete() {
618 self.connection.stmt_cache.remove(&sql_hash);
619 self.connection.prepared_statements.remove(&stmt_name);
620 self.connection.column_info_cache.remove(&sql_hash);
621 return return_with_desync(
622 self,
623 PgError::Protocol(
624 "Cache miss query reached ReadyForQuery without ParseComplete"
625 .to_string(),
626 ),
627 );
628 }
629 return Ok(rows);
630 }
631 crate::protocol::BackendMessage::ErrorResponse(err) => {
632 if error.is_none() {
633 let query_err = PgError::QueryServer(err.into());
634 if query_err.is_prepared_statement_retryable() {
635 self.connection.clear_prepared_statement_state();
636 }
637 error = Some(query_err);
638 }
639 }
640 msg if is_ignorable_session_message(&msg) => {}
641 other => {
642 if is_cache_miss && !flow.saw_parse_complete() {
643 self.connection.stmt_cache.remove(&sql_hash);
644 self.connection.prepared_statements.remove(&stmt_name);
645 self.connection.column_info_cache.remove(&sql_hash);
646 }
647 return return_with_desync(
648 self,
649 unexpected_backend_message("driver fetch_all_cached execute", &other),
650 );
651 }
652 }
653 }
654 }
655
656 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
658 use crate::protocol::AstEncoder;
659
660 let wire_bytes = AstEncoder::encode_cmd_reuse(
661 cmd,
662 &mut self.connection.sql_buf,
663 &mut self.connection.params_buf,
664 )
665 .map_err(|e| PgError::Encode(e.to_string()))?;
666
667 self.connection.send_bytes(&wire_bytes).await?;
668
669 let mut affected = 0u64;
670 let mut error: Option<PgError> = None;
671 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
672 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
673 );
674
675 loop {
676 let msg = self.connection.recv().await?;
677 if let Err(err) = flow.validate(&msg, "driver execute mutation", error.is_some()) {
678 return return_with_desync(self, err);
679 }
680 match msg {
681 crate::protocol::BackendMessage::ParseComplete
682 | crate::protocol::BackendMessage::BindComplete => {}
683 crate::protocol::BackendMessage::RowDescription(_) => {}
684 crate::protocol::BackendMessage::DataRow(_) => {}
685 crate::protocol::BackendMessage::NoData => {}
686 crate::protocol::BackendMessage::CommandComplete(tag) => {
687 if error.is_none() {
688 match super::parse_affected_rows(&tag) {
689 Ok(parsed) => affected = parsed,
690 Err(err) => return return_with_desync(self, err),
691 }
692 }
693 }
694 crate::protocol::BackendMessage::ReadyForQuery(_) => {
695 if let Some(err) = error {
696 return Err(err);
697 }
698 return Ok(affected);
699 }
700 crate::protocol::BackendMessage::ErrorResponse(err) => {
701 if error.is_none() {
702 error = Some(PgError::QueryServer(err.into()));
703 }
704 }
705 msg if is_ignorable_session_message(&msg) => {}
706 other => {
707 return return_with_desync(
708 self,
709 unexpected_backend_message("driver execute mutation", &other),
710 );
711 }
712 }
713 }
714 }
715
716 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
720 self.query_ast_with_format(cmd, ResultFormat::Text).await
721 }
722
723 pub async fn query_ast_with_format(
725 &mut self,
726 cmd: &Qail,
727 result_format: ResultFormat,
728 ) -> PgResult<QueryResult> {
729 use crate::protocol::AstEncoder;
730
731 let wire_bytes = AstEncoder::encode_cmd_reuse_with_result_format(
732 cmd,
733 &mut self.connection.sql_buf,
734 &mut self.connection.params_buf,
735 result_format.as_wire_code(),
736 )
737 .map_err(|e| PgError::Encode(e.to_string()))?;
738
739 self.connection.send_bytes(&wire_bytes).await?;
740
741 let mut columns: Vec<String> = Vec::new();
742 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
743 let mut error: Option<PgError> = None;
744 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
745 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
746 );
747
748 loop {
749 let msg = self.connection.recv().await?;
750 if let Err(err) = flow.validate(&msg, "driver query_ast", error.is_some()) {
751 return return_with_desync(self, err);
752 }
753 match msg {
754 crate::protocol::BackendMessage::ParseComplete
755 | crate::protocol::BackendMessage::BindComplete => {}
756 crate::protocol::BackendMessage::RowDescription(fields) => {
757 columns = fields.into_iter().map(|f| f.name).collect();
758 }
759 crate::protocol::BackendMessage::DataRow(data) => {
760 if error.is_none() {
761 let row: Vec<Option<String>> = data
762 .into_iter()
763 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
764 .collect();
765 rows.push(row);
766 }
767 }
768 crate::protocol::BackendMessage::CommandComplete(_) => {}
769 crate::protocol::BackendMessage::NoData => {}
770 crate::protocol::BackendMessage::ReadyForQuery(_) => {
771 if let Some(err) = error {
772 return Err(err);
773 }
774 return Ok(QueryResult { columns, rows });
775 }
776 crate::protocol::BackendMessage::ErrorResponse(err) => {
777 if error.is_none() {
778 error = Some(PgError::QueryServer(err.into()));
779 }
780 }
781 msg if is_ignorable_session_message(&msg) => {}
782 other => {
783 return return_with_desync(
784 self,
785 unexpected_backend_message("driver query_ast", &other),
786 );
787 }
788 }
789 }
790 }
791}
792
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn driver_encoded_sql_str_rejects_invalid_utf8() {
        let invalid_utf8: &[u8] = &[0xff];
        let message = encoded_sql_str(invalid_utf8)
            .expect_err("invalid SQL UTF-8 must fail")
            .to_string();
        assert!(message.contains("encoded SQL is not UTF-8"));
    }

    /// Builds a `PgDriver` over one end of a Unix socket pair, with empty
    /// buffers and caches, and hands the other end of the pair back to the caller.
    #[cfg(unix)]
    fn test_driver_with_peer() -> (PgDriver, tokio::net::UnixStream) {
        use crate::driver::connection::StatementCache;
        use crate::driver::stream::PgStream;
        use bytes::BytesMut;
        use std::collections::{HashMap, VecDeque};
        use std::num::NonZeroUsize;
        use tokio::net::UnixStream;

        let (driver_end, peer_end) = UnixStream::pair().expect("unix stream pair");
        let cache_capacity = NonZeroUsize::new(2).expect("non-zero");

        let connection = super::super::PgConnection {
            stream: PgStream::Unix(driver_end),
            buffer: BytesMut::with_capacity(1024),
            write_buf: BytesMut::with_capacity(1024),
            sql_buf: BytesMut::with_capacity(256),
            params_buf: Vec::new(),
            prepared_statements: HashMap::new(),
            stmt_cache: StatementCache::new(cache_capacity),
            column_info_cache: HashMap::new(),
            process_id: 0,
            cancel_key_bytes: Vec::new(),
            requested_protocol_minor: super::super::PgConnection::default_protocol_minor(),
            negotiated_protocol_minor: super::super::PgConnection::default_protocol_minor(),
            notifications: VecDeque::new(),
            replication_stream_active: false,
            replication_mode_enabled: false,
            last_replication_wal_end: None,
            io_desynced: false,
            pending_statement_closes: Vec::new(),
            draining_statement_closes: false,
        };
        let driver = PgDriver::new(connection);
        (driver, peer_end)
    }

    /// Appends one backend frame to `driver.connection.buffer`: the type byte,
    /// a big-endian length that counts its own 4 bytes, then the payload.
    #[cfg(unix)]
    fn push_backend_frame(driver: &mut PgDriver, msg_type: u8, payload: &[u8]) {
        let declared_len = (payload.len() + 4) as u32;
        let inbound = &mut driver.connection.buffer;
        inbound.extend_from_slice(&[msg_type]);
        inbound.extend_from_slice(&declared_len.to_be_bytes());
        inbound.extend_from_slice(payload);
    }

    /// Appends a `CommandComplete` ('C') frame whose payload is `tag` followed
    /// by a NUL terminator.
    #[cfg(unix)]
    fn push_command_complete(driver: &mut PgDriver, tag: &str) {
        let mut nul_terminated = tag.as_bytes().to_vec();
        nul_terminated.push(0);
        push_backend_frame(driver, b'C', &nul_terminated);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn fetch_fast_protocol_error_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        let zero_columns = 0i16.to_be_bytes();
        push_backend_frame(&mut driver, b'D', &zero_columns);

        let Err(err) = driver.fetch_all_fast(&Qail::get("users")).await else {
            panic!("out-of-order DataRow must fail");
        };

        assert!(err.to_string().contains("DataRow before BindComplete"));
        assert!(driver.connection.is_io_desynced());
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn execute_bad_command_tag_marks_driver_connection_desynced() {
        let (mut driver, _peer) = test_driver_with_peer();
        for msg_type in [b'1', b'2', b'n'] {
            push_backend_frame(&mut driver, msg_type, &[]);
        }
        push_command_complete(&mut driver, "UPDATE");
        push_backend_frame(&mut driver, b'Z', b"I");

        let err = driver
            .execute(&Qail::get("users"))
            .await
            .expect_err("malformed CommandComplete tag must fail");
        let message = err.to_string();

        assert!(
            message.contains("missing affected row count")
                || message.contains("invalid affected row count")
        );
        assert!(driver.connection.is_io_desynced());
    }
}