use crate::MdbClient;
use malwaredb_types::doc::{is_zip_file_doc, PK_HEADER};
use std::io::{Read, Seek, SeekFrom};
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use anyhow::{bail, ensure, Context, Result};
use clap::{Parser, ValueHint};
use flate2::read::GzDecoder;
use walkdir::WalkDir;
/// Leading two bytes of a gzip stream (`1f 8b`).
pub const GZIP_MAGIC: [u8; 2] = [0x1f, 0x8b];
/// Leading four bytes of a Zstandard frame: the magic number `0xFD2FB528`
/// laid out in little-endian byte order (`28 b5 2f fd`).
pub const ZSTD_MAGIC: [u8; 4] = 0xFD2F_B528_u32.to_le_bytes();
// Command-line arguments for submitting samples to a source.
//
// NOTE(review): plain `//` comments are used on purpose. clap's derive turns
// `///` doc comments into help/about text, so adding them here would change
// the `--help` output.
#[derive(Parser, Clone, Debug, PartialEq)]
pub struct SubmitSamples {
// Name of the destination source (short flag `-n`). `exec` looks it up in the
// source list and requires exactly one of this and `source_id` to be set.
#[arg(short = 'n', long, name = "name")]
pub source_name: Option<String>,
// Numeric ID of the destination source (short flag `-i`). `exec` checks that
// it appears in the source list.
#[arg(short = 'i', long, name = "id")]
pub source_id: Option<u32>,
// Passed to `WalkDir::max_depth` when a directory argument is walked;
// defaults to 100.
#[arg(short, long, default_value = "100")]
pub max_depth: usize,
// Password used to decrypt entries when an input file is handled as a Zip
// archive; without it entries are read undecrypted.
#[arg(short, long)]
pub password: Option<String>,
// Verbosity, counted per occurrence of the flag: above 0 reports each archive
// entry that was sent, above 1 reports archive detection and item counts,
// above 2 reports the size of each archive entry.
#[arg(short, long, action = clap::ArgAction::Count)]
pub debug: u8,
// Paths to submit. Regular files are submitted directly; directories are
// walked and every regular file found is submitted.
#[arg(value_name = "FILE", value_hint = ValueHint::FilePath)]
pub files: Vec<PathBuf>,
}
impl SubmitSamples {
pub async fn exec(&self, config: &MdbClient) -> Result<ExitCode> {
if self.source_name.is_none() == self.source_id.is_none() {
bail!("Pick a source ID or name! Not both, not neither.");
}
let sources = config.sources().await?;
let source_id = if let Some(name) = &self.source_name {
sources
.sources
.iter()
.find_map(|s| if s.name == *name { Some(s.id) } else { None })
.ok_or_else(|| panic!("Source {name} not found."))
.unwrap()
} else {
let source_id = self.source_id.unwrap();
ensure!(
sources.sources.iter().any(|s| s.id == source_id),
"Source ID {source_id} isn't valid."
);
source_id
};
let counter = Arc::new(AtomicU32::default());
let counter_copy = counter.clone();
ctrlc::set_handler(move || {
println!("Uploaded {} files.", counter_copy.load(Ordering::Relaxed));
std::process::exit(1)
})?;
for path in &self.files {
if path.is_file() {
if let Err(e) = self.submit_file(config, path, source_id, &counter).await {
eprintln!("Error submitting {}: {e}", path.display());
}
} else if path.is_dir() {
for entry in WalkDir::new(path)
.follow_links(true)
.max_depth(self.max_depth)
.into_iter()
.flatten()
{
let entry = entry.path().to_path_buf();
if entry.is_file() {
if let Err(e) = self.submit_file(config, &entry, source_id, &counter).await
{
eprintln!("Error submitting {}: {e}", path.display());
}
}
}
}
}
println!("Uploaded {} files.", counter.load(Ordering::Relaxed));
Ok(ExitCode::SUCCESS)
}
async fn submit_file(
&self,
config: &MdbClient,
path: &PathBuf,
source_id: u32,
counter: &Arc<AtomicU32>,
) -> Result<()> {
let mut header = [0u8; 4];
let mut file = std::fs::File::open(path)?;
if file.read(&mut header)? != 4 {
bail!("Failed to read header for {}", path.display());
}
if header[..2] == PK_HEADER && !is_zip_file_doc(path).unwrap_or(false) {
if self.debug > 1 {
println!(
"Assuming {} is an archive of files to be submitted.",
path.display()
);
}
file.seek(SeekFrom::Start(0))?;
let mut archive = zip::ZipArchive::new(file)?;
if self.debug > 1 {
println!("{} has {} items", path.display(), archive.len());
}
for i in 0..archive.len() {
let mut file = if let Some(password) = &self.password {
archive
.by_index_decrypt(i, password.as_bytes())
.expect("unable to get Zip object with password")
} else {
match archive.by_index(i) {
Ok(f) => f,
Err(e) => {
bail!("ZipError: {e}");
}
}
};
let mut contents = Vec::new();
std::io::copy(&mut file, &mut contents).context(format!(
"failed to read file #{i} {} from zip data",
file.name()
))?;
if self.debug > 2 {
println!("{} len: {}", file.name(), contents.len());
}
let file_name_clone = file.name().to_owned();
if let Ok(success) = config
.submit(contents, file_name_clone.clone(), source_id)
.await
{
if success {
if self.debug > 0 {
println!("Sent {file_name_clone} successfully");
}
counter.fetch_add(1, Ordering::Relaxed);
}
}
}
Ok(())
} else {
let buffer = if header[..2] == GZIP_MAGIC {
let mut decompressor = GzDecoder::new(std::fs::File::open(path)?);
let mut decompressed: Vec<u8> = vec![];
decompressor.read_to_end(&mut decompressed)?;
decompressed
} else if header == ZSTD_MAGIC {
let file = std::fs::File::open(path)?;
let mut decompressed: Vec<u8> = vec![];
zstd::stream::copy_decode(file, &mut decompressed)?;
decompressed
} else {
std::fs::read(path)?
};
config
.submit(
buffer,
path.file_name()
.expect("failed to get file name")
.to_string_lossy(),
source_id,
)
.await
.map(|s| {
if s {
counter.fetch_add(1, Ordering::Relaxed);
}
})
}
}
}
/// Runs clap's internal consistency checks on the `SubmitSamples` argument
/// definition.
#[test]
fn verify_cli() {
    let command = <SubmitSamples as clap::CommandFactory>::command();
    command.debug_assert();
}