1use std::{fmt, future::poll_fn, future::Future, pin::Pin, rc::Rc};
2
3use ntex_io::{Dispatcher as IoDispatcher, Filter, Io, IoBoxed};
4use ntex_service::{Service, ServiceCtx, ServiceFactory};
5use ntex_util::time::timeout_checked;
6
7use crate::control::{Control, ControlAck};
8use crate::{codec::Codec, connection::Connection};
9use crate::{config::Config, consts, dispatcher::Dispatcher, frame, message::Message};
10
11use super::{ServerBuilder, ServerError};
12
13#[derive(Debug)]
14pub struct Server<Ctl, Pub>(Rc<ServerInner<Ctl, Pub>>);
16
17#[derive(Debug)]
18struct ServerInner<Ctl, Pub> {
19 control: Ctl,
20 publish: Pub,
21 config: Config,
22}
23
24impl Server<(), ()> {
25 pub fn build<E>() -> ServerBuilder<E> {
28 ServerBuilder::new()
29 }
30}
31
32impl<Ctl, Pub> Server<Ctl, Pub>
33where
34 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
35 Ctl::Error: fmt::Debug,
36 Ctl::InitError: fmt::Debug,
37 Pub: ServiceFactory<Message, Response = ()> + 'static,
38 Pub::Error: fmt::Debug,
39 Pub::InitError: fmt::Debug,
40{
41 pub fn new(config: Config, control: Ctl, publish: Pub) -> Server<Ctl, Pub> {
43 Self(Rc::new(ServerInner {
44 control,
45 publish,
46 config,
47 }))
48 }
49
50 pub fn handler(&self) -> ServerHandler<Ctl, Pub> {
52 ServerHandler(self.0.clone())
53 }
54}
55
56impl<Ctl, Pub> ServiceFactory<IoBoxed> for Server<Ctl, Pub>
57where
58 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
59 Ctl::Error: fmt::Debug,
60 Ctl::InitError: fmt::Debug,
61 Pub: ServiceFactory<Message, Response = ()> + 'static,
62 Pub::Error: fmt::Debug,
63 Pub::InitError: fmt::Debug,
64{
65 type Response = ();
66 type Error = ServerError<()>;
67 type Service = ServerHandler<Ctl, Pub>;
68 type InitError = ();
69
70 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
71 Ok(ServerHandler(self.0.clone()))
72 }
73}
74
75impl<F, Ctl, Pub> ServiceFactory<Io<F>> for Server<Ctl, Pub>
76where
77 F: Filter,
78 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
79 Ctl::Error: fmt::Debug,
80 Ctl::InitError: fmt::Debug,
81 Pub: ServiceFactory<Message, Response = ()> + 'static,
82 Pub::Error: fmt::Debug,
83 Pub::InitError: fmt::Debug,
84{
85 type Response = ();
86 type Error = ServerError<()>;
87 type Service = ServerHandler<Ctl, Pub>;
88 type InitError = ();
89
90 async fn create(&self, _: ()) -> Result<Self::Service, Self::InitError> {
91 Ok(ServerHandler(self.0.clone()))
92 }
93}
94
95#[derive(Debug)]
96pub struct ServerHandler<Ctl, Pub>(Rc<ServerInner<Ctl, Pub>>);
98
99impl<Ctl, Pub> Clone for ServerHandler<Ctl, Pub> {
100 fn clone(&self) -> Self {
101 Self(self.0.clone())
102 }
103}
104
105impl<Ctl, Pub> ServerHandler<Ctl, Pub>
106where
107 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
108 Ctl::Error: fmt::Debug,
109 Ctl::InitError: fmt::Debug,
110 Pub: ServiceFactory<Message, Response = ()> + 'static,
111 Pub::Error: fmt::Debug,
112 Pub::InitError: fmt::Debug,
113{
114 pub async fn run(&self, io: IoBoxed) -> Result<(), ServerError<()>> {
115 let inner = &self.0;
116
117 let (ctl_srv, pub_srv) = timeout_checked(inner.config.0.handshake_timeout.get(), async {
118 read_preface(&io).await?;
119
120 let pub_srv = inner.publish.create(()).await.map_err(|e| {
122 log::error!("Publish service init error: {e:?}");
123 ServerError::PublishServiceError
124 })?;
125
126 let ctl_srv = inner.control.create(()).await.map_err(|e| {
128 log::error!("Control service init error: {e:?}");
129 ServerError::ControlServiceError
130 })?;
131
132 Ok::<_, ServerError<()>>((ctl_srv, pub_srv))
133 })
134 .await
135 .map_err(|_| ServerError::HandshakeTimeout)??;
136
137 let codec = Codec::default();
139 let con = Connection::new(
140 io.get_ref(),
141 codec.clone(),
142 inner.config.clone(),
143 true,
144 false,
145 );
146 let con2 = con.clone();
147
148 let mut fut = IoDispatcher::new(
150 io,
151 codec,
152 Dispatcher::new(con, ctl_srv, pub_srv),
153 &inner.config.inner().dispatcher_config,
154 );
155 poll_fn(|cx| {
156 if con2.config().is_shutdown() {
157 con2.disconnect_when_ready();
158 }
159 Pin::new(&mut fut).poll(cx)
160 })
161 .await
162 .map_err(|_| ServerError::Dispatcher)
163 }
164}
165
166impl<Ctl, Pub> Service<IoBoxed> for ServerHandler<Ctl, Pub>
167where
168 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
169 Ctl::Error: fmt::Debug,
170 Ctl::InitError: fmt::Debug,
171 Pub: ServiceFactory<Message, Response = ()> + 'static,
172 Pub::Error: fmt::Debug,
173 Pub::InitError: fmt::Debug,
174{
175 type Response = ();
176 type Error = ServerError<()>;
177
178 async fn call(
179 &self,
180 io: IoBoxed,
181 _: ServiceCtx<'_, Self>,
182 ) -> Result<Self::Response, Self::Error> {
183 self.run(io).await
184 }
185}
186
187impl<F, Ctl, Pub> Service<Io<F>> for ServerHandler<Ctl, Pub>
188where
189 F: Filter,
190 Ctl: ServiceFactory<Control<Pub::Error>, Response = ControlAck> + 'static,
191 Ctl::Error: fmt::Debug,
192 Ctl::InitError: fmt::Debug,
193 Pub: ServiceFactory<Message, Response = ()> + 'static,
194 Pub::Error: fmt::Debug,
195 Pub::InitError: fmt::Debug,
196{
197 type Response = ();
198 type Error = ServerError<()>;
199
200 async fn call(
201 &self,
202 req: Io<F>,
203 _: ServiceCtx<'_, Self>,
204 ) -> Result<Self::Response, Self::Error> {
205 self.run(req.into()).await
206 }
207}
208
209async fn read_preface(io: &IoBoxed) -> Result<(), ServerError<()>> {
210 loop {
211 let ready = io.with_read_buf(|buf| {
212 if buf.len() >= consts::PREFACE.len() {
213 if buf[..consts::PREFACE.len()] == consts::PREFACE {
214 buf.split_to(consts::PREFACE.len());
215 Ok(true)
216 } else {
217 log::trace!("read_preface: invalid preface {buf:?}");
218 Err(ServerError::<()>::Frame(frame::FrameError::InvalidPreface))
219 }
220 } else {
221 Ok(false)
222 }
223 })?;
224
225 if ready {
226 log::debug!("Preface has been received");
227 return Ok::<_, ServerError<_>>(());
228 } else {
229 io.read_ready()
230 .await?
231 .ok_or(ServerError::Disconnected(None))?;
232 }
233 }
234}
235
236pub async fn handle_one<Ctl, Pub>(
238 io: IoBoxed,
239 config: Config,
240 ctl_svc: Ctl,
241 pub_svc: Pub,
242) -> Result<(), ServerError<()>>
243where
244 Ctl: Service<Control<Pub::Error>, Response = ControlAck> + 'static,
245 Ctl::Error: fmt::Debug,
246 Pub: Service<Message, Response = ()> + 'static,
247 Pub::Error: fmt::Debug,
248{
249 timeout_checked(config.0.handshake_timeout.get(), async {
251 read_preface(&io).await
252 })
253 .await
254 .map_err(|_| ServerError::HandshakeTimeout)??;
255
256 let codec = Codec::default();
258 let con = Connection::new(io.get_ref(), codec.clone(), config.clone(), true, false);
259 let con2 = con.clone();
260
261 let mut fut = IoDispatcher::new(
263 io,
264 codec,
265 Dispatcher::new(con, ctl_svc, pub_svc),
266 &config.inner().dispatcher_config,
267 );
268
269 poll_fn(|cx| {
270 if con2.config().is_shutdown() {
271 con2.disconnect_when_ready();
272 }
273 Pin::new(&mut fut).poll(cx)
274 })
275 .await
276 .map_err(|_| ServerError::Dispatcher)
277}