malware-modeler 0.0.2

Train logistic regression models for benign vs. malicious files based on byte n-grams and publish research.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
// SPDX-License-Identifier: Apache-2.0

use crate::ftype::FileType;
use crate::{dataset::Dataset, Bytes};

use std::cmp::Ordering;
use std::collections::HashMap;
use std::path::Path;

use anyhow::{ensure, Result};
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};

// Adapted from https://github.com/ibnz36/lrclassifier/blob/06d136445028653af5d2061e002111ef11b14277/src/lib.rs
// Accessed 01 November 2025

/// Logistic (sigmoid) function `σ(x) = 1 / (1 + e^(-x))`.
///
/// Maps any real input into `[0, 1]` (in `f32`, large-magnitude inputs saturate to exactly
/// `0.0` or `1.0`).
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}

/// Binary classifier (benign vs. malicious) using logistic regression over byte n-gram features.
///
/// A prediction is `sigmoid(weights · input + bias)`. The model is serializable with serde; the
/// `features` map is (de)serialized through the crate's hex-map helpers.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
    /// Learning rate (step size) used by gradient descent during training
    pub learning_rate: f32,

    /// Bias (intercept) term added to the weighted sum before the sigmoid
    pub bias: f32,

    /// Model's weights, one per input column
    pub weights: Vec<f32>,

    /// L1 LASSO regularization strength
    pub l1: f32,

    /// L2 Ridge regularization strength
    pub l2: f32,

    /// N-grams used to train the model, mapped to the index of the weight (input column) each
    /// one corresponds to
    #[serde(
        serialize_with = "crate::serde::serialize_hex_map",
        deserialize_with = "crate::serde::deserialize_hex_map"
    )]
    pub features: HashMap<Bytes, usize>,

    /// N-gram size in bytes (0 until features or a training dataset provide it)
    n: usize,

    /// Whether the model has been trained
    pub trained: bool,

    /// Number of n-grams (weights) the model was created with, before any reduction
    pub original_ngrams: u32,

    /// The type of file this model is trained on
    pub file_type: FileType,
}

impl LogisticRegression {
    /// Probability at or above which a sample is classified as malicious. Shared by
    /// [`LogisticRegression::evaluate_dataset`] and [`LogisticRegression::evaluate_file`] so both
    /// classify identically.
    const DECISION_THRESHOLD: f32 = 0.5;

    /// Absolute weight magnitude at or below which [`LogisticRegression::reduce`] drops a weight.
    const DEFAULT_REDUCTION_THRESHOLD: f32 = 0.01;

    /// Clamp bound applied to predictions before taking logarithms in the log-loss.
    ///
    /// It must remain distinct from both 0 and 1 in `f32`: `1.0 - 1e-8` rounds to exactly `1.0`
    /// in single precision, which would let `ln(1 - p)` evaluate `ln(0)` and poison the loss.
    const LOSS_EPSILON: f32 = 1e-7;

    /// Creates an untrained model with `input_size` weights drawn uniformly from `[-1, 1)` and a
    /// random bias (`rng.random()`).
    ///
    /// The new model has no features, an n-gram size of 0 and `FileType::NotSet`.
    /// `original_ngrams` records `input_size`, saturating at `u32::MAX`.
    #[must_use]
    pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
        let mut rng = rand::rng();
        // Weights are drawn before the bias, preserving the RNG draw order.
        let weights: Vec<f32> = (0..input_size)
            .map(|_| rng.random_range(-1.0..1.0))
            .collect();
        let bias: f32 = rng.random();

