axum_range/
lib.rs

1//! # axum-range
2//!
3//! HTTP range responses for [`axum`][1].
4//!
5//! Fully generic, supports any body implementing the [`RangeBody`] trait.
6//!
//! Any type implementing both [`AsyncRead`] and [`AsyncSeekStart`] can be
//! used with the [`KnownSize`] adapter struct. There is also special-cased
//! support for [`tokio::fs::File`]; see the [`KnownSize::file`] method.
10//!
11//! [`AsyncSeekStart`] is a trait defined by this crate which only allows
12//! seeking from the start of a file. It is automatically implemented for any
13//! type implementing [`AsyncSeek`].
14//!
15//! ```
16//! use axum::Router;
17//! use axum::routing::get;
18//! use axum_extra::TypedHeader;
19//! use axum_extra::headers::Range;
20//!
21//! use tokio::fs::File;
22//!
23//! use axum_range::Ranged;
24//! use axum_range::KnownSize;
25//!
26//! async fn file(range: Option<TypedHeader<Range>>) -> Ranged<KnownSize<File>> {
27//!     let file = File::open("The Sims 1 - The Complete Collection.rar").await.unwrap();
28//!     let body = KnownSize::file(file).await.unwrap();
29//!     let range = range.map(|TypedHeader(range)| range);
30//!     Ranged::new(range, body)
31//! }
32//!
33//! #[tokio::main]
34//! async fn main() {
35//!     // build our application with a single route
36//!     let _app = Router::<()>::new().route("/", get(file));
37//!
//!     // run it on port 3000 on all interfaces
//!     #[cfg(feature = "run_server_in_example")]
//!     {
//!         let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
//!         axum::serve(listener, _app).await.unwrap();
//!     }
44//! }
45//! ```
46//!
47//! [1]: https://docs.rs/axum
48
49mod file;
50mod stream;
51
52use std::io;
53use std::ops::Bound;
54use std::pin::Pin;
55use std::task::{Context, Poll};
56
57use axum::http::StatusCode;
58use axum::response::{IntoResponse, Response};
59use axum_extra::TypedHeader;
60use axum_extra::headers::{Range, ContentRange, ContentLength, AcceptRanges};
61use tokio::io::{AsyncRead, AsyncSeek};
62
63pub use file::KnownSize;
64pub use stream::RangedStream;
65
/// Seeking restricted to absolute offsets counted from the start of the
/// stream: a narrowed form of [`AsyncSeek`].
pub trait AsyncSeekStart {
    /// Begins a seek to byte `position`, measured from the start.
    ///
    /// Mirrors [`AsyncSeek::start_seek`], with the target always expressed
    /// as `SeekFrom::Start(position)`.
    fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()>;

