1use crate::conversion::ToParams;
4use crate::error::{Error, Result};
5use crate::handler::BinaryHandler;
6use crate::protocol::backend::{
7 BindComplete, CloseComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse,
8 NoData, ParameterDescription, ParseComplete, PortalSuspended, RawMessage, ReadyForQuery,
9 RowDescription, msg_type,
10};
11use crate::protocol::frontend::{
12 write_bind, write_close_statement, write_describe_portal, write_describe_statement,
13 write_execute, write_parse, write_sync,
14};
15use crate::protocol::types::{Oid, TransactionStatus};
16
17use super::StateMachine;
18use super::action::{Action, AsyncMessage};
19use crate::buffer_set::BufferSet;
20
/// Phase of an `ExtendedQueryStateMachine` run: which backend reply it is
/// currently waiting for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Nothing has been sent yet; the first `step` picks the real wait state.
    Initial,
    /// Waiting for `ParseComplete`.
    WaitingParse,
    /// Waiting for `BindComplete`.
    WaitingBind,
    /// Waiting for the `ParameterDescription` of a described statement.
    WaitingDescribe,
    /// Waiting for `RowDescription` or `NoData` of a described statement.
    WaitingRowDesc,
    /// Receiving portal output (row description, data rows, completion).
    ProcessingRows,
    /// Waiting for `ReadyForQuery` (and tolerating `CloseComplete`).
    WaitingReady,
    /// `ReadyForQuery` has been consumed; the exchange is over.
    Finished,
}
33
/// Handle for a server-side prepared statement.
///
/// Either produced by `ExtendedQueryStateMachine::prepare` (retrieved with
/// `take_prepared_statement`) or built directly with `PreparedStatement::new`.
#[derive(Debug, Clone)]
pub struct PreparedStatement {
    /// Numeric id; the default wire name is `_zero_{idx}`.
    pub idx: u64,
    /// Parameter type OIDs. For statements created by `prepare`, these are
    /// filled in from the server's ParameterDescription reply.
    pub param_oids: Vec<Oid>,
    /// Raw RowDescription payload, kept as bytes and decoded on demand.
    /// `None` when the server answered NoData or no description was supplied.
    row_desc_payload: Option<Vec<u8>>,
    /// When set, used instead of the default `_zero_{idx}` wire name.
    custom_wire_name: Option<String>,
}
46
47impl PreparedStatement {
48 pub fn new(
50 idx: u64,
51 param_oids: Vec<Oid>,
52 row_desc_payload: Option<Vec<u8>>,
53 wire_name: String,
54 ) -> Self {
55 Self {
56 idx,
57 param_oids,
58 row_desc_payload,
59 custom_wire_name: Some(wire_name),
60 }
61 }
62
63 pub fn wire_name(&self) -> String {
65 if let Some(name) = &self.custom_wire_name {
66 name.clone()
67 } else {
68 format!("_zero_{}", self.idx)
69 }
70 }
71
72 pub fn parse_columns(&self) -> Option<Result<RowDescription<'_>>> {
76 self.row_desc_payload
77 .as_ref()
78 .map(|bytes| RowDescription::parse(bytes))
79 }
80
81 pub fn row_desc_payload(&self) -> Option<&[u8]> {
85 self.row_desc_payload.as_deref()
86 }
87}
88
/// Kind of request an `ExtendedQueryStateMachine` was constructed for; it
/// decides which state the first `step` moves into.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Operation {
    /// Parse + Describe(statement) + Sync for a named statement.
    Prepare,
    /// Bind + Describe(portal) + Execute + Sync of an existing statement.
    Execute,
    /// Parse + Bind + Describe(portal) + Execute + Sync of the unnamed statement.
    ExecuteSql,
    /// Close(statement) + Sync.
    CloseStatement,
}
101
102pub struct ExtendedQueryStateMachine<'a, H> {
104 state: State,
105 handler: &'a mut H,
106 operation: Operation,
107 transaction_status: TransactionStatus,
108 prepared_stmt: Option<PreparedStatement>,
109}
110
111impl<'a, H: BinaryHandler> ExtendedQueryStateMachine<'a, H> {
112 pub fn take_prepared_statement(&mut self) -> Option<PreparedStatement> {
114 self.prepared_stmt.take()
115 }
116
117 pub fn prepare(
121 handler: &'a mut H,
122 buffer_set: &mut BufferSet,
123 idx: u64,
124 query: &str,
125 param_oids: &[Oid],
126 ) -> Self {
127 let stmt_name = format!("_zero_{}", idx);
128 buffer_set.write_buffer.clear();
129 write_parse(&mut buffer_set.write_buffer, &stmt_name, query, param_oids);
130 write_describe_statement(&mut buffer_set.write_buffer, &stmt_name);
131 write_sync(&mut buffer_set.write_buffer);
132
133 Self {
134 state: State::Initial,
135 handler,
136 operation: Operation::Prepare,
137 transaction_status: TransactionStatus::Idle,
138 prepared_stmt: Some(PreparedStatement {
139 idx,
140 param_oids: Vec::new(),
141 row_desc_payload: None,
142 custom_wire_name: None,
143 }),
144 }
145 }
146
147 pub fn execute<P: ToParams>(
154 handler: &'a mut H,
155 buffer_set: &mut BufferSet,
156 statement_name: &str,
157 param_oids: &[Oid],
158 params: &P,
159 ) -> Result<Self> {
160 buffer_set.write_buffer.clear();
161 write_bind(
162 &mut buffer_set.write_buffer,
163 "",
164 statement_name,
165 params,
166 param_oids,
167 )?;
168 write_describe_portal(&mut buffer_set.write_buffer, "");
169 write_execute(&mut buffer_set.write_buffer, "", 0);
170 write_sync(&mut buffer_set.write_buffer);
171
172 Ok(Self {
173 state: State::Initial,
174 handler,
175 operation: Operation::Execute,
176 transaction_status: TransactionStatus::Idle,
177 prepared_stmt: None,
178 })
179 }
180
181 pub fn execute_sql<P: ToParams>(
189 handler: &'a mut H,
190 buffer_set: &mut BufferSet,
191 sql: &str,
192 params: &P,
193 ) -> Result<Self> {
194 let param_oids = params.natural_oids();
195 buffer_set.write_buffer.clear();
196 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
197 write_bind(&mut buffer_set.write_buffer, "", "", params, ¶m_oids)?;
198 write_describe_portal(&mut buffer_set.write_buffer, "");
199 write_execute(&mut buffer_set.write_buffer, "", 0);
200 write_sync(&mut buffer_set.write_buffer);
201
202 Ok(Self {
203 state: State::Initial,
204 handler,
205 operation: Operation::ExecuteSql,
206 transaction_status: TransactionStatus::Idle,
207 prepared_stmt: None,
208 })
209 }
210
211 pub fn close_statement(handler: &'a mut H, buffer_set: &mut BufferSet, name: &str) -> Self {
215 buffer_set.write_buffer.clear();
216 write_close_statement(&mut buffer_set.write_buffer, name);
217 write_sync(&mut buffer_set.write_buffer);
218
219 Self {
220 state: State::Initial,
221 handler,
222 operation: Operation::CloseStatement,
223 transaction_status: TransactionStatus::Idle,
224 prepared_stmt: None,
225 }
226 }
227
228 fn handle_parse(&mut self, buffer_set: &BufferSet) -> Result<Action> {
229 let type_byte = buffer_set.type_byte;
230 if type_byte != msg_type::PARSE_COMPLETE {
231 return Err(Error::Protocol(format!(
232 "Expected ParseComplete, got '{}'",
233 type_byte as char
234 )));
235 }
236
237 ParseComplete::parse(&buffer_set.read_buffer)?;
238 self.state = match self.operation {
241 Operation::ExecuteSql => State::WaitingBind,
242 Operation::Prepare => State::WaitingDescribe,
243 _ => unreachable!("handle_parse called for non-parse operation"),
244 };
245 Ok(Action::ReadMessage)
246 }
247
248 fn handle_describe(&mut self, buffer_set: &BufferSet) -> Result<Action> {
249 let type_byte = buffer_set.type_byte;
250 if type_byte != msg_type::PARAMETER_DESCRIPTION {
251 return Err(Error::Protocol(format!(
252 "Expected ParameterDescription, got '{}'",
253 type_byte as char
254 )));
255 }
256
257 let param_desc = ParameterDescription::parse(&buffer_set.read_buffer)?;
258 if let Some(ref mut stmt) = self.prepared_stmt {
259 stmt.param_oids = param_desc.oids().to_vec();
260 }
261
262 self.state = State::WaitingRowDesc;
263 Ok(Action::ReadMessage)
264 }
265
266 fn handle_row_desc(&mut self, buffer_set: &BufferSet) -> Result<Action> {
267 let type_byte = buffer_set.type_byte;
268
269 match type_byte {
270 msg_type::ROW_DESCRIPTION => {
271 if let Some(ref mut stmt) = self.prepared_stmt {
272 stmt.row_desc_payload = Some(buffer_set.read_buffer.clone());
273 }
274 self.state = State::WaitingReady;
275 Ok(Action::ReadMessage)
276 }
277 msg_type::NO_DATA => {
278 let payload = &buffer_set.read_buffer;
279 NoData::parse(payload)?;
280 self.state = State::WaitingReady;
282 Ok(Action::ReadMessage)
283 }
284 _ => Err(Error::Protocol(format!(
285 "Expected RowDescription or NoData, got '{}'",
286 type_byte as char
287 ))),
288 }
289 }
290
291 fn handle_bind(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
292 let type_byte = buffer_set.type_byte;
293
294 match type_byte {
295 msg_type::BIND_COMPLETE => {
296 BindComplete::parse(&buffer_set.read_buffer)?;
297 self.state = State::ProcessingRows;
298 Ok(Action::ReadMessage)
299 }
300 msg_type::ROW_DESCRIPTION => {
301 buffer_set.column_buffer.clear();
303 buffer_set
304 .column_buffer
305 .extend_from_slice(&buffer_set.read_buffer);
306 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
307 self.handler.result_start(cols)?;
308 self.state = State::ProcessingRows;
309 Ok(Action::ReadMessage)
310 }
311 _ => Err(Error::Protocol(format!(
312 "Expected BindComplete, got '{}'",
313 type_byte as char
314 ))),
315 }
316 }
317
318 fn handle_rows(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
319 let type_byte = buffer_set.type_byte;
320 let payload = &buffer_set.read_buffer;
321
322 match type_byte {
323 msg_type::ROW_DESCRIPTION => {
324 buffer_set.column_buffer.clear();
326 buffer_set.column_buffer.extend_from_slice(payload);
327 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
328 self.handler.result_start(cols)?;
329 Ok(Action::ReadMessage)
330 }
331 msg_type::NO_DATA => {
332 NoData::parse(payload)?;
334 Ok(Action::ReadMessage)
335 }
336 msg_type::DATA_ROW => {
337 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
338 let row = DataRow::parse(payload)?;
339 self.handler.row(cols, row)?;
340 Ok(Action::ReadMessage)
341 }
342 msg_type::COMMAND_COMPLETE => {
343 let complete = CommandComplete::parse(payload)?;
344 self.handler.result_end(complete)?;
345 self.state = State::WaitingReady;
346 Ok(Action::ReadMessage)
347 }
348 msg_type::EMPTY_QUERY_RESPONSE => {
349 EmptyQueryResponse::parse(payload)?;
350 self.state = State::WaitingReady;
352 Ok(Action::ReadMessage)
353 }
354 msg_type::PORTAL_SUSPENDED => {
355 PortalSuspended::parse(payload)?;
356 self.state = State::WaitingReady;
358 Ok(Action::ReadMessage)
359 }
360 msg_type::READY_FOR_QUERY => {
361 let ready = ReadyForQuery::parse(payload)?;
362 self.transaction_status = ready.transaction_status().unwrap_or_default();
363 self.state = State::Finished;
364 Ok(Action::Finished)
365 }
366 _ => Err(Error::Protocol(format!(
367 "Unexpected message in rows: '{}'",
368 type_byte as char
369 ))),
370 }
371 }
372
373 fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
374 let type_byte = buffer_set.type_byte;
375 let payload = &buffer_set.read_buffer;
376
377 match type_byte {
378 msg_type::READY_FOR_QUERY => {
379 let ready = ReadyForQuery::parse(payload)?;
380 self.transaction_status = ready.transaction_status().unwrap_or_default();
381 self.state = State::Finished;
382 Ok(Action::Finished)
383 }
384 msg_type::CLOSE_COMPLETE => {
385 CloseComplete::parse(payload)?;
386 Ok(Action::ReadMessage)
388 }
389 _ => Err(Error::Protocol(format!(
390 "Expected ReadyForQuery, got '{}'",
391 type_byte as char
392 ))),
393 }
394 }
395
396 fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
397 match msg.type_byte {
398 msg_type::NOTICE_RESPONSE => {
399 let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
400 Ok(Action::HandleAsyncMessageAndReadMessage(
401 AsyncMessage::Notice(notice.0),
402 ))
403 }
404 msg_type::PARAMETER_STATUS => {
405 let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
406 Ok(Action::HandleAsyncMessageAndReadMessage(
407 AsyncMessage::ParameterChanged {
408 name: param.name.to_string(),
409 value: param.value.to_string(),
410 },
411 ))
412 }
413 msg_type::NOTIFICATION_RESPONSE => {
414 let notification =
415 crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
416 Ok(Action::HandleAsyncMessageAndReadMessage(
417 AsyncMessage::Notification {
418 pid: notification.pid,
419 channel: notification.channel.to_string(),
420 payload: notification.payload.to_string(),
421 },
422 ))
423 }
424 _ => Err(Error::Protocol(format!(
425 "Unknown async message type: '{}'",
426 msg.type_byte as char
427 ))),
428 }
429 }
430}
431
432impl<H: BinaryHandler> StateMachine for ExtendedQueryStateMachine<'_, H> {
433 fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
434 if self.state == State::Initial {
436 self.state = match self.operation {
438 Operation::Prepare => State::WaitingParse,
439 Operation::Execute => State::WaitingBind, Operation::ExecuteSql => State::WaitingParse,
441 Operation::CloseStatement => State::WaitingReady,
442 };
443 return Ok(Action::WriteAndReadMessage);
444 }
445
446 let type_byte = buffer_set.type_byte;
447
448 if RawMessage::is_async_type(type_byte) {
450 let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
451 return self.handle_async_message(&msg);
452 }
453
454 if type_byte == msg_type::ERROR_RESPONSE {
456 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
457 self.state = State::WaitingReady;
459 return Err(error.into_error());
460 }
461
462 match self.state {
463 State::WaitingParse => self.handle_parse(buffer_set),
464 State::WaitingDescribe => self.handle_describe(buffer_set),
465 State::WaitingRowDesc => self.handle_row_desc(buffer_set),
466 State::WaitingBind => self.handle_bind(buffer_set),
467 State::ProcessingRows => self.handle_rows(buffer_set),
468 State::WaitingReady => self.handle_ready(buffer_set),
469 _ => Err(Error::Protocol(format!(
470 "Unexpected state {:?}",
471 self.state
472 ))),
473 }
474 }
475
476 fn transaction_status(&self) -> TransactionStatus {
477 self.transaction_status
478 }
479}
480
481use crate::protocol::frontend::write_flush;
485
/// Phase of a `BindStateMachine` run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BindState {
    /// Nothing sent yet.
    Initial,
    /// Waiting for `ParseComplete` (only when a Parse was queued).
    WaitingParse,
    /// Waiting for `BindComplete`.
    WaitingBind,
    /// `BindComplete` has been consumed.
    Finished,
}
494
495pub struct BindStateMachine {
499 state: BindState,
500 needs_parse: bool,
501}
502
503impl BindStateMachine {
504 pub fn bind_prepared<P: ToParams>(
513 buffer_set: &mut BufferSet,
514 portal_name: &str,
515 statement_name: &str,
516 param_oids: &[Oid],
517 params: &P,
518 ) -> Result<Self> {
519 buffer_set.write_buffer.clear();
520 write_bind(
521 &mut buffer_set.write_buffer,
522 portal_name,
523 statement_name,
524 params,
525 param_oids,
526 )?;
527 write_flush(&mut buffer_set.write_buffer);
528
529 Ok(Self {
530 state: BindState::Initial,
531 needs_parse: false,
532 })
533 }
534
535 pub fn bind_sql<P: ToParams>(
544 buffer_set: &mut BufferSet,
545 portal_name: &str,
546 sql: &str,
547 params: &P,
548 ) -> Result<Self> {
549 let param_oids = params.natural_oids();
550 buffer_set.write_buffer.clear();
551 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
552 write_bind(
553 &mut buffer_set.write_buffer,
554 portal_name,
555 "",
556 params,
557 ¶m_oids,
558 )?;
559 write_flush(&mut buffer_set.write_buffer);
560
561 Ok(Self {
562 state: BindState::Initial,
563 needs_parse: true,
564 })
565 }
566
567 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
569 if self.state == BindState::Initial {
571 self.state = if self.needs_parse {
572 BindState::WaitingParse
573 } else {
574 BindState::WaitingBind
575 };
576 return Ok(Action::WriteAndReadMessage);
577 }
578
579 let type_byte = buffer_set.type_byte;
580
581 if RawMessage::is_async_type(type_byte) {
583 return Ok(Action::ReadMessage);
584 }
585
586 if type_byte == msg_type::ERROR_RESPONSE {
588 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
589 return Err(error.into_error());
590 }
591
592 match self.state {
593 BindState::WaitingParse => {
594 if type_byte != msg_type::PARSE_COMPLETE {
595 return Err(Error::Protocol(format!(
596 "Expected ParseComplete, got '{}'",
597 type_byte as char
598 )));
599 }
600 ParseComplete::parse(&buffer_set.read_buffer)?;
601 self.state = BindState::WaitingBind;
602 Ok(Action::ReadMessage)
603 }
604 BindState::WaitingBind => {
605 if type_byte != msg_type::BIND_COMPLETE {
606 return Err(Error::Protocol(format!(
607 "Expected BindComplete, got '{}'",
608 type_byte as char
609 )));
610 }
611 BindComplete::parse(&buffer_set.read_buffer)?;
612 self.state = BindState::Finished;
613 Ok(Action::Finished)
614 }
615 _ => Err(Error::Protocol(format!(
616 "Unexpected state {:?}",
617 self.state
618 ))),
619 }
620 }
621}
622
/// Phase of a `BatchStateMachine` run.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BatchState {
    /// Nothing sent yet.
    Initial,
    /// Waiting for `ParseComplete` (only when the batch begins with a Parse).
    WaitingParse,
    /// Consuming per-statement replies until ReadyForQuery.
    Processing,
    /// ReadyForQuery has been consumed.
    Finished,
}
634
635pub struct BatchStateMachine {
639 state: BatchState,
640 needs_parse: bool,
641 transaction_status: TransactionStatus,
642}
643
644impl BatchStateMachine {
645 pub fn new(needs_parse: bool) -> Self {
652 Self {
653 state: BatchState::Initial,
654 needs_parse,
655 transaction_status: TransactionStatus::Idle,
656 }
657 }
658
659 pub fn transaction_status(&self) -> TransactionStatus {
661 self.transaction_status
662 }
663
664 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
666 if self.state == BatchState::Initial {
668 self.state = if self.needs_parse {
669 BatchState::WaitingParse
670 } else {
671 BatchState::Processing
672 };
673 return Ok(Action::WriteAndReadMessage);
674 }
675
676 let type_byte = buffer_set.type_byte;
677
678 if RawMessage::is_async_type(type_byte) {
680 return Ok(Action::ReadMessage);
681 }
682
683 if type_byte == msg_type::ERROR_RESPONSE {
685 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
686 self.state = BatchState::Processing;
687 return Err(error.into_error());
688 }
689
690 match self.state {
691 BatchState::WaitingParse => {
692 if type_byte != msg_type::PARSE_COMPLETE {
693 return Err(Error::Protocol(format!(
694 "Expected ParseComplete, got '{}'",
695 type_byte as char
696 )));
697 }
698 ParseComplete::parse(&buffer_set.read_buffer)?;
699 self.state = BatchState::Processing;
700 Ok(Action::ReadMessage)
701 }
702 BatchState::Processing => {
703 match type_byte {
704 msg_type::BIND_COMPLETE => {
705 BindComplete::parse(&buffer_set.read_buffer)?;
706 Ok(Action::ReadMessage)
707 }
708 msg_type::NO_DATA => {
709 NoData::parse(&buffer_set.read_buffer)?;
710 Ok(Action::ReadMessage)
711 }
712 msg_type::ROW_DESCRIPTION => {
713 RowDescription::parse(&buffer_set.read_buffer)?;
715 Ok(Action::ReadMessage)
716 }
717 msg_type::DATA_ROW => {
718 Ok(Action::ReadMessage)
720 }
721 msg_type::COMMAND_COMPLETE => {
722 CommandComplete::parse(&buffer_set.read_buffer)?;
723 Ok(Action::ReadMessage)
724 }
725 msg_type::EMPTY_QUERY_RESPONSE => {
726 EmptyQueryResponse::parse(&buffer_set.read_buffer)?;
727 Ok(Action::ReadMessage)
728 }
729 msg_type::READY_FOR_QUERY => {
730 let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
731 self.transaction_status = ready.transaction_status().unwrap_or_default();
732 self.state = BatchState::Finished;
733 Ok(Action::Finished)
734 }
735 _ => Err(Error::Protocol(format!(
736 "Unexpected message in batch: '{}'",
737 type_byte as char
738 ))),
739 }
740 }
741 _ => Err(Error::Protocol(format!(
742 "Unexpected state {:?}",
743 self.state
744 ))),
745 }
746 }
747}