// gaussdb_protocol/authentication/sasl.rs

1//! SASL-based authentication support.
2
3use base64::display::Base64Display;
4use base64::engine::general_purpose::STANDARD;
5use base64::Engine;
6use hmac::{Hmac, Mac};
7use rand::{self, Rng};
8use sha2::digest::FixedOutput;
9use sha2::{Digest, Sha256};
10use std::fmt::Write;
11use std::io;
12use std::iter;
13use std::mem;
14use std::str;
15
/// Number of random characters in the client nonce generated by `ScramSha256::new`.
const NONCE_LENGTH: usize = 24;

/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism (the channel-binding
/// variant of SCRAM-SHA-256).
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
22
23// since postgres passwords are not required to exclude saslprep-prohibited
24// characters or even be valid UTF8, we run saslprep if possible and otherwise
25// return the raw password.
26fn normalize(pass: &[u8]) -> Vec<u8> {
27    let pass = match str::from_utf8(pass) {
28        Ok(pass) => pass,
29        Err(_) => return pass.to_vec(),
30    };
31
32    match stringprep::saslprep(pass) {
33        Ok(pass) => pass.into_owned().into_bytes(),
34        Err(_) => pass.as_bytes().to_vec(),
35    }
36}
37
38pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
39    let mut hmac =
40        Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
41    hmac.update(salt);
42    hmac.update(&[0, 0, 0, 1]);
43    let mut prev = hmac.finalize().into_bytes();
44
45    let mut hi = prev;
46
47    for _ in 1..i {
48        let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
49        hmac.update(&prev);
50        prev = hmac.finalize().into_bytes();
51
52        for (hi, prev) in hi.iter_mut().zip(prev) {
53            *hi ^= prev;
54        }
55    }
56
57    hi.into()
58}
59
/// Internal representation of the client's channel binding choice.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ChannelBindingInner {
    /// The server did not ask for channel binding, so none is used.
    Unrequested,
    /// Channel binding was requested but the client cannot provide it.
    Unsupported,
    /// Bind to the TLS channel with the `tls-server-end-point` method. The payload is the
    /// channel binding data supplied by the caller. Presumably this is the RFC 5929 hash of the
    /// server certificate; confirm against the caller.
    TlsServerEndPoint(Vec<u8>),
}

/// The channel binding configuration for a SCRAM authentication exchange.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChannelBinding(ChannelBindingInner);

impl ChannelBinding {
    /// The server did not request channel binding.
    pub fn unrequested() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unrequested)
    }

    /// The server requested channel binding but the client is unable to provide it.
    pub fn unsupported() -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::Unsupported)
    }

    /// The server requested channel binding and the client will use the `tls-server-end-point`
    /// method.
    pub fn tls_server_end_point(signature: Vec<u8>) -> ChannelBinding {
        ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature))
    }

    /// Returns the GS2 header that starts the client-first-message and is echoed in `c=`.
    ///
    /// The RFC 5802 `gs2-cbind-flag` values are:
    /// - `y`: the client supports channel binding but believes the server does not.
    /// - `n`: the client does not support channel binding.
    /// - `p=<name>`: the client requires the named binding.
    fn gs2_header(&self) -> &'static str {
        match self.0 {
            ChannelBindingInner::Unrequested => "y,,",
            ChannelBindingInner::Unsupported => "n,,",
            ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,",
        }
    }

    /// Returns the channel binding data appended to the GS2 header in the `c=` attribute.
    /// The result is empty unless `tls-server-end-point` is in use.
    fn cbind_data(&self) -> &[u8] {
        match self.0 {
            ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[],
            ChannelBindingInner::TlsServerEndPoint(ref buf) => buf,
        }
    }
}
101
/// Where a `ScramSha256` exchange currently is.
///
/// Each step takes the current state out with `mem::replace(.., State::Done)`. A step that
/// returns an error therefore leaves the exchange in `Done`.
enum State {
    /// The client-first-message has been built; `update()` is expected next.
    Update {
        /// The client nonce sent in the client-first-message.
        nonce: String,
        /// The password after `normalize()`.
        password: Vec<u8>,
        /// Channel binding choice; its GS2 header and data are re-sent in the `c=` attribute.
        channel_binding: ChannelBinding,
    },
    /// The client-final-message has been built; `finish()` is expected next.
    Finish {
        /// `Hi()` of the normalized password, the server salt and the iteration count.
        salted_password: [u8; 32],
        /// The AuthMessage that both the client proof and the server signature are computed over.
        auth_message: String,
    },
    /// The exchange is over, successfully or not; `message()` panics in this state.
    Done,
}

