Skip to main content

sqlmodel_postgres/
connection.rs

1//! PostgreSQL connection implementation.
2//!
3//! This module implements the PostgreSQL wire protocol connection,
4//! including connection establishment, authentication, and state management.
5//!
6//! # Console Integration
7//!
//! When the `console` feature is enabled, a console can be attached through the
//! `ConsoleAware` trait once the connection is open; the connection can then emit
//! connection-progress and summary messages (`emit_progress`, `emit_connected`) to it.
11//!
12//! ```rust,ignore
13//! use sqlmodel_postgres::{PgConfig, PgConnection};
14//! use sqlmodel_console::{SqlModelConsole, ConsoleAware};
15//! use std::sync::Arc;
16//!
17//! let console = Arc::new(SqlModelConsole::new());
18//! let mut conn = PgConnection::connect(config)?;
19//! conn.set_console(Some(console));
20//! ```
21
22use std::collections::HashMap;
23use std::io::{Read, Write};
24use std::net::TcpStream;
25#[cfg(feature = "console")]
26use std::sync::Arc;
27
28use sqlmodel_core::Error;
29use sqlmodel_core::error::{
30    ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
31};
32
33#[cfg(feature = "console")]
34use sqlmodel_console::{ConsoleAware, SqlModelConsole};
35
36use crate::auth::ScramClient;
37use crate::config::PgConfig;
38#[cfg(not(feature = "tls"))]
39use crate::config::SslMode;
40use crate::protocol::{
41    BackendMessage, ErrorFields, FrontendMessage, MessageReader, MessageWriter, PROTOCOL_VERSION,
42    TransactionStatus,
43};
44
45#[cfg(feature = "tls")]
46use crate::tls;
47
/// Byte transport underneath a `PgConnection`.
///
/// Starts as a plain TCP socket; with the `tls` feature it can be upgraded to a
/// rustls session. `Closed` is the placeholder left behind while the plain socket
/// is moved out for that upgrade.
enum PgStream {
    /// Unencrypted TCP socket.
    Plain(TcpStream),
    /// TLS session layered over the TCP socket.
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ClientConnection, TcpStream>),
    /// No usable transport; every I/O call fails with `NotConnected`.
    #[cfg(feature = "tls")]
    Closed,
}

impl PgStream {
    /// The error every operation on `PgStream::Closed` reports.
    #[cfg(feature = "tls")]
    fn closed_error() -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::NotConnected, "connection closed")
    }

    /// Fill `buf` completely from the transport (only used by TLS negotiation).
    #[cfg(feature = "tls")]
    fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        match self {
            Self::Plain(sock) => sock.read_exact(buf),
            Self::Tls(session) => session.read_exact(buf),
            Self::Closed => Err(Self::closed_error()),
        }
    }

    /// Read whatever is available into `buf`, returning the byte count.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self {
            Self::Plain(sock) => sock.read(buf),
            #[cfg(feature = "tls")]
            Self::Tls(session) => session.read(buf),
            #[cfg(feature = "tls")]
            Self::Closed => Err(Self::closed_error()),
        }
    }

    /// Write all of `buf` to the transport.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        match self {
            Self::Plain(sock) => sock.write_all(buf),
            #[cfg(feature = "tls")]
            Self::Tls(session) => session.write_all(buf),
            #[cfg(feature = "tls")]
            Self::Closed => Err(Self::closed_error()),
        }
    }

    /// Flush any buffered output to the transport.
    fn flush(&mut self) -> std::io::Result<()> {
        match self {
            Self::Plain(sock) => sock.flush(),
            #[cfg(feature = "tls")]
            Self::Tls(session) => session.flush(),
            #[cfg(feature = "tls")]
            Self::Closed => Err(Self::closed_error()),
        }
    }
}
110
/// Where a `PgConnection` currently is in the PostgreSQL protocol state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No live connection to the server.
    Disconnected,
    /// TCP socket is open; the startup message is being sent.
    Connecting,
    /// The authentication exchange is in progress.
    Authenticating,
    /// Able to accept queries; carries the transaction status from the last
    /// `ReadyForQuery`.
    Ready(TransactionStatusState),
    /// A query is currently executing.
    InQuery,
    /// Inside a transaction block.
    InTransaction(TransactionStatusState),
    /// The connection has entered an error state.
    Error,
    /// The connection has been closed.
    Closed,
}

