// proto_tower_http_2/server/layer.rs

1use crate::server::parser::{read_next_frame, Http2Frame, WriteOnto};
2use crate::ProtoHttp2Config;
3use proto_tower_util::{AsyncReadToBuf, ZeroReadBehaviour};
4use std::fmt::Debug;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::sync::mpsc::{Receiver, Sender};
10use tower::Service;
11
/// Ways in which serving an HTTP/2 connection through the layer can end.
///
/// `E` is the error type of the wrapped inner service.
#[derive(Debug)]
pub enum ProtoHttp2Error<E>
where
    E: Debug,
{
    /// The first bytes on the connection were not the HTTP/2 client preface.
    InvalidPreface,
    /// Nothing happened on the connection within the configured timeout.
    Timeout,
    /// The inner service dropped its end of a frame channel, or its task ended
    /// without reporting a service error.
    InnerServiceClosed,
    /// The inner service (or its readiness check) reported an error.
    ServiceError(E),
    /// Any other failure, described by a static message.
    OtherInternalError(&'static str),
}
20
21/// A service to process HTTP/1.1 requests
22///
23/// This should not be constructed directly - it gets created by MakeService during invocation.
24pub struct ProtoHttp2Layer<Svc>
25where
26    Svc: Service<(Receiver<Http2Frame>, Sender<Http2Frame>), Response = ()> + Send + Clone,
27{
28    config: ProtoHttp2Config,
29    /// The inner service to process requests
30    inner: Svc,
31}
32
33impl<Svc> ProtoHttp2Layer<Svc>
34where
35    Svc: Service<(Receiver<Http2Frame>, Sender<Http2Frame>), Response = ()> + Send + Clone,
36{
37    /// Create a new instance of the service
38    pub fn new(config: ProtoHttp2Config, inner: Svc) -> Self {
39        ProtoHttp2Layer { config, inner }
40    }
41}
42
43impl<Reader, Writer, Svc, SvcError, SvcFut> Service<(Reader, Writer)> for ProtoHttp2Layer<Svc>
44where
45    Reader: AsyncReadExt + Send + Unpin + 'static,
46    Writer: AsyncWriteExt + Send + Unpin + 'static,
47    Svc: Service<(Receiver<Http2Frame>, Sender<Http2Frame>), Response = (), Error = SvcError, Future = SvcFut> + Send + Clone + 'static,
48    SvcFut: Future<Output = Result<(Receiver<Http2Frame>, Sender<Http2Frame>), SvcError>> + Send + 'static,
49    SvcError: Debug + Send + 'static,
50{
51    /// The response is handled by the protocol
52    type Response = ();
53    /// Errors would be failures in parsing the protocol - this should be handled by the protocol
54    type Error = ProtoHttp2Error<SvcError>;
55    /// The future is the protocol itself
56    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
57
58    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59        self.inner.poll_ready(cx).map_err(ProtoHttp2Error::ServiceError)
60    }
61
62    /// Indefinitely process the protocol
63    fn call(&mut self, (mut reader, mut writer): (Reader, Writer)) -> Self::Future {
64        let mut service = self.inner.clone();
65        let config = self.config.clone();
66        Box::pin(async move {
67            let async_read = AsyncReadToBuf::new_1024(ZeroReadBehaviour::TickAndYield);
68            let mut preface = async_read.read_with_timeout(&mut reader, config.timeout, Some(28)).await;
69            if preface != b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" {
70                return Err(ProtoHttp2Error::InvalidPreface);
71            }
72            let (layer_frame_sx, svc_frame_rx) = tokio::sync::mpsc::channel::<Http2Frame>(1);
73            let (svc_frame_sx, mut layer_frame_rx) = tokio::sync::mpsc::channel::<Http2Frame>(1);
74            let internal_task = tokio::spawn(service.call((svc_frame_rx, svc_frame_sx)));
75            // Validate request
76            loop {
77                tokio::select! {
78                    _ = tokio::time::sleep(config.timeout) => {
79                        return Err(ProtoHttp2Error::Timeout);
80                    }
81                    frame = layer_frame_rx.recv() => {
82                        match frame {
83                            Some(frame) => {
84                                frame.write_onto(&mut writer).await.unwrap();
85                            }
86                            None => {
87                                return Err(ProtoHttp2Error::InnerServiceClosed);
88                            }
89                        }
90                    }
91                    frame = read_next_frame(&mut reader, config.timeout) => {
92                        match frame {
93                            Ok(frame) => {
94                                layer_frame_sx.send(frame).await.unwrap();
95                            }
96                            Err(e) => {
97                                return Err(ProtoHttp2Error::OtherInternalError(e));
98                            }
99                        }
100                    }
101                }
102                if internal_task.is_finished() {
103                    return match internal_task.await {
104                        Ok(Err(e)) => Err(ProtoHttp2Error::ServiceError(e)),
105                        Ok(Ok(_)) | Err(_) => Err(ProtoHttp2Error::InnerServiceClosed),
106                    };
107                }
108            }
109        })
110    }
111}