Skip to main content

crabka_security/scram/
client.rs

1//! `ScramClientExchange` — RFC 5802 SCRAM client state machine.
2//! Supports SCRAM-SHA-256 and SCRAM-SHA-512; the mechanism is fixed at
3//! construction.
4
5use base64::Engine;
6use base64::engine::general_purpose::STANDARD as B64;
7use hmac::{Hmac, KeyInit, Mac};
8use ring::rand::{SecureRandom, SystemRandom};
9use sha2::{Digest, Sha256, Sha512};
10use subtle::ConstantTimeEq;
11
12use crate::{AuthError, SaslMechanism};
13
/// Position of a [`ScramClientExchange`] in the RFC 5802 message flow.
///
/// `step` and `verify_server_final` take the state out with
/// `std::mem::replace`, leaving `Finished` behind, so an error from either
/// leaves the exchange finished.
enum State {
    /// `client_first` has not produced a message yet.
    Initial,
    /// The client-first message was built; waiting for the server-first message.
    AwaitingServerFirst {
        /// `client-first-message-bare` (`n=<user>,r=<nonce>`); part of the AuthMessage.
        client_first_bare: String,
        /// Base64 client nonce that the server's combined nonce must start with.
        client_nonce: String,
    },
    /// The client-final message was built; waiting for the server-final message.
    AwaitingServerFinal {
        /// AuthMessage that the server signature is computed over.
        auth_message: String,
        /// HMAC(SaltedPassword, "Server Key"). Secret: redacted by `Debug`.
        server_key: Vec<u8>,
    },
    /// The exchange completed or failed; no further steps are accepted.
    Finished,
}

