use super::{invalid_state, ConnectionError, ConnectionSslRequirement, Credentials};
use crate::md5::md5_password;
use crate::postgres::SslError;
use crate::scram::{
generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError, Sha256Out,
};
use crate::AuthType;
use base64::Engine;
use gel_pg_protocol::{errors::PgServerError, prelude::*, protocol::*};
use tracing::{error, trace, warn};
/// SCRAM client environment backed by the connection's [`Credentials`].
///
/// The stored password is the input used when deriving the salted password
/// for a SCRAM exchange.
#[derive(Debug)]
struct ClientEnvironmentImpl {
    /// Credentials supplying the password for salted-password derivation.
    credentials: Credentials,
}
impl ClientEnvironment for ClientEnvironmentImpl {
    /// Returns a fresh nonce: 32 random bytes encoded with standard base64.
    fn generate_nonce(&self) -> String {
        use base64::engine::general_purpose::STANDARD;

        let random_bytes = rand::random::<[u8; 32]>();
        STANDARD.encode(random_bytes)
    }

    /// Derives the salted password from the stored password using the given
    /// `salt` and `iterations` count.
    fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
        let password = self.credentials.password.as_bytes();
        generate_salted_password(password, salt, iterations)
    }
}
/// Internal states of the client handshake state machine.
#[derive(Debug)]
enum ConnectionStateImpl {
    /// SSL was requested; the SSL request has not been sent yet.
    SslInitializing(Credentials, ConnectionSslRequirement),
    /// The SSL request was sent; waiting for the server's SSL response.
    SslWaiting(Credentials, ConnectionSslRequirement),
    /// The server accepted SSL and an upgrade was requested; waiting for
    /// `ConnectionDrive::SslReady`.
    SslConnecting(Credentials),
    /// SSL is disabled; the startup message has not been sent yet.
    Initializing(Credentials),
    /// The startup message was sent; waiting for an authentication request
    /// or `AuthenticationOk`. The flag records whether an authentication
    /// request has already been seen on this connection.
    Connecting(Credentials, bool),
    /// A SCRAM-SHA-256 exchange is in progress.
    Scram(ClientTransaction, ClientEnvironmentImpl),
    /// Authentication succeeded; consuming server parameters and key data
    /// until `ReadyForQuery` arrives.
    Connected,
    /// The handshake completed successfully.
    Ready,
    /// The server sent an `ErrorResponse`; the handshake has failed.
    Error,
}
/// Coarse handshake phase reported through
/// [`ConnectionStateUpdate::state_changed`].
#[derive(Debug, Clone, Copy)]
pub enum ConnectionStateType {
    /// Reported after the first message (SSL request or startup message) is
    /// sent in response to `ConnectionDrive::Initial`.
    Connecting,
    /// Reported after the server accepted SSL and an upgrade was requested.
    SslConnecting,
    /// Reported after the initial SCRAM response has been sent.
    Authenticating,
    /// Reported after `AuthenticationOk` was received.
    Synchronizing,
    /// Reported after `ReadyForQuery` was received.
    Ready,
}
/// An input event fed into [`ConnectionState::drive`].
#[derive(Debug)]
pub enum ConnectionDrive<'a> {
    /// Starts the handshake; triggers sending the first message.
    Initial,
    /// A message received from the server, or the error from parsing it.
    Message(Result<Message<'a>, ParseError>),
    /// The server's reply to the SSL request.
    SslResponse(SSLResponse<'a>),
    /// Consumed in the SSL-connecting state, after an upgrade was requested.
    // NOTE(review): presumably signals that the TLS handshake completed —
    // confirm against the caller.
    SslReady,
}
/// Outbound side of the handshake: the state machine calls these methods to
/// emit messages and to request a transport upgrade.
pub trait ConnectionStateSend {
    /// Sends an initial message. The state machine uses this for the SSL
    /// request and for the startup message.
    fn send_initial<'a, M>(
        &mut self,
        message: impl IntoInitialBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;

    /// Sends a frontend message. The state machine uses this for password
    /// and SASL responses.
    fn send<'a, M>(
        &mut self,
        message: impl IntoFrontendBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;

    /// Called when the server answers the SSL request with `S`.
    // NOTE(review): presumably begins the TLS upgrade of the underlying
    // transport — confirm against implementors.
    fn upgrade(&mut self) -> Result<(), std::io::Error>;
}
/// Callbacks through which the state machine reports handshake progress.
///
/// Every method has a default body, so implementors override only the events
/// they care about.
// The no-op default bodies deliberately ignore their parameters; this
// attribute silences the resulting unused-variable warnings.
#[allow(unused)]
pub trait ConnectionStateUpdate: ConnectionStateSend {
    /// Called for each `ParameterStatus` received after authentication.
    fn parameter(&mut self, name: &str, value: &str) {}

    /// Called with the pid and key carried by `BackendKeyData`.
    fn cancellation_key(&mut self, pid: i32, key: i32) {}

    /// Called whenever the handshake moves to a new coarse phase.
    fn state_changed(&mut self, state: ConnectionStateType) {}

    /// Called when the server sends an `ErrorResponse`. Logs at error level
    /// by default.
    fn server_error(&mut self, error: &PgServerError) {
        error!("Server error during handshake: {:?}", error);
    }

    /// Called when the server sends a `NoticeResponse`. Logs at warn level
    /// by default.
    fn server_notice(&mut self, notice: &PgServerError) {
        warn!("Server notice during handshake: {:?}", notice);
    }

    /// Called with the authentication method that is being used.
    fn auth(&mut self, auth: AuthType) {}
}
/// Client-side handshake state machine.
///
/// Opaque wrapper around the internal state enum, so callers cannot observe
/// or construct individual states directly.
#[derive(Debug)]
pub struct ConnectionState(ConnectionStateImpl);
impl ConnectionState {
    /// Creates a new handshake state machine.
    ///
    /// When `ssl_mode` is `Disable` the machine starts directly in the
    /// plain-text startup path; otherwise it starts with SSL negotiation.
    pub fn new(credentials: Credentials, ssl_mode: ConnectionSslRequirement) -> Self {
        if ssl_mode == ConnectionSslRequirement::Disable {
            Self(ConnectionStateImpl::Initializing(credentials))
        } else {
            Self(ConnectionStateImpl::SslInitializing(credentials, ssl_mode))
        }
    }

    /// Returns `true` once the handshake has completed successfully.
    pub fn is_ready(&self) -> bool {
        matches!(self.0, ConnectionStateImpl::Ready)
    }

    /// Returns `true` if the server reported an error during the handshake.
    pub fn is_error(&self) -> bool {
        matches!(self.0, ConnectionStateImpl::Error)
    }

    /// Returns `true` when the handshake has finished, successfully or not.
    pub fn is_done(&self) -> bool {
        self.is_ready() || self.is_error()
    }

    /// Returns `true` while the machine is waiting for the server's reply to
    /// the SSL request, i.e. the next input should be an SSL response.
    pub fn read_ssl_response(&self) -> bool {
        matches!(self.0, ConnectionStateImpl::SslWaiting(..))
    }

    /// Advances the state machine with a single input event.
    ///
    /// Outgoing messages and progress notifications are delivered through
    /// `update`.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the `(state, drive)` combination is not valid, including any drive
    ///   received after the machine reached `Ready` or `Error`;
    /// - the server refuses SSL while the client requires it, or answers the
    ///   SSL request with an unexpected code;
    /// - the server requests SASL without offering `SCRAM-SHA-256`, or the
    ///   SCRAM exchange fails;
    /// - the server sends an `ErrorResponse` (the machine then moves to the
    ///   error state);
    /// - sending a message fails, or a received message could not be parsed.
    ///
    /// Only an `ErrorResponse` moves the machine to the error state; every
    /// other error return leaves the state as it was at the point of failure.
    pub fn drive(
        &mut self,
        drive: ConnectionDrive,
        update: &mut impl ConnectionStateUpdate,
    ) -> Result<(), ConnectionError> {
        use ConnectionStateImpl::*;
        trace!("Received drive {drive:?} in state {:?}", self.0);
        match (&mut self.0, drive) {
            // SSL path, step 1: ask the server whether it supports SSL.
            (SslInitializing(credentials, mode), ConnectionDrive::Initial) => {
                update.send_initial(&SSLRequestBuilder::default())?;
                // `take` moves the credentials into the next state; the
                // default value left behind is overwritten by the assignment.
                self.0 = SslWaiting(std::mem::take(credentials), *mode);
                update.state_changed(ConnectionStateType::Connecting);
            }
            // SSL path, step 2: interpret the server's single-byte answer.
            (SslWaiting(credentials, mode), ConnectionDrive::SslResponse(response)) => {
                // `new` never enters the SSL path with `Disable`, so this is
                // an internal consistency check.
                if *mode == ConnectionSslRequirement::Disable {
                    return Err(invalid_state!("SSL mode is Disable in SslWaiting state"));
                }
                if response.code() == b'S' {
                    // Server accepts SSL: request the upgrade and wait for
                    // `SslReady` before sending the startup message.
                    update.upgrade()?;
                    self.0 = SslConnecting(std::mem::take(credentials));
                    update.state_changed(ConnectionStateType::SslConnecting);
                } else if response.code() == b'N' {
                    // Server refuses SSL: fail if the client requires it,
                    // otherwise continue without SSL.
                    if *mode == ConnectionSslRequirement::Required {
                        return Err(ConnectionError::SslError(SslError::SslRequiredByClient));
                    }
                    Self::send_startup_message(credentials, update)?;
                    self.0 = Connecting(std::mem::take(credentials), false);
                } else {
                    return Err(ConnectionError::UnexpectedResponse(format!(
                        "Unexpected SSL response from server: {:?}",
                        response.code() as char
                    )));
                }
            }
            // SSL path, step 3: the upgrade finished; start the session.
            (SslConnecting(credentials), ConnectionDrive::SslReady) => {
                Self::send_startup_message(credentials, update)?;
                self.0 = Connecting(std::mem::take(credentials), false);
            }
            // Non-SSL path: send the startup message right away.
            (Initializing(credentials), ConnectionDrive::Initial) => {
                Self::send_startup_message(credentials, update)?;
                self.0 = Connecting(std::mem::take(credentials), false);
                update.state_changed(ConnectionStateType::Connecting);
            }
            // Waiting for the server to request (or skip) authentication.
            (Connecting(credentials, sent_auth), ConnectionDrive::Message(message)) => {
                match_message!(message, Backend {
                    (AuthenticationOk) => {
                        // No authentication request preceded this message,
                        // so the server let the client in without one.
                        if !*sent_auth {
                            update.auth(AuthType::Trust);
                        }
                        trace!("auth ok");
                        self.0 = Connected;
                        update.state_changed(ConnectionStateType::Synchronizing);
                    },
                    (AuthenticationSASL as sasl) => {
                        *sent_auth = true;
                        // SCRAM-SHA-256 is the only SASL mechanism supported.
                        let mut found_scram_sha256 = false;
                        for mech in sasl.mechanisms() {
                            trace!("auth sasl: {:?}", mech);
                            if mech == "SCRAM-SHA-256" {
                                found_scram_sha256 = true;
                                break;
                            }
                        }
                        if !found_scram_sha256 {
                            return Err(ConnectionError::UnexpectedResponse(
                                "Server requested SASL authentication but does not support SCRAM-SHA-256".into(),
                            ));
                        }
                        let credentials = credentials.clone();
                        let mut tx = ClientTransaction::new("".into());
                        let env = ClientEnvironmentImpl { credentials };
                        // Processing an empty input yields the client's first
                        // SCRAM message.
                        let Some(initial_message) = tx.process_message(&[], &env)? else {
                            return Err(SCRAMError::ProtocolError.into());
                        };
                        update.auth(AuthType::ScramSha256);
                        update.send(&SASLInitialResponseBuilder {
                            mechanism: "SCRAM-SHA-256",
                            response: initial_message.as_slice(),
                        })?;
                        self.0 = Scram(tx, env);
                        update.state_changed(ConnectionStateType::Authenticating);
                    },
                    (AuthenticationMD5Password as md5) => {
                        *sent_auth = true;
                        trace!("auth md5");
                        let md5_hash = md5_password(&credentials.password, &credentials.username, md5.salt());
                        update.auth(AuthType::Md5);
                        update.send(&PasswordMessageBuilder {
                            password: &md5_hash,
                        })?;
                    },
                    (AuthenticationCleartextPassword) => {
                        *sent_auth = true;
                        trace!("auth cleartext");
                        update.auth(AuthType::Plain);
                        update.send(&PasswordMessageBuilder {
                            password: &credentials.password,
                        })?;
                    },
                    (NoticeResponse as notice) => {
                        let err = PgServerError::from(notice);
                        update.server_notice(&err);
                    },
                    (ErrorResponse as error) => {
                        self.0 = Error;
                        let err = PgServerError::from(error);
                        update.server_error(&err);
                        return Err(err.into());
                    },
                    message => {
                        log_unknown_message(message, "Connecting")?
                    },
                });
            }
            // SCRAM-SHA-256 exchange in progress.
            (Scram(tx, env), ConnectionDrive::Message(message)) => {
                match_message!(message, Backend {
                    (AuthenticationSASLContinue as sasl) => {
                        // The server challenge must produce a client response.
                        let Some(message) = tx.process_message(&sasl.data(), env)? else {
                            return Err(SCRAMError::ProtocolError.into());
                        };
                        update.send(&SASLResponseBuilder {
                            response: &message,
                        })?;
                    },
                    (AuthenticationSASLFinal as sasl) => {
                        // The final server message must not produce any
                        // further client output.
                        let None = tx.process_message(&sasl.data(), env)? else {
                            return Err(SCRAMError::ProtocolError.into());
                        };
                    },
                    (AuthenticationOk) => {
                        // NOTE(review): this is accepted without checking
                        // that AuthenticationSASLFinal was processed first —
                        // confirm whether the server-final step should be
                        // required before proceeding.
                        trace!("auth ok");
                        self.0 = Connected;
                        update.state_changed(ConnectionStateType::Synchronizing);
                    },
                    (AuthenticationMessage as auth) => {
                        trace!("SCRAM Unknown auth message: {}", auth.status())
                    },
                    (NoticeResponse as notice) => {
                        let err = PgServerError::from(notice);
                        update.server_notice(&err);
                    },
                    (ErrorResponse as error) => {
                        self.0 = Error;
                        let err = PgServerError::from(error);
                        update.server_error(&err);
                        return Err(err.into());
                    },
                    message => {
                        log_unknown_message(message, "SCRAM")?
                    },
                });
            }
            // Authenticated: collect session information until the server
            // reports that it is ready for queries.
            (Connected, ConnectionDrive::Message(message)) => {
                match_message!(message, Backend {
                    (ParameterStatus as param) => {
                        trace!("param: {:?}={:?}", param.name(), param.value());
                        update.parameter(param.name().try_into()?, param.value().try_into()?);
                    },
                    (BackendKeyData as key_data) => {
                        trace!("key={:?} pid={:?}", key_data.key(), key_data.pid());
                        update.cancellation_key(key_data.pid(), key_data.key());
                    },
                    (ReadyForQuery as ready) => {
                        trace!("ready: {:?}", ready.status() as char);
                        trace!("-> Ready");
                        self.0 = Ready;
                        update.state_changed(ConnectionStateType::Ready);
                    },
                    (NoticeResponse as notice) => {
                        let err = PgServerError::from(notice);
                        update.server_notice(&err);
                    },
                    (ErrorResponse as error) => {
                        self.0 = Error;
                        let err = PgServerError::from(error);
                        update.server_error(&err);
                        return Err(err.into());
                    },
                    message => {
                        log_unknown_message(message, "Connected")?
                    },
                });
            }
            // Terminal states accept no further input.
            (Ready, _) | (Error, _) => {
                return Err(invalid_state!("Unexpected drive for Ready or Error state"))
            }
            _ => return Err(invalid_state!("Unexpected (state, drive) combination")),
        }
        Ok(())
    }

    /// Sends the startup message built from `credentials`.
    ///
    /// The message carries the `user` and `database` parameters, followed by
    /// every entry of `credentials.server_settings`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported by `update.send_initial`.
    fn send_startup_message(
        credentials: &Credentials,
        update: &mut impl ConnectionStateUpdate,
    ) -> Result<(), std::io::Error> {
        let mut params = vec![
            StartupNameValueBuilder {
                name: "user",
                value: &credentials.username,
            },
            StartupNameValueBuilder {
                name: "database",
                value: &credentials.database,
            },
        ];
        for (name, value) in &credentials.server_settings {
            params.push(StartupNameValueBuilder { name, value })
        }
        // The builder takes a closure that yields the parameter list.
        update.send_initial(&StartupMessageBuilder { params: || &params })
    }
}
/// Logs a message that the handshake did not expect while in `state`.
///
/// A successfully parsed message is logged as a warning and tolerated. A
/// message that failed to parse is logged as an error and the parse error is
/// returned to the caller.
fn log_unknown_message(
    message: Result<Message, ParseError>,
    state: &str,
) -> Result<(), ParseError> {
    // Handle the corrupted-message case first so the remainder deals only
    // with a well-formed (but unexpected) message.
    let message = match message {
        Ok(parsed) => parsed,
        Err(parse_error) => {
            error!("Corrupted message received in {} state", state);
            return Err(parse_error);
        }
    };
    warn!(
        "Unexpected message {:?} (length {}) received in {} state",
        message.mtype(),
        message.mlen(),
        state
    );
    Ok(())
}