Skip to main content

wish/auth/
password.rs

1//! Password authentication handlers.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tracing::{debug, warn};
7
8use super::handler::{AuthContext, AuthHandler, AuthMethod, AuthResult};
9use crate::PublicKey;
10
11/// Authentication handler that accepts all authentication attempts.
12///
13/// **WARNING**: This should only be used for development and testing.
14/// Using this in production is a serious security risk.
15///
16/// # Example
17///
18/// ```rust,ignore
19/// use wish::auth::AcceptAllAuth;
20///
21/// let auth = AcceptAllAuth::new();
22/// ```
23#[derive(Debug, Clone, Copy, Default)]
24pub struct AcceptAllAuth {
25    _private: (),
26}
27
28impl AcceptAllAuth {
29    /// Creates a new AcceptAllAuth handler.
30    pub fn new() -> Self {
31        warn!("AcceptAllAuth in use - NOT FOR PRODUCTION");
32        Self { _private: () }
33    }
34}
35
36#[async_trait]
37impl AuthHandler for AcceptAllAuth {
38    async fn auth_password(&self, ctx: &AuthContext, _password: &str) -> AuthResult {
39        warn!(
40            username = %ctx.username(),
41            remote_addr = %ctx.remote_addr(),
42            "AcceptAllAuth: accepting password auth"
43        );
44        AuthResult::Accept
45    }
46
47    async fn auth_publickey(&self, ctx: &AuthContext, _key: &PublicKey) -> AuthResult {
48        warn!(
49            username = %ctx.username(),
50            remote_addr = %ctx.remote_addr(),
51            "AcceptAllAuth: accepting public key auth"
52        );
53        AuthResult::Accept
54    }
55
56    async fn auth_keyboard_interactive(&self, ctx: &AuthContext, _response: &str) -> AuthResult {
57        warn!(
58            username = %ctx.username(),
59            remote_addr = %ctx.remote_addr(),
60            "AcceptAllAuth: accepting keyboard-interactive auth"
61        );
62        AuthResult::Accept
63    }
64
65    async fn auth_none(&self, ctx: &AuthContext) -> AuthResult {
66        warn!(
67            username = %ctx.username(),
68            remote_addr = %ctx.remote_addr(),
69            "AcceptAllAuth: accepting none auth"
70        );
71        AuthResult::Accept
72    }
73
74    fn supported_methods(&self) -> Vec<AuthMethod> {
75        vec![
76            AuthMethod::None,
77            AuthMethod::Password,
78            AuthMethod::PublicKey,
79            AuthMethod::KeyboardInteractive,
80        ]
81    }
82}
83
84/// Callback-based password authentication handler.
85///
86/// Uses a user-provided callback function to validate passwords.
87///
88/// # Example
89///
90/// ```rust,ignore
91/// use wish::auth::CallbackAuth;
92///
93/// let auth = CallbackAuth::new(|ctx, password| {
94///     ctx.username() == "admin" && password == "secret"
95/// });
96/// ```
97pub struct CallbackAuth<F>
98where
99    F: Fn(&AuthContext, &str) -> bool + Send + Sync,
100{
101    callback: F,
102}
103
104impl<F> CallbackAuth<F>
105where
106    F: Fn(&AuthContext, &str) -> bool + Send + Sync,
107{
108    /// Creates a new callback-based auth handler.
109    pub fn new(callback: F) -> Self {
110        Self { callback }
111    }
112}
113
114#[async_trait]
115impl<F> AuthHandler for CallbackAuth<F>
116where
117    F: Fn(&AuthContext, &str) -> bool + Send + Sync + 'static,
118{
119    async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
120        debug!(
121            username = %ctx.username(),
122            remote_addr = %ctx.remote_addr(),
123            "CallbackAuth: password auth attempt"
124        );
125
126        if (self.callback)(ctx, password) {
127            debug!(username = %ctx.username(), "CallbackAuth: password accepted");
128            AuthResult::Accept
129        } else {
130            debug!(username = %ctx.username(), "CallbackAuth: password rejected");
131            AuthResult::Reject
132        }
133    }
134
135    fn supported_methods(&self) -> Vec<AuthMethod> {
136        vec![AuthMethod::Password]
137    }
138}
139
/// Password authentication backed by an in-memory username → password table.
///
/// Credentials are registered through the builder-style methods and kept as
/// plain strings for the lifetime of the handler.
///
/// # Example
///
/// ```rust,ignore
/// use wish::auth::PasswordAuth;
///
/// let auth = PasswordAuth::new()
///     .add_user("alice", "password123")
///     .add_user("bob", "secret456");
/// ```
#[derive(Default)]
pub struct PasswordAuth {
    // Maps each username to the password it must present.
    users: std::collections::HashMap<String, String>,
}

