Skip to main content

postgres_protocol/authentication/
sasl.rs

1//! SASL-based authentication support.
2
3use base64::Engine;
4use base64::display::Base64Display;
5use base64::engine::general_purpose::STANDARD;
6use hmac::{Hmac, KeyInit, Mac};
7use rand::{self, RngExt};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// Number of random characters in the client nonce sent as `r=` in the client-first-message.
const NONCE_LENGTH: usize = 24;

/// Upper bound on the SCRAM iteration count the client will accept from the server.
///
/// The server picks the iteration count, and the client then runs a PBKDF2 loop of that
/// length. Without a cap, a malicious or impersonating server could make the client perform
/// an arbitrary number of HMAC operations before authentication completes (a denial of
/// service). 100_000 is roughly 24x the PostgreSQL default of 4096. It matches the default cap
/// adopted by the PostgreSQL JDBC driver (pgjdbc) for the same issue (CVE-2026-42198).
const MAX_ITERATION_COUNT: u32 = 100_000;

/// SASL mechanism name for SCRAM-SHA-256.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// SASL mechanism name for SCRAM-SHA-256-PLUS (SCRAM-SHA-256 with channel binding).
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
32
33// since postgres passwords are not required to exclude saslprep-prohibited
34// characters or even be valid UTF8, we run saslprep if possible and otherwise
35// return the raw password.
36fn normalize(pass: &[u8]) -> Vec<u8> {
37    let pass = match str::from_utf8(pass) {
38        Ok(pass) => pass,
39        Err(_) => return pass.to_vec(),
40    };
41
42    match stringprep::saslprep(pass) {
43        Ok(pass) => pass.into_owned().into_bytes(),
44        Err(_) => pass.as_bytes().to_vec(),
45    }
46}
47
48pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
49    let mut hmac =
50        Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
51    hmac.update(salt);
52    hmac.update(&[0, 0, 0, 1]);
53    let mut prev = hmac.finalize().into_bytes();
54
55    let mut hi = prev;
56
57    for _ in 1..i {
58        let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
59        hmac.update(&prev);
60        prev = hmac.finalize().into_bytes();
61
62        for (hi, prev) in hi.iter_mut().zip(prev) {
63            *hi ^= prev;
64        }
65    }
66
67    hi.into()
68}
69
/// Internal representation of [`ChannelBinding`].
enum ChannelBindingInner {
    Unrequested,
    Unsupported,
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration for a SCRAM authentication exchange.
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// The server did not request channel binding.
    pub fn unrequested() -> ChannelBinding {
        Self(ChannelBindingInner::Unrequested)
    }

    /// The server requested channel binding but the client is unable to provide it.
    pub fn unsupported() -> ChannelBinding {
        Self(ChannelBindingInner::Unsupported)
    }

