1use super::core::PgDriver;
5use super::prepared::PreparedAstQuery;
6use super::types::*;
7use qail_core::ast::Qail;
8use std::sync::Arc;
9use std::{
10 collections::hash_map::DefaultHasher,
11 hash::{Hash, Hasher},
12};
13
14impl PgDriver {
15 pub async fn fetch_all(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
22 self.fetch_all_with_format(cmd, ResultFormat::Text).await
23 }
24
25 pub async fn fetch_all_with_format(
31 &mut self,
32 cmd: &Qail,
33 result_format: ResultFormat,
34 ) -> PgResult<Vec<PgRow>> {
35 self.fetch_all_cached_with_format(cmd, result_format).await
37 }
38
39 pub async fn prepare_ast_query(&mut self, cmd: &Qail) -> PgResult<PreparedAstQuery> {
45 use crate::protocol::AstEncoder;
46
47 let (sql, params) =
48 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
49 let stmt = self.connection.prepare(&sql).await?;
50
51 let mut hasher = DefaultHasher::new();
52 sql.hash(&mut hasher);
53 let sql_hash = hasher.finish();
54
55 self.connection
56 .stmt_cache
57 .put(sql_hash, stmt.name().to_string());
58 self.connection
59 .prepared_statements
60 .insert(stmt.name().to_string(), sql.clone());
61
62 Ok(PreparedAstQuery {
63 stmt,
64 params,
65 sql,
66 sql_hash,
67 })
68 }
69
70 pub async fn fetch_all_prepared_ast(
75 &mut self,
76 prepared: &PreparedAstQuery,
77 ) -> PgResult<Vec<PgRow>> {
78 self.fetch_all_prepared_ast_with_format(prepared, ResultFormat::Text)
79 .await
80 }
81
82 pub async fn fetch_all_prepared_ast_with_format(
84 &mut self,
85 prepared: &PreparedAstQuery,
86 result_format: ResultFormat,
87 ) -> PgResult<Vec<PgRow>> {
88 let mut retried = false;
89
90 loop {
91 self.connection.stmt_cache.touch_key(prepared.sql_hash);
92 self.connection.write_buf.clear();
93 if let Err(e) = crate::protocol::PgEncoder::encode_bind_to_with_result_format(
94 &mut self.connection.write_buf,
95 prepared.stmt.name(),
96 &prepared.params,
97 result_format.as_wire_code(),
98 ) {
99 return Err(PgError::Encode(e.to_string()));
100 }
101 crate::protocol::PgEncoder::encode_execute_to(&mut self.connection.write_buf);
102 crate::protocol::PgEncoder::encode_sync_to(&mut self.connection.write_buf);
103
104 if let Err(err) = self.connection.flush_write_buf().await {
105 if !retried && err.is_prepared_statement_retryable() {
106 retried = true;
107 let stmt = self.connection.prepare(&prepared.sql).await?;
108 self.connection
109 .stmt_cache
110 .put(prepared.sql_hash, stmt.name().to_string());
111 self.connection
112 .prepared_statements
113 .insert(stmt.name().to_string(), prepared.sql.clone());
114 continue;
115 }
116 return Err(err);
117 }
118
119 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
120 let mut error: Option<PgError> = None;
121 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
122 super::extended_flow::ExtendedFlowConfig::parse_bind_execute(false),
123 );
124
125 loop {
126 let msg = self.connection.recv().await?;
127 flow.validate(
128 &msg,
129 "driver fetch_all_prepared_ast execute",
130 error.is_some(),
131 )?;
132 match msg {
133 crate::protocol::BackendMessage::BindComplete => {}
134 crate::protocol::BackendMessage::RowDescription(_) => {}
135 crate::protocol::BackendMessage::DataRow(data) => {
136 if error.is_none() {
137 rows.push(PgRow {
138 columns: data,
139 column_info: None,
140 });
141 }
142 }
143 crate::protocol::BackendMessage::CommandComplete(_) => {}
144 crate::protocol::BackendMessage::NoData => {}
145 crate::protocol::BackendMessage::ReadyForQuery(_) => {
146 if let Some(err) = error {
147 if !retried && err.is_prepared_statement_retryable() {
148 retried = true;
149 let stmt = self.connection.prepare(&prepared.sql).await?;
150 self.connection
151 .stmt_cache
152 .put(prepared.sql_hash, stmt.name().to_string());
153 self.connection
154 .prepared_statements
155 .insert(stmt.name().to_string(), prepared.sql.clone());
156 break;
157 }
158 return Err(err);
159 }
160 return Ok(rows);
161 }
162 crate::protocol::BackendMessage::ErrorResponse(err) => {
163 if error.is_none() {
164 error = Some(PgError::QueryServer(err.into()));
165 }
166 }
167 msg if is_ignorable_session_message(&msg) => {}
168 other => {
169 return Err(unexpected_backend_message(
170 "driver fetch_all_prepared_ast execute",
171 &other,
172 ));
173 }
174 }
175 }
176 }
177 }
178
179 pub async fn fetch_typed<T: super::row::QailRow>(&mut self, cmd: &Qail) -> PgResult<Vec<T>> {
187 self.fetch_typed_with_format(cmd, ResultFormat::Text).await
188 }
189
190 pub async fn fetch_typed_with_format<T: super::row::QailRow>(
195 &mut self,
196 cmd: &Qail,
197 result_format: ResultFormat,
198 ) -> PgResult<Vec<T>> {
199 let rows = self.fetch_all_with_format(cmd, result_format).await?;
200 Ok(rows.iter().map(T::from_row).collect())
201 }
202
203 pub async fn fetch_one_typed<T: super::row::QailRow>(
206 &mut self,
207 cmd: &Qail,
208 ) -> PgResult<Option<T>> {
209 self.fetch_one_typed_with_format(cmd, ResultFormat::Text)
210 .await
211 }
212
213 pub async fn fetch_one_typed_with_format<T: super::row::QailRow>(
215 &mut self,
216 cmd: &Qail,
217 result_format: ResultFormat,
218 ) -> PgResult<Option<T>> {
219 let rows = self.fetch_all_with_format(cmd, result_format).await?;
220 Ok(rows.first().map(T::from_row))
221 }
222
223 pub async fn fetch_all_uncached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
229 self.fetch_all_uncached_with_format(cmd, ResultFormat::Text)
230 .await
231 }
232
233 pub async fn fetch_all_uncached_with_format(
235 &mut self,
236 cmd: &Qail,
237 result_format: ResultFormat,
238 ) -> PgResult<Vec<PgRow>> {
239 use crate::protocol::AstEncoder;
240
241 AstEncoder::encode_cmd_reuse_into_with_result_format(
242 cmd,
243 &mut self.connection.sql_buf,
244 &mut self.connection.params_buf,
245 &mut self.connection.write_buf,
246 result_format.as_wire_code(),
247 )
248 .map_err(|e| PgError::Encode(e.to_string()))?;
249
250 self.connection.flush_write_buf().await?;
251
252 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
253 let mut column_info: Option<Arc<ColumnInfo>> = None;
254
255 let mut error: Option<PgError> = None;
256 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
257 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
258 );
259
260 loop {
261 let msg = self.connection.recv().await?;
262 flow.validate(&msg, "driver fetch_all execute", error.is_some())?;
263 match msg {
264 crate::protocol::BackendMessage::ParseComplete
265 | crate::protocol::BackendMessage::BindComplete => {}
266 crate::protocol::BackendMessage::RowDescription(fields) => {
267 column_info = Some(Arc::new(ColumnInfo::from_fields(&fields)));
268 }
269 crate::protocol::BackendMessage::DataRow(data) => {
270 if error.is_none() {
271 rows.push(PgRow {
272 columns: data,
273 column_info: column_info.clone(),
274 });
275 }
276 }
277 crate::protocol::BackendMessage::NoData => {}
278 crate::protocol::BackendMessage::CommandComplete(_) => {}
279 crate::protocol::BackendMessage::ReadyForQuery(_) => {
280 if let Some(err) = error {
281 return Err(err);
282 }
283 return Ok(rows);
284 }
285 crate::protocol::BackendMessage::ErrorResponse(err) => {
286 if error.is_none() {
287 error = Some(PgError::QueryServer(err.into()));
288 }
289 }
290 msg if is_ignorable_session_message(&msg) => {}
291 other => {
292 return Err(unexpected_backend_message(
293 "driver fetch_all execute",
294 &other,
295 ));
296 }
297 }
298 }
299 }
300
301 pub async fn fetch_all_fast(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
305 self.fetch_all_fast_with_format(cmd, ResultFormat::Text)
306 .await
307 }
308
309 pub async fn fetch_all_fast_with_format(
311 &mut self,
312 cmd: &Qail,
313 result_format: ResultFormat,
314 ) -> PgResult<Vec<PgRow>> {
315 use crate::protocol::AstEncoder;
316
317 AstEncoder::encode_cmd_reuse_into_with_result_format(
318 cmd,
319 &mut self.connection.sql_buf,
320 &mut self.connection.params_buf,
321 &mut self.connection.write_buf,
322 result_format.as_wire_code(),
323 )
324 .map_err(|e| PgError::Encode(e.to_string()))?;
325
326 self.connection.flush_write_buf().await?;
327
328 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
330 let mut error: Option<PgError> = None;
331 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
332 super::extended_flow::ExtendedFlowConfig::parse_bind_execute(true),
333 );
334
335 loop {
336 let res = self.connection.recv_with_data_fast().await;
337 match res {
338 Ok((msg_type, data)) => {
339 flow.validate_msg_type(
340 msg_type,
341 "driver fetch_all_fast execute",
342 error.is_some(),
343 )?;
344 match msg_type {
345 b'D' => {
346 if error.is_none()
347 && let Some(columns) = data
348 {
349 rows.push(PgRow {
350 columns,
351 column_info: None,
352 });
353 }
354 }
355 b'Z' => {
356 if let Some(err) = error {
357 return Err(err);
358 }
359 return Ok(rows);
360 }
361 _ => {}
362 }
363 }
364 Err(e) => {
365 if matches!(&e, PgError::QueryServer(_)) {
367 if error.is_none() {
368 error = Some(e);
369 }
370 continue;
371 }
372 return Err(e);
373 }
374 }
375 }
376 }
377
378 pub async fn fetch_one(&mut self, cmd: &Qail) -> PgResult<PgRow> {
380 let rows = self.fetch_all(cmd).await?;
381 rows.into_iter().next().ok_or(PgError::NoRows)
382 }
383
384 pub async fn fetch_all_cached(&mut self, cmd: &Qail) -> PgResult<Vec<PgRow>> {
393 self.fetch_all_cached_with_format(cmd, ResultFormat::Text)
394 .await
395 }
396
397 pub async fn fetch_all_cached_with_format(
399 &mut self,
400 cmd: &Qail,
401 result_format: ResultFormat,
402 ) -> PgResult<Vec<PgRow>> {
403 let mut retried = false;
404 loop {
405 match self
406 .fetch_all_cached_with_format_once(cmd, result_format)
407 .await
408 {
409 Ok(rows) => return Ok(rows),
410 Err(err)
411 if !retried
412 && (err.is_prepared_statement_retryable()
413 || err.is_prepared_statement_already_exists()) =>
414 {
415 retried = true;
416 if err.is_prepared_statement_retryable() {
417 self.connection.clear_prepared_statement_state();
418 }
419 }
420 Err(err) => return Err(err),
421 }
422 }
423 }
424
425 async fn fetch_all_cached_with_format_once(
426 &mut self,
427 cmd: &Qail,
428 result_format: ResultFormat,
429 ) -> PgResult<Vec<PgRow>> {
430 use crate::protocol::AstEncoder;
431 use std::collections::hash_map::DefaultHasher;
432 use std::hash::{Hash, Hasher};
433
434 self.connection.sql_buf.clear();
435 self.connection.params_buf.clear();
436
437 match cmd.action {
439 qail_core::ast::Action::Get | qail_core::ast::Action::With => {
440 crate::protocol::ast_encoder::dml::encode_select(
441 cmd,
442 &mut self.connection.sql_buf,
443 &mut self.connection.params_buf,
444 )?;
445 }
446 qail_core::ast::Action::Add => {
447 crate::protocol::ast_encoder::dml::encode_insert(
448 cmd,
449 &mut self.connection.sql_buf,
450 &mut self.connection.params_buf,
451 )?;
452 }
453 qail_core::ast::Action::Set => {
454 crate::protocol::ast_encoder::dml::encode_update(
455 cmd,
456 &mut self.connection.sql_buf,
457 &mut self.connection.params_buf,
458 )?;
459 }
460 qail_core::ast::Action::Del => {
461 crate::protocol::ast_encoder::dml::encode_delete(
462 cmd,
463 &mut self.connection.sql_buf,
464 &mut self.connection.params_buf,
465 )?;
466 }
467 _ => {
468 let (sql, params) =
470 AstEncoder::encode_cmd_sql(cmd).map_err(|e| PgError::Encode(e.to_string()))?;
471 let raw_rows = self
472 .connection
473 .query_cached_with_result_format(&sql, ¶ms, result_format.as_wire_code())
474 .await?;
475 return Ok(raw_rows
476 .into_iter()
477 .map(|data| PgRow {
478 columns: data,
479 column_info: None,
480 })
481 .collect());
482 }
483 }
484
485 let mut hasher = DefaultHasher::new();
486 self.connection.sql_buf.hash(&mut hasher);
487 let sql_hash = hasher.finish();
488
489 let is_cache_miss = !self.connection.stmt_cache.contains(&sql_hash);
490
491 self.connection.write_buf.clear();
493
494 let stmt_name = if let Some(name) = self.connection.stmt_cache.get(&sql_hash) {
495 name
496 } else {
497 let name = format!("qail_{:x}", sql_hash);
498
499 self.connection.evict_prepared_if_full();
501
502 let sql_str = std::str::from_utf8(&self.connection.sql_buf).unwrap_or("");
503
504 use crate::protocol::PgEncoder;
506 let parse_msg = PgEncoder::try_encode_parse(&name, sql_str, &[])?;
507 let describe_msg = PgEncoder::try_encode_describe(false, &name)?;
508 self.connection.write_buf.extend_from_slice(&parse_msg);
509 self.connection.write_buf.extend_from_slice(&describe_msg);
510
511 self.connection.stmt_cache.put(sql_hash, name.clone());
512 self.connection
513 .prepared_statements
514 .insert(name.clone(), sql_str.to_string());
515
516 name
517 };
518
519 use crate::protocol::PgEncoder;
521 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
522 &mut self.connection.write_buf,
523 &stmt_name,
524 &self.connection.params_buf,
525 result_format.as_wire_code(),
526 ) {
527 if is_cache_miss {
528 self.connection.stmt_cache.remove(&sql_hash);
529 self.connection.prepared_statements.remove(&stmt_name);
530 self.connection.column_info_cache.remove(&sql_hash);
531 }
532 return Err(PgError::Encode(e.to_string()));
533 }
534 PgEncoder::encode_execute_to(&mut self.connection.write_buf);
535 PgEncoder::encode_sync_to(&mut self.connection.write_buf);
536
537 if let Err(err) = self.connection.flush_write_buf().await {
539 if is_cache_miss {
540 self.connection.stmt_cache.remove(&sql_hash);
541 self.connection.prepared_statements.remove(&stmt_name);
542 self.connection.column_info_cache.remove(&sql_hash);
543 }
544 return Err(err);
545 }
546
547 let cached_column_info = self.connection.column_info_cache.get(&sql_hash).cloned();
549
550 let mut rows: Vec<PgRow> = Vec::with_capacity(32);
551 let mut column_info: Option<Arc<ColumnInfo>> = cached_column_info;
552 let mut error: Option<PgError> = None;
553 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
554 super::extended_flow::ExtendedFlowConfig::parse_describe_statement_bind_execute(
555 is_cache_miss,
556 ),
557 );
558
559 loop {
560 let msg = match self.connection.recv().await {
561 Ok(msg) => msg,
562 Err(err) => {
563 if is_cache_miss && !flow.saw_parse_complete() {
564 self.connection.stmt_cache.remove(&sql_hash);
565 self.connection.prepared_statements.remove(&stmt_name);
566 self.connection.column_info_cache.remove(&sql_hash);
567 }
568 return Err(err);
569 }
570 };
571 if let Err(err) =
572 flow.validate(&msg, "driver fetch_all_cached execute", error.is_some())
573 {
574 if is_cache_miss && !flow.saw_parse_complete() {
575 self.connection.stmt_cache.remove(&sql_hash);
576 self.connection.prepared_statements.remove(&stmt_name);
577 self.connection.column_info_cache.remove(&sql_hash);
578 }
579 return Err(err);
580 }
581 match msg {
582 crate::protocol::BackendMessage::ParseComplete => {}
583 crate::protocol::BackendMessage::BindComplete => {}
584 crate::protocol::BackendMessage::ParameterDescription(_) => {
585 }
587 crate::protocol::BackendMessage::RowDescription(fields) => {
588 let info = Arc::new(ColumnInfo::from_fields(&fields));
590 if is_cache_miss {
591 self.connection
592 .column_info_cache
593 .insert(sql_hash, info.clone());
594 }
595 column_info = Some(info);
596 }
597 crate::protocol::BackendMessage::DataRow(data) => {
598 if error.is_none() {
599 rows.push(PgRow {
600 columns: data,
601 column_info: column_info.clone(),
602 });
603 }
604 }
605 crate::protocol::BackendMessage::CommandComplete(_) => {}
606 crate::protocol::BackendMessage::NoData => {
607 }
609 crate::protocol::BackendMessage::ReadyForQuery(_) => {
610 if let Some(err) = error {
611 if is_cache_miss
612 && !flow.saw_parse_complete()
613 && !err.is_prepared_statement_already_exists()
614 {
615 self.connection.stmt_cache.remove(&sql_hash);
616 self.connection.prepared_statements.remove(&stmt_name);
617 self.connection.column_info_cache.remove(&sql_hash);
618 }
619 return Err(err);
620 }
621 if is_cache_miss && !flow.saw_parse_complete() {
622 self.connection.stmt_cache.remove(&sql_hash);
623 self.connection.prepared_statements.remove(&stmt_name);
624 self.connection.column_info_cache.remove(&sql_hash);
625 return Err(PgError::Protocol(
626 "Cache miss query reached ReadyForQuery without ParseComplete"
627 .to_string(),
628 ));
629 }
630 return Ok(rows);
631 }
632 crate::protocol::BackendMessage::ErrorResponse(err) => {
633 if error.is_none() {
634 let query_err = PgError::QueryServer(err.into());
635 if query_err.is_prepared_statement_retryable() {
636 self.connection.clear_prepared_statement_state();
637 }
638 error = Some(query_err);
639 }
640 }
641 msg if is_ignorable_session_message(&msg) => {}
642 other => {
643 if is_cache_miss && !flow.saw_parse_complete() {
644 self.connection.stmt_cache.remove(&sql_hash);
645 self.connection.prepared_statements.remove(&stmt_name);
646 self.connection.column_info_cache.remove(&sql_hash);
647 }
648 return Err(unexpected_backend_message(
649 "driver fetch_all_cached execute",
650 &other,
651 ));
652 }
653 }
654 }
655 }
656
657 pub async fn execute(&mut self, cmd: &Qail) -> PgResult<u64> {
659 use crate::protocol::AstEncoder;
660
661 let wire_bytes = AstEncoder::encode_cmd_reuse(
662 cmd,
663 &mut self.connection.sql_buf,
664 &mut self.connection.params_buf,
665 )
666 .map_err(|e| PgError::Encode(e.to_string()))?;
667
668 self.connection.send_bytes(&wire_bytes).await?;
669
670 let mut affected = 0u64;
671 let mut error: Option<PgError> = None;
672 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
673 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
674 );
675
676 loop {
677 let msg = self.connection.recv().await?;
678 flow.validate(&msg, "driver execute mutation", error.is_some())?;
679 match msg {
680 crate::protocol::BackendMessage::ParseComplete
681 | crate::protocol::BackendMessage::BindComplete => {}
682 crate::protocol::BackendMessage::RowDescription(_) => {}
683 crate::protocol::BackendMessage::DataRow(_) => {}
684 crate::protocol::BackendMessage::NoData => {}
685 crate::protocol::BackendMessage::CommandComplete(tag) => {
686 if error.is_none()
687 && let Some(n) = tag.split_whitespace().last()
688 {
689 affected = n.parse().unwrap_or(0);
690 }
691 }
692 crate::protocol::BackendMessage::ReadyForQuery(_) => {
693 if let Some(err) = error {
694 return Err(err);
695 }
696 return Ok(affected);
697 }
698 crate::protocol::BackendMessage::ErrorResponse(err) => {
699 if error.is_none() {
700 error = Some(PgError::QueryServer(err.into()));
701 }
702 }
703 msg if is_ignorable_session_message(&msg) => {}
704 other => {
705 return Err(unexpected_backend_message(
706 "driver execute mutation",
707 &other,
708 ));
709 }
710 }
711 }
712 }
713
714 pub async fn query_ast(&mut self, cmd: &Qail) -> PgResult<QueryResult> {
718 self.query_ast_with_format(cmd, ResultFormat::Text).await
719 }
720
721 pub async fn query_ast_with_format(
723 &mut self,
724 cmd: &Qail,
725 result_format: ResultFormat,
726 ) -> PgResult<QueryResult> {
727 use crate::protocol::AstEncoder;
728
729 let wire_bytes = AstEncoder::encode_cmd_reuse_with_result_format(
730 cmd,
731 &mut self.connection.sql_buf,
732 &mut self.connection.params_buf,
733 result_format.as_wire_code(),
734 )
735 .map_err(|e| PgError::Encode(e.to_string()))?;
736
737 self.connection.send_bytes(&wire_bytes).await?;
738
739 let mut columns: Vec<String> = Vec::new();
740 let mut rows: Vec<Vec<Option<String>>> = Vec::new();
741 let mut error: Option<PgError> = None;
742 let mut flow = super::extended_flow::ExtendedFlowTracker::new(
743 super::extended_flow::ExtendedFlowConfig::parse_bind_describe_portal_execute(),
744 );
745
746 loop {
747 let msg = self.connection.recv().await?;
748 flow.validate(&msg, "driver query_ast", error.is_some())?;
749 match msg {
750 crate::protocol::BackendMessage::ParseComplete
751 | crate::protocol::BackendMessage::BindComplete => {}
752 crate::protocol::BackendMessage::RowDescription(fields) => {
753 columns = fields.into_iter().map(|f| f.name).collect();
754 }
755 crate::protocol::BackendMessage::DataRow(data) => {
756 if error.is_none() {
757 let row: Vec<Option<String>> = data
758 .into_iter()
759 .map(|col| col.map(|bytes| String::from_utf8_lossy(&bytes).to_string()))
760 .collect();
761 rows.push(row);
762 }
763 }
764 crate::protocol::BackendMessage::CommandComplete(_) => {}
765 crate::protocol::BackendMessage::NoData => {}
766 crate::protocol::BackendMessage::ReadyForQuery(_) => {
767 if let Some(err) = error {
768 return Err(err);
769 }
770 return Ok(QueryResult { columns, rows });
771 }
772 crate::protocol::BackendMessage::ErrorResponse(err) => {
773 if error.is_none() {
774 error = Some(PgError::QueryServer(err.into()));
775 }
776 }
777 msg if is_ignorable_session_message(&msg) => {}
778 other => return Err(unexpected_backend_message("driver query_ast", &other)),
779 }
780 }
781 }
782}