1mod file;
50mod stream;
51
52use std::io;
53use std::ops::Bound;
54use std::pin::Pin;
55use std::task::{Context, Poll};
56
57use axum::http::StatusCode;
58use axum::response::{IntoResponse, Response};
59use axum_extra::TypedHeader;
60use axum_extra::headers::{Range, ContentRange, ContentLength, AcceptRanges};
61use tokio::io::{AsyncRead, AsyncSeek};
62
63pub use file::KnownSize;
64pub use stream::RangedStream;
65
/// A body that can begin a seek to an absolute byte offset and then be polled
/// until that seek has finished.
///
/// This mirrors the two-phase protocol of tokio's [`AsyncSeek`] (begin with
/// `start_seek`, then poll `poll_complete`), restricted to seeking from the
/// start of the stream. Every `AsyncSeek` type gets this trait through the
/// blanket impl in this module.
pub trait AsyncSeekStart {
    /// Begins seeking to `position` bytes from the start of the stream.
    ///
    /// Presumably, as with `AsyncSeek`, the seek is not guaranteed to be done
    /// until [`poll_complete`](Self::poll_complete) returns `Ready(Ok(()))`.
    fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()>;

    /// Polls for completion of a seek begun with [`start_seek`](Self::start_seek).
    ///
    /// Unlike `AsyncSeek::poll_complete`, the resulting offset is not returned.
    fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>;
}
74
75impl<T: AsyncSeek> AsyncSeekStart for T {
76 fn start_seek(self: Pin<&mut Self>, position: u64) -> io::Result<()> {
77 AsyncSeek::start_seek(self, io::SeekFrom::Start(position))
78 }
79
80 fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81 AsyncSeek::poll_complete(self, cx).map_ok(|_| ())
82 }
83}
84
/// A readable, seekable body that can report its total length in bytes.
///
/// This is the bound [`Ranged`] requires on the content it serves.
pub trait RangeBody: AsyncRead + AsyncSeekStart {
    /// Total size of the body in bytes.
    ///
    /// [`Ranged::try_respond`] resolves requested ranges against this value
    /// and reports it as the complete length in `Content-Range`.
    fn byte_size(&self) -> u64;
}
93
/// Pairs a body with the client's optional `Range` request header.
///
/// Turn it into a response with [`Ranged::try_respond`], or directly through
/// its `IntoResponse` impl. That impl renders either the (partial) body or a
/// `416 Range Not Satisfiable`.
pub struct Ranged<B: RangeBody + Send + 'static> {
    /// The parsed `Range` header; `None` means the entire body is served.
    range: Option<Range>,
    /// The content to serve.
    body: B,
}
99
100impl<B: RangeBody + Send + 'static> Ranged<B> {
101 pub fn new(range: Option<Range>, body: B) -> Self {
104 Ranged { range, body }
105 }
106
107 pub fn try_respond(self) -> Result<RangedResponse<B>, RangeNotSatisfiable> {
111 let total_bytes = self.body.byte_size();
112
113 let range = self.range.and_then(|range| {
117 range.satisfiable_ranges(total_bytes).nth(0)
118 });
119
120 let seek_start = match range {
122 Some((Bound::Included(seek_start), _)) => seek_start,
123 _ => 0,
124 };
125
126 let seek_end_excl = match range {
127 Some((_, Bound::Included(end))) => {
129 if end >= total_bytes {
130 total_bytes
131 } else {
132 end + 1
133 }
134 },
135 _ => total_bytes,
136 };
137
138 let seek_start_beyond_seek_end = seek_start > seek_end_excl;
140 let zero_length_range = seek_start == seek_end_excl;
142
143 if seek_start_beyond_seek_end || zero_length_range {
144 let content_range = ContentRange::unsatisfied_bytes(total_bytes);
145 return Err(RangeNotSatisfiable(content_range));
146 }
147
148 let content_range = range.map(|_| {
150 ContentRange::bytes(seek_start..seek_end_excl, total_bytes)
151 .expect("ContentRange::bytes cannot panic in this usage")
152 });
153
154 let content_length = ContentLength(seek_end_excl - seek_start);
155
156 let stream = RangedStream::new(self.body, seek_start, content_length.0);
157
158 Ok(RangedResponse {
159 content_range,
160 content_length,
161 stream,
162 })
163 }
164}
165
166impl<B: RangeBody + Send + 'static> IntoResponse for Ranged<B> {
167 fn into_response(self) -> Response {
168 self.try_respond().into_response()
169 }
170}
171
/// Error returned by [`Ranged::try_respond`] when the requested range cannot
/// be served.
///
/// Wraps the `Content-Range` header value to send back. `try_respond` builds
/// it with `ContentRange::unsatisfied_bytes` for the body's size. As a
/// response it becomes a bodyless `416 Range Not Satisfiable`.
#[derive(Debug, Clone)]
pub struct RangeNotSatisfiable(pub ContentRange);
175
176impl IntoResponse for RangeNotSatisfiable {
177 fn into_response(self) -> Response {
178 let status = StatusCode::RANGE_NOT_SATISFIABLE;
179 let header = TypedHeader(self.0);
180 (status, header, ()).into_response()
181 }
182}
183
/// The successful result of [`Ranged::try_respond`]: response headers plus
/// the body stream.
///
/// As a response it is `206 Partial Content` when `content_range` is `Some`
/// and `200 OK` otherwise. Both advertise `Accept-Ranges: bytes`.
pub struct RangedResponse<B> {
    /// `Content-Range` for a partial response; `None` when the whole body is served.
    pub content_range: Option<ContentRange>,
    /// Length in bytes of the selected range; also passed to `RangedStream::new`.
    pub content_length: ContentLength,
    /// Stream yielding the selected byte range of the body.
    pub stream: RangedStream<B>,
}
190
191impl<B: RangeBody + Send + 'static> IntoResponse for RangedResponse<B> {
192 fn into_response(self) -> Response {
193 let content_range = self.content_range.map(TypedHeader);
194 let content_length = TypedHeader(self.content_length);
195 let accept_ranges = TypedHeader(AcceptRanges::bytes());
196 let stream = self.stream;
197
198 let status = match content_range {
199 Some(_) => StatusCode::PARTIAL_CONTENT,
200 None => StatusCode::OK,
201 };
202
203 (status, content_range, content_length, accept_ranges, stream).into_response()
204 }
205}
206
#[cfg(test)]
mod tests {
    // These tests exercise `Ranged::try_respond` against `test/fixture.txt`,
    // opened relative to the test process's working directory. As
    // `test_full_response` asserts, the fixture holds exactly the 54 bytes
    // "Hello world this is a file to test range requests on!\n". Every byte
    // offset below indexes into that string.

