ash_rpc/auth.rs
//! Authentication and authorization hooks
//!
//! This module provides minimal traits for implementing authentication
//! and authorization. The library makes NO assumptions about:
//! - How you authenticate (JWT, API keys, OAuth, certificates, etc.)
//! - What your user/identity model looks like
//! - How you authorize (RBAC, ABAC, ACL, custom logic, etc.)
//!
//! You implement the [`AuthPolicy`] trait; the library calls its `can_access`
//! method before executing a requested method.
//!
//! # Example
//! ```
//! use ash_rpc::auth::{AuthPolicy, ConnectionContext};
//!
//! struct MyAuth;
//!
//! impl AuthPolicy for MyAuth {
//!     fn can_access(
//!         &self,
//!         method: &str,
//!         params: Option<&serde_json::Value>,
//!         _ctx: &ConnectionContext,
//!     ) -> bool {
//!         // Your logic here - check API keys, JWT tokens, whatever you need
//!         let _ = (method, params);
//!         true
//!     }
//! }
//! ```
25
26use crate::Response;
27use std::any::Any;
28use std::net::SocketAddr;
29use std::sync::Arc;
30
/// Type alias for the type-erased metadata map stored in a [`ConnectionContext`].
///
/// Values are held behind `Arc`, so cloning a context shares the stored
/// values instead of deep-copying them.
type AuthMetadata = std::collections::HashMap<String, Arc<dyn Any + Send + Sync>>;

/// Connection context for authentication
///
/// This struct holds metadata about a connection that can be used for
/// authentication and authorization decisions. The library makes NO
/// assumptions about what data you need - store anything in `metadata`.
///
/// # Examples
/// - TLS client certificates
/// - IP addresses for whitelisting
/// - Custom connection-level tokens
/// - Session identifiers
#[derive(Default, Clone)]
pub struct ConnectionContext {
    /// Remote address of the connection, if known
    pub remote_addr: Option<SocketAddr>,

    /// User-defined metadata
    ///
    /// Store any auth-related data here:
    /// - TLS peer certificates
    /// - Extracted user IDs
    /// - Session tokens
    /// - Rate limiting state
    /// - Whatever you need for your auth logic
    pub metadata: AuthMetadata,
}

impl ConnectionContext {
    /// Create a new empty context (no remote address, no metadata).
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a context with the given remote address and no metadata.
    pub fn with_addr(remote_addr: SocketAddr) -> Self {
        Self {
            remote_addr: Some(remote_addr),
            ..Self::default()
        }
    }

    /// Insert typed metadata under `key`.
    ///
    /// `key` accepts anything convertible into a `String` (`&str`, `String`, ...).
    /// Any value previously stored under the same key is replaced, regardless
    /// of its type.
    pub fn insert<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
        self.metadata.insert(key.into(), Arc::new(value));
    }

    /// Get typed metadata.
    ///
    /// Returns `None` when `key` is absent or when the stored value is not a `T`.
    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
        self.metadata.get(key).and_then(|v| v.downcast_ref::<T>())
    }
}

