1use std::{fmt, future::Future, marker::PhantomData};
2
3use futures::TryFuture;
4use pin_project_lite::pin_project;
5use tower::{Layer, Service};
6
7use crate::{ExcService, ExchangeError, Request};
8
9pub trait Adaptor<R: Request>: Request {
11 fn from_request(req: R) -> Result<Self, ExchangeError>;
13
14 fn into_response(resp: Self::Response) -> Result<R::Response, ExchangeError>;
16}
17
18impl<T, R, E> Adaptor<R> for T
19where
20 T: Request,
21 R: Request,
22 T: TryFrom<R, Error = E>,
23 T::Response: TryInto<R::Response, Error = E>,
24 ExchangeError: From<E>,
25{
26 fn from_request(req: R) -> Result<Self, ExchangeError>
27 where
28 Self: Sized,
29 {
30 Ok(Self::try_from(req)?)
31 }
32
33 fn into_response(resp: Self::Response) -> Result<<R as Request>::Response, ExchangeError> {
34 Ok(resp.try_into()?)
35 }
36}
37
38#[derive(Debug)]
40pub struct AdaptLayer<Req, R>(PhantomData<fn() -> (Req, R)>);
41
42impl<Req, R> Default for AdaptLayer<Req, R> {
43 fn default() -> Self {
44 Self(PhantomData)
45 }
46}
47
48impl<S, Req, R> Layer<S> for AdaptLayer<Req, R> {
49 type Service = Adapt<S, Req, R>;
50
51 fn layer(&self, inner: S) -> Self::Service {
52 Adapt(inner, PhantomData)
53 }
54}
55
56pub trait AdaptService<Req, R>: ExcService<Req>
58where
59 Req: Request,
60 R: Request,
61{
62 type AdaptedResponse: Future<Output = Result<R::Response, ExchangeError>>;
64
65 fn adapt_from_request(&mut self, req: R) -> Result<Req, ExchangeError>;
67
68 fn adapt_into_response(&mut self, res: Self::Future) -> Self::AdaptedResponse;
70}
71
pin_project! {
    #[derive(Debug)]
    /// Future that drives the inner future `fut` and passes its success value
    /// through the one-shot conversion `f`; errors from `fut` are forwarded.
    pub struct AndThen<Fut, F> {
        // Inner future; structurally pinned so it can be polled in place.
        #[pin]
        fut: Fut,
        // Conversion applied to the inner success value; `take`n by `poll` so it
        // runs at most once.
        f: Option<F>,
    }
}
81
82impl<Fut, F> AndThen<Fut, F>
83where
84 Fut: TryFuture<Error = ExchangeError>,
85{
86 pub fn new(fut: Fut, f: F) -> Self {
88 Self { fut, f: Some(f) }
89 }
90}
91
92impl<Fut, F, T> Future for AndThen<Fut, F>
93where
94 Fut: TryFuture<Error = ExchangeError>,
95 F: FnOnce(Fut::Ok) -> Result<T, ExchangeError>,
96{
97 type Output = Result<T, ExchangeError>;
98
99 fn poll(
100 self: std::pin::Pin<&mut Self>,
101 cx: &mut std::task::Context<'_>,
102 ) -> std::task::Poll<Self::Output> {
103 let this = self.project();
104 match this.fut.try_poll(cx) {
105 std::task::Poll::Ready(Ok(ok)) => match this.f.take() {
106 Some(f) => std::task::Poll::Ready((f)(ok)),
107 None => std::task::Poll::Pending,
108 },
109 std::task::Poll::Ready(Err(err)) => std::task::Poll::Ready(Err(err)),
110 std::task::Poll::Pending => std::task::Poll::Pending,
111 }
112 }
113}
114
115impl<C, Req, R> AdaptService<Req, R> for C
116where
117 Req: Request,
118 R: Request,
119 Req: Adaptor<R>,
120 C: ExcService<Req>,
121{
122 type AdaptedResponse =
123 AndThen<Self::Future, fn(Req::Response) -> Result<R::Response, ExchangeError>>;
124
125 fn adapt_from_request(&mut self, req: R) -> Result<Req, ExchangeError> {
126 Req::from_request(req)
127 }
128
129 fn adapt_into_response(&mut self, res: Self::Future) -> Self::AdaptedResponse {
130 AndThen::new(res, Req::into_response)
131 }
132}
133
/// Service wrapper that lets the inner service `C` (which handles `Req`) be
/// called with requests of type `R`.
///
/// `Req` and `R` are type markers only, so the trait impls below place bounds
/// on `C` alone.
pub struct Adapt<C, Req, R>(C, PhantomData<fn() -> (Req, R)>);

