malwaredb 0.3.2

Service for storing malicious, benign, or unknown files and related metadata and relationships.
// SPDX-License-Identifier: Apache-2.0

use malwaredb_server::{State, StateBuilder};

use std::fmt::{Debug, Formatter};
use std::net::IpAddr;
use std::path::PathBuf;
use std::str::FromStr;

use anyhow::{bail, Context};
use bytesize::ByteSize;
use clap::{Parser, ValueHint};
use home::home_dir;
use serde::Deserialize;
#[cfg(feature = "admin")]
use serde::Serialize;
use zeroize::Zeroizing;

/// System-wide configuration file locations on Unix-like systems other than macOS.
#[cfg(all(unix, not(target_os = "macos")))]
pub(crate) const CONFIG_PATHS: [&str; 2] = [
    "/etc/mdb/mdb_config.toml",
    "/usr/local/etc/mdb/mdb_config.toml",
];

/// System-wide configuration file location on macOS.
#[cfg(target_os = "macos")]
pub(crate) const CONFIG_PATHS: [&str; 1] = ["/Library/Preferences/MalwareDB/mdb_config.toml"];

/// System-wide configuration file location on Windows.
#[cfg(windows)]
pub(crate) const CONFIG_PATHS: [&str; 1] = ["C:\\Program Files\\MalwareDB\\mdb_config.toml"];

/// Fallback for every other target family: a config file path relative to the
/// current working directory.
#[cfg(not(any(windows, unix)))]
pub(crate) const CONFIG_PATHS: [&str; 1] = ["mdb_config.toml"];

/// Malware DB configuration parameters
// One struct, two input sources: command-line arguments (clap's `Parser` derive)
// and a TOML file (serde's `Deserialize`, see `Config::from_file`). With the
// `admin` feature it can also be serialized.
// NOTE(review): clap's derive uses the `///` comments on this struct and its
// fields as `--help` text, so rewording them changes user-visible output.
#[derive(Parser, Deserialize)]
#[cfg_attr(feature = "admin", derive(Serialize))]
pub struct Config {
    /// The port which will be used to listen for connections.
    // Falls back to `default_port()` for both the CLI and the config file.
    #[arg(short, long, default_value_t = default_port())]
    #[serde(default = "default_port")]
    pub port: u16,

    /// The directory to store malware samples, if we're keeping them.
    // `None` means no storage directory is handed to the state builder.
    #[arg(long, value_hint = ValueHint::DirPath)]
    pub dir: Option<PathBuf>,

    /// IP address to use for listening for connections
    // Falls back to `default_ip_addr()` for both the CLI and the config file.
    #[arg(short, long, default_value_t = default_ip_addr())]
    #[serde(default = "default_ip_addr")]
    pub ip: IpAddr,

    /// Maximum size for a file upload, in bytes
    // Falls back to `default_max_upload_size()`; converted to `usize` in `into_state`.
    #[arg(short, long, default_value_t = default_max_upload_size())]
    #[serde(default = "default_max_upload_size")]
    pub max_upload_size: ByteSize,

    /// Database connection string
    // The only field without a default, in either input source.
    // Wrapped in `Zeroizing`, presumably so that embedded credentials are wiped
    // from memory on drop (confirm against the zeroize crate documentation).
    #[arg(long)]
    pub db: Zeroizing<String>,

    /// PEM-encoded Https certificate file, optionally with the CA certificate in the same file.
    /// (Server and CA certificates concatenated).
    // TLS is enabled in `into_state` when both `cert` and `key` are set.
    #[arg(long, value_hint = ValueHint::FilePath)]
    #[serde(default)]
    pub cert: Option<PathBuf>,

    /// PEM-encoded Https private key file
    // Companion of `cert`; see the note there.
    #[arg(long, value_hint = ValueHint::FilePath)]
    #[serde(default)]
    pub key: Option<PathBuf>,

    /// PEM-encoded certificate file for ssl mode with Postgres
    // Passed to `StateBuilder::new` together with the connection string.
    #[arg(long, value_hint = ValueHint::FilePath)]
    #[serde(default)]
    pub pg_cert: Option<PathBuf>,

    /// VT API Key, if Malware DB is to query VT for A/V data or submit unknown samples
    // Only compiled in with the `vt` feature. In the config file its keys are
    // flattened into the top level rather than nested under `vt_client`.
    #[cfg(feature = "vt")]
    #[serde(default, flatten)]
    pub vt_client: Option<malwaredb_virustotal::VirusTotalClient>,

    /// If Malware DB should advertise itself via Multicast DNS (also known as Bonjour or Zeroconf)
    // Off unless explicitly requested.
    #[arg(long, action, default_value_t = false)]
    #[serde(default)]
    pub mdns: bool,
}

/// Port the server listens on when neither the CLI nor the config file sets one.
const fn default_port() -> u16 {
    const FALLBACK_PORT: u16 = 8080;
    FALLBACK_PORT
}

