Skip to main content

sqlmodel_postgres/
async_connection.rs

1//! Async PostgreSQL connection implementation.
2//!
3//! This module implements an async PostgreSQL connection using asupersync's TCP
4//! primitives. It provides a shared wrapper that implements `sqlmodel-core`'s
5//! [`Connection`] trait.
6//!
7//! The implementation currently focuses on:
8//! - Async connect + authentication (cleartext, MD5, SCRAM-SHA-256)
9//! - Extended query protocol for parameterized queries
10//! - Row decoding via the postgres type registry (OID + text/binary format)
11//! - Basic transaction support (BEGIN/COMMIT/ROLLBACK + savepoints)
12
13// Allow `impl Future` return types in trait methods - intentional for async trait compat
14#![allow(clippy::manual_async_fn)]
15// The Error type is intentionally large to carry full context
16#![allow(clippy::result_large_err)]
17
18use std::collections::HashMap;
19use std::future::Future;
20#[cfg(feature = "tls")]
21use std::io::{Read, Write};
22use std::sync::Arc;
23
24use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
25use asupersync::net::TcpStream;
26use asupersync::sync::Mutex;
27use asupersync::{Cx, Outcome};
28
29use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
30use sqlmodel_core::error::{
31    ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
32};
33use sqlmodel_core::row::ColumnInfo;
34use sqlmodel_core::{Error, Row, Value};
35
36use crate::auth::ScramClient;
37use crate::config::{PgConfig, SslMode};
38use crate::connection::{ConnectionState, TransactionStatusState};
39use crate::protocol::{
40    BackendMessage, DescribeKind, ErrorFields, FrontendMessage, MessageReader, MessageWriter,
41    PROTOCOL_VERSION,
42};
43use crate::types::{Format, decode_value, encode_value};
44
45#[cfg(feature = "tls")]
46use crate::tls;
47
48enum PgAsyncStream {
49    Plain(TcpStream),
50    #[cfg(feature = "tls")]
51    Tls(AsyncTlsStream),
52    #[cfg(feature = "tls")]
53    Closed,
54}
55
56impl PgAsyncStream {
57    #[cfg(feature = "tls")]
58    async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59        match self {
60            PgAsyncStream::Plain(s) => read_exact_plain_async(s, buf).await,
61            #[cfg(feature = "tls")]
62            PgAsyncStream::Tls(s) => s.read_exact(buf).await,
63            #[cfg(feature = "tls")]
64            PgAsyncStream::Closed => Err(std::io::Error::new(
65                std::io::ErrorKind::NotConnected,
66                "connection closed",
67            )),
68        }
69    }
70
71    async fn read_some(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
72        match self {
73            PgAsyncStream::Plain(s) => read_some_plain_async(s, buf).await,
74            #[cfg(feature = "tls")]
75            PgAsyncStream::Tls(s) => s.read_plain(buf).await,
76            #[cfg(feature = "tls")]
77            PgAsyncStream::Closed => Err(std::io::Error::new(
78                std::io::ErrorKind::NotConnected,
79                "connection closed",
80            )),
81        }
82    }
83
84    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
85        match self {
86            PgAsyncStream::Plain(s) => write_all_plain_async(s, buf).await,
87            #[cfg(feature = "tls")]
88            PgAsyncStream::Tls(s) => s.write_all(buf).await,
89            #[cfg(feature = "tls")]
90            PgAsyncStream::Closed => Err(std::io::Error::new(
91                std::io::ErrorKind::NotConnected,
92                "connection closed",
93            )),
94        }
95    }
96
97    async fn flush(&mut self) -> std::io::Result<()> {
98        match self {
99            PgAsyncStream::Plain(s) => flush_plain_async(s).await,
100            #[cfg(feature = "tls")]
101            PgAsyncStream::Tls(s) => s.flush().await,
102            #[cfg(feature = "tls")]
103            PgAsyncStream::Closed => Err(std::io::Error::new(
104                std::io::ErrorKind::NotConnected,
105                "connection closed",
106            )),
107        }
108    }
109}
110
/// A rustls client session driven over an asupersync TCP socket.
///
/// rustls performs no I/O itself. The methods on this type move ciphertext
/// between `tls` and `tcp`.
#[cfg(feature = "tls")]
struct AsyncTlsStream {
    /// Raw socket carrying TLS records.
    tcp: TcpStream,
    /// rustls client state machine (encryption, decryption, plaintext buffering).
    tls: rustls::ClientConnection,
}
116
#[cfg(feature = "tls")]
impl AsyncTlsStream {
    /// Perform a client-side TLS handshake over an already-connected socket.
    ///
    /// Drives rustls until it reports the handshake finished. It then flushes any
    /// records rustls queued during its last processing step (e.g. the client
    /// `Finished`), so the session is fully established on return.
    ///
    /// # Errors
    /// Returns a connection error in any of these cases:
    /// - the TLS config or server name cannot be built;
    /// - the socket fails or closes mid-handshake;
    /// - the peer sends invalid TLS data;
    /// - the handshake stops making progress.
    async fn handshake(mut tcp: TcpStream, ssl_mode: SslMode, host: &str) -> Result<Self, Error> {
        let config = tls::build_client_config(ssl_mode)?;
        let server_name = tls::server_name(host)?;
        let mut tls = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
            .map_err(|e| connection_error(format!("Failed to create TLS connection: {e}")))?;

        while tls.is_handshaking() {
            Self::write_handshake_records(&mut tls, &mut tcp).await?;

            if !tls.wants_read() {
                // Still handshaking, but there is nothing left to send and nothing
                // to wait for. Looping again would spin without ever yielding.
                return Err(connection_error("TLS handshake stalled"));
            }

            let mut buf = [0u8; 8192];
            let n = read_some_plain_async(&mut tcp, &mut buf)
                .await
                .map_err(|e| {
                    Error::Connection(ConnectionError {
                        kind: ConnectionErrorKind::Disconnected,
                        message: format!("TLS handshake read error: {e}"),
                        source: Some(Box::new(e)),
                    })
                })?;
            if n == 0 {
                return Err(connection_error("Connection closed during TLS handshake"));
            }

            // `read_tls` takes at most one internal-buffer chunk per call, so keep
            // feeding until every received byte has been handed to rustls.
            let mut pending: &[u8] = &buf[..n];
            while !pending.is_empty() {
                let consumed = tls.read_tls(&mut pending).map_err(|e| {
                    connection_error(format!("TLS handshake read_tls error: {e}"))
                })?;
                if consumed == 0 {
                    return Err(connection_error("TLS handshake stalled"));
                }
                tls.process_new_packets()
                    .map_err(|e| connection_error(format!("TLS handshake error: {e}")))?;
            }
        }

        // The last `process_new_packets` may have queued our final flight.
        Self::write_handshake_records(&mut tls, &mut tcp).await?;

        Ok(Self { tcp, tls })
    }

    /// Write every TLS record rustls currently has queued to the raw socket.
    async fn write_handshake_records(
        tls: &mut rustls::ClientConnection,
        tcp: &mut TcpStream,
    ) -> Result<(), Error> {
        while tls.wants_write() {
            let mut out = Vec::new();
            tls.write_tls(&mut out)
                .map_err(|e| connection_error(format!("TLS handshake write_tls error: {e}")))?;
            if !out.is_empty() {
                write_all_plain_async(tcp, &out).await.map_err(|e| {
                    Error::Connection(ConnectionError {
                        kind: ConnectionErrorKind::Disconnected,
                        message: format!("TLS handshake write error: {e}"),
                        source: Some(Box::new(e)),
                    })
                })?;
            }
        }
        Ok(())
    }

