1mod authorized_keys;
20mod handler;
21mod password;
22mod publickey;
23
24pub use authorized_keys::{AuthorizedKey, AuthorizedKeysAuth, parse_authorized_keys};
25pub use handler::{AuthContext, AuthHandler, AuthMethod, AuthResult};
26pub use password::{AcceptAllAuth, AsyncCallbackAuth, CallbackAuth, PasswordAuth};
27pub use publickey::{AsyncPublicKeyAuth, PublicKeyAuth, PublicKeyCallbackAuth};
28
29use std::sync::Arc;
30
31use crate::PublicKey;
32
/// Opaque numeric identifier of a session, carried in an `AuthContext`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(pub u64);

impl std::fmt::Display for SessionId {
    /// Formats the inner number exactly like a bare `u64`.
    ///
    /// Delegating to `u64`'s `Display` (rather than `write!(f, "{}", ..)`)
    /// keeps the caller's width/fill/alignment/zero-pad flags, so e.g.
    /// `format!("{:>6}", id)` pads as expected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
42
/// Default delay, in milliseconds, that `RateLimitedAuth::new` applies after
/// each rejected authentication attempt.
pub const DEFAULT_AUTH_REJECTION_DELAY_MS: u64 = 100;

/// Default maximum-attempts value that `RateLimitedAuth::new` configures.
///
/// NOTE(review): nothing in this module counts attempts against this limit;
/// presumably the session/connection layer enforces it — confirm.
pub const DEFAULT_MAX_AUTH_ATTEMPTS: u32 = 6;
48
49pub struct CompositeAuth {
53 handlers: Vec<Arc<dyn AuthHandler>>,
54}
55
56impl CompositeAuth {
57 pub fn new() -> Self {
59 Self {
60 handlers: Vec::new(),
61 }
62 }
63
64 #[allow(clippy::should_implement_trait)]
66 pub fn add<H: AuthHandler + 'static>(mut self, handler: H) -> Self {
67 self.handlers.push(Arc::new(handler));
68 self
69 }
70}
71
72impl Default for CompositeAuth {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78#[async_trait::async_trait]
79impl AuthHandler for CompositeAuth {
80 async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
81 for handler in &self.handlers {
82 match handler.auth_password(ctx, password).await {
83 AuthResult::Accept => return AuthResult::Accept,
84 AuthResult::Partial { next_methods } => {
85 return AuthResult::Partial { next_methods };
86 }
87 AuthResult::Reject => continue,
88 }
89 }
90 AuthResult::Reject
91 }
92
93 async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
94 for handler in &self.handlers {
95 match handler.auth_publickey(ctx, key).await {
96 AuthResult::Accept => return AuthResult::Accept,
97 AuthResult::Partial { next_methods } => {
98 return AuthResult::Partial { next_methods };
99 }
100 AuthResult::Reject => continue,
101 }
102 }
103 AuthResult::Reject
104 }
105
106 async fn auth_keyboard_interactive(&self, ctx: &AuthContext, response: &str) -> AuthResult {
107 for handler in &self.handlers {
108 match handler.auth_keyboard_interactive(ctx, response).await {
109 AuthResult::Accept => return AuthResult::Accept,
110 AuthResult::Partial { next_methods } => {
111 return AuthResult::Partial { next_methods };
112 }
113 AuthResult::Reject => continue,
114 }
115 }
116 AuthResult::Reject
117 }
118}
119
120pub struct RateLimitedAuth<H> {
125 inner: H,
126 rejection_delay_ms: u64,
127 max_attempts: u32,
128}
129
130impl<H: AuthHandler> RateLimitedAuth<H> {
131 pub fn new(inner: H) -> Self {
133 Self {
134 inner,
135 rejection_delay_ms: DEFAULT_AUTH_REJECTION_DELAY_MS,
136 max_attempts: DEFAULT_MAX_AUTH_ATTEMPTS,
137 }
138 }
139
140 pub fn with_rejection_delay(mut self, delay_ms: u64) -> Self {
142 self.rejection_delay_ms = delay_ms;
143 self
144 }
145
146 pub fn with_max_attempts(mut self, max: u32) -> Self {
148 self.max_attempts = max;
149 self
150 }
151
152 pub fn max_attempts(&self) -> u32 {
154 self.max_attempts
155 }
156
157 async fn apply_rejection_delay(&self) {
158 if self.rejection_delay_ms > 0 {
159 tokio::time::sleep(std::time::Duration::from_millis(self.rejection_delay_ms)).await;
160 }
161 }
162}
163
164#[async_trait::async_trait]
165impl<H: AuthHandler + Send + Sync> AuthHandler for RateLimitedAuth<H> {
166 async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
167 let result = self.inner.auth_password(ctx, password).await;
168 if matches!(result, AuthResult::Reject) {
169 self.apply_rejection_delay().await;
170 }
171 result
172 }
173
174 async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
175 let result = self.inner.auth_publickey(ctx, key).await;
176 if matches!(result, AuthResult::Reject) {
177 self.apply_rejection_delay().await;
178 }
179 result
180 }
181
182 async fn auth_keyboard_interactive(&self, ctx: &AuthContext, response: &str) -> AuthResult {
183 let result = self.inner.auth_keyboard_interactive(ctx, response).await;
184 if matches!(result, AuthResult::Reject) {
185 self.apply_rejection_delay().await;
186 }
187 result
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use std::time::Duration;

    /// Handler that relies entirely on the trait's default methods; these
    /// tests treat those defaults as rejecting.
    struct RejectAuth;

    #[async_trait::async_trait]
    impl AuthHandler for RejectAuth {}

    /// Handler whose password check always asks for a public key next.
    struct PartialAuth;

    #[async_trait::async_trait]
    impl AuthHandler for PartialAuth {
        async fn auth_password(&self, _ctx: &AuthContext, _password: &str) -> AuthResult {
            AuthResult::Partial {
                next_methods: vec![AuthMethod::PublicKey],
            }
        }
    }

    /// Client address used for every test `AuthContext`.
    fn loopback() -> SocketAddr {
        "127.0.0.1:22".parse().expect("valid socket address literal")
    }

    #[test]
    fn test_session_id() {
        let id = SessionId(42);
        assert_eq!(id.0, 42);
        assert_eq!(format!("{}", id), "42");
    }

    #[test]
    fn test_composite_auth_empty() {
        let auth = CompositeAuth::new();
        assert!(auth.handlers.is_empty());
    }

    #[test]
    fn test_composite_auth_default_is_empty() {
        let auth = CompositeAuth::default();
        assert!(auth.handlers.is_empty());
    }

    #[tokio::test]
    async fn test_composite_auth_empty_rejects() {
        let auth = CompositeAuth::new();
        let ctx = AuthContext::new("testuser", loopback(), SessionId(1));

        let password = auth.auth_password(&ctx, "password").await;
        assert!(matches!(password, AuthResult::Reject));

        let interactive = auth.auth_keyboard_interactive(&ctx, "response").await;
        assert!(matches!(interactive, AuthResult::Reject));
    }

    #[tokio::test]
    async fn test_composite_auth_accepts_first() {
        let auth = CompositeAuth::new().add(AcceptAllAuth::new());
        let ctx = AuthContext::new("testuser", loopback(), SessionId(1));

        let result = auth.auth_password(&ctx, "password").await;
        assert!(matches!(result, AuthResult::Accept));
    }

    #[tokio::test]
    async fn test_composite_auth_rejects_all() {
        let auth = CompositeAuth::new().add(RejectAuth);
        let ctx = AuthContext::new("testuser", loopback(), SessionId(1));

        let result = auth.auth_password(&ctx, "password").await;
        assert!(matches!(result, AuthResult::Reject));
    }

    #[tokio::test]
    async fn test_composite_auth_falls_through_rejection() {
        // A rejection must not end the chain: the next handler gets a turn.
        let auth = CompositeAuth::new()
            .add(RejectAuth)
            .add(AcceptAllAuth::new());
        let ctx = AuthContext::new("testuser", loopback(), SessionId(1));

        let result = auth.auth_password(&ctx, "password").await;
        assert!(matches!(result, AuthResult::Accept));
    }

    #[tokio::test]
    async fn test_composite_auth_partial() {
        let auth = CompositeAuth::new()
            .add(PartialAuth)
            .add(AcceptAllAuth::new());
        let ctx = AuthContext::new("testuser", loopback(), SessionId(1));

        let result = auth.auth_password(&ctx, "password").await;
        match result {
            AuthResult::Partial { next_methods } => {
                assert_eq!(next_methods, vec![AuthMethod::PublicKey]);
            }
            _ => panic!("Expected partial auth result"),
        }
    }

    #[test]
    fn test_rate_limited_auth_defaults() {
        let auth = RateLimitedAuth::new(AcceptAllAuth::new());

        assert_eq!(auth.rejection_delay_ms, DEFAULT_AUTH_REJECTION_DELAY_MS);
        assert_eq!(auth.max_attempts(), DEFAULT_MAX_AUTH_ATTEMPTS);
    }

    #[test]
    fn test_rate_limited_auth_settings() {
        let inner = AcceptAllAuth::new();
        let auth = RateLimitedAuth::new(inner)
            .with_rejection_delay(200)
            .with_max_attempts(3);

        assert_eq!(auth.rejection_delay_ms, 200);
        assert_eq!(auth.max_attempts(), 3);
    }

    #[tokio::test]
    async fn test_rate_limited_auth_delay_on_reject() {
        let inner = RejectAuth;
        let auth = RateLimitedAuth::new(inner).with_rejection_delay(20);
        let ctx = AuthContext::new("testuser", loopback(), SessionId(1));

        let start = tokio::time::Instant::now();
        let result = auth.auth_password(&ctx, "password").await;
        let elapsed = start.elapsed();

        assert!(matches!(result, AuthResult::Reject));
        assert!(elapsed >= Duration::from_millis(15));
    }

    #[tokio::test]
    async fn test_rate_limited_auth_no_delay_on_accept() {
        // A huge delay makes any accidental sleep on success obvious.
        let auth = RateLimitedAuth::new(AcceptAllAuth::new()).with_rejection_delay(10_000);
        let ctx = AuthContext::new("testuser", loopback(), SessionId(1));

        let start = tokio::time::Instant::now();
        let result = auth.auth_password(&ctx, "password").await;
        let elapsed = start.elapsed();

        assert!(matches!(result, AuthResult::Accept));
        assert!(elapsed < Duration::from_secs(1));
    }
}