malwaredb 0.3.2

Service for storing malicious, benign, or unknown files and related metadata and relationships.
// SPDX-License-Identifier: Apache-2.0

use malwaredb_server::db::types::{FileMetadata, FileType};
use malwaredb_server::{State, GZIP_MAGIC};
use malwaredb_types::doc::{is_zip_file_doc, PK_HEADER};
use malwaredb_types::KnownType;

use std::io::{Cursor, Read};
use std::path::{Path, PathBuf};
use std::process::ExitCode;

use anyhow::{bail, ensure, Result};
use clap::Parser;
use flate2::read::GzDecoder;
use tracing::{info, trace};
use walkdir::WalkDir;

/// Bulk load files into Malware DB
// NOTE: the `///` comments on this struct and its fields are consumed by clap's derive
// macro as the `--help` text, so they are user-facing output rather than plain docs.
// Maintainer notes are therefore written as `//` comments.
#[derive(Clone, Debug, Parser, PartialEq)]
pub struct Load {
    /// The id number for the source to be used as the origin of the files added
    // `execute` checks this against `list_sources()` before any file is read; it is then
    // passed to `add_file` for every sample.
    #[arg(short, long)]
    pub source_id: u32,

    /// The user id to use as the uploading person for the files added
    // `execute` checks this against `list_users()` and `allowed_user_source()`; it is then
    // passed to `add_file` for every sample.
    #[arg(short, long)]
    pub user_id: u32,

    /// Maximum depth for directory recursion
    // Forwarded to `WalkDir::max_depth`; only relevant for entries of `paths` that are
    // directories.
    #[arg(short, long, default_value = "100")]
    pub max_depth: usize,

    /// Whether or not symbolic links should be followed
    // Forwarded to `WalkDir::follow_links`. The short flag is explicitly `-l`.
    #[arg(short = 'l', long, default_value = "false")]
    pub follow_links: bool,

    /// Password(s) needed to open protected Zip files, should any be encountered
    // Despite the "(s)" in the help text this holds at most one password. When set,
    // `add_from_zip` uses its bytes to decrypt every entry of every Zip archive it opens.
    #[arg(short, long)]
    pub password: Option<String>,

    /// The files and/or directories of files to be added
    // No `short`/`long` attribute, so these are the positional arguments. `execute` walks
    // directories recursively and loads anything else as a single file.
    pub paths: Vec<PathBuf>,
}

impl Load {
    async fn add_bytes(
        &self,
        state: &State,
        db_file_types: &Vec<FileType>,
        fname: &str,
        contents: &[u8],
    ) -> Result<bool> {
        let db_file_type = {
            let mut id = None;
            for db_file_type in db_file_types {
                for magic in &db_file_type.magic {
                    if contents.starts_with(magic)
                        && magic.is_empty() == state.db_config.keep_unknown_files
                    {
                        info!("File type for {fname}: {}", db_file_type.name);
                        id = Some(db_file_type.id);
                    }
                }
            }
            id
        };

        if let Some(type_id) = db_file_type {
            let known_type = KnownType::new(contents)?;
            let meta_data = FileMetadata::new(contents, Some(fname));

            if state
                .db_type
                .add_file(
                    &meta_data,
                    known_type,
                    self.user_id,
                    self.source_id,
                    type_id,
                    None,
                )
                .await?
                .is_new
            {
                state.store_bytes(contents).await.map_err(|e| panic!("Unable to write data to disk: {e}.\nAre you running the admin command as the correct user?")).expect("unable to write data to disk");
                return Ok(true);
            }
        }

        Ok(false)
    }

