use crate::{
md5::md5_password,
scram::{
generate_nonce, generate_salted_password, ClientEnvironment, ClientTransaction, SCRAMError,
Sha256Out,
},
AuthType, CredentialData,
};
use tracing::error;
/// Result of a single `ClientAuth::drive` step.
#[derive(Debug)]
pub enum ClientAuthResponse {
/// First reply of the exchange: the authentication type that was selected
/// and the payload bytes produced for it.
Initial(AuthType, Vec<u8>),
/// Follow-up payload produced while a SCRAM exchange is still in progress.
Continue(Vec<u8>),
/// The state machine reached its completed state.
Complete,
/// The SCRAM transaction produced no further payload; the state machine now
/// waits for `ClientAuthDrive::Ok`.
Waiting,
/// The first step of the exchange failed; carries the reason.
Error(ClientAuthError),
}
/// Errors reported by the client authentication state machine.
#[derive(Debug, derive_more::Error, derive_more::Display, derive_more::From)]
pub enum ClientAuthError {
/// Failure reported by the SCRAM transaction; convertible from `SCRAMError`
/// so it can be propagated with `?`.
#[display("SCRAM protocol error: {_0}")]
ScramError(#[from] SCRAMError),
/// The drive input is not valid for the state the machine is currently in.
#[display("Invalid authentication state")]
InvalidState,
/// The configured credentials cannot be used with the requested
/// authentication method.
#[display("Invalid credentials")]
InvalidCredentials,
/// Not produced by the code in this part of the file.
#[display("Unexpected message during authentication")]
UnexpectedMessage,
}
/// Internal state of `ClientAuth`.
#[derive(Debug)]
enum ClientAuthState {
/// Nothing has been driven yet; holds the username and the credentials
/// passed to `ClientAuth::new`.
Initial(String, CredentialData),
/// Terminal state. Also entered when the first step fails.
Complete,
/// A reply has been produced; `ClientAuthDrive::Ok` moves to `Complete`.
Waiting,
/// SCRAM exchange in progress: the transaction plus the environment that
/// supplies the password to it.
Sasl(ClientTransaction, ClientEnvironmentImpl),
}
/// Input fed to `ClientAuth::drive`.
// NOTE(review): these presumably mirror the server's authentication
// messages — confirm against the caller.
#[derive(Debug)]
pub enum ClientAuthDrive<'a> {
/// Authentication accepted. As the very first input it selects
/// `AuthType::Trust`; in the waiting state it completes the exchange.
Ok,
/// Request for the plain password.
Plain,
/// Request for an MD5 response computed with the given 4-byte salt.
Md5([u8; 4]),
/// Request to start a SCRAM-SHA-256 exchange.
Scram,
/// Message bytes handed to the running SCRAM transaction.
ScramResponse(&'a [u8]),
}
/// Client-side authentication state machine, advanced one step at a time by
/// `drive`.
#[derive(Debug)]
pub struct ClientAuth {
// Current position in the exchange.
state: ClientAuthState,
// Authentication type selected by the first handled drive input; `None`
// until then.
auth_type: Option<AuthType>,
}
impl ClientAuth {
pub fn new(username: String, credentials: CredentialData) -> Self {
Self {
state: ClientAuthState::Initial(username, credentials),
auth_type: None,
}
}
pub fn is_complete(&self) -> bool {
matches!(self.state, ClientAuthState::Complete)
}
pub fn auth_type(&self) -> Option<AuthType> {
self.auth_type
}
pub fn drive(&mut self, drive: ClientAuthDrive) -> Result<ClientAuthResponse, ClientAuthError> {
match (&mut self.state, drive) {
(ClientAuthState::Initial(username, credentials), drive) => {
let username = std::mem::take(username);
let credentials = std::mem::replace(credentials, CredentialData::Deny);
self.handle_initial(username, credentials, drive)
}
(ClientAuthState::Sasl(tx, env), ClientAuthDrive::ScramResponse(message)) => {
let response = tx.process_message(message, env)?;
match response {
Some(response) => Ok(ClientAuthResponse::Continue(response)),
None => {
self.state = ClientAuthState::Waiting;
Ok(ClientAuthResponse::Waiting)
}
}
}
(ClientAuthState::Waiting, ClientAuthDrive::Ok) => {
self.state = ClientAuthState::Complete;
Ok(ClientAuthResponse::Complete)
}
(_, drive) => {
error!("Received invalid drive {drive:?} in state {:?}", self.state);
Err(ClientAuthError::InvalidState)
}
}
}
fn handle_initial(
&mut self,
username: String,
credentials: CredentialData,
drive: ClientAuthDrive,
) -> Result<ClientAuthResponse, ClientAuthError> {
let (auth_type, (state, response)) = match drive {
ClientAuthDrive::Ok => (
AuthType::Trust,
match credentials {
CredentialData::Deny => (
ClientAuthState::Complete,
ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
),
_ => (ClientAuthState::Complete, ClientAuthResponse::Complete),
},
),
ClientAuthDrive::Plain => (
AuthType::Plain,
match credentials {
CredentialData::Plain(credentials) => (
ClientAuthState::Waiting,
ClientAuthResponse::Initial(
AuthType::Plain,
credentials.clone().into_bytes(),
),
),
_ => (
ClientAuthState::Complete,
ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
),
},
),
ClientAuthDrive::Md5(salt) => (
AuthType::Md5,
match credentials {
CredentialData::Md5(credentials) => (
ClientAuthState::Waiting,
ClientAuthResponse::Initial(
AuthType::Md5,
credentials.salted(salt).into_bytes(),
),
),
CredentialData::Plain(credentials) => (
ClientAuthState::Waiting,
ClientAuthResponse::Initial(
AuthType::Md5,
md5_password(&credentials, &username, salt).into_bytes(),
),
),
_ => (
ClientAuthState::Complete,
ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
),
},
),
ClientAuthDrive::Scram => (
AuthType::ScramSha256,
match credentials {
CredentialData::Plain(credentials) => {
let env = ClientEnvironmentImpl {
password: credentials,
};
let mut tx = ClientTransaction::new(username.into());
let response = tx.process_message(&[], &env);
match response {
Ok(Some(response)) => (
ClientAuthState::Sasl(tx, env),
ClientAuthResponse::Initial(AuthType::ScramSha256, response),
),
Ok(None) => (
ClientAuthState::Complete,
ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
),
Err(e) => (
ClientAuthState::Complete,
ClientAuthResponse::Error(ClientAuthError::ScramError(e)),
),
}
}
_ => (
ClientAuthState::Complete,
ClientAuthResponse::Error(ClientAuthError::InvalidCredentials),
),
},
),
_ => {
error!("Received invalid drive {drive:?} in state Initial");
return Err(ClientAuthError::InvalidState);
}
};
self.auth_type = Some(auth_type);
self.state = state;
Ok(response)
}
}
/// SCRAM client environment backed by a plaintext password.
///
/// `Debug` is written by hand instead of derived: this value is embedded in
/// the authentication state, which is formatted with `{:?}` when an invalid
/// transition is logged, and a derived impl would put the password in the log.
struct ClientEnvironmentImpl {
    /// Plaintext password used to derive the salted password.
    password: String,
}

impl std::fmt::Debug for ClientEnvironmentImpl {
    /// Formats the struct with the password replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientEnvironmentImpl")
            .field("password", &format_args!("<redacted>"))
            .finish()
    }
}
impl ClientEnvironment for ClientEnvironmentImpl {
    /// Produces a nonce by delegating to the module-level `generate_nonce`
    /// helper; no state of `self` is involved.
    fn generate_nonce(&self) -> String {
        let nonce = generate_nonce();
        nonce
    }

    /// Derives the salted password from the stored password bytes, the given
    /// `salt`, and the `iterations` count.
    fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
        let password_bytes = self.password.as_bytes();
        generate_salted_password(password_bytes, salt, iterations)
    }
}