    /// The server requested channel binding and the client will use the
    /// `tls-server-end-point` method, binding to the given certificate signature.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        Self(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// Returns the GS2 header for this configuration.
    ///
    /// The header prefixes the client-first-message. Together with [`Self::cbind_data`], it
    /// forms the base64-encoded `c=` attribute of the client-final-message.
    fn gs2_header(&self) -> &'static str {
        match &self.0 {
            ChannelBindingInner::Unrequested => "y,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
        }
    }

    /// Returns the channel binding data that follows the GS2 header in the `c=` attribute.
    ///
    /// The data is empty unless `tls-server-end-point` binding is in use.
    fn cbind_data(&self) -> &[u8] {
        match &self.0 {
            ChannelBindingInner::TlsServerEndPoint(signature) => signature.as_slice(),
            ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
        }
    }
}
111
112enum State {
113    Update {
114        nonce: String,
115        password: Vec<u8>,
116        channel_binding: ChannelBinding,
117    },
118    Finish {
119        salted_password: [u8; 32],
120        auth_message: String,
121    },
122    Done,
123}
124
125/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
126/// process.
127///
128/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
129/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
130///
131/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
132/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
133///
134/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
135/// passed to the `update()` method, after which the buffer returned by the `message()` method
136/// should be sent to the backend in a `SASLResponse` message.
137///
138/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
139/// to the `finish()` method, after which the authentication process is complete.
140pub struct ScramSha256 {
141    message: String,
142    state: State,
143}
144
145impl ScramSha256 {
146    /// Constructs a new instance which will use the provided password for authentication.
147    pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
148        // rand 0.5's ThreadRng is cryptographically secure
149        let mut rng = rand::rng();
150        let nonce = (0..NONCE_LENGTH)
151            .map(|_| {
152                let mut v = rng.random_range(0x21u8..0x7e);
153                if v == 0x2c {
154                    v = 0x7e
155                }
156                v as char
157            })
158            .collect::<String>();
159
160        ScramSha256::new_inner(password, channel_binding, nonce)
161    }
162
163    fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
164        ScramSha256 {
165            message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
166            state: State::Update {
167                nonce,
168                password: normalize(password),
169                channel_binding,
170            },
171        }
172    }
173
174    /// Returns the message which should be sent to the backend in an `SASLResponse` message.
175    pub fn message(&self) -> &[u8] {
176        if let State::Done = self.state {
177            panic!("invalid SCRAM state");
178        }
179        self.message.as_bytes()
180    }
181
182    /// Updates the state machine with the response from the backend.
183    ///
184    /// This should be called when an `AuthenticationSASLContinue` message is received.
185    pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
186        let (client_nonce, password, channel_binding) =
187            match mem::replace(&mut self.state, State::Done) {
188                State::Update {
189                    nonce,
190                    password,
191                    channel_binding,
192                } => (nonce, password, channel_binding),
193                _ => return Err(io::Error::other("invalid SCRAM state")),
194            };
195
196        let message =
197            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
198
199        let parsed = Parser::new(message).server_first_message()?;
200
201        if !parsed.nonce.starts_with(&client_nonce) {
202            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
203        }
204
205        if parsed.iteration_count > MAX_ITERATION_COUNT {
206            return Err(io::Error::new(
207                io::ErrorKind::InvalidInput,
208                "SCRAM iteration count exceeds the maximum allowed",
209            ));
210        }
211
212        let salt = match STANDARD.decode(parsed.salt) {
213            Ok(salt) => salt,
214            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
215        };
216
217        let salted_password = hi(&password, &salt, parsed.iteration_count);
218
219        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
220            .expect("HMAC is able to accept all key sizes");
221        hmac.update(b"Client Key");
222        let client_key = hmac.finalize().into_bytes();
223
224        let mut hash = Sha256::default();
225        hash.update(client_key);
226        let stored_key = hash.finalize_fixed();
227
228        let mut cbind_input = vec![];
229        cbind_input.extend(channel_binding.gs2_header().as_bytes());
230        cbind_input.extend(channel_binding.cbind_data());
231        let cbind_input = STANDARD.encode(&cbind_input);
232
233        self.message.clear();
234        write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();
235
236        let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);
237
238        let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
239            .expect("HMAC is able to accept all key sizes");
240        hmac.update(auth_message.as_bytes());
241        let client_signature = hmac.finalize().into_bytes();
242
243        let mut client_proof = client_key;
244        for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
245            *proof ^= signature;
246        }
247
248        write!(
249            &mut self.message,
250            ",p={}",
251            Base64Display::new(&client_proof, &STANDARD)
252        )
253        .unwrap();
254
255        self.state = State::Finish {
256            salted_password,
257            auth_message,
258        };
259        Ok(())
260    }
261
262    /// Finalizes the authentication process.
263    ///
264    /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
265    /// Authentication has only succeeded if this method returns `Ok(())`.
266    pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
267        let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
268            State::Finish {
269                salted_password,
270                auth_message,
271            } => (salted_password, auth_message),
272            _ => return Err(io::Error::other("invalid SCRAM state")),
273        };
274
275        let message =
276            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
277
278        let parsed = Parser::new(message).server_final_message()?;
279
280        let verifier = match parsed {
281            ServerFinalMessage::Error(e) => {
282                return Err(io::Error::other(format!("SCRAM error: {e}")));
283            }
284            ServerFinalMessage::Verifier(verifier) => verifier,
285        };
286
287        let verifier = match STANDARD.decode(verifier) {
288            Ok(verifier) => verifier,
289            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
290        };
291
292        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
293            .expect("HMAC is able to accept all key sizes");
294        hmac.update(b"Server Key");
295        let server_key = hmac.finalize().into_bytes();
296
297        let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
298            .expect("HMAC is able to accept all key sizes");
299        hmac.update(auth_message.as_bytes());
300        hmac.verify_slice(&verifier)
301            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
302    }
303}
304
/// A small recursive-descent parser for the server messages of a SCRAM exchange
/// (`server-first-message` and `server-final-message`).
///
/// Every value it returns is a sub-slice of the input string, so nothing is allocated.
struct Parser<'a> {
    s: &'a str,
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes exactly `target`.
    ///
    /// Fails with `InvalidInput` on any other character, and with `UnexpectedEof` at the end
    /// of the input.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m =
                    format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes the longest run of characters satisfying `f` and returns it (possibly empty).
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(i, _)) => i,
            None => return Ok(""),
        };

        loop {
            match self.it.peek() {
                Some(&(_, c)) if f(c) => {
                    self.it.next();
                }
                Some(&(i, _)) => return Ok(&self.s[start..i]),
                None => return Ok(&self.s[start..]),
            }
        }
    }

    /// Printable ASCII except `,`, the character set allowed in a nonce.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Characters of the standard base64 alphabet, including `=` padding. The content is not
    /// validated here; decoding happens later.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// A decimal number that fits in a `u32` and is greater than zero (`posit-number`).
    ///
    /// Leading zeros are tolerated. An empty digit run, `0`, and overflow are rejected with
    /// `InvalidInput`.
    fn posit_number(&mut self) -> io::Result<u32> {
        let n = self.take_while(|c| c.is_ascii_digit())?;
        n.parse::<std::num::NonZeroU32>()
            .map(std::num::NonZeroU32::get)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))
    }

    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if the entire input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {i}"),
            )),
            None => Ok(()),
        }
    }

    /// Parses `r=<nonce>,s=<salt>,i=<iteration-count>` with no trailing data.
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// A run of characters up to (not including) the first `\0`, `=` or `,`.
    fn value(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| !matches!(c, '\0' | '=' | ','))
    }

    /// Parses `e=<value>` if the input starts with `e`; otherwise consumes nothing.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses either `e=<error>` or `v=<verifier>`, with no trailing data.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// Fields of a parsed server-first-message, borrowed from the raw message.
