Skip to main content

ash_rpc/
auth.rs

//! Authentication and authorization hooks
//!
//! This module provides minimal traits for implementing authentication
//! and authorization. The library makes NO assumptions about:
//! - How you authenticate (JWT, API keys, OAuth, certificates, etc.)
//! - What your user/identity model looks like
//! - How you authorize (RBAC, ABAC, ACL, custom logic, etc.)
//!
//! You implement the [`AuthPolicy`] trait; the library calls its `can_access` method.
//!
//! # Example
//! ```
//! use ash_rpc::auth::{AuthPolicy, ConnectionContext};
//!
//! struct MyAuth;
//!
//! impl AuthPolicy for MyAuth {
//!     fn can_access(&self, method: &str, params: Option<&serde_json::Value>, _ctx: &ConnectionContext) -> bool {
//!         // Your logic here - check API keys, JWT tokens, whatever you need
//!         let _ = (method, params);
//!         true
//!     }
//! }
//! ```
25
26use crate::Response;
27use std::any::Any;
28use std::net::SocketAddr;
29use std::sync::Arc;
30
/// Storage for user-defined connection metadata: string keys mapped to
/// type-erased values shared through [`Arc`].
type AuthMetadata = std::collections::HashMap<String, Arc<dyn Any + Send + Sync>>;

/// Connection context for authentication
///
/// This struct holds metadata about a connection that can be used for
/// authentication and authorization decisions. The library makes NO
/// assumptions about what data you need - store anything in `metadata`.
///
/// # Typical contents
/// - TLS client certificates
/// - IP addresses for allow-listing
/// - Custom connection-level tokens
/// - Session identifiers
#[derive(Default, Clone)]
pub struct ConnectionContext {
    /// Remote address of the connection
    pub remote_addr: Option<SocketAddr>,

    /// User-defined metadata
    ///
    /// Store any auth-related data here:
    /// - TLS peer certificates
    /// - Extracted user IDs
    /// - Session tokens
    /// - Rate limiting state
    /// - Whatever you need for your auth logic
    pub metadata: AuthMetadata,
}

