#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![forbid(unsafe_code)]
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use std::process::ExitCode;
use anyhow::{bail, Result};
use clap::{Command, CommandFactory, Parser, Subcommand, ValueHint};
use clap_complete::{generate, Shell};
use walkdir::WalkDir;
#[derive(Parser)]
#[command(author, about, version = malware_modeler::VERSION)]
pub struct Args {
#[clap(subcommand)]
cmd: Actions,
#[clap(long, default_value_t = num_cpus::get())]
threads: usize,
}
// The available subcommands. Each variant's doc comment becomes that subcommand's
// one-line summary in `--help`.
#[derive(Subcommand)]
enum Actions {
    /// Extract n-grams from a directory of files and save them
    Ngram(Ngram),
    /// Build a dataset from malicious and benign files and an n-gram list
    Dataset(Dataset),
    /// Convert a dataset to another format, optionally removing features using a model
    DatasetConvert(DatasetConvert),
    /// Train a logistic regression model on a dataset
    Train(Train),
    /// Evaluate a trained model against a dataset
    Evaluate(Evaluate),
    /// Show model details, export its n-grams, or classify files with it
    Model(ModelInfo),
    /// Print the detected file type of a file or of each file in a directory
    Types(Types),
    /// Compare files by LZJD similarity, or find similar files in a directory
    Similarity(Similarity),
    /// Generate a shell completion script
    Generate(Generator),
}
// Arguments for the `ngram` subcommand.
#[derive(Parser)]
struct Ngram {
    /// Directory of files to extract n-grams from
    #[arg(long, value_hint = ValueHint::DirPath)]
    path: PathBuf,
    /// Optional file type filter for the n-gram search
    #[arg(short = 't', long = "type", ignore_case = true)]
    ftype: Option<malware_modeler::ftype::FileType>,
    /// Length of each n-gram (at least 2)
    #[arg(default_value = "6")]
    n: u16,
    /// Number of n-grams to find (at least 10)
    #[arg(default_value = "10000")]
    k: usize,
    /// File to write the n-grams to
    #[arg(long, value_hint = ValueHint::FilePath, default_value = "ngrams.txt")]
    output: PathBuf,
    /// Include counts when saving the n-grams
    #[arg(long, short, action)]
    counts: bool,
    /// Enable verbose output from the n-gram search
    #[arg(long, default_value = "false", action)]
    verbose: bool,
}
impl Ngram {
    /// Validates the n-gram parameters, runs the n-gram search over `path`, and saves
    /// the result to `output`.
    ///
    /// Returns `ExitCode::FAILURE` (after printing the reason) when `k < 10` or `n < 2`.
    ///
    /// # Errors
    /// Propagates failures from constructing the n-grammer or from saving its output.
    fn execute(&self) -> Result<ExitCode> {
        // `k` is how many n-grams to find; `n` is the length of each one. The messages
        // name the parameter so the two checks can't be confused.
        if self.k < 10 {
            eprintln!("The number of n-grams (k) must be at least 10, should be ~100k");
            return Ok(ExitCode::FAILURE);
        }
        if self.n < 2 {
            eprintln!("The n-gram length (n) must be at least 2, should be >4");
            return Ok(ExitCode::FAILURE);
        }
        let mut ngrammer = malware_modeler::ngram::Ngrammer::new(
            self.ftype,
            &self.path,
            self.n,
            self.k,
            self.verbose,
        )?;
        ngrammer.find();
        ngrammer.save(&self.output, self.counts)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `dataset` subcommand.
#[derive(Parser)]
struct Dataset {
    /// Directory of malicious sample files
    #[arg(short, long, value_hint = ValueHint::DirPath)]
    pub malicious: PathBuf,
    /// Directory of benign sample files
    #[arg(short, long, value_hint = ValueHint::DirPath)]
    pub benign: PathBuf,
    /// File containing the n-grams to build the dataset from
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub ngrams: PathBuf,
    /// File to write the dataset to
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub output: PathBuf,
}
impl Dataset {
    /// Builds a dataset from the malicious and benign sample directories plus the
    /// n-gram list, and saves it to `output` in a single library call.
    ///
    /// # Errors
    /// Propagates any failure from dataset creation or saving.
    fn execute(&self) -> Result<ExitCode> {
        // Destructure once so the library call reads as a plain argument list.
        let Self {
            malicious,
            benign,
            ngrams,
            output,
        } = self;
        malware_modeler::dataset::Dataset::create_save_from_benign_malicious_files_and_ngrams(
            malicious, benign, ngrams, output,
        )?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `dataset-convert` subcommand.
#[derive(Parser)]
struct DatasetConvert {
    /// Dataset file to read
    #[arg(short, long = "in", value_hint = ValueHint::FilePath)]
    pub input: PathBuf,
    /// File to write the converted dataset to
    #[arg(short, long = "out", value_hint = ValueHint::FilePath)]
    pub output: PathBuf,
    /// Model file (JSON) used to remove features from the dataset
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub model: Option<PathBuf>,
}
impl DatasetConvert {
fn execute(&self) -> Result<ExitCode> {
if self.input.extension() == self.output.extension() && self.model.is_none() {
eprintln!("Refusing to attempt 'conversion' between files with the same extension.");
return Ok(ExitCode::FAILURE);
}
let mut dataset = malware_modeler::dataset::Dataset::load(&self.input)?;
if let Some(model) = &self.model {
let original = dataset.data[0].len();
let model_contents = std::fs::read_to_string(model)?;
let model: malware_modeler::model::LogisticRegression =
serde_json::from_str(model_contents.as_str())?;
let removed = dataset.reduce(&model)?.len();
println!("Removed {removed} features from the original {original}");
}
dataset.save(&self.output)?;
Ok(ExitCode::SUCCESS)
}
}
// Arguments for the `train` subcommand.
#[derive(Parser)]
struct Train {
    /// Dataset file to train on
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub dataset: PathBuf,
    /// File to write the trained model to, as JSON
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub model: PathBuf,
    /// Number of training epochs
    #[arg(short, long, default_value_t = 100)]
    pub epochs: u32,
    /// Learning rate for training
    #[arg(long, default_value_t = 0.1)]
    pub learning_rate: f32,
    // The flag was originally published as `--l1-pentalty` (typo). The correct spelling
    // is now primary and the old one stays accepted as an alias so existing scripts
    // keep working. The field name is unchanged because other code refers to it.
    /// L1 penalty for the model
    #[arg(long = "l1-penalty", alias = "l1-pentalty", default_value_t = 0.1)]
    pub l1_pentalty: f32,
    /// L2 penalty for the model
    #[arg(long, default_value_t = 0.1)]
    pub l2_penalty: f32,
}
impl Train {
fn execute(&self) -> Result<ExitCode> {
let dataset = malware_modeler::dataset::Dataset::load(&self.dataset)?;
let mut model = malware_modeler::model::LogisticRegression::new(
dataset.data[0].len(),
self.learning_rate,
self.l1_pentalty,
self.l2_penalty,
);
match model.train(self.epochs, &dataset) {
Ok(error) => {
println!("Training complete, final error {error}");
}
Err(error) => {
eprintln!("Training failed: {error}");
return Ok(ExitCode::FAILURE);
}
}
let result = model.evaluate_dataset(&dataset)?;
println!("{result}");
println!("Accuracy: {:.2}", result.accuracy());
println!("Precision: {:.2}", result.precision());
println!("Recall: {:.2}", result.recall());
println!("F1: {:.2}", result.f1());
println!("Auc: {:.2}", result.auc());
if let Err(e) = model.set_features(dataset.features) {
eprintln!("Failed to set model features: {e}");
} else {
let original = model.weights.len();
model.reduce();
let reduced = model.weights.len();
if reduced != original {
let diff = original - reduced;
println!("Features reduced by {diff}: {original} - {reduced}");
}
}
let model_json = serde_json::to_string_pretty(&model)?;
std::fs::write(&self.model, model_json)?;
Ok(ExitCode::SUCCESS)
}
}
// Arguments for the `evaluate` subcommand.
#[derive(Parser)]
struct Evaluate {
    /// Dataset file to evaluate the model against
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    dataset: PathBuf,
    /// Trained model file (JSON)
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    model: PathBuf,
}
impl Evaluate {
    /// Scores a JSON logistic-regression model against a dataset.
    ///
    /// Prints the evaluation result's `Display` form followed by accuracy, precision,
    /// recall, F1 and AUC, each to two decimal places.
    ///
    /// # Errors
    /// Fails if the dataset or model cannot be loaded, or if evaluation fails.
    fn execute(&self) -> Result<ExitCode> {
        let data = malware_modeler::dataset::Dataset::load(&self.dataset)?;
        let json = std::fs::read_to_string(&self.model)?;
        let classifier =
            serde_json::from_str::<malware_modeler::model::LogisticRegression>(&json)?;
        let outcome = classifier.evaluate_dataset(&data)?;
        // One formatted write, producing the same lines as printing each metric separately.
        println!(
            "{outcome}\nAccuracy: {:.2}\nPrecision: {:.2}\nRecall: {:.2}\nF1: {:.2}\nAuc: {:.2}",
            outcome.accuracy(),
            outcome.precision(),
            outcome.recall(),
            outcome.f1(),
            outcome.auc(),
        );
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `model` subcommand.
#[derive(Parser)]
struct ModelInfo {
    /// Model file (JSON) to inspect
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    model: PathBuf,
    /// Write the model's n-grams, hex-encoded one per line, to this file
    #[arg(short = 'n', long, value_hint = ValueHint::FilePath)]
    extract_ngrams: Option<PathBuf>,
    /// File, or directory of files, to classify with the model
    #[arg(short, long, value_hint = ValueHint::AnyPath)]
    eval: Option<PathBuf>,
}
impl ModelInfo {
    /// Loads a JSON logistic-regression model, prints its weight count, then optionally
    /// exports its n-grams (`--extract-ngrams`) and classifies `--eval` targets.
    ///
    /// Returns `ExitCode::FAILURE` when the `--eval` path is neither a file nor a
    /// directory, or when any file under an `--eval` directory could not be evaluated.
    ///
    /// # Errors
    /// Fails if the model cannot be read or parsed, if the n-gram file cannot be
    /// written, or if evaluating a single `--eval` file fails.
    fn execute(&self) -> Result<ExitCode> {
        let model_contents = std::fs::read_to_string(&self.model)?;
        let model: malware_modeler::model::LogisticRegression =
            serde_json::from_str(model_contents.as_str())?;
        println!("Model has {} weights", model.weights.len());
        if let Some(extract_ngrams) = &self.extract_ngrams {
            Self::export_ngrams(&model, extract_ngrams)?;
        }
        let Some(eval) = &self.eval else {
            return Ok(ExitCode::SUCCESS);
        };
        Self::evaluate_path(&model, eval)
    }

    /// Writes the model's feature keys, hex-encoded, one per line (newline-terminated)
    /// to `path`, and syncs the file to disk.
    fn export_ngrams(
        model: &malware_modeler::model::LogisticRegression,
        path: &std::path::Path,
    ) -> Result<()> {
        let features = model
            .features
            .keys()
            .map(hex::encode)
            .collect::<Vec<String>>()
            .join("\n");
        let mut file = File::create(path)?;
        file.write_all(features.as_bytes())?;
        file.write_all(b"\n")?;
        file.sync_all()?;
        Ok(())
    }

    /// Classifies a single file (errors propagate) or every regular file under a
    /// directory as CSV-like lines (per-file errors are reported on stderr and the scan
    /// continues, as `types` does).
    fn evaluate_path(
        model: &malware_modeler::model::LogisticRegression,
        eval: &std::path::Path,
    ) -> Result<ExitCode> {
        let available = model.features.len();
        if eval.is_file() {
            let (label, prediction, features) = model.evaluate_file(eval)?;
            println!(
                "{} is predicted to be {label} (raw: {prediction}) and had {features} of {available} features",
                eval.display()
            );
            return Ok(ExitCode::SUCCESS);
        }
        if !eval.is_dir() {
            eprintln!("{} is not a file or directory", eval.display());
            return Ok(ExitCode::FAILURE);
        }
        println!("Path, PredictionLabel, PredictionRaw, FoundFeatures/AvailableFeatures");
        let mut failures = 0_usize;
        let files = WalkDir::new(eval)
            .max_depth(malware_modeler::MAX_RECURSION_DEPTH)
            .follow_links(true)
            .into_iter()
            .flatten()
            .filter(|entry| entry.file_type().is_file());
        for entry in files {
            match model.evaluate_file(entry.path()) {
                Ok((label, prediction, features)) => println!(
                    "{}, {label}, {prediction}, {features}/{available}",
                    entry.path().display()
                ),
                Err(e) => {
                    // One bad file should not abort a scan over many files.
                    failures += 1;
                    eprintln!("Failed to evaluate {}: {e}", entry.path().display());
                }
            }
        }
        if failures > 0 {
            eprintln!("{failures} file(s) could not be evaluated");
            return Ok(ExitCode::FAILURE);
        }
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `types` subcommand.
#[derive(Parser)]
struct Types {
    /// File, or directory of files, whose detected types should be printed
    #[arg(value_hint = ValueHint::FilePath)]
    path: PathBuf,
}
impl Types {
    /// Prints the detected file type of `path`, or of every regular file beneath it
    /// when it is a directory. "Unknown" is printed when no type is detected.
    ///
    /// Inside a directory, per-file failures are reported on stderr and skipped.
    /// Returns `ExitCode::FAILURE` when `path` is neither a file nor a directory
    /// (for example, when it does not exist).
    ///
    /// # Errors
    /// Fails if `path` is a single file and type detection returns an error.
    pub fn execute(&self) -> Result<ExitCode> {
        if self.path.is_file() {
            match malware_modeler::ftype::FileType::from_path(&self.path)? {
                Some(ft) => println!("{ft:?}"),
                None => println!("Unknown"),
            }
        } else if self.path.is_dir() {
            for entry in WalkDir::new(&self.path)
                .max_depth(malware_modeler::MAX_RECURSION_DEPTH)
                .follow_links(true)
                .into_iter()
                .flatten()
            {
                if entry.file_type().is_file() {
                    match malware_modeler::ftype::FileType::from_path(entry.path()) {
                        Ok(Some(ft)) => println!("{}: {ft:?}", entry.path().display()),
                        Ok(None) => println!("{}: Unknown", entry.path().display()),
                        Err(e) => eprintln!("Failed to open {}: {}", entry.path().display(), e),
                    }
                }
            }
        } else {
            // Previously this case printed nothing and reported success.
            eprintln!("{} is not a file or directory", self.path.display());
            return Ok(ExitCode::FAILURE);
        }
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `similarity` subcommand.
#[derive(Parser, Debug, Clone, PartialEq)]
struct Similarity {
    /// A file to hash or compare, or a directory to search for similar files
    one: PathBuf,
    /// Second file to compare against the first (files only)
    two: Option<PathBuf>,
    /// Similarity threshold used when searching a directory
    #[arg(long, default_value_t = 0.8)]
    threshold: f32,
    /// When searching a directory, delete the first file of each too-similar pair
    #[arg(long, default_value_t = false)]
    delete_file_if_too_similar_yes_i_really_mean_it: bool,
}
impl Similarity {
    /// Runs LZJD similarity on the given paths.
    ///
    /// - `one` is a file and `two` is a file: prints their similarity (4 decimals).
    /// - `one` is a file and `two` is absent: prints `one`'s LZJD value.
    /// - `one` is a directory: reports file pairs at or over the threshold (per the
    ///   library's `Similarity::find`), optionally deleting the first file of each pair.
    ///
    /// # Errors
    /// Fails if `one` is neither a file nor a directory, if `two` is given but is not
    /// a file, if `two` is combined with a directory, or if hashing or searching fails.
    pub fn execute(&self) -> Result<ExitCode> {
        if self.one.is_file() {
            self.compare_file()?;
        } else if self.one.is_dir() {
            self.scan_directory()?;
        } else {
            bail!("{} is not a file or directory", self.one.display());
        }
        Ok(ExitCode::SUCCESS)
    }

    /// File mode: compare `one` with `two`, or print `one`'s LZJD value when `two` is
    /// absent. A `two` that is not a file was previously ignored silently.
    fn compare_file(&self) -> Result<()> {
        match &self.two {
            Some(two) if two.is_file() => {
                let sim = malware_modeler::similarity::lzjd_compare_paths(&self.one, two)?;
                println!("{sim:.4}");
            }
            Some(two) => bail!("{} is not a file", two.display()),
            None => {
                let lzjd = malware_modeler::similarity::lzjd_from_path(&self.one)?;
                println!("{lzjd}");
            }
        }
        Ok(())
    }

    /// Directory mode: report every too-similar pair found under `one` and print the total.
    fn scan_directory(&self) -> Result<()> {
        if self.two.is_some() {
            bail!("Similarity does not support a directory argument with another path");
        }
        let similarity_check = malware_modeler::similarity::Similarity {
            path: &self.one,
            threshold: self.threshold,
        };
        let similar_count = similarity_check.find(|a, b, sim| {
            eprintln!("{} & {} too similar: {sim}", a.display(), b.display());
            // `a` may already have been deleted as part of an earlier pair.
            if self.delete_file_if_too_similar_yes_i_really_mean_it
                && std::fs::exists(a).unwrap_or_default()
            {
                if let Err(e) = std::fs::remove_file(a) {
                    eprintln!("Failed to remove {}: {e}", a.display());
                }
            }
        })?;
        println!("Total similar file pairs: {similar_count}");
        Ok(())
    }
}
// Arguments for the `generate` subcommand.
#[derive(Parser, Debug, Clone, PartialEq)]
pub struct Generator {
    /// Shell to generate a completion script for
    #[arg(value_enum)]
    shell: Shell,
}
impl Generator {
    /// Writes a completion script for the selected shell to stdout.
    ///
    /// The progress note goes to stderr, so redirecting stdout captures only the script.
    #[must_use]
    pub fn execute(&self) -> ExitCode {
        let mut command = Args::command();
        eprintln!("Generating completion file for {:?}...", self.shell);
        self.print_completions(&mut command);
        ExitCode::SUCCESS
    }

    /// Generates completions for `cmd`, using the command's own name as the binary name.
    fn print_completions(&self, cmd: &mut Command) {
        let bin_name = cmd.get_name().to_owned();
        let mut out = std::io::stdout();
        generate(self.shell, cmd, bin_name, &mut out);
    }
}
fn main() -> Result<ExitCode> {
let args = Args::parse();
rayon::ThreadPoolBuilder::new()
.num_threads(args.threads)
.build_global()?;
match args.cmd {
Actions::Ngram(ngram) => ngram.execute(),
Actions::Dataset(dataset) => dataset.execute(),
Actions::DatasetConvert(dataset) => dataset.execute(),
Actions::Train(training) => training.execute(),
Actions::Evaluate(evaluate) => evaluate.execute(),
Actions::Model(model) => model.execute(),
Actions::Types(types) => types.execute(),
Actions::Similarity(sim) => sim.execute(),
Actions::Generate(generator) => Ok(generator.execute()),
}
}
/// Runs clap's internal consistency checks on the whole command tree (duplicate
/// flags, conflicting short options, invalid defaults, ...).
#[test]
fn verify_cli() {
    // `CommandFactory` is already imported at file level, so `Args::command()`
    // resolves without a local `use`.
    Args::command().debug_assert();
}