/// Transaction status byte reported by the server in `ReadyForQuery`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransactionStatusState {
    /// Outside any transaction block (status `'I'`). This is the default.
    #[default]
    Idle,
    /// Inside a transaction block (status `'T'`).
    InTransaction,
    /// Inside a failed transaction block (status `'E'`).
    InFailed,
}
143
144impl From<TransactionStatus> for TransactionStatusState {
145    fn from(status: TransactionStatus) -> Self {
146        match status {
147            TransactionStatus::Idle => TransactionStatusState::Idle,
148            TransactionStatus::Transaction => TransactionStatusState::InTransaction,
149            TransactionStatus::Error => TransactionStatusState::InFailed,
150        }
151    }
152}
153
154/// PostgreSQL connection.
155///
156/// Manages a TCP connection to a PostgreSQL server, handling the wire protocol,
157/// authentication, and state tracking.
158///
159/// # Console Support
160///
161/// When the `console` feature is enabled, the connection can report progress
162/// via an attached `SqlModelConsole`. This provides rich feedback during
163/// connection establishment and query execution.
164pub struct PgConnection {
165    /// TCP stream to the server
166    stream: PgStream,
167    /// Current connection state
168    state: ConnectionState,
169    /// Backend process ID (for query cancellation)
170    process_id: i32,
171    /// Secret key (for query cancellation)
172    secret_key: i32,
173    /// Server parameters received during startup
174    parameters: HashMap<String, String>,
175    /// Connection configuration
176    config: PgConfig,
177    /// Message reader for parsing backend messages
178    reader: MessageReader,
179    /// Message writer for encoding frontend messages
180    writer: MessageWriter,
181    /// Read buffer
182    read_buf: Vec<u8>,
183    /// Optional console for rich output
184    #[cfg(feature = "console")]
185    console: Option<Arc<SqlModelConsole>>,
186}
187
188impl std::fmt::Debug for PgConnection {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("PgConnection")
191            .field("state", &self.state)
192            .field("process_id", &self.process_id)
193            .field("host", &self.config.host)
194            .field("port", &self.config.port)
195            .field("database", &self.config.database)
196            .finish_non_exhaustive()
197    }
198}
199
200impl PgConnection {
201    /// Establish a new connection to the PostgreSQL server.
202    ///
203    /// This performs the complete connection handshake:
204    /// 1. TCP connection
205    /// 2. SSL negotiation (if configured)
206    /// 3. Startup message
207    /// 4. Authentication
208    /// 5. Receive server parameters and ReadyForQuery
209    #[allow(clippy::result_large_err)]
210    pub fn connect(config: PgConfig) -> Result<Self, Error> {
211        // 1. TCP connection with timeout
212        let stream = TcpStream::connect_timeout(
213            &config.socket_addr().parse().map_err(|e| {
214                Error::Connection(ConnectionError {
215                    kind: ConnectionErrorKind::Connect,
216                    message: format!("Invalid socket address: {}", e),
217                    source: None,
218                })
219            })?,
220            config.connect_timeout,
221        )
222        .map_err(|e| {
223            let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
224                ConnectionErrorKind::Refused
225            } else {
226                ConnectionErrorKind::Connect
227            };
228            Error::Connection(ConnectionError {
229                kind,
230                message: format!("Failed to connect to {}: {}", config.socket_addr(), e),
231                source: Some(Box::new(e)),
232            })
233        })?;
234
235        // Set TCP options
236        stream.set_nodelay(true).ok();
237        stream.set_read_timeout(Some(config.connect_timeout)).ok();
238        stream.set_write_timeout(Some(config.connect_timeout)).ok();
239
240        let mut conn = Self {
241            stream: PgStream::Plain(stream),
242            state: ConnectionState::Connecting,
243            process_id: 0,
244            secret_key: 0,
245            parameters: HashMap::new(),
246            config,
247            reader: MessageReader::new(),
248            writer: MessageWriter::new(),
249            read_buf: vec![0u8; 8192],
250            #[cfg(feature = "console")]
251            console: None,
252        };
253
254        // 2. SSL negotiation (if configured)
255        if conn.config.ssl_mode.should_try_ssl() {
256            #[cfg(feature = "tls")]
257            conn.negotiate_ssl()?;
258
259            #[cfg(not(feature = "tls"))]
260            if conn.config.ssl_mode != SslMode::Prefer {
261                return Err(Error::Connection(ConnectionError {
262                    kind: ConnectionErrorKind::Ssl,
263                    message:
264                        "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'"
265                            .to_string(),
266                    source: None,
267                }));
268            }
269        }
270
271        // 3. Send startup message
272        conn.send_startup()?;
273        conn.state = ConnectionState::Authenticating;
274
275        // 4. Handle authentication
276        conn.handle_auth()?;
277
278        // 5. Read remaining startup messages until ReadyForQuery
279        conn.read_startup_messages()?;
280
281        Ok(conn)
282    }
283
284    /// Get the current connection state.
285    pub fn state(&self) -> ConnectionState {
286        self.state
287    }
288
289    /// Check if the connection is ready for queries.
290    pub fn is_ready(&self) -> bool {
291        matches!(self.state, ConnectionState::Ready(_))
292    }
293
294    /// Get the backend process ID (for query cancellation).
295    pub fn process_id(&self) -> i32 {
296        self.process_id
297    }
298
299    /// Get the secret key (for query cancellation).
300    pub fn secret_key(&self) -> i32 {
301        self.secret_key
302    }
303
304    /// Get a server parameter value.
305    pub fn parameter(&self, name: &str) -> Option<&str> {
306        self.parameters.get(name).map(|s| s.as_str())
307    }
308
309    /// Get all server parameters.
310    pub fn parameters(&self) -> &HashMap<String, String> {
311        &self.parameters
312    }
313
314    /// Close the connection gracefully.
315    #[allow(clippy::result_large_err)]
316    pub fn close(&mut self) -> Result<(), Error> {
317        if matches!(
318            self.state,
319            ConnectionState::Closed | ConnectionState::Disconnected
320        ) {
321            return Ok(());
322        }
323
324        // Send Terminate message
325        self.send_message(&FrontendMessage::Terminate)?;
326        self.state = ConnectionState::Closed;
327        Ok(())
328    }
329
330    // ==================== SSL Negotiation ====================
331
332    #[allow(clippy::result_large_err)]
333    #[cfg(feature = "tls")]
334    fn negotiate_ssl(&mut self) -> Result<(), Error> {
335        // Send SSL request
336        self.send_message(&FrontendMessage::SSLRequest)?;
337
338        // Read single-byte response
339        let mut buf = [0u8; 1];
340        self.stream.read_exact(&mut buf).map_err(|e| {
341            Error::Connection(ConnectionError {
342                kind: ConnectionErrorKind::Ssl,
343                message: format!("Failed to read SSL response: {}", e),
344                source: Some(Box::new(e)),
345            })
346        })?;
347
348        match buf[0] {
349            b'S' => {
350                // Server supports SSL; upgrade to TLS.
351                #[cfg(feature = "tls")]
352                {
353                    let plain = match std::mem::replace(&mut self.stream, PgStream::Closed) {
354                        PgStream::Plain(s) => s,
355                        other => {
356                            self.stream = other;
357                            return Err(Error::Connection(ConnectionError {
358                                kind: ConnectionErrorKind::Ssl,
359                                message: "TLS upgrade requires a plain TCP stream".to_string(),
360                                source: None,
361                            }));
362                        }
363                    };
364
365                    let config = tls::build_client_config(self.config.ssl_mode)?;
366                    let server_name = tls::server_name(&self.config.host)?;
367                    let conn =
368                        rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
369                            .map_err(|e| {
370                                Error::Connection(ConnectionError {
371                                    kind: ConnectionErrorKind::Ssl,
372                                    message: format!("Failed to create TLS connection: {e}"),
373                                    source: None,
374                                })
375                            })?;
376
377                    let mut tls_stream = rustls::StreamOwned::new(conn, plain);
378                    while tls_stream.conn.is_handshaking() {
379                        tls_stream
380                            .conn
381                            .complete_io(&mut tls_stream.sock)
382                            .map_err(|e| {
383                                Error::Connection(ConnectionError {
384                                    kind: ConnectionErrorKind::Ssl,
385                                    message: format!("TLS handshake failed: {e}"),
386                                    source: Some(Box::new(e)),
387                                })
388                            })?;
389                    }
390
391                    self.stream = PgStream::Tls(tls_stream);
392                    Ok(())
393                }
394
395                #[cfg(not(feature = "tls"))]
396                {
397                    Err(Error::Connection(ConnectionError {
398                        kind: ConnectionErrorKind::Ssl,
399                        message:
400                            "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'"
401                                .to_string(),
402                        source: None,
403                    }))
404                }
405            }
406            b'N' => {
407                // Server doesn't support SSL
408                if self.config.ssl_mode.is_required() {
409                    return Err(Error::Connection(ConnectionError {
410                        kind: ConnectionErrorKind::Ssl,
411                        message: "Server does not support SSL".to_string(),
412                        source: None,
413                    }));
414                }
415                // Continue without SSL (prefer mode)
416                Ok(())
417            }
418            _ => Err(Error::Connection(ConnectionError {
419                kind: ConnectionErrorKind::Ssl,
420                message: format!("Unexpected SSL response: 0x{:02x}", buf[0]),
421                source: None,
422            })),
423        }
424    }
425
426    // ==================== Startup ====================
427
428    #[allow(clippy::result_large_err)]
429    fn send_startup(&mut self) -> Result<(), Error> {
430        let params = self.config.startup_params();
431        let msg = FrontendMessage::Startup {
432            version: PROTOCOL_VERSION,
433            params,
434        };
435        self.send_message(&msg)
436    }
437
438    // ==================== Authentication ====================
439
440    #[allow(clippy::result_large_err)]
441    fn require_auth_value(&self, message: &'static str) -> Result<&str, Error> {
442        // NOTE: Auth values are sourced from runtime config, not hardcoded.
443        self.config
444            .password
445            .as_deref()
446            .ok_or_else(|| auth_error(message))
447    }
448
449    #[allow(clippy::result_large_err)]
450    fn handle_auth(&mut self) -> Result<(), Error> {
451        loop {
452            let msg = self.receive_message()?;
453
454            match msg {
455                BackendMessage::AuthenticationOk => {
456                    return Ok(());
457                }
458                BackendMessage::AuthenticationCleartextPassword => {
459                    let auth_value =
460                        self.require_auth_value("Authentication value required but not provided")?;
461                    self.send_message(&FrontendMessage::PasswordMessage(auth_value.to_string()))?;
462                }
463                BackendMessage::AuthenticationMD5Password(salt) => {
464                    let auth_value =
465                        self.require_auth_value("Authentication value required but not provided")?;
466                    let hash = md5_password(&self.config.user, auth_value, salt);
467                    self.send_message(&FrontendMessage::PasswordMessage(hash))?;
468                }
469                BackendMessage::AuthenticationSASL(mechanisms) => {
470                    if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
471                        self.scram_auth()?;
472                    } else {
473                        return Err(auth_error(format!(
474                            "Unsupported SASL mechanisms: {:?}",
475                            mechanisms
476                        )));
477                    }
478                }
479                BackendMessage::ErrorResponse(e) => {
480                    self.state = ConnectionState::Error;
481                    return Err(error_from_fields(&e));
482                }
483                _ => {
484                    return Err(Error::Protocol(ProtocolError {
485                        message: format!("Unexpected message during auth: {:?}", msg),
486                        raw_data: None,
487                        source: None,
488                    }));
489                }
490            }
491        }
492    }
493
494    #[allow(clippy::result_large_err)]
495    fn scram_auth(&mut self) -> Result<(), Error> {
496        let auth_value =
497            self.require_auth_value("Authentication value required for SCRAM-SHA-256")?;
498
499        let mut client = ScramClient::new(&self.config.user, auth_value);
500
501        // Send client-first message
502        let client_first = client.client_first();
503        self.send_message(&FrontendMessage::SASLInitialResponse {
504            mechanism: "SCRAM-SHA-256".to_string(),
505            data: client_first,
506        })?;
507
508        // Receive server-first
509        let msg = self.receive_message()?;
510        let server_first_data = match msg {
511            BackendMessage::AuthenticationSASLContinue(data) => data,
512            BackendMessage::ErrorResponse(e) => {
513                self.state = ConnectionState::Error;
514                return Err(error_from_fields(&e));
515            }
516            _ => {
517                return Err(Error::Protocol(ProtocolError {
518                    message: format!("Expected SASL continue, got: {:?}", msg),
519                    raw_data: None,
520                    source: None,
521                }));
522            }
523        };
524
525        // Generate and send client-final
526        let client_final = client.process_server_first(&server_first_data)?;
527        self.send_message(&FrontendMessage::SASLResponse(client_final))?;
528
529        // Receive server-final
530        let msg = self.receive_message()?;
531        let server_final_data = match msg {
532            BackendMessage::AuthenticationSASLFinal(data) => data,
533            BackendMessage::ErrorResponse(e) => {
534                self.state = ConnectionState::Error;
535                return Err(error_from_fields(&e));
536            }
537            _ => {
538                return Err(Error::Protocol(ProtocolError {
539                    message: format!("Expected SASL final, got: {:?}", msg),
540                    raw_data: None,
541                    source: None,
542                }));
543            }
544        };
545
546        // Verify server signature
547        client.verify_server_final(&server_final_data)?;
548
549        // Wait for AuthenticationOk
550        let msg = self.receive_message()?;
551        match msg {
552            BackendMessage::AuthenticationOk => Ok(()),
553            BackendMessage::ErrorResponse(e) => {
554                self.state = ConnectionState::Error;
555                Err(error_from_fields(&e))
556            }
557            _ => Err(Error::Protocol(ProtocolError {
558                message: format!("Expected AuthenticationOk, got: {:?}", msg),
559                raw_data: None,
560                source: None,
561            })),
562        }
563    }
564
565    // ==================== Startup Messages ====================
566
567    #[allow(clippy::result_large_err)]
568    fn read_startup_messages(&mut self) -> Result<(), Error> {
569        loop {
570            let msg = self.receive_message()?;
571
572            match msg {
573                BackendMessage::BackendKeyData {
574                    process_id,
575                    secret_key,
576                } => {
577                    self.process_id = process_id;
578                    self.secret_key = secret_key;
579                }
580                BackendMessage::ParameterStatus { name, value } => {
581                    self.parameters.insert(name, value);
582                }
583                BackendMessage::ReadyForQuery(status) => {
584                    self.state = ConnectionState::Ready(status.into());
585                    return Ok(());
586                }
587                BackendMessage::ErrorResponse(e) => {
588                    self.state = ConnectionState::Error;
589                    return Err(error_from_fields(&e));
590                }
591                BackendMessage::NoticeResponse(_notice) => {
592                    // Log but continue - notices are informational
593                }
594                _ => {
595                    return Err(Error::Protocol(ProtocolError {
596                        message: format!("Unexpected startup message: {:?}", msg),
597                        raw_data: None,
598                        source: None,
599                    }));
600                }
601            }
602        }
603    }
604
605    // ==================== Low-Level I/O ====================
606
607    #[allow(clippy::result_large_err)]
608    fn send_message(&mut self, msg: &FrontendMessage) -> Result<(), Error> {
609        let data = self.writer.write(msg);
610        self.stream.write_all(data).map_err(|e| {
611            self.state = ConnectionState::Error;
612            Error::Io(e)
613        })?;
614        self.stream.flush().map_err(|e| {
615            self.state = ConnectionState::Error;
616            Error::Io(e)
617        })?;
618        Ok(())
619    }
620
621    #[allow(clippy::result_large_err)]
622    fn receive_message(&mut self) -> Result<BackendMessage, Error> {
623        // Try to parse any complete messages from buffer first
624        loop {
625            match self.reader.next_message() {
626                Ok(Some(msg)) => return Ok(msg),
627                Ok(None) => {
628                    // Need more data
629                    let n = self.stream.read(&mut self.read_buf).map_err(|e| {
630                        if e.kind() == std::io::ErrorKind::TimedOut
631                            || e.kind() == std::io::ErrorKind::WouldBlock
632                        {
633                            Error::Timeout
634                        } else {
635                            self.state = ConnectionState::Error;
636                            Error::Connection(ConnectionError {
637                                kind: ConnectionErrorKind::Disconnected,
638                                message: format!("Failed to read from server: {}", e),
639                                source: Some(Box::new(e)),
640                            })
641                        }
642                    })?;
643
644                    if n == 0 {
645                        self.state = ConnectionState::Disconnected;
646                        return Err(Error::Connection(ConnectionError {
647                            kind: ConnectionErrorKind::Disconnected,
648                            message: "Connection closed by server".to_string(),
649                            source: None,
650                        }));
651                    }
652
653                    // Feed data to reader
654                    self.reader.feed(&self.read_buf[..n]).map_err(|e| {
655                        Error::Protocol(ProtocolError {
656                            message: format!("Protocol error: {}", e),
657                            raw_data: None,
658                            source: None,
659                        })
660                    })?;
661                }
662                Err(e) => {
663                    self.state = ConnectionState::Error;
664                    return Err(Error::Protocol(ProtocolError {
665                        message: format!("Protocol error: {}", e),
666                        raw_data: None,
667                        source: None,
668                    }));
669                }
670            }
671        }
672    }
673}
674
675impl Drop for PgConnection {
676    fn drop(&mut self) {
677        // Try to close gracefully, ignore errors
678        let _ = self.close();
679    }
680}
681
682// ==================== Console Support ====================
683
#[cfg(feature = "console")]
impl ConsoleAware for PgConnection {
    /// Attach (`Some`) or detach (`None`) the console used for rich output.
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }

    /// Borrow the attached console, if any.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        self.console.as_ref()
    }

    /// Whether a console is currently attached.
    fn has_console(&self) -> bool {
        matches!(self.console, Some(_))
    }
}
698
/// Milestones of connection establishment, used for console progress output.
#[cfg(feature = "console")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStage {
    /// Resolving the server host name.
    DnsResolve,
    /// Opening the TCP connection.
    TcpConnect,
    /// Negotiating SSL/TLS.
    SslNegotiate,
    /// SSL/TLS session established.
    SslEstablished,
    /// Sending the startup message.
    Startup,
    /// Authentication in progress.
    Authenticating,
    /// Authentication succeeded.
    Authenticated,
    /// Ready for queries.
    Ready,
}

