postgres_protocol/authentication/
sasl.rs1use base64::display::Base64Display;
4use base64::engine::general_purpose::STANDARD;
5use base64::Engine;
6use hmac::{Hmac, Mac};
7use rand::{self, Rng};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// Number of characters in the randomly generated client nonce.
const NONCE_LENGTH: usize = 24;

/// Mechanism name for SCRAM-SHA-256 without channel binding.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// Mechanism name for SCRAM-SHA-256 with channel binding (the `-PLUS` variant).
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
22
23fn normalize(pass: &[u8]) -> Vec<u8> {
27 let pass = match str::from_utf8(pass) {
28 Ok(pass) => pass,
29 Err(_) => return pass.to_vec(),
30 };
31
32 match stringprep::saslprep(pass) {
33 Ok(pass) => pass.into_owned().into_bytes(),
34 Err(_) => pass.as_bytes().to_vec(),
35 }
36}
37
38pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
39 let mut hmac =
40 Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
41 hmac.update(salt);
42 hmac.update(&[0, 0, 0, 1]);
43 let mut prev = hmac.finalize().into_bytes();
44
45 let mut hi = prev;
46
47 for _ in 1..i {
48 let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
49 hmac.update(&prev);
50 prev = hmac.finalize().into_bytes();
51
52 for (hi, prev) in hi.iter_mut().zip(prev) {
53 *hi ^= prev;
54 }
55 }
56
57 hi.into()
58}
59
/// Private representation behind [`ChannelBinding`].
enum ChannelBindingInner {
    /// Emits the `y` GS2 flag and carries no binding data.
    Unrequested,
    /// Emits the `n` GS2 flag and carries no binding data.
    Unsupported,
    /// Emits the `p=tls-server-end-point` GS2 flag and carries the binding data.
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration used for a SCRAM exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// Creates a configuration whose GS2 header is `y,,` with no binding data.
    pub fn unrequested() -> ChannelBinding {
        Self(ChannelBindingInner::Unrequested)
    }

    /// Creates a configuration whose GS2 header is `n,,` with no binding data.
    pub fn unsupported() -> ChannelBinding {
        Self(ChannelBindingInner::Unsupported)
    }

