1use axum_core::{
2 body,
3 response::{IntoResponse, Response},
4 BoxError,
5};
6use bytes::Bytes;
7use futures_util::TryStream;
8use http::{header, StatusCode};
9use std::{io, path::Path};
10use tokio::{
11 fs::File,
12 io::{AsyncReadExt, AsyncSeekExt},
13};
14use tokio_util::io::ReaderStream;
15
/// A response body that streams a file (or any other byte stream) to the client.
///
/// Converting it with [`IntoResponse`] produces a response with
/// `Content-Type: application/octet-stream`; the optional metadata fields add
/// `Content-Disposition` and `Content-Length` headers when set.
#[derive(Debug)]
#[must_use]
pub struct FileStream<S> {
    /// The stream of body chunks sent to the client.
    pub stream: S,
    /// File name advertised via `Content-Disposition: attachment`, if any.
    pub file_name: Option<String>,
    /// Body length in bytes, sent as `Content-Length` when present.
    pub content_size: Option<u64>,
}
57
58impl<S> FileStream<S>
59where
60 S: TryStream + Send + 'static,
61 S::Ok: Into<Bytes>,
62 S::Error: Into<BoxError>,
63{
64 pub fn new(stream: S) -> Self {
66 Self {
67 stream,
68 file_name: None,
69 content_size: None,
70 }
71 }
72
73 pub fn file_name(mut self, file_name: impl Into<String>) -> Self {
77 self.file_name = Some(file_name.into());
78 self
79 }
80
81 pub fn content_size(mut self, len: u64) -> Self {
83 self.content_size = Some(len);
84 self
85 }
86
87 pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
127 let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
128 resp = resp.status(StatusCode::PARTIAL_CONTENT);
129
130 resp = resp.header(
131 header::CONTENT_RANGE,
132 format!("bytes {start}-{end}/{total_size}"),
133 );
134
135 resp.body(body::Body::from_stream(self.stream))
136 .unwrap_or_else(|e| {
137 (
138 StatusCode::INTERNAL_SERVER_ERROR,
139 format!("build FileStream response error: {e}"),
140 )
141 .into_response()
142 })
143 }
144
145 pub async fn try_range_response(
185 file_path: impl AsRef<Path>,
186 start: u64,
187 mut end: u64,
188 ) -> io::Result<Response> {
189 let mut file = File::open(file_path).await?;
190
191 let metadata = file.metadata().await?;
192 let total_size = metadata.len();
193
194 if total_size == 0 {
195 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
196 }
197
198 if end == 0 {
199 end = total_size - 1;
200 }
201
202 if start > total_size {
203 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
204 }
205 if start > end {
206 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
207 }
208 if end >= total_size {
209 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
210 }
211
212 file.seek(std::io::SeekFrom::Start(start)).await?;
213
214 let stream = ReaderStream::new(file.take(end - start + 1));
215
216 Ok(FileStream::new(stream).into_range_response(start, end, total_size))
217 }
218}
219
220impl FileStream<ReaderStream<File>> {
222 pub async fn from_path(path: impl AsRef<Path>) -> io::Result<Self> {
245 let file = File::open(&path).await?;
246 let mut content_size = None;
247 let mut file_name = None;
248
249 if let Ok(metadata) = file.metadata().await {
250 content_size = Some(metadata.len());
251 }
252
253 if let Some(file_name_os) = path.as_ref().file_name() {
254 if let Some(file_name_str) = file_name_os.to_str() {
255 file_name = Some(file_name_str.to_owned());
256 }
257 }
258
259 Ok(Self {
260 stream: ReaderStream::new(file),
261 file_name,
262 content_size,
263 })
264 }
265}
266
267impl<S> IntoResponse for FileStream<S>
268where
269 S: TryStream + Send + 'static,
270 S::Ok: Into<Bytes>,
271 S::Error: Into<BoxError>,
272{
273 fn into_response(self) -> Response {
274 let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
275
276 if let Some(file_name) = self.file_name {
277 resp = resp.header(
278 header::CONTENT_DISPOSITION,
279 format!(
280 "attachment; filename=\"{}\"",
281 super::content_disposition::EscapedFilename(&file_name)
282 ),
283 );
284 }
285
286 if let Some(content_size) = self.content_size {
287 resp = resp.header(header::CONTENT_LENGTH, content_size);
288 }
289
290 resp.body(body::Body::from_stream(self.stream))
291 .unwrap_or_else(|e| {
292 (
293 StatusCode::INTERNAL_SERVER_ERROR,
294 format!("build FileStream responsec error: {e}"),
295 )
296 .into_response()
297 })
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, routing::get, Router};
    use body::Body;
    use http::HeaderMap;
    use http_body_util::BodyExt;
    use std::io::Cursor;
    use tokio_util::io::ReaderStream;
    use tower::ServiceExt;

    type TestResult<T = ()> = Result<T, Box<dyn std::error::Error>>;

    /// Body served by the in-memory handlers (42 bytes long).
    const SIMULATED_CONTENT: &str = "Hello, this is the simulated file content!";

    /// Returns a fresh in-memory stream yielding `SIMULATED_CONTENT`.
    fn simulated_stream() -> ReaderStream<Cursor<Vec<u8>>> {
        ReaderStream::new(Cursor::new(SIMULATED_CONTENT.as_bytes().to_vec()))
    }

    /// Sends `GET uri` through `app`, adding a `Range` header when `range` is given.
    async fn send(app: Router, uri: &str, range: Option<&str>) -> TestResult<Response> {
        let mut request = Request::builder().uri(uri);
        if let Some(range) = range {
            request = request.header(header::RANGE, range);
        }
        Ok(app.oneshot(request.body(Body::empty())?).await?)
    }

    /// Collects the whole response body into memory.
    async fn body_bytes(response: Response) -> TestResult<Bytes> {
        Ok(response.into_body().collect().await?.to_bytes())
    }

    #[tokio::test]
    async fn response() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async { FileStream::new(simulated_stream()).into_response() }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        let body = body_bytes(response).await?;
        assert_eq!(std::str::from_utf8(&body)?, SIMULATED_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_filename() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    .content_size(SIMULATED_CONTENT.len() as u64)
                    .into_response()
            }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        let body = body_bytes(response).await?;
        assert_eq!(std::str::from_utf8(&body)?, SIMULATED_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_content_size() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    .file_name("test")
                    .into_response()
            }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );

        let body = body_bytes(response).await?;
        assert_eq!(std::str::from_utf8(&body)?, SIMULATED_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_with_content_size_and_filename() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    .file_name("test")
                    .content_size(SIMULATED_CONTENT.len() as u64)
                    .into_response()
            }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        let body = body_bytes(response).await?;
        assert_eq!(std::str::from_utf8(&body)?, SIMULATED_CONTENT);
        Ok(())
    }

    #[tokio::test]
    async fn response_from_path() -> TestResult {
        let app = Router::new().route(
            "/from_path",
            get(|| async {
                FileStream::from_path(Path::new("CHANGELOG.md"))
                    .await
                    .expect("CHANGELOG.md should be readable from the crate root")
                    .into_response()
            }),
        );

        let response = send(app, "/from_path", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"CHANGELOG.md\""
        );

        let content_length = tokio::fs::metadata("CHANGELOG.md").await?.len();
        assert_eq!(
            response.headers().get("content-length").unwrap().to_str()?,
            content_length.to_string()
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_file() -> TestResult {
        let app = Router::new().route("/range_response", get(range_stream));

        let response = send(app, "/range_response", Some("bytes=20-1000")).await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        let content = tokio::fs::read("CHANGELOG.md").await?;
        let total = content.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-1000/{total}")
        );

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], &content[20..=1000]);
        Ok(())
    }

    #[tokio::test]
    async fn response_range_open_ended() -> TestResult {
        let app = Router::new().route("/range_response", get(range_stream));

        let response = send(app, "/range_response", Some("bytes=20-")).await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);

        let content = tokio::fs::read("CHANGELOG.md").await?;
        let total = content.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-{}/{total}", total - 1)
        );

        let body = body_bytes(response).await?;
        assert_eq!(&body[..], &content[20..]);
        Ok(())
    }

    /// Handler serving `CHANGELOG.md` with `Range` support.
    async fn range_stream(headers: HeaderMap) -> Response {
        serve_range(Path::new("CHANGELOG.md"), &headers).await
    }

    /// Serves `path`, honoring the request's `Range` header the way a real
    /// handler built on `try_range_response` would.
    ///
    /// A missing (or non-ASCII) `Range` header requests the whole file. An
    /// unparseable one gets `416 Invalid Range`. I/O errors become a `500`.
    async fn serve_range(path: &Path, headers: &HeaderMap) -> Response {
        let range = headers
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok());

        let (start, end) = match range {
            None => (0, 0),
            Some(range) => match parse_range_header(range) {
                Some(parsed) => parsed,
                None => {
                    return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response()
                }
            },
        };

        FileStream::<ReaderStream<File>>::try_range_response(path, start, end)
            .await
            .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
    }

    /// Parses a single `bytes=START-END` or open-ended `bytes=START-` range.
    ///
    /// Returns `(start, end)`, where `end == 0` means "to the end of the file".
    /// That is the sentinel `try_range_response` expects, so an explicit
    /// `bytes=0-0` is indistinguishable from `bytes=0-`. The following all yield
    /// `None`: suffix ranges (`bytes=-N`), multiple ranges, inverted ranges, and
    /// malformed numbers.
    fn parse_range_header(range: &str) -> Option<(u64, u64)> {
        let (start, end) = range.strip_prefix("bytes=")?.split_once('-')?;
        let start = start.parse::<u64>().ok()?;
        if end.is_empty() {
            return Some((start, 0));
        }
        let end = end.parse::<u64>().ok()?;
        (start <= end).then_some((start, end))
    }

    #[tokio::test]
    async fn filename_escapes_quotes() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    // A name crafted to inject an extra `filename*` parameter.
                    .file_name("evil\"; filename*=UTF-8''pwned.txt; x=\"")
                    .into_response()
            }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"evil\\\"; filename*=UTF-8''pwned.txt; x=\\\"\""
        );
        Ok(())
    }

    #[tokio::test]
    async fn filename_escapes_backslashes() -> TestResult {
        let app = Router::new().route(
            "/file",
            get(|| async {
                FileStream::new(simulated_stream())
                    .file_name("file\\name.txt")
                    .into_response()
            }),
        );

        let response = send(app, "/file", None).await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"file\\\\name.txt\""
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_empty_file() -> TestResult {
        // Keep `file` alive until the request completes; dropping it deletes the file.
        let file = tempfile::NamedTempFile::new()?;
        file.as_file().set_len(0)?;
        let path = file.path().to_owned();

        let app = Router::new().route(
            "/range_empty",
            get(move |headers: HeaderMap| {
                let path = path.clone();
                async move { serve_range(&path, &headers).await }
            }),
        );

        let response = send(app, "/range_empty", Some("bytes=0-")).await?;

        assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
        Ok(())
    }
}