use super::{ConnectionError, ConnectionSslRequirement};
use crate::{
handshake::{ServerAuth, ServerAuthDrive, ServerAuthError, ServerAuthResponse},
AuthType, CredentialData,
};
use gel_pg_protocol::{
errors::{
PgError, PgErrorConnectionException, PgErrorFeatureNotSupported,
PgErrorInvalidAuthorizationSpecification, PgServerError, PgServerErrorField,
},
prelude::*,
protocol::*,
};
use std::str::Utf8Error;
use tracing::{error, trace, warn};
/// Coarse connection phase; this is the value handed to
/// `ConnectionStateUpdate::state_changed`.
#[derive(Debug)]
#[derive(Clone, Copy)]
pub enum ConnectionStateType {
    /// Initial connection phase.
    Connecting,
    /// SSL negotiation phase.
    SslConnecting,
    /// Authentication phase.
    Authenticating,
    /// Post-authentication parameter synchronization.
    Synchronizing,
    /// Handshake finished.
    Ready,
}
/// Inputs that advance a [`ServerState`].
///
/// `Initial` and `Message` carry already-parsed messages; [`ServerState::drive`]
/// produces them itself from `RawMessage` bytes.
#[derive(Debug)]
pub enum ConnectionDrive<'a> {
    /// Bytes received from the client; buffered and split into messages
    /// according to the current phase.
    RawMessage(&'a [u8]),
    /// A startup-phase message (or the error from parsing one).
    Initial(Result<InitialMessage<'a>, ParseError>),
    /// A regular protocol message (or the error from parsing one).
    Message(Result<Message<'a>, ParseError>),
    /// The pending SSL upgrade finished. It is only accepted while waiting in
    /// the SSL-connecting phase. Afterwards the machine returns to the startup
    /// phase and refuses any further `SSLRequest`.
    SslReady,
    /// Authentication method and credentials for the user named in the
    /// startup message.
    AuthInfo(AuthType, CredentialData),
    /// `(name, value)` to forward to the client as `ParameterStatus` while
    /// synchronizing.
    Parameter(String, String),
    /// `(pid, key)` sent as `BackendKeyData`, followed by `ReadyForQuery`.
    Ready(i32, i32),
    /// Aborts the handshake with the given error code. The text is currently
    /// not forwarded to the client; a generic message is sent instead.
    Fail(PgError, &'a str),
}
/// Actions the handshake state machine asks its transport to perform.
pub trait ConnectionStateSend {
    /// Sends the single-byte reply to an `SSLRequest`.
    fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error>;
    /// Sends a backend protocol message to the client.
    fn send<'a, M>(
        &mut self,
        message: impl IntoBackendBuilder<'a, M>,
    ) -> Result<(), std::io::Error>;
    /// Requests that the transport upgrade the connection (presumably to
    /// TLS; confirm against the transport implementation).
    fn upgrade(&mut self) -> Result<(), std::io::Error>;
    /// Called once a startup message has been accepted for `user` and
    /// `database`. The server state machine then waits for
    /// [`ConnectionDrive::AuthInfo`].
    fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error>;
    /// Called after authentication succeeds. The server state machine then
    /// accepts [`ConnectionDrive::Parameter`] and [`ConnectionDrive::Ready`].
    fn params(&mut self) -> Result<(), std::io::Error>;
}
/// Optional notification hooks layered on [`ConnectionStateSend`]. Every
/// method defaults to doing nothing.
// The default bodies are no-ops, so their parameters are intentionally unused.
#[allow(unused)]
pub trait ConnectionStateUpdate: ConnectionStateSend {
    /// Reports a connection parameter. The server state machine calls this
    /// for each startup parameter sent by the client.
    fn parameter(&mut self, name: &str, value: &str) {}
    /// Reports a change of the coarse connection phase.
    fn state_changed(&mut self, state: ConnectionStateType) {}
    /// Reports an error. The server state machine calls this just before
    /// sending the corresponding `ErrorResponse`.
    fn server_error(&mut self, error: &PgServerError) {}
}
/// One call on [`ConnectionStateSend`] / [`ConnectionStateUpdate`], reified
/// as a value so that a plain `FnMut(ConnectionEvent)` closure can implement
/// both traits.
#[derive(Debug)]
pub enum ConnectionEvent<'a> {
    /// `ConnectionStateSend::send_ssl`.
    SendSSL(SSLResponseBuilder),
    /// `ConnectionStateSend::send`, already converted into a builder.
    Send(BackendBuilder<'a>),
    /// `ConnectionStateSend::upgrade`.
    Upgrade,
    /// `ConnectionStateSend::auth` as `(user, database)`.
    Auth(String, String),
    /// `ConnectionStateSend::params`.
    Params,
    /// `ConnectionStateUpdate::parameter` as `(name, value)`.
    Parameter(&'a str, &'a str),
    /// `ConnectionStateUpdate::state_changed`.
    StateChanged(ConnectionStateType),
    /// `ConnectionStateUpdate::server_error`.
    ServerError(&'a PgServerError),
}
/// Any `FnMut(ConnectionEvent) -> io::Result<()>` closure can act as the
/// output side of the state machine. Each call is packed into the matching
/// [`ConnectionEvent`] and handed to the closure, whose result is returned
/// unchanged.
impl<F> ConnectionStateSend for F
where
    F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
{
    fn send_ssl(&mut self, message: SSLResponseBuilder) -> Result<(), std::io::Error> {
        let event = ConnectionEvent::SendSSL(message);
        (self)(event)
    }

    fn send<'a, M>(
        &mut self,
        message: impl IntoBackendBuilder<'a, M>,
    ) -> Result<(), std::io::Error> {
        let builder = message.into_builder();
        (self)(ConnectionEvent::Send(builder))
    }

    fn upgrade(&mut self) -> Result<(), std::io::Error> {
        (self)(ConnectionEvent::Upgrade)
    }

    fn auth(&mut self, user: String, database: String) -> Result<(), std::io::Error> {
        let event = ConnectionEvent::Auth(user, database);
        (self)(event)
    }

    fn params(&mut self) -> Result<(), std::io::Error> {
        (self)(ConnectionEvent::Params)
    }
}
impl<F> ConnectionStateUpdate for F
where
F: FnMut(ConnectionEvent) -> Result<(), std::io::Error>,
{
fn parameter(&mut self, name: &str, value: &str) {
let _ = self(ConnectionEvent::Parameter(name, value));
}
fn state_changed(&mut self, state: ConnectionStateType) {
let _ = self(ConnectionEvent::StateChanged(state));
}
fn server_error(&mut self, error: &PgServerError) {
let _ = self(ConnectionEvent::ServerError(error));
}
}
/// Internal phases of the server-side handshake.
// `Authenticating` carries the whole `ServerAuth` value inline and is
// presumably much larger than the other variants; it is intentionally not
// boxed, hence the lint allowance.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum ServerStateImpl {
    /// Waiting for a startup message or an `SSLRequest`. Holds the SSL
    /// requirement, or `None` after `ConnectionDrive::SslReady` (a further
    /// `SSLRequest` is then a protocol error).
    Initial(Option<ConnectionSslRequirement>),
    /// Waiting for `ConnectionDrive::SslReady`.
    SslConnecting,
    /// Startup accepted for this user name; waiting for
    /// `ConnectionDrive::AuthInfo`.
    AuthInfo(String),
    /// Exchanging authentication messages with the client.
    Authenticating(ServerAuth),
    /// Authenticated; forwarding parameters until `ConnectionDrive::Ready`.
    Synchronizing,
    /// `ReadyForQuery` has been sent.
    Ready,
    /// A protocol error ended the handshake.
    Error,
}
/// Server side of the startup / authentication handshake.
///
/// Holds the current phase plus the two reassembly buffers that turn
/// [`ConnectionDrive::RawMessage`] bytes into protocol messages.
#[derive(derive_more::Debug)]
pub struct ServerState {
    state: ServerStateImpl,
    /// Reassembles startup-phase messages (startup message / `SSLRequest`).
    #[debug(skip)]
    initial_buffer: StructBuffer<InitialMessage<'static>>,
    /// Reassembles regular messages while authenticating.
    #[debug(skip)]
    buffer: StructBuffer<Message<'static>>,
}
/// Reports `code` to the client as an `ErrorResponse` carrying `message`.
///
/// The error is first surfaced through
/// [`ConnectionStateUpdate::server_error`]. It is then serialized with
/// severity `ERROR`, the textual error code, and the message.
///
/// # Errors
///
/// Returns any I/O error from sending the response.
fn send_error(
    update: &mut impl ConnectionStateUpdate,
    code: PgError,
    message: &str,
) -> std::io::Result<()> {
    update.server_error(&PgServerError::new(code, message, Default::default()));

    // Error codes are plain ASCII, so UTF-8 decoding cannot fail here.
    let code_bytes = code.to_code();
    let code_text = std::str::from_utf8(&code_bytes).unwrap();

    update.send(&ErrorResponseBuilder {
        fields: &[
            &ErrorFieldBuilder {
                etype: PgServerErrorField::Severity as u8,
                value: "ERROR",
            },
            &ErrorFieldBuilder {
                etype: PgServerErrorField::SeverityNonLocalized as u8,
                value: "ERROR",
            },
            &ErrorFieldBuilder {
                etype: PgServerErrorField::Code as u8,
                value: code_text,
            },
            &ErrorFieldBuilder {
                etype: PgServerErrorField::Message as u8,
                value: message,
            },
        ],
    })
}
/// Failure inside the server state machine.
///
/// Only `Protocol` errors are turned into an `ErrorResponse` for the client
/// and move the machine into the error state. I/O and UTF-8 errors are
/// returned straight to the caller of `ServerState::drive`.
#[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)]
enum ServerError {
    /// A call on the transport (`ConnectionStateSend`) failed.
    IO(#[from] std::io::Error),
    /// The client violated the protocol or failed authentication.
    Protocol(#[from] PgError),
    /// Text in a client message (e.g. a startup parameter) was not valid UTF-8.
    Utf8Error(#[from] Utf8Error),
}
/// Maps authentication failures onto the protocol error codes reported to
/// the client.
impl From<ServerAuthError> for ServerError {
    fn from(value: ServerAuthError) -> Self {
        let code = match value {
            ServerAuthError::InvalidAuthorizationSpecification => {
                PgError::InvalidAuthorizationSpecification(
                    PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification,
                )
            }
            ServerAuthError::InvalidPassword => PgError::InvalidAuthorizationSpecification(
                PgErrorInvalidAuthorizationSpecification::InvalidPassword,
            ),
            // Malformed SASL payloads and wrong message kinds are both
            // treated as protocol violations.
            ServerAuthError::InvalidSaslMessage(_) | ServerAuthError::InvalidMessageType => {
                PgError::ConnectionException(PgErrorConnectionException::ProtocolViolation)
            }
            ServerAuthError::UnsupportedAuthType => {
                PgError::FeatureNotSupported(PgErrorFeatureNotSupported::FeatureNotSupported)
            }
        };
        ServerError::Protocol(code)
    }
}
/// Protocol violation: malformed, corrupted, or out-of-order client traffic.
const PROTOCOL_ERROR: ServerError = ServerError::Protocol(PgError::ConnectionException(
    PgErrorConnectionException::ProtocolViolation,
));

/// Unsupported feature: a startup-phase message that is neither a startup
/// message nor an `SSLRequest`.
const PROTOCOL_VERSION_ERROR: ServerError = ServerError::Protocol(PgError::FeatureNotSupported(
    PgErrorFeatureNotSupported::FeatureNotSupported,
));

/// Authorization failure, e.g. a startup message without a `user` parameter.
const AUTH_ERROR: ServerError = ServerError::Protocol(PgError::InvalidAuthorizationSpecification(
    PgErrorInvalidAuthorizationSpecification::InvalidAuthorizationSpecification,
));
impl ServerState {
    /// Creates a server-side handshake state machine waiting for the
    /// client's first message.
    ///
    /// `ssl_requirement` decides how an incoming `SSLRequest` is answered.
    pub fn new(ssl_requirement: ConnectionSslRequirement) -> Self {
        Self {
            state: ServerStateImpl::Initial(Some(ssl_requirement)),
            initial_buffer: Default::default(),
            buffer: Default::default(),
        }
    }

    /// Returns `true` once the handshake completed and `ReadyForQuery` was sent.
    pub fn is_ready(&self) -> bool {
        matches!(self.state, ServerStateImpl::Ready)
    }

    /// Returns `true` after a protocol error terminated the handshake.
    pub fn is_error(&self) -> bool {
        matches!(self.state, ServerStateImpl::Error)
    }

    /// Returns `true` when the handshake has either completed or failed.
    pub fn is_done(&self) -> bool {
        self.is_ready() || self.is_error()
    }

    /// Feeds one event into the state machine, emitting replies via `update`.
    ///
    /// Raw bytes are reassembled into startup-phase messages (in the initial
    /// phase) or regular messages (while authenticating). Raw bytes in any
    /// other phase are a protocol error. All other events are applied
    /// directly.
    ///
    /// # Errors
    ///
    /// I/O and UTF-8 failures are returned unchanged. Protocol failures move
    /// the machine into the error state, send an `ErrorResponse` to the
    /// client, and are returned as a [`PgServerError`]-based
    /// [`ConnectionError`].
    pub fn drive(
        &mut self,
        drive: ConnectionDrive,
        update: &mut impl ConnectionStateUpdate,
    ) -> Result<(), ConnectionError> {
        trace!("SERVER DRIVE: {:?} {:?}", self.state, drive);
        let res = match drive {
            ConnectionDrive::RawMessage(raw) => match self.state {
                ServerStateImpl::Initial(..) => self.initial_buffer.push_fallible(raw, |message| {
                    self.state
                        .drive_inner(ConnectionDrive::Initial(message), update)
                }),
                ServerStateImpl::Authenticating(..) => self.buffer.push_fallible(raw, |message| {
                    self.state
                        .drive_inner(ConnectionDrive::Message(message), update)
                }),
                _ => {
                    error!("Unexpected drive in state {:?}", self.state);
                    Err(PROTOCOL_ERROR)
                }
            },
            ConnectionDrive::SslReady => {
                // Anything still buffered here arrived in plaintext before the
                // upgrade completed. It must never be spliced onto data read
                // after the upgrade (cf. PostgreSQL CVE-2021-23214), so drop it.
                // A well-behaved client sends nothing while waiting for the
                // SSL response, so this only discards injected bytes.
                self.initial_buffer = Default::default();
                self.state.drive_inner(ConnectionDrive::SslReady, update)
            }
            drive => self.state.drive_inner(drive, update),
        };
        match res {
            Ok(_) => Ok(()),
            Err(ServerError::IO(e)) => Err(e.into()),
            Err(ServerError::Utf8Error(e)) => Err(e.into()),
            Err(ServerError::Protocol(code)) => {
                self.state = ServerStateImpl::Error;
                send_error(update, code, "Connection error")?;
                Err(PgServerError::new(code, "Connection error", Default::default()).into())
            }
        }
    }
}
impl ServerStateImpl {
    /// Applies a single (already parsed) event to the handshake state.
    ///
    /// Protocol replies are sent through `update`. A returned
    /// `ServerError::Protocol` is turned by `ServerState::drive` into the
    /// error state plus an `ErrorResponse` to the client.
    fn drive_inner(
        &mut self,
        drive: ConnectionDrive,
        update: &mut impl ConnectionStateUpdate,
    ) -> Result<(), ServerError> {
        use ServerStateImpl::*;
        match (&mut *self, drive) {
            (Initial(ssl), ConnectionDrive::Initial(initial_message)) => {
                match_message!(initial_message, InitialMessage {
                    (StartupMessage as startup) => {
                        // NOTE(review): `ssl` is not consulted here, so a
                        // plaintext startup is accepted whatever the SSL
                        // requirement is; confirm this is intended.
                        let mut user = String::new();
                        let mut database = String::new();
                        for param in startup.params() {
                            if param.name() == "user" {
                                user = param.value().to_owned()?;
                            } else if param.name() == "database" {
                                database = param.value().to_owned()?;
                            }
                            trace!("param: {:?}={:?}", param.name(), param.value());
                            update.parameter(param.name().to_str()?, param.value().to_str()?);
                        }
                        if user.is_empty() {
                            return Err(AUTH_ERROR);
                        }
                        // Fall back to the user name when no database was requested.
                        if database.is_empty() {
                            database = user.clone();
                        }
                        *self = AuthInfo(user.clone());
                        update.auth(user, database)?;
                    },
                    (SSLRequest) => {
                        // `None` means SSL negotiation already happened; a second
                        // request is a protocol violation.
                        let Some(ssl) = *ssl else {
                            return Err(PROTOCOL_ERROR);
                        };
                        if ssl == ConnectionSslRequirement::Disable {
                            // Decline: the client continues in plaintext with a
                            // startup message, so no transport upgrade happens.
                            update.send_ssl(SSLResponseBuilder { code: b'N' })?;
                        } else {
                            // Accept: ask the transport to upgrade, then wait for
                            // `ConnectionDrive::SslReady` before reading startup.
                            update.send_ssl(SSLResponseBuilder { code: b'S' })?;
                            update.upgrade()?;
                            *self = SslConnecting;
                        }
                    },
                    unknown => {
                        log_unknown_initial_message(unknown, "Initial")?;
                    }
                });
            }
            (SslConnecting, ConnectionDrive::SslReady) => {
                // Back to the startup phase; further SSL requests are refused.
                *self = Initial(None);
            }
            (SslConnecting, _) => {
                return Err(PROTOCOL_ERROR);
            }
            (AuthInfo(username), ConnectionDrive::AuthInfo(auth_type, credential_data)) => {
                let mut auth = ServerAuth::new(username.clone(), auth_type, credential_data);
                // Send the challenge matching the configured method. An
                // immediate `Complete` skips the exchange entirely.
                match auth.drive(ServerAuthDrive::Initial) {
                    ServerAuthResponse::Initial(AuthType::Plain, _) => {
                        update.send(&AuthenticationCleartextPasswordBuilder::default())?;
                    }
                    ServerAuthResponse::Initial(AuthType::Md5, salt) => {
                        update.send(&AuthenticationMD5PasswordBuilder {
                            salt: TryInto::<[u8; 4]>::try_into(salt).map_err(|_| PROTOCOL_ERROR)?,
                        })?;
                    }
                    ServerAuthResponse::Initial(AuthType::ScramSha256, _) => {
                        update.send(&AuthenticationSASLBuilder {
                            mechanisms: ["SCRAM-SHA-256"],
                        })?;
                    }
                    ServerAuthResponse::Complete(..) => {
                        update.send(&AuthenticationOkBuilder::default())?;
                        *self = Synchronizing;
                        update.params()?;
                        return Ok(());
                    }
                    ServerAuthResponse::Error(e) => {
                        error!("Authentication error in initial state: {e:?}");
                        return Err(e.into());
                    }
                    response => {
                        error!("Unexpected response: {response:?}");
                        return Err(PROTOCOL_ERROR);
                    }
                }
                *self = Authenticating(auth);
            }
            (Authenticating(auth), ConnectionDrive::Message(message)) => {
                trace!("auth = {auth:?}, initial = {}", auth.is_initial_message());
                match_message!(message, Message {
                    (PasswordMessage as password) if matches!(auth.auth_type(), AuthType::Plain | AuthType::Md5) => {
                        match auth.drive(ServerAuthDrive::Message(auth.auth_type(), password.password().to_bytes())) {
                            ServerAuthResponse::Complete(..) => {
                                update.send(&AuthenticationOkBuilder::default())?;
                                *self = Synchronizing;
                                update.params()?;
                            }
                            ServerAuthResponse::Error(e) => {
                                error!("Authentication error for password message: {e:?}");
                                return Err(e.into())
                            },
                            response => {
                                error!("Unexpected response for password message: {response:?}");
                                return Err(PROTOCOL_ERROR);
                            }
                        }
                    },
                    (SASLInitialResponse as sasl) if auth.is_initial_message() => {
                        if sasl.mechanism() != "SCRAM-SHA-256" {
                            error!("Unexpected mechanism: {:?}", sasl.mechanism());
                            return Err(PROTOCOL_ERROR);
                        }
                        match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) {
                            ServerAuthResponse::Continue(final_message) => {
                                update.send(&AuthenticationSASLContinueBuilder {
                                    data: &final_message,
                                })?;
                            }
                            ServerAuthResponse::Error(e) => {
                                error!("Authentication error for SASL initial response: {e:?}");
                                return Err(e.into())
                            },
                            response => {
                                error!("Unexpected response for SASL initial response: {response:?}");
                                return Err(PROTOCOL_ERROR);
                            }
                        }
                    },
                    (SASLResponse as sasl) if !auth.is_initial_message() => {
                        match auth.drive(ServerAuthDrive::Message(AuthType::ScramSha256, sasl.response().as_ref())) {
                            ServerAuthResponse::Complete(data) => {
                                update.send(&AuthenticationSASLFinalBuilder {
                                    data,
                                })?;
                                update.send(&AuthenticationOkBuilder::default())?;
                                *self = Synchronizing;
                                update.params()?;
                            }
                            ServerAuthResponse::Error(e) => {
                                error!("Authentication error for SASL response: {e:?}");
                                return Err(e.into())
                            },
                            response => {
                                error!("Unexpected response for SASL response: {response:?}");
                                return Err(PROTOCOL_ERROR);
                            }
                        }
                    },
                    unknown => {
                        log_unknown_message(unknown, "Authenticating")?;
                    }
                });
            }
            (Synchronizing, ConnectionDrive::Parameter(name, value)) => {
                update.send(&ParameterStatusBuilder { name, value })?;
            }
            (Synchronizing, ConnectionDrive::Ready(pid, key)) => {
                update.send(&BackendKeyDataBuilder { pid, key })?;
                update.send(&ReadyForQueryBuilder { status: b'I' })?;
                *self = Ready;
            }
            (_, ConnectionDrive::Fail(error, _)) => {
                return Err(ServerError::Protocol(error));
            }
            _ => {
                error!("Unexpected drive in state {:?}", self);
                return Err(PROTOCOL_ERROR);
            }
        }
        Ok(())
    }
}
/// Handles a startup-phase message that no handler matched.
///
/// Both outcomes are errors. A message that failed to parse is a protocol
/// violation. A well-formed but unsupported one (identified by its protocol
/// version) is reported as an unsupported feature.
fn log_unknown_initial_message(
    message: Result<InitialMessage, ParseError>,
    state: &str,
) -> Result<(), ServerError> {
    let message = match message {
        Ok(parsed) => parsed,
        Err(e) => {
            error!("Corrupted message received in {} state {:?}", state, e);
            return Err(PROTOCOL_ERROR);
        }
    };
    warn!(
        "Unexpected message {:?} (length {}) received in {} state",
        message.protocol_version(),
        message.mlen(),
        state
    );
    Err(PROTOCOL_VERSION_ERROR)
}
/// Handles a regular message that no handler matched.
///
/// A well-formed message is only warned about and then ignored (`Ok(())`).
/// A message that failed to parse is a protocol violation.
fn log_unknown_message(
    message: Result<Message, ParseError>,
    state: &str,
) -> Result<(), ServerError> {
    let message = match message {
        Ok(parsed) => parsed,
        Err(e) => {
            error!("Corrupted message received in {} state {:?}", state, e);
            return Err(PROTOCOL_ERROR);
        }
    };
    warn!(
        "Unexpected message {:?} (length {}) received in {} state",
        message.mtype(),
        message.mlen(),
        state
    );
    Ok(())
}