use crate::protocol::codec::MessageBuilder;
/// Appends a password message carrying `password` to `buf`.
///
/// The message is tagged with `msg_type::PASSWORD` and the password is written as a
/// C-style (zero-terminated) string via `write_cstr`.
pub fn write_password(buf: &mut Vec<u8>, password: &str) {
    let mut builder = MessageBuilder::new(buf, super::msg_type::PASSWORD);
    builder.write_cstr(password);
    builder.finish();
}
/// Computes the MD5 password response for `username`/`password` and the 4-byte `salt`.
///
/// The result is `"md5"` followed by the lowercase hex of
/// `md5(hex(md5(password || username)) || salt)`.
pub fn md5_password(username: &str, password: &str, salt: &[u8; 4]) -> String {
    use md5::{Digest, Md5};

    // Inner digest over password then username, rendered as lowercase hex text.
    let mut inner = Md5::new();
    inner.update(password.as_bytes());
    inner.update(username.as_bytes());
    let inner_hex = format!("{:x}", inner.finalize());

    // Outer digest over the hex text then the raw salt bytes.
    let mut outer = Md5::new();
    outer.update(inner_hex.as_bytes());
    outer.update(salt);
    format!("md5{:x}", outer.finalize())
}
/// Appends a SASL initial-response message to `buf`.
///
/// The message uses the same type tag as the password message (`msg_type::PASSWORD`).
/// Its payload is:
/// - `mechanism` as a zero-terminated string,
/// - the byte length of `initial_response` as an `i32`,
/// - the response bytes.
///
/// # Panics
///
/// Panics if `initial_response` is longer than `i32::MAX` bytes, because such a length
/// cannot be represented in the 32-bit length field.
pub fn write_sasl_initial_response(buf: &mut Vec<u8>, mechanism: &str, initial_response: &[u8]) {
    // A plain `as` cast would silently wrap to a negative length and corrupt the stream.
    // The check runs before the builder is created so `buf` is untouched on failure.
    let response_len = i32::try_from(initial_response.len())
        .expect("SASL initial response longer than i32::MAX bytes");
    let mut msg = MessageBuilder::new(buf, super::msg_type::PASSWORD);
    msg.write_cstr(mechanism);
    msg.write_i32(response_len);
    msg.write_bytes(initial_response);
    msg.finish();
}
/// Appends a SASL response message to `buf`.
///
/// The message is tagged with `msg_type::PASSWORD`. Its payload is `response` written
/// verbatim, with no length prefix of its own.
pub fn write_sasl_response(buf: &mut Vec<u8>, response: &[u8]) {
    let mut builder = MessageBuilder::new(buf, super::msg_type::PASSWORD);
    builder.write_bytes(response);
    builder.finish();
}
/// Client-side state for a SCRAM-SHA-256 exchange (RFC 5802 / RFC 7677).
///
/// Flow:
/// 1. Send [`ScramClient::client_first_message`].
/// 2. Pass the server's reply to [`ScramClient::process_server_first`] and send the
///    client-final-message it returns.
/// 3. Check the server's last message with [`ScramClient::verify_server_final`].
pub struct ScramClient {
    /// Base64 encoding of 24 random bytes; the server must echo it as a prefix of its nonce.
    nonce: String,
    /// GS2 header that prefixes client-first-message: `"n,,"` without channel binding, or
    /// `"p=tls-server-end-point,,"` with it.
    channel_binding: String,
    /// Raw channel-binding data (empty when not binding). It is never sent on its own, only
    /// appended to the GS2 header inside the base64 `c=` attribute.
    channel_binding_data: Vec<u8>,
    /// Password fed to PBKDF2 as raw UTF-8 bytes (no SASLprep normalization is applied here).
    password: String,
    /// Server-first-message as received.
    server_first: Option<String>,
    /// AuthMessage computed by `process_server_first`; needed to verify the server signature.
    auth_message: Option<String>,
    /// SaltedPassword computed by `process_server_first`; needed to derive ServerKey.
    salted_password: Option<Vec<u8>>,
}

impl ScramClient {
    /// Creates a client that does not use channel binding (GS2 header `n,,`).
    pub fn new(password: &str) -> Self {
        Self::from_parts(password, "n,,".to_string(), Vec::new())
    }

