garde_actix_web/web/
either.rs1use super::form::Form;
2use super::json::Json;
3use actix_web::dev::Payload;
4use actix_web::web::Bytes;
5use actix_web::{Error, FromRequest, HttpRequest};
6use futures::ready;
7use pin_project_lite::pin_project;
8use std::future::Future;
9use std::mem;
10use std::pin::Pin;
11use std::task::Context;
12use std::task::Poll;
13
14#[derive(Debug, PartialEq, Eq)]
16pub enum Either<L, R> {
17 Left(L),
18 Right(R),
19}
20
21impl<T> Either<Form<T>, Json<T>> {
22 pub fn into_inner(self) -> T {
23 match self {
24 Either::Left(form) => form.into_inner(),
25 Either::Right(json) => json.into_inner(),
26 }
27 }
28}
29
30impl<T> Either<Json<T>, Form<T>> {
31 pub fn into_inner(self) -> T {
32 match self {
33 Either::Left(json) => json.into_inner(),
34 Either::Right(form) => form.into_inner(),
35 }
36 }
37}
38
39#[derive(Debug)]
40pub enum EitherExtractError<L, R> {
41 Bytes(Error),
42 Extract(L, R),
43}
44
45impl<L, R> From<EitherExtractError<L, R>> for Error
46where
47 L: Into<Error>,
48 R: Into<Error>,
49{
50 fn from(err: EitherExtractError<L, R>) -> Error {
51 match err {
52 EitherExtractError::Bytes(err) => err,
53 EitherExtractError::Extract(a_err, _b_err) => a_err.into(),
54 }
55 }
56}
57
58impl<L, R> FromRequest for Either<L, R>
59where
60 L: FromRequest + 'static,
61 R: FromRequest + 'static,
62{
63 type Error = EitherExtractError<L::Error, R::Error>;
64 type Future = EitherExtractFut<L, R>;
65
66 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
67 EitherExtractFut {
68 req: req.clone(),
69 state: EitherExtractState::Bytes {
70 bytes: Bytes::from_request(req, payload),
71 },
72 }
73 }
74}
75
76pin_project! {
77 pub struct EitherExtractFut<L, R>
78 where
79 R: FromRequest,
80 L: FromRequest,
81 {
82 req: HttpRequest,
83 #[pin]
84 state: EitherExtractState<L, R>,
85 }
86}
87
88pin_project! {
89 #[project = EitherExtractProj]
90 pub enum EitherExtractState<L, R>
91 where
92 L: FromRequest,
93 R: FromRequest,
94 {
95 Bytes {
96 #[pin]
97 bytes: <Bytes as FromRequest>::Future,
98 },
99 Left {
100 #[pin]
101 left: L::Future,
102 fallback: Bytes,
103 },
104 Right {
105 #[pin]
106 right: R::Future,
107 left_err: Option<L::Error>,
108 },
109 }
110}
111
// Drives `EitherExtractFut` through its stages: buffer the body (`Bytes`), try `L` on it
// (`Left`), then fall back to `R` on the same bytes (`Right`).
impl<R, RF, RE, L, LF, LE> Future for EitherExtractFut<L, R>
where
    L: FromRequest<Future = LF, Error = LE>,
    R: FromRequest<Future = RF, Error = RE>,
    LF: Future<Output = Result<L, LE>> + 'static,
    RF: Future<Output = Result<R, RE>> + 'static,
    LE: Into<Error>,
    RE: Into<Error>,
{
    type Output = Result<Either<L, R>, EitherExtractError<LE, RE>>;

    /// Polls the current stage and, when it completes, transitions to the next stage and
    /// polls that one immediately within the same call.
    ///
    /// Resolves to:
    /// - `Ok(Either::Left(_))` if the `L` extractor succeeds;
    /// - `Ok(Either::Right(_))` if `L` fails and the `R` extractor succeeds;
    /// - `Err(EitherExtractError::Extract(l, r))` if both extractors fail;
    /// - `Err(EitherExtractError::Bytes(_))` if buffering the body fails.
    ///
    /// Returns `Poll::Pending` whenever the inner future of the current stage is pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        // Each iteration either `break`s with the final result or yields the next state.
        let ready = loop {
            let next = match this.state.as_mut().project() {
                EitherExtractProj::Bytes { bytes } => {
                    let res = ready!(bytes.poll(cx));
                    match res {
                        Ok(bytes) => {
                            // Keep a copy of the body so it can be replayed to `R` if `L` fails.
                            let fallback = bytes.clone();
                            let left = L::from_request(this.req, &mut payload_from_bytes(bytes));
                            EitherExtractState::Left { left, fallback }
                        }
                        Err(err) => break Err(EitherExtractError::Bytes(err)),
                    }
                }
                EitherExtractProj::Left { left, fallback } => {
                    let res = ready!(left.poll(cx));
                    match res {
                        Ok(extracted) => break Ok(Either::Left(extracted)),
                        Err(left_err) => {
                            // `mem::take` moves the bytes out of the pinned projection. The
                            // `Left` state is replaced by `this.state.set(next)` below.
                            let right = R::from_request(this.req, &mut payload_from_bytes(mem::take(fallback)));
                            EitherExtractState::Right {
                                left_err: Some(left_err),
                                right,
                            }
                        }
                    }
                }
                EitherExtractProj::Right { right, left_err } => {
                    let res = ready!(right.poll(cx));
                    match res {
                        Ok(data) => break Ok(Either::Right(data)),
                        Err(err) => {
                            // `left_err` is always `Some` here: it is set to `Some` when entering
                            // the `Right` state (see the `Left` arm) and is only taken on this
                            // terminal path, after which the future has resolved.
                            #[allow(clippy::unwrap_used)]
                            break Err(EitherExtractError::Extract(left_err.take().unwrap(), err));
                        }
                    }
                }
            };
            this.state.set(next);
        };

        Poll::Ready(ready)
    }
}
168
169fn payload_from_bytes(bytes: Bytes) -> Payload {
170 let (_, mut h1_payload) = actix_http::h1::Payload::create(true);
171 h1_payload.unread_data(bytes);
172 Payload::from(h1_payload)
173}