ash_rpc/auth.rs
1//! Authentication and authorization hooks
2//!
3//! This module provides minimal traits for implementing authentication
4//! and authorization. The library makes NO assumptions about:
5//! - How you authenticate (JWT, API keys, OAuth, certificates, etc.)
6//! - What your user/identity model looks like
7//! - How you authorize (RBAC, ABAC, ACL, custom logic, etc.)
8//!
9//! You implement the trait, we call your `can_access` method.
10//!
11//! # Example
12//! ```
13//! use ash_rpc::auth::{AuthPolicy, ConnectionContext};
14//!
15//! struct MyAuth;
16//!
17//! impl AuthPolicy for MyAuth {
18//! fn can_access(&self, method: &str, params: Option<&serde_json::Value>, _ctx: &ConnectionContext) -> bool {
19//! // Your logic here - check API keys, JWT tokens, whatever you need
20//! let _ = (method, params);
21//! true
22//! }
23//! }
24//! ```
25
26use crate::Response;
27use std::any::Any;
28use std::net::SocketAddr;
29use std::sync::Arc;
30
/// Type-erased storage for user-defined auth metadata, keyed by name.
///
/// Values sit behind `Arc`, so cloning a [`ConnectionContext`] shares the
/// stored values rather than requiring them to be `Clone`.
type AuthMetadata = std::collections::HashMap<String, Arc<dyn Any + Send + Sync>>;

/// Connection context for authentication
///
/// Holds per-connection data that authentication and authorization
/// decisions can consult. The library makes NO assumptions about what
/// data you need - store anything in `metadata`.
///
/// # Examples
/// - TLS client certificates
/// - IP addresses for whitelisting
/// - Custom connection-level tokens
/// - Session identifiers
#[derive(Default, Clone)]
pub struct ConnectionContext {
    /// Remote address of the connection, if known
    pub remote_addr: Option<SocketAddr>,

    /// User-defined metadata
    ///
    /// Store any auth-related data here:
    /// - TLS peer certificates
    /// - Extracted user IDs
    /// - Session tokens
    /// - Rate limiting state
    /// - Whatever you need for your auth logic
    pub metadata: AuthMetadata,
}

impl std::fmt::Debug for ConnectionContext {
    /// Formats the remote address and the metadata *keys* only.
    ///
    /// The stored values are type-erased and may be secrets (tokens,
    /// certificates), so they are deliberately left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // HashMap iteration order is unspecified; sort so the output is stable.
        let mut keys: Vec<&str> = self.metadata.keys().map(String::as_str).collect();
        keys.sort_unstable();

        f.debug_struct("ConnectionContext")
            .field("remote_addr", &self.remote_addr)
            .field("metadata_keys", &keys)
            .finish()
    }
}

impl ConnectionContext {
    /// Create a new empty context (no remote address, no metadata)
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a context that records `remote_addr` and has no metadata
    #[must_use]
    pub fn with_addr(remote_addr: SocketAddr) -> Self {
        Self {
            remote_addr: Some(remote_addr),
            ..Self::default()
        }
    }

    /// Insert typed metadata under `key`
    ///
    /// A value already stored under the same key is replaced, whatever
    /// its type was.
    pub fn insert<T: Any + Send + Sync>(&mut self, key: String, value: T) {
        self.metadata.insert(key, Arc::new(value));
    }

