1use super::error::{BackendError, BackendResult};
15use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
16use hmac::{Hmac, Mac};
17use sha2::{Digest, Sha256};
18
/// HMAC instantiated with SHA-256; shared by [`hmac_sha256`] and [`pbkdf2_hmac_sha256`].
type HmacSha256 = Hmac<Sha256>;
20
21pub fn md5_password_response(user: &str, password: &str, salt: &[u8; 4]) -> Vec<u8> {
31 let mut out = Vec::with_capacity(35 + 1);
32 let inner = md5_hex(format!("{}{}", password, user).as_bytes());
33 let mut salted = Vec::with_capacity(inner.len() + 4);
34 salted.extend_from_slice(inner.as_bytes());
35 salted.extend_from_slice(salt);
36 out.extend_from_slice(b"md5");
37 out.extend_from_slice(md5_hex(&salted).as_bytes());
38 out.push(0);
39 out
40}
41
42fn md5_hex(bytes: &[u8]) -> String {
43 let digest = md5::Md5::digest(bytes);
44 let mut s = String::with_capacity(digest.len() * 2);
45 for b in digest {
46 s.push_str(&format!("{:02x}", b));
47 }
48 s
49}
50
/// Client-side state for a single SCRAM-SHA-256 exchange.
///
/// Created by [`Scram::client_first`], advanced by [`Scram::client_final`], and
/// completed by [`Scram::verify_server`].
pub struct Scram {
    /// `n=,r=<client nonce>`: the client-first message without its `n,,` header.
    /// It is the first component of the auth message.
    client_first_bare: String,
    /// The client nonce; the server's combined nonce must start with it.
    nonce: String,
    /// `HMAC(SaltedPassword, "Server Key")`; all zeroes until `client_final` computes it.
    server_key: [u8; 32],
    /// `client-first-bare,server-first,client-final-without-proof`; empty until
    /// `client_final` builds it.
    auth_message: String,
    /// Set once `client_final` has produced a message; `verify_server` refuses to run
    /// before that.
    finalised: bool,
}
70
/// An outgoing SCRAM payload produced by [`Scram`].
///
/// The tuple field is public so callers can take the buffer without copying.
/// `Clone`/`PartialEq`/`Eq` are derived so messages can be stored and compared
/// (e.g. in tests) in addition to being debug-printed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScramMessage(pub Vec<u8>);
74
75impl Scram {
76 pub fn client_first(nonce: impl Into<String>) -> (Self, ScramMessage) {
82 let nonce = nonce.into();
83 let client_first_bare = format!("n=,r={}", nonce);
87 let client_first = format!("n,,{}", client_first_bare);
88
89 let mech = b"SCRAM-SHA-256\0";
91 let mut out = Vec::with_capacity(mech.len() + 4 + client_first.len());
92 out.extend_from_slice(mech);
93 out.extend_from_slice(&(client_first.len() as u32).to_be_bytes());
94 out.extend_from_slice(client_first.as_bytes());
95
96 (
97 Self {
98 client_first_bare,
99 nonce,
100 server_key: [0u8; 32],
101 auth_message: String::new(),
102 finalised: false,
103 },
104 ScramMessage(out),
105 )
106 }
107
108 pub fn client_final(
114 &mut self,
115 server_first: &[u8],
116 password: &str,
117 ) -> BackendResult<ScramMessage> {
118 let server_first_str = std::str::from_utf8(server_first)
119 .map_err(|e| BackendError::Auth(format!("server-first is not UTF-8: {}", e)))?;
120
121 let mut server_nonce = None;
123 let mut salt_b64 = None;
124 let mut iterations: Option<u32> = None;
125 for field in server_first_str.split(',') {
126 if let Some(rest) = field.strip_prefix("r=") {
127 server_nonce = Some(rest);
128 } else if let Some(rest) = field.strip_prefix("s=") {
129 salt_b64 = Some(rest);
130 } else if let Some(rest) = field.strip_prefix("i=") {
131 iterations = rest.parse().ok();
132 }
133 }
134 let server_nonce =
135 server_nonce.ok_or_else(|| BackendError::Auth("missing r= in server-first".into()))?;
136 let salt_b64 =
137 salt_b64.ok_or_else(|| BackendError::Auth("missing s= in server-first".into()))?;
138 let iterations = iterations
139 .ok_or_else(|| BackendError::Auth("missing/invalid i= in server-first".into()))?;
140
141 if !server_nonce.starts_with(&self.nonce) {
143 return Err(BackendError::Auth(
144 "server nonce does not extend client nonce".into(),
145 ));
146 }
147 if iterations < 1 {
148 return Err(BackendError::Auth("iteration count must be >= 1".into()));
149 }
150
151 let salt = BASE64
152 .decode(salt_b64)
153 .map_err(|e| BackendError::Auth(format!("bad salt base64: {}", e)))?;
154
155 let salted_password = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
157 let client_key = hmac_sha256(&salted_password, b"Client Key");
158 let stored_key = sha256(&client_key);
159 self.server_key = hmac_sha256(&salted_password, b"Server Key");
160
161 let channel_binding = BASE64.encode(b"n,,");
163
164 let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
165 self.auth_message = format!(
166 "{},{},{}",
167 self.client_first_bare, server_first_str, client_final_without_proof
168 );
169
170 let client_signature = hmac_sha256(&stored_key, self.auth_message.as_bytes());
171 let mut client_proof = [0u8; 32];
172 for i in 0..32 {
173 client_proof[i] = client_key[i] ^ client_signature[i];
174 }
175
176 let client_final = format!(
177 "{},p={}",
178 client_final_without_proof,
179 BASE64.encode(client_proof)
180 );
181
182 self.finalised = true;
183 Ok(ScramMessage(client_final.into_bytes()))
184 }
185
186 pub fn verify_server(&self, server_final: &[u8]) -> BackendResult<()> {
190 if !self.finalised {
191 return Err(BackendError::Auth(
192 "verify_server called before client_final".into(),
193 ));
194 }
195 let s = std::str::from_utf8(server_final)
196 .map_err(|e| BackendError::Auth(format!("server-final is not UTF-8: {}", e)))?;
197 if let Some(err) = s.strip_prefix("e=") {
199 return Err(BackendError::Auth(format!("server reported: {}", err)));
200 }
201 let sig_b64 = s
202 .strip_prefix("v=")
203 .ok_or_else(|| BackendError::Auth("missing v= in server-final".into()))?
204 .split(',')
205 .next()
206 .unwrap_or("");
207 let received = BASE64
208 .decode(sig_b64)
209 .map_err(|e| BackendError::Auth(format!("bad v= base64: {}", e)))?;
210 let expected = hmac_sha256(&self.server_key, self.auth_message.as_bytes());
211 if received == expected {
212 Ok(())
213 } else {
214 Err(BackendError::Auth("server signature mismatch".into()))
215 }
216 }
217}
218
219pub(crate) fn sha256(data: &[u8]) -> [u8; 32] {
222 let mut h = Sha256::new();
223 h.update(data);
224 h.finalize().into()
225}
226
227pub(crate) fn hmac_sha256(key: &[u8], data: &[u8]) -> [u8; 32] {
228 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
229 mac.update(data);
230 let tag = mac.finalize().into_bytes();
231 let mut out = [0u8; 32];
232 out.copy_from_slice(&tag);
233 out
234}
235
236pub(crate) fn pbkdf2_hmac_sha256(password: &[u8], salt: &[u8], iters: u32) -> [u8; 32] {
237 let mut mac = HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
240 mac.update(salt);
241 mac.update(&1u32.to_be_bytes());
242 let mut u: [u8; 32] = mac.finalize().into_bytes().into();
243 let mut out = u;
244 for _ in 1..iters {
245 let mut mac = HmacSha256::new_from_slice(password).expect("HMAC accepts any key length");
246 mac.update(&u);
247 u = mac.finalize().into_bytes().into();
248 for i in 0..32 {
249 out[i] ^= u[i];
250 }
251 }
252 out
253}
254
#[cfg(test)]
mod tests {
    // Known-answer, round-trip and error-path tests for the MD5 and SCRAM-SHA-256
    // helpers in the parent module.
    use super::*;