/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication
/// process.
///
/// During the authentication process, if the backend sends an `AuthenticationSASL` message which
/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used.
///
/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be
/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name.
///
/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be
/// passed to the `update()` method, after which the buffer returned by the `message()` method
/// should be sent to the backend in a `SASLResponse` message.
///
/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed
/// to the `finish()` method, after which the authentication process is complete.
///
/// Any error from `update()` or `finish()` ends the exchange; a new instance is needed to retry.
pub struct ScramSha256 {
    // The next message to send to the backend: first the client-first-message, then (after
    // `update()`) the client-final-message.
    message: String,
    state: State,
}

impl ScramSha256 {
    /// Constructs a new instance which will use the provided password for authentication.
    ///
    /// A fresh random client nonce of `NONCE_LENGTH` characters is generated for every instance.
    pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 {
        // `rand::rng()` returns rand's thread-local `ThreadRng`, which is a cryptographically
        // secure generator in current rand releases.
        let mut rng = rand::rng();
        // Draw from printable ASCII 0x21..=0x7d and remap ',' (0x2c) to '~' (0x7e). ',' is the
        // SCRAM attribute separator, so this keeps every nonce character a valid `printable`.
        let nonce = (0..NONCE_LENGTH)
            .map(|_| {
                let mut v = rng.random_range(0x21u8..0x7e);
                if v == 0x2c {
                    v = 0x7e
                }
                v as char
            })
            .collect::<String>();

        ScramSha256::new_inner(password, channel_binding, nonce)
    }

    // Shared by `new` and the tests; the tests pass a fixed nonce to replay a recorded exchange.
    // The client-first-message is the GS2 header, then an empty username (`n=`), then the nonce.
    fn new_inner(password: &[u8], channel_binding: ChannelBinding, nonce: String) -> ScramSha256 {
        ScramSha256 {
            message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce),
            state: State::Update {
                nonce,
                password: normalize(password),
                channel_binding,
            },
        }
    }

    /// Returns the message which should be sent to the backend in an `SASLResponse` message.
    ///
    /// Before `update()`, this is the client-first-message; the type-level docs say to send it
    /// in `SASLInitialResponse`. After a successful `update()`, it is the client-final-message.
    ///
    /// # Panics
    ///
    /// Panics if the exchange is done. That is the case after `finish()` has been called, or
    /// after `update()` has returned an error.
    pub fn message(&self) -> &[u8] {
        if let State::Done = self.state {
            panic!("invalid SCRAM state");
        }
        self.message.as_bytes()
    }

    /// Updates the state machine with the response from the backend.
    ///
    /// This should be called when an `AuthenticationSASLContinue` message is received.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the method is called out of order;
    /// - the server-first-message is not UTF-8 or cannot be parsed;
    /// - the server nonce does not start with the client nonce;
    /// - the salt is not valid base64.
    ///
    /// After an error the exchange is done.
    pub fn update(&mut self, message: &[u8]) -> io::Result<()> {
        // Take the state out; it stays `Done` if any check below fails.
        let (client_nonce, password, channel_binding) =
            match mem::replace(&mut self.state, State::Done) {
                State::Update {
                    nonce,
                    password,
                    channel_binding,
                } => (nonce, password, channel_binding),
                _ => return Err(io::Error::other("invalid SCRAM state")),
            };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_first_message()?;

        // The server's nonce must extend the nonce we sent.
        if !parsed.nonce.starts_with(&client_nonce) {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce"));
        }

        let salt = match STANDARD.decode(parsed.salt) {
            Ok(salt) => salt,
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        };

        // SaltedPassword = Hi(Normalize(password), salt, i)
        let salted_password = hi(&password, &salt, parsed.iteration_count);

        // ClientKey = HMAC(SaltedPassword, "Client Key")
        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(b"Client Key");
        let client_key = hmac.finalize().into_bytes();

        // StoredKey = SHA-256(ClientKey)
        let mut hash = Sha256::default();
        hash.update(client_key.as_slice());
        let stored_key = hash.finalize_fixed();

        // The `c=` attribute is base64(GS2 header || channel binding data).
        let mut cbind_input = vec![];
        cbind_input.extend(channel_binding.gs2_header().as_bytes());
        cbind_input.extend(channel_binding.cbind_data());
        let cbind_input = STANDARD.encode(&cbind_input);

        // client-final-message-without-proof
        self.message.clear();
        write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap();

        // AuthMessage = client-first-message-bare "," server-first-message ","
        // client-final-message-without-proof. At this point `self.message` does not yet
        // contain the proof.
        let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message);

        // ClientSignature = HMAC(StoredKey, AuthMessage)
        let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        let client_signature = hmac.finalize().into_bytes();

        // ClientProof = ClientKey XOR ClientSignature
        let mut client_proof = client_key;
        for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
            *proof ^= signature;
        }

        // Complete the client-final-message with the base64-encoded proof.
        write!(
            &mut self.message,
            ",p={}",
            Base64Display::new(&client_proof, &STANDARD)
        )
        .unwrap();

        self.state = State::Finish {
            salted_password,
            auth_message,
        };
        Ok(())
    }

    /// Finalizes the authentication process.
    ///
    /// This should be called when the backend sends an `AuthenticationSASLFinal` message.
    /// Authentication has only succeeded if this method returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the method is called out of order;
    /// - the message is not UTF-8 or cannot be parsed;
    /// - the server reported a SCRAM error (`e=`);
    /// - the verifier is not valid base64;
    /// - the server signature does not match.
    ///
    /// The exchange is done afterwards in every case.
    pub fn finish(&mut self, message: &[u8]) -> io::Result<()> {
        let (salted_password, auth_message) = match mem::replace(&mut self.state, State::Done) {
            State::Finish {
                salted_password,
                auth_message,
            } => (salted_password, auth_message),
            _ => return Err(io::Error::other("invalid SCRAM state")),
        };

        let message =
            str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;

        let parsed = Parser::new(message).server_final_message()?;

        let verifier = match parsed {
            ServerFinalMessage::Error(e) => {
                return Err(io::Error::other(format!("SCRAM error: {}", e)));
            }
            ServerFinalMessage::Verifier(verifier) => verifier,
        };

        let verifier = match STANDARD.decode(verifier) {
            Ok(verifier) => verifier,
            Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        };

        // ServerKey = HMAC(SaltedPassword, "Server Key")
        let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(b"Server Key");
        let server_key = hmac.finalize().into_bytes();

        // ServerSignature = HMAC(ServerKey, AuthMessage). `verify_slice` checks the decoded
        // verifier against it through the `Mac` API rather than a plain `==`.
        let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
            .expect("HMAC is able to accept all key sizes");
        hmac.update(auth_message.as_bytes());
        hmac.verify_slice(&verifier)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error"))
    }
}
287
/// A small recursive-descent parser for the server messages of a SCRAM exchange.
/// It follows the RFC 5802 grammar.
struct Parser<'a> {
    /// The complete input; matched runs are returned as slices of it.
    s: &'a str,
    /// Cursor over the `(byte offset, char)` pairs of `s`.
    it: iter::Peekable<str::CharIndices<'a>>,
}