    use std::io;

    use axum::http::HeaderValue;
    use axum_extra::headers::{ContentRange, Header, Range};
    use bytes::Bytes;
    use futures::{pin_mut, Stream, StreamExt};
    use tokio::fs::File;

    use crate::Ranged;
    use crate::KnownSize;

    /// Drains `stream` and returns the concatenated text. Panics on any I/O
    /// error or any chunk that is not valid UTF-8.
    async fn collect_stream(stream: impl Stream<Item = io::Result<Bytes>>) -> String {
        let mut string = String::new();
        pin_mut!(stream);
        while let Some(chunk) = stream.next().await.transpose().unwrap() {
            string += std::str::from_utf8(&chunk).unwrap();
        }
        string
    }

    /// Parses a raw `Range` header value such as `"bytes=0-29"`. Panics if the
    /// value is rejected by `Range::decode`.
    fn range(header: &str) -> Option<Range> {
        let val = HeaderValue::from_str(header).unwrap();
        Some(Range::decode(&mut [val].iter()).unwrap())
    }

    /// Opens the fixture file and wraps it as a `KnownSize` body.
    async fn body() -> KnownSize<File> {
        let file = File::open("test/fixture.txt").await.unwrap();
        KnownSize::file(file).await.unwrap()
    }

    // No `Range` header: the whole file is served without a `Content-Range`.
    #[tokio::test]
    async fn test_full_response() {
        let ranged = Ranged::new(None, body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(54, response.content_length.0);
        assert!(response.content_range.is_none());
        assert_eq!("Hello world this is a file to test range requests on!\n",
            &collect_stream(response.stream).await);
    }

    // Closed range at the start of the file: inclusive end 29 means 30 bytes.
    #[tokio::test]
    async fn test_partial_response_1() {
        let ranged = Ranged::new(range("bytes=0-29"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(30, response.content_length.0);

        let expected_content_range = ContentRange::bytes(0..30, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!("Hello world this is a file to ",
            &collect_stream(response.stream).await);
    }

    // Closed range ending exactly on the last byte (offset 53).
    #[tokio::test]
    async fn test_partial_response_2() {
        let ranged = Ranged::new(range("bytes=30-53"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(24, response.content_length.0);

        let expected_content_range = ContentRange::bytes(30..54, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!("test range requests on!\n",
            &collect_stream(response.stream).await);
    }

    // Suffix range: the last 20 bytes, i.e. offsets 34..54.
    #[tokio::test]
    async fn test_unbounded_start_response() {
        let ranged = Ranged::new(range("bytes=-20"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(20, response.content_length.0);

        let expected_content_range = ContentRange::bytes(34..54, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!(" range requests on!\n",
            &collect_stream(response.stream).await);
    }

    // Open-ended range: from offset 40 through the end of the file.
    #[tokio::test]
    async fn test_unbounded_end_response() {
        let ranged = Ranged::new(range("bytes=40-"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(14, response.content_length.0);

        let expected_content_range = ContentRange::bytes(40..54, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!(" requests on!\n",
            &collect_stream(response.stream).await);
    }

    // Start == end (inclusive) selects exactly one byte.
    #[tokio::test]
    async fn test_one_byte_response() {
        let ranged = Ranged::new(range("bytes=30-30"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        assert_eq!(1, response.content_length.0);

        let expected_content_range = ContentRange::bytes(30..31, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!("t",
            &collect_stream(response.stream).await);
    }

    // End before start is rejected with an unsatisfied Content-Range for
    // the 54-byte length.
    #[tokio::test]
    async fn test_invalid_range() {
        let ranged = Ranged::new(range("bytes=30-29"), body().await);

        let err = ranged.try_respond().err().expect("try_respond should return Err");

        let expected_content_range = ContentRange::unsatisfied_bytes(54);
        assert_eq!(expected_content_range, err.0)
    }

    // An end offset past EOF is clamped to the last byte of the file.
    #[tokio::test]
    async fn test_range_end_exceed_length() {
        let ranged = Ranged::new(range("bytes=30-99"), body().await);

        let response = ranged.try_respond().expect("try_respond should return Ok");

        let expected_content_range = ContentRange::bytes(30..54, 54).unwrap();
        assert_eq!(Some(expected_content_range), response.content_range);

        assert_eq!("test range requests on!\n",
            &collect_stream(response.stream).await);
    }

    // A start offset past EOF cannot be satisfied.
    #[tokio::test]
    async fn test_range_start_exceed_length() {
        let ranged = Ranged::new(range("bytes=99-"), body().await);

        let err = ranged.try_respond().err().expect("try_respond should return Err");

        let expected_content_range = ContentRange::unsatisfied_bytes(54);
        assert_eq!(expected_content_range, err.0)
    }
}