    #[test]
    fn test_md5_password_response_known_answer() {
        let got = md5_password_response("alice", "secret", &[0x01, 0x02, 0x03, 0x04]);
        // The response is NUL-terminated.
        assert_eq!(got.last().copied(), Some(0u8));

        let body = std::str::from_utf8(&got[..got.len() - 1]).unwrap();
        assert!(body.starts_with("md5"));
        assert_eq!(body.len(), 3 + 32);

        // Recompute md5(md5(password || user) || salt) independently.
        let inner = md5_hex(b"secretalice");
        let mut combined = inner.into_bytes();
        combined.extend_from_slice(&[0x01, 0x02, 0x03, 0x04]);
        let outer = md5_hex(&combined);
        assert_eq!(&body[3..], outer);
    }

    #[test]
    fn test_md5_hex_known_answers() {
        // RFC 1321 test suite.
        assert_eq!(md5_hex(b""), "d41d8cd98f00b204e9800998ecf8427e");
        assert_eq!(md5_hex(b"abc"), "900150983cd24fb0d6963f7d28e17f72");
    }

    #[test]
    fn test_sha256_known_answer() {
        // FIPS 180-2 example: SHA-256("abc").
        let expected = [
            0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40, 0xde, 0x5d, 0xae,
            0x22, 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, 0xb4, 0x10, 0xff, 0x61,
            0xf2, 0x00, 0x15, 0xad,
        ];
        assert_eq!(sha256(b"abc"), expected);
    }

