gel_auth/scram/
mod.rs

1//! # SCRAM (Salted Challenge Response Authentication Mechanism)
2//!
3//! SCRAM is a protocol that securely authenticates users by hashing and salting
4//! their passwords before exchanging them, preventing exposure of the actual
5//! password during transmission. It is commonly used in applications and
6//! protocols like Postgres and SASL to enhance security against common attacks
7//! such as replay and man-in-the-middle attacks.
8//!
9//! <https://en.wikipedia.org/wiki/Salted_Challenge_Response_Authentication_Mechanism>
10//!
11//! ## Limitations of this implementation
12//!
13//! This code implements a sufficient form of SCRAM to authenticate to most
14//! PostgreSQL instances. It does not implement channel binding, and its
15//! implementation is likely not generic enough to work with any other
16//! implementation of SCRAM that isn't designed to be PostgreSQL-compatible.
17//!
18//! ## Transaction
19//!
20//! The transaction consists of four steps:
21//!
22//! 1. Client's Initial Response: The client sends its username and initial
23//!    nonce.
24//! 2. Server's Challenge: The server responds with a combined nonce, a
25//!    base64-encoded salt, and an iteration count for the PBKDF2 algorithm.
26//! 3. Client's Proof: The client sends its proof of possession of the password,
27//!    along with the combined nonce and base64-encoded channel binding data.
28//! 4. Server's Final Response: The server sends its verifier, proving
29//!    successful authentication.
30//!
31//! This transaction securely authenticates the client to the server without
32//! transmitting the actual password.
33//!
34//! ## Parameters
35//!
36//! The following parameters are used in the SCRAM authentication exchange:
37//!
38//! * `r=` (nonce): A random string generated by the client and server to ensure
39//!   the uniqueness of each authentication exchange. The client initially sends
40//!   its nonce, and the server responds with a combined nonce (client’s nonce +
41//!   server’s nonce).
42//!
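//! * `c=` (channel binding): A base64-encoded representation of the channel
//!   binding data. This parameter is used to bind the authentication to the
//!   specific channel over which it is occurring, ensuring the integrity of the
//!   communication channel. Because this implementation does not support
//!   channel binding, it always sends and expects `biws`, the base64 encoding
//!   of the GS2 header `n,,`.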
47//!
48//! * `s=` (salt): A base64-encoded salt provided by the server. The salt is
49//!   used in conjunction with the client’s password to generate a salted
50//!   password for enhanced security.
51//!
52//! * `i=` (iteration count): The number of iterations to apply in the PBKDF2
53//!   (Password-Based Key Derivation Function 2) algorithm. This parameter
54//!   defines the computational cost of generating the salted password.
55//!
56//! * `n=` (name): The username of the client. This parameter is included in the
57//!   client’s initial response. Note that the username is sent out-of-band when
58//!   using PostgreSQL and this parameter is set to an empty value.
59//!
60//! * `p=` (proof): The client’s proof of possession of the password. This is a
61//!   base64-encoded value calculated using the salted password and other SCRAM
62//!   parameters to prove that the client knows the password without sending it
63//!   directly.
64//!
65//! * `v=` (verifier): The server’s verifier, which is used to prove that the
66//!   server also knows the shared secret. This parameter is included in the
67//!   server’s final message to confirm successful authentication.
68#![allow(unused)]
69
70use base64::{prelude::BASE64_STANDARD, Engine};
71use hmac::{Hmac, Mac};
72use rand::Rng;
73use sha2::{digest::FixedOutput, Digest, Sha256};
74use std::borrow::Cow;
75use std::str::FromStr;
76
77pub mod stringprep;
78mod stringprep_table;
79
80use stringprep::sasl_normalize_password_bytes;
81
/// The `c=` value of the client-final-message: the base64 encoding of the GS2
/// header `n,,` (no channel binding, no authorization identity).
const CHANNEL_BINDING_ENCODED: &str = "biws";
/// Minimum length, in bytes, of each side's nonce contribution. The server
/// checks it against the client nonce; the client checks it against the
/// suffix the server appended to the client nonce.
const MINIMUM_NONCE_LENGTH: usize = 16;
84
/// HMAC keyed hash over SHA-256; every keyed hash in SCRAM-SHA-256 uses it.
type HmacSha256 = Hmac<Sha256>;
/// A raw 32-byte SHA-256 / HMAC-SHA-256 output.
pub type Sha256Out = [u8; 32];
87
/// Error returned by every fallible operation in this module.
///
/// There is a single variant, so callers learn only that a message (or a
/// stored-key string) was rejected, not the specific reason.
#[derive(Debug, derive_more::Error, derive_more::Display)]
pub enum SCRAMError {
    /// A message or stored key was malformed, arrived in the wrong state, or
    /// failed a nonce/proof/signature check. Note that it displays as
    /// "Invalid encoding" whatever the actual cause.
    #[display("Invalid encoding")]
    ProtocolError,
}
93
/// Credential lookups and randomness required by [`ServerTransaction`].
pub trait ServerEnvironment {
    /// Returns the raw (not base64-encoded) salt and the iteration count for
    /// `username`. The transaction base64-encodes the salt itself before
    /// sending both in the server-first-message.
    fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize);
    /// Returns `(stored_key, server_key)` for `username`. The stored key
    /// verifies the client proof; the server key produces the server signature.
    fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out);
    /// Returns the server's nonce contribution, which is appended to the client
    /// nonce. It must not contain `,` (messages are split on commas). This
    /// module's client rejects contributions shorter than 16 bytes.
    fn generate_nonce(&self) -> String;
}
99
/// Server side of a SCRAM-SHA-256 exchange.
///
/// Create one with `Default::default()` and pass each client message to
/// [`ServerTransaction::process_message`].
#[derive(Default, derive_more::Debug)]
pub struct ServerTransaction {
    // Omitted from `Debug` output: `ServerState` does not implement `Debug`.
    #[debug(skip)]
    state: ServerState,
}
105
106impl ServerTransaction {
107    pub fn success(&self) -> bool {
108        matches!(self.state, ServerState::Success)
109    }
110
111    pub fn initial(&self) -> bool {
112        matches!(self.state, ServerState::Initial)
113    }
114
115    pub fn process_message(
116        &mut self,
117        message: &[u8],
118        env: &impl ServerEnvironment,
119    ) -> Result<Vec<u8>, SCRAMError> {
120        match &self.state {
121            ServerState::Success => Err(SCRAMError::ProtocolError),
122            ServerState::Initial => {
123                let message = ClientFirstMessage::decode(message)?;
124                if message.channel_binding != ChannelBinding::NotSupported("".into()) {
125                    return Err(SCRAMError::ProtocolError);
126                }
127                if message.nonce.len() < MINIMUM_NONCE_LENGTH {
128                    return Err(SCRAMError::ProtocolError);
129                }
130                let (salt, iterations) = env.get_password_parameters(&message.username);
131                let mut nonce = message.nonce.to_string();
132                nonce += &env.generate_nonce();
133                let response = ServerFirstResponse {
134                    combined_nonce: nonce.to_string().into(),
135                    salt: BASE64_STANDARD.encode(salt).into(),
136                    iterations,
137                };
138                self.state =
139                    ServerState::SentChallenge(message.to_owned_bare(), response.to_owned());
140                Ok(response.encode().into_bytes())
141            }
142            ServerState::SentChallenge(first_message, first_response) => {
143                let message = ClientFinalMessage::decode(message)?;
144                if !constant_time_eq::constant_time_eq(
145                    message.combined_nonce.as_bytes(),
146                    first_response.combined_nonce.as_bytes(),
147                ) {
148                    return Err(SCRAMError::ProtocolError);
149                }
150                if message.channel_binding != CHANNEL_BINDING_ENCODED {
151                    return Err(SCRAMError::ProtocolError);
152                }
153                let (stored_key, server_key) = env.get_stored_key(&first_message.username);
154
155                // Decode the provided client proof
156                let mut provided_proof = vec![];
157                BASE64_STANDARD
158                    .decode_vec(message.proof.as_bytes(), &mut provided_proof)
159                    .map_err(|_| SCRAMError::ProtocolError)?;
160
161                let (calculated_stored_key, server_signature) = generate_server_proof(
162                    first_message.encode().as_bytes(),
163                    first_response.encode().as_bytes(),
164                    message.channel_binding.as_bytes(),
165                    message.combined_nonce.as_bytes(),
166                    &provided_proof,
167                    &server_key,
168                    &stored_key,
169                );
170
171                if !constant_time_eq::constant_time_eq(
172                    calculated_stored_key.as_slice(),
173                    &stored_key,
174                ) {
175                    return Err(SCRAMError::ProtocolError);
176                }
177
178                self.state = ServerState::Success;
179                let verifier = BASE64_STANDARD.encode(server_signature).into();
180                Ok(ServerFinalResponse { verifier }.encode().into_bytes())
181            }
182        }
183    }
184}
185
186#[derive(Default)]
187enum ServerState {
188    #[default]
189    Initial,
190    SentChallenge(ClientFirstMessage<'static>, ServerFirstResponse<'static>),
191    Success,
192}
193
194pub trait ClientEnvironment {
195    fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out;
196    fn generate_nonce(&self) -> String;
197}
198
199#[derive(Debug)]
200pub struct ClientTransaction {
201    state: ClientState,
202}
203
204impl ClientTransaction {
205    pub fn new(username: Cow<'static, str>) -> Self {
206        Self {
207            state: ClientState::Initial(username),
208        }
209    }
210
211    pub fn success(&self) -> bool {
212        matches!(self.state, ClientState::Success)
213    }
214
215    pub fn process_message(
216        &mut self,
217        message: &[u8],
218        env: &impl ClientEnvironment,
219    ) -> Result<Option<Vec<u8>>, SCRAMError> {
220        match &self.state {
221            ClientState::Success => Err(SCRAMError::ProtocolError),
222            ClientState::Initial(username) => {
223                if !message.is_empty() {
224                    return Err(SCRAMError::ProtocolError);
225                }
226                let nonce = env.generate_nonce().into();
227                let message = ClientFirstMessage {
228                    channel_binding: ChannelBinding::NotSupported("".into()),
229                    username: username.clone(),
230                    nonce,
231                };
232                self.state = ClientState::SentFirst(message.to_owned_bare());
233                Ok(Some(message.encode().into_bytes()))
234            }
235            ClientState::SentFirst(first_message) => {
236                let message = ServerFirstResponse::decode(message)?;
237                // Ensure the client nonce was concatenated with the server's nonce
238                if !message
239                    .combined_nonce
240                    .starts_with(first_message.nonce.as_ref())
241                {
242                    return Err(SCRAMError::ProtocolError);
243                }
244                if message.combined_nonce.len() - first_message.nonce.len() < MINIMUM_NONCE_LENGTH {
245                    return Err(SCRAMError::ProtocolError);
246                }
247                let mut buffer = [0; 1024];
248                let salt = decode_salt(&message.salt, &mut buffer)?;
249                let salted_password = env.get_salted_password(&salt, message.iterations);
250                let (client_proof, server_verifier) = generate_client_proof(
251                    first_message.encode().as_bytes(),
252                    message.encode().as_bytes(),
253                    CHANNEL_BINDING_ENCODED.as_bytes(),
254                    message.combined_nonce.as_bytes(),
255                    &salted_password,
256                );
257                let message = ClientFinalMessage {
258                    channel_binding: CHANNEL_BINDING_ENCODED.into(),
259                    combined_nonce: message.combined_nonce.to_string().into(),
260                    proof: BASE64_STANDARD.encode(client_proof).into(),
261                };
262                self.state = ClientState::ExpectingVerifier(ServerFinalResponse {
263                    verifier: BASE64_STANDARD.encode(server_verifier).into(),
264                });
265                Ok(Some(message.encode().into_bytes()))
266            }
267            ClientState::ExpectingVerifier(server_final_response) => {
268                let message = ServerFinalResponse::decode(message)?;
269                if !constant_time_eq::constant_time_eq(
270                    message.verifier.as_bytes(),
271                    server_final_response.verifier.as_bytes(),
272                ) {
273                    return Err(SCRAMError::ProtocolError);
274                }
275                self.state = ClientState::Success;
276                Ok(None)
277            }
278        }
279    }
280}
281
282#[derive(Debug)]
283enum ClientState {
284    Initial(Cow<'static, str>),
285    SentFirst(ClientFirstMessage<'static>),
286    ExpectingVerifier(ServerFinalResponse<'static>),
287    Success,
288}
289
290trait Encode {
291    fn encode(&self) -> String;
292}
293
294trait Decode<'a> {
295    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError>
296    where
297        Self: Sized + 'a;
298}
299
300fn extract<'a>(input: &'a [u8], prefix: &'static str) -> Result<&'a str, SCRAMError> {
301    let bytes = input
302        .strip_prefix(prefix.as_bytes())
303        .ok_or(SCRAMError::ProtocolError)?;
304    std::str::from_utf8(bytes).map_err(|_| SCRAMError::ProtocolError)
305}
306
307fn inext<'a>(it: &mut impl Iterator<Item = &'a [u8]>) -> Result<&'a [u8], SCRAMError> {
308    it.next().ok_or(SCRAMError::ProtocolError)
309}
310
311fn hmac(s: &[u8]) -> HmacSha256 {
312    // This is effectively infallible
313    HmacSha256::new_from_slice(s).expect("HMAC can take key of any size")
314}
315
316#[derive(Debug, Clone, PartialEq, Eq)]
317/// `gs2-cbind-flag` from RFC5802.
318enum ChannelBinding<'a> {
319    /// No channel binding
320    NotSpecified,
321    /// "n" -> client doesn't support channel binding.
322    NotSupported(Cow<'a, str>),
323    /// "y" -> client does support channel binding but thinks the server does
324    /// not.
325    Supported(Cow<'a, str>),
326    /// "p" -> client requires channel binding. The selected channel binding
327    /// follows "p=".
328    Required(Cow<'a, str>, Cow<'a, str>),
329}
330
331#[derive(Debug)]
332pub struct ClientFirstMessage<'a> {
333    channel_binding: ChannelBinding<'a>,
334    username: Cow<'a, str>,
335    nonce: Cow<'a, str>,
336}
337
338impl ClientFirstMessage<'_> {
339    /// Get the bare first message
340    pub fn to_owned_bare(&self) -> ClientFirstMessage<'static> {
341        ClientFirstMessage {
342            channel_binding: ChannelBinding::NotSpecified,
343            username: self.username.to_string().into(),
344            nonce: self.nonce.to_string().into(),
345        }
346    }
347}
348
349impl Encode for ClientFirstMessage<'_> {
350    fn encode(&self) -> String {
351        let channel_binding = match self.channel_binding {
352            ChannelBinding::NotSpecified => "".to_string(),
353            ChannelBinding::NotSupported(ref s) => format!("n,{s},"),
354            ChannelBinding::Supported(ref s) => format!("y,{s},"),
355            ChannelBinding::Required(ref s, ref t) => format!("p={t},{s},"),
356        };
357        format!("{channel_binding}n={},r={}", self.username, self.nonce)
358    }
359}
360
361impl<'a> Decode<'a> for ClientFirstMessage<'a> {
362    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
363        let mut parts = buf.split(|&b| b == b',');
364
365        // Check for channel binding
366        let mut next = inext(&mut parts)?;
367        let mut channel_binding = ChannelBinding::NotSpecified;
368        match (next.len(), next.first()) {
369            (_, Some(b'p')) => {
370                // p=(cb-name),(authz-id),
371                let Some(cb_name) = next.strip_prefix(b"p=") else {
372                    return Err(SCRAMError::ProtocolError);
373                };
374                let cb_name =
375                    std::str::from_utf8(cb_name).map_err(|_| SCRAMError::ProtocolError)?;
376                let param = inext(&mut parts)?;
377                channel_binding = ChannelBinding::Required(
378                    Cow::Borrowed(
379                        std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?,
380                    ),
381                    cb_name.into(),
382                );
383                next = inext(&mut parts)?;
384            }
385            (1, Some(b'n')) => {
386                let param = inext(&mut parts)?;
387                channel_binding = ChannelBinding::NotSupported(Cow::Borrowed(
388                    std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?,
389                ));
390                next = inext(&mut parts)?;
391            }
392            (1, Some(b'y')) => {
393                let param = inext(&mut parts)?;
394                channel_binding = ChannelBinding::Supported(Cow::Borrowed(
395                    std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?,
396                ));
397                next = inext(&mut parts)?;
398            }
399            (_, None) => {
400                return Err(SCRAMError::ProtocolError);
401            }
402            _ => {
403                // No channel binding specified
404            }
405        }
406        let username = extract(next, "n=")?.into();
407        let nonce = extract(inext(&mut parts)?, "r=")?.into();
408        Ok(ClientFirstMessage {
409            channel_binding,
410            username,
411            nonce,
412        })
413    }
414}
415
416pub struct ServerFirstResponse<'a> {
417    combined_nonce: Cow<'a, str>,
418    salt: Cow<'a, str>,
419    iterations: usize,
420}
421
422impl ServerFirstResponse<'_> {
423    pub fn to_owned(&self) -> ServerFirstResponse<'static> {
424        ServerFirstResponse {
425            combined_nonce: self.combined_nonce.to_string().into(),
426            salt: self.salt.to_string().into(),
427            iterations: self.iterations,
428        }
429    }
430}
431
432impl Encode for ServerFirstResponse<'_> {
433    fn encode(&self) -> String {
434        format!(
435            "r={},s={},i={}",
436            self.combined_nonce, self.salt, self.iterations
437        )
438    }
439}
440
441impl<'a> Decode<'a> for ServerFirstResponse<'a> {
442    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
443        let mut parts = buf.split(|&b| b == b',');
444        let combined_nonce = extract(inext(&mut parts)?, "r=")?.into();
445        let salt = extract(inext(&mut parts)?, "s=")?.into();
446        let iterations = extract(inext(&mut parts)?, "i=")?;
447        Ok(ServerFirstResponse {
448            combined_nonce,
449            salt,
450            iterations: str::parse(iterations).map_err(|_| SCRAMError::ProtocolError)?,
451        })
452    }
453}
454
455pub struct ClientFinalMessage<'a> {
456    channel_binding: Cow<'a, str>,
457    combined_nonce: Cow<'a, str>,
458    proof: Cow<'a, str>,
459}
460
461impl Encode for ClientFinalMessage<'_> {
462    fn encode(&self) -> String {
463        format!(
464            "c={},r={},p={}",
465            self.channel_binding, self.combined_nonce, self.proof
466        )
467    }
468}
469
470impl<'a> Decode<'a> for ClientFinalMessage<'a> {
471    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
472        let mut parts = buf.split(|&b| b == b',');
473        let channel_binding = extract(inext(&mut parts)?, "c=")?.into();
474        let combined_nonce = extract(inext(&mut parts)?, "r=")?.into();
475        let proof = extract(inext(&mut parts)?, "p=")?.into();
476        Ok(ClientFinalMessage {
477            channel_binding,
478            combined_nonce,
479            proof,
480        })
481    }
482}
483
484#[derive(Debug)]
485pub struct ServerFinalResponse<'a> {
486    verifier: Cow<'a, str>,
487}
488
489impl Encode for ServerFinalResponse<'_> {
490    fn encode(&self) -> String {
491        format!("v={}", self.verifier)
492    }
493}
494
495impl<'a> Decode<'a> for ServerFinalResponse<'a> {
496    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
497        let mut parts = buf.split(|&b| b == b',');
498        let verifier = extract(inext(&mut parts)?, "v=")?.into();
499        Ok(ServerFinalResponse { verifier })
500    }
501}
502
503pub fn decode_salt<'a>(salt: &str, buffer: &'a mut [u8]) -> Result<Cow<'a, [u8]>, SCRAMError> {
504    // The salt needs to be base64 decoded -- full binary must be used
505    if let Ok(n) = BASE64_STANDARD.decode_slice(salt, buffer) {
506        Ok(Cow::Borrowed(&buffer[..n]))
507    } else {
508        // In the unlikely case the salt is large -- note that we also fall back to this
509        // path for invalid base64 strings!
510        let mut buffer = vec![];
511        BASE64_STANDARD
512            .decode_vec(salt, &mut buffer)
513            .map_err(|_| SCRAMError::ProtocolError)?;
514        Ok(Cow::Owned(buffer))
515    }
516}
517
518/// Given a password in byte form, generates the salted version of the password,
519/// applying SASLprep to it beforehand.
520pub fn generate_salted_password(password: &[u8], salt: &[u8], iterations: usize) -> Sha256Out {
521    // Save the pre-keyed hmac
522    let ui_p = hmac(&sasl_normalize_password_bytes(password));
523
524    // The initial signature is the salt with a terminator of a 32-bit string ending in 1
525    let mut ui = ui_p.clone();
526
527    ui.update(salt);
528    ui.update(&[0, 0, 0, 1]);
529
530    // Grab the initial digest
531    let mut last_hash = Default::default();
532    ui.finalize_into(&mut last_hash);
533    let mut u = last_hash;
534
535    // For X number of iterations, recompute the HMAC signature against the password and the latest iteration of the hash, and XOR it with the previous version
536    for _ in 0..(iterations - 1) {
537        let mut ui = ui_p.clone();
538        ui.update(&last_hash);
539        ui.finalize_into(&mut last_hash);
540        for i in 0..u.len() {
541            u[i] ^= last_hash[i];
542        }
543    }
544
545    u.as_slice().try_into().unwrap()
546}
547
548pub fn generate_nonce() -> String {
549    let bytes: [u8; 32] = rand::random();
550    BASE64_STANDARD.encode(bytes)
551}
552
553/// A stored SCRAM-SHA-256 key.
554///
555/// The SCRAM key format consists of several components separated by '$' and ':'
556/// characters:
557///
558/// `"SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>"`
559///
560/// Where:
561///  - `iterations`: Number of PBKDF2-HMAC-SHA256 iterations used for key
562///    derivation
563///  - `salt`: Base64-encoded cryptographically secure random salt used in key
564///    derivation
565///  - `stored_key`: Hash of the client key, where client key is derived as
566///    `SHA-256(HMAC-SHA-256(salted_password, "Client Key"))`
567///  - `server_key`: Server key derived as `HMAC-SHA-256(salted_password,
568///    "Server Key")`
569///
570/// The `stored_key` and `server_key` are pre-computed cryptographic values that
571/// prevent storing the raw password while maintaining secure authentication.
572/// The `stored_key` is a `hash(hmac(P, ...))` used to verify client
573/// authentication proofs, while the `server_key` is a `hmac(P, ...)` used to
574/// generate server authentication signatures.
575#[derive(Clone, Debug)]
576pub struct StoredKey {
577    pub iterations: usize,
578    pub salt: Vec<u8>,
579    pub stored_key: Sha256Out,
580    pub server_key: Sha256Out,
581}
582
583impl PartialEq for StoredKey {
584    fn eq(&self, other: &Self) -> bool {
585        // If the salt and stored_key match, the remainder must match.
586        // We only need to compare these two fields for equality because:
587        // 1. The salt is used to derive the stored_key, so if both match,
588        //    it implies the same password and iteration count were used.
589        // 2. The server_key is derived from the same process, so it will
590        //    automatically match if the salt and stored_key match.
591        // 3. The iterations count doesn't need to be compared explicitly
592        //    as it's factored into the stored_key calculation.
593        constant_time_eq::constant_time_eq(&self.salt, &other.salt)
594            && constant_time_eq::constant_time_eq(&self.stored_key, &other.stored_key)
595    }
596}
597
598impl Eq for StoredKey {}
599
600impl FromStr for StoredKey {
601    type Err = SCRAMError;
602
603    // "SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>"
604
605    fn from_str(s: &str) -> Result<Self, Self::Err> {
606        let parts: Vec<&str> = s.split('$').collect();
607        if parts.len() != 3 || parts[0] != "SCRAM-SHA-256" {
608            return Err(SCRAMError::ProtocolError);
609        }
610
611        let iterations = parts[1]
612            .split(':')
613            .next()
614            .ok_or(SCRAMError::ProtocolError)?
615            .parse()
616            .map_err(|_| SCRAMError::ProtocolError)?;
617
618        let salt = BASE64_STANDARD
619            .decode(
620                parts[1]
621                    .split(':')
622                    .nth(1)
623                    .ok_or(SCRAMError::ProtocolError)?,
624            )
625            .map_err(|_| SCRAMError::ProtocolError)?;
626
627        let key_parts: Vec<&str> = parts[2].split(':').collect();
628        if key_parts.len() != 2 {
629            return Err(SCRAMError::ProtocolError);
630        }
631
632        let stored_key = BASE64_STANDARD
633            .decode(key_parts[0])
634            .map_err(|_| SCRAMError::ProtocolError)?
635            .try_into()
636            .map_err(|_| SCRAMError::ProtocolError)?;
637
638        let server_key = BASE64_STANDARD
639            .decode(key_parts[1])
640            .map_err(|_| SCRAMError::ProtocolError)?
641            .try_into()
642            .map_err(|_| SCRAMError::ProtocolError)?;
643
644        Ok(StoredKey {
645            iterations,
646            salt,
647            stored_key,
648            server_key,
649        })
650    }
651}
652use std::fmt;
653
654impl fmt::Display for StoredKey {
655    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
656        write!(
657            f,
658            "SCRAM-SHA-256${}:{}${}:{}",
659            self.iterations,
660            BASE64_STANDARD.encode(&self.salt),
661            BASE64_STANDARD.encode(self.stored_key),
662            BASE64_STANDARD.encode(self.server_key)
663        )
664    }
665}
666
667impl ServerEnvironment for StoredKey {
668    fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) {
669        (Cow::Owned(self.salt.clone()), self.iterations)
670    }
671
672    fn generate_nonce(&self) -> String {
673        let nonce: [u8; 32] = rand::random();
674        base64::engine::general_purpose::STANDARD.encode(nonce)
675    }
676
677    fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) {
678        (self.stored_key, self.server_key)
679    }
680}
681
682impl StoredKey {
683    /// Generate a stored key compatible with PostgreSQL's encoding.
684    pub fn generate(password: &[u8], salt: &[u8], iterations: usize) -> Self {
685        let digest_key = generate_salted_password(password, salt, iterations);
686
687        let client_key = hmac(&digest_key)
688            .chain_update(b"Client Key")
689            .finalize()
690            .into_bytes();
691
692        let stored_key = Sha256::digest(client_key);
693
694        let server_key = hmac(&digest_key)
695            .chain_update(b"Server Key")
696            .finalize()
697            .into_bytes();
698
699        Self {
700            iterations,
701            salt: salt.to_owned(),
702            stored_key: stored_key.into(),
703            server_key: server_key.into(),
704        }
705    }
706}
707
708fn generate_client_proof(
709    first_message_bare: &[u8],
710    server_first_message: &[u8],
711    channel_binding: &[u8],
712    server_nonce: &[u8],
713    salted_password: &[u8],
714) -> (Sha256Out, Sha256Out) {
715    let ui_p = hmac(salted_password);
716
717    let mut ui = ui_p.clone();
718    ui.update(b"Server Key");
719    let server_key = ui.finalize_fixed();
720
721    let mut ui = ui_p.clone();
722    ui.update(b"Client Key");
723    let client_key = ui.finalize_fixed();
724
725    let mut hash = Sha256::new();
726    hash.update(client_key);
727    let stored_key = hash.finalize_fixed();
728
729    let auth_message = [
730        (first_message_bare),
731        (b","),
732        (server_first_message),
733        (b",c="),
734        (channel_binding),
735        (b",r="),
736        (server_nonce),
737    ];
738
739    let mut client_signature = hmac(&stored_key);
740    for chunk in auth_message {
741        client_signature.update(chunk);
742    }
743
744    let client_signature = client_signature.finalize_fixed();
745    let mut client_signature: Sha256Out = client_signature.as_slice().try_into().unwrap();
746
747    for i in 0..client_signature.len() {
748        client_signature[i] ^= client_key[i];
749    }
750
751    let mut server_proof = hmac(&server_key);
752    for chunk in auth_message {
753        server_proof.update(chunk);
754    }
755    let server_proof = server_proof.finalize_fixed().as_slice().try_into().unwrap();
756
757    (client_signature, server_proof)
758}
759
760fn generate_server_proof(
761    first_message_bare: &[u8],
762    server_first_message: &[u8],
763    channel_binding: &[u8],
764    server_nonce: &[u8],
765    provided_proof: &[u8],
766    server_key: &[u8],
767    stored_key: &[u8],
768) -> (Sha256Out, Sha256Out) {
769    let auth_message = [
770        (first_message_bare),
771        (b","),
772        (server_first_message),
773        (b",c="),
774        (channel_binding),
775        (b",r="),
776        (server_nonce),
777    ];
778
779    let mut client_signature = hmac(stored_key);
780    for chunk in &auth_message {
781        client_signature.update(chunk);
782    }
783    let client_signature = client_signature.finalize_fixed();
784
785    let mut calculated_stored_key = [0u8; 32];
786    for (i, (&p, &c)) in provided_proof
787        .iter()
788        .zip(client_signature.iter())
789        .enumerate()
790    {
791        calculated_stored_key[i] = p ^ c;
792    }
793
794    let calculated_stored_key = Sha256::digest(calculated_stored_key);
795
796    let mut server_signature = hmac(server_key);
797    for chunk in &auth_message {
798        server_signature.update(chunk);
799    }
800    let server_signature = server_signature.finalize_fixed();
801
802    (calculated_stored_key.into(), server_signature.into())
803}
804
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    // Fixed inputs for a single SCRAM exchange, used by `test_client_proof`.
    // `CLIENT_PROOF` and `SERVER_VERIFY` hold the base64-encoded client proof
    // and server signature expected for exactly these inputs.
    const CLIENT_NONCE: &str = "2XendqvQOa6cl0+Q7Y6UU0gw";
    const SERVER_NONCE: &str = "xWn3mvDeVZwnUtT09vwXoItO";
    const USERNAME: &str = "";
    const PASSWORD: &[u8] = b"secret";
    const SALT: &str = "t5YekvL6lgy4RyPnsiyqsg==";
    const ITERATIONS: usize = 4096;
    const CLIENT_PROOF: &[u8] = "p/HmDcOziQQnyF8fbVnJnlvwoLp1kZY4xsI9cCJhzCE=".as_bytes();
    const SERVER_VERIFY: &[u8] = "g/X0codOryF0nCOWh7KkIab23ZFPX99iLzN5Ghn3nNc=".as_bytes();

    /// Checks `generate_salted_password` against known outputs.
    ///
    /// Each case supplies a raw password, a base64-encoded salt (decoded here
    /// with `decode_salt`), an iteration count, and the expected hash.
    #[rstest]
    #[case(
        b"1234",
        "1234",
        1,
        hex!("EBE7E5BA4BF5A4D178D3BADAADD4C49A98C72FCFF4FB357DA7090D584990FCAA")
    )]
    #[case(
        b"1234",
        "1234",
        2,
        hex!("F9271C334EE6CD7FEE63BBC86FAF951A4ED9E293BDD72AC33663BAE662D31953")
    )]
    #[case(
        b"1234",
        "1234",
        4096,
        hex!("4FF8D6443278AB43209DF5A1327949AAC99A5AA23921E5C9199626524776F751")
    )]
    #[case(
        b"password",
        "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu",
        4096,
        hex!("E118A9AD43C87938659AD736E63F26BA2EBAF079AA351DB44AE29228FB4F7EF0")
    )]
    #[case(
        b"secret",
        "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu",
        4096,
        hex!("77DFD8E62A4379296C9769F9BA2F77D503C4647DE7919B47D6CF121986981BCC")
    )]
    #[case(
        b"secret",
        "t5YekvL6lgy4RyPnsiyqsg==",
        4096,
        hex!("9FB413FE9F1D0C8020400A3D49CFBC47FBFB1251CEA9297630BD025DB2B65171")
    )]
    #[case(
        "😀".as_bytes(),
        "t5YekvL6lgy4RyPnsiyqsg==",
        4096,
        hex!("AF490CE1BEA2DDB585DAF9C3842D1528AB091EF6FAB2A92489870523A98835EE")
    )]
    fn test_generate_salted_password(
        #[case] password: &[u8],
        #[case] salt: &str,
        #[case] iterations: usize,
        #[case] expected_hash: Sha256Out,
    ) {
        let mut buffer = [0; 128];
        let salt = decode_salt(salt, &mut buffer).unwrap();
        let hash = generate_salted_password(password, &salt, iterations);
        assert_eq!(hash, expected_hash);
    }

    /// Tests that use real stored keys from postgres to match normalization
    /// behaviour. This exercises the saslprep code indirectly to ensure it
    /// matches the PostgreSQL implementation.
    ///
    /// Passwords in these tests were generated via `ALTER ROLE` DDL statements,
    /// and the salted password was then extracted from the roles table.
    ///
    /// Each case parses the stored-key string, regenerates the key from the
    /// password using the parsed salt and iteration count, and checks that
    /// both the parsed value and the string form round-trip exactly.
    ///
    /// Note that a PostgreSQL user may have a password that is only valid when
    /// interpreted as a bag of bytes and cannot be set directly via `ALTER
    /// ROLE` or the `initdb` command-line. This code _should_ support those
    /// passwords, however given the complexity of testing this we do not
    /// currently do so. Should we wish to test this in the future, we will need
    /// to manually create the stored key, set this string as the password for
    /// the role, and then validate that it can be used to authenticate via
    /// SCRAM.
    #[rstest]
    // ASCII
    #[case(b"password", "SCRAM-SHA-256$4096:jZLwuMbICV2L8i9SsfSEYQ==$Qhd2nOIlLW/dtVFERkVjVNdzzrVwPm2l+WHibmPesoc=:P1aH2cUHyPUbIdO06hEiXdwKxQyqBNUijLGFLkTXcHs=")]
    // Unicode
    #[case("schön".as_bytes(), "SCRAM-SHA-256$4096:uuH6VXsbbeId2AcdL0WmSA==$imMseND/Sg7tL5Tm1ltZJGa6PsdxwysUZ9s1lXPOPdo=:kMp6Rb9yN3zYpvwkuf0/xQZWhIGEa0ryjwnyDfpL3G0=")]
    // Unicode normalization -> half-width to full-width
    #[case("パスワード".as_bytes(), "SCRAM-SHA-256$4096:oCSGmW9Llo803DWp94yE0A==$TvNA2Hh1IqwCHlhxHhIaTeI7N/mFSx01D3/tb2VGQfw=:RBDsZImb7XoP6Md1j0zhjf7yBz0ocDoxqsPeFtJLyaI=")]
    // Non-ASCII space characters (U+00A0 NO-BREAK SPACE, U+2000 EN QUAD),
    // which SASLprep maps to an ASCII space
    #[case(b"pass\xc2\xa0\xe2\x80\x80word", "SCRAM-SHA-256$4096:ag3Z1WnqEn8dhTvSP7UtYA==$taWe9cZJYK5Y28V9Nw3zy6E9qQKbqKrMRS5DwlDXG04=:Y4n3uwZ4jQyG7nYCde3vtPxO1p0Oxz5ytJT1W+lqM+I=")]
    // Invalid control chars
    #[case(b"\x01\x02\x03", "SCRAM-SHA-256$4096:XGcYpEn2cwuS+BZXJBaqFg==$mG53wGoI6pAANoAZl7qxYiKPZ6u3CfhCVZK4et3l52A=:X5PUFkC5MVJWmuBTwWQHTFH81xjiyAHrJ9r0anOPXiI=")]
    // Prohibited char (ffff)
    #[case(b"\xef\xbf\xbf", "SCRAM-SHA-256$4096:Tdv5eCJIm+LU9QJBKO96gQ==$YXE4G3HKPwCmwo4FjiFKaiqVGCDTOpVETv+Fe6wWY9Q=:DK7MZ/OgGGgCDh6EfsmmcyFuaAD+T2Zh78sl+QDQFIo=")]
    fn test_stored_key(#[case] password: &[u8], #[case] stored_key: &str) {
        let parsed_key = StoredKey::from_str(stored_key).unwrap();
        assert_eq!(4096, parsed_key.iterations);
        let generated_key = StoredKey::generate(password, &parsed_key.salt, parsed_key.iterations);
        assert_eq!(generated_key, parsed_key);
        assert_eq!(generated_key.to_string(), stored_key);
    }

    /// Computes the client proof and server signature for the fixed module
    /// constants and compares them with the expected base64-decoded values.
    #[test]
    fn test_client_proof() {
        let mut buffer = [0; 128];
        let salt = decode_salt(SALT, &mut buffer).unwrap();
        let salted_password = generate_salted_password(PASSWORD, &salt, ITERATIONS);
        let (client, server) = generate_client_proof(
            format!("n={USERNAME},r={CLIENT_NONCE}").as_bytes(),
            format!("r={CLIENT_NONCE}{SERVER_NONCE},s={SALT},i={ITERATIONS}").as_bytes(),
            CHANNEL_BINDING_ENCODED.as_bytes(),
            format!("{CLIENT_NONCE}{SERVER_NONCE}").as_bytes(),
            &salted_password,
        );
        assert_eq!(
            &client,
            BASE64_STANDARD.decode(CLIENT_PROOF).unwrap().as_slice()
        );
        assert_eq!(
            &server,
            BASE64_STANDARD.decode(SERVER_VERIFY).unwrap().as_slice()
        );
    }

    /// Decodes a client-first message without channel binding (`n,,`) and
    /// checks that re-encoding reproduces the input.
    #[test]
    fn test_client_first_message() {
        let message = ClientFirstMessage::decode(b"n,,n=,r=480I9uIaXEU9oB2RRcenOxN/").unwrap();
        assert_eq!(
            message.channel_binding,
            ChannelBinding::NotSupported(Cow::Borrowed(""))
        );
        assert_eq!(message.username, "");
        assert_eq!(message.nonce, "480I9uIaXEU9oB2RRcenOxN/");
        assert_eq!(
            message.encode(),
            "n,,n=,r=480I9uIaXEU9oB2RRcenOxN/".to_owned()
        );
    }

    /// Decodes a client-first message that requires a named channel binding
    /// (`p=cb-name`) and checks that re-encoding reproduces the input.
    #[test]
    fn test_client_first_message_required() {
        let message =
            ClientFirstMessage::decode(b"p=cb-name,,n=,r=480I9uIaXEU9oB2RRcenOxN/").unwrap();
        assert_eq!(
            message.channel_binding,
            ChannelBinding::Required(Cow::Borrowed(""), Cow::Borrowed("cb-name"))
        );
        assert_eq!(message.username, "");
        assert_eq!(message.nonce, "480I9uIaXEU9oB2RRcenOxN/");
        assert_eq!(
            message.encode(),
            "p=cb-name,,n=,r=480I9uIaXEU9oB2RRcenOxN/".to_owned()
        );
    }

    /// Decodes the server-first response (combined nonce, salt, iteration
    /// count) and checks the encode/decode round-trip.
    #[test]
    fn test_server_first_response() {
        let message = ServerFirstResponse::decode(
            b"r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,s=t5YekvL6lgy4RyPnsiyqsg==,i=4096",
        )
        .unwrap();
        assert_eq!(
            message.combined_nonce,
            "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu"
        );
        assert_eq!(message.salt, "t5YekvL6lgy4RyPnsiyqsg==");
        assert_eq!(message.iterations, 4096);
        assert_eq!(
            message.encode(),
            "r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,s=t5YekvL6lgy4RyPnsiyqsg==,i=4096"
                .to_owned()
        );
    }

    /// Decodes the client-final message (channel binding, combined nonce,
    /// proof) and checks the encode/decode round-trip.
    #[test]
    fn test_client_final_message() {
        let message = b"c=biws,r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,p=7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4=";
        let decoded = ClientFinalMessage::decode(message).unwrap();
        assert_eq!(decoded.channel_binding, "biws");
        assert_eq!(
            decoded.combined_nonce,
            "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu"
        );
        assert_eq!(
            decoded.proof,
            "7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4="
        );
        let encoded = decoded.encode();
        assert_eq!(encoded, "c=biws,r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,p=7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4=");
    }

    /// Decodes the server-final response (`v=` verifier) and checks the
    /// encode/decode round-trip.
    #[test]
    fn test_server_final_response() {
        let message = b"v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=";
        let decoded: ServerFinalResponse = ServerFinalResponse::decode(message).unwrap();
        assert_eq!(
            decoded.verifier,
            "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="
        );
        let encoded = decoded.encode();
        assert_eq!(encoded, "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=");
    }

    /// Run a SCRAM conversation with a fixed set of parameters
    ///
    /// A single `Env` implements both the client and server environments.
    /// It returns deterministic nonces and a stored key derived from the
    /// password `password` and salt `hello`, so every message can be checked
    /// byte-for-byte.
    #[test]
    fn test_transaction() {
        let mut server = ServerTransaction::default();
        let mut client = ClientTransaction::new("username".into());

        struct Env {}
        impl ClientEnvironment for Env {
            fn generate_nonce(&self) -> String {
                "<<<client nonce>>>".into()
            }
            fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
                generate_salted_password(b"password", salt, iterations)
            }
        }
        impl ServerEnvironment for Env {
            fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) {
                assert_eq!(username, "username");
                let key = StoredKey::generate(b"password", b"hello", 4096);
                (key.stored_key, key.server_key)
            }
            fn generate_nonce(&self) -> String {
                "<<<server nonce>>>".into()
            }
            fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) {
                assert_eq!(username, "username");
                (Cow::Borrowed(b"hello"), 4096)
            }
        }
        let env = Env {};

        // Step 1: the client starts the exchange with no input.
        let message = client.process_message(&[], &env).unwrap().unwrap();
        assert_eq!(
            String::from_utf8(message.clone()).unwrap(),
            "n,,n=username,r=<<<client nonce>>>"
        );
        // Step 2: the server answers with the combined nonce, salt and iterations.
        let message = server.process_message(&message, &env).unwrap();
        assert_eq!(
            String::from_utf8(message.clone()).unwrap(),
            "r=<<<client nonce>>><<<server nonce>>>,s=aGVsbG8=,i=4096"
        );
        // Step 3: the client sends its proof.
        let message = client.process_message(&message, &env).unwrap().unwrap();
        assert_eq!(String::from_utf8(message.clone()).unwrap(), "c=biws,r=<<<client nonce>>><<<server nonce>>>,p=621h6u6V3axb7mNYHNgTspTZ3SqILcxuJOsFu5wMjV8=");
        // Step 4: the server returns its verifier.
        let message = server.process_message(&message, &env).unwrap();
        assert_eq!(
            String::from_utf8(message.clone()).unwrap(),
            "v=moj4kNnZKB3wjXZeQsKYI9luTTakwgH8r0NdGOjugRY="
        );
        // The client accepts the verifier and has nothing further to send.
        assert!(client.process_message(&message, &env).unwrap().is_none());
        assert!(client.success());
        assert!(server.success());
    }
}