impl<C, Req, R> fmt::Debug for Adapt<C, Req, R>
where
    C: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(inner, marker) = self;
        let mut tuple = f.debug_tuple("Adapt");
        tuple.field(inner);
        tuple.field(marker);
        tuple.finish()
    }
}

impl<C, Req, R> Clone for Adapt<C, Req, R>
where
    C: Clone,
{
    fn clone(&self) -> Self {
        let Self(inner, _) = self;
        Adapt(inner.clone(), PhantomData)
    }
}

impl<C: Copy, Req, R> Copy for Adapt<C, Req, R> {}
159
pin_project! {
    // NOTE(review): presumably silences `missing_docs` for the macro-generated
    // `AdaptProj` projection type — TODO confirm.
    #[allow(missing_docs)]
    /// Response future of [`Adapt`]'s `Service::call`.
    ///
    /// It holds one of two things: the error produced while converting the
    /// incoming request (the inner service was then never called), or the
    /// adapted inner response future.
    #[project = AdaptProj]
    #[derive(Debug)]
    pub enum AdaptFuture<Fut> {
        // Request conversion failed; the error is handed out by `poll`.
        FromRequestError {
            // `Some` until `poll` takes it.
            err: Option<ExchangeError>,
        },
        // Request conversion succeeded; polling is delegated to `fut`.
        IntoResponse {
            #[pin]
            fut: Fut,
        }
    }
}
177
178impl<Fut> AdaptFuture<Fut> {
179 pub fn from_request_error(err: ExchangeError) -> Self {
181 Self::FromRequestError { err: Some(err) }
182 }
183
184 pub fn into_response(fut: Fut) -> Self {
186 Self::IntoResponse { fut }
187 }
188}
189
190impl<Fut> Future for AdaptFuture<Fut>
191where
192 Fut: TryFuture<Error = ExchangeError>,
193{
194 type Output = Result<Fut::Ok, ExchangeError>;
195
196 fn poll(
197 mut self: std::pin::Pin<&mut Self>,
198 cx: &mut std::task::Context<'_>,
199 ) -> std::task::Poll<Self::Output> {
200 match self.as_mut().project() {
201 AdaptProj::FromRequestError { err } => match err.take() {
202 Some(err) => std::task::Poll::Ready(Err(err)),
203 None => std::task::Poll::Pending,
204 },
205 AdaptProj::IntoResponse { fut, .. } => fut.try_poll(cx),
206 }
207 }
208}
209
210impl<C, Req, R> Service<R> for Adapt<C, Req, R>
211where
212 C: AdaptService<Req, R>,
213 Req: Request,
214 R: Request,
215{
216 type Response = R::Response;
217
218 type Error = ExchangeError;
219
220 type Future = AdaptFuture<C::AdaptedResponse>;
221
222 fn poll_ready(
223 &mut self,
224 cx: &mut std::task::Context<'_>,
225 ) -> std::task::Poll<Result<(), Self::Error>> {
226 self.0.poll_ready(cx)
227 }
228
229 fn call(&mut self, req: R) -> Self::Future {
230 let req = match self.0.adapt_from_request(req) {
231 Ok(req) => req,
232 Err(err) => return AdaptFuture::from_request_error(err),
233 };
234 let res = self.0.call(req);
235 AdaptFuture::into_response(self.0.adapt_into_response(res))
236 }
237}