malwaredb 0.3.2

Service for storing malicious, benign, or unknown files and related metadata and relationships.
// SPDX-License-Identifier: Apache-2.0

mod compression;
mod config;
mod groups;
mod keys;
mod labels;
mod load;
mod sources;
mod users;

#[cfg(feature = "vt")]
mod vt;

mod unknowns;

use super::config::Config;
use malwaredb_server::State;
use std::io::Write;

use std::process::ExitCode;

use clap::{Args, Subcommand, ValueHint};
use malwaredb_server::db::DatabaseType;

/// Administrative Options for Malware Database Server
// Argument group parsed by clap (`Args` derive). The `///` doc comments on this struct and
// its fields are consumed by clap's derive macros as help text, so reviewer notes are kept
// in plain `//` comments to avoid changing the generated CLI help.
#[derive(Clone, Debug, Args, PartialEq)]
pub struct Admin {
    /// Config file: specify or search; needed for the database connection information.
    /// Specifying the path is best to avoid mix-up between instances.
    // Only a short flag (`-c`) is declared; the value is shown as `FILE` and hinted as a
    // file path. When this is `None`, `Admin::execute` calls `Config::from_found_files()`
    // instead of `Config::from_file(path)`.
    #[arg(short = 'c', value_name = "FILE", value_hint = ValueHint::FilePath)]
    pub config_file: Option<std::path::PathBuf>,

    /// Administrative subcommands
    // The selected subcommand; `Admin::execute` dispatches on this value.
    #[clap(subcommand)]
    pub action: AdminActions,
}

impl Admin {
    pub async fn execute(&self) -> anyhow::Result<ExitCode> {
        if let AdminActions::DefaultConfig(default_config) = &self.action {
            return default_config.execute();
        }

        let cfg = if let Some(path) = &self.config_file {
            Config::from_file(path)
        } else {
            Config::from_found_files()
        }?;

        if self.action == AdminActions::Migrate {
            eprintln!("Back up the database before proceeding. This migration may take a while if there is a lot of data.");
            let mut s = String::new();
            print!("Please enter \"CONFIRM\" to proceed: ");
            let _ = std::io::stdout().flush();
            std::io::stdin().read_line(&mut s)?;
            return if s.trim() == "CONFIRM" {
                match DatabaseType::migrate(&cfg.db, cfg.pg_cert).await {
                    Ok(_) => Ok(ExitCode::SUCCESS),
                    Err(e) => {
                        eprintln!("Migration failed: {e}");
                        Err(e)
                    }
                }
            } else {
                eprintln!("Migration aborted.");
                Ok(ExitCode::FAILURE)
            };
        }

        let state = cfg.into_state().await?;

        match &self.action {
            AdminActions::AddUserToGroup(cmd) => cmd.execute(state).await,
            AdminActions::ClearAPIKeys(cmd) => cmd.execute(state).await,
            AdminActions::ResetPassword(cmd) => cmd.execute(state).await,
            AdminActions::AddGroupToSource(cmd) => cmd.execute(state).await,
            AdminActions::BulkAdd(cmd) => cmd.execute(state).await,
            AdminActions::Create(sub) => sub.execute(state).await,
            AdminActions::List(sub) => sub.execute(state).await,
            AdminActions::Compression(comp) => comp.execute(state).await,
            #[cfg(feature = "vt")]
            AdminActions::VirusTotal(vt) => vt.execute(state).await,
            AdminActions::UnknownFiles(unk) => unk.execute(state).await,
            AdminActions::DefaultConfig(_) | AdminActions::Migrate => unreachable!(),
            AdminActions::RenameInstance(cmd) => cmd.execute(state).await,
            AdminActions::Stats => {
                let db_info = state.db_type.db_info().await?;
                println!("-- {} --", state.db_config.name);
                println!(
                    "Number of samples: {}\nNumber of users: {}\nNumber of groups: {}\nNumber of sources: {}\nDatabase size: {}\nDatabase version: {}",
                    db_info.num_files, db_info.num_users, db_info.num_groups, db_info.num_sources, db_info.size, db_info.version
                );

                if db_info.num_files > 0 {
                    println!("File counts by type:");
                    let file_types_counts = state.db_type.file_types_counts().await?;
                    for (name, count) in file_types_counts {
                        println!("{name}: {count}");
                    }
                }

                #[cfg(feature = "vt")]
                {
                    if state.db_config.send_samples_to_vt {
                        let vt_stats = state.db_type.get_vt_stats().await?;
                        println!("VT Stats:\nClean samples: {}\nSamples with AV hits: {}\nSamples without records: {}", vt_stats.clean_records, vt_stats.hits_records, vt_stats.files_without_records);
                    }
                }

                Ok(ExitCode::SUCCESS)
            }
        }
    }
}

