1use std::collections::HashMap;
16
17use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
18
19use crate::backend::auth::{hmac_sha256, pbkdf2_hmac_sha256, sha256};
20
21#[derive(Debug, Clone)]
24pub struct ScramVerifier {
25 pub salt: Vec<u8>,
26 pub iterations: u32,
27 pub stored_key: [u8; 32],
28 pub server_key: [u8; 32],
29}
30
31impl ScramVerifier {
32 pub fn from_password(password: &str, salt: Vec<u8>, iterations: u32) -> Self {
36 let salted = pbkdf2_hmac_sha256(password.as_bytes(), &salt, iterations);
37 let client_key = hmac_sha256(&salted, b"Client Key");
38 let stored_key = sha256(&client_key);
39 let server_key = hmac_sha256(&salted, b"Server Key");
40 Self { salt, iterations, stored_key, server_key }
41 }
42
43 pub fn parse(s: &str) -> Option<Self> {
47 let rest = s.strip_prefix("SCRAM-SHA-256$")?;
48 let (params, keys) = rest.split_once('$')?;
49 let (iter_str, salt_b64) = params.split_once(':')?;
50 let (stored_b64, server_b64) = keys.split_once(':')?;
51 let iterations: u32 = iter_str.parse().ok()?;
52 let salt = BASE64.decode(salt_b64.trim()).ok()?;
53 let stored = BASE64.decode(stored_b64.trim()).ok()?;
54 let server = BASE64.decode(server_b64.trim()).ok()?;
55 if stored.len() != 32 || server.len() != 32 {
56 return None;
57 }
58 let mut stored_key = [0u8; 32];
59 stored_key.copy_from_slice(&stored);
60 let mut server_key = [0u8; 32];
61 server_key.copy_from_slice(&server);
62 Some(Self { salt, iterations, stored_key, server_key })
63 }
64}
65
66#[derive(Debug, Clone, Default)]
73pub struct AuthFile {
74 users: HashMap<String, ScramVerifier>,
75}
76
77impl AuthFile {
78 pub fn load(path: &str) -> Result<Self, String> {
79 let data = std::fs::read_to_string(path)
80 .map_err(|e| format!("reading auth_file {}: {}", path, e))?;
81 Self::parse_str(&data, path)
82 }
83
84 pub fn parse_str(data: &str, path: &str) -> Result<Self, String> {
85 let mut users = HashMap::new();
86 for (lineno, raw) in data.lines().enumerate() {
87 let line = raw.trim();
88 if line.is_empty() || line.starts_with('#') {
89 continue;
90 }
91 let (user, secret) = line
92 .split_once(':')
93 .ok_or_else(|| format!("{}:{}: expected `user:secret`", path, lineno + 1))?;
94 let user = unquote(user.trim());
95 let secret = unquote(secret.trim());
96 let verifier = if secret.starts_with("SCRAM-SHA-256$") {
97 ScramVerifier::parse(&secret).ok_or_else(|| {
98 format!("{}:{}: malformed SCRAM verifier", path, lineno + 1)
99 })?
100 } else {
101 let salt = sha256(user.as_bytes())[..16].to_vec();
105 ScramVerifier::from_password(&secret, salt, 4096)
106 };
107 users.insert(user, verifier);
108 }
109 Ok(Self { users })
110 }
111
112 pub fn get(&self, user: &str) -> Option<&ScramVerifier> {
113 self.users.get(user)
114 }
115
116 pub fn is_empty(&self) -> bool {
117 self.users.is_empty()
118 }
119}
120
121fn unquote(s: &str) -> String {
122 let t = s.trim();
123 if t.len() >= 2 && t.starts_with('"') && t.ends_with('"') {
124 t[1..t.len() - 1].to_string()
125 } else {
126 t.to_string()
127 }
128}
129
130pub struct ScramServer {
132 verifier: ScramVerifier,
133 combined_nonce: String,
134 client_first_bare: String,
135 server_first: String,
136}
137
138impl ScramServer {
139 pub fn start(
144 verifier: ScramVerifier,
145 client_first: &str,
146 server_nonce: &str,
147 ) -> Result<(Self, String), String> {
148 let mut parts = client_first.splitn(3, ',');
151 let _gs2_cbind = parts.next();
152 let _gs2_authzid = parts.next();
153 let bare = parts
154 .next()
155 .ok_or_else(|| "malformed client-first (no bare part)".to_string())?;
156
157 let client_nonce = bare
158 .split(',')
159 .find_map(|f| f.strip_prefix("r="))
160 .ok_or_else(|| "client-first missing r=".to_string())?;
161 if client_nonce.is_empty() {
162 return Err("empty client nonce".to_string());
163 }
164
165 let combined_nonce = format!("{}{}", client_nonce, server_nonce);
166 let salt_b64 = BASE64.encode(&verifier.salt);
167 let server_first = format!("r={},s={},i={}", combined_nonce, salt_b64, verifier.iterations);
168
169 Ok((
170 Self {
171 verifier,
172 combined_nonce,
173 client_first_bare: bare.to_string(),
174 server_first: server_first.clone(),
175 },
176 server_first,
177 ))
178 }
179
180 pub fn finish(&self, client_final: &str) -> Result<String, String> {
184 let proof_pos = client_final
186 .rfind(",p=")
187 .ok_or_else(|| "client-final missing p=".to_string())?;
188 let without_proof = &client_final[..proof_pos];
189 let proof_b64 = &client_final[proof_pos + 3..];
190
191 let echoed_nonce = without_proof
193 .split(',')
194 .find_map(|f| f.strip_prefix("r="))
195 .ok_or_else(|| "client-final missing r=".to_string())?;
196 if echoed_nonce != self.combined_nonce {
197 return Err("nonce mismatch".to_string());
198 }
199
200 let proof = BASE64
201 .decode(proof_b64.trim())
202 .map_err(|e| format!("bad proof base64: {}", e))?;
203 if proof.len() != 32 {
204 return Err("proof wrong length".to_string());
205 }
206
207 let auth_message = format!(
208 "{},{},{}",
209 self.client_first_bare, self.server_first, without_proof
210 );
211
212 let client_signature = hmac_sha256(&self.verifier.stored_key, auth_message.as_bytes());
216 let mut client_key = [0u8; 32];
217 for i in 0..32 {
218 client_key[i] = proof[i] ^ client_signature[i];
219 }
220 let derived_stored = sha256(&client_key);
221 if !constant_time_eq(&derived_stored, &self.verifier.stored_key) {
222 return Err("authentication failed (proof mismatch)".to_string());
223 }
224
225 let server_signature = hmac_sha256(&self.verifier.server_key, auth_message.as_bytes());
227 Ok(format!("v={}", BASE64.encode(server_signature)))
228 }
229}
230
231fn constant_time_eq(a: &[u8; 32], b: &[u8; 32]) -> bool {
232 let mut diff = 0u8;
233 for i in 0..32 {
234 diff |= a[i] ^ b[i];
235 }
236 diff == 0
237}
238
239#[cfg(test)]
240mod tests {
241 use super::*;
242 use crate::backend::auth::Scram;
243
244 #[test]
245 fn parse_pg_verifier_roundtrips_from_password() {
246 let v = ScramVerifier::from_password("s3cret", b"0123456789abcdef".to_vec(), 4096);
248 let s = format!(
249 "SCRAM-SHA-256${}:{}${}:{}",
250 v.iterations,
251 BASE64.encode(&v.salt),
252 BASE64.encode(v.stored_key),
253 BASE64.encode(v.server_key),
254 );
255 let p = ScramVerifier::parse(&s).expect("parses");
256 assert_eq!(p.iterations, v.iterations);
257 assert_eq!(p.salt, v.salt);
258 assert_eq!(p.stored_key, v.stored_key);
259 assert_eq!(p.server_key, v.server_key);
260 }
261
262 #[test]
263 fn full_scram_handshake_client_vs_server() {
264 let password = "correct horse battery staple";
266 let verifier = ScramVerifier::from_password(password, b"saltsaltsaltsalt".to_vec(), 4096);
267
268 let (mut client, init) = Scram::client_first("clientNONCE123");
269 let data = &init.0;
272 let mech_end = data.iter().position(|&b| b == 0).unwrap() + 1;
273 let client_first = std::str::from_utf8(&data[mech_end + 4..]).unwrap();
274
275 let (server, server_first) =
276 ScramServer::start(verifier.clone(), client_first, "serverNONCE456").unwrap();
277
278 let client_final = client.client_final(server_first.as_bytes(), password).unwrap();
279 let server_final = server.finish(std::str::from_utf8(&client_final.0).unwrap()).unwrap();
280
281 client.verify_server(server_final.as_bytes()).unwrap();
283 }
284
285 #[test]
286 fn wrong_password_is_rejected() {
287 let verifier = ScramVerifier::from_password("rightpw", b"saltsaltsaltsalt".to_vec(), 4096);
288 let (mut client, init) = Scram::client_first("nonceAAA");
289 let data = &init.0;
290 let mech_end = data.iter().position(|&b| b == 0).unwrap() + 1;
291 let client_first = std::str::from_utf8(&data[mech_end + 4..]).unwrap();
292 let (server, server_first) =
293 ScramServer::start(verifier, client_first, "nonceBBB").unwrap();
294 let client_final = client.client_final(server_first.as_bytes(), "wrongpw").unwrap();
296 let res = server.finish(std::str::from_utf8(&client_final.0).unwrap());
297 assert!(res.is_err(), "wrong password must be rejected");
298 }
299
300 #[test]
301 fn auth_file_parses_plaintext_and_verifier() {
302 let v = ScramVerifier::from_password("pw", b"0123456789abcdef".to_vec(), 4096);
303 let verifier_line = format!(
304 "carol:SCRAM-SHA-256${}:{}${}:{}",
305 v.iterations,
306 BASE64.encode(&v.salt),
307 BASE64.encode(v.stored_key),
308 BASE64.encode(v.server_key),
309 );
310 let body = format!("# comment\nalice:secret\n\nbob:\"quoted\"\n{}\n", verifier_line);
311 let af = AuthFile::parse_str(&body, "test").unwrap();
312 assert!(af.get("alice").is_some());
313 assert!(af.get("bob").is_some());
314 assert!(af.get("carol").is_some());
315 assert!(af.get("dave").is_none());
316 }
317}