use crate::{Result, Error};
use sha2::{Sha256, Digest};
#[cfg(feature = "ring-crypto")]
use ring::pbkdf2;
#[cfg(feature = "ring-crypto")]
use ring::hmac;
use std::collections::HashMap;
use std::num::NonZeroU32;
use super::password_store::{InMemoryPasswordStore, SharedPasswordStore};
pub fn parse_scram_client_first(msg: &str) -> Result<(String, String)> {
let mut iter = msg.splitn(3, ',');
let _gs2_cbind_flag = iter.next().ok_or_else(|| {
Error::protocol("Invalid SCRAM client-first-message: missing GS2 channel-binding flag")
})?;
let _authzid = iter.next().ok_or_else(|| {
Error::protocol("Invalid SCRAM client-first-message: missing GS2 authzid slot")
})?;
let bare = iter.next().ok_or_else(|| {
Error::protocol("Invalid SCRAM client-first-message: missing client-first-message-bare")
})?;
if bare.is_empty() {
return Err(Error::protocol(
"Invalid SCRAM client-first-message: empty bare body",
));
}
let mut username: Option<&str> = None;
let mut nonce: Option<&str> = None;
for part in bare.split(',') {
if let Some(rest) = part.strip_prefix("n=") {
username = Some(rest);
} else if let Some(rest) = part.strip_prefix("r=") {
nonce = Some(rest);
}
}
let username = username.ok_or_else(|| {
Error::protocol("Invalid SCRAM client-first-message: missing username (n=)")
})?;
let nonce = nonce.ok_or_else(|| {
Error::protocol("Invalid SCRAM client-first-message: missing nonce (r=)")
})?;
if username.is_empty() {
return Err(Error::protocol(
"Invalid SCRAM client-first-message: empty username",
));
}
if nonce.is_empty() {
return Err(Error::protocol(
"Invalid SCRAM client-first-message: empty nonce",
));
}
Ok((username.to_string(), nonce.to_string()))
}
#[doc(hidden)]
pub fn parse_scram_client_first_for_test(msg: &str) -> Result<(String, String)> {
parse_scram_client_first(msg)
}
/// Authentication mechanism an [`AuthManager`] is configured for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMethod {
    /// Accept the connection without checking credentials.
    Trust,
    /// Client supplies the password in clear text.
    CleartextPassword,
    /// PostgreSQL-style salted MD5 challenge/response.
    Md5,
    /// SCRAM-SHA-256 SASL exchange.
    ScramSha256,
}
/// Stored credential record for one locally registered user.
#[derive(Clone, Debug)]
pub struct UserCredentials {
    /// Login name the record belongs to.
    pub username: String,
    /// Password digest. Records created by `AuthManager::add_user` hold a lowercase hex
    /// SHA-256 digest here; other producers may differ — NOTE(review): confirm per caller.
    pub password_hash: String,
    /// Optional salt associated with the digest (`None` for records made by `add_user`).
    pub salt: Option<Vec<u8>>,
}
pub struct AuthManager {
method: AuthMethod,
users: HashMap<String, UserCredentials>,
password_store: Option<SharedPasswordStore>,
}
impl AuthManager {
pub fn new(method: AuthMethod) -> Self {
Self {
method,
users: HashMap::new(),
password_store: None,
}
}
pub fn with_password_store(method: AuthMethod, password_store: SharedPasswordStore) -> Self {
Self {
method,
users: HashMap::new(),
password_store: Some(password_store),
}
}
pub fn with_scram_store(method: AuthMethod) -> Self {
let store = SharedPasswordStore::new(InMemoryPasswordStore::new());
Self::with_password_store(method, store)
}
pub fn add_user(&mut self, username: String, password: String) {
if self.method == AuthMethod::ScramSha256 {
if let Some(ref store) = self.password_store {
let _ = store.add_user(&username, &password);
return;
}
}
let password_hash = Self::hash_password(&password);
self.users.insert(
username.clone(),
UserCredentials {
username,
password_hash,
salt: None,
},
);
}
pub fn method(&self) -> AuthMethod {
self.method
}
pub fn password_store(&self) -> Option<&SharedPasswordStore> {
self.password_store.as_ref()
}
pub fn verify_cleartext(&self, username: &str, password: &str) -> Result<bool> {
if let Some(ref store) = self.password_store {
if let Some(creds) = store.get_credentials(username) {
return Ok(creds.verify_password(password));
} else {
let _ = Self::hash_password(password);
return Ok(false);
}
}
if let Some(user) = self.users.get(username) {
let password_hash = Self::hash_password(password);
Ok(user.password_hash == password_hash)
} else {
let _ = Self::hash_password(password);
Ok(false)
}
}
pub fn verify_md5(&self, username: &str, password: &str, salt: &[u8; 4]) -> Result<bool> {
if let Some(user) = self.users.get(username) {
let inner = format!("{}{}", password, username);
let inner_hash = format!("{:x}", md5::compute(inner.as_bytes()));
let mut outer_input = inner_hash.as_bytes().to_vec();
outer_input.extend_from_slice(salt);
let outer_hash = format!("md5{:x}", md5::compute(&outer_input));
Ok(outer_hash == user.password_hash)
} else {
Ok(false)
}
}
fn hash_password(password: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(password.as_bytes());
format!("{:x}", hasher.finalize())
}
pub fn with_default_users(mut self) -> Self {
self.add_user("postgres".to_string(), "postgres".to_string());
self.add_user("admin".to_string(), "admin".to_string());
self
}
}
/// Per-connection server-side state of one SCRAM-SHA-256 exchange.
#[derive(Clone, Debug)]
pub struct ScramAuthState {
    /// User being authenticated.
    username: String,
    /// Nonce supplied by the client in its first message (`r=`).
    client_nonce: String,
    /// Random nonce generated by the server and appended to the client nonce.
    server_nonce: String,
    /// Salt advertised to the client in the server-first-message (`s=`).
    salt: Vec<u8>,
    /// Iteration count advertised to the client (`i=`).
    iteration_count: u32,
    /// Client-first-message without its GS2 header; first part of the AuthMessage.
    client_first_message_bare: String,
    /// Server-first-message exactly as sent; second part of the AuthMessage.
    server_first_message: String,
}
impl ScramAuthState {
    /// Starts a SCRAM exchange for `username`.
    ///
    /// Generates a 24-character alphanumeric server nonce and a 16-byte random salt, and
    /// uses 4096 iterations. The salt is advertised to the client as `s=`, and the client
    /// derives its keys from it. Keys later passed to [`Self::verify_client_proof`] must
    /// therefore be derived from this same salt and iteration count.
    pub fn new(username: String) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let server_nonce: String = (0..24)
            .map(|_| rng.sample(rand::distributions::Alphanumeric) as char)
            .collect();
        let salt: Vec<u8> = (0..16).map(|_| rng.gen::<u8>()).collect();
        Self {
            username,
            client_nonce: String::new(),
            server_nonce,
            salt,
            iteration_count: 4096,
            client_first_message_bare: String::new(),
            server_first_message: String::new(),
        }
    }

    /// Records the nonce the client sent in its first message.
    pub fn set_client_nonce(&mut self, nonce: String) {
        self.client_nonce = nonce;
    }

    /// Records the client-first-message-bare; it becomes the first AuthMessage component.
    pub fn set_client_first_message_bare(&mut self, msg: String) {
        self.client_first_message_bare = msg;
    }

    /// Builds `r=<client+server nonce>,s=<base64 salt>,i=<iterations>`.
    ///
    /// The message is also remembered, because it forms the second part of the AuthMessage.
    pub fn build_server_first_message(&mut self) -> Result<String> {
        let salt_b64 = base64_encode(&self.salt)
            .map_err(|e| Error::authentication(format!("Failed to encode salt: {}", e)))?;
        let msg = format!(
            "r={}{},s={},i={}",
            self.client_nonce, self.server_nonce, salt_b64, self.iteration_count
        );
        self.server_first_message = msg.clone();
        Ok(msg)
    }

    /// Client nonce followed by server nonce, as advertised in the server-first-message.
    pub fn combined_nonce(&self) -> String {
        format!("{}{}", self.client_nonce, self.server_nonce)
    }

    /// Verifies the client's proof and, on success, returns the server signature.
    ///
    /// The AuthMessage is `client-first-bare "," server-first "," client-final-without-proof`.
    /// The client key is recovered as `proof XOR HMAC(stored_key, AuthMessage)`; its SHA-256
    /// must equal `stored_key`. The returned signature is `HMAC(server_key, AuthMessage)`.
    ///
    /// # Errors
    ///
    /// Returns an authentication error if any of these hold:
    /// - the proof is not valid base64;
    /// - the decoded proof is not exactly one HMAC-SHA-256 output (32 bytes) long;
    /// - the recovered key does not hash to `stored_key`.
    pub fn verify_client_proof(
        &self,
        client_proof_b64: &str,
        client_final_message_without_proof: &str,
        stored_key: &[u8],
        server_key: &[u8],
    ) -> Result<Vec<u8>> {
        let client_proof = base64_decode(client_proof_b64)
            .map_err(|e| Error::authentication(format!("Invalid client proof encoding: {}", e)))?;
        let auth_message = format!(
            "{},{},{}",
            self.client_first_message_bare,
            self.server_first_message,
            client_final_message_without_proof
        );
        let client_signature = scram_hmac_sha256(stored_key, auth_message.as_bytes());
        // `zip` below would silently truncate a proof of the wrong length; reject it outright,
        // using the same message as a wrong password so the failure mode is not revealed.
        if client_proof.len() != client_signature.len() {
            return Err(Error::authentication("Invalid password"));
        }
        let client_key: Vec<u8> = client_proof
            .iter()
            .zip(client_signature.iter())
            .map(|(a, b)| a ^ b)
            .collect();
        let computed_stored_key = scram_h(&client_key);
        if !constant_time_compare(&computed_stored_key, stored_key) {
            return Err(Error::authentication("Invalid password"));
        }
        let server_signature = scram_hmac_sha256(server_key, auth_message.as_bytes());
        Ok(server_signature)
    }

    /// Builds the server-final-message `v=<base64 server signature>`.
    pub fn build_server_final_message(&self, server_signature: &[u8]) -> Result<String> {
        let signature_b64 = base64_encode(server_signature)
            .map_err(|e| Error::authentication(format!("Failed to encode signature: {}", e)))?;
        Ok(format!("v={}", signature_b64))
    }

    /// User being authenticated.
    pub fn username(&self) -> &str {
        &self.username
    }

    /// Salt advertised to the client.
    pub fn salt(&self) -> &[u8] {
        &self.salt
    }

    /// Iteration count advertised to the client.
    pub fn iteration_count(&self) -> u32 {
        self.iteration_count
    }
}
/// RFC 5802 `Hi()`: PBKDF2 with HMAC-SHA-256, producing a 32-byte derived key.
///
/// An `iterations` value of 0 is not rejected; it is silently replaced by 4096.
pub fn scram_hi(password: &str, salt: &[u8], iterations: u32) -> Vec<u8> {
    const FALLBACK_ROUNDS: NonZeroU32 = match NonZeroU32::new(4096) {
        Some(rounds) => rounds,
        None => unreachable!(),
    };
    let rounds = NonZeroU32::new(iterations).unwrap_or(FALLBACK_ROUNDS);

    // 32 bytes = one SHA-256 output block.
    let mut derived = [0u8; 32];
    pbkdf2::derive(
        pbkdf2::PBKDF2_HMAC_SHA256,
        rounds,
        salt,
        password.as_bytes(),
        &mut derived,
    );
    derived.to_vec()
}
/// RFC 5802 `HMAC(key, str)` instantiated with HMAC-SHA-256; returns the 32-byte tag.
pub fn scram_hmac_sha256(key: &[u8], message: &[u8]) -> Vec<u8> {
    let tag = hmac::sign(&hmac::Key::new(hmac::HMAC_SHA256, key), message);
    let tag_bytes: &[u8] = tag.as_ref();
    tag_bytes.to_vec()
}
/// RFC 5802 `H(str)` instantiated with SHA-256; returns the 32-byte digest.
pub fn scram_h(input: &[u8]) -> Vec<u8> {
    Sha256::digest(input).to_vec()
}
pub fn scram_salted_password(password: &str, salt: &[u8], iterations: u32) -> Vec<u8> {
scram_hi(password, salt, iterations)
}
/// RFC 5802 `ClientKey := HMAC(SaltedPassword, "Client Key")`.
pub fn scram_client_key(salted_password: &[u8]) -> Vec<u8> {
    const CLIENT_KEY_LABEL: &[u8] = b"Client Key";
    scram_hmac_sha256(salted_password, CLIENT_KEY_LABEL)
}
pub fn scram_stored_key(client_key: &[u8]) -> Vec<u8> {
scram_h(client_key)
}
/// RFC 5802 `ServerKey := HMAC(SaltedPassword, "Server Key")`.
pub fn scram_server_key(salted_password: &[u8]) -> Vec<u8> {
    const SERVER_KEY_LABEL: &[u8] = b"Server Key";
    scram_hmac_sha256(salted_password, SERVER_KEY_LABEL)
}
/// Byte-slice equality that does not stop at the first differing byte.
///
/// Slices of different length return `false` straight away. Equal-length slices are
/// compared by OR-accumulating the XOR of every byte pair, so every byte is visited
/// regardless of where a mismatch occurs.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a
            .iter()
            .zip(b)
            .fold(0u8, |diff, (x, y)| diff | (x ^ y))
            == 0
}
/// Derives the SCRAM verifier pair `(StoredKey, ServerKey)` for a password.
///
/// Both keys come from the same `SaltedPassword = Hi(password, salt, iterations)`.
/// `StoredKey = H(HMAC(SaltedPassword, "Client Key"))` and
/// `ServerKey = HMAC(SaltedPassword, "Server Key")`.
pub fn prepare_scram_credentials(password: &str, salt: &[u8], iterations: u32) -> (Vec<u8>, Vec<u8>) {
    let salted = scram_salted_password(password, salt, iterations);
    let server_key = scram_server_key(&salted);
    let stored_key = scram_stored_key(&scram_client_key(&salted));
    (stored_key, server_key)
}
/// Encodes `data` as standard, padded base64.
///
/// Encoding cannot fail; the `Result` wrapper only mirrors `base64_decode`'s signature.
fn base64_encode(data: &[u8]) -> std::result::Result<String, Box<dyn std::error::Error>> {
    use base64::Engine as _;
    Ok(base64::engine::general_purpose::STANDARD.encode(data))
}
/// Decodes standard, padded base64 into bytes, boxing any decode error.
fn base64_decode(data: &str) -> std::result::Result<Vec<u8>, Box<dyn std::error::Error>> {
    use base64::Engine as _;
    base64::engine::general_purpose::STANDARD
        .decode(data)
        .map_err(Into::into)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Renders bytes as lowercase hex so digests can be compared with published test vectors.
    fn hex(bytes: &[u8]) -> String {
        use std::fmt::Write;
        bytes
            .iter()
            .fold(String::with_capacity(bytes.len() * 2), |mut out, byte| {
                let _ = write!(out, "{:02x}", byte);
                out
            })
    }

    #[test]
    fn test_auth_manager_creation() {
        let mut auth = AuthManager::new(AuthMethod::CleartextPassword);
        auth.add_user("test_user".to_string(), "test_pass".to_string());
        assert!(auth.verify_cleartext("test_user", "test_pass").unwrap());
        assert!(!auth.verify_cleartext("test_user", "wrong_pass").unwrap());
        assert!(!auth.verify_cleartext("unknown_user", "test_pass").unwrap());
    }

    #[test]
    fn test_password_hashing() {
        let hash1 = AuthManager::hash_password("password123");
        let hash2 = AuthManager::hash_password("password123");
        assert_eq!(hash1, hash2);
        assert_ne!(hash1, AuthManager::hash_password("password124"));
        // FIPS 180-2 SHA-256("abc").
        assert_eq!(
            AuthManager::hash_password("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_parse_scram_client_first() {
        let (username, nonce) =
            parse_scram_client_first("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL").unwrap();
        assert_eq!(username, "user");
        assert_eq!(nonce, "fyko+d2lbbFgONRv9qkxdawL");
        // Missing nonce, missing username, empty username, missing GS2 header.
        assert!(parse_scram_client_first("n,,n=user").is_err());
        assert!(parse_scram_client_first("n,,r=abc").is_err());
        assert!(parse_scram_client_first("n,,n=,r=abc").is_err());
        assert!(parse_scram_client_first("n=user,r=abc").is_err());
    }

    #[test]
    fn test_scram_state_creation() {
        let scram = ScramAuthState::new("testuser".to_string());
        assert_eq!(scram.username(), "testuser");
        assert_eq!(scram.salt().len(), 16);
        assert_eq!(scram.iteration_count(), 4096);
    }

    #[test]
    fn test_scram_server_first_message() {
        let mut scram = ScramAuthState::new("testuser".to_string());
        scram.set_client_nonce("clientnonce".to_string());
        let msg = scram.build_server_first_message().unwrap();
        assert!(msg.starts_with("r=clientnonce"));
        assert!(msg.contains(",s="));
        assert!(msg.contains(",i=4096"));
        assert!(msg.starts_with(&format!("r={},", scram.combined_nonce())));
    }

    #[test]
    fn test_scram_hi_function() {
        let result = scram_hi("pencil", b"salt", 4096);
        assert_eq!(result.len(), 32);
        // RFC 7914 §11: PBKDF2-HMAC-SHA256(P="passwd", S="salt", c=1); first 32 bytes.
        assert_eq!(
            hex(&scram_hi("passwd", b"salt", 1)),
            "55ac046e56e3089fec1691c22544b605f94185216dde0465e68b9d57c20dacbc"
        );
    }

    #[test]
    fn test_scram_hmac_sha256() {
        let result = scram_hmac_sha256(b"key", b"The quick brown fox jumps over the lazy dog");
        assert_eq!(result.len(), 32);
        assert_eq!(
            hex(&result),
            "f7bc83f430538424b13298e6aa6fb143ef4d59a14946175997479dbc2d1a3cd8"
        );
    }

    #[test]
    fn test_scram_h_function() {
        assert_eq!(scram_h(b"test data").len(), 32);
        assert_eq!(
            hex(&scram_h(b"abc")),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_scram_key_derivation() {
        let password = "pencil";
        let salt = b"salt1234567890ab";
        let iterations = 4096;
        let salted_password = scram_salted_password(password, salt, iterations);
        assert_eq!(salted_password.len(), 32);
        let client_key = scram_client_key(&salted_password);
        assert_eq!(client_key.len(), 32);
        let stored_key = scram_stored_key(&client_key);
        assert_eq!(stored_key.len(), 32);
        let server_key = scram_server_key(&salted_password);
        assert_eq!(server_key.len(), 32);
        assert_eq!(
            prepare_scram_credentials(password, salt, iterations),
            (stored_key, server_key)
        );
    }

    #[test]
    fn test_prepare_scram_credentials() {
        let password = "secret";
        let salt = b"randomsalt123456";
        let iterations = 4096;
        let (stored_key, server_key) = prepare_scram_credentials(password, salt, iterations);
        assert_eq!(stored_key.len(), 32);
        assert_eq!(server_key.len(), 32);
        assert_ne!(stored_key, server_key);
    }

    #[test]
    fn test_constant_time_compare() {
        let a = vec![1, 2, 3, 4];
        let b = vec![1, 2, 3, 4];
        let c = vec![1, 2, 3, 5];
        assert!(constant_time_compare(&a, &b));
        assert!(!constant_time_compare(&a, &c));
        assert!(!constant_time_compare(&a, &[1, 2, 3]));
    }

    #[test]
    fn test_base64_encoding() {
        let data = b"Hello, World!";
        let encoded = base64_encode(data).unwrap();
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(data.to_vec(), decoded);
        assert!(base64_decode("not base64!").is_err());
    }

    #[test]
    fn test_scram_proof_verification() {
        let password = "secret";
        let salt = b"randomsalt123456";
        let iterations = 4096;
        let salted_password = scram_salted_password(password, salt, iterations);
        let client_key = scram_client_key(&salted_password);
        let stored_key = scram_stored_key(&client_key);
        let server_key = scram_server_key(&salted_password);

        let mut scram = ScramAuthState::new("testuser".to_string());
        scram.set_client_nonce("clientnonce".to_string());
        let client_first_bare = "n=testuser,r=clientnonce";
        scram.set_client_first_message_bare(client_first_bare.to_string());
        let server_first = scram.build_server_first_message().unwrap();
        let client_final_without_proof = format!("c=biws,r={}", scram.combined_nonce());

        // Client side of RFC 5802: ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage).
        let auth_message = format!(
            "{},{},{}",
            client_first_bare, server_first, client_final_without_proof
        );
        let client_signature = scram_hmac_sha256(&stored_key, auth_message.as_bytes());
        let proof: Vec<u8> = client_key
            .iter()
            .zip(&client_signature)
            .map(|(k, s)| k ^ s)
            .collect();
        let proof_b64 = base64_encode(&proof).unwrap();

        let server_signature = scram
            .verify_client_proof(&proof_b64, &client_final_without_proof, &stored_key, &server_key)
            .unwrap();
        assert_eq!(
            server_signature,
            scram_hmac_sha256(&server_key, auth_message.as_bytes())
        );
        assert_eq!(
            scram.build_server_final_message(&server_signature).unwrap(),
            format!("v={}", base64_encode(&server_signature).unwrap())
        );

        // A single flipped bit must be rejected.
        let mut tampered = proof.clone();
        tampered[0] ^= 0x01;
        let tampered_b64 = base64_encode(&tampered).unwrap();
        assert!(scram
            .verify_client_proof(&tampered_b64, &client_final_without_proof, &stored_key, &server_key)
            .is_err());

        // A truncated proof must be rejected.
        let truncated_b64 = base64_encode(&proof[..16]).unwrap();
        assert!(scram
            .verify_client_proof(&truncated_b64, &client_final_without_proof, &stored_key, &server_key)
            .is_err());

        // Keys derived from a different password must not accept this proof.
        let (other_stored, other_server) = prepare_scram_credentials("wrong", salt, iterations);
        assert!(scram
            .verify_client_proof(&proof_b64, &client_final_without_proof, &other_stored, &other_server)
            .is_err());

        // Non-base64 input is an error, not a panic.
        assert!(scram
            .verify_client_proof("!!!", &client_final_without_proof, &stored_key, &server_key)
            .is_err());
    }
}