use std::fmt::Debug;
use std::sync::Arc;
use async_trait::async_trait;
use futures::{Sink, SinkExt};
use tokio::sync::Mutex;
use crate::api::auth::sasl::scram::ScramServerAuthWaitingForClientFinal;
use crate::api::{ClientInfo, PgWireConnectionState};
use crate::error::{PgWireError, PgWireResult};
use crate::messages::startup::{Authentication, PasswordMessageFamily};
use crate::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use super::{ServerParameterProvider, StartupHandler};
pub mod oauth;
pub mod scram;
/// SASL mechanism name for SCRAM-SHA-256 without channel binding.
pub const SCRAM_SHA_256_METHOD: &str = "SCRAM-SHA-256";
/// SASL mechanism name for SCRAM-SHA-256 with channel binding; only advertised when
/// the configured SCRAM backend reports channel-binding support.
pub const SCRAM_SHA_256_PLUS_METHOD: &str = "SCRAM-SHA-256-PLUS";
/// SASL mechanism name for OAuth 2.0 bearer-token authentication.
pub const OAUTHBEARER_METHOD: &str = "OAUTHBEARER";
/// Progress of the SASL exchange tracked by `SASLAuthStartupHandler`.
#[derive(Debug)]
pub enum SASLState {
    /// No SASL message received yet; the next password-family message is decoded as a
    /// `SASLInitialResponse` that names the selected mechanism.
    Initial,
    /// The client selected SCRAM-SHA-256 or SCRAM-SHA-256-PLUS in its initial response.
    ScramClientFirstReceived,
    /// SCRAM exchange in progress. The boxed value is server-side SCRAM state,
    /// presumably kept until the client-final message arrives (this transition is made
    /// inside the `scram` module, not shown here — TODO confirm).
    ScramServerFirstSent(Box<ScramServerAuthWaitingForClientFinal>),
    /// The client selected OAUTHBEARER in its initial response.
    OauthStateInit,
    /// OAuth exchange after a failure; presumably entered by the `oauth` module once it
    /// has reported an error to the client — TODO confirm.
    OauthStateError,
    /// Exchange complete; the handler then calls `finish_authentication`.
    Finished,
}
impl SASLState {
fn is_scram(&self) -> bool {
matches!(
self,
SASLState::ScramClientFirstReceived | SASLState::ScramServerFirstSent(_)
)
}
fn is_oauth(&self) -> bool {
matches!(self, SASLState::OauthStateInit | SASLState::OauthStateError)
}
}
/// Startup handler that authenticates clients with SASL
/// (SCRAM-SHA-256 / SCRAM-SHA-256-PLUS and/or OAUTHBEARER).
///
/// The mechanisms offered to the client depend on which backends were configured
/// through the `with_scram` / `with_oauth` builder methods.
#[derive(Debug)]
pub struct SASLAuthStartupHandler<P> {
    /// Passed to `finish_authentication` once the SASL exchange reaches `Finished`.
    parameter_provider: Arc<P>,
    /// Progress of the SASL exchange. It is a `tokio::sync::Mutex` because `on_startup`
    /// only receives `&self` and keeps the guard across `.await` points.
    state: Mutex<SASLState>,
    /// SCRAM backend; when `None`, no SCRAM mechanism is advertised.
    scram: Option<scram::ScramAuth>,
    /// OAUTHBEARER backend; when `None`, OAUTHBEARER is not advertised.
    oauth: Option<oauth::Oauth>,
}
#[async_trait]
impl<P: ServerParameterProvider> StartupHandler for SASLAuthStartupHandler<P> {
async fn on_startup<C>(
&self,
client: &mut C,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
match message {
PgWireFrontendMessage::Startup(ref startup) => {
super::protocol_negotiation(client, startup).await?;
super::save_startup_parameters_to_metadata(client, startup);
client.set_state(PgWireConnectionState::AuthenticationInProgress);
let supported_mechanisms = self.supported_mechanisms();
client
.send(PgWireBackendMessage::Authentication(Authentication::SASL(
supported_mechanisms,
)))
.await?;
}
PgWireFrontendMessage::PasswordMessageFamily(mut msg) => {
let mut state = self.state.lock().await;
msg = if let SASLState::Initial = *state {
let sasl_initial_response = msg.into_sasl_initial_response()?;
let selected_mechanism = sasl_initial_response.auth_method.as_str();
*state = if [SCRAM_SHA_256_METHOD, SCRAM_SHA_256_PLUS_METHOD]
.contains(&selected_mechanism)
{
SASLState::ScramClientFirstReceived
} else if OAUTHBEARER_METHOD == selected_mechanism {
SASLState::OauthStateInit
} else {
return Err(PgWireError::UnsupportedSASLAuthMethod(
selected_mechanism.to_string(),
));
};
PasswordMessageFamily::SASLInitialResponse(sasl_initial_response)
} else {
let sasl_response = msg.into_sasl_response()?;
PasswordMessageFamily::SASLResponse(sasl_response)
};
if state.is_scram() {
let scram = self.scram.as_ref().ok_or_else(|| {
PgWireError::UnsupportedSASLAuthMethod("SCRAM".to_string())
})?;
let (res, new_state) = scram.process_scram_message(client, msg, &state).await?;
client
.send(PgWireBackendMessage::Authentication(res))
.await?;
*state = new_state;
} else if state.is_oauth() {
let oauth = self.oauth.as_ref().ok_or_else(|| {
PgWireError::UnsupportedSASLAuthMethod("OAUTHBEARER".to_string())
})?;
let (res, new_state) = oauth.process_oauth_message(client, msg, &state).await?;
if let Some(res) = res {
client
.send(PgWireBackendMessage::Authentication(res))
.await?;
}
*state = new_state;
} else {
return Err(PgWireError::InvalidSASLState);
};
if matches!(*state, SASLState::Finished) {
super::finish_authentication(client, self.parameter_provider.as_ref()).await?;
}
}
_ => {}
}
Ok(())
}
}
impl<P> SASLAuthStartupHandler<P> {
    /// Creates a handler in the `Initial` SASL state with no mechanism configured.
    ///
    /// Enable mechanisms with [`Self::with_scram`] and/or [`Self::with_oauth`]. If
    /// neither is called, the mechanism list advertised to clients is empty.
    #[must_use]
    pub fn new(parameter_provider: Arc<P>) -> Self {
        SASLAuthStartupHandler {
            parameter_provider,
            state: Mutex::new(SASLState::Initial),
            scram: None,
            oauth: None,
        }
    }

