Skip to main content

malware_modeler/
model.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::ftype::FileType;
4use crate::{Bytes, dataset::Dataset};
5
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::path::Path;
9
10use anyhow::{Result, anyhow, ensure};
11use rand::RngExt;
12use rayon::prelude::*;
13use serde::{Deserialize, Serialize};
14
15// Adapted from https://github.com/ibnz36/lrclassifier/blob/06d136445028653af5d2061e002111ef11b14277/src/lib.rs
16// Accessed 01 November 2025
17
/// Logistic (sigmoid) activation: squashes `x` into the range `0.0..=1.0`.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
24
25/// Machine learning model using logistic regression
26#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
27pub struct LogisticRegression {
28    /// Learning rate
29    pub learning_rate: f32,
30
31    /// Bias term
32    pub bias: f32,
33
34    /// Model's weights
35    pub weights: Vec<f32>,
36
37    /// L1 LASSO regularization
38    pub l1: f32,
39
40    /// L2 Ridge regularization
41    pub l2: f32,
42
43    /// N-grams used to train the model
44    #[serde(
45        serialize_with = "crate::serde::serialize_hex_map",
46        deserialize_with = "crate::serde::deserialize_hex_map"
47    )]
48    pub features: HashMap<Bytes, usize>,
49
50    /// N-gram size
51    n: usize,
52
53    /// If the model has been trained
54    pub trained: bool,
55
56    /// Amount of n-grams originally used
57    pub original_ngrams: u32,
58
59    /// The type of file this model is trained on
60    pub file_type: FileType,
61
62    /// How well the model performed on the training dataset
63    pub train_performance: PerformanceStats,
64
65    /// How well the model performed on the testing dataset
66    #[serde(default)]
67    pub test_performance: Option<PerformanceStats>,
68}
69
70impl LogisticRegression {
71    /// New logistic regression model with random initial weights from given parameters.
72    #[must_use]
73    #[allow(clippy::cast_possible_truncation)]
74    pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
75        let mut rng = rand::rng();
76
77        Self {
78            learning_rate,
79            weights: (0..input_size)
80                .map(|_| rng.random_range(-1.0..1.0))
81                .collect(),
82            l1,
83            l2,
84            n: 0,
85            features: HashMap::new(),
86            trained: false,
87            bias: rng.random(),
88            original_ngrams: input_size as u32,
89            file_type: FileType::NotSet,
90            train_performance: PerformanceStats::default(),
91            test_performance: None,
92        }
93    }
94
95    /// Returns a trained logistic regression model from given parameters and dataset. Returns the model
96    /// and the error value from the last epoch.
97    ///
98    /// # Panics
99    ///
100    /// This code won't panic, but [`LogisticRegression::train`] would panic if the data sizes were
101    /// different, but that can't happen in this case since the same object is used in both places.
102    #[must_use]
103    pub fn new_from_dataset_and_train(
104        dataset: &mut Dataset,
105        test: Option<&Dataset>,
106        epochs: u32,
107        learning_rate: f32,
108        l1: f32,
109        l2: f32,
110    ) -> (Self, f32) {
111        let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
112        model.n = dataset.features[0].len();
113        model.features = dataset
114            .features
115            .iter()
116            .map(|f| (f.clone(), 0))
117            .collect::<HashMap<_, _>>();
118        let result = model.train(epochs, dataset).unwrap();
119        let train_performance = model.evaluate_dataset(dataset).unwrap();
120        model.file_type = dataset.ftype;
121        model.train_performance = train_performance.into();
122
123        if let Some(test) = test
124            && dataset.ftype == test.ftype
125            && model.n == test.data[0].len()
126            && let Ok(test_performance) = model.evaluate_dataset(test)
127        {
128            model.test_performance = Some(test_performance.into());
129        }
130
131        (model, result)
132    }
133
134    /// Predicts the output for a given input vector.
135    /// This will fail if the input vector isn't the same length as the weights vector.
136    #[inline]
137    #[must_use]
138    pub fn predict(&self, input: &[f32]) -> f32 {
139        let linear_model = input
140            .iter()
141            .zip(&self.weights)
142            .map(|(x, w)| x * w)
143            .sum::<f32>()
144            + self.bias;
145        sigmoid(linear_model)
146    }
147
148    /// Trains the classifier once with the given inputs and outputs.
149    ///
150    /// A mutable reference is required because the dataset is shuffled in place each epoch.
151    ///
152    /// # Errors
153    ///
154    /// Returns an error if the data isn't the correct size or if labels are missing.
155    #[allow(clippy::cast_precision_loss)]
156    pub fn train(&mut self, epochs: u32, dataset: &mut Dataset) -> Result<f32, &'static str> {
157        if dataset.labels.is_empty() {
158            return Err("Dataset must have labels");
159        }
160
161        if !dataset.validate() {
162            return Err("Dataset didn't pass validity check!");
163        }
164
165        if dataset.data[0].len() != self.weights.len() {
166            return Err("Dataset feature length must equal the number of model weights");
167        }
168
169        let mut loss = 0.0;
170        #[allow(unused)]
171        for epoch in 0..epochs {
172            loss = 0.0;
173            dataset.shuffle();
174            for (input, output) in dataset.data.iter().zip(&dataset.labels) {
175                let output = f32::from(*output);
176                let prediction = self.predict(input);
177                let error = prediction - output;
178                let p = prediction.clamp(1e-8, 1.0 - 1e-8);
179                loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
180
181                self.weights
182                    .par_iter_mut()
183                    .enumerate()
184                    .for_each(|(i, weight)| {
185                        let l1r = self.l1 * weight.signum();
186                        let l2r = self.l2 * *weight;
187                        *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
188                    });
189                self.bias -= self.learning_rate * error;
190            }
191            loss /= self.weights.len() as f32;
192
193            #[cfg(debug_assertions)]
194            println!("Epoch: {epoch}, Log loss: {loss}");
195
196            if loss < 1e-6 {
197                break;
198            }
199        }
200
201        self.trained = true;
202        self.file_type = dataset.ftype;
203        self.n = dataset.features[0].len();
204        if let Ok(confusion_matrix) = self.evaluate_dataset(dataset) {
205            self.train_performance = confusion_matrix.into();
206        }
207        Ok(loss)
208    }
209
210    /// Evaluate a dataset
211    ///
212    /// # Errors
213    ///
214    /// If the dataset doesn't match the model weight length or doesn't have labels
215    pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
216        ensure!(!dataset.is_empty(), "Dataset is empty");
217        ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
218        ensure!(
219            dataset.data[0].len() == self.weights.len(),
220            "Dataset length must equal the number of model weights"
221        );
222        ensure!(
223            self.file_type == dataset.ftype,
224            "Dataset file type must match model file type"
225        );
226
227        let mut tp_ = 0;
228        let mut fp_ = 0;
229        let mut tn_ = 0;
230        let mut fn_ = 0;
231        let mut predictions = Vec::with_capacity(dataset.labels.len());
232
233        for index in 0..dataset.len() {
234            let prediction = self.predict(&dataset.data[index]);
235            if prediction >= 0.5 && dataset.labels[index] >= 1 {
236                tp_ += 1;
237            } else if prediction >= 0.5 && dataset.labels[index] < 1 {
238                fp_ += 1;
239            } else if prediction < 0.5 && dataset.labels[index] < 1 {
240                tn_ += 1;
241            } else {
242                fn_ += 1;
243            }
244            predictions.push(prediction);
245        }
246
247        Ok(ConfusionMatrix {
248            true_p: tp_,
249            true_n: tn_,
250            false_p: fp_,
251            false_n: fn_,
252            dataset,
253            predictions,
254        })
255    }
256
257    /// Evaluate a file
258    ///
259    /// # Errors
260    ///
261    /// Errors will result if the model doesn't have features or if the sample file can't be read
262    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
263    pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
264        ensure!(
265            !self.features.is_empty(),
266            "Features are required for file evaluation"
267        );
268
269        ensure!(
270            self.file_type.matches_path(&path)?,
271            "File type doesn't match model type"
272        );
273
274        let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
275        let result = self.predict(&vector);
276        let features = vector.iter().map(|v| *v as u32).sum();
277        if result > 0.5 {
278            Ok(("Malicious", result, features))
279        } else {
280            Ok(("Benign", result, features))
281        }
282    }
283
284    /// Remove small weights (from regularization), returns the number of weights removed
285    ///
286    /// # Errors
287    ///
288    /// There's an error if the model is not trained
289    #[allow(clippy::cast_precision_loss)]
290    pub fn reduce(&mut self, dataset: &Dataset) -> Result<usize> {
291        const THRESHOLD: f32 = 0.01;
292
293        ensure!(
294            self.trained,
295            "Model must be trained before reducing weights"
296        );
297        ensure!(
298            self.file_type == dataset.ftype,
299            "Dataset file type must match model file type"
300        );
301
302        let mut average_features = Vec::with_capacity(self.weights.len());
303        for (index, model_weight) in self.weights.iter().enumerate() {
304            let col_iter = dataset
305                .column_iter(index)
306                .ok_or_else(|| anyhow!("Column index out of bounds"))?;
307            average_features.push((
308                index,
309                (model_weight * col_iter.sum::<f32>() / dataset.len() as f32).abs(),
310            ));
311        }
312
313        average_features.sort_by(|(_, a), (_, b)| {
314            if a > b {
315                Ordering::Greater
316            } else if a < b {
317                Ordering::Less
318            } else {
319                Ordering::Equal
320            }
321        });
322
323        let mut removed = vec![];
324        let mut running_sum = 0.0;
325        for (index, weight) in average_features {
326            running_sum += weight;
327            if running_sum < THRESHOLD {
328                removed.push(index);
329            }
330        }
331
332        removed.sort_unstable();
333        removed.reverse();
334        let removed_len = removed.len();
335        for to_remove in &removed {
336            self.weights.remove(*to_remove);
337        }
338
339        if !self.features.is_empty() {
340            let mut removed_features = Vec::with_capacity(removed_len);
341            for index in &removed {
342                for (feat, feat_index) in &self.features {
343                    if index == feat_index {
344                        removed_features.push(feat.clone());
345                    }
346                }
347            }
348
349            for removed_feature in removed_features {
350                self.features.remove(&removed_feature);
351            }
352        }
353
354        Ok(removed_len)
355    }
356
357    /// Set the features by adding to the struct
358    ///
359    /// # Errors
360    ///
361    /// Returns an error if the number of weights and features doesn't match
362    pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
363        ensure!(
364            features.len() == self.weights.len(),
365            "Provided features length {} does not equal the number of model features length {}",
366            features.len(),
367            self.weights.len()
368        );
369        self.features = features
370            .into_iter()
371            .enumerate()
372            .map(|(f, i)| (i, f))
373            .collect::<HashMap<_, _>>();
374
375        Ok(())
376    }
377
378    /// Convenience function to get the features from the dataset then reduce based on the dataset.
379    ///
380    /// # Errors
381    ///
382    /// An error occurs if the conditions below aren't met:
383    /// * Ensures the number of model weights matches the dataset features
384    /// * Ensures the file types match
385    pub fn set_features_and_reduce(&mut self, dataset: &Dataset) -> Result<usize> {
386        ensure!(
387            self.file_type == dataset.ftype,
388            "Dataset file type must match model file type"
389        );
390        ensure!(
391            dataset.data[0].len() == self.weights.len(),
392            "Dataset length must equal the number of model weights"
393        );
394
395        self.features = dataset
396            .features
397            .iter()
398            .enumerate()
399            .map(|(f, i)| (i.clone(), f))
400            .collect::<HashMap<_, _>>();
401
402        self.reduce(dataset)
403    }
404
405    /// Set the features by creating a new struct which has the features
406    ///
407    /// # Errors
408    ///
409    /// Returns an error if the number of weights and features doesn't match
410    pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
411        ensure!(
412            features.len() == self.weights.len(),
413            "Provided features length {} does not equal the number of model features length {}",
414            features.len(),
415            self.weights.len()
416        );
417
418        Ok(Self {
419            learning_rate: self.learning_rate,
420            bias: self.bias,
421            weights: self.weights,
422            l1: self.l1,
423            l2: self.l2,
424            trained: self.trained,
425            original_ngrams: self.original_ngrams,
426            file_type: self.file_type,
427            n: self.n,
428            train_performance: PerformanceStats::default(),
429            test_performance: None,
430            features: features
431                .into_iter()
432                .enumerate()
433                .map(|(f, i)| (i, f))
434                .collect::<HashMap<_, _>>(),
435        })
436    }
437}
438
439/// Model performance statistics
440#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
441pub struct PerformanceStats {
442    /// True positives
443    pub true_positives: u32,
444
445    /// True negatives
446    pub true_negatives: u32,
447
448    /// False positives
449    pub false_positives: u32,
450
451    /// False negatives
452    pub false_negatives: u32,
453
454    /// Recall
455    pub recall: f32,
456
457    /// Precision
458    pub precision: f32,
459
460    /// F1
461    pub f1: f32,
462
463    /// AUC
464    pub auc: f32,
465}
466
467impl<'a> From<ConfusionMatrix<'a>> for PerformanceStats {
468    fn from(value: ConfusionMatrix<'a>) -> Self {
469        PerformanceStats {
470            true_positives: value.true_p,
471            true_negatives: value.true_n,
472            false_positives: value.false_p,
473            false_negatives: value.false_n,
474            recall: value.recall(),
475            precision: value.precision(),
476            f1: value.f1(),
477            auc: value.auc(),
478        }
479    }
480}
481
482/// Confusion Matrix
483#[derive(Debug, Clone, PartialEq)]
484pub struct ConfusionMatrix<'a> {
485    /// True positives
486    pub true_p: u32,
487
488    /// True negatives
489    pub true_n: u32,
490
491    /// False positives
492    pub false_p: u32,
493
494    /// False negatives
495    pub false_n: u32,
496
497    /// Original dataset reference
498    dataset: &'a Dataset,
499
500    /// Model's outputs
501    predictions: Vec<f32>,
502}
503
504impl ConfusionMatrix<'_> {
505    /// Accuracy as correct vs total
506    #[inline]
507    #[must_use]
508    #[allow(clippy::cast_precision_loss)]
509    pub fn accuracy(&self) -> f32 {
510        (self.true_p + self.true_n) as f32 / self.total() as f32
511    }
512
513    /// Precision
514    #[must_use]
515    #[allow(clippy::cast_precision_loss)]
516    pub fn precision(&self) -> f32 {
517        self.true_p as f32 / (self.true_p + self.false_p) as f32
518    }
519
520    /// Recall
521    #[must_use]
522    #[allow(clippy::cast_precision_loss)]
523    pub fn recall(&self) -> f32 {
524        self.true_p as f32 / (self.true_p + self.false_n) as f32
525    }
526
527    /// F1 score
528    #[must_use]
529    #[allow(clippy::cast_precision_loss)]
530    pub fn f1(&self) -> f32 {
531        2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
532    }
533
534    /// Total items evaluated
535    #[inline]
536    #[must_use]
537    pub fn total(&self) -> u32 {
538        self.true_p + self.true_n + self.false_p + self.false_n
539    }
540
541    /// Calculate the area under the curve (AUC)
542    #[must_use]
543    #[allow(clippy::float_cmp)]
544    pub fn auc(&self) -> f32 {
545        // Adapted from the AUC implementation by Maciej Kula <maciej.kula@gmail.com>
546        // https://github.com/maciejkula/rustlearn/blob/2ac052559b04860c62d7ba34c563b57a02912e4d/src/metrics/ranking.rs
547        // The last commit was June 2020 so used here directly instead of importing the crate.
548        // This code was also Apache 2.0 licensed.
549
550        // vector of pairs (score, label) - the order is switched with respect to function arguments
551        let (mut true_positive_count, mut false_positive_count) = {
552            let mut pairs: Vec<_> = self
553                .predictions
554                .iter()
555                .copied()
556                .zip(self.dataset.labels.iter().copied())
557                .collect();
558
559            // Sort by scores in descending order
560            pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
561
562            let mut score_prev = f32::NAN;
563            // tp .. true positives, fp .. false positives
564            let (mut tp, mut fp) = (0.0f32, 0.0f32);
565            let (mut tps, mut fps) = (vec![], vec![]);
566            for (score, label) in pairs {
567                let label = f32::from(label);
568                // `tp` and `fp` from the previous iteration are pushed onto the ROC curve only if
569                // the `score` changed. This avoids errors due to arbitrary classification of points with
570                // identical scores
571                if score != score_prev {
572                    tps.push(tp);
573                    fps.push(fp);
574                    score_prev = score;
575                }
576                tp += label;
577                fp += 1.0 - label;
578            }
579            // Push the final point corresponding to the (1,1) ROC coordinates
580            tps.push(tp);
581            fps.push(fp);
582            (tps, fps)
583        };
584
585        let true_positives = true_positive_count[true_positive_count.len() - 1];
586        let false_positives = false_positive_count[false_positive_count.len() - 1];
587
588        for (tp, fp) in true_positive_count
589            .iter_mut()
590            .zip(false_positive_count.iter_mut())
591        {
592            *tp /= true_positives;
593            *fp /= false_positives;
594        }
595
596        let mut prev_x = false_positive_count[0];
597        let mut prev_y = true_positive_count[0];
598        let mut integral = 0.0;
599
600        for (&x, &y) in false_positive_count
601            .iter()
602            .skip(1)
603            .zip(true_positive_count.iter().skip(1))
604        {
605            integral += (x - prev_x) * (prev_y + y) / 2.0;
606
607            prev_x = x;
608            prev_y = y;
609        }
610
611        integral
612    }
613}
614
615impl std::fmt::Display for ConfusionMatrix<'_> {
616    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617        const WIDTH: usize = 10;
618
619        writeln!(f, "Result \\ Actual  |    Malicious   |  Benign")?;
620        writeln!(
621            f,
622            "   Malicious     |     {:<WIDTH$} |    {:<WIDTH$}",
623            self.true_p, self.false_p
624        )?;
625        writeln!(
626            f,
627            "   Benign        |     {:<WIDTH$} |    {:<WIDTH$}",
628            self.false_n, self.true_n
629        )
630    }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    /// Trains on the XOR fixture and expects more correct than incorrect classifications.
    #[test]
    fn xor() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
        lr.train(100, &mut dataset).unwrap();

        let (mut correct, mut incorrect) = (0u16, 0u16);

        for index in 0..dataset.data.len() {
            let predicted = lr.predict(&dataset.data[index]);
            let expected = dataset.labels[index];
            println!("Predicted: {predicted}, Expected: {expected}");

            let is_hit = (predicted >= 0.5 && expected >= 1) || (predicted < 0.5 && expected < 1);
            if is_hit {
                correct += 1;
            } else {
                incorrect += 1;
            }
        }

        println!("Correct: {correct}, Incorrect: {incorrect}");
        assert!(correct > incorrect);

        let result = lr.evaluate_dataset(&dataset).unwrap();
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());
    }

    /// Trains with regularization on the bogus fixture and expects `reduce` to drop weights.
    #[test]
    fn reduction() {
        const BOGUS_LEN: usize = 6;

        let mut dataset =
            Dataset::from_csv_string(include_str!("../testdata/bogus.csv"), BOGUS_LEN).unwrap();
        let mut lr = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
        lr.set_features(dataset.features.clone()).unwrap();
        lr.train(20, &mut dataset).unwrap();

        let cm = lr.evaluate_dataset(&dataset).unwrap();
        println!("{cm}");

        println!("Weights before reduction: {:?}", lr.weights);
        println!("Features before reduction: {:?}", lr.features);
        lr.reduce(&dataset).expect("Failed to reduce weights");
        println!("Weights after reduction: {:?}", lr.weights);
        println!("Features after reduction: {:?}", lr.features);

        let remaining = lr.weights.len();
        println!("Weights from {BOGUS_LEN} to {remaining}");
        // Initial weights are random, so this can occasionally fail.
        assert!(
            remaining < BOGUS_LEN,
            "** If this assertion fails, re-run the test once or twice. **"
        );

        lr.test_performance = Some(cm.into());
        println!("Model: {lr:?}");
    }

    /// Checks the AUC computation against a small hand-made example.
    #[test]
    fn auc() {
        let dataset = Dataset {
            data: vec![],
            labels: vec![1, 1, 0, 0],
            features: vec![],
            ftype: FileType::DOCFILE, // Doesn't matter for this test
        };

        let confusion_matrix = ConfusionMatrix {
            true_p: 0,
            true_n: 0,
            false_p: 0,
            false_n: 0,
            dataset: &dataset,
            predictions: vec![0.5, 0.2, 0.3, -1.0],
        };

        let auc = confusion_matrix.auc();
        println!("Auc: {auc:.2}, expected 0.75");
        assert!((0.73..0.78).contains(&auc));
    }
}