        Self {
            learning_rate,
            weights,
            l1,
            l2,
            n: 0,
            features: HashMap::new(),
            trained: false,
            bias,
            original_ngrams: u32::try_from(input_size).unwrap_or(u32::MAX),
            file_type: FileType::NotSet,
        }
    }

    /// Creates a model sized for `dataset`, records its features, and trains it.
    ///
    /// The model gets one weight per column of the dataset's rows. Feature `i` of
    /// `dataset.features` is mapped to weight/column `i`. Returns the trained model and the mean
    /// log-loss from the last epoch.
    ///
    /// # Panics
    ///
    /// Panics if [`LogisticRegression::train`] rejects the dataset: it has no labels, it fails
    /// its validity check, or it has no data rows. The weight count is derived from the data,
    /// so the length check cannot fail.
    #[must_use]
    pub fn new_from_dataset_and_train(
        dataset: &Dataset,
        epochs: u32,
        learning_rate: f32,
        l1: f32,
        l2: f32,
    ) -> (Self, f32) {
        // One weight per input column; this is exactly what `train` checks against.
        let input_size = dataset.data.first().map_or(0, |row| row.len());
        let mut model = Self::new(input_size, learning_rate, l1, l2);
        // Each feature maps to its own column index (previously all mapped to 0).
        model.features = dataset
            .features
            .iter()
            .enumerate()
            .map(|(index, feature)| (feature.clone(), index))
            .collect();
        // `train` also records the n-gram size and file type on success.
        let loss = model
            .train(epochs, dataset)
            .unwrap_or_else(|err| panic!("dataset rejected by LogisticRegression::train: {err}"));
        (model, loss)
    }

    /// Predicts the probability that `input` is malicious: `sigmoid(weights · input + bias)`.
    ///
    /// Inputs are paired with weights positionally. If the lengths differ, the surplus elements
    /// are silently ignored (no error is raised), so callers should pass exactly
    /// `self.weights.len()` values.
    #[inline]
    #[must_use]
    pub fn predict(&self, input: &[f32]) -> f32 {
        let linear_model = input
            .iter()
            .zip(&self.weights)
            .map(|(x, w)| x * w)
            .sum::<f32>()
            + self.bias;
        sigmoid(linear_model)
    }

    /// Trains the classifier for up to `epochs` epochs of per-sample gradient descent.
    ///
    /// Each sample updates every weight with the log-loss gradient plus the L1 term (a smooth
    /// sign of the weight) and the L2 term, then updates the bias. Training stops early once the
    /// mean log-loss of an epoch drops below `1e-6`. On success the model is marked trained and
    /// records the dataset's file type and n-gram size. Returns the mean per-sample log-loss of
    /// the last epoch run (0.0 if `epochs` is 0).
    ///
    /// # Errors
    ///
    /// Returns an error if labels are missing, if the dataset fails its validity check, if it has
    /// no data rows, or if its rows don't have one value per model weight.
    #[allow(clippy::cast_precision_loss)]
    pub fn train(&mut self, epochs: u32, dataset: &Dataset) -> Result<f32, &'static str> {
        if dataset.labels.is_empty() {
            return Err("Dataset must have labels");
        }

        if !dataset.validate() {
            return Err("Dataset didn't pass validity check!");
        }

        let Some(first_row) = dataset.data.first() else {
            return Err("Dataset must have data");
        };
        if first_row.len() != self.weights.len() {
            return Err("Dataset feature length must equal the number of model weights");
        }

        // Copy hyper-parameters out of `self` so the parallel closure only borrows locals.
        let learning_rate = self.learning_rate;
        let l1 = self.l1;
        let l2 = self.l2;
        // Number of (row, label) pairs visited per epoch; both sides are non-empty here.
        let sample_count = dataset.data.len().min(dataset.labels.len()) as f32;

        let mut loss = 0.0;
        // `epoch` is only read by the debug-build progress print.
        #[allow(unused)]
        for epoch in 0..epochs {
            loss = 0.0;
            for (input, &label) in dataset.data.iter().zip(&dataset.labels) {
                let prediction = self.predict(input);
                let error = prediction - label;
                let p = prediction.clamp(Self::LOSS_EPSILON, 1.0 - Self::LOSS_EPSILON);
                loss += -label * p.ln() - (1.0 - label) * (1.0 - p).ln();

                self.weights
                    .par_iter_mut()
                    .zip(input.par_iter())
                    .for_each(|(weight, x)| {
                        let l1r = l1 * (*weight / (weight.abs() + 1e-8));
                        let l2r = l2 * *weight;
                        *weight -= learning_rate * (error * x + l1r + l2r);
                    });
                self.bias -= learning_rate * error;
            }
            // Mean log-loss per sample (previously divided by the number of weights).
            loss /= sample_count;

            #[cfg(debug_assertions)]
            println!("Epoch: {epoch}, Log loss: {loss}");

            if loss < 1e-6 {
                break;
            }
        }

        self.trained = true;
        self.file_type = dataset.ftype;
        if let Some(first_feature) = dataset.features.first() {
            self.n = first_feature.len();
        }
        Ok(loss)
    }

    /// Evaluates every labelled sample in `dataset` and tallies a confusion matrix.
    ///
    /// A prediction or label at or above 0.5 counts as malicious (positive). The returned matrix
    /// also keeps every raw prediction, in row order, for AUC computation.
    ///
    /// # Errors
    ///
    /// Returns an error if the dataset or its labels are empty, if the label count differs from
    /// the row count, or if the rows don't match the model weight length.
    pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
        ensure!(!dataset.is_empty(), "Dataset is empty");
        ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
        ensure!(
            dataset.labels.len() == dataset.data.len(),
            "Dataset has {} samples but {} labels",
            dataset.data.len(),
            dataset.labels.len()
        );
        ensure!(
            dataset.data.first().map(|row| row.len()) == Some(self.weights.len()),
            "Dataset feature length must equal the number of model weights"
        );

        let mut tp_ = 0;
        let mut fp_ = 0;
        let mut tn_ = 0;
        let mut fn_ = 0;
        let mut predictions = Vec::with_capacity(dataset.labels.len());

        for (row, &label) in dataset.data.iter().zip(&dataset.labels) {
            let prediction = self.predict(row);
            // One threshold for both sides, so every sample lands in exactly the right cell.
            let predicted_malicious = prediction >= Self::DECISION_THRESHOLD;
            let actually_malicious = label >= Self::DECISION_THRESHOLD;
            match (predicted_malicious, actually_malicious) {
                (true, true) => tp_ += 1,
                (true, false) => fp_ += 1,
                (false, false) => tn_ += 1,
                (false, true) => fn_ += 1,
            }
            predictions.push(prediction);
        }

        Ok(ConfusionMatrix {
            true_p: tp_,
            true_n: tn_,
            false_p: fp_,
            false_n: fn_,
            dataset,
            predictions,
        })
    }

    /// Classifies a single file.
    ///
    /// Returns the verdict (`"Malicious"` when the prediction is at least 0.5, matching
    /// [`LogisticRegression::evaluate_dataset`], otherwise `"Benign"`), the raw prediction, and
    /// the sum of the file's feature vector truncated to `u32`.
    ///
    /// # Errors
    ///
    /// Errors if the model has no features, if the file's type doesn't match the model's type
    /// (or can't be determined), or if the sample file can't be read or featurized.
    #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
    pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
        ensure!(
            !self.features.is_empty(),
            "Features are required for file evaluation"
        );

        ensure!(
            self.file_type.matches_path(&path)?,
            "File type doesn't match model type"
        );

        let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
        let result = self.predict(&vector);
        let features = vector.iter().map(|v| *v as u32).sum();
        if result >= Self::DECISION_THRESHOLD {
            Ok(("Malicious", result, features))
        } else {
            Ok(("Benign", result, features))
        }
    }

    /// Removes small weights (driven towards zero by regularization) using the default
    /// threshold of 0.01. See [`LogisticRegression::reduce_with_threshold`].
    ///
    /// TODO: Ensure the filtered weight comparison value is sane
    pub fn reduce(&mut self) {
        self.reduce_with_threshold(Self::DEFAULT_REDUCTION_THRESHOLD);
    }

    /// Removes every weight whose absolute value is not greater than `threshold`, together with
    /// the feature mapped to it.
    ///
    /// The surviving features are re-indexed so that each one still points at its own (now
    /// shifted) weight. Features whose index was outside the weight vector are dropped, since
    /// they could never be used consistently. Does nothing if the model is untrained.
    pub fn reduce_with_threshold(&mut self, threshold: f32) {
        if !self.trained {
            return;
        }

        // remap[old_index] = Some(new_index) for surviving weights, None for removed ones.
        let mut remap = Vec::with_capacity(self.weights.len());
        let mut kept = 0usize;
        for weight in &self.weights {
            if weight.abs() > threshold {
                remap.push(Some(kept));
                kept += 1;
            } else {
                remap.push(None);
            }
        }

        // Same predicate as above, so `remap` and the retained weights agree.
        self.weights.retain(|weight| weight.abs() > threshold);

        self.features
            .retain(|_, index| match remap.get(*index).copied().flatten() {
                Some(new_index) => {
                    *index = new_index;
                    true
                }
                None => false,
            });
    }

    /// Sets the features in place. Feature `i` is mapped to weight `i`, and the n-gram size is
    /// taken from the first feature (if any).
    ///
    /// # Errors
    ///
    /// Returns an error if the number of weights and features doesn't match
    pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
        ensure!(
            features.len() == self.weights.len(),
            "Provided features length {} does not equal the number of model features length {}",
            features.len(),
            self.weights.len()
        );
        if let Some(first_feature) = features.first() {
            self.n = first_feature.len();
        }
        self.features = features
            .into_iter()
            .enumerate()
            .map(|(index, feature)| (feature, index))
            .collect();

        Ok(())
    }

    /// Builder-style variant of [`LogisticRegression::set_features`]: consumes the model and
    /// returns it with the given features.
    ///
    /// # Errors
    ///
    /// Returns an error if the number of weights and features doesn't match
    pub fn with_features(mut self, features: Vec<Bytes>) -> Result<Self> {
        self.set_features(features)?;
        Ok(self)
    }
}

