1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
7use hmac::{Hmac, Mac};
8use pbkdf2::pbkdf2;
9use rand::Rng;
10use sha2::{Digest, Sha256};
11use std::fmt;
12
/// HMAC-SHA-256: the keyed PRF used throughout this module, both as the PBKDF2
/// PRF when salting the password and for the ClientKey, ServerKey and
/// client/server signature computations.
type HmacSha256 = Hmac<Sha256>;
14
/// Errors produced while running the client side of a SCRAM exchange.
///
/// Every variant carries a short human-readable detail string. The type is
/// comparable (`PartialEq`/`Eq`) so callers can match on exact errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScramError {
    /// The server-final signature did not match the locally computed one.
    InvalidServerProof(String),
    /// A server message was malformed or failed a protocol check
    /// (missing fields, foreign nonce, bad iteration count, missing `v=`).
    InvalidServerMessage(String),
    /// Displayed as `UTF-8 error: <detail>`.
    Utf8Error(String),
    /// A server-supplied Base64 field (salt or signature) failed to decode.
    Base64Error(String),
}
27
28impl fmt::Display for ScramError {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 match self {
31 ScramError::InvalidServerProof(msg) => write!(f, "invalid server proof: {}", msg),
32 ScramError::InvalidServerMessage(msg) => write!(f, "invalid server message: {}", msg),
33 ScramError::Utf8Error(msg) => write!(f, "UTF-8 error: {}", msg),
34 ScramError::Base64Error(msg) => write!(f, "Base64 error: {}", msg),
35 }
36 }
37}
38
39impl std::error::Error for ScramError {}
40
/// Per-exchange state produced by `ScramClient::client_final` and consumed by
/// `ScramClient::verify_server_final`.
///
/// `Debug` is implemented by hand so the ServerKey is never printed. Anyone
/// holding it can compute HMAC(ServerKey, AuthMessage) and so forge a server
/// signature that this client would accept.
#[derive(Clone)]
pub struct ScramState {
    /// AuthMessage: client-first-bare, server-first and
    /// client-final-without-proof joined by commas.
    auth_message: Vec<u8>,
    /// ServerKey := HMAC(SaltedPassword, "Server Key"). Secret.
    server_key: Vec<u8>,
}