// Top-level administrative subcommands (clap `Subcommand` derive), dispatched by
// `Admin::execute`. The `///` doc comments on variants are consumed by clap's derive
// macros as help text, so they are left untouched and reviewer notes use plain `//` comments.
#[derive(Clone, Subcommand, Debug, PartialEq)]
pub enum AdminActions {
    // NOTE(review): the next four variants carry no `///` doc comment, unlike the rest of
    // this enum; presumably their help text comes from the wrapped argument types — confirm.
    // Arguments come from `groups::AddUser`.
    AddUserToGroup(groups::AddUser),
    // Arguments come from `users::ResetAPIKeys`.
    ClearAPIKeys(users::ResetAPIKeys),
    // Arguments come from `users::ResetPassword`.
    ResetPassword(users::ResetPassword),
    // Arguments come from `sources::AddGroup`.
    AddGroupToSource(sources::AddGroup),
    /// Some information about the contents of the Database
    // Takes no arguments; the output is produced inline in `Admin::execute`.
    Stats,

    /// Rename the Malware DB instance
    RenameInstance(config::Rename),

    /// Show (or toggle) the state of Malware DB storing incoming files with zstd compression
    Compression(compression::Compression),

    /// Show (or toggle) the state of Malware DB submitting samples to Virus Total
    // Only compiled in when the `vt` cargo feature is enabled.
    #[cfg(feature = "vt")]
    VirusTotal(vt::VirusTotal),

    /// Show (or toggle) the state of Malware DB storing unknown files
    UnknownFiles(unknowns::UnknownFiles),

    /// Bulk load files
    BulkAdd(load::Load),

    /// Create users, groups, sources
    // Nested subcommand group; see `CreateActions`.
    #[clap(subcommand)]
    Create(CreateActions),

    /// List users, groups, sources
    // Nested subcommand group; see `ListActions`.
    #[clap(subcommand)]
    List(ListActions),

    /// Create an example config file optionally in the operating system-specific location
    // Executed by `Admin::execute` before any configuration file is loaded.
    DefaultConfig(config::CreateDefaultConfig),

    /// Perform a database migration if needed (updating database schema).
    // `Admin::execute` asks for an interactive "CONFIRM" on stdin before migrating.
    Migrate,
}

// Subcommands nested under `AdminActions::Create`. Each variant wraps the argument type
// from the module responsible for that kind of object; `CreateActions::execute` forwards
// to the wrapped value. The `///` doc comments are consumed by clap's derive macros as
// help text, so they are left untouched.
#[derive(Clone, Subcommand, Debug, PartialEq)]
pub enum CreateActions {
    /// Create a user account
    User(users::Create),

    /// Create a group
    Group(groups::Create),

    /// Create a key to be used for file encryption
    Key(keys::Create),

    /// Create a sample source
    Source(sources::Create),

    /// Create a label
    Label(labels::Create),
}

impl CreateActions {
    /// Forward to the `execute` method of whichever creation subcommand was
    /// selected, handing it ownership of `state`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the selected subcommand returns.
    pub async fn execute(&self, state: State) -> anyhow::Result<ExitCode> {
        match self {
            Self::User(create_user) => create_user.execute(state).await,
            Self::Group(create_group) => create_group.execute(state).await,
            Self::Key(create_key) => create_key.execute(state).await,
            Self::Source(create_source) => create_source.execute(state).await,
            Self::Label(create_label) => create_label.execute(state).await,
        }
    }
}

// Subcommands nested under `AdminActions::List`. The `///` doc comments are consumed by
// clap's derive macros as help text, so they are left untouched and reviewer notes use
// plain `//` comments.
#[derive(Clone, Subcommand, Debug, PartialEq)]
pub enum ListActions {
    /// List users, accounts which may access Malware DB
    Users(users::List),
    /// List file encryption key information
    Keys(keys::List),
    /// List groups, collections of users which may access sample(s)
    Groups(groups::List),
    /// List sources, the origins of file samples
    Sources(sources::List),
    /// List file types known to and supported by Malware DB
    // Takes no arguments; `ListActions::execute` prints each type's name, optional
    // description, executable marker and hex-encoded magic values inline.
    Types,
    /// List labels, a hierarchical taxonomy for file samples
    Labels(labels::List),
}

impl ListActions {
    /// Run the selected listing subcommand.
    ///
    /// All variants except `Types` forward to the wrapped subcommand's own
    /// `execute`. `Types` prints one header line per known data type (name,
    /// optional description, and an executable marker) followed by one
    /// tab-indented, hex-encoded line per magic value.
    ///
    /// # Errors
    ///
    /// Returns the error of the selected subcommand, or the error from
    /// fetching the known data types.
    pub async fn execute(&self, state: State) -> anyhow::Result<ExitCode> {
        match self {
            Self::Types => {
                let known_types = state.db_type.get_known_data_types().await?;
                for entry in known_types {
                    // Assemble the optional parts first, then emit the header line at once.
                    let description = entry
                        .description
                        .map(|text| format!(" {text}"))
                        .unwrap_or_default();
                    let exec_marker = if entry.executable {
                        " -- is executable"
                    } else {
                        ""
                    };
                    println!("{}{description}{exec_marker}", entry.name);

                    for signature in entry.magic {
                        let encoded = hex::encode(signature);
                        println!("\t{encoded}");
                    }
                }
                Ok(ExitCode::SUCCESS)
            }
            Self::Users(list_users) => list_users.execute(state).await,
            Self::Keys(list_keys) => list_keys.execute(state).await,
            Self::Groups(list_groups) => list_groups.execute(state).await,
            Self::Sources(list_sources) => list_sources.execute(state).await,
            Self::Labels(list_labels) => list_labels.execute(state).await,
        }
    }
}