// haystack_server/auth/mod.rs

//! Server-side authentication manager using SCRAM-SHA-256.
//!
//! Manages user records, in-flight SCRAM handshakes, and active bearer
//! tokens.
5
6pub mod users;
7
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use parking_lot::RwLock;
15use sha2::Sha256;
16use uuid::Uuid;
17
18use haystack_core::auth::{
19    DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials,
20    extract_client_nonce, format_auth_info, format_www_authenticate, generate_nonce,
21    server_first_message, server_verify_final,
22};
23
24use users::{UserRecord, load_users_from_str, load_users_from_toml};
25
/// An authenticated user with associated permissions.
///
/// Stored alongside each issued bearer token and returned by
/// `AuthManager::validate_token`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthUser {
    /// Name the user authenticated as.
    pub username: String,
    /// Permission names granted to the user (for example `"read"`,
    /// `"write"`, or `"admin"`).
    pub permissions: Vec<String>,
}
32
/// Maximum age of an in-flight SCRAM handshake (created by HELLO) before it
/// is treated as expired: one minute.
const HANDSHAKE_TTL: Duration = Duration::from_millis(60_000);
35
36/// Server-side authentication manager.
37///
38/// Holds user credentials, in-flight SCRAM handshakes, and active
39/// bearer tokens.
40pub struct AuthManager {
41    /// Username -> pre-computed SCRAM credentials + permissions.
42    users: HashMap<String, UserRecord>,
43    /// In-flight SCRAM handshakes: handshake_token -> (ScramHandshake, created_at).
44    handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
45    /// Active bearer tokens: auth_token -> (AuthUser, created_at).
46    tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
47    /// Time-to-live for bearer tokens.
48    token_ttl: Duration,
49    /// Secret used to derive fake SCRAM challenges for unknown users,
50    /// preventing username enumeration attacks.
51    server_secret: [u8; 32],
52}
53
54impl AuthManager {
55    /// Create a new AuthManager with the given user records and token TTL.
56    pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
57        let mut server_secret = [0u8; 32];
58        rand::Rng::fill(&mut rand::rng(), &mut server_secret);
59        Self {
60            users,
61            handshakes: RwLock::new(HashMap::new()),
62            tokens: RwLock::new(HashMap::new()),
63            token_ttl,
64            server_secret,
65        }
66    }
67
68    /// Create an AuthManager with no users (auth effectively disabled).
69    pub fn empty() -> Self {
70        Self::new(HashMap::new(), Duration::from_secs(3600))
71    }
72
73    /// Builder method to configure the token TTL.
74    pub fn with_token_ttl(mut self, duration: Duration) -> Self {
75        self.token_ttl = duration;
76        self
77    }
78
79    /// Create an AuthManager from a TOML file.
80    pub fn from_toml(path: &str) -> Result<Self, String> {
81        let users = load_users_from_toml(path)?;
82        Ok(Self::new(users, Duration::from_secs(3600)))
83    }
84
85    /// Create an AuthManager from TOML content string.
86    pub fn from_toml_str(content: &str) -> Result<Self, String> {
87        let users = load_users_from_str(content)?;
88        Ok(Self::new(users, Duration::from_secs(3600)))
89    }
90
91    /// Returns true if authentication is enabled (there are registered users).
92    pub fn is_enabled(&self) -> bool {
93        !self.users.is_empty()
94    }
95
96    /// Derive deterministic fake SCRAM credentials for an unknown username.
97    ///
98    /// Uses HMAC(server_secret, username) so the same unknown username always
99    /// produces the same salt, making the response indistinguishable from a
100    /// real user's challenge to an outside observer.
101    fn fake_credentials(&self, username: &str) -> ScramCredentials {
102        let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
103            .expect("HMAC accepts keys of any size");
104        mac.update(username.as_bytes());
105        let fake_salt = mac.finalize().into_bytes();
106
107        // Derive credentials using a throwaway password; the handshake will
108        // always fail at the `handle_scram` step because the attacker does
109        // not know a valid password, but the challenge itself looks normal.
110        derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
111    }
112
113    /// Handle a HELLO request: look up user, create SCRAM handshake.
114    ///
115    /// `client_first_b64` is the optional base64-encoded client-first-message
116    /// containing the client nonce. If absent, the server generates a nonce
117    /// (but the handshake will fail if the client expects its own nonce).
118    ///
119    /// Returns the `WWW-Authenticate` header value for the 401 response.
120    /// Unknown users receive a fake but plausible challenge to prevent
121    /// username enumeration.
122    pub fn handle_hello(
123        &self,
124        username: &str,
125        client_first_b64: Option<&str>,
126    ) -> Result<String, String> {
127        let credentials = match self.users.get(username) {
128            Some(user_record) => user_record.credentials.clone(),
129            None => self.fake_credentials(username),
130        };
131
132        // Extract client nonce from client-first-message, or generate one
133        let client_nonce = match client_first_b64 {
134            Some(data) => extract_client_nonce(data)
135                .map_err(|e| format!("invalid client-first data: {e}"))?,
136            None => generate_nonce(),
137        };
138
139        // Create server-first-message
140        let (handshake, server_first_b64) =
141            server_first_message(username, &client_nonce, &credentials);
142
143        // Lazy cleanup: remove expired handshakes before inserting.
144        {
145            let now = Instant::now();
146            self.handshakes
147                .write()
148                .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
149        }
150
151        // Store handshake with a unique token and timestamp.
152        let handshake_token = Uuid::new_v4().to_string();
153        self.handshakes
154            .write()
155            .insert(handshake_token.clone(), (handshake, Instant::now()));
156
157        // Format the WWW-Authenticate header
158        let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
159        Ok(www_auth)
160    }
161
162    /// Handle a SCRAM request: verify client proof, issue auth token.
163    ///
164    /// Returns `(auth_token, authentication_info_header_value)`.
165    pub fn handle_scram(
166        &self,
167        handshake_token: &str,
168        data: &str,
169    ) -> Result<(String, String), String> {
170        // Remove the handshake (one-time use) and check expiry.
171        let (handshake, created_at) = self
172            .handshakes
173            .write()
174            .remove(handshake_token)
175            .ok_or_else(|| "invalid or expired handshake token".to_string())?;
176        if created_at.elapsed() > HANDSHAKE_TTL {
177            return Err("handshake token expired".to_string());
178        }
179
180        let username = handshake.username.clone();
181
182        // Verify client proof
183        let server_sig = server_verify_final(&handshake, data)
184            .map_err(|e| format!("SCRAM verification failed: {e}"))?;
185
186        // Issue auth token
187        let auth_token = Uuid::new_v4().to_string();
188
189        // Look up permissions
190        let permissions = self
191            .users
192            .get(&username)
193            .map(|r| r.permissions.clone())
194            .unwrap_or_default();
195
196        // Store token -> (user, created_at) mapping
197        self.tokens.write().insert(
198            auth_token.clone(),
199            (
200                AuthUser {
201                    username,
202                    permissions,
203                },
204                Instant::now(),
205            ),
206        );
207
208        // Format the server-final data (v=<server_signature>)
209        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
210        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
211        let auth_info = format_auth_info(&auth_token, &server_final_b64);
212
213        Ok((auth_token, auth_info))
214    }
215
216    /// Validate a bearer token and return the associated user.
217    ///
218    /// Returns `None` if the token is unknown or has expired. Expired
219    /// tokens are automatically removed.
220    pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
221        // First, check with a read lock.
222        {
223            let tokens = self.tokens.read();
224            match tokens.get(token) {
225                Some((user, created_at)) => {
226                    if created_at.elapsed() <= self.token_ttl {
227                        return Some(user.clone());
228                    }
229                    // Token expired -- fall through to remove it.
230                }
231                None => return None,
232            }
233        }
234        // Expired: remove under a write lock.
235        self.tokens.write().remove(token);
236        None
237    }
238
239    /// Remove a bearer token (logout / close).
240    pub fn revoke_token(&self, token: &str) -> bool {
241        self.tokens.write().remove(token).is_some()
242    }
243
244    /// Inject a token directly (for testing). The token is stamped with the
245    /// current instant so it will not be considered expired.
246    #[doc(hidden)]
247    pub fn inject_token(&self, token: String, user: AuthUser) {
248        self.tokens.write().insert(token, (user, Instant::now()));
249    }
250
251    /// Check whether a user has a required permission.
252    pub fn check_permission(user: &AuthUser, required: &str) -> bool {
253        // Admin has all permissions
254        if user.permissions.contains(&"admin".to_string()) {
255            return true;
256        }
257        user.permissions.contains(&required.to_string())
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Manager with an `admin` (read/write/admin) and a `viewer` (read) user,
    /// both using the password `s3cret`.
    fn make_test_manager() -> AuthManager {
        let hash = hash_password("s3cret");
        let toml_str = format!(
            r#"
[users.admin]
password_hash = "{hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&toml_str).unwrap()
    }

    #[test]
    fn empty_manager_is_disabled() {
        let mgr = AuthManager::empty();
        assert!(!mgr.is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        let mgr = make_test_manager();
        assert!(mgr.is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        let mgr = make_test_manager();
        // Unknown users get a plausible SCRAM challenge instead of an error,
        // preventing username enumeration.
        let result = mgr.handle_hello("nonexistent", None);
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_user_succeeds() {
        let mgr = make_test_manager();
        let result = mgr.handle_hello("admin", None);
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let mgr = make_test_manager();
        let known = mgr.handle_hello("admin", None).unwrap();
        let unknown = mgr.handle_hello("nonexistent", None).unwrap();

        // Both responses must have the same structural format so that an
        // attacker cannot distinguish real from fake users.
        assert!(known.starts_with("SCRAM handshakeToken="));
        assert!(unknown.starts_with("SCRAM handshakeToken="));
        assert!(known.contains("hash=SHA-256"));
        assert!(unknown.contains("hash=SHA-256"));
        assert!(known.contains("data="));
        assert!(unknown.contains("data="));
    }

    #[test]
    fn hello_issues_unique_handshake_tokens() {
        let mgr = make_test_manager();
        let first = mgr.handle_hello("admin", None).unwrap();
        let second = mgr.handle_hello("admin", None).unwrap();
        // Each HELLO must create its own handshake entry.
        assert_ne!(first, second);
        assert_eq!(mgr.handshakes.read().len(), 2);
    }

    #[test]
    fn handle_scram_rejects_unknown_handshake_token() {
        let mgr = make_test_manager();
        let err = mgr
            .handle_scram("no-such-handshake", "ignored")
            .unwrap_err();
        assert!(err.contains("handshake token"));
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let mgr = make_test_manager();
        // The fake salt must be deterministic so that repeated HELLO requests
        // for the same unknown username produce consistent parameters.
        let creds1 = mgr.fake_credentials("ghost");
        let creds2 = mgr.fake_credentials("ghost");
        assert_eq!(creds1.salt, creds2.salt);
        assert_eq!(creds1.stored_key, creds2.stored_key);
        assert_eq!(creds1.server_key, creds2.server_key);

        // Different usernames produce different fake salts.
        let creds3 = mgr.fake_credentials("phantom");
        assert_ne!(creds1.salt, creds3.salt);
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        let mgr = make_test_manager();
        assert!(mgr.validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let user = AuthUser {
            username: "admin".to_string(),
            permissions: vec!["admin".to_string()],
        };
        assert!(AuthManager::check_permission(&user, "read"));
        assert!(AuthManager::check_permission(&user, "write"));
        assert!(AuthManager::check_permission(&user, "admin"));
    }

    #[test]
    fn check_permission_viewer_limited() {
        let user = AuthUser {
            username: "viewer".to_string(),
            permissions: vec!["read".to_string()],
        };
        assert!(AuthManager::check_permission(&user, "read"));
        assert!(!AuthManager::check_permission(&user, "write"));
        assert!(!AuthManager::check_permission(&user, "admin"));
    }

    #[test]
    fn check_permission_empty_permissions_denies_everything() {
        let user = AuthUser {
            username: "nobody".to_string(),
            permissions: Vec::new(),
        };
        assert!(!AuthManager::check_permission(&user, "read"));
        assert!(!AuthManager::check_permission(&user, ""));
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        let mgr = make_test_manager();
        assert!(!mgr.revoke_token("nonexistent-token"));
    }

    #[test]
    fn revoke_token_removes_live_token() {
        let mgr = make_test_manager();
        let user = AuthUser {
            username: "viewer".to_string(),
            permissions: vec!["read".to_string()],
        };
        mgr.inject_token("live-token".to_string(), user);
        assert!(mgr.validate_token("live-token").is_some());

        assert!(mgr.revoke_token("live-token"));
        assert!(mgr.validate_token("live-token").is_none());
        // A second revoke finds nothing to remove.
        assert!(!mgr.revoke_token("live-token"));
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let mgr = make_test_manager();
        // Manually insert a token with Instant::now() (fresh, not expired).
        let user = AuthUser {
            username: "admin".to_string(),
            permissions: vec!["admin".to_string()],
        };
        mgr.tokens
            .write()
            .insert("good-token".to_string(), (user, Instant::now()));

        assert!(mgr.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        // Use a zero TTL so the token is already expired.
        let mgr = make_test_manager().with_token_ttl(Duration::from_secs(0));

        let user = AuthUser {
            username: "admin".to_string(),
            permissions: vec!["admin".to_string()],
        };
        // Insert a token that was created "now" -- with a 0s TTL it is
        // immediately expired.
        mgr.tokens
            .write()
            .insert("expired-token".to_string(), (user, Instant::now()));

        // Even though the token exists, it should be reported as expired.
        assert!(mgr.validate_token("expired-token").is_none());

        // The expired token should have been removed from the map.
        assert!(mgr.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let mgr = AuthManager::empty().with_token_ttl(Duration::from_secs(120));
        assert_eq!(mgr.token_ttl, Duration::from_secs(120));
    }
}