use async_compression::tokio::write::GzipEncoder;
use async_tar::Builder;
use bytes::BytesMut;
use clap::ValueEnum;
use headers::{ContentType, HeaderMapExt};
use http::{HeaderValue, Method, Response};
use hyper::{Body, body::Sender};
use mime_guess::Mime;
use std::fmt::Display;
use std::path::Path;
use std::path::PathBuf;
use std::str::FromStr;
use std::task::Poll::{Pending, Ready};
use tokio::fs;
use tokio::io;
use tokio::io::AsyncWriteExt;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use crate::Result;
use crate::handler::RequestHandlerOpts;
use crate::http_ext::MethodExt;
/// Name of the request parameter that asks for a directory to be downloaded as an archive.
/// NOTE(review): presumably looked up in the request query string by the caller — the
/// lookup is not in this module, confirm there.
pub const DOWNLOAD_PARAM_KEY: &str = "download";
// Archive formats that can be enabled for directory-listing downloads.
// Serialized in lowercase (e.g. `targz`) via the serde attribute below.
// Plain `//` comments on purpose: clap's `ValueEnum` would turn `///` docs on the
// variants into CLI help text.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, ValueEnum)]
#[serde(rename_all = "lowercase")]
pub enum DirDownloadFmt {
    // Gzip-compressed tar archive, served with a `.tar.gz` suffix.
    Targz,
}
impl Display for DirDownloadFmt {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(self, f)
}
}
/// Per-request settings consumed by `archive_reply`.
pub struct DirDownloadOpts<'a> {
    /// When `true`, entries whose file name starts with `.` are left out of the archive.
    pub ignore_hidden_files: bool,
    /// When `true`, symbolic links are not followed while building the archive.
    pub disable_symlinks: bool,
    /// Method of the incoming request; `HEAD` only receives the response headers.
    pub method: &'a Method,
}
/// Enables the given directory-download formats on `handler_opts`.
///
/// Each format not yet present in `handler_opts.dir_listing_download` is appended and
/// logged; duplicates are ignored. Finally logs whether any download format is enabled.
///
/// Takes a slice rather than `&Vec` so any contiguous list of formats can be passed;
/// existing `&Vec<DirDownloadFmt>` arguments coerce to it automatically.
pub fn init(formats: &[DirDownloadFmt], handler_opts: &mut RequestHandlerOpts) {
    for fmt in formats {
        // A linear `contains` is fine: `DirDownloadFmt` has only a handful of variants.
        if handler_opts.dir_listing_download.contains(fmt) {
            continue;
        }
        tracing::info!("directory listing download: enabled format {}", fmt);
        handler_opts.dir_listing_download.push(fmt.clone());
    }
    tracing::info!(
        "directory listing download: enabled={}",
        !handler_opts.dir_listing_download.is_empty()
    );
}
/// `tokio::io::AsyncWrite` adapter that forwards written bytes as chunks of a hyper
/// streaming response body (see its `AsyncWrite` impl and `archive_reply`).
pub struct ChannelBuffer {
    // Sending half of the `Body::channel()` pair created in `archive_reply`.
    s: Sender,
}
impl tokio::io::AsyncWrite for ChannelBuffer {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
let b = BytesMut::from(buf);
match this.s.poll_ready(cx) {
Ready(r) => match r {
Ok(()) => match this.s.try_send_data(b.freeze()) {
Ok(_) => Ready(Ok(buf.len())),
Err(_) => Pending,
},
Err(e) => Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, e))),
},
Pending => Pending,
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), std::io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
/// Writes the directory tree under `src_path` into `cb` as a gzip-compressed tar stream.
///
/// Every entry is stored as `path` joined with its location relative to `src_path`.
/// Directories are walked depth-first with an explicit stack. When `ignore_hidden` is
/// set, entries whose file name starts with `.` are skipped (a skipped directory is never
/// descended into). With `follow_symlinks`, symlinks that resolve to directories are
/// descended into; all other non-directory entries go to `append_path_with_name`, whose
/// link handling is governed by `Builder::follow_symlinks(follow_symlinks)`.
///
/// # Errors
///
/// Returns the first error from reading the tree, computing relative paths, building the
/// archive or writing to `cb`; the tar/gzip trailer is not written in that case.
async fn archive(
    path: PathBuf,
    src_path: PathBuf,
    cb: ChannelBuffer,
    follow_symlinks: bool,
    ignore_hidden: bool,
) -> Result {
    let gz = GzipEncoder::with_quality(cb, async_compression::Level::Default);
    let mut a = Builder::new(gz.compat_write());
    a.follow_symlinks(follow_symlinks);

    // Pending entries as (path on disk, is directory, is symlink).
    let mut stack = vec![(src_path.clone(), true, false)];
    while let Some((src, is_dir, is_symlink)) = stack.pop() {
        let dest = path.join(src.strip_prefix(&src_path)?);

        // Resolve a followed symlink with the async `fs::metadata` rather than the blocking
        // `Path::is_dir`; a dangling link still counts as "not a directory".
        // NOTE(review): symlink cycles are not detected; a link to an ancestor is
        // presumably descended until the OS refuses to resolve the path — confirm.
        let descend = is_dir
            || (is_symlink
                && follow_symlinks
                && fs::metadata(&src).await.is_ok_and(|meta| meta.is_dir()));

        if descend {
            let mut entries = fs::read_dir(&src).await?;
            while let Some(entry) = entries.next_entry().await? {
                if ignore_hidden && entry.file_name().as_encoded_bytes().starts_with(b".") {
                    continue;
                }
                let file_type = entry.file_type().await?;
                stack.push((entry.path(), file_type.is_dir(), file_type.is_symlink()));
            }
            // `dest` is only empty when `path` is empty: the root of an archive built
            // without a prefix gets no directory entry of its own.
            if dest != Path::new("") {
                a.append_dir(&dest, &src).await?;
            }
        } else {
            a.append_path_with_name(src, &dest).await?;
        }
    }

    a.finish().await?;
    // Shut down the gzip encoder (flushing its trailer) before it, and the channel
    // writer it wraps, is dropped.
    a.into_inner().await?.into_inner().shutdown().await?;
    Ok(())
}
/// Builds the response for a directory download: a `.tar.gz` archive of `src_path`.
///
/// * `path` – name offered in `Content-Disposition` (with `.tar.gz` appended) and used
///   as the prefix of every entry inside the archive.
/// * `src_path` – directory on disk whose contents are archived.
/// * `opts` – request method plus symlink / hidden-file handling.
///
/// For `HEAD` requests only the headers are returned. Otherwise the archive is produced
/// by a spawned Tokio task and streamed through a `Body::channel()`; because headers are
/// already sent by then, failures can only truncate the body and are logged.
pub fn archive_reply<P, Q>(path: P, src_path: Q, opts: DirDownloadOpts<'_>) -> Response<Body>
where
    P: AsRef<Path>,
    Q: AsRef<Path>,
{
    let dir_path = path.as_ref();

    // Append the suffix to the last component. `Path::with_extension` would *replace* a
    // dotted suffix and turn `release-1.2` into `release-1.tar.gz`. Paths without a final
    // component (e.g. "" or "..") keep the `with_extension` behaviour.
    let archive_name = match dir_path.file_name() {
        Some(dir_name) => {
            let mut file_name = dir_name.to_os_string();
            file_name.push(".tar.gz");
            dir_path.with_file_name(file_name)
        }
        None => dir_path.with_extension("tar.gz"),
    };

    let mut resp = Response::new(Body::empty());
    resp.headers_mut().typed_insert(ContentType::from(
        Mime::from_str("application/gzip").unwrap_or(mime_guess::mime::APPLICATION_OCTET_STREAM),
    ));

    // Offer two filename forms: an ASCII-only quoted fallback and the RFC 5987
    // `filename*` form carrying the exact UTF-8 name.
    let archive_name_str = archive_name.to_string_lossy();
    let ascii_safe = sanitize_filename_for_quoted_string(&archive_name_str);
    let percent_encoded = rfc5987_encode_filename(&archive_name_str);
    let hvals =
        format!("attachment; filename=\"{ascii_safe}\"; filename*=UTF-8''{percent_encoded}");
    match HeaderValue::from_str(hvals.as_str()) {
        Ok(hval) => {
            resp.headers_mut()
                .insert(hyper::header::CONTENT_DISPOSITION, hval);
        }
        Err(err) => {
            tracing::error!("can't make content disposition from {}: {:?}", hvals, err);
        }
    }

    if opts.method.is_head() {
        return resp;
    }

    // The task must own everything it uses: `opts` borrows from the request.
    let dest_prefix = dir_path.to_path_buf();
    let src_dir = src_path.as_ref().to_path_buf();
    let follow_symlinks = !opts.disable_symlinks;
    let ignore_hidden = opts.ignore_hidden_files;

    let (tx, body) = Body::channel();
    tokio::task::spawn(async move {
        let result = archive(
            dest_prefix,
            src_dir,
            ChannelBuffer { s: tx },
            follow_symlinks,
            ignore_hidden,
        )
        .await;
        // The JoinHandle is dropped, so this is the only place the error can surface.
        if let Err(err) = result {
            tracing::warn!("directory listing download: archive stream aborted: {:?}", err);
        }
    });
    *resp.body_mut() = body;
    resp
}
/// Produces an ASCII-only stand-in for `name` that is safe inside a quoted
/// `filename="..."` parameter.
///
/// Every character that is not printable ASCII, and every `"` or `\`, is replaced by a
/// single `_`. An empty `name` yields `"download"` so the parameter is never blank.
#[doc(hidden)]
pub fn sanitize_filename_for_quoted_string(name: &str) -> String {
    if name.is_empty() {
        return String::from("download");
    }
    name.chars()
        .map(|ch| {
            let quoted_safe = ch.is_ascii() && !ch.is_ascii_control() && ch != '"' && ch != '\\';
            if quoted_safe {
                ch
            } else {
                '_'
            }
        })
        .collect()
}
/// Percent-encodes `name` for an RFC 5987 `filename*=UTF-8''...` parameter.
///
/// UTF-8 bytes that are RFC 5987 `attr-char`s (ASCII letters, digits and
/// ``!#$&+-.^_`|~``) are copied unchanged; every other byte becomes `%XX` with
/// upper-case hexadecimal digits.
#[doc(hidden)]
pub fn rfc5987_encode_filename(name: &str) -> String {
    // `attr-char` punctuation allowed unescaped, in addition to ASCII alphanumerics.
    const ATTR_PUNCT: &[u8] = b"!#$&+-.^_`|~";
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(name.len());
    for byte in name.bytes() {
        if byte.is_ascii_alphanumeric() || ATTR_PUNCT.contains(&byte) {
            encoded.push(char::from(byte));
            continue;
        }
        encoded.push('%');
        encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
    }
    encoded
}
#[cfg(test)]
mod tests {
    use super::{rfc5987_encode_filename, sanitize_filename_for_quoted_string};

    #[test]
    fn sanitize_strips_quote_and_backslash() {
        // Replacement is one-for-one, so the exact output is known.
        assert_eq!(
            sanitize_filename_for_quoted_string("evil\".tar.gz"),
            "evil_.tar.gz"
        );
        assert_eq!(
            sanitize_filename_for_quoted_string("a\\b.tar.gz"),
            "a_b.tar.gz"
        );
    }

    #[test]
    fn sanitize_strips_control_bytes() {
        let out = sanitize_filename_for_quoted_string("a\r\nb\tc\x00d\x7f");
        assert_eq!(out, "a__b_c_d_");
        assert!(
            out.chars().all(|ch| !ch.is_ascii_control()),
            "control byte leaked: {out:?}"
        );
    }

    #[test]
    fn sanitize_replaces_non_ascii() {
        // One `_` per character, not per UTF-8 byte.
        let out = sanitize_filename_for_quoted_string("rep\u{00f6}rt.tar.gz");
        assert!(out.is_ascii());
        assert_eq!(out, "rep_rt.tar.gz");
    }

    #[test]
    fn sanitize_keeps_printable_ascii() {
        let name = "my dir (v1.2) [final]!.tar.gz";
        assert_eq!(sanitize_filename_for_quoted_string(name), name);
    }

    #[test]
    fn sanitize_never_empty() {
        assert_eq!(sanitize_filename_for_quoted_string(""), "download");
    }

    #[test]
    fn rfc5987_preserves_attr_char_alphabet() {
        let input = "abcXYZ0189!#$&+-.^_`|~";
        assert_eq!(rfc5987_encode_filename(input), input);
    }

    #[test]
    fn rfc5987_encodes_unsafe_bytes() {
        assert_eq!(rfc5987_encode_filename("a b"), "a%20b");
        assert_eq!(rfc5987_encode_filename("a\"b"), "a%22b");
        assert_eq!(rfc5987_encode_filename("a\\b"), "a%5Cb");
        assert_eq!(rfc5987_encode_filename("a\r\nb"), "a%0D%0Ab");
        assert_eq!(rfc5987_encode_filename("\u{00f6}"), "%C3%B6");
        // Characters outside `attr-char`, including `%` itself and path separators.
        assert_eq!(rfc5987_encode_filename("50%*'()/x"), "50%25%2A%27%28%29%2Fx");
    }

    #[test]
    fn rfc5987_empty_input() {
        assert_eq!(rfc5987_encode_filename(""), "");
    }
}