allframe_core/auth/
tonic.rs

//! Tonic (gRPC) integration for authentication.
//!
//! Provides interceptors for using AllFrame auth with gRPC services.
//!
//! Tonic interceptors are synchronous, so the async authenticator is driven with
//! `tokio::task::block_in_place`; interceptors must therefore run inside a
//! multi-threaded Tokio runtime.
//!
//! # Example
//!
//! ```rust,ignore
//! use allframe_core::auth::{AuthInterceptor, JwtValidator, JwtConfig};
//! use tonic::transport::Server;
//!
//! #[derive(Clone, serde::Deserialize)]
//! struct Claims {
//!     sub: String,
//! }
//!
//! let validator = JwtValidator::<Claims>::new(JwtConfig::hs256("secret"));
//! let interceptor = AuthInterceptor::new(validator);
//!
//! // Use with a service (`my_service` is your service implementation)
//! let service = MyServiceServer::with_interceptor(my_service, interceptor);
//! ```
22
23use std::{marker::PhantomData, sync::Arc};
24
25use tonic::{metadata::MetadataMap, Request, Status};
26
27use super::{extract_bearer_token, AuthContext, AuthError, Authenticator};
28
/// gRPC interceptor for authentication.
///
/// Validates the bearer token from the `authorization` metadata and, on success,
/// stores an [`AuthContext`] (claims plus the raw token) in the request extensions.
///
/// # Example
///
/// ```rust,ignore
/// use allframe_core::auth::{AuthInterceptor, JwtConfig, JwtValidator};
///
/// let validator = JwtValidator::<Claims>::new(JwtConfig::hs256("secret"));
/// let interceptor = AuthInterceptor::new(validator);
///
/// // Required auth - rejects unauthenticated requests.
/// // (`required`/`optional` consume the interceptor, so clone it to build both.)
/// let service = MyServiceServer::with_interceptor(my_service, interceptor.clone().required());
///
/// // Optional auth - allows unauthenticated requests
/// let service = MyServiceServer::with_interceptor(my_service, interceptor.optional());
/// ```
pub struct AuthInterceptor<A, C> {
    /// The authenticator, shared (not copied) between clones of the interceptor.
    authenticator: Arc<A>,
    /// `true` rejects requests without a valid token; `false` lets them through.
    required: bool,
    /// Ties the claims type `C` to the interceptor without storing a value of it.
    /// `fn() -> C` marks `C` as produced rather than owned, so the interceptor's
    /// `Send`/`Sync` depend only on `A`.
    _phantom: PhantomData<fn() -> C>,
}
53
54impl<A, C> Clone for AuthInterceptor<A, C> {
55    fn clone(&self) -> Self {
56        Self {
57            authenticator: self.authenticator.clone(),
58            required: self.required,
59            _phantom: PhantomData,
60        }
61    }
62}
63
64impl<A, C> AuthInterceptor<A, C>
65where
66    A: Authenticator<Claims = C>,
67    C: Clone + Send + Sync + 'static,
68{
69    /// Create a new auth interceptor (required by default).
70    pub fn new(authenticator: A) -> Self {
71        Self {
72            authenticator: Arc::new(authenticator),
73            required: true,
74            _phantom: PhantomData,
75        }
76    }
77
78    /// Make authentication required (rejects unauthenticated requests).
79    pub fn required(mut self) -> Self {
80        self.required = true;
81        self
82    }
83
84    /// Make authentication optional (allows unauthenticated requests).
85    pub fn optional(mut self) -> Self {
86        self.required = false;
87        self
88    }
89
90    /// Extract and validate the token from request metadata.
91    fn authenticate_request<T>(&self, request: &Request<T>) -> Result<Option<C>, Status> {
92        let metadata = request.metadata();
93
94        // Try to get the authorization header
95        let token = match extract_token_from_metadata(metadata) {
96            Some(token) => token,
97            None => {
98                if self.required {
99                    return Err(auth_error_to_status(AuthError::MissingToken));
100                }
101                return Ok(None);
102            }
103        };
104
105        // Validate the token synchronously
106        // Note: We use block_in_place since tonic interceptors are sync
107        let authenticator = self.authenticator.clone();
108        let token_owned = token.to_string();
109
110        // For sync validation (like JWT), we can use blocking
111        let result = tokio::task::block_in_place(|| {
112            tokio::runtime::Handle::current()
113                .block_on(async { authenticator.authenticate(&token_owned).await })
114        });
115
116        match result {
117            Ok(claims) => Ok(Some(claims)),
118            Err(e) => {
119                if self.required {
120                    Err(auth_error_to_status(e))
121                } else {
122                    Ok(None)
123                }
124            }
125        }
126    }
127}
128
129impl<A, C> tonic::service::Interceptor for AuthInterceptor<A, C>
130where
131    A: Authenticator<Claims = C> + 'static,
132    C: Clone + Send + Sync + 'static,
133{
134    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
135        match self.authenticate_request(&request)? {
136            Some(claims) => {
137                // Extract token for the context
138                let token = extract_token_from_metadata(request.metadata())
139                    .unwrap_or("")
140                    .to_string();
141
142                let ctx = AuthContext::new(claims, token);
143                request.extensions_mut().insert(ctx);
144            }
145            None => {
146                // Optional auth, no token present
147            }
148        }
149
150        Ok(request)
151    }
152}
153
154/// Extract bearer token from gRPC metadata.
155fn extract_token_from_metadata(metadata: &MetadataMap) -> Option<&str> {
156    metadata
157        .get("authorization")
158        .and_then(|v| v.to_str().ok())
159        .and_then(extract_bearer_token)
160}
161
162/// Convert an auth error to a tonic Status.
163fn auth_error_to_status(error: AuthError) -> Status {
164    match error {
165        AuthError::MissingToken => Status::unauthenticated("missing authentication token"),
166        AuthError::InvalidToken(msg) => Status::unauthenticated(format!("invalid token: {}", msg)),
167        AuthError::TokenExpired => Status::unauthenticated("token has expired"),
168        AuthError::InvalidSignature => Status::unauthenticated("invalid token signature"),
169        AuthError::InvalidIssuer => Status::unauthenticated("invalid token issuer"),
170        AuthError::InvalidAudience => Status::unauthenticated("invalid token audience"),
171        AuthError::ValidationFailed(msg) => {
172            Status::permission_denied(format!("validation failed: {}", msg))
173        }
174        AuthError::Internal(msg) => Status::internal(format!("auth error: {}", msg)),
175    }
176}
177
178/// Extension trait for extracting auth context from gRPC requests.
179pub trait GrpcAuthExt<T> {
180    /// Get the auth context if present.
181    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>>;
182
183    /// Get the claims if authenticated.
184    fn claims<C: Clone + Send + Sync + 'static>(&self) -> Option<&C> {
185        self.auth_context::<C>().map(|ctx| &ctx.claims)
186    }
187
188    /// Get the claims, returning an error if not authenticated.
189    fn require_auth<C: Clone + Send + Sync + 'static>(&self) -> Result<&C, Status> {
190        self.claims::<C>()
191            .ok_or_else(|| Status::unauthenticated("authentication required"))
192    }
193}
194
195impl<T> GrpcAuthExt<T> for Request<T> {
196    fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>> {
197        self.extensions().get::<AuthContext<C>>()
198    }
199}
200
201/// Simple function-based interceptor for quick auth setup.
202///
203/// # Example
204///
205/// ```rust,ignore
206/// use allframe_core::auth::tonic::auth_interceptor;
207///
208/// let validator = JwtValidator::<Claims>::new(config);
209///
210/// // Create a simple interceptor function
211/// let intercept = auth_interceptor(validator, true);
212///
213/// let service = MyServiceServer::with_interceptor(impl, intercept);
214/// ```
215pub fn auth_interceptor<A, C>(
216    authenticator: A,
217    required: bool,
218) -> impl FnMut(Request<()>) -> Result<Request<()>, Status> + Clone
219where
220    A: Authenticator<Claims = C> + 'static,
221    C: Clone + Send + Sync + 'static,
222{
223    let mut interceptor = AuthInterceptor::new(authenticator);
224    if !required {
225        interceptor = interceptor.optional();
226    }
227
228    move |req| {
229        let mut i = interceptor.clone();
230        tonic::service::Interceptor::call(&mut i, req)
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_token_from_metadata() {
        let mut metadata = MetadataMap::new();
        assert_eq!(
            extract_token_from_metadata(&metadata),
            None,
            "empty metadata carries no token"
        );

        metadata.insert("authorization", "Bearer test123".parse().unwrap());
        assert_eq!(extract_token_from_metadata(&metadata), Some("test123"));
    }

    #[test]
    fn test_auth_error_to_status() {
        let cases = [
            (AuthError::MissingToken, tonic::Code::Unauthenticated),
            (AuthError::TokenExpired, tonic::Code::Unauthenticated),
            (
                AuthError::ValidationFailed("test".into()),
                tonic::Code::PermissionDenied,
            ),
            (AuthError::Internal("error".into()), tonic::Code::Internal),
        ];

        for (error, expected) in cases {
            assert_eq!(auth_error_to_status(error).code(), expected);
        }
    }

    #[test]
    fn test_grpc_auth_ext() {
        #[derive(Clone)]
        struct Claims {
            sub: String,
        }

        let mut request = Request::new(());

        // Nothing is attached before authentication.
        assert!(request.auth_context::<Claims>().is_none());
        assert!(request.claims::<Claims>().is_none());
        assert!(request.require_auth::<Claims>().is_err());

        let claims = Claims {
            sub: "user123".to_string(),
        };
        request
            .extensions_mut()
            .insert(AuthContext::new(claims, "token"));

        // Every accessor now sees the inserted claims.
        assert!(request.auth_context::<Claims>().is_some());
        assert_eq!(
            request.claims::<Claims>().map(|c| c.sub.as_str()),
            Some("user123")
        );
        assert!(request.require_auth::<Claims>().is_ok());
    }
}