1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::Rng;
10use sha2::{Digest, Sha256};
11use std::fmt;
12
13type HmacSha256 = Hmac<Sha256>;
14
/// Errors produced by the client side of the SCRAM exchange.
///
/// Every variant carries a human-readable detail string that is appended to
/// the variant's fixed prefix when the error is displayed.
#[derive(Clone, Debug)]
pub enum ScramError {
    /// The signature sent by the server did not match the one computed locally.
    InvalidServerProof(String),
    /// A server message was malformed or failed a consistency check.
    InvalidServerMessage(String),
    /// Displayed as a UTF-8 error; within this file it is raised by the
    /// key-derivation helpers with an "HMAC key error" detail.
    Utf8Error(String),
    /// A base64-encoded field could not be decoded.
    Base64Error(String),
}
27
28impl fmt::Display for ScramError {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 ScramError::InvalidServerProof(msg) => write!(f, "invalid server proof: {}", msg),
32 ScramError::InvalidServerMessage(msg) => write!(f, "invalid server message: {}", msg),
33 ScramError::Utf8Error(msg) => write!(f, "UTF-8 error: {}", msg),
34 ScramError::Base64Error(msg) => write!(f, "Base64 error: {}", msg),
35 }
36 }
37}
38
39impl std::error::Error for ScramError {}
40
/// State produced by `ScramClient::client_final` and consumed by
/// `ScramClient::verify_server_final` to check the server's signature.
#[derive(Clone)]
pub struct ScramState {
    /// Bytes of the auth message that both client proof and server signature cover.
    auth_message: Vec<u8>,
    /// HMAC key used to recompute the expected server signature.
    server_key: Vec<u8>,
}

