use std::fmt::Debug;
use async_trait::async_trait;
use futures::sink::{Sink, SinkExt};
use super::{
AuthSource, ClientInfo, LoginInfo, PgWireConnectionState, ServerParameterProvider,
StartupHandler,
};
use crate::error::{PgWireError, PgWireResult};
use crate::messages::startup::Authentication;
use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
/// Startup handler that authenticates a client by asking for its password in
/// cleartext and comparing it with the one returned by an [`AuthSource`].
///
/// The type is generic over the password source (`A`) and the provider of
/// server parameters (`P`); the trait bounds live on the `StartupHandler`
/// impl rather than on the struct itself.
// NOTE(review): the constructor comes from the `new` derive — presumably
// `new(auth_source, parameter_provider)`; confirm against the derive crate.
#[derive(new)]
pub struct CleartextPasswordAuthStartupHandler<A, P> {
/// Queried with the client's login info to fetch the expected password.
auth_source: A,
/// Handed to `finish_authentication` once the password has been accepted.
parameter_provider: P,
}
#[async_trait]
impl<V: AuthSource, P: ServerParameterProvider> StartupHandler
    for CleartextPasswordAuthStartupHandler<V, P>
{
    /// Drives the cleartext-password handshake for one incoming frontend message.
    ///
    /// * `Startup` — runs protocol negotiation, stores the startup parameters
    ///   on the client, switches the connection state to
    ///   `AuthenticationInProgress`, and sends an
    ///   `Authentication::CleartextPassword` message to the client.
    /// * `PasswordMessageFamily` — converts the message into a password,
    ///   fetches the expected password from the auth source, and compares the
    ///   two byte-for-byte. On a match the authentication is finished with
    ///   the configured parameter provider.
    /// * Any other message is ignored.
    ///
    /// # Errors
    ///
    /// Returns `PgWireError::InvalidPassword` carrying the login user name
    /// (or the default value when no user is known) if the passwords differ.
    /// Errors from negotiation, sending, password conversion, the auth source
    /// and `finish_authentication` are propagated unchanged.
    async fn on_startup<C>(
        &self,
        client: &mut C,
        message: PgWireFrontendMessage,
    ) -> PgWireResult<()>
    where
        C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
        C::Error: Debug,
        PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
    {
        match message {
            PgWireFrontendMessage::Startup(ref startup) => {
                super::protocol_negotiation(client, startup).await?;
                super::save_startup_parameters_to_metadata(client, startup);
                client.set_state(PgWireConnectionState::AuthenticationInProgress);

                // Ask the client to reply with its password in cleartext.
                let password_request =
                    PgWireBackendMessage::Authentication(Authentication::CleartextPassword);
                client.send(password_request).await?;
            }
            PgWireFrontendMessage::PasswordMessageFamily(family) => {
                let supplied = family.into_password()?;
                let login_info = LoginInfo::from_client_info(client);
                let expected = self.auth_source.get_password(&login_info).await?;

                // NOTE(review): this is a plain equality check and is
                // presumably not constant-time — confirm whether that matters
                // for the deployment's threat model.
                if expected.password != supplied.password.as_bytes() {
                    // Only build the owned user name on the failure path.
                    let user = login_info
                        .user()
                        .map(|name| name.to_owned())
                        .unwrap_or_default();
                    return Err(PgWireError::InvalidPassword(user));
                }

                super::finish_authentication(client, &self.parameter_provider).await?;
            }
            // Messages outside the startup/password exchange are not handled here.
            _ => {}
        }

        Ok(())
    }
}