1use crate::error::{BoxDynError, Error};
2use crate::request::Request;
3use crate::response::Response;
4use futures::channel::mpsc::{SendError, Sender};
5use futures::SinkExt;
6use futures::{future::BoxFuture, Future, FutureExt};
7use serde::Serialize;
8use std::marker::PhantomData;
9use std::{fmt, sync::Arc};
10pub use tower::{
11 layer::layer_fn, layer::util::Identity, util::BoxCloneService, Layer, Service, ServiceBuilder,
12};
13
14pub struct CommonLayer<In, T, U, E> {
17 boxed: Arc<dyn Layer<In, Service = BoxCloneService<T, U, E>>>,
18}
19
20impl<In, T, U, E> CommonLayer<In, T, U, E> {
21 pub fn new<L>(inner_layer: L) -> Self
23 where
24 L: Layer<In> + 'static,
25 L::Service: Service<T, Response = U, Error = E> + Send + 'static + Clone,
26 <L::Service as Service<T>>::Future: Send + 'static,
27 E: std::error::Error,
28 {
29 let layer = layer_fn(move |inner: In| {
30 let out = inner_layer.layer(inner);
31 BoxCloneService::new(out)
32 });
33
34 Self {
35 boxed: Arc::new(layer),
36 }
37 }
38}
39
40impl<In, T, U, E> Layer<In> for CommonLayer<In, T, U, E> {
41 type Service = BoxCloneService<T, U, E>;
42
43 fn layer(&self, inner: In) -> Self::Service {
44 self.boxed.layer(inner)
45 }
46}
47
48impl<In, T, U, E> Clone for CommonLayer<In, T, U, E> {
49 fn clone(&self) -> Self {
50 Self {
51 boxed: Arc::clone(&self.boxed),
52 }
53 }
54}
55
56impl<In, T, U, E> fmt::Debug for CommonLayer<In, T, U, E> {
57 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
58 fmt.debug_struct("CommonLayer").finish()
59 }
60}
61
62pub mod extensions {
64 use std::{
65 ops::Deref,
66 task::{Context, Poll},
67 };
68 use tower::Service;
69
70 use crate::request::Request;
71
72 #[derive(Debug, Clone, Copy)]
101 pub struct Data<T>(T);
102 impl<T> Data<T> {
103 pub fn new(inner: T) -> Data<T> {
105 Data(inner)
106 }
107 }
108
109 impl<T> Deref for Data<T> {
110 type Target = T;
111 fn deref(&self) -> &Self::Target {
112 &self.0
113 }
114 }
115
116 impl<S, T> tower::Layer<S> for Data<T>
117 where
118 T: Clone + Send + Sync + 'static,
119 {
120 type Service = AddExtension<S, T>;
121
122 fn layer(&self, inner: S) -> Self::Service {
123 AddExtension {
124 inner,
125 value: self.0.clone(),
126 }
127 }
128 }
129
130 #[derive(Clone, Copy, Debug)]
132 pub struct AddExtension<S, T> {
133 inner: S,
134 value: T,
135 }
136
137 impl<S, T, Req, Ctx> Service<Request<Req, Ctx>> for AddExtension<S, T>
138 where
139 S: Service<Request<Req, Ctx>>,
140 T: Clone + Send + Sync + 'static,
141 {
142 type Response = S::Response;
143 type Error = S::Error;
144 type Future = S::Future;
145
146 #[inline]
147 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148 self.inner.poll_ready(cx)
149 }
150
151 fn call(&mut self, mut req: Request<Req, Ctx>) -> Self::Future {
152 req.parts.data.insert(self.value.clone());
153 self.inner.call(req)
154 }
155 }
156}
157
158pub trait Ack<Task, Res> {
162 type Context;
164 type AckError: std::error::Error;
166
167 fn ack(
169 &mut self,
170 ctx: &Self::Context,
171 response: &Response<Res>,
172 ) -> impl Future<Output = Result<(), Self::AckError>> + Send;
173}
174
175impl<T, Res: Clone + Send + Sync, Ctx: Clone + Send + Sync> Ack<T, Res>
176 for Sender<(Ctx, Response<Res>)>
177{
178 type AckError = SendError;
179 type Context = Ctx;
180 async fn ack(
181 &mut self,
182 ctx: &Self::Context,
183 result: &Response<Res>,
184 ) -> Result<(), Self::AckError> {
185 let ctx = ctx.clone();
186 self.send((ctx, result.clone())).await.unwrap();
187 Ok(())
188 }
189}
190
191#[derive(Debug)]
193pub struct AckLayer<A, Req, Ctx, Res> {
194 ack: A,
195 job_type: PhantomData<Request<Req, Ctx>>,
196 res: PhantomData<Res>,
197}
198
199impl<A, Req, Ctx, Res> AckLayer<A, Req, Ctx, Res> {
200 pub fn new(ack: A) -> Self {
202 Self {
203 ack,
204 job_type: PhantomData,
205 res: PhantomData,
206 }
207 }
208}
209
210impl<A, Req, Ctx, S, Res> Layer<S> for AckLayer<A, Req, Ctx, Res>
211where
212 S: Service<Request<Req, Ctx>> + Send + 'static,
213 S::Error: std::error::Error + Send + Sync + 'static,
214 S::Future: Send + 'static,
215 A: Ack<Req, S::Response> + Clone + Send + Sync + 'static,
216{
217 type Service = AckService<S, A, Req, Ctx, S::Response>;
218
219 fn layer(&self, service: S) -> Self::Service {
220 AckService {
221 service,
222 ack: self.ack.clone(),
223 job_type: PhantomData,
224 res: PhantomData,
225 }
226 }
227}
228
229#[derive(Debug)]
231pub struct AckService<SV, A, Req, Ctx, Res> {
232 service: SV,
233 ack: A,
234 job_type: PhantomData<Request<Req, Ctx>>,
235 res: PhantomData<Res>,
236}
237
238impl<Sv: Clone, A: Clone, Req, Ctx, Res> Clone for AckService<Sv, A, Req, Ctx, Res> {
239 fn clone(&self) -> Self {
240 Self {
241 ack: self.ack.clone(),
242 job_type: PhantomData,
243 service: self.service.clone(),
244 res: PhantomData,
245 }
246 }
247}
248
249impl<SV, A, Req, Res, Ctx> Service<Request<Req, Ctx>> for AckService<SV, A, Req, Ctx, Res>
250where
251 SV: Service<Request<Req, Ctx>> + Send + Sync + 'static,
252 <SV as Service<Request<Req, Ctx>>>::Error: Into<BoxDynError> + Send + Sync + 'static,
253 <SV as Service<Request<Req, Ctx>>>::Future: std::marker::Send + 'static,
254 A: Ack<Req, <SV as Service<Request<Req, Ctx>>>::Response, Context = Ctx>
255 + Send
256 + 'static
257 + Clone
258 + Send
259 + Sync,
260 Req: 'static + Send,
261 <SV as Service<Request<Req, Ctx>>>::Response: std::marker::Send + fmt::Debug + Sync + Serialize,
262 <A as Ack<Req, SV::Response>>::Context: Sync + Send + Clone,
263 <A as Ack<Req, <SV as Service<Request<Req, Ctx>>>::Response>>::Context: 'static,
264 Ctx: Clone,
265{
266 type Response = SV::Response;
267 type Error = Error;
268 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
269
270 fn poll_ready(
271 &mut self,
272 cx: &mut std::task::Context<'_>,
273 ) -> std::task::Poll<Result<(), Self::Error>> {
274 self.service
275 .poll_ready(cx)
276 .map_err(|e| Error::Failed(Arc::new(e.into())))
277 }
278
279 fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
280 let mut ack = self.ack.clone();
281 let ctx = request.parts.context.clone();
282 let attempt = request.parts.attempt.clone();
283 let task_id = request.parts.task_id.clone();
284 let fut = self.service.call(request);
285 let fut_with_ack = async move {
286 let res = fut.await.map_err(|err| {
287 let e: BoxDynError = err.into();
288 if let Some(custom_error) = e.downcast_ref::<Error>() {
290 return custom_error.clone();
291 }
292 Error::Failed(Arc::new(e))
293 });
294 let response = Response {
295 attempt,
296 inner: res,
297 task_id,
298 _priv: (),
299 };
300 if let Err(_e) = ack.ack(&ctx, &response).await {
301 }
304 response.inner
305 };
306 fut_with_ack.boxed()
307 }
308}