1use std::sync::Arc;
4
5use async_trait::async_trait;
6use tracing::{debug, warn};
7
8use super::handler::{AuthContext, AuthHandler, AuthMethod, AuthResult};
9use crate::PublicKey;
10
11#[derive(Debug, Clone, Copy, Default)]
24pub struct AcceptAllAuth {
25 _private: (),
26}
27
28impl AcceptAllAuth {
29 pub fn new() -> Self {
31 warn!("AcceptAllAuth in use - NOT FOR PRODUCTION");
32 Self { _private: () }
33 }
34}
35
36#[async_trait]
37impl AuthHandler for AcceptAllAuth {
38 async fn auth_password(&self, ctx: &AuthContext, _password: &str) -> AuthResult {
39 warn!(
40 username = %ctx.username(),
41 remote_addr = %ctx.remote_addr(),
42 "AcceptAllAuth: accepting password auth"
43 );
44 AuthResult::Accept
45 }
46
47 async fn auth_publickey(&self, ctx: &AuthContext, _key: &PublicKey) -> AuthResult {
48 warn!(
49 username = %ctx.username(),
50 remote_addr = %ctx.remote_addr(),
51 "AcceptAllAuth: accepting public key auth"
52 );
53 AuthResult::Accept
54 }
55
56 async fn auth_keyboard_interactive(&self, ctx: &AuthContext, _response: &str) -> AuthResult {
57 warn!(
58 username = %ctx.username(),
59 remote_addr = %ctx.remote_addr(),
60 "AcceptAllAuth: accepting keyboard-interactive auth"
61 );
62 AuthResult::Accept
63 }
64
65 async fn auth_none(&self, ctx: &AuthContext) -> AuthResult {
66 warn!(
67 username = %ctx.username(),
68 remote_addr = %ctx.remote_addr(),
69 "AcceptAllAuth: accepting none auth"
70 );
71 AuthResult::Accept
72 }
73
74 fn supported_methods(&self) -> Vec<AuthMethod> {
75 vec![
76 AuthMethod::None,
77 AuthMethod::Password,
78 AuthMethod::PublicKey,
79 AuthMethod::KeyboardInteractive,
80 ]
81 }
82}
83
/// Password authentication handler that delegates the accept/reject decision
/// to a user-supplied callback.
///
/// The type itself places no bounds on `F`; the
/// `Fn(&AuthContext, &str) -> bool + Send + Sync` requirement is stated on the
/// `impl` blocks that actually call or construct the callback. Keeping bounds
/// off the struct definition means code that merely names or stores the type
/// does not have to repeat them.
pub struct CallbackAuth<F> {
    /// Invoked with the auth context and the offered password.
    callback: F,
}
103
104impl<F> CallbackAuth<F>
105where
106 F: Fn(&AuthContext, &str) -> bool + Send + Sync,
107{
108 pub fn new(callback: F) -> Self {
110 Self { callback }
111 }
112}
113
114#[async_trait]
115impl<F> AuthHandler for CallbackAuth<F>
116where
117 F: Fn(&AuthContext, &str) -> bool + Send + Sync + 'static,
118{
119 async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
120 debug!(
121 username = %ctx.username(),
122 remote_addr = %ctx.remote_addr(),
123 "CallbackAuth: password auth attempt"
124 );
125
126 if (self.callback)(ctx, password) {
127 debug!(username = %ctx.username(), "CallbackAuth: password accepted");
128 AuthResult::Accept
129 } else {
130 debug!(username = %ctx.username(), "CallbackAuth: password rejected");
131 AuthResult::Reject
132 }
133 }
134
135 fn supported_methods(&self) -> Vec<AuthMethod> {
136 vec![AuthMethod::Password]
137 }
138}
139
/// In-memory table of usernames and passwords for password authentication.
///
/// Passwords are stored exactly as supplied (no hashing is applied here), so
/// the [`Debug`](std::fmt::Debug) implementation deliberately never prints them.
#[derive(Default)]
pub struct PasswordAuth {
    /// Maps username to the password expected for that user.
    users: std::collections::HashMap<String, String>,
}

impl PasswordAuth {
    /// Creates a handler with no registered users.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `username` with `password`, replacing any existing entry for
    /// that username, and returns the updated handler (builder style).
    #[must_use = "builder methods return the updated handler instead of mutating in place"]
    pub fn add_user(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.users.insert(username.into(), password.into());
        self
    }

