use clap::{Parser, ValueHint};
use clap_complete::Shell;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use std::net::IpAddr;
use std::path::PathBuf;
use crate::auth;
use crate::errors::ContextualError;
use crate::renderer;
// Media categories accepted as values of the `--media-type` option
// (see `CliArgs::media_type`, which is declared with `arg_enum`).
//
// `clap::ArgEnum` is derived so the variants can be used as command-line
// values; `Clone` is derived alongside it.
// NOTE(review): kept as `//` comments because clap's derive may surface `///`
// text on variants as help for the possible values — confirm before promoting.
#[derive(clap::ArgEnum, Clone)]
pub enum MediaType {
// Image media.
Image,
// Audio media.
Audio,
// Video media.
Video,
}
// Command-line arguments of `miniserve`, parsed through clap's `Parser` derive.
//
// NOTE(review): the notes below are plain `//` comments on purpose. clap's
// derive turns `///` doc comments on fields into user-visible help text, so
// promoting any of these to `///` changes `--help` output; do that only once
// the wording has been confirmed against where each option is consumed.
//
// `author`, `about` and `version` are given without explicit values here.
#[derive(Parser)]
#[clap(name = "miniserve", author, about, version)]
pub struct CliArgs {
// `-v` / `--verbose` boolean flag.
#[clap(short = 'v', long = "verbose")]
pub verbose: bool,
// Optional positional argument shown as `PATH`; converted from the raw OS
// string. Completion hint: any path.
#[clap(name = "PATH", parse(from_os_str), value_hint = ValueHint::AnyPath)]
pub path: Option<PathBuf>,
// `--index <path>`; optional, converted from the raw OS string. The argument
// is registered under the name `index_file`, which `--spa` refers to below.
#[clap(long, parse(from_os_str), name = "index_file", value_hint = ValueHint::FilePath)]
pub index: Option<PathBuf>,
// `--spa` boolean flag; only accepted together with `index_file` (`--index`).
#[clap(long, requires = "index_file")]
pub spa: bool,
// `-p` / `--port`; parsed as `u16`, defaults to 8080 when omitted.
#[clap(short = 'p', long = "port", default_value = "8080")]
pub port: u16,
// `-i` / `--interfaces`; may be repeated, exactly one value per occurrence,
// each value converted with `parse_interface`.
#[clap(
short = 'i',
long = "interfaces",
parse(try_from_str = parse_interface),
multiple_occurrences(true),
number_of_values = 1,
)]
pub interfaces: Vec<IpAddr>,
// `-a` / `--auth`; may be repeated, exactly one value per occurrence,
// each value converted with `parse_auth`.
#[clap(
short = 'a',
long = "auth",
parse(try_from_str = parse_auth),
multiple_occurrences(true),
number_of_values = 1,
)]
pub auth: Vec<auth::RequiredAuth>,
// `--route-prefix <string>`; optional.
#[clap(long = "route-prefix")]
pub route_prefix: Option<String>,
// `--random-route` boolean flag; declared as conflicting with `route-prefix`.
#[clap(long = "random-route", conflicts_with("route-prefix"))]
pub random_route: bool,
// `-P` / `--no-symlinks` boolean flag.
#[clap(short = 'P', long = "no-symlinks")]
pub no_symlinks: bool,
// `-H` / `--hidden` boolean flag.
#[clap(short = 'H', long = "hidden")]
pub hidden: bool,
// `-c` / `--color-scheme`; defaults to "squirrel". Must be one of
// `renderer::THEME_SLUGS`, matched case-insensitively.
#[clap(
short = 'c',
long = "color-scheme",
default_value = "squirrel",
possible_values = &*renderer::THEME_SLUGS,
ignore_case = true,
)]
pub color_scheme: String,
// `-d` / `--color-scheme-dark`; defaults to "archlinux". Must be one of
// `renderer::THEME_SLUGS`, matched case-insensitively.
#[clap(
short = 'd',
long = "color-scheme-dark",
default_value = "archlinux",
possible_values = &*renderer::THEME_SLUGS,
ignore_case = true,
)]
pub color_scheme_dark: String,
// `-q` / `--qrcode` boolean flag.
#[clap(short = 'q', long = "qrcode")]
pub qrcode: bool,
// `-u` / `--upload-files [path...]`; `min_values = 0` lets the option be given
// without any path. Completion hint: file path.
// NOTE(review): presumably `None` means "option absent" and `Some(vec![])`
// means "given with no paths" — confirm against clap's `Option<Vec<T>>` rules.
#[clap(short = 'u', long = "upload-files", value_hint = ValueHint::FilePath, min_values = 0)]
pub allowed_upload_dir: Option<Vec<PathBuf>>,
// `-U` / `--mkdir` boolean flag; requires the `allowed-upload-dir` argument
// (the `--upload-files` option above).
#[clap(short = 'U', long = "mkdir", requires = "allowed-upload-dir")]
pub mkdir_enabled: bool,
// `-m` / `--media-type`; values are `MediaType` variants (`arg_enum`).
// Requires the `allowed-upload-dir` argument.
#[clap(
arg_enum,
short = 'm',
long = "media-type",
requires = "allowed-upload-dir"
)]
pub media_type: Option<Vec<MediaType>>,
// `-M` / `--raw-media-type <string>`; requires the `allowed-upload-dir`
// argument and conflicts with `media-type`.
#[clap(
short = 'M',
long = "raw-media-type",
requires = "allowed-upload-dir",
conflicts_with = "media-type"
)]
pub media_type_raw: Option<String>,
// `-o` / `--overwrite-files` boolean flag.
#[clap(short = 'o', long = "overwrite-files")]
pub overwrite_files: bool,
// `-r` / `--enable-tar` boolean flag.
#[clap(short = 'r', long = "enable-tar")]
pub enable_tar: bool,
// `-g` / `--enable-tar-gz` boolean flag.
#[clap(short = 'g', long = "enable-tar-gz")]
pub enable_tar_gz: bool,
// `-z` / `--enable-zip` boolean flag.
#[clap(short = 'z', long = "enable-zip")]
pub enable_zip: bool,
// `-D` / `--dirs-first` boolean flag.
#[clap(short = 'D', long = "dirs-first")]
pub dirs_first: bool,
// `-t` / `--title <string>`; optional.
#[clap(short = 't', long = "title")]
pub title: Option<String>,
// `--header`; may be repeated, exactly one value per occurrence. Each value is
// converted with `parse_header`, so every occurrence yields its own `HeaderMap`.
#[clap(
long = "header",
parse(try_from_str = parse_header),
multiple_occurrences(true),
number_of_values = 1
)]
pub header: Vec<HeaderMap>,
// `-l` / `--show-symlink-info` boolean flag.
#[clap(short = 'l', long = "show-symlink-info")]
pub show_symlink_info: bool,
// `-F` / `--hide-version-footer` boolean flag.
#[clap(short = 'F', long = "hide-version-footer")]
pub hide_version_footer: bool,
// `--hide-theme-selector` boolean flag.
#[clap(long = "hide-theme-selector")]
pub hide_theme_selector: bool,
// `-W` / `--show-wget-footer` boolean flag.
#[clap(short = 'W', long = "show-wget-footer")]
pub show_wget_footer: bool,
// `--print-completions <shell>`; the value is one of `clap_complete::Shell`'s
// variants (`arg_enum`).
#[clap(long = "print-completions", value_name = "shell", arg_enum)]
pub print_completions: Option<Shell>,
// `--print-manpage` boolean flag.
#[clap(long = "print-manpage")]
pub print_manpage: bool,
// `--tls-cert <path>`; only compiled in with the `tls` feature. Requires
// `tls-key` to be supplied as well.
#[cfg(feature = "tls")]
#[clap(long = "tls-cert", requires = "tls-key", value_hint = ValueHint::FilePath)]
pub tls_cert: Option<PathBuf>,
// `--tls-key <path>`; only compiled in with the `tls` feature. Requires
// `tls-cert` to be supplied as well.
#[cfg(feature = "tls")]
#[clap(long = "tls-key", requires = "tls-cert", value_hint = ValueHint::FilePath)]
pub tls_key: Option<PathBuf>,
// `--readme` boolean flag.
#[clap(long)]
pub readme: bool,
}
/// Converts one `--interfaces` value into an [`IpAddr`].
///
/// The text is handed to `IpAddr`'s `FromStr` implementation as-is (no
/// trimming), and any [`std::net::AddrParseError`] it produces is returned
/// unchanged.
fn parse_interface(src: &str) -> Result<IpAddr, std::net::AddrParseError> {
    <IpAddr as std::str::FromStr>::from_str(src)
}
/// Converts one `--auth` value into an [`auth::RequiredAuth`].
///
/// The input is split on `:` into at most three fields, so the third field
/// keeps any further colons:
///
/// * `username:password` — the password is stored as
///   `RequiredAuthPassword::Plain`;
/// * `username:sha256:<hex>` or `username:sha512:<hex>` — `<hex>` is
///   hex-decoded and stored as `RequiredAuthPassword::Sha256` / `Sha512`.
///
/// # Errors
///
/// * `ContextualError::InvalidAuthFormat` when there is no second field
///   (the input contains no `:`);
/// * `ContextualError::InvalidPasswordHash` when a third field is present but
///   is not valid hex;
/// * `ContextualError::InvalidHashMethod` when the third field is valid hex but
///   the second field is neither `sha256` nor `sha512`;
/// * `ContextualError::PasswordTooLongError` when a plain password is longer
///   than 255 bytes.
fn parse_auth(src: &str) -> Result<auth::RequiredAuth, ContextualError> {
    let mut fields = src.splitn(3, ':');

    let username = fields.next().ok_or(ContextualError::InvalidAuthFormat)?;
    // Either the plain password or the hash method, depending on whether a
    // third field follows.
    let method_or_password = fields.next().ok_or(ContextualError::InvalidAuthFormat)?;

    let password = match fields.next() {
        Some(hash_hex) => {
            // The hash is decoded before the method is looked at, so a bad hex
            // payload is reported even when the method is unknown too.
            let digest =
                hex::decode(hash_hex).map_err(|_| ContextualError::InvalidPasswordHash)?;
            match method_or_password {
                "sha256" => auth::RequiredAuthPassword::Sha256(digest),
                "sha512" => auth::RequiredAuthPassword::Sha512(digest),
                unknown => return Err(ContextualError::InvalidHashMethod(unknown.to_owned())),
            }
        }
        // `str::len` counts bytes, so the limit is 255 bytes, not characters.
        None if method_or_password.len() > 255 => {
            return Err(ContextualError::PasswordTooLongError)
        }
        None => auth::RequiredAuthPassword::Plain(method_or_password.to_owned()),
    };

    Ok(auth::RequiredAuth {
        username: username.to_owned(),
        password,
    })
}
pub fn parse_header(src: &str) -> Result<HeaderMap, httparse::Error> {
let mut headers = [httparse::EMPTY_HEADER; 1];
let header = format!("{}\n", src);
httparse::parse_headers(header.as_bytes(), &mut headers)?;
let mut header_map = HeaderMap::new();
if let Some(h) = headers.first() {
if h.name != httparse::EMPTY_HEADER.name {
header_map.insert(
HeaderName::from_bytes(h.name.as_bytes()).unwrap(),
HeaderValue::from_bytes(h.value).unwrap(),
);
}
}
Ok(header_map)
}
// Unit tests for `parse_auth`.
// rustfmt is disabled for the whole module (presumably to keep the rstest case
// tables and long expected-message strings laid out as written).
#[rustfmt::skip]
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
// Shadows the standard `assert_eq!` inside this module.
use pretty_assertions::assert_eq;
/// Builds the `RequiredAuth` a test case expects.
///
/// `encrypt` selects the password variant: `"plain"` stores `password`
/// verbatim, `"sha256"` / `"sha512"` hex-decode `password` into the matching
/// hash variant. Panics on any other `encrypt` value or on invalid hex.
fn create_required_auth(username: &str, password: &str, encrypt: &str) -> auth::RequiredAuth {
use auth::*;
use RequiredAuthPassword::*;
let password = match encrypt {
"plain" => Plain(password.to_owned()),
"sha256" => Sha256(hex::decode(password.to_owned()).unwrap()),
"sha512" => Sha512(hex::decode(password.to_owned()).unwrap()),
_ => panic!("Unknown encryption type"),
};
auth::RequiredAuth {
username: username.to_owned(),
password,
}
}
// Well-formed credentials: each case parses `auth_string` and compares the
// result with the value built from (`username`, `password`, `encrypt`).
#[rstest(
auth_string, username, password, encrypt,
case("username:password", "username", "password", "plain"),
case("username:sha256:abcd", "username", "abcd", "sha256"),
case("username:sha512:abcd", "username", "abcd", "sha512")
)]
fn parse_auth_valid(auth_string: &str, username: &str, password: &str, encrypt: &str) {
assert_eq!(
parse_auth(auth_string).unwrap(),
create_required_auth(username, password, encrypt),
);
}
// Malformed credentials: each case expects `parse_auth` to fail and the
// error's `Display` output to equal `err_msg` exactly. Covers a missing `:`
// separator, an unknown hash method, and non-hex hashes for both methods.
#[rstest(
auth_string, err_msg,
case(
"foo",
"Invalid format for credentials string. Expected username:password, username:sha256:hash or username:sha512:hash"
),
case(
"username:blahblah:abcd",
"blahblah is not a valid hashing method. Expected sha256 or sha512"
),
case(
"username:sha256:invalid",
"Invalid format for password hash. Expected hex code"
),
case(
"username:sha512:invalid",
"Invalid format for password hash. Expected hex code"
),
)]
fn parse_auth_invalid(auth_string: &str, err_msg: &str) {
let err = parse_auth(auth_string).unwrap_err();
assert_eq!(format!("{}", err), err_msg.to_owned());
}
}