1use std::net::SocketAddr;
4
5use async_trait::async_trait;
6
7use super::SessionId;
8use crate::PublicKey;
9
10#[derive(Debug, Clone)]
15pub struct AuthContext {
16 pub username: String,
18 pub remote_addr: SocketAddr,
20 pub session_id: SessionId,
22 pub attempt_count: u32,
24}
25
26impl AuthContext {
27 pub fn new(
29 username: impl Into<String>,
30 remote_addr: SocketAddr,
31 session_id: SessionId,
32 ) -> Self {
33 Self {
34 username: username.into(),
35 remote_addr,
36 session_id,
37 attempt_count: 0,
38 }
39 }
40
41 pub fn with_attempt(mut self, count: u32) -> Self {
43 self.attempt_count = count;
44 self
45 }
46
47 pub fn username(&self) -> &str {
49 &self.username
50 }
51
52 pub fn remote_addr(&self) -> SocketAddr {
54 self.remote_addr
55 }
56
57 pub fn session_id(&self) -> SessionId {
59 self.session_id
60 }
61
62 pub fn attempt_count(&self) -> u32 {
64 self.attempt_count
65 }
66}
67
/// Authentication mechanisms known to this module.
///
/// The [`Display`](std::fmt::Display) form of each variant is its lowercase
/// method name, shown on the individual variants below.
// NOTE(review): the names look like the SSH user-authentication method
// names — confirm against the protocol code that consumes them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthMethod {
    /// Displayed as `none`.
    None,
    /// Displayed as `password`.
    Password,
    /// Displayed as `publickey`.
    PublicKey,
    /// Displayed as `keyboard-interactive`.
    KeyboardInteractive,
    /// Displayed as `hostbased`.
    HostBased,
}

impl std::fmt::Display for AuthMethod {
    /// Writes the method name.
    ///
    /// Formatter options such as width, fill, alignment and precision are
    /// honoured, so `format!("{:>10}", method)` pads as it would for a `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            AuthMethod::None => "none",
            AuthMethod::Password => "password",
            AuthMethod::PublicKey => "publickey",
            AuthMethod::KeyboardInteractive => "keyboard-interactive",
            AuthMethod::HostBased => "hostbased",
        };
        // `pad` applies the caller's width/fill/alignment/precision; writing
        // the literal with `write!` would silently drop them.
        f.pad(name)
    }
}
94
95#[derive(Debug, Clone)]
97pub enum AuthResult {
98 Accept,
100 Reject,
102 Partial {
104 next_methods: Vec<AuthMethod>,
106 },
107}
108
109impl AuthResult {
110 pub fn is_accepted(&self) -> bool {
112 matches!(self, AuthResult::Accept)
113 }
114
115 pub fn is_rejected(&self) -> bool {
117 matches!(self, AuthResult::Reject)
118 }
119
120 pub fn is_partial(&self) -> bool {
122 matches!(self, AuthResult::Partial { .. })
123 }
124}
125
126#[async_trait]
151pub trait AuthHandler: Send + Sync {
152 async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
163 let _ = (ctx, password);
164 AuthResult::Reject
165 }
166
167 async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
178 let _ = (ctx, key);
179 AuthResult::Reject
180 }
181
182 async fn auth_keyboard_interactive(&self, ctx: &AuthContext, response: &str) -> AuthResult {
193 let _ = (ctx, response);
194 AuthResult::Reject
195 }
196
197 async fn auth_none(&self, ctx: &AuthContext) -> AuthResult {
201 let _ = ctx;
202 AuthResult::Reject
203 }
204
205 fn supported_methods(&self) -> Vec<AuthMethod> {
207 vec![AuthMethod::Password, AuthMethod::PublicKey]
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::super::SessionId;
    use super::*;

    /// Handler that overrides nothing, so every hook falls back to the
    /// trait's default body.
    struct RejectAllAuth;

    #[async_trait]
    impl AuthHandler for RejectAllAuth {}

    /// Parses a socket-address literal used as test input.
    fn socket(literal: &str) -> SocketAddr {
        literal.parse().unwrap()
    }

    #[test]
    fn test_auth_context() {
        let addr = socket("192.168.1.1:12345");
        let fresh = AuthContext::new("testuser", addr, SessionId(42));

        assert_eq!(fresh.username(), "testuser");
        assert_eq!(fresh.remote_addr(), addr);
        assert_eq!(fresh.session_id(), SessionId(42));
        assert_eq!(fresh.attempt_count(), 0);

        let retried = fresh.with_attempt(3);
        assert_eq!(retried.attempt_count(), 3);
    }

    #[test]
    fn test_auth_method_display() {
        // One row per variant: the method and the text it must display as.
        let expected = [
            (AuthMethod::None, "none"),
            (AuthMethod::Password, "password"),
            (AuthMethod::PublicKey, "publickey"),
            (AuthMethod::KeyboardInteractive, "keyboard-interactive"),
            (AuthMethod::HostBased, "hostbased"),
        ];

        for (method, name) in expected {
            assert_eq!(method.to_string(), name);
        }
    }

    #[test]
    fn test_auth_result_checks() {
        // Each tuple is (is_accepted, is_rejected, is_partial).
        let accept = AuthResult::Accept;
        assert_eq!(
            (accept.is_accepted(), accept.is_rejected(), accept.is_partial()),
            (true, false, false)
        );

        let reject = AuthResult::Reject;
        assert_eq!(
            (reject.is_accepted(), reject.is_rejected(), reject.is_partial()),
            (false, true, false)
        );

        let partial = AuthResult::Partial {
            next_methods: vec![AuthMethod::Password],
        };
        assert_eq!(
            (partial.is_accepted(), partial.is_rejected(), partial.is_partial()),
            (false, false, true)
        );
    }

    #[tokio::test]
    async fn test_default_auth_handler_rejects() {
        let handler = RejectAllAuth;
        let ctx = AuthContext::new("user", socket("127.0.0.1:22"), SessionId(1));

        let password_outcome = handler.auth_password(&ctx, "pass").await;
        assert!(matches!(password_outcome, AuthResult::Reject));

        let none_outcome = handler.auth_none(&ctx).await;
        assert!(matches!(none_outcome, AuthResult::Reject));
    }
}