Skip to main content

wish/auth/
mod.rs

//! Authentication module for Wish SSH server.
//!
//! This module provides flexible authentication handlers supporting
//! password, public key, and keyboard-interactive authentication methods.
//! Handlers can be chained with [`CompositeAuth`], which consults them in
//! order, and wrapped in [`RateLimitedAuth`], which delays rejected attempts.
//!
//! # Example
//!
//! ```rust,ignore
//! use wish::auth::{AuthHandler, AcceptAllAuth, AuthorizedKeysAuth};
//!
//! // Development: accept all connections
//! let dev_auth = AcceptAllAuth::new();
//!
//! // Production: use authorized_keys file
//! let prod_auth = AuthorizedKeysAuth::new("~/.ssh/authorized_keys")
//!     .expect("Failed to load authorized_keys");
//! ```
18
19mod authorized_keys;
20mod handler;
21mod password;
22mod publickey;
23
24pub use authorized_keys::{AuthorizedKey, AuthorizedKeysAuth, parse_authorized_keys};
25pub use handler::{AuthContext, AuthHandler, AuthMethod, AuthResult};
26pub use password::{AcceptAllAuth, AsyncCallbackAuth, CallbackAuth, PasswordAuth};
27pub use publickey::{AsyncPublicKeyAuth, PublicKeyAuth, PublicKeyCallbackAuth};
28
29use std::sync::Arc;
30
31use crate::PublicKey;
32
/// Identifier of an SSH session, used to track authentication attempts.
///
/// Wraps a raw `u64`. Its [`Display`](std::fmt::Display) output is the bare
/// number, and it honours the caller's formatting flags (width, fill,
/// alignment, zero-padding) exactly as `u64` does.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(pub u64);

impl std::fmt::Display for SessionId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `u64`'s impl so that `{:>8}`-style flags apply to the id.
        // `write!(f, "{}", self.0)` would start a fresh format and drop them.
        std::fmt::Display::fmt(&self.0, f)
    }
}
42
/// Milliseconds to wait after a rejected authentication attempt before
/// reporting the rejection; the default used by [`RateLimitedAuth::new`].
pub const DEFAULT_AUTH_REJECTION_DELAY_MS: u64 = 100;

