1use super::Entity;
10use crate::etag;
11use crate::range;
12use bytes::Buf;
13use futures_core::Stream;
14use futures_util::stream::{self, StreamExt};
15use http::header::{self, HeaderMap, HeaderValue};
16use http::{self, Method, Request, Response, StatusCode};
17use http_body::Body;
18use httpdate::{fmt_http_date, parse_http_date};
19use pin_project::pin_project;
20use smallvec::SmallVec;
21use std::future::Future;
22use std::io::Write;
23use std::ops::Range;
24use std::pin::Pin;
25use std::time::SystemTime;
26
27const MAX_DECIMAL_U64_BYTES: usize = 20; fn parse_modified_hdrs(
30 etag: &Option<HeaderValue>,
31 req_hdrs: &HeaderMap,
32 last_modified: Option<SystemTime>,
33) -> Result<(bool, bool), &'static str> {
34 let precondition_failed = if !etag::any_match(etag, req_hdrs)? {
35 true
36 } else if let (Some(ref m), Some(since)) =
37 (last_modified, req_hdrs.get(header::IF_UNMODIFIED_SINCE))
38 {
39 const ERR: &str = "Unparseable If-Unmodified-Since";
40 *m > parse_http_date(since.to_str().map_err(|_| ERR)?).map_err(|_| ERR)?
41 } else {
42 false
43 };
44
45 let not_modified = match etag::none_match(etag, req_hdrs) {
46 Some(true) => false,
53
54 Some(false) => true,
55
56 None => {
57 if let (Some(ref m), Some(since)) =
58 (last_modified, req_hdrs.get(header::IF_MODIFIED_SINCE))
59 {
60 const ERR: &str = "Unparseable If-Modified-Since";
61 *m <= parse_http_date(since.to_str().map_err(|_| ERR)?).map_err(|_| ERR)?
62 } else {
63 false
64 }
65 }
66 };
67 Ok((precondition_failed, not_modified))
68}
69
70fn static_body<D, E>(s: &'static str) -> Box<dyn Stream<Item = Result<D, E>> + Send>
71where
72 D: 'static + Send + Buf + From<Vec<u8>> + From<&'static [u8]>,
73 E: 'static + Send,
74{
75 Box::new(stream::once(futures_util::future::ok(s.as_bytes().into())))
76}
77
78fn empty_body<D, E>() -> Box<dyn Stream<Item = Result<D, E>> + Send>
79where
80 D: 'static + Send + Buf + From<Vec<u8>> + From<&'static [u8]>,
81 E: 'static + Send,
82{
83 Box::new(stream::empty())
84}
85
86pub fn serve<
91 Ent: Entity,
92 B: Body + From<Box<dyn Stream<Item = Result<Ent::Data, Ent::Error>> + Send>>,
93 BI,
94>(
95 entity: Ent,
96 req: &Request<BI>,
97) -> Response<B> {
98 match serve_inner(&entity, req) {
102 ServeInner::Simple(res) => res,
103 ServeInner::Multipart {
104 res,
105 mut part_headers,
106 ranges,
107 } => {
108 let bodies = stream::unfold(0, move |state| {
109 next_multipart_body_chunk(state, &entity, &ranges[..], &mut part_headers[..])
110 });
111 let body = bodies.flatten();
112 let body: Box<dyn Stream<Item = Result<Ent::Data, Ent::Error>> + Send> = Box::new(body);
113 res.body(body.into()).unwrap()
114 }
115 }
116}
117
118enum ServeInner<B> {
120 Simple(Response<B>),
121 Multipart {
122 res: http::response::Builder,
123 part_headers: Vec<Vec<u8>>,
124 ranges: SmallVec<[Range<u64>; 1]>,
125 },
126}
127
128fn serve_inner<
130 D: 'static + Send + Sync + Buf + From<Vec<u8>> + From<&'static [u8]>,
131 E: 'static + Send + Sync,
132 B: Body + From<Box<dyn Stream<Item = Result<D, E>> + Send>>,
133 BI,
134>(
135 ent: &dyn Entity<Error = E, Data = D>,
136 req: &Request<BI>,
137) -> ServeInner<B> {
138 if *req.method() != Method::GET && *req.method() != Method::HEAD {
139 return ServeInner::Simple(
140 Response::builder()
141 .status(StatusCode::METHOD_NOT_ALLOWED)
142 .header(header::ALLOW, HeaderValue::from_static("get, head"))
143 .body(static_body::<D, E>("This resource only supports GET and HEAD.").into())
144 .unwrap(),
145 );
146 }
147
148 let last_modified = ent.last_modified();
149 let etag = ent.etag();
150
151 let (precondition_failed, not_modified) =
152 match parse_modified_hdrs(&etag, req.headers(), last_modified) {
153 Err(s) => {
154 return ServeInner::Simple(
155 Response::builder()
156 .status(StatusCode::BAD_REQUEST)
157 .body(static_body::<D, E>(s).into())
158 .unwrap(),
159 )
160 }
161 Ok(p) => p,
162 };
163
164 let mut range_hdr = req.headers().get(header::RANGE);
168 let include_entity_headers_on_range = match req.headers().get(header::IF_RANGE) {
169 Some(if_range) => {
170 let if_range = if_range.as_bytes();
171 if if_range.starts_with(b"W/\"") || if_range.starts_with(b"\"") {
172 if let Some(ref some_etag) = etag {
174 if etag::strong_eq(if_range, some_etag.as_bytes()) {
175 false
176 } else {
177 range_hdr = None;
178 true
179 }
180 } else {
181 range_hdr = None;
182 true
183 }
184 } else {
185 range_hdr = None;
190 true
191 }
192 }
193 None => true,
194 };
195
196 let mut res =
197 Response::builder().header(header::ACCEPT_RANGES, HeaderValue::from_static("bytes"));
198 if let Some(m) = last_modified {
199 let d = SystemTime::now();
203 res = res.header(header::DATE, fmt_http_date(d));
204 let clamped_m = std::cmp::min(m, d);
205 res = res.header(header::LAST_MODIFIED, fmt_http_date(clamped_m));
206 }
207 if let Some(e) = etag {
208 res = res.header(http::header::ETAG, e);
209 }
210
211 if precondition_failed {
212 res = res.status(StatusCode::PRECONDITION_FAILED);
213 return ServeInner::Simple(
214 res.body(static_body::<D, E>("Precondition failed").into())
215 .unwrap(),
216 );
217 }
218
219 if not_modified {
220 res = res.status(StatusCode::NOT_MODIFIED);
221 return ServeInner::Simple(res.body(empty_body::<D, E>().into()).unwrap());
222 }
223
224 let len = ent.len();
225 let (range, include_entity_headers) = match range::parse(range_hdr, len) {
226 range::ResolvedRanges::None => (0..len, true),
227 range::ResolvedRanges::Satisfiable(ranges) => {
228 if ranges.len() == 1 {
229 res = res.header(
230 header::CONTENT_RANGE,
231 unsafe_fmt_ascii_val!(
232 MAX_DECIMAL_U64_BYTES * 3 + "bytes -/".len(),
233 "bytes {}-{}/{}",
234 ranges[0].start,
235 ranges[0].end - 1,
236 len
237 ),
238 );
239 res = res.status(StatusCode::PARTIAL_CONTENT);
240 (ranges[0].clone(), include_entity_headers_on_range)
241 } else {
242 let est_len: u64 = ranges.iter().map(|r| 80 + r.end - r.start).sum();
246 if est_len < len {
247 let (res, part_headers) = prepare_multipart(
248 ent,
249 res,
250 &ranges[..],
251 len,
252 include_entity_headers_on_range,
253 );
254 if *req.method() == Method::HEAD {
255 return ServeInner::Simple(res.body(empty_body::<D, E>().into()).unwrap());
256 }
257 return ServeInner::Multipart {
258 res,
259 part_headers,
260 ranges,
261 };
262 }
263
264 (0..len, true)
265 }
266 }
267 range::ResolvedRanges::NotSatisfiable => {
268 res = res.header(
269 http::header::CONTENT_RANGE,
270 unsafe_fmt_ascii_val!(MAX_DECIMAL_U64_BYTES + "bytes */".len(), "bytes */{}", len),
271 );
272 res = res.status(StatusCode::RANGE_NOT_SATISFIABLE);
273 return ServeInner::Simple(res.body(empty_body::<D, E>().into()).unwrap());
274 }
275 };
276 res = res.header(
277 header::CONTENT_LENGTH,
278 unsafe_fmt_ascii_val!(MAX_DECIMAL_U64_BYTES, "{}", range.end - range.start),
279 );
280 let body = match *req.method() {
281 Method::HEAD => empty_body::<D, E>(),
282 _ => ent.get_range(range),
283 };
284 let mut res = res.body(body.into()).unwrap();
285 if include_entity_headers {
286 ent.add_headers(res.headers_mut());
287 }
288 ServeInner::Simple(res)
289}
290
291#[pin_project(project=InnerBodyProj)]
294enum InnerBody<D, E> {
295 Once(Option<D>),
296
297 B(Pin<Box<dyn Stream<Item = Result<D, E>> + Sync + Send>>),
299}
300
301impl<D, E> Stream for InnerBody<D, E> {
302 type Item = Result<D, E>;
303 fn poll_next(
304 self: Pin<&mut Self>,
305 ctx: &mut std::task::Context,
306 ) -> std::task::Poll<Option<Result<D, E>>> {
307 let mut this = self.project();
308 match this {
309 InnerBodyProj::Once(ref mut o) => std::task::Poll::Ready(o.take().map(Ok)),
310 InnerBodyProj::B(b) => b.as_mut().poll_next(ctx),
311 }
312 }
313}
314
315fn prepare_multipart<D, E>(
318 ent: &dyn Entity<Data = D, Error = E>,
319 mut res: http::response::Builder,
320 ranges: &[Range<u64>],
321 len: u64,
322 include_entity_headers: bool,
323) -> (http::response::Builder, Vec<Vec<u8>>)
324where
325 D: 'static + Send + Sync + Buf + From<Vec<u8>> + From<&'static [u8]>,
326 E: 'static + Send + Sync,
327{
328 let mut each_part_headers = Vec::new();
329 if include_entity_headers {
330 let mut h = http::header::HeaderMap::new();
331 ent.add_headers(&mut h);
332 each_part_headers.reserve(
333 h.iter()
334 .map(|(k, v)| k.as_str().len() + v.as_bytes().len() + 4)
335 .sum::<usize>()
336 + 2,
337 );
338 for (k, v) in &h {
339 each_part_headers.extend_from_slice(k.as_str().as_bytes());
340 each_part_headers.extend_from_slice(b": ");
341 each_part_headers.extend_from_slice(v.as_bytes());
342 each_part_headers.extend_from_slice(b"\r\n");
343 }
344 }
345 each_part_headers.extend_from_slice(b"\r\n");
346
347 let mut body_len = 0;
348 let mut part_headers: Vec<Vec<u8>> = Vec::with_capacity(2 * ranges.len() + 1);
349 for r in ranges {
350 let mut buf = Vec::with_capacity(64 + each_part_headers.len());
351 write!(
352 &mut buf,
353 "\r\n--B\r\nContent-Range: bytes {}-{}/{}\r\n",
354 r.start,
355 r.end - 1,
356 len
357 )
358 .unwrap();
359 buf.extend_from_slice(&each_part_headers);
360 body_len += buf.len() as u64 + r.end - r.start;
361 part_headers.push(buf);
362 }
363 body_len += PART_TRAILER.len() as u64;
364
365 res = res.header(
366 header::CONTENT_LENGTH,
367 unsafe_fmt_ascii_val!(MAX_DECIMAL_U64_BYTES, "{}", body_len),
368 );
369 res = res.header(
370 header::CONTENT_TYPE,
371 HeaderValue::from_static("multipart/byteranges; boundary=B"),
372 );
373 res = res.status(StatusCode::PARTIAL_CONTENT);
374
375 (res, part_headers)
376}
377
/// Closing delimiter of the `multipart/byteranges` body (boundary `B`), sent after the last part.
const PART_TRAILER: &[u8] = b"\r\n--B--\r\n";
380
381fn next_multipart_body_chunk<D, E>(
386 state: usize,
387 ent: &dyn Entity<Data = D, Error = E>,
388 ranges: &[Range<u64>],
389 part_headers: &mut [Vec<u8>],
390) -> impl Future<Output = Option<(InnerBody<D, E>, usize)>>
391where
392 D: 'static + Send + Sync + Buf + From<Vec<u8>> + From<&'static [u8]>,
393 E: 'static + Send + Sync,
394{
395 let i = state >> 1;
396 let odd = (state & 1) == 1;
397 let body = if i == ranges.len() && odd {
398 return futures_util::future::ready(None);
399 } else if i == ranges.len() {
400 InnerBody::Once(Some(PART_TRAILER.into()))
401 } else if odd {
402 InnerBody::B(Pin::from(ent.get_range(ranges[i].clone())))
403 } else {
404 let v = std::mem::take(&mut part_headers[i]);
405 InnerBody::Once(Some(v.into()))
406 };
407 futures_util::future::ready(Some((body, state + 1)))
408}