1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::Rng;
10use sha2::{Digest, Sha256};
11use std::fmt;
12
/// HMAC instantiated with SHA-256. Used as the PRF for the PBKDF2 password
/// derivation and for every keyed hash computed in this module.
type HmacSha256 = Hmac<Sha256>;
14
/// Errors produced while running the client side of a SCRAM exchange.
///
/// Every variant carries a human-readable detail string.
#[derive(Clone, Debug)]
pub enum ScramError {
    /// The server's final signature (`v=`) did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or violated a protocol requirement.
    InvalidServerMessage(String),
    /// Named for UTF-8 failures. NOTE(review): nothing in this module decodes UTF-8,
    /// so check call sites before relying on this variant's meaning.
    Utf8Error(String),
    /// A Base64-encoded field (salt or server signature) could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the category label first, then render "<label>: <detail>" once.
        let (label, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for ScramError {}
40
/// Channel-binding mode negotiated through the GS2 header.
#[derive(Debug, Clone)]
pub enum ChannelBinding {
    /// No channel binding; the client sends the `n` GS2 flag.
    None,
    /// `tls-server-end-point` binding. The bytes are appended verbatim after the
    /// GS2 header inside the `c=` attribute of the client-final-message.
    /// NOTE(review): presumably the RFC 5929 certificate hash supplied by the
    /// TLS layer; confirm with callers.
    TlsServerEndPoint(Vec<u8>),
}
49
/// Data carried from the client-final step to server-final verification.
///
/// Holds the exact AuthMessage bytes and the ServerKey derived from the password.
/// These are needed to recompute the signature the server must prove.
#[derive(Clone)]
pub struct ScramState {
    auth_message: Vec<u8>,
    server_key: Vec<u8>,
}

// Hand-written instead of derived: `server_key` is password-derived key material
// (anyone holding it can impersonate the server), so it must never reach logs.
impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            .field("auth_message", &self.auth_message)
            .field("server_key", &"<redacted>")
            .finish()
    }
}
58
59pub struct ScramClient {
61 username: String,
62 password: String,
63 nonce: String,
64 channel_binding: ChannelBinding,
65}
66
67impl ScramClient {
68 pub fn new(username: String, password: String) -> Self {
70 Self::with_channel_binding(username, password, ChannelBinding::None)
71 }
72
73 pub fn with_channel_binding(
75 username: String,
76 password: String,
77 channel_binding: ChannelBinding,
78 ) -> Self {
79 let mut rng = rand::thread_rng();
80 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
81 let nonce = BASE64.encode(&nonce_bytes);
82
83 Self {
84 username,
85 password,
86 nonce,
87 channel_binding,
88 }
89 }
90
91 fn gs2_header(&self) -> &'static str {
93 match self.channel_binding {
94 ChannelBinding::None => "n",
95 ChannelBinding::TlsServerEndPoint(_) => "p=tls-server-end-point",
96 }
97 }
98
99 pub fn client_first(&self) -> String {
101 format!("{},a={},r={}", self.gs2_header(), self.username, self.nonce)
102 }
103
104 pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
108 let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
110
111 if !server_nonce.starts_with(&self.nonce) {
113 return Err(ScramError::InvalidServerMessage(
114 "server nonce doesn't contain client nonce".to_string(),
115 ));
116 }
117
118 let salt_bytes = BASE64
120 .decode(&salt)
121 .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
122 let iterations = iterations
123 .parse::<u32>()
124 .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
125
126 let gs2_cbind = match &self.channel_binding {
129 ChannelBinding::None => {
130 b"n,,".to_vec()
132 }
133 ChannelBinding::TlsServerEndPoint(data) => {
134 let mut buf = b"p=tls-server-end-point,,".to_vec();
136 buf.extend_from_slice(data);
137 buf
138 }
139 };
140 let channel_binding = BASE64.encode(&gs2_cbind);
141
142 let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
144
145 let client_first_bare = format!("a={},r={}", self.username, self.nonce);
147 let auth_message = format!(
148 "{},{},{}",
149 client_first_bare, server_first, client_final_without_proof
150 );
151
152 let proof = calculate_client_proof(
154 &self.password,
155 &salt_bytes,
156 iterations,
157 auth_message.as_bytes(),
158 )?;
159
160 let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
162
163 let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
165
166 let state = ScramState {
167 auth_message: auth_message.into_bytes(),
168 server_key,
169 };
170
171 Ok((client_final, state))
172 }
173
174 pub fn verify_server_final(
176 &self,
177 server_final: &str,
178 state: &ScramState,
179 ) -> Result<(), ScramError> {
180 let server_sig_encoded = server_final
182 .strip_prefix("v=")
183 .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
184
185 let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
186 ScramError::Base64Error("invalid server signature encoding".to_string())
187 })?;
188
189 let expected_signature = calculate_server_signature(&state.server_key, &state.auth_message);
191
192 if constant_time_compare(&server_signature, &expected_signature) {
194 Ok(())
195 } else {
196 Err(ScramError::InvalidServerProof(
197 "server signature verification failed".to_string(),
198 ))
199 }
200 }
201}
202
203fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
205 let mut nonce = String::new();
206 let mut salt = String::new();
207 let mut iterations = String::new();
208
209 for part in msg.split(',') {
210 if let Some(value) = part.strip_prefix("r=") {
211 nonce = value.to_string();
212 } else if let Some(value) = part.strip_prefix("s=") {
213 salt = value.to_string();
214 } else if let Some(value) = part.strip_prefix("i=") {
215 iterations = value.to_string();
216 }
217 }
218
219 if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
220 return Err(ScramError::InvalidServerMessage(
221 "missing required fields in server first message".to_string(),
222 ));
223 }
224
225 Ok((nonce, salt, iterations))
226}
227
228fn calculate_client_proof(
230 password: &str,
231 salt: &[u8],
232 iterations: u32,
233 auth_message: &[u8],
234) -> Result<Vec<u8>, ScramError> {
235 let password_bytes = password.as_bytes();
237 let mut salted_password = vec![0u8; 32]; let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
239
240 let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
242 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
243 client_key_hmac.update(b"Client Key");
244 let client_key = client_key_hmac.finalize().into_bytes();
245
246 let stored_key = Sha256::digest(client_key.to_vec().as_slice());
248
249 let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
251 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
252 client_sig_hmac.update(auth_message);
253 let client_signature = client_sig_hmac.finalize().into_bytes();
254
255 let mut proof = client_key.to_vec();
257 for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
258 *proof_byte ^= sig_byte;
259 }
260
261 Ok(proof.to_vec())
262}
263
264fn calculate_server_key(
266 password: &str,
267 salt: &[u8],
268 iterations: u32,
269) -> Result<Vec<u8>, ScramError> {
270 let password_bytes = password.as_bytes();
272 let mut salted_password = vec![0u8; 32];
273 let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
274
275 let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
277 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
278 server_key_hmac.update(b"Server Key");
279
280 Ok(server_key_hmac.finalize().into_bytes().to_vec())
281}
282
283fn calculate_server_signature(server_key: &[u8], auth_message: &[u8]) -> Vec<u8> {
285 let mut hmac = HmacSha256::new_from_slice(server_key).expect("HMAC key should be valid");
286 hmac.update(auth_message);
287 hmac.finalize().into_bytes().to_vec()
288}
289
290fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
292 if a.len() != b.len() {
293 return false;
294 }
295 let mut result = 0u8;
296 for (x, y) in a.iter().zip(b.iter()) {
297 result |= x ^ y;
298 }
299 result == 0
300}
301
302#[cfg(test)]
303mod tests {
304 use super::*;
305
306 #[test]
307 fn test_scram_client_creation() {
308 let client = ScramClient::new("user".to_string(), "password".to_string());
309 assert_eq!(client.username, "user");
310 assert_eq!(client.password, "password");
311 assert!(!client.nonce.is_empty());
312 }
313
314 #[test]
315 fn test_client_first_message_format() {
316 let client = ScramClient::new("alice".to_string(), "secret".to_string());
317 let first = client.client_first();
318
319 assert!(first.starts_with("n,a=alice,r="));
320 assert!(first.len() > 20);
321 }
322
323 #[test]
324 fn test_parse_server_first_valid() {
325 let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
326 let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();
327
328 assert_eq!(nonce, "client_nonce_server_nonce");
329 assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
330 assert_eq!(iterations, "4096");
331 }
332
333 #[test]
334 fn test_parse_server_first_invalid() {
335 let server_first = "r=nonce,s=salt"; assert!(parse_server_first(server_first).is_err());
337 }
338
339 #[test]
340 fn test_constant_time_compare_equal() {
341 let a = b"test_value";
342 let b_arr = b"test_value";
343 assert!(constant_time_compare(a, b_arr));
344 }
345
346 #[test]
347 fn test_constant_time_compare_different() {
348 let a = b"test_value";
349 let b_arr = b"test_wrong";
350 assert!(!constant_time_compare(a, b_arr));
351 }
352
353 #[test]
354 fn test_constant_time_compare_different_length() {
355 let a = b"test";
356 let b_arr = b"test_longer";
357 assert!(!constant_time_compare(a, b_arr));
358 }
359
360 #[test]
361 fn test_client_first_with_channel_binding() {
362 let client = ScramClient::with_channel_binding(
363 "alice".to_string(),
364 "secret".to_string(),
365 ChannelBinding::TlsServerEndPoint(vec![1, 2, 3, 4]),
366 );
367 let first = client.client_first();
368 assert!(first.starts_with("p=tls-server-end-point,a=alice,r="));
370 }
371
372 #[test]
373 fn test_client_first_without_channel_binding() {
374 let client = ScramClient::new("alice".to_string(), "secret".to_string());
375 let first = client.client_first();
376 assert!(first.starts_with("n,a=alice,r="));
378 }
379
380 #[test]
381 fn test_client_final_with_channel_binding() {
382 let binding_data = vec![0xDE, 0xAD, 0xBE, 0xEF];
384 let mut client = ScramClient::with_channel_binding(
385 "user".to_string(),
386 "password".to_string(),
387 ChannelBinding::TlsServerEndPoint(binding_data.clone()),
388 );
389 let _first = client.client_first();
390
391 let server_nonce = format!("{}server_part", client.nonce);
392 let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
393
394 let (client_final, _state) = client.client_final(&server_first).unwrap();
395
396 let c_value = client_final
398 .split(',')
399 .find(|s| s.starts_with("c="))
400 .unwrap()
401 .strip_prefix("c=")
402 .unwrap();
403 let decoded = BASE64.decode(c_value).unwrap();
404 let header = b"p=tls-server-end-point,,";
406 assert!(decoded.starts_with(header));
407 assert_eq!(&decoded[header.len()..], &binding_data);
409 }
410
411 #[test]
412 fn test_scram_client_final_flow() {
413 let mut client = ScramClient::new("user".to_string(), "password".to_string());
414 let _client_first = client.client_first();
415
416 let server_nonce = format!("{}server_nonce_part", client.nonce);
418 let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
419
420 let result = client.client_final(&server_first);
422 assert!(result.is_ok());
423
424 let (client_final, state) = result.unwrap();
425 assert!(client_final.starts_with("c="));
426 assert!(!state.auth_message.is_empty());
427 }
428
429 #[test]
432 fn test_parse_server_first_missing_nonce() {
433 let result = parse_server_first("s=c2FsdA==,i=4096");
434 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
435 }
436
437 #[test]
438 fn test_parse_server_first_missing_salt() {
439 let result = parse_server_first("r=nonce,i=4096");
440 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
441 }
442
443 #[test]
444 fn test_parse_server_first_missing_iterations() {
445 let result = parse_server_first("r=nonce,s=c2FsdA==");
446 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
447 }
448
449 #[test]
450 fn test_parse_server_first_empty_string() {
451 let result = parse_server_first("");
452 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
453 }
454
455 #[test]
456 fn test_parse_server_first_empty_values() {
457 let result = parse_server_first("r=,s=,i=");
458 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
459 }
460
461 #[test]
462 fn test_parse_server_first_extra_fields_ignored() {
463 let result = parse_server_first("r=nonce123,x=junk,s=c2FsdA==,i=4096");
464 let (nonce, salt, iterations) = result.unwrap();
465 assert_eq!(nonce, "nonce123");
466 assert_eq!(salt, "c2FsdA==");
467 assert_eq!(iterations, "4096");
468 }
469
470 #[test]
471 fn test_parse_server_first_different_field_order() {
472 let result = parse_server_first("s=c2FsdA==,i=4096,r=nonce123");
473 let (nonce, salt, iterations) = result.unwrap();
474 assert_eq!(nonce, "nonce123");
475 assert_eq!(salt, "c2FsdA==");
476 assert_eq!(iterations, "4096");
477 }
478
479 #[test]
482 fn test_client_final_nonce_prefix_mismatch() {
483 let mut client = ScramClient::new("user".to_string(), "pass".to_string());
484 let _first = client.client_first();
485
486 let server_first = format!(
487 "r=TAMPERED_NONCE_server_ext,s={},i=4096",
488 BASE64.encode(b"salty")
489 );
490 let result = client.client_final(&server_first);
491 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
492 }
493
494 #[test]
495 fn test_client_final_nonce_identical_to_client() {
496 let mut client = ScramClient::new("user".to_string(), "pass".to_string());
497 let _first = client.client_first();
498 let client_nonce = client.nonce.clone();
499
500 let server_first = format!("r={},s={},i=4096", client_nonce, BASE64.encode(b"salty"));
502 let result = client.client_final(&server_first);
503 assert!(result.is_ok());
504 }
505
506 #[test]
509 fn test_client_final_invalid_base64_salt() {
510 let mut client = ScramClient::new("user".to_string(), "pass".to_string());
511 let _first = client.client_first();
512
513 let server_first = format!("r={}server_ext,s=!!!not-base64!!!,i=4096", client.nonce);
514 let result = client.client_final(&server_first);
515 assert!(matches!(result, Err(ScramError::Base64Error(_))));
516 }
517
518 #[test]
519 fn test_client_final_non_numeric_iterations() {
520 let mut client = ScramClient::new("user".to_string(), "pass".to_string());
521 let _first = client.client_first();
522
523 let server_first = format!(
524 "r={}server_ext,s={},i=abc",
525 client.nonce,
526 BASE64.encode(b"salty")
527 );
528 let result = client.client_final(&server_first);
529 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
530 }
531
532 #[test]
533 fn test_client_final_zero_iterations() {
534 let mut client = ScramClient::new("user".to_string(), "pass".to_string());
535 let _first = client.client_first();
536
537 let server_first = format!(
538 "r={}server_ext,s={},i=0",
539 client.nonce,
540 BASE64.encode(b"salty")
541 );
542 let result = client.client_final(&server_first);
544 assert!(result.is_ok());
545 }
546
547 #[test]
550 fn test_verify_server_final_missing_v_prefix() {
551 let client = ScramClient::new("user".to_string(), "pass".to_string());
552 let state = ScramState {
553 auth_message: b"dummy".to_vec(),
554 server_key: vec![0; 32],
555 };
556 let result = client.verify_server_final("not_a_valid_response", &state);
557 assert!(matches!(result, Err(ScramError::InvalidServerMessage(_))));
558 }
559
560 #[test]
561 fn test_verify_server_final_empty_after_v() {
562 let client = ScramClient::new("user".to_string(), "pass".to_string());
563 let state = ScramState {
564 auth_message: b"dummy".to_vec(),
565 server_key: vec![0; 32],
566 };
567 let result = client.verify_server_final("v=", &state);
569 assert!(matches!(result, Err(ScramError::InvalidServerProof(_))));
570 }
571
572 #[test]
573 fn test_verify_server_final_invalid_base64() {
574 let client = ScramClient::new("user".to_string(), "pass".to_string());
575 let state = ScramState {
576 auth_message: b"dummy".to_vec(),
577 server_key: vec![0; 32],
578 };
579 let result = client.verify_server_final("v=!!!invalid!!!", &state);
580 assert!(matches!(result, Err(ScramError::Base64Error(_))));
581 }
582
583 #[test]
584 fn test_verify_server_final_wrong_signature() {
585 let client = ScramClient::new("user".to_string(), "pass".to_string());
586 let state = ScramState {
587 auth_message: b"auth_msg".to_vec(),
588 server_key: vec![0x42; 32],
589 };
590 let wrong_sig = BASE64.encode(vec![0xFF; 32]);
592 let result = client.verify_server_final(&format!("v={}", wrong_sig), &state);
593 assert!(matches!(result, Err(ScramError::InvalidServerProof(_))));
594 }
595
596 #[test]
597 fn test_verify_server_final_correct_signature() {
598 let mut client = ScramClient::new("user".to_string(), "password".to_string());
599 let _first = client.client_first();
600
601 let server_nonce = format!("{}server_ext", client.nonce);
602 let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
603
604 let (_client_final, state) = client.client_final(&server_first).unwrap();
605
606 let expected = calculate_server_signature(&state.server_key, &state.auth_message);
608 let server_final = format!("v={}", BASE64.encode(&expected));
609
610 let result = client.verify_server_final(&server_final, &state);
611 assert!(result.is_ok());
612 }
613
614 #[test]
617 fn test_constant_time_compare_both_empty() {
618 assert!(constant_time_compare(&[], &[]));
619 }
620
621 #[test]
622 fn test_constant_time_compare_one_empty() {
623 assert!(!constant_time_compare(&[], &[1]));
624 }
625
626 #[test]
627 fn test_constant_time_compare_single_bit_flip() {
628 let a = vec![0b1010_1010; 32];
629 let mut b = a.clone();
630 b[15] ^= 0b0000_0001; assert!(!constant_time_compare(&a, &b));
632 }
633
634 #[test]
637 fn test_channel_binding_empty_data() {
638 let mut client = ScramClient::with_channel_binding(
639 "user".to_string(),
640 "pass".to_string(),
641 ChannelBinding::TlsServerEndPoint(vec![]),
642 );
643 let _first = client.client_first();
644
645 let server_nonce = format!("{}server_ext", client.nonce);
646 let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
647
648 let (client_final, _state) = client.client_final(&server_first).unwrap();
649
650 let c_value = client_final
651 .split(',')
652 .find(|s| s.starts_with("c="))
653 .unwrap()
654 .strip_prefix("c=")
655 .unwrap();
656 let decoded = BASE64.decode(c_value).unwrap();
657 assert_eq!(decoded, b"p=tls-server-end-point,,");
659 }
660
661 #[test]
664 fn test_client_final_empty_password() {
665 let mut client = ScramClient::new("user".to_string(), String::new());
666 let _first = client.client_first();
667
668 let server_nonce = format!("{}server_ext", client.nonce);
669 let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
670
671 let result = client.client_final(&server_first);
672 assert!(result.is_ok());
673 }
674
675 #[test]
676 fn test_client_final_unicode_credentials() {
677 let mut client = ScramClient::new("héllo".to_string(), "pässwörd™".to_string());
678 let _first = client.client_first();
679
680 let server_nonce = format!("{}server_ext", client.nonce);
681 let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));
682
683 let result = client.client_final(&server_first);
684 assert!(result.is_ok());
685 }
686}