use super::content_disposition::EscapedFilename;
use axum_core::response::IntoResponse;
use http::{header, HeaderMap, HeaderValue};
use tracing::error;
#[derive(Debug)]
#[must_use]
pub struct Attachment<T> {
inner: T,
filename: Option<HeaderValue>,
content_type: Option<HeaderValue>,
}
impl<T: IntoResponse> Attachment<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
filename: None,
content_type: None,
}
}
pub fn filename<H: TryInto<HeaderValue>>(mut self, value: H) -> Self {
self.filename = if let Ok(filename) = value.try_into() {
Some(filename)
} else {
error!("Attachment filename contains invalid characters");
None
};
self
}
pub fn content_type<H: TryInto<HeaderValue>>(mut self, value: H) -> Self {
if let Ok(content_type) = value.try_into() {
self.content_type = Some(content_type);
} else {
error!("Attachment content-type contains invalid characters");
}
self
}
}
impl<T> IntoResponse for Attachment<T>
where
T: IntoResponse,
{
fn into_response(self) -> axum_core::response::Response {
let mut headers = HeaderMap::new();
if let Some(content_type) = self.content_type {
headers.append(header::CONTENT_TYPE, content_type);
}
let content_disposition = if let Some(filename) = self.filename {
let filename_str = filename
.to_str()
.expect("This was a HeaderValue so this can not fail");
let value = format!("attachment; filename=\"{}\"", EscapedFilename(filename_str));
HeaderValue::try_from(value).expect("This was a HeaderValue so this can not fail")
} else {
HeaderValue::from_static("attachment")
};
headers.append(header::CONTENT_DISPOSITION, content_disposition);
(headers, self.inner).into_response()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum_core::response::IntoResponse;
    use http::header::CONTENT_DISPOSITION;

    /// Renders `response` and returns its `Content-Disposition` header as text.
    fn disposition_header<R: IntoResponse>(response: R) -> String {
        let rendered = response.into_response();
        let value = rendered
            .headers()
            .get(CONTENT_DISPOSITION)
            .expect("Attachment must always set Content-Disposition");
        value
            .to_str()
            .expect("test filenames are plain ASCII")
            .to_owned()
    }

    #[test]
    fn attachment_without_filename() {
        assert_eq!(disposition_header(Attachment::new("data")), "attachment");
    }

    #[test]
    fn attachment_with_normal_filename() {
        let header = disposition_header(Attachment::new("data").filename("report.pdf"));
        assert_eq!(header, r#"attachment; filename="report.pdf""#);
    }

    #[test]
    fn attachment_filename_escapes_quotes() {
        // A filename that tries to close the quoted-string and inject an extra
        // `filename*` parameter must remain inside the quotes.
        let hostile = "evil\"; filename*=UTF-8''pwned.txt; x=\"";
        let header = disposition_header(Attachment::new("data").filename(hostile));
        assert_eq!(
            header,
            "attachment; filename=\"evil\\\"; filename*=UTF-8''pwned.txt; x=\\\"\""
        );
    }

    #[test]
    fn attachment_filename_escapes_backslashes() {
        let header = disposition_header(Attachment::new("data").filename("file\\name.txt"));
        assert_eq!(header, r#"attachment; filename="file\\name.txt""#);
    }
}