1pub mod users;
7
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use parking_lot::RwLock;
15use sha2::Sha256;
16use uuid::Uuid;
17
18use haystack_core::auth::{
19 DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials, format_auth_info,
20 format_www_authenticate, generate_nonce, server_first_message, server_verify_final,
21};
22
23use users::{UserRecord, load_users_from_str, load_users_from_toml};
24
/// An authenticated principal, as resolved from a bearer token.
///
/// Tokens issued by a successful SCRAM exchange carry the permission list
/// copied from the user's configuration record at issue time; see
/// [`AuthManager::check_permission`] for how permissions are evaluated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthUser {
    /// Login name the token was issued to.
    pub username: String,
    /// Permission names granted to the user (for example `"read"` or `"admin"`).
    pub permissions: Vec<String>,
}
31
/// How long a pending SCRAM handshake may wait for the client's final message
/// before it is discarded (one minute).
const HANDSHAKE_TTL: Duration = Duration::from_millis(60_000);
34
35pub struct AuthManager {
40 users: HashMap<String, UserRecord>,
42 handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
44 tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
46 token_ttl: Duration,
48 server_secret: [u8; 32],
51}
52
53impl AuthManager {
54 pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
56 let mut server_secret = [0u8; 32];
57 rand::Rng::fill(&mut rand::rng(), &mut server_secret);
58 Self {
59 users,
60 handshakes: RwLock::new(HashMap::new()),
61 tokens: RwLock::new(HashMap::new()),
62 token_ttl,
63 server_secret,
64 }
65 }
66
67 pub fn empty() -> Self {
69 Self::new(HashMap::new(), Duration::from_secs(3600))
70 }
71
72 pub fn with_token_ttl(mut self, duration: Duration) -> Self {
74 self.token_ttl = duration;
75 self
76 }
77
78 pub fn from_toml(path: &str) -> Result<Self, String> {
80 let users = load_users_from_toml(path)?;
81 Ok(Self::new(users, Duration::from_secs(3600)))
82 }
83
84 pub fn from_toml_str(content: &str) -> Result<Self, String> {
86 let users = load_users_from_str(content)?;
87 Ok(Self::new(users, Duration::from_secs(3600)))
88 }
89
90 pub fn is_enabled(&self) -> bool {
92 !self.users.is_empty()
93 }
94
95 fn fake_credentials(&self, username: &str) -> ScramCredentials {
101 let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
102 .expect("HMAC accepts keys of any size");
103 mac.update(username.as_bytes());
104 let fake_salt = mac.finalize().into_bytes();
105
106 derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
110 }
111
112 pub fn handle_hello(&self, username: &str) -> Result<String, String> {
118 let credentials = match self.users.get(username) {
119 Some(user_record) => user_record.credentials.clone(),
120 None => self.fake_credentials(username),
121 };
122
123 let client_nonce = generate_nonce();
125
126 let (handshake, server_first_b64) =
128 server_first_message(username, &client_nonce, &credentials);
129
130 {
132 let now = Instant::now();
133 self.handshakes
134 .write()
135 .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
136 }
137
138 let handshake_token = Uuid::new_v4().to_string();
140 self.handshakes
141 .write()
142 .insert(handshake_token.clone(), (handshake, Instant::now()));
143
144 let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
146 Ok(www_auth)
147 }
148
149 pub fn handle_scram(
153 &self,
154 handshake_token: &str,
155 data: &str,
156 ) -> Result<(String, String), String> {
157 let (handshake, created_at) = self
159 .handshakes
160 .write()
161 .remove(handshake_token)
162 .ok_or_else(|| "invalid or expired handshake token".to_string())?;
163 if created_at.elapsed() > HANDSHAKE_TTL {
164 return Err("handshake token expired".to_string());
165 }
166
167 let username = handshake.username.clone();
168
169 let server_sig = server_verify_final(&handshake, data)
171 .map_err(|e| format!("SCRAM verification failed: {e}"))?;
172
173 let auth_token = Uuid::new_v4().to_string();
175
176 let permissions = self
178 .users
179 .get(&username)
180 .map(|r| r.permissions.clone())
181 .unwrap_or_default();
182
183 self.tokens.write().insert(
185 auth_token.clone(),
186 (
187 AuthUser {
188 username,
189 permissions,
190 },
191 Instant::now(),
192 ),
193 );
194
195 let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
197 let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
198 let auth_info = format_auth_info(&auth_token, &server_final_b64);
199
200 Ok((auth_token, auth_info))
201 }
202
203 pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
208 {
210 let tokens = self.tokens.read();
211 match tokens.get(token) {
212 Some((user, created_at)) => {
213 if created_at.elapsed() <= self.token_ttl {
214 return Some(user.clone());
215 }
216 }
218 None => return None,
219 }
220 }
221 self.tokens.write().remove(token);
223 None
224 }
225
226 pub fn revoke_token(&self, token: &str) -> bool {
228 self.tokens.write().remove(token).is_some()
229 }
230
231 #[doc(hidden)]
234 pub fn inject_token(&self, token: String, user: AuthUser) {
235 self.tokens.write().insert(token, (user, Instant::now()));
236 }
237
238 pub fn check_permission(user: &AuthUser, required: &str) -> bool {
240 if user.permissions.contains(&"admin".to_string()) {
242 return true;
243 }
244 user.permissions.contains(&required.to_string())
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Manager with two accounts sharing the password `s3cret`:
    /// `admin` (read, write, admin) and `viewer` (read only).
    fn sample_manager() -> AuthManager {
        let hash = hash_password("s3cret");
        let config = format!(
            r#"
[users.admin]
password_hash = "{hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&config).unwrap()
    }

    /// A user holding only the `admin` permission.
    fn admin_user() -> AuthUser {
        AuthUser {
            username: "admin".to_string(),
            permissions: vec!["admin".to_string()],
        }
    }

    #[test]
    fn empty_manager_is_disabled() {
        assert!(!AuthManager::empty().is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        assert!(sample_manager().is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        let challenge = sample_manager()
            .handle_hello("nonexistent")
            .expect("unknown users must still receive a challenge");
        assert!(challenge.contains("SCRAM"));
        assert!(challenge.contains("SHA-256"));
    }

    #[test]
    fn hello_known_user_succeeds() {
        let challenge = sample_manager()
            .handle_hello("admin")
            .expect("known users must receive a challenge");
        assert!(challenge.contains("SCRAM"));
        assert!(challenge.contains("SHA-256"));
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let mgr = sample_manager();
        let challenges = [
            mgr.handle_hello("admin").unwrap(),
            mgr.handle_hello("nonexistent").unwrap(),
        ];
        for challenge in &challenges {
            assert!(challenge.starts_with("SCRAM handshakeToken="));
            assert!(challenge.contains("hash=SHA-256"));
            assert!(challenge.contains("data="));
        }
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let mgr = sample_manager();
        let first = mgr.fake_credentials("ghost");
        let second = mgr.fake_credentials("ghost");
        assert_eq!(first.salt, second.salt);
        assert_eq!(first.stored_key, second.stored_key);
        assert_eq!(first.server_key, second.server_key);

        let other = mgr.fake_credentials("phantom");
        assert_ne!(first.salt, other.salt);
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        assert!(sample_manager().validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let user = admin_user();
        for perm in ["read", "write", "admin"] {
            assert!(AuthManager::check_permission(&user, perm));
        }
    }

    #[test]
    fn check_permission_viewer_limited() {
        let user = AuthUser {
            username: "viewer".to_string(),
            permissions: vec!["read".to_string()],
        };
        assert!(AuthManager::check_permission(&user, "read"));
        for perm in ["write", "admin"] {
            assert!(!AuthManager::check_permission(&user, perm));
        }
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        assert!(!sample_manager().revoke_token("nonexistent-token"));
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let mgr = sample_manager();
        mgr.tokens
            .write()
            .insert("good-token".to_string(), (admin_user(), Instant::now()));
        assert!(mgr.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        // With a zero TTL, any measurable elapsed time is past expiry.
        let mgr = sample_manager().with_token_ttl(Duration::from_secs(0));
        mgr.tokens
            .write()
            .insert("expired-token".to_string(), (admin_user(), Instant::now()));

        assert!(mgr.validate_token("expired-token").is_none());
        // Validation evicts the expired entry as a side effect.
        assert!(mgr.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let mgr = AuthManager::empty().with_token_ttl(Duration::from_secs(120));
        assert_eq!(mgr.token_ttl, Duration::from_secs(120));
    }
}