    #[test]
    fn test_hmac_sha256_rfc4231_case_2() {
        let expected = [
            0x5b, 0xdc, 0xc1, 0x46, 0xbf, 0x60, 0x75, 0x4e, 0x6a, 0x04, 0x24, 0x26, 0x08, 0x95,
            0x75, 0xc7, 0x5a, 0x00, 0x3f, 0x08, 0x9d, 0x27, 0x39, 0x83, 0x9d, 0xec, 0x58, 0xb9,
            0x64, 0xec, 0x38, 0x43,
        ];
        assert_eq!(
            hmac_sha256(b"Jefe", b"what do ya want for nothing?"),
            expected
        );
    }

    #[test]
    fn test_pbkdf2_hmac_sha256_rfc_vector() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 1);
        let expected = [
            0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, 0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4,
            0xf8, 0x37, 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48, 0x08, 0x05, 0x98, 0x7c,
            0xb7, 0x0b, 0xe1, 0x7b,
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn test_pbkdf2_hmac_sha256_high_iters() {
        let got = pbkdf2_hmac_sha256(b"password", b"salt", 4096);
        let expected = [
            0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, 0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c,
            0x4c, 0x8d, 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11, 0xa4, 0x96, 0x38, 0x73,
            0xaa, 0x98, 0x13, 0x4a,
        ];
        assert_eq!(got, expected);
    }

    #[test]
    fn test_scram_roundtrip_against_synthetic_server() {
        let (mut scram, first) = Scram::client_first("fyko+d2lbbFgONRv9qkxdawL");
        let msg = &first.0;

        // Layout: mechanism name, NUL, big-endian u32 length, client-first message.
        let mech_end = msg.iter().position(|&b| b == 0).unwrap();
        assert_eq!(&msg[..mech_end], b"SCRAM-SHA-256");
        let len = u32::from_be_bytes(msg[mech_end + 1..mech_end + 5].try_into().unwrap()) as usize;
        let cfirst = &msg[mech_end + 5..mech_end + 5 + len];
        let cfirst_str = std::str::from_utf8(cfirst).unwrap();
        assert!(cfirst_str.starts_with("n,,n=,r=fyko+d2lbbFgONRv9qkxdawL"));

        // Synthetic server: extend the client nonce and choose a salt and iteration count.
        let server_nonce_suffix = "3rfcNHYJY1ZVvWVs7j";
        let combined_nonce = format!("fyko+d2lbbFgONRv9qkxdawL{}", server_nonce_suffix);
        let salt: [u8; 16] = [
            0x41, 0x25, 0xc2, 0x47, 0xe4, 0x3a, 0xb1, 0xe9, 0x3c, 0x6d, 0xff, 0x76, 0xd1, 0x22,
            0x3a, 0x10,
        ];
        let iterations = 4096u32;
        let salt_b64 = BASE64.encode(salt);
        let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, iterations);

        let password = "pencil";
        let client_final = scram
            .client_final(server_first.as_bytes(), password)
            .expect("client_final");
        let cfinal_str = std::str::from_utf8(&client_final.0).unwrap();

        // "biws" is base64("n,,"): no channel binding.
        assert!(cfinal_str.starts_with("c=biws,r="));
        assert!(cfinal_str.contains(&format!("r={}", combined_nonce)));
        assert!(cfinal_str.contains(",p="));

        // Independently derive the keys the server would hold for this password.
        let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
        let client_key = hmac_sha256(&salted, b"Client Key");
        let stored_key = sha256(&client_key);
        let server_key = hmac_sha256(&salted, b"Server Key");

        let (cfinal_no_proof, proof_b64) = {
            let idx = cfinal_str.rfind(",p=").unwrap();
            (&cfinal_str[..idx], &cfinal_str[idx + 3..])
        };
        let auth_message = format!(
            "n=,r=fyko+d2lbbFgONRv9qkxdawL,{},{}",
            server_first, cfinal_no_proof
        );

        // Server-side proof check: H(ClientProof XOR ClientSignature) == StoredKey.
        let client_sig = hmac_sha256(&stored_key, auth_message.as_bytes());
        let proof = BASE64.decode(proof_b64).expect("proof is base64");
        assert_eq!(proof.len(), 32);
        let recovered_client_key: Vec<u8> = proof
            .iter()
            .zip(client_sig.iter())
            .map(|(p, s)| p ^ s)
            .collect();
        assert_eq!(sha256(&recovered_client_key), stored_key);

        let server_sig = hmac_sha256(&server_key, auth_message.as_bytes());
        let server_final = format!("v={}", BASE64.encode(server_sig));

        scram
            .verify_server(server_final.as_bytes())
            .expect("verify_server");
    }

    #[test]
    fn test_scram_rejects_nonce_mismatch() {
        let (mut scram, _) = Scram::client_first("client-nonce");
        let server_first = "r=OTHER-nonce,s=QUJD,i=4096";
        let err = scram
            .client_final(server_first.as_bytes(), "pw")
            .unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_scram_rejects_zero_iterations() {
        let (mut scram, _) = Scram::client_first("abc");
        let err = scram
            .client_final(b"r=abc-extension,s=QUJD,i=0", "pw")
            .unwrap_err();
        assert!(matches!(err, BackendError::Auth(_)));
    }

    #[test]
    fn test_scram_rejects_missing_or_malformed_fields() {
        for server_first in [
            "s=QUJD,i=4096",                 // no nonce
            "r=abc-extension,i=4096",        // no salt
            "r=abc-extension,s=QUJD",        // no iteration count
            "r=abc-extension,s=QUJD,i=lots", // non-numeric iteration count
            "r=abc-extension,s=!!!,i=4096",  // salt is not base64
        ] {
            let (mut scram, _) = Scram::client_first("abc");
            assert!(
                scram.client_final(server_first.as_bytes(), "pw").is_err(),
                "accepted {:?}",
                server_first
            );
        }
    }

    #[test]
    fn test_scram_rejects_non_utf8_server_first() {
        let (mut scram, _) = Scram::client_first("abc");
        assert!(scram.client_final(&[0xff, 0xfe, 0xfd], "pw").is_err());
    }

    #[test]
    fn test_verify_server_requires_client_final() {
        let (scram, _) = Scram::client_first("abc");
        let server_final = format!("v={}", BASE64.encode([0u8; 32]));
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_bad_server_signature() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        // An all-zero signature cannot match a real HMAC output.
        let bad_sig = BASE64.encode([0u8; 32]);
        let server_final = format!("v={}", bad_sig);
        assert!(scram.verify_server(server_final.as_bytes()).is_err());
    }

    #[test]
    fn test_scram_rejects_truncated_or_missing_server_signature() {
        let (mut scram, _) = Scram::client_first("abc");
        let _ = scram
            .client_final(b"r=abc-extension,s=QUJD,i=1", "pw")
            .unwrap();
        let short = format!("v={}", BASE64.encode([0u8; 16]));
        assert!(scram.verify_server(short.as_bytes()).is_err());
        assert!(scram.verify_server(b"x=not-a-signature").is_err());
    }

    #[test]
    fn test_scram_rejects_server_error() {
        let (mut scram, _) = Scram::client_first("abc");
        let server_first = "r=abc-extension,s=QUJD,i=4096";
        let _ = scram.client_final(server_first.as_bytes(), "pw").unwrap();
        assert!(scram.verify_server(b"e=invalid-proof").is_err());
    }
}