fraiseql_wire/auth/scram/
mod.rs1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::Rng;
10use sha2::{Digest, Sha256};
11use std::fmt;
12use zeroize::Zeroizing;
13
14type HmacSha256 = Hmac<Sha256>;
15
16pub(crate) const MAX_SCRAM_ITERATIONS: u32 = 1_000_000;
23
/// Errors produced while running the client side of a SCRAM-SHA-256 exchange.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum ScramError {
    /// The server's final signature did not match the expected value.
    InvalidServerProof(String),
    /// A server message was malformed or violated a protocol requirement.
    InvalidServerMessage(String),
    /// Reported as a UTF-8 error.
    ///
    /// NOTE(review): this module also uses this variant for HMAC keying
    /// failures, which are not UTF-8 related.
    Utf8Error(String),
    /// A Base64-encoded field could not be decoded.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::InvalidServerProof(detail) => ("invalid server proof", detail),
            Self::InvalidServerMessage(detail) => ("invalid server message", detail),
            Self::Utf8Error(detail) => ("UTF-8 error", detail),
            Self::Base64Error(detail) => ("Base64 error", detail),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for ScramError {}
50
/// State carried from `ScramClient::client_final` to
/// `ScramClient::verify_server_final`.
///
/// `Debug` is implemented by hand rather than derived: `server_key` is derived
/// from the password and is enough to compute a valid server signature, so it
/// must not leak into logs via `{:?}`.
#[derive(Clone)]
pub struct ScramState {
    /// The SCRAM `AuthMessage`: client-first-bare, server-first and
    /// client-final-without-proof joined by commas.
    auth_message: Vec<u8>,
    /// `ServerKey = HMAC(SaltedPassword, "Server Key")`.
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            // The auth message is built from text, so show it as text.
            .field("auth_message", &String::from_utf8_lossy(&self.auth_message))
            .field("server_key", &format_args!("<redacted>"))
            .finish()
    }
}
59
60pub struct ScramClient {
62 username: String,
63 password: Zeroizing<String>,
66 nonce: String,
67}
68
69impl ScramClient {
70 #[must_use]
72 pub fn new(username: String, password: String) -> Self {
73 let mut rng = rand::rng();
75 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.random()).collect();
76 let nonce = BASE64.encode(&nonce_bytes);
77
78 Self {
79 username,
80 password: Zeroizing::new(password),
81 nonce,
82 }
83 }
84
85 #[must_use]
87 pub fn client_first(&self) -> String {
88 let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
93 format!("n,,n={},r={}", escaped_username, self.nonce)
94 }
95
96 pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
107 let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
109
110 if !server_nonce.starts_with(&self.nonce) {
112 return Err(ScramError::InvalidServerMessage(
113 "server nonce doesn't contain client nonce".to_string(),
114 ));
115 }
116
117 let salt_bytes = BASE64
119 .decode(&salt)
120 .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
121 let iterations = iterations
122 .parse::<u32>()
123 .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
124
125 if iterations > MAX_SCRAM_ITERATIONS {
128 return Err(ScramError::InvalidServerMessage(format!(
129 "server iteration count {iterations} exceeds maximum of {MAX_SCRAM_ITERATIONS}"
130 )));
131 }
132
133 let channel_binding = BASE64.encode(b"n,,");
135
136 let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
138
139 let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
145 let client_first_bare = format!("n={},r={}", escaped_username, self.nonce);
146 let auth_message = format!(
147 "{},{},{}",
148 client_first_bare, server_first, client_final_without_proof
149 );
150
151 let proof = calculate_client_proof(
153 &self.password,
154 &salt_bytes,
155 iterations,
156 auth_message.as_bytes(),
157 )?;
158
159 let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
161
162 let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
164
165 let state = ScramState {
166 auth_message: auth_message.into_bytes(),
167 server_key,
168 };
169
170 Ok((client_final, state))
171 }
172
173 pub fn verify_server_final(
181 &self,
182 server_final: &str,
183 state: &ScramState,
184 ) -> Result<(), ScramError> {
185 let server_sig_encoded = server_final
187 .strip_prefix("v=")
188 .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
189
190 let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
191 ScramError::Base64Error("invalid server signature encoding".to_string())
192 })?;
193
194 let expected_signature =
196 calculate_server_signature(&state.server_key, &state.auth_message)?;
197
198 if constant_time_compare(&server_signature, &expected_signature) {
200 Ok(())
201 } else {
202 Err(ScramError::InvalidServerProof(
203 "server signature verification failed".to_string(),
204 ))
205 }
206 }
207}
208
209pub(crate) fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
211 let mut nonce = String::new();
212 let mut salt = String::new();
213 let mut iterations = String::new();
214
215 for part in msg.split(',') {
216 if let Some(value) = part.strip_prefix("r=") {
217 nonce = value.to_string();
218 } else if let Some(value) = part.strip_prefix("s=") {
219 salt = value.to_string();
220 } else if let Some(value) = part.strip_prefix("i=") {
221 iterations = value.to_string();
222 }
223 }
224
225 if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
226 return Err(ScramError::InvalidServerMessage(
227 "missing required fields in server first message".to_string(),
228 ));
229 }
230
231 Ok((nonce, salt, iterations))
232}
233
234fn calculate_client_proof(
236 password: &str,
237 salt: &[u8],
238 iterations: u32,
239 auth_message: &[u8],
240) -> Result<Vec<u8>, ScramError> {
241 let password_bytes = password.as_bytes();
243 let mut salted_password = vec![0u8; 32]; let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
245
246 let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
248 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
249 client_key_hmac.update(b"Client Key");
250 let client_key = client_key_hmac.finalize().into_bytes();
251
252 let stored_key = Sha256::digest(client_key.to_vec().as_slice());
254
255 let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
257 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
258 client_sig_hmac.update(auth_message);
259 let client_signature = client_sig_hmac.finalize().into_bytes();
260
261 let mut proof = client_key.to_vec();
263 for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
264 *proof_byte ^= sig_byte;
265 }
266
267 Ok(proof.clone())
268}
269
270fn calculate_server_key(
272 password: &str,
273 salt: &[u8],
274 iterations: u32,
275) -> Result<Vec<u8>, ScramError> {
276 let password_bytes = password.as_bytes();
278 let mut salted_password = vec![0u8; 32];
279 let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
280
281 let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
283 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
284 server_key_hmac.update(b"Server Key");
285
286 Ok(server_key_hmac.finalize().into_bytes().to_vec())
287}
288
289fn calculate_server_signature(
291 server_key: &[u8],
292 auth_message: &[u8],
293) -> Result<Vec<u8>, ScramError> {
294 let mut hmac = HmacSha256::new_from_slice(server_key)
295 .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))?;
296 hmac.update(auth_message);
297 Ok(hmac.finalize().into_bytes().to_vec())
298}
299
300pub(crate) fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
304 use subtle::ConstantTimeEq;
305 a.ct_eq(b).into()
306}
307
308#[cfg(test)]
309mod tests;