1pub mod users;
7
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use parking_lot::RwLock;
15use sha2::Sha256;
16use uuid::Uuid;
17use zeroize::Zeroize;
18
19use haystack_core::auth::{
20 DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials, extract_client_nonce,
21 format_auth_info, format_www_authenticate, generate_nonce, server_first_message,
22 server_verify_final,
23};
24
25use users::{UserRecord, load_users_from_str, load_users_from_toml};
26
/// An authenticated principal, as attached to an issued bearer token.
///
/// Equality compares both the username and the exact permission list, which
/// lets callers and tests compare users directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthUser {
    /// Name the user authenticated as.
    pub username: String,
    /// Permission strings granted to the user (e.g. `"read"`, `"write"`,
    /// `"admin"`); see `AuthManager::check_permission`.
    pub permissions: Vec<String>,
}
33
/// Maximum age of a pending SCRAM handshake: one started by `handle_hello`
/// must be finished via `handle_scram` within this window or it is rejected.
const HANDSHAKE_TTL: Duration = Duration::new(60, 0);
36
/// Server-side authentication state: the configured users, in-flight SCRAM
/// handshakes, and issued bearer tokens.
pub struct AuthManager {
    /// Configured users keyed by username. An empty map means authentication
    /// is disabled (see `is_enabled`).
    users: HashMap<String, UserRecord>,
    /// In-flight SCRAM handshakes keyed by handshake token, paired with the
    /// instant each one was created (checked against `HANDSHAKE_TTL`).
    handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
    /// Issued bearer tokens keyed by token string, paired with the instant
    /// each one was issued (checked against `token_ttl`).
    tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
    /// How long a bearer token stays valid after it is issued.
    token_ttl: Duration,
    /// Random per-instance key (filled in `new`) used to derive fake SCRAM
    /// credentials for unknown usernames; zeroized in `Drop`.
    server_secret: [u8; 32],
}
54
55impl Drop for AuthManager {
56 fn drop(&mut self) {
57 self.server_secret.zeroize();
58 }
59}
60
61impl AuthManager {
62 pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
64 let mut server_secret = [0u8; 32];
65 rand::RngExt::fill(&mut rand::rng(), &mut server_secret);
66 Self {
67 users,
68 handshakes: RwLock::new(HashMap::new()),
69 tokens: RwLock::new(HashMap::new()),
70 token_ttl,
71 server_secret,
72 }
73 }
74
75 pub fn empty() -> Self {
77 Self::new(HashMap::new(), Duration::from_secs(3600))
78 }
79
80 pub fn with_token_ttl(mut self, duration: Duration) -> Self {
82 self.token_ttl = duration;
83 self
84 }
85
86 pub fn from_toml(path: &str) -> Result<Self, String> {
88 let users = load_users_from_toml(path)?;
89 Ok(Self::new(users, Duration::from_secs(3600)))
90 }
91
92 pub fn from_toml_str(content: &str) -> Result<Self, String> {
94 let users = load_users_from_str(content)?;
95 Ok(Self::new(users, Duration::from_secs(3600)))
96 }
97
98 pub fn is_enabled(&self) -> bool {
100 !self.users.is_empty()
101 }
102
103 fn fake_credentials(&self, username: &str) -> ScramCredentials {
109 let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
110 .expect("HMAC accepts keys of any size");
111 mac.update(username.as_bytes());
112 let fake_salt = mac.finalize().into_bytes();
113
114 derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
118 }
119
120 pub fn handle_hello(
130 &self,
131 username: &str,
132 client_first_b64: Option<&str>,
133 ) -> Result<String, String> {
134 let credentials = match self.users.get(username) {
135 Some(user_record) => user_record.credentials.clone(),
136 None => self.fake_credentials(username),
137 };
138
139 let client_nonce = match client_first_b64 {
141 Some(data) => {
142 extract_client_nonce(data).map_err(|e| format!("invalid client-first data: {e}"))?
143 }
144 None => generate_nonce(),
145 };
146
147 let (handshake, server_first_b64) =
149 server_first_message(username, &client_nonce, &credentials);
150
151 {
153 let now = Instant::now();
154 self.handshakes
155 .write()
156 .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
157 }
158
159 let handshake_token = Uuid::new_v4().to_string();
161 self.handshakes
162 .write()
163 .insert(handshake_token.clone(), (handshake, Instant::now()));
164
165 let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
167 Ok(www_auth)
168 }
169
170 pub fn handle_scram(
174 &self,
175 handshake_token: &str,
176 data: &str,
177 ) -> Result<(String, String), String> {
178 let (handshake, created_at) = self
180 .handshakes
181 .write()
182 .remove(handshake_token)
183 .ok_or_else(|| "invalid or expired handshake token".to_string())?;
184 if created_at.elapsed() > HANDSHAKE_TTL {
185 return Err("handshake token expired".to_string());
186 }
187
188 let username = handshake.username.clone();
189
190 let server_sig = server_verify_final(&handshake, data)
192 .map_err(|e| format!("SCRAM verification failed: {e}"))?;
193
194 let auth_token = Uuid::new_v4().to_string();
196
197 let permissions = self
199 .users
200 .get(&username)
201 .map(|r| r.permissions.clone())
202 .unwrap_or_default();
203
204 self.tokens.write().insert(
206 auth_token.clone(),
207 (
208 AuthUser {
209 username,
210 permissions,
211 },
212 Instant::now(),
213 ),
214 );
215
216 let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
218 let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
219 let auth_info = format_auth_info(&auth_token, &server_final_b64);
220
221 Ok((auth_token, auth_info))
222 }
223
224 pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
230 let mut tokens = self.tokens.write();
231 match tokens.get(token) {
232 Some((user, created_at)) => {
233 if created_at.elapsed() <= self.token_ttl {
234 Some(user.clone())
235 } else {
236 tokens.remove(token);
238 None
239 }
240 }
241 None => None,
242 }
243 }
244
245 pub fn revoke_token(&self, token: &str) -> bool {
247 self.tokens.write().remove(token).is_some()
248 }
249
250 #[doc(hidden)]
253 pub fn inject_token(&self, token: String, user: AuthUser) {
254 self.tokens.write().insert(token, (user, Instant::now()));
255 }
256
257 pub fn check_permission(user: &AuthUser, required: &str) -> bool {
259 if user.permissions.contains(&"admin".to_string()) {
261 return true;
262 }
263 user.permissions.contains(&required.to_string())
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Builds a manager with an `admin` (read/write/admin) and a `viewer`
    /// (read-only) user, both with password `s3cret`.
    fn make_test_manager() -> AuthManager {
        let hash = hash_password("s3cret");
        let toml_str = format!(
            r#"
[users.admin]
password_hash = "{hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&toml_str).unwrap()
    }

    /// Shorthand for constructing an `AuthUser` in tests.
    fn user(name: &str, permissions: &[&str]) -> AuthUser {
        AuthUser {
            username: name.to_string(),
            permissions: permissions.iter().map(|p| p.to_string()).collect(),
        }
    }

    #[test]
    fn empty_manager_is_disabled() {
        let mgr = AuthManager::empty();
        assert!(!mgr.is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        let mgr = make_test_manager();
        assert!(mgr.is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        let mgr = make_test_manager();
        let result = mgr.handle_hello("nonexistent", None);
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_user_succeeds() {
        let mgr = make_test_manager();
        let result = mgr.handle_hello("admin", None);
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let mgr = make_test_manager();
        let known = mgr.handle_hello("admin", None).unwrap();
        let unknown = mgr.handle_hello("nonexistent", None).unwrap();

        for challenge in [&known, &unknown] {
            assert!(challenge.starts_with("SCRAM handshakeToken="));
            assert!(challenge.contains("hash=SHA-256"));
            assert!(challenge.contains("data="));
        }
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let mgr = make_test_manager();
        let creds1 = mgr.fake_credentials("ghost");
        let creds2 = mgr.fake_credentials("ghost");
        assert_eq!(creds1.salt, creds2.salt);
        assert_eq!(creds1.stored_key, creds2.stored_key);
        assert_eq!(creds1.server_key, creds2.server_key);

        let creds3 = mgr.fake_credentials("phantom");
        assert_ne!(creds1.salt, creds3.salt);
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        let mgr = make_test_manager();
        assert!(mgr.validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let admin = user("admin", &["admin"]);
        assert!(AuthManager::check_permission(&admin, "read"));
        assert!(AuthManager::check_permission(&admin, "write"));
        assert!(AuthManager::check_permission(&admin, "admin"));
    }

    #[test]
    fn check_permission_viewer_limited() {
        let viewer = user("viewer", &["read"]);
        assert!(AuthManager::check_permission(&viewer, "read"));
        assert!(!AuthManager::check_permission(&viewer, "write"));
        assert!(!AuthManager::check_permission(&viewer, "admin"));
    }

    #[test]
    fn check_permission_empty_permissions_denies_all() {
        let nobody = user("nobody", &[]);
        assert!(!AuthManager::check_permission(&nobody, "read"));
        assert!(!AuthManager::check_permission(&nobody, "admin"));
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        let mgr = make_test_manager();
        assert!(!mgr.revoke_token("nonexistent-token"));
    }

    #[test]
    fn inject_validate_revoke_round_trip() {
        let mgr = make_test_manager();
        mgr.inject_token("injected".to_string(), user("admin", &["admin"]));

        let got = mgr
            .validate_token("injected")
            .expect("injected token should validate");
        assert_eq!(got.username, "admin");
        assert_eq!(got.permissions, vec!["admin".to_string()]);

        assert!(mgr.revoke_token("injected"));
        assert!(mgr.validate_token("injected").is_none());
        // A second revoke finds nothing.
        assert!(!mgr.revoke_token("injected"));
    }

    #[test]
    fn scram_rejects_unknown_handshake_token() {
        let mgr = make_test_manager();
        let err = mgr
            .handle_scram("no-such-handshake", "ignored")
            .unwrap_err();
        assert_eq!(err, "invalid or expired handshake token");
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let mgr = make_test_manager();
        mgr.tokens.write().insert(
            "good-token".to_string(),
            (user("admin", &["admin"]), Instant::now()),
        );

        assert!(mgr.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        let mgr = make_test_manager().with_token_ttl(Duration::from_secs(0));

        mgr.tokens.write().insert(
            "expired-token".to_string(),
            (user("admin", &["admin"]), Instant::now()),
        );

        assert!(mgr.validate_token("expired-token").is_none());
        // The expired entry must have been evicted, not just reported invalid.
        assert!(mgr.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let mgr = AuthManager::empty().with_token_ttl(Duration::from_secs(120));
        assert_eq!(mgr.token_ttl, Duration::from_secs(120));
    }
}