Skip to main content

sqlmodel_postgres/
async_connection.rs

1//! Async PostgreSQL connection implementation.
2//!
3//! This module implements an async PostgreSQL connection using asupersync's TCP
4//! primitives. It provides a shared wrapper that implements `sqlmodel-core`'s
5//! [`Connection`] trait.
6//!
7//! The implementation currently focuses on:
8//! - Async connect + authentication (cleartext, MD5, SCRAM-SHA-256)
9//! - Extended query protocol for parameterized queries
10//! - Row decoding via the postgres type registry (OID + text/binary format)
11//! - Basic transaction support (BEGIN/COMMIT/ROLLBACK + savepoints)
12
13// Allow `impl Future` return types in trait methods - intentional for async trait compat
14#![allow(clippy::manual_async_fn)]
15// The Error type is intentionally large to carry full context
16#![allow(clippy::result_large_err)]
17
18use std::collections::HashMap;
19use std::future::Future;
20use std::sync::Arc;
21
22use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
23use asupersync::net::TcpStream;
24use asupersync::sync::Mutex;
25use asupersync::{Cx, Outcome};
26
27use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
28use sqlmodel_core::error::{
29    ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
30};
31use sqlmodel_core::row::ColumnInfo;
32use sqlmodel_core::{Error, Row, Value};
33
34use crate::auth::ScramClient;
35use crate::config::PgConfig;
36use crate::connection::{ConnectionState, TransactionStatusState};
37use crate::protocol::{
38    BackendMessage, DescribeKind, ErrorFields, FrontendMessage, MessageReader, MessageWriter,
39    PROTOCOL_VERSION,
40};
41use crate::types::{Format, decode_value, encode_value};
42
43/// Async PostgreSQL connection.
44///
45/// This connection uses asupersync's TCP stream for non-blocking I/O and
46/// supports the extended query protocol for parameter binding.
47pub struct PgAsyncConnection {
48    stream: TcpStream,
49    state: ConnectionState,
50    process_id: i32,
51    secret_key: i32,
52    parameters: HashMap<String, String>,
53    config: PgConfig,
54    reader: MessageReader,
55    writer: MessageWriter,
56    read_buf: Vec<u8>,
57}
58
59impl std::fmt::Debug for PgAsyncConnection {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("PgAsyncConnection")
62            .field("state", &self.state)
63            .field("process_id", &self.process_id)
64            .field("host", &self.config.host)
65            .field("port", &self.config.port)
66            .field("database", &self.config.database)
67            .finish_non_exhaustive()
68    }
69}
70
71impl PgAsyncConnection {
72    /// Establish a new async connection to the PostgreSQL server.
73    pub async fn connect(_cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
74        let addr = config.socket_addr();
75        let socket_addr = match addr.parse() {
76            Ok(a) => a,
77            Err(e) => {
78                return Outcome::Err(Error::Connection(ConnectionError {
79                    kind: ConnectionErrorKind::Connect,
80                    message: format!("Invalid socket address: {}", e),
81                    source: None,
82                }));
83            }
84        };
85
86        let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
87            Ok(s) => s,
88            Err(e) => {
89                let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
90                    ConnectionErrorKind::Refused
91                } else {
92                    ConnectionErrorKind::Connect
93                };
94                return Outcome::Err(Error::Connection(ConnectionError {
95                    kind,
96                    message: format!("Failed to connect to {}: {}", addr, e),
97                    source: Some(Box::new(e)),
98                }));
99            }
100        };
101
102        stream.set_nodelay(true).ok();
103
104        let mut conn = Self {
105            stream,
106            state: ConnectionState::Connecting,
107            process_id: 0,
108            secret_key: 0,
109            parameters: HashMap::new(),
110            config,
111            reader: MessageReader::new(),
112            writer: MessageWriter::new(),
113            read_buf: vec![0u8; 8192],
114        };
115
116        // SSL negotiation (TLS not implemented; matches sync driver behavior)
117        if conn.config.ssl_mode.should_try_ssl() {
118            match conn.negotiate_ssl().await {
119                Outcome::Ok(()) => {}
120                Outcome::Err(e) => return Outcome::Err(e),
121                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
122                Outcome::Panicked(p) => return Outcome::Panicked(p),
123            }
124        }
125
126        // Startup + authentication
127        if let Outcome::Err(e) = conn.send_startup().await {
128            return Outcome::Err(e);
129        }
130        conn.state = ConnectionState::Authenticating;
131
132        match conn.handle_auth().await {
133            Outcome::Ok(()) => {}
134            Outcome::Err(e) => return Outcome::Err(e),
135            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
136            Outcome::Panicked(p) => return Outcome::Panicked(p),
137        }
138
139        match conn.read_startup_messages().await {
140            Outcome::Ok(()) => Outcome::Ok(conn),
141            Outcome::Err(e) => Outcome::Err(e),
142            Outcome::Cancelled(r) => Outcome::Cancelled(r),
143            Outcome::Panicked(p) => Outcome::Panicked(p),
144        }
145    }
146
147    /// Run a parameterized query and return all rows.
148    pub async fn query_async(
149        &mut self,
150        cx: &Cx,
151        sql: &str,
152        params: &[Value],
153    ) -> Outcome<Vec<Row>, Error> {
154        match self.run_extended(cx, sql, params).await {
155            Outcome::Ok(result) => Outcome::Ok(result.rows),
156            Outcome::Err(e) => Outcome::Err(e),
157            Outcome::Cancelled(r) => Outcome::Cancelled(r),
158            Outcome::Panicked(p) => Outcome::Panicked(p),
159        }
160    }
161
162    /// Execute a statement and return rows affected.
163    pub async fn execute_async(
164        &mut self,
165        cx: &Cx,
166        sql: &str,
167        params: &[Value],
168    ) -> Outcome<u64, Error> {
169        match self.run_extended(cx, sql, params).await {
170            Outcome::Ok(result) => {
171                Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
172            }
173            Outcome::Err(e) => Outcome::Err(e),
174            Outcome::Cancelled(r) => Outcome::Cancelled(r),
175            Outcome::Panicked(p) => Outcome::Panicked(p),
176        }
177    }
178
179    /// Execute an INSERT and return the inserted id.
180    ///
181    /// PostgreSQL requires `RETURNING` to retrieve generated IDs. This method
182    /// expects the SQL to return a single-row, single-column result set
183    /// containing an integer id.
184    pub async fn insert_async(
185        &mut self,
186        cx: &Cx,
187        sql: &str,
188        params: &[Value],
189    ) -> Outcome<i64, Error> {
190        let result = match self.run_extended(cx, sql, params).await {
191            Outcome::Ok(r) => r,
192            Outcome::Err(e) => return Outcome::Err(e),
193            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
194            Outcome::Panicked(p) => return Outcome::Panicked(p),
195        };
196
197        let Some(row) = result.rows.first() else {
198            return Outcome::Err(query_error_msg(
199                "INSERT did not return an id; add `RETURNING id`",
200                QueryErrorKind::Database,
201            ));
202        };
203        let Some(id_value) = row.get(0) else {
204            return Outcome::Err(query_error_msg(
205                "INSERT result row missing id column",
206                QueryErrorKind::Database,
207            ));
208        };
209        match id_value.as_i64() {
210            Some(v) => Outcome::Ok(v),
211            None => Outcome::Err(query_error_msg(
212                "INSERT returned non-integer id",
213                QueryErrorKind::Database,
214            )),
215        }
216    }
217
218    /// Ping the server.
219    pub async fn ping_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
220        self.execute_async(cx, "SELECT 1", &[]).await.map(|_| ())
221    }
222
223    /// Close the connection.
224    pub async fn close_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
225        // Best-effort terminate. If this fails, the drop will close the socket.
226        let _ = self.send_message(cx, &FrontendMessage::Terminate).await;
227        self.state = ConnectionState::Closed;
228        Outcome::Ok(())
229    }
230
231    // ==================== Protocol: extended query ====================
232
233    async fn run_extended(
234        &mut self,
235        cx: &Cx,
236        sql: &str,
237        params: &[Value],
238    ) -> Outcome<PgQueryResult, Error> {
239        // Encode parameters
240        let mut param_types = Vec::with_capacity(params.len());
241        let mut param_values = Vec::with_capacity(params.len());
242
243        for v in params {
244            if matches!(v, Value::Null) {
245                param_types.push(0);
246                param_values.push(None);
247                continue;
248            }
249            match encode_value(v, Format::Text) {
250                Ok((bytes, oid)) => {
251                    param_types.push(oid);
252                    param_values.push(Some(bytes));
253                }
254                Err(e) => return Outcome::Err(e),
255            }
256        }
257
258        // Parse + bind unnamed statement/portal
259        if let Outcome::Err(e) = self
260            .send_message(
261                cx,
262                &FrontendMessage::Parse {
263                    name: String::new(),
264                    query: sql.to_string(),
265                    param_types,
266                },
267            )
268            .await
269        {
270            return Outcome::Err(e);
271        }
272
273        let param_formats = if params.is_empty() {
274            Vec::new()
275        } else {
276            vec![Format::Text.code()]
277        };
278        if let Outcome::Err(e) = self
279            .send_message(
280                cx,
281                &FrontendMessage::Bind {
282                    portal: String::new(),
283                    statement: String::new(),
284                    param_formats,
285                    params: param_values,
286                    // Default result formats (text) when empty.
287                    result_formats: Vec::new(),
288                },
289            )
290            .await
291        {
292            return Outcome::Err(e);
293        }
294
295        if let Outcome::Err(e) = self
296            .send_message(
297                cx,
298                &FrontendMessage::Describe {
299                    kind: DescribeKind::Portal,
300                    name: String::new(),
301                },
302            )
303            .await
304        {
305            return Outcome::Err(e);
306        }
307
308        if let Outcome::Err(e) = self
309            .send_message(
310                cx,
311                &FrontendMessage::Execute {
312                    portal: String::new(),
313                    max_rows: 0,
314                },
315            )
316            .await
317        {
318            return Outcome::Err(e);
319        }
320
321        if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
322            return Outcome::Err(e);
323        }
324
325        // Read responses until ReadyForQuery
326        let mut field_descs: Option<Vec<crate::protocol::FieldDescription>> = None;
327        let mut columns: Option<Arc<ColumnInfo>> = None;
328        let mut rows: Vec<Row> = Vec::new();
329        let mut command_tag: Option<String> = None;
330
331        loop {
332            let msg = match self.receive_message(cx).await {
333                Outcome::Ok(m) => m,
334                Outcome::Err(e) => return Outcome::Err(e),
335                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
336                Outcome::Panicked(p) => return Outcome::Panicked(p),
337            };
338
339            match msg {
340                BackendMessage::ParseComplete
341                | BackendMessage::BindComplete
342                | BackendMessage::CloseComplete
343                | BackendMessage::ParameterDescription(_)
344                | BackendMessage::NoData
345                | BackendMessage::PortalSuspended
346                | BackendMessage::EmptyQueryResponse => {}
347                BackendMessage::RowDescription(desc) => {
348                    let names: Vec<String> = desc.iter().map(|f| f.name.clone()).collect();
349                    columns = Some(Arc::new(ColumnInfo::new(names)));
350                    field_descs = Some(desc);
351                }
352                BackendMessage::DataRow(raw_values) => {
353                    let Some(ref desc) = field_descs else {
354                        return Outcome::Err(protocol_error(
355                            "DataRow received before RowDescription",
356                        ));
357                    };
358                    let Some(ref cols) = columns else {
359                        return Outcome::Err(protocol_error("Row column metadata missing"));
360                    };
361                    if raw_values.len() != desc.len() {
362                        return Outcome::Err(protocol_error("DataRow field count mismatch"));
363                    }
364
365                    let mut values = Vec::with_capacity(raw_values.len());
366                    for (i, raw) in raw_values.into_iter().enumerate() {
367                        match raw {
368                            None => values.push(Value::Null),
369                            Some(bytes) => {
370                                let field = &desc[i];
371                                let format = Format::from_code(field.format);
372                                let decoded = match decode_value(
373                                    field.type_oid,
374                                    Some(bytes.as_slice()),
375                                    format,
376                                ) {
377                                    Ok(v) => v,
378                                    Err(e) => return Outcome::Err(e),
379                                };
380                                values.push(decoded);
381                            }
382                        }
383                    }
384                    rows.push(Row::with_columns(Arc::clone(cols), values));
385                }
386                BackendMessage::CommandComplete(tag) => {
387                    command_tag = Some(tag);
388                }
389                BackendMessage::ReadyForQuery(status) => {
390                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
391                    break;
392                }
393                BackendMessage::ErrorResponse(e) => {
394                    self.state = ConnectionState::Error;
395                    return Outcome::Err(error_from_fields(&e));
396                }
397                BackendMessage::NoticeResponse(_notice) => {}
398                _ => {}
399            }
400        }
401
402        Outcome::Ok(PgQueryResult { rows, command_tag })
403    }
404
405    // ==================== Startup + auth ====================
406
407    async fn negotiate_ssl(&mut self) -> Outcome<(), Error> {
408        // Send SSL request
409        if let Outcome::Err(e) = self.send_message_no_cx(&FrontendMessage::SSLRequest).await {
410            return Outcome::Err(e);
411        }
412
413        // Read single-byte response
414        let mut buf = [0u8; 1];
415        match read_exact_async(&mut self.stream, &mut buf).await {
416            Ok(()) => {}
417            Err(e) => {
418                return Outcome::Err(Error::Connection(ConnectionError {
419                    kind: ConnectionErrorKind::Ssl,
420                    message: format!("Failed to read SSL response: {}", e),
421                    source: Some(Box::new(e)),
422                }));
423            }
424        }
425
426        match buf[0] {
427            b'S' => {
428                // Server supports SSL but TLS handshake is not implemented.
429                if self.config.ssl_mode.is_required() {
430                    Outcome::Err(Error::Connection(ConnectionError {
431                        kind: ConnectionErrorKind::Ssl,
432                        message: "SSL/TLS not yet implemented".to_string(),
433                        source: None,
434                    }))
435                } else {
436                    Outcome::Err(Error::Connection(ConnectionError {
437                        kind: ConnectionErrorKind::Ssl,
438                        message: "SSL/TLS not yet implemented, reconnect with ssl_mode=disable"
439                            .to_string(),
440                        source: None,
441                    }))
442                }
443            }
444            b'N' => {
445                if self.config.ssl_mode.is_required() {
446                    Outcome::Err(Error::Connection(ConnectionError {
447                        kind: ConnectionErrorKind::Ssl,
448                        message: "Server does not support SSL".to_string(),
449                        source: None,
450                    }))
451                } else {
452                    Outcome::Ok(())
453                }
454            }
455            other => Outcome::Err(Error::Connection(ConnectionError {
456                kind: ConnectionErrorKind::Ssl,
457                message: format!("Unexpected SSL response: 0x{other:02x}"),
458                source: None,
459            })),
460        }
461    }
462
463    async fn send_startup(&mut self) -> Outcome<(), Error> {
464        let params = self.config.startup_params();
465        self.send_message_no_cx(&FrontendMessage::Startup {
466            version: PROTOCOL_VERSION,
467            params,
468        })
469        .await
470    }
471
472    async fn handle_auth(&mut self) -> Outcome<(), Error> {
473        loop {
474            let msg = match self.receive_message_no_cx().await {
475                Outcome::Ok(m) => m,
476                Outcome::Err(e) => return Outcome::Err(e),
477                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
478                Outcome::Panicked(p) => return Outcome::Panicked(p),
479            };
480
481            match msg {
482                BackendMessage::AuthenticationOk => return Outcome::Ok(()),
483                BackendMessage::AuthenticationCleartextPassword => {
484                    let Some(password) = self.config.password.as_ref() else {
485                        return Outcome::Err(auth_error("Password required but not provided"));
486                    };
487                    if let Outcome::Err(e) = self
488                        .send_message_no_cx(&FrontendMessage::PasswordMessage(password.clone()))
489                        .await
490                    {
491                        return Outcome::Err(e);
492                    }
493                }
494                BackendMessage::AuthenticationMD5Password(salt) => {
495                    let Some(password) = self.config.password.as_ref() else {
496                        return Outcome::Err(auth_error("Password required but not provided"));
497                    };
498                    let hash = md5_password(&self.config.user, password, salt);
499                    if let Outcome::Err(e) = self
500                        .send_message_no_cx(&FrontendMessage::PasswordMessage(hash))
501                        .await
502                    {
503                        return Outcome::Err(e);
504                    }
505                }
506                BackendMessage::AuthenticationSASL(mechanisms) => {
507                    if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
508                        match self.scram_auth().await {
509                            Outcome::Ok(()) => {}
510                            Outcome::Err(e) => return Outcome::Err(e),
511                            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
512                            Outcome::Panicked(p) => return Outcome::Panicked(p),
513                        }
514                    } else {
515                        return Outcome::Err(auth_error(format!(
516                            "Unsupported SASL mechanisms: {:?}",
517                            mechanisms
518                        )));
519                    }
520                }
521                BackendMessage::ErrorResponse(e) => {
522                    self.state = ConnectionState::Error;
523                    return Outcome::Err(error_from_fields(&e));
524                }
525                other => {
526                    return Outcome::Err(protocol_error(format!(
527                        "Unexpected message during auth: {other:?}"
528                    )));
529                }
530            }
531        }
532    }
533
534    async fn scram_auth(&mut self) -> Outcome<(), Error> {
535        let Some(password) = self.config.password.as_ref() else {
536            return Outcome::Err(auth_error("Password required for SCRAM-SHA-256"));
537        };
538
539        let mut client = ScramClient::new(&self.config.user, password);
540
541        // Client-first
542        let client_first = client.client_first();
543        if let Outcome::Err(e) = self
544            .send_message_no_cx(&FrontendMessage::SASLInitialResponse {
545                mechanism: "SCRAM-SHA-256".to_string(),
546                data: client_first,
547            })
548            .await
549        {
550            return Outcome::Err(e);
551        }
552
553        // Server-first
554        let msg = match self.receive_message_no_cx().await {
555            Outcome::Ok(m) => m,
556            Outcome::Err(e) => return Outcome::Err(e),
557            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
558            Outcome::Panicked(p) => return Outcome::Panicked(p),
559        };
560        let server_first_data = match msg {
561            BackendMessage::AuthenticationSASLContinue(data) => data,
562            BackendMessage::ErrorResponse(e) => {
563                self.state = ConnectionState::Error;
564                return Outcome::Err(error_from_fields(&e));
565            }
566            other => {
567                return Outcome::Err(protocol_error(format!(
568                    "Expected SASL continue, got: {other:?}"
569                )));
570            }
571        };
572
573        // Client-final
574        let client_final = match client.process_server_first(&server_first_data) {
575            Ok(v) => v,
576            Err(e) => return Outcome::Err(e),
577        };
578        if let Outcome::Err(e) = self
579            .send_message_no_cx(&FrontendMessage::SASLResponse(client_final))
580            .await
581        {
582            return Outcome::Err(e);
583        }
584
585        // Server-final
586        let msg = match self.receive_message_no_cx().await {
587            Outcome::Ok(m) => m,
588            Outcome::Err(e) => return Outcome::Err(e),
589            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
590            Outcome::Panicked(p) => return Outcome::Panicked(p),
591        };
592        let server_final_data = match msg {
593            BackendMessage::AuthenticationSASLFinal(data) => data,
594            BackendMessage::ErrorResponse(e) => {
595                self.state = ConnectionState::Error;
596                return Outcome::Err(error_from_fields(&e));
597            }
598            other => {
599                return Outcome::Err(protocol_error(format!(
600                    "Expected SASL final, got: {other:?}"
601                )));
602            }
603        };
604
605        if let Err(e) = client.verify_server_final(&server_final_data) {
606            return Outcome::Err(e);
607        }
608
609        // Wait for AuthenticationOk
610        let msg = match self.receive_message_no_cx().await {
611            Outcome::Ok(m) => m,
612            Outcome::Err(e) => return Outcome::Err(e),
613            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
614            Outcome::Panicked(p) => return Outcome::Panicked(p),
615        };
616        match msg {
617            BackendMessage::AuthenticationOk => Outcome::Ok(()),
618            BackendMessage::ErrorResponse(e) => {
619                self.state = ConnectionState::Error;
620                Outcome::Err(error_from_fields(&e))
621            }
622            other => Outcome::Err(protocol_error(format!(
623                "Expected AuthenticationOk, got: {other:?}"
624            ))),
625        }
626    }
627
628    async fn read_startup_messages(&mut self) -> Outcome<(), Error> {
629        loop {
630            let msg = match self.receive_message_no_cx().await {
631                Outcome::Ok(m) => m,
632                Outcome::Err(e) => return Outcome::Err(e),
633                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
634                Outcome::Panicked(p) => return Outcome::Panicked(p),
635            };
636
637            match msg {
638                BackendMessage::BackendKeyData {
639                    process_id,
640                    secret_key,
641                } => {
642                    self.process_id = process_id;
643                    self.secret_key = secret_key;
644                }
645                BackendMessage::ParameterStatus { name, value } => {
646                    self.parameters.insert(name, value);
647                }
648                BackendMessage::ReadyForQuery(status) => {
649                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
650                    return Outcome::Ok(());
651                }
652                BackendMessage::ErrorResponse(e) => {
653                    self.state = ConnectionState::Error;
654                    return Outcome::Err(error_from_fields(&e));
655                }
656                BackendMessage::NoticeResponse(_notice) => {}
657                other => {
658                    return Outcome::Err(protocol_error(format!(
659                        "Unexpected startup message: {other:?}"
660                    )));
661                }
662            }
663        }
664    }
665
666    // ==================== I/O ====================
667
668    async fn send_message(&mut self, cx: &Cx, msg: &FrontendMessage) -> Outcome<(), Error> {
669        // If cancelled, propagate early.
670        if let Some(reason) = cx.cancel_reason() {
671            return Outcome::Cancelled(reason);
672        }
673        self.send_message_no_cx(msg).await
674    }
675
676    async fn receive_message(&mut self, cx: &Cx) -> Outcome<BackendMessage, Error> {
677        if let Some(reason) = cx.cancel_reason() {
678            return Outcome::Cancelled(reason);
679        }
680        self.receive_message_no_cx().await
681    }
682
683    async fn send_message_no_cx(&mut self, msg: &FrontendMessage) -> Outcome<(), Error> {
684        let data = self.writer.write(msg).to_vec();
685
686        let mut written = 0;
687        while written < data.len() {
688            match std::future::poll_fn(|cx| {
689                std::pin::Pin::new(&mut self.stream).poll_write(cx, &data[written..])
690            })
691            .await
692            {
693                Ok(n) => {
694                    if n == 0 {
695                        self.state = ConnectionState::Error;
696                        return Outcome::Err(Error::Connection(ConnectionError {
697                            kind: ConnectionErrorKind::Disconnected,
698                            message: "Connection closed while writing".to_string(),
699                            source: None,
700                        }));
701                    }
702                    written += n;
703                }
704                Err(e) => {
705                    self.state = ConnectionState::Error;
706                    return Outcome::Err(Error::Connection(ConnectionError {
707                        kind: ConnectionErrorKind::Disconnected,
708                        message: format!("Failed to write to server: {}", e),
709                        source: Some(Box::new(e)),
710                    }));
711                }
712            }
713        }
714
715        match std::future::poll_fn(|cx| std::pin::Pin::new(&mut self.stream).poll_flush(cx)).await {
716            Ok(()) => Outcome::Ok(()),
717            Err(e) => {
718                self.state = ConnectionState::Error;
719                Outcome::Err(Error::Connection(ConnectionError {
720                    kind: ConnectionErrorKind::Disconnected,
721                    message: format!("Failed to flush stream: {}", e),
722                    source: Some(Box::new(e)),
723                }))
724            }
725        }
726    }
727
728    async fn receive_message_no_cx(&mut self) -> Outcome<BackendMessage, Error> {
729        loop {
730            match self.reader.next_message() {
731                Ok(Some(msg)) => return Outcome::Ok(msg),
732                Ok(None) => {}
733                Err(e) => {
734                    self.state = ConnectionState::Error;
735                    return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
736                }
737            }
738
739            let mut read_buf = ReadBuf::new(&mut self.read_buf);
740            match std::future::poll_fn(|cx| {
741                std::pin::Pin::new(&mut self.stream).poll_read(cx, &mut read_buf)
742            })
743            .await
744            {
745                Ok(()) => {
746                    let n = read_buf.filled().len();
747                    if n == 0 {
748                        self.state = ConnectionState::Disconnected;
749                        return Outcome::Err(Error::Connection(ConnectionError {
750                            kind: ConnectionErrorKind::Disconnected,
751                            message: "Connection closed by server".to_string(),
752                            source: None,
753                        }));
754                    }
755                    if let Err(e) = self.reader.feed(read_buf.filled()) {
756                        self.state = ConnectionState::Error;
757                        return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
758                    }
759                }
760                Err(e) => {
761                    self.state = ConnectionState::Error;
762                    return Outcome::Err(match e.kind() {
763                        std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock => {
764                            Error::Timeout
765                        }
766                        _ => Error::Connection(ConnectionError {
767                            kind: ConnectionErrorKind::Disconnected,
768                            message: format!("Failed to read from server: {}", e),
769                            source: Some(Box::new(e)),
770                        }),
771                    });
772                }
773            }
774        }
775    }
776}
777
778/// Shared, cloneable PostgreSQL connection with interior mutability.
779pub struct SharedPgConnection {
780    inner: Arc<Mutex<PgAsyncConnection>>,
781}
782
783impl SharedPgConnection {
784    pub fn new(conn: PgAsyncConnection) -> Self {
785        Self {
786            inner: Arc::new(Mutex::new(conn)),
787        }
788    }
789
790    pub async fn connect(cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
791        match PgAsyncConnection::connect(cx, config).await {
792            Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
793            Outcome::Err(e) => Outcome::Err(e),
794            Outcome::Cancelled(r) => Outcome::Cancelled(r),
795            Outcome::Panicked(p) => Outcome::Panicked(p),
796        }
797    }
798
799    pub fn inner(&self) -> &Arc<Mutex<PgAsyncConnection>> {
800        &self.inner
801    }
802
803    async fn begin_transaction_impl(
804        &self,
805        cx: &Cx,
806        isolation: Option<IsolationLevel>,
807    ) -> Outcome<SharedPgTransaction<'_>, Error> {
808        let inner = Arc::clone(&self.inner);
809        let Ok(mut guard) = inner.lock(cx).await else {
810            return Outcome::Err(connection_error("Failed to acquire connection lock"));
811        };
812
813        if let Some(level) = isolation {
814            let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
815            match guard.execute_async(cx, &sql, &[]).await {
816                Outcome::Ok(_) => {}
817                Outcome::Err(e) => return Outcome::Err(e),
818                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
819                Outcome::Panicked(p) => return Outcome::Panicked(p),
820            }
821        }
822
823        match guard.execute_async(cx, "BEGIN", &[]).await {
824            Outcome::Ok(_) => {}
825            Outcome::Err(e) => return Outcome::Err(e),
826            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
827            Outcome::Panicked(p) => return Outcome::Panicked(p),
828        }
829
830        drop(guard);
831        Outcome::Ok(SharedPgTransaction {
832            inner,
833            committed: false,
834            _marker: std::marker::PhantomData,
835        })
836    }
837}
838
839impl Clone for SharedPgConnection {
840    fn clone(&self) -> Self {
841        Self {
842            inner: Arc::clone(&self.inner),
843        }
844    }
845}
846
847impl std::fmt::Debug for SharedPgConnection {
848    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849        f.debug_struct("SharedPgConnection")
850            .field("inner", &"Arc<Mutex<PgAsyncConnection>>")
851            .finish()
852    }
853}
854
855pub struct SharedPgTransaction<'conn> {
856    inner: Arc<Mutex<PgAsyncConnection>>,
857    committed: bool,
858    _marker: std::marker::PhantomData<&'conn ()>,
859}
860
861impl<'conn> Drop for SharedPgTransaction<'conn> {
862    fn drop(&mut self) {
863        if !self.committed {
864            // WARNING: Transaction was dropped without commit() or rollback()!
865            // We cannot do async work in Drop, so the PostgreSQL transaction will
866            // remain open until the connection is closed or a new transaction
867            // is started.
868            #[cfg(debug_assertions)]
869            eprintln!(
870                "WARNING: SharedPgTransaction dropped without commit/rollback. \
871                 The PostgreSQL transaction may still be open."
872            );
873        }
874    }
875}
876
877impl Connection for SharedPgConnection {
878    type Tx<'conn>
879        = SharedPgTransaction<'conn>
880    where
881        Self: 'conn;
882
883    fn query(
884        &self,
885        cx: &Cx,
886        sql: &str,
887        params: &[Value],
888    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
889        let inner = Arc::clone(&self.inner);
890        let sql = sql.to_string();
891        let params = params.to_vec();
892        async move {
893            let Ok(mut guard) = inner.lock(cx).await else {
894                return Outcome::Err(connection_error("Failed to acquire connection lock"));
895            };
896            guard.query_async(cx, &sql, &params).await
897        }
898    }
899
900    fn query_one(
901        &self,
902        cx: &Cx,
903        sql: &str,
904        params: &[Value],
905    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
906        let inner = Arc::clone(&self.inner);
907        let sql = sql.to_string();
908        let params = params.to_vec();
909        async move {
910            let Ok(mut guard) = inner.lock(cx).await else {
911                return Outcome::Err(connection_error("Failed to acquire connection lock"));
912            };
913            let rows = match guard.query_async(cx, &sql, &params).await {
914                Outcome::Ok(r) => r,
915                Outcome::Err(e) => return Outcome::Err(e),
916                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
917                Outcome::Panicked(p) => return Outcome::Panicked(p),
918            };
919            Outcome::Ok(rows.into_iter().next())
920        }
921    }
922
923    fn execute(
924        &self,
925        cx: &Cx,
926        sql: &str,
927        params: &[Value],
928    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
929        let inner = Arc::clone(&self.inner);
930        let sql = sql.to_string();
931        let params = params.to_vec();
932        async move {
933            let Ok(mut guard) = inner.lock(cx).await else {
934                return Outcome::Err(connection_error("Failed to acquire connection lock"));
935            };
936            guard.execute_async(cx, &sql, &params).await
937        }
938    }
939
940    fn insert(
941        &self,
942        cx: &Cx,
943        sql: &str,
944        params: &[Value],
945    ) -> impl Future<Output = Outcome<i64, Error>> + Send {
946        let inner = Arc::clone(&self.inner);
947        let sql = sql.to_string();
948        let params = params.to_vec();
949        async move {
950            let Ok(mut guard) = inner.lock(cx).await else {
951                return Outcome::Err(connection_error("Failed to acquire connection lock"));
952            };
953            guard.insert_async(cx, &sql, &params).await
954        }
955    }
956
957    fn batch(
958        &self,
959        cx: &Cx,
960        statements: &[(String, Vec<Value>)],
961    ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
962        let inner = Arc::clone(&self.inner);
963        let statements = statements.to_vec();
964        async move {
965            let Ok(mut guard) = inner.lock(cx).await else {
966                return Outcome::Err(connection_error("Failed to acquire connection lock"));
967            };
968            let mut results = Vec::with_capacity(statements.len());
969            for (sql, params) in &statements {
970                match guard.execute_async(cx, sql, params).await {
971                    Outcome::Ok(n) => results.push(n),
972                    Outcome::Err(e) => return Outcome::Err(e),
973                    Outcome::Cancelled(r) => return Outcome::Cancelled(r),
974                    Outcome::Panicked(p) => return Outcome::Panicked(p),
975                }
976            }
977            Outcome::Ok(results)
978        }
979    }
980
981    fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
982        self.begin_with(cx, IsolationLevel::default())
983    }
984
985    fn begin_with(
986        &self,
987        cx: &Cx,
988        isolation: IsolationLevel,
989    ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
990        self.begin_transaction_impl(cx, Some(isolation))
991    }
992
993    fn prepare(
994        &self,
995        _cx: &Cx,
996        sql: &str,
997    ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
998        let sql = sql.to_string();
999        async move {
1000            // Note: Client-side prepared statement stub. Server-side prepared statements
1001            // (PostgreSQL PREPARE/EXECUTE) can be added later for performance optimization.
1002            // Current implementation passes through to regular query execution.
1003            Outcome::Ok(PreparedStatement::new(0, sql, 0))
1004        }
1005    }
1006
1007    fn query_prepared(
1008        &self,
1009        cx: &Cx,
1010        stmt: &PreparedStatement,
1011        params: &[Value],
1012    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1013        self.query(cx, stmt.sql(), params)
1014    }
1015
1016    fn execute_prepared(
1017        &self,
1018        cx: &Cx,
1019        stmt: &PreparedStatement,
1020        params: &[Value],
1021    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1022        self.execute(cx, stmt.sql(), params)
1023    }
1024
1025    fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1026        let inner = Arc::clone(&self.inner);
1027        async move {
1028            let Ok(mut guard) = inner.lock(cx).await else {
1029                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1030            };
1031            guard.ping_async(cx).await
1032        }
1033    }
1034
1035    async fn close(self, cx: &Cx) -> sqlmodel_core::Result<()> {
1036        let Ok(mut guard) = self.inner.lock(cx).await else {
1037            return Err(connection_error("Failed to acquire connection lock"));
1038        };
1039        match guard.close_async(cx).await {
1040            Outcome::Ok(()) => Ok(()),
1041            Outcome::Err(e) => Err(e),
1042            Outcome::Cancelled(r) => Err(Error::Query(QueryError {
1043                kind: QueryErrorKind::Cancelled,
1044                message: format!("Cancelled: {r:?}"),
1045                sqlstate: None,
1046                sql: None,
1047                detail: None,
1048                hint: None,
1049                position: None,
1050                source: None,
1051            })),
1052            Outcome::Panicked(p) => Err(Error::Protocol(ProtocolError {
1053                message: format!("Panicked: {p:?}"),
1054                raw_data: None,
1055                source: None,
1056            })),
1057        }
1058    }
1059}
1060
1061impl<'conn> TransactionOps for SharedPgTransaction<'conn> {
1062    fn query(
1063        &self,
1064        cx: &Cx,
1065        sql: &str,
1066        params: &[Value],
1067    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1068        let inner = Arc::clone(&self.inner);
1069        let sql = sql.to_string();
1070        let params = params.to_vec();
1071        async move {
1072            let Ok(mut guard) = inner.lock(cx).await else {
1073                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1074            };
1075            guard.query_async(cx, &sql, &params).await
1076        }
1077    }
1078
1079    fn query_one(
1080        &self,
1081        cx: &Cx,
1082        sql: &str,
1083        params: &[Value],
1084    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1085        let inner = Arc::clone(&self.inner);
1086        let sql = sql.to_string();
1087        let params = params.to_vec();
1088        async move {
1089            let Ok(mut guard) = inner.lock(cx).await else {
1090                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1091            };
1092            let rows = match guard.query_async(cx, &sql, &params).await {
1093                Outcome::Ok(r) => r,
1094                Outcome::Err(e) => return Outcome::Err(e),
1095                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1096                Outcome::Panicked(p) => return Outcome::Panicked(p),
1097            };
1098            Outcome::Ok(rows.into_iter().next())
1099        }
1100    }
1101
1102    fn execute(
1103        &self,
1104        cx: &Cx,
1105        sql: &str,
1106        params: &[Value],
1107    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1108        let inner = Arc::clone(&self.inner);
1109        let sql = sql.to_string();
1110        let params = params.to_vec();
1111        async move {
1112            let Ok(mut guard) = inner.lock(cx).await else {
1113                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1114            };
1115            guard.execute_async(cx, &sql, &params).await
1116        }
1117    }
1118
1119    fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1120        let inner = Arc::clone(&self.inner);
1121        let name = name.to_string();
1122        async move {
1123            if let Err(e) = validate_savepoint_name(&name) {
1124                return Outcome::Err(e);
1125            }
1126            let sql = format!("SAVEPOINT {}", name);
1127            let Ok(mut guard) = inner.lock(cx).await else {
1128                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1129            };
1130            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1131        }
1132    }
1133
1134    fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1135        let inner = Arc::clone(&self.inner);
1136        let name = name.to_string();
1137        async move {
1138            if let Err(e) = validate_savepoint_name(&name) {
1139                return Outcome::Err(e);
1140            }
1141            let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
1142            let Ok(mut guard) = inner.lock(cx).await else {
1143                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1144            };
1145            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1146        }
1147    }
1148
1149    fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1150        let inner = Arc::clone(&self.inner);
1151        let name = name.to_string();
1152        async move {
1153            if let Err(e) = validate_savepoint_name(&name) {
1154                return Outcome::Err(e);
1155            }
1156            let sql = format!("RELEASE SAVEPOINT {}", name);
1157            let Ok(mut guard) = inner.lock(cx).await else {
1158                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1159            };
1160            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1161        }
1162    }
1163
1164    // Note: clippy sometimes flags `self.committed = true` as unused, but Drop reads it.
1165    #[allow(unused_assignments)]
1166    fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1167        let inner = Arc::clone(&self.inner);
1168        async move {
1169            let Ok(mut guard) = inner.lock(cx).await else {
1170                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1171            };
1172            let result = guard.execute_async(cx, "COMMIT", &[]).await;
1173            if matches!(result, Outcome::Ok(_)) {
1174                self.committed = true;
1175            }
1176            result.map(|_| ())
1177        }
1178    }
1179
1180    #[allow(unused_assignments)]
1181    fn rollback(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1182        let inner = Arc::clone(&self.inner);
1183        async move {
1184            let Ok(mut guard) = inner.lock(cx).await else {
1185                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1186            };
1187            let result = guard.execute_async(cx, "ROLLBACK", &[]).await;
1188            if matches!(result, Outcome::Ok(_)) {
1189                self.committed = true;
1190            }
1191            result.map(|_| ())
1192        }
1193    }
1194}
1195
1196// ==================== Helpers ====================
1197
1198struct PgQueryResult {
1199    rows: Vec<Row>,
1200    command_tag: Option<String>,
1201}
1202
1203fn connection_error(msg: impl Into<String>) -> Error {
1204    Error::Connection(ConnectionError {
1205        kind: ConnectionErrorKind::Connect,
1206        message: msg.into(),
1207        source: None,
1208    })
1209}
1210
1211fn auth_error(msg: impl Into<String>) -> Error {
1212    Error::Connection(ConnectionError {
1213        kind: ConnectionErrorKind::Authentication,
1214        message: msg.into(),
1215        source: None,
1216    })
1217}
1218
1219fn protocol_error(msg: impl Into<String>) -> Error {
1220    Error::Protocol(ProtocolError {
1221        message: msg.into(),
1222        raw_data: None,
1223        source: None,
1224    })
1225}
1226
1227fn query_error_msg(msg: impl Into<String>, kind: QueryErrorKind) -> Error {
1228    Error::Query(QueryError {
1229        kind,
1230        message: msg.into(),
1231        sqlstate: None,
1232        sql: None,
1233        detail: None,
1234        hint: None,
1235        position: None,
1236        source: None,
1237    })
1238}
1239
1240fn error_from_fields(fields: &ErrorFields) -> Error {
1241    let kind = match fields.code.get(..2) {
1242        Some("08") => {
1243            return Error::Connection(ConnectionError {
1244                kind: ConnectionErrorKind::Connect,
1245                message: fields.message.clone(),
1246                source: None,
1247            });
1248        }
1249        Some("28") => {
1250            return Error::Connection(ConnectionError {
1251                kind: ConnectionErrorKind::Authentication,
1252                message: fields.message.clone(),
1253                source: None,
1254            });
1255        }
1256        Some("42") => QueryErrorKind::Syntax,
1257        Some("23") => QueryErrorKind::Constraint,
1258        Some("40") => {
1259            if fields.code == "40001" {
1260                QueryErrorKind::Serialization
1261            } else {
1262                QueryErrorKind::Deadlock
1263            }
1264        }
1265        Some("57") => {
1266            if fields.code == "57014" {
1267                QueryErrorKind::Cancelled
1268            } else {
1269                QueryErrorKind::Timeout
1270            }
1271        }
1272        _ => QueryErrorKind::Database,
1273    };
1274
1275    Error::Query(QueryError {
1276        kind,
1277        sql: None,
1278        sqlstate: Some(fields.code.clone()),
1279        message: fields.message.clone(),
1280        detail: fields.detail.clone(),
1281        hint: fields.hint.clone(),
1282        position: fields.position.map(|p| p as usize),
1283        source: None,
1284    })
1285}
1286
/// Extracts the row count from a PostgreSQL command completion tag.
///
/// The count is the last whitespace-separated token of tags such as
/// `"INSERT 0 5"`, `"UPDATE 3"` or `"SELECT 10"`. Returns `None` when there is
/// no tag, the tag is empty, or its last token is not an unsigned integer
/// (e.g. `"BEGIN"`).
fn parse_rows_affected(tag: Option<&str>) -> Option<u64> {
    // `split_whitespace` is double-ended, so the last token is taken directly
    // instead of collecting every token into a temporary `Vec`.
    tag?.split_whitespace().next_back()?.parse::<u64>().ok()
}
1292
1293/// Validate a savepoint name to reduce SQL injection risk.
1294fn validate_savepoint_name(name: &str) -> sqlmodel_core::Result<()> {
1295    if name.is_empty() {
1296        return Err(query_error_msg(
1297            "Savepoint name cannot be empty",
1298            QueryErrorKind::Syntax,
1299        ));
1300    }
1301    if name.len() > 63 {
1302        return Err(query_error_msg(
1303            "Savepoint name exceeds maximum length of 63 characters",
1304            QueryErrorKind::Syntax,
1305        ));
1306    }
1307    let mut chars = name.chars();
1308    let Some(first) = chars.next() else {
1309        return Err(query_error_msg(
1310            "Savepoint name cannot be empty",
1311            QueryErrorKind::Syntax,
1312        ));
1313    };
1314    if !first.is_ascii_alphabetic() && first != '_' {
1315        return Err(query_error_msg(
1316            "Savepoint name must start with a letter or underscore",
1317            QueryErrorKind::Syntax,
1318        ));
1319    }
1320    for c in chars {
1321        if !c.is_ascii_alphanumeric() && c != '_' {
1322            return Err(query_error_msg(
1323                format!("Savepoint name contains invalid character: '{c}'"),
1324                QueryErrorKind::Syntax,
1325            ));
1326        }
1327    }
1328    Ok(())
1329}
1330
1331fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
1332    use std::fmt::Write;
1333
1334    let inner = format!("{password}{user}");
1335    let inner_hash = md5::compute(inner.as_bytes());
1336
1337    let mut outer_input = format!("{inner_hash:x}").into_bytes();
1338    outer_input.extend_from_slice(&salt);
1339    let outer_hash = md5::compute(&outer_input);
1340
1341    let mut result = String::with_capacity(35);
1342    result.push_str("md5");
1343    write!(&mut result, "{outer_hash:x}").unwrap();
1344    result
1345}
1346
1347async fn read_exact_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<()> {
1348    let mut read = 0;
1349    while read < buf.len() {
1350        let mut read_buf = ReadBuf::new(&mut buf[read..]);
1351        std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf))
1352            .await?;
1353        let n = read_buf.filled().len();
1354        if n == 0 {
1355            return Err(std::io::Error::new(
1356                std::io::ErrorKind::UnexpectedEof,
1357                "connection closed",
1358            ));
1359        }
1360        read += n;
1361    }
1362    Ok(())
1363}