postgres_protocol/authentication/sasl.rs

1//! SASL-based authentication support.
2
3use base64::display::Base64Display;
4use base64::engine::general_purpose::STANDARD;
5use base64::Engine;
6use hmac::{Hmac, Mac};
7use rand::{self, Rng};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// SASL mechanism name for SCRAM-SHA-256.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";

/// SASL mechanism name for SCRAM-SHA-256-PLUS, the channel-binding variant of SCRAM-SHA-256.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";

/// Number of characters in each randomly generated client nonce.
const NONCE_LENGTH: usize = 24;
22
23// since postgres passwords are not required to exclude saslprep-prohibited
24// characters or even be valid UTF8, we run saslprep if possible and otherwise
25// return the raw password.
26fn normalize(pass: &[u8]) -> Vec<u8> {
27    let pass = match str::from_utf8(pass) {
28        Ok(pass) => pass,
29        Err(_) => return pass.to_vec(),
30    };
31
32    match stringprep::saslprep(pass) {
33        Ok(pass) => pass.into_owned().into_bytes(),
34        Err(_) => pass.as_bytes().to_vec(),
35    }
36}
37
38pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
39    let mut hmac =
40        Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
41    hmac.update(salt);
42    hmac.update(&[0, 0, 0, 1]);
43    let mut prev = hmac.finalize().into_bytes();
44
45    let mut hi = prev;
46
47    for _ in 1..i {
48        let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
49        hmac.update(&prev);
50        prev = hmac.finalize().into_bytes();
51
52        for (hi, prev) in hi.iter_mut().zip(prev) {
53            *hi ^= prev;
54        }
55    }
56
57    hi.into()
58}
59
/// Private representation of the channel-binding choice. The public wrapper can only be
/// built through its named constructors.
enum ChannelBindingInner {
    Unrequested,
    Unsupported,
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// Channel binding configuration for when the server did not request channel binding.
    pub fn unrequested() -> ChannelBinding {
        Self(ChannelBindingInner::Unrequested)
    }

    /// Channel binding configuration for when the server requested channel binding but the
    /// client cannot provide it.
    pub fn unsupported() -> ChannelBinding {
        Self(ChannelBindingInner::Unsupported)
    }

    /// Channel binding configuration for when the server requested channel binding and the
    /// client will use the `tls-server-end-point` method with the given signature bytes.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        Self(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// The GS2 header that prefixes the client-first message.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::Unrequested => "y,,",
        }
    }

    /// The channel-binding payload that follows the GS2 header. It is empty unless the
    /// `tls-server-end-point` method is in use.
    fn cbind_data(&self) -> &[u8] {
        if let ChannelBindingInner::TlsServerEndPoint(signature) = &self.0 {
            signature.as_slice()
        } else {
            &[]
        }
    }
}
101
102enum State {
103    Update {
104        nonce: String,
105        password: Vec<u8>,
106        channel_binding: ChannelBinding,
107    },
108    Finish {
109        salted_password: [u8; 32],
110        auth_message: String,
111    },
112    Done,
113}
114
115/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
116/// process.
117///
118/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
119/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
120///
121/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
122/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
123///
124/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
125/// passed to the `update()` method, after which the buffer returned by the `message()` method
126/// should be sent to the backend in a `SASLResponse` message.
127///
128/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
129/// to the `finish()` method, after which the authentication process is complete.
130pub struct ScramSha256 {
131    message: String,
132    state: State,
133}
134
impl ScramSha256 {
    /// Constructs a new instance which will use the provided password for authentication.
    ///
    /// A fresh random client nonce of `NONCE_LENGTH` characters is generated for every
    /// instance.
    pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
        // The nonce must be unpredictable. This relies on rand's thread-local RNG (returned by
        // `rand::rng()`) being cryptographically secure.
        let mut rng = rand::rng();
        let nonce = (0..NONCE_LENGTH)
            .map(|_| {
                // Draw from 0x21..=0x7d (the range's upper bound is exclusive). Then remap ','
                // (0x2c), the SCRAM attribute separator, to '~' (0x7e). The nonce is therefore
                // printable ASCII without commas.
                let mut v = rng.random_range(0x21u8..0x7e);
                if v == 0x2c {
                    v = 0x7e
                }
                v as char
            })
            .collect::<String>();