/// Hand-written so that key material never ends up in logs: the derived
/// implementation would print every byte of `server_key`.
impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            .field("auth_message", &self.auth_message)
            .field("server_key", &"<redacted>")
            .finish()
    }
}
49
/// Client side of the SCRAM exchange: holds the credentials and the client
/// nonce for one authentication attempt.
pub struct ScramClient {
    /// Name sent to the server in the first client message.
    username: String,
    /// Secret from which the salted password is derived.
    password: String,
    /// Client nonce; the server's nonce must start with this value.
    nonce: String,
}
56
57impl ScramClient {
58 pub fn new(username: String, password: String) -> Self {
60 let mut rng = rand::thread_rng();
62 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
63 let nonce = BASE64.encode(&nonce_bytes);
64
65 Self {
66 username,
67 password,
68 nonce,
69 }
70 }
71
72 pub fn client_first(&self) -> String {
74 format!("n,a={},r={}", self.username, self.nonce)
77 }
78
79 pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
83 let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
85
86 if !server_nonce.starts_with(&self.nonce) {
88 return Err(ScramError::InvalidServerMessage(
89 "server nonce doesn't contain client nonce".to_string(),
90 ));
91 }
92
93 let salt_bytes = BASE64
95 .decode(&salt)
96 .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
97 let iterations = iterations
98 .parse::<u32>()
99 .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
100
101 let channel_binding = BASE64.encode(b"n,,");
103
104 let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
106
107 let client_first_bare = format!("a={},r={}", self.username, self.nonce);
109 let auth_message = format!(
110 "{},{},{}",
111 client_first_bare, server_first, client_final_without_proof
112 );
113
114 let proof = calculate_client_proof(
116 &self.password,
117 &salt_bytes,
118 iterations,
119 auth_message.as_bytes(),
120 )?;
121
122 let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
124
125 let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
127
128 let state = ScramState {
129 auth_message: auth_message.into_bytes(),
130 server_key,
131 };
132
133 Ok((client_final, state))
134 }
135
136 pub fn verify_server_final(
138 &self,
139 server_final: &str,
140 state: &ScramState,
141 ) -> Result<(), ScramError> {
142 let server_sig_encoded = server_final
144 .strip_prefix("v=")
145 .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
146
147 let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
148 ScramError::Base64Error("invalid server signature encoding".to_string())
149 })?;
150
151 let expected_signature = calculate_server_signature(&state.server_key, &state.auth_message);
153
154 if constant_time_compare(&server_signature, &expected_signature) {
156 Ok(())
157 } else {
158 Err(ScramError::InvalidServerProof(
159 "server signature verification failed".to_string(),
160 ))
161 }
162 }
163}
164
165fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
167 let mut nonce = String::new();
168 let mut salt = String::new();
169 let mut iterations = String::new();
170
171 for part in msg.split(',') {
172 if let Some(value) = part.strip_prefix("r=") {
173 nonce = value.to_string();
174 } else if let Some(value) = part.strip_prefix("s=") {
175 salt = value.to_string();
176 } else if let Some(value) = part.strip_prefix("i=") {
177 iterations = value.to_string();
178 }
179 }
180
181 if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
182 return Err(ScramError::InvalidServerMessage(
183 "missing required fields in server first message".to_string(),
184 ));
185 }
186
187 Ok((nonce, salt, iterations))
188}
189
190fn calculate_client_proof(
192 password: &str,
193 salt: &[u8],
194 iterations: u32,
195 auth_message: &[u8],
196) -> Result<Vec<u8>, ScramError> {
197 let password_bytes = password.as_bytes();
199 let mut salted_password = vec![0u8; 32]; let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
201
202 let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
204 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
205 client_key_hmac.update(b"Client Key");
206 let client_key = client_key_hmac.finalize().into_bytes();
207
208 let stored_key = Sha256::digest(client_key.to_vec().as_slice());
210
211 let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
213 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
214 client_sig_hmac.update(auth_message);
215 let client_signature = client_sig_hmac.finalize().into_bytes();
216
217 let mut proof = client_key.to_vec();
219 for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
220 *proof_byte ^= sig_byte;
221 }
222
223 Ok(proof.to_vec())
224}
225
226fn calculate_server_key(
228 password: &str,
229 salt: &[u8],
230 iterations: u32,
231) -> Result<Vec<u8>, ScramError> {
232 let password_bytes = password.as_bytes();
234 let mut salted_password = vec![0u8; 32];
235 let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
236
237 let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
239 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
240 server_key_hmac.update(b"Server Key");
241
242 Ok(server_key_hmac.finalize().into_bytes().to_vec())
243}
244
245fn calculate_server_signature(server_key: &[u8], auth_message: &[u8]) -> Vec<u8> {
247 let mut hmac = HmacSha256::new_from_slice(server_key).expect("HMAC key should be valid");
248 hmac.update(auth_message);
249 hmac.finalize().into_bytes().to_vec()
250}
251
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different length compare unequal immediately. Otherwise the XOR
/// of every byte pair is OR-ed into one accumulator, which is zero only when
/// all pairs match.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let difference = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));

    difference == 0
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a server-first message whose nonce extends the client's nonce.
    fn server_first_for(client: &ScramClient) -> String {
        let server_nonce = format!("{}server_nonce_part", client.nonce);
        format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"))
    }

    #[test]
    fn test_scram_client_creation() {
        let client = ScramClient::new("user".to_string(), "password".to_string());
        assert_eq!(client.username, "user");
        assert_eq!(client.password, "password");
        assert!(!client.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let client = ScramClient::new("alice".to_string(), "secret".to_string());
        let first = client.client_first();

        assert!(first.starts_with("n,a=alice,r="));
        assert!(first.len() > 20);
    }

    #[test]
    fn test_parse_server_first_valid() {
        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "client_nonce_server_nonce");
        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
        assert_eq!(iterations, "4096");
    }

    #[test]
    fn test_parse_server_first_invalid() {
        // The iteration count is missing.
        let server_first = "r=nonce,s=salt";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = b"test_value";
        let b_arr = b"test_value";
        assert!(constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different() {
        let a = b"test_value";
        let b_arr = b"test_wrong";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = b"test";
        let b_arr = b"test_longer";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        let server_first = server_first_for(&client);

        let result = client.client_final(&server_first);
        assert!(result.is_ok());

        let (client_final, state) = result.unwrap();
        assert!(client_final.starts_with("c="));
        assert!(!state.auth_message.is_empty());
    }

    #[test]
    fn test_client_final_rejects_foreign_nonce() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());

        // '-' never occurs in a standard-base64 nonce, so this cannot be a prefix match.
        let server_first = format!("r=not-the-client-nonce,s={},i=4096", BASE64.encode(b"salty"));

        assert!(matches!(
            client.client_final(&server_first),
            Err(ScramError::InvalidServerMessage(_))
        ));
    }

    #[test]
    fn test_verify_server_final_round_trip() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let server_first = server_first_for(&client);
        let (_client_final, state) = client.client_final(&server_first).unwrap();

        // A server that knows the key produces exactly this signature.
        let signature = calculate_server_signature(&state.server_key, &state.auth_message);
        let server_final = format!("v={}", BASE64.encode(&signature));
        assert!(client.verify_server_final(&server_final, &state).is_ok());

        // A well-formed message carrying the wrong signature must be rejected.
        let forged = format!("v={}", BASE64.encode([0u8; 32]));
        assert!(matches!(
            client.verify_server_final(&forged, &state),
            Err(ScramError::InvalidServerProof(_))
        ));

        // A message without the verifier attribute is malformed.
        assert!(matches!(
            client.verify_server_final("x=unexpected", &state),
            Err(ScramError::InvalidServerMessage(_))
        ));
    }
}