1use base64::Engine;
4use base64::display::Base64Display;
5use base64::engine::general_purpose::STANDARD;
6use hmac::{Hmac, KeyInit, Mac};
7use rand::{self, RngExt};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// Number of characters in the randomly generated client nonce.
const NONCE_LENGTH: usize = 24;

/// Largest iteration count accepted from the server's first message.
///
/// `ScramSha256::update` rejects any higher value, which caps the number of HMAC rounds
/// the client performs in `hi`.
const MAX_ITERATION_COUNT: u32 = 100_000;

/// The name of the SCRAM-SHA-256 SASL mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The name of the SCRAM-SHA-256-PLUS SASL mechanism (the `-PLUS` suffix marks the
/// channel-binding variant in RFC 5802's naming scheme).
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
32
33fn normalize(pass: &[u8]) -> Vec<u8> {
37 let pass = match str::from_utf8(pass) {
38 Ok(pass) => pass,
39 Err(_) => return pass.to_vec(),
40 };
41
42 match stringprep::saslprep(pass) {
43 Ok(pass) => pass.into_owned().into_bytes(),
44 Err(_) => pass.as_bytes().to_vec(),
45 }
46}
47
48pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
49 let mut hmac =
50 Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
51 hmac.update(salt);
52 hmac.update(&[0, 0, 0, 1]);
53 let mut prev = hmac.finalize().into_bytes();
54
55 let mut hi = prev;
56
57 for _ in 1..i {
58 let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
59 hmac.update(&prev);
60 prev = hmac.finalize().into_bytes();
61
62 for (hi, prev) in hi.iter_mut().zip(prev) {
63 *hi ^= prev;
64 }
65 }
66
67 hi.into()
68}
69
/// The channel-binding mode selected by the client (private representation).
enum ChannelBindingInner {
    Unrequested,
    Unsupported,
    TlsServerEndPoint(Vec<u8>),
}

