use malwaredb_server::{State, StateBuilder};
use std::fmt::{Debug, Formatter};
use std::net::IpAddr;
use std::path::PathBuf;
use std::str::FromStr;
use anyhow::{bail, Context};
use bytesize::ByteSize;
use clap::{Parser, ValueHint};
use home::home_dir;
use serde::Deserialize;
#[cfg(feature = "admin")]
use serde::Serialize;
use zeroize::Zeroizing;
/// System-wide locations searched for `mdb_config.toml` on Unix systems other than macOS.
///
/// `Config::from_found_files` checks these, in order, only after
/// `./mdb_config.toml` and `~/.mdb_server/mdb_config.toml` were not found.
#[cfg(all(not(target_os = "macos"), target_family = "unix"))]
pub(crate) const CONFIG_PATHS: [&str; 2] = [
"/etc/mdb/mdb_config.toml",
"/usr/local/etc/mdb/mdb_config.toml",
];
/// System-wide location searched for the config file on macOS, after the
/// working-directory and home-directory candidates.
#[cfg(target_os = "macos")]
pub(crate) const CONFIG_PATHS: [&str; 1] = ["/Library/Preferences/MalwareDB/mdb_config.toml"];
/// System-wide location searched for the config file on Windows, after the
/// working-directory and home-directory candidates.
#[cfg(target_family = "windows")]
pub(crate) const CONFIG_PATHS: [&str; 1] = ["C:\\Program Files\\MalwareDB\\mdb_config.toml"];
/// Fallback for targets that are neither Unix nor Windows. This is a relative
/// path, so it is resolved against the process's current working directory.
#[cfg(all(not(target_family = "windows"), not(target_family = "unix")))]
pub(crate) const CONFIG_PATHS: [&str; 1] = ["mdb_config.toml"];
/// MalwareDB server configuration, read from command-line arguments or a TOML file.
#[derive(Parser, Deserialize)]
#[cfg_attr(feature = "admin", derive(Serialize))]
pub struct Config {
    /// Port the server listens on.
    #[arg(short, long, default_value_t = default_port())]
    #[serde(default = "default_port")]
    pub port: u16,

    /// Directory where sample files are saved; files are not saved if omitted.
    #[arg(long, value_hint = ValueHint::DirPath)]
    pub dir: Option<PathBuf>,

    /// IP address the server listens on.
    #[arg(short, long, default_value_t = default_ip_addr())]
    #[serde(default = "default_ip_addr")]
    pub ip: IpAddr,

    /// Maximum upload size accepted by the server.
    #[arg(short, long, default_value_t = default_max_upload_size())]
    #[serde(default = "default_max_upload_size")]
    pub max_upload_size: ByteSize,

    /// Database connection string.
    // Held in `Zeroizing` so the string, which may embed credentials, is wiped
    // from memory when it is dropped.
    #[arg(long)]
    pub db: Zeroizing<String>,

    /// TLS certificate file; TLS is only enabled when both `cert` and `key` are set.
    #[arg(long, value_hint = ValueHint::FilePath)]
    #[serde(default)]
    pub cert: Option<PathBuf>,

    /// TLS private key file; TLS is only enabled when both `cert` and `key` are set.
    #[arg(long, value_hint = ValueHint::FilePath)]
    #[serde(default)]
    pub key: Option<PathBuf>,

    /// Certificate file passed to the database connection setup.
    // NOTE(review): presumably a CA/client certificate for PostgreSQL TLS — confirm
    // against `StateBuilder::new`.
    #[arg(long, value_hint = ValueHint::FilePath)]
    #[serde(default)]
    pub pg_cert: Option<PathBuf>,

