use std::env;
use std::path::PathBuf;
use std::sync::{Arc};
use std::time::Duration;
use anyhow::bail;
use clap::{Parser, Subcommand};
use dav_server::{memls::MemLs, DavHandler};
#[cfg(unix)]
use futures_util::stream::StreamExt;
use tracing::{debug, info};
use tracing_subscriber::EnvFilter;
#[cfg(unix)]
use {signal_hook::consts::signal::*, signal_hook_tokio::Signals};
use cache::Cache;
use drive::*;
use vfs::QuarkDriveFileSystem;
use webdav::WebDavServer;
mod cache;
mod drive;
mod vfs;
mod webdav;
use tokio::time::interval;
#[derive(Parser, Debug)]
#[command(name = "quarkdrive-webdav", about, version, author)]
#[command(args_conflicts_with_subcommands = true)]
struct Opt {
#[arg(long, env = "HOST", default_value = "0.0.0.0")]
host: String,
#[arg(short, env = "PORT", default_value = "8080")]
port: u16,
#[arg(long, env = "QUARK_COOKIE")]
quark_cookie: Option<String>,
#[arg(short = 'U', long, env = "WEBDAV_AUTH_USER")]
auth_user: Option<String>,
#[arg(short = 'W', long, env = "WEBDAV_AUTH_PASSWORD")]
auth_password: Option<String>,
#[arg(short = 'I', long)]
auto_index: bool,
#[arg(short = 'S', long, default_value = "10485760")]
read_buffer_size: usize,
#[arg(long, default_value = "16777216")]
upload_buffer_size: usize,
#[arg(long, default_value = "1000")]
cache_size: u64,
#[arg(long, default_value = "600")]
cache_ttl: u64,
#[arg(long, env = "WEBDAV_ROOT", default_value = "/")]
root: String,
#[arg(long)]
no_trash: bool,
#[arg(long)]
read_only: bool,
#[arg(long, env = "TLS_CERT")]
tls_cert: Option<PathBuf>,
#[arg(long, env = "TLS_KEY")]
tls_key: Option<PathBuf>,
#[arg(long, env = "WEBDAV_STRIP_PREFIX")]
strip_prefix: Option<String>,
#[arg(long)]
debug: bool,
#[arg(long)]
no_self_upgrade: bool,
#[arg(long)]
skip_upload_same_size: bool,
#[arg(long)]
prefer_http_download: bool,
#[arg(long)]
redirect: bool,
#[command(subcommand)]
subcommands: Option<Commands>,
#[arg(long, env = "REFRESH_CACHE_SECS_INTERVAL", default_value = "3600")]
refresh_cache_secs_interval: u64,
}
#[derive(Subcommand, Debug)]
enum Commands {
#[command(subcommand)]
Qr(QrCommand),
}
#[derive(Subcommand, Debug)]
enum QrCommand {
Login,
Generate,
#[command(arg_required_else_help = true)]
Query {
#[arg(long)]
sid: String,
},
}
pub fn start_periodic_invalidate(cache: Arc<Cache>, secs: u64) {
tokio::spawn(async move {
let mut ticker = interval(Duration::from_secs(secs));
loop {
ticker.tick().await;
cache.invalidate_all();
}
});
}
#[tokio::main(flavor = "multi_thread")]
async fn main() -> anyhow::Result<()> {
#[cfg(feature = "native-tls-vendored")]
openssl_probe::init_openssl_env_vars();
let opt = Opt::parse();
if env::var("RUST_LOG").is_err() {
if opt.debug {
unsafe { env::set_var("RUST_LOG", "quarkdrive_webdav=debug,reqwest=debug"); }
} else {
unsafe { env::set_var("RUST_LOG", "quarkdrive_webdav=info,reqwest=warn"); }
}
}
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.with_timer(tracing_subscriber::fmt::time::time())
.init();
let drive_config = DriveConfig {
api_base_url: "https://drive.quark.cn".to_string(),
cookie: opt.quark_cookie.clone(),
};
let auth_user = opt.auth_user;
let auth_password = opt.auth_password;
if (auth_user.is_some() && auth_password.is_none())
|| (auth_user.is_none() && auth_password.is_some())
{
bail!("auth-user and auth-password must be specified together.");
}
let tls_config = match (opt.tls_cert, opt.tls_key) {
(Some(cert), Some(key)) => Some((cert, key)),
(None, None) => None,
_ => bail!("tls-cert and tls-key must be specified together."),
};
let drive = QuarkDrive::new(drive_config)?;
let mut fs = QuarkDriveFileSystem::new(drive, opt.root, opt.cache_size, opt.cache_ttl)?;
fs.set_no_trash(opt.no_trash)
.set_read_only(opt.read_only)
.set_upload_buffer_size(opt.upload_buffer_size)
.set_skip_upload_same_size(opt.skip_upload_same_size)
.set_prefer_http_download(opt.prefer_http_download);
let cache = Arc::new(fs.dir_cache.clone());
start_periodic_invalidate(cache.clone(), opt.refresh_cache_secs_interval);
#[cfg(unix)]
let mut dav_server_builder = DavHandler::builder()
.filesystem(Box::new(fs))
.locksystem(MemLs::new())
.read_buf_size(opt.read_buffer_size)
.autoindex(opt.auto_index)
.redirect(opt.redirect);
if let Some(prefix) = opt.strip_prefix {
dav_server_builder = dav_server_builder.strip_prefix(prefix);
}
let dav_server = dav_server_builder.build_handler();
debug!(
read_buffer_size = opt.read_buffer_size,
auto_index = opt.auto_index,
"webdav handler initialized"
);
let server = WebDavServer {
host: opt.host,
port: opt.port,
auth_user,
auth_password,
tls_config,
handler: dav_server,
};
#[cfg(not(unix))]
server.serve().await?;
#[cfg(unix)]
{
let signals = Signals::new([SIGHUP])?;
let handle = signals.handle();
let signals_task = tokio::spawn(handle_signals(signals, cache));
server.serve().await?;
handle.close();
signals_task.await?;
}
Ok(())
}
#[cfg(unix)]
async fn handle_signals(mut signals: Signals, dir_cache: Arc<Cache>) {
while let Some(signal) = signals.next().await {
match signal {
SIGHUP => {
dir_cache.invalidate_all();
info!("directory cache invalidated by SIGHUP");
}
_ => unreachable!(),
}
}
}