impl std::fmt::Debug for ConnectionContext {
    /// Formats the remote address and the metadata *keys*.
    ///
    /// Metadata values are type-erased (`dyn Any`) and cannot be printed, so
    /// only their keys are shown, sorted to make the output deterministic.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut keys: Vec<&str> = self.metadata.keys().map(String::as_str).collect();
        keys.sort_unstable();
        f.debug_struct("ConnectionContext")
            .field("remote_addr", &self.remote_addr)
            .field("metadata_keys", &keys)
            .finish()
    }
}
85
86/// Trait for extracting authentication context from connections
87///
88/// Implement this to extract auth data from your transport layer.
89/// The library will call this when a new connection is established.
90///
91/// # What you can extract:
92/// - TLS client certificates for mutual TLS authentication
93/// - IP addresses for whitelisting/geoblocking
94/// - Custom connection-level authentication tokens
95/// - Any connection metadata you need for auth decisions
96///
97/// # Example: TLS Certificate Extraction
98/// ```text
99/// use ash_rpc::auth::{ContextExtractor, ConnectionContext};
100///
101/// struct TlsContextExtractor;
102///
103/// #[async_trait::async_trait]
104/// impl ContextExtractor for TlsContextExtractor {
105/// async fn extract(&self, stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>) -> ConnectionContext {
106/// let mut ctx = ConnectionContext::new();
107///
108/// // Extract TLS peer certificates
109/// if let Some(certs) = stream.get_ref().1.peer_certificates() {
110/// ctx.insert("peer_certs".to_string(), certs.clone());
111/// }
112///
113/// // Extract client IP
114/// if let Ok(addr) = stream.get_ref().0.peer_addr() {
115/// ctx.remote_addr = Some(addr);
116/// }
117///
118/// ctx
119/// }
120/// }
121/// ```
122#[async_trait::async_trait]
123pub trait ContextExtractor: Send + Sync {
124 /// Extract connection context for authentication
125 ///
126 /// This is called once when a connection is established.
127 /// The returned context is passed to the auth policy for each request.
128 ///
129 /// # Arguments
130 /// * `remote_addr` - Remote socket address of the connection
131 /// * `metadata` - Optional transport-specific metadata (e.g., TLS session data)
132 ///
133 /// # Returns
134 /// A `ConnectionContext` with whatever data you need for auth
135 async fn extract(
136 &self,
137 remote_addr: Option<SocketAddr>,
138 metadata: Option<Arc<dyn Any + Send + Sync>>,
139 ) -> ConnectionContext;
140}
141
142/// Default context extractor that only captures the remote address
143pub struct DefaultContextExtractor;
144
145#[async_trait::async_trait]
146impl ContextExtractor for DefaultContextExtractor {
147 async fn extract(
148 &self,
149 remote_addr: Option<SocketAddr>,
150 _metadata: Option<Arc<dyn Any + Send + Sync>>,
151 ) -> ConnectionContext {
152 ConnectionContext {
153 remote_addr,
154 metadata: std::collections::HashMap::new(),
155 }
156 }
157}
158
159/// Trait for implementing authentication/authorization checks
160///
161/// Implement this to control access to your JSON-RPC methods.
162/// The library will call `can_access` before executing methods.
163///
164/// You decide what "access" means - it could be:
165/// - Checking an API key in request params
166/// - Validating a JWT token
167/// - Verifying client certificates (via ConnectionContext)
168/// - Role-based checks
169/// - Rate limiting
170/// - IP whitelisting (via ConnectionContext)
171/// - Anything else you need
172pub trait AuthPolicy: Send + Sync {
173 /// Check if a request should be allowed to proceed
174 ///
175 /// # Arguments
176 /// * `method` - The JSON-RPC method being called
177 /// * `params` - Optional parameters from the request
178 /// * `ctx` - Connection context (IP, TLS certs, custom metadata)
179 ///
180 /// # Returns
181 /// `true` if the request should proceed, `false` to deny
182 ///
183 /// # Example: Using Connection Context
184 /// ```text
185 /// fn can_access(
186 /// &self,
187 /// method: &str,
188 /// params: Option<&serde_json::Value>,
189 /// ctx: &ConnectionContext,
190 /// ) -> bool {
191 /// // Check IP whitelist
192 /// if let Some(addr) = ctx.remote_addr {
193 /// if !self.is_ip_allowed(&addr.ip()) {
194 /// return false;
195 /// }
196 /// }
197 ///
198 /// // Check TLS client certificate
199 /// if let Some(certs) = ctx.get::<Vec<Certificate>>("peer_certs") {
200 /// return self.validate_client_cert(certs);
201 /// }
202 ///
203 /// // Check token in params
204 /// let token = params
205 /// .and_then(|p| p.get("auth_token"))
206 /// .and_then(|t| t.as_str());
207 ///
208 /// self.validate_token(token)
209 /// }
210 /// ```
211 fn can_access(
212 &self,
213 method: &str,
214 params: Option<&serde_json::Value>,
215 ctx: &ConnectionContext,
216 ) -> bool;
217
218 /// Optional: Get the unauthorized error response
219 ///
220 /// Override this if you want custom error messages for denied requests.
221 /// Default returns a generic "Unauthorized" error.
222 fn unauthorized_error(&self, method: &str) -> Response {
223 let _ = method;
224 crate::ResponseBuilder::new()
225 .error(
226 crate::ErrorBuilder::new(crate::error_codes::INTERNAL_ERROR, "Unauthorized")
227 .build(),
228 )
229 .id(None)
230 .build()
231 }
232}
233
234/// Helper: Always allow all requests (no authentication)
235///
236/// Use this as a placeholder or for development/testing.
237pub struct AllowAll;
238
239impl AuthPolicy for AllowAll {
240 fn can_access(
241 &self,
242 _method: &str,
243 _params: Option<&serde_json::Value>,
244 _ctx: &ConnectionContext,
245 ) -> bool {
246 true
247 }
248}
249
250/// Helper: Deny all requests
251///
252/// Useful for maintenance mode or testing denial paths.
253pub struct DenyAll;
254
255impl AuthPolicy for DenyAll {
256 fn can_access(
257 &self,
258 _method: &str,
259 _params: Option<&serde_json::Value>,
260 _ctx: &ConnectionContext,
261 ) -> bool {
262 false
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Build an IPv4 socket address for tests.
    fn v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port)
    }

    #[test]
    fn test_allow_all() {
        let policy = AllowAll;
        let ctx = ConnectionContext::new();
        assert!(policy.can_access("any_method", None, &ctx));
        assert!(policy.can_access(
            "another_method",
            Some(&serde_json::json!({"key": "value"})),
            &ctx
        ));
    }

    #[test]
    fn test_deny_all() {
        let policy = DenyAll;
        let ctx = ConnectionContext::new();
        assert!(!policy.can_access("any_method", None, &ctx));
        assert!(!policy.can_access(
            "another_method",
            Some(&serde_json::json!({"key": "value"})),
            &ctx
        ));
    }

    #[test]
    fn test_policies_as_trait_objects() {
        // AuthPolicy must stay object-safe so servers can hold `Box<dyn AuthPolicy>`.
        let ctx = ConnectionContext::new();
        let policies: Vec<(Box<dyn AuthPolicy>, bool)> =
            vec![(Box::new(AllowAll), true), (Box::new(DenyAll), false)];
        for (policy, expected) in &policies {
            assert_eq!(policy.can_access("method", None, &ctx), *expected);
        }
    }

    #[test]
    fn test_connection_context() {
        let mut ctx = ConnectionContext::new();

        // Insert and retrieve typed metadata
        ctx.insert("user_id".to_string(), 42u64);
        assert_eq!(ctx.get::<u64>("user_id"), Some(&42));

        // Wrong type returns None
        assert_eq!(ctx.get::<String>("user_id"), None);

        // Non-existent key returns None
        assert_eq!(ctx.get::<u64>("other"), None);
    }

    #[test]
    fn test_connection_context_insert_overwrites() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("key".to_string(), 1u32);
        ctx.insert("key".to_string(), String::from("two"));

        // The second insert replaces the first, even with a different type.
        assert_eq!(ctx.metadata.len(), 1);
        assert_eq!(ctx.get::<u32>("key"), None);
        assert_eq!(ctx.get::<String>("key"), Some(&String::from("two")));
    }

    #[test]
    fn test_connection_context_with_addr() {
        let addr = v4(127, 0, 0, 1, 8080);
        let ctx = ConnectionContext::with_addr(addr);

        assert_eq!(ctx.remote_addr, Some(addr));
        assert_eq!(ctx.metadata.len(), 0);
    }

    #[test]
    fn test_connection_context_default() {
        let ctx = ConnectionContext::default();
        assert!(ctx.remote_addr.is_none());
        assert_eq!(ctx.metadata.len(), 0);
    }

    #[test]
    fn test_connection_context_multiple_metadata() {
        let mut ctx = ConnectionContext::new();

        ctx.insert("user_id".to_string(), 123u64);
        ctx.insert("username".to_string(), String::from("alice"));
        ctx.insert("is_admin".to_string(), true);

        assert_eq!(ctx.get::<u64>("user_id"), Some(&123));
        assert_eq!(ctx.get::<String>("username"), Some(&String::from("alice")));
        assert_eq!(ctx.get::<bool>("is_admin"), Some(&true));
    }

    #[test]
    fn test_allow_all_unauthorized_error() {
        let policy = AllowAll;
        let response = policy.unauthorized_error("test_method");
        assert!(response.error.is_some());
        let error = response.error.unwrap();
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert_eq!(error.message, "Unauthorized");
    }

    #[test]
    fn test_deny_all_unauthorized_error() {
        let policy = DenyAll;
        let response = policy.unauthorized_error("blocked_method");
        assert!(response.error.is_some());
        let error = response.error.unwrap();
        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert_eq!(error.message, "Unauthorized");
    }

    #[tokio::test]
    async fn test_default_context_extractor() {
        let extractor = DefaultContextExtractor;
        let addr = v4(192, 168, 1, 1, 9000);

        let ctx = extractor.extract(Some(addr), None).await;
        assert_eq!(ctx.remote_addr, Some(addr));
        assert_eq!(ctx.metadata.len(), 0);
    }

    #[tokio::test]
    async fn test_default_context_extractor_no_addr() {
        let extractor = DefaultContextExtractor;
        let ctx = extractor.extract(None, None).await;
        assert!(ctx.remote_addr.is_none());
        assert!(ctx.metadata.is_empty());
    }

    #[tokio::test]
    async fn test_default_context_extractor_with_metadata() {
        let extractor = DefaultContextExtractor;
        let addr = v4(10, 0, 0, 1, 3000);
        let metadata: Arc<dyn Any + Send + Sync> = Arc::new(String::from("test"));

        let ctx = extractor.extract(Some(addr), Some(metadata)).await;
        assert_eq!(ctx.remote_addr, Some(addr));
        // The default extractor ignores transport metadata entirely.
        assert!(ctx.metadata.is_empty());
    }

    #[test]
    fn test_connection_context_clone() {
        let mut ctx1 = ConnectionContext::new();
        ctx1.insert("key".to_string(), 100u32);

        let ctx2 = ctx1.clone();
        assert_eq!(ctx2.get::<u32>("key"), Some(&100));
        // Cloning shares the stored value via Arc rather than copying it.
        assert!(Arc::ptr_eq(&ctx1.metadata["key"], &ctx2.metadata["key"]));
    }
}