#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![forbid(unsafe_code)]
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use std::process::ExitCode;
use anyhow::{Result, bail};
use clap::{Command, CommandFactory, Parser, Subcommand, ValueHint};
use clap_complete::{Shell, generate};
use walkdir::WalkDir;
#[derive(Parser)]
#[command(author, about, version = malware_modeler::VERSION)]
pub struct Args {
#[clap(subcommand)]
cmd: Actions,
#[clap(long, default_value_t = num_cpus::get())]
threads: usize,
}
// Subcommands of the CLI. `main` matches on this enum and calls `execute` on the payload.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Subcommand)]
enum Actions {
// Find n-grams under a path and save them to a file.
Ngram(Ngram),
// Build and save a dataset from malicious files, benign files and an n-gram file.
Dataset(Dataset),
// Load a dataset and save it to another path, optionally reducing it against a model.
DatasetConvert(DatasetConvert),
// Train a logistic regression model and write it out as JSON.
Train(Train),
// Evaluate a saved model against a dataset.
Evaluate(Evaluate),
// Print model information; optionally extract its n-grams or evaluate files with it.
Model(ModelInfo),
// Print detected file types, or sort a directory into a destination.
Types(Types),
// LZJD similarity of one file, two files, or the files under a directory.
Similarity(Similarity),
// Extract files from an archive, selected by type.
SelectiveUnzip(SelectiveUnzip),
// Print per-type file counts for an archive.
ZipSummary(ZipSummary),
// Print a shell completion script.
Generate(Generator),
}
// Arguments of the `ngram` subcommand; all of them are forwarded to the library's `Ngrammer`
// by `Ngram::execute`.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser)]
struct Ngram {
    // `--path`; shell completion hints a directory.
    #[arg(long, value_hint = ValueHint::DirPath)]
    path: PathBuf,
    // Optional file type, given as `-t` / `--type` and matched case-insensitively.
    #[arg(short = 't', long = "type", ignore_case = true)]
    ftype: Option<malware_modeler::ftype::FileType>,
    // First positional argument; `execute` rejects values below 2.
    #[arg(default_value_t = 6)]
    n: u16,
    // Second positional argument; `execute` rejects values below 10.
    #[arg(default_value_t = 10_000)]
    k: usize,
    // `--output`; file the n-grams are saved to.
    #[arg(long, value_hint = ValueHint::FilePath, default_value = "ngrams.txt")]
    output: PathBuf,
    // `-c` / `--counts`; passed as the second argument of `Ngrammer::save`.
    #[arg(long, short, action)]
    counts: bool,
    // `--verbose`; passed to `Ngrammer::new`.
    #[arg(long, default_value = "false", action)]
    verbose: bool,
}
impl Ngram {
    /// Validates `k` and `n`, then finds n-grams and saves them to `self.output`.
    ///
    /// Returns `ExitCode::FAILURE` (after printing a message to stderr) when `k < 10` or
    /// `n < 2`; `k` is checked first, so its message wins when both are too small.
    ///
    /// # Errors
    /// Propagates any error from `Ngrammer::new` or `Ngrammer::save`.
    fn execute(&self) -> Result<ExitCode> {
        let rejection = if self.k < 10 {
            Some("NGrams must be at least 10, should be ~100k")
        } else if self.n < 2 {
            Some("NGrams must be at least 2, should be >4")
        } else {
            None
        };
        if let Some(reason) = rejection {
            eprintln!("{reason}");
            return Ok(ExitCode::FAILURE);
        }

        let mut finder = malware_modeler::ngram::Ngrammer::new(
            self.ftype,
            &self.path,
            self.n,
            self.k,
            self.verbose,
        )?;
        finder.find();
        finder.save(&self.output, self.counts)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `dataset` subcommand: three input paths and one output path, all forwarded
// by reference to the library's dataset builder in `Dataset::execute`.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser)]
struct Dataset {
// `-m` / `--malicious`; shell completion hints a directory.
#[arg(short, long, value_hint = ValueHint::DirPath)]
pub malicious: PathBuf,
// `-b` / `--benign`; shell completion hints a directory.
#[arg(short, long, value_hint = ValueHint::DirPath)]
pub benign: PathBuf,
// `-n` / `--ngrams`; shell completion hints a file.
#[arg(short, long, value_hint = ValueHint::FilePath)]
pub ngrams: PathBuf,
// `-o` / `--output`; shell completion hints a file.
#[arg(short, long, value_hint = ValueHint::FilePath)]
pub output: PathBuf,
}
impl Dataset {
    /// Builds and saves a dataset by delegating to the library with the four configured paths.
    ///
    /// # Errors
    /// Propagates any error returned by the library call.
    fn execute(&self) -> Result<ExitCode> {
        let Self {
            malicious,
            benign,
            ngrams,
            output,
        } = self;
        // NOTE(review): the callee's name lists "benign" before "malicious" while the arguments
        // are passed malicious-first; confirm the order against the library signature.
        malware_modeler::dataset::Dataset::create_save_from_benign_malicious_files_and_ngrams(
            malicious, benign, ngrams, output,
        )?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `dataset-convert` subcommand.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser)]
struct DatasetConvert {
// `-i` / `--in`; dataset file that is loaded.
#[arg(short, long = "in", value_hint = ValueHint::FilePath)]
pub input: PathBuf,
// `-o` / `--out`; path the dataset is saved to.
#[arg(short, long = "out", value_hint = ValueHint::FilePath)]
pub output: PathBuf,
// `-m` / `--model`; optional JSON model used to reduce the dataset before saving.
#[arg(short, long, value_hint = ValueHint::FilePath)]
pub model: Option<PathBuf>,
}
impl DatasetConvert {
    /// Loads the input dataset, optionally reduces it against a model, and saves it to the
    /// output path.
    ///
    /// Returns `ExitCode::FAILURE` when input and output share the same extension and no model
    /// was given.
    ///
    /// # Errors
    /// Propagates errors from loading or saving the dataset, reading or parsing the model
    /// file, and reducing the dataset.
    fn execute(&self) -> Result<ExitCode> {
        let same_extension = self.input.extension() == self.output.extension();
        if same_extension && self.model.is_none() {
            eprintln!("Refusing to attempt 'conversion' between files with the same extension.");
            return Ok(ExitCode::FAILURE);
        }

        let mut dataset = malware_modeler::dataset::Dataset::load(&self.input)?;
        if let Some(model_path) = &self.model {
            // NOTE(review): indexing row 0 presumably panics on an empty dataset; confirm
            // whether `Dataset::load` can return one.
            let original = dataset.data[0].len();
            let model_json = std::fs::read_to_string(model_path)?;
            let model =
                serde_json::from_str::<malware_modeler::model::LogisticRegression>(&model_json)?;
            let removed = dataset.reduce(&model)?.len();
            println!("Removed {removed} features from the original {original}");
        }
        dataset.save(&self.output)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `train` subcommand.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser)]
struct Train {
    // `-d` / `--dataset`; dataset the model is trained on.
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub dataset: PathBuf,
    // `-t` / `--test`; optional second dataset evaluated after training.
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub test: Option<PathBuf>,
    // `-m` / `--model`; path the trained model is written to as JSON.
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub model: PathBuf,
    // `-e` / `--epochs`; passed to the model's `train`.
    #[arg(short, long, default_value_t = 100)]
    pub epochs: u32,
    // `--learning-rate`; passed to `LogisticRegression::new`.
    #[arg(long, default_value_t = 0.1)]
    pub learning_rate: f32,
    // The flag is spelled `--l1-penalty`. The historical misspelling `--l1-pentalty` stays
    // available as a hidden alias so existing invocations keep working. The field keeps its
    // old name because `Train::execute` reads it.
    #[arg(
        long = "l1-penalty",
        alias = "l1-pentalty",
        value_name = "L1_PENALTY",
        default_value_t = 0.1
    )]
    pub l1_pentalty: f32,
    // `--l2-penalty`; passed to `LogisticRegression::new`.
    #[arg(long, default_value_t = 0.1)]
    pub l2_penalty: f32,
}
impl Train {
    /// Trains a logistic regression model on the configured dataset, prints its metrics,
    /// optionally evaluates it on a test dataset, reduces its features and writes it as JSON.
    ///
    /// Returns `ExitCode::FAILURE` (after printing to stderr) when training itself fails.
    ///
    /// # Errors
    /// Propagates errors from loading datasets, evaluating, reducing features, serializing
    /// the model and writing the model file.
    fn execute(&self) -> Result<ExitCode> {
        let mut dataset = malware_modeler::dataset::Dataset::load(&self.dataset)?;
        // NOTE(review): indexing row 0 presumably panics on an empty dataset; confirm
        // whether `Dataset::load` can return one.
        let feature_count = dataset.data[0].len();
        let mut model = malware_modeler::model::LogisticRegression::new(
            feature_count,
            self.learning_rate,
            self.l1_pentalty,
            self.l2_penalty,
        );

        let error = match model.train(self.epochs, &mut dataset) {
            Ok(error) => error,
            Err(error) => {
                eprintln!("Training failed: {error}");
                return Ok(ExitCode::FAILURE);
            }
        };
        println!("Training complete, final error {error}");

        // Metrics on the training data.
        let result = model.evaluate_dataset(&dataset)?;
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());

