// wish/auth/handler.rs

//! Core authentication handler trait and types.
2
3use std::net::SocketAddr;
4
5use async_trait::async_trait;
6
7use super::SessionId;
8use crate::PublicKey;
9
10/// Context provided to authentication handlers.
11///
12/// Contains information about the authentication attempt including
13/// the username, remote address, and session identifier.
14#[derive(Debug, Clone)]
15pub struct AuthContext {
16    /// The username attempting authentication.
17    pub username: String,
18    /// The remote address of the client.
19    pub remote_addr: SocketAddr,
20    /// The session ID for this connection.
21    pub session_id: SessionId,
22    /// Number of authentication attempts so far.
23    pub attempt_count: u32,
24}
25
26impl AuthContext {
27    /// Creates a new authentication context.
28    pub fn new(
29        username: impl Into<String>,
30        remote_addr: SocketAddr,
31        session_id: SessionId,
32    ) -> Self {
33        Self {
34            username: username.into(),
35            remote_addr,
36            session_id,
37            attempt_count: 0,
38        }
39    }
40
41    /// Creates a context with an incremented attempt count.
42    pub fn with_attempt(mut self, count: u32) -> Self {
43        self.attempt_count = count;
44        self
45    }
46
47    /// Returns the username.
48    pub fn username(&self) -> &str {
49        &self.username
50    }
51
52    /// Returns the remote address.
53    pub fn remote_addr(&self) -> SocketAddr {
54        self.remote_addr
55    }
56
57    /// Returns the session ID.
58    pub fn session_id(&self) -> SessionId {
59        self.session_id
60    }
61
62    /// Returns the current attempt count.
63    pub fn attempt_count(&self) -> u32 {
64        self.attempt_count
65    }
66}
67
/// The SSH user-authentication methods a client may request.
///
/// The `Display` form of each variant is its SSH method-name string
/// (for example `"publickey"` or `"keyboard-interactive"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    /// The `none` method: no credentials (anonymous access).
    None,
    /// Password authentication.
    Password,
    /// Public-key authentication.
    PublicKey,
    /// Keyboard-interactive authentication.
    KeyboardInteractive,
    /// Host-based authentication.
    HostBased,
}

impl std::fmt::Display for AuthMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Method names as spelled by SSH (RFC 4252 / RFC 4256). Written
        // verbatim; width/fill flags are not applied.
        let name = match *self {
            Self::None => "none",
            Self::Password => "password",
            Self::PublicKey => "publickey",
            Self::KeyboardInteractive => "keyboard-interactive",
            Self::HostBased => "hostbased",
        };
        f.write_str(name)
    }
}
94
95/// Result of an authentication attempt.
96#[derive(Debug, Clone)]
97pub enum AuthResult {
98    /// Authentication was successful.
99    Accept,
100    /// Authentication was rejected.
101    Reject,
102    /// Authentication partially succeeded, continue with additional methods.
103    Partial {
104        /// Methods to continue with.
105        next_methods: Vec<AuthMethod>,
106    },
107}
108
109impl AuthResult {
110    /// Returns true if the authentication was accepted.
111    pub fn is_accepted(&self) -> bool {
112        matches!(self, AuthResult::Accept)
113    }
114
115    /// Returns true if the authentication was rejected.
116    pub fn is_rejected(&self) -> bool {
117        matches!(self, AuthResult::Reject)
118    }
119
120    /// Returns true if partial authentication is required.
121    pub fn is_partial(&self) -> bool {
122        matches!(self, AuthResult::Partial { .. })
123    }
124}
125
126/// Trait for implementing authentication handlers.
127///
128/// Authentication handlers decide whether to accept or reject
129/// authentication attempts based on credentials provided.
130///
131/// # Example
132///
133/// ```rust,ignore
134/// use wish::auth::{AuthHandler, AuthContext, AuthResult};
135/// use async_trait::async_trait;
136///
137/// struct MyAuth;
138///
139/// #[async_trait]
140/// impl AuthHandler for MyAuth {
141///     async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
142///         if ctx.username() == "admin" && password == "secret" {
143///             AuthResult::Accept
144///         } else {
145///             AuthResult::Reject
146///         }
147///     }
148/// }
149/// ```
150#[async_trait]
151pub trait AuthHandler: Send + Sync {
152    /// Authenticate with password.
153    ///
154    /// # Arguments
155    ///
156    /// * `ctx` - The authentication context.
157    /// * `password` - The password provided by the client.
158    ///
159    /// # Returns
160    ///
161    /// The authentication result.
162    async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
163        let _ = (ctx, password);
164        AuthResult::Reject
165    }
166
167    /// Authenticate with public key.
168    ///
169    /// # Arguments
170    ///
171    /// * `ctx` - The authentication context.
172    /// * `key` - The public key provided by the client.
173    ///
174    /// # Returns
175    ///
176    /// The authentication result.
177    async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
178        let _ = (ctx, key);
179        AuthResult::Reject
180    }
181
182    /// Authenticate with keyboard-interactive.
183    ///
184    /// # Arguments
185    ///
186    /// * `ctx` - The authentication context.
187    /// * `response` - The response provided by the client.
188    ///
189    /// # Returns
190    ///
191    /// The authentication result.
192    async fn auth_keyboard_interactive(&self, ctx: &AuthContext, response: &str) -> AuthResult {
193        let _ = (ctx, response);
194        AuthResult::Reject
195    }
196
197    /// Check if "none" authentication is allowed.
198    ///
199    /// By default, returns `Reject`. Override to allow anonymous access.
200    async fn auth_none(&self, ctx: &AuthContext) -> AuthResult {
201        let _ = ctx;
202        AuthResult::Reject
203    }
204
205    /// Returns the authentication methods supported by this handler.
206    fn supported_methods(&self) -> Vec<AuthMethod> {
207        vec![AuthMethod::Password, AuthMethod::PublicKey]
208    }
209}
210
#[cfg(test)]
mod tests {
    // The glob import also brings in the parent module's own imports
    // (`SessionId`, `SocketAddr`, `async_trait`), so no further `use` is needed.
    use super::*;

