1use base64::display::Base64Display;
4use base64::engine::general_purpose::STANDARD;
5use base64::Engine;
6use hmac::{Hmac, Mac};
7use rand::{self, Rng};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// Number of printable ASCII characters in the random client nonce generated by
/// `ScramSha256::new`.
const NONCE_LENGTH: usize = 24;

/// The name of the SCRAM-SHA-256 SASL mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The name of the SCRAM-SHA-256-PLUS SASL mechanism (SCRAM-SHA-256 with channel binding).
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
22
23fn normalize(pass: &[u8]) -> Vec<u8> {
27 let pass = match str::from_utf8(pass) {
28 Ok(pass) => pass,
29 Err(_) => return pass.to_vec(),
30 };
31
32 match stringprep::saslprep(pass) {
33 Ok(pass) => pass.into_owned().into_bytes(),
34 Err(_) => pass.as_bytes().to_vec(),
35 }
36}
37
38pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
39 let mut hmac =
40 Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
41 hmac.update(salt);
42 hmac.update(&[0, 0, 0, 1]);
43 let mut prev = hmac.finalize().into_bytes();
44
45 let mut hi = prev;
46
47 for _ in 1..i {
48 let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
49 hmac.update(&prev);
50 prev = hmac.finalize().into_bytes();
51
52 for (hi, prev) in hi.iter_mut().zip(prev) {
53 *hi ^= prev;
54 }
55 }
56
57 hi.into()
58}
59
/// How the client participates in SCRAM channel binding.
enum ChannelBindingInner {
    /// Sent with the GS2 flag `y`.
    Unrequested,
    /// Sent with the GS2 flag `n`.
    Unsupported,
    /// Binding of type `tls-server-end-point`, carrying the binding data.
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration for a SCRAM-SHA-256 exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// Channel binding was not requested by the server (GS2 flag `y`).
    ///
    /// Per RFC 5802, `y` means the client supports channel binding but believes
    /// the server does not.
    pub fn unrequested() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unrequested)
    }

    /// Channel binding is not supported (GS2 flag `n`).
    ///
    /// Per RFC 5802, `n` means the client does not support channel binding.
    pub fn unsupported() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unsupported)
    }

    /// `tls-server-end-point` channel binding that uses `signature` as the
    /// channel-binding data.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// The GS2 header that prefixes the client-first message and forms the start
    /// of the `c=` attribute input.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::Unrequested => "y,,",
        }
    }

    /// The channel-binding data appended after the GS2 header.
    ///
    /// This is empty unless a TLS end-point binding is in use.
    fn cbind_data(&self) -> &[u8] {
        if let ChannelBindingInner::TlsServerEndPoint(data) = &self.0 {
            data.as_slice()
        } else {
            &[]
        }
    }
}
101
102enum State {
103 Update {
104 nonce: String,
105 password: Vec<u8>,
106 channel_binding: ChannelBinding,
107 },
108 Finish {
109 salted_password: [u8; 32],
110 auth_message: String,
111 },
112 Done,
113}
114
115pub struct ScramSha256 {
131 message: String,
132 state: State,
133}
134
135impl ScramSha256 {
136 pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
138 let mut rng = rand::rng();
140 let nonce = (0..NONCE_LENGTH)
141 .map(|_| {
142 let mut v = rng.random_range(0x21u8..0x7e);
143 if v == 0x2c {
144 v = 0x7e
145 }
146 v as char
147 })
148 .collect::<String>();
149
150 ScramSha256::new_inner(password, channel_binding, nonce)
151 }
152
153 fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
154 ScramSha256 {
155 message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
156 state: State::Update {
157 nonce,
158 password: normalize(password),
159 channel_binding,
160 },
161 }
162 }
163
164 pub fn message(&self) -> &[u8] {
166 if let State::Done = self.state {
167 panic!("invalid SCRAM state");
168 }
169 self.message.as_bytes()
170 }
171
172 pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
176 let (client_nonce, password, channel_binding) =
177 match mem::replace(&mut self.state, State::Done) {
178 State::Update {
179 nonce,
180 password,
181 channel_binding,
182 } => (nonce, password, channel_binding),
183 _ => return Err(io::Error::other("invalid SCRAM state")),
184 };
185
186 let message =
187 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
188
189 let parsed = Parser::new(message).server_first_message()?;
190
191 if !parsed.nonce.starts_with(&client_nonce) {
192 return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
193 }
194
195 let salt = match STANDARD.decode(parsed.salt) {
196 Ok(salt) => salt,
197 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
198 };
199
200 let salted_password = hi(&password, &salt, parsed.iteration_count);
201
202 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
203 .expect("HMAC is able to accept all key sizes");
204 hmac.update(b"Client Key");
205 let client_key = hmac.finalize().into_bytes();
206
207 let mut hash = Sha256::default();
208 hash.update(client_key.as_slice());
209 let stored_key = hash.finalize_fixed();
210
211 let mut cbind_input = vec![];
212 cbind_input.extend(channel_binding.gs2_header().as_bytes());
213 cbind_input.extend(channel_binding.cbind_data());
214 let cbind_input = STANDARD.encode(&cbind_input);
215
216 self.message.clear();
217 write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
218
219 let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
220
221 let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
222 .expect("HMAC is able to accept all key sizes");
223 hmac.update(auth_message.as_bytes());
224 let client_signature = hmac.finalize().into_bytes();
225
226 let mut client_proof = client_key;
227 for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
228 *proof ^= signature;
229 }
230
231 write!(
232 &mut self.message,
233 ",p={}",
234 Base64Display::new(&client_proof, &STANDARD)
235 )
236 .unwrap();
237
238 self.state = State::Finish {
239 salted_password,
240 auth_message,
241 };
242 Ok(())
243 }
244
245 pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
250 let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
251 State::Finish {
252 salted_password,
253 auth_message,
254 } => (salted_password, auth_message),
255 _ => return Err(io::Error::other("invalid SCRAM state")),
256 };
257
258 let message =
259 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
260
261 let parsed = Parser::new(message).server_final_message()?;
262
263 let verifier = match parsed {
264 ServerFinalMessage::Error(e) => {
265 return Err(io::Error::other(format!("SCRAM error: {}", e)));
266 }
267 ServerFinalMessage::Verifier(verifier) => verifier,
268 };
269
270 let verifier = match STANDARD.decode(verifier) {
271 Ok(verifier) => verifier,
272 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
273 };
274
275 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
276 .expect("HMAC is able to accept all key sizes");
277 hmac.update(b"Server Key");
278 let server_key = hmac.finalize().into_bytes();
279
280 let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
281 .expect("HMAC is able to accept all key sizes");
282 hmac.update(auth_message.as_bytes());
283 hmac.verify_slice(&verifier)
284 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
285 }
286}
287
/// A small recursive-descent parser for SCRAM server messages (RFC 5802 §7).
struct Parser<'a> {
    s: &'a str,
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes the next character, failing unless it is exactly `target`.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m = format!(
                    "unexpected character at byte {}: expected `{}` but got `{}`",
                    i, target, c
                );
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes the longest run of characters satisfying `f` and returns it.
    ///
    /// The returned slice may be empty.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    /// Parses `printable = %x21-2B / %x2D-7E`, i.e. printable ASCII except `,`.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// Parses an `r=<nonce>` attribute.
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Parses a run of base64 alphabet characters (including `=` padding).
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// Parses an `s=<salt>` attribute.
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses `posit-number = %x31-39 *DIGIT`, a strictly positive decimal number.
    fn posit_number(&mut self) -> io::Result<u32> {
        let digits = self.take_while(|c| c.is_ascii_digit())?;
        let n: u32 = digits
            .parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        // The grammar forbids 0, e.g. an iteration count of zero.
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "expected a positive number but got 0",
            ));
        }
        Ok(n)
    }

    /// Parses an `i=<iteration-count>` attribute.
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if the entire input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {}", i),
            )),
            None => Ok(()),
        }
    }

    /// Parses `r=<nonce>,s=<salt>,i=<iteration-count>` with nothing trailing.
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// Parses a `value`: a run of characters other than NUL, `=` and `,`.
    fn value(&mut self) -> io::Result<&'a str> {
        // Must be negated: `value-char` EXCLUDES these three characters.
        self.take_while(|c| !matches!(c, '\0' | '=' | ','))
    }

    /// Parses an `e=<server-error-value>` attribute, if the message starts with `e`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// Parses a `v=<base64 server signature>` attribute.
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses either `e=<error>` or `v=<verifier>`, with nothing trailing.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// The attributes of a server-first message; all fields borrow from the input.
struct ServerFirstMessage<'a> {
    nonce: &'a str,
    salt: &'a str,
    iteration_count: u32,
}