    /// Get typed metadata stored under `key`
    ///
    /// Returns `None` when the key is absent or when the stored value is
    /// not a `T`.
    #[must_use]
    pub fn get<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
        self.metadata.get(key)?.downcast_ref::<T>()
    }
}
88
89/// Trait for extracting authentication context from connections
90///
91/// Implement this to extract auth data from your transport layer.
92/// The library will call this when a new connection is established.
93///
94/// # What you can extract:
95/// - TLS client certificates for mutual TLS authentication
96/// - IP addresses for whitelisting/geoblocking
97/// - Custom connection-level authentication tokens
98/// - Any connection metadata you need for auth decisions
99///
100/// # Example: TLS Certificate Extraction
101/// ```text
102/// use ash_rpc::auth::{ContextExtractor, ConnectionContext};
103///
104/// struct TlsContextExtractor;
105///
106/// #[async_trait::async_trait]
107/// impl ContextExtractor for TlsContextExtractor {
108/// async fn extract(&self, stream: &tokio_rustls::server::TlsStream<tokio::net::TcpStream>) -> ConnectionContext {
109/// let mut ctx = ConnectionContext::new();
110///
111/// // Extract TLS peer certificates
112/// if let Some(certs) = stream.get_ref().1.peer_certificates() {
113/// ctx.insert("peer_certs".to_string(), certs.clone());
114/// }
115///
116/// // Extract client IP
117/// if let Ok(addr) = stream.get_ref().0.peer_addr() {
118/// ctx.remote_addr = Some(addr);
119/// }
120///
121/// ctx
122/// }
123/// }
124/// ```
125#[async_trait::async_trait]
126pub trait ContextExtractor: Send + Sync {
127 /// Extract connection context for authentication
128 ///
129 /// This is called once when a connection is established.
130 /// The returned context is passed to the auth policy for each request.
131 ///
132 /// # Arguments
133 /// * `remote_addr` - Remote socket address of the connection
134 /// * `metadata` - Optional transport-specific metadata (e.g., TLS session data)
135 ///
136 /// # Returns
137 /// A `ConnectionContext` with whatever data you need for auth
138 async fn extract(
139 &self,
140 remote_addr: Option<SocketAddr>,
141 metadata: Option<Arc<dyn Any + Send + Sync>>,
142 ) -> ConnectionContext;
143}
144
/// Default context extractor that only captures the remote address
///
/// A stateless unit type, so it is cheap to copy and can be created with
/// `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultContextExtractor;
147
148#[async_trait::async_trait]
149impl ContextExtractor for DefaultContextExtractor {
150 async fn extract(
151 &self,
152 remote_addr: Option<SocketAddr>,
153 _metadata: Option<Arc<dyn Any + Send + Sync>>,
154 ) -> ConnectionContext {
155 ConnectionContext {
156 remote_addr,
157 metadata: std::collections::HashMap::new(),
158 }
159 }
160}
161
162/// Trait for implementing authentication/authorization checks
163///
164/// Implement this to control access to your JSON-RPC methods.
165/// The library will call `can_access` before executing methods.
166///
167/// You decide what "access" means - it could be:
168/// - Checking an API key in request params
169/// - Validating a JWT token
170/// - Verifying client certificates (via `ConnectionContext`)
171/// - Role-based checks
172/// - Rate limiting
173/// - IP whitelisting (via `ConnectionContext`)
174/// - Anything else you need
175pub trait AuthPolicy: Send + Sync {
176 /// Check if a request should be allowed to proceed
177 ///
178 /// # Arguments
179 /// * `method` - The JSON-RPC method being called
180 /// * `params` - Optional parameters from the request
181 /// * `ctx` - Connection context (IP, TLS certs, custom metadata)
182 ///
183 /// # Returns
184 /// `true` if the request should proceed, `false` to deny
185 ///
186 /// # Example: Using Connection Context
187 /// ```text
188 /// fn can_access(
189 /// &self,
190 /// method: &str,
191 /// params: Option<&serde_json::Value>,
192 /// ctx: &ConnectionContext,
193 /// ) -> bool {
194 /// // Check IP whitelist
195 /// if let Some(addr) = ctx.remote_addr {
196 /// if !self.is_ip_allowed(&addr.ip()) {
197 /// return false;
198 /// }
199 /// }
200 ///
201 /// // Check TLS client certificate
202 /// if let Some(certs) = ctx.get::<Vec<Certificate>>("peer_certs") {
203 /// return self.validate_client_cert(certs);
204 /// }
205 ///
206 /// // Check token in params
207 /// let token = params
208 /// .and_then(|p| p.get("auth_token"))
209 /// .and_then(|t| t.as_str());
210 ///
211 /// self.validate_token(token)
212 /// }
213 /// ```
214 fn can_access(
215 &self,
216 method: &str,
217 params: Option<&serde_json::Value>,
218 ctx: &ConnectionContext,
219 ) -> bool;
220
221 /// Optional: Get the unauthorized error response
222 ///
223 /// Override this if you want custom error messages for denied requests.
224 /// Default returns a generic "Unauthorized" error.
225 fn unauthorized_error(&self, method: &str) -> Response {
226 let _ = method;
227 crate::ResponseBuilder::new()
228 .error(
229 crate::ErrorBuilder::new(crate::error_codes::INTERNAL_ERROR, "Unauthorized")
230 .build(),
231 )
232 .id(None)
233 .build()
234 }
235}
236
/// Helper: Always allow all requests (no authentication)
///
/// Use this as a placeholder or for development/testing.
/// A stateless unit type, so it is cheap to copy and can be created with
/// `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AllowAll;
241
242impl AuthPolicy for AllowAll {
243 fn can_access(
244 &self,
245 _method: &str,
246 _params: Option<&serde_json::Value>,
247 _ctx: &ConnectionContext,
248 ) -> bool {
249 true
250 }
251}
252
/// Helper: Deny all requests
///
/// Useful for maintenance mode or testing denial paths.
/// A stateless unit type, so it is cheap to copy and can be created with
/// `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DenyAll;
257
258impl AuthPolicy for DenyAll {
259 fn can_access(
260 &self,
261 _method: &str,
262 _params: Option<&serde_json::Value>,
263 _ctx: &ConnectionContext,
264 ) -> bool {
265 false
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds an IPv4 socket address from its octets and port.
    fn v4(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::from(octets)), port)
    }

    // --- built-in policies -------------------------------------------------

    #[test]
    fn test_allow_all() {
        let ctx = ConnectionContext::new();
        let params = serde_json::json!({"key": "value"});
        let policy = AllowAll;

        assert!(policy.can_access("any_method", None, &ctx));
        assert!(policy.can_access("another_method", Some(&params), &ctx));
    }

    #[test]
    fn test_deny_all() {
        let ctx = ConnectionContext::new();
        let params = serde_json::json!({"key": "value"});
        let policy = DenyAll;

        assert!(!policy.can_access("any_method", None, &ctx));
        assert!(!policy.can_access("another_method", Some(&params), &ctx));
    }

    // --- ConnectionContext -------------------------------------------------

    #[test]
    fn test_connection_context() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("user_id".to_string(), 42u64);

        // Stored value comes back under the type it was inserted with
        assert_eq!(ctx.get::<u64>("user_id"), Some(&42));
        // Asking for a different type yields nothing
        assert_eq!(ctx.get::<String>("user_id"), None);
        // So does an unknown key
        assert_eq!(ctx.get::<u64>("other"), None);
    }

    #[test]
    fn test_connection_context_with_addr() {
        let addr = v4([127, 0, 0, 1], 8080);
        let ctx = ConnectionContext::with_addr(addr);

        assert_eq!(ctx.remote_addr, Some(addr));
        assert_eq!(ctx.metadata.len(), 0);
    }

    #[test]
    fn test_connection_context_default() {
        let ctx = ConnectionContext::default();

        assert!(ctx.remote_addr.is_none());
        assert_eq!(ctx.metadata.len(), 0);
    }

    #[test]
    fn test_connection_context_multiple_metadata() {
        let mut ctx = ConnectionContext::new();
        ctx.insert("user_id".to_string(), 123u64);
        ctx.insert("username".to_string(), String::from("alice"));
        ctx.insert("is_admin".to_string(), true);

        assert_eq!(ctx.get::<u64>("user_id"), Some(&123));
        assert_eq!(ctx.get::<String>("username"), Some(&String::from("alice")));
        assert_eq!(ctx.get::<bool>("is_admin"), Some(&true));
    }

    #[test]
    fn test_connection_context_clone() {
        let mut original = ConnectionContext::new();
        original.insert("key".to_string(), 100u32);

        let duplicate = original.clone();
        assert_eq!(duplicate.get::<u32>("key"), Some(&100));
    }

    // --- default unauthorized response -------------------------------------

    #[test]
    fn test_allow_all_unauthorized_error() {
        let response = AllowAll.unauthorized_error("test_method");
        let error = response
            .error
            .expect("default unauthorized response must carry an error");

        assert_eq!(error.code, crate::error_codes::INTERNAL_ERROR);
        assert_eq!(error.message, "Unauthorized");
    }

    #[test]
    fn test_deny_all_unauthorized_error() {
        let response = DenyAll.unauthorized_error("blocked_method");
        assert!(response.error.is_some());
    }

    // --- DefaultContextExtractor -------------------------------------------

    #[tokio::test]
    async fn test_default_context_extractor() {
        let addr = v4([192, 168, 1, 1], 9000);

        let ctx = DefaultContextExtractor.extract(Some(addr), None).await;
        assert_eq!(ctx.remote_addr, Some(addr));
        assert_eq!(ctx.metadata.len(), 0);
    }

    #[tokio::test]
    async fn test_default_context_extractor_no_addr() {
        let ctx = DefaultContextExtractor.extract(None, None).await;
        assert!(ctx.remote_addr.is_none());
    }

    #[tokio::test]
    async fn test_default_context_extractor_with_metadata() {
        let addr = v4([10, 0, 0, 1], 3000);
        let payload: Arc<dyn Any + Send + Sync> = Arc::new(String::from("test"));

        let ctx = DefaultContextExtractor
            .extract(Some(addr), Some(payload))
            .await;
        assert_eq!(ctx.remote_addr, Some(addr));
    }
}