Skip to main content

ash_rpc/
auth.rs

1//! Authentication and authorization hooks
2//!
3//! This module provides minimal traits for implementing authentication
4//! and authorization. The library makes NO assumptions about:
5//! - How you authenticate (JWT, API keys, OAuth, certificates, etc.)
6//! - What your user/identity model looks like
7//! - How you authorize (RBAC, ABAC, ACL, custom logic, etc.)
8//!
9//! You implement the trait, we call your `can_access` method.
10//!
11//! # Example
12//! ```
13//! use ash_rpc::auth::{AuthPolicy, ConnectionContext};
14//!
15//! struct MyAuth;
16//!
17//! impl AuthPolicy for MyAuth {
18//!     fn can_access(&self, method: &str, params: Option<&serde_json::Value>, _ctx: &ConnectionContext) -> bool {
19//!         // Your logic here - check API keys, JWT tokens, whatever you need
20//!         let _ = (method, params);
21//!         true
22//!     }
23//! }
24//! ```
25
26use crate::Response;
27use std::any::Any;
28use std::net::SocketAddr;
29use std::sync::Arc;
30
/// Storage for connection-level auth metadata
///
/// Maps a caller-chosen string key to a type-erased value. Values sit behind
/// an `Arc`, so cloning the map shares them instead of copying them.
type AuthMetadata = std::collections::HashMap<String, Arc<dyn Any + Send + Sync>>;

/// Connection context for authentication
///
/// Holds metadata about a connection that authentication and authorization
/// decisions can draw on. The library makes NO assumptions about what data
/// you need - store anything in `metadata`.
///
/// # Examples
/// - TLS client certificates
/// - IP addresses for whitelisting
/// - Custom connection-level tokens
/// - Session identifiers
///
/// # Debug output
/// The `Debug` representation shows `remote_addr` and the metadata *keys*
/// only. Stored values are type-erased and are never printed.
#[derive(Default, Clone)]
pub struct ConnectionContext {
    /// Remote address of the connection
    pub remote_addr: Option<SocketAddr>,

    /// User-defined metadata
    ///
    /// Store any auth-related data here:
    /// - TLS peer certificates
    /// - Extracted user IDs
    /// - Session tokens
    /// - Rate limiting state
    /// - Whatever you need for your auth logic
    pub metadata: AuthMetadata,
}

impl std::fmt::Debug for ConnectionContext {
    /// Formats the context as `ConnectionContext { remote_addr, metadata_keys }`.
    ///
    /// `dyn Any` values cannot be formatted, so only the keys are listed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order is unspecified; sort so the output is stable.
        let mut keys: Vec<&str> = self.metadata.keys().map(String::as_str).collect();
        keys.sort_unstable();

