// malware_modeler/model.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::ftype::FileType;
4use crate::{dataset::Dataset, Bytes};
5
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::path::Path;
9
10use anyhow::{anyhow, ensure, Result};
11use rand::Rng;
12use rayon::prelude::*;
13use serde::{Deserialize, Serialize};
14
15// Adapted from https://github.com/ibnz36/lrclassifier/blob/06d136445028653af5d2061e002111ef11b14277/src/lib.rs
16// Accessed 01 November 2025
17
/// Logistic (sigmoid) function: squashes any real input into the range `0.0..=1.0`.
///
/// Large positive inputs saturate to `1.0`, large negative inputs to `0.0`, and
/// `0.0` maps to exactly `0.5`.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
24
/// Machine learning model using logistic regression.
///
/// Holds the learned parameters (`weights`, `bias`), the hyper-parameters used for
/// training (`learning_rate`, `l1`, `l2`) and the metadata needed to turn a file into
/// an input vector (`features`, `n`, `file_type`). The struct is (de)serializable with
/// serde; `features` uses the crate's hex-map helpers.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
    /// Learning rate: the step size applied to every gradient update during training.
    pub learning_rate: f32,

    /// Bias term added to the weighted sum before the sigmoid is applied.
    pub bias: f32,

    /// Model's weights, one per input feature.
    pub weights: Vec<f32>,

    /// L1 (LASSO) regularization strength; contributes `l1 * sign(weight)` to each
    /// weight's update during training.
    pub l1: f32,

    /// L2 (Ridge) regularization strength; contributes `l2 * weight` to each weight's
    /// update during training.
    pub l2: f32,

    /// N-grams used to train the model, mapped to an index.
    /// NOTE(review): the index is presumably the n-gram's position in `weights` —
    /// confirm against `crate::dataset::featurize_file`.
    #[serde(
        serialize_with = "crate::serde::serialize_hex_map",
        deserialize_with = "crate::serde::deserialize_hex_map"
    )]
    pub features: HashMap<Bytes, usize>,

    /// N-gram size; taken from the length of the dataset's first feature when training.
    n: usize,

    /// If the model has been trained (set once `train` completes).
    pub trained: bool,

    /// Amount of n-grams originally used, i.e. the input size the model was created with.
    pub original_ngrams: u32,

    /// The type of file this model is trained on; copied from the training dataset.
    pub file_type: FileType,
}
62
63impl LogisticRegression {
64    /// New logistic regression model with random initial weights from given parameters.
65    #[must_use]
66    #[allow(clippy::cast_possible_truncation)]
67    pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
68        let mut rng = rand::rng();
69
70        Self {
71            learning_rate,
72            weights: (0..input_size)
73                .map(|_| rng.random_range(-1.0..1.0))
74                .collect(),
75            l1,
76            l2,
77            n: 0,
78            features: HashMap::new(),
79            trained: false,
80            bias: rng.random(),
81            original_ngrams: input_size as u32,
82            file_type: FileType::NotSet,
83        }
84    }
85
86    /// Returns a trained logistic regression model from given parameters and dataset. Returns the model
87    /// and the error value from the last epoch.
88    ///
89    /// # Panics
90    ///
91    /// This code won't panic, but [`LogisticRegression::train`] would panic if the data sizes were
92    /// different, but that can't happen in this case since the same object is used in both places.
93    #[must_use]
94    pub fn new_from_dataset_and_train(
95        dataset: &mut Dataset,
96        epochs: u32,
97        learning_rate: f32,
98        l1: f32,
99        l2: f32,
100    ) -> (Self, f32) {
101        let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
102        model.n = dataset.features[0].len();
103        model.features = dataset
104            .features
105            .iter()
106            .map(|f| (f.clone(), 0))
107            .collect::<HashMap<_, _>>();
108        let result = model.train(epochs, dataset).unwrap();
109        model.file_type = dataset.ftype;
110        (model, result)
111    }
112
113    /// Predicts the output for a given input vector.
114    /// This will fail if the input vector isn't the same length as the weights vector.
115    #[inline]
116    #[must_use]
117    pub fn predict(&self, input: &[f32]) -> f32 {
118        let linear_model = input
119            .iter()
120            .zip(&self.weights)
121            .map(|(x, w)| x * w)
122            .sum::<f32>()
123            + self.bias;
124        sigmoid(linear_model)
125    }
126
127    /// Trains the classifier once with the given inputs and outputs.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the data isn't the correct size or if labels are missing.
132    #[allow(clippy::cast_precision_loss)]
133    pub fn train(&mut self, epochs: u32, dataset: &mut Dataset) -> Result<f32, &'static str> {
134        if dataset.labels.is_empty() {
135            return Err("Dataset must have labels");
136        }
137
138        if !dataset.validate() {
139            return Err("Dataset didn't pass validity check!");
140        }
141
142        if dataset.data[0].len() != self.weights.len() {
143            return Err("Dataset feature length must equal the number of model weights");
144        }
145
146        let mut loss = 0.0;
147        #[allow(unused)]
148        for epoch in 0..epochs {
149            loss = 0.0;
150            dataset.shuffle();
151            for (input, output) in dataset.data.iter().zip(&dataset.labels) {
152                let prediction = self.predict(input);
153                let error = prediction - output;
154                let p = prediction.clamp(1e-8, 1.0 - 1e-8);
155                loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
156
157                self.weights
158                    .par_iter_mut()
159                    .enumerate()
160                    .for_each(|(i, weight)| {
161                        let l1r = self.l1 * weight.signum();
162                        let l2r = self.l2 * *weight;
163                        *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
164                    });
165                self.bias -= self.learning_rate * error;
166            }
167            loss /= self.weights.len() as f32;
168
169            #[cfg(debug_assertions)]
170            println!("Epoch: {epoch}, Log loss: {loss}");
171
172            if loss < 1e-6 {
173                break;
174            }
175        }
176
177        self.trained = true;
178        self.file_type = dataset.ftype;
179        self.n = dataset.features[0].len();
180        Ok(loss)
181    }
182
183    /// Evaluate a dataset
184    ///
185    /// # Errors
186    ///
187    /// If the dataset doesn't match the model weight length or doesn't have labels
188    pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
189        ensure!(!dataset.is_empty(), "Dataset is empty");
190        ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
191        ensure!(
192            dataset.data[0].len() == self.weights.len(),
193            "Dataset length must equal the number of model weights"
194        );
195        ensure!(
196            self.file_type == dataset.ftype,
197            "Dataset file type must match model file type"
198        );
199
200        let mut tp_ = 0;
201        let mut fp_ = 0;
202        let mut tn_ = 0;
203        let mut fn_ = 0;
204        let mut predictions = Vec::with_capacity(dataset.labels.len());
205
206        for index in 0..dataset.len() {
207            let prediction = self.predict(&dataset.data[index]);
208            if prediction >= 0.5 && dataset.labels[index] >= 0.9 {
209                tp_ += 1;
210            } else if prediction >= 0.5 && dataset.labels[index] < 0.5 {
211                fp_ += 1;
212            } else if prediction < 0.5 && dataset.labels[index] < 0.5 {
213                tn_ += 1;
214            } else {
215                fn_ += 1;
216            }
217            predictions.push(prediction);
218        }
219
220        Ok(ConfusionMatrix {
221            true_p: tp_,
222            true_n: tn_,
223            false_p: fp_,
224            false_n: fn_,
225            dataset,
226            predictions,
227        })
228    }
229
230    /// Evaluate a file
231    ///
232    /// # Errors
233    ///
234    /// Errors will result if the model doesn't have features or if the sample file can't be read
235    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
236    pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
237        ensure!(
238            !self.features.is_empty(),
239            "Features are required for file evaluation"
240        );
241
242        ensure!(
243            self.file_type.matches_path(&path)?,
244            "File type doesn't match model type"
245        );
246
247        let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
248        let result = self.predict(&vector);
249        let features = vector.iter().map(|v| *v as u32).sum();
250        if result > 0.5 {
251            Ok(("Malicious", result, features))
252        } else {
253            Ok(("Benign", result, features))
254        }
255    }
256
257    /// Remove small weights (from regularization), returns the number of weights removed
258    ///
259    /// # Errors
260    ///
261    /// There's an error if the model is not trained
262    #[allow(clippy::cast_precision_loss)]
263    pub fn reduce(&mut self, dataset: &Dataset) -> Result<usize> {
264        const THRESHOLD: f32 = 0.01;
265
266        ensure!(
267            self.trained,
268            "Model must be trained before reducing weights"
269        );
270        ensure!(
271            self.file_type == dataset.ftype,
272            "Dataset file type must match model file type"
273        );
274
275        let mut average_features = Vec::with_capacity(self.weights.len());
276        for (index, model_weight) in self.weights.iter().enumerate() {
277            let col_iter = dataset
278                .column_iter(index)
279                .ok_or_else(|| anyhow!("Column index out of bounds"))?;
280            average_features.push((
281                index,
282                (model_weight * col_iter.sum::<f32>() / dataset.len() as f32).abs(),
283            ));
284        }
285
286        average_features.sort_by(|(_, a), (_, b)| {
287            if a > b {
288                Ordering::Greater
289            } else if a < b {
290                Ordering::Less
291            } else {
292                Ordering::Equal
293            }
294        });
295
296        let mut removed = vec![];
297        let mut running_sum = 0.0;
298        for (index, weight) in average_features {
299            running_sum += weight;
300            if running_sum < THRESHOLD {
301                removed.push(index);
302            }
303        }
304
305        removed.sort_unstable();
306        removed.reverse();
307        let removed_len = removed.len();
308        for to_remove in &removed {
309            self.weights.remove(*to_remove);
310        }
311
312        if !self.features.is_empty() {
313            let mut removed_features = Vec::with_capacity(removed_len);
314            for index in &removed {
315                for (feat, feat_index) in &self.features {
316                    if index == feat_index {
317                        removed_features.push(feat.clone());
318                    }
319                }
320            }
321
322            for removed_feature in removed_features {
323                self.features.remove(&removed_feature);
324            }
325        }
326
327        Ok(removed_len)
328    }
329
330    /// Set the features by adding to the struct
331    ///
332    /// # Errors
333    ///
334    /// Returns an error if the number of weights and features doesn't match
335    pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
336        ensure!(
337            features.len() == self.weights.len(),
338            "Provided features length {} does not equal the number of model features length {}",
339            features.len(),
340            self.weights.len()
341        );
342        self.features = features
343            .into_iter()
344            .enumerate()
345            .map(|(f, i)| (i, f))
346            .collect::<HashMap<_, _>>();
347
348        Ok(())
349    }
350
351    /// Convenience function to get the features from the dataset then reduce based on the dataset.
352    ///
353    /// # Errors
354    ///
355    /// An error occurs if the conditions below aren't met:
356    /// * Ensures the number of model weights matches the dataset features
357    /// * Ensures the file types match
358    pub fn set_features_and_reduce(&mut self, dataset: &Dataset) -> Result<usize> {
359        ensure!(
360            self.file_type == dataset.ftype,
361            "Dataset file type must match model file type"
362        );
363        ensure!(
364            dataset.data[0].len() == self.weights.len(),
365            "Dataset length must equal the number of model weights"
366        );
367
368        self.features = dataset
369            .features
370            .iter()
371            .enumerate()
372            .map(|(f, i)| (i.clone(), f))
373            .collect::<HashMap<_, _>>();
374
375        self.reduce(dataset)
376    }
377
378    /// Set the features by creating a new struct which has the features
379    ///
380    /// # Errors
381    ///
382    /// Returns an error if the number of weights and features doesn't match
383    pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
384        ensure!(
385            features.len() == self.weights.len(),
386            "Provided features length {} does not equal the number of model features length {}",
387            features.len(),
388            self.weights.len()
389        );
390
391        Ok(Self {
392            learning_rate: self.learning_rate,
393            bias: self.bias,
394            weights: self.weights,
395            l1: self.l1,
396            l2: self.l2,
397            trained: self.trained,
398            original_ngrams: self.original_ngrams,
399            file_type: self.file_type,
400            n: self.n,
401            features: features
402                .into_iter()
403                .enumerate()
404                .map(|(f, i)| (i, f))
405                .collect::<HashMap<_, _>>(),
406        })
407    }
408}
409
/// Confusion Matrix
///
/// Tally of a model's predictions against a dataset's labels, together with the raw
/// prediction scores and a reference to the evaluated dataset (both are needed to
/// compute the AUC).
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
    /// True positives: malicious samples predicted as malicious
    pub true_p: u32,

    /// True negatives: benign samples predicted as benign
    pub true_n: u32,

    /// False positives: benign samples predicted as malicious
    pub false_p: u32,

    /// False negatives: malicious samples predicted as benign
    pub false_n: u32,

    /// Original dataset reference; its labels are paired with `predictions` for the AUC
    dataset: &'a Dataset,

    /// Model's outputs, one score per evaluated sample
    predictions: Vec<f32>,
}
431
432impl ConfusionMatrix<'_> {
433    /// Accuracy as correct vs total
434    #[inline]
435    #[must_use]
436    #[allow(clippy::cast_precision_loss)]
437    pub fn accuracy(&self) -> f32 {
438        (self.true_p + self.true_n) as f32 / self.total() as f32
439    }
440
441    /// Precision
442    #[must_use]
443    #[allow(clippy::cast_precision_loss)]
444    pub fn precision(&self) -> f32 {
445        self.true_p as f32 / (self.true_p + self.false_p) as f32
446    }
447
448    /// Recall
449    #[must_use]
450    #[allow(clippy::cast_precision_loss)]
451    pub fn recall(&self) -> f32 {
452        self.true_p as f32 / (self.true_p + self.false_n) as f32
453    }
454
455    /// F1 score
456    #[must_use]
457    #[allow(clippy::cast_precision_loss)]
458    pub fn f1(&self) -> f32 {
459        2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
460    }
461
462    /// Total items evaluated
463    #[inline]
464    #[must_use]
465    pub fn total(&self) -> u32 {
466        self.true_p + self.true_n + self.false_p + self.false_n
467    }
468
469    /// Calculate the area under the curve (AUC)
470    #[must_use]
471    #[allow(clippy::float_cmp)]
472    pub fn auc(&self) -> f32 {
473        // Adapted from the AUC implementation by Maciej Kula <maciej.kula@gmail.com>
474        // https://github.com/maciejkula/rustlearn/blob/2ac052559b04860c62d7ba34c563b57a02912e4d/src/metrics/ranking.rs
475        // The last commit was June 2020 so used here directly instead of importing the crate.
476        // This code was also Apache 2.0 licensed.
477
478        // vector of pairs (score, label) - the order is switched with respect to function arguments
479        let (mut true_positive_count, mut false_positive_count) = {
480            let mut pairs: Vec<_> = self
481                .predictions
482                .iter()
483                .copied()
484                .zip(self.dataset.labels.iter().copied())
485                .collect();
486
487            // Sort by scores in descending order
488            pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
489
490            let mut score_prev = f32::NAN;
491            // tp .. true positives, fp .. false positives
492            let (mut tp, mut fp) = (0.0f32, 0.0f32);
493            let (mut tps, mut fps) = (vec![], vec![]);
494            for (score, label) in pairs {
495                // `tp` and `fp` from the previous iteration are pushed onto the ROC curve only if
496                // the `score` changed. This avoids errors due to arbitrary classification of points with
497                // identical scores
498                if score != score_prev {
499                    tps.push(tp);
500                    fps.push(fp);
501                    score_prev = score;
502                }
503                tp += label;
504                fp += 1.0 - label;
505            }
506            // Push the final point corresponding to the (1,1) ROC coordinates
507            tps.push(tp);
508            fps.push(fp);
509            (tps, fps)
510        };
511
512        let true_positives = true_positive_count[true_positive_count.len() - 1];
513        let false_positives = false_positive_count[false_positive_count.len() - 1];
514
515        for (tp, fp) in true_positive_count
516            .iter_mut()
517            .zip(false_positive_count.iter_mut())
518        {
519            *tp /= true_positives;
520            *fp /= false_positives;
521        }
522
523        let mut prev_x = false_positive_count[0];
524        let mut prev_y = true_positive_count[0];
525        let mut integral = 0.0;
526
527        for (&x, &y) in false_positive_count
528            .iter()
529            .skip(1)
530            .zip(true_positive_count.iter().skip(1))
531        {
532            integral += (x - prev_x) * (prev_y + y) / 2.0;
533
534            prev_x = x;
535            prev_y = y;
536        }
537
538        integral
539    }
540}
541
542impl std::fmt::Display for ConfusionMatrix<'_> {
543    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
544        const WIDTH: usize = 10;
545
546        writeln!(f, "Result \\ Actual  |    Malicious   |  Benign")?;
547        writeln!(
548            f,
549            "   Malicious     |     {:<WIDTH$} |    {:<WIDTH$}",
550            self.true_p, self.false_p
551        )?;
552        writeln!(
553            f,
554            "   Benign        |     {:<WIDTH$} |    {:<WIDTH$}",
555            self.false_n, self.true_n
556        )
557    }
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    /// Trains on the XOR fixture and checks that more samples are classified correctly than
    /// incorrectly, then prints the evaluation metrics.
    #[test]
    fn xor() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
        lr.train(100, &mut dataset).unwrap();

        let mut correct = 0u16;
        let mut incorrect = 0u16;

        for index in 0..dataset.data.len() {
            // `predict` is pure, so one call per sample is enough.
            let predicted = lr.predict(&dataset.data[index]);
            let expected = dataset.labels[index];
            println!("Predicted: {predicted}, Expected: {expected}");

            let hit_positive = predicted >= 0.5 && expected >= 0.99;
            let hit_negative = predicted < 0.5 && expected < 0.1;
            if hit_positive || hit_negative {
                correct += 1;
            } else {
                incorrect += 1;
            }
        }

        println!("Correct: {correct}, Incorrect: {incorrect}");
        assert!(correct > incorrect);

        let result = lr.evaluate_dataset(&dataset).unwrap();
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());
    }

    /// Trains with regularization on the bogus fixture and checks that `reduce` drops at least
    /// one weight. The initial weights are random, so the outcome can vary between runs.
    #[test]
    fn reduction() {
        const BOGUS_LEN: usize = 6;

        let csv = include_str!("../testdata/bogus.csv");
        let mut dataset = Dataset::from_csv_string(csv, BOGUS_LEN).unwrap();
        let mut lr = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
        lr.set_features(dataset.features.clone()).unwrap();
        lr.train(20, &mut dataset).unwrap();

        let cm = lr.evaluate_dataset(&dataset).unwrap();
        println!("{cm}");
        println!("Weights before reduction: {:?}", lr.weights);
        println!("Features before reduction: {:?}", lr.features);

        lr.reduce(&dataset).expect("Failed to reduce weights");
        println!("Weights after reduction: {:?}", lr.weights);
        println!("Features after reduction: {:?}", lr.features);

        let remaining = lr.weights.len();
        println!("Weights from {BOGUS_LEN} to {}", remaining);
        assert!(
            remaining < BOGUS_LEN,
            "** If this assertion fails, re-run the test once or twice. **"
        );
    }

    /// Checks the AUC against a hand-computed example (expected value 0.75).
    #[test]
    fn auc() {
        let labels = vec![1.0, 1.0, 0.0, 0.0];
        let scores = vec![0.5, 0.2, 0.3, -1.0];

        let dataset = Dataset {
            data: vec![],
            labels,
            features: vec![],
            ftype: FileType::DOCFILE, // Doesn't matter for this test
        };

        // The counts are irrelevant here; `auc` only reads the scores and the labels.
        let confusion_matrix = ConfusionMatrix {
            true_p: 0,
            true_n: 0,
            false_p: 0,
            false_n: 0,
            dataset: &dataset,
            predictions: scores,
        };

        let auc = confusion_matrix.auc();
        println!("Auc: {auc:.2}, expected 0.75");
        assert!((0.73..0.78).contains(&auc));
    }
}