Skip to main content

haystack_server/auth/
mod.rs

1//! Server-side authentication manager using SCRAM SHA-256.
2//!
3//! Manages user records, in-flight SCRAM handshakes, and active bearer
4//! tokens.
5
6pub mod users;
7
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use base64::Engine;
13use base64::engine::general_purpose::STANDARD as BASE64;
14use hmac::{Hmac, Mac};
15use parking_lot::RwLock;
16use sha2::Sha256;
17use uuid::Uuid;
18use zeroize::Zeroize;
19
20use haystack_core::auth::{
21    DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials, extract_client_nonce,
22    format_auth_info, format_www_authenticate, generate_nonce, server_first_message,
23    server_verify_final,
24};
25
26use users::{UserRecord, load_users_from_str, load_users_from_toml};
27
/// Identity attached to a successfully validated bearer token.
///
/// Pairs the name the client authenticated as with the permission strings
/// granted to that user.
#[derive(Clone, Debug)]
pub struct AuthUser {
    /// Name the client authenticated as.
    pub username: String,
    /// Permission names granted to this user (for example `"read"` or `"admin"`).
    pub permissions: Vec<String>,
}
34
/// Maximum age of an in-flight SCRAM handshake (HELLO issued, SCRAM step not
/// yet received); handshakes older than this are discarded.
const HANDSHAKE_TTL: Duration = Duration::from_millis(60_000);
37
38/// Server-side authentication manager.
39///
40/// Holds user credentials, in-flight SCRAM handshakes, and active
41/// bearer tokens.
42pub struct AuthManager {
43    /// Username -> pre-computed SCRAM credentials + permissions.
44    users: HashMap<String, UserRecord>,
45    /// In-flight SCRAM handshakes: handshake_token -> (ScramHandshake, created_at).
46    handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
47    /// Active bearer tokens: auth_token -> (AuthUser, created_at).
48    tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
49    /// Time-to-live for bearer tokens.
50    token_ttl: Duration,
51    /// Secret used to derive fake SCRAM challenges for unknown users,
52    /// preventing username enumeration attacks.
53    server_secret: [u8; 32],
54    /// Counter used to periodically trigger bulk cleanup of expired tokens.
55    cleanup_counter: AtomicU64,
56}
57
58impl Drop for AuthManager {
59    fn drop(&mut self) {
60        self.server_secret.zeroize();
61    }
62}
63
64impl AuthManager {
65    /// Create a new AuthManager with the given user records and token TTL.
66    pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
67        let mut server_secret = [0u8; 32];
68        rand::RngExt::fill(&mut rand::rng(), &mut server_secret);
69        Self {
70            users,
71            handshakes: RwLock::new(HashMap::new()),
72            tokens: RwLock::new(HashMap::new()),
73            token_ttl,
74            server_secret,
75            cleanup_counter: AtomicU64::new(0),
76        }
77    }
78
79    /// Create an AuthManager with no users (auth effectively disabled).
80    pub fn empty() -> Self {
81        Self::new(HashMap::new(), Duration::from_secs(3600))
82    }
83
84    /// Builder method to configure the token TTL.
85    pub fn with_token_ttl(mut self, duration: Duration) -> Self {
86        self.token_ttl = duration;
87        self
88    }
89
90    /// Create an AuthManager from a TOML file.
91    pub fn from_toml(path: &str) -> Result<Self, String> {
92        let users = load_users_from_toml(path)?;
93        Ok(Self::new(users, Duration::from_secs(3600)))
94    }
95
96    /// Create an AuthManager from TOML content string.
97    pub fn from_toml_str(content: &str) -> Result<Self, String> {
98        let users = load_users_from_str(content)?;
99        Ok(Self::new(users, Duration::from_secs(3600)))
100    }
101
102    /// Returns true if authentication is enabled (there are registered users).
103    pub fn is_enabled(&self) -> bool {
104        !self.users.is_empty()
105    }
106
107    /// Derive deterministic fake SCRAM credentials for an unknown username.
108    ///
109    /// Uses HMAC(server_secret, username) so the same unknown username always
110    /// produces the same salt, making the response indistinguishable from a
111    /// real user's challenge to an outside observer.
112    fn fake_credentials(&self, username: &str) -> ScramCredentials {
113        let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
114            .expect("HMAC accepts keys of any size");
115        mac.update(username.as_bytes());
116        let fake_salt = mac.finalize().into_bytes();
117
118        // Derive credentials using a throwaway password; the handshake will
119        // always fail at the `handle_scram` step because the attacker does
120        // not know a valid password, but the challenge itself looks normal.
121        derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
122    }
123
124    /// Handle a HELLO request: look up user, create SCRAM handshake.
125    ///
126    /// `client_first_b64` is the optional base64-encoded client-first-message
127    /// containing the client nonce. If absent, the server generates a nonce
128    /// (but the handshake will fail if the client expects its own nonce).
129    ///
130    /// Returns the `WWW-Authenticate` header value for the 401 response.
131    /// Unknown users receive a fake but plausible challenge to prevent
132    /// username enumeration.
133    pub fn handle_hello(
134        &self,
135        username: &str,
136        client_first_b64: Option<&str>,
137    ) -> Result<String, String> {
138        let owned_fake;
139        let credentials: &ScramCredentials = match self.users.get(username) {
140            Some(user_record) => &user_record.credentials,
141            None => {
142                owned_fake = self.fake_credentials(username);
143                &owned_fake
144            }
145        };
146
147        // Extract client nonce from client-first-message, or generate one
148        let client_nonce = match client_first_b64 {
149            Some(data) => {
150                extract_client_nonce(data).map_err(|e| format!("invalid client-first data: {e}"))?
151            }
152            None => generate_nonce(),
153        };
154
155        // Create server-first-message
156        let (handshake, server_first_b64) =
157            server_first_message(username, &client_nonce, credentials);
158
159        // Lazy cleanup: remove expired handshakes before inserting.
160        {
161            let now = Instant::now();
162            self.handshakes
163                .write()
164                .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
165        }
166
167        // Store handshake with a unique token and timestamp.
168        let handshake_token = Uuid::new_v4().to_string();
169        self.handshakes
170            .write()
171            .insert(handshake_token.clone(), (handshake, Instant::now()));
172
173        // Format the WWW-Authenticate header
174        let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
175        Ok(www_auth)
176    }
177
178    /// Handle a SCRAM request: verify client proof, issue auth token.
179    ///
180    /// Returns `(auth_token, authentication_info_header_value)`.
181    pub fn handle_scram(
182        &self,
183        handshake_token: &str,
184        data: &str,
185    ) -> Result<(String, String), String> {
186        // Remove the handshake (one-time use) and check expiry.
187        let (handshake, created_at) = self
188            .handshakes
189            .write()
190            .remove(handshake_token)
191            .ok_or_else(|| "invalid or expired handshake token".to_string())?;
192        if created_at.elapsed() > HANDSHAKE_TTL {
193            return Err("handshake token expired".to_string());
194        }
195
196        let username = handshake.username.clone();
197
198        // Verify client proof
199        let server_sig = server_verify_final(&handshake, data)
200            .map_err(|e| format!("SCRAM verification failed: {e}"))?;
201
202        // Issue auth token
203        let auth_token = Uuid::new_v4().to_string();
204
205        // Look up permissions
206        let permissions = self
207            .users
208            .get(&username)
209            .map(|r| r.permissions.clone())
210            .unwrap_or_default();
211
212        // Store token -> (user, created_at) mapping
213        self.tokens.write().insert(
214            auth_token.clone(),
215            (
216                AuthUser {
217                    username,
218                    permissions,
219                },
220                Instant::now(),
221            ),
222        );
223
224        // Format the server-final data (v=<server_signature>)
225        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
226        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
227        let auth_info = format_auth_info(&auth_token, &server_final_b64);
228
229        Ok((auth_token, auth_info))
230    }
231
232    /// Validate a bearer token and return the associated user.
233    ///
234    /// Returns `None` if the token is unknown or has expired. Expired
235    /// tokens are automatically removed under a single write lock to
236    /// avoid TOCTOU races.
237    pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
238        // Periodically sweep all expired tokens to prevent unbounded growth.
239        let count = self.cleanup_counter.fetch_add(1, Ordering::Relaxed);
240        if count.is_multiple_of(100) {
241            let mut tokens = self.tokens.write();
242            let ttl = self.token_ttl;
243            tokens.retain(|_, (_, created)| created.elapsed() < ttl);
244        }
245
246        let mut tokens = self.tokens.write();
247        match tokens.get(token) {
248            Some((user, created_at)) => {
249                if created_at.elapsed() <= self.token_ttl {
250                    Some(user.clone())
251                } else {
252                    // Token expired — remove immediately under the same lock.
253                    tokens.remove(token);
254                    None
255                }
256            }
257            None => None,
258        }
259    }
260
261    /// Remove a bearer token (logout / close).
262    pub fn revoke_token(&self, token: &str) -> bool {
263        self.tokens.write().remove(token).is_some()
264    }
265
266    /// Inject a token directly (for testing). The token is stamped with the
267    /// current instant so it will not be considered expired.
268    #[doc(hidden)]
269    pub fn inject_token(&self, token: String, user: AuthUser) {
270        self.tokens.write().insert(token, (user, Instant::now()));
271    }
272
273    /// Check whether a user has a required permission.
274    pub fn check_permission(user: &AuthUser, required: &str) -> bool {
275        // Admin has all permissions
276        if user.permissions.contains(&"admin".to_string()) {
277            return true;
278        }
279        user.permissions.contains(&required.to_string())
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Manager with two users sharing the password `s3cret`:
    /// `admin` (read, write, admin) and `viewer` (read only).
    fn make_test_manager() -> AuthManager {
        let hash = hash_password("s3cret");
        let toml_str = format!(
            r#"
[users.admin]
password_hash = "{hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&toml_str).unwrap()
    }

    /// Build an `AuthUser` from borrowed name and permission strings.
    fn make_user(name: &str, perms: &[&str]) -> AuthUser {
        AuthUser {
            username: name.to_owned(),
            permissions: perms.iter().map(|p| (*p).to_owned()).collect(),
        }
    }

    #[test]
    fn empty_manager_is_disabled() {
        assert!(!AuthManager::empty().is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        assert!(make_test_manager().is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        // Unknown users receive a plausible SCRAM challenge instead of an
        // error, so responses do not reveal which usernames exist.
        let www_auth = make_test_manager()
            .handle_hello("nonexistent", None)
            .expect("unknown user should still be challenged");
        for needle in ["SCRAM", "SHA-256"] {
            assert!(www_auth.contains(needle), "missing {needle:?} in {www_auth}");
        }
    }

    #[test]
    fn hello_known_user_succeeds() {
        let www_auth = make_test_manager()
            .handle_hello("admin", None)
            .expect("known user should be challenged");
        for needle in ["SCRAM", "SHA-256"] {
            assert!(www_auth.contains(needle), "missing {needle:?} in {www_auth}");
        }
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let mgr = make_test_manager();
        let responses = [
            mgr.handle_hello("admin", None).unwrap(),
            mgr.handle_hello("nonexistent", None).unwrap(),
        ];

        // Real and fake challenges must share one structural format so an
        // attacker cannot tell real users from fake ones.
        for www_auth in &responses {
            assert!(www_auth.starts_with("SCRAM handshakeToken="));
            assert!(www_auth.contains("hash=SHA-256"));
            assert!(www_auth.contains("data="));
        }
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let mgr = make_test_manager();

        // Repeated HELLOs for the same unknown username must yield the same
        // parameters.
        let (first, second) = (mgr.fake_credentials("ghost"), mgr.fake_credentials("ghost"));
        assert_eq!(first.salt, second.salt);
        assert_eq!(first.stored_key, second.stored_key);
        assert_eq!(first.server_key, second.server_key);

        // A different unknown username gets a different fake salt.
        let other = mgr.fake_credentials("phantom");
        assert_ne!(first.salt, other.salt);
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        assert!(make_test_manager().validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let admin = make_user("admin", &["admin"]);
        for perm in ["read", "write", "admin"] {
            assert!(AuthManager::check_permission(&admin, perm), "admin lacks {perm}");
        }
    }

    #[test]
    fn check_permission_viewer_limited() {
        let viewer = make_user("viewer", &["read"]);
        assert!(AuthManager::check_permission(&viewer, "read"));
        for perm in ["write", "admin"] {
            assert!(!AuthManager::check_permission(&viewer, perm), "viewer has {perm}");
        }
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        assert!(!make_test_manager().revoke_token("nonexistent-token"));
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let mgr = make_test_manager();

        // A token stamped "now" is well within the default TTL.
        mgr.tokens.write().insert(
            "good-token".to_string(),
            (make_user("admin", &["admin"]), Instant::now()),
        );

        assert!(mgr.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        // With a zero TTL, every token is already expired when it is inserted.
        let mgr = make_test_manager().with_token_ttl(Duration::from_secs(0));
        mgr.tokens.write().insert(
            "expired-token".to_string(),
            (make_user("admin", &["admin"]), Instant::now()),
        );

        // The token exists, but it must be reported as expired...
        assert!(mgr.validate_token("expired-token").is_none());
        // ...and it must have been evicted from the table.
        assert!(mgr.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let custom = Duration::from_secs(120);
        let mgr = AuthManager::empty().with_token_ttl(custom);
        assert_eq!(mgr.token_ttl, custom);
    }
}
446}