malware_modeler/
model.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::{dataset::Dataset, Bytes};
4
5use std::cmp::Ordering;
6use std::path::Path;
7
8use anyhow::{ensure, Result};
9use rand::Rng;
10use rayon::prelude::*;
11use serde::ser::Error;
12use serde::{Deserialize, Deserializer, Serialize, Serializer};
13// Adapted from https://github.com/ibnz36/lrclassifier/blob/06d136445028653af5d2061e002111ef11b14277/src/lib.rs
14// Accessed 01 November 2025
15
/// Logistic sigmoid, `1 / (1 + e^(-x))`.
///
/// Squashes any finite input into the range `0.0..=1.0`; `0.0` maps to `0.5`.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let decay = (-x).exp();
    (1.0 + decay).recip()
}
22
/// Machine learning model using logistic regression
///
/// The model can be serialized and deserialized with `serde`; the n-gram `features` go through
/// the custom `model_serialize_features` / `model_deserialize_features` helpers.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
    /// Learning rate: factor applied to every weight and bias update during training
    pub learning_rate: f32,

    /// Bias term, added to the weighted input sum before the sigmoid is applied
    pub bias: f32,

    /// Model's weights, paired positionally with the values of an input vector
    pub weights: Vec<f32>,

    /// L1 LASSO regularization strength
    pub l1: f32,

    /// L2 Ridge regularization strength
    pub l2: f32,

    /// N-grams used to train the model; file evaluation takes the n-gram size from the first entry
    #[serde(
        serialize_with = "model_serialize_features",
        deserialize_with = "model_deserialize_features"
    )]
    pub features: Vec<Bytes>,

    /// If the model has been trained (set by `train`; `reduce` does nothing until then)
    pub trained: bool,

    /// Amount of n-grams originally used, i.e. the input size given to `new`; `reduce` leaves it
    /// untouched
    pub original_ngrams: u32,
}
54
55impl LogisticRegression {
56    /// New logistic regression model with random initial weights from given parameters.
57    #[must_use]
58    #[allow(clippy::cast_possible_truncation)]
59    pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
60        let mut rng = rand::rng();
61
62        Self {
63            learning_rate,
64            weights: (0..input_size)
65                .map(|_| rng.random_range(-1.0..1.0))
66                .collect(),
67            l1,
68            l2,
69            features: vec![],
70            trained: false,
71            bias: rng.random(),
72            original_ngrams: input_size as u32,
73        }
74    }
75
76    /// New trained logistic regression model from given parameters and dataset. Returns the model
77    /// and the error value from the last epoch.
78    ///
79    /// # Panics
80    ///
81    /// This code won't panic, but [`LogisticRegression::train`] would panic if the data sizes weren't
82    /// the same but that can't happen in this case since the same object is used in both places.
83    #[must_use]
84    pub fn new_from_dataset_and_train(
85        dataset: &Dataset,
86        epochs: u32,
87        learning_rate: f32,
88        l1: f32,
89        l2: f32,
90    ) -> (Self, f32) {
91        let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
92        model.features.clone_from(&dataset.features);
93        let result = model.train(epochs, dataset).unwrap();
94        (model, result)
95    }
96
97    /// Predicts the output for a given input vector.
98    /// This will fail if the input vector isn't the same length as the weights vector.
99    #[inline]
100    #[must_use]
101    pub fn predict(&self, input: &[f32]) -> f32 {
102        let linear_model = input
103            .iter()
104            .zip(&self.weights)
105            .map(|(x, w)| x * w)
106            .sum::<f32>()
107            + self.bias;
108        sigmoid(linear_model)
109    }
110
111    /// Trains the classifier once with the given inputs and outputs.
112    ///
113    /// # Errors
114    ///
115    /// Returns an error if the data isn't the correct size or if labels are missing.
116    #[allow(clippy::cast_precision_loss)]
117    pub fn train(&mut self, epochs: u32, dataset: &Dataset) -> Result<f32, &'static str> {
118        if dataset.labels.is_empty() {
119            return Err("Dataset must have labels");
120        }
121
122        if !dataset.validate() {
123            return Err("Dataset didn't pass validity check!");
124        }
125
126        if dataset.data[0].len() != self.weights.len() {
127            return Err("Dataset feature length must equal the number of model weights");
128        }
129
130        let mut loss = 0.0;
131        #[allow(unused)]
132        for epoch in 0..epochs {
133            loss = 0.0;
134            for (input, output) in dataset.data.iter().zip(&dataset.labels) {
135                let prediction = self.predict(input);
136                let error = prediction - output;
137                let p = prediction.clamp(1e-8, 1.0 - 1e-8);
138                loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
139
140                self.weights
141                    .par_iter_mut()
142                    .enumerate()
143                    .for_each(|(i, weight)| {
144                        let l1r = self.l1 * (*weight / (weight.abs() + 1e-8));
145                        let l2r = self.l2 * *weight;
146                        *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
147                    });
148                self.bias -= self.learning_rate * error;
149            }
150            loss /= self.weights.len() as f32;
151
152            #[cfg(debug_assertions)]
153            println!("Epoch: {epoch}, Log loss: {loss}");
154
155            if loss < 1e-6 {
156                break;
157            }
158        }
159
160        self.trained = true;
161        Ok(loss)
162    }
163
164    /// Evaluate a dataset
165    ///
166    /// # Errors
167    ///
168    /// If the dataset doesn't match the model weight length or doesn't have labels
169    pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
170        ensure!(!dataset.is_empty(), "Dataset is empty");
171        ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
172        ensure!(
173            dataset.data[0].len() == self.weights.len(),
174            "Dataset length must equal the number of model weights"
175        );
176
177        let mut tp_ = 0;
178        let mut fp_ = 0;
179        let mut tn_ = 0;
180        let mut fn_ = 0;
181        let mut predictions = Vec::with_capacity(dataset.labels.len());
182
183        for index in 0..dataset.len() {
184            let prediction = self.predict(&dataset.data[index]);
185            if prediction >= 0.5 && dataset.labels[index] >= 0.9 {
186                tp_ += 1;
187            } else if prediction >= 0.5 && dataset.labels[index] < 0.5 {
188                fp_ += 1;
189            } else if prediction < 0.5 && dataset.labels[index] < 0.5 {
190                tn_ += 1;
191            } else {
192                fn_ += 1;
193            }
194            predictions.push(prediction);
195        }
196
197        Ok(ConfusionMatrix {
198            true_p: tp_,
199            true_n: tn_,
200            false_p: fp_,
201            false_n: fn_,
202            dataset,
203            predictions,
204        })
205    }
206
207    /// Evaluate a file
208    ///
209    /// # Errors
210    ///
211    /// Errors will result if the model doesn't have features or if the sample file can't be read
212    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
213    pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
214        ensure!(
215            !self.features.is_empty(),
216            "Features are required for file evaluation"
217        );
218
219        let n = self.features[0].len();
220        let vector = crate::dataset::featurize_file(path, n, &self.features)?;
221        let result = self.predict(&vector);
222        let features = vector.iter().map(|v| *v as u32).sum();
223        if result > 0.5 {
224            Ok(("Malicious", result, features))
225        } else {
226            Ok(("Benign", result, features))
227        }
228    }
229
230    /// Remove small weights (from regularization)
231    ///
232    /// TODO: Ensure the filtered weight comparison value is sane
233    pub fn reduce(&mut self) {
234        if self.trained {
235            let mut removed = vec![];
236            self.weights = self
237                .weights
238                .iter()
239                .enumerate()
240                .filter_map(|(index, w)| {
241                    if w.abs() > 0.01 {
242                        Some(w)
243                    } else {
244                        removed.push(index);
245                        None
246                    }
247                })
248                .copied()
249                .collect();
250
251            if !self.features.is_empty() {
252                removed.sort_unstable();
253                removed.reverse();
254                for index in removed {
255                    self.features.remove(index);
256                }
257            }
258        }
259    }
260
261    /// Set the features by adding to the struct
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if the number of weights and features doesn't match
266    pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
267        ensure!(
268            features.len() == self.weights.len(),
269            "Provided features length {} does not equal the number of model features length {}",
270            features.len(),
271            self.weights.len()
272        );
273        self.features = features;
274
275        Ok(())
276    }
277
278    /// Set the features by creating a new struct which has the features
279    ///
280    /// # Errors
281    ///
282    /// Returns an error if the number of weights and features doesn't match
283    pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
284        ensure!(
285            features.len() == self.weights.len(),
286            "Provided features length {} does not equal the number of model features length {}",
287            features.len(),
288            self.weights.len()
289        );
290
291        Ok(Self {
292            learning_rate: self.learning_rate,
293            bias: self.bias,
294            weights: self.weights,
295            l1: self.l1,
296            l2: self.l2,
297            trained: self.trained,
298            original_ngrams: self.original_ngrams,
299            features,
300        })
301    }
302}
303
304fn model_serialize_features<S>(x: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error>
305where
306    S: Serializer,
307{
308    if x.is_empty() {
309        return Err(Error::custom("N-gram features not set!"));
310    }
311
312    let features = x.iter().map(hex::encode).collect::<Vec<String>>();
313    s.collect_seq(features)
314}
315
316fn model_deserialize_features<'de, D>(deserializer: D) -> Result<Vec<Vec<u8>>, D::Error>
317where
318    D: Deserializer<'de>,
319{
320    use serde::de::Error;
321    let features = Vec::<String>::deserialize(deserializer)?;
322    if features.is_empty() {
323        return Err(Error::custom("N-gram features were empty!"));
324    }
325
326    features
327        .into_iter()
328        .map(hex::decode)
329        .collect::<Result<Vec<Vec<u8>>, _>>()
330        .map_err(Error::custom)
331}
332
/// Confusion Matrix
///
/// Holds the four outcome counts of an evaluation together with the data needed to compute
/// the area under the curve afterwards.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
    /// True positives: malicious samples that were predicted malicious
    pub true_p: u32,

    /// True negatives: benign samples that were predicted benign
    pub true_n: u32,

    /// False positives: benign samples that were predicted malicious
    pub false_p: u32,

    /// False negatives: malicious samples that were predicted benign
    pub false_n: u32,

    /// Original dataset reference; its labels are read when computing the AUC
    dataset: &'a Dataset,

    /// Model's raw outputs, paired positionally with the dataset's labels
    predictions: Vec<f32>,
}
354
355impl ConfusionMatrix<'_> {
356    /// Accuracy as correct vs total
357    #[inline]
358    #[must_use]
359    #[allow(clippy::cast_precision_loss)]
360    pub fn accuracy(&self) -> f32 {
361        (self.true_p + self.true_n) as f32 / self.total() as f32
362    }
363
364    /// Precision
365    #[must_use]
366    #[allow(clippy::cast_precision_loss)]
367    pub fn precision(&self) -> f32 {
368        self.true_p as f32 / (self.true_p + self.false_p) as f32
369    }
370
371    /// Recall
372    #[must_use]
373    #[allow(clippy::cast_precision_loss)]
374    pub fn recall(&self) -> f32 {
375        self.true_p as f32 / (self.true_p + self.false_n) as f32
376    }
377
378    /// F1 score
379    #[must_use]
380    #[allow(clippy::cast_precision_loss)]
381    pub fn f1(&self) -> f32 {
382        2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
383    }
384
385    /// Total items evaluated
386    #[inline]
387    #[must_use]
388    pub fn total(&self) -> u32 {
389        self.true_p + self.true_n + self.false_p + self.false_n
390    }
391
392    /// Calculate the area under the curve (AUC)
393    #[must_use]
394    #[allow(clippy::float_cmp)]
395    pub fn auc(&self) -> f32 {
396        // Adapted from the AUC implementation by Maciej Kula <maciej.kula@gmail.com>
397        // https://github.com/maciejkula/rustlearn/blob/2ac052559b04860c62d7ba34c563b57a02912e4d/src/metrics/ranking.rs
398        // The last commit was June 2020 so used here directly instead of importing the crate.
399        // This code was also Apache 2.0 licensed.
400
401        // vector of pairs (score, label) - the order is switched with respect to function arguments
402        let (mut true_positive_count, mut false_positive_count) = {
403            let mut pairs: Vec<_> = self
404                .predictions
405                .iter()
406                .copied()
407                .zip(self.dataset.labels.iter().copied())
408                .collect();
409
410            // Sort by scores in descending order
411            pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
412
413            let mut score_prev = f32::NAN;
414            // tp .. true positives, fp .. false positives
415            let (mut tp, mut fp) = (0.0f32, 0.0f32);
416            let (mut tps, mut fps) = (vec![], vec![]);
417            for (score, label) in pairs {
418                // `tp` and `fp` from the previous iteration are pushed onto the ROC curve only if
419                // the `score` changed. This avoids errors due to arbitrary classification of points with
420                // identical scores
421                if score != score_prev {
422                    tps.push(tp);
423                    fps.push(fp);
424                    score_prev = score;
425                }
426                tp += label;
427                fp += 1.0 - label;
428            }
429            // Push the final point corresponding to the (1,1) ROC coordinates
430            tps.push(tp);
431            fps.push(fp);
432            (tps, fps)
433        };
434
435        let true_positives = true_positive_count[true_positive_count.len() - 1];
436        let false_positives = false_positive_count[false_positive_count.len() - 1];
437
438        for (tp, fp) in true_positive_count
439            .iter_mut()
440            .zip(false_positive_count.iter_mut())
441        {
442            *tp /= true_positives;
443            *fp /= false_positives;
444        }
445
446        let mut prev_x = false_positive_count[0];
447        let mut prev_y = true_positive_count[0];
448        let mut integral = 0.0;
449
450        for (&x, &y) in false_positive_count
451            .iter()
452            .skip(1)
453            .zip(true_positive_count.iter().skip(1))
454        {
455            integral += (x - prev_x) * (prev_y + y) / 2.0;
456
457            prev_x = x;
458            prev_y = y;
459        }
460
461        integral
462    }
463}
464
465impl std::fmt::Display for ConfusionMatrix<'_> {
466    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467        writeln!(f, "Result \\ Actual | Malicious | Benign")?;
468        writeln!(
469            f,
470            " Malicious       |     {} |    {}",
471            self.true_p, self.false_p
472        )?;
473        writeln!(
474            f,
475            " Benign          |     {} |    {}",
476            self.false_n, self.true_n
477        )
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    /// Trains on the XOR fixture and expects more correct than incorrect classifications.
    #[test]
    fn xor() {
        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
        lr.train(100, &dataset).unwrap();

        let mut correct = 0u16;
        let mut incorrect = 0u16;

        for index in 0..dataset.data.len() {
            // `predict` does not modify the model, so one call per sample is enough.
            let predicted = lr.predict(&dataset.data[index]);
            println!(
                "Predicted: {}, Expected: {}",
                predicted, dataset.labels[index]
            );

            let expected = dataset.labels[index];
            let hit = if predicted >= 0.5 {
                expected >= 0.99
            } else {
                predicted < 0.5 && expected < 0.1
            };
            if hit {
                correct += 1;
            } else {
                incorrect += 1;
            }
        }

        println!("Correct: {correct}, Incorrect: {incorrect}");
        assert!(correct > incorrect);

        let result = lr.evaluate_dataset(&dataset).unwrap();
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());
    }

    /// Trains with regularization on the bogus fixture and expects `reduce` to drop at least
    /// one weight. The starting weights are random, so the outcome can vary between runs.
    #[test]
    fn reduction() {
        const BOGUS_LEN: usize = 6;

        let csv = include_str!("../testdata/bogus.csv");
        let dataset = Dataset::from_csv_string(csv, BOGUS_LEN).unwrap();

        let mut lr = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
        lr.set_features(dataset.features.clone()).unwrap();
        lr.train(20, &dataset).unwrap();

        println!("Weights before reduction: {:?}", lr.weights);
        println!("Features before reduction: {:?}", lr.features);
        lr.reduce();
        println!("Weights after reduction: {:?}", lr.weights);
        println!("Features after reduction: {:?}", lr.features);

        let remaining = lr.weights.len();
        println!("Weights from {BOGUS_LEN} to {}", remaining);
        assert!(
            remaining < BOGUS_LEN,
            "** If this assertion fails, re-run the test once or twice. **"
        );
    }

    /// Checks the AUC of a hand-built set of four scores against the expected value of 0.75.
    #[test]
    fn auc() {
        // Only the labels are read by `auc`, so the dataset needs no samples or features.
        let dataset = Dataset {
            data: vec![],
            labels: vec![1.0, 1.0, 0.0, 0.0],
            features: vec![],
        };

        let confusion_matrix = ConfusionMatrix {
            true_p: 0,
            true_n: 0,
            false_p: 0,
            false_n: 0,
            dataset: &dataset,
            predictions: vec![0.5, 0.2, 0.3, -1.0],
        };

        let auc = confusion_matrix.auc();
        println!("Auc: {auc:.2}, expected 0.75");
        assert!((0.73..0.78).contains(&auc));
    }
}