malware_modeler/
model.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::ftype::FileType;
4use crate::{dataset::Dataset, Bytes};
5
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::path::Path;
9
10use anyhow::{ensure, Result};
11use rand::Rng;
12use rayon::prelude::*;
13use serde::{Deserialize, Serialize};
14
15// Adapted from https://github.com/ibnz36/lrclassifier/blob/06d136445028653af5d2061e002111ef11b14277/src/lib.rs
16// Accessed 01 November 2025
17
/// Logistic sigmoid `1 / (1 + e^-x)`.
///
/// Maps any real `x` into `[0, 1]`. In `f32` the result saturates to exactly `1.0` for large
/// positive `x` and to exactly `0.0` for large negative `x`. It never produces NaN for
/// non-NaN input.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
24
/// Machine learning model using logistic regression over n-gram features.
///
/// `weights` holds one weight per feature column. `features` maps each n-gram to the index of
/// the weight it corresponds to.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
    /// Learning rate (gradient-descent step size used by `train`)
    pub learning_rate: f32,

    /// Bias term, added to the weighted sum before the sigmoid
    pub bias: f32,

    /// Model's weights, one per feature column
    pub weights: Vec<f32>,

    /// L1 LASSO regularization strength
    pub l1: f32,

    /// L2 Ridge regularization strength
    pub l2: f32,

    /// N-grams used to train the model, each mapped to its index in `weights`.
    /// (De)serialized through `crate::serde::{serialize,deserialize}_hex_map` — presumably a
    /// hex-encoded key map; confirm in that module.
    #[serde(
        serialize_with = "crate::serde::serialize_hex_map",
        deserialize_with = "crate::serde::deserialize_hex_map"
    )]
    pub features: HashMap<Bytes, usize>,

    /// N-gram size: the length of the dataset's features, recorded by `train` (0 when
    /// created by `new`)
    n: usize,

    /// If the model has been trained
    pub trained: bool,

    /// Amount of n-grams (weights) the model was created with; not updated by `reduce`
    pub original_ngrams: u32,

    /// The type of file this model is trained on
    pub file_type: FileType,
}
62
63impl LogisticRegression {
64    /// New logistic regression model with random initial weights from given parameters.
65    #[must_use]
66    #[allow(clippy::cast_possible_truncation)]
67    pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
68        let mut rng = rand::rng();
69
70        Self {
71            learning_rate,
72            weights: (0..input_size)
73                .map(|_| rng.random_range(-1.0..1.0))
74                .collect(),
75            l1,
76            l2,
77            n: 0,
78            features: HashMap::new(),
79            trained: false,
80            bias: rng.random(),
81            original_ngrams: input_size as u32,
82            file_type: FileType::NotSet,
83        }
84    }
85
86    /// Returns a trained logistic regression model from given parameters and dataset. Returns the model
87    /// and the error value from the last epoch.
88    ///
89    /// # Panics
90    ///
91    /// This code won't panic, but [`LogisticRegression::train`] would panic if the data sizes were
92    /// different, but that can't happen in this case since the same object is used in both places.
93    #[must_use]
94    pub fn new_from_dataset_and_train(
95        dataset: &Dataset,
96        epochs: u32,
97        learning_rate: f32,
98        l1: f32,
99        l2: f32,
100    ) -> (Self, f32) {
101        let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
102        model.n = dataset.features[0].len();
103        model.features = dataset
104            .features
105            .iter()
106            .map(|f| (f.clone(), 0))
107            .collect::<HashMap<_, _>>();
108        let result = model.train(epochs, dataset).unwrap();
109        model.file_type = dataset.ftype;
110        (model, result)
111    }
112
113    /// Predicts the output for a given input vector.
114    /// This will fail if the input vector isn't the same length as the weights vector.
115    #[inline]
116    #[must_use]
117    pub fn predict(&self, input: &[f32]) -> f32 {
118        let linear_model = input
119            .iter()
120            .zip(&self.weights)
121            .map(|(x, w)| x * w)
122            .sum::<f32>()
123            + self.bias;
124        sigmoid(linear_model)
125    }
126
127    /// Trains the classifier once with the given inputs and outputs.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the data isn't the correct size or if labels are missing.
132    #[allow(clippy::cast_precision_loss)]
133    pub fn train(&mut self, epochs: u32, dataset: &Dataset) -> Result<f32, &'static str> {
134        if dataset.labels.is_empty() {
135            return Err("Dataset must have labels");
136        }
137
138        if !dataset.validate() {
139            return Err("Dataset didn't pass validity check!");
140        }
141
142        if dataset.data[0].len() != self.weights.len() {
143            return Err("Dataset feature length must equal the number of model weights");
144        }
145
146        let mut loss = 0.0;
147        #[allow(unused)]
148        for epoch in 0..epochs {
149            loss = 0.0;
150            for (input, output) in dataset.data.iter().zip(&dataset.labels) {
151                let prediction = self.predict(input);
152                let error = prediction - output;
153                let p = prediction.clamp(1e-8, 1.0 - 1e-8);
154                loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
155
156                self.weights
157                    .par_iter_mut()
158                    .enumerate()
159                    .for_each(|(i, weight)| {
160                        let l1r = self.l1 * (*weight / (weight.abs() + 1e-8));
161                        let l2r = self.l2 * *weight;
162                        *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
163                    });
164                self.bias -= self.learning_rate * error;
165            }
166            loss /= self.weights.len() as f32;
167
168            #[cfg(debug_assertions)]
169            println!("Epoch: {epoch}, Log loss: {loss}");
170
171            if loss < 1e-6 {
172                break;
173            }
174        }
175
176        self.trained = true;
177        self.file_type = dataset.ftype;
178        self.n = dataset.features[0].len();
179        Ok(loss)
180    }
181
182    /// Evaluate a dataset
183    ///
184    /// # Errors
185    ///
186    /// If the dataset doesn't match the model weight length or doesn't have labels
187    pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
188        ensure!(!dataset.is_empty(), "Dataset is empty");
189        ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
190        ensure!(
191            dataset.data[0].len() == self.weights.len(),
192            "Dataset length must equal the number of model weights"
193        );
194
195        let mut tp_ = 0;
196        let mut fp_ = 0;
197        let mut tn_ = 0;
198        let mut fn_ = 0;
199        let mut predictions = Vec::with_capacity(dataset.labels.len());
200
201        for index in 0..dataset.len() {
202            let prediction = self.predict(&dataset.data[index]);
203            if prediction >= 0.5 && dataset.labels[index] >= 0.9 {
204                tp_ += 1;
205            } else if prediction >= 0.5 && dataset.labels[index] < 0.5 {
206                fp_ += 1;
207            } else if prediction < 0.5 && dataset.labels[index] < 0.5 {
208                tn_ += 1;
209            } else {
210                fn_ += 1;
211            }
212            predictions.push(prediction);
213        }
214
215        Ok(ConfusionMatrix {
216            true_p: tp_,
217            true_n: tn_,
218            false_p: fp_,
219            false_n: fn_,
220            dataset,
221            predictions,
222        })
223    }
224
225    /// Evaluate a file
226    ///
227    /// # Errors
228    ///
229    /// Errors will result if the model doesn't have features or if the sample file can't be read
230    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
231    pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
232        ensure!(
233            !self.features.is_empty(),
234            "Features are required for file evaluation"
235        );
236
237        ensure!(
238            self.file_type.matches_path(&path)?,
239            "File type doesn't match model type"
240        );
241
242        let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
243        let result = self.predict(&vector);
244        let features = vector.iter().map(|v| *v as u32).sum();
245        if result > 0.5 {
246            Ok(("Malicious", result, features))
247        } else {
248            Ok(("Benign", result, features))
249        }
250    }
251
252    /// Remove small weights (from regularization)
253    ///
254    /// TODO: Ensure the filtered weight comparison value is sane
255    pub fn reduce(&mut self) {
256        if self.trained {
257            let mut removed = vec![];
258            self.weights = self
259                .weights
260                .iter()
261                .enumerate()
262                .filter_map(|(index, w)| {
263                    if w.abs() > 0.01 {
264                        Some(w)
265                    } else {
266                        removed.push(index);
267                        None
268                    }
269                })
270                .copied()
271                .collect();
272
273            if !self.features.is_empty() {
274                removed.sort_unstable();
275                removed.reverse();
276                let mut removed_features = Vec::with_capacity(removed.len());
277                for index in removed {
278                    for (feat, feat_index) in &self.features {
279                        if index == *feat_index {
280                            removed_features.push(feat.clone());
281                        }
282                    }
283                }
284
285                for removed_feature in removed_features {
286                    self.features.remove(&removed_feature);
287                }
288            }
289        }
290    }
291
292    /// Set the features by adding to the struct
293    ///
294    /// # Errors
295    ///
296    /// Returns an error if the number of weights and features doesn't match
297    pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
298        ensure!(
299            features.len() == self.weights.len(),
300            "Provided features length {} does not equal the number of model features length {}",
301            features.len(),
302            self.weights.len()
303        );
304        self.features = features
305            .into_iter()
306            .enumerate()
307            .map(|(f, i)| (i, f))
308            .collect::<HashMap<_, _>>();
309
310        Ok(())
311    }
312
313    /// Set the features by creating a new struct which has the features
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if the number of weights and features doesn't match
318    pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
319        ensure!(
320            features.len() == self.weights.len(),
321            "Provided features length {} does not equal the number of model features length {}",
322            features.len(),
323            self.weights.len()
324        );
325
326        Ok(Self {
327            learning_rate: self.learning_rate,
328            bias: self.bias,
329            weights: self.weights,
330            l1: self.l1,
331            l2: self.l2,
332            trained: self.trained,
333            original_ngrams: self.original_ngrams,
334            file_type: self.file_type,
335            n: self.n,
336            features: features
337                .into_iter()
338                .enumerate()
339                .map(|(f, i)| (i, f))
340                .collect::<HashMap<_, _>>(),
341        })
342    }
343}
344
/// Confusion Matrix produced by `LogisticRegression::evaluate_dataset`, together with the
/// per-sample predictions needed to compute the AUC.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
    /// True positives (predicted malicious, labelled malicious)
    pub true_p: u32,

    /// True negatives (predicted benign, labelled benign)
    pub true_n: u32,

    /// False positives (predicted malicious, labelled benign)
    pub false_p: u32,

    /// False negatives (predicted benign, labelled malicious)
    pub false_n: u32,

    /// Original dataset reference; `auc` pairs its labels with `predictions`
    dataset: &'a Dataset,

    /// Model's outputs, one per evaluated sample, in dataset order
    predictions: Vec<f32>,
}
366
367impl ConfusionMatrix<'_> {
368    /// Accuracy as correct vs total
369    #[inline]
370    #[must_use]
371    #[allow(clippy::cast_precision_loss)]
372    pub fn accuracy(&self) -> f32 {
373        (self.true_p + self.true_n) as f32 / self.total() as f32
374    }
375
376    /// Precision
377    #[must_use]
378    #[allow(clippy::cast_precision_loss)]
379    pub fn precision(&self) -> f32 {
380        self.true_p as f32 / (self.true_p + self.false_p) as f32
381    }
382
383    /// Recall
384    #[must_use]
385    #[allow(clippy::cast_precision_loss)]
386    pub fn recall(&self) -> f32 {
387        self.true_p as f32 / (self.true_p + self.false_n) as f32
388    }
389
390    /// F1 score
391    #[must_use]
392    #[allow(clippy::cast_precision_loss)]
393    pub fn f1(&self) -> f32 {
394        2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
395    }
396
397    /// Total items evaluated
398    #[inline]
399    #[must_use]
400    pub fn total(&self) -> u32 {
401        self.true_p + self.true_n + self.false_p + self.false_n
402    }
403
404    /// Calculate the area under the curve (AUC)
405    #[must_use]
406    #[allow(clippy::float_cmp)]
407    pub fn auc(&self) -> f32 {
408        // Adapted from the AUC implementation by Maciej Kula <maciej.kula@gmail.com>
409        // https://github.com/maciejkula/rustlearn/blob/2ac052559b04860c62d7ba34c563b57a02912e4d/src/metrics/ranking.rs
410        // The last commit was June 2020 so used here directly instead of importing the crate.
411        // This code was also Apache 2.0 licensed.
412
413        // vector of pairs (score, label) - the order is switched with respect to function arguments
414        let (mut true_positive_count, mut false_positive_count) = {
415            let mut pairs: Vec<_> = self
416                .predictions
417                .iter()
418                .copied()
419                .zip(self.dataset.labels.iter().copied())
420                .collect();
421
422            // Sort by scores in descending order
423            pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
424
425            let mut score_prev = f32::NAN;
426            // tp .. true positives, fp .. false positives
427            let (mut tp, mut fp) = (0.0f32, 0.0f32);
428            let (mut tps, mut fps) = (vec![], vec![]);
429            for (score, label) in pairs {
430                // `tp` and `fp` from the previous iteration are pushed onto the ROC curve only if
431                // the `score` changed. This avoids errors due to arbitrary classification of points with
432                // identical scores
433                if score != score_prev {
434                    tps.push(tp);
435                    fps.push(fp);
436                    score_prev = score;
437                }
438                tp += label;
439                fp += 1.0 - label;
440            }
441            // Push the final point corresponding to the (1,1) ROC coordinates
442            tps.push(tp);
443            fps.push(fp);
444            (tps, fps)
445        };
446
447        let true_positives = true_positive_count[true_positive_count.len() - 1];
448        let false_positives = false_positive_count[false_positive_count.len() - 1];
449
450        for (tp, fp) in true_positive_count
451            .iter_mut()
452            .zip(false_positive_count.iter_mut())
453        {
454            *tp /= true_positives;
455            *fp /= false_positives;
456        }
457
458        let mut prev_x = false_positive_count[0];
459        let mut prev_y = true_positive_count[0];
460        let mut integral = 0.0;
461
462        for (&x, &y) in false_positive_count
463            .iter()
464            .skip(1)
465            .zip(true_positive_count.iter().skip(1))
466        {
467            integral += (x - prev_x) * (prev_y + y) / 2.0;
468
469            prev_x = x;
470            prev_y = y;
471        }
472
473        integral
474    }
475}
476
477impl std::fmt::Display for ConfusionMatrix<'_> {
478    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479        writeln!(f, "Result \\ Actual | Malicious | Benign")?;
480        writeln!(
481            f,
482            " Malicious       |     {} |    {}",
483            self.true_p, self.false_p
484        )?;
485        writeln!(
486            f,
487            " Benign          |     {} |    {}",
488            self.false_n, self.true_n
489        )
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    /// Trains on the XOR fixture, checks that most samples are classified correctly, and
    /// prints the evaluation metrics.
    #[test]
    fn xor() {
        let csv = include_str!("../testdata/xor.csv");
        let dataset = Dataset::from_csv_string(csv, 6).unwrap();
        let mut model = LogisticRegression::new(6, 0.2, 0.0, 0.0);
        model.train(100, &dataset).unwrap();

        let mut correct = 0u16;
        let mut incorrect = 0u16;

        for (index, row) in dataset.data.iter().enumerate() {
            let expected = dataset.labels[index];
            let guess = model.predict(row);
            println!("Predicted: {}, Expected: {}", guess, expected);

            let positive_hit = guess >= 0.5 && expected >= 0.99;
            let negative_hit = guess < 0.5 && expected < 0.1;
            if positive_hit || negative_hit {
                correct += 1;
            } else {
                incorrect += 1;
            }
        }

        println!("Correct: {correct}, Incorrect: {incorrect}");
        assert!(correct > incorrect);

        let result = model.evaluate_dataset(&dataset).unwrap();
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());
    }

    /// With L1/L2 regularization, `reduce` should prune at least one of the bogus weights.
    #[test]
    fn reduction() {
        const BOGUS_LEN: usize = 6;

        let csv = include_str!("../testdata/bogus.csv");
        let dataset = Dataset::from_csv_string(csv, BOGUS_LEN).unwrap();
        let mut model = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
        model.set_features(dataset.features.clone()).unwrap();
        model.train(20, &dataset).unwrap();

        println!("Weights before reduction: {:?}", model.weights);
        println!("Features before reduction: {:?}", model.features);
        model.reduce();
        println!("Weights after reduction: {:?}", model.weights);
        println!("Features after reduction: {:?}", model.features);
        println!("Weights from {BOGUS_LEN} to {}", model.weights.len());

        assert!(
            model.weights.len() < BOGUS_LEN,
            "** If this assertion fails, re-run the test once or twice. **"
        );
    }

    /// AUC on a tiny hand-computed example (expected 0.75).
    #[test]
    fn auc() {
        let labels = vec![1.0, 1.0, 0.0, 0.0];
        let scores = vec![0.5, 0.2, 0.3, -1.0];

        let dataset = Dataset {
            data: vec![],
            labels,
            features: vec![],
            ftype: FileType::DOCFILE, // Doesn't matter for this test
        };

        let confusion_matrix = ConfusionMatrix {
            true_p: 0,
            true_n: 0,
            false_p: 0,
            false_n: 0,
            dataset: &dataset,
            predictions: scores,
        };

        let auc = confusion_matrix.auc();
        println!("Auc: {auc:.2}, expected 0.75");
        assert!((0.73..0.78).contains(&auc));
    }
}