/// A server-final message: either a reported error or a base64 server signature.
enum ServerFinalMessage<'a> {
    Error(&'a str),
    Verifier(&'a str),
}
436
#[cfg(test)]
mod test {
    use super::*;

    /// The parser splits a server-first message into nonce, salt and iteration count.
    #[test]
    fn parse_server_first_message() {
        let input = concat!(
            "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,",
            "s=QSXCR+Q6sek8bf92,",
            "i=4096"
        );
        let parsed = Parser::new(input).server_first_message().unwrap();

        assert_eq!(parsed.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(parsed.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(parsed.iteration_count, 4096);
    }

    /// Replays a full exchange (password "foobar", fixed client nonce, no channel
    /// binding) and checks every client message plus the final server verification.
    #[test]
    fn exchange() {
        let client_nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        let expected_client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let server_first = concat!(
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,",
            "s=fs3IXBy7U7+IvVjZ,",
            "i=4096"
        );
        let expected_client_final = concat!(
            "c=biws,",
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,",
            "p=AmNKosjJzS31NTlQYNs5BTeQjdHdk7lOflDo5re2an8="
        );
        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            b"foobar",
            ChannelBinding::unsupported(),
            client_nonce.to_owned(),
        );
        assert_eq!(scram.message(), expected_client_first.as_bytes());

        scram.update(server_first.as_bytes()).unwrap();
        assert_eq!(scram.message(), expected_client_final.as_bytes());

        scram.finish(server_final.as_bytes()).unwrap();
    }
}
477}