impl<'a> Parser<'a> {
    /// Creates a parser positioned at the start of `s`.
    fn new(s: &'a str) -> Parser<'a> {
        Parser {
            s,
            it: s.char_indices().peekable(),
        }
    }

    /// Consumes exactly one `target` character.
    ///
    /// Fails with `InvalidInput` if a different character comes next, and with `UnexpectedEof`
    /// if the input is exhausted.
    fn eat(&mut self, target: char) -> io::Result<()> {
        match self.it.next() {
            Some((_, c)) if c == target => Ok(()),
            Some((i, c)) => {
                let m = format!(
                    "unexpected character at byte {}: expected `{}` but got `{}`",
                    i, target, c
                );
                Err(io::Error::new(io::ErrorKind::InvalidInput, m))
            }
            None => Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "unexpected EOF",
            )),
        }
    }

    /// Consumes the longest (possibly empty) run of characters satisfying `f` and returns it.
    ///
    /// This never fails; the `io::Result` keeps the signature uniform with the other steps.
    fn take_while<F>(&mut self, f: F) -> io::Result<&'a str>
    where
        F: Fn(char) -> bool,
    {
        let len = self.s.len();
        let start = self.it.peek().map_or(len, |&(i, _)| i);
        while self.it.next_if(|&(_, c)| f(c)).is_some() {}
        let end = self.it.peek().map_or(len, |&(i, _)| i);
        Ok(&self.s[start..end])
    }

    /// Parses RFC 5802 `printable`: ASCII 0x21-0x7E except `,`.
    fn printable(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e'))
    }

    /// Parses the `r=` nonce attribute.
    fn nonce(&mut self) -> io::Result<&'a str> {
        self.eat('r')?;
        self.eat('=')?;
        self.printable()
    }

    /// Consumes characters of the standard base64 alphabet, including `=` padding.
    fn base64(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '='))
    }

    /// Parses the `s=` salt attribute. The returned salt is still base64-encoded.
    fn salt(&mut self) -> io::Result<&'a str> {
        self.eat('s')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses RFC 5802 `posit-number`: a decimal integer that must be greater than zero.
    ///
    /// Zero is rejected because the grammar requires a positive value; an iteration count of 0
    /// would otherwise be silently accepted. Leading zeros are tolerated.
    fn posit_number(&mut self) -> io::Result<u32> {
        let digits = self.take_while(|c| c.is_ascii_digit())?;
        match digits.parse::<u32>() {
            Ok(0) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "expected a positive number but got 0",
            )),
            Ok(n) => Ok(n),
            Err(e) => Err(io::Error::new(io::ErrorKind::InvalidInput, e)),
        }
    }

    /// Parses the `i=` iteration-count attribute.
    fn iteration_count(&mut self) -> io::Result<u32> {
        self.eat('i')?;
        self.eat('=')?;
        self.posit_number()
    }

    /// Succeeds only if the whole input has been consumed.
    fn eof(&mut self) -> io::Result<()> {
        match self.it.peek() {
            Some(&(i, _)) => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("unexpected trailing data at byte {}", i),
            )),
            None => Ok(()),
        }
    }

    /// Parses a complete server-first-message, `r=<nonce>,s=<salt>,i=<count>`.
    /// No extensions or trailing data are accepted.
    fn server_first_message(&mut self) -> io::Result<ServerFirstMessage<'a>> {
        let nonce = self.nonce()?;
        self.eat(',')?;
        let salt = self.salt()?;
        self.eat(',')?;
        let iteration_count = self.iteration_count()?;
        self.eof()?;

        Ok(ServerFirstMessage {
            nonce,
            salt,
            iteration_count,
        })
    }

    /// Parses an RFC 5802 `value`: every character up to the next `,` or NUL.
    ///
    /// The grammar's `value-char` is any UTF-8 character except NUL and `,` (`=` is allowed).
    fn value(&mut self) -> io::Result<&'a str> {
        self.take_while(|c| !matches!(c, '\0' | ','))
    }

    /// Parses an optional `e=<value>` server error.
    ///
    /// Returns `None`, consuming nothing, if the next character is not `e`.
    fn server_error(&mut self) -> io::Result<Option<&'a str>> {
        match self.it.peek() {
            Some(&(_, 'e')) => {}
            _ => return Ok(None),
        }

        self.eat('e')?;
        self.eat('=')?;
        self.value().map(Some)
    }

    /// Parses the `v=` verifier attribute. The returned verifier is still base64-encoded.
    fn verifier(&mut self) -> io::Result<&'a str> {
        self.eat('v')?;
        self.eat('=')?;
        self.base64()
    }

    /// Parses a complete server-final-message, either `e=<error>` or `v=<verifier>`.
    /// No extensions or trailing data are accepted.
    fn server_final_message(&mut self) -> io::Result<ServerFinalMessage<'a>> {
        let message = match self.server_error()? {
            Some(error) => ServerFinalMessage::Error(error),
            None => ServerFinalMessage::Verifier(self.verifier()?),
        };
        self.eof()?;
        Ok(message)
    }
}