        ScramSha256::new_inner(password, channel_binding, nonce)
    }

    /// Builds the client-first message (`<gs2 header>n=,r=<nonce>`) from an explicit nonce.
    ///
    /// This is split out from `new` so the tests can supply a fixed nonce. The SCRAM
    /// username attribute is sent empty (`n=`). The password is normalized once, here.
    fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
        ScramSha256 {
            message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
            state: State::Update {
                nonce,
                password: normalize(password),
                channel_binding,
            },
        }
    }

    /// Returns the message which should be sent to the backend next.
    ///
    /// Right after construction this is the client-first message, sent in a
    /// `SASLInitialResponse` message. After a successful `update()` it is the client-final
    /// message, sent in a `SASLResponse` message.
    ///
    /// # Panics
    ///
    /// Panics if the exchange is in its terminal state. That is the case after `finish()` has
    /// been called, or after `update()` or `finish()` has returned an error.
    pub fn message(&self) -> &[u8] {
        if let State::Done = self.state {
            panic!("invalid SCRAM state");
        }
        self.message.as_bytes()
    }

    /// Updates the state machine with the response from the backend.
    ///
    /// This should be called when an `AuthenticationSASLContinue` message is received. On
    /// success, `message()` returns the client-final message, including the client proof.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving the exchange in its terminal state, in any of these cases:
    /// - the instance is not awaiting the server-first message;
    /// - `message` is not UTF-8, or does not parse as a server-first message;
    /// - the server nonce does not start with the client nonce;
    /// - the salt is not valid base64.
    pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
        // Move to `Done` up front so that any early return leaves the exchange unusable.
        let (client_nonce, password, channel_binding) =
            match mem::replace(&mut self.state, State::Done) {
                State::Update {
                    nonce,
                    password,
                    channel_binding,
                } => (nonce, password, channel_binding),
                _ => return Err(io::Error::other("invalid SCRAM state")),
            };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_first_message()?;

        // The server extends the client nonce with its own part; anything else is rejected.
        if !parsed.nonce.starts_with(&client_nonce) {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
        }

        let salt = match STANDARD.decode(parsed.salt) {
            Ok(salt) => salt,
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        };

        // SaltedPassword := Hi(Normalize(password), salt, i)
        let salted_password = hi(&password, &salt, parsed.iteration_count);

        // ClientKey := HMAC(SaltedPassword, "Client Key")
        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(b"Client Key");
        let client_key = hmac.finalize().into_bytes();

        // StoredKey := SHA-256(ClientKey)
        let mut hash = Sha256::default();
        hash.update(client_key.as_slice());
        let stored_key = hash.finalize_fixed();

        // The `c=` attribute is base64(GS2 header || channel-binding data).
        let mut cbind_input = vec![];
        cbind_input.extend(channel_binding.gs2_header().as_bytes());
        cbind_input.extend(channel_binding.cbind_data());
        let cbind_input = STANDARD.encode(&cbind_input);

        // Write the client-final message *without* the proof first; AuthMessage covers only
        // this prefix.
        self.message.clear();
        write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();

        // AuthMessage := client-first-bare "," server-first "," client-final-without-proof.
        // This must be computed before ",p=" is appended to `self.message` below.
        let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);

        // ClientSignature := HMAC(StoredKey, AuthMessage)
        let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        let client_signature = hmac.finalize().into_bytes();

        // ClientProof := ClientKey XOR ClientSignature
        let mut client_proof = client_key;
        for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
            *proof ^= signature;
        }

        write!(
            &mut self.message,
            ",p={}",
            Base64Display::new(&client_proof, &STANDARD)
        )
        .unwrap();

        // Keep what `finish` needs to verify the server's signature.
        self.state = State::Finish {
            salted_password,
            auth_message,
        };
        Ok(())
    }

    /// Finalizes the authentication process.
    ///
    /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
    /// Authentication has only succeeded if this method returns `Ok(())`. Whatever the
    /// outcome, the exchange is left in its terminal state afterwards.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the instance is not awaiting the server-final message;
    /// - `message` is not UTF-8, or does not parse as a server-final message;
    /// - the server reported an `e=` error;
    /// - the verifier is not valid base64;
    /// - the server signature does not match.
    pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
        // State becomes `Done` here and is not restored, even on success.
        let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
            State::Finish {
                salted_password,
                auth_message,
            } => (salted_password, auth_message),
            _ => return Err(io::Error::other("invalid SCRAM state")),
        };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_final_message()?;

        let verifier = match parsed {
            ServerFinalMessage::Error(e) => {
                return Err(io::Error::other(format!("SCRAM error: {e}")));
            }
            ServerFinalMessage::Verifier(verifier) => verifier,
        };

        let verifier = match STANDARD.decode(verifier) {
            Ok(verifier) => verifier,
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        };

        // ServerKey := HMAC(SaltedPassword, "Server Key")
        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(b"Server Key");
        let server_key = hmac.finalize().into_bytes();

        // ServerSignature := HMAC(ServerKey, AuthMessage), compared with the received verifier.
        // NOTE(review): `verify_slice` is used instead of `==`. The hmac crate documents it as
        // a constant-time comparison; confirm against the crate version in use.
        let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        hmac.verify_slice(&verifier)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
    }
}
287
/// A small recursive-descent parser for the server messages of a SCRAM exchange
/// (RFC 5802 §7).
///
/// `it` walks `s` character by character while keeping byte offsets. That is how parsed
/// attribute values can be returned as slices borrowed from the original message.
struct Parser<'a> {
    s: &'a str,
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes exactly one character, failing unless it is `target`.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m =
                    format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes characters while `f` holds and returns them as a slice of the input.
    ///
    /// This never fails. The `Result` only keeps call sites uniform with the other
    /// productions.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    /// Parses a run of printable ASCII characters other than `,` (RFC 5802 `printable`).
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// Parses an `r=<nonce>` attribute.
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Parses a run of base64 alphabet characters, padding included. The text is returned
    /// still encoded.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// Parses an `s=<base64 salt>` attribute.
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses a strictly positive decimal integer (RFC 5802 `posit-number`).
    fn posit_number(&mut self) -> io::Result<u32> {
        let n = self.take_while(|c| c.is_ascii_digit())?;
        let n: u32 = n
            .parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
        // The grammar forbids zero. A zero iteration count would also silently degrade `Hi`
        // to a single HMAC round.
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "expected a positive number",
            ));
        }
        Ok(n)
    }

    /// Parses an `i=<iteration count>` attribute.
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Fails if any input remains.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {i}"),
            )),
            None => Ok(()),
        }
    }

    /// Parses a complete server-first message: `r=<nonce>,s=<salt>,i=<count>`.
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// Parses an attribute value: any run of characters other than NUL, `=` and `,`
    /// (RFC 5802 `value`).
    fn value(&mut self) -> io::Result<&'a str> {
        // Stop at the excluded characters. The previous predicate was inverted, so error
        // text was never captured.
        self.take_while(|c| !matches!(c, '\0' | '=' | ','))
    }

    /// Parses an `e=<value>` server error if the next character is `e`. Otherwise it
    /// consumes nothing and returns `None`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// Parses a `v=<base64 server signature>` attribute.
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses a complete server-final message: either `e=<error>` or `v=<verifier>`, with
    /// nothing after it.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// Fields of a server-first message, all borrowed from the raw message text.