    /// Fill `buf` with decrypted bytes, failing with `UnexpectedEof` if the stream ends first.
    async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        let mut read = 0;
        while read < buf.len() {
            let n = self.read_plain(&mut buf[read..]).await?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            read += n;
        }
        Ok(())
    }

    /// Read decrypted application data into `out` and return the byte count.
    ///
    /// Returns `Ok(0)` on end of stream: either the socket hit EOF or rustls no
    /// longer wants to read. Also returns `Ok(0)` immediately when `out` is empty.
    async fn read_plain(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
        if out.is_empty() {
            return Ok(0);
        }
        loop {
            match self.tls.reader().read(out) {
                Ok(n) if n > 0 => return Ok(n),
                Ok(_) => {}
                // No plaintext buffered yet; fetch more ciphertext below.
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            if !self.tls.wants_read() {
                return Ok(0);
            }

            let mut enc = [0u8; 8192];
            let n = read_some_plain_async(&mut self.tcp, &mut enc).await?;
            if n == 0 {
                return Ok(0);
            }

            // A single `read_tls` call may take only part of the chunk; loop so no
            // received ciphertext is silently discarded.
            let mut pending: &[u8] = &enc[..n];
            while !pending.is_empty() {
                if self.tls.read_tls(&mut pending)? == 0 {
                    return Err(std::io::Error::other("TLS engine accepted no ciphertext"));
                }
                self.tls
                    .process_new_packets()
                    .map_err(|e| std::io::Error::other(format!("TLS error: {e}")))?;
            }
        }
    }

    /// Encrypt and send all of `buf`, flushing to the socket after each chunk rustls accepts.
    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut written = 0;
        while written < buf.len() {
            let n = self.tls.writer().write(&buf[written..])?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "TLS write zero",
                ));
            }
            written += n;
            self.flush().await?;
        }
        Ok(())
    }

    /// Flush the rustls writer, write every pending TLS record, then flush the socket.
    async fn flush(&mut self) -> std::io::Result<()> {
        self.tls.writer().flush()?;
        while self.tls.wants_write() {
            let mut out = Vec::new();
            self.tls.write_tls(&mut out)?;
            if !out.is_empty() {
                write_all_plain_async(&mut self.tcp, &out).await?;
            }
        }
        flush_plain_async(&mut self.tcp).await
    }
}
237
/// Read exactly `buf.len()` bytes from a plain TCP stream.
///
/// # Errors
/// `UnexpectedEof` if the peer closes before `buf` is full. Errors from the
/// underlying read are passed through unchanged.
#[cfg(feature = "tls")]
async fn read_exact_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<()> {
    let mut remaining = buf;
    while !remaining.is_empty() {
        match read_some_plain_async(stream, remaining).await? {
            0 => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            got => remaining = &mut std::mem::take(&mut remaining)[got..],
        }
    }
    Ok(())
}
253
254async fn read_some_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<usize> {
255    let mut read_buf = ReadBuf::new(buf);
256    std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf))
257        .await?;
258    Ok(read_buf.filled().len())
259}
260
261async fn write_all_plain_async(stream: &mut TcpStream, buf: &[u8]) -> std::io::Result<()> {
262    let mut written = 0;
263    while written < buf.len() {
264        let n = std::future::poll_fn(|cx| {
265            std::pin::Pin::new(&mut *stream).poll_write(cx, &buf[written..])
266        })
267        .await?;
268        if n == 0 {
269            return Err(std::io::Error::new(
270                std::io::ErrorKind::WriteZero,
271                "connection closed",
272            ));
273        }
274        written += n;
275    }
276    Ok(())
277}
278
279async fn flush_plain_async(stream: &mut TcpStream) -> std::io::Result<()> {
280    std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
281}
282
283/// Async PostgreSQL connection.
284///
285/// This connection uses asupersync's TCP stream for non-blocking I/O and
286/// supports the extended query protocol for parameter binding.
287pub struct PgAsyncConnection {
288    stream: PgAsyncStream,
289    state: ConnectionState,
290    process_id: i32,
291    secret_key: i32,
292    parameters: HashMap<String, String>,
293    next_prepared_id: u64,
294    prepared: HashMap<u64, PgPreparedMeta>,
295    config: PgConfig,
296    reader: MessageReader,
297    writer: MessageWriter,
298    read_buf: Vec<u8>,
299}
300
/// Client-side bookkeeping for one server-side prepared statement.
#[derive(Clone, Debug)]
struct PgPreparedMeta {
    /// Statement name sent in `Parse`/`Bind` (e.g. `sqlmodel_stmt_1`).
    name: String,
    /// Parameter type OIDs reported by the server's `ParameterDescription`.
    param_type_oids: Vec<u32>,
}
306
307impl std::fmt::Debug for PgAsyncConnection {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        f.debug_struct("PgAsyncConnection")
310            .field("state", &self.state)
311            .field("process_id", &self.process_id)
312            .field("host", &self.config.host)
313            .field("port", &self.config.port)
314            .field("database", &self.config.database)
315            .finish_non_exhaustive()
316    }
317}
318
319impl PgAsyncConnection {
320    /// Establish a new async connection to the PostgreSQL server.
321    pub async fn connect(_cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
322        let addr = config.socket_addr();
323        let socket_addr = match addr.parse() {
324            Ok(a) => a,
325            Err(e) => {
326                return Outcome::Err(Error::Connection(ConnectionError {
327                    kind: ConnectionErrorKind::Connect,
328                    message: format!("Invalid socket address: {}", e),
329                    source: None,
330                }));
331            }
332        };
333
334        let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
335            Ok(s) => s,
336            Err(e) => {
337                let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
338                    ConnectionErrorKind::Refused
339                } else {
340                    ConnectionErrorKind::Connect
341                };
342                return Outcome::Err(Error::Connection(ConnectionError {
343                    kind,
344                    message: format!("Failed to connect to {}: {}", addr, e),
345                    source: Some(Box::new(e)),
346                }));
347            }
348        };
349
350        stream.set_nodelay(true).ok();
351
352        let mut conn = Self {
353            stream: PgAsyncStream::Plain(stream),
354            state: ConnectionState::Connecting,
355            process_id: 0,
356            secret_key: 0,
357            parameters: HashMap::new(),
358            next_prepared_id: 1,
359            prepared: HashMap::new(),
360            config,
361            reader: MessageReader::new(),
362            writer: MessageWriter::new(),
363            read_buf: vec![0u8; 8192],
364        };
365
366        // SSL negotiation (feature-gated TLS)
367        if conn.config.ssl_mode.should_try_ssl() {
368            #[cfg(feature = "tls")]
369            match conn.negotiate_ssl().await {
370                Outcome::Ok(()) => {}
371                Outcome::Err(e) => return Outcome::Err(e),
372                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
373                Outcome::Panicked(p) => return Outcome::Panicked(p),
374            }
375
376            #[cfg(not(feature = "tls"))]
377            if conn.config.ssl_mode != SslMode::Prefer {
378                return Outcome::Err(connection_error(
379                    "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
380                ));
381            }
382        }
383
384        // Startup + authentication
385        if let Outcome::Err(e) = conn.send_startup().await {
386            return Outcome::Err(e);
387        }
388        conn.state = ConnectionState::Authenticating;
389
390        match conn.handle_auth().await {
391            Outcome::Ok(()) => {}
392            Outcome::Err(e) => return Outcome::Err(e),
393            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
394            Outcome::Panicked(p) => return Outcome::Panicked(p),
395        }
396
397        match conn.read_startup_messages().await {
398            Outcome::Ok(()) => Outcome::Ok(conn),
399            Outcome::Err(e) => Outcome::Err(e),
400            Outcome::Cancelled(r) => Outcome::Cancelled(r),
401            Outcome::Panicked(p) => Outcome::Panicked(p),
402        }
403    }
404
405    /// Run a parameterized query and return all rows.
406    pub async fn query_async(
407        &mut self,
408        cx: &Cx,
409        sql: &str,
410        params: &[Value],
411    ) -> Outcome<Vec<Row>, Error> {
412        match self.run_extended(cx, sql, params).await {
413            Outcome::Ok(result) => Outcome::Ok(result.rows),
414            Outcome::Err(e) => Outcome::Err(e),
415            Outcome::Cancelled(r) => Outcome::Cancelled(r),
416            Outcome::Panicked(p) => Outcome::Panicked(p),
417        }
418    }
419
420    /// Execute a statement and return rows affected.
421    pub async fn execute_async(
422        &mut self,
423        cx: &Cx,
424        sql: &str,
425        params: &[Value],
426    ) -> Outcome<u64, Error> {
427        match self.run_extended(cx, sql, params).await {
428            Outcome::Ok(result) => {
429                Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
430            }
431            Outcome::Err(e) => Outcome::Err(e),
432            Outcome::Cancelled(r) => Outcome::Cancelled(r),
433            Outcome::Panicked(p) => Outcome::Panicked(p),
434        }
435    }
436
437    /// Execute an INSERT and return the inserted id.
438    ///
439    /// PostgreSQL requires `RETURNING` to retrieve generated IDs. This method
440    /// expects the SQL to return a single-row, single-column result set
441    /// containing an integer id.
442    pub async fn insert_async(
443        &mut self,
444        cx: &Cx,
445        sql: &str,
446        params: &[Value],
447    ) -> Outcome<i64, Error> {
448        let result = match self.run_extended(cx, sql, params).await {
449            Outcome::Ok(r) => r,
450            Outcome::Err(e) => return Outcome::Err(e),
451            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
452            Outcome::Panicked(p) => return Outcome::Panicked(p),
453        };
454
455        let Some(row) = result.rows.first() else {
456            return Outcome::Err(query_error_msg(
457                "INSERT did not return an id; add `RETURNING id`",
458                QueryErrorKind::Database,
459            ));
460        };
461        let Some(id_value) = row.get(0) else {
462            return Outcome::Err(query_error_msg(
463                "INSERT result row missing id column",
464                QueryErrorKind::Database,
465            ));
466        };
467        match id_value.as_i64() {
468            Some(v) => Outcome::Ok(v),
469            None => Outcome::Err(query_error_msg(
470                "INSERT returned non-integer id",
471                QueryErrorKind::Database,
472            )),
473        }
474    }
475
476    /// Ping the server.
477    pub async fn ping_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
478        self.execute_async(cx, "SELECT 1", &[]).await.map(|_| ())
479    }
480
481    /// Close the connection.
482    pub async fn close_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
483        // Best-effort terminate. If this fails, the drop will close the socket.
484        //
485        // Note: server-side prepared statements are released when the connection terminates;
486        // explicit Close/DEALLOCATE is not required for correctness here.
487        let _ = self.send_message(cx, &FrontendMessage::Terminate).await;
488        self.state = ConnectionState::Closed;
489        Outcome::Ok(())
490    }
491
492    // ==================== Prepared statements ====================
493
494    /// Prepare a server-side statement and return a reusable handle.
495    pub async fn prepare_async(&mut self, cx: &Cx, sql: &str) -> Outcome<PreparedStatement, Error> {
496        let stmt_id = self.next_prepared_id;
497        self.next_prepared_id = self.next_prepared_id.saturating_add(1);
498        let stmt_name = format!("sqlmodel_stmt_{stmt_id}");
499
500        if let Outcome::Err(e) = self
501            .send_message(
502                cx,
503                &FrontendMessage::Parse {
504                    name: stmt_name.clone(),
505                    query: sql.to_string(),
506                    // Let PostgreSQL infer types where possible; ambiguous queries will error
507                    // and should add explicit casts.
508                    param_types: Vec::new(),
509                },
510            )
511            .await
512        {
513            return Outcome::Err(e);
514        }
515
516        if let Outcome::Err(e) = self
517            .send_message(
518                cx,
519                &FrontendMessage::Describe {
520                    kind: DescribeKind::Statement,
521                    name: stmt_name.clone(),
522                },
523            )
524            .await
525        {
526            return Outcome::Err(e);
527        }
528
529        if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
530            return Outcome::Err(e);
531        }
532
533        let mut param_type_oids: Option<Vec<u32>> = None;
534        let mut columns: Option<Vec<String>> = None;
535
536        loop {
537            let msg = match self.receive_message(cx).await {
538                Outcome::Ok(m) => m,
539                Outcome::Err(e) => return Outcome::Err(e),
540                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
541                Outcome::Panicked(p) => return Outcome::Panicked(p),
542            };
543
544            match msg {
545                BackendMessage::ParseComplete
546                | BackendMessage::BindComplete
547                | BackendMessage::CloseComplete
548                | BackendMessage::NoData
549                | BackendMessage::EmptyQueryResponse => {}
550                BackendMessage::ParameterDescription(oids) => {
551                    param_type_oids = Some(oids);
552                }
553                BackendMessage::RowDescription(desc) => {
554                    columns = Some(desc.iter().map(|f| f.name.clone()).collect());
555                }
556                BackendMessage::ReadyForQuery(status) => {
557                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
558                    break;
559                }
560                BackendMessage::ErrorResponse(e) => {
561                    self.state = ConnectionState::Error;
562                    return Outcome::Err(error_from_fields(&e));
563                }
564                BackendMessage::NoticeResponse(_notice) => {}
565                other => {
566                    return Outcome::Err(protocol_error(format!(
567                        "Unexpected message during prepare: {other:?}"
568                    )));
569                }
570            }
571        }
572
573        let param_type_oids = param_type_oids.unwrap_or_default();
574        self.prepared.insert(
575            stmt_id,
576            PgPreparedMeta {
577                name: stmt_name,
578                param_type_oids: param_type_oids.clone(),
579            },
580        );
581
582        match columns {
583            Some(cols) => Outcome::Ok(PreparedStatement::with_columns(
584                stmt_id,
585                sql.to_string(),
586                param_type_oids.len(),
587                cols,
588            )),
589            None => Outcome::Ok(PreparedStatement::new(
590                stmt_id,
591                sql.to_string(),
592                param_type_oids.len(),
593            )),
594        }
595    }
596
597    pub async fn query_prepared_async(
598        &mut self,
599        cx: &Cx,
600        stmt: &PreparedStatement,
601        params: &[Value],
602    ) -> Outcome<Vec<Row>, Error> {
603        let meta = match self.prepared.get(&stmt.id()) {
604            Some(m) => m.clone(),
605            None => {
606                return Outcome::Err(query_error_msg(
607                    format!("Unknown prepared statement id {}", stmt.id()),
608                    QueryErrorKind::Database,
609                ));
610            }
611        };
612
613        if meta.param_type_oids.len() != params.len() {
614            return Outcome::Err(query_error_msg(
615                format!(
616                    "Prepared statement expects {} params, got {}",
617                    meta.param_type_oids.len(),
618                    params.len()
619                ),
620                QueryErrorKind::Database,
621            ));
622        }
623
624        match self.run_prepared(cx, &meta, params).await {
625            Outcome::Ok(result) => Outcome::Ok(result.rows),
626            Outcome::Err(e) => Outcome::Err(e),
627            Outcome::Cancelled(r) => Outcome::Cancelled(r),
628            Outcome::Panicked(p) => Outcome::Panicked(p),
629        }
630    }
631
632    pub async fn execute_prepared_async(
633        &mut self,
634        cx: &Cx,
635        stmt: &PreparedStatement,
636        params: &[Value],
637    ) -> Outcome<u64, Error> {
638        let meta = match self.prepared.get(&stmt.id()) {
639            Some(m) => m.clone(),
640            None => {
641                return Outcome::Err(query_error_msg(
642                    format!("Unknown prepared statement id {}", stmt.id()),
643                    QueryErrorKind::Database,
644                ));
645            }
646        };
647
648        if meta.param_type_oids.len() != params.len() {
649            return Outcome::Err(query_error_msg(
650                format!(
651                    "Prepared statement expects {} params, got {}",
652                    meta.param_type_oids.len(),
653                    params.len()
654                ),
655                QueryErrorKind::Database,
656            ));
657        }
658
659        match self.run_prepared(cx, &meta, params).await {
660            Outcome::Ok(result) => {
661                Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
662            }
663            Outcome::Err(e) => Outcome::Err(e),
664            Outcome::Cancelled(r) => Outcome::Cancelled(r),
665            Outcome::Panicked(p) => Outcome::Panicked(p),
666        }
667    }
668
669    // ==================== Protocol: extended query ====================
670
671    async fn read_extended_result(&mut self, cx: &Cx) -> Outcome<PgQueryResult, Error> {
672        // Read responses until ReadyForQuery
673        let mut field_descs: Option<Vec<crate::protocol::FieldDescription>> = None;
674        let mut columns: Option<Arc<ColumnInfo>> = None;
675        let mut rows: Vec<Row> = Vec::new();
676        let mut command_tag: Option<String> = None;
677
678        loop {
679            let msg = match self.receive_message(cx).await {
680                Outcome::Ok(m) => m,
681                Outcome::Err(e) => return Outcome::Err(e),
682                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
683                Outcome::Panicked(p) => return Outcome::Panicked(p),
684            };
685
686            match msg {
687                BackendMessage::ParseComplete
688                | BackendMessage::BindComplete
689                | BackendMessage::CloseComplete
690                | BackendMessage::ParameterDescription(_)
691                | BackendMessage::NoData
692                | BackendMessage::PortalSuspended
693                | BackendMessage::EmptyQueryResponse => {}
694                BackendMessage::RowDescription(desc) => {
695                    let names: Vec<String> = desc.iter().map(|f| f.name.clone()).collect();
696                    columns = Some(Arc::new(ColumnInfo::new(names)));
697                    field_descs = Some(desc);
698                }
699                BackendMessage::DataRow(raw_values) => {
700                    let Some(ref desc) = field_descs else {
701                        return Outcome::Err(protocol_error(
702                            "DataRow received before RowDescription",
703                        ));
704                    };
705                    let Some(ref cols) = columns else {
706                        return Outcome::Err(protocol_error("Row column metadata missing"));
707                    };
708                    if raw_values.len() != desc.len() {
709                        return Outcome::Err(protocol_error("DataRow field count mismatch"));
710                    }
711
712                    let mut values = Vec::with_capacity(raw_values.len());
713                    for (i, raw) in raw_values.into_iter().enumerate() {
714                        match raw {
715                            None => values.push(Value::Null),
716                            Some(bytes) => {
717                                let field = &desc[i];
718                                let format = Format::from_code(field.format);
719                                let decoded = match decode_value(
720                                    field.type_oid,
721                                    Some(bytes.as_slice()),
722                                    format,
723                                ) {
724                                    Ok(v) => v,
725                                    Err(e) => return Outcome::Err(e),
726                                };
727                                values.push(decoded);
728                            }
729                        }
730                    }
731                    rows.push(Row::with_columns(Arc::clone(cols), values));
732                }
733                BackendMessage::CommandComplete(tag) => {
734                    command_tag = Some(tag);
735                }
736                BackendMessage::ReadyForQuery(status) => {
737                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
738                    break;
739                }
740                BackendMessage::ErrorResponse(e) => {
741                    self.state = ConnectionState::Error;
742                    return Outcome::Err(error_from_fields(&e));
743                }
744                BackendMessage::NoticeResponse(_notice) => {}
745                _ => {}
746            }
747        }
748
749        Outcome::Ok(PgQueryResult { rows, command_tag })
750    }
751
752    async fn run_extended(
753        &mut self,
754        cx: &Cx,
755        sql: &str,
756        params: &[Value],
757    ) -> Outcome<PgQueryResult, Error> {
758        // Encode parameters
759        let mut param_types = Vec::with_capacity(params.len());
760        let mut param_values = Vec::with_capacity(params.len());
761
762        for v in params {
763            if matches!(v, Value::Null) {
764                param_types.push(0);
765                param_values.push(None);
766                continue;
767            }
768            match encode_value(v, Format::Text) {
769                Ok((bytes, oid)) => {
770                    param_types.push(oid);
771                    param_values.push(Some(bytes));
772                }
773                Err(e) => return Outcome::Err(e),
774            }
775        }
776
777        // Parse + bind unnamed statement/portal
778        if let Outcome::Err(e) = self
779            .send_message(
780                cx,
781                &FrontendMessage::Parse {
782                    name: String::new(),
783                    query: sql.to_string(),
784                    param_types,
785                },
786            )
787            .await
788        {
789            return Outcome::Err(e);
790        }
791
792        let param_formats = if params.is_empty() {
793            Vec::new()
794        } else {
795            vec![Format::Text.code()]
796        };
797        if let Outcome::Err(e) = self
798            .send_message(
799                cx,
800                &FrontendMessage::Bind {
801                    portal: String::new(),
802                    statement: String::new(),
803                    param_formats,
804                    params: param_values,
805                    // Default result formats (text) when empty.
806                    result_formats: Vec::new(),
807                },
808            )
809            .await
810        {
811            return Outcome::Err(e);
812        }
813
814        if let Outcome::Err(e) = self
815            .send_message(
816                cx,
817                &FrontendMessage::Describe {
818                    kind: DescribeKind::Portal,
819                    name: String::new(),
820                },
821            )
822            .await
823        {
824            return Outcome::Err(e);
825        }
826
827        if let Outcome::Err(e) = self
828            .send_message(
829                cx,
830                &FrontendMessage::Execute {
831                    portal: String::new(),
832                    max_rows: 0,
833                },
834            )
835            .await
836        {
837            return Outcome::Err(e);
838        }
839
840        if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
841            return Outcome::Err(e);
842        }
843        self.read_extended_result(cx).await
844    }
845
846    async fn run_prepared(
847        &mut self,
848        cx: &Cx,
849        meta: &PgPreparedMeta,
850        params: &[Value],
851    ) -> Outcome<PgQueryResult, Error> {
852        let mut param_values = Vec::with_capacity(params.len());
853
854        for (i, v) in params.iter().enumerate() {
855            if matches!(v, Value::Null) {
856                param_values.push(None);
857                continue;
858            }
859            match encode_value(v, Format::Text) {
860                Ok((bytes, oid)) => {
861                    let expected = meta.param_type_oids.get(i).copied().unwrap_or(0);
862                    if expected != 0 && expected != oid {
863                        return Outcome::Err(query_error_msg(
864                            format!(
865                                "Prepared statement param {} expects type OID {}, got {}",
866                                i + 1,
867                                expected,
868                                oid
869                            ),
870                            QueryErrorKind::Database,
871                        ));
872                    }
873                    param_values.push(Some(bytes));
874                }
875                Err(e) => return Outcome::Err(e),
876            }
877        }
878
879        let param_formats = if params.is_empty() {
880            Vec::new()
881        } else {
882            vec![Format::Text.code()]
883        };
884
885        if let Outcome::Err(e) = self
886            .send_message(
887                cx,
888                &FrontendMessage::Bind {
889                    portal: String::new(),
890                    statement: meta.name.clone(),
891                    param_formats,
892                    params: param_values,
893                    result_formats: Vec::new(),
894                },
895            )
896            .await
897        {
898            return Outcome::Err(e);
899        }
900
901        if let Outcome::Err(e) = self
902            .send_message(
903                cx,
904                &FrontendMessage::Describe {
905                    kind: DescribeKind::Portal,
906                    name: String::new(),
907                },
908            )
909            .await
910        {
911            return Outcome::Err(e);
912        }
913
914        if let Outcome::Err(e) = self
915            .send_message(
916                cx,
917                &FrontendMessage::Execute {
918                    portal: String::new(),
919                    max_rows: 0,
920                },
921            )
922            .await
923        {
924            return Outcome::Err(e);
925        }
926
927        if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
928            return Outcome::Err(e);
929        }
930
931        self.read_extended_result(cx).await
932    }
933
934    // ==================== Startup + auth ====================
935
936    #[cfg(feature = "tls")]
937    async fn negotiate_ssl(&mut self) -> Outcome<(), Error> {
938        // Send SSL request
939        if let Outcome::Err(e) = self.send_message_no_cx(&FrontendMessage::SSLRequest).await {
940            return Outcome::Err(e);
941        }
942
943        // Read single-byte response
944        let mut buf = [0u8; 1];
945        if let Err(e) = self.stream.read_exact(&mut buf).await {
946            return Outcome::Err(Error::Connection(ConnectionError {
947                kind: ConnectionErrorKind::Ssl,
948                message: format!("Failed to read SSL response: {}", e),
949                source: Some(Box::new(e)),
950            }));
951        }
952
953        match buf[0] {
954            b'S' => {
955                #[cfg(feature = "tls")]
956                {
957                    let plain = match std::mem::replace(&mut self.stream, PgAsyncStream::Closed) {
958                        PgAsyncStream::Plain(s) => s,
959                        other => {
960                            self.stream = other;
961                            return Outcome::Err(connection_error(
962                                "TLS upgrade requires a plain TCP stream",
963                            ));
964                        }
965                    };
966
967                    let tls_stream = match AsyncTlsStream::handshake(
968                        plain,
969                        self.config.ssl_mode,
970                        &self.config.host,
971                    )
972                    .await
973                    {
974                        Ok(s) => s,
975                        Err(e) => return Outcome::Err(e),
976                    };
977
978                    self.stream = PgAsyncStream::Tls(tls_stream);
979                    Outcome::Ok(())
980                }
981
982                #[cfg(not(feature = "tls"))]
983                {
984                    Outcome::Err(connection_error(
985                        "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
986                    ))
987                }
988            }
989            b'N' => {
990                if self.config.ssl_mode.is_required() {
991                    Outcome::Err(Error::Connection(ConnectionError {
992                        kind: ConnectionErrorKind::Ssl,
993                        message: "Server does not support SSL".to_string(),
994                        source: None,
995                    }))
996                } else {
997                    Outcome::Ok(())
998                }
999            }
1000            other => Outcome::Err(Error::Connection(ConnectionError {
1001                kind: ConnectionErrorKind::Ssl,
1002                message: format!("Unexpected SSL response: 0x{other:02x}"),
1003                source: None,
1004            })),
1005        }
1006    }
1007
1008    async fn send_startup(&mut self) -> Outcome<(), Error> {
1009        let params = self.config.startup_params();
1010        self.send_message_no_cx(&FrontendMessage::Startup {
1011            version: PROTOCOL_VERSION,
1012            params,
1013        })
1014        .await
1015    }
1016
1017    fn require_auth_value(&self, message: &'static str) -> Outcome<&str, Error> {
1018        // NOTE: Auth values are sourced from runtime config, not hardcoded.
1019        match self.config.password.as_deref() {
1020            Some(password) => Outcome::Ok(password),
1021            None => Outcome::Err(auth_error(message)),
1022        }
1023    }
1024
1025    async fn handle_auth(&mut self) -> Outcome<(), Error> {
1026        loop {
1027            let msg = match self.receive_message_no_cx().await {
1028                Outcome::Ok(m) => m,
1029                Outcome::Err(e) => return Outcome::Err(e),
1030                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1031                Outcome::Panicked(p) => return Outcome::Panicked(p),
1032            };
1033
1034            match msg {
1035                BackendMessage::AuthenticationOk => return Outcome::Ok(()),
1036                BackendMessage::AuthenticationCleartextPassword => {
1037                    let auth_value = match self
1038                        .require_auth_value("Authentication value required but not provided")
1039                    {
1040                        Outcome::Ok(password) => password,
1041                        Outcome::Err(e) => return Outcome::Err(e),
1042                        Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1043                        Outcome::Panicked(p) => return Outcome::Panicked(p),
1044                    };
1045                    if let Outcome::Err(e) = self
1046                        .send_message_no_cx(&FrontendMessage::PasswordMessage(
1047                            auth_value.to_string(),
1048                        ))
1049                        .await
1050                    {
1051                        return Outcome::Err(e);
1052                    }
1053                }
1054                BackendMessage::AuthenticationMD5Password(salt) => {
1055                    let auth_value = match self
1056                        .require_auth_value("Authentication value required but not provided")
1057                    {
1058                        Outcome::Ok(password) => password,
1059                        Outcome::Err(e) => return Outcome::Err(e),
1060                        Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1061                        Outcome::Panicked(p) => return Outcome::Panicked(p),
1062                    };
1063                    let hash = md5_password(&self.config.user, auth_value, salt);
1064                    if let Outcome::Err(e) = self
1065                        .send_message_no_cx(&FrontendMessage::PasswordMessage(hash))
1066                        .await
1067                    {
1068                        return Outcome::Err(e);
1069                    }
1070                }
1071                BackendMessage::AuthenticationSASL(mechanisms) => {
1072                    if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
1073                        match self.scram_auth().await {
1074                            Outcome::Ok(()) => {}
1075                            Outcome::Err(e) => return Outcome::Err(e),
1076                            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1077                            Outcome::Panicked(p) => return Outcome::Panicked(p),
1078                        }
1079                    } else {
1080                        return Outcome::Err(auth_error(format!(
1081                            "Unsupported SASL mechanisms: {:?}",
1082                            mechanisms
1083                        )));
1084                    }
1085                }
1086                BackendMessage::ErrorResponse(e) => {
1087                    self.state = ConnectionState::Error;
1088                    return Outcome::Err(error_from_fields(&e));
1089                }
1090                other => {
1091                    return Outcome::Err(protocol_error(format!(
1092                        "Unexpected message during auth: {other:?}"
1093                    )));
1094                }
1095            }
1096        }
1097    }
1098
    /// Runs a complete SCRAM-SHA-256 SASL exchange with the server.
    ///
    /// Sequence: send `SASLInitialResponse` (client-first), receive
    /// `AuthenticationSASLContinue` (server-first), send `SASLResponse`
    /// (client-final), receive `AuthenticationSASLFinal` (server-final), check it
    /// with `ScramClient::verify_server_final`, then receive `AuthenticationOk`.
    ///
    /// On success the final `AuthenticationOk` has already been consumed, so the
    /// caller must not wait for another one.
    ///
    /// # Errors
    /// - an auth error if no password is configured;
    /// - errors from `ScramClient` while processing the server messages;
    /// - the server's error (connection marked `ConnectionState::Error`) on
    ///   `ErrorResponse` at any step;
    /// - a protocol error if a step receives an unexpected message.
    async fn scram_auth(&mut self) -> Outcome<(), Error> {
        let auth_value =
            match self.require_auth_value("Authentication value required for SCRAM-SHA-256") {
                Outcome::Ok(password) => password,
                Outcome::Err(e) => return Outcome::Err(e),
                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
                Outcome::Panicked(p) => return Outcome::Panicked(p),
            };

        // ScramClient keeps the exchange state (nonces etc.) across the steps below.
        let mut client = ScramClient::new(&self.config.user, auth_value);

        // Client-first: sent as the SASLInitialResponse payload for the chosen mechanism.
        let client_first = client.client_first();
        if let Outcome::Err(e) = self
            .send_message_no_cx(&FrontendMessage::SASLInitialResponse {
                mechanism: "SCRAM-SHA-256".to_string(),
                data: client_first,
            })
            .await
        {
            return Outcome::Err(e);
        }

        // Server-first: must arrive as AuthenticationSASLContinue.
        let msg = match self.receive_message_no_cx().await {
            Outcome::Ok(m) => m,
            Outcome::Err(e) => return Outcome::Err(e),
            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
            Outcome::Panicked(p) => return Outcome::Panicked(p),
        };
        let server_first_data = match msg {
            BackendMessage::AuthenticationSASLContinue(data) => data,
            BackendMessage::ErrorResponse(e) => {
                self.state = ConnectionState::Error;
                return Outcome::Err(error_from_fields(&e));
            }
            other => {
                return Outcome::Err(protocol_error(format!(
                    "Expected SASL continue, got: {other:?}"
                )));
            }
        };

        // Client-final: derived from the server-first message by ScramClient.
        let client_final = match client.process_server_first(&server_first_data) {
            Ok(v) => v,
            Err(e) => return Outcome::Err(e),
        };
        if let Outcome::Err(e) = self
            .send_message_no_cx(&FrontendMessage::SASLResponse(client_final))
            .await
        {
            return Outcome::Err(e);
        }

        // Server-final: must arrive as AuthenticationSASLFinal.
        let msg = match self.receive_message_no_cx().await {
            Outcome::Ok(m) => m,
            Outcome::Err(e) => return Outcome::Err(e),
            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
            Outcome::Panicked(p) => return Outcome::Panicked(p),
        };
        let server_final_data = match msg {
            BackendMessage::AuthenticationSASLFinal(data) => data,
            BackendMessage::ErrorResponse(e) => {
                self.state = ConnectionState::Error;
                return Outcome::Err(error_from_fields(&e));
            }
            other => {
                return Outcome::Err(protocol_error(format!(
                    "Expected SASL final, got: {other:?}"
                )));
            }
        };

        // NOTE(review): presumably checks the server's signature so the server
        // proves it knows the credentials — ScramClient internals are not visible here.
        if let Err(e) = client.verify_server_final(&server_final_data) {
            return Outcome::Err(e);
        }

        // Wait for AuthenticationOk (consumed here, not by the caller).
        let msg = match self.receive_message_no_cx().await {
            Outcome::Ok(m) => m,
            Outcome::Err(e) => return Outcome::Err(e),
            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
            Outcome::Panicked(p) => return Outcome::Panicked(p),
        };
        match msg {
            BackendMessage::AuthenticationOk => Outcome::Ok(()),
            BackendMessage::ErrorResponse(e) => {
                self.state = ConnectionState::Error;
                Outcome::Err(error_from_fields(&e))
            }
            other => Outcome::Err(protocol_error(format!(
                "Expected AuthenticationOk, got: {other:?}"
            ))),
        }
    }
