// axum_extra/response/file_stream.rs

1use axum_core::{
2    body,
3    response::{IntoResponse, Response},
4    BoxError,
5};
6use bytes::Bytes;
7use futures_util::TryStream;
8use http::{header, StatusCode};
9use std::{io, path::Path};
10use tokio::{
11    fs::File,
12    io::{AsyncReadExt, AsyncSeekExt},
13};
14use tokio_util::io::ReaderStream;
15
/// A streaming file response.
///
/// Wraps a byte stream together with optional metadata: a file name and a content size.
/// Converting it with [`IntoResponse`] produces an `application/octet-stream` response that
/// additionally carries a `Content-Disposition` header when the file name is set and a
/// `Content-Length` header when the content size is set.
///
/// # Examples
///
/// ```
/// use axum::{
///     http::StatusCode,
///     response::{IntoResponse, Response},
///     routing::get,
///     Router,
/// };
/// use axum_extra::response::file_stream::FileStream;
/// use tokio::fs::File;
/// use tokio_util::io::ReaderStream;
///
/// async fn file_stream() -> Result<Response, (StatusCode, String)> {
///     let file = File::open("test.txt")
///         .await
///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?;
///
///     let stream = ReaderStream::new(file);
///     let file_stream_resp = FileStream::new(stream).file_name("test.txt");
///
///     Ok(file_stream_resp.into_response())
/// }
///
/// let app = Router::new().route("/file-stream", get(file_stream));
/// # let _: Router = app;
/// ```
#[derive(Debug)]
#[must_use]
pub struct FileStream<S> {
    /// The stream that produces the response body.
    pub stream: S,
    /// The file name advertised to the client, if any.
    pub file_name: Option<String>,
    /// The size of the content in bytes, if known.
    pub content_size: Option<u64>,
}
57
58impl<S> FileStream<S>
59where
60    S: TryStream + Send + 'static,
61    S::Ok: Into<Bytes>,
62    S::Error: Into<BoxError>,
63{
64    /// Create a new [`FileStream`]
65    pub fn new(stream: S) -> Self {
66        Self {
67            stream,
68            file_name: None,
69            content_size: None,
70        }
71    }
72
73    /// Set the file name of the [`FileStream`].
74    ///
75    /// This adds the attachment `Content-Disposition` header with the given `file_name`.
76    pub fn file_name(mut self, file_name: impl Into<String>) -> Self {
77        self.file_name = Some(file_name.into());
78        self
79    }
80
81    /// Set the size of the file.
82    pub fn content_size(mut self, len: u64) -> Self {
83        self.content_size = Some(len);
84        self
85    }
86
87    /// Return a range response.
88    ///
89    /// range: (start, end, total_size)
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use axum::{
95    ///     http::StatusCode,
96    ///     response::IntoResponse,
97    ///     routing::get,
98    ///     Router,
99    /// };
100    /// use axum_extra::response::file_stream::FileStream;
101    /// use tokio::fs::File;
102    /// use tokio::io::AsyncSeekExt;
103    /// use tokio_util::io::ReaderStream;
104    ///
105    /// async fn range_response() -> Result<impl IntoResponse, (StatusCode, String)> {
106    ///     let mut file = File::open("test.txt")
107    ///         .await
108    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?;
109    ///     let mut file_size = file
110    ///         .metadata()
111    ///         .await
112    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("Get file size: {e}")))?
113    ///         .len();
114    ///
115    ///     file.seek(std::io::SeekFrom::Start(10))
116    ///         .await
117    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File seek error: {e}")))?;
118    ///     let stream = ReaderStream::new(file);
119    ///
120    ///     Ok(FileStream::new(stream).into_range_response(10, file_size - 1, file_size))
121    /// }
122    ///
123    /// let app = Router::new().route("/file-stream", get(range_response));
124    /// # let _: Router = app;
125    /// ```
126    pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
127        let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
128        resp = resp.status(StatusCode::PARTIAL_CONTENT);
129
130        resp = resp.header(
131            header::CONTENT_RANGE,
132            format!("bytes {start}-{end}/{total_size}"),
133        );
134
135        resp.body(body::Body::from_stream(self.stream))
136            .unwrap_or_else(|e| {
137                (
138                    StatusCode::INTERNAL_SERVER_ERROR,
139                    format!("build FileStream response error: {e}"),
140                )
141                    .into_response()
142            })
143    }
144
145    /// Attempts to return RANGE requests directly from the file path.
146    ///
147    /// # Arguments
148    ///
149    /// * `file_path` - The path of the file to be streamed
150    /// * `start` - The start position of the range
151    /// * `end` - The end position of the range
152    ///
153    /// # Note
154    ///
155    /// * If `end` is 0, then it is used as `file_size - 1`
156    /// * If `start` > `file_size` or `start` > `end`, then `Range Not Satisfiable` is returned
157    ///
158    /// # Examples
159    ///
160    /// ```
161    /// use axum::{
162    ///     http::StatusCode,
163    ///     response::IntoResponse,
164    ///     Router,
165    ///     routing::get
166    /// };
167    /// use std::path::Path;
168    /// use axum_extra::response::file_stream::FileStream;
169    /// use tokio::fs::File;
170    /// use tokio_util::io::ReaderStream;
171    /// use tokio::io::AsyncSeekExt;
172    ///
173    /// async fn range_stream() -> impl IntoResponse {
174    ///     let range_start = 0;
175    ///     let range_end = 1024;
176    ///
177    ///     FileStream::<ReaderStream<File>>::try_range_response("CHANGELOG.md", range_start, range_end).await
178    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
179    /// }
180    ///
181    /// let app = Router::new().route("/file-stream", get(range_stream));
182    /// # let _: Router = app;
183    /// ```
184    pub async fn try_range_response(
185        file_path: impl AsRef<Path>,
186        start: u64,
187        mut end: u64,
188    ) -> io::Result<Response> {
189        let mut file = File::open(file_path).await?;
190
191        let metadata = file.metadata().await?;
192        let total_size = metadata.len();
193
194        if total_size == 0 {
195            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
196        }
197
198        if end == 0 {
199            end = total_size - 1;
200        }
201
202        if start > total_size {
203            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
204        }
205        if start > end {
206            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
207        }
208        if end >= total_size {
209            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
210        }
211
212        file.seek(std::io::SeekFrom::Start(start)).await?;
213
214        let stream = ReaderStream::new(file.take(end - start + 1));
215
216        Ok(FileStream::new(stream).into_range_response(start, end, total_size))
217    }
218}
219
220// Split because the general impl requires to specify `S` and this one does not.
221impl FileStream<ReaderStream<File>> {
222    /// Create a [`FileStream`] from a file path.
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// use axum::{
228    ///     http::StatusCode,
229    ///     response::IntoResponse,
230    ///     Router,
231    ///     routing::get
232    /// };
233    /// use axum_extra::response::file_stream::FileStream;
234    ///
235    /// async fn file_stream() -> impl IntoResponse {
236    ///     FileStream::from_path("test.txt")
237    ///         .await
238    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
239    /// }
240    ///
241    /// let app = Router::new().route("/file-stream", get(file_stream));
242    /// # let _: Router = app;
243    /// ```
244    pub async fn from_path(path: impl AsRef<Path>) -> io::Result<Self> {
245        let file = File::open(&path).await?;
246        let mut content_size = None;
247        let mut file_name = None;
248
249        if let Ok(metadata) = file.metadata().await {
250            content_size = Some(metadata.len());
251        }
252
253        if let Some(file_name_os) = path.as_ref().file_name() {
254            if let Some(file_name_str) = file_name_os.to_str() {
255                file_name = Some(file_name_str.to_owned());
256            }
257        }
258
259        Ok(Self {
260            stream: ReaderStream::new(file),
261            file_name,
262            content_size,
263        })
264    }
265}
266
267impl<S> IntoResponse for FileStream<S>
268where
269    S: TryStream + Send + 'static,
270    S::Ok: Into<Bytes>,
271    S::Error: Into<BoxError>,
272{
273    fn into_response(self) -> Response {
274        let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
275
276        if let Some(file_name) = self.file_name {
277            resp = resp.header(
278                header::CONTENT_DISPOSITION,
279                format!("attachment; filename=\"{file_name}\""),
280            );
281        }
282
283        if let Some(content_size) = self.content_size {
284            resp = resp.header(header::CONTENT_LENGTH, content_size);
285        }
286
287        resp.body(body::Body::from_stream(self.stream))
288            .unwrap_or_else(|e| {
289                (
290                    StatusCode::INTERNAL_SERVER_ERROR,
291                    format!("build FileStream responsec error: {e}"),
292                )
293                    .into_response()
294            })
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, routing::get, Router};
    use body::Body;
    use http::HeaderMap;
    use http_body_util::BodyExt;
    use std::io::Cursor;
    use tokio_util::io::ReaderStream;
    use tower::ServiceExt;

    /// Payload served by the tests that simulate a file in memory.
    const FILE_CONTENT: &str = "Hello, this is the simulated file content!";

    /// Builds a stream over an in-memory copy of [`FILE_CONTENT`], standing in for a file.
    fn simulated_stream() -> ReaderStream<Cursor<Vec<u8>>> {
        ReaderStream::new(Cursor::new(FILE_CONTENT.as_bytes().to_vec()))
    }

    /// Parses a single `bytes=<start>-<end>` range.
    ///
    /// A missing or unparsable `<end>` is reported as `0`, which `try_range_response`
    /// interprets as "until the end of the file". Because of that sentinel, the
    /// `start <= end` check is only applied when an explicit non-zero end was given;
    /// otherwise an open-ended range such as `bytes=5-` would be rejected.
    fn parse_range_header(range: &str) -> Option<(u64, u64)> {
        let range = range.strip_prefix("bytes=")?;
        let mut parts = range.split('-');
        let start = parts.next()?.parse::<u64>().ok()?;
        let end = parts
            .next()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(0);
        if end != 0 && start > end {
            return None;
        }
        Some((start, end))
    }

    /// Serves `path` according to the `Range` header in `headers`.
    ///
    /// Without a `Range` header the whole file is requested (`(0, 0)`); a header that
    /// cannot be parsed is answered with `416 Invalid Range`.
    async fn serve_range(path: &Path, headers: &HeaderMap) -> io::Result<Response> {
        let requested = headers
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok());

        let (start, end) = match requested {
            // default range end = 0, if end = 0 end == file size - 1
            None => (0, 0),
            Some(raw) => match parse_range_header(raw) {
                Some(range) => range,
                None => {
                    return Ok(
                        (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response()
                    );
                }
            },
        };

        FileStream::<ReaderStream<File>>::try_range_response(path, start, end).await
    }

    /// Handler that serves ranges of `CHANGELOG.md`.
    async fn range_stream(headers: HeaderMap) -> Response {
        serve_range(Path::new("CHANGELOG.md"), &headers)
            .await
            .unwrap()
    }

    #[test]
    fn parse_range_header_cases() {
        assert_eq!(parse_range_header("bytes=20-1000"), Some((20, 1000)));
        assert_eq!(parse_range_header("bytes=0-"), Some((0, 0)));
        // Open-ended range with a non-zero start must be accepted.
        assert_eq!(parse_range_header("bytes=5-"), Some((5, 0)));
        assert_eq!(parse_range_header("bytes=10-5"), None);
        assert_eq!(parse_range_header("items=0-5"), None);
        assert_eq!(parse_range_header("bytes=-5"), None);
    }

    #[tokio::test]
    async fn response() -> Result<(), Box<dyn std::error::Error>> {
        // Content size and file name are not attached by default
        let app = Router::new().route(
            "/file",
            get(|| async { FileStream::new(simulated_stream()).into_response() }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(std::str::from_utf8(body)?, FILE_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    .content_size(FILE_CONTENT.len() as u64)
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(std::str::from_utf8(body)?, FILE_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_content_size() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    .file_name("test")
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(std::str::from_utf8(body)?, FILE_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_with_content_size_and_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    .file_name("test")
                    .content_size(FILE_CONTENT.len() as u64)
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(std::str::from_utf8(body)?, FILE_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_from_path() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/from_path",
            get(|| async {
                FileStream::from_path(Path::new("CHANGELOG.md"))
                    .await
                    .unwrap()
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/from_path").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"CHANGELOG.md\""
        );

        // The advertised length must match the size of the file on disk.
        let content_length = File::open("CHANGELOG.md").await?.metadata().await?.len();

        assert_eq!(
            response
                .headers()
                .get("content-length")
                .unwrap()
                .to_str()?,
            content_length.to_string()
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_file() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));

        // Simulating a GET request
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-1000")
                    .body(Body::empty())?,
            )
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        // The total in `Content-Range` must match the size of the file on disk.
        let content_length = File::open("CHANGELOG.md").await?.metadata().await?.len();

        assert_eq!(
            response
                .headers()
                .get("content-range")
                .unwrap()
                .to_str()?,
            format!("bytes 20-1000/{content_length}")
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_empty_file() -> Result<(), Box<dyn std::error::Error>> {
        let file = tempfile::NamedTempFile::new()?;
        file.as_file().set_len(0)?;
        let path = file.path().to_owned();

        let app = Router::new().route(
            "/range_empty",
            get(move |headers: HeaderMap| {
                let path = path.clone();
                async move {
                    serve_range(&path, &headers)
                        .await
                        .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
                }
            }),
        );

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_empty")
                    .header(header::RANGE, "bytes=0-")
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
        Ok(())
    }
}