1use crate::codec::Codec;
2use crate::error::{BoxDynError, Error};
3use crate::request::Request;
4use crate::response::Response;
5use futures::channel::mpsc::{SendError, Sender};
6use futures::SinkExt;
7use futures::{future::BoxFuture, Future, FutureExt};
8use serde::Serialize;
9use std::fmt::Debug;
10use std::marker::PhantomData;
11use std::{fmt, sync::Arc};
12pub use tower::{
13 layer::layer_fn, layer::util::Identity, util::BoxCloneService, Layer, Service, ServiceBuilder,
14};
15
16pub struct CommonLayer<In, T, U, E> {
19 boxed: Arc<dyn Layer<In, Service = BoxCloneService<T, U, E>>>,
20}
21
22impl<In, T, U, E> CommonLayer<In, T, U, E> {
23 pub fn new<L>(inner_layer: L) -> Self
25 where
26 L: Layer<In> + 'static,
27 L::Service: Service<T, Response = U, Error = E> + Send + 'static + Clone,
28 <L::Service as Service<T>>::Future: Send + 'static,
29 E: std::error::Error,
30 {
31 let layer = layer_fn(move |inner: In| {
32 let out = inner_layer.layer(inner);
33 BoxCloneService::new(out)
34 });
35
36 Self {
37 boxed: Arc::new(layer),
38 }
39 }
40}
41
42impl<In, T, U, E> Layer<In> for CommonLayer<In, T, U, E> {
43 type Service = BoxCloneService<T, U, E>;
44
45 fn layer(&self, inner: In) -> Self::Service {
46 self.boxed.layer(inner)
47 }
48}
49
50impl<In, T, U, E> Clone for CommonLayer<In, T, U, E> {
51 fn clone(&self) -> Self {
52 Self {
53 boxed: Arc::clone(&self.boxed),
54 }
55 }
56}
57
58impl<In, T, U, E> fmt::Debug for CommonLayer<In, T, U, E> {
59 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
60 fmt.debug_struct("CommonLayer").finish()
61 }
62}
63
64pub mod extensions {
66 use std::{
67 ops::Deref,
68 task::{Context, Poll},
69 };
70 use tower::Service;
71
72 use crate::request::Request;
73
74 #[derive(Debug, Clone, Copy)]
103 pub struct Data<T>(T);
104 impl<T> Data<T> {
105 pub fn new(inner: T) -> Data<T> {
107 Data(inner)
108 }
109 }
110
111 impl<T> Deref for Data<T> {
112 type Target = T;
113 fn deref(&self) -> &Self::Target {
114 &self.0
115 }
116 }
117
118 impl<S, T> tower::Layer<S> for Data<T>
119 where
120 T: Clone + Send + Sync + 'static,
121 {
122 type Service = AddExtension<S, T>;
123
124 fn layer(&self, inner: S) -> Self::Service {
125 AddExtension {
126 inner,
127 value: self.0.clone(),
128 }
129 }
130 }
131
132 #[derive(Clone, Copy, Debug)]
134 pub struct AddExtension<S, T> {
135 inner: S,
136 value: T,
137 }
138
139 impl<S, T, Req, Ctx> Service<Request<Req, Ctx>> for AddExtension<S, T>
140 where
141 S: Service<Request<Req, Ctx>>,
142 T: Clone + Send + Sync + 'static,
143 {
144 type Response = S::Response;
145 type Error = S::Error;
146 type Future = S::Future;
147
148 #[inline]
149 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150 self.inner.poll_ready(cx)
151 }
152
153 fn call(&mut self, mut req: Request<Req, Ctx>) -> Self::Future {
154 req.parts.data.insert(self.value.clone());
155 self.inner.call(req)
156 }
157 }
158}
159
160pub trait Ack<Task, Res, Codec> {
164 type Context;
166 type AckError: std::error::Error;
168
169 fn ack(
171 &mut self,
172 ctx: &Self::Context,
173 response: &Response<Res>,
174 ) -> impl Future<Output = Result<(), Self::AckError>> + Send;
175}
176
177impl<T, Res: Clone + Send + Sync + Serialize, Ctx: Clone + Send + Sync, Cdc: Codec> Ack<T, Res, Cdc>
178 for Sender<(Ctx, Response<Cdc::Compact>)>
179where
180 Cdc::Error: Debug,
181 Cdc::Compact: Send,
182{
183 type AckError = SendError;
184 type Context = Ctx;
185 async fn ack(
186 &mut self,
187 ctx: &Self::Context,
188 result: &Response<Res>,
189 ) -> Result<(), Self::AckError> {
190 let ctx = ctx.clone();
191 let res = result.map(|res| Cdc::encode(res).unwrap());
192 self.send((ctx, res)).await.unwrap();
193 Ok(())
194 }
195}
196
197#[derive(Debug)]
199pub struct AckLayer<A, Req, Ctx, Cdc> {
200 ack: A,
201 job_type: PhantomData<Request<Req, Ctx>>,
202 codec: PhantomData<Cdc>,
203}
204
205impl<A, Req, Ctx, Cdc> AckLayer<A, Req, Ctx, Cdc> {
206 pub fn new(ack: A) -> Self {
208 Self {
209 ack,
210 job_type: PhantomData,
211 codec: PhantomData,
212 }
213 }
214}
215
216impl<A, Req, Ctx, S, Cdc> Layer<S> for AckLayer<A, Req, Ctx, Cdc>
217where
218 S: Service<Request<Req, Ctx>> + Send + 'static,
219 S::Error: std::error::Error + Send + Sync + 'static,
220 S::Future: Send + 'static,
221 A: Ack<Req, S::Response, Cdc> + Clone + Send + Sync + 'static,
222{
223 type Service = AckService<S, A, Req, Ctx, Cdc>;
224
225 fn layer(&self, service: S) -> Self::Service {
226 AckService {
227 service,
228 ack: self.ack.clone(),
229 job_type: PhantomData,
230 codec: PhantomData,
231 }
232 }
233}
234
235#[derive(Debug)]
237pub struct AckService<SV, A, Req, Ctx, Cdc> {
238 service: SV,
239 ack: A,
240 job_type: PhantomData<Request<Req, Ctx>>,
241 codec: PhantomData<Cdc>,
242}
243
244impl<Sv: Clone, A: Clone, Req, Ctx, Cdc> Clone for AckService<Sv, A, Req, Ctx, Cdc> {
245 fn clone(&self) -> Self {
246 Self {
247 ack: self.ack.clone(),
248 job_type: PhantomData,
249 service: self.service.clone(),
250 codec: PhantomData,
251 }
252 }
253}
254
255impl<SV, A, Req, Ctx, Cdc> Service<Request<Req, Ctx>> for AckService<SV, A, Req, Ctx, Cdc>
256where
257 SV: Service<Request<Req, Ctx>> + Send + 'static,
258 SV::Error: Into<BoxDynError> + Send + 'static,
259 SV::Future: Send + 'static,
260 A: Ack<Req, SV::Response, Cdc, Context = Ctx> + Send + 'static + Clone,
261 Req: 'static + Send,
262 SV::Response: std::marker::Send + Serialize,
263 <A as Ack<Req, SV::Response, Cdc>>::Context: Send + Clone,
264 <A as Ack<Req, SV::Response, Cdc>>::Context: 'static,
265 Ctx: Clone,
266{
267 type Response = SV::Response;
268 type Error = Error;
269 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
270
271 fn poll_ready(
272 &mut self,
273 cx: &mut std::task::Context<'_>,
274 ) -> std::task::Poll<Result<(), Self::Error>> {
275 self.service
276 .poll_ready(cx)
277 .map_err(|e| Error::Failed(Arc::new(e.into())))
278 }
279
280 fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
281 let mut ack = self.ack.clone();
282 let ctx = request.parts.context.clone();
283 let attempt = request.parts.attempt.clone();
284 let task_id = request.parts.task_id.clone();
285 let fut = self.service.call(request);
286 let fut_with_ack = async move {
287 let res = fut.await.map_err(|err| {
288 let e: BoxDynError = err.into();
289 if let Some(custom_error) = e.downcast_ref::<Error>() {
291 return custom_error.clone();
292 }
293 Error::Failed(Arc::new(e))
294 });
295 let response = Response {
296 attempt,
297 inner: res,
298 task_id,
299 _priv: (),
300 };
301 if let Err(_e) = ack.ack(&ctx, &response).await {
302 }
305 response.inner
306 };
307 fut_with_ack.boxed()
308 }
309}