1196
1197    async fn read_startup_messages(&mut self) -> Outcome<(), Error> {
1198        loop {
1199            let msg = match self.receive_message_no_cx().await {
1200                Outcome::Ok(m) => m,
1201                Outcome::Err(e) => return Outcome::Err(e),
1202                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1203                Outcome::Panicked(p) => return Outcome::Panicked(p),
1204            };
1205
1206            match msg {
1207                BackendMessage::BackendKeyData {
1208                    process_id,
1209                    secret_key,
1210                } => {
1211                    self.process_id = process_id;
1212                    self.secret_key = secret_key;
1213                }
1214                BackendMessage::ParameterStatus { name, value } => {
1215                    self.parameters.insert(name, value);
1216                }
1217                BackendMessage::ReadyForQuery(status) => {
1218                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
1219                    return Outcome::Ok(());
1220                }
1221                BackendMessage::ErrorResponse(e) => {
1222                    self.state = ConnectionState::Error;
1223                    return Outcome::Err(error_from_fields(&e));
1224                }
1225                BackendMessage::NoticeResponse(_notice) => {}
1226                other => {
1227                    return Outcome::Err(protocol_error(format!(
1228                        "Unexpected startup message: {other:?}"
1229                    )));
1230                }
1231            }
1232        }
1233    }
1234
1235    // ==================== I/O ====================
1236
1237    async fn send_message(&mut self, cx: &Cx, msg: &FrontendMessage) -> Outcome<(), Error> {
1238        // If cancelled, propagate early.
1239        if let Some(reason) = cx.cancel_reason() {
1240            return Outcome::Cancelled(reason);
1241        }
1242        self.send_message_no_cx(msg).await
1243    }
1244
1245    async fn receive_message(&mut self, cx: &Cx) -> Outcome<BackendMessage, Error> {
1246        if let Some(reason) = cx.cancel_reason() {
1247            return Outcome::Cancelled(reason);
1248        }
1249        self.receive_message_no_cx().await
1250    }
1251
1252    async fn send_message_no_cx(&mut self, msg: &FrontendMessage) -> Outcome<(), Error> {
1253        let data = self.writer.write(msg).to_vec();
1254
1255        if let Err(e) = self.stream.write_all(&data).await {
1256            self.state = ConnectionState::Error;
1257            return Outcome::Err(Error::Connection(ConnectionError {
1258                kind: ConnectionErrorKind::Disconnected,
1259                message: format!("Failed to write to server: {}", e),
1260                source: Some(Box::new(e)),
1261            }));
1262        }
1263
1264        if let Err(e) = self.stream.flush().await {
1265            self.state = ConnectionState::Error;
1266            return Outcome::Err(Error::Connection(ConnectionError {
1267                kind: ConnectionErrorKind::Disconnected,
1268                message: format!("Failed to flush stream: {}", e),
1269                source: Some(Box::new(e)),
1270            }));
1271        }
1272
1273        Outcome::Ok(())
1274    }
1275
1276    async fn receive_message_no_cx(&mut self) -> Outcome<BackendMessage, Error> {
1277        loop {
1278            match self.reader.next_message() {
1279                Ok(Some(msg)) => return Outcome::Ok(msg),
1280                Ok(None) => {}
1281                Err(e) => {
1282                    self.state = ConnectionState::Error;
1283                    return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
1284                }
1285            }
1286
1287            let n = match self.stream.read_some(&mut self.read_buf).await {
1288                Ok(n) => n,
1289                Err(e) => {
1290                    self.state = ConnectionState::Error;
1291                    return Outcome::Err(match e.kind() {
1292                        std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock => {
1293                            Error::Timeout
1294                        }
1295                        _ => Error::Connection(ConnectionError {
1296                            kind: ConnectionErrorKind::Disconnected,
1297                            message: format!("Failed to read from server: {}", e),
1298                            source: Some(Box::new(e)),
1299                        }),
1300                    });
1301                }
1302            };
1303
1304            if n == 0 {
1305                self.state = ConnectionState::Disconnected;
1306                return Outcome::Err(Error::Connection(ConnectionError {
1307                    kind: ConnectionErrorKind::Disconnected,
1308                    message: "Connection closed by server".to_string(),
1309                    source: None,
1310                }));
1311            }
1312
1313            if let Err(e) = self.reader.feed(&self.read_buf[..n]) {
1314                self.state = ConnectionState::Error;
1315                return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
1316            }
1317        }
1318    }
1319}
1320
/// Shared, cloneable PostgreSQL connection with interior mutability.
///
/// Every clone holds the same `Arc<Mutex<PgAsyncConnection>>`, so all clones
/// talk to one server session. Each operation acquires the mutex for its
/// duration, which serializes concurrent callers.
pub struct SharedPgConnection {
    // The single underlying connection, shared by every clone of this handle.
    inner: Arc<Mutex<PgAsyncConnection>>,
}
1325
1326impl SharedPgConnection {
1327    pub fn new(conn: PgAsyncConnection) -> Self {
1328        Self {
1329            inner: Arc::new(Mutex::new(conn)),
1330        }
1331    }
1332
1333    pub async fn connect(cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
1334        match PgAsyncConnection::connect(cx, config).await {
1335            Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
1336            Outcome::Err(e) => Outcome::Err(e),
1337            Outcome::Cancelled(r) => Outcome::Cancelled(r),
1338            Outcome::Panicked(p) => Outcome::Panicked(p),
1339        }
1340    }
1341
1342    pub fn inner(&self) -> &Arc<Mutex<PgAsyncConnection>> {
1343        &self.inner
1344    }
1345
1346    async fn begin_transaction_impl(
1347        &self,
1348        cx: &Cx,
1349        isolation: Option<IsolationLevel>,
1350    ) -> Outcome<SharedPgTransaction<'_>, Error> {
1351        let inner = Arc::clone(&self.inner);
1352        let Ok(mut guard) = inner.lock(cx).await else {
1353            return Outcome::Err(connection_error("Failed to acquire connection lock"));
1354        };
1355
1356        if let Some(level) = isolation {
1357            let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
1358            match guard.execute_async(cx, &sql, &[]).await {
1359                Outcome::Ok(_) => {}
1360                Outcome::Err(e) => return Outcome::Err(e),
1361                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1362                Outcome::Panicked(p) => return Outcome::Panicked(p),
1363            }
1364        }
1365
1366        match guard.execute_async(cx, "BEGIN", &[]).await {
1367            Outcome::Ok(_) => {}
1368            Outcome::Err(e) => return Outcome::Err(e),
1369            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1370            Outcome::Panicked(p) => return Outcome::Panicked(p),
1371        }
1372
1373        drop(guard);
1374        Outcome::Ok(SharedPgTransaction {
1375            inner,
1376            committed: false,
1377            _marker: std::marker::PhantomData,
1378        })
1379    }
1380}
1381
1382impl Clone for SharedPgConnection {
1383    fn clone(&self) -> Self {
1384        Self {
1385            inner: Arc::clone(&self.inner),
1386        }
1387    }
1388}
1389
1390impl std::fmt::Debug for SharedPgConnection {
1391    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1392        f.debug_struct("SharedPgConnection")
1393            .field("inner", &"Arc<Mutex<PgAsyncConnection>>")
1394            .finish()
1395    }
1396}
1397
/// Transaction handle created by `SharedPgConnection`'s `begin`/`begin_with`.
///
/// It shares the connection's `Arc<Mutex<PgAsyncConnection>>`, so statements
/// issued through it go to the same server session (and contend for the same
/// lock) as the `SharedPgConnection` it came from.
pub struct SharedPgTransaction<'conn> {
    // Same shared connection as the originating `SharedPgConnection`.
    inner: Arc<Mutex<PgAsyncConnection>>,
    // Read by `Drop` to decide whether to warn about an unfinished transaction.
    // NOTE(review): presumably set to `true` by commit/rollback further down
    // this file (not visible here) — confirm.
    committed: bool,
    // Ties the handle to the lifetime of the connection borrow it was created
    // from, without storing a reference.
    _marker: std::marker::PhantomData<&'conn ()>,
}
1403
1404impl<'conn> Drop for SharedPgTransaction<'conn> {
1405    fn drop(&mut self) {
1406        if !self.committed {
1407            // WARNING: Transaction was dropped without commit() or rollback()!
1408            // We cannot do async work in Drop, so the PostgreSQL transaction will
1409            // remain open until the connection is closed or a new transaction
1410            // is started.
1411            #[cfg(debug_assertions)]
1412            eprintln!(
1413                "WARNING: SharedPgTransaction dropped without commit/rollback. \
1414                 The PostgreSQL transaction may still be open."
1415            );
1416        }
1417    }
1418}
1419
/// `sqlmodel-core` [`Connection`] implementation for the shared PostgreSQL handle.
///
/// Every method copies its borrowed inputs into owned values, returns a future
/// that locks the shared connection, and delegates to the matching
/// `PgAsyncConnection` method. A failed lock acquisition is reported as the
/// connection error "Failed to acquire connection lock".
impl Connection for SharedPgConnection {
    // Transactions share the same underlying connection as this handle.
    type Tx<'conn>
        = SharedPgTransaction<'conn>
    where
        Self: 'conn;

