1#[cfg(feature = "scram")]
36use base64::{engine::general_purpose::STANDARD as B64, Engine as _};
37#[cfg(feature = "scram")]
38use hmac::{Hmac, Mac};
39#[cfg(feature = "scram")]
40use rand::RngCore;
41#[cfg(feature = "scram")]
42use sha2::{Digest, Sha256};
43
44use crate::error::{PgWireError, Result};
45
/// HMAC keyed-hash over SHA-256; the MAC used for every SCRAM-SHA-256 step in this file.
#[cfg(feature = "scram")]
type HmacSha256 = Hmac<Sha256>;
48
/// Client-side state for one SCRAM-SHA-256 exchange: the client nonce and the
/// client-first message built from it.
///
/// Create it with `ScramClient::new`, send `client_first`, then feed the server's
/// reply to `ScramClient::client_final`.
#[cfg(feature = "scram")]
#[derive(Debug, Clone)]
pub struct ScramClient {
    /// Base64-encoded client nonce; the `r=` value of the client-first message.
    pub client_nonce_b64: String,
    /// `n=<escaped user>,r=<nonce>` — the client-first message without its `n,,`
    /// prefix. It is the first component of the SCRAM AuthMessage.
    pub client_first_bare: String,
    /// `n,,` followed by `client_first_bare`; presumably the message sent to the
    /// server as the initial SASL response — confirm against the caller.
    pub client_first: String,
}
62
#[cfg(feature = "scram")]
impl ScramClient {
    /// Starts a SCRAM exchange for `username` with a fresh random client nonce.
    ///
    /// The nonce is 18 random bytes from `rand::rng()`, base64-encoded (24 characters).
    /// The username is escaped so that `=` becomes `=3D` and `,` becomes `=2C`.
    pub fn new(username: &str) -> ScramClient {
        let mut nonce = [0u8; 18];
        rand::rng().fill_bytes(&mut nonce);
        Self::from_parts(username, B64.encode(nonce))
    }

    /// Test-only constructor taking a caller-chosen, already base64-encoded nonce so
    /// exchanges can be replayed against known vectors.
    #[cfg(test)]
    pub(crate) fn with_nonce(username: &str, nonce_b64: &str) -> ScramClient {
        Self::from_parts(username, nonce_b64.to_string())
    }

    /// Builds the client-first messages from `username` and `nonce_b64`.
    ///
    /// Shared by [`ScramClient::new`] and the test-only `with_nonce` so both always
    /// produce identically shaped messages.
    fn from_parts(username: &str, nonce_b64: String) -> ScramClient {
        let user = sasl_escape_username(username);
        let client_first_bare = format!("n={user},r={nonce_b64}");
        // `n,,` header: no channel binding, no authorization identity.
        let client_first = format!("n,,{client_first_bare}");

        ScramClient {
            client_nonce_b64: nonce_b64,
            client_first_bare,
            client_first,
        }
    }

    /// Parses a SCRAM server-first message into `(nonce, salt_b64, iterations)`.
    ///
    /// Attributes are matched by prefix in any order. Unknown attributes are ignored,
    /// and if an attribute repeats, the last occurrence wins.
    ///
    /// # Errors
    ///
    /// Returns [`PgWireError::Auth`] if:
    /// - the message carries a mandatory extension (`m=`), which RFC 5802 requires
    ///   the client to reject;
    /// - `r=` or `s=` is missing;
    /// - `i=` is missing, not a number, or zero.
    pub fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
        let mut r = None;
        let mut s = None;
        let mut i = None;

        for part in server_first.split(',') {
            if part.starts_with("m=") {
                return Err(PgWireError::Auth(
                    "SCRAM server-first contains unsupported mandatory extension (m=)".into(),
                ));
            } else if let Some(v) = part.strip_prefix("r=") {
                r = Some(v.to_string());
            } else if let Some(v) = part.strip_prefix("s=") {
                s = Some(v.to_string());
            } else if let Some(v) = part.strip_prefix("i=") {
                // The iteration count must be positive: with `i=0` the Hi() loop
                // never runs, silently degrading key stretching to one PRF round.
                i = v.parse::<u32>().ok().filter(|&n| n > 0);
            }
        }

