use malwaredb_server::db::types::{FileMetadata, FileType};
use malwaredb_server::{State, GZIP_MAGIC};
use malwaredb_types::doc::{is_zip_file_doc, PK_HEADER};
use malwaredb_types::KnownType;
use std::io::{Cursor, Read};
use std::path::{Path, PathBuf};
use std::process::ExitCode;
use anyhow::{bail, ensure, Result};
use clap::Parser;
use flate2::read::GzDecoder;
use tracing::{info, trace};
use walkdir::WalkDir;
// Command-line arguments for loading sample files (plain, gzip-compressed, or Zip
// archives) into the database on behalf of a user and source.
#[derive(Clone, Debug, Parser, PartialEq)]
pub struct Load {
    /// ID of the source to record the loaded samples under; the user must be allowed to use it
    #[arg(short, long)]
    pub source_id: u32,

    /// ID of the user to record the loaded samples under
    #[arg(short, long)]
    pub user_id: u32,

    /// Maximum directory recursion depth when a path is a directory
    #[arg(short, long, default_value = "100")]
    pub max_depth: usize,

    /// Follow symbolic links while walking directories
    #[arg(short = 'l', long, default_value = "false")]
    pub follow_links: bool,

    /// Password used to decrypt entries of encrypted Zip archives
    #[arg(short, long)]
    pub password: Option<String>,

    /// Files and/or directories to load
    pub paths: Vec<PathBuf>,
}
impl Load {
async fn add_bytes(
&self,
state: &State,
db_file_types: &Vec<FileType>,
fname: &str,
contents: &[u8],
) -> Result<bool> {
let db_file_type = {
let mut id = None;
for db_file_type in db_file_types {
for magic in &db_file_type.magic {
if contents.starts_with(magic)
&& magic.is_empty() == state.db_config.keep_unknown_files
{
info!("File type for {fname}: {}", db_file_type.name);
id = Some(db_file_type.id);
}
}
}
id
};
if let Some(type_id) = db_file_type {
let known_type = KnownType::new(contents)?;
let meta_data = FileMetadata::new(contents, Some(fname));
if state
.db_type
.add_file(
&meta_data,
known_type,
self.user_id,
self.source_id,
type_id,
None,
)
.await?
.is_new
{
state.store_bytes(contents).await.map_err(|e| panic!("Unable to write data to disk: {e}.\nAre you running the admin command as the correct user?")).expect("unable to write data to disk");
return Ok(true);
}
}
Ok(false)
}
async fn load_file_path(
&self,
state: &State,
db_file_types: &Vec<FileType>,
path: &Path,
) -> Result<u32> {
let contents = std::fs::read(path)?;
let fname = path.file_name().unwrap().to_str().unwrap().to_string();
let mut file = std::fs::File::open(path)?;
let mut header = [0u8; 2];
if file.read(&mut header)? != 2 {
bail!("Failed to read header for {}", path.display());
}
if header == GZIP_MAGIC {
let buff = Cursor::new(contents);
let mut decompressor = GzDecoder::new(buff);
let mut decompressed: Vec<u8> = vec![];
decompressor.read_to_end(&mut decompressed)?;
return self
.add_bytes(state, db_file_types, &fname, &decompressed)
.await
.and(Ok(1));
} else if header == PK_HEADER {
if let Ok(false) = is_zip_file_doc(path) {
trace!("{fname:?} is a Zip!");
return self.add_from_zip(state, db_file_types, path).await;
}
}
self.add_bytes(state, db_file_types, &fname, &contents)
.await
.and(Ok(1))
}
pub async fn execute(&self, state: State) -> Result<ExitCode> {
let sources = state.db_type.list_sources().await?;
let users = state.db_type.list_users().await?;
ensure!(
sources.iter().any(|s| s.id == self.source_id),
"Source ID {} is not valid.",
self.source_id
);
ensure!(
users.iter().any(|u| u.id == self.user_id),
"User ID {} is not valid.",
self.user_id
);
if !state
.db_type
.allowed_user_source(self.user_id, self.source_id)
.await?
{
panic!(
"User ID {} is not in a group which is allowed to use source {}.",
self.user_id, self.source_id
);
}
let db_file_types = state.db_type.get_known_data_types().await?;
let mut counter = 0u32;
for path in &self.paths {
if path.is_dir() {
for entry in WalkDir::new(path)
.follow_links(self.follow_links)
.max_depth(self.max_depth)
.into_iter()
.flatten()
{
if entry.file_type().is_file() {
counter += self
.load_file_path(&state, &db_file_types, entry.path())
.await?;
}
}
} else {
counter += self
.load_file_path(&state, &db_file_types, path.as_path())
.await?;
}
}
println!("Inserted records for {counter} files.");
Ok(ExitCode::SUCCESS)
}
async fn add_from_zip(
&self,
state: &State,
db_file_types: &Vec<FileType>,
path: &Path,
) -> Result<u32> {
let file = std::fs::File::open(path)?;
let mut counter = 0;
let mut archive = zip::ZipArchive::new(file)?;
for i in 0..archive.len() {
let mut file = if let Some(password) = &self.password {
archive
.by_index_decrypt(i, password.as_bytes())
.expect("unable to get Zip object with password")
} else {
match archive.by_index(i) {
Ok(f) => f,
Err(e) => {
bail!("ZipError: {e}");
}
}
};
if (*file.name()).ends_with('/') {
continue;
}
let fname = (*file.name()).to_string();
let mut contents = Vec::new();
std::io::copy(&mut file, &mut contents).unwrap();
match self
.add_bytes(state, db_file_types, file.name(), &contents)
.await
{
Ok(false) => {}
Ok(true) => {
counter += 1;
}
Err(e) => {
eprintln!("Error parsing sample {fname}: {e}");
}
}
}
Ok(counter)
}
}