    /// Always the PostgreSQL dialect.
    fn dialect(&self) -> sqlmodel_core::Dialect {
        sqlmodel_core::Dialect::Postgres
    }

    /// Runs `sql` with `params` and returns all result rows.
    fn query(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
        // Owned copies so the returned future does not borrow the caller's inputs.
        let inner = Arc::clone(&self.inner);
        let sql = sql.to_string();
        let params = params.to_vec();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            guard.query_async(cx, &sql, &params).await
        }
    }

    /// Runs `sql` and returns the first row, if any.
    ///
    /// The full result set is fetched; rows after the first are discarded.
    fn query_one(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
        let inner = Arc::clone(&self.inner);
        let sql = sql.to_string();
        let params = params.to_vec();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            let rows = match guard.query_async(cx, &sql, &params).await {
                Outcome::Ok(r) => r,
                Outcome::Err(e) => return Outcome::Err(e),
                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
                Outcome::Panicked(p) => return Outcome::Panicked(p),
            };
            Outcome::Ok(rows.into_iter().next())
        }
    }

    /// Runs a statement and returns the `u64` reported by `execute_async`
    /// (presumably the affected-row count — see `execute_async`).
    fn execute(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
        let inner = Arc::clone(&self.inner);
        let sql = sql.to_string();
        let params = params.to_vec();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            guard.execute_async(cx, &sql, &params).await
        }
    }

    /// Runs an insert and returns the `i64` produced by `insert_async`
    /// (presumably the generated id — see `insert_async`).
    fn insert(
        &self,
        cx: &Cx,
        sql: &str,
        params: &[Value],
    ) -> impl Future<Output = Outcome<i64, Error>> + Send {
        let inner = Arc::clone(&self.inner);
        let sql = sql.to_string();
        let params = params.to_vec();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            guard.insert_async(cx, &sql, &params).await
        }
    }

    /// Executes `statements` in order under a single lock acquisition and
    /// collects each statement's result.
    ///
    /// Stops at the first non-`Ok` outcome and returns it. Statements that
    /// already ran are not undone by this method.
    fn batch(
        &self,
        cx: &Cx,
        statements: &[(String, Vec<Value>)],
    ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
        let inner = Arc::clone(&self.inner);
        let statements = statements.to_vec();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            let mut results = Vec::with_capacity(statements.len());
            for (sql, params) in &statements {
                match guard.execute_async(cx, sql, params).await {
                    Outcome::Ok(n) => results.push(n),
                    Outcome::Err(e) => return Outcome::Err(e),
                    Outcome::Cancelled(r) => return Outcome::Cancelled(r),
                    Outcome::Panicked(p) => return Outcome::Panicked(p),
                }
            }
            Outcome::Ok(results)
        }
    }

    /// Starts a transaction with `IsolationLevel::default()`.
    fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
        self.begin_with(cx, IsolationLevel::default())
    }

    /// Starts a transaction with the given isolation level.
    fn begin_with(
        &self,
        cx: &Cx,
        isolation: IsolationLevel,
    ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
        self.begin_transaction_impl(cx, Some(isolation))
    }

    /// Prepares `sql` on the server via `prepare_async`.
    fn prepare(
        &self,
        cx: &Cx,
        sql: &str,
    ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
        let inner = Arc::clone(&self.inner);
        let sql = sql.to_string();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            guard.prepare_async(cx, &sql).await
        }
    }

    /// Runs a prepared statement and returns all result rows.
    fn query_prepared(
        &self,
        cx: &Cx,
        stmt: &PreparedStatement,
        params: &[Value],
    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
        let inner = Arc::clone(&self.inner);
        let stmt = stmt.clone();
        let params = params.to_vec();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            guard.query_prepared_async(cx, &stmt, &params).await
        }
    }

    /// Executes a prepared statement and returns the `u64` reported by
    /// `execute_prepared_async`.
    fn execute_prepared(
        &self,
        cx: &Cx,
        stmt: &PreparedStatement,
        params: &[Value],
    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
        let inner = Arc::clone(&self.inner);
        let stmt = stmt.clone();
        let params = params.to_vec();
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            guard.execute_prepared_async(cx, &stmt, &params).await
        }
    }

    /// Checks connectivity via `ping_async`.
    fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
        let inner = Arc::clone(&self.inner);
        async move {
            let Ok(mut guard) = inner.lock(cx).await else {
                return Outcome::Err(connection_error("Failed to acquire connection lock"));
            };
            guard.ping_async(cx).await
        }
    }

    /// Closes the underlying connection via `close_async`.
    ///
    /// This consumes only this handle, but the connection is shared: any other
    /// clone now refers to the same closed connection. `Cancelled` is mapped to
    /// a `QueryErrorKind::Cancelled` query error, and `Panicked` to a protocol
    /// error.
    async fn close(self, cx: &Cx) -> sqlmodel_core::Result<()> {
        let Ok(mut guard) = self.inner.lock(cx).await else {
            return Err(connection_error("Failed to acquire connection lock"));
        };
        match guard.close_async(cx).await {
            Outcome::Ok(()) => Ok(()),
            Outcome::Err(e) => Err(e),
            Outcome::Cancelled(r) => Err(Error::Query(QueryError {
                kind: QueryErrorKind::Cancelled,
                message: format!("Cancelled: {r:?}"),
                sqlstate: None,
                sql: None,
                detail: None,
                hint: None,
                position: None,
                source: None,
            })),
            Outcome::Panicked(p) => Err(Error::Protocol(ProtocolError {
                message: format!("Panicked: {p:?}"),
                raw_data: None,
                source: None,
            })),
        }
    }
}
1624
1625impl<'conn> TransactionOps for SharedPgTransaction<'conn> {
1626    fn query(
1627        &self,
1628        cx: &Cx,
1629        sql: &str,
1630        params: &[Value],
1631    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1632        let inner = Arc::clone(&self.inner);
1633        let sql = sql.to_string();
1634        let params = params.to_vec();
1635        async move {
1636            let Ok(mut guard) = inner.lock(cx).await else {
1637                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1638            };
1639            guard.query_async(cx, &sql, &params).await
1640        }
1641    }
1642
1643    fn query_one(
1644        &self,
1645        cx: &Cx,
1646        sql: &str,
1647        params: &[Value],
1648    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1649        let inner = Arc::clone(&self.inner);
1650        let sql = sql.to_string();
1651        let params = params.to_vec();
1652        async move {
1653            let Ok(mut guard) = inner.lock(cx).await else {
1654                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1655            };
1656            let rows = match guard.query_async(cx, &sql, &params).await {
1657                Outcome::Ok(r) => r,
1658                Outcome::Err(e) => return Outcome::Err(e),
1659                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1660                Outcome::Panicked(p) => return Outcome::Panicked(p),
1661            };
1662            Outcome::Ok(rows.into_iter().next())
1663        }
1664    }
1665
1666    fn execute(
1667        &self,
1668        cx: &Cx,
1669        sql: &str,
1670        params: &[Value],
1671    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1672        let inner = Arc::clone(&self.inner);
1673        let sql = sql.to_string();
1674        let params = params.to_vec();
1675        async move {
1676            let Ok(mut guard) = inner.lock(cx).await else {
1677                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1678            };
1679            guard.execute_async(cx, &sql, &params).await
1680        }
1681    }
1682
1683    fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1684        let inner = Arc::clone(&self.inner);
1685        let name = name.to_string();
1686        async move {
1687            if let Err(e) = validate_savepoint_name(&name) {
1688                return Outcome::Err(e);
1689            }
1690            let sql = format!("SAVEPOINT {}", name);
1691            let Ok(mut guard) = inner.lock(cx).await else {
1692                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1693            };
1694            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1695        }
1696    }
1697
1698    fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1699        let inner = Arc::clone(&self.inner);
1700        let name = name.to_string();
1701        async move {
1702            if let Err(e) = validate_savepoint_name(&name) {
1703                return Outcome::Err(e);
1704            }
1705            let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
1706            let Ok(mut guard) = inner.lock(cx).await else {
1707                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1708            };
1709            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1710        }
1711    }
1712
1713    fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1714        let inner = Arc::clone(&self.inner);
1715        let name = name.to_string();
1716        async move {
1717            if let Err(e) = validate_savepoint_name(&name) {
1718                return Outcome::Err(e);
1719            }
1720            let sql = format!("RELEASE SAVEPOINT {}", name);
1721            let Ok(mut guard) = inner.lock(cx).await else {
1722                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1723            };
1724            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1725        }
1726    }
1727
1728    // Note: clippy sometimes flags `self.committed = true` as unused, but Drop reads it.
1729    #[allow(unused_assignments)]
1730    fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1731        let inner = Arc::clone(&self.inner);
1732        async move {
1733            let Ok(mut guard) = inner.lock(cx).await else {
1734                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1735            };
1736            let result = guard.execute_async(cx, "COMMIT", &[]).await;
1737            if matches!(result, Outcome::Ok(_)) {
1738                self.committed = true;
1739            }
1740            result.map(|_| ())
1741        }
1742    }
1743
1744    #[allow(unused_assignments)]
1745    fn rollback(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1746        let inner = Arc::clone(&self.inner);
1747        async move {
1748            let Ok(mut guard) = inner.lock(cx).await else {
1749                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1750            };
1751            let result = guard.execute_async(cx, "ROLLBACK", &[]).await;
1752            if matches!(result, Outcome::Ok(_)) {
1753                self.committed = true;
1754            }
1755            result.map(|_| ())
1756        }
1757    }
1758}
1759
1760// ==================== Helpers ====================
1761
1762struct PgQueryResult {
1763    rows: Vec<Row>,
1764    command_tag: Option<String>,
1765}
1766
1767fn connection_error(msg: impl Into<String>) -> Error {
1768    Error::Connection(ConnectionError {
1769        kind: ConnectionErrorKind::Connect,
1770        message: msg.into(),
1771        source: None,
1772    })
1773}
1774
1775fn auth_error(msg: impl Into<String>) -> Error {
1776    Error::Connection(ConnectionError {
1777        kind: ConnectionErrorKind::Authentication,
1778        message: msg.into(),
1779        source: None,
1780    })
1781}
1782
1783fn protocol_error(msg: impl Into<String>) -> Error {
1784    Error::Protocol(ProtocolError {
1785        message: msg.into(),
1786        raw_data: None,
1787        source: None,
1788    })
1789}
1790
1791fn query_error_msg(msg: impl Into<String>, kind: QueryErrorKind) -> Error {
1792    Error::Query(QueryError {
1793        kind,
1794        message: msg.into(),
1795        sqlstate: None,
1796        sql: None,
1797        detail: None,
1798        hint: None,
1799        position: None,
1800        source: None,
1801    })
1802}
1803
1804fn error_from_fields(fields: &ErrorFields) -> Error {
1805    let kind = match fields.code.get(..2) {
1806        Some("08") => {
1807            return Error::Connection(ConnectionError {
1808                kind: ConnectionErrorKind::Connect,
1809                message: fields.message.clone(),
1810                source: None,
1811            });
1812        }
1813        Some("28") => {
1814            return Error::Connection(ConnectionError {
1815                kind: ConnectionErrorKind::Authentication,
1816                message: fields.message.clone(),
1817                source: None,
1818            });
1819        }
1820        Some("42") => QueryErrorKind::Syntax,
1821        Some("23") => QueryErrorKind::Constraint,
1822        Some("40") => {
1823            if fields.code == "40001" {
1824                QueryErrorKind::Serialization
1825            } else {
1826                QueryErrorKind::Deadlock
1827            }
1828        }
1829        Some("57") => {
1830            if fields.code == "57014" {
1831                QueryErrorKind::Cancelled
1832            } else {
1833                QueryErrorKind::Timeout
1834            }
1835        }
1836        _ => QueryErrorKind::Database,
1837    };
1838
1839    Error::Query(QueryError {
1840        kind,
1841        sql: None,
1842        sqlstate: Some(fields.code.clone()),
1843        message: fields.message.clone(),
1844        detail: fields.detail.clone(),
1845        hint: fields.hint.clone(),
1846        position: fields.position.map(|p| p as usize),
1847        source: None,
1848    })
1849}
1850
/// Extract the affected-row count from a command tag.
///
/// PostgreSQL command tags end with the row count when one applies
/// (`INSERT 0 5`, `UPDATE 3`, `DELETE 0`, `SELECT 7`, `COPY 2`, ...), so the last
/// whitespace-separated token is parsed. Returns `None` for a missing tag, an
/// empty tag, or a tag whose last token is not a number (e.g. `CREATE TABLE`).
fn parse_rows_affected(tag: Option<&str>) -> Option<u64> {
    // `split_whitespace` is double-ended, so the last token can be read without
    // collecting every token into a Vec first.
    tag?.split_whitespace().next_back()?.parse().ok()
}
1856
1857/// Validate a savepoint name to reduce SQL injection risk.
1858fn validate_savepoint_name(name: &str) -> sqlmodel_core::Result<()> {
1859    if name.is_empty() {
1860        return Err(query_error_msg(
1861            "Savepoint name cannot be empty",
1862            QueryErrorKind::Syntax,
1863        ));
1864    }
1865    if name.len() > 63 {
1866        return Err(query_error_msg(
1867            "Savepoint name exceeds maximum length of 63 characters",
1868            QueryErrorKind::Syntax,
1869        ));
1870    }
1871    let mut chars = name.chars();
1872    let Some(first) = chars.next() else {
1873        return Err(query_error_msg(
1874            "Savepoint name cannot be empty",
1875            QueryErrorKind::Syntax,
1876        ));
1877    };
1878    if !first.is_ascii_alphabetic() && first != '_' {
1879        return Err(query_error_msg(
1880            "Savepoint name must start with a letter or underscore",
1881            QueryErrorKind::Syntax,
1882        ));
1883    }
1884    for c in chars {
1885        if !c.is_ascii_alphanumeric() && c != '_' {
1886            return Err(query_error_msg(
1887                format!("Savepoint name contains invalid character: '{c}'"),
1888                QueryErrorKind::Syntax,
1889            ));
1890        }
1891    }
1892    Ok(())
1893}
1894
1895fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
1896    use std::fmt::Write;
1897
1898    let inner = format!("{password}{user}");
1899    let inner_hash = md5::compute(inner.as_bytes());
1900
1901    let mut outer_input = format!("{inner_hash:x}").into_bytes();
1902    outer_input.extend_from_slice(&salt);
1903    let outer_hash = md5::compute(&outer_input);
1904
1905    let mut result = String::with_capacity(35);
1906    result.push_str("md5");
1907    write!(&mut result, "{outer_hash:x}").unwrap();
1908    result
1909}
1910
1911// Note: read/write helpers are implemented above on PgAsyncStream.