1use axum::{
2 extract::{FromRequest, FromRequestParts, Request},
3 handler::Handler,
4 response::{IntoResponse, Response},
5 routing::MethodRouter,
6};
7use std::{any::Any, future::Future, pin::Pin};
8
9use crate::{
10 context::{
11 ConnectContext, RequestProtocol, validate_streaming_content_type,
12 validate_unary_content_type,
13 },
14 error::ConnectError,
15 message::{ConnectRequest, ConnectResponse, StreamBody, Streaming},
16};
17use futures::Stream;
18use prost::Message;
19use serde::de::DeserializeOwned;
20
21pub(crate) fn handle_extractor_rejection<R>(rejection: R, protocol: RequestProtocol) -> Response
27where
28 R: IntoResponse + Any,
29{
30 let rejection_any: Box<dyn Any> = Box::new(rejection);
31
32 match rejection_any.downcast::<ConnectError>() {
33 Ok(connect_err) => connect_err.into_response_with_protocol(protocol),
34 Err(any_box) => {
35 tracing::warn!(
36 "Extractor rejection is not ConnectError, returning as-is. \
37 If this is unintentional, consider using an extractor that returns ConnectError."
38 );
39 any_box
41 .downcast::<R>()
42 .map(|r| r.into_response())
43 .unwrap_or_else(|_| {
44 ConnectError::new_internal("extractor rejection")
46 .into_response_with_protocol(protocol)
47 })
48 }
49 }
50}
51
52fn validate_unary_protocol(ctx: &ConnectContext) -> Option<Response> {
57 validate_unary_content_type(ctx.protocol)
58 .map(|err| err.into_response_with_protocol(ctx.protocol))
59}
60
61pub(crate) fn validate_streaming_protocol(ctx: &ConnectContext) -> Option<Response> {
67 validate_streaming_content_type(ctx.protocol).map(|err| {
68 let use_proto = ctx.protocol.is_proto();
69 err.into_streaming_response(use_proto)
70 })
71}
72
/// Adapter that turns a plain async function into an axum [`Handler`] that
/// speaks the Connect protocol.
///
/// The wrapped function lives in the public tuple field. The `Handler` impls for
/// the supported unary and streaming signatures are defined in this module.
#[derive(Clone)]
pub struct ConnectHandlerWrapper<F>(pub F);