    /// Load the file at `path`, unpacking it first when it is a gzip stream or a Zip archive.
    ///
    /// The first two bytes of the file decide how it is handled:
    ///
    /// * `GZIP_MAGIC`: the content is decompressed and the result is added as one sample;
    /// * `PK_HEADER`: if `is_zip_file_doc` returns `Ok(false)` the file is handed to
    ///   `add_from_zip`; otherwise (including when that check errors) it is added as-is;
    /// * anything else, including files shorter than two bytes: added as-is.
    ///
    /// Returns the number of samples that were newly inserted: `0` or `1` for a plain or
    /// gzip file, or whatever `add_from_zip` reports for an archive.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be read, if gzip decompression fails, or if `add_bytes` /
    /// `add_from_zip` fail.
    async fn load_file_path(
        &self,
        state: &State,
        db_file_types: &Vec<FileType>,
        path: &Path,
    ) -> Result<u32> {
        let contents = std::fs::read(path)?;

        // Lossy conversion keeps non-UTF-8 file names from aborting the whole run.
        let fname = path
            .file_name()
            .map_or_else(|| path.to_string_lossy(), |name| name.to_string_lossy())
            .into_owned();

        // Take the header from the bytes already in memory rather than reopening the file.
        let header = match contents.as_slice() {
            [first, second, ..] => Some([*first, *second]),
            _ => None,
        };

        if let Some(header) = header {
            if header == GZIP_MAGIC {
                let mut decompressed: Vec<u8> = Vec::new();
                GzDecoder::new(Cursor::new(contents)).read_to_end(&mut decompressed)?;
                return self
                    .add_bytes(state, db_file_types, &fname, &decompressed)
                    .await
                    .map(u32::from);
            } else if header == PK_HEADER {
                if let Ok(false) = is_zip_file_doc(path) {
                    trace!("{fname:?} is a Zip!");
                    return self.add_from_zip(state, db_file_types, path).await;
                }
            }
        }

        // Count the file only when a record was actually inserted.
        self.add_bytes(state, db_file_types, &fname, &contents)
            .await
            .map(u32::from)
    }

    /// Run the bulk load: validate the source and user, then load every requested path.
    ///
    /// Directories in `paths` are walked recursively, honouring `max_depth` and
    /// `follow_links`, and every regular file found is loaded. Any other path is loaded
    /// directly as a file. Directory entries that cannot be read are reported on stderr
    /// and skipped. The total reported by the per-file loaders is printed at the end.
    ///
    /// # Errors
    ///
    /// Fails if the source ID or user ID is unknown, if the user is not allowed to use the
    /// source, if a database query fails, or if loading any file fails. The first failing
    /// file aborts the run.
    pub async fn execute(&self, state: State) -> Result<ExitCode> {
        let sources = state.db_type.list_sources().await?;
        let users = state.db_type.list_users().await?;

        ensure!(
            sources.iter().any(|s| s.id == self.source_id),
            "Source ID {} is not valid.",
            self.source_id
        );
        ensure!(
            users.iter().any(|u| u.id == self.user_id),
            "User ID {} is not valid.",
            self.user_id
        );

        // A bad user/source pairing is invalid input, so report it as an error like the
        // two checks above instead of panicking.
        let allowed = state
            .db_type
            .allowed_user_source(self.user_id, self.source_id)
            .await?;
        ensure!(
            allowed,
            "User ID {} is not in a group which is allowed to use source {}.",
            self.user_id,
            self.source_id
        );

        let db_file_types = state.db_type.get_known_data_types().await?;
        let mut counter = 0u32;

        for path in &self.paths {
            if path.is_dir() {
                let walker = WalkDir::new(path)
                    .follow_links(self.follow_links)
                    .max_depth(self.max_depth);
                for entry in walker {
                    let entry = match entry {
                        Ok(entry) => entry,
                        Err(e) => {
                            // Keep going, but make the skipped entry visible to the operator.
                            eprintln!("Skipping unreadable entry under {}: {e}", path.display());
                            continue;
                        }
                    };
                    if entry.file_type().is_file() {
                        counter += self
                            .load_file_path(&state, &db_file_types, entry.path())
                            .await?;
                    }
                }
            } else {
                counter += self
                    .load_file_path(&state, &db_file_types, path.as_path())
                    .await?;
            }
        }

        println!("Inserted records for {counter} files.");

        Ok(ExitCode::SUCCESS)
    }

    async fn add_from_zip(
        &self,
        state: &State,
        db_file_types: &Vec<FileType>,
        path: &Path,
    ) -> Result<u32> {
        let file = std::fs::File::open(path)?;

        let mut counter = 0;
        let mut archive = zip::ZipArchive::new(file)?;
        for i in 0..archive.len() {
            let mut file = if let Some(password) = &self.password {
                archive
                    .by_index_decrypt(i, password.as_bytes())
                    .expect("unable to get Zip object with password")
            } else {
                match archive.by_index(i) {
                    Ok(f) => f,
                    Err(e) => {
                        bail!("ZipError: {e}");
                    }
                }
            };
            if (*file.name()).ends_with('/') {
                continue;
            }

            let fname = (*file.name()).to_string();
            let mut contents = Vec::new();

            std::io::copy(&mut file, &mut contents).unwrap();
            match self
                .add_bytes(state, db_file_types, file.name(), &contents)
                .await
            {
                Ok(false) => {}
                Ok(true) => {
                    counter += 1;
                }
                Err(e) => {
                    eprintln!("Error parsing sample {fname}: {e}");
                }
            }
        }

        Ok(counter)
    }
}