    /// Registers every `(username, password)` pair from `users`.
    ///
    /// Pairs are applied in iteration order, so a later pair overrides an
    /// earlier one (or an already registered user) with the same username.
    #[must_use = "builder methods return the updated handler instead of mutating in place"]
    pub fn add_users<I, U, P>(mut self, users: I) -> Self
    where
        I: IntoIterator<Item = (U, P)>,
        U: Into<String>,
        P: Into<String>,
    {
        self.users.extend(
            users
                .into_iter()
                .map(|(username, password)| (username.into(), password.into())),
        );
        self
    }

    /// Returns how many users are registered.
    pub fn user_count(&self) -> usize {
        self.users.len()
    }

    /// Returns `true` when `username` is registered.
    pub fn has_user(&self, username: &str) -> bool {
        self.users.contains_key(username)
    }
}

impl std::fmt::Debug for PasswordAuth {
    /// Prints only the number of registered users; credentials are omitted so
    /// they cannot end up in logs through `{:?}` formatting.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PasswordAuth")
            .field("user_count", &self.users.len())
            .finish_non_exhaustive()
    }
}
200
201#[async_trait]
202impl AuthHandler for PasswordAuth {
203 async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
204 debug!(
205 username = %ctx.username(),
206 remote_addr = %ctx.remote_addr(),
207 "PasswordAuth: auth attempt"
208 );
209
210 let stored = self.users.get(ctx.username());
211 let target = stored.map(String::as_str).unwrap_or("");
214
215 if constant_time_eq(target, password) && stored.is_some() {
216 debug!(username = %ctx.username(), "PasswordAuth: accepted");
217 AuthResult::Accept
218 } else {
219 debug!(username = %ctx.username(), "PasswordAuth: rejected");
220 AuthResult::Reject
221 }
222 }
223
224 fn supported_methods(&self) -> Vec<AuthMethod> {
225 vec![AuthMethod::Password]
226 }
227}
228
/// Compares two strings byte-wise without returning early on the first mismatch.
///
/// Every position up to the longer input's length is visited; positions past
/// the end of the shorter input are treated as a zero byte. All differences,
/// including a difference in length, are OR-ed into one accumulator, and the
/// strings are equal only when that accumulator is zero.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    let span = lhs.len().max(rhs.len());

    // Seed with the length XOR (kept at full `usize` width) so inputs of
    // different lengths can never compare equal, even when the padding zero
    // bytes happen to match real zero bytes.
    let seed = lhs.len() ^ rhs.len();

    let diff = (0..span).fold(seed, |acc, idx| {
        let left = lhs.get(idx).copied().unwrap_or(0);
        let right = rhs.get(idx).copied().unwrap_or(0);
        acc | usize::from(left ^ right)
    });

    diff == 0
}
247
/// Password authentication handler that delegates the decision to an
/// asynchronous callback held behind an [`Arc`].
///
/// The type itself places no bounds on `F`; the requirement that `F` is a
/// `Send + Sync` function returning a boxed, pinned `Future<Output = bool>` is
/// stated on the `impl` blocks that construct or call it.
// NOTE(review): the reason for suppressing `dead_code` is not visible in this
// file — confirm it is still needed.
#[allow(dead_code)]
pub struct AsyncCallbackAuth<F> {
    /// Shared callback invoked with the auth context and the offered password.
    callback: Arc<F>,
}
275
276impl<F> AsyncCallbackAuth<F>
277where
278 F: Fn(&AuthContext, &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
279 + Send
280 + Sync,
281{
282 #[allow(dead_code)]
284 pub fn new(callback: Arc<F>) -> Self {
285 Self { callback }
286 }
287}
288
289#[async_trait]
290impl<F> AuthHandler for AsyncCallbackAuth<F>
291where
292 F: Fn(&AuthContext, &str) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
293 + Send
294 + Sync
295 + 'static,
296{
297 async fn auth_password(&self, ctx: &AuthContext, password: &str) -> AuthResult {
298 debug!(
299 username = %ctx.username(),
300 remote_addr = %ctx.remote_addr(),
301 "AsyncCallbackAuth: password auth attempt"
302 );
303
304 if (self.callback)(ctx, password).await {
305 debug!(username = %ctx.username(), "AsyncCallbackAuth: password accepted");
306 AuthResult::Accept
307 } else {
308 debug!(username = %ctx.username(), "AsyncCallbackAuth: password rejected");
309 AuthResult::Reject
310 }
311 }
312
313 fn supported_methods(&self) -> Vec<AuthMethod> {
314 vec![AuthMethod::Password]
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::super::SessionId;
    use super::*;
    use std::net::SocketAddr;
    use std::sync::Arc;

    /// Builds an auth context for `username` connecting from 127.0.0.1:22 on
    /// session id 1.
    fn context_for(username: &str) -> AuthContext {
        let loopback = "127.0.0.1:22".parse::<SocketAddr>().unwrap();
        AuthContext::new(username, loopback, SessionId(1))
    }

    /// `AcceptAllAuth` accepts both a password attempt and the "none" method.
    #[tokio::test]
    async fn test_accept_all_auth() {
        let handler = AcceptAllAuth::new();
        let ctx = context_for("anyone");

        let by_password = handler.auth_password(&ctx, "anything").await;
        assert!(matches!(by_password, AuthResult::Accept));

        let by_none = handler.auth_none(&ctx).await;
        assert!(matches!(by_none, AuthResult::Accept));
    }

    /// `CallbackAuth` accepts exactly the username/password pair the callback approves.
    #[tokio::test]
    async fn test_callback_auth() {
        let handler =
            CallbackAuth::new(|ctx, password| ctx.username() == "admin" && password == "secret");

        let admin = context_for("admin");
        let good = handler.auth_password(&admin, "secret").await;
        assert!(matches!(good, AuthResult::Accept));
        let bad = handler.auth_password(&admin, "wrong").await;
        assert!(matches!(bad, AuthResult::Reject));

        // The right password for the wrong user is still rejected.
        let other = context_for("user");
        let wrong_user = handler.auth_password(&other, "secret").await;
        assert!(matches!(wrong_user, AuthResult::Reject));
    }

    /// `PasswordAuth` accepts a matching password and rejects a wrong password
    /// as well as an unknown user.
    #[tokio::test]
    async fn test_password_auth() {
        let handler = PasswordAuth::new()
            .add_user("alice", "password123")
            .add_user("bob", "secret456");

        assert_eq!(handler.user_count(), 2);
        assert!(handler.has_user("alice"));
        assert!(!handler.has_user("charlie"));

        let alice = context_for("alice");
        let good = handler.auth_password(&alice, "password123").await;
        assert!(matches!(good, AuthResult::Accept));
        let bad = handler.auth_password(&alice, "wrong").await;
        assert!(matches!(bad, AuthResult::Reject));

        let charlie = context_for("charlie");
        let unknown = handler.auth_password(&charlie, "any").await;
        assert!(matches!(unknown, AuthResult::Reject));
    }

    /// `add_users` registers every pair from an iterator.
    #[test]
    fn test_password_auth_add_users() {
        let pairs = vec![("user1", "pass1"), ("user2", "pass2")];
        let handler = PasswordAuth::new().add_users(pairs);

        assert_eq!(handler.user_count(), 2);
        assert!(handler.has_user("user1"));
        assert!(handler.has_user("user2"));
    }

    /// `AsyncCallbackAuth` awaits the callback's future and honours its verdict.
    #[tokio::test]
    async fn test_async_callback_auth() {
        use std::future::Future;
        use std::pin::Pin;

        let handler = AsyncCallbackAuth::new(Arc::new(|ctx: &AuthContext, password: &str| {
            // Own the inputs so the returned future does not borrow from the arguments.
            let username = ctx.username().to_string();
            let password = password.to_string();
            let fut: Pin<Box<dyn Future<Output = bool> + Send>> =
                Box::pin(async move { username == "admin" && password == "secret" });
            fut
        }));

        let admin = context_for("admin");
        let good = handler.auth_password(&admin, "secret").await;
        assert!(matches!(good, AuthResult::Accept));
        let bad = handler.auth_password(&admin, "wrong").await;
        assert!(matches!(bad, AuthResult::Reject));
    }

    /// Equal strings match; differing content or length does not.
    #[test]
    fn test_constant_time_eq_basic() {
        assert!(constant_time_eq("hello", "hello"));
        assert!(!constant_time_eq("hello", "world"));
        assert!(!constant_time_eq("hello", "hell"));
        assert!(!constant_time_eq("", "a"));
        assert!(constant_time_eq("", ""));
    }

    /// A length difference of exactly 256 must still be detected, i.e. the
    /// length comparison must not be truncated to a single byte.
    #[test]
    fn test_constant_time_eq_length_differs_by_256() {
        let empty = "";
        let long = "a".repeat(256);
        assert!(!constant_time_eq(empty, &long));
        assert!(!constant_time_eq(&long, empty));

        let prefix = "x";
        let extended = "x".to_string() + &"y".repeat(256);
        assert!(!constant_time_eq(prefix, &extended));
    }
}