malwaredb-client 0.3.4

Client application and library for connecting to MalwareDB.
Documentation
// SPDX-License-Identifier: Apache-2.0

use crate::MdbClient;
use malwaredb_types::doc::{is_zip_file_doc, PK_HEADER};

use std::io::{Read, Seek, SeekFrom};
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

use anyhow::{bail, ensure, Context, Result};
use clap::{Parser, ValueHint};
use flate2::read::GzDecoder;
use walkdir::WalkDir;

/// First two bytes of every gzip member (RFC 1952 `ID1`, `ID2`).
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// First four bytes of a Zstandard frame: the magic number `0xFD2FB528` stored little-endian.
pub const ZSTD_MAGIC: [u8; 4] = [0x28, 0xb5, 0x2f, 0xfd];

// The `///` comments in this struct double as clap's about/help text, so they are runtime
// strings; reviewer notes are kept in plain `//` comments. Exactly one of `source_name` /
// `source_id` must be given — `exec` enforces that at runtime, not clap.
/// Submit one or more samples to the server by source ID
#[derive(Parser, Clone, Debug, PartialEq)]
pub struct SubmitSamples {
    // Resolved in `exec` by exact match against the names in the server's source list.
    /// The source name which is the origin for this file
    #[arg(short = 'n', long, name = "name")]
    pub source_name: Option<String>,

    // Checked in `exec` against the IDs in the server's source list before any upload.
    /// The source ID which is the origin for this file
    #[arg(short = 'i', long, name = "id")]
    pub source_id: Option<u32>,

    // Forwarded to `WalkDir::max_depth` for directory arguments; plain file arguments
    // are submitted directly and are not affected.
    /// Max depth, useful if there might be a recursive symlink
    #[arg(short, long, default_value = "100")]
    pub max_depth: usize,

    // Applied to every entry of every Zip archive encountered (`by_index_decrypt`).
    // NOTE(review): only a single value is accepted despite the "(s)" in the help text.
    /// Password(s) needed to open protected Zip files, should any be encountered
    #[arg(short, long)]
    pub password: Option<String>,

    // Counted verbosity: >0 reports each archive entry sent successfully, >1 reports
    // archive detection and entry count, >2 reports each entry's decompressed size.
    /// Turn debugging information on
    #[arg(short, long, action = clap::ArgAction::Count)]
    pub debug: u8,

    // Files are submitted directly; directories are walked with symlinks followed,
    // bounded by `max_depth`.
    /// The file(s) to send, walking directories and following symlinks
    #[arg(value_name = "FILE", value_hint = ValueHint::FilePath)]
    pub files: Vec<PathBuf>,
}

impl SubmitSamples {
    pub async fn exec(&self, config: &MdbClient) -> Result<ExitCode> {
        if self.source_name.is_none() == self.source_id.is_none() {
            bail!("Pick a source ID or name! Not both, not neither.");
        }

        let sources = config.sources().await?;

        let source_id = if let Some(name) = &self.source_name {
            sources
                .sources
                .iter()
                .find_map(|s| if s.name == *name { Some(s.id) } else { None })
                .ok_or_else(|| panic!("Source {name} not found."))
                .unwrap()
        } else {
            let source_id = self.source_id.unwrap();
            ensure!(
                sources.sources.iter().any(|s| s.id == source_id),
                "Source ID {source_id} isn't valid."
            );

            source_id
        };

        let counter = Arc::new(AtomicU32::default());

        let counter_copy = counter.clone();
        ctrlc::set_handler(move || {
            println!("Uploaded {} files.", counter_copy.load(Ordering::Relaxed));
            std::process::exit(1)
        })?;

        for path in &self.files {
            if path.is_file() {
                if let Err(e) = self.submit_file(config, path, source_id, &counter).await {
                    eprintln!("Error submitting {}: {e}", path.display());
                }
            } else if path.is_dir() {
                for entry in WalkDir::new(path)
                    .follow_links(true)
                    .max_depth(self.max_depth)
                    .into_iter()
                    .flatten()
                {
                    let entry = entry.path().to_path_buf();
                    if entry.is_file() {
                        if let Err(e) = self.submit_file(config, &entry, source_id, &counter).await
                        {
                            eprintln!("Error submitting {}: {e}", path.display());
                        }
                    }
                }
            }
        }

        println!("Uploaded {} files.", counter.load(Ordering::Relaxed));

        Ok(ExitCode::SUCCESS)
    }

    async fn submit_file(
        &self,
        config: &MdbClient,
        path: &PathBuf,
        source_id: u32,
        counter: &Arc<AtomicU32>,
    ) -> Result<()> {
        let mut header = [0u8; 4];
        let mut file = std::fs::File::open(path)?;
        if file.read(&mut header)? != 4 {
            bail!("Failed to read header for {}", path.display());
        }

        if header[..2] == PK_HEADER && !is_zip_file_doc(path).unwrap_or(false) {
            if self.debug > 1 {
                println!(
                    "Assuming {} is an archive of files to be submitted.",
                    path.display()
                );
            }

            file.seek(SeekFrom::Start(0))?;
            let mut archive = zip::ZipArchive::new(file)?;
            if self.debug > 1 {
                println!("{} has {} items", path.display(), archive.len());
            }

            for i in 0..archive.len() {
                let mut file = if let Some(password) = &self.password {
                    archive
                        .by_index_decrypt(i, password.as_bytes())
                        .expect("unable to get Zip object with password")
                } else {
                    match archive.by_index(i) {
                        Ok(f) => f,
                        Err(e) => {
                            bail!("ZipError: {e}");
                        }
                    }
                };

                let mut contents = Vec::new();
                std::io::copy(&mut file, &mut contents).context(format!(
                    "failed to read file #{i} {} from zip data",
                    file.name()
                ))?;
                if self.debug > 2 {
                    println!("{} len: {}", file.name(), contents.len());
                }
                let file_name_clone = file.name().to_owned();
                if let Ok(success) = config
                    .submit(contents, file_name_clone.clone(), source_id)
                    .await
                {
                    if success {
                        if self.debug > 0 {
                            println!("Sent {file_name_clone} successfully");
                        }
                        counter.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }

            Ok(())
        } else {
            let buffer = if header[..2] == GZIP_MAGIC {
                let mut decompressor = GzDecoder::new(std::fs::File::open(path)?);
                let mut decompressed: Vec<u8> = vec![];
                decompressor.read_to_end(&mut decompressed)?;
                decompressed
            } else if header == ZSTD_MAGIC {
                let file = std::fs::File::open(path)?;
                let mut decompressed: Vec<u8> = vec![];
                zstd::stream::copy_decode(file, &mut decompressed)?;
                decompressed
            } else {
                std::fs::read(path)?
            };

            config
                .submit(
                    buffer,
                    path.file_name()
                        .expect("failed to get file name")
                        .to_string_lossy(),
                    source_id,
                )
                .await
                .map(|s| {
                    if s {
                        counter.fetch_add(1, Ordering::Relaxed);
                    }
                })
        }
    }
}

/// Runs clap's built-in consistency checks (duplicate flags, bad ids, etc.) on the
/// command generated from `SubmitSamples`.
#[test]
fn verify_cli() {
    let command = <SubmitSamples as clap::CommandFactory>::command();
    command.debug_assert();
}