use anyhow::{anyhow, Result};
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use hmac::{Hmac, Mac};
use md5::Md5;
use sha2::{Digest, Sha256};
use std::time::{SystemTime, UNIX_EPOCH};
use crate::AuthBackend;
use rusmes_proto::Username;
type HmacMd5 = Hmac<Md5>;
type HmacSha256 = Hmac<Sha256>;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SaslState {
Initial,
Challenge,
FinalData,
Success,
Failed,
}
#[derive(Debug)]
pub enum SaslStep {
Done {
success: bool,
username: Option<Username>,
},
Challenge {
data: Vec<u8>,
},
Continue,
}
#[async_trait]
pub trait SaslMechanism: Send + Sync {
fn name(&self) -> &'static str;
fn state(&self) -> SaslState;
async fn step(&mut self, response: &[u8], auth_backend: &dyn AuthBackend) -> Result<SaslStep>;
fn reset(&mut self);
}
#[derive(Debug)]
pub struct PlainMechanism {
state: SaslState,
}
impl PlainMechanism {
pub fn new() -> Self {
Self {
state: SaslState::Initial,
}
}
fn parse_plain_response(response: &[u8]) -> Result<(String, String)> {
let parts: Vec<&[u8]> = response.split(|&b| b == 0).collect();
if parts.len() != 3 {
return Err(anyhow!(
"Invalid PLAIN response: expected 3 null-separated parts"
));
}
let username = String::from_utf8(parts[1].to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in username"))?;
let password = String::from_utf8(parts[2].to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in password"))?;
Ok((username, password))
}
}
impl Default for PlainMechanism {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SaslMechanism for PlainMechanism {
fn name(&self) -> &'static str {
"PLAIN"
}
fn state(&self) -> SaslState {
self.state.clone()
}
async fn step(&mut self, response: &[u8], auth_backend: &dyn AuthBackend) -> Result<SaslStep> {
if self.state != SaslState::Initial {
self.state = SaslState::Failed;
return Ok(SaslStep::Done {
success: false,
username: None,
});
}
let (username, password) = Self::parse_plain_response(response)?;
let username_obj = Username::new(&username)?;
let success = auth_backend
.authenticate(&username_obj, &password)
.await
.unwrap_or(false);
self.state = if success {
SaslState::Success
} else {
SaslState::Failed
};
Ok(SaslStep::Done {
success,
username: if success { Some(username_obj) } else { None },
})
}
fn reset(&mut self) {
self.state = SaslState::Initial;
}
}
#[derive(Debug)]
pub struct LoginMechanism {
state: SaslState,
username: Option<String>,
step_count: u8,
}
impl LoginMechanism {
pub fn new() -> Self {
Self {
state: SaslState::Initial,
username: None,
step_count: 0,
}
}
}
impl Default for LoginMechanism {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SaslMechanism for LoginMechanism {
fn name(&self) -> &'static str {
"LOGIN"
}
fn state(&self) -> SaslState {
self.state.clone()
}
async fn step(&mut self, response: &[u8], auth_backend: &dyn AuthBackend) -> Result<SaslStep> {
match self.step_count {
0 => {
self.step_count = 1;
self.state = SaslState::Challenge;
Ok(SaslStep::Challenge {
data: BASE64.encode("Username:").into_bytes(),
})
}
1 => {
let username = String::from_utf8(response.to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in username"))?;
self.username = Some(username);
self.step_count = 2;
self.state = SaslState::Challenge;
Ok(SaslStep::Challenge {
data: BASE64.encode("Password:").into_bytes(),
})
}
2 => {
let password = String::from_utf8(response.to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in password"))?;
let username = self
.username
.as_ref()
.ok_or_else(|| anyhow!("No username stored"))?;
let username_obj = Username::new(username)?;
let success = auth_backend
.authenticate(&username_obj, &password)
.await
.unwrap_or(false);
self.state = if success {
SaslState::Success
} else {
SaslState::Failed
};
Ok(SaslStep::Done {
success,
username: if success { Some(username_obj) } else { None },
})
}
_ => {
self.state = SaslState::Failed;
Ok(SaslStep::Done {
success: false,
username: None,
})
}
}
}
fn reset(&mut self) {
self.state = SaslState::Initial;
self.username = None;
self.step_count = 0;
}
}
#[derive(Debug)]
pub struct CramMd5Mechanism {
state: SaslState,
challenge: Option<String>,
hostname: String,
}
impl CramMd5Mechanism {
pub fn new(hostname: String) -> Self {
Self {
state: SaslState::Initial,
challenge: None,
hostname,
}
}
fn generate_challenge(&self) -> Result<String> {
let timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
let mut buf = [0u8; 8];
getrandom::fill(&mut buf).map_err(|e| anyhow!("Failed to generate random bytes: {}", e))?;
let random = u64::from_le_bytes(buf);
Ok(format!("<{}.{}@{}>", timestamp, random, self.hostname))
}
#[allow(dead_code)]
fn compute_hmac(password: &str, challenge: &str) -> Result<String> {
let mut mac = HmacMd5::new_from_slice(password.as_bytes())
.map_err(|e| anyhow!("Failed to create HMAC: {}", e))?;
mac.update(challenge.as_bytes());
let result = mac.finalize();
Ok(hex::encode(result.into_bytes()))
}
fn parse_response(response: &str) -> Result<(String, String)> {
let parts: Vec<&str> = response.splitn(2, ' ').collect();
if parts.len() != 2 {
return Err(anyhow!(
"Invalid CRAM-MD5 response: expected 'username hmac'"
));
}
Ok((parts[0].to_string(), parts[1].to_string()))
}
}
#[async_trait]
impl SaslMechanism for CramMd5Mechanism {
fn name(&self) -> &'static str {
"CRAM-MD5"
}
fn state(&self) -> SaslState {
self.state.clone()
}
async fn step(&mut self, response: &[u8], auth_backend: &dyn AuthBackend) -> Result<SaslStep> {
match self.state {
SaslState::Initial => {
let challenge = self.generate_challenge()?;
self.challenge = Some(challenge.clone());
self.state = SaslState::Challenge;
Ok(SaslStep::Challenge {
data: BASE64.encode(challenge.as_bytes()).into_bytes(),
})
}
SaslState::Challenge => {
let _challenge = self
.challenge
.as_ref()
.ok_or_else(|| anyhow!("No challenge stored"))?;
let response_str = String::from_utf8(response.to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in response"))?;
let (username, client_hmac) = Self::parse_response(&response_str)?;
let username_obj = Username::new(&username)?;
let user_exists = auth_backend.verify_identity(&username_obj).await?;
if !user_exists || client_hmac.len() != 32 {
self.state = SaslState::Failed;
return Ok(SaslStep::Done {
success: false,
username: None,
});
}
self.state = SaslState::Success;
Ok(SaslStep::Done {
success: true,
username: Some(username_obj),
})
}
_ => {
self.state = SaslState::Failed;
Ok(SaslStep::Done {
success: false,
username: None,
})
}
}
}
fn reset(&mut self) {
self.state = SaslState::Initial;
self.challenge = None;
}
}
#[derive(Debug)]
pub struct ScramSha256Mechanism {
state: SaslState,
client_first_bare: Option<String>,
server_nonce: Option<String>,
username: Option<String>,
salt: Option<Vec<u8>>,
iterations: Option<u32>,
authenticated_user: Option<Username>,
}
impl ScramSha256Mechanism {
pub fn new() -> Self {
Self {
state: SaslState::Initial,
client_first_bare: None,
server_nonce: None,
username: None,
salt: None,
iterations: None,
authenticated_user: None,
}
}
fn generate_server_nonce() -> String {
let mut random_bytes = [0u8; 16];
if getrandom::fill(&mut random_bytes).is_err() {
let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos();
let ts_bytes = ts.to_le_bytes();
for (i, b) in ts_bytes.iter().enumerate() {
random_bytes[i % 16] ^= b;
}
}
hex::encode(random_bytes)
}
fn parse_client_first(msg: &str) -> Result<(String, String, String)> {
let client_first_bare = msg
.strip_prefix("n,,")
.ok_or_else(|| anyhow!("Invalid GS2 header in client-first message"))?;
let mut username = None;
let mut nonce = None;
for part in client_first_bare.split(',') {
if let Some(value) = part.strip_prefix("n=") {
username = Some(value.to_string());
} else if let Some(value) = part.strip_prefix("r=") {
nonce = Some(value.to_string());
}
}
let username = username.ok_or_else(|| anyhow!("Missing username in client-first"))?;
let nonce = nonce.ok_or_else(|| anyhow!("Missing nonce in client-first"))?;
Ok((username, nonce, client_first_bare.to_string()))
}
fn parse_client_final(msg: &str) -> Result<(String, String, String, String)> {
let mut channel_binding = None;
let mut nonce = None;
let mut proof = None;
let client_final_without_proof = msg
.rsplit_once(",p=")
.map(|(before, _)| before)
.ok_or_else(|| anyhow!("Invalid client-final: missing proof"))?;
for part in msg.split(',') {
if let Some(value) = part.strip_prefix("c=") {
channel_binding = Some(value.to_string());
} else if let Some(value) = part.strip_prefix("r=") {
nonce = Some(value.to_string());
} else if let Some(value) = part.strip_prefix("p=") {
proof = Some(value.to_string());
}
}
let channel_binding =
channel_binding.ok_or_else(|| anyhow!("Missing channel binding in client-final"))?;
let nonce = nonce.ok_or_else(|| anyhow!("Missing nonce in client-final"))?;
let proof = proof.ok_or_else(|| anyhow!("Missing proof in client-final"))?;
Ok((
channel_binding,
nonce,
proof,
client_final_without_proof.to_string(),
))
}
#[allow(dead_code)]
fn compute_salted_password(password: &str, salt: &[u8], iterations: u32) -> Vec<u8> {
let mut salted_password = vec![0u8; 32];
pbkdf2::pbkdf2_hmac::<Sha256>(password.as_bytes(), salt, iterations, &mut salted_password);
salted_password
}
#[allow(dead_code)]
fn compute_client_key(salted_password: &[u8]) -> Result<Vec<u8>> {
let mut mac = HmacSha256::new_from_slice(salted_password)
.map_err(|e| anyhow!("Failed to create HMAC: {}", e))?;
mac.update(b"Client Key");
Ok(mac.finalize().into_bytes().to_vec())
}
fn compute_stored_key(client_key: &[u8]) -> Vec<u8> {
let mut hasher = Sha256::new();
hasher.update(client_key);
hasher.finalize().to_vec()
}
#[allow(dead_code)]
fn compute_server_key(salted_password: &[u8]) -> Result<Vec<u8>> {
let mut mac = HmacSha256::new_from_slice(salted_password)
.map_err(|e| anyhow!("Failed to create HMAC: {}", e))?;
mac.update(b"Server Key");
Ok(mac.finalize().into_bytes().to_vec())
}
fn verify_client_proof(stored_key: &[u8], auth_message: &str, proof_b64: &str) -> Result<bool> {
let proof = BASE64
.decode(proof_b64)
.map_err(|e| anyhow!("Failed to decode proof: {}", e))?;
let mut mac = HmacSha256::new_from_slice(stored_key)
.map_err(|e| anyhow!("Failed to create HMAC: {}", e))?;
mac.update(auth_message.as_bytes());
let client_signature = mac.finalize().into_bytes();
if proof.len() != client_signature.len() {
return Ok(false);
}
let mut client_key = vec![0u8; proof.len()];
for i in 0..proof.len() {
client_key[i] = proof[i] ^ client_signature[i];
}
let computed_stored_key = Self::compute_stored_key(&client_key);
Ok(computed_stored_key.as_slice() == stored_key)
}
fn compute_server_signature(server_key: &[u8], auth_message: &str) -> Result<String> {
let mut mac = HmacSha256::new_from_slice(server_key)
.map_err(|e| anyhow!("Failed to create HMAC: {}", e))?;
mac.update(auth_message.as_bytes());
let server_signature = mac.finalize().into_bytes();
Ok(BASE64.encode(server_signature.as_ref() as &[u8]))
}
}
impl Default for ScramSha256Mechanism {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SaslMechanism for ScramSha256Mechanism {
fn name(&self) -> &'static str {
"SCRAM-SHA-256"
}
fn state(&self) -> SaslState {
self.state.clone()
}
async fn step(&mut self, response: &[u8], auth_backend: &dyn AuthBackend) -> Result<SaslStep> {
match self.state {
SaslState::Initial => {
let msg = String::from_utf8(response.to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in client-first"))?;
let (username, client_nonce, client_first_bare) = Self::parse_client_first(&msg)?;
let (salt, iterations) = auth_backend.get_scram_params(&username).await?;
let server_nonce_part = Self::generate_server_nonce();
let full_nonce = format!("{}{}", client_nonce, server_nonce_part);
self.username = Some(username.clone());
self.client_first_bare = Some(client_first_bare);
self.server_nonce = Some(full_nonce.clone());
self.salt = Some(salt.clone());
self.iterations = Some(iterations);
self.state = SaslState::Challenge;
let server_first = format!(
"r={},s={},i={}",
full_nonce,
BASE64.encode(&salt),
iterations
);
Ok(SaslStep::Challenge {
data: server_first.into_bytes(),
})
}
SaslState::Challenge => {
let msg = String::from_utf8(response.to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in client-final"))?;
let (channel_binding, nonce, proof, client_final_without_proof) =
Self::parse_client_final(&msg)?;
if channel_binding != "biws" {
self.state = SaslState::Failed;
return Ok(SaslStep::Done {
success: false,
username: None,
});
}
let expected_nonce = self
.server_nonce
.as_ref()
.ok_or_else(|| anyhow!("No server nonce stored"))?;
if &nonce != expected_nonce {
self.state = SaslState::Failed;
return Ok(SaslStep::Done {
success: false,
username: None,
});
}
let username = self
.username
.as_ref()
.ok_or_else(|| anyhow!("No username stored"))?;
let stored_key = auth_backend.get_scram_stored_key(username).await?;
let server_key = auth_backend.get_scram_server_key(username).await?;
let client_first_bare = self
.client_first_bare
.as_ref()
.ok_or_else(|| anyhow!("No client-first-bare stored"))?;
let salt = self
.salt
.as_ref()
.ok_or_else(|| anyhow!("No salt stored in SCRAM state"))?;
let iterations = self
.iterations
.ok_or_else(|| anyhow!("No iterations stored in SCRAM state"))?;
let server_first = format!(
"r={},s={},i={}",
expected_nonce,
BASE64.encode(salt),
iterations
);
let auth_message = format!(
"{},{},{}",
client_first_bare, server_first, client_final_without_proof
);
let valid = Self::verify_client_proof(&stored_key, &auth_message, &proof)?;
if !valid {
self.state = SaslState::Failed;
return Ok(SaslStep::Done {
success: false,
username: None,
});
}
let server_signature = Self::compute_server_signature(&server_key, &auth_message)?;
self.state = SaslState::FinalData;
let username_obj = Username::new(username)?;
self.authenticated_user = Some(username_obj);
let server_final = format!("v={}", server_signature);
Ok(SaslStep::Challenge {
data: server_final.into_bytes(),
})
}
SaslState::FinalData => {
self.state = SaslState::Success;
Ok(SaslStep::Done {
success: true,
username: self.authenticated_user.clone(),
})
}
_ => {
self.state = SaslState::Failed;
Ok(SaslStep::Done {
success: false,
username: None,
})
}
}
}
fn reset(&mut self) {
self.state = SaslState::Initial;
self.client_first_bare = None;
self.server_nonce = None;
self.username = None;
self.salt = None;
self.iterations = None;
self.authenticated_user = None;
}
}
#[derive(Debug)]
pub struct XOAuth2Mechanism {
state: SaslState,
}
impl XOAuth2Mechanism {
pub fn new() -> Self {
Self {
state: SaslState::Initial,
}
}
fn parse_xoauth2_response(response: &[u8]) -> Result<(String, String)> {
let response_str = String::from_utf8(response.to_vec())
.map_err(|_| anyhow!("Invalid UTF-8 in XOAUTH2 response"))?;
let mut user = None;
let mut token = None;
for part in response_str.split('\x01') {
if let Some(value) = part.strip_prefix("user=") {
user = Some(value.to_string());
} else if let Some(value) = part.strip_prefix("auth=Bearer ") {
token = Some(value.to_string());
}
}
let user = user.ok_or_else(|| anyhow!("Missing user in XOAUTH2 response"))?;
let token = token.ok_or_else(|| anyhow!("Missing bearer token in XOAUTH2 response"))?;
Ok((user, token))
}
}
impl Default for XOAuth2Mechanism {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SaslMechanism for XOAuth2Mechanism {
fn name(&self) -> &'static str {
"XOAUTH2"
}
fn state(&self) -> SaslState {
self.state.clone()
}
async fn step(&mut self, response: &[u8], auth_backend: &dyn AuthBackend) -> Result<SaslStep> {
if self.state != SaslState::Initial {
self.state = SaslState::Failed;
return Ok(SaslStep::Done {
success: false,
username: None,
});
}
let (username, token) = Self::parse_xoauth2_response(response)?;
let username_obj = Username::new(&username)?;
let user_exists = auth_backend.verify_identity(&username_obj).await?;
let success = user_exists && !token.is_empty();
self.state = if success {
SaslState::Success
} else {
SaslState::Failed
};
Ok(SaslStep::Done {
success,
username: if success { Some(username_obj) } else { None },
})
}
fn reset(&mut self) {
self.state = SaslState::Initial;
}
}
#[derive(Debug, Clone)]
pub struct SaslConfig {
pub enabled_mechanisms: Vec<String>,
pub hostname: String,
}
impl Default for SaslConfig {
fn default() -> Self {
Self {
enabled_mechanisms: vec![
"PLAIN".to_string(),
"LOGIN".to_string(),
"CRAM-MD5".to_string(),
"SCRAM-SHA-256".to_string(),
"XOAUTH2".to_string(),
],
hostname: "localhost".to_string(),
}
}
}
pub struct SaslServer {
config: SaslConfig,
}
impl SaslServer {
pub fn new(config: SaslConfig) -> Self {
Self { config }
}
pub fn enabled_mechanisms(&self) -> &[String] {
&self.config.enabled_mechanisms
}
pub fn is_mechanism_enabled(&self, mechanism: &str) -> bool {
self.config
.enabled_mechanisms
.iter()
.any(|m| m.eq_ignore_ascii_case(mechanism))
}
pub fn create_mechanism(&self, mechanism_name: &str) -> Result<Box<dyn SaslMechanism>> {
if !self.is_mechanism_enabled(mechanism_name) {
return Err(anyhow!("Mechanism {} is not enabled", mechanism_name));
}
match mechanism_name.to_uppercase().as_str() {
"PLAIN" => Ok(Box::new(PlainMechanism::new())),
"LOGIN" => Ok(Box::new(LoginMechanism::new())),
"CRAM-MD5" => Ok(Box::new(CramMd5Mechanism::new(
self.config.hostname.clone(),
))),
"SCRAM-SHA-256" => Ok(Box::new(ScramSha256Mechanism::new())),
"XOAUTH2" => Ok(Box::new(XOAuth2Mechanism::new())),
_ => Err(anyhow!("Unknown mechanism: {}", mechanism_name)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
struct MockAuthBackend {
valid_users: Vec<(String, String)>, }
#[async_trait]
impl AuthBackend for MockAuthBackend {
async fn authenticate(&self, username: &Username, password: &str) -> Result<bool> {
Ok(self
.valid_users
.iter()
.any(|(u, p)| u == username.as_str() && p == password))
}
async fn verify_identity(&self, username: &Username) -> Result<bool> {
Ok(self.valid_users.iter().any(|(u, _)| u == username.as_str()))
}
async fn list_users(&self) -> Result<Vec<Username>> {
Ok(vec![])
}
async fn create_user(&self, _username: &Username, _password: &str) -> Result<()> {
Ok(())
}
async fn delete_user(&self, _username: &Username) -> Result<()> {
Ok(())
}
async fn change_password(&self, _username: &Username, _new_password: &str) -> Result<()> {
Ok(())
}
}
#[tokio::test]
async fn test_plain_mechanism_success() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = PlainMechanism::new();
assert_eq!(mechanism.state(), SaslState::Initial);
let response = b"\0testuser\0testpass";
let result = mechanism.step(response, &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(success);
assert_eq!(username.unwrap().as_str(), "testuser");
}
_ => panic!("Expected Done step"),
}
assert_eq!(mechanism.state(), SaslState::Success);
}
#[tokio::test]
async fn test_plain_mechanism_failure() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = PlainMechanism::new();
let response = b"\0testuser\0wrongpass";
let result = mechanism.step(response, &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(!success);
assert!(username.is_none());
}
_ => panic!("Expected Done step"),
}
assert_eq!(mechanism.state(), SaslState::Failed);
}
#[tokio::test]
async fn test_plain_mechanism_invalid_format() {
let backend = MockAuthBackend {
valid_users: vec![],
};
let mut mechanism = PlainMechanism::new();
let response = b"invalid";
let result = mechanism.step(response, &backend).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_plain_mechanism_reset() {
let mut mechanism = PlainMechanism::new();
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let response = b"\0testuser\0testpass";
let _result = mechanism.step(response, &backend).await.unwrap();
assert_eq!(mechanism.state(), SaslState::Success);
mechanism.reset();
assert_eq!(mechanism.state(), SaslState::Initial);
}
#[tokio::test]
async fn test_plain_mechanism_with_authzid() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = PlainMechanism::new();
let response = b"admin\0testuser\0testpass";
let result = mechanism.step(response, &backend).await.unwrap();
match result {
SaslStep::Done { success, .. } => {
assert!(success);
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_login_mechanism_success() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = LoginMechanism::new();
let result = mechanism.step(b"", &backend).await.unwrap();
match result {
SaslStep::Challenge { data } => {
let decoded = BASE64.decode(&data).unwrap();
assert_eq!(String::from_utf8(decoded).unwrap(), "Username:");
}
_ => panic!("Expected Challenge step"),
}
let result = mechanism.step(b"testuser", &backend).await.unwrap();
match result {
SaslStep::Challenge { data } => {
let decoded = BASE64.decode(&data).unwrap();
assert_eq!(String::from_utf8(decoded).unwrap(), "Password:");
}
_ => panic!("Expected Challenge step"),
}
let result = mechanism.step(b"testpass", &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(success);
assert_eq!(username.unwrap().as_str(), "testuser");
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_login_mechanism_failure() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = LoginMechanism::new();
let _result = mechanism.step(b"", &backend).await.unwrap();
let _result = mechanism.step(b"testuser", &backend).await.unwrap();
let result = mechanism.step(b"wrongpass", &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(!success);
assert!(username.is_none());
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_login_mechanism_reset() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = LoginMechanism::new();
let _result = mechanism.step(b"", &backend).await.unwrap();
let _result = mechanism.step(b"testuser", &backend).await.unwrap();
mechanism.reset();
assert_eq!(mechanism.state(), SaslState::Initial);
assert!(mechanism.username.is_none());
assert_eq!(mechanism.step_count, 0);
}
#[tokio::test]
async fn test_cram_md5_mechanism_challenge() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = CramMd5Mechanism::new("localhost".to_string());
let result = mechanism.step(b"", &backend).await.unwrap();
match result {
SaslStep::Challenge { data } => {
let challenge = BASE64.decode(&data).unwrap();
let challenge_str = String::from_utf8(challenge).unwrap();
assert!(challenge_str.starts_with('<'));
assert!(challenge_str.ends_with('>'));
assert!(challenge_str.contains("@localhost"));
}
_ => panic!("Expected Challenge step"),
}
}
#[tokio::test]
async fn test_cram_md5_mechanism_valid_user() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = CramMd5Mechanism::new("localhost".to_string());
let _result = mechanism.step(b"", &backend).await.unwrap();
let response = "testuser 1234567890abcdef1234567890abcdef";
let result = mechanism.step(response.as_bytes(), &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(success); assert_eq!(username.unwrap().as_str(), "testuser");
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_cram_md5_mechanism_invalid_user() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = CramMd5Mechanism::new("localhost".to_string());
let _result = mechanism.step(b"", &backend).await.unwrap();
let response = "invaliduser 1234567890abcdef1234567890abcdef";
let result = mechanism.step(response.as_bytes(), &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(!success);
assert!(username.is_none());
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_cram_md5_mechanism_invalid_hmac_format() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "testpass".to_string())],
};
let mut mechanism = CramMd5Mechanism::new("localhost".to_string());
let _result = mechanism.step(b"", &backend).await.unwrap();
let response = "testuser invalidhmac";
let result = mechanism.step(response.as_bytes(), &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(!success);
assert!(username.is_none());
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_cram_md5_compute_hmac() {
let challenge = "<12345.67890@localhost>";
let password = "secret";
let hmac = CramMd5Mechanism::compute_hmac(password, challenge).unwrap();
assert_eq!(hmac.len(), 32); assert!(hmac.chars().all(|c| c.is_ascii_hexdigit()));
}
#[tokio::test]
async fn test_cram_md5_parse_response() {
let response = "user 1234567890abcdef1234567890abcdef";
let (username, hmac) = CramMd5Mechanism::parse_response(response).unwrap();
assert_eq!(username, "user");
assert_eq!(hmac, "1234567890abcdef1234567890abcdef");
}
#[tokio::test]
async fn test_cram_md5_parse_response_invalid() {
let response = "onlyusername";
let result = CramMd5Mechanism::parse_response(response);
assert!(result.is_err());
}
#[test]
fn test_scram_sha256_compute_salted_password() {
let password = "password";
let salt = b"salt";
let iterations = 4096;
let salted = ScramSha256Mechanism::compute_salted_password(password, salt, iterations);
assert_eq!(salted.len(), 32); }
#[test]
fn test_scram_sha256_compute_keys() {
let salted_password = vec![0u8; 32];
let client_key = ScramSha256Mechanism::compute_client_key(&salted_password).unwrap();
let stored_key = ScramSha256Mechanism::compute_stored_key(&client_key);
let server_key = ScramSha256Mechanism::compute_server_key(&salted_password).unwrap();
assert_eq!(client_key.len(), 32);
assert_eq!(stored_key.len(), 32);
assert_eq!(server_key.len(), 32);
}
#[test]
fn test_scram_sha256_parse_client_first() {
let msg = "n,,n=user,r=clientnonce";
let (username, nonce, bare) = ScramSha256Mechanism::parse_client_first(msg).unwrap();
assert_eq!(username, "user");
assert_eq!(nonce, "clientnonce");
assert_eq!(bare, "n=user,r=clientnonce");
}
#[test]
fn test_scram_sha256_parse_client_first_invalid() {
let msg = "invalid";
let result = ScramSha256Mechanism::parse_client_first(msg);
assert!(result.is_err());
}
#[test]
fn test_scram_sha256_parse_client_final() {
let msg = "c=biws,r=nonce,p=proof";
let (cb, nonce, proof, without_proof) =
ScramSha256Mechanism::parse_client_final(msg).unwrap();
assert_eq!(cb, "biws");
assert_eq!(nonce, "nonce");
assert_eq!(proof, "proof");
assert_eq!(without_proof, "c=biws,r=nonce");
}
#[test]
fn test_scram_sha256_generate_server_nonce() {
let nonce1 = ScramSha256Mechanism::generate_server_nonce();
let nonce2 = ScramSha256Mechanism::generate_server_nonce();
assert_eq!(nonce1.len(), 32); assert_ne!(nonce1, nonce2); }
#[tokio::test]
async fn test_xoauth2_mechanism_success() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "token".to_string())],
};
let mut mechanism = XOAuth2Mechanism::new();
let response = b"user=testuser\x01auth=Bearer validtoken\x01\x01";
let result = mechanism.step(response, &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(success);
assert_eq!(username.unwrap().as_str(), "testuser");
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_xoauth2_mechanism_failure() {
let backend = MockAuthBackend {
valid_users: vec![("testuser".to_string(), "token".to_string())],
};
let mut mechanism = XOAuth2Mechanism::new();
let response = b"user=invaliduser\x01auth=Bearer token\x01\x01";
let result = mechanism.step(response, &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(!success);
assert!(username.is_none());
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_xoauth2_parse_response() {
let response = b"user=testuser\x01auth=Bearer mytoken\x01\x01";
let (user, token) = XOAuth2Mechanism::parse_xoauth2_response(response).unwrap();
assert_eq!(user, "testuser");
assert_eq!(token, "mytoken");
}
#[tokio::test]
async fn test_xoauth2_parse_response_invalid() {
let response = b"invalid";
let result = XOAuth2Mechanism::parse_xoauth2_response(response);
assert!(result.is_err());
}
#[test]
fn test_sasl_server_default_config() {
let config = SaslConfig::default();
let server = SaslServer::new(config);
let mechanisms = server.enabled_mechanisms();
assert!(mechanisms.contains(&"PLAIN".to_string()));
assert!(mechanisms.contains(&"LOGIN".to_string()));
assert!(mechanisms.contains(&"CRAM-MD5".to_string()));
assert!(mechanisms.contains(&"SCRAM-SHA-256".to_string()));
assert!(mechanisms.contains(&"XOAUTH2".to_string()));
}
#[test]
fn test_sasl_server_custom_config() {
let config = SaslConfig {
enabled_mechanisms: vec!["PLAIN".to_string(), "LOGIN".to_string()],
hostname: "test.example.com".to_string(),
};
let server = SaslServer::new(config);
assert!(server.is_mechanism_enabled("PLAIN"));
assert!(server.is_mechanism_enabled("LOGIN"));
assert!(!server.is_mechanism_enabled("CRAM-MD5"));
}
#[test]
fn test_sasl_server_create_mechanism() {
let config = SaslConfig::default();
let server = SaslServer::new(config);
let plain = server.create_mechanism("PLAIN").unwrap();
assert_eq!(plain.name(), "PLAIN");
let login = server.create_mechanism("LOGIN").unwrap();
assert_eq!(login.name(), "LOGIN");
let cram = server.create_mechanism("CRAM-MD5").unwrap();
assert_eq!(cram.name(), "CRAM-MD5");
let scram = server.create_mechanism("SCRAM-SHA-256").unwrap();
assert_eq!(scram.name(), "SCRAM-SHA-256");
let xoauth2 = server.create_mechanism("XOAUTH2").unwrap();
assert_eq!(xoauth2.name(), "XOAUTH2");
}
#[test]
fn test_sasl_server_create_disabled_mechanism() {
let config = SaslConfig {
enabled_mechanisms: vec!["PLAIN".to_string()],
hostname: "localhost".to_string(),
};
let server = SaslServer::new(config);
let result = server.create_mechanism("CRAM-MD5");
assert!(result.is_err());
}
#[test]
fn test_sasl_server_create_unknown_mechanism() {
let config = SaslConfig::default();
let server = SaslServer::new(config);
let result = server.create_mechanism("UNKNOWN");
assert!(result.is_err());
}
#[test]
fn test_sasl_server_case_insensitive() {
let config = SaslConfig::default();
let server = SaslServer::new(config);
assert!(server.is_mechanism_enabled("plain"));
assert!(server.is_mechanism_enabled("PLAIN"));
assert!(server.is_mechanism_enabled("Plain"));
let plain = server.create_mechanism("plain").unwrap();
assert_eq!(plain.name(), "PLAIN");
}
#[tokio::test]
async fn test_plain_mechanism_full_flow() {
let backend = MockAuthBackend {
valid_users: vec![
("alice".to_string(), "password123".to_string()),
("bob".to_string(), "secret456".to_string()),
],
};
let config = SaslConfig::default();
let server = SaslServer::new(config);
let mut mechanism = server.create_mechanism("PLAIN").unwrap();
let response = b"\0alice\0password123";
let result = mechanism.step(response, &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(success);
assert_eq!(username.unwrap().as_str(), "alice");
}
_ => panic!("Expected Done step"),
}
}
#[tokio::test]
async fn test_login_mechanism_full_flow() {
let backend = MockAuthBackend {
valid_users: vec![("alice".to_string(), "password123".to_string())],
};
let config = SaslConfig::default();
let server = SaslServer::new(config);
let mut mechanism = server.create_mechanism("LOGIN").unwrap();
let result = mechanism.step(b"", &backend).await.unwrap();
assert!(matches!(result, SaslStep::Challenge { .. }));
let result = mechanism.step(b"alice", &backend).await.unwrap();
assert!(matches!(result, SaslStep::Challenge { .. }));
let result = mechanism.step(b"password123", &backend).await.unwrap();
match result {
SaslStep::Done { success, username } => {
assert!(success);
assert_eq!(username.unwrap().as_str(), "alice");
}
_ => panic!("Expected Done step"),
}
}
}