1use crate::conversion::ToParams;
4use crate::error::{Error, Result};
5use crate::handler::ExtendedHandler;
6use crate::protocol::backend::{
7 BindComplete, CloseComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse,
8 NoData, ParameterDescription, ParseComplete, PortalSuspended, RawMessage, ReadyForQuery,
9 RowDescription, msg_type,
10};
11use crate::protocol::frontend::{
12 write_bind, write_close_statement, write_describe_portal, write_describe_statement,
13 write_execute, write_parse, write_sync,
14};
15use crate::protocol::types::{Oid, TransactionStatus};
16
17use super::StateMachine;
18use super::action::{Action, AsyncMessage};
19use crate::buffer_set::BufferSet;
20
/// Which backend response an `ExtendedQueryStateMachine` expects next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Request not sent yet; the first `step` asks the caller to write and read.
    Initial,
    /// Expecting ParseComplete.
    WaitingParse,
    /// Expecting BindComplete (a RowDescription is also tolerated here).
    WaitingBind,
    /// Expecting the ParameterDescription of a described statement.
    WaitingDescribe,
    /// Expecting RowDescription or NoData of a described statement.
    WaitingRowDesc,
    /// Consuming result messages until CommandComplete / EmptyQueryResponse /
    /// PortalSuspended (or ReadyForQuery).
    ProcessingRows,
    /// Expecting ReadyForQuery (CloseComplete is tolerated); also entered after
    /// an ErrorResponse.
    WaitingReady,
    /// ReadyForQuery has been received.
    Finished,
}
33
34#[derive(Debug, Clone)]
36pub struct PreparedStatement {
37 pub idx: u64,
39 pub param_oids: Vec<Oid>,
41 pub(crate) row_desc_payload: Option<Vec<u8>>,
43}
44
45impl PreparedStatement {
46 pub fn wire_name(&self) -> String {
48 format!("_zero_s_{}", self.idx)
49 }
50
51 pub fn parse_columns(&self) -> Option<Result<RowDescription<'_>>> {
55 self.row_desc_payload
56 .as_ref()
57 .map(|bytes| RowDescription::parse(bytes))
58 }
59
60 pub fn row_desc_payload(&self) -> Option<&[u8]> {
64 self.row_desc_payload.as_deref()
65 }
66}
67
/// The request an `ExtendedQueryStateMachine` was built to carry out.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Operation {
    /// Parse + Describe(statement) + Sync for a named statement.
    Prepare,
    /// Bind + Describe(portal) + Execute + Sync against an existing statement.
    Execute,
    /// Parse + Bind + Describe(portal) + Execute + Sync using the unnamed statement.
    ExecuteSql,
    /// Close(statement) + Sync.
    CloseStatement,
}
80
81pub struct ExtendedQueryStateMachine<'a, H> {
83 state: State,
84 handler: &'a mut H,
85 operation: Operation,
86 transaction_status: TransactionStatus,
87 prepared_stmt: Option<PreparedStatement>,
88}
89
90impl<'a, H: ExtendedHandler> ExtendedQueryStateMachine<'a, H> {
91 pub fn take_prepared_statement(&mut self) -> Option<PreparedStatement> {
93 self.prepared_stmt.take()
94 }
95
96 pub fn prepare(
100 handler: &'a mut H,
101 buffer_set: &mut BufferSet,
102 idx: u64,
103 query: &str,
104 param_oids: &[Oid],
105 ) -> Self {
106 let stmt_name = format!("_zero_s_{}", idx);
107 buffer_set.write_buffer.clear();
108 write_parse(&mut buffer_set.write_buffer, &stmt_name, query, param_oids);
109 write_describe_statement(&mut buffer_set.write_buffer, &stmt_name);
110 write_sync(&mut buffer_set.write_buffer);
111
112 Self {
113 state: State::Initial,
114 handler,
115 operation: Operation::Prepare,
116 transaction_status: TransactionStatus::Idle,
117 prepared_stmt: Some(PreparedStatement {
118 idx,
119 param_oids: Vec::new(),
120 row_desc_payload: None,
121 }),
122 }
123 }
124
125 pub fn execute<P: ToParams>(
132 handler: &'a mut H,
133 buffer_set: &mut BufferSet,
134 statement_name: &str,
135 param_oids: &[Oid],
136 params: &P,
137 ) -> Result<Self> {
138 buffer_set.write_buffer.clear();
139 write_bind(
140 &mut buffer_set.write_buffer,
141 "",
142 statement_name,
143 params,
144 param_oids,
145 )?;
146 write_describe_portal(&mut buffer_set.write_buffer, "");
147 write_execute(&mut buffer_set.write_buffer, "", 0);
148 write_sync(&mut buffer_set.write_buffer);
149
150 Ok(Self {
151 state: State::Initial,
152 handler,
153 operation: Operation::Execute,
154 transaction_status: TransactionStatus::Idle,
155 prepared_stmt: None,
156 })
157 }
158
159 pub fn execute_sql<P: ToParams>(
167 handler: &'a mut H,
168 buffer_set: &mut BufferSet,
169 sql: &str,
170 params: &P,
171 ) -> Result<Self> {
172 let param_oids = params.natural_oids();
173 buffer_set.write_buffer.clear();
174 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
175 write_bind(&mut buffer_set.write_buffer, "", "", params, ¶m_oids)?;
176 write_describe_portal(&mut buffer_set.write_buffer, "");
177 write_execute(&mut buffer_set.write_buffer, "", 0);
178 write_sync(&mut buffer_set.write_buffer);
179
180 Ok(Self {
181 state: State::Initial,
182 handler,
183 operation: Operation::ExecuteSql,
184 transaction_status: TransactionStatus::Idle,
185 prepared_stmt: None,
186 })
187 }
188
189 pub fn close_statement(handler: &'a mut H, buffer_set: &mut BufferSet, name: &str) -> Self {
193 buffer_set.write_buffer.clear();
194 write_close_statement(&mut buffer_set.write_buffer, name);
195 write_sync(&mut buffer_set.write_buffer);
196
197 Self {
198 state: State::Initial,
199 handler,
200 operation: Operation::CloseStatement,
201 transaction_status: TransactionStatus::Idle,
202 prepared_stmt: None,
203 }
204 }
205
206 fn handle_parse(&mut self, buffer_set: &BufferSet) -> Result<Action> {
207 let type_byte = buffer_set.type_byte;
208 if type_byte != msg_type::PARSE_COMPLETE {
209 return Err(Error::Protocol(format!(
210 "Expected ParseComplete, got '{}'",
211 type_byte as char
212 )));
213 }
214
215 ParseComplete::parse(&buffer_set.read_buffer)?;
216 self.state = match self.operation {
219 Operation::ExecuteSql => State::WaitingBind,
220 Operation::Prepare => State::WaitingDescribe,
221 _ => unreachable!("handle_parse called for non-parse operation"),
222 };
223 Ok(Action::ReadMessage)
224 }
225
226 fn handle_describe(&mut self, buffer_set: &BufferSet) -> Result<Action> {
227 let type_byte = buffer_set.type_byte;
228 if type_byte != msg_type::PARAMETER_DESCRIPTION {
229 return Err(Error::Protocol(format!(
230 "Expected ParameterDescription, got '{}'",
231 type_byte as char
232 )));
233 }
234
235 let param_desc = ParameterDescription::parse(&buffer_set.read_buffer)?;
236 if let Some(ref mut stmt) = self.prepared_stmt {
237 stmt.param_oids = param_desc.oids().to_vec();
238 }
239
240 self.state = State::WaitingRowDesc;
241 Ok(Action::ReadMessage)
242 }
243
244 fn handle_row_desc(&mut self, buffer_set: &BufferSet) -> Result<Action> {
245 let type_byte = buffer_set.type_byte;
246
247 match type_byte {
248 msg_type::ROW_DESCRIPTION => {
249 if let Some(ref mut stmt) = self.prepared_stmt {
250 stmt.row_desc_payload = Some(buffer_set.read_buffer.clone());
251 }
252 self.state = State::WaitingReady;
253 Ok(Action::ReadMessage)
254 }
255 msg_type::NO_DATA => {
256 let payload = &buffer_set.read_buffer;
257 NoData::parse(payload)?;
258 self.state = State::WaitingReady;
260 Ok(Action::ReadMessage)
261 }
262 _ => Err(Error::Protocol(format!(
263 "Expected RowDescription or NoData, got '{}'",
264 type_byte as char
265 ))),
266 }
267 }
268
269 fn handle_bind(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
270 let type_byte = buffer_set.type_byte;
271
272 match type_byte {
273 msg_type::BIND_COMPLETE => {
274 BindComplete::parse(&buffer_set.read_buffer)?;
275 self.state = State::ProcessingRows;
276 Ok(Action::ReadMessage)
277 }
278 msg_type::ROW_DESCRIPTION => {
279 buffer_set.column_buffer.clear();
281 buffer_set
282 .column_buffer
283 .extend_from_slice(&buffer_set.read_buffer);
284 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
285 self.handler.result_start(cols)?;
286 self.state = State::ProcessingRows;
287 Ok(Action::ReadMessage)
288 }
289 _ => Err(Error::Protocol(format!(
290 "Expected BindComplete, got '{}'",
291 type_byte as char
292 ))),
293 }
294 }
295
296 fn handle_rows(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
297 let type_byte = buffer_set.type_byte;
298 let payload = &buffer_set.read_buffer;
299
300 match type_byte {
301 msg_type::ROW_DESCRIPTION => {
302 buffer_set.column_buffer.clear();
304 buffer_set.column_buffer.extend_from_slice(payload);
305 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
306 self.handler.result_start(cols)?;
307 Ok(Action::ReadMessage)
308 }
309 msg_type::NO_DATA => {
310 NoData::parse(payload)?;
312 Ok(Action::ReadMessage)
313 }
314 msg_type::DATA_ROW => {
315 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
316 let row = DataRow::parse(payload)?;
317 self.handler.row(cols, row)?;
318 Ok(Action::ReadMessage)
319 }
320 msg_type::COMMAND_COMPLETE => {
321 let complete = CommandComplete::parse(payload)?;
322 self.handler.result_end(complete)?;
323 self.state = State::WaitingReady;
324 Ok(Action::ReadMessage)
325 }
326 msg_type::EMPTY_QUERY_RESPONSE => {
327 EmptyQueryResponse::parse(payload)?;
328 self.state = State::WaitingReady;
330 Ok(Action::ReadMessage)
331 }
332 msg_type::PORTAL_SUSPENDED => {
333 PortalSuspended::parse(payload)?;
334 self.state = State::WaitingReady;
336 Ok(Action::ReadMessage)
337 }
338 msg_type::READY_FOR_QUERY => {
339 let ready = ReadyForQuery::parse(payload)?;
340 self.transaction_status = ready.transaction_status().unwrap_or_default();
341 self.state = State::Finished;
342 Ok(Action::Finished)
343 }
344 _ => Err(Error::Protocol(format!(
345 "Unexpected message in rows: '{}'",
346 type_byte as char
347 ))),
348 }
349 }
350
351 fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
352 let type_byte = buffer_set.type_byte;
353 let payload = &buffer_set.read_buffer;
354
355 match type_byte {
356 msg_type::READY_FOR_QUERY => {
357 let ready = ReadyForQuery::parse(payload)?;
358 self.transaction_status = ready.transaction_status().unwrap_or_default();
359 self.state = State::Finished;
360 Ok(Action::Finished)
361 }
362 msg_type::CLOSE_COMPLETE => {
363 CloseComplete::parse(payload)?;
364 Ok(Action::ReadMessage)
366 }
367 _ => Err(Error::Protocol(format!(
368 "Expected ReadyForQuery, got '{}'",
369 type_byte as char
370 ))),
371 }
372 }
373
374 fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
375 match msg.type_byte {
376 msg_type::NOTICE_RESPONSE => {
377 let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
378 Ok(Action::HandleAsyncMessageAndReadMessage(
379 AsyncMessage::Notice(notice.0),
380 ))
381 }
382 msg_type::PARAMETER_STATUS => {
383 let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
384 Ok(Action::HandleAsyncMessageAndReadMessage(
385 AsyncMessage::ParameterChanged {
386 name: param.name.to_string(),
387 value: param.value.to_string(),
388 },
389 ))
390 }
391 msg_type::NOTIFICATION_RESPONSE => {
392 let notification =
393 crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
394 Ok(Action::HandleAsyncMessageAndReadMessage(
395 AsyncMessage::Notification {
396 pid: notification.pid,
397 channel: notification.channel.to_string(),
398 payload: notification.payload.to_string(),
399 },
400 ))
401 }
402 _ => Err(Error::Protocol(format!(
403 "Unknown async message type: '{}'",
404 msg.type_byte as char
405 ))),
406 }
407 }
408}
409
410impl<H: ExtendedHandler> StateMachine for ExtendedQueryStateMachine<'_, H> {
411 fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
412 if self.state == State::Initial {
414 self.state = match self.operation {
416 Operation::Prepare => State::WaitingParse,
417 Operation::Execute => State::WaitingBind, Operation::ExecuteSql => State::WaitingParse,
419 Operation::CloseStatement => State::WaitingReady,
420 };
421 return Ok(Action::WriteAndReadMessage);
422 }
423
424 let type_byte = buffer_set.type_byte;
425
426 if RawMessage::is_async_type(type_byte) {
428 let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
429 return self.handle_async_message(&msg);
430 }
431
432 if type_byte == msg_type::ERROR_RESPONSE {
434 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
435 self.state = State::WaitingReady;
437 return Err(error.into_error());
438 }
439
440 match self.state {
441 State::WaitingParse => self.handle_parse(buffer_set),
442 State::WaitingDescribe => self.handle_describe(buffer_set),
443 State::WaitingRowDesc => self.handle_row_desc(buffer_set),
444 State::WaitingBind => self.handle_bind(buffer_set),
445 State::ProcessingRows => self.handle_rows(buffer_set),
446 State::WaitingReady => self.handle_ready(buffer_set),
447 _ => Err(Error::Protocol(format!(
448 "Unexpected state {:?}",
449 self.state
450 ))),
451 }
452 }
453
454 fn transaction_status(&self) -> TransactionStatus {
455 self.transaction_status
456 }
457}
458
459use crate::protocol::frontend::write_flush;
463
/// Which response a `BindStateMachine` expects next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BindState {
    /// Request not sent yet.
    Initial,
    /// Expecting ParseComplete (only when the SQL was parsed inline).
    WaitingParse,
    /// Expecting BindComplete.
    WaitingBind,
    /// BindComplete received.
    Finished,
}
472
473pub struct BindStateMachine {
477 state: BindState,
478 needs_parse: bool,
479}
480
481impl BindStateMachine {
482 pub fn bind_prepared<P: ToParams>(
491 buffer_set: &mut BufferSet,
492 portal_name: &str,
493 statement_name: &str,
494 param_oids: &[Oid],
495 params: &P,
496 ) -> Result<Self> {
497 buffer_set.write_buffer.clear();
498 write_bind(
499 &mut buffer_set.write_buffer,
500 portal_name,
501 statement_name,
502 params,
503 param_oids,
504 )?;
505 write_flush(&mut buffer_set.write_buffer);
506
507 Ok(Self {
508 state: BindState::Initial,
509 needs_parse: false,
510 })
511 }
512
513 pub fn bind_sql<P: ToParams>(
522 buffer_set: &mut BufferSet,
523 portal_name: &str,
524 sql: &str,
525 params: &P,
526 ) -> Result<Self> {
527 let param_oids = params.natural_oids();
528 buffer_set.write_buffer.clear();
529 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
530 write_bind(
531 &mut buffer_set.write_buffer,
532 portal_name,
533 "",
534 params,
535 ¶m_oids,
536 )?;
537 write_flush(&mut buffer_set.write_buffer);
538
539 Ok(Self {
540 state: BindState::Initial,
541 needs_parse: true,
542 })
543 }
544
545 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
547 if self.state == BindState::Initial {
549 self.state = if self.needs_parse {
550 BindState::WaitingParse
551 } else {
552 BindState::WaitingBind
553 };
554 return Ok(Action::WriteAndReadMessage);
555 }
556
557 let type_byte = buffer_set.type_byte;
558
559 if RawMessage::is_async_type(type_byte) {
561 return Ok(Action::ReadMessage);
562 }
563
564 if type_byte == msg_type::ERROR_RESPONSE {
566 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
567 return Err(error.into_error());
568 }
569
570 match self.state {
571 BindState::WaitingParse => {
572 if type_byte != msg_type::PARSE_COMPLETE {
573 return Err(Error::Protocol(format!(
574 "Expected ParseComplete, got '{}'",
575 type_byte as char
576 )));
577 }
578 ParseComplete::parse(&buffer_set.read_buffer)?;
579 self.state = BindState::WaitingBind;
580 Ok(Action::ReadMessage)
581 }
582 BindState::WaitingBind => {
583 if type_byte != msg_type::BIND_COMPLETE {
584 return Err(Error::Protocol(format!(
585 "Expected BindComplete, got '{}'",
586 type_byte as char
587 )));
588 }
589 BindComplete::parse(&buffer_set.read_buffer)?;
590 self.state = BindState::Finished;
591 Ok(Action::Finished)
592 }
593 _ => Err(Error::Protocol(format!(
594 "Unexpected state {:?}",
595 self.state
596 ))),
597 }
598 }
599}
600
/// Which response a `BatchStateMachine` expects next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BatchState {
    /// Request not sent yet.
    Initial,
    /// Expecting a leading ParseComplete.
    WaitingParse,
    /// Consuming bind/result messages until ReadyForQuery.
    Processing,
    /// ReadyForQuery received.
    Finished,
}
612
613pub struct BatchStateMachine {
617 state: BatchState,
618 needs_parse: bool,
619 transaction_status: TransactionStatus,
620}
621
622impl BatchStateMachine {
623 pub fn new(needs_parse: bool) -> Self {
630 Self {
631 state: BatchState::Initial,
632 needs_parse,
633 transaction_status: TransactionStatus::Idle,
634 }
635 }
636
637 pub fn transaction_status(&self) -> TransactionStatus {
639 self.transaction_status
640 }
641
642 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
644 if self.state == BatchState::Initial {
646 self.state = if self.needs_parse {
647 BatchState::WaitingParse
648 } else {
649 BatchState::Processing
650 };
651 return Ok(Action::WriteAndReadMessage);
652 }
653
654 let type_byte = buffer_set.type_byte;
655
656 if RawMessage::is_async_type(type_byte) {
658 return Ok(Action::ReadMessage);
659 }
660
661 if type_byte == msg_type::ERROR_RESPONSE {
663 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
664 self.state = BatchState::Processing;
665 return Err(error.into_error());
666 }
667
668 match self.state {
669 BatchState::WaitingParse => {
670 if type_byte != msg_type::PARSE_COMPLETE {
671 return Err(Error::Protocol(format!(
672 "Expected ParseComplete, got '{}'",
673 type_byte as char
674 )));
675 }
676 ParseComplete::parse(&buffer_set.read_buffer)?;
677 self.state = BatchState::Processing;
678 Ok(Action::ReadMessage)
679 }
680 BatchState::Processing => {
681 match type_byte {
682 msg_type::BIND_COMPLETE => {
683 BindComplete::parse(&buffer_set.read_buffer)?;
684 Ok(Action::ReadMessage)
685 }
686 msg_type::NO_DATA => {
687 NoData::parse(&buffer_set.read_buffer)?;
688 Ok(Action::ReadMessage)
689 }
690 msg_type::ROW_DESCRIPTION => {
691 RowDescription::parse(&buffer_set.read_buffer)?;
693 Ok(Action::ReadMessage)
694 }
695 msg_type::DATA_ROW => {
696 Ok(Action::ReadMessage)
698 }
699 msg_type::COMMAND_COMPLETE => {
700 CommandComplete::parse(&buffer_set.read_buffer)?;
701 Ok(Action::ReadMessage)
702 }
703 msg_type::EMPTY_QUERY_RESPONSE => {
704 EmptyQueryResponse::parse(&buffer_set.read_buffer)?;
705 Ok(Action::ReadMessage)
706 }
707 msg_type::READY_FOR_QUERY => {
708 let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
709 self.transaction_status = ready.transaction_status().unwrap_or_default();
710 self.state = BatchState::Finished;
711 Ok(Action::Finished)
712 }
713 _ => Err(Error::Protocol(format!(
714 "Unexpected message in batch: '{}'",
715 type_byte as char
716 ))),
717 }
718 }
719 _ => Err(Error::Protocol(format!(
720 "Unexpected state {:?}",
721 self.state
722 ))),
723 }
724 }
725}