use crate::ftype::FileType;
use crate::{dataset::Dataset, Bytes};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::path::Path;
use anyhow::{ensure, Result};
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Logistic function `1 / (1 + e^(-x))`.
///
/// In `f32` the result saturates to exactly `0.0` for very negative `x` (where `e^(-x)`
/// overflows to infinity) and to exactly `1.0` for very positive `x`.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
/// Binary logistic-regression classifier over fixed-width feature vectors.
///
/// `predict` scores an input as `sigmoid(weights · input + bias)`; `evaluate_dataset`
/// treats scores `>= 0.5` as the positive ("Malicious") class.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
/// Step size applied to every gradient-descent update in `train`.
pub learning_rate: f32,
/// Intercept added to the weighted sum in `predict`.
pub bias: f32,
/// One weight per position of the input vector passed to `predict`.
pub weights: Vec<f32>,
/// Strength of the L1 penalty term added to each weight update in `train`.
pub l1: f32,
/// Strength of the L2 penalty term added to each weight update in `train`.
pub l2: f32,
/// Maps each feature (a byte string) to the index of the weight it feeds.
///
/// Serialized through the `crate::serde` hex-map helpers (presumably hex-encoding the
/// byte-string keys — confirm against that module).
#[serde(
serialize_with = "crate::serde::serialize_hex_map",
deserialize_with = "crate::serde::deserialize_hex_map"
)]
pub features: HashMap<Bytes, usize>,
/// Byte length of the dataset's first feature, recorded by `train` and passed to
/// `featurize_file` (presumably the n-gram size — TODO confirm).
n: usize,
/// Set to `true` once `train` completes successfully.
pub trained: bool,
/// Number of weights the model was created with (`input_size` in `new`).
pub original_ngrams: u32,
/// File type the model was trained on; `evaluate_file` rejects paths of other types.
pub file_type: FileType,
}
impl LogisticRegression {
    /// Creates an untrained model with `input_size` randomly initialised weights.
    ///
    /// Weights are drawn uniformly from `[-1, 1)` and the bias from `rng.random()`.
    /// `original_ngrams` records `input_size` so the width before any
    /// [`reduce`](Self::reduce) is remembered.
    #[must_use]
    #[allow(clippy::cast_possible_truncation)] // input_size is expected to fit in u32
    pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
        let mut rng = rand::rng();
        // Field initialisers run top to bottom, so the weights are drawn before the bias.
        Self {
            learning_rate,
            weights: (0..input_size)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect(),
            l1,
            l2,
            n: 0,
            features: HashMap::new(),
            trained: false,
            bias: rng.random(),
            original_ngrams: input_size as u32,
            file_type: FileType::NotSet,
        }
    }

    /// Builds a model sized to `dataset`, records its feature map, and trains it.
    ///
    /// The model gets one weight per column of `dataset.data`. `dataset.features[i]` is
    /// mapped to weight `i`, the same convention as [`set_features`](Self::set_features).
    /// This assumes the features are listed in column order (TODO confirm with `Dataset`).
    ///
    /// Returns the trained model and the final mean log loss.
    ///
    /// # Panics
    /// Panics if [`train`](Self::train) fails, for example when the dataset has no labels
    /// or fails validation.
    #[must_use]
    pub fn new_from_dataset_and_train(
        dataset: &Dataset,
        epochs: u32,
        learning_rate: f32,
        l1: f32,
        l2: f32,
    ) -> (Self, f32) {
        // Size by the row width: `train` rejects any other weight count.
        let input_size = dataset.data.first().map_or(0, |row| row.len());
        let mut model = Self::new(input_size, learning_rate, l1, l2);
        model.features = dataset
            .features
            .iter()
            .enumerate()
            .map(|(index, feature)| (feature.clone(), index))
            .collect();
        // `train` also records `n` and `file_type` from the dataset on success.
        let loss = model
            .train(epochs, dataset)
            .expect("training a model sized from its own dataset failed");
        (model, loss)
    }

    /// Scores `input` as `sigmoid(weights · input + bias)`.
    ///
    /// `input` is paired element-wise with the weights. If the lengths differ, the surplus
    /// entries of the longer side are ignored.
    #[inline]
    #[must_use]
    pub fn predict(&self, input: &[f32]) -> f32 {
        let linear_model = input
            .iter()
            .zip(&self.weights)
            .map(|(x, w)| x * w)
            .sum::<f32>()
            + self.bias;
        sigmoid(linear_model)
    }

    /// Fits the model to `dataset` with per-sample gradient descent.
    ///
    /// Each epoch visits every `(input, label)` pair once. It updates every weight with the
    /// log-loss gradient plus a smoothed L1 term (`l1 * w / (|w| + 1e-8)`) and an L2 term
    /// (`l2 * w`), and updates the bias with the unregularised gradient. Training stops
    /// early once an epoch's mean log loss drops below `1e-6`.
    ///
    /// On success the model is marked as trained and takes its file type from `dataset`.
    /// It also takes `n` from `dataset`'s first feature, when one exists.
    ///
    /// Returns the mean per-sample log loss of the last epoch run (`0.0` if `epochs == 0`).
    ///
    /// # Errors
    /// Returns an error if the dataset has no labels, fails [`Dataset::validate`], has no
    /// rows, or has rows that are not exactly as wide as the weight vector.
    #[allow(clippy::cast_precision_loss)] // usize -> f32 is only inexact above 2^24 samples
    pub fn train(&mut self, epochs: u32, dataset: &Dataset) -> Result<f32, &'static str> {
        if dataset.labels.is_empty() {
            return Err("Dataset must have labels");
        }
        if !dataset.validate() {
            return Err("Dataset didn't pass validity check!");
        }
        if dataset.data.first().map(|row| row.len()) != Some(self.weights.len()) {
            return Err("Dataset feature length must equal the number of model weights");
        }
        // The inner `zip` visits exactly this many samples. It is non-zero because both
        // sides were checked to be non-empty above.
        let sample_count = dataset.data.len().min(dataset.labels.len()) as f32;
        let (learning_rate, l1, l2) = (self.learning_rate, self.l1, self.l2);
        let mut loss = 0.0;
        #[allow(unused)] // `epoch` is only read by the debug-build progress print
        for epoch in 0..epochs {
            loss = 0.0;
            for (input, &label) in dataset.data.iter().zip(&dataset.labels) {
                let prediction = self.predict(input);
                let error = prediction - label;
                // Clamp away from 0 and 1 so both logarithms stay finite.
                let p = prediction.clamp(1e-8, 1.0 - 1e-8);
                loss += -label * p.ln() - (1.0 - label) * (1.0 - p).ln();
                self.weights
                    .par_iter_mut()
                    .enumerate()
                    .for_each(|(i, weight)| {
                        // `w / (|w| + eps)` is a smooth stand-in for sign(w), the L1 subgradient.
                        let l1_term = l1 * (*weight / (weight.abs() + 1e-8));
                        let l2_term = l2 * *weight;
                        *weight -= learning_rate * (error * input[i] + l1_term + l2_term);
                    });
                self.bias -= learning_rate * error;
            }
            // Log loss is a mean over samples, not over weights.
            loss /= sample_count;
            #[cfg(debug_assertions)]
            println!("Epoch: {epoch}, Log loss: {loss}");
            if loss < 1e-6 {
                break;
            }
        }
        self.trained = true;
        self.file_type = dataset.ftype;
        if let Some(feature) = dataset.features.first() {
            self.n = feature.len();
        }
        Ok(loss)
    }

    /// Scores every sample in `dataset` and tallies the outcomes into a [`ConfusionMatrix`].
    ///
    /// Predictions and labels are both split at `0.5`. A prediction `>= 0.5` means
    /// "Malicious" (positive), and so does a label `>= 0.5`.
    ///
    /// # Errors
    /// Fails if the dataset or its labels are empty, or if its rows are not as wide as the
    /// weight vector.
    pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
        ensure!(!dataset.is_empty(), "Dataset is empty");
        ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
        ensure!(
            dataset.data.first().map(|row| row.len()) == Some(self.weights.len()),
            "Dataset length must equal the number of model weights"
        );
        let (mut true_p, mut true_n, mut false_p, mut false_n) = (0, 0, 0, 0);
        let mut predictions = Vec::with_capacity(dataset.labels.len());
        for index in 0..dataset.len() {
            let prediction = self.predict(&dataset.data[index]);
            let predicted_malicious = prediction >= 0.5;
            let actually_malicious = dataset.labels[index] >= 0.5;
            match (predicted_malicious, actually_malicious) {
                (true, true) => true_p += 1,
                (true, false) => false_p += 1,
                (false, false) => true_n += 1,
                (false, true) => false_n += 1,
            }
            predictions.push(prediction);
        }
        Ok(ConfusionMatrix {
            true_p,
            true_n,
            false_p,
            false_n,
            dataset,
            predictions,
        })
    }

    /// Featurizes the file at `path` and classifies it.
    ///
    /// Returns `(verdict, score, feature_count)`:
    /// - `verdict` is `"Malicious"` when `score >= 0.5` (the same threshold as
    ///   [`evaluate_dataset`](Self::evaluate_dataset)), otherwise `"Benign"`.
    /// - `feature_count` is the sum of the feature vector's entries, each cast to `u32`.
    ///
    /// # Errors
    /// Fails if the model has no feature map, if the file's type does not match the model's,
    /// or if featurization fails.
    // Casts assume feature values are small non-negative counts — TODO confirm in featurize_file.
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
    pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
        ensure!(
            !self.features.is_empty(),
            "Features are required for file evaluation"
        );
        ensure!(
            self.file_type.matches_path(&path)?,
            "File type doesn't match model type"
        );
        let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
        let score = self.predict(&vector);
        let feature_count: u32 = vector.iter().map(|v| *v as u32).sum();
        let verdict = if score >= 0.5 { "Malicious" } else { "Benign" };
        Ok((verdict, score, feature_count))
    }

    /// Prunes near-zero weights (`|w| <= 0.01`) from a trained model.
    ///
    /// Equivalent to [`reduce_with_threshold(0.01)`](Self::reduce_with_threshold). Does
    /// nothing if the model has not been trained.
    pub fn reduce(&mut self) {
        self.reduce_with_threshold(0.01);
    }

    /// Drops every weight whose magnitude is not strictly greater than `threshold`.
    ///
    /// Surviving weights keep their relative order. The feature map is rewritten to match.
    /// Features whose weight was dropped, or whose index addresses no weight, are removed.
    /// Every other feature is re-pointed at its weight's new position, so it still selects
    /// the same weight it did before pruning. Does nothing if the model has not been
    /// trained.
    pub fn reduce_with_threshold(&mut self, threshold: f32) {
        if !self.trained {
            return;
        }
        // new_positions[old_index] is Some(new_index) for kept weights and None for dropped ones.
        let mut kept = Vec::with_capacity(self.weights.len());
        let new_positions: Vec<Option<usize>> = self
            .weights
            .iter()
            .map(|&weight| {
                if weight.abs() > threshold {
                    kept.push(weight);
                    Some(kept.len() - 1)
                } else {
                    None
                }
            })
            .collect();
        self.weights = kept;
        self.features.retain(|_, index| {
            if let Some(&Some(new_index)) = new_positions.get(*index) {
                *index = new_index;
                true
            } else {
                false
            }
        });
    }

    /// Replaces the feature map, assigning `features[i]` to weight `i`.
    ///
    /// # Errors
    /// Fails if `features.len()` differs from the number of weights.
    pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
        ensure!(
            features.len() == self.weights.len(),
            "Provided features length {} does not equal the number of model features length {}",
            features.len(),
            self.weights.len()
        );
        self.features = features
            .into_iter()
            .enumerate()
            .map(|(index, feature)| (feature, index))
            .collect();
        Ok(())
    }

    /// Builder form of [`set_features`](Self::set_features): consumes the model and returns
    /// it with the new feature map.
    ///
    /// # Errors
    /// Fails if `features.len()` differs from the number of weights.
    pub fn with_features(mut self, features: Vec<Bytes>) -> Result<Self> {
        self.set_features(features)?;
        Ok(self)
    }
}
/// Outcome counts from evaluating a model on a dataset, plus the raw scores needed for
/// threshold-free metrics such as `auc`.
///
/// The `Display` impl lays the counts out with the model's verdict as rows and the actual
/// class as columns, with "Malicious" as the positive class.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
/// Predicted Malicious, actually Malicious.
pub true_p: u32,
/// Predicted Benign, actually Benign.
pub true_n: u32,
/// Predicted Malicious, actually Benign.
pub false_p: u32,
/// Predicted Benign, actually Malicious.
pub false_n: u32,
/// The evaluated dataset; `auc` pairs its `labels` with `predictions`.
dataset: &'a Dataset,
/// Model score for each evaluated sample, in evaluation order.
predictions: Vec<f32>,
}
impl ConfusionMatrix<'_> {
    /// Returns `numerator / denominator`, or `0.0` when `denominator` is zero.
    ///
    /// A metric over an empty class is mathematically undefined. Reporting it as `0.0`
    /// (scikit-learn's `zero_division=0` convention) stops one missing class from turning
    /// every derived metric, such as F1, into `NaN`.
    #[inline]
    #[allow(clippy::cast_precision_loss)] // u32 -> f32 is only inexact above 2^24
    fn ratio(numerator: u32, denominator: u32) -> f32 {
        if denominator == 0 {
            0.0
        } else {
            numerator as f32 / denominator as f32
        }
    }

    /// Fraction of all samples classified correctly: `(TP + TN) / total`.
    ///
    /// Returns `0.0` for an empty matrix.
    #[inline]
    #[must_use]
    pub fn accuracy(&self) -> f32 {
        Self::ratio(self.true_p + self.true_n, self.total())
    }

    /// Fraction of positive predictions that were correct: `TP / (TP + FP)`.
    ///
    /// Returns `0.0` when nothing was predicted positive.
    #[must_use]
    pub fn precision(&self) -> f32 {
        Self::ratio(self.true_p, self.true_p + self.false_p)
    }

    /// Fraction of actual positives that were detected: `TP / (TP + FN)`.
    ///
    /// Returns `0.0` when there are no actual positives.
    #[must_use]
    pub fn recall(&self) -> f32 {
        Self::ratio(self.true_p, self.true_p + self.false_n)
    }

    /// Harmonic mean of [`precision`](Self::precision) and [`recall`](Self::recall).
    ///
    /// Returns `0.0` when both are zero.
    #[must_use]
    pub fn f1(&self) -> f32 {
        let precision = self.precision();
        let recall = self.recall();
        let sum = precision + recall;
        if sum == 0.0 {
            0.0
        } else {
            2.0 * (precision * recall) / sum
        }
    }

    /// Total number of samples counted in the matrix.
    #[inline]
    #[must_use]
    pub fn total(&self) -> u32 {
        self.true_p + self.true_n + self.false_p + self.false_n
    }

    /// Area under the ROC curve of the stored scores against the dataset's labels.
    ///
    /// Scores are sorted from highest to lowest. One ROC point is emitted per distinct
    /// score, and the points are integrated with the trapezoidal rule. Each label `l`
    /// contributes `l` to the true-positive tally and `1 - l` to the false-positive tally,
    /// so 0/1 labels give the usual ROC curve.
    ///
    /// Returns `NaN` when the labels contain only one class, because a rate is then `0/0`.
    /// Returns `0.0` when there are no predictions.
    #[must_use]
    #[allow(clippy::float_cmp)] // exact equality is intended: tied scores share one ROC point
    pub fn auc(&self) -> f32 {
        let mut pairs: Vec<(f32, f32)> = self
            .predictions
            .iter()
            .copied()
            .zip(self.dataset.labels.iter().copied())
            .collect();
        // Highest score first; incomparable (NaN) scores compare as equal while sorting.
        pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));

        // Cumulative TP/FP weight recorded just before each new distinct score, followed by
        // the final totals. The NaN sentinel never equals a score, so the first point is (0, 0).
        let mut previous_score = f32::NAN;
        let (mut tp, mut fp) = (0.0f32, 0.0f32);
        let (mut tpr, mut fpr) = (vec![], vec![]);
        for (score, label) in pairs {
            if score != previous_score {
                tpr.push(tp);
                fpr.push(fp);
                previous_score = score;
            }
            tp += label;
            fp += 1.0 - label;
        }
        tpr.push(tp);
        fpr.push(fp);

        // Normalise the cumulative counts into rates; the last entries hold the totals.
        let total_tp = tpr[tpr.len() - 1];
        let total_fp = fpr[fpr.len() - 1];
        for (y, x) in tpr.iter_mut().zip(fpr.iter_mut()) {
            *y /= total_tp;
            *x /= total_fp;
        }

        // Trapezoidal rule over consecutive (FPR, TPR) points.
        let mut integral = 0.0;
        for (xs, ys) in fpr.windows(2).zip(tpr.windows(2)) {
            integral += (xs[1] - xs[0]) * (ys[0] + ys[1]) / 2.0;
        }
        integral
    }
}
impl std::fmt::Display for ConfusionMatrix<'_> {
    /// Renders the matrix as a three-line text table. Rows are the model's verdict and
    /// columns are the actual class, both ordered Malicious then Benign.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            true_p,
            true_n,
            false_p,
            false_n,
            ..
        } = self;
        writeln!(f, "Result \\ Actual | Malicious | Benign")?;
        writeln!(f, " Malicious | {true_p} | {false_p}")?;
        writeln!(f, " Benign | {false_n} | {true_n}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    /// Trains on the XOR fixture and checks that most samples are classified correctly,
    /// then prints the evaluation metrics.
    #[test]
    fn xor() {
        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut model = LogisticRegression::new(6, 0.2, 0.0, 0.0);
        model.train(100, &dataset).unwrap();

        let (mut correct, mut incorrect) = (0u16, 0u16);
        for (index, sample) in dataset.data.iter().enumerate() {
            let expected = dataset.labels[index];
            let predicted = model.predict(sample);
            println!("Predicted: {predicted}, Expected: {expected}");
            let agrees =
                (predicted >= 0.5 && expected >= 0.99) || (predicted < 0.5 && expected < 0.1);
            if agrees {
                correct += 1;
            } else {
                incorrect += 1;
            }
        }
        println!("Correct: {correct}, Incorrect: {incorrect}");
        assert!(correct > incorrect);

        let matrix = model.evaluate_dataset(&dataset).unwrap();
        println!("{matrix}");
        println!("Accuracy: {:.2}", matrix.accuracy());
        println!("Precision: {:.2}", matrix.precision());
        println!("Recall: {:.2}", matrix.recall());
        println!("F1: {:.2}", matrix.f1());
        println!("Auc: {:.2}", matrix.auc());
    }

    /// With L1/L2 regularisation, training on the bogus fixture should leave at least one
    /// weight small enough for `reduce` to prune. The weights start random, so this check
    /// can occasionally fail.
    #[test]
    fn reduction() {
        const BOGUS_LEN: usize = 6;
        let dataset =
            Dataset::from_csv_string(include_str!("../testdata/bogus.csv"), BOGUS_LEN).unwrap();
        let mut model = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
        model.set_features(dataset.features.clone()).unwrap();
        model.train(20, &dataset).unwrap();

        println!("Weights before reduction: {:?}", model.weights);
        println!("Features before reduction: {:?}", model.features);
        model.reduce();
        println!("Weights after reduction: {:?}", model.weights);
        println!("Features after reduction: {:?}", model.features);
        println!("Weights from {BOGUS_LEN} to {}", model.weights.len());

        assert!(
            model.weights.len() < BOGUS_LEN,
            "** If this assertion fails, re-run the test once or twice. **"
        );
    }

    /// Four scored samples with a hand-computed ROC AUC of 0.75.
    #[test]
    fn auc() {
        let dataset = Dataset {
            data: vec![],
            labels: vec![1.0, 1.0, 0.0, 0.0],
            features: vec![],
            ftype: FileType::DOCFILE,
        };
        let confusion_matrix = ConfusionMatrix {
            true_p: 0,
            true_n: 0,
            false_p: 0,
            false_n: 0,
            dataset: &dataset,
            predictions: vec![0.5, 0.2, 0.3, -1.0],
        };

        let auc = confusion_matrix.auc();
        println!("Auc: {auc:.2}, expected 0.75");
        assert!((0.73..0.78).contains(&auc));
    }
}