1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::rc::Rc;
5use std::task::{Context, Poll};
6
7use actix_codec::{AsyncRead, AsyncWrite, Decoder, Encoder};
8use actix_service::{IntoService, IntoServiceFactory, Service, ServiceFactory};
9use actix_utils::mpsc;
10use either::Either;
11use futures::future::{FutureExt, LocalBoxFuture};
12use pin_project::project;
13
14use crate::connect::{Connect, ConnectResult};
15use crate::dispatcher::{Dispatcher, Message};
16use crate::error::ServiceError;
17use crate::item::Item;
18use crate::sink::Sink;
19
20type RequestItem<S, U> = Item<S, U>;
21type ResponseItem<U> = Option<<U as Encoder>::Item>;
22type ServiceResult<U, E> = Result<Message<<U as Encoder>::Item>, E>;
23
24pub struct Builder<St, Codec>(PhantomData<(St, Codec)>);
27
28impl<St: Clone, Codec> Default for Builder<St, Codec> {
29 fn default() -> Builder<St, Codec> {
30 Builder::new()
31 }
32}
33
34impl<St: Clone, Codec> Builder<St, Codec> {
35 pub fn new() -> Builder<St, Codec> {
36 Builder(PhantomData)
37 }
38
39 pub fn service<Io, C, F>(self, connect: F) -> ServiceBuilder<St, C, Io, Codec>
41 where
42 F: IntoService<C>,
43 Io: AsyncRead + AsyncWrite,
44 C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
45 Codec: Decoder + Encoder,
46 {
47 ServiceBuilder {
48 connect: connect.into_service(),
49 disconnect: None,
50 _t: PhantomData,
51 }
52 }
53
54 pub fn factory<Io, C, F>(self, connect: F) -> NewServiceBuilder<St, C, Io, Codec>
56 where
57 F: IntoServiceFactory<C>,
58 Io: AsyncRead + AsyncWrite,
59 C: ServiceFactory<
60 Config = (),
61 Request = Connect<Io, Codec>,
62 Response = ConnectResult<Io, St, Codec>,
63 >,
64 C::Error: 'static,
65 C::Future: 'static,
66 Codec: Decoder + Encoder,
67 {
68 NewServiceBuilder {
69 connect: connect.into_factory(),
70 disconnect: None,
71 _t: PhantomData,
72 }
73 }
74}
75
76pub struct ServiceBuilder<St, C, Io, Codec> {
77 connect: C,
78 disconnect: Option<Rc<dyn Fn(St, bool)>>,
79 _t: PhantomData<(St, Io, Codec)>,
80}
81
82impl<St, C, Io, Codec> ServiceBuilder<St, C, Io, Codec>
83where
84 St: Clone,
85 C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
86 Io: AsyncRead + AsyncWrite,
87 Codec: Decoder + Encoder,
88 <Codec as Encoder>::Item: 'static,
89 <Codec as Encoder>::Error: std::fmt::Debug,
90{
91 pub fn disconnect<F, Out>(mut self, disconnect: F) -> Self
95 where
96 F: Fn(St, bool) + 'static,
97 {
98 self.disconnect = Some(Rc::new(disconnect));
99 self
100 }
101
102 pub fn finish<F, T>(self, service: F) -> FramedServiceImpl<St, C, T, Io, Codec>
104 where
105 F: IntoServiceFactory<T>,
106 T: ServiceFactory<
107 Config = St,
108 Request = RequestItem<St, Codec>,
109 Response = ResponseItem<Codec>,
110 Error = C::Error,
111 InitError = C::Error,
112 >,
113 {
114 FramedServiceImpl {
115 connect: self.connect,
116 handler: Rc::new(service.into_factory()),
117 disconnect: self.disconnect.clone(),
118 _t: PhantomData,
119 }
120 }
121}
122
123pub struct NewServiceBuilder<St, C, Io, Codec> {
124 connect: C,
125 disconnect: Option<Rc<dyn Fn(St, bool)>>,
126 _t: PhantomData<(St, Io, Codec)>,
127}
128
129impl<St, C, Io, Codec> NewServiceBuilder<St, C, Io, Codec>
130where
131 St: Clone,
132 Io: AsyncRead + AsyncWrite,
133 C: ServiceFactory<
134 Config = (),
135 Request = Connect<Io, Codec>,
136 Response = ConnectResult<Io, St, Codec>,
137 >,
138 C::Error: 'static,
139 C::Future: 'static,
140 Codec: Decoder + Encoder,
141 <Codec as Encoder>::Item: 'static,
142 <Codec as Encoder>::Error: std::fmt::Debug,
143{
144 pub fn disconnect<F>(mut self, disconnect: F) -> Self
148 where
149 F: Fn(St, bool) + 'static,
150 {
151 self.disconnect = Some(Rc::new(disconnect));
152 self
153 }
154
155 pub fn finish<F, T, Cfg>(self, service: F) -> FramedService<St, C, T, Io, Codec, Cfg>
156 where
157 F: IntoServiceFactory<T>,
158 T: ServiceFactory<
159 Config = St,
160 Request = RequestItem<St, Codec>,
161 Response = ResponseItem<Codec>,
162 Error = C::Error,
163 InitError = C::Error,
164 > + 'static,
165 {
166 FramedService {
167 connect: self.connect,
168 handler: Rc::new(service.into_factory()),
169 disconnect: self.disconnect,
170 _t: PhantomData,
171 }
172 }
173}
174
175pub struct FramedService<St, C, T, Io, Codec, Cfg> {
176 connect: C,
177 handler: Rc<T>,
178 disconnect: Option<Rc<dyn Fn(St, bool)>>,
179 _t: PhantomData<(St, Io, Codec, Cfg)>,
180}
181
182impl<St, C, T, Io, Codec, Cfg> ServiceFactory for FramedService<St, C, T, Io, Codec, Cfg>
183where
184 St: Clone + 'static,
185 Io: AsyncRead + AsyncWrite,
186 C: ServiceFactory<
187 Config = (),
188 Request = Connect<Io, Codec>,
189 Response = ConnectResult<Io, St, Codec>,
190 >,
191 C::Error: 'static,
192 C::Future: 'static,
193 T: ServiceFactory<
194 Config = St,
195 Request = RequestItem<St, Codec>,
196 Response = ResponseItem<Codec>,
197 Error = C::Error,
198 InitError = C::Error,
199 > + 'static,
200 <T::Service as Service>::Future: 'static,
201 Codec: Decoder + Encoder,
202 <Codec as Encoder>::Item: 'static,
203 <Codec as Encoder>::Error: std::fmt::Debug,
204{
205 type Config = Cfg;
206 type Request = Io;
207 type Response = ();
208 type Error = ServiceError<C::Error, Codec>;
209 type InitError = C::InitError;
210 type Service = FramedServiceImpl<St, C::Service, T, Io, Codec>;
211 type Future = LocalBoxFuture<'static, Result<Self::Service, Self::InitError>>;
212
213 fn new_service(&self, _: Cfg) -> Self::Future {
214 let handler = self.handler.clone();
215 let disconnect = self.disconnect.clone();
216
217 self.connect
219 .new_service(())
220 .map(move |result| {
221 result.map(move |connect| FramedServiceImpl {
222 connect,
223 handler,
224 disconnect,
225 _t: PhantomData,
226 })
227 })
228 .boxed_local()
229 }
230}
231
232pub struct FramedServiceImpl<St, C, T, Io, Codec> {
233 connect: C,
234 handler: Rc<T>,
235 disconnect: Option<Rc<dyn Fn(St, bool)>>,
236 _t: PhantomData<(St, Io, Codec)>,
237}
238
239impl<St, C, T, Io, Codec> Service for FramedServiceImpl<St, C, T, Io, Codec>
240where
241 St: Clone,
242 Io: AsyncRead + AsyncWrite,
243 C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
244 C::Error: 'static,
245 T: ServiceFactory<
246 Config = St,
247 Request = RequestItem<St, Codec>,
248 Response = ResponseItem<Codec>,
249 Error = C::Error,
250 InitError = C::Error,
251 >,
252 <T::Service as Service>::Future: 'static,
253 Codec: Decoder + Encoder,
254 <Codec as Encoder>::Item: 'static,
255 <Codec as Encoder>::Error: std::fmt::Debug,
256{
257 type Request = Io;
258 type Response = ();
259 type Error = ServiceError<C::Error, Codec>;
260 type Future = FramedServiceImplResponse<St, Io, Codec, C, T>;
261
262 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
263 self.connect.poll_ready(cx).map_err(|e| e.into())
264 }
265
266 fn call(&mut self, req: Io) -> Self::Future {
267 let (tx, rx) = mpsc::channel();
268 let sink = Sink::new(Rc::new(move |msg| {
269 let _ = tx.send(Ok(msg));
270 }));
271 FramedServiceImplResponse {
272 inner: FramedServiceImplResponseInner::Connect(
273 self.connect.call(Connect::new(req, sink.clone())),
274 self.handler.clone(),
275 self.disconnect.clone(),
276 Some(rx),
277 ),
278 }
279 }
280}
281
282#[pin_project::pin_project]
283pub struct FramedServiceImplResponse<St, Io, Codec, C, T>
284where
285 St: Clone,
286 C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
287 C::Error: 'static,
288 T: ServiceFactory<
289 Config = St,
290 Request = RequestItem<St, Codec>,
291 Response = ResponseItem<Codec>,
292 Error = C::Error,
293 InitError = C::Error,
294 >,
295 <T::Service as Service>::Future: 'static,
296 Io: AsyncRead + AsyncWrite,
297 Codec: Encoder + Decoder,
298 <Codec as Encoder>::Item: 'static,
299 <Codec as Encoder>::Error: std::fmt::Debug,
300{
301 #[pin]
302 inner: FramedServiceImplResponseInner<St, Io, Codec, C, T>,
303}
304
305impl<St, Io, Codec, C, T> Future for FramedServiceImplResponse<St, Io, Codec, C, T>
306where
307 St: Clone,
308 C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
309 C::Error: 'static,
310 T: ServiceFactory<
311 Config = St,
312 Request = RequestItem<St, Codec>,
313 Response = ResponseItem<Codec>,
314 Error = C::Error,
315 InitError = C::Error,
316 >,
317 <T::Service as Service>::Future: 'static,
318 Io: AsyncRead + AsyncWrite,
319 Codec: Encoder + Decoder,
320 <Codec as Encoder>::Item: 'static,
321 <Codec as Encoder>::Error: std::fmt::Debug,
322{
323 type Output = Result<(), ServiceError<C::Error, Codec>>;
324
325 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
326 let mut this = self.as_mut().project();
327
328 loop {
329 match this.inner.poll(cx) {
330 Either::Left(new) => {
331 this = self.as_mut().project();
332 this.inner.set(new)
333 }
334 Either::Right(poll) => return poll,
335 };
336 }
337 }
338}
339
340#[pin_project::pin_project]
341enum FramedServiceImplResponseInner<St, Io, Codec, C, T>
342where
343 St: Clone,
344 C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
345 C::Error: 'static,
346 T: ServiceFactory<
347 Config = St,
348 Request = RequestItem<St, Codec>,
349 Response = ResponseItem<Codec>,
350 Error = C::Error,
351 InitError = C::Error,
352 >,
353 <T::Service as Service>::Future: 'static,
354 Io: AsyncRead + AsyncWrite,
355 Codec: Encoder + Decoder,
356 <Codec as Encoder>::Item: 'static,
357 <Codec as Encoder>::Error: std::fmt::Debug,
358{
359 Connect(
360 #[pin] C::Future,
361 Rc<T>,
362 Option<Rc<dyn Fn(St, bool)>>,
363 Option<mpsc::Receiver<ServiceResult<Codec, C::Error>>>,
364 ),
365 Handler(
366 #[pin] T::Future,
367 Option<ConnectResult<Io, St, Codec>>,
368 Option<Rc<dyn Fn(St, bool)>>,
369 Option<mpsc::Receiver<ServiceResult<Codec, C::Error>>>,
370 ),
371 Dispatcher(Dispatcher<St, T::Service, Io, Codec>),
372}
373
374impl<St, Io, Codec, C, T> FramedServiceImplResponseInner<St, Io, Codec, C, T>
375where
376 St: Clone,
377 C: Service<Request = Connect<Io, Codec>, Response = ConnectResult<Io, St, Codec>>,
378 C::Error: 'static,
379 T: ServiceFactory<
380 Config = St,
381 Request = RequestItem<St, Codec>,
382 Response = ResponseItem<Codec>,
383 Error = C::Error,
384 InitError = C::Error,
385 >,
386 <T::Service as Service>::Future: 'static,
387 Io: AsyncRead + AsyncWrite,
388 Codec: Encoder + Decoder,
389 <Codec as Encoder>::Item: 'static,
390 <Codec as Encoder>::Error: std::fmt::Debug,
391{
392 #[project]
393 fn poll(
394 self: Pin<&mut Self>,
395 cx: &mut Context<'_>,
396 ) -> Either<
397 FramedServiceImplResponseInner<St, Io, Codec, C, T>,
398 Poll<Result<(), ServiceError<C::Error, Codec>>>,
399 > {
400 #[project]
401 match self.project() {
402 FramedServiceImplResponseInner::Connect(fut, handler, disconnect, rx) => {
403 match fut.poll(cx) {
404 Poll::Ready(Ok(res)) => {
405 Either::Left(FramedServiceImplResponseInner::Handler(
406 handler.new_service(res.state.clone()),
407 Some(res),
408 disconnect.take(),
409 rx.take(),
410 ))
411 }
412 Poll::Pending => Either::Right(Poll::Pending),
413 Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))),
414 }
415 }
416 FramedServiceImplResponseInner::Handler(fut, res, disconnect, rx) => {
417 match fut.poll(cx) {
418 Poll::Ready(Ok(handler)) => {
419 let res = res.take().unwrap();
420 Either::Left(FramedServiceImplResponseInner::Dispatcher(
421 Dispatcher::new(
422 res.framed,
423 res.state,
424 handler,
425 res.sink,
426 rx.take().unwrap(),
427 disconnect.take(),
428 ),
429 ))
430 }
431 Poll::Pending => Either::Right(Poll::Pending),
432 Poll::Ready(Err(e)) => Either::Right(Poll::Ready(Err(e.into()))),
433 }
434 }
435 FramedServiceImplResponseInner::Dispatcher(ref mut fut) => {
436 Either::Right(fut.poll(cx))
437 }
438 }
439 }
440}