use super::wire;
use crate::error::ReplicationError;
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncWrite};
pub async fn authenticate<S: AsyncRead + AsyncWrite + Unpin>(
stream: &mut S,
buf: &mut BytesMut,
user: &str,
password: Option<&str>,
tls_server_end_point: Option<Vec<u8>>,
) -> Result<(), ReplicationError> {
loop {
let msg = wire::read_message(stream, buf).await?;
if msg.is_empty() {
return Err(ReplicationError::protocol(
"Empty message during authentication".to_string(),
));
}
match msg[0] {
b'R' => {
if msg.len() < 9 {
return Err(ReplicationError::protocol(
"Authentication message too short".to_string(),
));
}
let auth_type = i32::from_be_bytes(msg[5..9].try_into().unwrap());
match auth_type {
0 => {
tracing::debug!("Authentication successful");
return Ok(());
}
3 => {
if tls_server_end_point.is_none() {
tracing::warn!(
"Server requested cleartext password over unencrypted connection"
);
}
let pw = password.ok_or_else(|| {
ReplicationError::authentication(
"Server requires password but none provided".to_string(),
)
})?;
let pw_msg = wire::build_password_message(pw);
wire::write_all(stream, &pw_msg).await?;
wire::flush(stream).await?;
}
5 => {
let pw = password.ok_or_else(|| {
ReplicationError::authentication(
"Server requires password but none provided".to_string(),
)
})?;
if msg.len() < 13 {
return Err(ReplicationError::protocol(
"MD5 auth message too short (missing salt)".to_string(),
));
}
let salt = &msg[9..13];
let hashed = md5_password(user, pw, salt);
let pw_msg = wire::build_password_message(&hashed);
wire::write_all(stream, &pw_msg).await?;
wire::flush(stream).await?;
}
10 => {
handle_scram_sha256(stream, buf, &msg, password, &tls_server_end_point)
.await?;
return Ok(());
}
_ => {
return Err(ReplicationError::authentication(format!(
"Unsupported authentication type: {auth_type}"
)));
}
}
}
b'E' => {
let fields = super::error::parse_error_fields(&msg[5..]);
return Err(ReplicationError::authentication(format!(
"Authentication failed: {}",
fields
)));
}
_ => {
tracing::warn!(
"Unexpected message type '{}' during authentication",
msg[0] as char
);
}
}
}
}
/// Builds the password response for an MD5 (`auth type 5`) request.
///
/// The result is `"md5"` followed by the hex MD5 of the hex MD5 of
/// `password || user`, concatenated with the 4-byte `salt` sent by the server.
fn md5_password(user: &str, password: &str, salt: &[u8]) -> String {
    // Stage 1: unsalted digest of the credentials.
    let credentials_hex = {
        let mut hasher = super::md5::Context::new();
        hasher.consume(password.as_bytes());
        hasher.consume(user.as_bytes());
        let digest = hasher.finalize();
        super::md5::hex_digest(&digest)
    };

    // Stage 2: bind the stage-1 hex digest to this connection's salt.
    let salted_hex = {
        let mut hasher = super::md5::Context::new();
        hasher.consume(credentials_hex.as_bytes());
        hasher.consume(salt);
        let digest = hasher.finalize();
        super::md5::hex_digest(&digest)
    };

    format!("md5{salted_hex}")
}
async fn handle_scram_sha256<S: AsyncRead + AsyncWrite + Unpin>(
stream: &mut S,
buf: &mut BytesMut,
initial_msg: &[u8],
password: Option<&str>,
tls_server_end_point: &Option<Vec<u8>>,
) -> Result<(), ReplicationError> {
use postgres_protocol::authentication::sasl;
let pw = password.ok_or_else(|| {
ReplicationError::authentication(
"Server requires password for SCRAM-SHA-256 but none provided".to_string(),
)
})?;
let mechanisms_payload = &initial_msg[9..];
let mut has_scram_sha256 = false;
let mut has_scram_sha256_plus = false;
let mut pos = 0;
while pos < mechanisms_payload.len() {
if mechanisms_payload[pos] == 0 {
break;
}
let end = mechanisms_payload[pos..]
.iter()
.position(|&b| b == 0)
.unwrap_or(mechanisms_payload.len() - pos);
let mechanism = std::str::from_utf8(&mechanisms_payload[pos..pos + end]).unwrap_or("");
match mechanism {
"SCRAM-SHA-256" => has_scram_sha256 = true,
"SCRAM-SHA-256-PLUS" => has_scram_sha256_plus = true,
_ => {}
}
pos += end + 1;
}
if !has_scram_sha256 && !has_scram_sha256_plus {
return Err(ReplicationError::authentication(
"Server does not support SCRAM-SHA-256".to_string(),
));
}
let (channel_binding, use_plus) = match tls_server_end_point {
Some(ref data) if has_scram_sha256_plus => (
sasl::ChannelBinding::tls_server_end_point(data.clone()),
true,
),
Some(_) => (sasl::ChannelBinding::unrequested(), false),
None => (sasl::ChannelBinding::unsupported(), false),
};
let mechanism_name = if use_plus {
sasl::SCRAM_SHA_256_PLUS
} else {
sasl::SCRAM_SHA_256
};
let mut scram = sasl::ScramSha256::new(pw.as_bytes(), channel_binding);
let mut sasl_init = BytesMut::new();
postgres_protocol::message::frontend::sasl_initial_response(
mechanism_name,
scram.message(),
&mut sasl_init,
)
.map_err(|e| {
ReplicationError::authentication(format!("Failed to build SASL initial response: {e}"))
})?;
wire::write_all(stream, &sasl_init).await?;
wire::flush(stream).await?;
let msg2 = wire::read_message(stream, buf).await?;
if msg2[0] == b'E' {
let fields = super::error::parse_error_fields(&msg2[5..]);
return Err(ReplicationError::authentication(format!(
"SCRAM authentication failed: {}",
fields
)));
}
if msg2[0] != b'R' || msg2.len() < 9 {
return Err(ReplicationError::protocol(
"Expected AuthenticationSASLContinue".to_string(),
));
}
let auth_type2 = i32::from_be_bytes(msg2[5..9].try_into().unwrap());
if auth_type2 != 11 {
return Err(ReplicationError::protocol(format!(
"Expected SASL continue (11), got {auth_type2}"
)));
}
let server_first = &msg2[9..];
scram.update(server_first).map_err(|e| {
ReplicationError::authentication(format!("SCRAM server-first processing failed: {e}"))
})?;
let mut sasl_resp = BytesMut::new();
postgres_protocol::message::frontend::sasl_response(scram.message(), &mut sasl_resp).map_err(
|e| ReplicationError::authentication(format!("Failed to build SASL response: {e}")),
)?;
wire::write_all(stream, &sasl_resp).await?;
wire::flush(stream).await?;
let msg3 = wire::read_message(stream, buf).await?;
if msg3[0] == b'E' {
let fields = super::error::parse_error_fields(&msg3[5..]);
return Err(ReplicationError::authentication(format!(
"SCRAM authentication failed: {}",
fields
)));
}
if msg3[0] != b'R' || msg3.len() < 9 {
return Err(ReplicationError::protocol(
"Expected AuthenticationSASLFinal".to_string(),
));
}
let auth_type3 = i32::from_be_bytes(msg3[5..9].try_into().unwrap());
if auth_type3 != 12 {
return Err(ReplicationError::protocol(format!(
"Expected SASL final (12), got {auth_type3}"
)));
}
let server_final = &msg3[9..];
scram.finish(server_final).map_err(|e| {
ReplicationError::authentication(format!("SCRAM server signature verification failed: {e}"))
})?;
let msg4 = wire::read_message(stream, buf).await?;
if msg4[0] == b'R' && msg4.len() >= 9 {
let auth_type4 = i32::from_be_bytes(msg4[5..9].try_into().unwrap());
if auth_type4 != 0 {
return Err(ReplicationError::protocol(format!(
"Expected AuthenticationOk (0) after SCRAM, got {auth_type4}"
)));
}
} else if msg4[0] == b'E' {
let fields = super::error::parse_error_fields(&msg4[5..]);
return Err(ReplicationError::authentication(format!(
"Authentication failed after SCRAM: {}",
fields
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // NOTE: a panic inside a `tokio::spawn`ed task does not fail the test by
    // itself; it only surfaces through the task's `JoinHandle`. Every mock
    // server task below is therefore awaited before its test returns.

    #[test]
    fn test_md5_password() {
        let result = md5_password("md5", "md5", &[0x01, 0x02, 0x03, 0x04]);
        assert!(result.starts_with("md5"));
        // "md5" prefix + 32 hex digits.
        assert_eq!(result.len(), 35);
    }

    #[test]
    fn test_md5_password_known() {
        let result = md5_password("test", "test", &[0, 0, 0, 0]);
        assert!(result.starts_with("md5"));
        assert_eq!(
            result,
            md5_password("test", "test", &[0, 0, 0, 0]),
            "hash must be deterministic"
        );
        assert_ne!(
            result,
            md5_password("test", "test", &[0, 0, 0, 1]),
            "salt must affect the hash"
        );
    }

    fn build_auth_ok() -> Vec<u8> {
        let mut msg = vec![b'R'];
        msg.extend_from_slice(&8i32.to_be_bytes());
        msg.extend_from_slice(&0i32.to_be_bytes());
        msg
    }

    fn build_auth_cleartext() -> Vec<u8> {
        let mut msg = vec![b'R'];
        msg.extend_from_slice(&8i32.to_be_bytes());
        msg.extend_from_slice(&3i32.to_be_bytes());
        msg
    }

    fn build_auth_md5(salt: &[u8; 4]) -> Vec<u8> {
        let mut msg = vec![b'R'];
        msg.extend_from_slice(&12i32.to_be_bytes());
        msg.extend_from_slice(&5i32.to_be_bytes());
        msg.extend_from_slice(salt);
        msg
    }

    fn build_error_response(severity: &str, code: &str, message: &str) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.push(b'S');
        payload.extend_from_slice(severity.as_bytes());
        payload.push(0);
        payload.push(b'C');
        payload.extend_from_slice(code.as_bytes());
        payload.push(0);
        payload.push(b'M');
        payload.extend_from_slice(message.as_bytes());
        payload.push(0);
        payload.push(0);
        let mut msg = vec![b'E'];
        let len = (4 + payload.len()) as i32;
        msg.extend_from_slice(&len.to_be_bytes());
        msg.extend_from_slice(&payload);
        msg
    }

    fn build_auth_sasl(mechanisms: &[&str]) -> Vec<u8> {
        let mut payload = Vec::new();
        for mech in mechanisms {
            payload.extend_from_slice(mech.as_bytes());
            payload.push(0);
        }
        // An empty name terminates the mechanism list.
        payload.push(0);
        // The length field counts itself (4), the auth code (4) and the payload.
        let body_len = 4 + 4 + payload.len();
        let mut msg = vec![b'R'];
        msg.extend_from_slice(&(body_len as i32).to_be_bytes());
        msg.extend_from_slice(&10i32.to_be_bytes());
        msg.extend_from_slice(&payload);
        msg
    }

    /// Reads the client's SASLInitialResponse (`'p'`) and returns the
    /// mechanism name it selected.
    async fn read_sasl_mechanism(server: &mut tokio::io::DuplexStream) -> String {
        let mut response = vec![0u8; 4096];
        let n = server.read(&mut response).await.unwrap();
        let response = &response[..n];
        assert_eq!(response[0], b'p', "expected SASLInitialResponse");
        let mech_start = 5;
        let mech_end =
            response[mech_start..].iter().position(|&b| b == 0).unwrap() + mech_start;
        std::str::from_utf8(&response[mech_start..mech_end])
            .unwrap()
            .to_string()
    }

    /// Offers `offered` mechanisms, runs `authenticate` with the given TLS
    /// end-point and returns the mechanism the client picked. The mock server
    /// aborts with an ErrorResponse right after the initial response.
    async fn negotiated_mechanism(
        offered: &[&str],
        tls_server_end_point: Option<Vec<u8>>,
    ) -> String {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let offer = build_auth_sasl(offered);
        let server_task = tokio::spawn(async move {
            server.write_all(&offer).await.unwrap();
            server.flush().await.unwrap();
            let mechanism = read_sasl_mechanism(&mut server).await;
            let err = build_error_response("FATAL", "28000", "test abort");
            server.write_all(&err).await.unwrap();
            server.flush().await.unwrap();
            mechanism
        });
        let mut buf = BytesMut::new();
        let result = authenticate(
            &mut client,
            &mut buf,
            "user",
            Some("pass"),
            tls_server_end_point,
        )
        .await;
        assert!(result.is_err(), "mock server aborts, so authentication must fail");
        server_task.await.expect("mock server task panicked")
    }

    #[tokio::test]
    async fn test_auth_ok_immediately() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            server.write_all(&build_auth_ok()).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = authenticate(&mut client, &mut buf, "user", Some("pass"), None).await;
        assert!(result.is_ok());
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_auth_cleartext_then_ok() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            server.write_all(&build_auth_cleartext()).await.unwrap();
            server.flush().await.unwrap();
            let mut discard = vec![0u8; 1024];
            let _ = server.read(&mut discard).await;
            server.write_all(&build_auth_ok()).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = authenticate(&mut client, &mut buf, "user", Some("secret"), None).await;
        assert!(result.is_ok());
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_auth_md5_then_ok() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let salt = [0x01, 0x02, 0x03, 0x04];
        let server_task = tokio::spawn(async move {
            server.write_all(&build_auth_md5(&salt)).await.unwrap();
            server.flush().await.unwrap();
            let mut discard = vec![0u8; 1024];
            let _ = server.read(&mut discard).await;
            server.write_all(&build_auth_ok()).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result =
            authenticate(&mut client, &mut buf, "testuser", Some("testpass"), None).await;
        assert!(result.is_ok());
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_auth_cleartext_no_password() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            server.write_all(&build_auth_cleartext()).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = authenticate(&mut client, &mut buf, "user", None, None).await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("password"),
            "Expected password error, got: {err}"
        );
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_auth_error_response() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            let err_msg = build_error_response("FATAL", "28000", "password authentication failed");
            server.write_all(&err_msg).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = authenticate(&mut client, &mut buf, "user", Some("wrong"), None).await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("authentication")
                || err.to_string().contains("Authentication"),
            "Expected authentication error, got: {err}"
        );
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_auth_unsupported_type() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            let mut msg = vec![b'R'];
            msg.extend_from_slice(&8i32.to_be_bytes());
            msg.extend_from_slice(&99i32.to_be_bytes());
            server.write_all(&msg).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = authenticate(&mut client, &mut buf, "user", Some("pass"), None).await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("Unsupported") || err.to_string().contains("unsupported"),
            "Expected unsupported auth type error, got: {err}"
        );
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_scram_no_password_returns_error() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            let msg = build_auth_sasl(&["SCRAM-SHA-256"]);
            server.write_all(&msg).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = authenticate(&mut client, &mut buf, "user", None, None).await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("password"),
            "Expected password error, got: {err}"
        );
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_scram_no_supported_mechanism_returns_error() {
        let (mut client, mut server) = tokio::io::duplex(8192);
        let server_task = tokio::spawn(async move {
            let msg = build_auth_sasl(&["SCRAM-SHA-512"]);
            server.write_all(&msg).await.unwrap();
            server.flush().await.unwrap();
        });
        let mut buf = BytesMut::new();
        let result = authenticate(&mut client, &mut buf, "user", Some("pass"), None).await;
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("SCRAM-SHA-256"),
            "Expected SCRAM-SHA-256 not supported error, got: {err}"
        );
        server_task.await.expect("mock server task panicked");
    }

    #[tokio::test]
    async fn test_scram_selects_plain_without_tls_endpoint() {
        let mechanism =
            negotiated_mechanism(&["SCRAM-SHA-256-PLUS", "SCRAM-SHA-256"], None).await;
        assert_eq!(
            mechanism, "SCRAM-SHA-256",
            "Without TLS endpoint, should select SCRAM-SHA-256, not PLUS"
        );
    }

    #[tokio::test]
    async fn test_scram_selects_plus_with_tls_endpoint() {
        let fake_tls_hash = vec![0xAA; 32];
        let mechanism =
            negotiated_mechanism(&["SCRAM-SHA-256-PLUS", "SCRAM-SHA-256"], Some(fake_tls_hash))
                .await;
        assert_eq!(
            mechanism, "SCRAM-SHA-256-PLUS",
            "With TLS endpoint and server offering PLUS, should select SCRAM-SHA-256-PLUS"
        );
    }

    #[tokio::test]
    async fn test_scram_falls_back_without_plus_offered() {
        let fake_tls_hash = vec![0xBB; 32];
        let mechanism = negotiated_mechanism(&["SCRAM-SHA-256"], Some(fake_tls_hash)).await;
        assert_eq!(
            mechanism, "SCRAM-SHA-256",
            "When server doesn't offer PLUS, should fall back to SCRAM-SHA-256"
        );
    }
}