Skip to main content

ntex_h2/server/
service.rs

1use std::{fmt, future::Future, future::poll_fn, pin::Pin, rc::Rc};
2
3use ntex_dispatcher::Dispatcher as IoDispatcher;
4use ntex_io::{Filter, Io, IoBoxed};
5use ntex_service::cfg::{Cfg, SharedCfg};
6use ntex_service::{IntoServiceFactory, Service, ServiceCtx, ServiceFactory};
7use ntex_util::{channel::pool, time::timeout_checked};
8
9use crate::control::{Control, ControlAck};
10use crate::{codec::Codec, connection::Connection, default::DefaultControlService};
11use crate::{config::ServiceConfig, consts, dispatcher::Dispatcher, frame, message::Message};
12
13use super::ServerError;
14
15#[derive(Debug)]
16/// Http/2 server factory
17pub struct Server<Pub, Ctl>(ServerInner<Pub, Ctl>);
18
19#[derive(Debug)]
20struct ServerInner<Pub, Ctl> {
21    control: Rc<Ctl>,
22    publish: Rc<Pub>,
23    pool: pool::Pool<()>,
24}
25
26impl<Pub, Ctl> Clone for ServerInner<Pub, Ctl> {
27    fn clone(&self) -> Self {
28        Self {
29            control: self.control.clone(),
30            publish: self.publish.clone(),
31            pool: self.pool.clone(),
32        }
33    }
34}
35
36impl<Pub> Server<Pub, DefaultControlService>
37where
38    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
39    Pub::Error: fmt::Debug,
40    Pub::InitError: fmt::Debug,
41{
42    /// Create new instance of Server factory
43    pub fn new(publish: Pub) -> Self {
44        Self(ServerInner {
45            publish: Rc::new(publish),
46            control: Rc::new(DefaultControlService),
47            pool: pool::new(),
48        })
49    }
50}
51
52impl<Pub, Ctl> Server<Pub, Ctl>
53where
54    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
55    Ctl::Error: fmt::Debug,
56    Ctl::InitError: fmt::Debug,
57    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
58    Pub::Error: fmt::Debug,
59    Pub::InitError: fmt::Debug,
60{
61    /// Service to handle control frames
62    pub fn control<S, F>(&self, service: F) -> Server<Pub, S>
63    where
64        F: IntoServiceFactory<S, Control<Pub::Error>, SharedCfg>,
65        S: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
66        S::Error: fmt::Debug,
67        S::InitError: fmt::Debug,
68    {
69        Server(ServerInner {
70            control: Rc::new(service.into_factory()),
71            publish: self.0.publish.clone(),
72            pool: self.0.pool.clone(),
73        })
74    }
75
76    /// Construct service handler
77    pub fn handler(&self, cfg: SharedCfg) -> ServerHandler<Pub, Ctl> {
78        ServerHandler::new(cfg, self.0.clone())
79    }
80}
81
82impl<Pub, Ctl> ServiceFactory<IoBoxed, SharedCfg> for Server<Pub, Ctl>
83where
84    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
85    Ctl::Error: fmt::Debug,
86    Ctl::InitError: fmt::Debug,
87    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
88    Pub::Error: fmt::Debug,
89    Pub::InitError: fmt::Debug,
90{
91    type Response = ();
92    type Error = ServerError<()>;
93    type Service = ServerHandler<Pub, Ctl>;
94    type InitError = ();
95
96    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
97        Ok(ServerHandler::new(cfg, self.0.clone()))
98    }
99}
100
101impl<F, Pub, Ctl> ServiceFactory<Io<F>, SharedCfg> for Server<Pub, Ctl>
102where
103    F: Filter,
104    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
105    Ctl::Error: fmt::Debug,
106    Ctl::InitError: fmt::Debug,
107    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
108    Pub::Error: fmt::Debug,
109    Pub::InitError: fmt::Debug,
110{
111    type Response = ();
112    type Error = ServerError<()>;
113    type Service = ServerHandler<Pub, Ctl>;
114    type InitError = ();
115
116    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
117        Ok(ServerHandler::new(cfg, self.0.clone()))
118    }
119}
120
121#[derive(Debug)]
122/// Http2 connections handler
123pub struct ServerHandler<Pub, Ctl> {
124    cfg: Cfg<ServiceConfig>,
125    inner: ServerInner<Pub, Ctl>,
126    shared: SharedCfg,
127}
128
129impl<Pub, Ctl> ServerHandler<Pub, Ctl> {
130    fn new(shared: SharedCfg, inner: ServerInner<Pub, Ctl>) -> Self {
131        let cfg = shared.get();
132        Self { cfg, inner, shared }
133    }
134}
135
136impl<Pub, Ctl> Clone for ServerHandler<Pub, Ctl> {
137    fn clone(&self) -> Self {
138        Self {
139            inner: self.inner.clone(),
140            cfg: self.cfg.clone(),
141            shared: self.shared.clone(),
142        }
143    }
144}
145
146impl<Pub, Ctl> ServerHandler<Pub, Ctl>
147where
148    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
149    Ctl::Error: fmt::Debug,
150    Ctl::InitError: fmt::Debug,
151    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
152    Pub::Error: fmt::Debug,
153    Pub::InitError: fmt::Debug,
154{
155    pub async fn run(&self, io: IoBoxed) -> Result<(), ServerError<()>> {
156        let inner = &self.inner;
157
158        let (ctl_srv, pub_srv) = timeout_checked(self.cfg.handshake_timeout, async {
159            read_preface(&io).await?;
160
161            // create publish service
162            let pub_srv = inner
163                .publish
164                .create(self.shared.clone())
165                .await
166                .map_err(|e| {
167                    log::error!("Publish service init error: {e:?}");
168                    ServerError::PublishServiceError
169                })?;
170
171            // create control service
172            let ctl_srv = inner
173                .control
174                .create(self.shared.clone())
175                .await
176                .map_err(|e| {
177                    log::error!("Control service init error: {e:?}");
178                    ServerError::ControlServiceError
179                })?;
180
181            Ok::<_, ServerError<()>>((ctl_srv, pub_srv))
182        })
183        .await
184        .map_err(|()| ServerError::HandshakeTimeout)??;
185
186        // create h2 codec
187        let codec = Codec::default();
188        let con = Connection::new(
189            true,
190            io.get_ref(),
191            codec.clone(),
192            self.cfg.clone(),
193            true,
194            false,
195            self.inner.pool.clone(),
196        );
197        let con2 = con.clone();
198
199        // start protocol dispatcher
200        let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_srv, pub_srv));
201        poll_fn(|cx| {
202            if con2.config().is_shutdown() {
203                con2.disconnect_when_ready();
204            }
205            Pin::new(&mut fut).poll(cx)
206        })
207        .await
208        .map_err(|()| ServerError::Dispatcher)
209    }
210}
211
212impl<Pub, Ctl> Service<IoBoxed> for ServerHandler<Pub, Ctl>
213where
214    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
215    Ctl::Error: fmt::Debug,
216    Ctl::InitError: fmt::Debug,
217    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
218    Pub::Error: fmt::Debug,
219    Pub::InitError: fmt::Debug,
220{
221    type Response = ();
222    type Error = ServerError<()>;
223
224    async fn call(
225        &self,
226        io: IoBoxed,
227        _: ServiceCtx<'_, Self>,
228    ) -> Result<Self::Response, Self::Error> {
229        self.run(io).await
230    }
231}
232
233impl<F, Pub, Ctl> Service<Io<F>> for ServerHandler<Pub, Ctl>
234where
235    F: Filter,
236    Ctl: ServiceFactory<Control<Pub::Error>, SharedCfg, Response = ControlAck> + 'static,
237    Ctl::Error: fmt::Debug,
238    Ctl::InitError: fmt::Debug,
239    Pub: ServiceFactory<Message, SharedCfg, Response = ()> + 'static,
240    Pub::Error: fmt::Debug,
241    Pub::InitError: fmt::Debug,
242{
243    type Response = ();
244    type Error = ServerError<()>;
245
246    async fn call(
247        &self,
248        req: Io<F>,
249        _: ServiceCtx<'_, Self>,
250    ) -> Result<Self::Response, Self::Error> {
251        self.run(req.into()).await
252    }
253}
254
255async fn read_preface(io: &IoBoxed) -> Result<(), ServerError<()>> {
256    loop {
257        let ready = io.with_read_buf(|buf| {
258            if buf.len() >= consts::PREFACE.len() {
259                if buf[..consts::PREFACE.len()] == consts::PREFACE {
260                    buf.advance_to(consts::PREFACE.len());
261                    Ok(true)
262                } else {
263                    log::trace!("read_preface: invalid preface {buf:?}");
264                    Err(ServerError::<()>::Frame(frame::FrameError::InvalidPreface))
265                }
266            } else {
267                Ok(false)
268            }
269        })?;
270
271        if ready {
272            log::debug!("Preface has been received");
273            return Ok::<_, ServerError<_>>(());
274        }
275        io.read_ready()
276            .await?
277            .ok_or(ServerError::Disconnected(None))?;
278    }
279}
280
281/// Handle io object.
282pub async fn handle_one<Pub, Ctl>(
283    io: IoBoxed,
284    pub_svc: Pub,
285    ctl_svc: Ctl,
286) -> Result<(), ServerError<()>>
287where
288    Ctl: Service<Control<Pub::Error>, Response = ControlAck> + 'static,
289    Ctl::Error: fmt::Debug,
290    Pub: Service<Message, Response = ()> + 'static,
291    Pub::Error: fmt::Debug,
292{
293    let config: Cfg<ServiceConfig> = io.shared().get();
294
295    // read preface
296    timeout_checked(config.handshake_timeout, async { read_preface(&io).await })
297        .await
298        .map_err(|()| ServerError::HandshakeTimeout)??;
299
300    // create h2 codec
301    let codec = Codec::default();
302    let con = Connection::new(
303        true,
304        io.get_ref(),
305        codec.clone(),
306        config,
307        true,
308        false,
309        pool::new(),
310    );
311    let con2 = con.clone();
312
313    // start protocol dispatcher
314    let mut fut = IoDispatcher::new(io, codec, Dispatcher::new(con, ctl_svc, pub_svc));
315
316    poll_fn(|cx| {
317        if con2.config().is_shutdown() {
318            con2.disconnect_when_ready();
319        }
320        Pin::new(&mut fut).poll(cx)
321    })
322    .await
323    .map_err(|()| ServerError::Dispatcher)
324}