1use std::sync::Arc;
4
5use async_trait::async_trait;
6use tracing::debug;
7
8use super::handler::{AuthContext, AuthHandler, AuthMethod, AuthResult};
9use crate::PublicKey;
10
/// Public-key authenticator that delegates the accept/reject decision to a
/// synchronous callback.
///
/// The callback receives the [`AuthContext`] of the attempt and the offered
/// [`PublicKey`]; returning `true` accepts the attempt and `false` rejects it.
/// Construct it with [`PublicKeyCallbackAuth::new`].
///
/// The `F: Fn(&AuthContext, &PublicKey) -> bool + Send + Sync` requirement is
/// enforced on the `impl` blocks rather than on the type itself, so merely
/// naming the type (e.g. in another generic signature) does not force callers
/// to repeat the bound.
pub struct PublicKeyCallbackAuth<F> {
    /// Decision function consulted for each public-key attempt.
    callback: F,
}

impl<F> std::fmt::Debug for PublicKeyCallbackAuth<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures are not `Debug`, so the callback itself is omitted.
        f.debug_struct("PublicKeyCallbackAuth").finish_non_exhaustive()
    }
}
31
32impl<F> PublicKeyCallbackAuth<F>
33where
34 F: Fn(&AuthContext, &PublicKey) -> bool + Send + Sync,
35{
36 pub fn new(callback: F) -> Self {
38 Self { callback }
39 }
40}
41
42#[async_trait]
43impl<F> AuthHandler for PublicKeyCallbackAuth<F>
44where
45 F: Fn(&AuthContext, &PublicKey) -> bool + Send + Sync + 'static,
46{
47 async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
48 debug!(
49 username = %ctx.username(),
50 remote_addr = %ctx.remote_addr(),
51 key_type = %key.key_type,
52 "PublicKeyCallbackAuth: auth attempt"
53 );
54
55 if (self.callback)(ctx, key) {
56 debug!(username = %ctx.username(), "PublicKeyCallbackAuth: accepted");
57 AuthResult::Accept
58 } else {
59 debug!(username = %ctx.username(), "PublicKeyCallbackAuth: rejected");
60 AuthResult::Reject
61 }
62 }
63
64 fn supported_methods(&self) -> Vec<AuthMethod> {
65 vec![AuthMethod::PublicKey]
66 }
67}
68
69#[derive(Default)]
83pub struct PublicKeyAuth {
84 global_keys: Vec<PublicKey>,
86 user_keys: std::collections::HashMap<String, Vec<PublicKey>>,
88}
89
90impl PublicKeyAuth {
91 pub fn new() -> Self {
93 Self::default()
94 }
95
96 pub fn add_key(mut self, key: PublicKey) -> Self {
98 self.global_keys.push(key);
99 self
100 }
101
102 pub fn add_keys<I>(mut self, keys: I) -> Self
104 where
105 I: IntoIterator<Item = PublicKey>,
106 {
107 self.global_keys.extend(keys);
108 self
109 }
110
111 pub fn add_user_key(mut self, username: impl Into<String>, key: PublicKey) -> Self {
113 self.user_keys.entry(username.into()).or_default().push(key);
114 self
115 }
116
117 pub fn global_key_count(&self) -> usize {
119 self.global_keys.len()
120 }
121
122 pub fn user_key_count(&self, username: &str) -> usize {
124 self.user_keys.get(username).map(|v| v.len()).unwrap_or(0)
125 }
126
127 fn is_key_allowed(&self, username: &str, key: &PublicKey) -> bool {
129 if self.global_keys.iter().any(|k| k == key) {
131 return true;
132 }
133
134 if let Some(user_keys) = self.user_keys.get(username)
136 && user_keys.iter().any(|k| k == key)
137 {
138 return true;
139 }
140
141 false
142 }
143}
144
145#[async_trait]
146impl AuthHandler for PublicKeyAuth {
147 async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
148 debug!(
149 username = %ctx.username(),
150 remote_addr = %ctx.remote_addr(),
151 key_type = %key.key_type,
152 "PublicKeyAuth: auth attempt"
153 );
154
155 if self.is_key_allowed(ctx.username(), key) {
156 debug!(username = %ctx.username(), "PublicKeyAuth: accepted");
157 AuthResult::Accept
158 } else {
159 debug!(username = %ctx.username(), "PublicKeyAuth: rejected");
160 AuthResult::Reject
161 }
162 }
163
164 fn supported_methods(&self) -> Vec<AuthMethod> {
165 vec![AuthMethod::PublicKey]
166 }
167}
168
/// Public-key authenticator that delegates the accept/reject decision to an
/// asynchronous callback.
///
/// The callback receives the [`AuthContext`] and the offered [`PublicKey`] and
/// returns a boxed future resolving to `true` (accept) or `false` (reject).
/// The future type is `Pin<Box<dyn Future<Output = bool> + Send>>`, which is
/// implicitly `'static`, so it cannot borrow the callback's arguments. Copy
/// whatever it needs (e.g. the username) out of them before building the
/// future.
///
/// The callback bound is enforced on the `impl` blocks rather than on the type
/// itself. The callback is held in the [`Arc`] supplied to
/// [`AsyncPublicKeyAuth::new`].
// NOTE(review): `dead_code` is allowed here presumably because nothing in the
// crate constructs this type yet; confirm, and drop the allow once it is used.
#[allow(dead_code)]
pub struct AsyncPublicKeyAuth<F> {
    /// Shared decision function consulted for each public-key attempt.
    callback: Arc<F>,
}

