// haystack_server/auth/mod.rs

1//! Server-side authentication manager using SCRAM SHA-256.
2//!
3//! Manages user records, in-flight SCRAM handshakes, and active bearer
4//! tokens.
5
6pub mod users;
7
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use parking_lot::RwLock;
15use sha2::Sha256;
16use uuid::Uuid;
17use zeroize::Zeroize;
18
19use haystack_core::auth::{
20    DEFAULT_ITERATIONS, ScramCredentials, ScramHandshake, derive_credentials, extract_client_nonce,
21    format_auth_info, format_www_authenticate, generate_nonce, server_first_message,
22    server_verify_final,
23};
24
25use users::{UserRecord, load_users_from_str, load_users_from_toml};
26
/// Identity attached to a successfully authenticated request.
///
/// Built by [`AuthManager::handle_scram`] once a SCRAM exchange completes
/// and handed back by [`AuthManager::validate_token`] for live bearer tokens.
#[derive(Debug, Clone)]
pub struct AuthUser {
    /// Name the client authenticated as.
    pub username: String,
    /// Permission names granted to this user (e.g. `"read"`, `"admin"`).
    pub permissions: Vec<String>,
}
33
/// Maximum age (one minute) of an in-flight SCRAM handshake.
///
/// Older handshakes are rejected when the client sends its final message
/// and are purged lazily whenever a new handshake is started.
const HANDSHAKE_TTL: Duration = Duration::from_millis(60_000);
36
37/// Server-side authentication manager.
38///
39/// Holds user credentials, in-flight SCRAM handshakes, and active
40/// bearer tokens.
41pub struct AuthManager {
42    /// Username -> pre-computed SCRAM credentials + permissions.
43    users: HashMap<String, UserRecord>,
44    /// In-flight SCRAM handshakes: handshake_token -> (ScramHandshake, created_at).
45    handshakes: RwLock<HashMap<String, (ScramHandshake, Instant)>>,
46    /// Active bearer tokens: auth_token -> (AuthUser, created_at).
47    tokens: RwLock<HashMap<String, (AuthUser, Instant)>>,
48    /// Time-to-live for bearer tokens.
49    token_ttl: Duration,
50    /// Secret used to derive fake SCRAM challenges for unknown users,
51    /// preventing username enumeration attacks.
52    server_secret: [u8; 32],
53}
54
55impl Drop for AuthManager {
56    fn drop(&mut self) {
57        self.server_secret.zeroize();
58    }
59}
60
61impl AuthManager {
62    /// Create a new AuthManager with the given user records and token TTL.
63    pub fn new(users: HashMap<String, UserRecord>, token_ttl: Duration) -> Self {
64        let mut server_secret = [0u8; 32];
65        rand::RngExt::fill(&mut rand::rng(), &mut server_secret);
66        Self {
67            users,
68            handshakes: RwLock::new(HashMap::new()),
69            tokens: RwLock::new(HashMap::new()),
70            token_ttl,
71            server_secret,
72        }
73    }
74
75    /// Create an AuthManager with no users (auth effectively disabled).
76    pub fn empty() -> Self {
77        Self::new(HashMap::new(), Duration::from_secs(3600))
78    }
79
80    /// Builder method to configure the token TTL.
81    pub fn with_token_ttl(mut self, duration: Duration) -> Self {
82        self.token_ttl = duration;
83        self
84    }
85
86    /// Create an AuthManager from a TOML file.
87    pub fn from_toml(path: &str) -> Result<Self, String> {
88        let users = load_users_from_toml(path)?;
89        Ok(Self::new(users, Duration::from_secs(3600)))
90    }
91
92    /// Create an AuthManager from TOML content string.
93    pub fn from_toml_str(content: &str) -> Result<Self, String> {
94        let users = load_users_from_str(content)?;
95        Ok(Self::new(users, Duration::from_secs(3600)))
96    }
97
98    /// Returns true if authentication is enabled (there are registered users).
99    pub fn is_enabled(&self) -> bool {
100        !self.users.is_empty()
101    }
102
103    /// Derive deterministic fake SCRAM credentials for an unknown username.
104    ///
105    /// Uses HMAC(server_secret, username) so the same unknown username always
106    /// produces the same salt, making the response indistinguishable from a
107    /// real user's challenge to an outside observer.
108    fn fake_credentials(&self, username: &str) -> ScramCredentials {
109        let mut mac = <Hmac<Sha256>>::new_from_slice(&self.server_secret)
110            .expect("HMAC accepts keys of any size");
111        mac.update(username.as_bytes());
112        let fake_salt = mac.finalize().into_bytes();
113
114        // Derive credentials using a throwaway password; the handshake will
115        // always fail at the `handle_scram` step because the attacker does
116        // not know a valid password, but the challenge itself looks normal.
117        derive_credentials("", &fake_salt, DEFAULT_ITERATIONS)
118    }
119
120    /// Handle a HELLO request: look up user, create SCRAM handshake.
121    ///
122    /// `client_first_b64` is the optional base64-encoded client-first-message
123    /// containing the client nonce. If absent, the server generates a nonce
124    /// (but the handshake will fail if the client expects its own nonce).
125    ///
126    /// Returns the `WWW-Authenticate` header value for the 401 response.
127    /// Unknown users receive a fake but plausible challenge to prevent
128    /// username enumeration.
129    pub fn handle_hello(
130        &self,
131        username: &str,
132        client_first_b64: Option<&str>,
133    ) -> Result<String, String> {
134        let credentials = match self.users.get(username) {
135            Some(user_record) => user_record.credentials.clone(),
136            None => self.fake_credentials(username),
137        };
138
139        // Extract client nonce from client-first-message, or generate one
140        let client_nonce = match client_first_b64 {
141            Some(data) => {
142                extract_client_nonce(data).map_err(|e| format!("invalid client-first data: {e}"))?
143            }
144            None => generate_nonce(),
145        };
146
147        // Create server-first-message
148        let (handshake, server_first_b64) =
149            server_first_message(username, &client_nonce, &credentials);
150
151        // Lazy cleanup: remove expired handshakes before inserting.
152        {
153            let now = Instant::now();
154            self.handshakes
155                .write()
156                .retain(|_, (_, created)| now.duration_since(*created) < HANDSHAKE_TTL);
157        }
158
159        // Store handshake with a unique token and timestamp.
160        let handshake_token = Uuid::new_v4().to_string();
161        self.handshakes
162            .write()
163            .insert(handshake_token.clone(), (handshake, Instant::now()));
164
165        // Format the WWW-Authenticate header
166        let www_auth = format_www_authenticate(&handshake_token, "SHA-256", &server_first_b64);
167        Ok(www_auth)
168    }
169
170    /// Handle a SCRAM request: verify client proof, issue auth token.
171    ///
172    /// Returns `(auth_token, authentication_info_header_value)`.
173    pub fn handle_scram(
174        &self,
175        handshake_token: &str,
176        data: &str,
177    ) -> Result<(String, String), String> {
178        // Remove the handshake (one-time use) and check expiry.
179        let (handshake, created_at) = self
180            .handshakes
181            .write()
182            .remove(handshake_token)
183            .ok_or_else(|| "invalid or expired handshake token".to_string())?;
184        if created_at.elapsed() > HANDSHAKE_TTL {
185            return Err("handshake token expired".to_string());
186        }
187
188        let username = handshake.username.clone();
189
190        // Verify client proof
191        let server_sig = server_verify_final(&handshake, data)
192            .map_err(|e| format!("SCRAM verification failed: {e}"))?;
193
194        // Issue auth token
195        let auth_token = Uuid::new_v4().to_string();
196
197        // Look up permissions
198        let permissions = self
199            .users
200            .get(&username)
201            .map(|r| r.permissions.clone())
202            .unwrap_or_default();
203
204        // Store token -> (user, created_at) mapping
205        self.tokens.write().insert(
206            auth_token.clone(),
207            (
208                AuthUser {
209                    username,
210                    permissions,
211                },
212                Instant::now(),
213            ),
214        );
215
216        // Format the server-final data (v=<server_signature>)
217        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
218        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
219        let auth_info = format_auth_info(&auth_token, &server_final_b64);
220
221        Ok((auth_token, auth_info))
222    }
223
224    /// Validate a bearer token and return the associated user.
225    ///
226    /// Returns `None` if the token is unknown or has expired. Expired
227    /// tokens are automatically removed under a single write lock to
228    /// avoid TOCTOU races.
229    pub fn validate_token(&self, token: &str) -> Option<AuthUser> {
230        let mut tokens = self.tokens.write();
231        match tokens.get(token) {
232            Some((user, created_at)) => {
233                if created_at.elapsed() <= self.token_ttl {
234                    Some(user.clone())
235                } else {
236                    // Token expired — remove immediately under the same lock.
237                    tokens.remove(token);
238                    None
239                }
240            }
241            None => None,
242        }
243    }
244
245    /// Remove a bearer token (logout / close).
246    pub fn revoke_token(&self, token: &str) -> bool {
247        self.tokens.write().remove(token).is_some()
248    }
249
250    /// Inject a token directly (for testing). The token is stamped with the
251    /// current instant so it will not be considered expired.
252    #[doc(hidden)]
253    pub fn inject_token(&self, token: String, user: AuthUser) {
254        self.tokens.write().insert(token, (user, Instant::now()));
255    }
256
257    /// Check whether a user has a required permission.
258    pub fn check_permission(user: &AuthUser, required: &str) -> bool {
259        // Admin has all permissions
260        if user.permissions.contains(&"admin".to_string()) {
261            return true;
262        }
263        user.permissions.contains(&required.to_string())
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::users::hash_password;

    /// Manager with an `admin` (read/write/admin) and a `viewer` (read) user,
    /// both with password "s3cret".
    fn make_test_manager() -> AuthManager {
        let hash = hash_password("s3cret");
        let toml_str = format!(
            r#"
[users.admin]
password_hash = "{hash}"
permissions = ["read", "write", "admin"]

[users.viewer]
password_hash = "{hash}"
permissions = ["read"]
"#
        );
        AuthManager::from_toml_str(&toml_str).unwrap()
    }

    /// Pull the handshake token (a hyphenated UUID) out of a
    /// `WWW-Authenticate` challenge.
    fn handshake_token_of(www_auth: &str) -> String {
        www_auth
            .strip_prefix("SCRAM handshakeToken=")
            .expect("challenge starts with handshakeToken")
            .chars()
            .take_while(|c| c.is_ascii_hexdigit() || *c == '-')
            .collect()
    }

    #[test]
    fn empty_manager_is_disabled() {
        let mgr = AuthManager::empty();
        assert!(!mgr.is_enabled());
    }

    #[test]
    fn manager_with_users_is_enabled() {
        let mgr = make_test_manager();
        assert!(mgr.is_enabled());
    }

    #[test]
    fn hello_unknown_user_returns_fake_challenge() {
        let mgr = make_test_manager();
        // Unknown users get a plausible SCRAM challenge instead of an error,
        // preventing username enumeration.
        let result = mgr.handle_hello("nonexistent", None);
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_user_succeeds() {
        let mgr = make_test_manager();
        let result = mgr.handle_hello("admin", None);
        assert!(result.is_ok());
        let www_auth = result.unwrap();
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
    }

    #[test]
    fn hello_known_and_unknown_users_look_similar() {
        let mgr = make_test_manager();
        let known = mgr.handle_hello("admin", None).unwrap();
        let unknown = mgr.handle_hello("nonexistent", None).unwrap();

        // Both responses must have the same structural format so that an
        // attacker cannot distinguish real from fake users.
        assert!(known.starts_with("SCRAM handshakeToken="));
        assert!(unknown.starts_with("SCRAM handshakeToken="));
        assert!(known.contains("hash=SHA-256"));
        assert!(unknown.contains("hash=SHA-256"));
        assert!(known.contains("data="));
        assert!(unknown.contains("data="));
    }

    #[test]
    fn fake_challenge_is_deterministic_per_username() {
        let mgr = make_test_manager();
        // The fake salt must be deterministic so that repeated HELLO requests
        // for the same unknown username produce consistent parameters.
        let creds1 = mgr.fake_credentials("ghost");
        let creds2 = mgr.fake_credentials("ghost");
        assert_eq!(creds1.salt, creds2.salt);
        assert_eq!(creds1.stored_key, creds2.stored_key);
        assert_eq!(creds1.server_key, creds2.server_key);

        // Different usernames produce different fake salts.
        let creds3 = mgr.fake_credentials("phantom");
        assert_ne!(creds1.salt, creds3.salt);
    }

    #[test]
    fn handle_scram_rejects_unknown_handshake_token() {
        let mgr = make_test_manager();
        let err = mgr
            .handle_scram("no-such-handshake", "bm90LXZhbGlk")
            .unwrap_err();
        assert_eq!(err, "invalid or expired handshake token");
    }

    #[test]
    fn handshake_token_is_single_use() {
        let mgr = make_test_manager();
        let www_auth = mgr.handle_hello("admin", None).unwrap();
        let token = handshake_token_of(&www_auth);
        assert_eq!(token.len(), 36);

        // A bogus client-final message fails verification...
        assert!(mgr.handle_scram(&token, "bm90LXZhbGlk").is_err());
        // ...and the handshake was consumed, so a retry is rejected outright.
        let err = mgr.handle_scram(&token, "bm90LXZhbGlk").unwrap_err();
        assert_eq!(err, "invalid or expired handshake token");
    }

    #[test]
    fn validate_token_returns_none_for_unknown() {
        let mgr = make_test_manager();
        assert!(mgr.validate_token("nonexistent-token").is_none());
    }

    #[test]
    fn check_permission_admin_has_all() {
        let user = AuthUser {
            username: "admin".to_string(),
            permissions: vec!["admin".to_string()],
        };
        assert!(AuthManager::check_permission(&user, "read"));
        assert!(AuthManager::check_permission(&user, "write"));
        assert!(AuthManager::check_permission(&user, "admin"));
    }

    #[test]
    fn check_permission_viewer_limited() {
        let user = AuthUser {
            username: "viewer".to_string(),
            permissions: vec!["read".to_string()],
        };
        assert!(AuthManager::check_permission(&user, "read"));
        assert!(!AuthManager::check_permission(&user, "write"));
        assert!(!AuthManager::check_permission(&user, "admin"));
    }

    #[test]
    fn check_permission_empty_denies_everything() {
        let user = AuthUser {
            username: "nobody".to_string(),
            permissions: Vec::new(),
        };
        assert!(!AuthManager::check_permission(&user, "read"));
        assert!(!AuthManager::check_permission(&user, "admin"));
        assert!(!AuthManager::check_permission(&user, ""));
    }

    #[test]
    fn revoke_token_returns_false_for_unknown() {
        let mgr = make_test_manager();
        assert!(!mgr.revoke_token("nonexistent-token"));
    }

    #[test]
    fn revoke_token_removes_existing_token() {
        let mgr = make_test_manager();
        let user = AuthUser {
            username: "viewer".to_string(),
            permissions: vec!["read".to_string()],
        };
        mgr.inject_token("live-token".to_string(), user);
        assert!(mgr.validate_token("live-token").is_some());

        assert!(mgr.revoke_token("live-token"));
        assert!(mgr.validate_token("live-token").is_none());
        // A second revoke finds nothing to remove.
        assert!(!mgr.revoke_token("live-token"));
    }

    #[test]
    fn validate_token_succeeds_before_expiry() {
        let mgr = make_test_manager();
        // Manually insert a token with Instant::now() (fresh, not expired).
        let user = AuthUser {
            username: "admin".to_string(),
            permissions: vec!["admin".to_string()],
        };
        mgr.tokens
            .write()
            .insert("good-token".to_string(), (user, Instant::now()));

        assert!(mgr.validate_token("good-token").is_some());
    }

    #[test]
    fn validate_token_fails_after_expiry() {
        // Use a very short TTL so the token is already expired.
        let mgr = make_test_manager().with_token_ttl(Duration::from_secs(0));

        let user = AuthUser {
            username: "admin".to_string(),
            permissions: vec!["admin".to_string()],
        };
        // Insert a token that was created "now" -- with a 0s TTL it is
        // immediately expired.
        mgr.tokens
            .write()
            .insert("expired-token".to_string(), (user, Instant::now()));

        // Even though the token exists, it should be reported as expired.
        assert!(mgr.validate_token("expired-token").is_none());

        // The expired token should have been removed from the map.
        assert!(mgr.tokens.read().get("expired-token").is_none());
    }

    #[test]
    fn with_token_ttl_sets_custom_duration() {
        let mgr = AuthManager::empty().with_token_ttl(Duration::from_secs(120));
        assert_eq!(mgr.token_ttl, Duration::from_secs(120));
    }
}