Skip to main content

axum_extra/response/
file_stream.rs

1use axum_core::{
2    body,
3    response::{IntoResponse, Response},
4    BoxError,
5};
6use bytes::Bytes;
7use futures_util::TryStream;
8use http::{header, StatusCode};
9use std::{io, path::Path};
10use tokio::{
11    fs::File,
12    io::{AsyncReadExt, AsyncSeekExt},
13};
14use tokio_util::io::ReaderStream;
15
/// A response that streams a file (or any other byte stream) to the client.
///
/// Build one from an arbitrary stream with [`FileStream::new`], or directly from a path with
/// [`FileStream::from_path`]. Converting it into a response sends the stream as the body with
/// `Content-Type: application/octet-stream`; the optional [`file_name`](FileStream::file_name)
/// and [`content_size`](FileStream::content_size) add `Content-Disposition` and
/// `Content-Length` headers respectively.
///
/// # Examples
///
/// ```
/// use axum::{
///     http::StatusCode,
///     response::{IntoResponse, Response},
///     routing::get,
///     Router,
/// };
/// use axum_extra::response::file_stream::FileStream;
/// use tokio::fs::File;
/// use tokio_util::io::ReaderStream;
///
/// async fn file_stream() -> Result<Response, (StatusCode, String)> {
///     let file = File::open("test.txt")
///         .await
///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?;
///
///     let response = FileStream::new(ReaderStream::new(file))
///         .file_name("test.txt")
///         .into_response();
///
///     Ok(response)
/// }
///
/// let app = Router::new().route("/file-stream", get(file_stream));
/// # let _: Router = app;
/// ```
#[derive(Debug)]
#[must_use]
pub struct FileStream<S> {
    /// The byte stream that becomes the response body.
    pub stream: S,
    /// Optional file name; when set, the `IntoResponse` impl sends it in an attachment
    /// `Content-Disposition` header.
    pub file_name: Option<String>,
    /// Optional size of the stream in bytes; when set, the `IntoResponse` impl sends it as the
    /// `Content-Length` header.
    pub content_size: Option<u64>,
}
57
58impl<S> FileStream<S>
59where
60    S: TryStream + Send + 'static,
61    S::Ok: Into<Bytes>,
62    S::Error: Into<BoxError>,
63{
64    /// Create a new [`FileStream`]
65    pub fn new(stream: S) -> Self {
66        Self {
67            stream,
68            file_name: None,
69            content_size: None,
70        }
71    }
72
73    /// Set the file name of the [`FileStream`].
74    ///
75    /// This adds the attachment `Content-Disposition` header with the given `file_name`.
76    pub fn file_name(mut self, file_name: impl Into<String>) -> Self {
77        self.file_name = Some(file_name.into());
78        self
79    }
80
81    /// Set the size of the file.
82    pub fn content_size(mut self, len: u64) -> Self {
83        self.content_size = Some(len);
84        self
85    }
86
87    /// Return a range response.
88    ///
89    /// range: (start, end, total_size)
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use axum::{
95    ///     http::StatusCode,
96    ///     response::IntoResponse,
97    ///     routing::get,
98    ///     Router,
99    /// };
100    /// use axum_extra::response::file_stream::FileStream;
101    /// use tokio::fs::File;
102    /// use tokio::io::AsyncSeekExt;
103    /// use tokio_util::io::ReaderStream;
104    ///
105    /// async fn range_response() -> Result<impl IntoResponse, (StatusCode, String)> {
106    ///     let mut file = File::open("test.txt")
107    ///         .await
108    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))?;
109    ///     let mut file_size = file
110    ///         .metadata()
111    ///         .await
112    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("Get file size: {e}")))?
113    ///         .len();
114    ///
115    ///     file.seek(std::io::SeekFrom::Start(10))
116    ///         .await
117    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File seek error: {e}")))?;
118    ///     let stream = ReaderStream::new(file);
119    ///
120    ///     Ok(FileStream::new(stream).into_range_response(10, file_size - 1, file_size))
121    /// }
122    ///
123    /// let app = Router::new().route("/file-stream", get(range_response));
124    /// # let _: Router = app;
125    /// ```
126    pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
127        let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
128        resp = resp.status(StatusCode::PARTIAL_CONTENT);
129
130        resp = resp.header(
131            header::CONTENT_RANGE,
132            format!("bytes {start}-{end}/{total_size}"),
133        );
134
135        resp.body(body::Body::from_stream(self.stream))
136            .unwrap_or_else(|e| {
137                (
138                    StatusCode::INTERNAL_SERVER_ERROR,
139                    format!("build FileStream response error: {e}"),
140                )
141                    .into_response()
142            })
143    }
144
145    /// Attempts to return RANGE requests directly from the file path.
146    ///
147    /// # Arguments
148    ///
149    /// * `file_path` - The path of the file to be streamed
150    /// * `start` - The start position of the range
151    /// * `end` - The end position of the range
152    ///
153    /// # Note
154    ///
155    /// * If `end` is 0, then it is used as `file_size - 1`
156    /// * If `start` > `file_size` or `start` > `end`, then `Range Not Satisfiable` is returned
157    ///
158    /// # Examples
159    ///
160    /// ```
161    /// use axum::{
162    ///     http::StatusCode,
163    ///     response::IntoResponse,
164    ///     Router,
165    ///     routing::get
166    /// };
167    /// use std::path::Path;
168    /// use axum_extra::response::file_stream::FileStream;
169    /// use tokio::fs::File;
170    /// use tokio_util::io::ReaderStream;
171    /// use tokio::io::AsyncSeekExt;
172    ///
173    /// async fn range_stream() -> impl IntoResponse {
174    ///     let range_start = 0;
175    ///     let range_end = 1024;
176    ///
177    ///     FileStream::<ReaderStream<File>>::try_range_response("CHANGELOG.md", range_start, range_end).await
178    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
179    /// }
180    ///
181    /// let app = Router::new().route("/file-stream", get(range_stream));
182    /// # let _: Router = app;
183    /// ```
184    pub async fn try_range_response(
185        file_path: impl AsRef<Path>,
186        start: u64,
187        mut end: u64,
188    ) -> io::Result<Response> {
189        let mut file = File::open(file_path).await?;
190
191        let metadata = file.metadata().await?;
192        let total_size = metadata.len();
193
194        if total_size == 0 {
195            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
196        }
197
198        if end == 0 {
199            end = total_size - 1;
200        }
201
202        if start > total_size {
203            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
204        }
205        if start > end {
206            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
207        }
208        if end >= total_size {
209            return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
210        }
211
212        file.seek(std::io::SeekFrom::Start(start)).await?;
213
214        let stream = ReaderStream::new(file.take(end - start + 1));
215
216        Ok(FileStream::new(stream).into_range_response(start, end, total_size))
217    }
218}
219
220// Split because the general impl requires to specify `S` and this one does not.
221impl FileStream<ReaderStream<File>> {
222    /// Create a [`FileStream`] from a file path.
223    ///
224    /// # Examples
225    ///
226    /// ```
227    /// use axum::{
228    ///     http::StatusCode,
229    ///     response::IntoResponse,
230    ///     Router,
231    ///     routing::get
232    /// };
233    /// use axum_extra::response::file_stream::FileStream;
234    ///
235    /// async fn file_stream() -> impl IntoResponse {
236    ///     FileStream::from_path("test.txt")
237    ///         .await
238    ///         .map_err(|e| (StatusCode::NOT_FOUND, format!("File not found: {e}")))
239    /// }
240    ///
241    /// let app = Router::new().route("/file-stream", get(file_stream));
242    /// # let _: Router = app;
243    /// ```
244    pub async fn from_path(path: impl AsRef<Path>) -> io::Result<Self> {
245        let file = File::open(&path).await?;
246        let mut content_size = None;
247        let mut file_name = None;
248
249        if let Ok(metadata) = file.metadata().await {
250            content_size = Some(metadata.len());
251        }
252
253        if let Some(file_name_os) = path.as_ref().file_name() {
254            if let Some(file_name_str) = file_name_os.to_str() {
255                file_name = Some(file_name_str.to_owned());
256            }
257        }
258
259        Ok(Self {
260            stream: ReaderStream::new(file),
261            file_name,
262            content_size,
263        })
264    }
265}
266
267impl<S> IntoResponse for FileStream<S>
268where
269    S: TryStream + Send + 'static,
270    S::Ok: Into<Bytes>,
271    S::Error: Into<BoxError>,
272{
273    fn into_response(self) -> Response {
274        let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
275
276        if let Some(file_name) = self.file_name {
277            resp = resp.header(
278                header::CONTENT_DISPOSITION,
279                format!(
280                    "attachment; filename=\"{}\"",
281                    super::content_disposition::EscapedFilename(&file_name)
282                ),
283            );
284        }
285
286        if let Some(content_size) = self.content_size {
287            resp = resp.header(header::CONTENT_LENGTH, content_size);
288        }
289
290        resp.body(body::Body::from_stream(self.stream))
291            .unwrap_or_else(|e| {
292                (
293                    StatusCode::INTERNAL_SERVER_ERROR,
294                    format!("build FileStream responsec error: {e}"),
295                )
296                    .into_response()
297            })
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, routing::get, Router};
    use body::Body;
    use http::HeaderMap;
    use http_body_util::BodyExt;
    use std::io::Cursor;
    use tokio_util::io::ReaderStream;
    use tower::ServiceExt;

    #[tokio::test]
    async fn response() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);

                // Response file stream
                // Content size and file name are not attached by default
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);

                // Response file stream
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).content_size(size).into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_content_size() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);

                // Response file stream
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).file_name("test").into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_with_content_size_and_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                // Simulating a file stream
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);

                // Response file stream
                let stream = ReaderStream::new(reader);
                FileStream::new(stream)
                    .file_name("test")
                    .content_size(size)
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        // Validate Response Body
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_from_path() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/from_path",
            get(move || async move {
                FileStream::from_path(Path::new("CHANGELOG.md"))
                    .await
                    .unwrap()
                    .into_response()
            }),
        );

        // Simulating a GET request
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/from_path")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::OK);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"CHANGELOG.md\""
        );

        let file = File::open("CHANGELOG.md").await.unwrap();
        // get file size
        let content_length = file.metadata().await.unwrap().len();

        assert_eq!(
            response
                .headers()
                .get("content-length")
                .unwrap()
                .to_str()
                .unwrap(),
            content_length.to_string()
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_file() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));

        // Simulating a GET request
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-1000")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        // Validate Response Status Code
        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

        // Validate Response Headers
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        let file = File::open("CHANGELOG.md").await.unwrap();
        // get file size
        let content_length = file.metadata().await.unwrap().len();

        assert_eq!(
            response
                .headers()
                .get("content-range")
                .unwrap()
                .to_str()
                .unwrap(),
            format!("bytes 20-1000/{content_length}")
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_open_ended() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));

        // `bytes=20-` asks for everything from byte 20 to the end of the file.
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-")
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

        let content_length = File::open("CHANGELOG.md").await?.metadata().await?.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-{}/{content_length}", content_length - 1)
        );

        let body = response.into_body().collect().await?.to_bytes();
        assert_eq!(body.len() as u64, content_length - 20);
        Ok(())
    }

    /// Handler that serves `CHANGELOG.md`, honouring an optional `Range` header.
    async fn range_stream(headers: HeaderMap) -> Response {
        let (start, end) = match requested_range(&headers) {
            Some(range) => range,
            None => return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response(),
        };

        FileStream::<ReaderStream<File>>::try_range_response(Path::new("CHANGELOG.md"), start, end)
            .await
            .unwrap()
    }

    /// Read the requested `(start, end)` from the `Range` header.
    ///
    /// A missing (or non-UTF-8) header means "whole file" and yields `Some((0, 0))`, because an
    /// `end` of 0 is treated by `try_range_response` as "through the last byte". A header that
    /// cannot be parsed yields `None`.
    fn requested_range(headers: &HeaderMap) -> Option<(u64, u64)> {
        match headers
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok())
        {
            Some(range) => parse_range_header(range),
            None => Some((0, 0)),
        }
    }

    /// Parse a single `bytes=start-end` range.
    ///
    /// An absent end (`bytes=20-`) is encoded as `0`, the "through the last byte" sentinel of
    /// `try_range_response`. Suffix ranges (`bytes=-500`) and multiple ranges are not supported.
    fn parse_range_header(range: &str) -> Option<(u64, u64)> {
        let (start, end) = range.strip_prefix("bytes=")?.split_once('-')?;
        let start = start.parse::<u64>().ok()?;
        let end = if end.is_empty() {
            0
        } else {
            end.parse::<u64>().ok()?
        };
        // `end == 0` means open-ended, so only explicit ends before the start are invalid.
        if end != 0 && start > end {
            return None;
        }
        Some((start, end))
    }

    #[test]
    fn parse_range_header_variants() {
        assert_eq!(parse_range_header("bytes=20-1000"), Some((20, 1000)));
        assert_eq!(parse_range_header("bytes=0-"), Some((0, 0)));
        assert_eq!(parse_range_header("bytes=20-"), Some((20, 0)));
        assert_eq!(parse_range_header("bytes=30-20"), None);
        assert_eq!(parse_range_header("bytes=-500"), None);
        assert_eq!(parse_range_header("items=0-10"), None);
        assert_eq!(parse_range_header("bytes=5"), None);
    }

    #[tokio::test]
    async fn filename_escapes_quotes() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"data".to_vec();
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                // Filename containing double quotes that could cause parameter injection
                FileStream::new(stream)
                    .file_name("evil\"; filename*=UTF-8''pwned.txt; x=\"")
                    .into_response()
            }),
        );

        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"evil\\\"; filename*=UTF-8''pwned.txt; x=\\\"\""
        );
        Ok(())
    }

    #[tokio::test]
    async fn filename_escapes_backslashes() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"data".to_vec();
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                FileStream::new(stream)
                    .file_name("file\\name.txt")
                    .into_response()
            }),
        );

        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"file\\\\name.txt\""
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_empty_file() -> Result<(), Box<dyn std::error::Error>> {
        let file = tempfile::NamedTempFile::new()?;
        file.as_file().set_len(0)?;
        let path = file.path().to_owned();

        let app = Router::new().route(
            "/range_empty",
            get(move |headers: HeaderMap| {
                let path = path.clone();
                async move {
                    let (start, end) = match requested_range(&headers) {
                        Some(range) => range,
                        None => {
                            return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range")
                                .into_response()
                        }
                    };

                    FileStream::<ReaderStream<File>>::try_range_response(path, start, end)
                        .await
                        .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
                }
            }),
        );

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_empty")
                    .header(header::RANGE, "bytes=0-")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
        Ok(())
    }
}