use std::collections::HashMap;
use bytes::{BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::{
PgConnection, PgMessage,
auth::{ScramClient, cleartext_password, md5_password},
message::{
backend::{self},
frontend::{MessageCode, cstring_len, frame},
},
};
use self::auth_msg::{AuthMessage, read_auth_message};
mod auth_msg;
mod error;
pub use error::*;
/// The fixed 8-byte request written by `connect_with_tls` before the startup
/// message to ask the server whether the connection may be upgraded to TLS.
///
/// Per the PostgreSQL wire protocol this is a big-endian `Int32` length (8,
/// counting itself) followed by the `Int32` request code `80877103`
/// (`1234` in the high 16 bits, `5679` in the low 16 bits).
pub const SSL_REQUEST: &[u8] = &[
    // Message length, including these four bytes.
    0x00, 0x00, 0x00, 0x08,
    // SSL request code: 0x04D2 = 1234, 0x162F = 5679.
    0x04, 0xD2, 0x16, 0x2F,
];
/// How the client answers an authentication request from the server.
///
/// The `Debug` implementation is written by hand so that the password held
/// by [`AuthenticationMode::Password`] never ends up in logs or panic
/// messages.
#[derive(Clone, PartialEq, Eq)]
pub enum AuthenticationMode {
    /// No credentials are configured. If the server asks for a password
    /// (cleartext, MD5 or SASL) the startup handshake fails with
    /// `Error::PasswordRequired`.
    Trust,
    /// The password used to answer cleartext, MD5 and SASL requests.
    Password(String),
}

impl std::fmt::Debug for AuthenticationMode {
    /// Formats the variant name only; the password itself is redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Trust => f.write_str("Trust"),
            Self::Password(_) => f.write_str("Password(<redacted>)"),
        }
    }
}
/// Protocol version a freshly created builder requests: 3.0.
const CURRENT_VERSION: ProtocolVersion = ProtocolVersion::new(3, 0);

/// A protocol version packed into a single `u32`.
///
/// The major number lives in the upper 16 bits and the minor number in the
/// lower 16 bits. Ordering and equality compare the packed value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ProtocolVersion(u32);

impl ProtocolVersion {
    /// Number of low bits occupied by the minor component.
    const MINOR_BITS: u32 = 16;
    /// Mask that selects the minor component of the packed value.
    const MINOR_MASK: u32 = (1 << Self::MINOR_BITS) - 1;

    /// Packs `major` and `minor` into one version value.
    const fn new(major: u16, minor: u16) -> Self {
        let high = (major as u32) << Self::MINOR_BITS;
        let low = minor as u32;
        Self(high | low)
    }

    /// Returns the major component (upper 16 bits).
    fn major(&self) -> u16 {
        let Self(packed) = *self;
        // After the shift only 16 significant bits remain, so the cast is lossless.
        (packed >> Self::MINOR_BITS) as u16
    }

    /// Returns the minor component (lower 16 bits).
    fn minor(&self) -> u16 {
        let Self(packed) = *self;
        // The mask leaves at most 16 significant bits, so the cast is lossless.
        (packed & Self::MINOR_MASK) as u16
    }
}

/// Interprets a raw packed value as a protocol version without validation.
impl From<u32> for ProtocolVersion {
    fn from(value: u32) -> Self {
        ProtocolVersion(value)
    }
}

/// Extracts the raw packed value.
impl From<ProtocolVersion> for u32 {
    fn from(value: ProtocolVersion) -> Self {
        let ProtocolVersion(packed) = value;
        packed
    }
}

/// Compares the packed value against a raw `u32`.
impl PartialEq<u32> for ProtocolVersion {
    fn eq(&self, other: &u32) -> bool {
        let Self(packed) = self;
        packed == other
    }
}

/// Mirror of `ProtocolVersion == u32` so the raw value may be on the left.
impl PartialEq<ProtocolVersion> for u32 {
    fn eq(&self, other: &ProtocolVersion) -> bool {
        // Delegate to the impl above; equality is symmetric.
        *other == *self
    }
}