    /// Enables SCRAM-SHA-256, plus SCRAM-SHA-256-PLUS when `scram_auth` reports
    /// channel-binding support. Replaces any previously configured SCRAM backend.
    ///
    /// Consumes and returns the handler; dropping the result discards the handler.
    #[must_use]
    pub fn with_scram(mut self, scram_auth: scram::ScramAuth) -> Self {
        self.scram = Some(scram_auth);
        self
    }

    /// Enables OAUTHBEARER. Replaces any previously configured OAuth backend.
    ///
    /// Consumes and returns the handler; dropping the result discards the handler.
    #[must_use]
    pub fn with_oauth(mut self, oauth: oauth::Oauth) -> Self {
        self.oauth = Some(oauth);
        self
    }

    /// Lists the SASL mechanism names this handler can serve, in the order
    /// SCRAM-SHA-256, SCRAM-SHA-256-PLUS (only with channel binding), OAUTHBEARER.
    fn supported_mechanisms(&self) -> Vec<String> {
        // At most three mechanisms exist, so allocate once.
        let mut mechanisms = Vec::with_capacity(3);
        if let Some(scram) = &self.scram {
            mechanisms.push(SCRAM_SHA_256_METHOD.to_owned());
            if scram.supports_channel_binding() {
                mechanisms.push(SCRAM_SHA_256_PLUS_METHOD.to_owned());
            }
        }
        if self.oauth.is_some() {
            mechanisms.push(OAUTHBEARER_METHOD.to_owned());
        }
        mechanisms
    }
}