/// The attributes of a parsed server-first-message, borrowed from the input.
struct ServerFirstMessage<'a> {
    /// Combined client+server nonce.
    nonce: &'a str,
    /// Base64-encoded salt.
    salt: &'a str,
    /// Positive iteration count for `Hi()`.
    iteration_count: u32,
}

/// A parsed server-final-message, borrowed from the input.
enum ServerFinalMessage<'a> {
    /// The server reported an error (the `e=` value).
    Error(&'a str),
    /// The base64-encoded server signature (the `v=` value).
    Verifier(&'a str),
}
436
#[cfg(test)]
mod test {
    use super::*;

    // A well-formed server-first-message splits into nonce, salt (still base64) and count.
    #[test]
    fn parse_server_first_message() {
        let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096";
        let message = Parser::new(message).server_first_message().unwrap();
        assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
        assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
        assert_eq!(message.iteration_count, 4096);
    }

    // recorded auth exchange from psql
    //
    // Replays the exchange with the recorded client nonce. It checks each client message
    // byte-for-byte and then checks that the recorded server signature verifies. `c=biws` is
    // base64 of the `n,,` GS2 header produced by `ChannelBinding::unsupported()`. The trailing
    // `\` in the long literals continues the string and drops the newline and leading spaces.
    #[test]
    fn exchange() {
        let password = "foobar";
        let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";

        let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
        let server_first =
            "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
             =4096";
        let client_final =
            "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
             1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
        let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";

        let mut scram = ScramSha256::new_inner(
            password.as_bytes(),
            ChannelBinding::unsupported(),
            nonce.to_string(),
        );
        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first);

        scram.update(server_first.as_bytes()).unwrap();
        assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final);

        scram.finish(server_final.as_bytes()).unwrap();
    }
}