1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::{rngs::OsRng, Rng};
10use sha2::{Digest, Sha256};
11use std::fmt;
12
/// HMAC instantiated with SHA-256; used below as the PBKDF2 PRF and for the
/// SCRAM key and signature computations.
type HmacSha256 = Hmac<Sha256>;
14
/// Errors produced while running the client side of a SCRAM exchange.
///
/// Every variant carries a human-readable detail message.
#[derive(Clone, Debug)]
pub enum ScramError {
    /// The server-final signature did not match the locally computed one.
    InvalidServerProof(String),
    /// A server message was malformed or violated the exchange's expectations.
    InvalidServerMessage(String),
    /// In the code visible here this is returned for HMAC keying failures.
    /// NOTE(review): the name suggests UTF-8 decoding — confirm intended use.
    Utf8Error(String),
    /// A Base64 field (salt or server signature) failed to decode.
    Base64Error(String),
}

impl fmt::Display for ScramError {
    /// Writes a category prefix, a colon, and the detail message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidServerProof(detail) => write!(f, "invalid server proof: {}", detail),
            Self::InvalidServerMessage(detail) => {
                write!(f, "invalid server message: {}", detail)
            }
            Self::Utf8Error(detail) => write!(f, "UTF-8 error: {}", detail),
            Self::Base64Error(detail) => write!(f, "Base64 error: {}", detail),
        }
    }
}

/// Marker impl so `ScramError` works with `?` into `Box<dyn Error>` and friends.
impl std::error::Error for ScramError {}
40
/// Data carried from [`ScramClient::client_final`] to
/// [`ScramClient::verify_server_final`].
///
/// Note that the derived `Debug` output includes the raw server key bytes.
#[derive(Clone, Debug)]
pub struct ScramState {
    /// The AuthMessage bytes the server signature is computed over.
    auth_message: Vec<u8>,
    /// The derived ServerKey used to recompute the expected server signature.
    server_key: Vec<u8>,
}
49
/// Client side of a SCRAM-SHA-256 authentication exchange.
///
/// Holds the credentials plus the client nonce chosen at construction time.
/// It deliberately has no `Debug` derive, so the password is not exposed
/// through formatting.
pub struct ScramClient {
    /// Username, sent (escaped) in the client-first message.
    username: String,
    /// Plain-text password from which the SCRAM keys are derived.
    password: String,
    /// Base64-encoded random client nonce.
    nonce: String,
}
56
57impl ScramClient {
58 pub fn new(username: String, password: String) -> Self {
60 let mut rng = OsRng;
62 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
63 let nonce = BASE64.encode(&nonce_bytes);
64
65 Self {
66 username,
67 password,
68 nonce,
69 }
70 }
71
72 pub fn client_first(&self) -> String {
74 let escaped_username = self.username.replace('=', "=3D").replace(',', "=2C");
79 format!("n,,n={},r={}", escaped_username, self.nonce)
80 }
81
82 pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
92 let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
94
95 if !server_nonce.starts_with(&self.nonce) {
97 return Err(ScramError::InvalidServerMessage(
98 "server nonce doesn't contain client nonce".to_string(),
99 ));
100 }
101
102 let salt_bytes = BASE64
104 .decode(&salt)
105 .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
106 let iterations = iterations
107 .parse::<u32>()
108 .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
109
110 let channel_binding = BASE64.encode(b"n,,");
112
113 let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
115
116 let client_first_bare = format!("n={},r={}", self.username, self.nonce);
119 let auth_message = format!(
120 "{},{},{}",
121 client_first_bare, server_first, client_final_without_proof
122 );
123
124 let proof = calculate_client_proof(
126 &self.password,
127 &salt_bytes,
128 iterations,
129 auth_message.as_bytes(),
130 )?;
131
132 let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
134
135 let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
137
138 let state = ScramState {
139 auth_message: auth_message.into_bytes(),
140 server_key,
141 };
142
143 Ok((client_final, state))
144 }
145
146 pub fn verify_server_final(
154 &self,
155 server_final: &str,
156 state: &ScramState,
157 ) -> Result<(), ScramError> {
158 let server_sig_encoded = server_final
160 .strip_prefix("v=")
161 .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
162
163 let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
164 ScramError::Base64Error("invalid server signature encoding".to_string())
165 })?;
166
167 let expected_signature =
169 calculate_server_signature(&state.server_key, &state.auth_message)?;
170
171 if constant_time_compare(&server_signature, &expected_signature) {
173 Ok(())
174 } else {
175 Err(ScramError::InvalidServerProof(
176 "server signature verification failed".to_string(),
177 ))
178 }
179 }
180}
181
182fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
184 let mut nonce = String::new();
185 let mut salt = String::new();
186 let mut iterations = String::new();
187
188 for part in msg.split(',') {
189 if let Some(value) = part.strip_prefix("r=") {
190 nonce = value.to_string();
191 } else if let Some(value) = part.strip_prefix("s=") {
192 salt = value.to_string();
193 } else if let Some(value) = part.strip_prefix("i=") {
194 iterations = value.to_string();
195 }
196 }
197
198 if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
199 return Err(ScramError::InvalidServerMessage(
200 "missing required fields in server first message".to_string(),
201 ));
202 }
203
204 Ok((nonce, salt, iterations))
205}
206
207fn calculate_client_proof(
209 password: &str,
210 salt: &[u8],
211 iterations: u32,
212 auth_message: &[u8],
213) -> Result<Vec<u8>, ScramError> {
214 let password_bytes = password.as_bytes();
216 let mut salted_password = vec![0u8; 32]; let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
218
219 let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
221 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
222 client_key_hmac.update(b"Client Key");
223 let client_key = client_key_hmac.finalize().into_bytes();
224
225 let stored_key = Sha256::digest(client_key.to_vec().as_slice());
227
228 let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
230 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
231 client_sig_hmac.update(auth_message);
232 let client_signature = client_sig_hmac.finalize().into_bytes();
233
234 let mut proof = client_key.to_vec();
236 for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
237 *proof_byte ^= sig_byte;
238 }
239
240 Ok(proof.clone())
241}
242
243fn calculate_server_key(
245 password: &str,
246 salt: &[u8],
247 iterations: u32,
248) -> Result<Vec<u8>, ScramError> {
249 let password_bytes = password.as_bytes();
251 let mut salted_password = vec![0u8; 32];
252 let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
253
254 let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
256 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
257 server_key_hmac.update(b"Server Key");
258
259 Ok(server_key_hmac.finalize().into_bytes().to_vec())
260}
261
262fn calculate_server_signature(
264 server_key: &[u8],
265 auth_message: &[u8],
266) -> Result<Vec<u8>, ScramError> {
267 let mut hmac = HmacSha256::new_from_slice(server_key)
268 .map_err(|_| ScramError::Utf8Error("invalid HMAC key for server signature".to_string()))?;
269 hmac.update(auth_message);
270 Ok(hmac.finalize().into_bytes().to_vec())
271}
272
273fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
277 use subtle::ConstantTimeEq;
278 a.ct_eq(b).into()
279}
280
#[cfg(test)]
mod tests {
    // A panic from `unwrap` is how these tests report an unexpected error, so
    // clippy's `unwrap_used` lint is allowed for this module. It is presumably
    // enabled for non-test code elsewhere — confirm.
    #![allow(clippy::unwrap_used)]

