1use std::{fmt, future::Future, future::poll_fn, pin::Pin, rc::Rc};
2
3use ntex_io::{Dispatcher as IoDispatcher, Filter, Io, IoBoxed};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory};
6use ntex_util::{channel::pool, time::timeout_checked};
7
8use crate::control::{Control, ControlAck};
9use crate::{codec::Codec, connection::Connection, default::DefaultControlService};
10use crate::{config::ServiceConfig, consts, dispatcher::Dispatcher, frame, message::Message};
11
12use super::ServerError;
13
14#[derive(Debug)]
15pub struct Server<Pub, Ctl>(ServerInner<Pub, Ctl>);
17
18#[derive(Debug)]
19struct ServerInner<Pub, Ctl> {
20 control: Rc<Ctl>,
21 publish: Rc<Pub>,
22 pool: pool::Pool<()>,
23}
24
25impl<Pub, Ctl> Clone for ServerInner<Pub, Ctl> {
26 fn clone(&self) -> Self {
27 Self {
28 control: self.control.clone(),
29 publish: self.publish.clone(),
30 pool: self.pool.clone(),
31 }
32 }
33}
34
35impl<Pub> Server<Pub, DefaultControlService>
36where
37 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
38 Pub::Error: fmt::Debug,
39 Pub::InitError: fmt::Debug,
40{
41 pub fn new(publish: Pub) -> Self {
43 Self(ServerInner {
44 publish: Rc::new(publish),
45 control: Rc::new(DefaultControlService),
46 pool: pool::new(),
47 })
48 }
49}
50
51impl<Pub, Ctl> Server<Pub, Ctl>
52where
53 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
54 Ctl::Error: fmt::Debug,
55 Ctl::InitError: fmt::Debug,
56 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
57 Pub::Error: fmt::Debug,
58 Pub::InitError: fmt::Debug,
59{
60 pub fn control<S, F>(&self, service: F) -> Server<Pub, S>
62 where
63 F: IntoServiceFactory<S, Control<Pub::Error>, SharedCfg>,
64 S: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
65 S::Error: fmt::Debug,
66 S::InitError: fmt::Debug,
67 {
68 Server(ServerInner {
69 control: Rc::new(service.into_factory()),
70 publish: self.0.publish.clone(),
71 pool: self.0.pool.clone(),
72 })
73 }
74
75 pub fn handler(&self, cfg: SharedCfg) -> ServerHandler<Pub, Ctl> {
77 ServerHandler::new(cfg, self.0.clone())
78 }
79}
80
81impl<Pub, Ctl> ServiceFactory<IoBoxed, SharedCfg> for Server<Pub, Ctl>
82where
83 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
84 Ctl::Error: fmt::Debug,
85 Ctl::InitError: fmt::Debug,
86 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
87 Pub::Error: fmt::Debug,
88 Pub::InitError: fmt::Debug,
89{
90 type Response = ();
91 type Error = ServerError<()>;
92 type Service = ServerHandler<Pub, Ctl>;
93 type InitError = ();
94
95 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
96 Ok(ServerHandler::new(cfg, self.0.clone()))
97 }
98}
99
100impl<F, Pub, Ctl> ServiceFactory<Io<F>, SharedCfg> for Server<Pub, Ctl>
101where
102 F: Filter,
103 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
104 Ctl::Error: fmt::Debug,
105 Ctl::InitError: fmt::Debug,
106 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
107 Pub::Error: fmt::Debug,
108 Pub::InitError: fmt::Debug,
109{
110 type Response = ();
111 type Error = ServerError<()>;
112 type Service = ServerHandler<Pub, Ctl>;
113 type InitError = ();
114
115 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
116 Ok(ServerHandler::new(cfg, self.0.clone()))
117 }
118}
119
120#[derive(Debug)]
121pub struct ServerHandler<Pub, Ctl> {
123 inner: ServerInner<Pub, Ctl>,
124 cfg: Cfg<ServiceConfig>,
125 shared: SharedCfg,
126}
127
128impl<Pub, Ctl> ServerHandler<Pub, Ctl> {
129 fn new(shared: SharedCfg, inner: ServerInner<Pub, Ctl>) -> Self {
130 let cfg = shared.get();
131 Self { cfg, shared, inner }
132 }
133}
134
135impl<Pub, Ctl> Clone for ServerHandler<Pub, Ctl> {
136 fn clone(&self) -> Self {
137 Self {
138 inner: self.inner.clone(),
139 cfg: self.cfg,
140 shared: self.shared,
141 }
142 }
143}
144
145impl<Pub, Ctl> ServerHandler<Pub, Ctl>
146where
147 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
148 Ctl::Error: fmt::Debug,
149 Ctl::InitError: fmt::Debug,
150 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
151 Pub::Error: fmt::Debug,
152 Pub::InitError: fmt::Debug,
153{
154 pub async fn run(&self, io: IoBoxed) -> Result<(), ServerError<()>> {
155 let inner = &self.inner;
156
157 let (ctl_srv, pub_srv) = timeout_checked(self.cfg.handshake_timeout, async {
158 read_preface(&io).await?;
159
160 let pub_srv = inner.publish.create(self.shared).await.map_err(|e| {
162 log::error!("Publish service init error: {e:?}");
163 ServerError::PublishServiceError
164 })?;
165
166 let ctl_srv = inner.control.create(self.shared).await.map_err(|e| {
168 log::error!("Control service init error: {e:?}");
169 ServerError::ControlServiceError
170 })?;
171
172 Ok::<_, ServerError<()>>((ctl_srv, pub_srv))
173 })
174 .await
175 .map_err(|_| ServerError::HandshakeTimeout)??;
176
177 let codec = Codec::default();
179 let con = Connection::new(
180 true,
181 io.get_ref(),
182 codec.clone(),
183 self.cfg,
184 true,
185 false,
186 self.inner.pool.clone(),
187 );
188 let con2 = con.clone();
189
190 let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_srv, pub_srv));
192 poll_fn(|cx| {
193 if con2.config().is_shutdown() {
194 con2.disconnect_when_ready();
195 }
196 Pin::new(&mut fut).poll(cx)
197 })
198 .await
199 .map_err(|_| ServerError::Dispatcher)
200 }
201}
202
203impl<Pub, Ctl> Service<IoBoxed> for ServerHandler<Pub, Ctl>
204where
205 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
206 Ctl::Error: fmt::Debug,
207 Ctl::InitError: fmt::Debug,
208 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
209 Pub::Error: fmt::Debug,
210 Pub::InitError: fmt::Debug,
211{
212 type Response = ();
213 type Error = ServerError<()>;
214
215 async fn call(
216 &self,
217 io: IoBoxed,
218 _: ServiceCtx<'_, Self>,
219 ) -> Result<Self::Response, Self::Error> {
220 self.run(io).await
221 }
222}
223
224impl<F, Pub, Ctl> Service<Io<F>> for ServerHandler<Pub, Ctl>
225where
226 F: Filter,
227 Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
228 Ctl::Error: fmt::Debug,
229 Ctl::InitError: fmt::Debug,
230 Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
231 Pub::Error: fmt::Debug,
232 Pub::InitError: fmt::Debug,
233{
234 type Response = ();
235 type Error = ServerError<()>;
236
237 async fn call(
238 &self,
239 req: Io<F>,
240 _: ServiceCtx<'_, Self>,
241 ) -> Result<Self::Response, Self::Error> {
242 self.run(req.into()).await
243 }
244}
245
246async fn read_preface(io: &IoBoxed) -> Result<(), ServerError<()>> {
247 loop {
248 let ready = io.with_read_buf(|buf| {
249 if buf.len() >= consts::PREFACE.len() {
250 if buf[..consts::PREFACE.len()] == consts::PREFACE {
251 buf.split_to(consts::PREFACE.len());
252 Ok(true)
253 } else {
254 log::trace!("read_preface: invalid preface {buf:?}");
255 Err(ServerError::<()>::Frame(frame::FrameError::InvalidPreface))
256 }
257 } else {
258 Ok(false)
259 }
260 })?;
261
262 if ready {
263 log::debug!("Preface has been received");
264 return Ok::<_, ServerError<_>>(());
265 } else {
266 io.read_ready()
267 .await?
268 .ok_or(ServerError::Disconnected(None))?;
269 }
270 }
271}
272
273pub async fn handle_one<Pub, Ctl>(
275 io: IoBoxed,
276 pub_svc: Pub,
277 ctl_svc: Ctl,
278) -> Result<(), ServerError<()>>
279where
280 Ctl: Service<Control<Pub::Error>, Response = ControlAck> + 'static,
281 Ctl::Error: fmt::Debug,
282 Pub: Service<Message, Response = ()> + 'static,
283 Pub::Error: fmt::Debug,
284{
285 let config: Cfg<ServiceConfig> = io.shared().get();
286
287 timeout_checked(config.handshake_timeout, async { read_preface(&io).await })
289 .await
290 .map_err(|_| ServerError::HandshakeTimeout)??;
291
292 let codec = Codec::default();
294 let con = Connection::new(
295 true,
296 io.get_ref(),
297 codec.clone(),
298 config,
299 true,
300 false,
301 pool::new(),
302 );
303 let con2 = con.clone();
304
305 let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_svc, pub_svc));
307
308 poll_fn(|cx| {
309 if con2.config().is_shutdown() {
310 con2.disconnect_when_ready();
311 }
312 Pin::new(&mut fut).poll(cx)
313 })
314 .await
315 .map_err(|_| ServerError::Dispatcher)
316}