allframe_core/auth/
tonic.rs1use std::{marker::PhantomData, sync::Arc};
24
25use tonic::{metadata::MetadataMap, Request, Status};
26
27use super::{extract_bearer_token, AuthContext, AuthError, Authenticator};
28
/// Tonic interceptor that authenticates requests from their bearer token.
///
/// On success the resulting claims are attached to the request (see the
/// `tonic::service::Interceptor` implementation below). `C` is the claims type
/// produced by the authenticator `A`; it is tracked only at the type level.
pub struct AuthInterceptor<A, C> {
    /// When `true`, requests without valid credentials are rejected; when
    /// `false`, they are passed through without claims.
    required: bool,
    /// Authenticator shared by every clone of this interceptor.
    authenticator: Arc<A>,
    /// Binds the claims type `C` without storing a value of it.
    _phantom: PhantomData<C>,
}
53
54impl<A, C> Clone for AuthInterceptor<A, C> {
55 fn clone(&self) -> Self {
56 Self {
57 authenticator: self.authenticator.clone(),
58 required: self.required,
59 _phantom: PhantomData,
60 }
61 }
62}
63
64impl<A, C> AuthInterceptor<A, C>
65where
66 A: Authenticator<Claims = C>,
67 C: Clone + Send + Sync + 'static,
68{
69 pub fn new(authenticator: A) -> Self {
71 Self {
72 authenticator: Arc::new(authenticator),
73 required: true,
74 _phantom: PhantomData,
75 }
76 }
77
78 pub fn required(mut self) -> Self {
80 self.required = true;
81 self
82 }
83
84 pub fn optional(mut self) -> Self {
86 self.required = false;
87 self
88 }
89
90 fn authenticate_request<T>(&self, request: &Request<T>) -> Result<Option<C>, Status> {
92 let metadata = request.metadata();
93
94 let token = match extract_token_from_metadata(metadata) {
96 Some(token) => token,
97 None => {
98 if self.required {
99 return Err(auth_error_to_status(AuthError::MissingToken));
100 }
101 return Ok(None);
102 }
103 };
104
105 let authenticator = self.authenticator.clone();
108 let token_owned = token.to_string();
109
110 let result = tokio::task::block_in_place(|| {
112 tokio::runtime::Handle::current()
113 .block_on(async { authenticator.authenticate(&token_owned).await })
114 });
115
116 match result {
117 Ok(claims) => Ok(Some(claims)),
118 Err(e) => {
119 if self.required {
120 Err(auth_error_to_status(e))
121 } else {
122 Ok(None)
123 }
124 }
125 }
126 }
127}
128
129impl<A, C> tonic::service::Interceptor for AuthInterceptor<A, C>
130where
131 A: Authenticator<Claims = C> + 'static,
132 C: Clone + Send + Sync + 'static,
133{
134 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
135 match self.authenticate_request(&request)? {
136 Some(claims) => {
137 let token = extract_token_from_metadata(request.metadata())
139 .unwrap_or("")
140 .to_string();
141
142 let ctx = AuthContext::new(claims, token);
143 request.extensions_mut().insert(ctx);
144 }
145 None => {
146 }
148 }
149
150 Ok(request)
151 }
152}
153
154fn extract_token_from_metadata(metadata: &MetadataMap) -> Option<&str> {
156 metadata
157 .get("authorization")
158 .and_then(|v| v.to_str().ok())
159 .and_then(extract_bearer_token)
160}
161
162fn auth_error_to_status(error: AuthError) -> Status {
164 match error {
165 AuthError::MissingToken => Status::unauthenticated("missing authentication token"),
166 AuthError::InvalidToken(msg) => Status::unauthenticated(format!("invalid token: {}", msg)),
167 AuthError::TokenExpired => Status::unauthenticated("token has expired"),
168 AuthError::InvalidSignature => Status::unauthenticated("invalid token signature"),
169 AuthError::InvalidIssuer => Status::unauthenticated("invalid token issuer"),
170 AuthError::InvalidAudience => Status::unauthenticated("invalid token audience"),
171 AuthError::ValidationFailed(msg) => {
172 Status::permission_denied(format!("validation failed: {}", msg))
173 }
174 AuthError::Internal(msg) => Status::internal(format!("auth error: {}", msg)),
175 }
176}
177
178pub trait GrpcAuthExt<T> {
180 fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>>;
182
183 fn claims<C: Clone + Send + Sync + 'static>(&self) -> Option<&C> {
185 self.auth_context::<C>().map(|ctx| &ctx.claims)
186 }
187
188 fn require_auth<C: Clone + Send + Sync + 'static>(&self) -> Result<&C, Status> {
190 self.claims::<C>()
191 .ok_or_else(|| Status::unauthenticated("authentication required"))
192 }
193}
194
195impl<T> GrpcAuthExt<T> for Request<T> {
196 fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>> {
197 self.extensions().get::<AuthContext<C>>()
198 }
199}
200
201pub fn auth_interceptor<A, C>(
216 authenticator: A,
217 required: bool,
218) -> impl FnMut(Request<()>) -> Result<Request<()>, Status> + Clone
219where
220 A: Authenticator<Claims = C> + 'static,
221 C: Clone + Send + Sync + 'static,
222{
223 let mut interceptor = AuthInterceptor::new(authenticator);
224 if !required {
225 interceptor = interceptor.optional();
226 }
227
228 move |req| {
229 let mut i = interceptor.clone();
230 tonic::service::Interceptor::call(&mut i, req)
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_token_from_metadata() {
        let mut metadata = MetadataMap::new();

        // No `authorization` entry at all.
        assert!(extract_token_from_metadata(&metadata).is_none());

        metadata.insert("authorization", "Bearer test123".parse().unwrap());
        assert_eq!(extract_token_from_metadata(&metadata), Some("test123"));
    }

    #[test]
    fn test_auth_error_to_status() {
        // Every `AuthError` variant, so a changed mapping is caught.
        let cases = [
            (AuthError::MissingToken, tonic::Code::Unauthenticated),
            (AuthError::InvalidToken("bad".into()), tonic::Code::Unauthenticated),
            (AuthError::TokenExpired, tonic::Code::Unauthenticated),
            (AuthError::InvalidSignature, tonic::Code::Unauthenticated),
            (AuthError::InvalidIssuer, tonic::Code::Unauthenticated),
            (AuthError::InvalidAudience, tonic::Code::Unauthenticated),
            (AuthError::ValidationFailed("test".into()), tonic::Code::PermissionDenied),
            (AuthError::Internal("error".into()), tonic::Code::Internal),
        ];

        for (index, (error, expected)) in cases.into_iter().enumerate() {
            assert_eq!(auth_error_to_status(error).code(), expected, "case #{}", index);
        }
    }

    #[test]
    fn test_auth_error_to_status_messages() {
        // Variants with a payload must carry it through to the client message.
        let status = auth_error_to_status(AuthError::InvalidToken("bad".into()));
        assert_eq!(status.message(), "invalid token: bad");

        let status = auth_error_to_status(AuthError::ValidationFailed("test".into()));
        assert_eq!(status.message(), "validation failed: test");

        let status = auth_error_to_status(AuthError::Internal("error".into()));
        assert_eq!(status.message(), "auth error: error");

        let status = auth_error_to_status(AuthError::MissingToken);
        assert_eq!(status.message(), "missing authentication token");
    }

    #[test]
    fn test_grpc_auth_ext() {
        #[derive(Clone)]
        struct Claims {
            sub: String,
        }

        let mut request = Request::new(());

        // Nothing attached yet.
        assert!(request.auth_context::<Claims>().is_none());
        assert!(request.claims::<Claims>().is_none());
        assert!(request.require_auth::<Claims>().is_err());

        request.extensions_mut().insert(AuthContext::new(
            Claims {
                sub: "user123".to_string(),
            },
            "token",
        ));

        assert!(request.auth_context::<Claims>().is_some());
        assert_eq!(request.claims::<Claims>().unwrap().sub, "user123");
        assert!(request.require_auth::<Claims>().is_ok());
    }
}