    /// Drives a seek begun with `start_seek` to completion.
    ///
    /// Mirrors [`AsyncSeek::poll_complete`], but yields `()` rather than the
    /// resulting stream position.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
74
75impl<T: AsyncSeek> AsyncSeekStart for T {
76    fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()> {
77        AsyncSeek::start_seek(self, io::SeekFrom::Start(position))
78    }
79
80    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81        AsyncSeek::poll_complete(self, cx).map_ok(|_| ())
82    }
83}
84
85/// An [`AsyncRead`] and [`AsyncSeekStart`] with a fixed known byte size.
86pub trait RangeBody: AsyncRead + AsyncSeekStart {
87    /// The total size of the underlying file.
88    ///
89    /// This should not change for the lifetime of the object once queried.
90    /// Behaviour is not guaranteed if it does change.
91    fn byte_size(&self) -> u64;
92}
93
94/// The main responder type. Implements [`IntoResponse`].
95pub struct Ranged<B: RangeBody + Send + 'static> {
96    range: Option<Range>,
97    body: B,
98}
99
100impl<B: RangeBody + Send + 'static> Ranged<B> {
101    /// Construct a ranged response over any type implementing [`RangeBody`]
102    /// and an optional [`Range`] header.
103    pub fn new(range: Option<Range>, body: B) -> Self {
104        Ranged { range, body }
105    }
106
107    /// Responds to the request, returning headers and body as
108    /// [`RangedResponse`]. Returns [`RangeNotSatisfiable`] error if requested
109    /// range in header was not satisfiable.
110    pub fn try_respond(self) -> Result<RangedResponse<B>, RangeNotSatisfiable> {
111        let total_bytes = self.body.byte_size();
112
113        // we don't support multiple byte ranges, only none or one
114        // fortunately, only responding with one of the requested ranges and
115        // no more seems to be compliant with the HTTP spec.
116        let range = self.range.and_then(|range| {
117            range.satisfiable_ranges(total_bytes).nth(0)
118        });
119
120        // pull seek positions out of range header
121        let seek_start = match range {
122            Some((Bound::Included(seek_start), _)) => seek_start,
123            _ => 0,
124        };
125
126        let seek_end_excl = match range {
127            // HTTP byte ranges are inclusive, so we translate to exclusive by adding 1:
128            Some((_, Bound::Included(end))) => {
129                if end >= total_bytes {
130                    total_bytes
131                } else {
132                    end + 1
133                }
134            },
135            _ => total_bytes,
136        };
137
138        // check seek positions and return with 416 Range Not Satisfiable if invalid
139        let seek_start_beyond_seek_end = seek_start > seek_end_excl;
140        // we could use >= above but I think this reads more clearly:
141        let zero_length_range = seek_start == seek_end_excl;
142
143        if seek_start_beyond_seek_end || zero_length_range {
144            let content_range = ContentRange::unsatisfied_bytes(total_bytes);
145            return Err(RangeNotSatisfiable(content_range));
146        }
147
148        // if we're good, build the response
149        let content_range = range.map(|_| {
150            ContentRange::bytes(seek_start..seek_end_excl, total_bytes)
151                .expect("ContentRange::bytes cannot panic in this usage")
152        });
153
154        let content_length = ContentLength(seek_end_excl - seek_start);
155
156        let stream = RangedStream::new(self.body, seek_start, content_length.0);
157
158        Ok(RangedResponse {
159            content_range,
160            content_length,
161            stream,
162        })
163    }
164}
165
166impl<B: RangeBody + Send + 'static> IntoResponse for Ranged<B> {
167    fn into_response(self) -> Response {
168        self.try_respond().into_response()
169    }
170}
171
172/// Error type indicating that the requested range was not satisfiable. Implements [`IntoResponse`].
173#[derive(Debug, Clone)]
174pub struct RangeNotSatisfiable(pub ContentRange);
175
176impl IntoResponse for RangeNotSatisfiable {
177    fn into_response(self) -> Response {
178        let status = StatusCode::RANGE_NOT_SATISFIABLE;
179        let header = TypedHeader(self.0);
180        (status, header, ()).into_response()
181    }
182}
183
184/// Data type containing computed headers and body for a range response. Implements [`IntoResponse`].
185pub struct RangedResponse<B> {
186    pub content_range: Option<ContentRange>,
187    pub content_length: ContentLength,
188    pub stream: RangedStream<B>,
189}
190
191impl<B: RangeBody + Send + 'static> IntoResponse for RangedResponse<B> {
192    fn into_response(self) -> Response {
193        let content_range = self.content_range.map(TypedHeader);
194        let content_length = TypedHeader(self.content_length);
195        let accept_ranges = TypedHeader(AcceptRanges::bytes());
196        let stream = self.stream;
197
198        let status = match content_range {
199            Some(_) => StatusCode::PARTIAL_CONTENT,
200            None => StatusCode::OK,
201        };
202
203        (status, content_range, content_length, accept_ranges, stream).into_response()
204    }
205}
206
#[cfg(test)]
mod tests {
    use std::io;

    use axum::http::HeaderValue;
    use axum_extra::headers::{ContentRange, Header, Range};
    use bytes::Bytes;
    use futures::{pin_mut, Stream, StreamExt};
    use tokio::fs::File;

    use crate::KnownSize;
    use crate::Ranged;