    use super::*;

    #[test]
    fn test_scram_client_creation() {
        let client = ScramClient::new("user".to_string(), "password".to_string());
        assert_eq!(client.username, "user");
        assert_eq!(client.password, "password");
        assert!(!client.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let client = ScramClient::new("alice".to_string(), "secret".to_string());
        let first = client.client_first();

        assert!(first.starts_with("n,,n=alice,r="));
        assert!(first.len() > 20);
    }

    #[test]
    fn test_client_first_escapes_username() {
        // '=' and ',' are reserved in SCRAM attribute values and must be escaped.
        let client = ScramClient::new("a=b,c".to_string(), "secret".to_string());
        let first = client.client_first();

        assert!(first.starts_with("n,,n=a=3Db=2Cc,r="));
    }

    #[test]
    fn test_parse_server_first_valid() {
        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "client_nonce_server_nonce");
        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
        assert_eq!(iterations, "4096");
    }

    #[test]
    fn test_parse_server_first_invalid() {
        let server_first = "r=nonce,s=salt";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = b"test_value";
        let b_arr = b"test_value";
        assert!(constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different() {
        let a = b"test_value";
        let b_arr = b"test_wrong";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = b"test";
        let b_arr = b"test_longer";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        // The server nonce must extend the client nonce.
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));

        let (client_final, state) = client
            .client_final(&server_first)
            .expect("client_final should accept a well-formed server-first message");

        assert!(client_final.starts_with("c="));
        assert!(client_final.contains(",p="));
        assert!(!state.auth_message.is_empty());
    }

    #[test]
    fn test_client_final_rejects_foreign_nonce() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let server_first = format!("r=not_our_nonce,s={},i=16", BASE64.encode(b"salty"));

        assert!(matches!(
            client.client_final(&server_first),
            Err(ScramError::InvalidServerMessage(_))
        ));
    }

    #[test]
    fn test_client_final_rejects_bad_iteration_count() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let server_first = format!(
            "r={}srv,s={},i=abc",
            client.nonce,
            BASE64.encode(b"salty")
        );

        assert!(matches!(
            client.client_final(&server_first),
            Err(ScramError::InvalidServerMessage(_))
        ));
    }

    #[test]
    fn test_verify_server_final_round_trip() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        // A small iteration count keeps the test fast.
        let server_first = format!("r={}srv,s={},i=16", client.nonce, BASE64.encode(b"salty"));
        let (_client_final, state) = client
            .client_final(&server_first)
            .expect("client_final should succeed");

        let signature = calculate_server_signature(&state.server_key, &state.auth_message)
            .expect("HMAC accepts the derived server key");
        let server_final = format!("v={}", BASE64.encode(&signature));
        assert!(client.verify_server_final(&server_final, &state).is_ok());

        let forged = format!("v={}", BASE64.encode(&[0u8; 32]));
        assert!(matches!(
            client.verify_server_final(&forged, &state),
            Err(ScramError::InvalidServerProof(_))
        ));

        assert!(matches!(
            client.verify_server_final("x=abc", &state),
            Err(ScramError::InvalidServerMessage(_))
        ));
    }
}