    /// Optional VirusTotal client.
    // In the TOML file its settings sit at the top level (`flatten`), not in a sub-table.
    #[cfg(feature = "vt")]
    #[serde(default, flatten)]
    pub vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// Enable mDNS.
    #[arg(long, action, default_value_t = false)]
    #[serde(default)]
    pub mdns: bool,
}
/// Port used when neither the command line nor the config file specifies one.
const fn default_port() -> u16 {
    8_080
}
/// Address used when no `ip` is configured: the IPv4 loopback `127.0.0.1`,
/// so by default the server only listens locally.
const fn default_ip_addr() -> IpAddr {
    let loopback = std::net::Ipv4Addr::new(127, 0, 0, 1);
    IpAddr::V4(loopback)
}
const fn default_max_upload_size() -> ByteSize {
ByteSize::mib(100)
}
impl Config {
pub async fn into_state(self) -> anyhow::Result<State> {
let builder = StateBuilder::new(&self.db, self.pg_cert).await?;
let mut builder = builder
.ip(self.ip)
.port(self.port)
.max_upload(usize::try_from(self.max_upload_size.0)?);
#[cfg(feature = "vt")]
if let Some(vt_client) = self.vt_client.clone() {
builder = builder.vt_client(vt_client);
}
if let Some(dir) = self.dir {
builder = builder.directory(dir);
}
if self.mdns {
builder = builder.enable_mdns();
}
if self.cert.is_some() && self.key.is_some() {
let Some(cert) = self.cert else {
return Err(anyhow::anyhow!("Certificate file not found!"));
};
let Some(key) = self.key else {
return Err(anyhow::anyhow!("Key file not found!"));
};
builder = builder.tls(cert, key).await?;
}
builder.into_state().await
}
pub fn from_found_files() -> anyhow::Result<Self> {
let current_config = if let Ok(mut current_dir) = std::env::current_dir() {
current_dir.push("mdb_config.toml");
if current_dir.exists() {
return Self::from_file(¤t_dir);
}
Some(current_dir)
} else {
None
};
let home_config = if let Some(mut home_config) = home_dir() {
home_config.push(".mdb_server");
home_config.push("mdb_config.toml");
if home_config.exists() {
return Self::from_file(&home_config);
}
Some(home_config)
} else {
None
};
for system_path in CONFIG_PATHS {
let system_path = PathBuf::from_str(system_path)?;
if system_path.exists() {
return Self::from_file(&system_path);
}
}
let mut dirs = Vec::with_capacity(3);
if let Some(current_dir) = current_config {
dirs.push(current_dir);
}
if let Some(home_config) = home_config {
dirs.push(home_config);
}
for path in CONFIG_PATHS {
dirs.push(PathBuf::from(path));
}
bail!(
"MalwareDB could not automatically find a configuration file, checked: {:?}",
dirs.iter()
.map(|p| format!("{}", p.display()))
.collect::<Vec<String>>()
.join(",")
)
}
pub fn from_file(path: &PathBuf) -> anyhow::Result<Self> {
let config = std::fs::read_to_string(path)
.context(format!("failed to read config file {}", path.display()))?;
let cfg: Config = toml::from_str(&config)
.context(format!("failed to parse config file {}", path.display()))?;
Ok(cfg)
}
}
impl Debug for Config {
    /// Human-readable summary of where the server listens and where it stores files.
    ///
    /// Intentionally omits `db`, which may embed database credentials.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // `SocketAddr` brackets IPv6 addresses (`[::1]:8080`), so the port cannot be
        // misread as part of the address; IPv4 output is unchanged (`127.0.0.1:8080`).
        let addr = std::net::SocketAddr::new(self.ip, self.port);
        write!(f, "MalwareDB listening on {}", addr)?;
        match &self.dir {
            Some(path) => write!(f, " saving files to {}", path.display()),
            None => write!(f, " not saving files"),
        }
    }
}
/// Template configuration with placeholder values, available with the `admin` feature.
///
/// Network settings use the regular defaults. `dir` and `db` hold example values
/// meant to be edited; they are not usable as-is.
#[cfg(feature = "admin")]
impl Default for Config {
    fn default() -> Self {
        // With SQLite support the template shows both connection-string styles.
        #[cfg(feature = "sqlite")]
        let db = "sqlite: file:/path/to/sqlite.db, postgres: postgres host=localhost dbname=malwaredb user=malwaredb password=malwaredb";
        #[cfg(not(feature = "sqlite"))]
        let db = "postgres host=localhost dbname=malwaredb user=malwaredb password=malwaredb";

        Self {
            ip: default_ip_addr(),
            port: default_port(),
            max_upload_size: default_max_upload_size(),
            dir: Some(PathBuf::from("/path/to/samples/")),
            db: Zeroizing::new(String::from(db)),
            cert: None,
            key: None,
            pg_cert: None,
            #[cfg(feature = "vt")]
            vt_client: None,
            mdns: false,
        }
    }
}