1use super::{
6 PgConnection, PgError, PgResult,
7 extended_flow::{ExtendedFlowConfig, ExtendedFlowTracker},
8 is_ignorable_session_message, unexpected_backend_message,
9};
10use crate::protocol::{BackendMessage, PgEncoder};
11use bytes::BytesMut;
12
/// Row-stream state of the statement currently being answered during a
/// simple-query round trip.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SimpleStatementState {
    /// No `RowDescription` is open: before the first statement's result or
    /// after the previous statement's `CommandComplete`.
    AwaitingResult,
    /// A `RowDescription` has been received and its `CommandComplete` has not.
    InRowStream,
}
18
19#[derive(Debug, Clone, Copy)]
20struct SimpleFlowTracker {
21 state: SimpleStatementState,
22 saw_completion: bool,
23}
24
25impl SimpleFlowTracker {
26 fn new() -> Self {
27 Self {
28 state: SimpleStatementState::AwaitingResult,
29 saw_completion: false,
30 }
31 }
32
33 fn on_row_description(&mut self, context: &'static str) -> PgResult<()> {
34 if self.state == SimpleStatementState::InRowStream {
35 return Err(PgError::Protocol(format!(
36 "{}: duplicate RowDescription before statement completion",
37 context
38 )));
39 }
40 self.state = SimpleStatementState::InRowStream;
41 self.saw_completion = false;
42 Ok(())
43 }
44
45 fn on_data_row(&self, context: &'static str) -> PgResult<()> {
46 if self.state != SimpleStatementState::InRowStream {
47 return Err(PgError::Protocol(format!(
48 "{}: DataRow before RowDescription",
49 context
50 )));
51 }
52 Ok(())
53 }
54
55 fn on_command_complete(&mut self) {
56 self.state = SimpleStatementState::AwaitingResult;
57 self.saw_completion = true;
58 }
59
60 fn on_empty_query_response(&mut self, context: &'static str) -> PgResult<()> {
61 if self.state == SimpleStatementState::InRowStream {
62 return Err(PgError::Protocol(format!(
63 "{}: EmptyQueryResponse during active row stream",
64 context
65 )));
66 }
67 self.saw_completion = true;
68 Ok(())
69 }
70
71 fn on_ready_for_query(&self, context: &'static str, error_pending: bool) -> PgResult<()> {
72 if error_pending {
73 return Ok(());
74 }
75 if self.state == SimpleStatementState::InRowStream {
76 return Err(PgError::Protocol(format!(
77 "{}: ReadyForQuery before CommandComplete",
78 context
79 )));
80 }
81 if !self.saw_completion {
82 return Err(PgError::Protocol(format!(
83 "{}: ReadyForQuery before completion",
84 context
85 )));
86 }
87 Ok(())
88 }
89}
90
91impl PgConnection {
92 pub(crate) async fn query(
98 &mut self,
99 sql: &str,
100 params: &[Option<Vec<u8>>],
101 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
102 self.query_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
103 .await
104 }
105
106 pub(crate) async fn query_with_result_format(
108 &mut self,
109 sql: &str,
110 params: &[Option<Vec<u8>>],
111 result_format: i16,
112 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
113 let bytes = PgEncoder::encode_extended_query_with_result_format(sql, params, result_format)
114 .map_err(|e| PgError::Encode(e.to_string()))?;
115 self.write_all_with_timeout(&bytes, "stream write").await?;
116
117 let mut rows = Vec::new();
118
119 let mut error: Option<PgError> = None;
120 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(true));
121
122 loop {
123 let msg = self.recv().await?;
124 flow.validate(&msg, "extended-query execute", error.is_some())?;
125 match msg {
126 BackendMessage::ParseComplete => {}
127 BackendMessage::BindComplete => {}
128 BackendMessage::RowDescription(_) => {}
129 BackendMessage::DataRow(data) => {
130 if error.is_none() {
132 rows.push(data);
133 }
134 }
135 BackendMessage::CommandComplete(_) => {}
136 BackendMessage::NoData => {}
137 BackendMessage::ReadyForQuery(_) => {
138 if let Some(err) = error {
139 return Err(err);
140 }
141 return Ok(rows);
142 }
143 BackendMessage::ErrorResponse(err) => {
144 if error.is_none() {
145 error = Some(PgError::QueryServer(err.into()));
146 }
147 }
148 msg if is_ignorable_session_message(&msg) => {}
149 other => {
150 return Err(unexpected_backend_message("extended-query execute", &other));
151 }
152 }
153 }
154 }
155
156 pub async fn query_cached(
161 &mut self,
162 sql: &str,
163 params: &[Option<Vec<u8>>],
164 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
165 self.query_cached_with_result_format(sql, params, PgEncoder::FORMAT_TEXT)
166 .await
167 }
168
169 pub async fn query_cached_with_result_format(
171 &mut self,
172 sql: &str,
173 params: &[Option<Vec<u8>>],
174 result_format: i16,
175 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
176 let mut retried = false;
177 loop {
178 match self
179 .query_cached_with_result_format_once(sql, params, result_format)
180 .await
181 {
182 Ok(rows) => return Ok(rows),
183 Err(err)
184 if !retried
185 && (err.is_prepared_statement_retryable()
186 || err.is_prepared_statement_already_exists()) =>
187 {
188 retried = true;
189 if err.is_prepared_statement_retryable() {
190 self.clear_prepared_statement_state();
191 }
192 }
193 Err(err) => return Err(err),
194 }
195 }
196 }
197
198 async fn query_cached_with_result_format_once(
199 &mut self,
200 sql: &str,
201 params: &[Option<Vec<u8>>],
202 result_format: i16,
203 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
204 let stmt_name = Self::sql_to_stmt_name(sql);
205 let is_new = !self.prepared_statements.contains_key(&stmt_name);
206
207 let params_size: usize = params
209 .iter()
210 .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
211 .sum();
212
213 let estimated_size = if is_new {
214 50 + sql.len() + stmt_name.len() * 2 + params_size
215 } else {
216 30 + stmt_name.len() + params_size
217 };
218
219 let mut buf = BytesMut::with_capacity(estimated_size);
220
221 if is_new {
222 self.evict_prepared_if_full();
226 buf.extend(PgEncoder::try_encode_parse(&stmt_name, sql, &[])?);
227 self.prepared_statements
229 .insert(stmt_name.clone(), sql.to_string());
230 }
231
232 if let Err(e) = PgEncoder::encode_bind_to_with_result_format(
234 &mut buf,
235 &stmt_name,
236 params,
237 result_format,
238 ) {
239 if is_new {
240 self.prepared_statements.remove(&stmt_name);
241 }
242 return Err(PgError::Encode(e.to_string()));
243 }
244 PgEncoder::encode_execute_to(&mut buf);
245 PgEncoder::encode_sync_to(&mut buf);
246
247 if let Err(err) = self.write_all_with_timeout(&buf, "stream write").await {
248 if is_new {
249 self.prepared_statements.remove(&stmt_name);
250 }
251 return Err(err);
252 }
253
254 let mut rows = Vec::new();
255
256 let mut error: Option<PgError> = None;
257 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(is_new));
258
259 loop {
260 let msg = match self.recv().await {
261 Ok(msg) => msg,
262 Err(err) => {
263 if is_new && !flow.saw_parse_complete() {
264 self.prepared_statements.remove(&stmt_name);
265 }
266 return Err(err);
267 }
268 };
269 if let Err(err) = flow.validate(&msg, "extended-query cached execute", error.is_some())
270 {
271 if is_new && !flow.saw_parse_complete() {
272 self.prepared_statements.remove(&stmt_name);
273 }
274 return Err(err);
275 }
276 match msg {
277 BackendMessage::ParseComplete => {
278 }
280 BackendMessage::BindComplete => {}
281 BackendMessage::RowDescription(_) => {}
282 BackendMessage::DataRow(data) => {
283 if error.is_none() {
284 rows.push(data);
285 }
286 }
287 BackendMessage::CommandComplete(_) => {}
288 BackendMessage::NoData => {}
289 BackendMessage::ReadyForQuery(_) => {
290 if let Some(err) = error {
291 if is_new
292 && !flow.saw_parse_complete()
293 && !err.is_prepared_statement_already_exists()
294 {
295 self.prepared_statements.remove(&stmt_name);
296 }
297 return Err(err);
298 }
299 if is_new && !flow.saw_parse_complete() {
300 self.prepared_statements.remove(&stmt_name);
301 return Err(PgError::Protocol(
302 "Cache miss query reached ReadyForQuery without ParseComplete"
303 .to_string(),
304 ));
305 }
306 return Ok(rows);
307 }
308 BackendMessage::ErrorResponse(err) => {
309 if error.is_none() {
310 let query_err = PgError::QueryServer(err.into());
311 if !query_err.is_prepared_statement_already_exists() {
312 self.prepared_statements.remove(&stmt_name);
314 }
315 error = Some(query_err);
316 }
317 }
318 msg if is_ignorable_session_message(&msg) => {}
319 other => {
320 if is_new && !flow.saw_parse_complete() {
321 self.prepared_statements.remove(&stmt_name);
322 }
323 return Err(unexpected_backend_message(
324 "extended-query cached execute",
325 &other,
326 ));
327 }
328 }
329 }
330 }
331
332 pub(crate) fn sql_to_stmt_name(sql: &str) -> String {
335 use std::collections::hash_map::DefaultHasher;
336 use std::hash::{Hash, Hasher};
337
338 let mut hasher = DefaultHasher::new();
339 sql.hash(&mut hasher);
340 format!("s{:016x}", hasher.finish())
341 }
342
343 pub async fn execute_simple(&mut self, sql: &str) -> PgResult<()> {
345 let bytes = PgEncoder::try_encode_query_string(sql)?;
346 self.write_all_with_timeout(&bytes, "stream write").await?;
347
348 let mut error: Option<PgError> = None;
349 let mut flow = SimpleFlowTracker::new();
350
351 loop {
352 let msg = self.recv().await?;
353 match msg {
354 BackendMessage::RowDescription(_) => {
355 flow.on_row_description("simple-query execute")?;
359 }
360 BackendMessage::DataRow(_) => {
361 flow.on_data_row("simple-query execute")?;
362 }
363 BackendMessage::CommandComplete(_) => {
364 flow.on_command_complete();
365 }
366 BackendMessage::EmptyQueryResponse => {
367 flow.on_empty_query_response("simple-query execute")?;
368 }
369 BackendMessage::ReadyForQuery(_) => {
370 if let Some(err) = error {
371 return Err(err);
372 }
373 flow.on_ready_for_query("simple-query execute", error.is_some())?;
374 return Ok(());
375 }
376 BackendMessage::ErrorResponse(err) => {
377 if error.is_none() {
378 error = Some(PgError::QueryServer(err.into()));
379 }
380 }
381 msg if is_ignorable_session_message(&msg) => {}
382 other => {
383 return Err(unexpected_backend_message("simple-query execute", &other));
384 }
385 }
386 }
387 }
388
389 pub async fn simple_query(&mut self, sql: &str) -> PgResult<Vec<super::PgRow>> {
396 use std::sync::Arc;
397
398 const MAX_SIMPLE_QUERY_ROWS: usize = 10_000;
401
402 let bytes = PgEncoder::try_encode_query_string(sql)?;
403 self.write_all_with_timeout(&bytes, "stream write").await?;
404
405 let mut rows: Vec<super::PgRow> = Vec::new();
406 let mut column_info: Option<Arc<super::ColumnInfo>> = None;
407 let mut error: Option<PgError> = None;
408 let mut flow = SimpleFlowTracker::new();
409
410 loop {
411 let msg = self.recv().await?;
412 match msg {
413 BackendMessage::RowDescription(fields) => {
414 flow.on_row_description("simple-query read")?;
415 column_info = Some(Arc::new(super::ColumnInfo::from_fields(&fields)));
416 }
417 BackendMessage::DataRow(data) => {
418 flow.on_data_row("simple-query read")?;
419 if error.is_none() {
420 if rows.len() >= MAX_SIMPLE_QUERY_ROWS {
421 if error.is_none() {
422 error = Some(PgError::Query(format!(
423 "simple_query exceeded {} row safety cap",
424 MAX_SIMPLE_QUERY_ROWS,
425 )));
426 }
427 } else {
429 rows.push(super::PgRow {
430 columns: data,
431 column_info: column_info.clone(),
432 });
433 }
434 }
435 }
436 BackendMessage::CommandComplete(_) => {
437 flow.on_command_complete();
438 column_info = None;
439 }
440 BackendMessage::EmptyQueryResponse => {
441 flow.on_empty_query_response("simple-query read")?;
442 column_info = None;
443 }
444 BackendMessage::ReadyForQuery(_) => {
445 if let Some(err) = error {
446 return Err(err);
447 }
448 flow.on_ready_for_query("simple-query read", error.is_some())?;
449 return Ok(rows);
450 }
451 BackendMessage::ErrorResponse(err) => {
452 if error.is_none() {
453 error = Some(PgError::QueryServer(err.into()));
454 }
455 }
456 msg if is_ignorable_session_message(&msg) => {}
457 other => {
458 return Err(unexpected_backend_message("simple-query read", &other));
459 }
460 }
461 }
462 }
463
464 #[inline]
477 pub async fn query_prepared_single(
478 &mut self,
479 stmt: &super::PreparedStatement,
480 params: &[Option<Vec<u8>>],
481 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
482 self.query_prepared_single_with_result_format(stmt, params, PgEncoder::FORMAT_TEXT)
483 .await
484 }
485
486 #[inline]
488 pub async fn query_prepared_single_with_result_format(
489 &mut self,
490 stmt: &super::PreparedStatement,
491 params: &[Option<Vec<u8>>],
492 result_format: i16,
493 ) -> PgResult<Vec<Vec<Option<Vec<u8>>>>> {
494 let params_size: usize = params
496 .iter()
497 .map(|p| 4 + p.as_ref().map_or(0, |v| v.len()))
498 .sum();
499
500 let mut buf = BytesMut::with_capacity(30 + stmt.name.len() + params_size);
502
503 PgEncoder::encode_bind_to_with_result_format(&mut buf, &stmt.name, params, result_format)
505 .map_err(|e| PgError::Encode(e.to_string()))?;
506 PgEncoder::encode_execute_to(&mut buf);
507 PgEncoder::encode_sync_to(&mut buf);
508
509 self.write_all_with_timeout(&buf, "stream write").await?;
510
511 let mut rows = Vec::new();
512
513 let mut error: Option<PgError> = None;
514 let mut flow = ExtendedFlowTracker::new(ExtendedFlowConfig::parse_bind_execute(false));
515
516 loop {
517 let msg = self.recv().await?;
518 flow.validate(&msg, "prepared single execute", error.is_some())?;
519 match msg {
520 BackendMessage::BindComplete => {}
521 BackendMessage::RowDescription(_) => {}
522 BackendMessage::DataRow(data) => {
523 if error.is_none() {
524 rows.push(data);
525 }
526 }
527 BackendMessage::CommandComplete(_) => {}
528 BackendMessage::NoData => {}
529 BackendMessage::ReadyForQuery(_) => {
530 if let Some(err) = error {
531 return Err(err);
532 }
533 return Ok(rows);
534 }
535 BackendMessage::ErrorResponse(err) => {
536 if error.is_none() {
537 error = Some(PgError::QueryServer(err.into()));
538 }
539 }
540 msg if is_ignorable_session_message(&msg) => {}
541 other => {
542 return Err(unexpected_backend_message(
543 "prepared single execute",
544 &other,
545 ));
546 }
547 }
548 }
549 }
550}