impl<F> std::fmt::Debug for AsyncPublicKeyAuth<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures are not `Debug`, so the callback itself is omitted.
        f.debug_struct("AsyncPublicKeyAuth").finish_non_exhaustive()
    }
}
185
186impl<F> AsyncPublicKeyAuth<F>
187where
188 F: Fn(
189 &AuthContext,
190 &PublicKey,
191 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
192 + Send
193 + Sync,
194{
195 #[allow(dead_code)]
197 pub fn new(callback: Arc<F>) -> Self {
198 Self { callback }
199 }
200}
201
202#[async_trait]
203impl<F> AuthHandler for AsyncPublicKeyAuth<F>
204where
205 F: Fn(
206 &AuthContext,
207 &PublicKey,
208 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
209 + Send
210 + Sync
211 + 'static,
212{
213 async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
214 debug!(
215 username = %ctx.username(),
216 remote_addr = %ctx.remote_addr(),
217 key_type = %key.key_type,
218 "AsyncPublicKeyAuth: auth attempt"
219 );
220
221 if (self.callback)(ctx, key).await {
222 debug!(username = %ctx.username(), "AsyncPublicKeyAuth: accepted");
223 AuthResult::Accept
224 } else {
225 debug!(username = %ctx.username(), "AsyncPublicKeyAuth: rejected");
226 AuthResult::Reject
227 }
228 }
229
230 fn supported_methods(&self) -> Vec<AuthMethod> {
231 vec![AuthMethod::PublicKey]
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::super::SessionId;
    use super::*;
    use std::net::SocketAddr;
    use std::sync::Arc;

    /// Builds an [`AuthContext`] for `username` with a fixed loopback address
    /// and session id.
    fn make_context(username: &str) -> AuthContext {
        let addr: SocketAddr = "127.0.0.1:22".parse().unwrap();
        AuthContext::new(username, addr, SessionId(1))
    }

    /// Builds a [`PublicKey`] from an algorithm name and raw key bytes.
    fn make_key(key_type: &str, data: &[u8]) -> PublicKey {
        PublicKey::new(key_type, data.to_vec())
    }

    #[tokio::test]
    async fn test_publickey_callback_auth() {
        let auth = PublicKeyCallbackAuth::new(|ctx, key| {
            ctx.username() == "alice" && key.key_type == "ssh-ed25519"
        });

        let ctx = make_context("alice");
        let key = make_key("ssh-ed25519", b"keydata");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Accept
        ));

        let key = make_key("ssh-rsa", b"keydata");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Reject
        ));

        let ctx = make_context("bob");
        let key = make_key("ssh-ed25519", b"keydata");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Reject
        ));
    }

    #[tokio::test]
    async fn test_publickey_auth_global() {
        let key1 = make_key("ssh-ed25519", b"key1");
        let key2 = make_key("ssh-ed25519", b"key2");
        let key3 = make_key("ssh-ed25519", b"key3");

        let auth = PublicKeyAuth::new()
            .add_key(key1.clone())
            .add_key(key2.clone());

        assert_eq!(auth.global_key_count(), 2);

        let ctx = make_context("anyone");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key1).await,
            AuthResult::Accept
        ));
        assert!(matches!(
            auth.auth_publickey(&ctx, &key2).await,
            AuthResult::Accept
        ));
        assert!(matches!(
            auth.auth_publickey(&ctx, &key3).await,
            AuthResult::Reject
        ));
    }

    #[tokio::test]
    async fn test_publickey_auth_per_user() {
        let alice_key = make_key("ssh-ed25519", b"alice_key");
        let bob_key = make_key("ssh-ed25519", b"bob_key");

        let auth = PublicKeyAuth::new()
            .add_user_key("alice", alice_key.clone())
            .add_user_key("bob", bob_key.clone());

        assert_eq!(auth.user_key_count("alice"), 1);
        assert_eq!(auth.user_key_count("bob"), 1);
        assert_eq!(auth.user_key_count("charlie"), 0);

        let ctx = make_context("alice");
        assert!(matches!(
            auth.auth_publickey(&ctx, &alice_key).await,
            AuthResult::Accept
        ));
        assert!(matches!(
            auth.auth_publickey(&ctx, &bob_key).await,
            AuthResult::Reject
        ));

        let ctx = make_context("bob");
        assert!(matches!(
            auth.auth_publickey(&ctx, &bob_key).await,
            AuthResult::Accept
        ));
        assert!(matches!(
            auth.auth_publickey(&ctx, &alice_key).await,
            AuthResult::Reject
        ));
    }

    #[tokio::test]
    async fn test_publickey_auth_global_and_per_user_combined() {
        let global_key = make_key("ssh-ed25519", b"global_key");
        let alice_key = make_key("ssh-ed25519", b"alice_key");

        let auth = PublicKeyAuth::new()
            .add_key(global_key.clone())
            .add_user_key("alice", alice_key.clone());

        // Global and per-user keys are counted separately.
        assert_eq!(auth.global_key_count(), 1);
        assert_eq!(auth.user_key_count("alice"), 1);

        // A user with per-user keys may still authenticate with a global key.
        let ctx = make_context("alice");
        assert!(matches!(
            auth.auth_publickey(&ctx, &global_key).await,
            AuthResult::Accept
        ));
        assert!(matches!(
            auth.auth_publickey(&ctx, &alice_key).await,
            AuthResult::Accept
        ));

        // A user without per-user keys only gets the global keys.
        let ctx = make_context("charlie");
        assert!(matches!(
            auth.auth_publickey(&ctx, &global_key).await,
            AuthResult::Accept
        ));
        assert!(matches!(
            auth.auth_publickey(&ctx, &alice_key).await,
            AuthResult::Reject
        ));
    }

    #[tokio::test]
    async fn test_publickey_auth_empty_rejects_everything() {
        let auth = PublicKeyAuth::new();
        assert_eq!(auth.global_key_count(), 0);
        assert_eq!(auth.user_key_count("alice"), 0);

        let ctx = make_context("alice");
        let key = make_key("ssh-ed25519", b"keydata");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Reject
        ));
    }

    // Nothing here is awaited, so a plain test without an async runtime suffices.
    #[test]
    fn test_publickey_auth_add_keys() {
        let keys = vec![
            make_key("ssh-ed25519", b"key1"),
            make_key("ssh-ed25519", b"key2"),
        ];
        let auth = PublicKeyAuth::new().add_keys(keys);
        assert_eq!(auth.global_key_count(), 2);
    }

    #[test]
    fn test_supported_methods_are_publickey_only() {
        let callback_auth = PublicKeyCallbackAuth::new(|_, _| true);
        assert!(matches!(
            callback_auth.supported_methods().as_slice(),
            [AuthMethod::PublicKey]
        ));
        assert!(matches!(
            PublicKeyAuth::new().supported_methods().as_slice(),
            [AuthMethod::PublicKey]
        ));
    }

    #[tokio::test]
    async fn test_async_publickey_auth() {
        let auth = AsyncPublicKeyAuth::new(Arc::new(|ctx: &AuthContext, key: &PublicKey| {
            let username = ctx.username().to_string();
            let key_type = key.key_type.clone();
            let fut: std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> =
                Box::pin(async move { username == "alice" && key_type == "ssh-ed25519" });
            fut
        }));

        let ctx = make_context("alice");
        let key = make_key("ssh-ed25519", b"keydata");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Accept
        ));

        let key = make_key("ssh-rsa", b"keydata");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Reject
        ));

        // Right key type, wrong user.
        let ctx = make_context("bob");
        let key = make_key("ssh-ed25519", b"keydata");
        assert!(matches!(
            auth.auth_publickey(&ctx, &key).await,
            AuthResult::Reject
        ));
    }
}