1use std::{fmt, future::poll_fn, future::Future, pin::Pin, rc::Rc};
2
3use ntex_io::{Dispatcher as IoDispatcher, Filter, Io, IoBoxed};
4use ntex_service::{Service, ServiceCtx, ServiceFactory};
5use ntex_util::time::timeout_checked;
6
7use crate::control::{Control, ControlAck};
8use crate::{codec::Codec, connection::Connection};
9use crate::{config::Config, consts, dispatcher::Dispatcher, frame, message::Message};
10
11use super::{ServerBuilder, ServerError};
12
13#[derive(Debug)]
14pub struct Server<Ctl, Pub>(Rc<ServerInner<Ctl, Pub>>);
16
17#[derive(Debug)]
18struct ServerInner<Ctl, Pub> {
19 control: Ctl,
20 publish: Pub,
21 config: Config,
22}
23
24impl Server<(), ()> {
25 pub fn build<E>() -> ServerBuilder<E> {
28 ServerBuilder::new()
29 }
30}
31
32impl<Ctl, Pub> Server<Ctl, Pub>
33where
34 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
35 Ctl::Error: fmt::Debug,
36 Ctl::InitError: fmt::Debug,
37 Pub: ServiceFactory<Message, Response = ()> + 'static,
38 Pub::Error: fmt::Debug,
39 Pub::InitError: fmt::Debug,
40{
41 pub fn new(config: Config, control: Ctl, publish: Pub) -> Server<Ctl, Pub> {
43 Self(Rc::new(ServerInner {
44 control,
45 publish,
46 config,
47 }))
48 }
49
50 pub fn handler(&self) -> ServerHandler<Ctl, Pub> {
52 ServerHandler(self.0.clone())
53 }
54}
55
56impl<Ctl, Pub> ServiceFactory<IoBoxed> for Server<Ctl, Pub>
57where
58 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
59 Ctl::Error: fmt::Debug,
60 Ctl::InitError: fmt::Debug,
61 Pub: ServiceFactory<Message, Response = ()> + 'static,
62 Pub::Error: fmt::Debug,
63 Pub::InitError: fmt::Debug,
64{
65 type Response = ();
66 type Error = ServerError<()>;
67 type Service = ServerHandler<Ctl, Pub>;
68 type InitError = ();
69
70 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
71 Ok(ServerHandler(self.0.clone()))
72 }
73}
74
75impl<F, Ctl, Pub> ServiceFactory<Io<F>> for Server<Ctl, Pub>
76where
77 F: Filter,
78 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
79 Ctl::Error: fmt::Debug,
80 Ctl::InitError: fmt::Debug,
81 Pub: ServiceFactory<Message, Response = ()> + 'static,
82 Pub::Error: fmt::Debug,
83 Pub::InitError: fmt::Debug,
84{
85 type Response = ();
86 type Error = ServerError<()>;
87 type Service = ServerHandler<Ctl, Pub>;
88 type InitError = ();
89
90 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
91 Ok(ServerHandler(self.0.clone()))
92 }
93}
94
95#[derive(Debug)]
96pub struct ServerHandler<Ctl, Pub>(Rc<ServerInner<Ctl, Pub>>);
98
99impl<Ctl, Pub> Clone for ServerHandler<Ctl, Pub> {
100 fn clone(&self) -> Self {
101 Self(self.0.clone())
102 }
103}
104
105impl<Ctl, Pub> ServerHandler<Ctl, Pub>
106where
107 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
108 Ctl::Error: fmt::Debug,
109 Ctl::InitError: fmt::Debug,
110 Pub: ServiceFactory<Message, Response = ()> + 'static,
111 Pub::Error: fmt::Debug,
112 Pub::InitError: fmt::Debug,
113{
114 pub async fn run(&self, io: IoBoxed) -> Result<(), ServerError<()>> {
115 let inner = &self.0;
116
117 let (ctl_srv, pub_srv) = timeout_checked(inner.config.0.handshake_timeout.get(), async {
118 read_preface(&io).await?;
119
120 let pub_srv = inner.publish.create(()).await.map_err(|e| {
122 log::error!("Publish service init error: {:?}", e);
123 ServerError::PublishServiceError
124 })?;
125
126 let ctl_srv = inner.control.create(()).await.map_err(|e| {
128 log::error!("Control service init error: {:?}", e);
129 ServerError::ControlServiceError
130 })?;
131
132 Ok::<_, ServerError<()>>((ctl_srv, pub_srv))
133 })
134 .await
135 .map_err(|_| ServerError::HandshakeTimeout)??;
136
137 let codec = Codec::default();
139 let con = Connection::new(io.get_ref(), codec.clone(), inner.config.clone(), true);
140 let con2 = con.clone();
141
142 let mut fut = IoDispatcher::new(
144 io,
145 codec,
146 Dispatcher::new(con, ctl_srv, pub_srv),
147 &inner.config.inner().dispatcher_config,
148 );
149 poll_fn(|cx| {
150 if con2.config().is_shutdown() {
151 con2.disconnect_when_ready();
152 }
153 Pin::new(&mut fut).poll(cx)
154 })
155 .await
156 .map_err(|_| ServerError::Dispatcher)
157 }
158}
159
160impl<Ctl, Pub> Service<IoBoxed> for ServerHandler<Ctl, Pub>
161where
162 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
163 Ctl::Error: fmt::Debug,
164 Ctl::InitError: fmt::Debug,
165 Pub: ServiceFactory<Message, Response = ()> + 'static,
166 Pub::Error: fmt::Debug,
167 Pub::InitError: fmt::Debug,
168{
169 type Response = ();
170 type Error = ServerError<()>;
171
172 async fn call(
173 &self,
174 io: IoBoxed,
175 _: ServiceCtx<'_, Self>,
176 ) -> Result<Self::Response, Self::Error> {
177 self.run(io).await
178 }
179}
180
181impl<F, Ctl, Pub> Service<Io<F>> for ServerHandler<Ctl, Pub>
182where
183 F: Filter,
184 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
185 Ctl::Error: fmt::Debug,
186 Ctl::InitError: fmt::Debug,
187 Pub: ServiceFactory<Message, Response = ()> + 'static,
188 Pub::Error: fmt::Debug,
189 Pub::InitError: fmt::Debug,
190{
191 type Response = ();
192 type Error = ServerError<()>;
193
194 async fn call(
195 &self,
196 req: Io<F>,
197 _: ServiceCtx<'_, Self>,
198 ) -> Result<Self::Response, Self::Error> {
199 self.run(req.into()).await
200 }
201}
202
203async fn read_preface(io: &IoBoxed) -> Result<(), ServerError<()>> {
204 loop {
205 let ready = io.with_read_buf(|buf| {
206 if buf.len() >= consts::PREFACE.len() {
207 if buf[..consts::PREFACE.len()] == consts::PREFACE {
208 buf.split_to(consts::PREFACE.len());
209 Ok(true)
210 } else {
211 log::trace!("read_preface: invalid preface {:?}", buf);
212 Err(ServerError::<()>::Frame(frame::FrameError::InvalidPreface))
213 }
214 } else {
215 Ok(false)
216 }
217 })?;
218
219 if ready {
220 log::debug!("Preface has been received");
221 return Ok::<_, ServerError<_>>(());
222 } else {
223 io.read_ready()
224 .await?
225 .ok_or(ServerError::Disconnected(None))?;
226 }
227 }
228}
229
230pub async fn handle_one<Ctl, Pub>(
232 io: IoBoxed,
233 config: Config,
234 ctl_svc: Ctl,
235 pub_svc: Pub,
236) -> Result<(), ServerError<()>>
237where
238 Ctl: Service<Control<Pub::Error>, Response = ControlAck> + 'static,
239 Ctl::Error: fmt::Debug,
240 Pub: Service<Message, Response = ()> + 'static,
241 Pub::Error: fmt::Debug,
242{
243 timeout_checked(config.0.handshake_timeout.get(), async {
245 read_preface(&io).await
246 })
247 .await
248 .map_err(|_| ServerError::HandshakeTimeout)??;
249
250 let codec = Codec::default();
252 let con = Connection::new(io.get_ref(), codec.clone(), config.clone(), true);
253 let con2 = con.clone();
254
255 let mut fut = IoDispatcher::new(
257 io,
258 codec,
259 Dispatcher::new(con, ctl_svc, pub_svc),
260 &config.inner().dispatcher_config,
261 );
262
263 poll_fn(|cx| {
264 if con2.config().is_shutdown() {
265 con2.disconnect_when_ready();
266 }
267 Pin::new(&mut fut).poll(cx)
268 })
269 .await
270 .map_err(|_| ServerError::Dispatcher)
271}