1use super::{
6 PgConnection, PgError, PgResult,
7 extended_flow::{ExtendedFlowConfig, ExtendedFlowTracker},
8 is_ignorable_session_message, unexpected_backend_message,
9};
10use crate::protocol::{BackendMessage, PgEncoder};
11use bytes::BytesMut;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
14enum SimpleStatementState {
15 AwaitingResult,
16 InRowStream,
17}
18
19#[derive(Debug, Clone, Copy)]
20struct SimpleFlowTracker {
21 state: SimpleStatementState,
22 saw_completion: bool,
23}
24
25impl SimpleFlowTracker {
26 fn new() -> Self {
27 Self {
28 state: SimpleStatementState::AwaitingResult,
29 saw_completion: false,
30 }
31 }
32
33 fn on_row_description(&mut self, context: &'static str) -> PgResult<()> {
34 if self.state == SimpleStatementState::InRowStream {
35 return Err(PgError::Protocol(format!(
36 "{}: duplicate RowDescription before statement completion",
37 context
38 )));
39 }
40 self.state = SimpleStatementState::InRowStream;
41 self.saw_completion = false;
42 Ok(())
43 }
44
45 fn on_data_row(&self, context: &'static str) -> PgResult<()> {
46 if self.state != SimpleStatementState::InRowStream {
47 return Err(PgError::Protocol(format!(
48 "{}: DataRow before RowDescription",
49 context
50 )));
51 }
52 Ok(())
53 }
54
55 fn on_command_complete(&mut self) {
56 self.state = SimpleStatementState::AwaitingResult;
57 self.saw_completion = true;
58 }
59
60 fn on_empty_query_response(&mut self, context: &'static str) -> PgResult<()> {
61 if self.state == SimpleStatementState::InRowStream {
62 return Err(PgError::Protocol(format!(
63 "{}: EmptyQueryResponse during active row stream",
64 context
65 )));
66 }
67 self.saw_completion = true;
68 Ok(())
69 }
70
71 fn on_ready_for_query(&self, context: &'static str, error_pending: bool) -> PgResult<()> {
72 if error_pending {
73 return Ok(());
74 }
75 if self.state == SimpleStatementState::InRowStream {
76 return Err(PgError::Protocol(format!(
77 "{}: ReadyForQuery before CommandComplete",
78 context
79 )));
80 }
81 if !self.saw_completion {
82 return Err(PgError::Protocol(format!(
83 "{}: ReadyForQuery before completion",
84 context
85 )));
86 }
87 Ok(())
88 }
89}
90
91impl PgConnection {
92 pub(crate) async fn query(
98 &mut self,
99 sql: &str,
100 params: &[Option<Vec<u8>>],
101 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
102 self.query_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
103 .await
104 }
105
106 pub(crate) async fn query_with_result_format(
108 &mut self,
109 sql: &str,
110 params: &[Option<Vec<u8>>],
111 result_format: i16,
112 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
113 let bytes = PgEncoder::encode_extended_query_with_result_format(sql, params, result_format)
114 .map_err(|e| PgError::Encode(e.to_string()))?;
115 self.write_all_with_timeout(&bytes, "stream write").await?;
116
117 let mut rows = Vec::new();
118
119 let mut error: Option<PgError> = None;
120 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(true));
121
122 loop {
123 let msg = self.recv().await?;
124 flow.validate(&msg, "extended-query execute", error.is_some())?;
125 match msg {
126 BackendMessage::ParseComplete => {}
127 BackendMessage::BindComplete => {}
128 BackendMessage::RowDescription(_) => {}
129 BackendMessage::DataRow(data) => {
130 if error.is_none() {
132 rows.push(data);
133 }
134 }
135 BackendMessage::CommandComplete(_) => {}
136 BackendMessage::NoData => {}
137 BackendMessage::ReadyForQuery(_) => {
138 if let Some(err) = error {
139 return Err(err);
140 }
141 return Ok(rows);
142 }
143 BackendMessage::ErrorResponse(err) => {
144 if error.is_none() {
145 error = Some(PgError::QueryServer(err.into()));
146 }
147 }
148 msg if is_ignorable_session_message(&msg) => {}
149 other => {
150 return Err(unexpected_backend_message("extended-query execute", &other));
151 }
152 }
153 }
154 }
155
156 pub async fn query_cached(
161 &mut self,
162 sql: &str,
163 params: &[Option<Vec<u8>>],
164 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
165 self.query_cached_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
166 .await
167 }
168
169 pub async fn query_cached_with_result_format(
171 &mut self,
172 sql: &str,
173 params: &[Option<Vec<u8>>],
174 result_format: i16,
175 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
176 let mut retried = false;
177 loop {
178 match self
179 .query_cached_with_result_format_once(sql, params, result_format)
180 .await
181 {
182 Ok(rows) => return Ok(rows),
183 Err(err)
184 if !retried
185 && (err.is_prepared_statement_retryable()
186 || err.is_prepared_statement_already_exists()) =>
187 {
188 retried = true;
189 if err.is_prepared_statement_retryable() {
190 self.clear_prepared_statement_state();
191 }
192 }
193 Err(err) => return Err(err),
194 }
195 }
196 }
197
198 async fn query_cached_with_result_format_once(
199 &mut self,
200 sql: &str,
201 params: &[Option<Vec<u8>>],
202 result_format: i16,
203 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
204 let stmt_name = Self::sql_to_stmt_name(sql);
205 let is_new = !self.prepared_statements.contains_key(&stmt_name);
206
207 let params_size: usize = params
209 .iter()
210 .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
211 .sum();
212
213 let estimated_size = if is_new {
214 50 + sql.len() + stmt_name.len() * 2 + params_size
215 } else {
216 30 + stmt_name.len() + params_size
217 };
218
219 let mut buf = BytesMut::with_capacity(estimated_size);
220
221 if is_new {
222 self.evict_prepared_if_full();
226 buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
227 self.prepared_statements
229 .insert(stmt_name.clone(), sql.to_string());
230 }
231
232 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
234 &mut buf,
235 &stmt_name,
236 params,
237 result_format,
238 ) {
239 if is_new {
240 self.prepared_statements.remove(&stmt_name);
241 }
242 return Err(PgError::Encode(e.to_string()));
243 }
244 PgEncoder::encode_execute_to(&mut buf);
245 PgEncoder::encode_sync_to(&mut buf);
246
247 if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
248 if is_new {
249 self.prepared_statements.remove(&stmt_name);
250 }
251 return Err(err);
252 }
253
254 let mut rows = Vec::new();
255
256 let mut error: Option<PgError> = None;
257 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(is_new));
258
259 loop {
260 let msg = match self.recv().await {
261 Ok(msg) => msg,
262 Err(err) => {
263 if is_new && !flow.saw_parse_complete() {
264 self.prepared_statements.remove(&stmt_name);
265 }
266 return Err(err);
267 }
268 };
269 if let Err(err) = flow.validate(&msg, "extended-query cached execute", error.is_some())
270 {
271 if is_new && !flow.saw_parse_complete() {
272 self.prepared_statements.remove(&stmt_name);
273 }
274 return Err(err);
275 }
276 match msg {
277 BackendMessage::ParseComplete => {
278 }
280 BackendMessage::BindComplete => {}
281 BackendMessage::RowDescription(_) => {}
282 BackendMessage::DataRow(data) => {
283 if error.is_none() {
284 rows.push(data);
285 }
286 }
287 BackendMessage::CommandComplete(_) => {}
288 BackendMessage::NoData => {}
289 BackendMessage::ReadyForQuery(_) => {
290 if let Some(err) = error {
291 if is_new
292 && !flow.saw_parse_complete()
293 && !err.is_prepared_statement_already_exists()
294 {
295 self.prepared_statements.remove(&stmt_name);
296 }
297 return Err(err);
298 }
299 if is_new && !flow.saw_parse_complete() {
300 self.prepared_statements.remove(&stmt_name);
301 return Err(PgError::Protocol(
302 "Cache miss query reached ReadyForQuery without ParseComplete"
303 .to_string(),
304 ));
305 }
306 return Ok(rows);
307 }
308 BackendMessage::ErrorResponse(err) => {
309 if error.is_none() {
310 let query_err = PgError::QueryServer(err.into());
311 if !query_err.is_prepared_statement_already_exists() {
312 self.prepared_statements.remove(&stmt_name);
314 }
315 error = Some(query_err);
316 }
317 }
318 msg if is_ignorable_session_message(&msg) => {}
319 other => {
320 if is_new && !flow.saw_parse_complete() {
321 self.prepared_statements.remove(&stmt_name);
322 }
323 return Err(unexpected_backend_message(
324 "extended-query cached execute",
325 &other,
326 ));
327 }
328 }
329 }
330 }
331
332 pub(crate) fn sql_to_stmt_name(sql: &str) -> String {
335 use std::collections::hash_map::DefaultHasher;
336 use std::hash::{Hash, Hasher};
337
338 let mut hasher = DefaultHasher::new();
339 sql.hash(&mut hasher);
340 format!("s{:016x}", hasher.finish())
341 }
342
343 pub async fn execute_simple(&mut self, sql: &str) -> PgResult<()> {
345 let bytes = PgEncoder::try_encode_query_string(sql)?;
346 self.write_all_with_timeout(&bytes, "stream write").await?;
347
348 let mut error: Option<PgError> = None;
349 let mut flow = SimpleFlowTracker::new();
350
351 loop {
352 let msg = self.recv().await?;
353 match msg {
354 BackendMessage::CommandComplete(_) => {
355 flow.on_command_complete();
356 }
357 BackendMessage::EmptyQueryResponse => {
358 flow.on_empty_query_response("simple-query execute")?;
359 }
360 BackendMessage::ReadyForQuery(_) => {
361 if let Some(err) = error {
362 return Err(err);
363 }
364 flow.on_ready_for_query("simple-query execute", error.is_some())?;
365 return Ok(());
366 }
367 BackendMessage::ErrorResponse(err) => {
368 if error.is_none() {
369 error = Some(PgError::QueryServer(err.into()));
370 }
371 }
372 msg if is_ignorable_session_message(&msg) => {}
373 other => {
374 return Err(unexpected_backend_message("simple-query execute", &other));
375 }
376 }
377 }
378 }
379
380 pub async fn simple_query(&mut self, sql: &str) -> PgResult<Vec<super::PgRow>> {
387 use std::sync::Arc;
388
389 const MAX_SIMPLE_QUERY_ROWS: usize = 10_000;
392
393 let bytes = PgEncoder::try_encode_query_string(sql)?;
394 self.write_all_with_timeout(&bytes, "stream write").await?;
395
396 let mut rows: Vec<super::PgRow> = Vec::new();
397 let mut column_info: Option<Arc<super::ColumnInfo>> = None;
398 let mut error: Option<PgError> = None;
399 let mut flow = SimpleFlowTracker::new();
400
401 loop {
402 let msg = self.recv().await?;
403 match msg {
404 BackendMessage::RowDescription(fields) => {
405 flow.on_row_description("simple-query read")?;
406 column_info = Some(Arc::new(super::ColumnInfo::from_fields(&fields)));
407 }
408 BackendMessage::DataRow(data) => {
409 flow.on_data_row("simple-query read")?;
410 if error.is_none() {
411 if rows.len() >= MAX_SIMPLE_QUERY_ROWS {
412 if error.is_none() {
413 error = Some(PgError::Query(format!(
414 "simple_query exceeded {} row safety cap",
415 MAX_SIMPLE_QUERY_ROWS,
416 )));
417 }
418 } else {
420 rows.push(super::PgRow {
421 columns: data,
422 column_info: column_info.clone(),
423 });
424 }
425 }
426 }
427 BackendMessage::CommandComplete(_) => {
428 flow.on_command_complete();
429 column_info = None;
430 }
431 BackendMessage::EmptyQueryResponse => {
432 flow.on_empty_query_response("simple-query read")?;
433 column_info = None;
434 }
435 BackendMessage::ReadyForQuery(_) => {
436 if let Some(err) = error {
437 return Err(err);
438 }
439 flow.on_ready_for_query("simple-query read", error.is_some())?;
440 return Ok(rows);
441 }
442 BackendMessage::ErrorResponse(err) => {
443 if error.is_none() {
444 error = Some(PgError::QueryServer(err.into()));
445 }
446 }
447 msg if is_ignorable_session_message(&msg) => {}
448 other => {
449 return Err(unexpected_backend_message("simple-query read", &other));
450 }
451 }
452 }
453 }
454
455 #[inline]
468 pub async fn query_prepared_single(
469 &mut self,
470 stmt: &super::PreparedStatement,
471 params: &[Option<Vec<u8>>],
472 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
473 self.query_prepared_single_with_result_format(stmt, params, PgEncoder::FORMAT_TEXT)
474 .await
475 }
476
477 #[inline]
479 pub async fn query_prepared_single_with_result_format(
480 &mut self,
481 stmt: &super::PreparedStatement,
482 params: &[Option<Vec<u8>>],
483 result_format: i16,
484 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
485 let params_size: usize = params
487 .iter()
488 .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
489 .sum();
490
491 let mut buf = BytesMut::with_capacity(30 + stmt.name.len() + params_size);
493
494 PgEncoder::encode_bind_to_with_result_format(&mut buf, &stmt.name, params, result_format)
496 .map_err(|e| PgError::Encode(e.to_string()))?;
497 PgEncoder::encode_execute_to(&mut buf);
498 PgEncoder::encode_sync_to(&mut buf);
499
500 self.write_all_with_timeout(&buf, "stream write").await?;
501
502 let mut rows = Vec::new();
503
504 let mut error: Option<PgError> = None;
505 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(false));
506
507 loop {
508 let msg = self.recv().await?;
509 flow.validate(&msg, "prepared single execute", error.is_some())?;
510 match msg {
511 BackendMessage::BindComplete => {}
512 BackendMessage::RowDescription(_) => {}
513 BackendMessage::DataRow(data) => {
514 if error.is_none() {
515 rows.push(data);
516 }
517 }
518 BackendMessage::CommandComplete(_) => {}
519 BackendMessage::NoData => {}
520 BackendMessage::ReadyForQuery(_) => {
521 if let Some(err) = error {
522 return Err(err);
523 }
524 return Ok(rows);
525 }
526 BackendMessage::ErrorResponse(err) => {
527 if error.is_none() {
528 error = Some(PgError::QueryServer(err.into()));
529 }
530 }
531 msg if is_ignorable_session_message(&msg) => {}
532 other => {
533 return Err(unexpected_backend_message(
534 "prepared single execute",
535 &other,
536 ));
537 }
538 }
539 }
540 }
541}