/// Shorter alias for [`ConnectHandlerWrapper`].
pub type ConnectHandler<F> = ConnectHandlerWrapper<F>;
79
/// Invokes `$m!([A1, ..., An])` once for every arity `n` in `1..=16`.
///
/// Invocations are emitted in increasing length order. Each one receives the
/// identifiers to use as the extra extractor type parameters.
macro_rules! all_tuples_nonempty {
    // Internal: nothing left to append, so emit the longest list and stop.
    (@grow $m:ident; [$($seen:ident),*];) => {
        $m!([$($seen),*]);
    };
    // Internal: emit the current list, then recurse with the next identifier appended.
    (@grow $m:ident; [$($seen:ident),*]; $next:ident $(, $rest:ident)*) => {
        $m!([$($seen),*]);
        all_tuples_nonempty!(@grow $m; [$($seen,)* $next]; $($rest),*);
    };
    ($m:ident) => {
        all_tuples_nonempty!(
            @grow $m;
            [A1];
            A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16
        );
    };
}
105
106impl<F, Fut, Req, Resp> Handler<(ConnectRequest<Req>,), ()> for ConnectHandlerWrapper<F>
110where
111 F: Fn(ConnectRequest<Req>) -> Fut + Clone + Send + Sync + 'static,
112 Fut: Future<Output = Result<ConnectResponse<Resp>, ConnectError>> + Send + 'static,
113 ConnectRequest<Req>: FromRequest<()>,
114 Req: Send + Sync + 'static,
115 Resp: prost::Message + serde::Serialize + Send + Clone + Sync + 'static,
116{
117 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
118
119 fn call(self, req: Request, _state: ()) -> Self::Future {
120 Box::pin(async move {
121 let ctx = req
123 .extensions()
124 .get::<ConnectContext>()
125 .cloned()
126 .unwrap_or_default();
127
128 if let Some(err_response) = validate_unary_protocol(&ctx) {
130 return err_response;
131 }
132
133 let connect_req = match ConnectRequest::<Req>::from_request(req, &()).await {
135 Ok(value) => value,
136 Err(err) => return err.into_response(),
137 };
138
139 let result = (self.0)(connect_req).await;
142
143 match result {
145 Ok(response) => response.into_response_with_context(&ctx),
146 Err(err) => err.into_response_with_protocol(ctx.protocol),
147 }
148 })
149 }
150}
151
152macro_rules! impl_handler_for_connect_handler_wrapper {
155 ([$($A:ident),*]) => {
156 impl<F, Fut, S, Req, Resp, $($A,)*> Handler<($($A,)* ConnectRequest<Req>,), S>
158 for ConnectHandlerWrapper<F>
159 where
160 F: Fn($($A,)* ConnectRequest<Req>) -> Fut + Clone + Send + Sync + 'static,
161 Fut: Future<Output = Result<ConnectResponse<Resp>, ConnectError>> + Send + 'static,
162 S:Clone+Send+Sync+'static,
163
164 $( $A: FromRequestParts<S> + Send + Sync + 'static,
166 <$A as FromRequestParts<S>>::Rejection: 'static, )*
167 ConnectRequest<Req>: FromRequest<S>,
168 Req: Send + Sync + 'static,
169 S: Send + Sync + 'static,
170
171 Resp: prost::Message + serde::Serialize + Send + Clone + Sync + 'static,
173 {
174 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
175
176 #[allow(unused_mut)]
177 fn call(self, req: Request, state: S) -> Self::Future {
178 Box::pin(async move {
179 let ctx = req
181 .extensions()
182 .get::<ConnectContext>()
183 .cloned()
184 .unwrap_or_default();
185
186 if let Some(err_response) = validate_unary_protocol(&ctx) {
188 return err_response;
189 }
190
191 let (mut parts, body) = req.into_parts();
193
194 $(
196 let $A = match $A::from_request_parts(&mut parts, &state).await {
197 Ok(value) => value,
198 Err(rejection) => return handle_extractor_rejection(rejection, ctx.protocol),
199 };
200 )*
201
202 let req = Request::from_parts(parts, body);
204
205 let connect_req = match ConnectRequest::<Req>::from_request(req, &state).await {
207 Ok(value) => value,
208 Err(err) => return err.into_response(),
209 };
210
211 let result = (self.0)($($A,)* connect_req).await;
214
215 match result {
217 Ok(response) => response.into_response_with_context(&ctx),
218 Err(err) => err.into_response_with_protocol(ctx.protocol),
219 }
220 })
221 }
222 }
223
224 };
225}
226
227#[allow(non_snake_case)]
228mod generated_handler_impls {
229 use super::*;
230 all_tuples_nonempty!(impl_handler_for_connect_handler_wrapper);
232}
233
234macro_rules! impl_server_stream_handler_for_connect_handler_wrapper {
238 ([$($A:ident),*]) => {
239 impl<F, Fut, S, Req, Resp, St, $($A,)*> Handler<($($A,)* ConnectRequest<Req>, StreamBody<St>), S>
240 for ConnectHandlerWrapper<F>
241 where
242 F: Fn($($A,)* ConnectRequest<Req>) -> Fut + Clone + Send + Sync + 'static,
243 Fut: Future<Output = Result<ConnectResponse<StreamBody<St>>, ConnectError>> + Send + 'static,
244 St: Stream<Item = Result<Resp, ConnectError>> + Send + 'static,
245 S: Clone + Send + Sync + 'static,
246
247 $( $A: FromRequestParts<S> + Send + Sync + 'static,
249 <$A as FromRequestParts<S>>::Rejection: 'static, )*
250 Req: Message + DeserializeOwned + Default + Send + Sync + 'static,
251 S: Send + Sync + 'static,
252
253 Resp: Message + serde::Serialize + Send + Sync + 'static,
255 {
256 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
257
258 #[allow(unused_mut)]
259 fn call(self, req: Request, state: S) -> Self::Future {
260 Box::pin(async move {
261 let ctx = req
262 .extensions()
263 .get::<ConnectContext>()
264 .cloned()
265 .unwrap_or_default();
266
267 if let Some(err_response) = validate_streaming_protocol(&ctx) {
268 return err_response;
269 }
270
271 let (mut parts, body) = req.into_parts();
272
273 $(
274 let $A = match $A::from_request_parts(&mut parts, &state).await {
275 Ok(value) => value,
276 Err(rejection) => return handle_extractor_rejection(rejection, ctx.protocol),
277 };
278 )*
279
280 let req = Request::from_parts(parts, body);
281
282 let connect_req = match ConnectRequest::<Req>::from_request(req, &state).await {
283 Ok(value) => value,
284 Err(err) => return err.into_response(),
285 };
286
287 let result = (self.0)($($A,)* connect_req).await;
288
289 match result {
290 Ok(response) => response.into_response_with_context(&ctx),
291 Err(err) => {
292 let use_proto = ctx.protocol.is_proto();
293 err.into_streaming_response(use_proto)
294 }
295 }
296 })
297 }
298 }
299 };
300}
301
302#[allow(non_snake_case)]
303mod generated_server_stream_handler_impls {
304 use super::*;
305 all_tuples_nonempty!(impl_server_stream_handler_for_connect_handler_wrapper);
306}
307
308macro_rules! impl_client_stream_handler_for_connect_handler_wrapper {
312 ([$($A:ident),*]) => {
313 impl<F, Fut, S, Req, Resp, $($A,)*> Handler<($($A,)* ConnectRequest<Streaming<Req>>, Resp), S>
314 for ConnectHandlerWrapper<F>
315 where
316 F: Fn($($A,)* ConnectRequest<Streaming<Req>>) -> Fut + Clone + Send + Sync + 'static,
317 Fut: Future<Output = Result<ConnectResponse<Resp>, ConnectError>> + Send + 'static,
318 S: Clone + Send + Sync + 'static,
319
320 $( $A: FromRequestParts<S> + Send + Sync + 'static,
322 <$A as FromRequestParts<S>>::Rejection: 'static, )*
323 Req: Message + DeserializeOwned + Default + Send + 'static,
324 S: Send + Sync + 'static,
325
326 Resp: Message + serde::Serialize + Send + Clone + Sync + 'static,
328 {
329 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
330
331 #[allow(unused_mut)]
332 fn call(self, req: Request, state: S) -> Self::Future {
333 Box::pin(async move {
334 let ctx = req
335 .extensions()
336 .get::<ConnectContext>()
337 .cloned()
338 .unwrap_or_default();
339
340 if let Some(err_response) = validate_streaming_protocol(&ctx) {
341 return err_response;
342 }
343
344 let (mut parts, body) = req.into_parts();
345
346 $(
347 let $A = match $A::from_request_parts(&mut parts, &state).await {
348 Ok(value) => value,
349 Err(rejection) => return handle_extractor_rejection(rejection, ctx.protocol),
350 };
351 )*
352
353 let req = Request::from_parts(parts, body);
354
355 let streaming_req =
356 match ConnectRequest::<Streaming<Req>>::from_request(req, &state).await {
357 Ok(value) => value,
358 Err(err) => return err.into_response(),
359 };
360
361 let result = (self.0)($($A,)* streaming_req).await;
362
363 match result {
364 Ok(response) => response.into_streaming_response_with_context(&ctx),
365 Err(err) => {
366 let use_proto = ctx.protocol.is_proto();
367 err.into_streaming_response(use_proto)
368 }
369 }
370 })
371 }
372 }
373 };
374}
375
376#[allow(non_snake_case)]
377mod generated_client_stream_handler_impls {
378 use super::*;
379 all_tuples_nonempty!(impl_client_stream_handler_for_connect_handler_wrapper);
380}
381
382macro_rules! impl_bidi_stream_handler_for_connect_handler_wrapper {
386 ([$($A:ident),*]) => {
387 impl<F, Fut, S, Req, Resp, St, $($A,)*> Handler<($($A,)* ConnectRequest<Streaming<Req>>, StreamBody<St>), S>
388 for ConnectHandlerWrapper<F>
389 where
390 F: Fn($($A,)* ConnectRequest<Streaming<Req>>) -> Fut + Clone + Send + Sync + 'static,
391 Fut: Future<Output = Result<ConnectResponse<StreamBody<St>>, ConnectError>> + Send + 'static,
392 St: Stream<Item = Result<Resp, ConnectError>> + Send + 'static,
393 S: Clone + Send + Sync + 'static,
394
395 $( $A: FromRequestParts<S> + Send + Sync + 'static,
397 <$A as FromRequestParts<S>>::Rejection: 'static, )*
398 Req: Message + DeserializeOwned + Default + Send + 'static,
399 S: Send + Sync + 'static,
400
401 Resp: Message + serde::Serialize + Send + Sync + 'static,
403 {
404 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
405
406 #[allow(unused_mut)]
407 fn call(self, req: Request, state: S) -> Self::Future {
408 Box::pin(async move {
409 let ctx = req
410 .extensions()
411 .get::<ConnectContext>()
412 .cloned()
413 .unwrap_or_default();
414
415 if let Some(err_response) = validate_streaming_protocol(&ctx) {
416 return err_response;
417 }
418
419 let (mut parts, body) = req.into_parts();
420
421 $(
422 let $A = match $A::from_request_parts(&mut parts, &state).await {
423 Ok(value) => value,
424 Err(rejection) => return handle_extractor_rejection(rejection, ctx.protocol),
425 };
426 )*
427
428 let req = Request::from_parts(parts, body);
429
430 let streaming_req =
431 match ConnectRequest::<Streaming<Req>>::from_request(req, &state).await {
432 Ok(value) => value,
433 Err(err) => return err.into_response(),
434 };
435
436 let result = (self.0)($($A,)* streaming_req).await;
437
438 match result {
439 Ok(response) => response.into_response_with_context(&ctx),
440 Err(err) => {
441 let use_proto = ctx.protocol.is_proto();
442 err.into_streaming_response(use_proto)
443 }
444 }
445 })
446 }
447 }
448 };
449}
450
451#[allow(non_snake_case)]
452mod generated_bidi_stream_handler_impls {
453 use super::*;
454 all_tuples_nonempty!(impl_bidi_stream_handler_for_connect_handler_wrapper);
455}
456
457impl<F, Fut, Req, Resp, St> Handler<(ConnectRequest<Req>, StreamBody<St>), ()>
463 for ConnectHandlerWrapper<F>
464where
465 F: Fn(ConnectRequest<Req>) -> Fut + Clone + Send + Sync + 'static,
466 Fut: Future<Output = Result<ConnectResponse<StreamBody<St>>, ConnectError>> + Send + 'static,
467 St: Stream<Item = Result<Resp, ConnectError>> + Send + 'static,
468 Req: Message + DeserializeOwned + Default + Send + Sync + 'static,
470 Resp: Message + serde::Serialize + Send + Sync + 'static,
471{
472 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
473
474 fn call(self, req: Request, _state: ()) -> Self::Future {
475 Box::pin(async move {
476 let ctx = req
477 .extensions()
478 .get::<ConnectContext>()
479 .cloned()
480 .unwrap_or_default();
481
482 if let Some(err_response) = validate_streaming_protocol(&ctx) {
484 return err_response;
485 }
486
487 let connect_req = match ConnectRequest::<Req>::from_request(req, &()).await {
488 Ok(value) => value,
489 Err(err) => return err.into_response(),
490 };
491
492 let result = (self.0)(connect_req).await;
493
494 match result {
495 Ok(response) => response.into_response_with_context(&ctx),
496 Err(err) => {
497 let use_proto = ctx.protocol.is_proto();
498 err.into_streaming_response(use_proto)
499 }
500 }
501 })
502 }
503}
504
505impl<F, Fut, Req, Resp> Handler<(ConnectRequest<Streaming<Req>>, Resp), ()>
511 for ConnectHandlerWrapper<F>
512where
513 F: Fn(ConnectRequest<Streaming<Req>>) -> Fut + Clone + Send + Sync + 'static,
514 Fut: Future<Output = Result<ConnectResponse<Resp>, ConnectError>> + Send + 'static,
515 Req: Message + DeserializeOwned + Default + Send + 'static,
516 Resp: Message + serde::Serialize + Send + Clone + Sync + 'static,
517{
518 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
519
520 fn call(self, req: Request, _state: ()) -> Self::Future {
521 Box::pin(async move {
522 let ctx = req
523 .extensions()
524 .get::<ConnectContext>()
525 .cloned()
526 .unwrap_or_default();
527
528 if let Some(err_response) = validate_streaming_protocol(&ctx) {
530 return err_response;
531 }
532
533 let streaming_req = match ConnectRequest::<Streaming<Req>>::from_request(req, &()).await
534 {
535 Ok(value) => value,
536 Err(err) => return err.into_response(),
537 };
538
539 let result = (self.0)(streaming_req).await;
540
541 match result {
543 Ok(response) => response.into_streaming_response_with_context(&ctx),
544 Err(err) => {
545 let use_proto = ctx.protocol.is_proto();
546 err.into_streaming_response(use_proto)
547 }
548 }
549 })
550 }
551}
552
553impl<F, Fut, Req, Resp, St> Handler<(ConnectRequest<Streaming<Req>>, StreamBody<St>), ()>
560 for ConnectHandlerWrapper<F>
561where
562 F: Fn(ConnectRequest<Streaming<Req>>) -> Fut + Clone + Send + Sync + 'static,
563 Fut: Future<Output = Result<ConnectResponse<StreamBody<St>>, ConnectError>> + Send + 'static,
564 St: Stream<Item = Result<Resp, ConnectError>> + Send + 'static,
565 Req: Message + DeserializeOwned + Default + Send + 'static,
566 Resp: Message + serde::Serialize + Send + Sync + 'static,
567{
568 type Future = Pin<Box<dyn Future<Output = Response> + Send>>;
569
570 fn call(self, req: Request, _state: ()) -> Self::Future {
571 Box::pin(async move {
572 let ctx = req
573 .extensions()
574 .get::<ConnectContext>()
575 .cloned()
576 .unwrap_or_default();
577
578 if let Some(err_response) = validate_streaming_protocol(&ctx) {
580 return err_response;
581 }
582
583 let streaming_req = match ConnectRequest::<Streaming<Req>>::from_request(req, &()).await
584 {
585 Ok(value) => value,
586 Err(err) => return err.into_response(),
587 };
588
589 let result = (self.0)(streaming_req).await;
590
591 match result {
592 Ok(response) => response.into_response_with_context(&ctx),
593 Err(err) => {
594 let use_proto = ctx.protocol.is_proto();
595 err.into_streaming_response(use_proto)
596 }
597 }
598 })
599 }
600}
601
602pub fn post_connect<F, T, S>(f: F) -> MethodRouter<S>
633where
634 S: Clone + Send + Sync + 'static,
635 ConnectHandlerWrapper<F>: Handler<T, S>,
636 T: 'static,
637{
638 axum::routing::post(ConnectHandlerWrapper(f))
639}
640
641pub fn get_connect<F, T, S>(f: F) -> MethodRouter<S>
658where
659 S: Clone + Send + Sync + 'static,
660 ConnectHandlerWrapper<F>: Handler<T, S>,
661 T: 'static,
662{
663 axum::routing::get(ConnectHandlerWrapper(f))
664}