#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![forbid(unsafe_code)]
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use std::process::ExitCode;
use anyhow::Result;
use clap::{Parser, Subcommand, ValueHint};
use walkdir::WalkDir;
// Top-level command line definition for this binary.
//
// Plain `//` comments are used on purpose: clap's derive macros turn `///`
// doc comments into help/about text, which would change the generated CLI.
#[derive(Parser)]
#[command(author, about, version = malware_modeler::VERSION)]
pub struct Args {
    // The selected subcommand; `main` dispatches on it.
    // `#[command(subcommand)]` is the clap 4 spelling, matching the
    // `#[command]`/`#[arg]` attributes used everywhere else in this file.
    #[command(subcommand)]
    cmd: Actions,
}
// Subcommands of the CLI. Each variant wraps the argument struct whose
// `execute` method is invoked from `main`.
//
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// derive type become user-visible help text.
#[derive(Subcommand)]
enum Actions {
// Writes hex-encoded ngrams (optionally with counts) to a file.
Ngram(Ngram),
// Builds a dataset from the malicious/benign/ngrams paths and saves it.
Dataset(Dataset),
// Trains a logistic regression model on a dataset and writes it as JSON.
Train(Train),
// Loads a JSON model and prints its metrics for a dataset.
Evaluate(Evaluate),
// Inspects a JSON model: weight count, optional feature dump, optional
// evaluation of a file or directory. Note the variant is `Model` while the
// argument struct is `ModelInfo`.
Model(ModelInfo),
}
// Arguments for the `Ngram` subcommand (consumed by `Ngram::execute`).
//
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// derive type become user-visible help text. Field order matters here because
// `n` and `k` are positional arguments.
#[derive(Parser)]
struct Ngram {
    // Input path handed to `malware_modeler::Ngrammer::new`; shells are hinted
    // that it is a directory.
    #[arg(long, value_hint = ValueHint::DirPath)]
    path: PathBuf,
    // First positional argument; `execute` rejects values below 2.
    #[arg(default_value = "6")]
    n: u16,
    // Second positional argument; `execute` rejects values below 10.
    #[arg(default_value = "10000")]
    k: usize,
    // File the ngrams are written to.
    #[arg(long, value_hint = ValueHint::FilePath, default_value = "ngrams.txt")]
    output: PathBuf,
    // When set, each output line is `<hex>,<count>` instead of just `<hex>`.
    // Uses `#[arg]` (clap 4 spelling) for consistency with the other fields.
    #[arg(long, short, action)]
    counts: bool,
}
impl Ngram {
    /// Validates `n` and `k`, collects ngrams through `malware_modeler::Ngrammer`,
    /// and writes them hex-encoded, one per line, to `self.output`.
    ///
    /// With `--counts`, each line is `<hex>,<count>`.
    ///
    /// Returns `ExitCode::FAILURE` (after printing to stderr) when `k < 10` or
    /// `n < 2`; no output file is created in that case.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Ngrammer::new`, from creating the output file,
    /// and from writing to or flushing it.
    fn execute(&self) -> Result<ExitCode> {
        // Each message names the offending argument so the user can tell the
        // two validation failures apart.
        if self.k < 10 {
            eprintln!("k (number of NGrams) must be at least 10, should be ~100k");
            return Ok(ExitCode::FAILURE);
        }
        if self.n < 2 {
            eprintln!("n must be at least 2, should be >4");
            return Ok(ExitCode::FAILURE);
        }

        let ngrammer = malware_modeler::Ngrammer::new(&self.path, self.n, self.k)?;
        let ngrams = ngrammer.ngrams();

        // Buffer the writes: one `writeln!` per ngram on a bare `File` would
        // issue a system call for every line.
        let mut output = std::io::BufWriter::new(File::create(&self.output)?);
        for (n, count) in ngrams {
            if self.counts {
                writeln!(output, "{},{count}", hex::encode(n))?;
            } else {
                writeln!(output, "{}", hex::encode(n))?;
            }
        }
        // Flush explicitly; dropping a `BufWriter` discards any flush error.
        output.flush()?;

        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `Dataset` subcommand (consumed by `Dataset::execute`).
//
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// derive type become user-visible help text.
#[derive(Parser)]
struct Dataset {
// Passed as the first path to
// `Dataset::create_from_benign_malicious_files_and_ngrams`; hinted as a directory.
#[arg(short, long, value_hint = ValueHint::DirPath)]
pub malicious: PathBuf,
// Passed as the second path to the same constructor; hinted as a directory.
#[arg(short, long, value_hint = ValueHint::DirPath)]
pub benign: PathBuf,
// Passed as the third path to the same constructor; hinted as a file.
// Presumably the file produced by the `Ngram` subcommand — TODO confirm.
#[arg(short, long, value_hint = ValueHint::FilePath)]
pub ngrams: PathBuf,
// Destination handed to `save` on the created dataset.
#[arg(short, long, value_hint = ValueHint::FilePath)]
pub output: PathBuf,
}
impl Dataset {
    /// Creates a dataset from the malicious, benign and ngrams paths and saves
    /// it to `self.output`.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating or saving the dataset.
    fn execute(&self) -> Result<ExitCode> {
        // NOTE(review): the constructor's name lists "benign" before
        // "malicious", yet the malicious path is passed first. Confirm the
        // parameter order against the library signature.
        malware_modeler::dataset::Dataset::create_from_benign_malicious_files_and_ngrams(
            &self.malicious,
            &self.benign,
            &self.ngrams,
        )?
        .save(&self.output)?;

        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `Train` subcommand (consumed by `Train::execute`).
//
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// derive type become user-visible help text.
#[derive(Parser)]
struct Train {
    // Dataset file loaded with `malware_modeler::dataset::Dataset::load`.
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub dataset: PathBuf,
    // Destination of the trained model, written as pretty-printed JSON.
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub model: PathBuf,
    // Passed to `LogisticRegression::train`.
    #[arg(short, long, default_value_t = 100)]
    pub epochs: u32,
    // Passed to `LogisticRegression::new`.
    #[arg(long, default_value_t = 0.1)]
    pub learning_rate: f32,
    // Passed to `LogisticRegression::new`. The flag used to be misspelled
    // `--l1-pentalty`; the old spelling stays accepted as a hidden alias so
    // existing invocations keep working.
    #[arg(long, alias = "l1-pentalty", default_value_t = 0.1)]
    pub l1_penalty: f32,
    // Passed to `LogisticRegression::new`.
    #[arg(long, default_value_t = 0.1)]
    pub l2_penalty: f32,
}

impl Train {
    /// Loads the dataset, trains a logistic regression model on it, prints
    /// the metrics for that same dataset, and writes the model as JSON to
    /// `self.model`.
    ///
    /// Returns `ExitCode::FAILURE` (after printing to stderr) when the dataset
    /// is empty or when training fails. A failure to attach the feature list
    /// to the model is reported on stderr but the model is still written.
    ///
    /// # Errors
    ///
    /// Propagates errors from loading the dataset, evaluating it,
    /// serializing the model, and writing the model file.
    fn execute(&self) -> Result<ExitCode> {
        let dataset = malware_modeler::dataset::Dataset::load(&self.dataset)?;

        // The model width is taken from the first row, so an empty dataset
        // would otherwise panic on the index below.
        if dataset.data.is_empty() {
            eprintln!(
                "Dataset {} is empty, nothing to train on",
                self.dataset.display()
            );
            return Ok(ExitCode::FAILURE);
        }

        let mut model = malware_modeler::model::LogisticRegression::new(
            dataset.data[0].len(),
            self.learning_rate,
            self.l1_penalty,
            self.l2_penalty,
        );

        match model.train(self.epochs, &dataset) {
            Ok(error) => {
                println!("Training complete, final error {error}");
            }
            Err(error) => {
                eprintln!("Training failed: {error}");
                return Ok(ExitCode::FAILURE);
            }
        }

        let result = model.evaluate_dataset(&dataset)?;
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());

        // Best effort: the model is saved even when the features cannot be
        // attached; in that case the reduction step is skipped.
        if let Err(e) = model.set_features(dataset.features) {
            eprintln!("Failed to set model features: {e}");
        } else {
            let original = model.weights.len();
            model.reduce();
            let reduced = model.weights.len();
            // `<` rather than `!=` so the subtraction can never underflow.
            if reduced < original {
                let diff = original - reduced;
                println!("Features reduced by {diff}: {original} - {reduced}");
            }
        }

        let model_json = serde_json::to_string_pretty(&model)?;
        std::fs::write(&self.model, model_json)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `Evaluate` subcommand (consumed by `Evaluate::execute`).
//
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// derive type become user-visible help text.
#[derive(Parser)]
struct Evaluate {
// Dataset file loaded with `malware_modeler::dataset::Dataset::load`.
#[arg(short, long, value_hint = ValueHint::FilePath)]
dataset: PathBuf,
// JSON model file deserialized into `LogisticRegression`.
#[arg(short, long, value_hint = ValueHint::FilePath)]
model: PathBuf,
}
impl Evaluate {
    /// Loads a dataset and a JSON model, evaluates the model on the dataset,
    /// and prints the result followed by its individual metrics.
    ///
    /// # Errors
    ///
    /// Propagates errors from loading the dataset, reading or parsing the
    /// model file, and evaluating the dataset.
    fn execute(&self) -> Result<ExitCode> {
        // The dataset is loaded before the model so that, if both are broken,
        // the dataset error is the one reported.
        let samples = malware_modeler::dataset::Dataset::load(&self.dataset)?;

        let raw_model = std::fs::read_to_string(&self.model)?;
        let classifier =
            serde_json::from_str::<malware_modeler::model::LogisticRegression>(&raw_model)?;

        let metrics = classifier.evaluate_dataset(&samples)?;
        println!("{metrics}");
        println!("Accuracy: {:.2}", metrics.accuracy());
        println!("Precision: {:.2}", metrics.precision());
        println!("Recall: {:.2}", metrics.recall());
        println!("F1: {:.2}", metrics.f1());
        println!("Auc: {:.2}", metrics.auc());

        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `Model` subcommand (consumed by `ModelInfo::execute`).
//
// Plain `//` comments are used on purpose: `///` doc comments on a clap
// derive type become user-visible help text.
#[derive(Parser)]
struct ModelInfo {
// JSON model file deserialized into `LogisticRegression`.
#[arg(short, long, value_hint = ValueHint::FilePath)]
model: PathBuf,
// When given, the model's features are written hex-encoded, one per line,
// to this file. The short flag is `-n` instead of the derived `-e`, which is
// taken by `eval`.
#[arg(short = 'n', long, value_hint = ValueHint::FilePath)]
extract_ngrams: Option<PathBuf>,
// When given, a file or directory to run the model against; `execute`
// branches on which of the two the path is.
#[arg(short, long, value_hint = ValueHint::AnyPath)]
eval: Option<PathBuf>,
}
impl ModelInfo {
    /// Loads a JSON model, prints its weight count, and optionally dumps its
    /// features and/or evaluates a file or directory with it.
    ///
    /// * `--extract-ngrams`: writes every feature hex-encoded, one per line.
    /// * `--eval` on a file: prints the prediction for that file.
    /// * `--eval` on a directory: prints a header followed by one
    ///   comma-separated line per regular file found while walking it.
    ///   Entries that cannot be traversed are reported on stderr and skipped.
    ///
    /// Returns `ExitCode::FAILURE` when the `--eval` path is neither a file
    /// nor a directory.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading or parsing the model, writing the
    /// feature file, and evaluating any individual file.
    fn execute(&self) -> Result<ExitCode> {
        let model_contents = std::fs::read_to_string(&self.model)?;
        let model: malware_modeler::model::LogisticRegression =
            serde_json::from_str(model_contents.as_str())?;
        println!("Model has {} weights", model.weights.len());

        if let Some(extract_ngrams) = &self.extract_ngrams {
            let features = model
                .features
                .iter()
                .map(hex::encode)
                .collect::<Vec<String>>()
                .join("\n");
            let mut file = File::create(extract_ngrams)?;
            file.write_all(features.as_bytes())?;
            file.write_all(b"\n")?;
            file.sync_all()?;
        }

        if let Some(eval) = &self.eval {
            if eval.is_file() {
                let (label, prediction, features) = model.evaluate_file(eval)?;
                println!("{} is predicted to be {label} (raw: {prediction}) and had {features} of {} features", eval.display(), model.features.len());
            } else if eval.is_dir() {
                println!("Path, PredictionLabel, PredictionRaw, FoundFeatures/AvailableFeatures");
                for entry in WalkDir::new(eval)
                    .max_depth(malware_modeler::MAX_RECURSION_DEPTH)
                    .follow_links(true)
                {
                    // Report traversal problems instead of dropping them
                    // silently; stderr keeps the stdout listing clean.
                    let entry = match entry {
                        Ok(entry) => entry,
                        Err(e) => {
                            eprintln!("Skipping unreadable entry: {e}");
                            continue;
                        }
                    };
                    if entry.file_type().is_file() {
                        let (label, prediction, features) = model.evaluate_file(entry.path())?;
                        println!(
                            "{}, {label}, {prediction}, {features}/{}",
                            entry.path().display(),
                            model.features.len()
                        );
                    }
                }
            } else {
                // An unusable path is an error, so the exit code must say so.
                eprintln!("{} is not a file or directory", eval.display());
                return Ok(ExitCode::FAILURE);
            }
        }

        Ok(ExitCode::SUCCESS)
    }
}
fn main() -> Result<ExitCode> {
match Args::parse().cmd {
Actions::Ngram(ngram) => ngram.execute(),
Actions::Dataset(dataset) => dataset.execute(),
Actions::Train(training) => training.execute(),
Actions::Evaluate(evaluate) => evaluate.execute(),
Actions::Model(model) => model.execute(),
}
}
/// Runs clap's built-in consistency assertions over the derived command.
#[test]
fn verify_cli() {
    let command = <Args as clap::CommandFactory>::command();
    command.debug_assert();
}