impl fmt::Debug for ScramState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ScramState")
            .field("auth_message", &self.auth_message)
            .field("server_key", &"<redacted>")
            .finish()
    }
}
49
/// Client side of a SCRAM-SHA-256 authentication exchange.
///
/// Create one with `ScramClient::new`, send `client_first()`, feed the server's
/// reply to `client_final`, and finally check the server's answer with
/// `verify_server_final`.
pub struct ScramClient {
    /// Username placed in the `n=` attribute of the client-first message.
    username: String,
    /// Plaintext password; the input to PBKDF2 when deriving the keys.
    password: String,
    /// Base64-encoded random client nonce generated in `new`.
    nonce: String,
}
56
57impl ScramClient {
58 pub fn new(username: String, password: String) -> Self {
60 let mut rng = rand::thread_rng();
62 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
63 let nonce = BASE64.encode(&nonce_bytes);
64
65 Self {
66 username,
67 password,
68 nonce,
69 }
70 }
71
72 pub fn client_first(&self) -> String {
74 format!("n,,n={},r={}", self.username, self.nonce)
79 }
80
81 pub fn client_final(&mut self, server_first: &str) -> Result<(String, ScramState), ScramError> {
85 let (server_nonce, salt, iterations) = parse_server_first(server_first)?;
87
88 if !server_nonce.starts_with(&self.nonce) {
90 return Err(ScramError::InvalidServerMessage(
91 "server nonce doesn't contain client nonce".to_string(),
92 ));
93 }
94
95 let salt_bytes = BASE64
97 .decode(&salt)
98 .map_err(|_| ScramError::Base64Error("invalid salt encoding".to_string()))?;
99 let iterations = iterations
100 .parse::<u32>()
101 .map_err(|_| ScramError::InvalidServerMessage("invalid iteration count".to_string()))?;
102
103 let channel_binding = BASE64.encode(b"n,,");
105
106 let client_final_without_proof = format!("c={},r={}", channel_binding, server_nonce);
108
109 let client_first_bare = format!("n={},r={}", self.username, self.nonce);
112 let auth_message = format!(
113 "{},{},{}",
114 client_first_bare, server_first, client_final_without_proof
115 );
116
117 let proof = calculate_client_proof(
119 &self.password,
120 &salt_bytes,
121 iterations,
122 auth_message.as_bytes(),
123 )?;
124
125 let server_key = calculate_server_key(&self.password, &salt_bytes, iterations)?;
127
128 let client_final = format!("{},p={}", client_final_without_proof, BASE64.encode(&proof));
130
131 let state = ScramState {
132 auth_message: auth_message.into_bytes(),
133 server_key,
134 };
135
136 Ok((client_final, state))
137 }
138
139 pub fn verify_server_final(
141 &self,
142 server_final: &str,
143 state: &ScramState,
144 ) -> Result<(), ScramError> {
145 let server_sig_encoded = server_final
147 .strip_prefix("v=")
148 .ok_or_else(|| ScramError::InvalidServerMessage("missing 'v=' prefix".to_string()))?;
149
150 let server_signature = BASE64.decode(server_sig_encoded).map_err(|_| {
151 ScramError::Base64Error("invalid server signature encoding".to_string())
152 })?;
153
154 let expected_signature = calculate_server_signature(&state.server_key, &state.auth_message);
156
157 if constant_time_compare(&server_signature, &expected_signature) {
159 Ok(())
160 } else {
161 Err(ScramError::InvalidServerProof(
162 "server signature verification failed".to_string(),
163 ))
164 }
165 }
166}
167
168fn parse_server_first(msg: &str) -> Result<(String, String, String), ScramError> {
170 let mut nonce = String::new();
171 let mut salt = String::new();
172 let mut iterations = String::new();
173
174 for part in msg.split(',') {
175 if let Some(value) = part.strip_prefix("r=") {
176 nonce = value.to_string();
177 } else if let Some(value) = part.strip_prefix("s=") {
178 salt = value.to_string();
179 } else if let Some(value) = part.strip_prefix("i=") {
180 iterations = value.to_string();
181 }
182 }
183
184 if nonce.is_empty() || salt.is_empty() || iterations.is_empty() {
185 return Err(ScramError::InvalidServerMessage(
186 "missing required fields in server first message".to_string(),
187 ));
188 }
189
190 Ok((nonce, salt, iterations))
191}
192
193fn calculate_client_proof(
195 password: &str,
196 salt: &[u8],
197 iterations: u32,
198 auth_message: &[u8],
199) -> Result<Vec<u8>, ScramError> {
200 let password_bytes = password.as_bytes();
202 let mut salted_password = vec![0u8; 32]; let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
204
205 let mut client_key_hmac = HmacSha256::new_from_slice(&salted_password)
207 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
208 client_key_hmac.update(b"Client Key");
209 let client_key = client_key_hmac.finalize().into_bytes();
210
211 let stored_key = Sha256::digest(client_key.to_vec().as_slice());
213
214 let mut client_sig_hmac = HmacSha256::new_from_slice(&stored_key)
216 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
217 client_sig_hmac.update(auth_message);
218 let client_signature = client_sig_hmac.finalize().into_bytes();
219
220 let mut proof = client_key.to_vec();
222 for (proof_byte, sig_byte) in proof.iter_mut().zip(client_signature.iter()) {
223 *proof_byte ^= sig_byte;
224 }
225
226 Ok(proof.to_vec())
227}
228
229fn calculate_server_key(
231 password: &str,
232 salt: &[u8],
233 iterations: u32,
234) -> Result<Vec<u8>, ScramError> {
235 let password_bytes = password.as_bytes();
237 let mut salted_password = vec![0u8; 32];
238 let _ = pbkdf2::<HmacSha256>(password_bytes, salt, iterations, &mut salted_password);
239
240 let mut server_key_hmac = HmacSha256::new_from_slice(&salted_password)
242 .map_err(|_| ScramError::Utf8Error("HMAC key error".to_string()))?;
243 server_key_hmac.update(b"Server Key");
244
245 Ok(server_key_hmac.finalize().into_bytes().to_vec())
246}
247
248fn calculate_server_signature(server_key: &[u8], auth_message: &[u8]) -> Vec<u8> {
250 let mut hmac = HmacSha256::new_from_slice(server_key).expect("HMAC key should be valid");
251 hmac.update(auth_message);
252 hmac.finalize().into_bytes().to_vec()
253}
254
/// Returns `true` when `a` and `b` hold the same bytes.
///
/// Slices of different lengths are rejected immediately. Otherwise the XOR of
/// every byte pair is OR-ed into one accumulator instead of returning at the
/// first mismatch, so the loop always visits every byte.
fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len() && a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scram_client_creation() {
        let client = ScramClient::new("user".to_string(), "password".to_string());
        assert_eq!(client.username, "user");
        assert_eq!(client.password, "password");
        assert!(!client.nonce.is_empty());
    }

    #[test]
    fn test_client_first_message_format() {
        let client = ScramClient::new("alice".to_string(), "secret".to_string());
        let first = client.client_first();

        assert!(first.starts_with("n,,n=alice,r="));
        assert!(first.len() > 20);
    }

    #[test]
    fn test_parse_server_first_valid() {
        let server_first = "r=client_nonce_server_nonce,s=aW1hZ2luYXJ5c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "client_nonce_server_nonce");
        assert_eq!(salt, "aW1hZ2luYXJ5c2FsdA==");
        assert_eq!(iterations, "4096");
    }

    #[test]
    fn test_parse_server_first_invalid() {
        let server_first = "r=nonce,s=salt";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_constant_time_compare_equal() {
        let a = b"test_value";
        let b_arr = b"test_value";
        assert!(constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different() {
        let a = b"test_value";
        let b_arr = b"test_wrong";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_constant_time_compare_different_length() {
        let a = b"test";
        let b_arr = b"test_longer";
        assert!(!constant_time_compare(a, b_arr));
    }

    #[test]
    fn test_scram_client_final_flow() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let _client_first = client.client_first();

        let server_nonce = format!("{}server_nonce_part", client.nonce);
        let server_first = format!("r={},s={},i=4096", server_nonce, BASE64.encode(b"salty"));

        let result = client.client_final(&server_first);
        assert!(result.is_ok());

        let (client_final, state) = result.unwrap();
        assert!(client_final.starts_with("c="));
        assert!(!state.auth_message.is_empty());
    }

    /// Replays the SCRAM-SHA-256 example exchange from RFC 7677 §3 and checks
    /// every message byte-for-byte. This pins the whole key-derivation chain
    /// (PBKDF2, ClientKey, StoredKey, proof, ServerKey, server signature).
    #[test]
    fn test_rfc7677_example_exchange() {
        let mut client = ScramClient {
            username: "user".to_string(),
            password: "pencil".to_string(),
            nonce: "rOprNGfwEbeRWgbNEkqO".to_string(),
        };
        assert_eq!(client.client_first(), "n,,n=user,r=rOprNGfwEbeRWgbNEkqO");

        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
                            s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
        let (client_final, state) = client.client_final(server_first).unwrap();
        assert_eq!(
            client_final,
            "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
             p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );

        assert!(client
            .verify_server_final("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=", &state)
            .is_ok());
    }

    #[test]
    fn test_client_final_rejects_foreign_nonce() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        // '_' is not in the Base64 alphabet, so this can never extend our nonce.
        let server_first = format!("r=not_our_nonce,s={},i=16", BASE64.encode(b"salty"));

        assert!(matches!(
            client.client_final(&server_first),
            Err(ScramError::InvalidServerMessage(_))
        ));
    }

    #[test]
    fn test_verify_server_final_rejects_bad_messages() {
        let mut client = ScramClient::new("user".to_string(), "password".to_string());
        let server_first = format!("r={}srv,s={},i=16", client.nonce, BASE64.encode(b"salty"));
        let (_, state) = client.client_final(&server_first).unwrap();

        let forged = format!("v={}", BASE64.encode(&[0u8; 32]));
        assert!(matches!(
            client.verify_server_final(&forged, &state),
            Err(ScramError::InvalidServerProof(_))
        ));
        assert!(matches!(
            client.verify_server_final("garbage", &state),
            Err(ScramError::InvalidServerMessage(_))
        ));
        assert!(matches!(
            client.verify_server_final("v=!!not-base64!!", &state),
            Err(ScramError::Base64Error(_))
        ));
    }
}