/// Channel-binding configuration for a SCRAM exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// The client supports channel binding, but it was not requested.
    ///
    /// Produces the GS2 flag `y` (RFC 5802: the client supports channel binding but
    /// believes the server does not).
    pub fn unrequested() -> ChannelBinding {
        Self(ChannelBindingInner::Unrequested)
    }

    /// The client does not support channel binding (GS2 flag `n`).
    pub fn unsupported() -> ChannelBinding {
        Self(ChannelBindingInner::Unsupported)
    }

    /// Use the `tls-server-end-point` channel-binding type with the given binding data.
    ///
    /// The bytes are appended verbatim after the GS2 header in the client-final `c=`
    /// attribute. Presumably this is the server-certificate hash defined by RFC 5929;
    /// the caller is responsible for computing it.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        Self(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// The GS2 header (channel-binding flag plus empty authzid) for this mode.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::Unrequested => "y,,",
        }
    }

    /// The channel-binding data that follows the GS2 header; empty unless binding is used.
    fn cbind_data(&self) -> &[u8] {
        if let ChannelBindingInner::TlsServerEndPoint(data) = &self.0 {
            data.as_slice()
        } else {
            &[]
        }
    }
}
111
/// Progress of a `ScramSha256` exchange.
///
/// `update` and `finish` move the current state out with `mem::replace(.., State::Done)`,
/// so any error leaves the exchange in `Done`.
enum State {
    /// Client-first message built; waiting for the server-first message.
    Update {
        /// The client nonce sent in the client-first message.
        nonce: String,
        /// The password after `normalize`.
        password: Vec<u8>,
        channel_binding: ChannelBinding,
    },
    /// Client-final message built; waiting for the server-final message to verify.
    Finish {
        /// `hi(password, salt, iteration_count)`.
        salted_password: [u8; 32],
        /// The SCRAM AuthMessage that both signatures are computed over.
        auth_message: String,
    },
    /// Exchange finished or failed; no further transitions are possible.
    Done,
}
124
/// Client side of a SCRAM-SHA-256 SASL authentication exchange.
///
/// Send `message()`, pass the server's first reply to `update()`, send `message()` again,
/// then pass the server's final reply to `finish()`.
pub struct ScramSha256 {
    /// The next message to send: the client-first message after construction, and the
    /// client-final message after a successful `update`.
    message: String,
    /// Current position in the exchange.
    state: State,
}
144
145impl ScramSha256 {
146 pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
148 let mut rng = rand::rng();
150 let nonce = (0..NONCE_LENGTH)
151 .map(|_| {
152 let mut v = rng.random_range(0x21u8..0x7e);
153 if v == 0x2c {
154 v = 0x7e
155 }
156 v as char
157 })
158 .collect::<String>();
159
160 ScramSha256::new_inner(password, channel_binding, nonce)
161 }
162
163 fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
164 ScramSha256 {
165 message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
166 state: State::Update {
167 nonce,
168 password: normalize(password),
169 channel_binding,
170 },
171 }
172 }
173
174 pub fn message(&self) -> &[u8] {
176 if let State::Done = self.state {
177 panic!("invalid SCRAM state");
178 }
179 self.message.as_bytes()
180 }
181
182 pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
186 let (client_nonce, password, channel_binding) =
187 match mem::replace(&mut self.state, State::Done) {
188 State::Update {
189 nonce,
190 password,
191 channel_binding,
192 } => (nonce, password, channel_binding),
193 _ => return Err(io::Error::other("invalid SCRAM state")),
194 };
195
196 let message =
197 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
198
199 let parsed = Parser::new(message).server_first_message()?;
200
201 if !parsed.nonce.starts_with(&client_nonce) {
202 return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
203 }
204
205 if parsed.iteration_count > MAX_ITERATION_COUNT {
206 return Err(io::Error::new(
207 io::ErrorKind::InvalidInput,
208 "SCRAM iteration count exceeds the maximum allowed",
209 ));
210 }
211
212 let salt = match STANDARD.decode(parsed.salt) {
213 Ok(salt) => salt,
214 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
215 };
216
217 let salted_password = hi(&password, &salt, parsed.iteration_count);
218
219 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
220 .expect("HMAC is able to accept all key sizes");
221 hmac.update(b"Client Key");
222 let client_key = hmac.finalize().into_bytes();
223
224 let mut hash = Sha256::default();
225 hash.update(client_key);
226 let stored_key = hash.finalize_fixed();
227
228 let mut cbind_input = vec![];
229 cbind_input.extend(channel_binding.gs2_header().as_bytes());
230 cbind_input.extend(channel_binding.cbind_data());
231 let cbind_input = STANDARD.encode(&cbind_input);
232
233 self.message.clear();
234 write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
235
236 let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
237
238 let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
239 .expect("HMAC is able to accept all key sizes");
240 hmac.update(auth_message.as_bytes());
241 let client_signature = hmac.finalize().into_bytes();
242
243 let mut client_proof = client_key;
244 for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
245 *proof ^= signature;
246 }
247
248 write!(
249 &mut self.message,
250 ",p={}",
251 Base64Display::new(&client_proof, &STANDARD)
252 )
253 .unwrap();
254
255 self.state = State::Finish {
256 salted_password,
257 auth_message,
258 };
259 Ok(())
260 }
261
262 pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
267 let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
268 State::Finish {
269 salted_password,
270 auth_message,
271 } => (salted_password, auth_message),
272 _ => return Err(io::Error::other("invalid SCRAM state")),
273 };
274
275 let message =
276 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
277
278 let parsed = Parser::new(message).server_final_message()?;
279
280 let verifier = match parsed {
281 ServerFinalMessage::Error(e) => {
282 return Err(io::Error::other(format!("SCRAM error: {e}")));
283 }
284 ServerFinalMessage::Verifier(verifier) => verifier,
285 };
286
287 let verifier = match STANDARD.decode(verifier) {
288 Ok(verifier) => verifier,
289 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
290 };
291
292 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
293 .expect("HMAC is able to accept all key sizes");
294 hmac.update(b"Server Key");
295 let server_key = hmac.finalize().into_bytes();
296
297 let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
298 .expect("HMAC is able to accept all key sizes");
299 hmac.update(auth_message.as_bytes());
300 hmac.verify_slice(&verifier)
301 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
302 }
303}
304
/// A small recursive-descent parser for the server's SCRAM messages (RFC 5802 §7).
///
/// Every returned `&'a str` is a slice of the original input. Byte offsets come from
/// `char_indices`, so slices always fall on character boundaries.
struct Parser<'a> {
    s: &'a str,
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes the next character, failing unless it equals `target`.
    ///
    /// Returns `InvalidInput` on a mismatch and `UnexpectedEof` at end of input.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                // Both the expected and the actual character are quoted in backticks.
                let m =
                    format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes and returns the longest (possibly empty) prefix whose characters all
    /// satisfy `f`. Always returns `Ok`.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    /// `printable = %x21-2B / %x2D-7E`: printable ASCII except `,`.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// `nonce = "r=" printable`.
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Base64 alphabet plus `=` padding; actual decoding happens later.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// `salt = "s=" base64`.
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses an RFC 5802 `posit-number` (`%x31-39 *DIGIT`) as a `u32`.
    ///
    /// Zero, leading zeros, an empty number, and values overflowing `u32` are rejected
    /// with `InvalidInput`. Rejecting zero also keeps `hi` from being asked for a
    /// zero-iteration derivation.
    fn posit_number(&mut self) -> io::Result<u32> {
        let n = self.take_while(|c| c.is_ascii_digit())?;
        if n.starts_with('0') {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid positive number `{n}`: must not be zero or have leading zeros"),
            ));
        }
        n.parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
    }

    /// `iteration-count = "i=" posit-number`.
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if the whole input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {i}"),
            )),
            None => Ok(()),
        }
    }

    /// Parses `nonce "," salt "," iteration-count` with nothing after it
    /// (the optional `reserved-mext` and extensions are not accepted).
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// A SCRAM attribute value: any characters except NUL, `=` and `,`.
    fn value(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| !matches!(c, '\0' | '=' | ','))
    }

    /// Parses `"e=" value` if the input starts with `e`; otherwise consumes nothing and
    /// returns `None`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// `verifier = "v=" base64`.
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses `server-final-message = (server-error / verifier)` with nothing after it.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// The fields of a parsed server-first message, borrowed from the input.
struct ServerFirstMessage<'a> {
    /// Full nonce: the client nonce followed by the server's part.
    nonce: &'a str,
    /// Base64-encoded salt (not yet decoded).
    salt: &'a str,
    iteration_count: u32,
}