        // Metrics on the optional test data; these are also stored on the model.
        if let Some(test_path) = &self.test {
            let test_data = malware_modeler::dataset::Dataset::load(test_path)?;
            let result = model.evaluate_dataset(&test_data)?;
            println!("{result}");
            println!("Accuracy: {:.2}", result.accuracy());
            println!("Precision: {:.2}", result.precision());
            println!("Recall: {:.2}", result.recall());
            println!("F1: {:.2}", result.f1());
            println!("Auc: {:.2}", result.auc());
            model.test_performance = Some(result.into());
        }

        // The reduction uses the training dataset, not the test dataset.
        let original = model.weights.len();
        model.set_features_and_reduce(&dataset)?;
        let reduced = model.weights.len();
        let diff = original - reduced;
        println!("Features reduced by {diff}: {original} - {reduced}");

        let serialized = serde_json::to_string_pretty(&model)?;
        std::fs::write(&self.model, serialized)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `evaluate` subcommand.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser)]
struct Evaluate {
// `-d` / `--dataset`; dataset the model is evaluated against.
#[arg(short, long, value_hint = ValueHint::FilePath)]
dataset: PathBuf,
// `-m` / `--model`; JSON model file (rewritten in place when `update_model` is set).
#[arg(short, long, value_hint = ValueHint::FilePath)]
model: PathBuf,
// `--update-model`; store the evaluation result in the model and write the model back.
#[arg(long, default_value_t = false)]
update_model: bool,
}
impl Evaluate {
    /// Evaluates a saved model against a dataset and prints the metrics.
    ///
    /// When `update_model` is set, the result is stored in the model's `test_performance`
    /// and the model file is overwritten with the updated JSON.
    ///
    /// # Errors
    /// Propagates errors from loading the dataset, reading or parsing the model, evaluating,
    /// serializing and writing the model.
    fn execute(&self) -> Result<ExitCode> {
        let dataset = malware_modeler::dataset::Dataset::load(&self.dataset)?;
        let model_json = std::fs::read_to_string(&self.model)?;
        let mut model =
            serde_json::from_str::<malware_modeler::model::LogisticRegression>(&model_json)?;

        let result = model.evaluate_dataset(&dataset)?;
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());