        Ok((
            r.ok_or_else(|| PgWireError::Auth("SCRAM server-first missing nonce (r=)".into()))?,
            s.ok_or_else(|| PgWireError::Auth("SCRAM server-first missing salt (s=)".into()))?,
            i.ok_or_else(|| {
                PgWireError::Auth(
                    "SCRAM server-first missing or invalid iteration count (i=)".into(),
                )
            })?,
        ))
    }

    /// Computes the client-final message for `password`, given the server-first reply.
    ///
    /// Returns `(client_final, auth_message, salted_password)`. Keep the last two
    /// values so the server's reply can be checked with
    /// [`ScramClient::verify_server_final`].
    ///
    /// NOTE(review): the password is used as raw UTF-8 bytes; no SASLprep
    /// normalization is applied. Confirm this matches the target servers for
    /// non-ASCII passwords.
    ///
    /// # Errors
    ///
    /// - Propagates errors from [`ScramClient::parse_server_first`].
    /// - Returns [`PgWireError::Auth`] if the server nonce does not begin with our
    ///   client nonce.
    /// - Returns [`PgWireError::Auth`] if the salt is not valid base64.
    pub fn client_final(
        &self,
        password: &str,
        server_first: &str,
    ) -> Result<(String, String, Vec<u8>)> {
        let (rnonce, salt_b64, iters) = Self::parse_server_first(server_first)?;

        // The combined nonce must extend the one we sent; otherwise the reply does
        // not belong to this exchange.
        if !rnonce.starts_with(&self.client_nonce_b64) {
            return Err(PgWireError::Auth(
                "SCRAM nonce mismatch: server nonce doesn't include client nonce".into(),
            ));
        }

        let salt = B64
            .decode(salt_b64.as_bytes())
            .map_err(|e| PgWireError::Auth(format!("SCRAM invalid salt base64: {e}")))?;

        // "biws" is base64("n,,"), echoing the header sent in client-first.
        let channel_binding = "biws";
        let client_final_wo_proof = format!("c={channel_binding},r={rnonce}");

        let auth_message = format!(
            "{},{},{}",
            self.client_first_bare, server_first, client_final_wo_proof
        );

        // SaltedPassword = Hi(password, salt, i)
        let salted_password = hi_sha256(password.as_bytes(), &salt, iters);
        // ClientKey = HMAC(SaltedPassword, "Client Key"); StoredKey = H(ClientKey)
        let client_key = hmac_sha256(&salted_password, b"Client Key");
        let stored_key = Sha256::digest(&client_key);

        // ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage)
        let client_sig = hmac_sha256(stored_key.as_slice(), auth_message.as_bytes());
        let proof = xor_bytes(&client_key, &client_sig);
        let proof_b64 = B64.encode(proof);

        let client_final = format!("{client_final_wo_proof},p={proof_b64}");
        Ok((client_final, auth_message, salted_password))
    }

    /// Checks the server-final message against the expected server signature.
    ///
    /// `salted_password` and `auth_message` are the values returned by
    /// [`ScramClient::client_final`]. The expected signature is
    /// `HMAC(HMAC(SaltedPassword, "Server Key"), AuthMessage)`. It is compared
    /// without a data-dependent early exit.
    ///
    /// # Errors
    ///
    /// Returns [`PgWireError::Auth`] if:
    /// - the server reported an error (`e=`);
    /// - `v=` is missing or is not valid base64;
    /// - the signature does not match.
    pub fn verify_server_final(
        server_final: &str,
        salted_password: &[u8],
        auth_message: &str,
    ) -> Result<()> {
        if let Some(err) = server_final.split(',').find_map(|p| p.strip_prefix("e=")) {
            return Err(PgWireError::Auth(format!("SCRAM server error: {err}")));
        }

        let v = server_final
            .split(',')
            .find_map(|p| p.strip_prefix("v="))
            .ok_or_else(|| PgWireError::Auth("SCRAM server-final missing signature (v=)".into()))?;

        let server_sig = B64.decode(v.trim().as_bytes()).map_err(|e| {
            PgWireError::Auth(format!("SCRAM invalid server signature base64: {e}"))
        })?;

        let server_key = hmac_sha256(salted_password, b"Server Key");
        let expected = hmac_sha256(&server_key, auth_message.as_bytes());

        if !constant_time_eq(&server_sig, &expected) {
            return Err(PgWireError::Auth(
                "SCRAM server signature mismatch: server may not know the password".into(),
            ));
        }

        Ok(())
    }
}
237
238#[cfg(feature = "scram")]
/// Escapes a username for the SCRAM `n=` attribute (RFC 5802 `saslname`).
///
/// `=` becomes `=3D` and `,` becomes `=2C`; every other character is copied as is.
/// Done in a single pass, so the `=` introduced by `=2C` is never re-escaped.
fn sasl_escape_username(u: &str) -> String {
    let mut escaped = String::with_capacity(u.len());
    for ch in u.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
245
/// SCRAM `Hi(password, salt, iters)`: PBKDF2-HMAC-SHA-256 producing one 32-byte block.
///
/// U1 = HMAC(password, salt || INT(1)), Ui = HMAC(password, U(i-1)), and the result
/// is U1 XOR U2 XOR ... XOR U(iters).
///
/// `iters == 0` returns U1 (the same as `iters == 1`), because the loop body never
/// runs. Callers should reject a zero iteration count before calling this.
#[cfg(feature = "scram")]
fn hi_sha256(password: &[u8], salt: &[u8], iters: u32) -> Vec<u8> {
    // The HMAC key (the password) is identical for every round. Key the MAC once and
    // clone the pre-keyed state per round instead of re-deriving it `iters` times.
    let keyed = HmacSha256::new_from_slice(password).expect("HMAC key length is always valid");
    let prf = |msg: &[u8]| -> Vec<u8> {
        let mut mac = keyed.clone();
        mac.update(msg);
        mac.finalize().into_bytes().to_vec()
    };

    // First round input: salt followed by the big-endian block index 1.
    let mut first_input = Vec::with_capacity(salt.len() + 4);
    first_input.extend_from_slice(salt);
    first_input.extend_from_slice(&1u32.to_be_bytes());

    let mut u = prf(&first_input);
    let mut out = u.clone();

    for _ in 1..iters {
        u = prf(&u);
        for (o, ui) in out.iter_mut().zip(&u) {
            *o ^= ui;
        }
    }

    out
}
269
/// One-shot HMAC-SHA-256 of `msg` under `key`, returned as an owned byte vector.
///
/// Panics only if the MAC rejects the key length. HMAC accepts keys of any length,
/// so this is treated as an invariant.
#[cfg(feature = "scram")]
fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
    let tag = HmacSha256::new_from_slice(key)
        .expect("HMAC key length is always valid")
        .chain_update(msg)
        .finalize();
    tag.into_bytes().to_vec()
}
277
278#[cfg(feature = "scram")]
/// Byte-wise XOR of two slices (used to combine ClientKey and ClientSignature).
///
/// Equal lengths are only checked by `debug_assert_eq!`. In release builds a
/// mismatch yields a result as long as the shorter slice.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    debug_assert_eq!(a.len(), b.len(), "XOR operands must have equal length");
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (&lhs, &rhs) in a.iter().zip(b) {
        out.push(lhs ^ rhs);
    }
    out
}
284
285#[cfg(feature = "scram")]
/// Compares two byte slices for equality with no data-dependent early exit.
///
/// Slices of different lengths compare unequal immediately, so only the contents
/// are compared uniformly. For equal lengths, every byte pair is XOR-ed and the
/// differences are OR-ed into one accumulator, which is inspected after the loop.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (x, y) in a.iter().zip(b) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}
303
#[cfg(test)]
#[cfg(feature = "scram")]
mod tests {
    use super::*;

