1use axum_core::{
2 body,
3 response::{IntoResponse, Response},
4 BoxError,
5};
6use bytes::Bytes;
7use futures_util::TryStream;
8use http::{header, StatusCode};
9use std::{io, path::Path};
10use tokio::{
11 fs::File,
12 io::{AsyncReadExt, AsyncSeekExt},
13};
14use tokio_util::io::ReaderStream;
15
/// A streamed file download: any fallible byte stream plus optional metadata
/// that is turned into HTTP response headers.
#[derive(Debug)]
#[must_use]
pub struct FileStream<S> {
    /// The byte stream used as the response body.
    pub stream: S,
    /// Optional file name; when set, the full (non-range) response carries a
    /// `Content-Disposition: attachment` header naming it.
    pub file_name: Option<String>,
    /// Optional size in bytes; when set, the full (non-range) response carries
    /// it as `Content-Length`.
    pub content_size: Option<u64>,
}
57
58impl<S> FileStream<S>
59where
60 S: TryStream + Send + 'static,
61 S::Ok: Into<Bytes>,
62 S::Error: Into<BoxError>,
63{
64 pub fn new(stream: S) -> Self {
66 Self {
67 stream,
68 file_name: None,
69 content_size: None,
70 }
71 }
72
73 pub fn file_name(mut self, file_name: impl Into<String>) -> Self {
77 self.file_name = Some(file_name.into());
78 self
79 }
80
81 pub fn content_size(mut self, len: u64) -> Self {
83 self.content_size = Some(len);
84 self
85 }
86
87 pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
127 let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
128 resp = resp.status(StatusCode::PARTIAL_CONTENT);
129
130 resp = resp.header(
131 header::CONTENT_RANGE,
132 format!("bytes {start}-{end}/{total_size}"),
133 );
134
135 resp.body(body::Body::from_stream(self.stream))
136 .unwrap_or_else(|e| {
137 (
138 StatusCode::INTERNAL_SERVER_ERROR,
139 format!("build FileStream response error: {e}"),
140 )
141 .into_response()
142 })
143 }
144
145 pub async fn try_range_response(
185 file_path: impl AsRef<Path>,
186 start: u64,
187 mut end: u64,
188 ) -> io::Result<Response> {
189 let mut file = File::open(file_path).await?;
190
191 let metadata = file.metadata().await?;
192 let total_size = metadata.len();
193
194 if total_size == 0 {
195 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
196 }
197
198 if end == 0 {
199 end = total_size - 1;
200 }
201
202 if start > total_size {
203 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
204 }
205 if start > end {
206 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
207 }
208 if end >= total_size {
209 return Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response());
210 }
211
212 file.seek(std::io::SeekFrom::Start(start)).await?;
213
214 let stream = ReaderStream::new(file.take(end - start + 1));
215
216 Ok(FileStream::new(stream).into_range_response(start, end, total_size))
217 }
218}
219
220impl FileStream<ReaderStream<File>> {
222 pub async fn from_path(path: impl AsRef<Path>) -> io::Result<Self> {
245 let file = File::open(&path).await?;
246 let mut content_size = None;
247 let mut file_name = None;
248
249 if let Ok(metadata) = file.metadata().await {
250 content_size = Some(metadata.len());
251 }
252
253 if let Some(file_name_os) = path.as_ref().file_name() {
254 if let Some(file_name_str) = file_name_os.to_str() {
255 file_name = Some(file_name_str.to_owned());
256 }
257 }
258
259 Ok(Self {
260 stream: ReaderStream::new(file),
261 file_name,
262 content_size,
263 })
264 }
265}
266
267impl<S> IntoResponse for FileStream<S>
268where
269 S: TryStream + Send + 'static,
270 S::Ok: Into<Bytes>,
271 S::Error: Into<BoxError>,
272{
273 fn into_response(self) -> Response {
274 let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
275
276 if let Some(file_name) = self.file_name {
277 resp = resp.header(
278 header::CONTENT_DISPOSITION,
279 format!("attachment; filename=\"{file_name}\""),
280 );
281 }
282
283 if let Some(content_size) = self.content_size {
284 resp = resp.header(header::CONTENT_LENGTH, content_size);
285 }
286
287 resp.body(body::Body::from_stream(self.stream))
288 .unwrap_or_else(|e| {
289 (
290 StatusCode::INTERNAL_SERVER_ERROR,
291 format!("build FileStream responsec error: {e}"),
292 )
293 .into_response()
294 })
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, routing::get, Router};
    use body::Body;
    use http::HeaderMap;
    use http_body_util::BodyExt;
    use std::{io::Cursor, path::PathBuf};
    use tokio_util::io::ReaderStream;
    use tower::ServiceExt;

    #[tokio::test]
    async fn response() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);

                let stream = ReaderStream::new(reader);
                FileStream::new(stream).into_response()
            }),
        );

        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);

                let stream = ReaderStream::new(reader);
                FileStream::new(stream).content_size(size).into_response()
            }),
        );

        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_not_set_content_size() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);

                let stream = ReaderStream::new(reader);
                FileStream::new(stream).file_name("test").into_response()
            }),
        );

        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );

        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_with_content_size_and_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);

                let stream = ReaderStream::new(reader);
                FileStream::new(stream)
                    .file_name("test")
                    .content_size(size)
                    .into_response()
            }),
        );

        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");

        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_from_path() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/from_path",
            get(move || async move {
                FileStream::from_path(Path::new("CHANGELOG.md"))
                    .await
                    .unwrap()
                    .into_response()
            }),
        );

        let response = app
            .oneshot(Request::builder().uri("/from_path").body(Body::empty())?)
            .await?;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"CHANGELOG.md\""
        );

        let content_length = File::open("CHANGELOG.md").await?.metadata().await?.len();
        assert_eq!(
            response.headers().get("content-length").unwrap().to_str()?,
            content_length.to_string()
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_file() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));

        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-1000")
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );

        let content_length = File::open("CHANGELOG.md").await?.metadata().await?.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-1000/{content_length}")
        );
        Ok(())
    }

    #[tokio::test]
    async fn response_range_empty_file() -> Result<(), Box<dyn std::error::Error>> {
        let file = tempfile::NamedTempFile::new()?;
        file.as_file().set_len(0)?;

        let response = range_app("/range_empty", file.path().to_owned())
            .oneshot(
                Request::builder()
                    .uri("/range_empty")
                    .header(header::RANGE, "bytes=0-")
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
        Ok(())
    }

    #[tokio::test]
    async fn response_range_open_ended() -> Result<(), Box<dyn std::error::Error>> {
        let file = tempfile::NamedTempFile::new()?;
        std::fs::write(file.path(), b"0123456789")?;

        let response = range_app("/range_open", file.path().to_owned())
            .oneshot(
                Request::builder()
                    .uri("/range_open")
                    .header(header::RANGE, "bytes=4-")
                    .body(Body::empty())?,
            )
            .await?;

        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            response.headers().get("content-range").unwrap(),
            "bytes 4-9/10"
        );
        let body = response.into_body().collect().await?.to_bytes();
        assert_eq!(&body[..], b"456789");
        Ok(())
    }

    #[test]
    fn parse_range_header_variants() {
        assert_eq!(parse_range_header("bytes=20-1000"), Some((20, 1000)));
        assert_eq!(parse_range_header("bytes=0-"), Some((0, 0)));
        assert_eq!(parse_range_header("bytes=20-"), Some((20, 0)));
        assert_eq!(parse_range_header("bytes=10-5"), None);
        assert_eq!(parse_range_header("bytes=5"), None);
        assert_eq!(parse_range_header("bytes=0-abc"), None);
        assert_eq!(parse_range_header("items=0-5"), None);
    }

    /// Handler serving `CHANGELOG.md` according to the request's `Range` header.
    async fn range_stream(headers: HeaderMap) -> Response {
        serve_range(Path::new("CHANGELOG.md"), &headers).await
    }

    /// Router that serves `path` at `uri` through [`serve_range`].
    fn range_app(uri: &str, path: PathBuf) -> Router {
        Router::new().route(
            uri,
            get(move |headers: HeaderMap| {
                let path = path.clone();
                async move { serve_range(&path, &headers).await }
            }),
        )
    }

    /// Serves `path` as a range response driven by the `Range` header.
    ///
    /// - A missing (or non-UTF-8) header becomes `(0, 0)`, which
    ///   `try_range_response` treats as "the whole file".
    /// - An unparsable range yields 416.
    /// - An I/O error yields 500.
    async fn serve_range(path: &Path, headers: &HeaderMap) -> Response {
        let range_header = headers
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok());

        let (start, end) = match range_header {
            Some(range) => match parse_range_header(range) {
                Some(range) => range,
                None => {
                    return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response()
                }
            },
            None => (0, 0),
        };

        FileStream::<ReaderStream<File>>::try_range_response(path, start, end)
            .await
            .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
    }

    /// Parses a single `bytes=START-[END]` range.
    ///
    /// An omitted END is returned as `0`, which `try_range_response` interprets
    /// as "through the end of the file". Suffix ranges (`bytes=-N`), multiple
    /// ranges, and non-numeric bounds yield `None`.
    fn parse_range_header(range: &str) -> Option<(u64, u64)> {
        let (start, end) = range.strip_prefix("bytes=")?.split_once('-')?;
        let start = start.parse::<u64>().ok()?;
        if end.is_empty() {
            // Open-ended range: defer to the `end == 0` sentinel instead of
            // rejecting it as `start > end`.
            return Some((start, 0));
        }
        let end = end.parse::<u64>().ok()?;
        (start <= end).then_some((start, end))
    }
}