/// Confusion matrix produced by [`LogisticRegression::evaluate_dataset`], plus the raw
/// predictions needed for ranking metrics such as AUC.
///
/// Borrows the evaluated dataset so its labels can be paired with `predictions`.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
    /// True positives (predicted malicious, labelled malicious)
    pub true_p: u32,

    /// True negatives (predicted benign, labelled benign)
    pub true_n: u32,

    /// False positives (predicted malicious, labelled benign)
    pub false_p: u32,

    /// False negatives (predicted benign, labelled malicious)
    pub false_n: u32,

    /// Original dataset reference; its labels are zipped with `predictions` when computing AUC
    dataset: &'a Dataset,

    /// Model's outputs, one per evaluated sample; the AUC code pairs them positionally with
    /// the dataset's labels
    predictions: Vec<f32>,
}

impl ConfusionMatrix<'_> {
    /// Returns `numerator / denominator` as `f32`, or `0.0` when the denominator is zero.
    ///
    /// This is the usual reporting convention for undefined ratios (e.g. precision with no
    /// positive predictions) and keeps NaN out of printed metrics.
    #[inline]
    #[allow(clippy::cast_precision_loss)]
    fn ratio(numerator: u32, denominator: u32) -> f32 {
        if denominator == 0 {
            0.0
        } else {
            numerator as f32 / denominator as f32
        }
    }

    /// Accuracy: correct predictions over all predictions (0.0 if nothing was evaluated).
    #[inline]
    #[must_use]
    pub fn accuracy(&self) -> f32 {
        Self::ratio(self.true_p + self.true_n, self.total())
    }

    /// Precision: `TP / (TP + FP)`, or 0.0 when there were no positive predictions.
    #[must_use]
    pub fn precision(&self) -> f32 {
        Self::ratio(self.true_p, self.true_p + self.false_p)
    }

    /// Recall: `TP / (TP + FN)`, or 0.0 when there were no positive labels.
    #[must_use]
    pub fn recall(&self) -> f32 {
        Self::ratio(self.true_p, self.true_p + self.false_n)
    }

    /// F1 score: harmonic mean of precision and recall, or 0.0 when both are zero.
    #[must_use]
    pub fn f1(&self) -> f32 {
        let precision = self.precision();
        let recall = self.recall();
        let denominator = precision + recall;
        if denominator > 0.0 {
            2.0 * (precision * recall) / denominator
        } else {
            0.0
        }
    }

    /// Total items evaluated
    #[inline]
    #[must_use]
    pub fn total(&self) -> u32 {
        self.true_p + self.true_n + self.false_p + self.false_n
    }

    /// Calculates the area under the ROC curve (AUC) from the stored predictions and the
    /// dataset's labels, using trapezoidal integration.
    ///
    /// Labels are accumulated as weights (`label` towards positives, `1 - label` towards
    /// negatives). Returns `f32::NAN` when the curve is undefined, i.e. there are no predictions
    /// or the labels contain no positive or no negative weight.
    #[must_use]
    #[allow(clippy::float_cmp)]
    pub fn auc(&self) -> f32 {
        // Adapted from the AUC implementation by Maciej Kula <maciej.kula@gmail.com>
        // https://github.com/maciejkula/rustlearn/blob/2ac052559b04860c62d7ba34c563b57a02912e4d/src/metrics/ranking.rs
        // The last commit was June 2020 so used here directly instead of importing the crate.
        // This code was also Apache 2.0 licensed.

        // Pairs of (score, label), sorted by score in descending order.
        let mut pairs: Vec<(f32, f32)> = self
            .predictions
            .iter()
            .copied()
            .zip(self.dataset.labels.iter().copied())
            .collect();
        pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));

        // Cumulative true/false positive counts at each distinct score threshold.
        let mut score_prev = f32::NAN;
        let (mut tp, mut fp) = (0.0f32, 0.0f32);
        let (mut tps, mut fps) = (vec![], vec![]);
        for (score, label) in pairs {
            // `tp` and `fp` from the previous iteration are pushed onto the ROC curve only if
            // the `score` changed. This avoids errors due to arbitrary classification of points
            // with identical scores.
            if score != score_prev {
                tps.push(tp);
                fps.push(fp);
                score_prev = score;
            }
            tp += label;
            fp += 1.0 - label;
        }
        // Push the final point corresponding to the (1,1) ROC coordinates
        tps.push(tp);
        fps.push(fp);

        // Without any positive or any negative weight the ROC curve cannot be normalised.
        if tp <= 0.0 || fp <= 0.0 || tps.len() < 2 {
            return f32::NAN;
        }

        // Normalise to rates and integrate with the trapezoidal rule over consecutive points.
        let mut prev_x = fps[0] / fp;
        let mut prev_y = tps[0] / tp;
        let mut integral = 0.0;
        for (&raw_x, &raw_y) in fps.iter().zip(&tps).skip(1) {
            let x = raw_x / fp;
            let y = raw_y / tp;
            integral += (x - prev_x) * (prev_y + y) / 2.0;
            prev_x = x;
            prev_y = y;
        }

        integral
    }
}