    /// Drains `stream`, decoding every chunk as UTF-8 and appending it to the
    /// returned string. Panics on an I/O error or a non-UTF-8 chunk.
    async fn collect_stream(stream: impl Stream<Item = io::Result<Bytes>>) -> String {
        pin_mut!(stream);
        let mut text = String::new();
        while let Some(item) = stream.next().await {
            let chunk = item.unwrap();
            text.push_str(std::str::from_utf8(&chunk).unwrap());
        }
        text
    }

    /// Parses `header` as a `Range` header value, panicking if it is invalid.
    fn range(header: &str) -> Option<Range> {
        let value = HeaderValue::from_str(header).unwrap();
        let mut values = std::iter::once(&value);
        Some(Range::decode(&mut values).unwrap())
    }

    /// Opens the shared fixture file as a [`KnownSize`] body.
    async fn body() -> KnownSize<File> {
        let file = File::open("test/fixture.txt").await.unwrap();
        KnownSize::file(file).await.unwrap()
    }

    #[tokio::test]
    async fn test_full_response() {
        let response = Ranged::new(None, body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(54, response.content_length.0);
        assert!(response.content_range.is_none());

        let text = collect_stream(response.stream).await;
        assert_eq!("Hello world this is a file to test range requests on!\n", &text);
    }

    #[tokio::test]
    async fn test_partial_response_1() {
        let response = Ranged::new(range("bytes=0-29"), body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(30, response.content_length.0);
        assert_eq!(Some(ContentRange::bytes(0..30, 54).unwrap()), response.content_range);

        let text = collect_stream(response.stream).await;
        assert_eq!("Hello world this is a file to ", &text);
    }

    #[tokio::test]
    async fn test_partial_response_2() {
        let response = Ranged::new(range("bytes=30-53"), body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(24, response.content_length.0);
        assert_eq!(Some(ContentRange::bytes(30..54, 54).unwrap()), response.content_range);

        let text = collect_stream(response.stream).await;
        assert_eq!("test range requests on!\n", &text);
    }

    #[tokio::test]
    async fn test_unbounded_start_response() {
        // An HTTP range without a start is a suffix: the last N bytes.
        let response = Ranged::new(range("bytes=-20"), body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(20, response.content_length.0);
        assert_eq!(Some(ContentRange::bytes(34..54, 54).unwrap()), response.content_range);

        let text = collect_stream(response.stream).await;
        assert_eq!(" range requests on!\n", &text);
    }

    #[tokio::test]
    async fn test_unbounded_end_response() {
        let response = Ranged::new(range("bytes=40-"), body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(14, response.content_length.0);
        assert_eq!(Some(ContentRange::bytes(40..54, 54).unwrap()), response.content_range);

        let text = collect_stream(response.stream).await;
        assert_eq!(" requests on!\n", &text);
    }

    #[tokio::test]
    async fn test_one_byte_response() {
        let response = Ranged::new(range("bytes=30-30"), body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(1, response.content_length.0);
        assert_eq!(Some(ContentRange::bytes(30..31, 54).unwrap()), response.content_range);

        let text = collect_stream(response.stream).await;
        assert_eq!("t", &text);
    }

    #[tokio::test]
    async fn test_invalid_range() {
        let err = Ranged::new(range("bytes=30-29"), body().await)
            .try_respond()
            .err()
            .expect("try_respond should return Err");

        assert_eq!(ContentRange::unsatisfied_bytes(54), err.0);
    }

    #[tokio::test]
    async fn test_range_end_exceed_length() {
        let response = Ranged::new(range("bytes=30-99"), body().await)
            .try_respond()
            .expect("try_respond should return Ok");

        assert_eq!(Some(ContentRange::bytes(30..54, 54).unwrap()), response.content_range);

        let text = collect_stream(response.stream).await;
        assert_eq!("test range requests on!\n", &text);
    }

    #[tokio::test]
    async fn test_range_start_exceed_length() {
        let err = Ranged::new(range("bytes=99-"), body().await)
            .try_respond()
            .err()
            .expect("try_respond should return Err");

        assert_eq!(ContentRange::unsatisfied_bytes(54), err.0);
    }
}