Skip to main content

tower_embed/
response.rs

1use std::{
2    borrow::Cow,
3    pin::Pin,
4    task::{Context, Poll, ready},
5};
6
7use bytes::Bytes;
8use futures_core::{Stream, future::BoxFuture};
9use http_body::{Body, Frame};
10use http_body_util::BodyExt;
11use tower_embed_core::{BoxError, Embed, Embedded, headers};
12
13use crate::core::headers::HeaderMapExt;
14
15type BoxBody = http_body_util::combinators::UnsyncBoxBody<Bytes, BoxError>;
16
17/// The body used in crate responses.
18#[derive(Debug)]
19pub struct ResponseBody(BoxBody);
20
21impl ResponseBody {
22    /// Create an empty response body.
23    pub fn empty() -> Self {
24        ResponseBody::new(http_body_util::Empty::new())
25    }
26
27    /// Create a new response body that contains a single chunk
28    pub fn full(data: Bytes) -> Self {
29        ResponseBody::new(http_body_util::Full::new(data))
30    }
31
32    /// Create a response body from a stream of bytes.
33    pub fn stream<S, E>(stream: S) -> Self
34    where
35        S: Stream<Item = Result<Frame<Bytes>, E>> + Send + 'static,
36        E: Into<BoxError>,
37    {
38        ResponseBody::new(http_body_util::StreamBody::new(stream))
39    }
40
41    fn new<B>(body: B) -> Self
42    where
43        B: Body<Data = Bytes> + Send + 'static,
44        B::Error: Into<BoxError>,
45    {
46        ResponseBody(body.map_err(|err| err.into()).boxed_unsync())
47    }
48}
49
50impl http_body::Body for ResponseBody {
51    type Data = Bytes;
52    type Error = BoxError;
53
54    fn poll_frame(
55        mut self: Pin<&mut Self>,
56        cx: &mut Context<'_>,
57    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
58        Pin::new(&mut self.0).poll_frame(cx)
59    }
60
61    fn is_end_stream(&self) -> bool {
62        self.0.is_end_stream()
63    }
64
65    fn size_hint(&self) -> http_body::SizeHint {
66        self.0.size_hint()
67    }
68}
69
70/// Response future of [`ServeEmbed`]
71///
72/// [`ServeEmbed`]: crate::ServeEmbed
73pub struct ResponseFuture {
74    inner: ResponseFutureInner,
75}
76
77enum ResponseFutureInner {
78    Ready(Option<http::Response<ResponseBody>>),
79    WaitingEmbedded {
80        fut: BoxFuture<'static, std::io::Result<Embedded>>,
81        if_none_match: Option<headers::IfNoneMatch>,
82        if_modified_since: Option<headers::IfModifiedSince>,
83    },
84}
85
86impl ResponseFuture {
87    pub(crate) fn new<E, B>(req: &http::Request<B>) -> Self
88    where
89        E: Embed,
90    {
91        if req.method() != http::Method::GET && req.method() != http::Method::HEAD {
92            return Self::method_not_allowed();
93        }
94
95        let path = get_file_path_from_uri(req.uri());
96        let embedded = E::get(path.as_ref());
97
98        let if_none_match = req.headers().typed_get::<headers::IfNoneMatch>();
99        let if_modified_since = req.headers().typed_get::<headers::IfModifiedSince>();
100
101        let inner = ResponseFutureInner::WaitingEmbedded {
102            fut: Box::pin(embedded),
103            if_none_match,
104            if_modified_since,
105        };
106        Self { inner }
107    }
108
109    pub(crate) fn method_not_allowed() -> Self {
110        let response = http::Response::builder()
111            .header(
112                http::header::ALLOW,
113                http::HeaderValue::from_static("GET, HEAD"),
114            )
115            .status(http::StatusCode::METHOD_NOT_ALLOWED)
116            .body(ResponseBody::empty())
117            .unwrap();
118
119        Self {
120            inner: ResponseFutureInner::Ready(Some(response)),
121        }
122    }
123}
124
125impl Future for ResponseFuture {
126    type Output = Result<http::Response<ResponseBody>, std::convert::Infallible>;
127
128    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
129        let inner = &mut self.get_mut().inner;
130
131        let response = match inner {
132            ResponseFutureInner::Ready(response) => response
133                .take()
134                .expect("ResponseFuture polled after completion"),
135            ResponseFutureInner::WaitingEmbedded {
136                fut,
137                if_none_match,
138                if_modified_since,
139            } => match ready!(Pin::new(fut).poll(cx)) {
140                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
141                    *inner = ResponseFutureInner::Ready(None);
142                    http::Response::builder()
143                        .status(http::StatusCode::NOT_FOUND)
144                        .body(ResponseBody::empty())
145                        .unwrap()
146                }
147                Err(_) => {
148                    *inner = ResponseFutureInner::Ready(None);
149                    http::Response::builder()
150                        .status(http::StatusCode::INTERNAL_SERVER_ERROR)
151                        .body(ResponseBody::empty())
152                        .unwrap()
153                }
154                Ok(embedded) => {
155                    // Make the request conditional if an If-None-Match header is present
156                    if let Some(if_none_match) = if_none_match
157                        && let Some(etag) = &embedded.metadata.etag
158                        && !if_none_match.condition_passes(etag)
159                    {
160                        return Poll::Ready(Ok(http::Response::builder()
161                            .status(http::StatusCode::NOT_MODIFIED)
162                            .body(ResponseBody::empty())
163                            .unwrap()));
164                    }
165
166                    // Make the request conditional if an If-Modified-Since header is present
167                    if let Some(if_modified_since) = if_modified_since
168                        && let Some(last_modified) = embedded.metadata.last_modified
169                        && !if_modified_since.condition_passes(&last_modified)
170                    {
171                        return Poll::Ready(Ok(http::Response::builder()
172                            .status(http::StatusCode::NOT_MODIFIED)
173                            .body(ResponseBody::empty())
174                            .unwrap()));
175                    }
176
177                    let Embedded { content, metadata } = embedded;
178                    let mut response = http::Response::builder()
179                        .status(http::StatusCode::OK)
180                        .body(ResponseBody::stream(content))
181                        .unwrap();
182
183                    response.headers_mut().typed_insert(metadata.content_type);
184                    if let Some(etag) = metadata.etag {
185                        response.headers_mut().typed_insert(etag);
186                    }
187                    if let Some(last_modified) = metadata.last_modified {
188                        response.headers_mut().typed_insert(last_modified);
189                    }
190
191                    response
192                }
193            },
194        };
195        Poll::Ready(Ok(response))
196    }
197}
198
199fn get_file_path_from_uri(uri: &http::Uri) -> Cow<'_, str> {
200    let path = uri.path();
201    if path.ends_with("/") {
202        Cow::Owned(format!("{}index.html", path.trim_start_matches('/')))
203    } else {
204        Cow::Borrowed(path.trim_start_matches('/'))
205    }
206}