1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::{rngs::OsRng, Rng};
10use sha2::{Digest, Sha256};
11use std::fmt;
12
/// HMAC with SHA-256; used as the PRF for PBKDF2 and for every HMAC step of the
/// SCRAM exchange in this module.
type HmacSha256 = Hmac<Sha256>;

/// Largest server-supplied iteration count (`i=`) that `ScramClient::client_final`
/// accepts. Larger values are rejected before any key derivation runs, which bounds
/// the PBKDF2 work a server can make the client perform.
const MAX_SCRAM_ITERATIONS: u32 = 1_000_000;
22
/// Errors produced while running the client side of a SCRAM exchange.
///
/// Every variant carries a human-readable detail message, which `Display`
/// appends after a short category label.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ScramError {
    /// The server's final signature did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or violated the protocol.
    InvalidServerMessage(String),
    /// UTF-8 failure; the HMAC helpers in this module also report key-setup
    /// failures through this variant.
    Utf8Error(String),
    /// A Base64-encoded field could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the category label first, then emit "<label>: <detail>" once.
        let (label, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for ScramError {}
49
/// State carried from `ScramClient::client_final` to
/// `ScramClient::verify_server_final` for a single exchange.
///
/// `Debug` is implemented by hand rather than derived so that `server_key`, which
/// is derived from the password, is never printed by `{:?}` or debug logging.
#[derive(Clone)]
pub struct ScramState {
    /// The SCRAM AuthMessage (`client-first-bare,server-first,client-final-without-proof`).
    auth_message: Vec<u8>,
    /// ServerKey, used to recompute the server signature expected in `v=`.
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            .field("auth_message", &self.auth_message)
            // Secret key material: never render the bytes.
            .field("server_key", &"<redacted>")
            .finish()
    }
}
58
/// Client side of one SCRAM-SHA-256 authentication exchange.
///
/// Holds the credentials plus the client nonce chosen at construction time. The
/// type derives no traits, so it cannot be printed with `{:?}`.
pub struct ScramClient {
    /// Client nonce, sent as `r=` in the client-first-message.
    nonce: String,
    /// Login name; escaped before it is placed in any SCRAM message.
    username: String,
    /// Plain-text password; the input to PBKDF2 when deriving the SCRAM keys.
    password: String,
}
65
66impl ScramClient {
67 pub fn new(username: String, password: String) -> Self {
69 let mut rng = OsRng;
71 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
72 let nonce = BASE64.encode(&nonce_bytes);
73
74 Self {
75 username,
76 password,
77 nonce,
78 }
79 }
80
81 pub fn client_first(&self) -> String {
83 let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
88 format!("n,,n={},r={}", escaped_username, self.nonce)
89 }
90
91 pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
102 let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
104
105 if !server_nonce.starts_with(&self.nonce) {
107 return Err(ScramError::InvalidServerMessage(
108 "server nonce doesn't contain client nonce".to_string(),
109 ));
110 }
111
112 let salt_bytes = BASE64
114 .decode(&salt)
115 .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
116 let iterations = iterations
117 .parse::<u32>()
118 .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
119
120 if iterations > MAX_SCRAM_ITERATIONS {
123 return Err(ScramError::InvalidServerMessage(format!(
124 "server iteration count {iterations} exceeds maximum of {MAX_SCRAM_ITERATIONS}"
125 )));
126 }
127
128 let channel_binding = BASE64.encode(b"n,,");
130
131 let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
133
134 let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
140 let client_first_bare = format!("n={},r={}", escaped_username, self.nonce);
141 let auth_message = format!(
142 "{},{},{}",
143 client_first_bare, server_first, client_final_without_proof
144 );
145
146 let proof = calculate_client_proof(
148 &self.password,
149 &salt_bytes,
150 iterations,
151 auth_message.as_bytes(),
152 )?;
153
154 let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
156
157 let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
159
160 let state = ScramState {
161 auth_message: auth_message.into_bytes(),
162 server_key,
163 };
164
165 Ok((client_final, state))
166 }
167
168 pub fn verify_server_final(
176 &self,
177 server_final: &str,
178 state: &ScramState,
179 ) -> Result<(), ScramError> {
180 let server_sig_encoded = server_final
182 .strip_prefix("v=")
183 .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
184
185 let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
186 ScramError::Base64Error("invalid server signature encoding".to_string())
187 })?;
188
189 let expected_signature =
191 calculate_server_signature(&state.server_key, &state.auth_message)?;
192
193 if constant_time_compare(&server_signature, &expected_signature) {
195 Ok(())
196 } else {
197 Err(ScramError::InvalidServerProof(
198 "server signature verification failed".to_string(),
199 ))
200 }
201 }
202}
203
204fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
206 let mut nonce = String::new();
207 let mut salt = String::new();
208 let mut iterations = String::new();
209
210 for part in msg.split(',') {
211 if let Some(value) = part.strip_prefix("r=") {
212 nonce = value.to_string();
213 } else if let Some(value) = part.strip_prefix("s=") {
214 salt = value.to_string();
215 } else if let Some(value) = part.strip_prefix("i=") {
216 iterations = value.to_string();
217 }
218 }
219
220 if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
221 return Err(ScramError::InvalidServerMessage(
222 "missing required fields in server first message".to_string(),
223 ));
224 }
225
226 Ok((nonce, salt, iterations))
227}
228
229fn calculate_client_proof(
231 password: &str,
232 salt: &[u8],
233 iterations: u32,
234 auth_message: &[u8],
235) -> Result<Vec<u8>, ScramError> {
236 let password_bytes = password.as_bytes();
238 let mut salted_password = vec![0u8; 32]; let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
240
241 let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
243 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
244 client_key_hmac.update(b"Client Key");
245 let client_key = client_key_hmac.finalize().into_bytes();
246
247 let stored_key = Sha256::digest(client_key.to_vec().as_slice());
249
250 let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
252 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
253 client_sig_hmac.update(auth_message);
254 let client_signature = client_sig_hmac.finalize().into_bytes();
255
256 let mut proof = client_key.to_vec();
258 for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
259 *proof_byte ^= sig_byte;
260 }
261
262 Ok(proof.clone())
263}
264
265fn calculate_server_key(
267 password: &str,
268 salt: &[u8],
269 iterations: u32,
270) -> Result<Vec<u8>, ScramError> {
271 let password_bytes = password.as_bytes();
273 let mut salted_password = vec![0u8; 32];
274 let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
275
276 let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
278 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
279 server_key_hmac.update(b"Server Key");
280
281 Ok(server_key_hmac.finalize().into_bytes().to_vec())
282}
283
284fn calculate_server_signature(
286 server_key: &[u8],
287 auth_message: &[u8],
288) -> Result<Vec<u8>, ScramError> {
289 let mut hmac = HmacSha256::new_from_slice(server_key)
290 .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))?;
291 hmac.update(auth_message);
292 Ok(hmac.finalize().into_bytes().to_vec())
293}
294
295fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
299 use subtle::ConstantTimeEq;
300 a.ct_eq(b).into()
301}
302
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    // RFC 7677 §3 example exchange: username "user", password "pencil".
    const RFC7677_CLIENT_NONCE: &str = "rOprNGfwEbeRWgbNEkqO";
    const RFC7677_SERVER_FIRST: &str = concat!(
        "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,",
        "s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"
    );
    const RFC7677_CLIENT_FINAL: &str = concat!(
        "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,",
        "p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
    );
    const RFC7677_SERVER_FINAL: &str = "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=";

    /// Builds a client with a fixed nonce so that exchanges are reproducible.
    fn client_with_nonce(username: &str, password: &str, nonce: &str) -> ScramClient {
        ScramClient {
            username: username.to_string(),
            password: password.to_string(),
            nonce: nonce.to_string(),
        }
    }

    #[test]
    fn test_scram_client_creation() {
        let client = ScramClient::new("user".to_string(), "password".to_string());
        assert_eq!(client.username, "user");
        assert_eq!(client.password, "password");
        assert!(!client.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let client = ScramClient::new("alice".to_string(), "secret".to_string());
        let first = client.client_first();

        assert!(first.starts_with("n,,n=alice,r="));
        assert!(first.len() > 20);
    }

    #[test]
    fn test_rfc7677_client_first() {
        let client = client_with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        assert_eq!(client.client_first(), "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");
    }

    #[test]
    fn test_rfc7677_full_exchange() {
        let mut client = client_with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        let (client_final, state) = client.client_final(RFC7677_SERVER_FIRST).unwrap();
        assert_eq!(client_final, RFC7677_CLIENT_FINAL);

        let verified = client.verify_server_final(RFC7677_SERVER_FINAL, &state);
        assert!(verified.is_ok(), "expected RFC 7677 server signature to verify: {verified:?}");
    }

    #[test]
    fn test_verify_server_final_rejects_forged_signature() {
        let mut client = client_with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        let (_client_final, state) = client.client_final(RFC7677_SERVER_FIRST).unwrap();

        let forged = format!("v={}", BASE64.encode([0u8; 32]));
        let result = client.verify_server_final(&forged, &state);
        assert!(
            matches!(result, Err(ScramError::InvalidServerProof(_))),
            "expected InvalidServerProof for forged signature, got: {result:?}"
        );
    }

    #[test]
    fn test_verify_server_final_rejects_non_verifier_message() {
        let mut client = client_with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        let (_client_final, state) = client.client_final(RFC7677_SERVER_FIRST).unwrap();

        for server_final in [
            "e=invalid-proof",
            "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=",
        ] {
            let result = client.verify_server_final(server_final, &state);
            assert!(
                matches!(result, Err(ScramError::InvalidServerMessage(_))),
                "expected InvalidServerMessage for {server_final:?}, got: {result:?}"
            );
        }
    }

    #[test]
    fn test_client_final_rejects_foreign_server_nonce() {
        let mut client = client_with_nonce("user", "pencil", RFC7677_CLIENT_NONCE);
        let server_first = format!(
            "r=someone_elses_nonce,s={},i=4096",
            BASE64.encode(b"salty")
        );

        let result = client.client_final(&server_first);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage for foreign nonce, got: {result:?}"
        );
    }

    #[test]
    fn test_parse_server_first_valid() {
        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "client_nonce_server_nonce");
        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
        assert_eq!(iterations, "4096");
    }

    #[test]
    fn test_parse_server_first_invalid() {
        let server_first = "r=nonce,s=salt";
        let result = parse_server_first(server_first);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage error, got: {result:?}"
        );
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = b"test_value";
        let b_arr = b"test_value";
        assert!(constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different() {
        let a = b"test_value";
        let b_arr = b"test_wrong";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = b"test";
        let b_arr = b"test_longer";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));

        let result = client.client_final(&server_first);
        let (client_final, state) = result.unwrap_or_else(|e| {
            panic!("expected Ok for client_final with valid server message: {e}")
        });
        assert!(client_final.starts_with("c="));
        assert!(!state.auth_message.is_empty());
    }

    #[test]
    fn test_scram_iteration_count_too_high_is_rejected() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let excessive_iterations = MAX_SCRAM_ITERATIONS + 1;
        let server_first = format!(
            "r={},s={},i={}",
            server_nonce,
            BASE64.encode(b"salty"),
            excessive_iterations
        );

        let result = client.client_final(&server_first);
        assert!(
            matches!(result, Err(ScramError::InvalidServerMessage(_))),
            "expected InvalidServerMessage for excessive iterations, got: {result:?}"
        );
    }

    #[test]
    fn test_scram_iteration_count_at_limit_is_accepted() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!(
            "r={},s={},i={}",
            server_nonce,
            BASE64.encode(b"salty"),
            MAX_SCRAM_ITERATIONS
        );

        let result = client.client_final(&server_first);
        if let Err(ScramError::InvalidServerMessage(msg)) = &result {
            assert!(
                !msg.contains("iteration count"),
                "unexpected iteration-count rejection at limit: {msg}"
            );
        }
    }

    #[test]
    fn test_scram_username_escaping_in_auth_message() {
        let mut client = ScramClient::new("user,admin=evil".to_string(), "password".to_string());
        let client_first = client.client_first();
        assert!(
            client_first.contains("user=2Cadmin=3Devil"),
            "client_first should escape ',' and '=' in username, got: {client_first}"
        );

        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));

        let result = client.client_final(&server_first);
        let (_client_final, state) =
            result.unwrap_or_else(|e| panic!("expected Ok for escaped-username client_final: {e}"));

        let auth_message = String::from_utf8(state.auth_message).unwrap();
        assert!(
            auth_message.contains("user=2Cadmin=3Devil"),
            "auth_message should contain escaped username, got: {auth_message}"
        );
        assert!(
            !auth_message.contains("user,admin=evil"),
            "auth_message must NOT contain raw (unescaped) username, got: {auth_message}"
        );
    }
}