proto_tower_http_2/server/
layer.rs1use crate::server::parser::{read_next_frame, Http2Frame, WriteOnto};
2use crate::ProtoHttp2Config;
3use proto_tower_util::{AsyncReadToBuf, ZeroReadBehaviour};
4use std::fmt::Debug;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::sync::mpsc::{Receiver, Sender};
10use tower::Service;
11
/// Errors produced while serving an HTTP/2 connection through `ProtoHttp2Layer`.
///
/// `E` is the error type of the wrapped inner service.
#[derive(Debug)]
pub enum ProtoHttp2Error<E: Debug> {
    /// The bytes at the start of the connection were not the HTTP/2 client connection preface.
    InvalidPreface,
    /// No frame moved in either direction within the configured timeout.
    Timeout,
    /// The inner service stopped or closed its side of the frame channels without an error
    /// being available to report.
    InnerServiceClosed,
    /// The inner service returned an error.
    ServiceError(E),
    /// Any other failure inside the layer, described by a static message.
    OtherInternalError(&'static str),
}

impl<E: Debug> std::fmt::Display for ProtoHttp2Error<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidPreface => {
                f.write_str("connection did not start with the HTTP/2 client preface")
            }
            Self::Timeout => f.write_str("timed out waiting for HTTP/2 frames"),
            Self::InnerServiceClosed => f.write_str("inner HTTP/2 service closed"),
            // The service error is only known to be `Debug`, so it is rendered with `{:?}`.
            Self::ServiceError(e) => write!(f, "inner HTTP/2 service failed: {e:?}"),
            Self::OtherInternalError(msg) => write!(f, "internal HTTP/2 layer error: {msg}"),
        }
    }
}

/// Lets the layer's error participate in standard error handling (`Box<dyn Error>`, `?`).
impl<E: Debug> std::error::Error for ProtoHttp2Error<E> {}
20
21pub struct ProtoHttp2Layer<Svc>
25where
26 Svc: Service<(Receiver<Http2Frame>, Sender<Http2Frame>), Response = ()> + Send + Clone,
27{
28 config: ProtoHttp2Config,
29 inner: Svc,
31}
32
33impl<Svc> ProtoHttp2Layer<Svc>
34where
35 Svc: Service<(Receiver<Http2Frame>, Sender<Http2Frame>), Response = ()> + Send + Clone,
36{
37 pub fn new(config: ProtoHttp2Config, inner: Svc) -> Self {
39 ProtoHttp2Layer { config, inner }
40 }
41}
42
43impl<Reader, Writer, Svc, SvcError, SvcFut> Service<(Reader, Writer)> for ProtoHttp2Layer<Svc>
44where
45 Reader: AsyncReadExt + Send + Unpin + 'static,
46 Writer: AsyncWriteExt + Send + Unpin + 'static,
47 Svc: Service<(Receiver<Http2Frame>, Sender<Http2Frame>), Response = (), Error = SvcError, Future = SvcFut> + Send + Clone + 'static,
48 SvcFut: Future<Output = Result<(Receiver<Http2Frame>, Sender<Http2Frame>), SvcError>> + Send + 'static,
49 SvcError: Debug + Send + 'static,
50{
51 type Response = ();
53 type Error = ProtoHttp2Error<SvcError>;
55 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
57
58 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
59 self.inner.poll_ready(cx).map_err(ProtoHttp2Error::ServiceError)
60 }
61
62 fn call(&mut self, (mut reader, mut writer): (Reader, Writer)) -> Self::Future {
64 let mut service = self.inner.clone();
65 let config = self.config.clone();
66 Box::pin(async move {
67 let async_read = AsyncReadToBuf::new_1024(ZeroReadBehaviour::TickAndYield);
68 let mut preface = async_read.read_with_timeout(&mut reader, config.timeout, Some(28)).await;
69 if preface != b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" {
70 return Err(ProtoHttp2Error::InvalidPreface);
71 }
72 let (layer_frame_sx, svc_frame_rx) = tokio::sync::mpsc::channel::<Http2Frame>(1);
73 let (svc_frame_sx, mut layer_frame_rx) = tokio::sync::mpsc::channel::<Http2Frame>(1);
74 let internal_task = tokio::spawn(service.call((svc_frame_rx, svc_frame_sx)));
75 loop {
77 tokio::select! {
78 _ = tokio::time::sleep(config.timeout) => {
79 return Err(ProtoHttp2Error::Timeout);
80 }
81 frame = layer_frame_rx.recv() => {
82 match frame {
83 Some(frame) => {
84 frame.write_onto(&mut writer).await.unwrap();
85 }
86 None => {
87 return Err(ProtoHttp2Error::InnerServiceClosed);
88 }
89 }
90 }
91 frame = read_next_frame(&mut reader, config.timeout) => {
92 match frame {
93 Ok(frame) => {
94 layer_frame_sx.send(frame).await.unwrap();
95 }
96 Err(e) => {
97 return Err(ProtoHttp2Error::OtherInternalError(e));
98 }
99 }
100 }
101 }
102 if internal_task.is_finished() {
103 return match internal_task.await {
104 Ok(Err(e)) => Err(ProtoHttp2Error::ServiceError(e)),
105 Ok(Ok(_)) | Err(_) => Err(ProtoHttp2Error::InnerServiceClosed),
106 };
107 }
108 }
109 })
110 }
111}