1use super::error::{BackendError, BackendResult};
15use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
16use hmac::{Hmac, Mac};
17use sha2::{Digest, Sha256};
18
/// HMAC instantiated with SHA-256; the keyed PRF used by `hmac_sha256`,
/// `pbkdf2_hmac_sha256` and the SCRAM exchange.
type HmacSha256 = Hmac<Sha256>;
20
21pub fn md5_password_response(user: &str, password: &str, salt: &[u8; 4]) -> Vec<u8> {
31 let mut out = Vec::with_capacity(35 + 1);
32 let inner = md5_hex(format!("{}{}", password, user).as_bytes());
33 let mut salted = Vec::with_capacity(inner.len() + 4);
34 salted.extend_from_slice(inner.as_bytes());
35 salted.extend_from_slice(salt);
36 out.extend_from_slice(b"md5");
37 out.extend_from_slice(md5_hex(&salted).as_bytes());
38 out.push(0);
39 out
40}
41
42fn md5_hex(bytes: &[u8]) -> String {
43 let digest = md5::Md5::digest(bytes);
44 let mut s = String::with_capacity(digest.len() * 2);
45 for b in digest {
46 s.push_str(&format!("{:02x}", b));
47 }
48 s
49}
50
/// Client-side state of a single SCRAM-SHA-256 exchange.
///
/// Created by `Scram::client_first`, completed by `Scram::client_final`, and
/// consulted by `Scram::verify_server` to authenticate the server.
pub struct Scram {
    /// Client nonce sent in client-first; the server's nonce must extend it.
    nonce: String,
    /// Client-first message without the GS2 header, i.e. `n=,r=<nonce>`.
    client_first_bare: String,
    /// `client-first-bare "," server-first "," client-final-without-proof`;
    /// empty until `client_final` has run.
    auth_message: String,
    /// HMAC-SHA-256(SaltedPassword, "Server Key"); all zeroes until
    /// `client_final` has run.
    server_key: [u8; 32],
    /// True once `client_final` has successfully produced a client-final message.
    finalised: bool,
}
70
/// An outgoing SCRAM payload produced by [`Scram`].
///
/// From `Scram::client_first` it holds the mechanism name `SCRAM-SHA-256`, a
/// NUL byte, a big-endian `u32` length and the client-first message; from
/// `Scram::client_final` it holds the client-final message bytes only.
/// Any outer protocol framing is presumably added by the caller — TODO confirm.
#[derive(Debug)]
pub struct ScramMessage(pub Vec<u8>);
74
75impl Scram {
76 pub fn client_first(nonce: impl Into<String>) -> (Self, ScramMessage) {
82 let nonce = nonce.into();
83 let client_first_bare = format!("n=,r={}", nonce);
87 let client_first = format!("n,,{}", client_first_bare);
88
89 let mech = b"SCRAM-SHA-256\0";
91 let mut out = Vec::with_capacity(mech.len() + 4 + client_first.len());
92 out.extend_from_slice(mech);
93 out.extend_from_slice(&(client_first.len() as u32).to_be_bytes());
94 out.extend_from_slice(client_first.as_bytes());
95
96 (
97 Self {
98 client_first_bare,
99 nonce,
100 server_key: [0u8; 32],
101 auth_message: String::new(),
102 finalised: false,
103 },
104 ScramMessage(out),
105 )
106 }
107
108 pub fn client_final(
114 &mut self,
115 server_first: &[u8],
116 password: &str,
117 ) -> BackendResult<ScramMessage> {
118 let server_first_str = std::str::from_utf8(server_first).map_err(|e| {
119 BackendError::Auth(format!("server-first is not UTF-8: {}", e))
120 })?;
121
122 let mut server_nonce = None;
124 let mut salt_b64 = None;
125 let mut iterations: Option<u32> = None;
126 for field in server_first_str.split(',') {
127 if let Some(rest) = field.strip_prefix("r=") {
128 server_nonce = Some(rest);
129 } else if let Some(rest) = field.strip_prefix("s=") {
130 salt_b64 = Some(rest);
131 } else if let Some(rest) = field.strip_prefix("i=") {
132 iterations = rest.parse().ok();
133 }
134 }
135 let server_nonce = server_nonce
136 .ok_or_else(|| BackendError::Auth("missing r= in server-first".into()))?;
137 let salt_b64 = salt_b64
138 .ok_or_else(|| BackendError::Auth("missing s= in server-first".into()))?;
139 let iterations = iterations
140 .ok_or_else(|| BackendError::Auth("missing/invalid i= in server-first".into()))?;
141
142 if !server_nonce.starts_with(&self.nonce) {
144 return Err(BackendError::Auth(
145 "server nonce does not extend client nonce".into(),
146 ));
147 }
148 if iterations < 1 {
149 return Err(BackendError::Auth("iteration count must be >= 1".into()));
150 }
151
152 let salt = BASE64
153 .decode(salt_b64)
154 .map_err(|e| BackendError::Auth(format!("bad salt base64: {}", e)))?;
155
156 let salted_password = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
158 let client_key = hmac_sha256(&salted_password, b"Client Key");
159 let stored_key = sha256(&client_key);
160 self.server_key = hmac_sha256(&salted_password, b"Server Key");
161
162 let channel_binding = BASE64.encode(b"n,,");
164
165 let client_final_without_proof =
166 format!("c={},r={}", channel_binding, server_nonce);
167 self.auth_message = format!(
168 "{},{},{}",
169 self.client_first_bare, server_first_str, client_final_without_proof
170 );
171
172 let client_signature = hmac_sha256(&stored_key, self.auth_message.as_bytes());
173 let mut client_proof = [0u8; 32];
174 for i in 0..32 {
175 client_proof[i] = client_key[i] ^ client_signature[i];
176 }
177
178 let client_final = format!(
179 "{},p={}",
180 client_final_without_proof,
181 BASE64.encode(client_proof)
182 );
183
184 self.finalised = true;
185 Ok(ScramMessage(client_final.into_bytes()))
186 }
187
188 pub fn verify_server(&self, server_final: &[u8]) -> BackendResult<()> {
192 if !self.finalised {
193 return Err(BackendError::Auth(
194 "verify_server called before client_final".into(),
195 ));
196 }
197 let s = std::str::from_utf8(server_final).map_err(|e| {
198 BackendError::Auth(format!("server-final is not UTF-8: {}", e))
199 })?;
200 if let Some(err) = s.strip_prefix("e=") {
202 return Err(BackendError::Auth(format!("server reported: {}", err)));
203 }
204 let sig_b64 = s
205 .strip_prefix("v=")
206 .ok_or_else(|| BackendError::Auth("missing v= in server-final".into()))?
207 .split(',')
208 .next()
209 .unwrap_or("");
210 let received = BASE64
211 .decode(sig_b64)
212 .map_err(|e| BackendError::Auth(format!("bad v= base64: {}", e)))?;
213 let expected = hmac_sha256(&self.server_key, self.auth_message.as_bytes());
214 if received == expected {
215 Ok(())
216 } else {
217 Err(BackendError::Auth("server signature mismatch".into()))
218 }
219 }
220}
221
222fn sha256(data: &[u8]) -> [u8; 32] {
225 let mut h = Sha256::new();
226 h.update(data);
227 h.finalize().into()
228}
229
230fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
231 let mut mac =
232 HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
233 mac.update(data);
234 let tag = mac.finalize().into_bytes();
235 let mut out = [0u8; 32];
236 out.copy_from_slice(&tag);
237 out
238}
239
240fn pbkdf2_hmac_sha256(password: &[u8], salt: &[u8], iters: u32) -> [u8; 32] {
241 let mut mac = HmacSha256::new_from_slice(password)
244 .expect("HMAC accepts any key length");
245 mac.update(salt);
246 mac.update(&1u32.to_be_bytes());
247 let mut u: [u8; 32] = mac.finalize().into_bytes().into();
248 let mut out = u;
249 for _ in 1..iters {
250 let mut mac = HmacSha256::new_from_slice(password)
251 .expect("HMAC accepts any key length");
252 mac.update(&u);
253 u = mac.finalize().into_bytes().into();
254 for i in 0..32 {
255 out[i] ^= u[i];
256 }
257 }
258 out
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Lowercase hex encoding, used to compare digests against published vectors.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{:02x}", b)).collect()
    }

    /// RFC 1321 appendix A.5 test-suite values.
    #[test]
    fn test_md5_hex_known_vectors() {
        assert_eq!(md5_hex(b""), "d41d8cd98f00b204e9800998ecf8427e");
        assert_eq!(md5_hex(b"abc"), "900150983cd24fb0d6963f7d28e17f72");
    }

    /// FIPS 180-2 SHA-256 examples.
    #[test]
    fn test_sha256_known_vectors() {
        assert_eq!(
            hex(&sha256(b"")),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            hex(&sha256(b"abc")),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    /// RFC 4231 test case 2.
    #[test]
    fn test_hmac_sha256_rfc4231_case2() {
        assert_eq!(
            hex(&hmac_sha256(b"Jefe", b"what do ya want for nothing?")),
            "5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843"
        );
    }

    /// The MD5 response is "md5" + md5_hex(md5_hex(password ++ user) ++ salt) + NUL.
    #[test]
    fn test_md5_password_response_known_answer() {
        let got = md5_password_response("alice", "secret", &[0x01, 0x02, 0x03, 0x04]);
        assert_eq!(got.last().copied(), Some(0u8));

        let body = std::str::from_utf8(&got[..got.len() - 1]).unwrap();
        assert!(body.starts_with("md5"));
        assert_eq!(body.len(), 3 + 32);

        let inner = md5_hex(b"secretalice");
        let mut combined = inner.into_bytes();
        combined.extend_from_slice(&[0x01, 0x02, 0x03, 0x04]);
        let outer = md5_hex(&combined);
        assert_eq!(&body[3..], outer);
    }

    /// RFC 7914 §11 PBKDF2-HMAC-SHA-256 vector (c = 1), first 32 bytes.
    #[test]
    fn test_pbkdf2_hmac_sha256_rfc_vector() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 1);
        let expected = [
            0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, 0x43, 0xe7, 0x22, 0x52,
            0x56, 0xc4, 0xf8, 0x37, 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48,
            0x08, 0x05, 0x98, 0x7c, 0xb7, 0x0b, 0xe1, 0x7b,
        ];
        assert_eq!(got, expected);
    }

    /// Same inputs with c = 4096, exercising the iteration loop.
    #[test]
    fn test_pbkdf2_hmac_sha256_high_iters() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 4096);
        let expected = [
            0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, 0xaa, 0x53, 0x0d, 0xb6,
            0x84, 0x5c, 0x4c, 0x8d, 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11,
            0xa4, 0x96, 0x38, 0x73, 0xaa, 0x98, 0x13, 0x4a,
        ];
        assert_eq!(got, expected);
    }

    /// Full exchange against a hand-rolled server.
    ///
    /// Checks the client-first framing, validates the client proof the way a
    /// server would, and feeds back a correct server signature.
    #[test]
    fn test_scram_roundtrip_against_synthetic_server() {
        let (mut scram, first) = Scram::client_first("fyko+d2lbbFgONRv9qkxdawL");

        // Framing: mechanism, NUL, big-endian u32 length, client-first.
        let msg = &first.0;
        let mech_end = msg.iter().position(|&b| b == 0).unwrap();
        assert_eq!(&msg[..mech_end], b"SCRAM-SHA-256");
        let len =
            u32::from_be_bytes(msg[mech_end + 1..mech_end + 5].try_into().unwrap())
                as usize;
        let cfirst = &msg[mech_end + 5..mech_end + 5 + len];
        let cfirst_str = std::str::from_utf8(cfirst).unwrap();
        assert!(cfirst_str.starts_with("n,,n=,r=fyko+d2lbbFgONRv9qkxdawL"));

        let server_nonce_suffix = "3rfcNHYJY1ZVvWVs7j";
        let combined_nonce =
            format!("fyko+d2lbbFgONRv9qkxdawL{}", server_nonce_suffix);
        let salt: [u8; 16] = [
            0x41, 0x25, 0xc2, 0x47, 0xe4, 0x3a, 0xb1, 0xe9, 0x3c, 0x6d, 0xff, 0x76,
            0xd1, 0x22, 0x3a, 0x10,
        ];
        let iterations = 4096u32;
        let salt_b64 = BASE64.encode(salt);
        let server_first = format!(
            "r={},s={},i={}",
            combined_nonce, salt_b64, iterations
        );

        let password = "pencil";
        let client_final = scram
            .client_final(server_first.as_bytes(), password)
            .expect("client_final");
        let cfinal_str = std::str::from_utf8(&client_final.0).unwrap();

        // "biws" is base64("n,,").
        assert!(cfinal_str.starts_with("c=biws,r="));
        assert!(cfinal_str.contains(&format!("r={}", combined_nonce)));
        assert!(cfinal_str.contains(",p="));

        // Server side: derive keys independently from the password.
        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
        let client_key = hmac_sha256(&salted, b"Client Key");
        let stored_key = sha256(&client_key);
        let server_key = hmac_sha256(&salted, b"Server Key");
        let (cfinal_no_proof, proof_b64) = {
            let idx = cfinal_str.rfind(",p=").unwrap();
            (&cfinal_str[..idx], &cfinal_str[idx + 3..])
        };
        let auth_message = format!(
            "n=,r=fyko+d2lbbFgONRv9qkxdawL,{},{}",
            server_first, cfinal_no_proof
        );

        // Verify the client proof as a server would:
        // H(ClientProof XOR ClientSignature) must equal StoredKey.
        let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
        let proof = BASE64.decode(proof_b64).expect("proof base64");
        assert_eq!(proof.len(), 32);
        let recovered: Vec<u8> = proof
            .iter()
            .zip(client_signature.iter())
            .map(|(p, s)| p ^ s)
            .collect();
        assert_eq!(sha256(&recovered), stored_key);

        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", BASE64.encode(server_sig));

        scram
            .verify_server(server_final.as_bytes())
            .expect("verify_server");
    }

    #[test]
    fn test_scram_rejects_nonce_mismatch() {
        let (mut scram, _) = Scram::client_first("client-nonce");
        let server_first = "r=OTHER-nonce,s=QUJD,i=4096";
        let err = scram.client_final(server_first.as_bytes(), "pw").unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_scram_rejects_missing_or_invalid_attributes() {
        let cases = [
            "s=QUJD,i=1",               // no r=
            "r=abc-ext,i=1",            // no s=
            "r=abc-ext,s=QUJD",         // no i=
            "r=abc-ext,s=QUJD,i=0",     // zero iterations
            "r=abc-ext,s=QUJD,i=lots",  // non-numeric iterations
            "r=abc-ext,s=***,i=1",      // salt is not base64
        ];
        for server_first in cases.iter() {
            let (mut scram, _) = Scram::client_first("abc");
            let err = scram.client_final(server_first.as_bytes(), "pw").unwrap_err();
            assert!(matches!(err, BackendError::Auth(_)), "case {}", server_first);
        }
    }

    #[test]
    fn test_scram_rejects_non_utf8_server_first() {
        let (mut scram, _) = Scram::client_first("abc");
        let err = scram.client_final(&[0xff, 0xfe], "pw").unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_scram_verify_before_final_is_error() {
        let (scram, _) = Scram::client_first("abc");
        assert!(scram.verify_server(b"v=AAAA").is_err());
    }

    #[test]
    fn test_scram_rejects_bad_server_signature() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        let bad_sig = BASE64.encode([0u8; 32]);
        let server_final = format!("v={}", bad_sig);
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_truncated_server_signature() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=1";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        let short_sig = BASE64.encode([0u8; 16]);
        let server_final = format!("v={}", short_sig);
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_server_error() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        assert!(scram.verify_server(b"e=invalid-proof").is_err());
    }
}