    /// Creates a `tls-server-end-point` configuration that binds to `signature`.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        Self(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// Returns the GS2 header that prefixes the client-first message.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::Unrequested => "y,,",
        }
    }

    /// Returns the channel binding data, which is empty unless
    /// `tls-server-end-point` is in use.
    fn cbind_data(&self) -> &[u8] {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(signature) => signature.as_slice(),
            ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
        }
    }
}
101
102enum State {
103 Update {
104 nonce: String,
105 password: Vec<u8>,
106 channel_binding: ChannelBinding,
107 },
108 Finish {
109 salted_password: [u8; 32],
110 auth_message: String,
111 },
112 Done,
113}
114
115pub struct ScramSha256 {
131 message: String,
132 state: State,
133}
134
135impl ScramSha256 {
136 pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
138 let mut rng = rand::rng();
140 let nonce = (0..NONCE_LENGTH)
141 .map(|_| {
142 let mut v = rng.random_range(0x21u8..0x7e);
143 if v == 0x2c {
144 v = 0x7e
145 }
146 v as char
147 })
148 .collect::<String>();
149
150 ScramSha256::new_inner(password, channel_binding, nonce)
151 }
152
153 fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
154 ScramSha256 {
155 message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
156 state: State::Update {
157 nonce,
158 password: normalize(password),
159 channel_binding,
160 },
161 }
162 }
163
164 pub fn message(&self) -> &[u8] {
166 if let State::Done = self.state {
167 panic!("invalid SCRAM state");
168 }
169 self.message.as_bytes()
170 }
171
172 pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
176 let (client_nonce, password, channel_binding) =
177 match mem::replace(&mut self.state, State::Done) {
178 State::Update {
179 nonce,
180 password,
181 channel_binding,
182 } => (nonce, password, channel_binding),
183 _ => return Err(io::Error::other("invalid SCRAM state")),
184 };
185
186 let message =
187 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
188
189 let parsed = Parser::new(message).server_first_message()?;
190
191 if !parsed.nonce.starts_with(&client_nonce) {
192 return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
193 }
194
195 let salt = match STANDARD.decode(parsed.salt) {
196 Ok(salt) => salt,
197 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
198 };
199
200 let salted_password = hi(&password, &salt, parsed.iteration_count);
201
202 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
203 .expect("HMAC is able to accept all key sizes");
204 hmac.update(b"Client Key");
205 let client_key = hmac.finalize().into_bytes();
206
207 let mut hash = Sha256::default();
208 hash.update(client_key.as_slice());
209 let stored_key = hash.finalize_fixed();
210
211 let mut cbind_input = vec![];
212 cbind_input.extend(channel_binding.gs2_header().as_bytes());
213 cbind_input.extend(channel_binding.cbind_data());
214 let cbind_input = STANDARD.encode(&cbind_input);
215
216 self.message.clear();
217 write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
218
219 let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
220
221 let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
222 .expect("HMAC is able to accept all key sizes");
223 hmac.update(auth_message.as_bytes());
224 let client_signature = hmac.finalize().into_bytes();
225
226 let mut client_proof = client_key;
227 for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
228 *proof ^= signature;
229 }
230
231 write!(
232 &mut self.message,
233 ",p={}",
234 Base64Display::new(&client_proof, &STANDARD)
235 )
236 .unwrap();
237
238 self.state = State::Finish {
239 salted_password,
240 auth_message,
241 };
242 Ok(())
243 }
244
245 pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
250 let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
251 State::Finish {
252 salted_password,
253 auth_message,
254 } => (salted_password, auth_message),
255 _ => return Err(io::Error::other("invalid SCRAM state")),
256 };
257
258 let message =
259 str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
260
261 let parsed = Parser::new(message).server_final_message()?;
262
263 let verifier = match parsed {
264 ServerFinalMessage::Error(e) => {
265 return Err(io::Error::other(format!("SCRAM error: {e}")));
266 }
267 ServerFinalMessage::Verifier(verifier) => verifier,
268 };
269
270 let verifier = match STANDARD.decode(verifier) {
271 Ok(verifier) => verifier,
272 Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
273 };
274
275 let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
276 .expect("HMAC is able to accept all key sizes");
277 hmac.update(b"Server Key");
278 let server_key = hmac.finalize().into_bytes();
279
280 let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
281 .expect("HMAC is able to accept all key sizes");
282 hmac.update(auth_message.as_bytes());
283 hmac.verify_slice(&verifier)
284 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
285 }
286}
287
/// Hand-written parser for the server-first and server-final SCRAM messages.
///
/// It walks the message one character at a time and hands out sub-slices of
/// the original text, so nothing is copied.
struct Parser<'a> {
    /// The complete message; every returned slice borrows from it.
    s: &'a str,
    /// Cursor over `s`, yielding `(byte offset, char)` pairs.
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    /// Creates a parser positioned at the start of `s`.
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes the next character, which must be `target`.
    ///
    /// Returns `InvalidInput` if a different character is found and
    /// `UnexpectedEof` if the input is exhausted.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m =
                    format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes the longest prefix whose characters all satisfy `f` and
    /// returns it. The result is empty when the first character does not
    /// match or the input is exhausted.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        while let Some(&(i, c)) = self.it.peek() {
            if !f(c) {
                return Ok(&self.s[start..i]);
            }
            self.it.next();
        }

        // Ran off the end: the token extends to the end of the message.
        Ok(&self.s[start..])
    }

    /// Consumes a run of printable ASCII characters other than `,`.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// Parses `r=<printable>` and returns the nonce text.
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Consumes a run of base64 alphabet characters (including `=` padding).
    /// The text is not decoded or validated here.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// Parses `s=<base64>` and returns the still-encoded salt.
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses a run of ASCII digits as a `u32`.
    ///
    /// An empty run or a value that does not fit in `u32` is `InvalidInput`.
    /// Zero and leading zeros are not rejected.
    fn posit_number(&mut self) -> io::Result<u32> {
        let n = self.take_while(|c| c.is_ascii_digit())?;
        n.parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
    }

    /// Parses `i=<number>` and returns the iteration count.
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if the whole input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {i}"),
            )),
            None => Ok(()),
        }
    }

    /// Parses a complete server-first message: `r=<nonce>,s=<salt>,i=<count>`.
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// Consumes an attribute value: every character up to the next `,` or NUL.
    ///
    /// This follows the RFC 5802 `value` production, which allows any
    /// character except NUL and `,` (`=` is permitted). The predicate must
    /// exclude the separators rather than match them; otherwise a server
    /// error such as `e=invalid-proof` parses as an empty value followed by
    /// trailing garbage, and the server's reason is lost.
    fn value(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| !matches!(c, '\0' | ','))
    }

    /// Parses `e=<value>` if the message starts with `e`.
    ///
    /// Returns `Ok(None)` without consuming anything when the next character
    /// is not `e`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// Parses `v=<base64>` and returns the still-encoded verifier.
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses a complete server-final message: either `e=<error>` or
    /// `v=<verifier>`, with nothing after it.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// The fields of a parsed server-first message, borrowed from the input.
struct ServerFirstMessage<'a> {
    /// The nonce text following `r=`.
    nonce: &'a str,
    /// The base64-encoded salt following `s=`.
    salt: &'a str,
    /// The iteration count following `i=`.
    iteration_count: u32,
}

/// The content of a parsed server-final message, borrowed from the input.
enum ServerFinalMessage<'a> {
    /// The server reported an error (`e=`); holds the error value.
    Error(&'a str),
    /// The server sent its signature (`v=`); holds the base64 text.
    Verifier(&'a str),
}
434
#[cfg(test)]
mod test {
    use super::*;

    /// A well-formed server-first message is split into its three attributes.
    #[test]
    fn parse_server_first_message() {
        let raw = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";

        let ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        } = Parser::new(raw).server_first_message().unwrap();

        assert_eq!(nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(salt, "QSXCR+Q6sek8bf92");
        assert_eq!(iteration_count, 4096);
    }

    /// Replays a fixed transcript with a pinned nonce and checks every message
    /// the client produces, then verifies the server signature.
    #[test]
    fn exchange() {
        const PASSWORD: &str = "foobar";
        const NONCE: &str = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        const CLIENT_FIRST: &str = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        const SERVER_FIRST: &str =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
             =4096";
        const CLIENT_FINAL: &str =
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
             1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
        const SERVER_FINAL: &str = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut client = ScramSha256::new_inner(
            PASSWORD.as_bytes(),
            ChannelBinding::unsupported(),
            NONCE.to_string(),
        );
        assert_eq!(str::from_utf8(client.message()).unwrap(), CLIENT_FIRST);

        client.update(SERVER_FIRST.as_bytes()).unwrap();
        assert_eq!(str::from_utf8(client.message()).unwrap(), CLIENT_FINAL);

        client.finish(SERVER_FINAL.as_bytes()).unwrap();
    }
}
475}