gel_auth/scram/
mod.rs

1//! # SCRAM (Salted Challenge Response Authentication Mechanism)
2//!
3//! SCRAM is a protocol that securely authenticates users by hashing and salting
4//! their passwords before exchanging them, preventing exposure of the actual
5//! password during transmission. It is commonly used in applications and
6//! protocols like Postgres and SASL to enhance security against common attacks
7//! such as replay and man-in-the-middle attacks.
8//!
9//! <https://en.wikipedia.org/wiki/Salted_Challenge_Response_Authentication_Mechanism>
10//!
11//! ## Limitations of this implementation
12//!
13//! This code implements a sufficient form of SCRAM to authenticate to most
14//! PostgreSQL instances. It does not implement channel binding, and its
15//! implementation is likely not generic enough to work with any other
16//! implementation of SCRAM that isn't designed to be PostgreSQL-compatible.
17//!
18//! ## Transaction
19//!
20//! The transaction consists of four steps:
21//!
22//! 1. Client's Initial Response: The client sends its username and initial
23//!    nonce.
24//! 2. Server's Challenge: The server responds with a combined nonce, a
25//!    base64-encoded salt, and an iteration count for the PBKDF2 algorithm.
26//! 3. Client's Proof: The client sends its proof of possession of the password,
27//!    along with the combined nonce and base64-encoded channel binding data.
28//! 4. Server's Final Response: The server sends its verifier, proving
29//!    successful authentication.
30//!
31//! This transaction securely authenticates the client to the server without
32//! transmitting the actual password.
33//!
34//! ## Parameters
35//!
36//! The following parameters are used in the SCRAM authentication exchange:
37//!
38//! * `r=` (nonce): A random string generated by the client and server to ensure
39//!   the uniqueness of each authentication exchange. The client initially sends
40//!   its nonce, and the server responds with a combined nonce (client’s nonce +
41//!   server’s nonce).
42//!
43//! * `c=` (channel binding): A base64-encoded representation of the channel
44//!   binding data. This parameter is used to bind the authentication to the
45//!   specific channel over which it is occurring, ensuring the integrity of the
46//!   communication channel.
47//!
48//! * `s=` (salt): A base64-encoded salt provided by the server. The salt is
49//!   used in conjunction with the client’s password to generate a salted
50//!   password for enhanced security.
51//!
52//! * `i=` (iteration count): The number of iterations to apply in the PBKDF2
53//!   (Password-Based Key Derivation Function 2) algorithm. This parameter
54//!   defines the computational cost of generating the salted password.
55//!
56//! * `n=` (name): The username of the client. This parameter is included in the
57//!   client’s initial response. Note that the username is sent out-of-band when
58//!   using PostgreSQL and this parameter is set to an empty value.
59//!
60//! * `p=` (proof): The client’s proof of possession of the password. This is a
61//!   base64-encoded value calculated using the salted password and other SCRAM
62//!   parameters to prove that the client knows the password without sending it
63//!   directly.
64//!
65//! * `v=` (verifier): The server’s verifier, which is used to prove that the
66//!   server also knows the shared secret. This parameter is included in the
67//!   server’s final message to confirm successful authentication.
68#![allow(unused)]
69
70use base64::{prelude::BASE64_STANDARD, Engine};
71use hmac::{Hmac, Mac};
72use rand::Rng;
73use sha2::{digest::FixedOutput, Digest, Sha256};
74use std::borrow::Cow;
75use std::str::FromStr;
76
77pub mod stringprep;
78mod stringprep_table;
79
80use stringprep::sasl_normalize_password_bytes;
81
82const CHANNEL_BINDING_ENCODED: &str = "biws";
83const MINIMUM_NONCE_LENGTH: usize = 16;
84
85type HmacSha256 = Hmac<Sha256>;
86pub type Sha256Out = [u8; 32];
87
88#[derive(Debug, thiserror::Error)]
89pub enum SCRAMError {
90    #[error("Invalid encoding")]
91    ProtocolError,
92}
93
94pub trait ServerEnvironment {
95    fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize);
96    fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out);
97    fn generate_nonce(&self) -> String;
98}
99
100#[derive(Default, derive_more::Debug)]
101pub struct ServerTransaction {
102    #[debug(skip)]
103    state: ServerState,
104}
105
106impl ServerTransaction {
107    pub fn success(&self) -> bool {
108        matches!(self.state, ServerState::Success)
109    }
110
111    pub fn initial(&self) -> bool {
112        matches!(self.state, ServerState::Initial)
113    }
114
115    pub fn process_message(
116        &mut self,
117        message: &[u8],
118        env: &impl ServerEnvironment,
119    ) -> Result<Vec<u8>, SCRAMError> {
120        match &self.state {
121            ServerState::Success => Err(SCRAMError::ProtocolError),
122            ServerState::Initial => {
123                let message = ClientFirstMessage::decode(message)?;
124                if message.channel_binding != ChannelBinding::NotSupported("".into()) {
125                    return Err(SCRAMError::ProtocolError);
126                }
127                if message.nonce.len() < MINIMUM_NONCE_LENGTH {
128                    return Err(SCRAMError::ProtocolError);
129                }
130                let (salt, iterations) = env.get_password_parameters(&message.username);
131                let mut nonce = message.nonce.to_string();
132                nonce += &env.generate_nonce();
133                let response = ServerFirstResponse {
134                    combined_nonce: nonce.to_string().into(),
135                    salt: BASE64_STANDARD.encode(salt).into(),
136                    iterations,
137                };
138                self.state =
139                    ServerState::SentChallenge(message.to_owned_bare(), response.to_owned());
140                Ok(response.encode().into_bytes())
141            }
142            ServerState::SentChallenge(first_message, first_response) => {
143                let message = ClientFinalMessage::decode(message)?;
144                if !constant_time_eq::constant_time_eq(
145                    message.combined_nonce.as_bytes(),
146                    first_response.combined_nonce.as_bytes(),
147                ) {
148                    return Err(SCRAMError::ProtocolError);
149                }
150                if message.channel_binding != CHANNEL_BINDING_ENCODED {
151                    return Err(SCRAMError::ProtocolError);
152                }
153                let (stored_key, server_key) = env.get_stored_key(&first_message.username);
154
155                // Decode the provided client proof
156                let mut provided_proof = vec![];
157                BASE64_STANDARD
158                    .decode_vec(message.proof.as_bytes(), &mut provided_proof)
159                    .map_err(|_| SCRAMError::ProtocolError)?;
160
161                let (calculated_stored_key, server_signature) = generate_server_proof(
162                    first_message.encode().as_bytes(),
163                    first_response.encode().as_bytes(),
164                    message.channel_binding.as_bytes(),
165                    message.combined_nonce.as_bytes(),
166                    &provided_proof,
167                    &server_key,
168                    &stored_key,
169                );
170
171                if !constant_time_eq::constant_time_eq(
172                    calculated_stored_key.as_slice(),
173                    &stored_key,
174                ) {
175                    return Err(SCRAMError::ProtocolError);
176                }
177
178                self.state = ServerState::Success;
179                let verifier = BASE64_STANDARD.encode(server_signature).into();
180                Ok(ServerFinalResponse { verifier }.encode().into_bytes())
181            }
182        }
183    }
184}
185
186#[derive(Default)]
187enum ServerState {
188    #[default]
189    Initial,
190    SentChallenge(ClientFirstMessage<'static>, ServerFirstResponse<'static>),
191    Success,
192}
193
194pub trait ClientEnvironment {
195    fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out;
196    fn generate_nonce(&self) -> String;
197}
198
199#[derive(Debug)]
200pub struct ClientTransaction {
201    state: ClientState,
202}
203
204impl ClientTransaction {
205    pub fn new(username: Cow<'static, str>) -> Self {
206        Self {
207            state: ClientState::Initial(username),
208        }
209    }
210
211    pub fn success(&self) -> bool {
212        matches!(self.state, ClientState::Success)
213    }
214
215    pub fn process_message(
216        &mut self,
217        message: &[u8],
218        env: &impl ClientEnvironment,
219    ) -> Result<Option<Vec<u8>>, SCRAMError> {
220        match &self.state {
221            ClientState::Success => Err(SCRAMError::ProtocolError),
222            ClientState::Initial(username) => {
223                if !message.is_empty() {
224                    return Err(SCRAMError::ProtocolError);
225                }
226                let nonce = env.generate_nonce().into();
227                let message = ClientFirstMessage {
228                    channel_binding: ChannelBinding::NotSupported("".into()),
229                    username: username.clone(),
230                    nonce,
231                };
232                self.state = ClientState::SentFirst(message.to_owned_bare());
233                Ok(Some(message.encode().into_bytes()))
234            }
235            ClientState::SentFirst(first_message) => {
236                let message = ServerFirstResponse::decode(message)?;
237                // Ensure the client nonce was concatenated with the server's nonce
238                if !message
239                    .combined_nonce
240                    .starts_with(first_message.nonce.as_ref())
241                {
242                    return Err(SCRAMError::ProtocolError);
243                }
244                if message.combined_nonce.len() - first_message.nonce.len() < MINIMUM_NONCE_LENGTH {
245                    return Err(SCRAMError::ProtocolError);
246                }
247                let mut buffer = [0; 1024];
248                let salt = decode_salt(&message.salt, &mut buffer)?;
249                let salted_password = env.get_salted_password(&salt, message.iterations);
250                let (client_proof, server_verifier) = generate_client_proof(
251                    first_message.encode().as_bytes(),
252                    message.encode().as_bytes(),
253                    CHANNEL_BINDING_ENCODED.as_bytes(),
254                    message.combined_nonce.as_bytes(),
255                    &salted_password,
256                );
257                let message = ClientFinalMessage {
258                    channel_binding: CHANNEL_BINDING_ENCODED.into(),
259                    combined_nonce: message.combined_nonce.to_string().into(),
260                    proof: BASE64_STANDARD.encode(client_proof).into(),
261                };
262                self.state = ClientState::ExpectingVerifier(ServerFinalResponse {
263                    verifier: BASE64_STANDARD.encode(server_verifier).into(),
264                });
265                Ok(Some(message.encode().into_bytes()))
266            }
267            ClientState::ExpectingVerifier(server_final_response) => {
268                let message = ServerFinalResponse::decode(message)?;
269                if !constant_time_eq::constant_time_eq(
270                    message.verifier.as_bytes(),
271                    server_final_response.verifier.as_bytes(),
272                ) {
273                    return Err(SCRAMError::ProtocolError);
274                }
275                self.state = ClientState::Success;
276                Ok(None)
277            }
278        }
279    }
280}
281
282#[derive(Debug)]
283enum ClientState {
284    Initial(Cow<'static, str>),
285    SentFirst(ClientFirstMessage<'static>),
286    ExpectingVerifier(ServerFinalResponse<'static>),
287    Success,
288}
289
290trait Encode {
291    fn encode(&self) -> String;
292}
293
294trait Decode<'a> {
295    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError>
296    where
297        Self: Sized + 'a;
298}
299
300fn extract<'a>(input: &'a [u8], prefix: &'static str) -> Result<&'a str, SCRAMError> {
301    let bytes = input
302        .strip_prefix(prefix.as_bytes())
303        .ok_or(SCRAMError::ProtocolError)?;
304    std::str::from_utf8(bytes).map_err(|_| SCRAMError::ProtocolError)
305}
306
307fn inext<'a>(it: &mut impl Iterator<Item = &'a [u8]>) -> Result<&'a [u8], SCRAMError> {
308    it.next().ok_or(SCRAMError::ProtocolError)
309}
310
311fn hmac(s: &[u8]) -> HmacSha256 {
312    // This is effectively infallible
313    HmacSha256::new_from_slice(s).expect("HMAC can take key of any size")
314}
315
316#[derive(Debug, Clone, PartialEq, Eq)]
317/// `gs2-cbind-flag` from RFC5802.
318enum ChannelBinding<'a> {
319    /// No channel binding
320    NotSpecified,
321    /// "n" -> client doesn't support channel binding.
322    NotSupported(Cow<'a, str>),
323    /// "y" -> client does support channel binding but thinks the server does
324    /// not.
325    Supported(Cow<'a, str>),
326    /// "p" -> client requires channel binding. The selected channel binding
327    /// follows "p=".
328    Required(Cow<'a, str>, Cow<'a, str>),
329}
330
331#[derive(Debug)]
332pub struct ClientFirstMessage<'a> {
333    channel_binding: ChannelBinding<'a>,
334    username: Cow<'a, str>,
335    nonce: Cow<'a, str>,
336}
337
338impl ClientFirstMessage<'_> {
339    /// Get the bare first message
340    pub fn to_owned_bare(&self) -> ClientFirstMessage<'static> {
341        ClientFirstMessage {
342            channel_binding: ChannelBinding::NotSpecified,
343            username: self.username.to_string().into(),
344            nonce: self.nonce.to_string().into(),
345        }
346    }
347}
348
349impl Encode for ClientFirstMessage<'_> {
350    fn encode(&self) -> String {
351        let channel_binding = match self.channel_binding {
352            ChannelBinding::NotSpecified => "".to_string(),
353            ChannelBinding::NotSupported(ref s) => format!("n,{},", s),
354            ChannelBinding::Supported(ref s) => format!("y,{},", s),
355            ChannelBinding::Required(ref s, ref t) => format!("p={},{},", t, s),
356        };
357        format!("{channel_binding}n={},r={}", self.username, self.nonce)
358    }
359}
360
361impl<'a> Decode<'a> for ClientFirstMessage<'a> {
362    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
363        let mut parts = buf.split(|&b| b == b',');
364
365        // Check for channel binding
366        let mut next = inext(&mut parts)?;
367        let mut channel_binding = ChannelBinding::NotSpecified;
368        match (next.len(), next.first()) {
369            (_, Some(b'p')) => {
370                // p=(cb-name),(authz-id),
371                let Some(cb_name) = next.strip_prefix(b"p=") else {
372                    return Err(SCRAMError::ProtocolError);
373                };
374                let cb_name =
375                    std::str::from_utf8(cb_name).map_err(|_| SCRAMError::ProtocolError)?;
376                let param = inext(&mut parts)?;
377                channel_binding = ChannelBinding::Required(
378                    Cow::Borrowed(
379                        std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?,
380                    ),
381                    cb_name.into(),
382                );
383                next = inext(&mut parts)?;
384            }
385            (1, Some(b'n')) => {
386                let param = inext(&mut parts)?;
387                channel_binding = ChannelBinding::NotSupported(Cow::Borrowed(
388                    std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?,
389                ));
390                next = inext(&mut parts)?;
391            }
392            (1, Some(b'y')) => {
393                let param = inext(&mut parts)?;
394                channel_binding = ChannelBinding::Supported(Cow::Borrowed(
395                    std::str::from_utf8(param).map_err(|_| SCRAMError::ProtocolError)?,
396                ));
397                next = inext(&mut parts)?;
398            }
399            (_, None) => {
400                return Err(SCRAMError::ProtocolError);
401            }
402            _ => {
403                // No channel binding specified
404            }
405        }
406        let username = extract(next, "n=")?.into();
407        let nonce = extract(inext(&mut parts)?, "r=")?.into();
408        Ok(ClientFirstMessage {
409            channel_binding,
410            username,
411            nonce,
412        })
413    }
414}
415
416pub struct ServerFirstResponse<'a> {
417    combined_nonce: Cow<'a, str>,
418    salt: Cow<'a, str>,
419    iterations: usize,
420}
421
422impl ServerFirstResponse<'_> {
423    pub fn to_owned(&self) -> ServerFirstResponse<'static> {
424        ServerFirstResponse {
425            combined_nonce: self.combined_nonce.to_string().into(),
426            salt: self.salt.to_string().into(),
427            iterations: self.iterations,
428        }
429    }
430}
431
432impl Encode for ServerFirstResponse<'_> {
433    fn encode(&self) -> String {
434        format!(
435            "r={},s={},i={}",
436            self.combined_nonce, self.salt, self.iterations
437        )
438    }
439}
440
441impl<'a> Decode<'a> for ServerFirstResponse<'a> {
442    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
443        let mut parts = buf.split(|&b| b == b',');
444        let combined_nonce = extract(inext(&mut parts)?, "r=")?.into();
445        let salt = extract(inext(&mut parts)?, "s=")?.into();
446        let iterations = extract(inext(&mut parts)?, "i=")?;
447        Ok(ServerFirstResponse {
448            combined_nonce,
449            salt,
450            iterations: str::parse(iterations).map_err(|_| SCRAMError::ProtocolError)?,
451        })
452    }
453}
454
455pub struct ClientFinalMessage<'a> {
456    channel_binding: Cow<'a, str>,
457    combined_nonce: Cow<'a, str>,
458    proof: Cow<'a, str>,
459}
460
461impl Encode for ClientFinalMessage<'_> {
462    fn encode(&self) -> String {
463        format!(
464            "c={},r={},p={}",
465            self.channel_binding, self.combined_nonce, self.proof
466        )
467    }
468}
469
470impl<'a> Decode<'a> for ClientFinalMessage<'a> {
471    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
472        let mut parts = buf.split(|&b| b == b',');
473        let channel_binding = extract(inext(&mut parts)?, "c=")?.into();
474        let combined_nonce = extract(inext(&mut parts)?, "r=")?.into();
475        let proof = extract(inext(&mut parts)?, "p=")?.into();
476        Ok(ClientFinalMessage {
477            channel_binding,
478            combined_nonce,
479            proof,
480        })
481    }
482}
483
484#[derive(Debug)]
485pub struct ServerFinalResponse<'a> {
486    verifier: Cow<'a, str>,
487}
488
489impl<'a> Encode for ServerFinalResponse<'a> {
490    fn encode(&self) -> String {
491        format!("v={}", self.verifier)
492    }
493}
494
495impl<'a> Decode<'a> for ServerFinalResponse<'a> {
496    fn decode(buf: &'a [u8]) -> Result<Self, SCRAMError> {
497        let mut parts = buf.split(|&b| b == b',');
498        let verifier = extract(inext(&mut parts)?, "v=")?.into();
499        Ok(ServerFinalResponse { verifier })
500    }
501}
502
503pub fn decode_salt<'a>(salt: &str, buffer: &'a mut [u8]) -> Result<Cow<'a, [u8]>, SCRAMError> {
504    // The salt needs to be base64 decoded -- full binary must be used
505    if let Ok(n) = BASE64_STANDARD.decode_slice(salt, buffer) {
506        Ok(Cow::Borrowed(&buffer[..n]))
507    } else {
508        // In the unlikely case the salt is large -- note that we also fall back to this
509        // path for invalid base64 strings!
510        let mut buffer = vec![];
511        BASE64_STANDARD
512            .decode_vec(salt, &mut buffer)
513            .map_err(|_| SCRAMError::ProtocolError)?;
514        Ok(Cow::Owned(buffer))
515    }
516}
517
518/// Given a password in byte form, generates the salted version of the password,
519/// applying SASLprep to it beforehand.
520pub fn generate_salted_password(password: &[u8], salt: &[u8], iterations: usize) -> Sha256Out {
521    // Save the pre-keyed hmac
522    let ui_p = hmac(&sasl_normalize_password_bytes(password));
523
524    // The initial signature is the salt with a terminator of a 32-bit string ending in 1
525    let mut ui = ui_p.clone();
526
527    ui.update(salt);
528    ui.update(&[0, 0, 0, 1]);
529
530    // Grab the initial digest
531    let mut last_hash = Default::default();
532    ui.finalize_into(&mut last_hash);
533    let mut u = last_hash;
534
535    // For X number of iterations, recompute the HMAC signature against the password and the latest iteration of the hash, and XOR it with the previous version
536    for _ in 0..(iterations - 1) {
537        let mut ui = ui_p.clone();
538        ui.update(&last_hash);
539        ui.finalize_into(&mut last_hash);
540        for i in 0..u.len() {
541            u[i] ^= last_hash[i];
542        }
543    }
544
545    u.as_slice().try_into().unwrap()
546}
547
548pub fn generate_nonce() -> String {
549    let mut rng = rand::thread_rng();
550    let bytes: [u8; 32] = rng.gen();
551    BASE64_STANDARD.encode(bytes)
552}
553
554#[derive(Clone, Debug)]
555pub struct StoredKey {
556    pub iterations: usize,
557    pub salt: Vec<u8>,
558    pub stored_key: Sha256Out,
559    pub server_key: Sha256Out,
560}
561
562impl PartialEq for StoredKey {
563    fn eq(&self, other: &Self) -> bool {
564        // If the salt and stored_key match, the remainder must match.
565        // We only need to compare these two fields for equality because:
566        // 1. The salt is used to derive the stored_key, so if both match,
567        //    it implies the same password and iteration count were used.
568        // 2. The server_key is derived from the same process, so it will
569        //    automatically match if the salt and stored_key match.
570        // 3. The iterations count doesn't need to be compared explicitly
571        //    as it's factored into the stored_key calculation.
572        constant_time_eq::constant_time_eq(&self.salt, &other.salt)
573            && constant_time_eq::constant_time_eq(&self.stored_key, &other.stored_key)
574    }
575}
576
577impl Eq for StoredKey {}
578
579impl FromStr for StoredKey {
580    type Err = SCRAMError;
581
582    // "SCRAM-SHA-256$<iterations>:<salt>$<stored_key>:<server_key>"
583
584    fn from_str(s: &str) -> Result<Self, Self::Err> {
585        let parts: Vec<&str> = s.split('$').collect();
586        if parts.len() != 3 || parts[0] != "SCRAM-SHA-256" {
587            return Err(SCRAMError::ProtocolError);
588        }
589
590        let iterations = parts[1]
591            .split(':')
592            .next()
593            .ok_or(SCRAMError::ProtocolError)?
594            .parse()
595            .map_err(|_| SCRAMError::ProtocolError)?;
596
597        let salt = BASE64_STANDARD
598            .decode(
599                parts[1]
600                    .split(':')
601                    .nth(1)
602                    .ok_or(SCRAMError::ProtocolError)?,
603            )
604            .map_err(|_| SCRAMError::ProtocolError)?;
605
606        let key_parts: Vec<&str> = parts[2].split(':').collect();
607        if key_parts.len() != 2 {
608            return Err(SCRAMError::ProtocolError);
609        }
610
611        let stored_key = BASE64_STANDARD
612            .decode(key_parts[0])
613            .map_err(|_| SCRAMError::ProtocolError)?
614            .try_into()
615            .map_err(|_| SCRAMError::ProtocolError)?;
616
617        let server_key = BASE64_STANDARD
618            .decode(key_parts[1])
619            .map_err(|_| SCRAMError::ProtocolError)?
620            .try_into()
621            .map_err(|_| SCRAMError::ProtocolError)?;
622
623        Ok(StoredKey {
624            iterations,
625            salt,
626            stored_key,
627            server_key,
628        })
629    }
630}
631use std::fmt;
632
633impl fmt::Display for StoredKey {
634    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
635        write!(
636            f,
637            "SCRAM-SHA-256${}:{}${}:{}",
638            self.iterations,
639            BASE64_STANDARD.encode(&self.salt),
640            BASE64_STANDARD.encode(self.stored_key),
641            BASE64_STANDARD.encode(self.server_key)
642        )
643    }
644}
645
646impl ServerEnvironment for StoredKey {
647    fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) {
648        (Cow::Owned(self.salt.clone()), self.iterations)
649    }
650
651    fn generate_nonce(&self) -> String {
652        let nonce: [u8; 32] = rand::thread_rng().gen();
653        base64::engine::general_purpose::STANDARD.encode(nonce)
654    }
655
656    fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) {
657        (self.stored_key, self.server_key)
658    }
659}
660
661impl StoredKey {
662    /// Generate a stored key compatible with PostgreSQL's encoding.
663    pub fn generate(password: &[u8], salt: &[u8], iterations: usize) -> Self {
664        let digest_key = generate_salted_password(password, salt, iterations);
665
666        let client_key = hmac(&digest_key)
667            .chain_update(b"Client Key")
668            .finalize()
669            .into_bytes();
670
671        let stored_key = Sha256::digest(client_key);
672
673        let server_key = hmac(&digest_key)
674            .chain_update(b"Server Key")
675            .finalize()
676            .into_bytes();
677
678        Self {
679            iterations,
680            salt: salt.to_owned(),
681            stored_key: stored_key.into(),
682            server_key: server_key.into(),
683        }
684    }
685}
686
687fn generate_client_proof(
688    first_message_bare: &[u8],
689    server_first_message: &[u8],
690    channel_binding: &[u8],
691    server_nonce: &[u8],
692    salted_password: &[u8],
693) -> (Sha256Out, Sha256Out) {
694    let ui_p = hmac(salted_password);
695
696    let mut ui = ui_p.clone();
697    ui.update(b"Server Key");
698    let server_key = ui.finalize_fixed();
699
700    let mut ui = ui_p.clone();
701    ui.update(b"Client Key");
702    let client_key = ui.finalize_fixed();
703
704    let mut hash = Sha256::new();
705    hash.update(client_key);
706    let stored_key = hash.finalize_fixed();
707
708    let auth_message = [
709        (first_message_bare),
710        (b","),
711        (server_first_message),
712        (b",c="),
713        (channel_binding),
714        (b",r="),
715        (server_nonce),
716    ];
717
718    let mut client_signature = hmac(&stored_key);
719    for chunk in auth_message {
720        client_signature.update(chunk);
721    }
722
723    let client_signature = client_signature.finalize_fixed();
724    let mut client_signature: Sha256Out = client_signature.as_slice().try_into().unwrap();
725
726    for i in 0..client_signature.len() {
727        client_signature[i] ^= client_key[i];
728    }
729
730    let mut server_proof = hmac(&server_key);
731    for chunk in auth_message {
732        server_proof.update(chunk);
733    }
734    let server_proof = server_proof.finalize_fixed().as_slice().try_into().unwrap();
735
736    (client_signature, server_proof)
737}
738
739fn generate_server_proof(
740    first_message_bare: &[u8],
741    server_first_message: &[u8],
742    channel_binding: &[u8],
743    server_nonce: &[u8],
744    provided_proof: &[u8],
745    server_key: &[u8],
746    stored_key: &[u8],
747) -> (Sha256Out, Sha256Out) {
748    let auth_message = [
749        (first_message_bare),
750        (b","),
751        (server_first_message),
752        (b",c="),
753        (channel_binding),
754        (b",r="),
755        (server_nonce),
756    ];
757
758    let mut client_signature = hmac(stored_key);
759    for chunk in &auth_message {
760        client_signature.update(chunk);
761    }
762    let client_signature = client_signature.finalize_fixed();
763
764    let mut calculated_stored_key = [0u8; 32];
765    for (i, (&p, &c)) in provided_proof
766        .iter()
767        .zip(client_signature.iter())
768        .enumerate()
769    {
770        calculated_stored_key[i] = p ^ c;
771    }
772
773    let calculated_stored_key = Sha256::digest(calculated_stored_key);
774
775    let mut server_signature = hmac(server_key);
776    for chunk in &auth_message {
777        server_signature.update(chunk);
778    }
779    let server_signature = server_signature.finalize_fixed();
780
781    (calculated_stored_key.into(), server_signature.into())
782}
783
784#[cfg(test)]
785mod tests {
786    use super::*;
787    use hex_literal::hex;
788    use pretty_assertions::{assert_eq, assert_ne};
789    use rstest::rstest;
790
791    // Define a set of test parameters
792    const CLIENT_NONCE: &str = "2XendqvQOa6cl0+Q7Y6UU0gw";
793    const SERVER_NONCE: &str = "xWn3mvDeVZwnUtT09vwXoItO";
794    const USERNAME: &str = "";
795    const PASSWORD: &[u8] = b"secret";
796    const SALT: &str = "t5YekvL6lgy4RyPnsiyqsg==";
797    const ITERATIONS: usize = 4096;
798    const CLIENT_PROOF: &[u8] = "p/HmDcOziQQnyF8fbVnJnlvwoLp1kZY4xsI9cCJhzCE=".as_bytes();
799    const SERVER_VERIFY: &[u8] = "g/X0codOryF0nCOWh7KkIab23ZFPX99iLzN5Ghn3nNc=".as_bytes();
800
801    #[rstest]
802    #[case(
803        b"1234",
804        "1234",
805        1,
806        hex!("EBE7E5BA4BF5A4D178D3BADAADD4C49A98C72FCFF4FB357DA7090D584990FCAA")
807    )]
808    #[case(
809        b"1234",
810        "1234",
811        2,
812        hex!("F9271C334EE6CD7FEE63BBC86FAF951A4ED9E293BDD72AC33663BAE662D31953")
813    )]
814    #[case(
815        b"1234",
816        "1234",
817        4096,
818        hex!("4FF8D6443278AB43209DF5A1327949AAC99A5AA23921E5C9199626524776F751")
819    )]
820    #[case(
821        b"password",
822        "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu",
823        4096,
824        hex!("E118A9AD43C87938659AD736E63F26BA2EBAF079AA351DB44AE29228FB4F7EF0")
825    )]
826    #[case(
827        b"secret",
828        "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu",
829        4096,
830        hex!("77DFD8E62A4379296C9769F9BA2F77D503C4647DE7919B47D6CF121986981BCC")
831    )]
832    #[case(
833        b"secret",
834        "t5YekvL6lgy4RyPnsiyqsg==",
835        4096,
836        hex!("9FB413FE9F1D0C8020400A3D49CFBC47FBFB1251CEA9297630BD025DB2B65171")
837    )]
838    #[case(
839        "😀".as_bytes(),
840        "t5YekvL6lgy4RyPnsiyqsg==",
841        4096,
842        hex!("AF490CE1BEA2DDB585DAF9C3842D1528AB091EF6FAB2A92489870523A98835EE")
843    )]
844    fn test_generate_salted_password(
845        #[case] password: &[u8],
846        #[case] salt: &str,
847        #[case] iterations: usize,
848        #[case] expected_hash: Sha256Out,
849    ) {
850        let mut buffer = [0; 128];
851        let salt = decode_salt(salt, &mut buffer).unwrap();
852        let hash = generate_salted_password(password, &salt, iterations);
853        assert_eq!(hash, expected_hash);
854    }
855
856    /// Tests that use real stored keys from postgres to match normalization
857    /// behaviour. This exercises the saslprep code indirectly to ensure it
858    /// matches the PostgreSQL implementation.
859    ///
860    /// Passwords in these tests were generated via `ALTER ROLE` DDL statements,
861    /// and the salted password was then extracted from the roles table.
862    ///
863    /// Note that a PostgreSQL user may have a password that is only valid when
864    /// interpreted as a bag of bytes and cannot be set directly via `ALTER
865    /// ROLE` or the `initdb` command-line. This code _should_ support those
866    /// passwords, however given the complexity of testing this we do not
867    /// currently do so. Should we wish to test this in the future, we will need
868    /// to manually create the stored key, set this string as the password for
869    /// the role, and then validate that it can be used to authenticate via
870    /// SCRAM.
871    #[rstest]
872    // ASCII
873    #[case(b"password", "SCRAM-SHA-256$4096:jZLwuMbICV2L8i9SsfSEYQ==$Qhd2nOIlLW/dtVFERkVjVNdzzrVwPm2l+WHibmPesoc=:P1aH2cUHyPUbIdO06hEiXdwKxQyqBNUijLGFLkTXcHs=")]
874    // Unicode
875    #[case("schön".as_bytes(), "SCRAM-SHA-256$4096:uuH6VXsbbeId2AcdL0WmSA==$imMseND/Sg7tL5Tm1ltZJGa6PsdxwysUZ9s1lXPOPdo=:kMp6Rb9yN3zYpvwkuf0/xQZWhIGEa0ryjwnyDfpL3G0=")]
876    // Unicode normalization -> half-width to full-width
877    #[case("パスワード".as_bytes(), "SCRAM-SHA-256$4096:oCSGmW9Llo803DWp94yE0A==$TvNA2Hh1IqwCHlhxHhIaTeI7N/mFSx01D3/tb2VGQfw=:RBDsZImb7XoP6Md1j0zhjf7yBz0ocDoxqsPeFtJLyaI=")]
878    // Chars that normalize to space and nothing
879    #[case(b"pass\xc2\xa0\xe2\x80\x80word", "SCRAM-SHA-256$4096:ag3Z1WnqEn8dhTvSP7UtYA==$taWe9cZJYK5Y28V9Nw3zy6E9qQKbqKrMRS5DwlDXG04=:Y4n3uwZ4jQyG7nYCde3vtPxO1p0Oxz5ytJT1W+lqM+I=")]
880    // Invalid control chars
881    #[case(b"\x01\x02\x03", "SCRAM-SHA-256$4096:XGcYpEn2cwuS+BZXJBaqFg==$mG53wGoI6pAANoAZl7qxYiKPZ6u3CfhCVZK4et3l52A=:X5PUFkC5MVJWmuBTwWQHTFH81xjiyAHrJ9r0anOPXiI=")]
882    // Prohibited char (ffff)
883    #[case(b"\xef\xbf\xbf", "SCRAM-SHA-256$4096:Tdv5eCJIm+LU9QJBKO96gQ==$YXE4G3HKPwCmwo4FjiFKaiqVGCDTOpVETv+Fe6wWY9Q=:DK7MZ/OgGGgCDh6EfsmmcyFuaAD+T2Zh78sl+QDQFIo=")]
884    fn test_stored_key(#[case] password: &[u8], #[case] stored_key: &str) {
885        let parsed_key = StoredKey::from_str(stored_key).unwrap();
886        assert_eq!(4096, parsed_key.iterations);
887        let generated_key = StoredKey::generate(password, &parsed_key.salt, parsed_key.iterations);
888        assert_eq!(generated_key, parsed_key);
889        assert_eq!(generated_key.to_string(), stored_key);
890    }
891
892    #[test]
893    fn test_client_proof() {
894        let mut buffer = [0; 128];
895        let salt = decode_salt(SALT, &mut buffer).unwrap();
896        let salted_password = generate_salted_password(PASSWORD, &salt, ITERATIONS);
897        let (client, server) = generate_client_proof(
898            format!("n={USERNAME},r={CLIENT_NONCE}").as_bytes(),
899            format!("r={CLIENT_NONCE}{SERVER_NONCE},s={SALT},i={ITERATIONS}").as_bytes(),
900            CHANNEL_BINDING_ENCODED.as_bytes(),
901            format!("{CLIENT_NONCE}{SERVER_NONCE}").as_bytes(),
902            &salted_password,
903        );
904        assert_eq!(
905            &client,
906            BASE64_STANDARD.decode(CLIENT_PROOF).unwrap().as_slice()
907        );
908        assert_eq!(
909            &server,
910            BASE64_STANDARD.decode(SERVER_VERIFY).unwrap().as_slice()
911        );
912    }
913
914    #[test]
915    fn test_client_first_message() {
916        let message = ClientFirstMessage::decode(b"n,,n=,r=480I9uIaXEU9oB2RRcenOxN/").unwrap();
917        assert_eq!(
918            message.channel_binding,
919            ChannelBinding::NotSupported(Cow::Borrowed(""))
920        );
921        assert_eq!(message.username, "");
922        assert_eq!(message.nonce, "480I9uIaXEU9oB2RRcenOxN/");
923        assert_eq!(
924            message.encode(),
925            "n,,n=,r=480I9uIaXEU9oB2RRcenOxN/".to_owned()
926        );
927    }
928
929    #[test]
930    fn test_client_first_message_required() {
931        let message =
932            ClientFirstMessage::decode(b"p=cb-name,,n=,r=480I9uIaXEU9oB2RRcenOxN/").unwrap();
933        assert_eq!(
934            message.channel_binding,
935            ChannelBinding::Required(Cow::Borrowed(""), Cow::Borrowed("cb-name"))
936        );
937        assert_eq!(message.username, "");
938        assert_eq!(message.nonce, "480I9uIaXEU9oB2RRcenOxN/");
939        assert_eq!(
940            message.encode(),
941            "p=cb-name,,n=,r=480I9uIaXEU9oB2RRcenOxN/".to_owned()
942        );
943    }
944
945    #[test]
946    fn test_server_first_response() {
947        let message = ServerFirstResponse::decode(
948            b"r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,s=t5YekvL6lgy4RyPnsiyqsg==,i=4096",
949        )
950        .unwrap();
951        assert_eq!(
952            message.combined_nonce,
953            "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu"
954        );
955        assert_eq!(message.salt, "t5YekvL6lgy4RyPnsiyqsg==");
956        assert_eq!(message.iterations, 4096);
957        assert_eq!(
958            message.encode(),
959            "r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,s=t5YekvL6lgy4RyPnsiyqsg==,i=4096"
960                .to_owned()
961        );
962    }
963
964    #[test]
965    fn test_client_final_message() {
966        let message = b"c=biws,r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,p=7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4=";
967        let decoded = ClientFinalMessage::decode(message).unwrap();
968        assert_eq!(decoded.channel_binding, "biws");
969        assert_eq!(
970            decoded.combined_nonce,
971            "480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu"
972        );
973        assert_eq!(
974            decoded.proof,
975            "7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4="
976        );
977        let encoded = decoded.encode();
978        assert_eq!(encoded, "c=biws,r=480I9uIaXEU9oB2RRcenOxN/RsOCy0BKJvyRSeuOtQ0cF0hu,p=7Vkz4SfWTNhB3hNdhTucC+3MaGmg3+PrAG3xfuepjP4=");
979    }
980
981    #[test]
982    fn test_server_final_response() {
983        let message = b"v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=";
984        let decoded: ServerFinalResponse = ServerFinalResponse::decode(message).unwrap();
985        assert_eq!(
986            decoded.verifier,
987            "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="
988        );
989        let encoded = decoded.encode();
990        assert_eq!(encoded, "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=");
991    }
992
993    /// Run a SCRAM conversation with a fixed set of parameters
994    #[test]
995    fn test_transaction() {
996        let mut server = ServerTransaction::default();
997        let mut client = ClientTransaction::new("username".into());
998
999        struct Env {}
1000        impl ClientEnvironment for Env {
1001            fn generate_nonce(&self) -> String {
1002                "<<<client nonce>>>".into()
1003            }
1004            fn get_salted_password(&self, salt: &[u8], iterations: usize) -> Sha256Out {
1005                generate_salted_password(b"password", salt, iterations)
1006            }
1007        }
1008        impl ServerEnvironment for Env {
1009            fn get_stored_key(&self, username: &str) -> (Sha256Out, Sha256Out) {
1010                assert_eq!(username, "username");
1011                let key = StoredKey::generate(b"password", b"hello", 4096);
1012                (key.stored_key, key.server_key)
1013            }
1014            fn generate_nonce(&self) -> String {
1015                "<<<server nonce>>>".into()
1016            }
1017            fn get_password_parameters(&self, username: &str) -> (Cow<'static, [u8]>, usize) {
1018                assert_eq!(username, "username");
1019                (Cow::Borrowed(b"hello"), 4096)
1020            }
1021        }
1022        let env = Env {};
1023        let message = client.process_message(&[], &env).unwrap().unwrap();
1024        assert_eq!(
1025            String::from_utf8(message.clone()).unwrap(),
1026            "n,,n=username,r=<<<client nonce>>>"
1027        );
1028        let message = server.process_message(&message, &env).unwrap();
1029        assert_eq!(
1030            String::from_utf8(message.clone()).unwrap(),
1031            "r=<<<client nonce>>><<<server nonce>>>,s=aGVsbG8=,i=4096"
1032        );
1033        let message = client.process_message(&message, &env).unwrap().unwrap();
1034        assert_eq!(String::from_utf8(message.clone()).unwrap(), "c=biws,r=<<<client nonce>>><<<server nonce>>>,p=621h6u6V3axb7mNYHNgTspTZ3SqILcxuJOsFu5wMjV8=");
1035        let message = server.process_message(&message, &env).unwrap();
1036        assert_eq!(
1037            String::from_utf8(message.clone()).unwrap(),
1038            "v=moj4kNnZKB3wjXZeQsKYI9luTTakwgH8r0NdGOjugRY="
1039        );
1040        assert!(client.process_message(&message, &env).unwrap().is_none());
1041        assert!(client.success());
1042        assert!(server.success());
1043    }
1044}