struct ServerFirstMessage<'a> {
    /// Combined client + server nonce.
    nonce: &'a str,
    /// Base64-encoded salt (not yet decoded).
    salt: &'a str,
    /// PBKDF2 iteration count; always at least 1 when produced by `Parser`.
    iteration_count: u32,
}

/// A parsed server-final-message.
enum ServerFinalMessage<'a> {
    /// The server reported an error (`e=`).
    Error(&'a str),
    /// Base64-encoded server signature (`v=`), not yet decoded.
    Verifier(&'a str),
}
451
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_server_first_message() {
        let input = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
        let parsed = Parser::new(input).server_first_message().unwrap();

        assert_eq!(parsed.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(parsed.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(parsed.iteration_count, 4096);
    }

    #[test]
    fn parse_server_error_message() {
        let parsed = Parser::new("e=invalid-proof").server_final_message().unwrap();
        let ServerFinalMessage::Error(error) = parsed else {
            panic!("expected server error");
        };
        assert_eq!(error, "invalid-proof");

        // An error value stops at the first '\0', '=' or ','.
        let inputs = ["invalid-proof\0x", "invalid-proof=x", "invalid-proof,x"];
        for input in inputs {
            let value = Parser::new(input).value().unwrap();
            assert_eq!(value, "invalid-proof");
        }
    }

    /// Replays an authentication exchange recorded from psql.
    #[test]
    fn exchange() {
        let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let server_first =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i=4096";
        let client_final = "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS31NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            "foobar".as_bytes(),
            ChannelBinding::unsupported(),
            String::from("9IZ2O01zb9IgiIZ1WJ/zgpJB"),
        );
        let sent_first = str::from_utf8(scram.message()).unwrap();
        assert_eq!(sent_first, client_first);

        scram.update(server_first.as_bytes()).unwrap();
        let sent_final = str::from_utf8(scram.message()).unwrap();
        assert_eq!(sent_final, client_final);

        scram.finish(server_final.as_bytes()).unwrap();
    }

    /// A malicious server cannot force an unbounded PBKDF2 loop: the iteration count is
    /// rejected before `hi()` runs.
    #[test]
    fn excessive_iteration_count_is_rejected() {
        let server_first =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i=1000000";

        let mut scram = ScramSha256::new_inner(
            b"foobar",
            ChannelBinding::unsupported(),
            String::from("9IZ2O01zb9IgiIZ1WJ/zgpJB"),
        );
        let outcome = scram.update(server_first.as_bytes());
        assert!(outcome.is_err());
    }
}
517}