1use crate::ftype::FileType;
4use crate::{dataset::Dataset, Bytes};
5
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::path::Path;
9
10use anyhow::{ensure, Result};
11use rand::Rng;
12use rayon::prelude::*;
13use serde::{Deserialize, Serialize};
14
/// Logistic (sigmoid) activation: `1 / (1 + e^(-x))`.
///
/// Large positive inputs approach `1.0`, large negative inputs approach `0.0`,
/// and `0.0` maps to exactly `0.5`.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    1.0 / denominator
}
24
/// Binary logistic-regression classifier bundled with the feature table used
/// to turn files into input vectors.
///
/// The model is (de)serializable with serde; the `features` map goes through
/// the crate's hex-map helpers.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
    /// Step size multiplied into every weight and bias update during training.
    pub learning_rate: f32,

    /// Constant term added to the weighted input sum before the sigmoid.
    pub bias: f32,

    /// One coefficient per input column.
    pub weights: Vec<f32>,

    /// L1 regularization coefficient used by `train`.
    pub l1: f32,

    /// L2 regularization coefficient used by `train`.
    pub l2: f32,

    /// Maps each feature to an index; `set_features`/`with_features` assign
    /// each feature its position in the provided list.
    #[serde(
        serialize_with = "crate::serde::serialize_hex_map",
        deserialize_with = "crate::serde::deserialize_hex_map"
    )]
    pub features: HashMap<Bytes, usize>,

    /// Length of the first dataset feature, passed to `featurize_file`.
    /// NOTE(review): presumably the n-gram size; confirm against `dataset`.
    n: usize,

    /// Set to `true` when `train` finishes; `reduce` does nothing until then.
    pub trained: bool,

    /// Input size the model was constructed with (`new` stores `input_size` here).
    pub original_ngrams: u32,

    /// File type the model applies to: `FileType::NotSet` for a fresh model,
    /// copied from the dataset by `train`. `evaluate_file` rejects paths that
    /// do not match it.
    pub file_type: FileType,
}
62
63impl LogisticRegression {
64 #[must_use]
66 #[allow(clippy::cast_possible_truncation)]
67 pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
68 let mut rng = rand::rng();
69
70 Self {
71 learning_rate,
72 weights: (0..input_size)
73 .map(|_| rng.random_range(-1.0..1.0))
74 .collect(),
75 l1,
76 l2,
77 n: 0,
78 features: HashMap::new(),
79 trained: false,
80 bias: rng.random(),
81 original_ngrams: input_size as u32,
82 file_type: FileType::NotSet,
83 }
84 }
85
86 #[must_use]
94 pub fn new_from_dataset_and_train(
95 dataset: &Dataset,
96 epochs: u32,
97 learning_rate: f32,
98 l1: f32,
99 l2: f32,
100 ) -> (Self, f32) {
101 let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
102 model.n = dataset.features[0].len();
103 model.features = dataset
104 .features
105 .iter()
106 .map(|f| (f.clone(), 0))
107 .collect::<HashMap<_, _>>();
108 let result = model.train(epochs, dataset).unwrap();
109 model.file_type = dataset.ftype;
110 (model, result)
111 }
112
113 #[inline]
116 #[must_use]
117 pub fn predict(&self, input: &[f32]) -> f32 {
118 let linear_model = input
119 .iter()
120 .zip(&self.weights)
121 .map(|(x, w)| x * w)
122 .sum::<f32>()
123 + self.bias;
124 sigmoid(linear_model)
125 }
126
127 #[allow(clippy::cast_precision_loss)]
133 pub fn train(&mut self, epochs: u32, dataset: &Dataset) -> Result<f32, &'static str> {
134 if dataset.labels.is_empty() {
135 return Err("Dataset must have labels");
136 }
137
138 if !dataset.validate() {
139 return Err("Dataset didn't pass validity check!");
140 }
141
142 if dataset.data[0].len() != self.weights.len() {
143 return Err("Dataset feature length must equal the number of model weights");
144 }
145
146 let mut loss = 0.0;
147 #[allow(unused)]
148 for epoch in 0..epochs {
149 loss = 0.0;
150 for (input, output) in dataset.data.iter().zip(&dataset.labels) {
151 let prediction = self.predict(input);
152 let error = prediction - output;
153 let p = prediction.clamp(1e-8, 1.0 - 1e-8);
154 loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
155
156 self.weights
157 .par_iter_mut()
158 .enumerate()
159 .for_each(|(i, weight)| {
160 let l1r = self.l1 * (*weight / (weight.abs() + 1e-8));
161 let l2r = self.l2 * *weight;
162 *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
163 });
164 self.bias -= self.learning_rate * error;
165 }
166 loss /= self.weights.len() as f32;
167
168 #[cfg(debug_assertions)]
169 println!("Epoch: {epoch}, Log loss: {loss}");
170
171 if loss < 1e-6 {
172 break;
173 }
174 }
175
176 self.trained = true;
177 self.file_type = dataset.ftype;
178 self.n = dataset.features[0].len();
179 Ok(loss)
180 }
181
182 pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
188 ensure!(!dataset.is_empty(), "Dataset is empty");
189 ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
190 ensure!(
191 dataset.data[0].len() == self.weights.len(),
192 "Dataset length must equal the number of model weights"
193 );
194
195 let mut tp_ = 0;
196 let mut fp_ = 0;
197 let mut tn_ = 0;
198 let mut fn_ = 0;
199 let mut predictions = Vec::with_capacity(dataset.labels.len());
200
201 for index in 0..dataset.len() {
202 let prediction = self.predict(&dataset.data[index]);
203 if prediction >= 0.5 && dataset.labels[index] >= 0.9 {
204 tp_ += 1;
205 } else if prediction >= 0.5 && dataset.labels[index] < 0.5 {
206 fp_ += 1;
207 } else if prediction < 0.5 && dataset.labels[index] < 0.5 {
208 tn_ += 1;
209 } else {
210 fn_ += 1;
211 }
212 predictions.push(prediction);
213 }
214
215 Ok(ConfusionMatrix {
216 true_p: tp_,
217 true_n: tn_,
218 false_p: fp_,
219 false_n: fn_,
220 dataset,
221 predictions,
222 })
223 }
224
225 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
231 pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
232 ensure!(
233 !self.features.is_empty(),
234 "Features are required for file evaluation"
235 );
236
237 ensure!(
238 self.file_type.matches_path(&path)?,
239 "File type doesn't match model type"
240 );
241
242 let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
243 let result = self.predict(&vector);
244 let features = vector.iter().map(|v| *v as u32).sum();
245 if result > 0.5 {
246 Ok(("Malicious", result, features))
247 } else {
248 Ok(("Benign", result, features))
249 }
250 }
251
252 pub fn reduce(&mut self) {
256 if self.trained {
257 let mut removed = vec![];
258 self.weights = self
259 .weights
260 .iter()
261 .enumerate()
262 .filter_map(|(index, w)| {
263 if w.abs() > 0.01 {
264 Some(w)
265 } else {
266 removed.push(index);
267 None
268 }
269 })
270 .copied()
271 .collect();
272
273 if !self.features.is_empty() {
274 removed.sort_unstable();
275 removed.reverse();
276 let mut removed_features = Vec::with_capacity(removed.len());
277 for index in removed {
278 for (feat, feat_index) in &self.features {
279 if index == *feat_index {
280 removed_features.push(feat.clone());
281 }
282 }
283 }
284
285 for removed_feature in removed_features {
286 self.features.remove(&removed_feature);
287 }
288 }
289 }
290 }
291
292 pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
298 ensure!(
299 features.len() == self.weights.len(),
300 "Provided features length {} does not equal the number of model features length {}",
301 features.len(),
302 self.weights.len()
303 );
304 self.features = features
305 .into_iter()
306 .enumerate()
307 .map(|(f, i)| (i, f))
308 .collect::<HashMap<_, _>>();
309
310 Ok(())
311 }
312
313 pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
319 ensure!(
320 features.len() == self.weights.len(),
321 "Provided features length {} does not equal the number of model features length {}",
322 features.len(),
323 self.weights.len()
324 );
325
326 Ok(Self {
327 learning_rate: self.learning_rate,
328 bias: self.bias,
329 weights: self.weights,
330 l1: self.l1,
331 l2: self.l2,
332 trained: self.trained,
333 original_ngrams: self.original_ngrams,
334 file_type: self.file_type,
335 n: self.n,
336 features: features
337 .into_iter()
338 .enumerate()
339 .map(|(f, i)| (i, f))
340 .collect::<HashMap<_, _>>(),
341 })
342 }
343}
344
/// Outcome counts of evaluating a model on a dataset, plus the raw scores
/// needed to compute the area under the ROC curve.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
    /// Samples predicted positive whose label is positive.
    pub true_p: u32,

    /// Samples predicted negative whose label is negative.
    pub true_n: u32,

    /// Samples predicted positive whose label is negative.
    pub false_p: u32,

    /// Samples predicted negative whose label is positive.
    pub false_n: u32,

    /// Dataset the predictions were made on; `auc` reads its labels.
    dataset: &'a Dataset,

    /// Model output for each sample, in dataset order.
    predictions: Vec<f32>,
}
366
367impl ConfusionMatrix<'_> {
368 #[inline]
370 #[must_use]
371 #[allow(clippy::cast_precision_loss)]
372 pub fn accuracy(&self) -> f32 {
373 (self.true_p + self.true_n) as f32 / self.total() as f32
374 }
375
376 #[must_use]
378 #[allow(clippy::cast_precision_loss)]
379 pub fn precision(&self) -> f32 {
380 self.true_p as f32 / (self.true_p + self.false_p) as f32
381 }
382
383 #[must_use]
385 #[allow(clippy::cast_precision_loss)]
386 pub fn recall(&self) -> f32 {
387 self.true_p as f32 / (self.true_p + self.false_n) as f32
388 }
389
390 #[must_use]
392 #[allow(clippy::cast_precision_loss)]
393 pub fn f1(&self) -> f32 {
394 2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
395 }
396
397 #[inline]
399 #[must_use]
400 pub fn total(&self) -> u32 {
401 self.true_p + self.true_n + self.false_p + self.false_n
402 }
403
404 #[must_use]
406 #[allow(clippy::float_cmp)]
407 pub fn auc(&self) -> f32 {
408 let (mut true_positive_count, mut false_positive_count) = {
415 let mut pairs: Vec<_> = self
416 .predictions
417 .iter()
418 .copied()
419 .zip(self.dataset.labels.iter().copied())
420 .collect();
421
422 pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
424
425 let mut score_prev = f32::NAN;
426 let (mut tp, mut fp) = (0.0f32, 0.0f32);
428 let (mut tps, mut fps) = (vec![], vec![]);
429 for (score, label) in pairs {
430 if score != score_prev {
434 tps.push(tp);
435 fps.push(fp);
436 score_prev = score;
437 }
438 tp += label;
439 fp += 1.0 - label;
440 }
441 tps.push(tp);
443 fps.push(fp);
444 (tps, fps)
445 };
446
447 let true_positives = true_positive_count[true_positive_count.len() - 1];
448 let false_positives = false_positive_count[false_positive_count.len() - 1];
449
450 for (tp, fp) in true_positive_count
451 .iter_mut()
452 .zip(false_positive_count.iter_mut())
453 {
454 *tp /= true_positives;
455 *fp /= false_positives;
456 }
457
458 let mut prev_x = false_positive_count[0];
459 let mut prev_y = true_positive_count[0];
460 let mut integral = 0.0;
461
462 for (&x, &y) in false_positive_count
463 .iter()
464 .skip(1)
465 .zip(true_positive_count.iter().skip(1))
466 {
467 integral += (x - prev_x) * (prev_y + y) / 2.0;
468
469 prev_x = x;
470 prev_y = y;
471 }
472
473 integral
474 }
475}
476
477impl std::fmt::Display for ConfusionMatrix<'_> {
478 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
479 writeln!(f, "Result \\ Actual | Malicious | Benign")?;
480 writeln!(
481 f,
482 " Malicious | {} | {}",
483 self.true_p, self.false_p
484 )?;
485 writeln!(
486 f,
487 " Benign | {} | {}",
488 self.false_n, self.true_n
489 )
490 }
491}
492
493#[cfg(test)]
494mod tests {
495 use super::*;
496 use crate::dataset::Dataset;
497
498 #[test]
499 fn xor() {
500 let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
501 let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
502 lr.train(100, &dataset).unwrap();
503
504 let mut correct = 0u16;
505 let mut incorrect = 0u16;
506
507 for index in 0..dataset.data.len() {
508 println!(
509 "Predicted: {}, Expected: {}",
510 lr.predict(&dataset.data[index]),
511 dataset.labels[index]
512 );
513 if (lr.predict(&dataset.data[index]) >= 0.5 && dataset.labels[index] >= 0.99)
514 || (lr.predict(&dataset.data[index]) < 0.5 && dataset.labels[index] < 0.1)
515 {
516 correct += 1;
517 } else {
518 incorrect += 1;
519 }
520 }
521
522 println!("Correct: {correct}, Incorrect: {incorrect}");
523 assert!(correct > incorrect);
524
525 let result = lr.evaluate_dataset(&dataset).unwrap();
526 println!("{result}");
527 println!("Accuracy: {:.2}", result.accuracy());
528 println!("Precision: {:.2}", result.precision());
529 println!("Recall: {:.2}", result.recall());
530 println!("F1: {:.2}", result.f1());
531 println!("Auc: {:.2}", result.auc());
532 }
533
534 #[test]
535 fn reduction() {
536 const BOGUS_LEN: usize = 6;
537
538 let dataset =
539 Dataset::from_csv_string(include_str!("../testdata/bogus.csv"), BOGUS_LEN).unwrap();
540 let mut lr = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
541 lr.set_features(dataset.features.clone()).unwrap();
542 lr.train(20, &dataset).unwrap();
543 println!("Weights before reduction: {:?}", lr.weights);
544 println!("Features before reduction: {:?}", lr.features);
545 lr.reduce();
546 println!("Weights after reduction: {:?}", lr.weights);
547 println!("Features after reduction: {:?}", lr.features);
548 println!("Weights from {BOGUS_LEN} to {}", lr.weights.len());
549 assert!(
550 lr.weights.len() < BOGUS_LEN,
551 "** If this assertion fails, re-run the test once or twice. **"
552 );
553 }
554
555 #[test]
556 fn auc() {
557 let y_true = vec![1.0, 1.0, 0.0, 0.0];
558 let y_hat = vec![0.5, 0.2, 0.3, -1.0];
559
560 let dataset = Dataset {
561 data: vec![],
562 labels: y_true,
563 features: vec![],
564 ftype: FileType::DOCFILE, };
566
567 let confusion_matrix = ConfusionMatrix {
568 true_p: 0,
569 true_n: 0,
570 false_p: 0,
571 false_n: 0,
572 dataset: &dataset,
573 predictions: y_hat,
574 };
575
576 let auc = confusion_matrix.auc();
577 println!("Auc: {auc:.2}, expected 0.75");
578 assert!((0.73..0.78).contains(&auc));
579 }
580}