        if !self.update_model {
            return Ok(ExitCode::SUCCESS);
        }
        model.test_performance = Some(result.into());
        let updated = serde_json::to_string_pretty(&model)?;
        std::fs::write(&self.model, updated)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `model` subcommand.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser)]
struct ModelInfo {
// `-m` / `--model`; JSON model file to inspect.
#[arg(short, long, value_hint = ValueHint::FilePath)]
model: PathBuf,
// `-n` / `--extract-ngrams`; when given, the model's feature keys are written here, hex-encoded, one per line.
#[arg(short = 'n', long, value_hint = ValueHint::FilePath)]
extract_ngrams: Option<PathBuf>,
// `-e` / `--eval`; file or directory to evaluate with the model.
#[arg(short, long, value_hint = ValueHint::AnyPath)]
eval: Option<PathBuf>,
}
impl ModelInfo {
    /// Prints the number of weights in a saved model, optionally writes its feature keys
    /// (hex-encoded, one per line) to a file, and optionally evaluates a file or every file
    /// under a directory with the model.
    ///
    /// Returns `ExitCode::FAILURE` when `--eval` names something that is neither a file nor a
    /// directory, matching the other error paths in this program that print to stderr.
    ///
    /// # Errors
    /// Propagates errors from reading or parsing the model, writing the n-gram file, and
    /// evaluating files. A failure on one file during a directory walk aborts the walk.
    fn execute(&self) -> Result<ExitCode> {
        let model_contents = std::fs::read_to_string(&self.model)?;
        let model: malware_modeler::model::LogisticRegression =
            serde_json::from_str(model_contents.as_str())?;
        println!("Model has {} weights", model.weights.len());

        if let Some(extract_ngrams) = &self.extract_ngrams {
            let features = model
                .features
                .keys()
                .map(hex::encode)
                .collect::<Vec<String>>()
                .join("\n");
            let mut file = File::create(extract_ngrams)?;
            file.write_all(features.as_bytes())?;
            file.write_all(b"\n")?;
            file.sync_all()?;
        }

        let Some(eval) = &self.eval else {
            return Ok(ExitCode::SUCCESS);
        };

        if eval.is_file() {
            let (label, prediction, features) = model.evaluate_file(eval)?;
            println!(
                "{} is predicted to be {label} (raw: {prediction}) and had {features} of {} features",
                eval.display(),
                model.features.len()
            );
        } else if eval.is_dir() {
            println!("Path, PredictionLabel, PredictionRaw, FoundFeatures/AvailableFeatures");
            // `flatten` drops directory entries that could not be read instead of failing.
            for entry in WalkDir::new(eval)
                .max_depth(malware_modeler::MAX_RECURSION_DEPTH)
                .follow_links(true)
                .into_iter()
                .flatten()
            {
                if entry.file_type().is_file() {
                    let (label, prediction, features) = model.evaluate_file(entry.path())?;
                    println!(
                        "{}, {label}, {prediction}, {features}/{}",
                        entry.path().display(),
                        model.features.len()
                    );
                }
            }
        } else {
            // Previously this case printed the error but still exited successfully.
            eprintln!("{} is not a file or directory", eval.display());
            return Ok(ExitCode::FAILURE);
        }
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `types` subcommand.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser)]
struct Types {
// First positional argument; a file or a directory.
#[arg(value_hint = ValueHint::AnyPath)]
path: PathBuf,
// Optional second positional argument; when given with a directory `path`, files are sorted into it.
#[arg(value_hint = ValueHint::DirPath)]
destination: Option<PathBuf>,
// `--depth`; passed to the library's `file_sorting` (the listing walk uses `MAX_RECURSION_DEPTH` instead).
#[arg(long, default_value_t = 3)]
depth: u8,
}
impl Types {
    /// Reports file types, or sorts a directory into a destination.
    ///
    /// * `path` is a file: prints its detected type.
    /// * `path` is a directory and no destination is given: prints `path: type` for every
    ///   file found by the walk.
    /// * `path` is a directory and a destination is given: delegates to the library's
    ///   `file_sorting` and prints its counters.
    ///
    /// Returns `ExitCode::FAILURE` when `path` is neither a file nor a directory (previously
    /// a silent no-op that exited successfully) or when the destination is an existing file.
    ///
    /// # Errors
    /// Propagates errors from reading files and from `file_sorting`.
    pub fn execute(&self) -> Result<ExitCode> {
        if self.path.is_file() {
            let contents = std::fs::read(&self.path)?;
            let ftype = malware_modeler::sorting::FileTypeUnion::from_bytes(&contents);
            println!("{ftype}");
            return Ok(ExitCode::SUCCESS);
        }

        if !self.path.is_dir() {
            eprintln!("{} is not a file or directory", self.path.display());
            return Ok(ExitCode::FAILURE);
        }

        let Some(destination) = &self.destination else {
            // `flatten` drops directory entries that could not be read instead of failing.
            for entry in WalkDir::new(&self.path)
                .max_depth(malware_modeler::MAX_RECURSION_DEPTH)
                .follow_links(true)
                .into_iter()
                .flatten()
            {
                if entry.file_type().is_file() {
                    let contents = std::fs::read(entry.path())?;
                    let ftype = malware_modeler::sorting::FileTypeUnion::from_bytes(&contents);
                    println!("{}: {ftype}", entry.path().display());
                }
            }
            return Ok(ExitCode::SUCCESS);
        };

        if destination.is_file() {
            eprintln!("Destination must be a directory, not a file");
            return Ok(ExitCode::FAILURE);
        }
        // The destination is only borrowed, so no copy of the path is needed.
        let result = malware_modeler::sorting::file_sorting(&self.path, destination, self.depth)?;
        println!(
            "Encountered {} files, {} of which were duplicates.",
            result.total_files, result.files_removed
        );
        if result.errors > 0 {
            eprintln!("Encountered {} errors while sorting files.", result.errors);
        }
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `similarity` subcommand.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser, Debug, Clone, PartialEq)]
struct Similarity {
// First positional argument; a file, or otherwise handed to the library's similarity search as its path.
one: PathBuf,
// Optional second positional argument; a file to compare with `one`.
two: Option<PathBuf>,
// `--threshold`; passed to the library's similarity search.
#[arg(long, default_value_t = 0.8)]
threshold: f32,
// Deliberately long flag: when set, the first file of each reported pair is removed.
#[arg(long, default_value_t = false)]
delete_file_if_too_similar_yes_i_really_mean_it: bool,
}
impl Similarity {
    /// Computes LZJD-based similarity.
    ///
    /// * `one` is a file and `two` is a file: prints their similarity with four decimals.
    /// * `one` is a file and `two` is absent: prints the LZJD value of `one`.
    /// * `one` is not a file: runs the library's similarity search over it, reports every
    ///   pair passed to the callback on stderr, and prints the total pair count. With the
    ///   delete flag set, the first file of each reported pair is removed if it still exists.
    ///
    /// # Errors
    /// Fails when `one` is a file but `two` is given and is not a file (previously ignored
    /// silently), when `one` is not a file and `two` is given, and on any error from the
    /// library calls.
    pub fn execute(&self) -> Result<ExitCode> {
        if self.one.is_file() {
            match &self.two {
                Some(two) if two.is_file() => {
                    let sim = malware_modeler::similarity::lzjd_compare_paths(&self.one, two)?;
                    println!("{sim:.4}");
                }
                Some(two) => {
                    bail!(
                        "Cannot compare {} with {}: the second path is not a file",
                        self.one.display(),
                        two.display()
                    );
                }
                None => {
                    let lzjd = malware_modeler::similarity::lzjd_from_path(&self.one)?;
                    println!("{lzjd}");
                }
            }
        } else {
            if self.two.is_some() {
                bail!("Similarity does not accept a directory argument with another path");
            }
            let similarity_check = malware_modeler::similarity::Similarity {
                path: &self.one,
                threshold: self.threshold,
            };
            let similar_count = similarity_check.find(|a, b, sim| {
                eprintln!("{} & {} too similar: {sim}", a.display(), b.display());
                // `a` may already have been removed as part of an earlier pair, hence the
                // existence check before trying to delete it.
                if self.delete_file_if_too_similar_yes_i_really_mean_it
                    && std::fs::exists(a).unwrap_or_default()
                    && let Err(e) = std::fs::remove_file(a)
                {
                    eprintln!("Failed to remove {}: {e}", a.display());
                }
            })?;
            println!("Total similar file pairs: {similar_count}");
        }
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `selective-unzip` subcommand; all are forwarded to the library's
// `unzip_files_by_type`.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser, Debug, Clone, PartialEq)]
struct SelectiveUnzip {
// First positional argument; the archive to read.
#[arg(value_hint = ValueHint::FilePath)]
source: PathBuf,
// Second positional argument; directory hint for the extraction target.
#[arg(value_hint = ValueHint::DirPath)]
destination: PathBuf,
// `-p` / `--password`; optional archive password.
#[arg(short, long)]
password: Option<String>,
// `--depth`; passed through to the library unchanged.
#[arg(long, default_value_t = 3)]
depth: u8,
// `-t` / `--type`; parsed with the library's `parse_file_type_union`.
#[arg(short = 't', long = "type", value_parser = malware_modeler::sorting::parse_file_type_union)]
file_type: Option<malware_modeler::sorting::FileTypeUnion>,
// `-m` / `--magic`; only exists when the `libmagic` feature is enabled.
#[cfg(feature = "libmagic")]
#[arg(short = 'm', long = "magic")]
magic: Option<String>,
// `--keep-unknowns`; passed through to the library unchanged.
#[arg(long, default_value_t = false)]
keep_unknowns: bool,
}
impl SelectiveUnzip {
    /// Consumes the parsed arguments, hands them to the library's `unzip_files_by_type`, and
    /// prints how many files were extracted.
    ///
    /// # Errors
    /// Propagates any error returned by `unzip_files_by_type`.
    pub fn execute(self) -> Result<ExitCode> {
        // Borrow the optional values up front; the owned paths are moved into the call below.
        let password = self.password.as_ref();
        #[cfg(feature = "libmagic")]
        let magic = self.magic.as_deref();

        let extracted_count = malware_modeler::unzip::unzip_files_by_type(
            self.source,
            self.destination,
            password,
            self.depth,
            self.file_type,
            #[cfg(feature = "libmagic")]
            magic,
            self.keep_unknowns,
        )?;
        println!("Extracted {extracted_count} files");
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `zip-summary` subcommand.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser, Debug, Clone, PartialEq)]
struct ZipSummary {
// Positional argument; the archive to summarize.
#[arg(value_hint = ValueHint::FilePath)]
source: PathBuf,
// `-p` / `--password`; optional archive password.
#[arg(short, long)]
password: Option<String>,
// `-u` / `--unknown-magic`; passed to the library's `zip_file_type_counts` unchanged.
#[arg(short, long, default_value_t = 0)]
unknown_magic: usize,
// `--unknown-minimum`; unknown headers seen fewer times than this are not printed.
#[arg(long, default_value_t = 2)]
unknown_minimum: usize,
}
impl ZipSummary {
    /// Prints a table of file-type counts for an archive, most frequent first, followed by a
    /// total line. If the library reported unknown headers, a second table lists those seen
    /// at least `unknown_minimum` times (with an extra "Magic" column under the `libmagic`
    /// feature).
    ///
    /// # Errors
    /// Propagates any error returned by `zip_file_type_counts`.
    pub fn execute(self) -> Result<ExitCode> {
        const WIDTH: usize = 14;
        let minimum = self.unknown_minimum;
        let zip_summary = malware_modeler::unzip::zip_file_type_counts(
            self.source,
            self.password.as_ref(),
            self.unknown_magic,
        )?;

        println!("{0:<WIDTH$} {1:<WIDTH$}", "Type", "Count");
        let mut type_rows = zip_summary
            .file_type_counts
            .into_iter()
            .collect::<Vec<_>>();
        type_rows.sort_by(|lhs, rhs| rhs.1.cmp(&lhs.1));
        for (ftype, count) in type_rows {
            // Rendered to a `String` first so the column padding applies to the text.
            println!("{:<WIDTH$} {count:<WIDTH$}", ftype.to_string());
        }
        println!("{:^WIDTH$} {:<WIDTH$}", "Total", zip_summary.total_files);

        if zip_summary.unknown_magic_counts.is_empty() {
            return Ok(ExitCode::SUCCESS);
        }

        #[cfg(feature = "libmagic")]
        println!("\n{0:<WIDTH$} {1:<WIDTH$} Magic", "Unknown Header", "Count");
        #[cfg(not(feature = "libmagic"))]
        println!("\n{0:<WIDTH$} {1:<WIDTH$}", "Unknown Header", "Count");

        let mut unknown_rows = zip_summary
            .unknown_magic_counts
            .into_iter()
            .collect::<Vec<_>>();
        unknown_rows.sort_by(|lhs, rhs| rhs.1.cmp(&lhs.1));

        #[cfg(feature = "libmagic")]
        for (unknown_bytes, (count, magic)) in unknown_rows
            .into_iter()
            .filter(|(_, (count, _))| *count >= minimum)
        {
            println!(
                "{:<WIDTH$} {count:<WIDTH$} {magic}",
                hex::encode(unknown_bytes)
            );
        }
        #[cfg(not(feature = "libmagic"))]
        for (unknown_bytes, count) in unknown_rows
            .into_iter()
            .filter(|(_, count)| *count >= minimum)
        {
            println!("{:<WIDTH$} {count:<WIDTH$}", hex::encode(unknown_bytes));
        }
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments of the `generate` subcommand, which prints a shell completion script.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into help text.
#[derive(Parser, Debug, Clone, PartialEq)]
pub struct Generator {
// Positional argument; the shell to generate completions for.
#[arg(value_enum)]
shell: Shell,
}
impl Generator {
    /// Writes the completion script for the selected shell to stdout and returns
    /// `ExitCode::SUCCESS`.
    #[must_use]
    pub fn execute(&self) -> ExitCode {
        let mut cli = Args::command();
        self.print_completions(&mut cli);
        ExitCode::SUCCESS
    }

    /// Generates completions for `cmd`, using the command's own name as the binary name.
    fn print_completions(&self, cmd: &mut Command) {
        let bin_name = cmd.get_name().to_string();
        let mut stdout = std::io::stdout();
        generate(self.shell, cmd, bin_name, &mut stdout);
    }
}
fn main() -> Result<ExitCode> {
let args = Args::parse();
rayon::ThreadPoolBuilder::new()
.num_threads(args.threads)
.build_global()?;
let ret = match args.cmd {
Actions::Ngram(ngram) => ngram.execute(),
Actions::Dataset(dataset) => dataset.execute(),
Actions::DatasetConvert(dataset) => dataset.execute(),
Actions::Train(training) => training.execute(),
Actions::Evaluate(evaluate) => evaluate.execute(),
Actions::Model(model) => model.execute(),
Actions::Types(types) => types.execute(),
Actions::Similarity(sim) => sim.execute(),
Actions::SelectiveUnzip(unzip) => unzip.execute(),
Actions::ZipSummary(summary) => summary.execute(),
Actions::Generate(generator) => Ok(generator.execute()),
};
if let Some(mem) = app_memory_usage_fetcher::get_memory_usage_string() {
println!("Memory used: {mem}");
}
ret
}
/// Runs clap's built-in consistency assertions over the full command definition.
#[test]
fn verify_cli() {
    use clap::CommandFactory;
    let command = Args::command();
    command.debug_assert();
}