/// Renders the version as `major.minor`, e.g. `3.0`.
impl std::fmt::Display for ProtocolVersion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (major, minor) = (self.major(), self.minor());
        write!(f, "{major}.{minor}")
    }
}
/// Session information gathered while the connection is being established.
///
/// `connect` starts with both ids at `0` and an empty parameter map, then
/// fills them in from the messages the server sends before it reports that
/// it is ready for queries.
#[derive(Clone, Debug)]
pub struct StartupResponse {
    /// Process id taken from the backend key data message (`0` if none was received).
    pub process_id: u32,
    /// Secret key taken from the backend key data message (`0` if none was received).
    pub secret_key: u32,
    /// Name/value pairs from every parameter status message received during startup.
    pub parameters: HashMap<String, String>,
}
/// Collects the settings for a connection and then performs the startup and
/// authentication handshake on a caller-supplied stream.
pub struct ConnectionBuilder {
/// Credentials used when the server requests password-based authentication.
auth: AuthenticationMode,
/// Protocol version written at the start of the startup message.
protocol: ProtocolVersion,
/// Key/value pairs written into the startup message. `new` always inserts
/// `user`, `database` and `application_name`.
options: HashMap<String, String>,
}
impl ConnectionBuilder {
pub fn new(user: impl Into<String>) -> Self {
let user = user.into();
let mut options = HashMap::new();
options.insert("application_name".into(), "pg_stream".into());
options.insert("database".into(), user.clone());
options.insert("user".into(), user);
Self {
auth: AuthenticationMode::Trust,
protocol: CURRENT_VERSION,
options,
}
}
pub fn database(self, db: impl Into<String>) -> Self {
self.add_option("database", db.into())
}
pub fn user(self, user: impl Into<String>) -> Self {
self.add_option("user", user.into())
}
pub fn auth(mut self, auth: AuthenticationMode) -> Self {
self.auth = auth;
self
}
pub fn application_name(self, app: impl Into<String>) -> Self {
self.add_option("application_name", app.into())
}
pub fn protocol(mut self, protocol: impl Into<ProtocolVersion>) -> Self {
self.protocol = protocol.into();
self
}
pub fn add_option(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
self.options.insert(key.into(), val.into());
self
}
fn get_user(&self) -> &str {
self.options.get("user").expect("user should always be set")
}
fn as_startup_message(&self) -> Bytes {
let mut buf = BytesMut::new();
let payload_len = {
let mut len = 4 + 1; for (key, val) in &self.options {
len += cstring_len(key.as_bytes()) + cstring_len(val.as_bytes());
}
len
};
frame(&mut buf, payload_len, |buf| {
buf.put_u32(self.protocol.into());
for (key, val) in &self.options {
buf.put_slice(key.as_bytes());
buf.put_u8(0);
buf.put_slice(val.as_bytes());
buf.put_u8(0);
}
buf.put_u8(0);
});
buf.freeze()
}
pub async fn connect_with_tls<S, T, F, Fut>(
&self,
mut stream: S,
upgrade_fn: F,
) -> Result<(PgConnection<T>, StartupResponse)>
where
S: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
F: FnOnce(S) -> Fut,
Fut: Future<Output = std::io::Result<T>>,
{
stream.write_all(SSL_REQUEST).await?;
stream.flush().await?;
let mut buf = [0; 1];
stream.read_exact(&mut buf).await?;
let res = u8::from_be_bytes(buf);
const SSL_SUCCESS: u8 = b'S';
const SSL_FAILURE: u8 = b'N';
let stream = match res {
SSL_SUCCESS => upgrade_fn(stream).await?,
SSL_FAILURE => Err(Error::TlsUnsupported)?,
_ => Err(format!("unexpected SSL response code '{res}'"))?,
};
self.connect(stream).await
}
pub async fn connect<S>(&self, mut stream: S) -> Result<(PgConnection<S>, StartupResponse)>
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.startup(&mut stream).await?;
let mut startup_res = StartupResponse {
process_id: 0,
secret_key: 0,
parameters: HashMap::new(),
};
loop {
let msg = backend::read_message(&mut stream).await?;
match msg {
PgMessage::ParameterStatus(ps) => {
startup_res
.parameters
.insert(ps.name().into_owned(), ps.value().into_owned());
}
PgMessage::BackendKeyData(bkd) => {
startup_res.process_id = bkd.process_id();
startup_res.secret_key = bkd.secret_key();
}
PgMessage::ReadyForQuery(_) => break,
msg => Err(format!("unexpected message: {:?}", msg))?,
}
}
Ok((PgConnection::new(stream), startup_res))
}
async fn startup<S>(&self, stream: &mut S) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let startup_msg = self.as_startup_message();
stream.write_all(&startup_msg).await?;
stream.flush().await?;
match read_auth_message(stream).await? {
AuthMessage::Ok => Ok(()),
AuthMessage::CleartextPassword => {
let AuthenticationMode::Password(pw) = &self.auth else {
return Err(Error::PasswordRequired);
};
let msg = cleartext_password(pw);
stream.write_all(&msg).await?;
stream.flush().await?;
match read_auth_message(stream).await? {
AuthMessage::Ok => Ok(()),
code => Err(format!("unexpected authentication code {code}"))?,
}
}
AuthMessage::Md5Password(salt) => {
let AuthenticationMode::Password(pw) = &self.auth else {
return Err(Error::PasswordRequired);
};
let msg = md5_password(self.get_user(), pw, &salt);
stream.write_all(&msg).await?;
stream.flush().await?;
match read_auth_message(stream).await? {
AuthMessage::Ok => Ok(()),
code => Err(format!("unexpected authentication code {code}"))?,
}
}
AuthMessage::Sasl(mech) => {
let AuthenticationMode::Password(pw) = &self.auth else {
return Err(Error::PasswordRequired);
};
let mut scram = ScramClient::new(self.get_user(), pw);
let client_first = scram.client_first();
let mech = mech.to_string();
let mut msg = BytesMut::new();
let payload_len = cstring_len(mech.as_bytes()) + 4 + client_first.len();
msg.put_u8(MessageCode::SASL_RESPONSE.as_u8());
frame(&mut msg, payload_len, |buf| {
buf.put_slice(mech.as_bytes());
buf.put_u8(0);
buf.put_u32(client_first.len() as u32);
buf.put_slice(client_first.as_bytes());
});
stream.write_all(&msg).await?;
stream.flush().await?;
let res = read_auth_message(stream).await?;
let AuthMessage::SaslContinue(server_first) = res else {
return Err(format!("unexpected authentication response {res}"))?;
};
let client_final = scram
.client_final(&server_first)
.map_err(|e| format!("scram handshake failed: {e}"))?;
let mut msg = BytesMut::new();
msg.put_u8(MessageCode::SASL_RESPONSE.as_u8());
frame(&mut msg, client_final.len(), |buf| {
buf.put_slice(client_final.as_bytes());
});
stream.write_all(&msg).await?;
stream.flush().await?;
let res = read_auth_message(stream).await?;
let AuthMessage::SaslFinal(server_final) = res else {
return Err(format!("unexpected authentication response {res}"))?;
};
scram
.verify_server(&server_final)
.map_err(|e| format!("scram handshake failed: {e}"))?;
match read_auth_message(stream).await? {
AuthMessage::Ok => Ok(()),
code => Err(format!("unexpected authentication code {code}"))?,
}
}
code => unimplemented!("oops: {code}"),
}
}
}
#[cfg(test)]
mod tests {
    use crate::startup::ProtocolVersion;

    /// Version 3.0 must round-trip through `major`/`minor` and pack to
    /// `3 << 16` (196608).
    #[test]
    fn test_protocol_version() {
        let (major, minor) = (3, 0);
        let version = ProtocolVersion::new(major, minor);

        assert_eq!(major, version.major());
        assert_eq!(minor, version.minor());
        assert_eq!(196608, version.0);
    }
}