impl std::fmt::Display for ConfusionMatrix<'_> {
    /// Renders the matrix as a small text table: rows are the model's verdicts, columns are the
    /// actual labels.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            true_p,
            true_n,
            false_p,
            false_n,
            ..
        } = self;

        writeln!(f, "Result \\ Actual | Malicious | Benign")?;
        writeln!(f, " Malicious       |     {} |    {}", true_p, false_p)?;
        writeln!(f, " Benign          |     {} |    {}", false_n, true_n)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    /// Trains on the XOR fixture, checks that more samples are classified correctly than
    /// incorrectly, and prints the evaluation metrics.
    #[test]
    fn xor() {
        let csv = include_str!("../testdata/xor.csv");
        let dataset = Dataset::from_csv_string(csv, 6).unwrap();
        let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
        lr.train(100, &dataset).unwrap();

        let mut correct = 0u16;
        let mut incorrect = 0u16;

        for (input, &expected) in dataset.data.iter().zip(&dataset.labels) {
            // Prediction is deterministic for a fixed model, so compute it once per sample.
            let predicted = lr.predict(input);
            println!("Predicted: {}, Expected: {}", predicted, expected);

            let positive_hit = predicted >= 0.5 && expected >= 0.99;
            let negative_hit = predicted < 0.5 && expected < 0.1;
            if positive_hit || negative_hit {
                correct += 1;
            } else {
                incorrect += 1;
            }
        }

        println!("Correct: {correct}, Incorrect: {incorrect}");
        assert!(correct > incorrect);

        let result = lr.evaluate_dataset(&dataset).unwrap();
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());
    }

    /// Trains with L1/L2 regularization on the bogus fixture and checks that `reduce` drops at
    /// least one weight. Initial weights are random, so this can occasionally fail.
    #[test]
    fn reduction() {
        const BOGUS_LEN: usize = 6;

        let csv = include_str!("../testdata/bogus.csv");
        let dataset = Dataset::from_csv_string(csv, BOGUS_LEN).unwrap();
        let mut lr = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
        lr.set_features(dataset.features.clone()).unwrap();
        lr.train(20, &dataset).unwrap();

        println!("Weights before reduction: {:?}", lr.weights);
        println!("Features before reduction: {:?}", lr.features);
        lr.reduce();
        println!("Weights after reduction: {:?}", lr.weights);
        println!("Features after reduction: {:?}", lr.features);
        println!("Weights from {BOGUS_LEN} to {}", lr.weights.len());

        assert!(
            lr.weights.len() < BOGUS_LEN,
            "** If this assertion fails, re-run the test once or twice. **"
        );
    }

    /// Four scored samples whose ranking yields an ROC AUC of 0.75.
    #[test]
    fn auc() {
        let dataset = Dataset {
            data: vec![],
            labels: vec![1.0, 1.0, 0.0, 0.0],
            features: vec![],
            ftype: FileType::DOCFILE, // Doesn't matter for this test
        };

        let confusion_matrix = ConfusionMatrix {
            true_p: 0,
            true_n: 0,
            false_p: 0,
            false_n: 0,
            dataset: &dataset,
            predictions: vec![0.5, 0.2, 0.3, -1.0],
        };

        let auc = confusion_matrix.auc();
        println!("Auc: {auc:.2}, expected 0.75");
        assert!((0.73..0.78).contains(&auc));
    }
}