1use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25
26pub struct Conn {
28 pub(crate) stream: Stream,
29 pub(crate) buffer_set: PooledBufferSet,
30 backend_key: Option<BackendKeyData>,
31 server_params: Vec<(String, String)>,
32 pub(crate) transaction_status: TransactionStatus,
33 pub(crate) is_broken: bool,
34 name_counter: u64,
35 async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
36}
37
38impl Conn {
39 pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41 where
42 Error: From<O::Error>,
43 {
44 let opts = opts.try_into()?;
45
46 let stream = if let Some(socket_path) = &opts.socket {
47 #[cfg(unix)]
48 {
49 Stream::unix(UnixStream::connect(socket_path).await?)
50 }
51 #[cfg(not(unix))]
52 {
53 let _ = socket_path;
54 return Err(Error::Unsupported(
55 "Unix sockets are not supported on this platform".into(),
56 ));
57 }
58 } else {
59 if opts.host.is_empty() {
60 return Err(Error::InvalidUsage("host is empty".into()));
61 }
62 let addr = format!("{}:{}", opts.host, opts.port);
63 let tcp = TcpStream::connect(&addr).await?;
64 tcp.set_nodelay(true)?;
65 Stream::tcp(tcp)
66 };
67
68 Self::new_with_stream(stream, opts).await
69 }
70
71 #[allow(unused_mut)]
73 pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
74 let mut buffer_set = options.buffer_pool.get_buffer_set();
75 let mut state_machine = ConnectionStateMachine::new(options.clone());
76
77 loop {
79 match state_machine.step(&mut buffer_set)? {
80 Action::WriteAndReadByte => {
81 stream.write_all(&buffer_set.write_buffer).await?;
82 stream.flush().await?;
83 let byte = stream.read_u8().await?;
84 state_machine.set_ssl_response(byte);
85 }
86 Action::ReadMessage => {
87 stream.read_message(&mut buffer_set).await?;
88 }
89 Action::Write => {
90 stream.write_all(&buffer_set.write_buffer).await?;
91 stream.flush().await?;
92 }
93 Action::WriteAndReadMessage => {
94 stream.write_all(&buffer_set.write_buffer).await?;
95 stream.flush().await?;
96 stream.read_message(&mut buffer_set).await?;
97 }
98 Action::TlsHandshake => {
99 #[cfg(feature = "tokio-tls")]
100 {
101 stream = stream.upgrade_to_tls(&options.host).await?;
102 }
103 #[cfg(not(feature = "tokio-tls"))]
104 {
105 return Err(Error::Unsupported(
106 "TLS requested but tokio-tls feature not enabled".into(),
107 ));
108 }
109 }
110 Action::HandleAsyncMessageAndReadMessage(_) => {
111 stream.read_message(&mut buffer_set).await?;
113 }
114 Action::Finished => break,
115 }
116 }
117
118 let conn = Self {
119 stream,
120 buffer_set,
121 backend_key: state_machine.backend_key().cloned(),
122 server_params: state_machine.take_server_params(),
123 transaction_status: state_machine.transaction_status(),
124 is_broken: false,
125 name_counter: 0,
126 async_message_handler: None,
127 };
128
129 #[cfg(unix)]
131 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
132 conn.try_upgrade_to_unix_socket(&options).await
133 } else {
134 conn
135 };
136
137 Ok(conn)
138 }
139
140 #[cfg(unix)]
143 fn try_upgrade_to_unix_socket(
144 mut self,
145 opts: &Opts,
146 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
147 let opts = opts.clone();
148 Box::pin(async move {
149 let mut handler = FirstRowHandler::<(String,)>::new();
151 if self
152 .query("SHOW unix_socket_directories", &mut handler)
153 .await
154 .is_err()
155 {
156 return self;
157 }
158
159 let socket_dir = match handler.into_row() {
160 Some((dirs,)) => {
161 match dirs.split(',').next() {
163 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
164 _ => return self,
165 }
166 }
167 None => return self,
168 };
169
170 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
172
173 let unix_stream = match UnixStream::connect(&socket_path).await {
175 Ok(s) => s,
176 Err(_) => return self,
177 };
178
179 let mut opts_unix = opts.clone();
181 opts_unix.upgrade_to_unix_socket = false;
182
183 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
184 Ok(new_conn) => new_conn,
185 Err(_) => self,
186 }
187 })
188 }
189
190 pub fn backend_key(&self) -> Option<&BackendKeyData> {
192 self.backend_key.as_ref()
193 }
194
195 pub fn connection_id(&self) -> u32 {
199 self.backend_key.as_ref().map_or(0, |k| k.process_id())
200 }
201
202 pub fn server_params(&self) -> &[(String, String)] {
204 &self.server_params
205 }
206
207 pub fn transaction_status(&self) -> TransactionStatus {
209 self.transaction_status
210 }
211
212 pub fn in_transaction(&self) -> bool {
214 self.transaction_status.in_transaction()
215 }
216
217 pub fn is_broken(&self) -> bool {
219 self.is_broken
220 }
221
222 pub(crate) fn next_portal_name(&mut self) -> String {
224 self.name_counter += 1;
225 format!("_zero_p_{}", self.name_counter)
226 }
227
228 pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
232 &mut self,
233 portal_name: &str,
234 statement: &S,
235 params: &P,
236 ) -> Result<()> {
237 let mut state_machine = if let Some(sql) = statement.as_sql() {
239 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
240 } else {
241 let stmt = statement.as_prepared().unwrap();
242 BindStateMachine::bind_prepared(
243 &mut self.buffer_set,
244 portal_name,
245 &stmt.wire_name(),
246 &stmt.param_oids,
247 params,
248 )?
249 };
250
251 loop {
253 match state_machine.step(&mut self.buffer_set)? {
254 Action::ReadMessage => {
255 self.stream.read_message(&mut self.buffer_set).await?;
256 }
257 Action::Write => {
258 self.stream.write_all(&self.buffer_set.write_buffer).await?;
259 self.stream.flush().await?;
260 }
261 Action::WriteAndReadMessage => {
262 self.stream.write_all(&self.buffer_set.write_buffer).await?;
263 self.stream.flush().await?;
264 self.stream.read_message(&mut self.buffer_set).await?;
265 }
266 Action::Finished => break,
267 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
268 }
269 }
270
271 Ok(())
272 }
273
274 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
281 self.async_message_handler = Some(Box::new(handler));
282 }
283
284 pub fn clear_async_message_handler(&mut self) {
286 self.async_message_handler = None;
287 }
288
289 pub async fn ping(&mut self) -> Result<()> {
291 self.query_drop("").await?;
292 Ok(())
293 }
294
295 async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
297 loop {
298 match state_machine.step(&mut self.buffer_set)? {
299 Action::WriteAndReadByte => {
300 return Err(Error::Protocol(
301 "Unexpected WriteAndReadByte in query state machine".into(),
302 ));
303 }
304 Action::ReadMessage => {
305 self.stream.read_message(&mut self.buffer_set).await?;
306 }
307 Action::Write => {
308 self.stream.write_all(&self.buffer_set.write_buffer).await?;
309 self.stream.flush().await?;
310 }
311 Action::WriteAndReadMessage => {
312 self.stream.write_all(&self.buffer_set.write_buffer).await?;
313 self.stream.flush().await?;
314 self.stream.read_message(&mut self.buffer_set).await?;
315 }
316 Action::TlsHandshake => {
317 return Err(Error::Protocol(
318 "Unexpected TlsHandshake in query state machine".into(),
319 ));
320 }
321 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
322 if let Some(ref mut h) = self.async_message_handler {
323 h.handle(async_msg);
324 }
325 self.stream.read_message(&mut self.buffer_set).await?;
327 }
328 Action::Finished => {
329 self.transaction_status = state_machine.transaction_status();
330 break;
331 }
332 }
333 }
334 Ok(())
335 }
336
337 pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
339 let result = self.query_inner(sql, handler).await;
340 if let Err(e) = &result
341 && e.is_connection_broken()
342 {
343 self.is_broken = true;
344 }
345 result
346 }
347
348 async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
349 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
350 self.drive(&mut state_machine).await
351 }
352
353 pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
355 let mut handler = DropHandler::new();
356 self.query(sql, &mut handler).await?;
357 Ok(handler.rows_affected())
358 }
359
360 pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
362 &mut self,
363 sql: &str,
364 ) -> Result<Vec<T>> {
365 let mut handler = crate::handler::CollectHandler::<T>::new();
366 self.query(sql, &mut handler).await?;
367 Ok(handler.into_rows())
368 }
369
370 pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
372 &mut self,
373 sql: &str,
374 ) -> Result<Option<T>> {
375 let mut handler = crate::handler::FirstRowHandler::<T>::new();
376 self.query(sql, &mut handler).await?;
377 Ok(handler.into_row())
378 }
379
380 pub async fn close(mut self) -> Result<()> {
382 self.buffer_set.write_buffer.clear();
383 write_terminate(&mut self.buffer_set.write_buffer);
384 self.stream.write_all(&self.buffer_set.write_buffer).await?;
385 self.stream.flush().await?;
386 Ok(())
387 }
388
389 pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
393 self.prepare_typed(query, &[]).await
394 }
395
396 pub async fn prepare_typed(
398 &mut self,
399 query: &str,
400 param_oids: &[u32],
401 ) -> Result<PreparedStatement> {
402 self.name_counter += 1;
403 let idx = self.name_counter;
404 let result = self.prepare_inner(idx, query, param_oids).await;
405 if let Err(e) = &result
406 && e.is_connection_broken()
407 {
408 self.is_broken = true;
409 }
410 result
411 }
412
413 pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
430 if queries.is_empty() {
431 return Ok(Vec::new());
432 }
433
434 let start_idx = self.name_counter + 1;
435 self.name_counter += queries.len() as u64;
436
437 let result = self.prepare_batch_inner(queries, start_idx).await;
438 if let Err(e) = &result
439 && e.is_connection_broken()
440 {
441 self.is_broken = true;
442 }
443 result
444 }
445
446 async fn prepare_batch_inner(
447 &mut self,
448 queries: &[&str],
449 start_idx: u64,
450 ) -> Result<Vec<PreparedStatement>> {
451 use crate::state::batch_prepare::BatchPrepareStateMachine;
452
453 let mut state_machine =
454 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
455
456 loop {
457 match state_machine.step(&mut self.buffer_set)? {
458 Action::ReadMessage => {
459 self.stream.read_message(&mut self.buffer_set).await?;
460 }
461 Action::WriteAndReadMessage => {
462 self.stream.write_all(&self.buffer_set.write_buffer).await?;
463 self.stream.flush().await?;
464 self.stream.read_message(&mut self.buffer_set).await?;
465 }
466 Action::Finished => {
467 self.transaction_status = state_machine.transaction_status();
468 break;
469 }
470 _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
471 }
472 }
473
474 Ok(state_machine.take_statements())
475 }
476
477 async fn prepare_inner(
478 &mut self,
479 idx: u64,
480 query: &str,
481 param_oids: &[u32],
482 ) -> Result<PreparedStatement> {
483 let mut handler = DropHandler::new();
484 let mut state_machine = ExtendedQueryStateMachine::prepare(
485 &mut handler,
486 &mut self.buffer_set,
487 idx,
488 query,
489 param_oids,
490 );
491 self.drive(&mut state_machine).await?;
492 state_machine
493 .take_prepared_statement()
494 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
495 }
496
497 pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
503 &mut self,
504 statement: S,
505 params: P,
506 handler: &mut H,
507 ) -> Result<()> {
508 let result = self.exec_inner(&statement, ¶ms, handler).await;
509 if let Err(e) = &result
510 && e.is_connection_broken()
511 {
512 self.is_broken = true;
513 }
514 result
515 }
516
517 async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
518 &mut self,
519 statement: &S,
520 params: &P,
521 handler: &mut H,
522 ) -> Result<()> {
523 let mut state_machine = if statement.needs_parse() {
524 ExtendedQueryStateMachine::execute_sql(
525 handler,
526 &mut self.buffer_set,
527 statement.as_sql().unwrap(),
528 params,
529 )?
530 } else {
531 let stmt = statement.as_prepared().unwrap();
532 ExtendedQueryStateMachine::execute(
533 handler,
534 &mut self.buffer_set,
535 &stmt.wire_name(),
536 &stmt.param_oids,
537 params,
538 )?
539 };
540
541 self.drive(&mut state_machine).await
542 }
543
544 pub async fn exec_drop<S: IntoStatement, P: ToParams>(
548 &mut self,
549 statement: S,
550 params: P,
551 ) -> Result<Option<u64>> {
552 let mut handler = DropHandler::new();
553 self.exec(statement, params, &mut handler).await?;
554 Ok(handler.rows_affected())
555 }
556
557 pub async fn exec_collect<
561 T: for<'a> crate::conversion::FromRow<'a>,
562 S: IntoStatement,
563 P: ToParams,
564 >(
565 &mut self,
566 statement: S,
567 params: P,
568 ) -> Result<Vec<T>> {
569 let mut handler = crate::handler::CollectHandler::<T>::new();
570 self.exec(statement, params, &mut handler).await?;
571 Ok(handler.into_rows())
572 }
573
574 pub async fn exec_batch<S: IntoStatement, P: ToParams>(
605 &mut self,
606 statement: S,
607 params_list: &[P],
608 ) -> Result<()> {
609 self.exec_batch_chunked(statement, params_list, 1000).await
610 }
611
612 pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
616 &mut self,
617 statement: S,
618 params_list: &[P],
619 chunk_size: usize,
620 ) -> Result<()> {
621 let result = self
622 .exec_batch_inner(&statement, params_list, chunk_size)
623 .await;
624 if let Err(e) = &result
625 && e.is_connection_broken()
626 {
627 self.is_broken = true;
628 }
629 result
630 }
631
632 async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
633 &mut self,
634 statement: &S,
635 params_list: &[P],
636 chunk_size: usize,
637 ) -> Result<()> {
638 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
639 use crate::state::extended::BatchStateMachine;
640
641 if params_list.is_empty() {
642 return Ok(());
643 }
644
645 let chunk_size = chunk_size.max(1);
646 let needs_parse = statement.needs_parse();
647 let sql = statement.as_sql();
648 let prepared = statement.as_prepared();
649
650 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
652 stmt.param_oids.clone()
653 } else {
654 params_list[0].natural_oids()
655 };
656
657 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
659
660 for chunk in params_list.chunks(chunk_size) {
661 self.buffer_set.write_buffer.clear();
662
663 let parse_in_chunk = needs_parse;
665 if parse_in_chunk {
666 write_parse(
667 &mut self.buffer_set.write_buffer,
668 "",
669 sql.unwrap(),
670 ¶m_oids,
671 );
672 }
673
674 for params in chunk {
676 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
677 write_bind(
678 &mut self.buffer_set.write_buffer,
679 "",
680 effective_stmt_name,
681 params,
682 ¶m_oids,
683 )?;
684 write_execute(&mut self.buffer_set.write_buffer, "", 0);
685 }
686
687 write_sync(&mut self.buffer_set.write_buffer);
689
690 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
692 self.drive_batch(&mut state_machine).await?;
693 self.transaction_status = state_machine.transaction_status();
694 }
695
696 Ok(())
697 }
698
699 async fn drive_batch(
701 &mut self,
702 state_machine: &mut crate::state::extended::BatchStateMachine,
703 ) -> Result<()> {
704 use crate::protocol::backend::{ReadyForQuery, msg_type};
705 use crate::state::action::Action;
706
707 loop {
708 let step_result = state_machine.step(&mut self.buffer_set);
709 match step_result {
710 Ok(Action::ReadMessage) => {
711 self.stream.read_message(&mut self.buffer_set).await?;
712 }
713 Ok(Action::WriteAndReadMessage) => {
714 self.stream.write_all(&self.buffer_set.write_buffer).await?;
715 self.stream.flush().await?;
716 self.stream.read_message(&mut self.buffer_set).await?;
717 }
718 Ok(Action::Finished) => {
719 break;
720 }
721 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
722 Err(e) => {
723 loop {
725 self.stream.read_message(&mut self.buffer_set).await?;
726 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
727 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
728 self.transaction_status =
729 ready.transaction_status().unwrap_or_default();
730 break;
731 }
732 }
733 return Err(e);
734 }
735 }
736 }
737 Ok(())
738 }
739
740 pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
742 let result = self.close_statement_inner(&stmt.wire_name()).await;
743 if let Err(e) = &result
744 && e.is_connection_broken()
745 {
746 self.is_broken = true;
747 }
748 result
749 }
750
751 async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
752 let mut handler = DropHandler::new();
753 let mut state_machine =
754 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
755 self.drive(&mut state_machine).await
756 }
757
758 pub async fn lowlevel_flush(&mut self) -> Result<()> {
766 use crate::protocol::frontend::write_flush;
767
768 self.buffer_set.write_buffer.clear();
769 write_flush(&mut self.buffer_set.write_buffer);
770
771 self.stream.write_all(&self.buffer_set.write_buffer).await?;
772 self.stream.flush().await?;
773 Ok(())
774 }
775
776 pub async fn lowlevel_sync(&mut self) -> Result<()> {
783 let result = self.lowlevel_sync_inner().await;
784 if let Err(e) = &result
785 && e.is_connection_broken()
786 {
787 self.is_broken = true;
788 }
789 result
790 }
791
792 async fn lowlevel_sync_inner(&mut self) -> Result<()> {
793 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
794 use crate::protocol::frontend::write_sync;
795
796 self.buffer_set.write_buffer.clear();
797 write_sync(&mut self.buffer_set.write_buffer);
798
799 self.stream.write_all(&self.buffer_set.write_buffer).await?;
800 self.stream.flush().await?;
801
802 let mut pending_error: Option<Error> = None;
803
804 loop {
805 self.stream.read_message(&mut self.buffer_set).await?;
806 let type_byte = self.buffer_set.type_byte;
807
808 if RawMessage::is_async_type(type_byte) {
809 continue;
810 }
811
812 match type_byte {
813 msg_type::READY_FOR_QUERY => {
814 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
815 self.transaction_status = ready.transaction_status().unwrap_or_default();
816 if let Some(e) = pending_error {
817 return Err(e);
818 }
819 return Ok(());
820 }
821 msg_type::ERROR_RESPONSE => {
822 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
823 pending_error = Some(error.into_error());
824 }
825 _ => {
826 }
828 }
829 }
830 }
831
832 pub async fn lowlevel_bind<P: ToParams>(
842 &mut self,
843 portal: &str,
844 statement_name: &str,
845 params: P,
846 ) -> Result<()> {
847 let result = self
848 .lowlevel_bind_inner(portal, statement_name, ¶ms)
849 .await;
850 if let Err(e) = &result
851 && e.is_connection_broken()
852 {
853 self.is_broken = true;
854 }
855 result
856 }
857
858 async fn lowlevel_bind_inner<P: ToParams>(
859 &mut self,
860 portal: &str,
861 statement_name: &str,
862 params: &P,
863 ) -> Result<()> {
864 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
865 use crate::protocol::frontend::{write_bind, write_flush};
866
867 let param_oids = params.natural_oids();
868 self.buffer_set.write_buffer.clear();
869 write_bind(
870 &mut self.buffer_set.write_buffer,
871 portal,
872 statement_name,
873 params,
874 ¶m_oids,
875 )?;
876 write_flush(&mut self.buffer_set.write_buffer);
877
878 self.stream.write_all(&self.buffer_set.write_buffer).await?;
879 self.stream.flush().await?;
880
881 loop {
882 self.stream.read_message(&mut self.buffer_set).await?;
883 let type_byte = self.buffer_set.type_byte;
884
885 if RawMessage::is_async_type(type_byte) {
886 continue;
887 }
888
889 match type_byte {
890 msg_type::BIND_COMPLETE => {
891 BindComplete::parse(&self.buffer_set.read_buffer)?;
892 return Ok(());
893 }
894 msg_type::ERROR_RESPONSE => {
895 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
896 return Err(error.into_error());
897 }
898 _ => {
899 return Err(Error::Protocol(format!(
900 "Expected BindComplete or ErrorResponse, got '{}'",
901 type_byte as char
902 )));
903 }
904 }
905 }
906 }
907
908 pub async fn lowlevel_execute<H: BinaryHandler>(
921 &mut self,
922 portal: &str,
923 max_rows: u32,
924 handler: &mut H,
925 ) -> Result<bool> {
926 let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
927 if let Err(e) = &result
928 && e.is_connection_broken()
929 {
930 self.is_broken = true;
931 }
932 result
933 }
934
935 async fn lowlevel_execute_inner<H: BinaryHandler>(
936 &mut self,
937 portal: &str,
938 max_rows: u32,
939 handler: &mut H,
940 ) -> Result<bool> {
941 use crate::protocol::backend::{
942 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
943 RowDescription, msg_type,
944 };
945 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
946
947 self.buffer_set.write_buffer.clear();
948 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
949 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
950 write_flush(&mut self.buffer_set.write_buffer);
951
952 self.stream.write_all(&self.buffer_set.write_buffer).await?;
953 self.stream.flush().await?;
954
955 let mut column_buffer: Vec<u8> = Vec::new();
956
957 loop {
958 self.stream.read_message(&mut self.buffer_set).await?;
959 let type_byte = self.buffer_set.type_byte;
960
961 if RawMessage::is_async_type(type_byte) {
962 continue;
963 }
964
965 match type_byte {
966 msg_type::ROW_DESCRIPTION => {
967 column_buffer.clear();
968 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
969 let cols = RowDescription::parse(&column_buffer)?;
970 handler.result_start(cols)?;
971 }
972 msg_type::NO_DATA => {
973 NoData::parse(&self.buffer_set.read_buffer)?;
974 }
975 msg_type::DATA_ROW => {
976 let cols = RowDescription::parse(&column_buffer)?;
977 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
978 handler.row(cols, row)?;
979 }
980 msg_type::COMMAND_COMPLETE => {
981 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
982 handler.result_end(complete)?;
983 return Ok(false); }
985 msg_type::PORTAL_SUSPENDED => {
986 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
987 return Ok(true); }
989 msg_type::ERROR_RESPONSE => {
990 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
991 return Err(error.into_error());
992 }
993 _ => {
994 return Err(Error::Protocol(format!(
995 "Unexpected message in execute: '{}'",
996 type_byte as char
997 )));
998 }
999 }
1000 }
1001 }
1002
1003 pub async fn exec_iter<S: IntoStatement, P, F, Fut, T>(
1033 &mut self,
1034 statement: S,
1035 params: P,
1036 f: F,
1037 ) -> Result<T>
1038 where
1039 P: ToParams,
1040 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1041 Fut: std::future::Future<Output = Result<T>>,
1042 {
1043 let result = self.exec_iter_inner(&statement, ¶ms, f).await;
1044 if let Err(e) = &result
1045 && e.is_connection_broken()
1046 {
1047 self.is_broken = true;
1048 }
1049 result
1050 }
1051
1052 async fn exec_iter_inner<S: IntoStatement, P, F, Fut, T>(
1053 &mut self,
1054 statement: &S,
1055 params: &P,
1056 f: F,
1057 ) -> Result<T>
1058 where
1059 P: ToParams,
1060 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1061 Fut: std::future::Future<Output = Result<T>>,
1062 {
1063 let mut state_machine = if let Some(sql) = statement.as_sql() {
1065 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1066 } else {
1067 let stmt = statement.as_prepared().unwrap();
1068 BindStateMachine::bind_prepared(
1069 &mut self.buffer_set,
1070 "",
1071 &stmt.wire_name(),
1072 &stmt.param_oids,
1073 params,
1074 )?
1075 };
1076
1077 loop {
1079 match state_machine.step(&mut self.buffer_set)? {
1080 Action::ReadMessage => {
1081 self.stream.read_message(&mut self.buffer_set).await?;
1082 }
1083 Action::Write => {
1084 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1085 self.stream.flush().await?;
1086 }
1087 Action::WriteAndReadMessage => {
1088 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1089 self.stream.flush().await?;
1090 self.stream.read_message(&mut self.buffer_set).await?;
1091 }
1092 Action::Finished => break,
1093 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1094 }
1095 }
1096
1097 let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1099 let result = f(&mut portal).await;
1100
1101 let sync_result = portal.conn.lowlevel_sync().await;
1103
1104 match (result, sync_result) {
1106 (Ok(v), Ok(())) => Ok(v),
1107 (Err(e), _) => Err(e),
1108 (Ok(_), Err(e)) => Err(e),
1109 }
1110 }
1111
1112 pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1114 let result = self.lowlevel_close_portal_inner(portal).await;
1115 if let Err(e) = &result
1116 && e.is_connection_broken()
1117 {
1118 self.is_broken = true;
1119 }
1120 result
1121 }
1122
1123 async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1124 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1125 use crate::protocol::frontend::{write_close_portal, write_flush};
1126
1127 self.buffer_set.write_buffer.clear();
1128 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1129 write_flush(&mut self.buffer_set.write_buffer);
1130
1131 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1132 self.stream.flush().await?;
1133
1134 loop {
1135 self.stream.read_message(&mut self.buffer_set).await?;
1136 let type_byte = self.buffer_set.type_byte;
1137
1138 if RawMessage::is_async_type(type_byte) {
1139 continue;
1140 }
1141
1142 match type_byte {
1143 msg_type::CLOSE_COMPLETE => {
1144 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1145 return Ok(());
1146 }
1147 msg_type::ERROR_RESPONSE => {
1148 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1149 return Err(error.into_error());
1150 }
1151 _ => {
1152 return Err(Error::Protocol(format!(
1153 "Expected CloseComplete or ErrorResponse, got '{}'",
1154 type_byte as char
1155 )));
1156 }
1157 }
1158 }
1159 }
1160
1161 pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1192 where
1193 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1194 Fut: std::future::Future<Output = Result<T>>,
1195 {
1196 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1197 let result = f(&mut pipeline).await;
1198 pipeline.cleanup().await;
1199 result
1200 }
1201
1202 pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1212 where
1213 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1214 Fut: std::future::Future<Output = Result<R>>,
1215 {
1216 if self.in_transaction() {
1217 return Err(Error::InvalidUsage(
1218 "nested transactions are not supported".into(),
1219 ));
1220 }
1221
1222 self.query_drop("BEGIN").await?;
1223
1224 let tx = super::transaction::Transaction::new(self.connection_id());
1225
1226 let result = f(self, tx).await;
1229
1230 if self.in_transaction() {
1232 let rollback_result = self.query_drop("ROLLBACK").await;
1233
1234 if let Err(e) = result {
1236 return Err(e);
1237 }
1238 rollback_result?;
1239 }
1240
1241 result
1242 }
1243}