    /// Creates a client that uses `tls-server-end-point` channel binding.
    ///
    /// `channel_binding_data` is the raw binding data. It is concatenated after the GS2
    /// header and base64-encoded into the `c=` attribute of the client-final-message.
    pub fn new_with_channel_binding(password: &str, channel_binding_data: &[u8]) -> Self {
        Self::from_parts(
            password,
            "p=tls-server-end-point,,".to_string(),
            channel_binding_data.to_vec(),
        )
    }

    /// Shared constructor: generates a fresh random nonce and stores the binding parameters.
    fn from_parts(password: &str, gs2_header: String, channel_binding_data: Vec<u8>) -> Self {
        use rand::Rng;
        let mut nonce_bytes = [0u8; 24];
        rand::rng().fill(&mut nonce_bytes);
        let nonce =
            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, nonce_bytes);
        Self {
            nonce,
            channel_binding: gs2_header,
            channel_binding_data,
            password: password.to_string(),
            server_first: None,
            auth_message: None,
            salted_password: None,
        }
    }

    /// Returns client-first-message: the GS2 header followed by client-first-message-bare.
    ///
    /// The user name is left empty (`n=`).
    pub fn client_first_message(&self) -> String {
        format!("{}{}", self.channel_binding, self.client_first_message_bare())
    }

    /// Returns client-first-message-bare (`n=,r=<nonce>`), which is part of AuthMessage.
    fn client_first_message_bare(&self) -> String {
        format!("n=,r={}", self.nonce)
    }

    /// Parses server-first-message and returns the client-final-message (with proof).
    ///
    /// Side effects: records the server message, the AuthMessage, and the SaltedPassword
    /// so that [`ScramClient::verify_server_final`] can check the server's signature later.
    ///
    /// # Errors
    ///
    /// Returns `Err` in any of these cases:
    /// - the nonce, salt or iteration count is missing or malformed;
    /// - the iteration count is zero;
    /// - the server demands a mandatory extension (`m=`);
    /// - the server nonce does not extend the client nonce.
    pub fn process_server_first(&mut self, server_first: &str) -> Result<String, String> {
        use base64::Engine;
        use hmac::{Hmac, Mac};
        use pbkdf2::pbkdf2_hmac;
        use sha2::{Digest, Sha256};

        self.server_first = Some(server_first.to_string());

        let mut combined_nonce = None;
        let mut salt_b64 = None;
        let mut iterations_str = None;
        for part in server_first.split(',') {
            if let Some(value) = part.strip_prefix("r=") {
                combined_nonce = Some(value);
            } else if let Some(value) = part.strip_prefix("s=") {
                salt_b64 = Some(value);
            } else if let Some(value) = part.strip_prefix("i=") {
                iterations_str = Some(value);
            } else if part.starts_with("m=") {
                // RFC 5802: a client that doesn't support a mandatory extension must fail.
                return Err(
                    "Unsupported mandatory extension in server-first-message".to_string(),
                );
            }
        }
        let combined_nonce = combined_nonce.ok_or("Missing nonce in server-first-message")?;
        let salt_b64 = salt_b64.ok_or("Missing salt in server-first-message")?;
        let iterations = iterations_str
            .ok_or("Missing iterations in server-first-message")?
            .parse::<u32>()
            .ok()
            .filter(|&count| count > 0)
            .ok_or("Invalid iterations in server-first-message")?;
        if !combined_nonce.starts_with(&self.nonce) {
            return Err("Server nonce doesn't start with client nonce".to_string());
        }
        let salt = base64::engine::general_purpose::STANDARD
            .decode(salt_b64)
            .map_err(|e| format!("Invalid salt: {}", e))?;

        // SaltedPassword := Hi(password, salt, i), i.e. PBKDF2-HMAC-SHA-256.
        let mut salted_password = vec![0u8; 32];
        pbkdf2_hmac::<Sha256>(
            self.password.as_bytes(),
            &salt,
            iterations,
            &mut salted_password,
        );

        // ClientKey := HMAC(SaltedPassword, "Client Key"); StoredKey := SHA-256(ClientKey).
        let client_key = {
            let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(&salted_password)
                .map_err(|e| format!("HMAC error: {}", e))?;
            mac.update(b"Client Key");
            mac.finalize().into_bytes()
        };
        let stored_key = Sha256::digest(client_key);

        // The c= attribute is base64(gs2-header || raw channel-binding data) (RFC 5802
        // `cbind-input`). The binding data must be raw bytes, not text.
        let mut cbind_input = self.channel_binding.as_bytes().to_vec();
        cbind_input.extend_from_slice(&self.channel_binding_data);
        let channel_binding_b64 = base64::engine::general_purpose::STANDARD.encode(&cbind_input);
        let client_final_without_proof = format!("c={},r={}", channel_binding_b64, combined_nonce);

        // AuthMessage := client-first-bare "," server-first "," client-final-without-proof.
        let auth_message = format!(
            "{},{},{}",
            self.client_first_message_bare(),
            server_first,
            client_final_without_proof
        );

        // ClientSignature := HMAC(StoredKey, AuthMessage); ClientProof := ClientKey XOR it.
        let client_signature = {
            let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(&stored_key)
                .map_err(|e| format!("HMAC error: {}", e))?;
            mac.update(auth_message.as_bytes());
            mac.finalize().into_bytes()
        };
        let client_proof: Vec<u8> = client_key
            .iter()
            .zip(client_signature.iter())
            .map(|(key_byte, sig_byte)| key_byte ^ sig_byte)
            .collect();
        let proof_b64 = base64::engine::general_purpose::STANDARD.encode(&client_proof);

        self.salted_password = Some(salted_password);
        self.auth_message = Some(auth_message);
        Ok(format!("{},p={}", client_final_without_proof, proof_b64))
    }

    /// Verifies the server's signature in server-final-message.
    ///
    /// # Errors
    ///
    /// Returns `Err` in any of these cases:
    /// - the server reported an error (`e=...`);
    /// - the message is not `v=<base64>`;
    /// - `process_server_first` has not succeeded yet;
    /// - the signature does not match.
    pub fn verify_server_final(&self, server_final: &str) -> Result<(), String> {
        use base64::Engine;
        use hmac::{Hmac, Mac};

        // server-final-message is "e=<server-error>" or "v=<verifier>[,<extensions>]".
        if let Some(server_error) = server_final.strip_prefix("e=") {
            return Err(format!("Server reported SCRAM error: {}", server_error));
        }
        let verifier = server_final
            .strip_prefix("v=")
            .ok_or("Invalid server-final-message format")?;
        // Ignore any extension attributes that follow the verifier.
        let server_signature_b64 = verifier.split_once(',').map_or(verifier, |(sig, _)| sig);
        let server_signature = base64::engine::general_purpose::STANDARD
            .decode(server_signature_b64)
            .map_err(|e| format!("Invalid server signature: {}", e))?;
        let salted_password = self
            .salted_password
            .as_ref()
            .ok_or("Missing salted password")?;
        let auth_message = self.auth_message.as_ref().ok_or("Missing auth message")?;

        // ServerKey := HMAC(SaltedPassword, "Server Key").
        let server_key = {
            let mut mac = <Hmac<sha2::Sha256> as Mac>::new_from_slice(salted_password)
                .map_err(|e| format!("HMAC error: {}", e))?;
            mac.update(b"Server Key");
            mac.finalize().into_bytes()
        };

        // ServerSignature := HMAC(ServerKey, AuthMessage).
        // `verify_slice` checks the length and compares in constant time.
        let mut mac = <Hmac<sha2::Sha256> as Mac>::new_from_slice(&server_key)
            .map_err(|e| format!("HMAC error: {}", e))?;
        mac.update(auth_message.as_bytes());
        mac.verify_slice(&server_signature)
            .map_err(|_| "Server signature verification failed".to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn md5_password_hash() {
        let salt = [0x01, 0x02, 0x03, 0x04];
        let hashed = md5_password("postgres", "password", &salt);
        // Expected shape: the literal "md5" followed by a 32-character hex digest.
        assert!(hashed.starts_with("md5"));
        assert_eq!(hashed.len(), 35);
    }

    #[test]
    fn password_message() {
        let mut out = Vec::new();
        write_password(&mut out, "secret");
        // The first byte is the message type tag.
        // The password is written with `write_cstr`, so the message ends in a zero byte.
        assert_eq!(out.first(), Some(&b'p'));
        assert_eq!(out.last(), Some(&0));
    }
}