ntex_h2/server/
service.rs

1use std::{fmt, future::Future, future::poll_fn, pin::Pin, rc::Rc};
2
3use ntex_io::{Dispatcher as IoDispatcher, Filter, Io, IoBoxed};
4use ntex_service::cfg::{Cfg, SharedCfg};
5use ntex_service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory};
6use ntex_util::{channel::pool, time::timeout_checked};
7
8use crate::control::{Control, ControlAck};
9use crate::{codec::Codec, connection::Connection, default::DefaultControlService};
10use crate::{config::ServiceConfig, consts, dispatcher::Dispatcher, frame, message::Message};
11
12use super::ServerError;
13
14#[derive(Debug)]
15/// Http/2 server factory
16pub struct Server<Pub, Ctl>(ServerInner<Pub, Ctl>);
17
18#[derive(Debug)]
19struct ServerInner<Pub, Ctl> {
20    control: Rc<Ctl>,
21    publish: Rc<Pub>,
22    pool: pool::Pool<()>,
23}
24
25impl<Pub, Ctl> Clone for ServerInner<Pub, Ctl> {
26    fn clone(&self) -> Self {
27        Self {
28            control: self.control.clone(),
29            publish: self.publish.clone(),
30            pool: self.pool.clone(),
31        }
32    }
33}
34
35impl<Pub> Server<Pub, DefaultControlService>
36where
37    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
38    Pub::Error: fmt::Debug,
39    Pub::InitError: fmt::Debug,
40{
41    /// Create new instance of Server factory
42    pub fn new(publish: Pub) -> Self {
43        Self(ServerInner {
44            publish: Rc::new(publish),
45            control: Rc::new(DefaultControlService),
46            pool: pool::new(),
47        })
48    }
49}
50
51impl<Pub, Ctl> Server<Pub, Ctl>
52where
53    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
54    Ctl::Error: fmt::Debug,
55    Ctl::InitError: fmt::Debug,
56    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
57    Pub::Error: fmt::Debug,
58    Pub::InitError: fmt::Debug,
59{
60    /// Service to handle control frames
61    pub fn control<S, F>(&self, service: F) -> Server<Pub, S>
62    where
63        F: IntoServiceFactory<S, Control<Pub::Error>, SharedCfg>,
64        S: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
65        S::Error: fmt::Debug,
66        S::InitError: fmt::Debug,
67    {
68        Server(ServerInner {
69            control: Rc::new(service.into_factory()),
70            publish: self.0.publish.clone(),
71            pool: self.0.pool.clone(),
72        })
73    }
74
75    /// Construct service handler
76    pub fn handler(&self, cfg: SharedCfg) -> ServerHandler<Pub, Ctl> {
77        ServerHandler::new(cfg, self.0.clone())
78    }
79}
80
81impl<Pub, Ctl> ServiceFactory<IoBoxed, SharedCfg> for Server<Pub, Ctl>
82where
83    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
84    Ctl::Error: fmt::Debug,
85    Ctl::InitError: fmt::Debug,
86    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
87    Pub::Error: fmt::Debug,
88    Pub::InitError: fmt::Debug,
89{
90    type Response = ();
91    type Error = ServerError<()>;
92    type Service = ServerHandler<Pub, Ctl>;
93    type InitError = ();
94
95    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
96        Ok(ServerHandler::new(cfg, self.0.clone()))
97    }
98}
99
100impl<F, Pub, Ctl> ServiceFactory<Io<F>, SharedCfg> for Server<Pub, Ctl>
101where
102    F: Filter,
103    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
104    Ctl::Error: fmt::Debug,
105    Ctl::InitError: fmt::Debug,
106    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
107    Pub::Error: fmt::Debug,
108    Pub::InitError: fmt::Debug,
109{
110    type Response = ();
111    type Error = ServerError<()>;
112    type Service = ServerHandler<Pub, Ctl>;
113    type InitError = ();
114
115    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
116        Ok(ServerHandler::new(cfg, self.0.clone()))
117    }
118}
119
120#[derive(Debug)]
121/// Http2 connections handler
122pub struct ServerHandler<Pub, Ctl> {
123    inner: ServerInner<Pub, Ctl>,
124    cfg: Cfg<ServiceConfig>,
125    shared: SharedCfg,
126}
127
128impl<Pub, Ctl> ServerHandler<Pub, Ctl> {
129    fn new(shared: SharedCfg, inner: ServerInner<Pub, Ctl>) -> Self {
130        let cfg = shared.get();
131        Self { cfg, shared, inner }
132    }
133}
134
135impl<Pub, Ctl> Clone for ServerHandler<Pub, Ctl> {
136    fn clone(&self) -> Self {
137        Self {
138            inner: self.inner.clone(),
139            cfg: self.cfg,
140            shared: self.shared,
141        }
142    }
143}
144
145impl<Pub, Ctl> ServerHandler<Pub, Ctl>
146where
147    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
148    Ctl::Error: fmt::Debug,
149    Ctl::InitError: fmt::Debug,
150    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
151    Pub::Error: fmt::Debug,
152    Pub::InitError: fmt::Debug,
153{
154    pub async fn run(&self, io: IoBoxed) -> Result<(), ServerError<()>> {
155        let inner = &self.inner;
156
157        let (ctl_srv, pub_srv) = timeout_checked(self.cfg.handshake_timeout, async {
158            read_preface(&io).await?;
159
160            // create publish service
161            let pub_srv = inner.publish.create(self.shared).await.map_err(|e| {
162                log::error!("Publish service init error: {e:?}");
163                ServerError::PublishServiceError
164            })?;
165
166            // create control service
167            let ctl_srv = inner.control.create(self.shared).await.map_err(|e| {
168                log::error!("Control service init error: {e:?}");
169                ServerError::ControlServiceError
170            })?;
171
172            Ok::<_, ServerError<()>>((ctl_srv, pub_srv))
173        })
174        .await
175        .map_err(|_| ServerError::HandshakeTimeout)??;
176
177        // create h2 codec
178        let codec = Codec::default();
179        let con = Connection::new(
180            true,
181            io.get_ref(),
182            codec.clone(),
183            self.cfg,
184            true,
185            false,
186            self.inner.pool.clone(),
187        );
188        let con2 = con.clone();
189
190        // start protocol dispatcher
191        let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_srv, pub_srv));
192        poll_fn(|cx| {
193            if con2.config().is_shutdown() {
194                con2.disconnect_when_ready();
195            }
196            Pin::new(&mut fut).poll(cx)
197        })
198        .await
199        .map_err(|_| ServerError::Dispatcher)
200    }
201}
202
203impl<Pub, Ctl> Service<IoBoxed> for ServerHandler<Pub, Ctl>
204where
205    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
206    Ctl::Error: fmt::Debug,
207    Ctl::InitError: fmt::Debug,
208    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
209    Pub::Error: fmt::Debug,
210    Pub::InitError: fmt::Debug,
211{
212    type Response = ();
213    type Error = ServerError<()>;
214
215    async fn call(
216        &self,
217        io: IoBoxed,
218        _: ServiceCtx<'_, Self>,
219    ) -> Result<Self::Response, Self::Error> {
220        self.run(io).await
221    }
222}
223
224impl<F, Pub, Ctl> Service<Io<F>> for ServerHandler<Pub, Ctl>
225where
226    F: Filter,
227    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
228    Ctl::Error: fmt::Debug,
229    Ctl::InitError: fmt::Debug,
230    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
231    Pub::Error: fmt::Debug,
232    Pub::InitError: fmt::Debug,
233{
234    type Response = ();
235    type Error = ServerError<()>;
236
237    async fn call(
238        &self,
239        req: Io<F>,
240        _: ServiceCtx<'_, Self>,
241    ) -> Result<Self::Response, Self::Error> {
242        self.run(req.into()).await
243    }
244}
245
246async fn read_preface(io: &IoBoxed) -> Result<(), ServerError<()>> {
247    loop {
248        let ready = io.with_read_buf(|buf| {
249            if buf.len() >= consts::PREFACE.len() {
250                if buf[..consts::PREFACE.len()] == consts::PREFACE {
251                    buf.split_to(consts::PREFACE.len());
252                    Ok(true)
253                } else {
254                    log::trace!("read_preface: invalid preface {buf:?}");
255                    Err(ServerError::<()>::Frame(frame::FrameError::InvalidPreface))
256                }
257            } else {
258                Ok(false)
259            }
260        })?;
261
262        if ready {
263            log::debug!("Preface has been received");
264            return Ok::<_, ServerError<_>>(());
265        } else {
266            io.read_ready()
267                .await?
268                .ok_or(ServerError::Disconnected(None))?;
269        }
270    }
271}
272
273/// Handle io object.
274pub async fn handle_one<Pub, Ctl>(
275    io: IoBoxed,
276    pub_svc: Pub,
277    ctl_svc: Ctl,
278) -> Result<(), ServerError<()>>
279where
280    Ctl: Service<Control<Pub::Error>, Response = ControlAck> + 'static,
281    Ctl::Error: fmt::Debug,
282    Pub: Service<Message, Response = ()> + 'static,
283    Pub::Error: fmt::Debug,
284{
285    let config: Cfg<ServiceConfig> = io.shared().get();
286
287    // read preface
288    timeout_checked(config.handshake_timeout, async { read_preface(&io).await })
289        .await
290        .map_err(|_| ServerError::HandshakeTimeout)??;
291
292    // create h2 codec
293    let codec = Codec::default();
294    let con = Connection::new(
295        true,
296        io.get_ref(),
297        codec.clone(),
298        config,
299        true,
300        false,
301        pool::new(),
302    );
303    let con2 = con.clone();
304
305    // start protocol dispatcher
306    let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_svc, pub_svc));
307
308    poll_fn(|cx| {
309        if con2.config().is_shutdown() {
310            con2.disconnect_when_ready();
311        }
312        Pin::new(&mut fut).poll(cx)
313    })
314    .await
315    .map_err(|_| ServerError::Dispatcher)
316}