impl PasswordAuth {
    /// Creates a handler with no registered users.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `username` with `password`, replacing any earlier entry for
    /// the same username, and returns the handler for chaining.
    pub fn add_user(self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.add_users(std::iter::once((username, password)))
    }

    /// Registers every `(username, password)` pair yielded by `users`.
    ///
    /// Pairs are inserted in iteration order, so a later pair overwrites an
    /// earlier one with the same username.
    pub fn add_users<I, U, P>(mut self, users: I) -> Self
    where
        I: IntoIterator<Item = (U, P)>,
        U: Into<String>,
        P: Into<String>,
    {
        self.users.extend(
            users
                .into_iter()
                .map(|(name, secret)| (name.into(), secret.into())),
        );
        self
    }

    /// Returns how many distinct usernames are registered.
    pub fn user_count(&self) -> usize {
        self.users.len()
    }

    /// Returns `true` when `username` has a registered password.
    pub fn has_user(&self, username: &str) -> bool {
        self.users.get(username).is_some()
    }
}
200
201#[async_trait]
202impl AuthHandler for PasswordAuth {
203    async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
204        debug!(
205            username = %ctx.username(),
206            remote_addr = %ctx.remote_addr(),
207            "PasswordAuth: auth attempt"
208        );
209
210        let stored = self.users.get(ctx.username());
211        // Use a dummy string for comparison if user is not found to mitigate timing attacks
212        // against username enumeration (though checking the map itself might still leak timing)
213        let target = stored.map(String::as_str).unwrap_or("");
214
215        if constant_time_eq(target, password) && stored.is_some() {
216            debug!(username = %ctx.username(), "PasswordAuth: accepted");
217            AuthResult::Accept
218        } else {
219            debug!(username = %ctx.username(), "PasswordAuth: rejected");
220            AuthResult::Reject
221        }
222    }
223
224    fn supported_methods(&self) -> Vec<AuthMethod> {
225        vec![AuthMethod::Password]
226    }
227}
228
/// Compares two strings byte-wise without stopping at the first difference.
///
/// Both inputs are walked for the length of the longer one, with the shorter
/// input padded by zero bytes, so the amount of work does not depend on where
/// the inputs diverge. Byte differences are XOR-ed and OR-accumulated; the
/// accumulator is seeded with the XOR of the two lengths so inputs of
/// different length never compare equal.
///
/// Returns `true` only when `a` and `b` have the same length and bytes.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let span = a.len().max(b.len());
    let left = a.bytes().chain(std::iter::repeat(0)).take(span);
    let right = b.bytes().chain(std::iter::repeat(0)).take(span);

    // The accumulator is a `usize` so a length difference that is a multiple
    // of 256 is not truncated away.
    let difference = left
        .zip(right)
        .fold(a.len() ^ b.len(), |acc, (x, y)| acc | usize::from(x ^ y));

    difference == 0
}
247
248/// Async callback-based password authentication handler.
249///
250/// Uses a user-provided async callback function to validate passwords.
251/// Useful for database lookups or remote authentication services.
252///
253/// # Example
254///
255/// ```rust,ignore
256/// use wish::auth::AsyncCallbackAuth;
257/// use std::sync::Arc;
258///
259/// let auth = AsyncCallbackAuth::new(Arc::new(|ctx, password| {
260///     Box::pin(async move {
261///         // Async database lookup
262///         database_check(ctx.username(), password).await
263///     })
264/// }));
265/// ```
266#[allow(dead_code)]
267pub struct AsyncCallbackAuth<F>
268where
269    F: Fn(&AuthContext, &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
270        + Send
271        + Sync,
272{
273    callback: Arc<F>,
274}
275
276impl<F> AsyncCallbackAuth<F>
277where
278    F: Fn(&AuthContext, &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
279        + Send
280        + Sync,
281{
282    /// Creates a new async callback-based auth handler.
283    #[allow(dead_code)]
284    pub fn new(callback: Arc<F>) -> Self {
285        Self { callback }
286    }
287}
288
289#[async_trait]
290impl<F> AuthHandler for AsyncCallbackAuth<F>
291where
292    F: Fn(&AuthContext, &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
293        + Send
294        + Sync
295        + 'static,
296{
297    async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
298        debug!(
299            username = %ctx.username(),
300            remote_addr = %ctx.remote_addr(),
301            "AsyncCallbackAuth: password auth attempt"
302        );
303
304        if (self.callback)(ctx, password).await {
305            debug!(username = %ctx.username(), "AsyncCallbackAuth: password accepted");
306            AuthResult::Accept
307        } else {
308            debug!(username = %ctx.username(), "AsyncCallbackAuth: password rejected");
309            AuthResult::Reject
310        }
311    }
312
313    fn supported_methods(&self) -> Vec<AuthMethod> {
314        vec![AuthMethod::Password]
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::super::SessionId;
    use super::*;
    use std::net::SocketAddr;
    use std::sync::Arc;

    /// Builds an `AuthContext` for `username` coming from 127.0.0.1:22 with
    /// session id 1.
    fn make_context(username: &str) -> AuthContext {
        let loopback = "127.0.0.1:22".parse::<SocketAddr>().unwrap();
        AuthContext::new(username, loopback, SessionId(1))
    }

    /// `AcceptAllAuth` accepts both a password attempt and the `none` method.
    #[tokio::test]
    async fn test_accept_all_auth() {
        let auth = AcceptAllAuth::new();
        let ctx = make_context("anyone");

        let via_password = auth.auth_password(&ctx, "anything").await;
        assert!(matches!(via_password, AuthResult::Accept));

        let via_none = auth.auth_none(&ctx).await;
        assert!(matches!(via_none, AuthResult::Accept));
    }

    /// `CallbackAuth` accepts exactly the combination its predicate approves.
    #[tokio::test]
    async fn test_callback_auth() {
        let auth =
            CallbackAuth::new(|ctx, password| ctx.username() == "admin" && password == "secret");

        let admin = make_context("admin");
        let right_password = auth.auth_password(&admin, "secret").await;
        assert!(matches!(right_password, AuthResult::Accept));
        let wrong_password = auth.auth_password(&admin, "wrong").await;
        assert!(matches!(wrong_password, AuthResult::Reject));

        let other = make_context("user");
        let wrong_user = auth.auth_password(&other, "secret").await;
        assert!(matches!(wrong_user, AuthResult::Reject));
    }

    /// `PasswordAuth` accepts a registered user with the right password and
    /// rejects wrong passwords and unknown users.
    #[tokio::test]
    async fn test_password_auth() {
        let auth = PasswordAuth::new()
            .add_user("alice", "password123")
            .add_user("bob", "secret456");

        assert_eq!(auth.user_count(), 2);
        assert!(auth.has_user("alice"));
        assert!(!auth.has_user("charlie"));

        let alice = make_context("alice");
        let right_password = auth.auth_password(&alice, "password123").await;
        assert!(matches!(right_password, AuthResult::Accept));
        let wrong_password = auth.auth_password(&alice, "wrong").await;
        assert!(matches!(wrong_password, AuthResult::Reject));

        let charlie = make_context("charlie");
        let unknown_user = auth.auth_password(&charlie, "any").await;
        assert!(matches!(unknown_user, AuthResult::Reject));
    }

    /// `add_users` registers every pair from the iterator.
    #[test]
    fn test_password_auth_add_users() {
        let auth = PasswordAuth::new().add_users(vec![("user1", "pass1"), ("user2", "pass2")]);
        assert_eq!(auth.user_count(), 2);
        assert!(auth.has_user("user1"));
        assert!(auth.has_user("user2"));
    }

    /// `AsyncCallbackAuth` follows the verdict of its awaited callback.
    #[tokio::test]
    async fn test_async_callback_auth() {
        type Verdict = std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>;

        let auth = AsyncCallbackAuth::new(Arc::new(|ctx: &AuthContext, password: &str| {
            // Own the data so the boxed future does not borrow from the arguments.
            let username = ctx.username().to_string();
            let password = password.to_string();
            let fut: Verdict = Box::pin(async move { username == "admin" && password == "secret" });
            fut
        }));

        let admin = make_context("admin");
        let right_password = auth.auth_password(&admin, "secret").await;
        assert!(matches!(right_password, AuthResult::Accept));
        let wrong_password = auth.auth_password(&admin, "wrong").await;
        assert!(matches!(wrong_password, AuthResult::Reject));
    }

    /// Equal, different, prefix and empty inputs.
    #[test]
    fn test_constant_time_eq_basic() {
        assert!(constant_time_eq("hello", "hello"));
        assert!(!constant_time_eq("hello", "world"));
        assert!(!constant_time_eq("hello", "hell"));
        assert!(!constant_time_eq("", "a"));
        assert!(constant_time_eq("", ""));
    }

    /// Regression test for the u8 truncation bug: with the old code,
    /// `(0 ^ 256) as u8 == 0`, which seeded the result as "equal" when the
    /// lengths differed by 256.
    #[test]
    fn test_constant_time_eq_length_differs_by_256() {
        let empty = "";
        let long = "a".repeat(256);
        assert!(!constant_time_eq(empty, &long));
        assert!(!constant_time_eq(&long, empty));

        // Non-empty strings whose lengths differ by 256.
        let short = "x";
        let extended = format!("x{}", "y".repeat(256));
        assert!(!constant_time_eq(short, &extended));
    }
}
445}