use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::TcpStream;
#[cfg(feature = "console")]
use std::sync::Arc;
use sqlmodel_core::Error;
use sqlmodel_core::error::{
ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
};
#[cfg(feature = "console")]
use sqlmodel_console::{ConsoleAware, SqlModelConsole};
use crate::auth::ScramClient;
use crate::config::PgConfig;
#[cfg(not(feature = "tls"))]
use crate::config::SslMode;
use crate::protocol::{
BackendMessage, ErrorFields, FrontendMessage, MessageReader, MessageWriter, PROTOCOL_VERSION,
TransactionStatus,
};
#[cfg(feature = "tls")]
use crate::tls;
enum PgStream {
Plain(TcpStream),
#[cfg(feature = "tls")]
Tls(rustls::StreamOwned<rustls::ClientConnection, TcpStream>),
#[cfg(feature = "tls")]
Closed,
}
impl PgStream {
#[cfg(feature = "tls")]
fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
match self {
PgStream::Plain(s) => s.read_exact(buf),
#[cfg(feature = "tls")]
PgStream::Tls(s) => s.read_exact(buf),
#[cfg(feature = "tls")]
PgStream::Closed => Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"connection closed",
)),
}
}
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
PgStream::Plain(s) => s.read(buf),
#[cfg(feature = "tls")]
PgStream::Tls(s) => s.read(buf),
#[cfg(feature = "tls")]
PgStream::Closed => Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"connection closed",
)),
}
}
fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
match self {
PgStream::Plain(s) => s.write_all(buf),
#[cfg(feature = "tls")]
PgStream::Tls(s) => s.write_all(buf),
#[cfg(feature = "tls")]
PgStream::Closed => Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"connection closed",
)),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
PgStream::Plain(s) => s.flush(),
#[cfg(feature = "tls")]
PgStream::Tls(s) => s.flush(),
#[cfg(feature = "tls")]
PgStream::Closed => Err(std::io::Error::new(
std::io::ErrorKind::NotConnected,
"connection closed",
)),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
Disconnected,
Connecting,
Authenticating,
Ready(TransactionStatusState),
InQuery,
InTransaction(TransactionStatusState),
Error,
Closed,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransactionStatusState {
#[default]
Idle,
InTransaction,
InFailed,
}
impl From<TransactionStatus> for TransactionStatusState {
fn from(status: TransactionStatus) -> Self {
match status {
TransactionStatus::Idle => TransactionStatusState::Idle,
TransactionStatus::Transaction => TransactionStatusState::InTransaction,
TransactionStatus::Error => TransactionStatusState::InFailed,
}
}
}
pub struct PgConnection {
stream: PgStream,
state: ConnectionState,
process_id: i32,
secret_key: i32,
parameters: HashMap<String, String>,
config: PgConfig,
reader: MessageReader,
writer: MessageWriter,
read_buf: Vec<u8>,
#[cfg(feature = "console")]
console: Option<Arc<SqlModelConsole>>,
}
impl std::fmt::Debug for PgConnection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PgConnection")
.field("state", &self.state)
.field("process_id", &self.process_id)
.field("host", &self.config.host)
.field("port", &self.config.port)
.field("database", &self.config.database)
.finish_non_exhaustive()
}
}
impl PgConnection {
#[allow(clippy::result_large_err)]
pub fn connect(config: PgConfig) -> Result<Self, Error> {
let stream = TcpStream::connect_timeout(
&config.socket_addr().parse().map_err(|e| {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Connect,
message: format!("Invalid socket address: {}", e),
source: None,
})
})?,
config.connect_timeout,
)
.map_err(|e| {
let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
ConnectionErrorKind::Refused
} else {
ConnectionErrorKind::Connect
};
Error::Connection(ConnectionError {
kind,
message: format!("Failed to connect to {}: {}", config.socket_addr(), e),
source: Some(Box::new(e)),
})
})?;
stream.set_nodelay(true).ok();
stream.set_read_timeout(Some(config.connect_timeout)).ok();
stream.set_write_timeout(Some(config.connect_timeout)).ok();
let mut conn = Self {
stream: PgStream::Plain(stream),
state: ConnectionState::Connecting,
process_id: 0,
secret_key: 0,
parameters: HashMap::new(),
config,
reader: MessageReader::new(),
writer: MessageWriter::new(),
read_buf: vec![0u8; 8192],
#[cfg(feature = "console")]
console: None,
};
if conn.config.ssl_mode.should_try_ssl() {
#[cfg(feature = "tls")]
conn.negotiate_ssl()?;
#[cfg(not(feature = "tls"))]
if conn.config.ssl_mode != SslMode::Prefer {
return Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message:
"TLS requested but 'sqlmodel-postgres' was built without feature 'tls'"
.to_string(),
source: None,
}));
}
}
conn.send_startup()?;
conn.state = ConnectionState::Authenticating;
conn.handle_auth()?;
conn.read_startup_messages()?;
Ok(conn)
}
pub fn state(&self) -> ConnectionState {
self.state
}
pub fn is_ready(&self) -> bool {
matches!(self.state, ConnectionState::Ready(_))
}
pub fn process_id(&self) -> i32 {
self.process_id
}
pub fn secret_key(&self) -> i32 {
self.secret_key
}
pub fn parameter(&self, name: &str) -> Option<&str> {
self.parameters.get(name).map(|s| s.as_str())
}
pub fn parameters(&self) -> &HashMap<String, String> {
&self.parameters
}
#[allow(clippy::result_large_err)]
pub fn close(&mut self) -> Result<(), Error> {
if matches!(
self.state,
ConnectionState::Closed | ConnectionState::Disconnected
) {
return Ok(());
}
self.send_message(&FrontendMessage::Terminate)?;
self.state = ConnectionState::Closed;
Ok(())
}
#[allow(clippy::result_large_err)]
#[cfg(feature = "tls")]
fn negotiate_ssl(&mut self) -> Result<(), Error> {
self.send_message(&FrontendMessage::SSLRequest)?;
let mut buf = [0u8; 1];
self.stream.read_exact(&mut buf).map_err(|e| {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message: format!("Failed to read SSL response: {}", e),
source: Some(Box::new(e)),
})
})?;
match buf[0] {
b'S' => {
#[cfg(feature = "tls")]
{
let plain = match std::mem::replace(&mut self.stream, PgStream::Closed) {
PgStream::Plain(s) => s,
other => {
self.stream = other;
return Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message: "TLS upgrade requires a plain TCP stream".to_string(),
source: None,
}));
}
};
let config = tls::build_client_config(self.config.ssl_mode)?;
let server_name = tls::server_name(&self.config.host)?;
let conn =
rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
.map_err(|e| {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message: format!("Failed to create TLS connection: {e}"),
source: None,
})
})?;
let mut tls_stream = rustls::StreamOwned::new(conn, plain);
while tls_stream.conn.is_handshaking() {
tls_stream
.conn
.complete_io(&mut tls_stream.sock)
.map_err(|e| {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message: format!("TLS handshake failed: {e}"),
source: Some(Box::new(e)),
})
})?;
}
self.stream = PgStream::Tls(tls_stream);
Ok(())
}
#[cfg(not(feature = "tls"))]
{
Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message:
"TLS requested but 'sqlmodel-postgres' was built without feature 'tls'"
.to_string(),
source: None,
}))
}
}
b'N' => {
if self.config.ssl_mode.is_required() {
return Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message: "Server does not support SSL".to_string(),
source: None,
}));
}
Ok(())
}
_ => Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Ssl,
message: format!("Unexpected SSL response: 0x{:02x}", buf[0]),
source: None,
})),
}
}
#[allow(clippy::result_large_err)]
fn send_startup(&mut self) -> Result<(), Error> {
let params = self.config.startup_params();
let msg = FrontendMessage::Startup {
version: PROTOCOL_VERSION,
params,
};
self.send_message(&msg)
}
#[allow(clippy::result_large_err)]
fn require_auth_value(&self, message: &'static str) -> Result<&str, Error> {
self.config
.password
.as_deref()
.ok_or_else(|| auth_error(message))
}
#[allow(clippy::result_large_err)]
fn handle_auth(&mut self) -> Result<(), Error> {
loop {
let msg = self.receive_message()?;
match msg {
BackendMessage::AuthenticationOk => {
return Ok(());
}
BackendMessage::AuthenticationCleartextPassword => {
let auth_value =
self.require_auth_value("Authentication value required but not provided")?;
self.send_message(&FrontendMessage::PasswordMessage(auth_value.to_string()))?;
}
BackendMessage::AuthenticationMD5Password(salt) => {
let auth_value =
self.require_auth_value("Authentication value required but not provided")?;
let hash = md5_password(&self.config.user, auth_value, salt);
self.send_message(&FrontendMessage::PasswordMessage(hash))?;
}
BackendMessage::AuthenticationSASL(mechanisms) => {
if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
self.scram_auth()?;
} else {
return Err(auth_error(format!(
"Unsupported SASL mechanisms: {:?}",
mechanisms
)));
}
}
BackendMessage::ErrorResponse(e) => {
self.state = ConnectionState::Error;
return Err(error_from_fields(&e));
}
_ => {
return Err(Error::Protocol(ProtocolError {
message: format!("Unexpected message during auth: {:?}", msg),
raw_data: None,
source: None,
}));
}
}
}
}
#[allow(clippy::result_large_err)]
fn scram_auth(&mut self) -> Result<(), Error> {
let auth_value =
self.require_auth_value("Authentication value required for SCRAM-SHA-256")?;
let mut client = ScramClient::new(&self.config.user, auth_value);
let client_first = client.client_first();
self.send_message(&FrontendMessage::SASLInitialResponse {
mechanism: "SCRAM-SHA-256".to_string(),
data: client_first,
})?;
let msg = self.receive_message()?;
let server_first_data = match msg {
BackendMessage::AuthenticationSASLContinue(data) => data,
BackendMessage::ErrorResponse(e) => {
self.state = ConnectionState::Error;
return Err(error_from_fields(&e));
}
_ => {
return Err(Error::Protocol(ProtocolError {
message: format!("Expected SASL continue, got: {:?}", msg),
raw_data: None,
source: None,
}));
}
};
let client_final = client.process_server_first(&server_first_data)?;
self.send_message(&FrontendMessage::SASLResponse(client_final))?;
let msg = self.receive_message()?;
let server_final_data = match msg {
BackendMessage::AuthenticationSASLFinal(data) => data,
BackendMessage::ErrorResponse(e) => {
self.state = ConnectionState::Error;
return Err(error_from_fields(&e));
}
_ => {
return Err(Error::Protocol(ProtocolError {
message: format!("Expected SASL final, got: {:?}", msg),
raw_data: None,
source: None,
}));
}
};
client.verify_server_final(&server_final_data)?;
let msg = self.receive_message()?;
match msg {
BackendMessage::AuthenticationOk => Ok(()),
BackendMessage::ErrorResponse(e) => {
self.state = ConnectionState::Error;
Err(error_from_fields(&e))
}
_ => Err(Error::Protocol(ProtocolError {
message: format!("Expected AuthenticationOk, got: {:?}", msg),
raw_data: None,
source: None,
})),
}
}
#[allow(clippy::result_large_err)]
fn read_startup_messages(&mut self) -> Result<(), Error> {
loop {
let msg = self.receive_message()?;
match msg {
BackendMessage::BackendKeyData {
process_id,
secret_key,
} => {
self.process_id = process_id;
self.secret_key = secret_key;
}
BackendMessage::ParameterStatus { name, value } => {
self.parameters.insert(name, value);
}
BackendMessage::ReadyForQuery(status) => {
self.state = ConnectionState::Ready(status.into());
return Ok(());
}
BackendMessage::ErrorResponse(e) => {
self.state = ConnectionState::Error;
return Err(error_from_fields(&e));
}
BackendMessage::NoticeResponse(_notice) => {
}
_ => {
return Err(Error::Protocol(ProtocolError {
message: format!("Unexpected startup message: {:?}", msg),
raw_data: None,
source: None,
}));
}
}
}
}
#[allow(clippy::result_large_err)]
fn send_message(&mut self, msg: &FrontendMessage) -> Result<(), Error> {
let data = self.writer.write(msg);
self.stream.write_all(data).map_err(|e| {
self.state = ConnectionState::Error;
Error::Io(e)
})?;
self.stream.flush().map_err(|e| {
self.state = ConnectionState::Error;
Error::Io(e)
})?;
Ok(())
}
#[allow(clippy::result_large_err)]
fn receive_message(&mut self) -> Result<BackendMessage, Error> {
loop {
match self.reader.next_message() {
Ok(Some(msg)) => return Ok(msg),
Ok(None) => {
let n = self.stream.read(&mut self.read_buf).map_err(|e| {
if e.kind() == std::io::ErrorKind::TimedOut
|| e.kind() == std::io::ErrorKind::WouldBlock
{
Error::Timeout
} else {
self.state = ConnectionState::Error;
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: format!("Failed to read from server: {}", e),
source: Some(Box::new(e)),
})
}
})?;
if n == 0 {
self.state = ConnectionState::Disconnected;
return Err(Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Disconnected,
message: "Connection closed by server".to_string(),
source: None,
}));
}
self.reader.feed(&self.read_buf[..n]).map_err(|e| {
Error::Protocol(ProtocolError {
message: format!("Protocol error: {}", e),
raw_data: None,
source: None,
})
})?;
}
Err(e) => {
self.state = ConnectionState::Error;
return Err(Error::Protocol(ProtocolError {
message: format!("Protocol error: {}", e),
raw_data: None,
source: None,
}));
}
}
}
}
}
impl Drop for PgConnection {
fn drop(&mut self) {
let _ = self.close();
}
}
#[cfg(feature = "console")]
impl ConsoleAware for PgConnection {
fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
self.console = console;
}
fn console(&self) -> Option<&Arc<SqlModelConsole>> {
self.console.as_ref()
}
fn has_console(&self) -> bool {
self.console.is_some()
}
}
#[cfg(feature = "console")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStage {
DnsResolve,
TcpConnect,
SslNegotiate,
SslEstablished,
Startup,
Authenticating,
Authenticated,
Ready,
}
#[cfg(feature = "console")]
impl ConnectionStage {
#[must_use]
pub fn description(&self) -> &'static str {
match self {
Self::DnsResolve => "Resolving DNS",
Self::TcpConnect => "Connecting (TCP)",
Self::SslNegotiate => "Negotiating SSL",
Self::SslEstablished => "SSL established",
Self::Startup => "Sending startup",
Self::Authenticating => "Authenticating",
Self::Authenticated => "Authenticated",
Self::Ready => "Ready",
}
}
}
#[cfg(feature = "console")]
impl PgConnection {
pub fn emit_progress(&self, stage: ConnectionStage, success: bool) {
if let Some(console) = &self.console {
let status = if success { "[OK]" } else { "[..] " };
let message = format!("{} {}", status, stage.description());
console.info(&message);
}
}
pub fn emit_connected(&self) {
if let Some(console) = &self.console {
let server_version = self
.parameters
.get("server_version")
.map_or("unknown", |s| s.as_str());
let message = format!(
"Connected to PostgreSQL {} at {}:{}",
server_version, self.config.host, self.config.port
);
console.success(&message);
}
}
pub fn emit_connected_plain(&self) -> String {
let server_version = self
.parameters
.get("server_version")
.map_or("unknown", |s| s.as_str());
format!(
"Connected to PostgreSQL {} at {}:{}",
server_version, self.config.host, self.config.port
)
}
}
fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
use std::fmt::Write;
let inner = format!("{}{}", password, user);
let inner_hash = md5::compute(inner.as_bytes());
let mut outer_input = format!("{:x}", inner_hash).into_bytes();
outer_input.extend_from_slice(&salt);
let outer_hash = md5::compute(&outer_input);
let mut result = String::with_capacity(35);
result.push_str("md5");
write!(&mut result, "{:x}", outer_hash).unwrap();
result
}
fn auth_error(msg: impl Into<String>) -> Error {
Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Authentication,
message: msg.into(),
source: None,
})
}
fn error_from_fields(fields: &ErrorFields) -> Error {
let kind = match fields.code.get(..2) {
Some("08") => {
return Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Connect,
message: fields.message.clone(),
source: None,
});
}
Some("28") => {
return Error::Connection(ConnectionError {
kind: ConnectionErrorKind::Authentication,
message: fields.message.clone(),
source: None,
});
}
Some("42") => QueryErrorKind::Syntax, Some("23") => QueryErrorKind::Constraint, Some("40") => {
if fields.code == "40001" {
QueryErrorKind::Serialization
} else {
QueryErrorKind::Deadlock
}
}
Some("57") => {
if fields.code == "57014" {
QueryErrorKind::Cancelled
} else {
QueryErrorKind::Timeout
}
}
_ => QueryErrorKind::Database,
};
Error::Query(QueryError {
kind,
sql: None,
sqlstate: Some(fields.code.clone()),
message: fields.message.clone(),
detail: fields.detail.clone(),
hint: fields.hint.clone(),
position: fields.position.map(|p| p as usize),
source: None,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_md5_password() {
let hash = md5_password("postgres", "mysecretpassword", *b"abcd");
assert!(hash.starts_with("md5"));
assert_eq!(hash.len(), 35); }
#[test]
fn test_transaction_status_conversion() {
assert_eq!(
TransactionStatusState::from(TransactionStatus::Idle),
TransactionStatusState::Idle
);
assert_eq!(
TransactionStatusState::from(TransactionStatus::Transaction),
TransactionStatusState::InTransaction
);
assert_eq!(
TransactionStatusState::from(TransactionStatus::Error),
TransactionStatusState::InFailed
);
}
#[test]
fn test_error_classification() {
let fields = ErrorFields {
severity: "ERROR".to_string(),
code: "23505".to_string(),
message: "unique violation".to_string(),
..Default::default()
};
let err = error_from_fields(&fields);
assert!(matches!(err, Error::Query(q) if q.kind == QueryErrorKind::Constraint));
let fields = ErrorFields {
severity: "FATAL".to_string(),
code: "28P01".to_string(),
message: "password authentication failed".to_string(),
..Default::default()
};
let err = error_from_fields(&fields);
assert!(matches!(
err,
Error::Connection(c) if c.kind == ConnectionErrorKind::Authentication
));
}
}