        f.debug_struct("ConnectionContext")
            .field("remote_addr", &self.remote_addr)
            .field("metadata_keys", &keys)
            .finish()
    }
}
60
61impl ConnectionContext {
62    /// Create a new empty context
63    #[must_use]
64    pub fn new() -> Self {
65        Self::default()
66    }
67
68    /// Create context with remote address
69    #[must_use]
70    pub fn with_addr(remote_addr: SocketAddr) -> Self {
71        Self {
72            remote_addr: Some(remote_addr),
73            metadata: std::collections::HashMap::new(),
74        }
75    }
76
77    /// Insert typed metadata
78    pub fn insert<T: Any + Send + Sync>(&mut self, key: String, value: T) {
79        self.metadata.insert(key, Arc::new(value));
80    }
81
82    /// Get typed metadata
83    #[must_use]
84    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
85        self.metadata.get(key).and_then(|v| v.downcast_ref::<T>())
86    }
87}
88
89/// Trait for extracting authentication context from connections
90///
91/// Implement this to extract auth data from your transport layer.
92/// The library will call this when a new connection is established.
93///
94/// # What you can extract:
95/// - TLS client certificates for mutual TLS authentication
96/// - IP addresses for whitelisting/geoblocking
97/// - Custom connection-level authentication tokens
98/// - Any connection metadata you need for auth decisions
99///
100/// # Example: TLS Certificate Extraction
101/// ```text
102/// use ash_rpc::auth::{ContextExtractor, ConnectionContext};
103///
104/// struct TlsContextExtractor;
105///
106/// #[async_trait::async_trait]
107/// impl ContextExtractor for TlsContextExtractor {
108///     async fn extract(&self, stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>) -> ConnectionContext {
109///         let mut ctx = ConnectionContext::new();
110///         
111///         // Extract TLS peer certificates
112///         if let Some(certs) = stream.get_ref().1.peer_certificates() {
113///             ctx.insert("peer_certs".to_string(), certs.clone());
114///         }
115///         
116///         // Extract client IP
117///         if let Ok(addr) = stream.get_ref().0.peer_addr() {
118///             ctx.remote_addr = Some(addr);
119///         }
120///         
121///         ctx
122///     }
123/// }
124/// ```
125#[async_trait::async_trait]
126pub trait ContextExtractor: Send + Sync {
127    /// Extract connection context for authentication
128    ///
129    /// This is called once when a connection is established.
130    /// The returned context is passed to the auth policy for each request.
131    ///
132    /// # Arguments
133    /// * `remote_addr` - Remote socket address of the connection
134    /// * `metadata` - Optional transport-specific metadata (e.g., TLS session data)
135    ///
136    /// # Returns
137    /// A `ConnectionContext` with whatever data you need for auth
138    async fn extract(
139        &self,
140        remote_addr: Option<SocketAddr>,
141        metadata: Option<Arc<dyn Any + Send + Sync>>,
142    ) -> ConnectionContext;
143}
144
145/// Default context extractor that only captures the remote address
146pub struct DefaultContextExtractor;
147
148#[async_trait::async_trait]
149impl ContextExtractor for DefaultContextExtractor {
150    async fn extract(
151        &self,
152        remote_addr: Option<SocketAddr>,
153        _metadata: Option<Arc<dyn Any + Send + Sync>>,
154    ) -> ConnectionContext {
155        ConnectionContext {
156            remote_addr,
157            metadata: std::collections::HashMap::new(),
158        }
159    }
160}
161
162/// Trait for implementing authentication/authorization checks
163///
164/// Implement this to control access to your JSON-RPC methods.
165/// The library will call `can_access` before executing methods.
166///
167/// You decide what "access" means - it could be:
168/// - Checking an API key in request params
169/// - Validating a JWT token
170/// - Verifying client certificates (via `ConnectionContext`)
171/// - Role-based checks
172/// - Rate limiting
173/// - IP whitelisting (via `ConnectionContext`)
174/// - Anything else you need
175pub trait AuthPolicy: Send + Sync {
176    /// Check if a request should be allowed to proceed
177    ///
178    /// # Arguments
179    /// * `method` - The JSON-RPC method being called
180    /// * `params` - Optional parameters from the request
181    /// * `ctx` - Connection context (IP, TLS certs, custom metadata)
182    ///
183    /// # Returns
184    /// `true` if the request should proceed, `false` to deny
185    ///
186    /// # Example: Using Connection Context
187    /// ```text
188    /// fn can_access(
189    ///     &self,
190    ///     method: &str,
191    ///     params: Option<&serde_json::Value>,
192    ///     ctx: &ConnectionContext,
193    /// ) -> bool {
194    ///     // Check IP whitelist
195    ///     if let Some(addr) = ctx.remote_addr {
196    ///         if !self.is_ip_allowed(&addr.ip()) {
197    ///             return false;
198    ///         }
199    ///     }
200    ///     
201    ///     // Check TLS client certificate
202    ///     if let Some(certs) = ctx.get::<Vec<Certificate>>("peer_certs") {
203    ///         return self.validate_client_cert(certs);
204    ///     }
205    ///     
206    ///     // Check token in params
207    ///     let token = params
208    ///         .and_then(|p| p.get("auth_token"))
209    ///         .and_then(|t| t.as_str());
210    ///     
211    ///     self.validate_token(token)
212    /// }
213    /// ```
214    fn can_access(
215        &self,
216        method: &str,
217        params: Option<&serde_json::Value>,
218        ctx: &ConnectionContext,
219    ) -> bool;
220
221    /// Optional: Get the unauthorized error response
222    ///
223    /// Override this if you want custom error messages for denied requests.
224    /// Default returns a generic "Unauthorized" error.
225    fn unauthorized_error(&self, method: &str) -> Response {
226        let _ = method;
227        crate::ResponseBuilder::new()
228            .error(
229                crate::ErrorBuilder::new(crate::error_codes::INTERNAL_ERROR, "Unauthorized")
230                    .build(),
231            )
232            .id(None)
233            .build()
234    }
235}
236
237/// Helper: Always allow all requests (no authentication)
238///
239/// Use this as a placeholder or for development/testing.
240pub struct AllowAll;
241
242impl AuthPolicy for AllowAll {
243    fn can_access(
244        &self,
245        _method: &str,
246        _params: Option<&serde_json::Value>,
247        _ctx: &ConnectionContext,
248    ) -> bool {
249        true
250    }
251}
252
253/// Helper: Deny all requests
254///
255/// Useful for maintenance mode or testing denial paths.
256pub struct DenyAll;
257
258impl AuthPolicy for DenyAll {
259    fn can_access(
260        &self,
261        _method: &str,
262        _params: Option<&serde_json::Value>,
263        _ctx: &ConnectionContext,
264    ) -> bool {
265        false
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds an IPv4 socket address from its octets and a port.
    fn v4(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::from(octets)), port)
    }

    // --- Built-in policies -------------------------------------------------

    #[test]
    fn test_allow_all() {
        let ctx = ConnectionContext::new();
        let params = serde_json::json!({"key": "value"});

        assert!(AllowAll.can_access("any_method", None, &ctx));
        assert!(AllowAll.can_access("another_method", Some(&params), &ctx));
    }

    #[test]
    fn test_deny_all() {
        let ctx = ConnectionContext::new();
        let params = serde_json::json!({"key": "value"});

        assert!(!DenyAll.can_access("any_method", None, &ctx));
        assert!(!DenyAll.can_access("another_method", Some(&params), &ctx));
    }

    // --- ConnectionContext -------------------------------------------------

    #[test]
    fn test_connection_context() {
        let mut ctx = ConnectionContext::new();

        // A stored value comes back when requested with its own type
        ctx.insert(String::from("user_id"), 42u64);
        assert_eq!(ctx.get::<u64>("user_id").copied(), Some(42));

        // Requesting a different type yields nothing
        assert!(ctx.get::<String>("user_id").is_none());

        // So does an unknown key
        assert!(ctx.get::<u64>("other").is_none());
    }

    #[test]
    fn test_connection_context_with_addr() {
        let addr = v4([127, 0, 0, 1], 8080);
        let ctx = ConnectionContext::with_addr(addr);

        assert_eq!(ctx.remote_addr, Some(addr));
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_default() {
        let ctx = ConnectionContext::default();

        assert_eq!(ctx.remote_addr, None);
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_multiple_metadata() {
        let mut ctx = ConnectionContext::new();

        ctx.insert(String::from("user_id"), 123u64);
        ctx.insert(String::from("username"), "alice".to_string());
        ctx.insert(String::from("is_admin"), true);

        assert_eq!(ctx.get::<u64>("user_id").copied(), Some(123));
        assert_eq!(
            ctx.get::<String>("username").map(String::as_str),
            Some("alice")
        );
        assert_eq!(ctx.get::<bool>("is_admin").copied(), Some(true));
    }

    #[test]
    fn test_connection_context_clone() {
        let mut original = ConnectionContext::new();
        original.insert(String::from("key"), 100u32);

        let copy = original.clone();
        assert_eq!(copy.get::<u32>("key").copied(), Some(100));
    }

    // --- Default unauthorized response --------------------------------------

    #[test]
    fn test_allow_all_unauthorized_error() {
        let response = AllowAll.unauthorized_error("test_method");

        let error = response
            .error
            .expect("unauthorized response must carry an error");
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert_eq!(error.message, "Unauthorized");
    }

    #[test]
    fn test_deny_all_unauthorized_error() {
        let response = DenyAll.unauthorized_error("blocked_method");
        assert!(response.error.is_some());
    }

    // --- DefaultContextExtractor --------------------------------------------

    #[tokio::test]
    async fn test_default_context_extractor() {
        let addr = v4([192, 168, 1, 1], 9000);

        let ctx = DefaultContextExtractor.extract(Some(addr), None).await;

        assert_eq!(ctx.remote_addr, Some(addr));
        assert!(ctx.metadata.is_empty());
    }

    #[tokio::test]
    async fn test_default_context_extractor_no_addr() {
        let ctx = DefaultContextExtractor.extract(None, None).await;
        assert_eq!(ctx.remote_addr, None);
    }

    #[tokio::test]
    async fn test_default_context_extractor_with_metadata() {
        let addr = v4([10, 0, 0, 1], 3000);
        let transport_data: Arc<dyn Any + Send + Sync> = Arc::new("test".to_string());

        let ctx = DefaultContextExtractor
            .extract(Some(addr), Some(transport_data))
            .await;

        assert_eq!(ctx.remote_addr, Some(addr));
    }
}