use axum_core::{
body,
response::{IntoResponse, Response},
BoxError,
};
use bytes::Bytes;
use futures_util::TryStream;
use http::{header, StatusCode};
use std::{io, path::Path};
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncSeekExt},
};
use tokio_util::io::ReaderStream;
/// A response body that streams a file (or any fallible byte stream) to the client.
///
/// Construct it with [`FileStream::new`] or [`FileStream::from_path`], optionally set
/// [`file_name`](FileStream::file_name) / [`content_size`](FileStream::content_size),
/// and return it from a handler. Its `IntoResponse` impl always sends
/// `Content-Type: application/octet-stream`, and additionally sends
/// `Content-Disposition` / `Content-Length` when the corresponding fields are set.
#[must_use]
#[derive(Debug)]
pub struct FileStream<S> {
    /// The byte stream that becomes the response body.
    pub stream: S,
    /// File name advertised to the client via `Content-Disposition: attachment`.
    pub file_name: Option<String>,
    /// Body length in bytes, emitted as `Content-Length` when present.
    pub content_size: Option<u64>,
}
impl<S> FileStream<S>
where
    S: TryStream + Send + 'static,
    S::Ok: Into<Bytes>,
    S::Error: Into<BoxError>,
{
    /// Wraps `stream` with no file name and no known content size.
    pub fn new(stream: S) -> Self {
        Self {
            stream,
            file_name: None,
            content_size: None,
        }
    }

    /// Sets the file name reported to the client in `Content-Disposition`.
    pub fn file_name(self, file_name: impl Into<String>) -> Self {
        Self {
            file_name: Some(file_name.into()),
            ..self
        }
    }

    /// Sets the body length in bytes, reported to the client as `Content-Length`.
    pub fn content_size(self, len: u64) -> Self {
        Self {
            content_size: Some(len),
            ..self
        }
    }

    /// Builds a `206 Partial Content` response that streams `self.stream` with
    /// `Content-Range: bytes {start}-{end}/{total_size}`.
    ///
    /// The arguments are not validated here: the caller is responsible for
    /// `self.stream` yielding exactly the bytes `start..=end`. The `file_name`
    /// and `content_size` fields are not used by this response.
    pub fn into_range_response(self, start: u64, end: u64, total_size: u64) -> Response {
        let content_range = format!("bytes {start}-{end}/{total_size}");
        let built = Response::builder()
            .header(header::CONTENT_TYPE, "application/octet-stream")
            .status(StatusCode::PARTIAL_CONTENT)
            .header(header::CONTENT_RANGE, content_range)
            .body(body::Body::from_stream(self.stream));

        match built {
            Ok(response) => response,
            Err(err) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("build FileStream response error: {err}"),
            )
                .into_response(),
        }
    }

    /// Opens `file_path` and responds with the inclusive byte range `start..=end`
    /// as a `206 Partial Content` response.
    ///
    /// An `end` of `0` means "through the last byte of the file". Because of this,
    /// a range that covers only the first byte (`0-0`) cannot be expressed: it
    /// yields the whole file instead. A `416 Range Not Satisfiable` response
    /// (returned as `Ok`) is produced when the file is empty, when `start > end`,
    /// or when `end` lies at or past the end of the file.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the file cannot be opened, its metadata cannot be
    /// read, or seeking to `start` fails.
    pub async fn try_range_response(
        file_path: impl AsRef<Path>,
        start: u64,
        end: u64,
    ) -> io::Result<Response> {
        fn not_satisfiable() -> io::Result<Response> {
            Ok((StatusCode::RANGE_NOT_SATISFIABLE, "Range Not Satisfiable").into_response())
        }

        let mut file = File::open(file_path).await?;
        let total_size = file.metadata().await?.len();

        // No byte of an empty file can be addressed. Checking this first also
        // keeps `total_size - 1` below from underflowing.
        if total_size == 0 {
            return not_satisfiable();
        }

        // `end == 0` is the "through the last byte" sentinel.
        let end = if end == 0 { total_size - 1 } else { end };

        // `start <= end < total_size` also implies `start < total_size`, so no
        // separate check on `start` is needed.
        let in_bounds = start <= end && end < total_size;
        if !in_bounds {
            return not_satisfiable();
        }

        file.seek(std::io::SeekFrom::Start(start)).await?;
        let range_len = end - start + 1;
        let stream = ReaderStream::new(file.take(range_len));
        Ok(FileStream::new(stream).into_range_response(start, end, total_size))
    }
}
impl FileStream<ReaderStream<File>> {
    /// Opens the file at `path` and wraps it in a streaming body.
    ///
    /// The last path component becomes `file_name` when it is valid UTF-8, and the
    /// length from the file's metadata becomes `content_size`. Both are
    /// best-effort: if either cannot be determined, that field is left as `None`.
    ///
    /// # Errors
    ///
    /// Returns an error only if the file cannot be opened.
    pub async fn from_path(path: impl AsRef<Path>) -> io::Result<Self> {
        let file = File::open(&path).await?;

        // Best-effort: a metadata failure just means no Content-Length.
        let content_size = file.metadata().await.ok().map(|meta| meta.len());

        // Non-UTF-8 names are skipped rather than lossily converted.
        let file_name = path
            .as_ref()
            .file_name()
            .and_then(|name| name.to_str())
            .map(str::to_owned);

        Ok(Self {
            stream: ReaderStream::new(file),
            file_name,
            content_size,
        })
    }
}
impl<S> IntoResponse for FileStream<S>
where
    S: TryStream + Send + 'static,
    S::Ok: Into<Bytes>,
    S::Error: Into<BoxError>,
{
    /// Builds a response that streams the body as `application/octet-stream`.
    ///
    /// When `file_name` is set, the response also carries
    /// `Content-Disposition: attachment; filename="…"`, with the name escaped via
    /// `EscapedFilename`. When `content_size` is set, it also carries
    /// `Content-Length`. If the response builder rejects a header, a
    /// `500 Internal Server Error` describing the failure is returned instead.
    fn into_response(self) -> Response {
        let mut resp = Response::builder().header(header::CONTENT_TYPE, "application/octet-stream");
        if let Some(file_name) = self.file_name {
            resp = resp.header(
                header::CONTENT_DISPOSITION,
                format!(
                    "attachment; filename=\"{}\"",
                    super::content_disposition::EscapedFilename(&file_name)
                ),
            );
        }
        if let Some(content_size) = self.content_size {
            resp = resp.header(header::CONTENT_LENGTH, content_size);
        }
        resp.body(body::Body::from_stream(self.stream))
            .unwrap_or_else(|e| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    // Same wording as `into_range_response` (previously misspelled
                    // "responsec" here).
                    format!("build FileStream response error: {e}"),
                )
                    .into_response()
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{extract::Request, routing::get, Router};
    use body::Body;
    use http::HeaderMap;
    use http_body_util::BodyExt;
    use std::io::Cursor;
    use tokio_util::io::ReaderStream;
    use tower::ServiceExt;
    #[tokio::test]
    async fn response() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).into_response()
            }),
        );
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }
    #[tokio::test]
    async fn response_not_set_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).content_size(size).into_response()
            }),
        );
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }
    #[tokio::test]
    async fn response_not_set_content_size() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                FileStream::new(stream).file_name("test").into_response()
            }),
        );
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }
    #[tokio::test]
    async fn response_with_content_size_and_filename() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"Hello, this is the simulated file content!".to_vec();
                let size = file_content.len() as u64;
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                FileStream::new(stream)
                    .file_name("test")
                    .content_size(size)
                    .into_response()
            }),
        );
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"test\""
        );
        assert_eq!(response.headers().get("content-length").unwrap(), "42");
        let body: &[u8] = &response.into_body().collect().await?.to_bytes();
        assert_eq!(
            std::str::from_utf8(body)?,
            "Hello, this is the simulated file content!"
        );
        Ok(())
    }
    #[tokio::test]
    async fn response_from_path() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/from_path",
            get(move || async move {
                FileStream::from_path(Path::new("CHANGELOG.md"))
                    .await
                    .unwrap()
                    .into_response()
            }),
        );
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/from_path")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"CHANGELOG.md\""
        );
        let file = File::open("CHANGELOG.md").await.unwrap();
        let content_length = file.metadata().await.unwrap().len();
        assert_eq!(
            response
                .headers()
                .get("content-length")
                .unwrap()
                .to_str()
                .unwrap(),
            content_length.to_string()
        );
        Ok(())
    }
    #[tokio::test]
    async fn response_range_file() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-1000")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/octet-stream"
        );
        let file = File::open("CHANGELOG.md").await.unwrap();
        let content_length = file.metadata().await.unwrap().len();
        assert_eq!(
            response
                .headers()
                .get("content-range")
                .unwrap()
                .to_str()
                .unwrap(),
            format!("bytes 20-1000/{content_length}")
        );
        Ok(())
    }
    /// An open-ended range (`bytes=N-`) with a non-zero start must stream
    /// from `N` through the last byte, not be rejected as invalid.
    #[tokio::test]
    async fn response_range_file_open_ended() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route("/range_response", get(range_stream));
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_response")
                    .header(header::RANGE, "bytes=20-")
                    .body(Body::empty())?,
            )
            .await?;
        assert_eq!(response.status(), StatusCode::PARTIAL_CONTENT);
        let content_length = File::open("CHANGELOG.md").await?.metadata().await?.len();
        assert_eq!(
            response.headers().get("content-range").unwrap().to_str()?,
            format!("bytes 20-{}/{content_length}", content_length - 1)
        );
        let body = response.into_body().collect().await?.to_bytes();
        assert_eq!(body.len() as u64, content_length - 20);
        Ok(())
    }
    async fn range_stream(headers: HeaderMap) -> Response {
        let range_header = headers
            .get(header::RANGE)
            .and_then(|value| value.to_str().ok());
        let (start, end) = if let Some(range) = range_header {
            if let Some(range) = parse_range_header(range) {
                range
            } else {
                return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range").into_response();
            }
        } else {
            (0, 0)
        };
        FileStream::<ReaderStream<File>>::try_range_response(Path::new("CHANGELOG.md"), start, end)
            .await
            .unwrap()
    }
    /// Parses a single `bytes=START-END` or open-ended `bytes=START-` range.
    ///
    /// An open-ended range maps its end to `0`, which `try_range_response`
    /// treats as "through the last byte". An explicit end smaller than the
    /// start, a non-numeric bound, or a missing `bytes=` prefix yields `None`.
    fn parse_range_header(range: &str) -> Option<(u64, u64)> {
        let range = range.strip_prefix("bytes=")?;
        let (start, end) = range.split_once('-')?;
        let start = start.parse::<u64>().ok()?;
        if end.is_empty() {
            // Open-ended: do not compare against the 0 sentinel, or every
            // `bytes=N-` with N > 0 would be rejected.
            return Some((start, 0));
        }
        let end = end.parse::<u64>().ok()?;
        (start <= end).then_some((start, end))
    }
    #[test]
    fn parse_range_header_variants() {
        assert_eq!(parse_range_header("bytes=20-1000"), Some((20, 1000)));
        assert_eq!(parse_range_header("bytes=0-"), Some((0, 0)));
        assert_eq!(parse_range_header("bytes=20-"), Some((20, 0)));
        assert_eq!(parse_range_header("bytes=30-20"), None);
        assert_eq!(parse_range_header("bytes=abc-20"), None);
        assert_eq!(parse_range_header("items=0-10"), None);
    }
    #[tokio::test]
    async fn filename_escapes_quotes() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"data".to_vec();
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                FileStream::new(stream)
                    .file_name("evil\"; filename*=UTF-8''pwned.txt; x=\"")
                    .into_response()
            }),
        );
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"evil\\\"; filename*=UTF-8''pwned.txt; x=\\\"\""
        );
        Ok(())
    }
    #[tokio::test]
    async fn filename_escapes_backslashes() -> Result<(), Box<dyn std::error::Error>> {
        let app = Router::new().route(
            "/file",
            get(|| async {
                let file_content = b"data".to_vec();
                let reader = Cursor::new(file_content);
                let stream = ReaderStream::new(reader);
                FileStream::new(stream)
                    .file_name("file\\name.txt")
                    .into_response()
            }),
        );
        let response = app
            .oneshot(Request::builder().uri("/file").body(Body::empty())?)
            .await?;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-disposition").unwrap(),
            "attachment; filename=\"file\\\\name.txt\""
        );
        Ok(())
    }
    #[tokio::test]
    async fn response_range_empty_file() -> Result<(), Box<dyn std::error::Error>> {
        let file = tempfile::NamedTempFile::new()?;
        file.as_file().set_len(0)?;
        let path = file.path().to_owned();
        let app = Router::new().route(
            "/range_empty",
            get(move |headers: HeaderMap| {
                let path = path.clone();
                async move {
                    let range_header = headers
                        .get(header::RANGE)
                        .and_then(|value| value.to_str().ok());
                    let (start, end) = if let Some(range) = range_header {
                        if let Some(range) = parse_range_header(range) {
                            range
                        } else {
                            return (StatusCode::RANGE_NOT_SATISFIABLE, "Invalid Range")
                                .into_response();
                        }
                    } else {
                        (0, 0)
                    };
                    FileStream::<ReaderStream<File>>::try_range_response(path, start, end)
                        .await
                        .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response())
                }
            }),
        );
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/range_empty")
                    .header(header::RANGE, "bytes=0-")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::RANGE_NOT_SATISFIABLE);
        Ok(())
    }
}