#[cfg(feature = "console")]
impl ConnectionStage {
    /// Short human-readable label for this stage (e.g. `"Connecting (TCP)"`).
    #[must_use]
    pub fn description(&self) -> &'static str {
        let label = match *self {
            ConnectionStage::DnsResolve => "Resolving DNS",
            ConnectionStage::TcpConnect => "Connecting (TCP)",
            ConnectionStage::SslNegotiate => "Negotiating SSL",
            ConnectionStage::SslEstablished => "SSL established",
            ConnectionStage::Startup => "Sending startup",
            ConnectionStage::Authenticating => "Authenticating",
            ConnectionStage::Authenticated => "Authenticated",
            ConnectionStage::Ready => "Ready",
        };
        label
    }
}
738
#[cfg(feature = "console")]
impl PgConnection {
    /// Emit a connection progress message to the console.
    ///
    /// The line reads `"[OK] <stage>"` when `success` is true and
    /// `"[..] <stage>"` otherwise. This is a no-op if no console is attached.
    pub fn emit_progress(&self, stage: ConnectionStage, success: bool) {
        if let Some(console) = &self.console {
            // Both markers are four characters wide so stage labels line up.
            let status = if success { "[OK]" } else { "[..]" };
            let message = format!("{} {}", status, stage.description());
            console.info(&message);
        }
    }

