1#![allow(clippy::manual_async_fn)]
15#![allow(clippy::result_large_err)]
17
18use std::collections::HashMap;
19use std::future::Future;
20use std::sync::Arc;
21
22use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
23use asupersync::net::TcpStream;
24use asupersync::sync::Mutex;
25use asupersync::{Cx, Outcome};
26
27use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
28use sqlmodel_core::error::{
29 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
30};
31use sqlmodel_core::row::ColumnInfo;
32use sqlmodel_core::{Error, Row, Value};
33
34use crate::auth::ScramClient;
35use crate::config::PgConfig;
36use crate::connection::{ConnectionState, TransactionStatusState};
37use crate::protocol::{
38 BackendMessage, DescribeKind, ErrorFields, FrontendMessage, MessageReader, MessageWriter,
39 PROTOCOL_VERSION,
40};
41use crate::types::{Format, decode_value, encode_value};
42
43pub struct PgAsyncConnection {
48 stream: TcpStream,
49 state: ConnectionState,
50 process_id: i32,
51 secret_key: i32,
52 parameters: HashMap<String, String>,
53 config: PgConfig,
54 reader: MessageReader,
55 writer: MessageWriter,
56 read_buf: Vec<u8>,
57}
58
59impl std::fmt::Debug for PgAsyncConnection {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("PgAsyncConnection")
62 .field("state", &self.state)
63 .field("process_id", &self.process_id)
64 .field("host", &self.config.host)
65 .field("port", &self.config.port)
66 .field("database", &self.config.database)
67 .finish_non_exhaustive()
68 }
69}
70
71impl PgAsyncConnection {
72 pub async fn connect(_cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
74 let addr = config.socket_addr();
75 let socket_addr = match addr.parse() {
76 Ok(a) => a,
77 Err(e) => {
78 return Outcome::Err(Error::Connection(ConnectionError {
79 kind: ConnectionErrorKind::Connect,
80 message: format!("Invalid socket address: {}", e),
81 source: None,
82 }));
83 }
84 };
85
86 let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
87 Ok(s) => s,
88 Err(e) => {
89 let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
90 ConnectionErrorKind::Refused
91 } else {
92 ConnectionErrorKind::Connect
93 };
94 return Outcome::Err(Error::Connection(ConnectionError {
95 kind,
96 message: format!("Failed to connect to {}: {}", addr, e),
97 source: Some(Box::new(e)),
98 }));
99 }
100 };
101
102 stream.set_nodelay(true).ok();
103
104 let mut conn = Self {
105 stream,
106 state: ConnectionState::Connecting,
107 process_id: 0,
108 secret_key: 0,
109 parameters: HashMap::new(),
110 config,
111 reader: MessageReader::new(),
112 writer: MessageWriter::new(),
113 read_buf: vec![0u8; 8192],
114 };
115
116 if conn.config.ssl_mode.should_try_ssl() {
118 match conn.negotiate_ssl().await {
119 Outcome::Ok(()) => {}
120 Outcome::Err(e) => return Outcome::Err(e),
121 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
122 Outcome::Panicked(p) => return Outcome::Panicked(p),
123 }
124 }
125
126 if let Outcome::Err(e) = conn.send_startup().await {
128 return Outcome::Err(e);
129 }
130 conn.state = ConnectionState::Authenticating;
131
132 match conn.handle_auth().await {
133 Outcome::Ok(()) => {}
134 Outcome::Err(e) => return Outcome::Err(e),
135 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
136 Outcome::Panicked(p) => return Outcome::Panicked(p),
137 }
138
139 match conn.read_startup_messages().await {
140 Outcome::Ok(()) => Outcome::Ok(conn),
141 Outcome::Err(e) => Outcome::Err(e),
142 Outcome::Cancelled(r) => Outcome::Cancelled(r),
143 Outcome::Panicked(p) => Outcome::Panicked(p),
144 }
145 }
146
147 pub async fn query_async(
149 &mut self,
150 cx: &Cx,
151 sql: &str,
152 params: &[Value],
153 ) -> Outcome<Vec<Row>, Error> {
154 match self.run_extended(cx, sql, params).await {
155 Outcome::Ok(result) => Outcome::Ok(result.rows),
156 Outcome::Err(e) => Outcome::Err(e),
157 Outcome::Cancelled(r) => Outcome::Cancelled(r),
158 Outcome::Panicked(p) => Outcome::Panicked(p),
159 }
160 }
161
162 pub async fn execute_async(
164 &mut self,
165 cx: &Cx,
166 sql: &str,
167 params: &[Value],
168 ) -> Outcome<u64, Error> {
169 match self.run_extended(cx, sql, params).await {
170 Outcome::Ok(result) => {
171 Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
172 }
173 Outcome::Err(e) => Outcome::Err(e),
174 Outcome::Cancelled(r) => Outcome::Cancelled(r),
175 Outcome::Panicked(p) => Outcome::Panicked(p),
176 }
177 }
178
179 pub async fn insert_async(
185 &mut self,
186 cx: &Cx,
187 sql: &str,
188 params: &[Value],
189 ) -> Outcome<i64, Error> {
190 let result = match self.run_extended(cx, sql, params).await {
191 Outcome::Ok(r) => r,
192 Outcome::Err(e) => return Outcome::Err(e),
193 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
194 Outcome::Panicked(p) => return Outcome::Panicked(p),
195 };
196
197 let Some(row) = result.rows.first() else {
198 return Outcome::Err(query_error_msg(
199 "INSERT did not return an id; add `RETURNING id`",
200 QueryErrorKind::Database,
201 ));
202 };
203 let Some(id_value) = row.get(0) else {
204 return Outcome::Err(query_error_msg(
205 "INSERT result row missing id column",
206 QueryErrorKind::Database,
207 ));
208 };
209 match id_value.as_i64() {
210 Some(v) => Outcome::Ok(v),
211 None => Outcome::Err(query_error_msg(
212 "INSERT returned non-integer id",
213 QueryErrorKind::Database,
214 )),
215 }
216 }
217
218 pub async fn ping_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
220 self.execute_async(cx, "SELECT 1", &[]).await.map(|_| ())
221 }
222
223 pub async fn close_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
225 let _ = self.send_message(cx, &FrontendMessage::Terminate).await;
227 self.state = ConnectionState::Closed;
228 Outcome::Ok(())
229 }
230
231 async fn run_extended(
234 &mut self,
235 cx: &Cx,
236 sql: &str,
237 params: &[Value],
238 ) -> Outcome<PgQueryResult, Error> {
239 let mut param_types = Vec::with_capacity(params.len());
241 let mut param_values = Vec::with_capacity(params.len());
242
243 for v in params {
244 if matches!(v, Value::Null) {
245 param_types.push(0);
246 param_values.push(None);
247 continue;
248 }
249 match encode_value(v, Format::Text) {
250 Ok((bytes, oid)) => {
251 param_types.push(oid);
252 param_values.push(Some(bytes));
253 }
254 Err(e) => return Outcome::Err(e),
255 }
256 }
257
258 if let Outcome::Err(e) = self
260 .send_message(
261 cx,
262 &FrontendMessage::Parse {
263 name: String::new(),
264 query: sql.to_string(),
265 param_types,
266 },
267 )
268 .await
269 {
270 return Outcome::Err(e);
271 }
272
273 let param_formats = if params.is_empty() {
274 Vec::new()
275 } else {
276 vec![Format::Text.code()]
277 };
278 if let Outcome::Err(e) = self
279 .send_message(
280 cx,
281 &FrontendMessage::Bind {
282 portal: String::new(),
283 statement: String::new(),
284 param_formats,
285 params: param_values,
286 result_formats: Vec::new(),
288 },
289 )
290 .await
291 {
292 return Outcome::Err(e);
293 }
294
295 if let Outcome::Err(e) = self
296 .send_message(
297 cx,
298 &FrontendMessage::Describe {
299 kind: DescribeKind::Portal,
300 name: String::new(),
301 },
302 )
303 .await
304 {
305 return Outcome::Err(e);
306 }
307
308 if let Outcome::Err(e) = self
309 .send_message(
310 cx,
311 &FrontendMessage::Execute {
312 portal: String::new(),
313 max_rows: 0,
314 },
315 )
316 .await
317 {
318 return Outcome::Err(e);
319 }
320
321 if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
322 return Outcome::Err(e);
323 }
324
325 let mut field_descs: Option<Vec<crate::protocol::FieldDescription>> = None;
327 let mut columns: Option<Arc<ColumnInfo>> = None;
328 let mut rows: Vec<Row> = Vec::new();
329 let mut command_tag: Option<String> = None;
330
331 loop {
332 let msg = match self.receive_message(cx).await {
333 Outcome::Ok(m) => m,
334 Outcome::Err(e) => return Outcome::Err(e),
335 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
336 Outcome::Panicked(p) => return Outcome::Panicked(p),
337 };
338
339 match msg {
340 BackendMessage::ParseComplete
341 | BackendMessage::BindComplete
342 | BackendMessage::CloseComplete
343 | BackendMessage::ParameterDescription(_)
344 | BackendMessage::NoData
345 | BackendMessage::PortalSuspended
346 | BackendMessage::EmptyQueryResponse => {}
347 BackendMessage::RowDescription(desc) => {
348 let names: Vec<String> = desc.iter().map(|f| f.name.clone()).collect();
349 columns = Some(Arc::new(ColumnInfo::new(names)));
350 field_descs = Some(desc);
351 }
352 BackendMessage::DataRow(raw_values) => {
353 let Some(ref desc) = field_descs else {
354 return Outcome::Err(protocol_error(
355 "DataRow received before RowDescription",
356 ));
357 };
358 let Some(ref cols) = columns else {
359 return Outcome::Err(protocol_error("Row column metadata missing"));
360 };
361 if raw_values.len() != desc.len() {
362 return Outcome::Err(protocol_error("DataRow field count mismatch"));
363 }
364
365 let mut values = Vec::with_capacity(raw_values.len());
366 for (i, raw) in raw_values.into_iter().enumerate() {
367 match raw {
368 None => values.push(Value::Null),
369 Some(bytes) => {
370 let field = &desc[i];
371 let format = Format::from_code(field.format);
372 let decoded = match decode_value(
373 field.type_oid,
374 Some(bytes.as_slice()),
375 format,
376 ) {
377 Ok(v) => v,
378 Err(e) => return Outcome::Err(e),
379 };
380 values.push(decoded);
381 }
382 }
383 }
384 rows.push(Row::with_columns(Arc::clone(cols), values));
385 }
386 BackendMessage::CommandComplete(tag) => {
387 command_tag = Some(tag);
388 }
389 BackendMessage::ReadyForQuery(status) => {
390 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
391 break;
392 }
393 BackendMessage::ErrorResponse(e) => {
394 self.state = ConnectionState::Error;
395 return Outcome::Err(error_from_fields(&e));
396 }
397 BackendMessage::NoticeResponse(_notice) => {}
398 _ => {}
399 }
400 }
401
402 Outcome::Ok(PgQueryResult { rows, command_tag })
403 }
404
405 async fn negotiate_ssl(&mut self) -> Outcome<(), Error> {
408 if let Outcome::Err(e) = self.send_message_no_cx(&FrontendMessage::SSLRequest).await {
410 return Outcome::Err(e);
411 }
412
413 let mut buf = [0u8; 1];
415 match read_exact_async(&mut self.stream, &mut buf).await {
416 Ok(()) => {}
417 Err(e) => {
418 return Outcome::Err(Error::Connection(ConnectionError {
419 kind: ConnectionErrorKind::Ssl,
420 message: format!("Failed to read SSL response: {}", e),
421 source: Some(Box::new(e)),
422 }));
423 }
424 }
425
426 match buf[0] {
427 b'S' => {
428 if self.config.ssl_mode.is_required() {
430 Outcome::Err(Error::Connection(ConnectionError {
431 kind: ConnectionErrorKind::Ssl,
432 message: "SSL/TLS not yet implemented".to_string(),
433 source: None,
434 }))
435 } else {
436 Outcome::Err(Error::Connection(ConnectionError {
437 kind: ConnectionErrorKind::Ssl,
438 message: "SSL/TLS not yet implemented, reconnect with ssl_mode=disable"
439 .to_string(),
440 source: None,
441 }))
442 }
443 }
444 b'N' => {
445 if self.config.ssl_mode.is_required() {
446 Outcome::Err(Error::Connection(ConnectionError {
447 kind: ConnectionErrorKind::Ssl,
448 message: "Server does not support SSL".to_string(),
449 source: None,
450 }))
451 } else {
452 Outcome::Ok(())
453 }
454 }
455 other => Outcome::Err(Error::Connection(ConnectionError {
456 kind: ConnectionErrorKind::Ssl,
457 message: format!("Unexpected SSL response: 0x{other:02x}"),
458 source: None,
459 })),
460 }
461 }
462
463 async fn send_startup(&mut self) -> Outcome<(), Error> {
464 let params = self.config.startup_params();
465 self.send_message_no_cx(&FrontendMessage::Startup {
466 version: PROTOCOL_VERSION,
467 params,
468 })
469 .await
470 }
471
472 async fn handle_auth(&mut self) -> Outcome<(), Error> {
473 loop {
474 let msg = match self.receive_message_no_cx().await {
475 Outcome::Ok(m) => m,
476 Outcome::Err(e) => return Outcome::Err(e),
477 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
478 Outcome::Panicked(p) => return Outcome::Panicked(p),
479 };
480
481 match msg {
482 BackendMessage::AuthenticationOk => return Outcome::Ok(()),
483 BackendMessage::AuthenticationCleartextPassword => {
484 let Some(password) = self.config.password.as_ref() else {
485 return Outcome::Err(auth_error("Password required but not provided"));
486 };
487 if let Outcome::Err(e) = self
488 .send_message_no_cx(&FrontendMessage::PasswordMessage(password.clone()))
489 .await
490 {
491 return Outcome::Err(e);
492 }
493 }
494 BackendMessage::AuthenticationMD5Password(salt) => {
495 let Some(password) = self.config.password.as_ref() else {
496 return Outcome::Err(auth_error("Password required but not provided"));
497 };
498 let hash = md5_password(&self.config.user, password, salt);
499 if let Outcome::Err(e) = self
500 .send_message_no_cx(&FrontendMessage::PasswordMessage(hash))
501 .await
502 {
503 return Outcome::Err(e);
504 }
505 }
506 BackendMessage::AuthenticationSASL(mechanisms) => {
507 if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
508 match self.scram_auth().await {
509 Outcome::Ok(()) => {}
510 Outcome::Err(e) => return Outcome::Err(e),
511 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
512 Outcome::Panicked(p) => return Outcome::Panicked(p),
513 }
514 } else {
515 return Outcome::Err(auth_error(format!(
516 "Unsupported SASL mechanisms: {:?}",
517 mechanisms
518 )));
519 }
520 }
521 BackendMessage::ErrorResponse(e) => {
522 self.state = ConnectionState::Error;
523 return Outcome::Err(error_from_fields(&e));
524 }
525 other => {
526 return Outcome::Err(protocol_error(format!(
527 "Unexpected message during auth: {other:?}"
528 )));
529 }
530 }
531 }
532 }
533
534 async fn scram_auth(&mut self) -> Outcome<(), Error> {
535 let Some(password) = self.config.password.as_ref() else {
536 return Outcome::Err(auth_error("Password required for SCRAM-SHA-256"));
537 };
538
539 let mut client = ScramClient::new(&self.config.user, password);
540
541 let client_first = client.client_first();
543 if let Outcome::Err(e) = self
544 .send_message_no_cx(&FrontendMessage::SASLInitialResponse {
545 mechanism: "SCRAM-SHA-256".to_string(),
546 data: client_first,
547 })
548 .await
549 {
550 return Outcome::Err(e);
551 }
552
553 let msg = match self.receive_message_no_cx().await {
555 Outcome::Ok(m) => m,
556 Outcome::Err(e) => return Outcome::Err(e),
557 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
558 Outcome::Panicked(p) => return Outcome::Panicked(p),
559 };
560 let server_first_data = match msg {
561 BackendMessage::AuthenticationSASLContinue(data) => data,
562 BackendMessage::ErrorResponse(e) => {
563 self.state = ConnectionState::Error;
564 return Outcome::Err(error_from_fields(&e));
565 }
566 other => {
567 return Outcome::Err(protocol_error(format!(
568 "Expected SASL continue, got: {other:?}"
569 )));
570 }
571 };
572
573 let client_final = match client.process_server_first(&server_first_data) {
575 Ok(v) => v,
576 Err(e) => return Outcome::Err(e),
577 };
578 if let Outcome::Err(e) = self
579 .send_message_no_cx(&FrontendMessage::SASLResponse(client_final))
580 .await
581 {
582 return Outcome::Err(e);
583 }
584
585 let msg = match self.receive_message_no_cx().await {
587 Outcome::Ok(m) => m,
588 Outcome::Err(e) => return Outcome::Err(e),
589 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
590 Outcome::Panicked(p) => return Outcome::Panicked(p),
591 };
592 let server_final_data = match msg {
593 BackendMessage::AuthenticationSASLFinal(data) => data,
594 BackendMessage::ErrorResponse(e) => {
595 self.state = ConnectionState::Error;
596 return Outcome::Err(error_from_fields(&e));
597 }
598 other => {
599 return Outcome::Err(protocol_error(format!(
600 "Expected SASL final, got: {other:?}"
601 )));
602 }
603 };
604
605 if let Err(e) = client.verify_server_final(&server_final_data) {
606 return Outcome::Err(e);
607 }
608
609 let msg = match self.receive_message_no_cx().await {
611 Outcome::Ok(m) => m,
612 Outcome::Err(e) => return Outcome::Err(e),
613 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
614 Outcome::Panicked(p) => return Outcome::Panicked(p),
615 };
616 match msg {
617 BackendMessage::AuthenticationOk => Outcome::Ok(()),
618 BackendMessage::ErrorResponse(e) => {
619 self.state = ConnectionState::Error;
620 Outcome::Err(error_from_fields(&e))
621 }
622 other => Outcome::Err(protocol_error(format!(
623 "Expected AuthenticationOk, got: {other:?}"
624 ))),
625 }
626 }
627
628 async fn read_startup_messages(&mut self) -> Outcome<(), Error> {
629 loop {
630 let msg = match self.receive_message_no_cx().await {
631 Outcome::Ok(m) => m,
632 Outcome::Err(e) => return Outcome::Err(e),
633 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
634 Outcome::Panicked(p) => return Outcome::Panicked(p),
635 };
636
637 match msg {
638 BackendMessage::BackendKeyData {
639 process_id,
640 secret_key,
641 } => {
642 self.process_id = process_id;
643 self.secret_key = secret_key;
644 }
645 BackendMessage::ParameterStatus { name, value } => {
646 self.parameters.insert(name, value);
647 }
648 BackendMessage::ReadyForQuery(status) => {
649 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
650 return Outcome::Ok(());
651 }
652 BackendMessage::ErrorResponse(e) => {
653 self.state = ConnectionState::Error;
654 return Outcome::Err(error_from_fields(&e));
655 }
656 BackendMessage::NoticeResponse(_notice) => {}
657 other => {
658 return Outcome::Err(protocol_error(format!(
659 "Unexpected startup message: {other:?}"
660 )));
661 }
662 }
663 }
664 }
665
666 async fn send_message(&mut self, cx: &Cx, msg: &FrontendMessage) -> Outcome<(), Error> {
669 if let Some(reason) = cx.cancel_reason() {
671 return Outcome::Cancelled(reason);
672 }
673 self.send_message_no_cx(msg).await
674 }
675
676 async fn receive_message(&mut self, cx: &Cx) -> Outcome<BackendMessage, Error> {
677 if let Some(reason) = cx.cancel_reason() {
678 return Outcome::Cancelled(reason);
679 }
680 self.receive_message_no_cx().await
681 }
682
683 async fn send_message_no_cx(&mut self, msg: &FrontendMessage) -> Outcome<(), Error> {
684 let data = self.writer.write(msg).to_vec();
685
686 let mut written = 0;
687 while written < data.len() {
688 match std::future::poll_fn(|cx| {
689 std::pin::Pin::new(&mut self.stream).poll_write(cx, &data[written..])
690 })
691 .await
692 {
693 Ok(n) => {
694 if n == 0 {
695 self.state = ConnectionState::Error;
696 return Outcome::Err(Error::Connection(ConnectionError {
697 kind: ConnectionErrorKind::Disconnected,
698 message: "Connection closed while writing".to_string(),
699 source: None,
700 }));
701 }
702 written += n;
703 }
704 Err(e) => {
705 self.state = ConnectionState::Error;
706 return Outcome::Err(Error::Connection(ConnectionError {
707 kind: ConnectionErrorKind::Disconnected,
708 message: format!("Failed to write to server: {}", e),
709 source: Some(Box::new(e)),
710 }));
711 }
712 }
713 }
714
715 match std::future::poll_fn(|cx| std::pin::Pin::new(&mut self.stream).poll_flush(cx)).await {
716 Ok(()) => Outcome::Ok(()),
717 Err(e) => {
718 self.state = ConnectionState::Error;
719 Outcome::Err(Error::Connection(ConnectionError {
720 kind: ConnectionErrorKind::Disconnected,
721 message: format!("Failed to flush stream: {}", e),
722 source: Some(Box::new(e)),
723 }))
724 }
725 }
726 }
727
728 async fn receive_message_no_cx(&mut self) -> Outcome<BackendMessage, Error> {
729 loop {
730 match self.reader.next_message() {
731 Ok(Some(msg)) => return Outcome::Ok(msg),
732 Ok(None) => {}
733 Err(e) => {
734 self.state = ConnectionState::Error;
735 return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
736 }
737 }
738
739 let mut read_buf = ReadBuf::new(&mut self.read_buf);
740 match std::future::poll_fn(|cx| {
741 std::pin::Pin::new(&mut self.stream).poll_read(cx, &mut read_buf)
742 })
743 .await
744 {
745 Ok(()) => {
746 let n = read_buf.filled().len();
747 if n == 0 {
748 self.state = ConnectionState::Disconnected;
749 return Outcome::Err(Error::Connection(ConnectionError {
750 kind: ConnectionErrorKind::Disconnected,
751 message: "Connection closed by server".to_string(),
752 source: None,
753 }));
754 }
755 if let Err(e) = self.reader.feed(read_buf.filled()) {
756 self.state = ConnectionState::Error;
757 return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
758 }
759 }
760 Err(e) => {
761 self.state = ConnectionState::Error;
762 return Outcome::Err(match e.kind() {
763 std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock => {
764 Error::Timeout
765 }
766 _ => Error::Connection(ConnectionError {
767 kind: ConnectionErrorKind::Disconnected,
768 message: format!("Failed to read from server: {}", e),
769 source: Some(Box::new(e)),
770 }),
771 });
772 }
773 }
774 }
775 }
776}
777
778pub struct SharedPgConnection {
780 inner: Arc<Mutex<PgAsyncConnection>>,
781}
782
783impl SharedPgConnection {
784 pub fn new(conn: PgAsyncConnection) -> Self {
785 Self {
786 inner: Arc::new(Mutex::new(conn)),
787 }
788 }
789
790 pub async fn connect(cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
791 match PgAsyncConnection::connect(cx, config).await {
792 Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
793 Outcome::Err(e) => Outcome::Err(e),
794 Outcome::Cancelled(r) => Outcome::Cancelled(r),
795 Outcome::Panicked(p) => Outcome::Panicked(p),
796 }
797 }
798
799 pub fn inner(&self) -> &Arc<Mutex<PgAsyncConnection>> {
800 &self.inner
801 }
802
803 async fn begin_transaction_impl(
804 &self,
805 cx: &Cx,
806 isolation: Option<IsolationLevel>,
807 ) -> Outcome<SharedPgTransaction<'_>, Error> {
808 let inner = Arc::clone(&self.inner);
809 let Ok(mut guard) = inner.lock(cx).await else {
810 return Outcome::Err(connection_error("Failed to acquire connection lock"));
811 };
812
813 if let Some(level) = isolation {
814 let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
815 match guard.execute_async(cx, &sql, &[]).await {
816 Outcome::Ok(_) => {}
817 Outcome::Err(e) => return Outcome::Err(e),
818 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
819 Outcome::Panicked(p) => return Outcome::Panicked(p),
820 }
821 }
822
823 match guard.execute_async(cx, "BEGIN", &[]).await {
824 Outcome::Ok(_) => {}
825 Outcome::Err(e) => return Outcome::Err(e),
826 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
827 Outcome::Panicked(p) => return Outcome::Panicked(p),
828 }
829
830 drop(guard);
831 Outcome::Ok(SharedPgTransaction {
832 inner,
833 committed: false,
834 _marker: std::marker::PhantomData,
835 })
836 }
837}
838
839impl Clone for SharedPgConnection {
840 fn clone(&self) -> Self {
841 Self {
842 inner: Arc::clone(&self.inner),
843 }
844 }
845}
846
847impl std::fmt::Debug for SharedPgConnection {
848 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849 f.debug_struct("SharedPgConnection")
850 .field("inner", &"Arc<Mutex<PgAsyncConnection>>")
851 .finish()
852 }
853}
854
855pub struct SharedPgTransaction<'conn> {
856 inner: Arc<Mutex<PgAsyncConnection>>,
857 committed: bool,
858 _marker: std::marker::PhantomData<&'conn ()>,
859}
860
861impl<'conn> Drop for SharedPgTransaction<'conn> {
862 fn drop(&mut self) {
863 if !self.committed {
864 #[cfg(debug_assertions)]
869 eprintln!(
870 "WARNING: SharedPgTransaction dropped without commit/rollback. \
871 The PostgreSQL transaction may still be open."
872 );
873 }
874 }
875}
876
877impl Connection for SharedPgConnection {
878 type Tx<'conn>
879 = SharedPgTransaction<'conn>
880 where
881 Self: 'conn;
882
883 fn query(
884 &self,
885 cx: &Cx,
886 sql: &str,
887 params: &[Value],
888 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
889 let inner = Arc::clone(&self.inner);
890 let sql = sql.to_string();
891 let params = params.to_vec();
892 async move {
893 let Ok(mut guard) = inner.lock(cx).await else {
894 return Outcome::Err(connection_error("Failed to acquire connection lock"));
895 };
896 guard.query_async(cx, &sql, ¶ms).await
897 }
898 }
899
900 fn query_one(
901 &self,
902 cx: &Cx,
903 sql: &str,
904 params: &[Value],
905 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
906 let inner = Arc::clone(&self.inner);
907 let sql = sql.to_string();
908 let params = params.to_vec();
909 async move {
910 let Ok(mut guard) = inner.lock(cx).await else {
911 return Outcome::Err(connection_error("Failed to acquire connection lock"));
912 };
913 let rows = match guard.query_async(cx, &sql, ¶ms).await {
914 Outcome::Ok(r) => r,
915 Outcome::Err(e) => return Outcome::Err(e),
916 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
917 Outcome::Panicked(p) => return Outcome::Panicked(p),
918 };
919 Outcome::Ok(rows.into_iter().next())
920 }
921 }
922
923 fn execute(
924 &self,
925 cx: &Cx,
926 sql: &str,
927 params: &[Value],
928 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
929 let inner = Arc::clone(&self.inner);
930 let sql = sql.to_string();
931 let params = params.to_vec();
932 async move {
933 let Ok(mut guard) = inner.lock(cx).await else {
934 return Outcome::Err(connection_error("Failed to acquire connection lock"));
935 };
936 guard.execute_async(cx, &sql, ¶ms).await
937 }
938 }
939
940 fn insert(
941 &self,
942 cx: &Cx,
943 sql: &str,
944 params: &[Value],
945 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
946 let inner = Arc::clone(&self.inner);
947 let sql = sql.to_string();
948 let params = params.to_vec();
949 async move {
950 let Ok(mut guard) = inner.lock(cx).await else {
951 return Outcome::Err(connection_error("Failed to acquire connection lock"));
952 };
953 guard.insert_async(cx, &sql, ¶ms).await
954 }
955 }
956
957 fn batch(
958 &self,
959 cx: &Cx,
960 statements: &[(String, Vec<Value>)],
961 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
962 let inner = Arc::clone(&self.inner);
963 let statements = statements.to_vec();
964 async move {
965 let Ok(mut guard) = inner.lock(cx).await else {
966 return Outcome::Err(connection_error("Failed to acquire connection lock"));
967 };
968 let mut results = Vec::with_capacity(statements.len());
969 for (sql, params) in &statements {
970 match guard.execute_async(cx, sql, params).await {
971 Outcome::Ok(n) => results.push(n),
972 Outcome::Err(e) => return Outcome::Err(e),
973 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
974 Outcome::Panicked(p) => return Outcome::Panicked(p),
975 }
976 }
977 Outcome::Ok(results)
978 }
979 }
980
981 fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
982 self.begin_with(cx, IsolationLevel::default())
983 }
984
985 fn begin_with(
986 &self,
987 cx: &Cx,
988 isolation: IsolationLevel,
989 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
990 self.begin_transaction_impl(cx, Some(isolation))
991 }
992
993 fn prepare(
994 &self,
995 _cx: &Cx,
996 sql: &str,
997 ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
998 let sql = sql.to_string();
999 async move {
1000 Outcome::Ok(PreparedStatement::new(0, sql, 0))
1004 }
1005 }
1006
1007 fn query_prepared(
1008 &self,
1009 cx: &Cx,
1010 stmt: &PreparedStatement,
1011 params: &[Value],
1012 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1013 self.query(cx, stmt.sql(), params)
1014 }
1015
1016 fn execute_prepared(
1017 &self,
1018 cx: &Cx,
1019 stmt: &PreparedStatement,
1020 params: &[Value],
1021 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1022 self.execute(cx, stmt.sql(), params)
1023 }
1024
1025 fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1026 let inner = Arc::clone(&self.inner);
1027 async move {
1028 let Ok(mut guard) = inner.lock(cx).await else {
1029 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1030 };
1031 guard.ping_async(cx).await
1032 }
1033 }
1034
1035 async fn close(self, cx: &Cx) -> sqlmodel_core::Result<()> {
1036 let Ok(mut guard) = self.inner.lock(cx).await else {
1037 return Err(connection_error("Failed to acquire connection lock"));
1038 };
1039 match guard.close_async(cx).await {
1040 Outcome::Ok(()) => Ok(()),
1041 Outcome::Err(e) => Err(e),
1042 Outcome::Cancelled(r) => Err(Error::Query(QueryError {
1043 kind: QueryErrorKind::Cancelled,
1044 message: format!("Cancelled: {r:?}"),
1045 sqlstate: None,
1046 sql: None,
1047 detail: None,
1048 hint: None,
1049 position: None,
1050 source: None,
1051 })),
1052 Outcome::Panicked(p) => Err(Error::Protocol(ProtocolError {
1053 message: format!("Panicked: {p:?}"),
1054 raw_data: None,
1055 source: None,
1056 })),
1057 }
1058 }
1059}
1060
1061impl<'conn> TransactionOps for SharedPgTransaction<'conn> {
1062 fn query(
1063 &self,
1064 cx: &Cx,
1065 sql: &str,
1066 params: &[Value],
1067 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1068 let inner = Arc::clone(&self.inner);
1069 let sql = sql.to_string();
1070 let params = params.to_vec();
1071 async move {
1072 let Ok(mut guard) = inner.lock(cx).await else {
1073 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1074 };
1075 guard.query_async(cx, &sql, ¶ms).await
1076 }
1077 }
1078
1079 fn query_one(
1080 &self,
1081 cx: &Cx,
1082 sql: &str,
1083 params: &[Value],
1084 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1085 let inner = Arc::clone(&self.inner);
1086 let sql = sql.to_string();
1087 let params = params.to_vec();
1088 async move {
1089 let Ok(mut guard) = inner.lock(cx).await else {
1090 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1091 };
1092 let rows = match guard.query_async(cx, &sql, ¶ms).await {
1093 Outcome::Ok(r) => r,
1094 Outcome::Err(e) => return Outcome::Err(e),
1095 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1096 Outcome::Panicked(p) => return Outcome::Panicked(p),
1097 };
1098 Outcome::Ok(rows.into_iter().next())
1099 }
1100 }
1101
1102 fn execute(
1103 &self,
1104 cx: &Cx,
1105 sql: &str,
1106 params: &[Value],
1107 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1108 let inner = Arc::clone(&self.inner);
1109 let sql = sql.to_string();
1110 let params = params.to_vec();
1111 async move {
1112 let Ok(mut guard) = inner.lock(cx).await else {
1113 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1114 };
1115 guard.execute_async(cx, &sql, ¶ms).await
1116 }
1117 }
1118
1119 fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1120 let inner = Arc::clone(&self.inner);
1121 let name = name.to_string();
1122 async move {
1123 if let Err(e) = validate_savepoint_name(&name) {
1124 return Outcome::Err(e);
1125 }
1126 let sql = format!("SAVEPOINT {}", name);
1127 let Ok(mut guard) = inner.lock(cx).await else {
1128 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1129 };
1130 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1131 }
1132 }
1133
1134 fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1135 let inner = Arc::clone(&self.inner);
1136 let name = name.to_string();
1137 async move {
1138 if let Err(e) = validate_savepoint_name(&name) {
1139 return Outcome::Err(e);
1140 }
1141 let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
1142 let Ok(mut guard) = inner.lock(cx).await else {
1143 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1144 };
1145 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1146 }
1147 }
1148
1149 fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1150 let inner = Arc::clone(&self.inner);
1151 let name = name.to_string();
1152 async move {
1153 if let Err(e) = validate_savepoint_name(&name) {
1154 return Outcome::Err(e);
1155 }
1156 let sql = format!("RELEASE SAVEPOINT {}", name);
1157 let Ok(mut guard) = inner.lock(cx).await else {
1158 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1159 };
1160 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1161 }
1162 }
1163
1164 #[allow(unused_assignments)]
1166 fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1167 let inner = Arc::clone(&self.inner);
1168 async move {
1169 let Ok(mut guard) = inner.lock(cx).await else {
1170 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1171 };
1172 let result = guard.execute_async(cx, "COMMIT", &[]).await;
1173 if matches!(result, Outcome::Ok(_)) {
1174 self.committed = true;
1175 }
1176 result.map(|_| ())
1177 }
1178 }
1179
1180 #[allow(unused_assignments)]
1181 fn rollback(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1182 let inner = Arc::clone(&self.inner);
1183 async move {
1184 let Ok(mut guard) = inner.lock(cx).await else {
1185 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1186 };
1187 let result = guard.execute_async(cx, "ROLLBACK", &[]).await;
1188 if matches!(result, Outcome::Ok(_)) {
1189 self.committed = true;
1190 }
1191 result.map(|_| ())
1192 }
1193 }
1194}
1195
1196struct PgQueryResult {
1199 rows: Vec<Row>,
1200 command_tag: Option<String>,
1201}
1202
1203fn connection_error(msg: impl Into<String>) -> Error {
1204 Error::Connection(ConnectionError {
1205 kind: ConnectionErrorKind::Connect,
1206 message: msg.into(),
1207 source: None,
1208 })
1209}
1210
1211fn auth_error(msg: impl Into<String>) -> Error {
1212 Error::Connection(ConnectionError {
1213 kind: ConnectionErrorKind::Authentication,
1214 message: msg.into(),
1215 source: None,
1216 })
1217}
1218
1219fn protocol_error(msg: impl Into<String>) -> Error {
1220 Error::Protocol(ProtocolError {
1221 message: msg.into(),
1222 raw_data: None,
1223 source: None,
1224 })
1225}
1226
1227fn query_error_msg(msg: impl Into<String>, kind: QueryErrorKind) -> Error {
1228 Error::Query(QueryError {
1229 kind,
1230 message: msg.into(),
1231 sqlstate: None,
1232 sql: None,
1233 detail: None,
1234 hint: None,
1235 position: None,
1236 source: None,
1237 })
1238}
1239
1240fn error_from_fields(fields: &ErrorFields) -> Error {
1241 let kind = match fields.code.get(..2) {
1242 Some("08") => {
1243 return Error::Connection(ConnectionError {
1244 kind: ConnectionErrorKind::Connect,
1245 message: fields.message.clone(),
1246 source: None,
1247 });
1248 }
1249 Some("28") => {
1250 return Error::Connection(ConnectionError {
1251 kind: ConnectionErrorKind::Authentication,
1252 message: fields.message.clone(),
1253 source: None,
1254 });
1255 }
1256 Some("42") => QueryErrorKind::Syntax,
1257 Some("23") => QueryErrorKind::Constraint,
1258 Some("40") => {
1259 if fields.code == "40001" {
1260 QueryErrorKind::Serialization
1261 } else {
1262 QueryErrorKind::Deadlock
1263 }
1264 }
1265 Some("57") => {
1266 if fields.code == "57014" {
1267 QueryErrorKind::Cancelled
1268 } else {
1269 QueryErrorKind::Timeout
1270 }
1271 }
1272 _ => QueryErrorKind::Database,
1273 };
1274
1275 Error::Query(QueryError {
1276 kind,
1277 sql: None,
1278 sqlstate: Some(fields.code.clone()),
1279 message: fields.message.clone(),
1280 detail: fields.detail.clone(),
1281 hint: fields.hint.clone(),
1282 position: fields.position.map(|p| p as usize),
1283 source: None,
1284 })
1285}
1286
/// Extracts the affected-row count from a command tag.
///
/// The count is read from the last whitespace-separated token, so
/// `"INSERT 0 5"` yields `Some(5)` and `"UPDATE 3"` yields `Some(3)`. Tags
/// whose last token is not an unsigned integer (e.g. `"BEGIN"`), empty tags
/// and `None` all yield `None`.
fn parse_rows_affected(tag: Option<&str>) -> Option<u64> {
    // `next_back` reads the final token directly, without collecting every
    // token into a temporary Vec first.
    tag?.split_whitespace().next_back()?.parse::<u64>().ok()
}
1292
1293fn validate_savepoint_name(name: &str) -> sqlmodel_core::Result<()> {
1295 if name.is_empty() {
1296 return Err(query_error_msg(
1297 "Savepoint name cannot be empty",
1298 QueryErrorKind::Syntax,
1299 ));
1300 }
1301 if name.len() > 63 {
1302 return Err(query_error_msg(
1303 "Savepoint name exceeds maximum length of 63 characters",
1304 QueryErrorKind::Syntax,
1305 ));
1306 }
1307 let mut chars = name.chars();
1308 let Some(first) = chars.next() else {
1309 return Err(query_error_msg(
1310 "Savepoint name cannot be empty",
1311 QueryErrorKind::Syntax,
1312 ));
1313 };
1314 if !first.is_ascii_alphabetic() && first != '_' {
1315 return Err(query_error_msg(
1316 "Savepoint name must start with a letter or underscore",
1317 QueryErrorKind::Syntax,
1318 ));
1319 }
1320 for c in chars {
1321 if !c.is_ascii_alphanumeric() && c != '_' {
1322 return Err(query_error_msg(
1323 format!("Savepoint name contains invalid character: '{c}'"),
1324 QueryErrorKind::Syntax,
1325 ));
1326 }
1327 }
1328 Ok(())
1329}
1330
1331fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
1332 use std::fmt::Write;
1333
1334 let inner = format!("{password}{user}");
1335 let inner_hash = md5::compute(inner.as_bytes());
1336
1337 let mut outer_input = format!("{inner_hash:x}").into_bytes();
1338 outer_input.extend_from_slice(&salt);
1339 let outer_hash = md5::compute(&outer_input);
1340
1341 let mut result = String::with_capacity(35);
1342 result.push_str("md5");
1343 write!(&mut result, "{outer_hash:x}").unwrap();
1344 result
1345}
1346
1347async fn read_exact_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<()> {
1348 let mut read = 0;
1349 while read < buf.len() {
1350 let mut read_buf = ReadBuf::new(&mut buf[read..]);
1351 std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf))
1352 .await?;
1353 let n = read_buf.filled().len();
1354 if n == 0 {
1355 return Err(std::io::Error::new(
1356 std::io::ErrorKind::UnexpectedEof,
1357 "connection closed",
1358 ));
1359 }
1360 read += n;
1361 }
1362 Ok(())
1363}