// Hand-written instead of derived so `{:?}` never prints the ServerKey:
// HMAC(ServerKey, AuthMessage) is exactly the server signature the client
// verifies, so anyone holding the key could impersonate the server.
impl std::fmt::Debug for State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Initial => f.write_str("Initial"),
            Self::AwaitingServerFirst {
                client_first_bare,
                client_nonce,
            } => f
                .debug_struct("AwaitingServerFirst")
                .field("client_first_bare", client_first_bare)
                .field("client_nonce", client_nonce)
                .finish(),
            Self::AwaitingServerFinal { auth_message, .. } => f
                .debug_struct("AwaitingServerFinal")
                .field("auth_message", auth_message)
                .field("server_key", &"<redacted>")
                .finish(),
            Self::Finished => f.write_str("Finished"),
        }
    }
}
27
28#[derive(Debug)]
29pub struct ScramClientExchange {
30    username: String,
31    password: Vec<u8>,
32    mechanism: SaslMechanism,
33    state: State,
34}
35
36impl ScramClientExchange {
37    #[must_use]
38    pub fn new(username: String, password: Vec<u8>, mechanism: SaslMechanism) -> Self {
39        assert!(
40            mechanism.is_scram(),
41            "ScramClientExchange::new called with non-SCRAM mechanism {mechanism:?}"
42        );
43        Self {
44            username,
45            password,
46            mechanism,
47            state: State::Initial,
48        }
49    }
50
51    pub fn client_first(&mut self) -> Result<Vec<u8>, AuthError> {
52        if !matches!(self.state, State::Initial) {
53            return Err(AuthError::MalformedMessage);
54        }
55        let mut nonce_bytes = [0u8; 18];
56        SystemRandom::new()
57            .fill(&mut nonce_bytes)
58            .map_err(|_| AuthError::MalformedMessage)?;
59        let client_nonce = B64.encode(nonce_bytes);
60        let bare = format!("n={},r={}", self.username, client_nonce);
61        let msg = format!("n,,{bare}");
62        self.state = State::AwaitingServerFirst {
63            client_first_bare: bare,
64            client_nonce,
65        };
66        Ok(msg.into_bytes())
67    }
68
69    pub fn step(&mut self, server_bytes: &[u8]) -> Result<Vec<u8>, AuthError> {
70        let State::AwaitingServerFirst {
71            client_first_bare,
72            client_nonce,
73        } = std::mem::replace(&mut self.state, State::Finished)
74        else {
75            return Err(AuthError::MalformedMessage);
76        };
77        let s = std::str::from_utf8(server_bytes).map_err(|_| AuthError::MalformedMessage)?;
78        let mut nonce = None;
79        let mut salt = None;
80        let mut iterations = None;
81        for attr in s.split(',') {
82            if let Some(v) = attr.strip_prefix("r=") {
83                nonce = Some(v.to_string());
84            } else if let Some(v) = attr.strip_prefix("s=") {
85                salt = Some(B64.decode(v).map_err(|_| AuthError::MalformedMessage)?);
86            } else if let Some(v) = attr.strip_prefix("i=") {
87                iterations = Some(v.parse::<u32>().map_err(|_| AuthError::MalformedMessage)?);
88            }
89        }
90        let (Some(combined_nonce), Some(salt), Some(iters)) = (nonce, salt, iterations) else {
91            return Err(AuthError::MalformedMessage);
92        };
93        if !combined_nonce.starts_with(&client_nonce) {
94            return Err(AuthError::BadProof);
95        }
96
97        let channel_binding = B64.encode(b"n,,");
98        let client_final_no_proof = format!("c={channel_binding},r={combined_nonce}");
99        let auth_message = format!("{client_first_bare},{s},{client_final_no_proof}");
100
101        let (proof, server_key) = match self.mechanism {
102            SaslMechanism::ScramSha512 => {
103                compute_proof_sha512(&self.password, &salt, iters, auth_message.as_bytes())?
104            }
105            SaslMechanism::ScramSha256 => {
106                compute_proof_sha256(&self.password, &salt, iters, auth_message.as_bytes())?
107            }
108            SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
109                return Err(AuthError::MalformedMessage);
110            }
111        };
112
113        let client_final = format!("{client_final_no_proof},p={}", B64.encode(&proof));
114        self.state = State::AwaitingServerFinal {
115            auth_message,
116            server_key,
117        };
118        Ok(client_final.into_bytes())
119    }
120
121    pub fn verify_server_final(&mut self, server_bytes: &[u8]) -> Result<(), AuthError> {
122        let State::AwaitingServerFinal {
123            auth_message,
124            server_key,
125        } = std::mem::replace(&mut self.state, State::Finished)
126        else {
127            return Err(AuthError::MalformedMessage);
128        };
129        let s = std::str::from_utf8(server_bytes).map_err(|_| AuthError::MalformedMessage)?;
130        let v_b64 = s.strip_prefix("v=").ok_or(AuthError::MalformedMessage)?;
131        let v = B64.decode(v_b64).map_err(|_| AuthError::MalformedMessage)?;
132        let expected: Vec<u8> = match self.mechanism {
133            SaslMechanism::ScramSha512 => {
134                let mut mac = <Hmac<Sha512>>::new_from_slice(&server_key)
135                    .map_err(|_| AuthError::MalformedMessage)?;
136                mac.update(auth_message.as_bytes());
137                mac.finalize().into_bytes().to_vec()
138            }
139            SaslMechanism::ScramSha256 => {
140                let mut mac = <Hmac<Sha256>>::new_from_slice(&server_key)
141                    .map_err(|_| AuthError::MalformedMessage)?;
142                mac.update(auth_message.as_bytes());
143                mac.finalize().into_bytes().to_vec()
144            }
145            SaslMechanism::Plain | SaslMechanism::OAuthBearer | SaslMechanism::Gssapi => {
146                return Err(AuthError::MalformedMessage);
147            }
148        };
149        if expected.ct_eq(&v).unwrap_u8() != 1 {
150            return Err(AuthError::BadProof);
151        }
152        Ok(())
153    }
154}
155
156fn compute_proof_sha512(
157    password: &[u8],
158    salt: &[u8],
159    iters: u32,
160    auth_message: &[u8],
161) -> Result<(Vec<u8>, Vec<u8>), AuthError> {
162    let salted: [u8; 64] = pbkdf2::pbkdf2_hmac_array::<Sha512, 64>(password, salt, iters);
163    let mut client_key_mac =
164        <Hmac<Sha512>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
165    client_key_mac.update(b"Client Key");
166    let client_key = client_key_mac.finalize().into_bytes();
167    let stored_key = Sha512::digest(client_key);
168    let mut server_key_mac =
169        <Hmac<Sha512>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
170    server_key_mac.update(b"Server Key");
171    let server_key = server_key_mac.finalize().into_bytes().to_vec();
172
173    let mut client_sig_mac =
174        <Hmac<Sha512>>::new_from_slice(&stored_key).map_err(|_| AuthError::MalformedMessage)?;
175    client_sig_mac.update(auth_message);
176    let client_signature = client_sig_mac.finalize().into_bytes();
177    let proof: Vec<u8> = client_key
178        .iter()
179        .zip(client_signature.iter())
180        .map(|(a, b)| a ^ b)
181        .collect();
182    Ok((proof, server_key))
183}
184
185fn compute_proof_sha256(
186    password: &[u8],
187    salt: &[u8],
188    iters: u32,
189    auth_message: &[u8],
190) -> Result<(Vec<u8>, Vec<u8>), AuthError> {
191    let salted: [u8; 32] = pbkdf2::pbkdf2_hmac_array::<Sha256, 32>(password, salt, iters);
192    let mut client_key_mac =
193        <Hmac<Sha256>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
194    client_key_mac.update(b"Client Key");
195    let client_key = client_key_mac.finalize().into_bytes();
196    let stored_key = Sha256::digest(client_key);
197    let mut server_key_mac =
198        <Hmac<Sha256>>::new_from_slice(&salted).map_err(|_| AuthError::MalformedMessage)?;
199    server_key_mac.update(b"Server Key");
200    let server_key = server_key_mac.finalize().into_bytes().to_vec();
201
202    let mut client_sig_mac =
203        <Hmac<Sha256>>::new_from_slice(&stored_key).map_err(|_| AuthError::MalformedMessage)?;
204    client_sig_mac.update(auth_message);
205    let client_signature = client_sig_mac.finalize().into_bytes();
206    let proof: Vec<u8> = client_key
207        .iter()
208        .zip(client_signature.iter())
209        .map(|(a, b)| a ^ b)
210        .collect();
211    Ok((proof, server_key))
212}