    /// Emit a connection success message with server info.
    ///
    /// Uses the same text as [`PgConnection::emit_connected_plain`]. This is a
    /// no-op if no console is attached.
    pub fn emit_connected(&self) {
        if let Some(console) = &self.console {
            console.success(&self.emit_connected_plain());
        }
    }

    /// Build a plain-text connection summary (for agent mode).
    ///
    /// Despite the name, this only returns the text and writes nothing; the
    /// server version comes from the `server_version` parameter, or `"unknown"`
    /// if the server did not report one.
    pub fn emit_connected_plain(&self) -> String {
        let server_version = self
            .parameters
            .get("server_version")
            .map_or("unknown", |s| s.as_str());
        format!(
            "Connected to PostgreSQL {} at {}:{}",
            server_version, self.config.host, self.config.port
        )
    }
}
779
780// ==================== Helper Functions ====================
781
782/// Compute MD5 password hash as per PostgreSQL protocol.
783fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
784    use std::fmt::Write;
785
786    // md5(md5(password + user) + salt)
787    let inner = format!("{}{}", password, user);
788    let inner_hash = md5::compute(inner.as_bytes());
789
790    let mut outer_input = format!("{:x}", inner_hash).into_bytes();
791    outer_input.extend_from_slice(&salt);
792    let outer_hash = md5::compute(&outer_input);
793
794    let mut result = String::with_capacity(35);
795    result.push_str("md5");
796    write!(&mut result, "{:x}", outer_hash).unwrap();
797    result
798}
799
800fn auth_error(msg: impl Into<String>) -> Error {
801    Error::Connection(ConnectionError {
802        kind: ConnectionErrorKind::Authentication,
803        message: msg.into(),
804        source: None,
805    })
806}
807
808fn error_from_fields(fields: &ErrorFields) -> Error {
809    // Determine error kind from SQLSTATE
810    let kind = match fields.code.get(..2) {
811        Some("08") => {
812            // Connection exception
813            return Error::Connection(ConnectionError {
814                kind: ConnectionErrorKind::Connect,
815                message: fields.message.clone(),
816                source: None,
817            });
818        }
819        Some("28") => {
820            // Invalid authorization specification
821            return Error::Connection(ConnectionError {
822                kind: ConnectionErrorKind::Authentication,
823                message: fields.message.clone(),
824                source: None,
825            });
826        }
827        Some("42") => QueryErrorKind::Syntax, // Syntax error or access rule violation
828        Some("23") => QueryErrorKind::Constraint, // Integrity constraint violation
829        Some("40") => {
830            if fields.code == "40001" {
831                QueryErrorKind::Serialization
832            } else {
833                QueryErrorKind::Deadlock
834            }
835        }
836        Some("57") => {
837            if fields.code == "57014" {
838                QueryErrorKind::Cancelled
839            } else {
840                QueryErrorKind::Timeout
841            }
842        }
843        _ => QueryErrorKind::Database,
844    };
845
846    Error::Query(QueryError {
847        kind,
848        sql: None,
849        sqlstate: Some(fields.code.clone()),
850        message: fields.message.clone(),
851        detail: fields.detail.clone(),
852        hint: fields.hint.clone(),
853        position: fields.position.map(|p| p as usize),
854        source: None,
855    })
856}
857
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `ErrorFields` carrying only a severity, SQLSTATE, and message.
    fn fields(severity: &str, code: &str, message: &str) -> ErrorFields {
        ErrorFields {
            severity: severity.to_string(),
            code: code.to_string(),
            message: message.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_md5_password() {
        // Shape and determinism checks (not a published test vector):
        // "md5" followed by 32 lowercase hex digits.
        let hash = md5_password("postgres", "mysecretpassword", *b"abcd");
        assert!(hash.starts_with("md5"));
        assert_eq!(hash.len(), 35); // "md5" + 32 hex chars
        assert!(hash[3..]
            .bytes()
            .all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)));
        assert_eq!(hash, md5_password("postgres", "mysecretpassword", *b"abcd"));
        // The salt must influence the result.
        assert_ne!(hash, md5_password("postgres", "mysecretpassword", *b"abce"));
    }

    #[test]
    fn test_transaction_status_conversion() {
        assert_eq!(
            TransactionStatusState::from(TransactionStatus::Idle),
            TransactionStatusState::Idle
        );
        assert_eq!(
            TransactionStatusState::from(TransactionStatus::Transaction),
            TransactionStatusState::InTransaction
        );
        assert_eq!(
            TransactionStatusState::from(TransactionStatus::Error),
            TransactionStatusState::InFailed
        );
    }

    #[test]
    fn test_error_classification() {
        let err = error_from_fields(&fields("ERROR", "23505", "unique violation"));
        assert!(matches!(err, Error::Query(q) if q.kind == QueryErrorKind::Constraint));

        let err = error_from_fields(&fields(
            "FATAL",
            "28P01",
            "password authentication failed",
        ));
        assert!(matches!(
            err,
            Error::Connection(c) if c.kind == ConnectionErrorKind::Authentication
        ));
    }

    #[test]
    fn test_error_classification_query_kinds() {
        let cases = [
            ("42601", QueryErrorKind::Syntax),
            ("40001", QueryErrorKind::Serialization),
            ("57014", QueryErrorKind::Cancelled),
            ("XX000", QueryErrorKind::Database),
            // Empty or short codes must not panic and fall back to Database.
            ("", QueryErrorKind::Database),
        ];
        for (code, expected) in cases {
            let err = error_from_fields(&fields("ERROR", code, "boom"));
            assert!(
                matches!(
                    &err,
                    Error::Query(q) if q.kind == expected && q.sqlstate.as_deref() == Some(code)
                ),
                "unexpected classification for SQLSTATE {code:?}"
            );
        }
    }

    #[test]
    fn test_error_classification_connection_class() {
        let err = error_from_fields(&fields("FATAL", "08006", "connection failure"));
        assert!(matches!(
            err,
            Error::Connection(c) if c.kind == ConnectionErrorKind::Connect
        ));
    }
}