struct ServerFirstMessage<'a> {
    /// The server's nonce, which the caller checks begins with the client nonce.
    nonce: &'a str,
    /// The salt, still base64-encoded.
    salt: &'a str,
    /// Number of `Hi` rounds to apply (always at least 1).
    iteration_count: u32,
}

/// The outcome reported by a server-final message.
enum ServerFinalMessage<'a> {
    /// `e=`: the server rejected the exchange with this error value.
    Error(&'a str),
    /// `v=`: the base64-encoded server signature to verify.
    Verifier(&'a str),
}
434
#[cfg(test)]
mod test {
    use super::*;

    /// A well-formed server-first message is split into its nonce, salt and iteration count.
    #[test]
    fn parse_server_first_message() {
        let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
        let message = Parser::new(message).server_first_message().unwrap();
        assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(message.iteration_count, 4096);
    }

    // recorded auth exchange from psql
    //
    // Replays a full exchange with no channel binding (`n,,`) and a fixed client nonce. It
    // checks that both client messages match the recorded ones byte for byte, and that the
    // recorded server signature verifies.
    #[test]
    fn exchange() {
        let password = "foobar";
        let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        // The long literals below use `\` line continuations; the leading whitespace of the
        // continued line is not part of the string.
        let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let server_first =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
             =4096";
        let client_final =
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
             1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        // `new_inner` is used instead of `new` so the nonce matches the recording.
        let mut scram = ScramSha256::new_inner(
            password.as_bytes(),
            ChannelBinding::unsupported(),
            nonce.to_string(),
        );
        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);

        scram.update(server_first.as_bytes()).unwrap();
        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);

        scram.finish(server_final.as_bytes()).unwrap();
    }
}