    /// Lower-case hex encoding, for comparing against published test vectors.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{b:02x}")).collect()
    }

    #[test]
    fn scram_builds_first_message() {
        let c = ScramClient::new("user");
        assert!(c.client_first.starts_with("n,,n=user,r="));
        assert!(c.client_first_bare.starts_with("n=user,r="));
        assert!(!c.client_nonce_b64.is_empty());
    }

    #[test]
    fn scram_escapes_special_chars_in_username() {
        let c = ScramClient::new("user=name,test");
        assert!(c.client_first.contains("n=user=3Dname=2Ctest,r="));
    }

    #[test]
    fn scram_unique_nonces() {
        let c1 = ScramClient::new("user");
        let c2 = ScramClient::new("user");
        assert_ne!(c1.client_nonce_b64, c2.client_nonce_b64);
    }

    #[test]
    fn parse_server_first_valid() {
        let (r, s, i) = ScramClient::parse_server_first("r=abc123,s=c2FsdA==,i=4096").unwrap();
        assert_eq!(r, "abc123");
        assert_eq!(s, "c2FsdA==");
        assert_eq!(i, 4096);
    }

    #[test]
    fn parse_server_first_different_order() {
        let (r, s, i) = ScramClient::parse_server_first("i=1000,s=Zm9v,r=xyz").unwrap();
        assert_eq!(r, "xyz");
        assert_eq!(s, "Zm9v");
        assert_eq!(i, 1000);
    }

    #[test]
    fn parse_server_first_with_extensions() {
        let (r, s, i) =
            ScramClient::parse_server_first("r=nonce,s=c2FsdA==,i=4096,x=unknown").unwrap();
        assert_eq!(r, "nonce");
        assert_eq!(s, "c2FsdA==");
        assert_eq!(i, 4096);
    }

    #[test]
    fn parse_server_first_missing_nonce() {
        let err = ScramClient::parse_server_first("s=c2FsdA==,i=4096").unwrap_err();
        assert!(err.to_string().contains("nonce"));
    }

    #[test]
    fn parse_server_first_missing_salt() {
        let err = ScramClient::parse_server_first("r=abc,i=4096").unwrap_err();
        assert!(err.to_string().contains("salt"));
    }

    #[test]
    fn parse_server_first_missing_iterations() {
        let err = ScramClient::parse_server_first("r=abc,s=c2FsdA==").unwrap_err();
        assert!(err.to_string().contains("iteration"));
    }

    #[test]
    fn parse_server_first_invalid_iterations() {
        let err = ScramClient::parse_server_first("r=abc,s=c2FsdA==,i=notanumber").unwrap_err();
        assert!(err.to_string().contains("iteration"));
    }

    #[test]
    fn client_final_computes_proof() {
        let client = ScramClient::with_nonce("user", "rOprNGfwEbeRWgbNEkqO");

        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";

        let (client_final, auth_message, salted_password) =
            client.client_final("pencil", server_first).unwrap();

        assert!(client_final.starts_with("c=biws,r="));
        assert!(client_final.contains(",p="));

        assert!(auth_message.contains(&client.client_first_bare));
        assert!(auth_message.contains(server_first));

        assert_eq!(salted_password.len(), 32);
    }

    /// Full SCRAM-SHA-256 exchange from RFC 7677, section 3.
    #[test]
    fn client_final_matches_rfc7677_vector() {
        let client = ScramClient::with_nonce("user", "rOprNGfwEbeRWgbNEkqO");
        assert_eq!(client.client_first, "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");

        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";

        let (client_final, auth_message, salted_password) =
            client.client_final("pencil", server_first).unwrap();

        assert_eq!(
            client_final,
            "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );

        ScramClient::verify_server_final(
            "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=",
            &salted_password,
            &auth_message,
        )
        .unwrap();
    }

    #[test]
    fn client_final_rejects_nonce_mismatch() {
        let client = ScramClient::with_nonce("user", "clientnonce");

        let server_first = "r=differentnonce,s=c2FsdA==,i=4096";

        let err = client.client_final("password", server_first).unwrap_err();
        assert!(err.to_string().contains("nonce mismatch"));
    }

    #[test]
    fn client_final_rejects_invalid_salt_base64() {
        let client = ScramClient::with_nonce("user", "abc");

        let server_first = "r=abcdef,s=!!!invalid!!!,i=4096";

        let err = client.client_final("password", server_first).unwrap_err();
        assert!(err.to_string().contains("base64"));
    }

    #[test]
    fn verify_server_final_accepts_valid_signature() {
        let client = ScramClient::with_nonce("user", "fyko+d2lbbFgONRv9qkxdawL");

        let server_first = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";

        let (_, auth_message, salted_password) =
            client.client_final("pencil", server_first).unwrap();

        let server_key = hmac_sha256(&salted_password, b"Server Key");
        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", B64.encode(&server_sig));

        ScramClient::verify_server_final(&server_final, &salted_password, &auth_message).unwrap();
    }

    #[test]
    fn verify_server_final_rejects_wrong_signature() {
        let salted_password = vec![0u8; 32];
        let auth_message = "test";
        let server_final = "v=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";

        let err = ScramClient::verify_server_final(server_final, &salted_password, auth_message)
            .unwrap_err();
        assert!(err.to_string().contains("signature mismatch"));
    }

    #[test]
    fn verify_server_final_rejects_missing_signature() {
        let err = ScramClient::verify_server_final("", &[], "").unwrap_err();
        assert!(err.to_string().contains("missing signature"));
    }

    #[test]
    fn verify_server_final_handles_server_error() {
        let err = ScramClient::verify_server_final("e=invalid-proof", &[], "").unwrap_err();
        assert!(err.to_string().contains("server error"));
        assert!(err.to_string().contains("invalid-proof"));
    }

    #[test]
    fn verify_server_final_rejects_invalid_base64() {
        let err = ScramClient::verify_server_final("v=!!!invalid!!!", &[], "").unwrap_err();
        assert!(err.to_string().contains("base64"));
    }

    #[test]
    fn sasl_escape_username_escapes_equals() {
        assert_eq!(sasl_escape_username("a=b"), "a=3Db");
    }

    #[test]
    fn sasl_escape_username_escapes_comma() {
        assert_eq!(sasl_escape_username("a,b"), "a=2Cb");
    }

    #[test]
    fn sasl_escape_username_escapes_both() {
        assert_eq!(sasl_escape_username("a=b,c"), "a=3Db=2Cc");
    }

    #[test]
    fn sasl_escape_username_preserves_normal() {
        assert_eq!(sasl_escape_username("normal_user123"), "normal_user123");
    }

    /// PBKDF2-HMAC-SHA256(P = "password", S = "salt", c = 1, dkLen = 32).
    #[test]
    fn hi_sha256_single_iteration() {
        let result = hi_sha256(b"password", b"salt", 1);
        assert_eq!(result.len(), 32);
        assert_eq!(
            hex(&result),
            "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b"
        );
    }

    /// PBKDF2-HMAC-SHA256(P = "password", S = "salt", c = 4096, dkLen = 32).
    #[test]
    fn hi_sha256_multiple_iterations() {
        let result = hi_sha256(b"password", b"salt", 4096);
        assert_eq!(result.len(), 32);
        assert_eq!(
            hex(&result),
            "c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a"
        );

        let result2 = hi_sha256(b"password", b"salt", 1000);
        assert_ne!(result, result2);
    }

    #[test]
    fn hmac_sha256_produces_correct_length() {
        let result = hmac_sha256(b"key", b"message");
        assert_eq!(result.len(), 32);
    }

    /// RFC 4231, test case 2.
    #[test]
    fn hmac_sha256_matches_rfc4231_vector() {
        let result = hmac_sha256(b"Jefe", b"what do ya want for nothing?");
        assert_eq!(
            hex(&result),
            "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
        );
    }

    #[test]
    fn xor_bytes_works() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0x0F, 0xF0]), vec![0xF0, 0xF0]);
        assert_eq!(xor_bytes(&[0x00], &[0x00]), vec![0x00]);
    }

    #[test]
    fn constant_time_eq_equal() {
        assert!(constant_time_eq(&[1, 2, 3], &[1, 2, 3]));
        assert!(constant_time_eq(&[], &[]));
    }

    #[test]
    fn constant_time_eq_not_equal() {
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 4]));
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2]));
    }

    #[test]
    fn constant_time_eq_different_lengths() {
        assert!(!constant_time_eq(&[1, 2, 3], &[1, 2, 3, 4]));
    }
}
549}