/// Address the server binds to when none is configured: the IPv4 loopback address.
const fn default_ip_addr() -> IpAddr {
    let loopback = std::net::Ipv4Addr::new(127, 0, 0, 1);
    IpAddr::V4(loopback)
}

/// Upload size limit used when none is configured: 100 MiB.
const fn default_max_upload_size() -> ByteSize {
    const FALLBACK_MIB: u64 = 100;
    ByteSize::mib(FALLBACK_MIB)
}

impl Config {
    pub async fn into_state(self) -> anyhow::Result<State> {
        let builder = StateBuilder::new(&self.db, self.pg_cert).await?;
        let mut builder = builder
            .ip(self.ip)
            .port(self.port)
            .max_upload(usize::try_from(self.max_upload_size.0)?);

        #[cfg(feature = "vt")]
        if let Some(vt_client) = self.vt_client.clone() {
            builder = builder.vt_client(vt_client);
        }

        if let Some(dir) = self.dir {
            builder = builder.directory(dir);
        }

        if self.mdns {
            builder = builder.enable_mdns();
        }

        if self.cert.is_some() && self.key.is_some() {
            let Some(cert) = self.cert else {
                return Err(anyhow::anyhow!("Certificate file not found!"));
            };

            let Some(key) = self.key else {
                return Err(anyhow::anyhow!("Key file not found!"));
            };

            builder = builder.tls(cert, key).await?;
        }

        builder.into_state().await
    }

    /// Check for a config file in the following locations:
    /// * The directory containing the server binary.
    /// * The user's home directory
    /// * OS-specific directory(ies)
    ///
    /// Malware DB will use the first available configuration file in this order!
    pub fn from_found_files() -> anyhow::Result<Self> {
        let current_config = if let Ok(mut current_dir) = std::env::current_dir() {
            current_dir.push("mdb_config.toml");
            if current_dir.exists() {
                return Self::from_file(&current_dir);
            }
            Some(current_dir)
        } else {
            None
        };

        let home_config = if let Some(mut home_config) = home_dir() {
            home_config.push(".mdb_server");
            home_config.push("mdb_config.toml");
            if home_config.exists() {
                return Self::from_file(&home_config);
            }
            Some(home_config)
        } else {
            None
        };

        for system_path in CONFIG_PATHS {
            let system_path = PathBuf::from_str(system_path)?;
            if system_path.exists() {
                return Self::from_file(&system_path);
            }
        }

        let mut dirs = Vec::with_capacity(3);
        if let Some(current_dir) = current_config {
            dirs.push(current_dir);
        }
        if let Some(home_config) = home_config {
            dirs.push(home_config);
        }
        for path in CONFIG_PATHS {
            dirs.push(PathBuf::from(path));
        }
        bail!(
            "MalwareDB could not automatically find a configuration file, checked: {:?}",
            dirs.iter()
                .map(|p| format!("{}", p.display()))
                .collect::<Vec<String>>()
                .join(",")
        )
    }

    /// Read configuration from a Toml file
    pub fn from_file(path: &PathBuf) -> anyhow::Result<Self> {
        let config = std::fs::read_to_string(path)
            .context(format!("failed to read config file {}", path.display()))?;
        let cfg: Config = toml::from_str(&config)
            .context(format!("failed to parse config file {}", path.display()))?;
        Ok(cfg)
    }
}

impl Debug for Config {
    /// Write a one-line summary: the listen address and whether samples are stored.
    ///
    /// Only `ip`, `port` and `dir` are printed; the database connection string and the
    /// remaining fields are not included in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Formatting through `SocketAddr` brackets IPv6 addresses (`[::1]:8080`), which
        // keeps the port unambiguous; IPv4 output is unchanged (`127.0.0.1:8080`).
        let listen_addr = std::net::SocketAddr::new(self.ip, self.port);
        write!(f, "MalwareDB listening on {listen_addr}")?;
        match &self.dir {
            Some(path) => write!(f, " saving files to {}", path.display()),
            None => write!(f, " not saving files"),
        }
    }
}

#[cfg(feature = "admin")]
impl Default for Config {
    /// Build a configuration filled with placeholder values.
    ///
    /// The sample directory and database string are example text rather than working
    /// settings; presumably this is used to emit a template config (confirm at the
    /// call site). Only compiled with the `admin` feature.
    fn default() -> Self {
        // With SQLite support compiled in, the placeholder shows both connection
        // string formats; otherwise only the Postgres one.
        let db_placeholder = if cfg!(feature = "sqlite") {
            "sqlite: file:/path/to/sqlite.db, postgres: postgres host=localhost dbname=malwaredb user=malwaredb password=malwaredb"
        } else {
            "postgres host=localhost dbname=malwaredb user=malwaredb password=malwaredb"
        };

        Self {
            port: default_port(),
            dir: Some(PathBuf::from("/path/to/samples/")),
            ip: default_ip_addr(),
            max_upload_size: default_max_upload_size(),
            db: Zeroizing::new(db_placeholder.to_string()),
            cert: None,
            key: None,
            pg_cert: None,
            #[cfg(feature = "vt")]
            vt_client: None,
            mdns: false,
        }
    }
}