use async_compression::tokio::write::GzipEncoder;
use async_tar::Builder;
use bytes::BytesMut;
use clap::ValueEnum;
use headers::{ContentType, HeaderMapExt};
use http::{HeaderValue, Method, Response};
use hyper::{Body, body::Sender};
use mime_guess::Mime;
use std::fmt::Display;
use std::path::Path;
use std::path::PathBuf;
use std::str::FromStr;
use std::task::Poll::{Pending, Ready};
use tokio::fs;
use tokio::io;
use tokio::io::AsyncWriteExt;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use crate::Result;
use crate::handler::RequestHandlerOpts;
use crate::http_ext::MethodExt;
/// Parameter key with the value `"download"`.
// NOTE(review): presumably the request query parameter that selects a directory
// download; the lookup itself is not in this part of the file, so confirm at the caller.
pub const DOWNLOAD_PARAM_KEY: &str = "download";
/// Archive formats that can be enabled for directory listing downloads.
///
/// Variant names are (de)serialized in lowercase (see the `serde` attribute below).
#[derive(Debug, Serialize, Deserialize, Clone, ValueEnum, Eq, Hash, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DirDownloadFmt {
/// Gzip-compressed tar archive (`.tar.gz`).
Targz,
}
impl Display for DirDownloadFmt {
    /// Renders the format using the variant name, i.e. the same text as its
    /// `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}
/// Options that control how a directory download reply is built.
pub struct DirDownloadOpts<'a> {
/// HTTP method of the incoming request; a `HEAD` request gets the headers only.
pub method: &'a Method,
/// When `true`, symlinks are not followed while building the archive.
pub disable_symlinks: bool,
/// When `true`, entries whose name starts with `.` are left out of the archive.
pub ignore_hidden_files: bool,
}
/// Registers the requested directory download formats on the handler options.
///
/// Each entry of `formats` that is not yet present in
/// `handler_opts.dir_listing_download` is appended (and logged), so duplicates
/// are never stored. A final log line reports whether at least one format is
/// enabled.
///
/// * `formats` - formats to enable; accepts any slice (a `&Vec<_>` coerces).
/// * `handler_opts` - request handler options updated in place.
pub fn init(formats: &[DirDownloadFmt], handler_opts: &mut RequestHandlerOpts) {
    for fmt in formats {
        // Skip formats that were already registered.
        if handler_opts.dir_listing_download.contains(fmt) {
            continue;
        }
        tracing::info!("directory listing download: enabled format {}", fmt);
        handler_opts.dir_listing_download.push(fmt.clone());
    }
    tracing::info!(
        "directory listing download: enabled={}",
        !handler_opts.dir_listing_download.is_empty()
    );
}
/// Writer that forwards every written chunk to the sending half of a `hyper`
/// body channel (see the `AsyncWrite` implementation).
pub struct ChannelBuffer {
// Sending half of the channel; `archive_reply` creates it via `Body::channel()`.
s: Sender,
}
/// Lets a [`ChannelBuffer`] be used as an asynchronous byte sink: every write
/// is copied into one chunk and handed to the body channel.
impl tokio::io::AsyncWrite for ChannelBuffer {
    /// Sends `buf` as a single chunk once the channel reports readiness.
    ///
    /// Returns `Ready(Ok(buf.len()))` when the whole buffer was queued,
    /// `Ready(Err(BrokenPipe))` when the readiness check fails, and `Pending`
    /// when the channel cannot take data yet.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let this = self.get_mut();
        match this.s.poll_ready(cx) {
            // `poll_ready` was given `cx`, so it is responsible for the wake-up.
            Pending => Pending,
            Ready(Err(e)) => Ready(Err(io::Error::new(io::ErrorKind::BrokenPipe, e))),
            Ready(Ok(())) => {
                // Copy the caller's bytes only when the channel can accept
                // them, instead of on every (possibly pending) poll.
                let chunk = BytesMut::from(buf).freeze();
                match this.s.try_send_data(chunk) {
                    Ok(_) => Ready(Ok(buf.len())),
                    Err(_) => {
                        // No waker was registered on this path. Returning
                        // `Pending` without one could stall the task forever,
                        // so request an immediate re-poll; the next
                        // `poll_ready` call then either registers the waker
                        // or reports that the channel is closed.
                        cx.waker().wake_by_ref();
                        Pending
                    }
                }
            }
        }
    }

    /// Nothing is buffered by this type, so flushing completes immediately.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Shutdown performs no work of its own and completes immediately.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        std::task::Poll::Ready(Ok(()))
    }
}
/// Writes the directory tree rooted at `src_path` into `cb` as a
/// gzip-compressed tar stream.
///
/// The tree is walked iteratively with an explicit stack. Each entry is stored
/// in the archive under `path` joined with the entry's location relative to
/// `src_path`.
///
/// * `path` - prefix used for entry names inside the archive.
/// * `src_path` - directory on disk to archive.
/// * `cb` - sink receiving the compressed bytes.
/// * `follow_symlinks` - forwarded to the tar builder; also makes the walk
///   descend into symlinks that resolve to directories.
/// * `ignore_hidden` - skip entries whose file name starts with `.`.
///
/// # Errors
///
/// Fails when an entry is not located under `src_path`, when a directory or
/// file type cannot be read, or when appending to / finishing / shutting down
/// the archive stream fails.
async fn archive(
    path: PathBuf,
    src_path: PathBuf,
    cb: ChannelBuffer,
    follow_symlinks: bool,
    ignore_hidden: bool,
) -> Result {
    let gz = GzipEncoder::with_quality(cb, async_compression::Level::Default);
    let mut a = Builder::new(gz.compat_write());
    a.follow_symlinks(follow_symlinks);

    // Pending entries as (path on disk, is directory, is symlink).
    let mut stack = vec![(src_path.clone(), true, false)];
    while let Some((src, is_dir, is_symlink)) = stack.pop() {
        let dest = path.join(src.strip_prefix(&src_path)?);

        // `DirEntry::file_type` does not resolve symlinks, so a followed link
        // needs a metadata lookup to find out whether it targets a directory.
        // Use the async lookup so the executor thread is not blocked; a failed
        // lookup (e.g. a dangling link) is treated as "not a directory".
        let descend = is_dir
            || (is_symlink
                && follow_symlinks
                && fs::metadata(&src)
                    .await
                    .map(|meta| meta.is_dir())
                    .unwrap_or(false));
        // NOTE(review): there is no explicit guard against symlink cycles
        // here; the walk relies on the OS rejecting over-long link chains.

        if descend {
            let mut entries = fs::read_dir(&src).await?;
            while let Some(entry) = entries.next_entry().await? {
                let name = entry.file_name();
                if ignore_hidden && name.as_encoded_bytes().first().is_some_and(|c| *c == b'.') {
                    continue;
                }
                let file_type = entry.file_type().await?;
                stack.push((entry.path(), file_type.is_dir(), file_type.is_symlink()));
            }
            // An empty destination would be a nameless entry; skip it.
            if dest != Path::new("") {
                a.append_dir(&dest, &src).await?;
            }
        } else {
            a.append_path_with_name(src, &dest).await?;
        }
    }

    // Finish the tar stream, then shut down the gzip encoder underneath it.
    a.finish().await?;
    a.into_inner().await?.into_inner().shutdown().await?;
    Ok(())
}
pub fn archive_reply<P, Q>(path: P, src_path: Q, opts: DirDownloadOpts<'_>) -> Response<Body>
where
P: AsRef<Path>,
Q: AsRef<Path>,
{
let archive_name = path.as_ref().with_extension("tar.gz");
let mut resp = Response::new(Body::empty());
resp.headers_mut().typed_insert(ContentType::from(
Mime::from_str("application/gzip").unwrap(),
));
let hvals = format!(
"attachment; filename=\"{}\"",
archive_name.to_string_lossy()
);
match HeaderValue::from_str(hvals.as_str()) {
Ok(hval) => {
resp.headers_mut()
.insert(hyper::header::CONTENT_DISPOSITION, hval);
}
Err(err) => {
tracing::error!("can't make content disposition from {}: {:?}", hvals, err);
}
}
if opts.method.is_head() {
return resp;
}
let (tx, body) = Body::channel();
tokio::task::spawn(archive(
path.as_ref().into(),
src_path.as_ref().into(),
ChannelBuffer { s: tx },
!opts.disable_symlinks,
opts.ignore_hidden_files,
));
*resp.body_mut() = body;
resp
}