1use std::{fmt, future::Future, future::poll_fn, pin::Pin, rc::Rc};
2
3use ntex_dispatcher::Dispatcher as IoDispatcher;
4use ntex_io::{Filter, Io, IoBoxed};
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory};
7use ntex_util::{channel::pool, time::timeout_checked};
8
9use crate::control::{Control, ControlAck};
10use crate::{codec::Codec, connection::Connection, default::DefaultControlService};
11use crate::{config::ServiceConfig, consts, dispatcher::Dispatcher, frame, message::Message};
12
13use super::ServerError;
14
15#[derive(Debug)]
16pub struct Server<Pub, Ctl>(ServerInner<Pub, Ctl>);
18
19#[derive(Debug)]
20struct ServerInner<Pub, Ctl> {
21 control: Rc<Ctl>,
22 publish: Rc<Pub>,
23 pool: pool::Pool<()>,
24}
25
26impl<Pub, Ctl> Clone for ServerInner<Pub, Ctl> {
27 fn clone(&self) -> Self {
28 Self {
29 control: self.control.clone(),
30 publish: self.publish.clone(),
31 pool: self.pool.clone(),
32 }
33 }
34}
35
36impl<Pub> Server<Pub, DefaultControlService>
37where
38 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
39 Pub::Error: fmt::Debug,
40 Pub::InitError: fmt::Debug,
41{
42 pub fn new(publish: Pub) -> Self {
44 Self(ServerInner {
45 publish: Rc::new(publish),
46 control: Rc::new(DefaultControlService),
47 pool: pool::new(),
48 })
49 }
50}
51
52impl<Pub, Ctl> Server<Pub, Ctl>
53where
54 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
55 Ctl::Error: fmt::Debug,
56 Ctl::InitError: fmt::Debug,
57 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
58 Pub::Error: fmt::Debug,
59 Pub::InitError: fmt::Debug,
60{
61 pub fn control<S, F>(&self, service: F) -> Server<Pub, S>
63 where
64 F: IntoServiceFactory<S, Control<Pub::Error>, SharedCfg>,
65 S: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
66 S::Error: fmt::Debug,
67 S::InitError: fmt::Debug,
68 {
69 Server(ServerInner {
70 control: Rc::new(service.into_factory()),
71 publish: self.0.publish.clone(),
72 pool: self.0.pool.clone(),
73 })
74 }
75
76 pub fn handler(&self, cfg: SharedCfg) -> ServerHandler<Pub, Ctl> {
78 ServerHandler::new(cfg, self.0.clone())
79 }
80}
81
82impl<Pub, Ctl> ServiceFactory<IoBoxed, SharedCfg> for Server<Pub, Ctl>
83where
84 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
85 Ctl::Error: fmt::Debug,
86 Ctl::InitError: fmt::Debug,
87 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
88 Pub::Error: fmt::Debug,
89 Pub::InitError: fmt::Debug,
90{
91 type Response = ();
92 type Error = ServerError<()>;
93 type Service = ServerHandler<Pub, Ctl>;
94 type InitError = ();
95
96 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
97 Ok(ServerHandler::new(cfg, self.0.clone()))
98 }
99}
100
101impl<F, Pub, Ctl> ServiceFactory<Io<F>, SharedCfg> for Server<Pub, Ctl>
102where
103 F: Filter,
104 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
105 Ctl::Error: fmt::Debug,
106 Ctl::InitError: fmt::Debug,
107 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
108 Pub::Error: fmt::Debug,
109 Pub::InitError: fmt::Debug,
110{
111 type Response = ();
112 type Error = ServerError<()>;
113 type Service = ServerHandler<Pub, Ctl>;
114 type InitError = ();
115
116 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
117 Ok(ServerHandler::new(cfg, self.0.clone()))
118 }
119}
120
121#[derive(Debug)]
122pub struct ServerHandler<Pub, Ctl> {
124 cfg: Cfg<ServiceConfig>,
125 inner: ServerInner<Pub, Ctl>,
126 shared: SharedCfg,
127}
128
129impl<Pub, Ctl> ServerHandler<Pub, Ctl> {
130 fn new(shared: SharedCfg, inner: ServerInner<Pub, Ctl>) -> Self {
131 let cfg = shared.get();
132 Self { cfg, inner, shared }
133 }
134}
135
136impl<Pub, Ctl> Clone for ServerHandler<Pub, Ctl> {
137 fn clone(&self) -> Self {
138 Self {
139 inner: self.inner.clone(),
140 cfg: self.cfg.clone(),
141 shared: self.shared.clone(),
142 }
143 }
144}
145
146impl<Pub, Ctl> ServerHandler<Pub, Ctl>
147where
148 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
149 Ctl::Error: fmt::Debug,
150 Ctl::InitError: fmt::Debug,
151 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
152 Pub::Error: fmt::Debug,
153 Pub::InitError: fmt::Debug,
154{
155 pub async fn run(&self, io: IoBoxed) -> Result<(), ServerError<()>> {
156 let inner = &self.inner;
157
158 let (ctl_srv, pub_srv) = timeout_checked(self.cfg.handshake_timeout, async {
159 read_preface(&io).await?;
160
161 let pub_srv = inner
163 .publish
164 .create(self.shared.clone())
165 .await
166 .map_err(|e| {
167 log::error!("Publish service init error: {e:?}");
168 ServerError::PublishServiceError
169 })?;
170
171 let ctl_srv = inner
173 .control
174 .create(self.shared.clone())
175 .await
176 .map_err(|e| {
177 log::error!("Control service init error: {e:?}");
178 ServerError::ControlServiceError
179 })?;
180
181 Ok::<_, ServerError<()>>((ctl_srv, pub_srv))
182 })
183 .await
184 .map_err(|()| ServerError::HandshakeTimeout)??;
185
186 let codec = Codec::default();
188 let con = Connection::new(
189 true,
190 io.get_ref(),
191 codec.clone(),
192 self.cfg.clone(),
193 true,
194 false,
195 self.inner.pool.clone(),
196 );
197 let con2 = con.clone();
198
199 let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_srv, pub_srv));
201 poll_fn(|cx| {
202 if con2.config().is_shutdown() {
203 con2.disconnect_when_ready();
204 }
205 Pin::new(&mut fut).poll(cx)
206 })
207 .await
208 .map_err(|()| ServerError::Dispatcher)
209 }
210}
211
212impl<Pub, Ctl> Service<IoBoxed> for ServerHandler<Pub, Ctl>
213where
214 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
215 Ctl::Error: fmt::Debug,
216 Ctl::InitError: fmt::Debug,
217 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
218 Pub::Error: fmt::Debug,
219 Pub::InitError: fmt::Debug,
220{
221 type Response = ();
222 type Error = ServerError<()>;
223
224 async fn call(
225 &self,
226 io: IoBoxed,
227 _: ServiceCtx<'_, Self>,
228 ) -> Result<Self::Response, Self::Error> {
229 self.run(io).await
230 }
231}
232
233impl<F, Pub, Ctl> Service<Io<F>> for ServerHandler<Pub, Ctl>
234where
235 F: Filter,
236 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
237 Ctl::Error: fmt::Debug,
238 Ctl::InitError: fmt::Debug,
239 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
240 Pub::Error: fmt::Debug,
241 Pub::InitError: fmt::Debug,
242{
243 type Response = ();
244 type Error = ServerError<()>;
245
246 async fn call(
247 &self,
248 req: Io<F>,
249 _: ServiceCtx<'_, Self>,
250 ) -> Result<Self::Response, Self::Error> {
251 self.run(req.into()).await
252 }
253}
254
255async fn read_preface(io: &IoBoxed) -> Result<(), ServerError<()>> {
256 loop {
257 let ready = io.with_read_buf(|buf| {
258 if buf.len() >= consts::PREFACE.len() {
259 if buf[..consts::PREFACE.len()] == consts::PREFACE {
260 buf.advance_to(consts::PREFACE.len());
261 Ok(true)
262 } else {
263 log::trace!("read_preface: invalid preface {buf:?}");
264 Err(ServerError::<()>::Frame(frame::FrameError::InvalidPreface))
265 }
266 } else {
267 Ok(false)
268 }
269 })?;
270
271 if ready {
272 log::debug!("Preface has been received");
273 return Ok::<_, ServerError<_>>(());
274 }
275 io.read_ready()
276 .await?
277 .ok_or(ServerError::Disconnected(None))?;
278 }
279}
280
281pub async fn handle_one<Pub, Ctl>(
283 io: IoBoxed,
284 pub_svc: Pub,
285 ctl_svc: Ctl,
286) -> Result<(), ServerError<()>>
287where
288 Ctl: Service<Control<Pub::Error>, Response = ControlAck> + 'static,
289 Ctl::Error: fmt::Debug,
290 Pub: Service<Message, Response = ()> + 'static,
291 Pub::Error: fmt::Debug,
292{
293 let config: Cfg<ServiceConfig> = io.shared().get();
294
295 timeout_checked(config.handshake_timeout, async { read_preface(&io).await })
297 .await
298 .map_err(|()| ServerError::HandshakeTimeout)??;
299
300 let codec = Codec::default();
302 let con = Connection::new(
303 true,
304 io.get_ref(),
305 codec.clone(),
306 config,
307 true,
308 false,
309 pool::new(),
310 );
311 let con2 = con.clone();
312
313 let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_svc, pub_svc));
315
316 poll_fn(|cx| {
317 if con2.config().is_shutdown() {
318 con2.disconnect_when_ready();
319 }
320 Pin::new(&mut fut).poll(cx)
321 })
322 .await
323 .map_err(|()| ServerError::Dispatcher)
324}