1use std::collections::HashMap;
16
17use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
18
19use crate::backend::auth::{hmac_sha256, pbkdf2_hmac_sha256, sha256};
20
/// Server-side SCRAM-SHA-256 credentials for one user.
///
/// Bundles the key-derivation parameters (`salt`, `iterations`) with the two
/// 32-byte keys the server keeps instead of the password itself.
#[derive(Debug, Clone)]
pub struct ScramVerifier {
    /// Salt bytes fed to the password key-derivation step.
    pub salt: Vec<u8>,
    /// Iteration count fed to the password key-derivation step.
    pub iterations: u32,
    /// SHA-256 digest of the client key; a client proof is accepted when the
    /// key recovered from it hashes to this value.
    pub stored_key: [u8; 32],
    /// HMAC key used to compute the server signature sent back to the client.
    pub server_key: [u8; 32],
}
30
31impl ScramVerifier {
32 pub fn from_password(password: &str, salt: Vec<u8>, iterations: u32) -> Self {
36 let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
37 let client_key = hmac_sha256(&salted, b"Client Key");
38 let stored_key = sha256(&client_key);
39 let server_key = hmac_sha256(&salted, b"Server Key");
40 Self {
41 salt,
42 iterations,
43 stored_key,
44 server_key,
45 }
46 }
47
48 pub fn parse(s: &str) -> Option<Self> {
52 let rest = s.strip_prefix("SCRAM-SHA-256$")?;
53 let (params, keys) = rest.split_once('$')?;
54 let (iter_str, salt_b64) = params.split_once(':')?;
55 let (stored_b64, server_b64) = keys.split_once(':')?;
56 let iterations: u32 = iter_str.parse().ok()?;
57 let salt = BASE64.decode(salt_b64.trim()).ok()?;
58 let stored = BASE64.decode(stored_b64.trim()).ok()?;
59 let server = BASE64.decode(server_b64.trim()).ok()?;
60 if stored.len() != 32 || server.len() != 32 {
61 return None;
62 }
63 let mut stored_key = [0u8; 32];
64 stored_key.copy_from_slice(&stored);
65 let mut server_key = [0u8; 32];
66 server_key.copy_from_slice(&server);
67 Some(Self {
68 salt,
69 iterations,
70 stored_key,
71 server_key,
72 })
73 }
74}
75
76#[derive(Debug, Clone, Default)]
83pub struct AuthFile {
84 users: HashMap<String, ScramVerifier>,
85}
86
87impl AuthFile {
88 pub fn load(path: &str) -> Result<Self, String> {
89 let data = std::fs::read_to_string(path)
90 .map_err(|e| format!("reading auth_file {}: {}", path, e))?;
91 Self::parse_str(&data, path)
92 }
93
94 pub fn parse_str(data: &str, path: &str) -> Result<Self, String> {
95 let mut users = HashMap::new();
96 for (lineno, raw) in data.lines().enumerate() {
97 let line = raw.trim();
98 if line.is_empty() || line.starts_with('#') {
99 continue;
100 }
101 let (user, secret) = line
102 .split_once(':')
103 .ok_or_else(|| format!("{}:{}: expected `user:secret`", path, lineno + 1))?;
104 let user = unquote(user.trim());
105 let secret = unquote(secret.trim());
106 let verifier = if secret.starts_with("SCRAM-SHA-256$") {
107 ScramVerifier::parse(&secret)
108 .ok_or_else(|| format!("{}:{}: malformed SCRAM verifier", path, lineno + 1))?
109 } else {
110 let salt = sha256(user.as_bytes())[..16].to_vec();
114 ScramVerifier::from_password(&secret, salt, 4096)
115 };
116 users.insert(user, verifier);
117 }
118 Ok(Self { users })
119 }
120
121 pub fn get(&self, user: &str) -> Option<&ScramVerifier> {
122 self.users.get(user)
123 }
124
125 pub fn is_empty(&self) -> bool {
126 self.users.is_empty()
127 }
128}
129
/// Trims surrounding whitespace and, when the remainder both starts and ends
/// with a double quote (two distinct characters), removes that one pair.
///
/// Anything else is returned trimmed but otherwise untouched; whitespace
/// inside the quotes is preserved.
fn unquote(s: &str) -> String {
    let trimmed = s.trim();
    trimmed
        .strip_prefix('"')
        .and_then(|inner| inner.strip_suffix('"'))
        .unwrap_or(trimmed)
        .to_string()
}
138
139pub struct ScramServer {
141 verifier: ScramVerifier,
142 combined_nonce: String,
143 client_first_bare: String,
144 server_first: String,
145}
146
147impl ScramServer {
148 pub fn start(
153 verifier: ScramVerifier,
154 client_first: &str,
155 server_nonce: &str,
156 ) -> Result<(Self, String), String> {
157 let mut parts = client_first.splitn(3, ',');
160 let _gs2_cbind = parts.next();
161 let _gs2_authzid = parts.next();
162 let bare = parts
163 .next()
164 .ok_or_else(|| "malformed client-first (no bare part)".to_string())?;
165
166 let client_nonce = bare
167 .split(',')
168 .find_map(|f| f.strip_prefix("r="))
169 .ok_or_else(|| "client-first missing r=".to_string())?;
170 if client_nonce.is_empty() {
171 return Err("empty client nonce".to_string());
172 }
173
174 let combined_nonce = format!("{}{}", client_nonce, server_nonce);
175 let salt_b64 = BASE64.encode(&verifier.salt);
176 let server_first = format!(
177 "r={},s={},i={}",
178 combined_nonce, salt_b64, verifier.iterations
179 );
180
181 Ok((
182 Self {
183 verifier,
184 combined_nonce,
185 client_first_bare: bare.to_string(),
186 server_first: server_first.clone(),
187 },
188 server_first,
189 ))
190 }
191
192 pub fn finish(&self, client_final: &str) -> Result<String, String> {
196 let proof_pos = client_final
198 .rfind(",p=")
199 .ok_or_else(|| "client-final missing p=".to_string())?;
200 let without_proof = &client_final[..proof_pos];
201 let proof_b64 = &client_final[proof_pos + 3..];
202
203 let echoed_nonce = without_proof
205 .split(',')
206 .find_map(|f| f.strip_prefix("r="))
207 .ok_or_else(|| "client-final missing r=".to_string())?;
208 if echoed_nonce != self.combined_nonce {
209 return Err("nonce mismatch".to_string());
210 }
211
212 let proof = BASE64
213 .decode(proof_b64.trim())
214 .map_err(|e| format!("bad proof base64: {}", e))?;
215 if proof.len() != 32 {
216 return Err("proof wrong length".to_string());
217 }
218
219 let auth_message = format!(
220 "{},{},{}",
221 self.client_first_bare, self.server_first, without_proof
222 );
223
224 let client_signature = hmac_sha256(&self.verifier.stored_key, auth_message.as_bytes());
228 let mut client_key = [0u8; 32];
229 for i in 0..32 {
230 client_key[i] = proof[i] ^ client_signature[i];
231 }
232 let derived_stored = sha256(&client_key);
233 if !constant_time_eq(&derived_stored, &self.verifier.stored_key) {
234 return Err("authentication failed (proof mismatch)".to_string());
235 }
236
237 let server_signature = hmac_sha256(&self.verifier.server_key, auth_message.as_bytes());
239 Ok(format!("v={}", BASE64.encode(server_signature)))
240 }
241}
242
/// Compares two 32-byte arrays for equality.
///
/// Every byte pair is XOR-ed and OR-ed into one accumulator, so the loop
/// always visits all 32 positions instead of stopping at the first mismatch.
fn constant_time_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
    let accumulated = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    accumulated == 0
}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::auth::Scram;

    /// Pulls the SCRAM client-first text out of the client's initial payload:
    /// skips everything up to and including the first NUL byte, then the
    /// four bytes that follow it.
    fn client_first_text(payload: &[u8]) -> &str {
        let nul_at = payload.iter().position(|&b| b == 0).unwrap();
        let body_start = nul_at + 1 + 4;
        std::str::from_utf8(&payload[body_start..]).unwrap()
    }

    /// A verifier rendered to text must parse back to the same fields.
    #[test]
    fn parse_pg_verifier_roundtrips_from_password() {
        let original = ScramVerifier::from_password("s3cret", b"0123456789abcdef".to_vec(), 4096);
        let encoded = format!(
            "SCRAM-SHA-256${}:{}${}:{}",
            original.iterations,
            BASE64.encode(&original.salt),
            BASE64.encode(original.stored_key),
            BASE64.encode(original.server_key),
        );

        let parsed = ScramVerifier::parse(&encoded).expect("parses");

        assert_eq!(parsed.iterations, original.iterations);
        assert_eq!(parsed.salt, original.salt);
        assert_eq!(parsed.stored_key, original.stored_key);
        assert_eq!(parsed.server_key, original.server_key);
    }

    /// Runs the crate's SCRAM client against `ScramServer` end to end.
    #[test]
    fn full_scram_handshake_client_vs_server() {
        let password = "correct horse battery staple";
        let verifier = ScramVerifier::from_password(password, b"saltsaltsaltsalt".to_vec(), 4096);

        let (mut client, init) = Scram::client_first("clientNONCE123");
        let client_first = client_first_text(&init.0);

        let (server, server_first) =
            ScramServer::start(verifier.clone(), client_first, "serverNONCE456").unwrap();

        let client_final = client
            .client_final(server_first.as_bytes(), password)
            .unwrap();
        let client_final_text = std::str::from_utf8(&client_final.0).unwrap();
        let server_final = server.finish(client_final_text).unwrap();

        client.verify_server(server_final.as_bytes()).unwrap();
    }

    /// A proof computed from the wrong password must not be accepted.
    #[test]
    fn wrong_password_is_rejected() {
        let verifier = ScramVerifier::from_password("rightpw", b"saltsaltsaltsalt".to_vec(), 4096);

        let (mut client, init) = Scram::client_first("nonceAAA");
        let client_first = client_first_text(&init.0);

        let (server, server_first) =
            ScramServer::start(verifier, client_first, "nonceBBB").unwrap();

        let client_final = client
            .client_final(server_first.as_bytes(), "wrongpw")
            .unwrap();
        let client_final_text = std::str::from_utf8(&client_final.0).unwrap();

        let res = server.finish(client_final_text);
        assert!(res.is_err(), "wrong password must be rejected");
    }

    /// Comments, blank lines, cleartext, quoted and verifier entries all load.
    #[test]
    fn auth_file_parses_plaintext_and_verifier() {
        let carol = ScramVerifier::from_password("pw", b"0123456789abcdef".to_vec(), 4096);
        let verifier_line = format!(
            "carol:SCRAM-SHA-256${}:{}${}:{}",
            carol.iterations,
            BASE64.encode(&carol.salt),
            BASE64.encode(carol.stored_key),
            BASE64.encode(carol.server_key),
        );
        let body = format!(
            "# comment\nalice:secret\n\nbob:\"quoted\"\n{}\n",
            verifier_line
        );

        let auth = AuthFile::parse_str(&body, "test").unwrap();

        for present in ["alice", "bob", "carol"] {
            assert!(auth.get(present).is_some());
        }
        assert!(auth.get("dave").is_none());
    }
}