1use std::{fmt, marker, rc::Rc};
2
3use ntex_dispatcher::Dispatcher as IoDispatcher;
4use ntex_io::{Filter, Io, IoBoxed};
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{IntoServiceFactory, Pipeline, Service, ServiceCtx, ServiceFactory};
7use ntex_util::time::{Millis, timeout_checked};
8
9use crate::codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, ProtocolIdError, protocol::ProtocolId};
10use crate::{AmqpServiceConfig, Connection, ControlFrame, State, default::DefaultControlService};
11use crate::{dispatcher::Dispatcher, types::Message};
12
13use super::handshake::{Handshake, HandshakeAck};
14use super::{Error, HandshakeError, ServerError};
15
16pub struct Server<St, H, Ctl, Pb> {
18 handshake: H,
19 inner: Rc<ServerInner<St, Ctl, Pb>>,
20}
21
/// Builder produced by `Server::build`.
///
/// Holds the handshake factory and the control factory until `finish`
/// supplies the publish factory and yields a [`Server`].
pub struct ServerBuilder<St, H, Ctl> {
    /// Factory for the handshake service.
    handshake: H,
    /// Factory for the control-frame service.
    control: Ctl,
    /// Ties the builder to the connection state type `St` without storing one.
    _t: std::marker::PhantomData<St>,
}
28
29pub(super) struct ServerInner<St, Ctl, Pb> {
30 control: Ctl,
31 publish: Pb,
32 _t: marker::PhantomData<St>,
33}
34
35impl<St> Server<St, (), (), ()>
36where
37 St: 'static,
38{
39 pub fn build<F, H>(handshake: F) -> ServerBuilder<St, H, DefaultControlService<St, H::Error>>
41 where
42 F: IntoServiceFactory<H, Handshake, SharedCfg>,
43 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>>,
44 {
45 ServerBuilder {
46 handshake: handshake.into_factory(),
47 control: DefaultControlService::default(),
48 _t: marker::PhantomData,
49 }
50 }
51}
52
53impl<St, H, Ctl> ServerBuilder<St, H, Ctl>
54where
55 St: 'static,
56 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
57 Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
58 Ctl::InitError: fmt::Debug,
59 Error: From<Ctl::Error>,
60{
61 pub fn control<F, S>(self, service: F) -> ServerBuilder<St, H, S>
63 where
64 F: IntoServiceFactory<S, ControlFrame, State<St>>,
65 S: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
66 S::InitError: fmt::Debug,
67 Error: From<S::Error>,
68 {
69 ServerBuilder {
70 handshake: self.handshake,
71 control: service.into_factory(),
72 _t: marker::PhantomData,
73 }
74 }
75
76 pub fn finish<S, Pb>(self, service: S) -> Server<St, H, Ctl, Pb>
78 where
79 S: IntoServiceFactory<Pb, Message, State<St>>,
80 Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
81 Pb::InitError: fmt::Debug,
82 Error: From<Pb::Error> + From<Ctl::Error>,
83 {
84 Server {
85 handshake: self.handshake,
86 inner: Rc::new(ServerInner {
87 publish: service.into_factory(),
88 control: self.control,
89 _t: marker::PhantomData,
90 }),
91 }
92 }
93}
94
95impl<F, St, H, Ctl, Pb> ServiceFactory<Io<F>, SharedCfg> for Server<St, H, Ctl, Pb>
96where
97 F: Filter,
98 St: 'static,
99 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
100 Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
101 Ctl::InitError: fmt::Debug,
102 Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
103 Pb::InitError: fmt::Debug,
104 Error: From<Pb::Error> + From<Ctl::Error>,
105{
106 type Response = ();
107 type Error = ServerError<H::Error>;
108 type Service = ServerHandler<St, H::Service, Ctl, Pb>;
109 type InitError = H::InitError;
110
111 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
112 self.handshake
113 .pipeline(cfg.clone())
114 .await
115 .map(move |handshake| ServerHandler {
116 handshake,
117 cfg: cfg.get(),
118 inner: self.inner.clone(),
119 })
120 }
121}
122
123impl<St, H, Ctl, Pb> ServiceFactory<IoBoxed, SharedCfg> for Server<St, H, Ctl, Pb>
124where
125 St: 'static,
126 H: ServiceFactory<Handshake, SharedCfg, Response = HandshakeAck<St>> + 'static,
127 Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
128 Ctl::InitError: fmt::Debug,
129 Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
130 Pb::InitError: fmt::Debug,
131 Error: From<Pb::Error> + From<Ctl::Error>,
132{
133 type Response = ();
134 type Error = ServerError<H::Error>;
135 type Service = ServerHandler<St, H::Service, Ctl, Pb>;
136 type InitError = H::InitError;
137
138 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
139 self.handshake
140 .pipeline(cfg.clone())
141 .await
142 .map(move |handshake| ServerHandler {
143 handshake,
144 cfg: cfg.get(),
145 inner: self.inner.clone(),
146 })
147 }
148}
149
150pub struct ServerHandler<St, H, Ctl, Pb> {
152 cfg: Cfg<AmqpServiceConfig>,
153 handshake: Pipeline<H>,
154 inner: Rc<ServerInner<St, Ctl, Pb>>,
155}
156
157impl<St, H, Ctl, Pb> ServerHandler<St, H, Ctl, Pb>
158where
159 St: 'static,
160 H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
161 Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
162 Ctl::InitError: fmt::Debug,
163 Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
164 Pb::InitError: fmt::Debug,
165 Error: From<Pb::Error> + From<Ctl::Error>,
166{
167 async fn create(&self, req: IoBoxed) -> Result<(), ServerError<H::Error>> {
168 let fut = handshake(req, &self.handshake, self.cfg.clone());
169 let inner = self.inner.clone();
170
171 let (state, codec, sink, st, idle_timeout) =
172 timeout_checked(self.cfg.handshake_timeout, fut)
173 .await
174 .map_err(|()| HandshakeError::Timeout)??;
175
176 let pb_srv = inner.publish.pipeline(st.clone()).await.map_err(|e| {
178 log::error!("Publish service init error: {e:?}");
179 ServerError::PublishServiceError
180 })?;
181
182 let ctl_srv = inner.control.pipeline(st.clone()).await.map_err(|e| {
184 log::error!("Control service init error: {e:?}");
185 ServerError::ControlServiceError
186 })?;
187
188 IoDispatcher::new(
189 state,
190 codec,
191 Dispatcher::new(sink, pb_srv, ctl_srv, idle_timeout),
192 )
193 .await
194 .map_err(ServerError::Dispatcher)
195 }
196}
197
198impl<F, St, H, Ctl, Pb> Service<Io<F>> for ServerHandler<St, H, Ctl, Pb>
199where
200 F: Filter,
201 St: 'static,
202 H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
203 Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
204 Ctl::InitError: fmt::Debug,
205 Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
206 Pb::InitError: fmt::Debug,
207 Error: From<Pb::Error> + From<Ctl::Error>,
208{
209 type Response = ();
210 type Error = ServerError<H::Error>;
211
212 #[inline]
213 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
214 self.handshake.ready().await.map_err(ServerError::Service)
215 }
216
217 #[inline]
218 async fn shutdown(&self) {
219 self.handshake.shutdown().await;
220 }
221
222 async fn call(
223 &self,
224 req: Io<F>,
225 _: ServiceCtx<'_, Self>,
226 ) -> Result<Self::Response, Self::Error> {
227 self.create(IoBoxed::from(req)).await
228 }
229}
230
231impl<St, H, Ctl, Pb> Service<IoBoxed> for ServerHandler<St, H, Ctl, Pb>
232where
233 St: 'static,
234 H: Service<Handshake, Response = HandshakeAck<St>> + 'static,
235 Ctl: ServiceFactory<ControlFrame, State<St>, Response = ()> + 'static,
236 Ctl::InitError: fmt::Debug,
237 Pb: ServiceFactory<Message, State<St>, Response = ()> + 'static,
238 Pb::InitError: fmt::Debug,
239 Error: From<Pb::Error> + From<Ctl::Error>,
240{
241 type Response = ();
242 type Error = ServerError<H::Error>;
243
244 #[inline]
245 async fn ready(&self, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
246 self.handshake.ready().await.map_err(ServerError::Service)
247 }
248
249 #[inline]
250 async fn shutdown(&self) {
251 self.handshake.shutdown().await;
252 }
253
254 #[inline]
255 async fn call(
256 &self,
257 req: IoBoxed,
258 _: ServiceCtx<'_, Self>,
259 ) -> Result<Self::Response, Self::Error> {
260 self.create(req).await
261 }
262}
263
264async fn handshake<St, H>(
265 io: IoBoxed,
266 handshake: &Pipeline<H>,
267 cfg: Cfg<AmqpServiceConfig>,
268) -> Result<(IoBoxed, AmqpCodec<AmqpFrame>, Connection, State<St>, Millis), ServerError<H::Error>>
269where
270 St: 'static,
271 H: Service<Handshake, Response = HandshakeAck<St>>,
272{
273 let protocol = io
274 .recv(&ProtocolIdCodec)
275 .await
276 .map_err(HandshakeError::from)?
277 .ok_or_else(|| {
278 log::trace!("{}: Server amqp is disconnected during handshake", io.tag());
279 HandshakeError::Disconnected(None)
280 })?;
281
282 match protocol {
283 ProtocolId::Amqp | ProtocolId::AmqpSasl => {
285 io.send(protocol, &ProtocolIdCodec)
287 .await
288 .map_err(HandshakeError::from)?;
289
290 let ack = handshake
292 .call(if protocol == ProtocolId::Amqp {
293 Handshake::new_plain(io, cfg.clone())
294 } else {
295 Handshake::new_sasl(io, cfg.clone())
296 })
297 .await
298 .map_err(ServerError::Service)?;
299
300 let (st, sink, idle_timeout, io) = ack.into_inner();
301
302 let codec = AmqpCodec::new().max_size(cfg.max_size);
303
304 let local = cfg.to_open();
306 io.send(AmqpFrame::new(0, local.into()), &codec)
307 .await
308 .map_err(HandshakeError::from)?;
309
310 Ok((io, codec, sink, State::new(st), Millis::from(idle_timeout)))
311 }
312 ProtocolId::AmqpTls => Err(ServerError::Handshake(HandshakeError::from(
313 ProtocolIdError::Unexpected {
314 exp: ProtocolId::Amqp,
315 got: ProtocolId::AmqpTls,
316 },
317 ))),
318 }
319}