Skip to main content

ntex_amqp/server/
service.rs

1use std::{fmt, marker, rc::Rc};
2
3use ntex_dispatcher::Dispatcher as IoDispatcher;
4use ntex_io::{Filter, Io, IoBoxed};
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{IntoServiceFactory, Pipeline, Service, ServiceCtx, ServiceFactory};
7use ntex_util::time::{Millis, timeout_checked};
8
9use crate::codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, ProtocolIdError, protocol::ProtocolId};
10use crate::{AmqpServiceConfig, Connection, ControlFrame, State, default::DefaultControlService};
11use crate::{dispatcher::Dispatcher, types::Message};
12
13use super::handshake::{Handshake, HandshakeAck};
14use super::{Error, HandshakeError, ServerError};
15
16/// Amqp server factory
17pub struct Server<St, H, Ctl, Pb> {
18    handshake: H,
19    inner: Rc<ServerInner<St, Ctl, Pb>>,
20}
21
22/// Amqp server builder
23pub struct ServerBuilder<St, H, Ctl> {
24    handshake: H,
25    control: Ctl,
26    _t: marker::PhantomData<St>,
27}
28
29pub(super) struct ServerInner<St, Ctl, Pb> {
30    control: Ctl,
31    publish: Pb,
32    _t: marker::PhantomData<St>,
33}
34
35impl<St> Server<St, (), (), ()>
36where
37    St: 'static,
38{
39    /// Start server building process with provided handshake service
40    pub fn build<F, H>(handshake: F) -> ServerBuilder<St, H, DefaultControlService<St, H::Error>>
41    where
42        F: IntoServiceFactory<H, Handshake, SharedCfg>,
43        H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>>,
44    {
45        ServerBuilder {
46            handshake: handshake.into_factory(),
47            control: DefaultControlService::default(),
48            _t: marker::PhantomData,
49        }
50    }
51}
52
53impl<St, H, Ctl> ServerBuilder<St, H, Ctl>
54where
55    St: 'static,
56    H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
57    Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
58    Ctl::InitError: fmt::Debug,
59    Error: From<Ctl::Error>,
60{
61    /// Service to call with control frames
62    pub fn control<F, S>(self, service: F) -> ServerBuilder<St, H, S>
63    where
64        F: IntoServiceFactory<S, ControlFrame, State<St>>,
65        S: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
66        S::InitError: fmt::Debug,
67        Error: From<S::Error>,
68    {
69        ServerBuilder {
70            handshake: self.handshake,
71            control: service.into_factory(),
72            _t: marker::PhantomData,
73        }
74    }
75
76    /// Set service to execute for incoming links and create service factory
77    pub fn finish<S, Pb>(self, service: S) -> Server<St, H, Ctl, Pb>
78    where
79        S: IntoServiceFactory<Pb, Message, State<St>>,
80        Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
81        Pb::InitError: fmt::Debug,
82        Error: From<Pb::Error> + From<Ctl::Error>,
83    {
84        Server {
85            handshake: self.handshake,
86            inner: Rc::new(ServerInner {
87                publish: service.into_factory(),
88                control: self.control,
89                _t: marker::PhantomData,
90            }),
91        }
92    }
93}
94
95impl<F, St, H, Ctl, Pb> ServiceFactory<Io<F>, SharedCfg> for Server<St, H, Ctl, Pb>
96where
97    F: Filter,
98    St: 'static,
99    H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
100    Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
101    Ctl::InitError: fmt::Debug,
102    Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
103    Pb::InitError: fmt::Debug,
104    Error: From<Pb::Error> + From<Ctl::Error>,
105{
106    type Response = ();
107    type Error = ServerError<H::Error>;
108    type Service = ServerHandler<St, H::Service, Ctl, Pb>;
109    type InitError = H::InitError;
110
111    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
112        self.handshake
113            .pipeline(cfg.clone())
114            .await
115            .map(move |handshake| ServerHandler {
116                handshake,
117                cfg: cfg.get(),
118                inner: self.inner.clone(),
119            })
120    }
121}
122
123impl<St, H, Ctl, Pb> ServiceFactory<IoBoxed, SharedCfg> for Server<St, H, Ctl, Pb>
124where
125    St: 'static,
126    H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
127    Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
128    Ctl::InitError: fmt::Debug,
129    Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
130    Pb::InitError: fmt::Debug,
131    Error: From<Pb::Error> + From<Ctl::Error>,
132{
133    type Response = ();
134    type Error = ServerError<H::Error>;
135    type Service = ServerHandler<St, H::Service, Ctl, Pb>;
136    type InitError = H::InitError;
137
138    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
139        self.handshake
140            .pipeline(cfg.clone())
141            .await
142            .map(move |handshake| ServerHandler {
143                handshake,
144                cfg: cfg.get(),
145                inner: self.inner.clone(),
146            })
147    }
148}
149
150/// Amqp connections handler
151pub struct ServerHandler<St, H, Ctl, Pb> {
152    cfg: Cfg<AmqpServiceConfig>,
153    handshake: Pipeline<H>,
154    inner: Rc<ServerInner<St, Ctl, Pb>>,
155}
156
157impl<St, H, Ctl, Pb> ServerHandler<St, H, Ctl, Pb>
158where
159    St: 'static,
160    H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
161    Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
162    Ctl::InitError: fmt::Debug,
163    Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
164    Pb::InitError: fmt::Debug,
165    Error: From<Pb::Error> + From<Ctl::Error>,
166{
167    async fn create(&self, req: IoBoxed) -> Result<(), ServerError<H::Error>> {
168        let fut = handshake(req, &self.handshake, self.cfg.clone());
169        let inner = self.inner.clone();
170
171        let (state, codec, sink, st, idle_timeout) =
172            timeout_checked(self.cfg.handshake_timeout, fut)
173                .await
174                .map_err(|()| HandshakeError::Timeout)??;
175
176        // create publish service
177        let pb_srv = inner.publish.pipeline(st.clone()).await.map_err(|e| {
178            log::error!("Publish service init error: {e:?}");
179            ServerError::PublishServiceError
180        })?;
181
182        // create control service
183        let ctl_srv = inner.control.pipeline(st.clone()).await.map_err(|e| {
184            log::error!("Control service init error: {e:?}");
185            ServerError::ControlServiceError
186        })?;
187
188        IoDispatcher::new(
189            state,
190            codec,
191            Dispatcher::new(sink, pb_srv, ctl_srv, idle_timeout),
192        )
193        .await
194        .map_err(ServerError::Dispatcher)
195    }
196}
197
198impl<F, St, H, Ctl, Pb> Service<Io<F>> for ServerHandler<St, H, Ctl, Pb>
199where
200    F: Filter,
201    St: 'static,
202    H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
203    Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
204    Ctl::InitError: fmt::Debug,
205    Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
206    Pb::InitError: fmt::Debug,
207    Error: From<Pb::Error> + From<Ctl::Error>,
208{
209    type Response = ();
210    type Error = ServerError<H::Error>;
211
212    #[inline]
213    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
214        self.handshake.ready().await.map_err(ServerError::Service)
215    }
216
217    #[inline]
218    async fn shutdown(&self) {
219        self.handshake.shutdown().await;
220    }
221
222    async fn call(
223        &self,
224        req: Io<F>,
225        _: ServiceCtx<'_, Self>,
226    ) -> Result<Self::Response, Self::Error> {
227        self.create(IoBoxed::from(req)).await
228    }
229}
230
231impl<St, H, Ctl, Pb> Service<IoBoxed> for ServerHandler<St, H, Ctl, Pb>
232where
233    St: 'static,
234    H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
235    Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
236    Ctl::InitError: fmt::Debug,
237    Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
238    Pb::InitError: fmt::Debug,
239    Error: From<Pb::Error> + From<Ctl::Error>,
240{
241    type Response = ();
242    type Error = ServerError<H::Error>;
243
244    #[inline]
245    async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
246        self.handshake.ready().await.map_err(ServerError::Service)
247    }
248
249    #[inline]
250    async fn shutdown(&self) {
251        self.handshake.shutdown().await;
252    }
253
254    #[inline]
255    async fn call(
256        &self,
257        req: IoBoxed,
258        _: ServiceCtx<'_, Self>,
259    ) -> Result<Self::Response, Self::Error> {
260        self.create(req).await
261    }
262}
263
264async fn handshake<St, H>(
265    io: IoBoxed,
266    handshake: &Pipeline<H>,
267    cfg: Cfg<AmqpServiceConfig>,
268) -> Result<(IoBoxed, AmqpCodec<AmqpFrame>, Connection, State<St>, Millis), ServerError<H::Error>>
269where
270    St: 'static,
271    H: Service<Handshake, Response = HandshakeAck<St>>,
272{
273    let protocol = io
274        .recv(&ProtocolIdCodec)
275        .await
276        .map_err(HandshakeError::from)?
277        .ok_or_else(|| {
278            log::trace!("{}: Server amqp is disconnected during handshake", io.tag());
279            HandshakeError::Disconnected(None)
280        })?;
281
282    match protocol {
283        // start amqp processing
284        ProtocolId::Amqp | ProtocolId::AmqpSasl => {
285            // confirm protocol
286            io.send(protocol, &ProtocolIdCodec)
287                .await
288                .map_err(HandshakeError::from)?;
289
290            // handshake protocol
291            let ack = handshake
292                .call(if protocol == ProtocolId::Amqp {
293                    Handshake::new_plain(io, cfg.clone())
294                } else {
295                    Handshake::new_sasl(io, cfg.clone())
296                })
297                .await
298                .map_err(ServerError::Service)?;
299
300            let (st, sink, idle_timeout, io) = ack.into_inner();
301
302            let codec = AmqpCodec::new().max_size(cfg.max_size);
303
304            // confirm Open
305            let local = cfg.to_open();
306            io.send(AmqpFrame::new(0, local.into()), &codec)
307                .await
308                .map_err(HandshakeError::from)?;
309
310            Ok((io, codec, sink, State::new(st), Millis::from(idle_timeout)))
311        }
312        ProtocolId::AmqpTls => Err(ServerError::Handshake(HandshakeError::from(
313            ProtocolIdError::Unexpected {
314                exp: ProtocolId::Amqp,
315                got: ProtocolId::AmqpTls,
316            },
317        ))),
318    }
319}