// wish/auth/publickey.rs

1//! Public key authentication handlers.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tracing::debug;
7
8use super::handler::{AuthContext, AuthHandler, AuthMethod, AuthResult};
9use crate::PublicKey;
10
/// Public key authentication handler backed by a synchronous callback.
///
/// The handler stores a single user-provided callback and consults it to
/// decide whether a presented public key is acceptable.
///
/// The struct itself places no bounds on `F`; the
/// `Fn(&AuthContext, &PublicKey) -> bool + Send + Sync` requirement is
/// enforced by the `impl` blocks that construct and use the handler. Keeping
/// bounds off the definition means code that merely names or stores the type
/// does not have to repeat them.
///
/// # Example
///
/// ```rust,ignore
/// use wish::auth::PublicKeyCallbackAuth;
///
/// let auth = PublicKeyCallbackAuth::new(|ctx, key| {
///     // Check if key is in allowed list
///     allowed_keys.contains(key)
/// });
/// ```
pub struct PublicKeyCallbackAuth<F> {
    /// Callback consulted for every public key authentication attempt.
    callback: F,
}
31
32impl<F> PublicKeyCallbackAuth<F>
33where
34    F: Fn(&AuthContext, &PublicKey) -> bool + Send + Sync,
35{
36    /// Creates a new callback-based public key auth handler.
37    pub fn new(callback: F) -> Self {
38        Self { callback }
39    }
40}
41
42#[async_trait]
43impl<F> AuthHandler for PublicKeyCallbackAuth<F>
44where
45    F: Fn(&AuthContext, &PublicKey) -> bool + Send + Sync + 'static,
46{
47    async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
48        debug!(
49            username = %ctx.username(),
50            remote_addr = %ctx.remote_addr(),
51            key_type = %key.key_type,
52            "PublicKeyCallbackAuth: auth attempt"
53        );
54
55        if (self.callback)(ctx, key) {
56            debug!(username = %ctx.username(), "PublicKeyCallbackAuth: accepted");
57            AuthResult::Accept
58        } else {
59            debug!(username = %ctx.username(), "PublicKeyCallbackAuth: rejected");
60            AuthResult::Reject
61        }
62    }
63
64    fn supported_methods(&self) -> Vec<AuthMethod> {
65        vec![AuthMethod::PublicKey]
66    }
67}
68
69/// Simple public key authentication against a static set of keys.
70///
71/// Stores public keys and validates against them.
72///
73/// # Example
74///
75/// ```rust,ignore
76/// use wish::auth::PublicKeyAuth;
77/// use wish::PublicKey;
78///
79/// let key = PublicKey::new("ssh-ed25519", key_data);
80/// let auth = PublicKeyAuth::new().add_key(key);
81/// ```
82#[derive(Default)]
83pub struct PublicKeyAuth {
84    /// All allowed keys (regardless of user).
85    global_keys: Vec<PublicKey>,
86    /// Per-user allowed keys.
87    user_keys: std::collections::HashMap<String, Vec<PublicKey>>,
88}
89
90impl PublicKeyAuth {
91    /// Creates a new empty public key auth handler.
92    pub fn new() -> Self {
93        Self::default()
94    }
95
96    /// Adds a global key that can authenticate any user.
97    pub fn add_key(mut self, key: PublicKey) -> Self {
98        self.global_keys.push(key);
99        self
100    }
101
102    /// Adds multiple global keys.
103    pub fn add_keys<I>(mut self, keys: I) -> Self
104    where
105        I: IntoIterator<Item = PublicKey>,
106    {
107        self.global_keys.extend(keys);
108        self
109    }
110
111    /// Adds a key that can only authenticate a specific user.
112    pub fn add_user_key(mut self, username: impl Into<String>, key: PublicKey) -> Self {
113        self.user_keys.entry(username.into()).or_default().push(key);
114        self
115    }
116
117    /// Returns the number of global keys.
118    pub fn global_key_count(&self) -> usize {
119        self.global_keys.len()
120    }
121
122    /// Returns the number of keys for a specific user.
123    pub fn user_key_count(&self, username: &str) -> usize {
124        self.user_keys.get(username).map(|v| v.len()).unwrap_or(0)
125    }
126
127    /// Checks if a key is allowed for authentication.
128    fn is_key_allowed(&self, username: &str, key: &PublicKey) -> bool {
129        // Check global keys
130        if self.global_keys.iter().any(|k| k == key) {
131            return true;
132        }
133
134        // Check user-specific keys
135        if let Some(user_keys) = self.user_keys.get(username)
136            && user_keys.iter().any(|k| k == key)
137        {
138            return true;
139        }
140
141        false
142    }
143}
144
145#[async_trait]
146impl AuthHandler for PublicKeyAuth {
147    async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
148        debug!(
149            username = %ctx.username(),
150            remote_addr = %ctx.remote_addr(),
151            key_type = %key.key_type,
152            "PublicKeyAuth: auth attempt"
153        );
154
155        if self.is_key_allowed(ctx.username(), key) {
156            debug!(username = %ctx.username(), "PublicKeyAuth: accepted");
157            AuthResult::Accept
158        } else {
159            debug!(username = %ctx.username(), "PublicKeyAuth: rejected");
160            AuthResult::Reject
161        }
162    }
163
164    fn supported_methods(&self) -> Vec<AuthMethod> {
165        vec![AuthMethod::PublicKey]
166    }
167}
168
/// Public key authentication handler backed by an asynchronous callback.
///
/// The handler stores a shared, user-provided callback that returns a boxed
/// future resolving to `bool`. Useful for database lookups or remote
/// authentication services.
///
/// The struct itself places no bounds on `F`; the callback signature is
/// enforced by the `impl` blocks that construct and use the handler. Keeping
/// bounds off the definition means code that merely names or stores the type
/// does not have to repeat them.
// NOTE(review): the reason for this `allow` is not visible in this file;
// presumably the type is not referenced on every build configuration. Confirm.
#[allow(dead_code)]
pub struct AsyncPublicKeyAuth<F> {
    /// Shared callback consulted for every public key authentication attempt.
    callback: Arc<F>,
}
185
186impl<F> AsyncPublicKeyAuth<F>
187where
188    F: Fn(
189            &AuthContext,
190            &PublicKey,
191        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
192        + Send
193        + Sync,
194{
195    /// Creates a new async callback-based public key auth handler.
196    #[allow(dead_code)]
197    pub fn new(callback: Arc<F>) -> Self {
198        Self { callback }
199    }
200}
201
202#[async_trait]
203impl<F> AuthHandler for AsyncPublicKeyAuth<F>
204where
205    F: Fn(
206            &AuthContext,
207            &PublicKey,
208        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>
209        + Send
210        + Sync
211        + 'static,
212{
213    async fn auth_publickey(&self, ctx: &AuthContext, key: &PublicKey) -> AuthResult {
214        debug!(
215            username = %ctx.username(),
216            remote_addr = %ctx.remote_addr(),
217            key_type = %key.key_type,
218            "AsyncPublicKeyAuth: auth attempt"
219        );
220
221        if (self.callback)(ctx, key).await {
222            debug!(username = %ctx.username(), "AsyncPublicKeyAuth: accepted");
223            AuthResult::Accept
224        } else {
225            debug!(username = %ctx.username(), "AsyncPublicKeyAuth: rejected");
226            AuthResult::Reject
227        }
228    }
229
230    fn supported_methods(&self) -> Vec<AuthMethod> {
231        vec![AuthMethod::PublicKey]
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::super::SessionId;
    use super::*;
    use std::net::SocketAddr;
    use std::sync::Arc;

    /// Boxed future type returned by the async callback under test.
    type BoxedCheck = std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>>;

    /// Builds an auth context for `username` on a fixed loopback address.
    fn make_context(username: &str) -> AuthContext {
        let addr = "127.0.0.1:22".parse::<SocketAddr>().unwrap();
        AuthContext::new(username, addr, SessionId(1))
    }

    /// Builds a public key of the given type from raw bytes.
    fn make_key(key_type: &str, data: &[u8]) -> PublicKey {
        PublicKey::new(key_type, data.to_vec())
    }

    /// True when the handler accepted the attempt.
    fn accepted(result: AuthResult) -> bool {
        matches!(result, AuthResult::Accept)
    }

    /// True when the handler rejected the attempt.
    fn rejected(result: AuthResult) -> bool {
        matches!(result, AuthResult::Reject)
    }

    #[tokio::test]
    async fn test_publickey_callback_auth() {
        let auth = PublicKeyCallbackAuth::new(|ctx, key| {
            ctx.username() == "alice" && key.key_type == "ssh-ed25519"
        });

        // Right user, right key type.
        let alice = make_context("alice");
        let ed25519 = make_key("ssh-ed25519", b"keydata");
        assert!(accepted(auth.auth_publickey(&alice, &ed25519).await));

        // Right user, wrong key type.
        let rsa = make_key("ssh-rsa", b"keydata");
        assert!(rejected(auth.auth_publickey(&alice, &rsa).await));

        // Wrong user, right key type.
        let bob = make_context("bob");
        let ed25519 = make_key("ssh-ed25519", b"keydata");
        assert!(rejected(auth.auth_publickey(&bob, &ed25519).await));
    }

    #[tokio::test]
    async fn test_publickey_auth_global() {
        let key1 = make_key("ssh-ed25519", b"key1");
        let key2 = make_key("ssh-ed25519", b"key2");
        let key3 = make_key("ssh-ed25519", b"key3");

        let auth = PublicKeyAuth::new()
            .add_key(key1.clone())
            .add_key(key2.clone());
        assert_eq!(auth.global_key_count(), 2);

        let ctx = make_context("anyone");
        assert!(accepted(auth.auth_publickey(&ctx, &key1).await));
        assert!(accepted(auth.auth_publickey(&ctx, &key2).await));
        assert!(rejected(auth.auth_publickey(&ctx, &key3).await));
    }

    #[tokio::test]
    async fn test_publickey_auth_per_user() {
        let alice_key = make_key("ssh-ed25519", b"alice_key");
        let bob_key = make_key("ssh-ed25519", b"bob_key");

        let auth = PublicKeyAuth::new()
            .add_user_key("alice", alice_key.clone())
            .add_user_key("bob", bob_key.clone());

        assert_eq!(auth.user_key_count("alice"), 1);
        assert_eq!(auth.user_key_count("bob"), 1);
        assert_eq!(auth.user_key_count("charlie"), 0);

        // Each user is accepted with their own key and rejected with the other's.
        let alice = make_context("alice");
        assert!(accepted(auth.auth_publickey(&alice, &alice_key).await));
        assert!(rejected(auth.auth_publickey(&alice, &bob_key).await));

        let bob = make_context("bob");
        assert!(accepted(auth.auth_publickey(&bob, &bob_key).await));
        assert!(rejected(auth.auth_publickey(&bob, &alice_key).await));
    }

    #[tokio::test]
    async fn test_publickey_auth_add_keys() {
        let auth = PublicKeyAuth::new().add_keys(vec![
            make_key("ssh-ed25519", b"key1"),
            make_key("ssh-ed25519", b"key2"),
        ]);
        assert_eq!(auth.global_key_count(), 2);
    }

    #[tokio::test]
    async fn test_async_publickey_auth() {
        let callback = |ctx: &AuthContext, key: &PublicKey| {
            // Copy what the future needs so it does not borrow the arguments.
            let username = ctx.username().to_string();
            let key_type = key.key_type.clone();
            let check: BoxedCheck =
                Box::pin(async move { username == "alice" && key_type == "ssh-ed25519" });
            check
        };
        let auth = AsyncPublicKeyAuth::new(Arc::new(callback));

        let alice = make_context("alice");
        let ed25519 = make_key("ssh-ed25519", b"keydata");
        assert!(accepted(auth.auth_publickey(&alice, &ed25519).await));

        let rsa = make_key("ssh-rsa", b"keydata");
        assert!(rejected(auth.auth_publickey(&alice, &rsa).await));
    }
}