#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![forbid(unsafe_code)]
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use std::process::ExitCode;
use anyhow::Result;
use clap::{Parser, Subcommand, ValueHint};
use walkdir::WalkDir;
#[derive(Parser)]
#[command(author, about, version = malware_modeler::VERSION)]
pub struct Args {
#[clap(subcommand)]
cmd: Actions,
#[clap(long, default_value_t = num_cpus::get())]
threads: usize,
}
// Subcommands of the CLI. Each variant's `///` doc comment is what clap shows
// as that subcommand's description in `--help` (previously every subcommand
// was listed with no description at all).
#[derive(Subcommand)]
enum Actions {
    /// Collect n-grams from the files under a directory and save them to a file
    Ngram(Ngram),
    /// Build a dataset from malicious and benign file directories and an n-gram list
    Dataset(Dataset),
    /// Convert a dataset to another file format, optionally reducing features with a model
    DatasetConvert(DatasetConvert),
    /// Train a logistic regression model on a dataset and save it as JSON
    Train(Train),
    /// Evaluate a trained model against a dataset and print its metrics
    Evaluate(Evaluate),
    /// Show a model's size, export its n-grams, or score files with it
    Model(ModelInfo),
    /// Print the detected file type of a file or of every file in a directory
    Types(Types),
}
// Arguments for the `ngram` subcommand.
#[derive(Parser)]
struct Ngram {
    // Directory handed to `Ngrammer::new` as the input location.
    #[arg(long, value_hint = ValueHint::DirPath)]
    path: PathBuf,
    // Optional file type passed through to `Ngrammer::new` (`-t` / `--type`).
    #[arg(short = 't', long = "type")]
    ftype: Option<malware_modeler::ftype::FileType>,
    // Positional: n-gram length; `execute` rejects values below 2.
    #[arg(default_value_t = 6)]
    n: u16,
    // Positional: n-gram count handed to `Ngrammer::new`; `execute` rejects values below 10.
    #[arg(default_value_t = 10000)]
    k: usize,
    // Where the n-grams are written.
    #[arg(long, value_hint = ValueHint::FilePath, default_value = "ngrams.txt")]
    output: PathBuf,
    // Boolean switch (`-c` / `--counts`) forwarded to `Ngrammer::save`.
    #[arg(short, long)]
    counts: bool,
}
impl Ngram {
    /// Validates `k` and `n`, then runs the n-gram search over `self.path`
    /// and saves the result to `self.output`.
    ///
    /// Returns `ExitCode::FAILURE` (after printing a message) when `k < 10`
    /// or `n < 2`; `k` is checked first.
    ///
    /// # Errors
    /// Propagates failures from constructing the `Ngrammer` or saving its output.
    fn execute(&self) -> Result<ExitCode> {
        let rejection = if self.k < 10 {
            Some("NGrams must be at least 10, should be ~100k")
        } else if self.n < 2 {
            Some("NGrams must be at least 2, should be >4")
        } else {
            None
        };
        if let Some(message) = rejection {
            eprintln!("{message}");
            return Ok(ExitCode::FAILURE);
        }

        let mut finder =
            malware_modeler::ngram::Ngrammer::new(self.ftype, &self.path, self.n, self.k)?;
        finder.find();
        finder.save(&self.output, self.counts)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `dataset` subcommand. The short/long flag names are
// spelled out explicitly; they match what clap would derive from the field names.
#[derive(Parser)]
struct Dataset {
    // Directory of malicious samples.
    #[arg(short = 'm', long = "malicious", value_hint = ValueHint::DirPath)]
    pub malicious: PathBuf,
    // Directory of benign samples.
    #[arg(short = 'b', long = "benign", value_hint = ValueHint::DirPath)]
    pub benign: PathBuf,
    // N-gram list file used to build the features.
    #[arg(short = 'n', long = "ngrams", value_hint = ValueHint::FilePath)]
    pub ngrams: PathBuf,
    // Destination file for the dataset.
    #[arg(short = 'o', long = "output", value_hint = ValueHint::FilePath)]
    pub output: PathBuf,
}
impl Dataset {
    /// Builds a dataset from the malicious and benign directories plus the
    /// n-gram list, and saves it to `output`.
    ///
    /// # Errors
    /// Propagates any failure from the library's create-and-save routine.
    fn execute(&self) -> Result<ExitCode> {
        use malware_modeler::dataset::Dataset as ModelDataset;

        let Self {
            malicious,
            benign,
            ngrams,
            output,
        } = self;
        ModelDataset::create_save_from_benign_malicious_files_and_ngrams(
            malicious, benign, ngrams, output,
        )?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `dataset-convert` subcommand.
#[derive(Parser)]
struct DatasetConvert {
    // Source dataset file (`-i` / `--in`).
    #[arg(short = 'i', long = "in", value_hint = ValueHint::FilePath)]
    pub input: PathBuf,
    // Destination dataset file (`-o` / `--out`).
    #[arg(short = 'o', long = "out", value_hint = ValueHint::FilePath)]
    pub output: PathBuf,
    // Optional JSON model whose features are used to reduce the dataset.
    #[arg(short = 'm', long, value_hint = ValueHint::FilePath)]
    pub model: Option<PathBuf>,
}
impl DatasetConvert {
    /// Loads `input` and, when `--model` is given, reduces the dataset with that
    /// JSON model. The result is then saved to `output`.
    ///
    /// Returns `ExitCode::FAILURE` when the two paths share an extension and no
    /// model is supplied (nothing to do), or when a model is supplied but the
    /// dataset has no samples to measure the feature count from.
    ///
    /// # Errors
    /// Propagates load/save failures, model read/parse failures, and errors
    /// from `Dataset::reduce`.
    fn execute(&self) -> Result<ExitCode> {
        if self.input.extension() == self.output.extension() && self.model.is_none() {
            eprintln!("Refusing to attempt 'conversion' between files with the same extension.");
            return Ok(ExitCode::FAILURE);
        }
        let mut dataset = malware_modeler::dataset::Dataset::load(&self.input)?;
        if let Some(model_path) = &self.model {
            // The feature count is read from the first sample. Indexing `data[0]`
            // directly panicked on an empty dataset.
            let original = match dataset.data.first() {
                Some(sample) => sample.len(),
                None => {
                    eprintln!(
                        "Dataset {} contains no samples; cannot reduce its features",
                        self.input.display()
                    );
                    return Ok(ExitCode::FAILURE);
                }
            };
            let model_contents = std::fs::read_to_string(model_path)?;
            let model: malware_modeler::model::LogisticRegression =
                serde_json::from_str(model_contents.as_str())?;
            let removed = dataset.reduce(&model)?.len();
            println!("Removed {removed} features from the original {original}");
        }
        dataset.save(&self.output)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `train` subcommand.
#[derive(Parser)]
struct Train {
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub dataset: PathBuf,
    #[arg(short, long, value_hint = ValueHint::FilePath)]
    pub model: PathBuf,
    #[arg(short, long, default_value_t = 100)]
    pub epochs: u32,
    #[arg(long, default_value_t = 0.1)]
    pub learning_rate: f32,
    // Exposed as `--l1-penalty`. The hidden alias keeps the previously
    // misspelled `--l1-pentalty` flag working for existing scripts.
    #[arg(long, alias = "l1-pentalty", default_value_t = 0.1)]
    pub l1_penalty: f32,
    #[arg(long, default_value_t = 0.1)]
    pub l2_penalty: f32,
}
impl Train {
    /// Trains a `LogisticRegression` on `self.dataset` and prints evaluation
    /// metrics computed on that same dataset. It then attaches the dataset's
    /// feature list to the model, calls `reduce`, and writes the model as
    /// pretty-printed JSON to `self.model`.
    ///
    /// Returns `ExitCode::FAILURE` if the dataset has no samples or training fails.
    /// A failure to attach features is reported on stderr but does not stop the save.
    ///
    /// # Errors
    /// Propagates dataset load failures, evaluation errors, JSON serialization
    /// errors, and failures writing the model file.
    fn execute(&self) -> Result<ExitCode> {
        let dataset = malware_modeler::dataset::Dataset::load(&self.dataset)?;
        // The model width comes from the first sample. Indexing `data[0]`
        // directly panicked on an empty dataset.
        let feature_count = match dataset.data.first() {
            Some(sample) => sample.len(),
            None => {
                eprintln!("Dataset {} contains no samples", self.dataset.display());
                return Ok(ExitCode::FAILURE);
            }
        };
        let mut model = malware_modeler::model::LogisticRegression::new(
            feature_count,
            self.learning_rate,
            self.l1_penalty,
            self.l2_penalty,
        );
        match model.train(self.epochs, &dataset) {
            Ok(error) => {
                println!("Training complete, final error {error}");
            }
            Err(error) => {
                eprintln!("Training failed: {error}");
                return Ok(ExitCode::FAILURE);
            }
        }
        let result = model.evaluate_dataset(&dataset)?;
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());
        if let Err(e) = model.set_features(dataset.features) {
            eprintln!("Failed to set model features: {e}");
        } else {
            let original = model.weights.len();
            model.reduce();
            let reduced = model.weights.len();
            // Compare with `<` rather than `!=` so the subtraction below can
            // never underflow.
            if reduced < original {
                let diff = original - reduced;
                println!("Features reduced by {diff}: {original} - {reduced}");
            }
        }
        let model_json = serde_json::to_string_pretty(&model)?;
        std::fs::write(&self.model, model_json)?;
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `evaluate` subcommand (`-d` / `--dataset`, `-m` / `--model`).
#[derive(Parser)]
struct Evaluate {
    // Dataset file to score the model against.
    #[arg(short = 'd', long, value_hint = ValueHint::FilePath)]
    dataset: PathBuf,
    // JSON-serialized model file.
    #[arg(short = 'm', long, value_hint = ValueHint::FilePath)]
    model: PathBuf,
}
impl Evaluate {
    /// Scores the JSON model at `self.model` against the dataset at
    /// `self.dataset`, printing the evaluation result and summary metrics.
    ///
    /// # Errors
    /// Propagates dataset load failures, model read/parse failures, and
    /// evaluation errors. The dataset is loaded before the model is read.
    fn execute(&self) -> Result<ExitCode> {
        use malware_modeler::model::LogisticRegression;

        let samples = malware_modeler::dataset::Dataset::load(&self.dataset)?;
        let json = std::fs::read_to_string(&self.model)?;
        let classifier = serde_json::from_str::<LogisticRegression>(&json)?;

        let report = classifier.evaluate_dataset(&samples)?;
        println!("{report}");
        println!("Accuracy: {:.2}", report.accuracy());
        println!("Precision: {:.2}", report.precision());
        println!("Recall: {:.2}", report.recall());
        println!("F1: {:.2}", report.f1());
        println!("Auc: {:.2}", report.auc());
        Ok(ExitCode::SUCCESS)
    }
}
// Arguments for the `model` subcommand.
#[derive(Parser)]
struct ModelInfo {
    // JSON-serialized model file (`-m` / `--model`).
    #[arg(short = 'm', long, value_hint = ValueHint::FilePath)]
    model: PathBuf,
    // Optional output file for the model's hex-encoded n-grams (`-n` / `--extract-ngrams`).
    #[arg(short = 'n', long = "extract-ngrams", value_hint = ValueHint::FilePath)]
    extract_ngrams: Option<PathBuf>,
    // Optional file or directory to score with the model (`-e` / `--eval`).
    #[arg(short = 'e', long, value_hint = ValueHint::AnyPath)]
    eval: Option<PathBuf>,
}
impl ModelInfo {
    /// Loads the JSON model at `self.model` and prints its weight count. It then
    /// optionally exports the model's n-gram features (`--extract-ngrams`) and/or
    /// scores a file or directory tree with it (`--eval`).
    ///
    /// Returns `ExitCode::FAILURE` when the `--eval` path is neither a file nor a
    /// directory, or when any file in an evaluated directory could not be scored.
    ///
    /// # Errors
    /// Propagates model read/parse failures, failures writing the n-gram
    /// export, and an evaluation error for a single-file `--eval` path.
    fn execute(&self) -> Result<ExitCode> {
        let model_contents = std::fs::read_to_string(&self.model)?;
        let model: malware_modeler::model::LogisticRegression =
            serde_json::from_str(model_contents.as_str())?;
        println!("Model has {} weights", model.weights.len());
        if let Some(extract_ngrams) = &self.extract_ngrams {
            Self::write_ngrams(&model, extract_ngrams)?;
        }
        match &self.eval {
            Some(eval) => Self::evaluate_path(&model, eval),
            None => Ok(ExitCode::SUCCESS),
        }
    }

    /// Writes the hex-encoded keys of `model.features` to `path`, one per line
    /// with a trailing newline, and syncs the file to disk.
    fn write_ngrams(
        model: &malware_modeler::model::LogisticRegression,
        path: &std::path::Path,
    ) -> Result<()> {
        let features = model
            .features
            .keys()
            .map(hex::encode)
            .collect::<Vec<String>>()
            .join("\n");
        let mut file = File::create(path)?;
        file.write_all(features.as_bytes())?;
        file.write_all(b"\n")?;
        file.sync_all()?;
        Ok(())
    }

    /// Scores a single file, or every regular file under a directory (CSV-style
    /// output), and returns the resulting exit code.
    fn evaluate_path(
        model: &malware_modeler::model::LogisticRegression,
        eval: &std::path::Path,
    ) -> Result<ExitCode> {
        if eval.is_file() {
            let (label, prediction, features) = model.evaluate_file(eval)?;
            println!(
                "{} is predicted to be {label} (raw: {prediction}) and had {features} of {} features",
                eval.display(),
                model.features.len()
            );
            return Ok(ExitCode::SUCCESS);
        }
        if !eval.is_dir() {
            // Previously this only printed the message and still exited successfully.
            eprintln!("{} is not a file or directory", eval.display());
            return Ok(ExitCode::FAILURE);
        }
        println!("Path, PredictionLabel, PredictionRaw, FoundFeatures/AvailableFeatures");
        let mut any_failed = false;
        // Unreadable walk entries are skipped by `flatten()`.
        for entry in WalkDir::new(eval)
            .max_depth(malware_modeler::MAX_RECURSION_DEPTH)
            .follow_links(true)
            .into_iter()
            .flatten()
        {
            if !entry.file_type().is_file() {
                continue;
            }
            // Report a file that cannot be scored and keep going, as `types`
            // does, instead of aborting the whole scan on the first bad file.
            match model.evaluate_file(entry.path()) {
                Ok((label, prediction, features)) => println!(
                    "{}, {label}, {prediction}, {features}/{}",
                    entry.path().display(),
                    model.features.len()
                ),
                Err(e) => {
                    eprintln!("Failed to evaluate {}: {e}", entry.path().display());
                    any_failed = true;
                }
            }
        }
        Ok(if any_failed {
            ExitCode::FAILURE
        } else {
            ExitCode::SUCCESS
        })
    }
}
// Arguments for the `types` subcommand.
#[derive(Parser)]
struct Types {
    // `execute` handles both a single file and a directory tree, so shells
    // should complete any path. `FilePath` previously offered files only.
    #[arg(value_hint = ValueHint::AnyPath)]
    path: PathBuf,
}
impl Types {
    /// Prints the detected file type of `self.path`. For a directory it prints
    /// the type of every regular file beneath it, following symlinks up to
    /// `MAX_RECURSION_DEPTH`. Undetected types are printed as `Unknown`.
    ///
    /// Returns `ExitCode::FAILURE` when the path is neither a file nor a directory
    /// (for example, it does not exist). Previously that case printed nothing and
    /// exited successfully.
    ///
    /// # Errors
    /// Propagates a detection error for a single-file path. During a directory
    /// walk, per-file errors are reported on stderr and skipped.
    pub fn execute(&self) -> Result<ExitCode> {
        if self.path.is_file() {
            let ftype = malware_modeler::ftype::FileType::from_path(&self.path)?;
            match ftype {
                Some(ft) => println!("{ft:?}"),
                None => println!("Unknown"),
            }
        } else if self.path.is_dir() {
            for entry in WalkDir::new(&self.path)
                .max_depth(malware_modeler::MAX_RECURSION_DEPTH)
                .follow_links(true)
                .into_iter()
                .flatten()
            {
                if entry.file_type().is_file() {
                    match malware_modeler::ftype::FileType::from_path(entry.path()) {
                        Ok(Some(ft)) => println!("{}: {ft:?}", entry.path().display()),
                        Ok(None) => println!("{}: Unknown", entry.path().display()),
                        Err(e) => eprintln!("Failed to open {}: {}", entry.path().display(), e),
                    }
                }
            }
        } else {
            eprintln!("{} is not a file or directory", self.path.display());
            return Ok(ExitCode::FAILURE);
        }
        Ok(ExitCode::SUCCESS)
    }
}
fn main() -> Result<ExitCode> {
let args = Args::parse();
rayon::ThreadPoolBuilder::new()
.num_threads(args.threads)
.build_global()?;
match args.cmd {
Actions::Ngram(ngram) => ngram.execute(),
Actions::Dataset(dataset) => dataset.execute(),
Actions::DatasetConvert(dataset) => dataset.execute(),
Actions::Train(training) => training.execute(),
Actions::Evaluate(evaluate) => evaluate.execute(),
Actions::Model(model) => model.execute(),
Actions::Types(types) => types.execute(),
}
}
/// Asks clap to validate the derived command tree (duplicate flags, bad
/// defaults, and similar) so CLI definition mistakes surface in `cargo test`.
#[test]
fn verify_cli() {
    let command = <Args as clap::CommandFactory>::command();
    command.debug_assert();
}