use crate::error::Result;
use crate::proto::grpc::sasl::{ChannelAuthenticationScheme, SaslMessage, SaslMessageType};
/// Client side of the SASL handshake used when authenticating a channel.
///
/// The `Send + Sync` supertraits let a single handler be shared between
/// threads (for example behind an `Arc<dyn SaslClientHandler>`).
pub trait SaslClientHandler: Send + Sync {
    /// Authentication scheme this handler negotiates.
    fn auth_scheme(&self) -> ChannelAuthenticationScheme;

    /// Builds the message that opens the handshake.
    ///
    /// `client_id` and `channel_ref` are supplied by the caller for inclusion
    /// in the outgoing message.
    fn initial_message(&self, client_id: &str, channel_ref: &str) -> Result<SaslMessage>;

    /// Processes a message received from the server.
    ///
    /// Returns `Ok(Some(reply))` when a reply must be sent back, or `Ok(None)`
    /// when there is nothing further to send (authentication has completed).
    fn handle_message(&self, message: &SaslMessage) -> Result<Option<SaslMessage>>;
}
/// SASL `PLAIN` client handler (RFC 4616).
///
/// The credentials are encoded exactly once, when the handler is built, into
/// the PLAIN response `authzid NUL authcid NUL passwd`. That byte string is
/// what the handler sends to the server.
pub struct PlainSaslClientHandler {
    /// Scheme this handler reports and advertises.
    auth_scheme: ChannelAuthenticationScheme,
    /// Pre-encoded PLAIN response: `authzid \0 username \0 password`.
    initial_response: Vec<u8>,
}

impl PlainSaslClientHandler {
    /// Creates a handler for [`ChannelAuthenticationScheme::Simple`].
    ///
    /// When `impersonation_user` is `Some`, it is sent as the PLAIN
    /// authorization identity (authzid). With `None` the authzid is empty.
    pub fn new_simple(username: &str, password: &str, impersonation_user: Option<&str>) -> Self {
        let scheme = ChannelAuthenticationScheme::Simple;
        Self::new(scheme, username, password, impersonation_user)
    }

    /// Creates a handler for `auth_scheme`, encoding the credentials as
    /// `impersonation_user (or "") NUL username NUL password`.
    ///
    /// No validation is performed. The fields are NUL-delimited, so a NUL
    /// byte inside any input would make the encoded response ambiguous.
    fn new(
        auth_scheme: ChannelAuthenticationScheme,
        username: &str,
        password: &str,
        impersonation_user: Option<&str>,
    ) -> Self {
        // RFC 4616 field delimiter.
        const NUL: u8 = 0;

        let authzid = impersonation_user.unwrap_or_default();
        let mut response =
            Vec::with_capacity(authzid.len() + username.len() + password.len() + 2);
        response.extend_from_slice(authzid.as_bytes());
        response.push(NUL);
        response.extend_from_slice(username.as_bytes());
        response.push(NUL);
        response.extend_from_slice(password.as_bytes());

        Self {
            auth_scheme,
            initial_response: response,
        }
    }
}
impl SaslClientHandler for PlainSaslClientHandler {
    /// Opens the handshake with a `Challenge` message that already carries the
    /// encoded PLAIN credentials. The message is tagged with the caller's
    /// client id, the channel reference and this handler's scheme.
    fn initial_message(&self, client_id: &str, channel_ref: &str) -> Result<SaslMessage> {
        let opening = SaslMessage {
            message_type: Some(SaslMessageType::Challenge as i32),
            message: Some(self.initial_response.to_vec()),
            client_id: Some(client_id.to_owned()),
            authentication_scheme: Some(self.auth_scheme as i32),
            channel_ref: Some(channel_ref.to_owned()),
        };
        Ok(opening)
    }

    /// Reacts to a message from the server.
    ///
    /// A missing or unrecognised `message_type` is treated as `Challenge`.
    /// A `Challenge` is answered by re-sending the encoded credentials, with
    /// no client id, scheme or channel reference attached. `Success` yields
    /// `None`: nothing more needs to be sent.
    fn handle_message(&self, message: &SaslMessage) -> Result<Option<SaslMessage>> {
        let kind = match message.message_type.map(SaslMessageType::try_from) {
            Some(Ok(kind)) => kind,
            // Absent or out-of-range discriminant: fall back to Challenge.
            Some(Err(_)) | None => SaslMessageType::Challenge,
        };

        let reply = match kind {
            SaslMessageType::Success => None,
            SaslMessageType::Challenge => Some(SaslMessage {
                message_type: Some(SaslMessageType::Challenge as i32),
                message: Some(self.initial_response.to_vec()),
                client_id: None,
                authentication_scheme: None,
                channel_ref: None,
            }),
        };
        Ok(reply)
    }

    /// Reports the scheme this handler was constructed with.
    fn auth_scheme(&self) -> ChannelAuthenticationScheme {
        self.auth_scheme
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Credentials shared by every test case.
    const USER: &str = "testuser";
    const PASSWORD: &str = "noPassword";

    /// Builds a `Simple`-scheme handler for the shared credentials.
    fn simple_handler(impersonation_user: Option<&str>) -> PlainSaslClientHandler {
        PlainSaslClientHandler::new_simple(USER, PASSWORD, impersonation_user)
    }

    /// Builds a server-originated message of type `kind` carrying `payload`.
    fn server_message(kind: SaslMessageType, payload: Option<Vec<u8>>) -> SaslMessage {
        SaslMessage {
            message_type: Some(kind as i32),
            message: payload,
            client_id: None,
            authentication_scheme: None,
            channel_ref: None,
        }
    }

    #[test]
    fn test_plain_sasl_initial_response_format() {
        let handler = simple_handler(None);
        assert_eq!(handler.initial_response, b"\0testuser\0noPassword".to_vec());
    }

    #[test]
    fn test_plain_sasl_with_impersonation_user() {
        let handler = simple_handler(Some("proxyuser"));
        assert_eq!(
            handler.initial_response,
            b"proxyuser\0testuser\0noPassword".to_vec()
        );
    }

    #[test]
    fn test_plain_sasl_initial_message() {
        let handler = simple_handler(None);
        let opening = handler
            .initial_message("test-client-id", "test-channel")
            .unwrap();

        assert_eq!(opening.message_type, Some(SaslMessageType::Challenge as i32));
        assert_eq!(opening.message, Some(b"\0testuser\0noPassword".to_vec()));
        assert_eq!(opening.client_id, Some("test-client-id".to_string()));
        assert_eq!(
            opening.authentication_scheme,
            Some(ChannelAuthenticationScheme::Simple as i32)
        );
        assert_eq!(opening.channel_ref, Some("test-channel".to_string()));
    }

    #[test]
    fn test_plain_sasl_handle_success() {
        let reply = simple_handler(None)
            .handle_message(&server_message(SaslMessageType::Success, None))
            .unwrap();
        assert!(
            reply.is_none(),
            "SUCCESS message should return None indicating auth complete"
        );
    }

    #[test]
    fn test_plain_sasl_handle_challenge() {
        let reply = simple_handler(None)
            .handle_message(&server_message(SaslMessageType::Challenge, Some(Vec::new())))
            .unwrap();
        assert!(
            reply.is_some(),
            "CHALLENGE message should return a response"
        );
    }
}