    #[test]
    fn test_auth_context() {
        let addr: SocketAddr = "192.168.1.1:12345".parse().unwrap();
        let ctx = AuthContext::new("testuser", addr, SessionId(42));

        assert_eq!(ctx.username(), "testuser");
        assert_eq!(ctx.remote_addr(), addr);
        assert_eq!(ctx.session_id(), SessionId(42));
        assert_eq!(ctx.attempt_count(), 0);

        let ctx = ctx.with_attempt(3);
        assert_eq!(ctx.attempt_count(), 3);
    }

    #[test]
    fn test_auth_method_display() {
        assert_eq!(format!("{}", AuthMethod::None), "none");
        assert_eq!(format!("{}", AuthMethod::Password), "password");
        assert_eq!(format!("{}", AuthMethod::PublicKey), "publickey");
        assert_eq!(
            format!("{}", AuthMethod::KeyboardInteractive),
            "keyboard-interactive"
        );
        assert_eq!(format!("{}", AuthMethod::HostBased), "hostbased");
    }

    #[test]
    fn test_auth_result_checks() {
        let accept = AuthResult::Accept;
        assert!(accept.is_accepted());
        assert!(!accept.is_rejected());
        assert!(!accept.is_partial());

        let reject = AuthResult::Reject;
        assert!(!reject.is_accepted());
        assert!(reject.is_rejected());
        assert!(!reject.is_partial());

        let partial = AuthResult::Partial {
            next_methods: vec![AuthMethod::Password],
        };
        assert!(!partial.is_accepted());
        assert!(!partial.is_rejected());
        assert!(partial.is_partial());
    }

    /// Handler that relies entirely on the trait's default methods.
    struct RejectAllAuth;

    #[async_trait]
    impl AuthHandler for RejectAllAuth {}

    #[tokio::test]
    async fn test_default_auth_handler_rejects() {
        let handler = RejectAllAuth;
        let addr: SocketAddr = "127.0.0.1:22".parse().unwrap();
        let ctx = AuthContext::new("user", addr, SessionId(1));

        assert!(matches!(
            handler.auth_password(&ctx, "pass").await,
            AuthResult::Reject
        ));
        assert!(matches!(
            handler.auth_keyboard_interactive(&ctx, "response").await,
            AuthResult::Reject
        ));
        assert!(matches!(handler.auth_none(&ctx).await, AuthResult::Reject));
    }

    #[test]
    fn test_default_supported_methods() {
        assert_eq!(
            RejectAllAuth.supported_methods(),
            vec![AuthMethod::Password, AuthMethod::PublicKey]
        );
    }
}