/// Number of authentication attempts a client gets before being
/// disconnected; the default reported by [`RateLimitedAuth::max_attempts`].
pub const DEFAULT_MAX_AUTH_ATTEMPTS: u32 = 6;
48
49/// Composite authentication handler that tries multiple handlers in order.
50///
51/// Returns `Accept` if any handler accepts, `Reject` if all reject.
52pub struct CompositeAuth {
53    handlers: Vec<Arc<dyn AuthHandler>>,
54}
55
56impl CompositeAuth {
57    /// Creates a new composite auth handler.
58    pub fn new() -> Self {
59        Self {
60            handlers: Vec::new(),
61        }
62    }
63
64    /// Adds an authentication handler.
65    #[allow(clippy::should_implement_trait)]
66    pub fn add<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
67        self.handlers.push(Arc::new(handler));
68        self
69    }
70}
71
72impl Default for CompositeAuth {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78#[async_trait::async_trait]
79impl AuthHandler for CompositeAuth {
80    async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
81        for handler in &self.handlers {
82            match handler.auth_password(ctx, password).await {
83                AuthResult::Accept => return AuthResult::Accept,
84                AuthResult::Partial { next_methods } => {
85                    return AuthResult::Partial { next_methods };
86                }
87                AuthResult::Reject => continue,
88            }
89        }
90        AuthResult::Reject
91    }
92
93    async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
94        for handler in &self.handlers {
95            match handler.auth_publickey(ctx, key).await {
96                AuthResult::Accept => return AuthResult::Accept,
97                AuthResult::Partial { next_methods } => {
98                    return AuthResult::Partial { next_methods };
99                }
100                AuthResult::Reject => continue,
101            }
102        }
103        AuthResult::Reject
104    }
105
106    async fn auth_keyboard_interactive(&self, ctx: &AuthContext, response: &str) -> AuthResult {
107        for handler in &self.handlers {
108            match handler.auth_keyboard_interactive(ctx, response).await {
109                AuthResult::Accept => return AuthResult::Accept,
110                AuthResult::Partial { next_methods } => {
111                    return AuthResult::Partial { next_methods };
112                }
113                AuthResult::Reject => continue,
114            }
115        }
116        AuthResult::Reject
117    }
118}
119
120/// Rate-limited authentication wrapper.
121///
122/// Adds a delay after failed authentication attempts to mitigate
123/// brute-force attacks and timing attacks.
124pub struct RateLimitedAuth<H> {
125    inner: H,
126    rejection_delay_ms: u64,
127    max_attempts: u32,
128}
129
130impl<H: AuthHandler> RateLimitedAuth<H> {
131    /// Creates a new rate-limited auth wrapper with default settings.
132    pub fn new(inner: H) -> Self {
133        Self {
134            inner,
135            rejection_delay_ms: DEFAULT_AUTH_REJECTION_DELAY_MS,
136            max_attempts: DEFAULT_MAX_AUTH_ATTEMPTS,
137        }
138    }
139
140    /// Sets the rejection delay in milliseconds.
141    pub fn with_rejection_delay(mut self, delay_ms: u64) -> Self {
142        self.rejection_delay_ms = delay_ms;
143        self
144    }
145
146    /// Sets the maximum authentication attempts.
147    pub fn with_max_attempts(mut self, max: u32) -> Self {
148        self.max_attempts = max;
149        self
150    }
151
152    /// Returns the maximum authentication attempts.
153    pub fn max_attempts(&self) -> u32 {
154        self.max_attempts
155    }
156
157    async fn apply_rejection_delay(&self) {
158        if self.rejection_delay_ms > 0 {
159            tokio::time::sleep(std::time::Duration::from_millis(self.rejection_delay_ms)).await;
160        }
161    }
162}
163
164#[async_trait::async_trait]
165impl<H: AuthHandler + Send + Sync> AuthHandler for RateLimitedAuth<H> {
166    async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
167        let result = self.inner.auth_password(ctx, password).await;
168        if matches!(result, AuthResult::Reject) {
169            self.apply_rejection_delay().await;
170        }
171        result
172    }
173
174    async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
175        let result = self.inner.auth_publickey(ctx, key).await;
176        if matches!(result, AuthResult::Reject) {
177            self.apply_rejection_delay().await;
178        }
179        result
180    }
181
182    async fn auth_keyboard_interactive(&self, ctx: &AuthContext, response: &str) -> AuthResult {
183        let result = self.inner.auth_keyboard_interactive(ctx, response).await;
184        if matches!(result, AuthResult::Reject) {
185            self.apply_rejection_delay().await;
186        }
187        result
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use std::time::Duration;

    /// Handler that overrides nothing, so every method uses the trait
    /// defaults; the tests below rely on those defaults rejecting.
    struct RejectAuth;

    #[async_trait::async_trait]
    impl AuthHandler for RejectAuth {}

    /// Handler whose password check always requests a further public-key step.
    struct PartialAuth;

    #[async_trait::async_trait]
    impl AuthHandler for PartialAuth {
        async fn auth_password(&self, _ctx: &AuthContext, _password: &str) -> AuthResult {
            AuthResult::Partial {
                next_methods: vec![AuthMethod::PublicKey],
            }
        }
    }

    /// Builds the context shared by the async tests: user `testuser`
    /// connecting from 127.0.0.1:22 in session 1.
    fn test_ctx() -> AuthContext {
        let addr: SocketAddr = "127.0.0.1:22".parse().unwrap();
        AuthContext::new("testuser", addr, SessionId(1))
    }

    #[test]
    fn test_session_id() {
        let id = SessionId(42);
        assert_eq!(id.0, 42);
        assert_eq!(format!("{}", id), "42");
    }

    #[test]
    fn test_composite_auth_empty() {
        let auth = CompositeAuth::new();
        assert!(auth.handlers.is_empty());
    }

    #[tokio::test]
    async fn test_composite_auth_empty_rejects() {
        let auth = CompositeAuth::new();
        let ctx = test_ctx();

        let result = auth.auth_password(&ctx, "password").await;
        assert!(matches!(result, AuthResult::Reject));
    }

    #[tokio::test]
    async fn test_composite_auth_accepts_first() {
        let auth = CompositeAuth::new().add(AcceptAllAuth::new());
        let ctx = test_ctx();

        let result = auth.auth_password(&ctx, "password").await;
        assert!(matches!(result, AuthResult::Accept));
    }

    #[tokio::test]
    async fn test_composite_auth_rejects_all() {
        let auth = CompositeAuth::new().add(RejectAuth);
        let ctx = test_ctx();

        let result = auth.auth_password(&ctx, "password").await;
        assert!(matches!(result, AuthResult::Reject));
    }

    #[tokio::test]
    async fn test_composite_auth_falls_through_rejection() {
        // A rejecting handler must not stop later handlers from being consulted.
        let auth = CompositeAuth::new()
            .add(RejectAuth)
            .add(AcceptAllAuth::new());
        let ctx = test_ctx();

        let result = auth.auth_password(&ctx, "password").await;
        assert!(matches!(result, AuthResult::Accept));
    }

    #[tokio::test]
    async fn test_composite_auth_partial() {
        // A `Partial` answer short-circuits even though a later handler accepts.
        let auth = CompositeAuth::new()
            .add(PartialAuth)
            .add(AcceptAllAuth::new());
        let ctx = test_ctx();

        let result = auth.auth_password(&ctx, "password").await;
        match result {
            AuthResult::Partial { next_methods } => {
                assert_eq!(next_methods, vec![AuthMethod::PublicKey]);
            }
            _ => panic!("Expected partial auth result"),
        }
    }

    #[test]
    fn test_rate_limited_auth_defaults() {
        let auth = RateLimitedAuth::new(AcceptAllAuth::new());

        assert_eq!(auth.rejection_delay_ms, DEFAULT_AUTH_REJECTION_DELAY_MS);
        assert_eq!(auth.max_attempts(), DEFAULT_MAX_AUTH_ATTEMPTS);
    }

    #[test]
    fn test_rate_limited_auth_settings() {
        let inner = AcceptAllAuth::new();
        let auth = RateLimitedAuth::new(inner)
            .with_rejection_delay(200)
            .with_max_attempts(3);

        assert_eq!(auth.rejection_delay_ms, 200);
        assert_eq!(auth.max_attempts(), 3);
    }

    #[tokio::test]
    async fn test_rate_limited_auth_delay_on_reject() {
        let auth = RateLimitedAuth::new(RejectAuth).with_rejection_delay(20);
        let ctx = test_ctx();

        let start = tokio::time::Instant::now();
        let result = auth.auth_password(&ctx, "password").await;
        let elapsed = start.elapsed();

        assert!(matches!(result, AuthResult::Reject));
        assert!(elapsed >= Duration::from_millis(15));
    }

    #[tokio::test]
    async fn test_rate_limited_auth_no_delay_on_accept() {
        // With a 60 s rejection delay, an accepted attempt must still return
        // promptly; the timeout turns a regression into a failure, not a hang.
        let auth = RateLimitedAuth::new(AcceptAllAuth::new()).with_rejection_delay(60_000);
        let ctx = test_ctx();

        let attempt = auth.auth_password(&ctx, "password");
        let result = tokio::time::timeout(Duration::from_secs(5), attempt)
            .await
            .expect("accepted attempts must not be delayed");
        assert!(matches!(result, AuthResult::Accept));
    }
}