impl std::fmt::Debug for ConnectionContext {
    /// Prints the remote address and the metadata keys in sorted order.
    ///
    /// Metadata values are type-erased (`dyn Any`) and cannot be formatted,
    /// so only their keys are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut keys: Vec<&str> = self.metadata.keys().map(String::as_str).collect();
        // HashMap iteration order is unspecified; sort so the output is stable.
        keys.sort_unstable();
        f.debug_struct("ConnectionContext")
            .field("remote_addr", &self.remote_addr)
            .field("metadata_keys", &keys)
            .finish()
    }
}
60
61impl ConnectionContext {
62    /// Create a new empty context
63    pub fn new() -> Self {
64        Self::default()
65    }
66
67    /// Create context with remote address
68    pub fn with_addr(remote_addr: SocketAddr) -> Self {
69        Self {
70            remote_addr: Some(remote_addr),
71            metadata: std::collections::HashMap::new(),
72        }
73    }
74
75    /// Insert typed metadata
76    pub fn insert<T: Any + Send + Sync>(&mut self, key: String, value: T) {
77        self.metadata.insert(key, Arc::new(value));
78    }
79
80    /// Get typed metadata
81    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
82        self.metadata.get(key).and_then(|v| v.downcast_ref::<T>())
83    }
84}
85
86/// Trait for extracting authentication context from connections
87///
88/// Implement this to extract auth data from your transport layer.
89/// The library will call this when a new connection is established.
90///
91/// # What you can extract:
92/// - TLS client certificates for mutual TLS authentication
93/// - IP addresses for whitelisting/geoblocking
94/// - Custom connection-level authentication tokens
95/// - Any connection metadata you need for auth decisions
96///
97/// # Example: TLS Certificate Extraction
98/// ```text
99/// use ash_rpc::auth::{ContextExtractor, ConnectionContext};
100///
101/// struct TlsContextExtractor;
102///
103/// #[async_trait::async_trait]
104/// impl ContextExtractor for TlsContextExtractor {
105///     async fn extract(&self, stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>) -> ConnectionContext {
106///         let mut ctx = ConnectionContext::new();
107///         
108///         // Extract TLS peer certificates
109///         if let Some(certs) = stream.get_ref().1.peer_certificates() {
110///             ctx.insert("peer_certs".to_string(), certs.clone());
111///         }
112///         
113///         // Extract client IP
114///         if let Ok(addr) = stream.get_ref().0.peer_addr() {
115///             ctx.remote_addr = Some(addr);
116///         }
117///         
118///         ctx
119///     }
120/// }
121/// ```
122#[async_trait::async_trait]
123pub trait ContextExtractor: Send + Sync {
124    /// Extract connection context for authentication
125    ///
126    /// This is called once when a connection is established.
127    /// The returned context is passed to the auth policy for each request.
128    ///
129    /// # Arguments
130    /// * `remote_addr` - Remote socket address of the connection
131    /// * `metadata` - Optional transport-specific metadata (e.g., TLS session data)
132    ///
133    /// # Returns
134    /// A `ConnectionContext` with whatever data you need for auth
135    async fn extract(
136        &self,
137        remote_addr: Option<SocketAddr>,
138        metadata: Option<Arc<dyn Any + Send + Sync>>,
139    ) -> ConnectionContext;
140}
141
142/// Default context extractor that only captures the remote address
143pub struct DefaultContextExtractor;
144
145#[async_trait::async_trait]
146impl ContextExtractor for DefaultContextExtractor {
147    async fn extract(
148        &self,
149        remote_addr: Option<SocketAddr>,
150        _metadata: Option<Arc<dyn Any + Send + Sync>>,
151    ) -> ConnectionContext {
152        ConnectionContext {
153            remote_addr,
154            metadata: std::collections::HashMap::new(),
155        }
156    }
157}
158
159/// Trait for implementing authentication/authorization checks
160///
161/// Implement this to control access to your JSON-RPC methods.
162/// The library will call `can_access` before executing methods.
163///
164/// You decide what "access" means - it could be:
165/// - Checking an API key in request params
166/// - Validating a JWT token
167/// - Verifying client certificates (via ConnectionContext)
168/// - Role-based checks
169/// - Rate limiting
170/// - IP whitelisting (via ConnectionContext)
171/// - Anything else you need
172pub trait AuthPolicy: Send + Sync {
173    /// Check if a request should be allowed to proceed
174    ///
175    /// # Arguments
176    /// * `method` - The JSON-RPC method being called
177    /// * `params` - Optional parameters from the request
178    /// * `ctx` - Connection context (IP, TLS certs, custom metadata)
179    ///
180    /// # Returns
181    /// `true` if the request should proceed, `false` to deny
182    ///
183    /// # Example: Using Connection Context
184    /// ```text
185    /// fn can_access(
186    ///     &self,
187    ///     method: &str,
188    ///     params: Option<&serde_json::Value>,
189    ///     ctx: &ConnectionContext,
190    /// ) -> bool {
191    ///     // Check IP whitelist
192    ///     if let Some(addr) = ctx.remote_addr {
193    ///         if !self.is_ip_allowed(&addr.ip()) {
194    ///             return false;
195    ///         }
196    ///     }
197    ///     
198    ///     // Check TLS client certificate
199    ///     if let Some(certs) = ctx.get::<Vec<Certificate>>("peer_certs") {
200    ///         return self.validate_client_cert(certs);
201    ///     }
202    ///     
203    ///     // Check token in params
204    ///     let token = params
205    ///         .and_then(|p| p.get("auth_token"))
206    ///         .and_then(|t| t.as_str());
207    ///     
208    ///     self.validate_token(token)
209    /// }
210    /// ```
211    fn can_access(
212        &self,
213        method: &str,
214        params: Option<&serde_json::Value>,
215        ctx: &ConnectionContext,
216    ) -> bool;
217
218    /// Optional: Get the unauthorized error response
219    ///
220    /// Override this if you want custom error messages for denied requests.
221    /// Default returns a generic "Unauthorized" error.
222    fn unauthorized_error(&self, method: &str) -> Response {
223        let _ = method;
224        crate::ResponseBuilder::new()
225            .error(
226                crate::ErrorBuilder::new(crate::error_codes::INTERNAL_ERROR, "Unauthorized")
227                    .build(),
228            )
229            .id(None)
230            .build()
231    }
232}
233
234/// Helper: Always allow all requests (no authentication)
235///
236/// Use this as a placeholder or for development/testing.
237pub struct AllowAll;
238
239impl AuthPolicy for AllowAll {
240    fn can_access(
241        &self,
242        _method: &str,
243        _params: Option<&serde_json::Value>,
244        _ctx: &ConnectionContext,
245    ) -> bool {
246        true
247    }
248}
249
250/// Helper: Deny all requests
251///
252/// Useful for maintenance mode or testing denial paths.
253pub struct DenyAll;
254
255impl AuthPolicy for DenyAll {
256    fn can_access(
257        &self,
258        _method: &str,
259        _params: Option<&serde_json::Value>,
260        _ctx: &ConnectionContext,
261    ) -> bool {
262        false
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an IPv4 socket address from its octets and a port.
    fn v4(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::from((octets, port))
    }

    /// Non-empty request params, used to show the helper policies ignore them.
    fn sample_params() -> serde_json::Value {
        serde_json::json!({"key": "value"})
    }

    #[test]
    fn test_allow_all() {
        let ctx = ConnectionContext::new();
        let params = sample_params();

        assert!(AllowAll.can_access("any_method", None, &ctx));
        assert!(AllowAll.can_access("another_method", Some(&params), &ctx));
    }

    #[test]
    fn test_deny_all() {
        let ctx = ConnectionContext::new();
        let params = sample_params();

        assert!(!DenyAll.can_access("any_method", None, &ctx));
        assert!(!DenyAll.can_access("another_method", Some(&params), &ctx));
    }

    #[test]
    fn test_connection_context() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("user_id".to_string(), 42u64);

        // Round-trips when asked for the stored type.
        assert_eq!(ctx.get::<u64>("user_id"), Some(&42));
        // A mismatched type is reported as absent.
        assert_eq!(ctx.get::<String>("user_id"), None);
        // So is an unknown key.
        assert_eq!(ctx.get::<u64>("other"), None);
    }

    #[test]
    fn test_connection_context_with_addr() {
        let addr = v4([127, 0, 0, 1], 8080);
        let ctx = ConnectionContext::with_addr(addr);

        assert_eq!(ctx.remote_addr, Some(addr));
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_default() {
        let ctx = ConnectionContext::default();

        assert_eq!(ctx.remote_addr, None);
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_multiple_metadata() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("user_id".to_string(), 123u64);
        ctx.insert("username".to_string(), String::from("alice"));
        ctx.insert("is_admin".to_string(), true);

        assert_eq!(ctx.get::<u64>("user_id"), Some(&123));
        assert_eq!(ctx.get::<String>("username").map(String::as_str), Some("alice"));
        assert_eq!(ctx.get::<bool>("is_admin"), Some(&true));
    }

    #[test]
    fn test_allow_all_unauthorized_error() {
        let response = AllowAll.unauthorized_error("test_method");
        let error = response
            .error
            .expect("default unauthorized response should carry an error");

        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert_eq!(error.message, "Unauthorized");
    }

    #[test]
    fn test_deny_all_unauthorized_error() {
        let response = DenyAll.unauthorized_error("blocked_method");
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_default_context_extractor() {
        let extractor = DefaultContextExtractor;
        let addr = v4([192, 168, 1, 1], 9000);

        let ctx = extractor.extract(Some(addr), None).await;

        assert_eq!(ctx.remote_addr, Some(addr));
        assert!(ctx.metadata.is_empty());
    }

    #[tokio::test]
    async fn test_default_context_extractor_no_addr() {
        let extractor = DefaultContextExtractor;

        let ctx = extractor.extract(None, None).await;

        assert!(ctx.remote_addr.is_none());
    }

    #[tokio::test]
    async fn test_default_context_extractor_with_metadata() {
        let extractor = DefaultContextExtractor;
        let addr = v4([10, 0, 0, 1], 3000);
        let metadata: Arc<dyn Any + Send + Sync> = Arc::new(String::from("test"));

        let ctx = extractor.extract(Some(addr), Some(metadata)).await;

        assert_eq!(ctx.remote_addr, Some(addr));
    }

    #[test]
    fn test_connection_context_clone() {
        let mut original = ConnectionContext::new();
        original.insert("key".to_string(), 100u32);

        let copy = original.clone();

        assert_eq!(copy.get::<u32>("key"), Some(&100));
    }
}