/// A parsed server-final message, borrowed from the input.
enum ServerFinalMessage<'a> {
    /// The `e=` server error value.
    Error(&'a str),
    /// The base64-encoded `v=` server signature (not yet decoded).
    Verifier(&'a str),
}
451
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_server_first_message() {
        let parsed =
            Parser::new("r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096")
                .server_first_message()
                .unwrap();
        assert_eq!(parsed.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(parsed.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(parsed.iteration_count, 4096);
    }

    #[test]
    fn parse_server_error_message() {
        let parsed = Parser::new("e=invalid-proof").server_final_message().unwrap();
        let ServerFinalMessage::Error(error) = parsed else {
            panic!("expected server error");
        };
        assert_eq!(error, "invalid-proof");

        // `value()` must stop at each character that cannot appear inside a SCRAM value.
        let inputs = ["invalid-proof\0x", "invalid-proof=x", "invalid-proof,x"];
        for input in inputs {
            assert_eq!(Parser::new(input).value().unwrap(), "invalid-proof");
        }
    }

    #[test]
    fn exchange() {
        let client_nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let mut scram = ScramSha256::new_inner(
            "foobar".as_bytes(),
            ChannelBinding::unsupported(),
            client_nonce.to_string(),
        );

        // client-first
        assert_eq!(scram.message(), "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB".as_bytes());

        // server-first -> client-final
        scram
            .update(b"r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i=4096")
            .unwrap();
        assert_eq!(
            scram.message(),
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,\
             p=AmNKosjJzS31NTlQYNs5BTeQjdHdk7lOflDo5re2an8="
                .as_bytes()
        );

        // server-final verification
        scram
            .finish(b"v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=")
            .unwrap();
    }

    #[test]
    fn excessive_iteration_count_is_rejected() {
        let mut scram = ScramSha256::new_inner(
            b"foobar",
            ChannelBinding::unsupported(),
            String::from("9IZ2O01zb9IgiIZ1WJ/zgpJB"),
        );
        let server_first =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i=1000000";
        let outcome = scram.update(server_first.as_bytes());
        assert!(outcome.is_err());
    }
}
517}