1use crate::ftype::FileType;
4use crate::{Bytes, dataset::Dataset};
5
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::path::Path;
9
10use anyhow::{Result, anyhow, ensure};
11use rand::RngExt;
12use rayon::prelude::*;
13use serde::{Deserialize, Serialize};
14
15#[inline]
20#[must_use]
21fn sigmoid(x: f32) -> f32 {
22 1.0 / (1.0 + (-x).exp())
23}
24
25#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
27pub struct LogisticRegression {
28 pub learning_rate: f32,
30
31 pub bias: f32,
33
34 pub weights: Vec<f32>,
36
37 pub l1: f32,
39
40 pub l2: f32,
42
43 #[serde(
45 serialize_with = "crate::serde::serialize_hex_map",
46 deserialize_with = "crate::serde::deserialize_hex_map"
47 )]
48 pub features: HashMap<Bytes, usize>,
49
50 n: usize,
52
53 pub trained: bool,
55
56 pub original_ngrams: u32,
58
59 pub file_type: FileType,
61
62 pub train_performance: PerformanceStats,
64
65 #[serde(default)]
67 pub test_performance: Option<PerformanceStats>,
68}
69
70impl LogisticRegression {
71 #[must_use]
73 #[allow(clippy::cast_possible_truncation)]
74 pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
75 let mut rng = rand::rng();
76
77 Self {
78 learning_rate,
79 weights: (0..input_size)
80 .map(|_| rng.random_range(-1.0..1.0))
81 .collect(),
82 l1,
83 l2,
84 n: 0,
85 features: HashMap::new(),
86 trained: false,
87 bias: rng.random(),
88 original_ngrams: input_size as u32,
89 file_type: FileType::NotSet,
90 train_performance: PerformanceStats::default(),
91 test_performance: None,
92 }
93 }
94
95 #[must_use]
103 pub fn new_from_dataset_and_train(
104 dataset: &mut Dataset,
105 test: Option<&Dataset>,
106 epochs: u32,
107 learning_rate: f32,
108 l1: f32,
109 l2: f32,
110 ) -> (Self, f32) {
111 let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
112 model.n = dataset.features[0].len();
113 model.features = dataset
114 .features
115 .iter()
116 .map(|f| (f.clone(), 0))
117 .collect::<HashMap<_, _>>();
118 let result = model.train(epochs, dataset).unwrap();
119 let train_performance = model.evaluate_dataset(dataset).unwrap();
120 model.file_type = dataset.ftype;
121 model.train_performance = train_performance.into();
122
123 if let Some(test) = test
124 && dataset.ftype == test.ftype
125 && model.n == test.data[0].len()
126 && let Ok(test_performance) = model.evaluate_dataset(test)
127 {
128 model.test_performance = Some(test_performance.into());
129 }
130
131 (model, result)
132 }
133
134 #[inline]
137 #[must_use]
138 pub fn predict(&self, input: &[f32]) -> f32 {
139 let linear_model = input
140 .iter()
141 .zip(&self.weights)
142 .map(|(x, w)| x * w)
143 .sum::<f32>()
144 + self.bias;
145 sigmoid(linear_model)
146 }
147
148 #[allow(clippy::cast_precision_loss)]
156 pub fn train(&mut self, epochs: u32, dataset: &mut Dataset) -> Result<f32, &'static str> {
157 if dataset.labels.is_empty() {
158 return Err("Dataset must have labels");
159 }
160
161 if !dataset.validate() {
162 return Err("Dataset didn't pass validity check!");
163 }
164
165 if dataset.data[0].len() != self.weights.len() {
166 return Err("Dataset feature length must equal the number of model weights");
167 }
168
169 let mut loss = 0.0;
170 #[allow(unused)]
171 for epoch in 0..epochs {
172 loss = 0.0;
173 dataset.shuffle();
174 for (input, output) in dataset.data.iter().zip(&dataset.labels) {
175 let output = f32::from(*output);
176 let prediction = self.predict(input);
177 let error = prediction - output;
178 let p = prediction.clamp(1e-8, 1.0 - 1e-8);
179 loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
180
181 self.weights
182 .par_iter_mut()
183 .enumerate()
184 .for_each(|(i, weight)| {
185 let l1r = self.l1 * weight.signum();
186 let l2r = self.l2 * *weight;
187 *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
188 });
189 self.bias -= self.learning_rate * error;
190 }
191 loss /= self.weights.len() as f32;
192
193 #[cfg(debug_assertions)]
194 println!("Epoch: {epoch}, Log loss: {loss}");
195
196 if loss < 1e-6 {
197 break;
198 }
199 }
200
201 self.trained = true;
202 self.file_type = dataset.ftype;
203 self.n = dataset.features[0].len();
204 if let Ok(confusion_matrix) = self.evaluate_dataset(dataset) {
205 self.train_performance = confusion_matrix.into();
206 }
207 Ok(loss)
208 }
209
210 pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
216 ensure!(!dataset.is_empty(), "Dataset is empty");
217 ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
218 ensure!(
219 dataset.data[0].len() == self.weights.len(),
220 "Dataset length must equal the number of model weights"
221 );
222 ensure!(
223 self.file_type == dataset.ftype,
224 "Dataset file type must match model file type"
225 );
226
227 let mut tp_ = 0;
228 let mut fp_ = 0;
229 let mut tn_ = 0;
230 let mut fn_ = 0;
231 let mut predictions = Vec::with_capacity(dataset.labels.len());
232
233 for index in 0..dataset.len() {
234 let prediction = self.predict(&dataset.data[index]);
235 if prediction >= 0.5 && dataset.labels[index] >= 1 {
236 tp_ += 1;
237 } else if prediction >= 0.5 && dataset.labels[index] < 1 {
238 fp_ += 1;
239 } else if prediction < 0.5 && dataset.labels[index] < 1 {
240 tn_ += 1;
241 } else {
242 fn_ += 1;
243 }
244 predictions.push(prediction);
245 }
246
247 Ok(ConfusionMatrix {
248 true_p: tp_,
249 true_n: tn_,
250 false_p: fp_,
251 false_n: fn_,
252 dataset,
253 predictions,
254 })
255 }
256
257 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
263 pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
264 ensure!(
265 !self.features.is_empty(),
266 "Features are required for file evaluation"
267 );
268
269 ensure!(
270 self.file_type.matches_path(&path)?,
271 "File type doesn't match model type"
272 );
273
274 let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
275 let result = self.predict(&vector);
276 let features = vector.iter().map(|v| *v as u32).sum();
277 if result > 0.5 {
278 Ok(("Malicious", result, features))
279 } else {
280 Ok(("Benign", result, features))
281 }
282 }
283
284 #[allow(clippy::cast_precision_loss)]
290 pub fn reduce(&mut self, dataset: &Dataset) -> Result<usize> {
291 const THRESHOLD: f32 = 0.01;
292
293 ensure!(
294 self.trained,
295 "Model must be trained before reducing weights"
296 );
297 ensure!(
298 self.file_type == dataset.ftype,
299 "Dataset file type must match model file type"
300 );
301
302 let mut average_features = Vec::with_capacity(self.weights.len());
303 for (index, model_weight) in self.weights.iter().enumerate() {
304 let col_iter = dataset
305 .column_iter(index)
306 .ok_or_else(|| anyhow!("Column index out of bounds"))?;
307 average_features.push((
308 index,
309 (model_weight * col_iter.sum::<f32>() / dataset.len() as f32).abs(),
310 ));
311 }
312
313 average_features.sort_by(|(_, a), (_, b)| {
314 if a > b {
315 Ordering::Greater
316 } else if a < b {
317 Ordering::Less
318 } else {
319 Ordering::Equal
320 }
321 });
322
323 let mut removed = vec![];
324 let mut running_sum = 0.0;
325 for (index, weight) in average_features {
326 running_sum += weight;
327 if running_sum < THRESHOLD {
328 removed.push(index);
329 }
330 }
331
332 removed.sort_unstable();
333 removed.reverse();
334 let removed_len = removed.len();
335 for to_remove in &removed {
336 self.weights.remove(*to_remove);
337 }
338
339 if !self.features.is_empty() {
340 let mut removed_features = Vec::with_capacity(removed_len);
341 for index in &removed {
342 for (feat, feat_index) in &self.features {
343 if index == feat_index {
344 removed_features.push(feat.clone());
345 }
346 }
347 }
348
349 for removed_feature in removed_features {
350 self.features.remove(&removed_feature);
351 }
352 }
353
354 Ok(removed_len)
355 }
356
357 pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
363 ensure!(
364 features.len() == self.weights.len(),
365 "Provided features length {} does not equal the number of model features length {}",
366 features.len(),
367 self.weights.len()
368 );
369 self.features = features
370 .into_iter()
371 .enumerate()
372 .map(|(f, i)| (i, f))
373 .collect::<HashMap<_, _>>();
374
375 Ok(())
376 }
377
378 pub fn set_features_and_reduce(&mut self, dataset: &Dataset) -> Result<usize> {
386 ensure!(
387 self.file_type == dataset.ftype,
388 "Dataset file type must match model file type"
389 );
390 ensure!(
391 dataset.data[0].len() == self.weights.len(),
392 "Dataset length must equal the number of model weights"
393 );
394
395 self.features = dataset
396 .features
397 .iter()
398 .enumerate()
399 .map(|(f, i)| (i.clone(), f))
400 .collect::<HashMap<_, _>>();
401
402 self.reduce(dataset)
403 }
404
405 pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
411 ensure!(
412 features.len() == self.weights.len(),
413 "Provided features length {} does not equal the number of model features length {}",
414 features.len(),
415 self.weights.len()
416 );
417
418 Ok(Self {
419 learning_rate: self.learning_rate,
420 bias: self.bias,
421 weights: self.weights,
422 l1: self.l1,
423 l2: self.l2,
424 trained: self.trained,
425 original_ngrams: self.original_ngrams,
426 file_type: self.file_type,
427 n: self.n,
428 train_performance: PerformanceStats::default(),
429 test_performance: None,
430 features: features
431 .into_iter()
432 .enumerate()
433 .map(|(f, i)| (i, f))
434 .collect::<HashMap<_, _>>(),
435 })
436 }
437}
438
439#[derive(Debug, Clone, Copy, Default, PartialEq, Serialize, Deserialize)]
441pub struct PerformanceStats {
442 pub true_positives: u32,
444
445 pub true_negatives: u32,
447
448 pub false_positives: u32,
450
451 pub false_negatives: u32,
453
454 pub recall: f32,
456
457 pub precision: f32,
459
460 pub f1: f32,
462
463 pub auc: f32,
465}
466
467impl<'a> From<ConfusionMatrix<'a>> for PerformanceStats {
468 fn from(value: ConfusionMatrix<'a>) -> Self {
469 PerformanceStats {
470 true_positives: value.true_p,
471 true_negatives: value.true_n,
472 false_positives: value.false_p,
473 false_negatives: value.false_n,
474 recall: value.recall(),
475 precision: value.precision(),
476 f1: value.f1(),
477 auc: value.auc(),
478 }
479 }
480}
481
482#[derive(Debug, Clone, PartialEq)]
484pub struct ConfusionMatrix<'a> {
485 pub true_p: u32,
487
488 pub true_n: u32,
490
491 pub false_p: u32,
493
494 pub false_n: u32,
496
497 dataset: &'a Dataset,
499
500 predictions: Vec<f32>,
502}
503
504impl ConfusionMatrix<'_> {
505 #[inline]
507 #[must_use]
508 #[allow(clippy::cast_precision_loss)]
509 pub fn accuracy(&self) -> f32 {
510 (self.true_p + self.true_n) as f32 / self.total() as f32
511 }
512
513 #[must_use]
515 #[allow(clippy::cast_precision_loss)]
516 pub fn precision(&self) -> f32 {
517 self.true_p as f32 / (self.true_p + self.false_p) as f32
518 }
519
520 #[must_use]
522 #[allow(clippy::cast_precision_loss)]
523 pub fn recall(&self) -> f32 {
524 self.true_p as f32 / (self.true_p + self.false_n) as f32
525 }
526
527 #[must_use]
529 #[allow(clippy::cast_precision_loss)]
530 pub fn f1(&self) -> f32 {
531 2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
532 }
533
534 #[inline]
536 #[must_use]
537 pub fn total(&self) -> u32 {
538 self.true_p + self.true_n + self.false_p + self.false_n
539 }
540
541 #[must_use]
543 #[allow(clippy::float_cmp)]
544 pub fn auc(&self) -> f32 {
545 let (mut true_positive_count, mut false_positive_count) = {
552 let mut pairs: Vec<_> = self
553 .predictions
554 .iter()
555 .copied()
556 .zip(self.dataset.labels.iter().copied())
557 .collect();
558
559 pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
561
562 let mut score_prev = f32::NAN;
563 let (mut tp, mut fp) = (0.0f32, 0.0f32);
565 let (mut tps, mut fps) = (vec![], vec![]);
566 for (score, label) in pairs {
567 let label = f32::from(label);
568 if score != score_prev {
572 tps.push(tp);
573 fps.push(fp);
574 score_prev = score;
575 }
576 tp += label;
577 fp += 1.0 - label;
578 }
579 tps.push(tp);
581 fps.push(fp);
582 (tps, fps)
583 };
584
585 let true_positives = true_positive_count[true_positive_count.len() - 1];
586 let false_positives = false_positive_count[false_positive_count.len() - 1];
587
588 for (tp, fp) in true_positive_count
589 .iter_mut()
590 .zip(false_positive_count.iter_mut())
591 {
592 *tp /= true_positives;
593 *fp /= false_positives;
594 }
595
596 let mut prev_x = false_positive_count[0];
597 let mut prev_y = true_positive_count[0];
598 let mut integral = 0.0;
599
600 for (&x, &y) in false_positive_count
601 .iter()
602 .skip(1)
603 .zip(true_positive_count.iter().skip(1))
604 {
605 integral += (x - prev_x) * (prev_y + y) / 2.0;
606
607 prev_x = x;
608 prev_y = y;
609 }
610
611 integral
612 }
613}
614
615impl std::fmt::Display for ConfusionMatrix<'_> {
616 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
617 const WIDTH: usize = 10;
618
619 writeln!(f, "Result \\ Actual | Malicious | Benign")?;
620 writeln!(
621 f,
622 " Malicious | {:<WIDTH$} | {:<WIDTH$}",
623 self.true_p, self.false_p
624 )?;
625 writeln!(
626 f,
627 " Benign | {:<WIDTH$} | {:<WIDTH$}",
628 self.false_n, self.true_n
629 )
630 }
631}
632
633#[cfg(test)]
634mod tests {
635 use super::*;
636 use crate::dataset::Dataset;
637
638 #[test]
639 fn xor() {
640 let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
641 let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
642 lr.train(100, &mut dataset).unwrap();
643
644 let mut correct = 0u16;
645 let mut incorrect = 0u16;
646
647 for index in 0..dataset.data.len() {
648 println!(
649 "Predicted: {}, Expected: {}",
650 lr.predict(&dataset.data[index]),
651 dataset.labels[index]
652 );
653 if (lr.predict(&dataset.data[index]) >= 0.5 && dataset.labels[index] >= 1)
654 || (lr.predict(&dataset.data[index]) < 0.5 && dataset.labels[index] < 1)
655 {
656 correct += 1;
657 } else {
658 incorrect += 1;
659 }
660 }
661
662 println!("Correct: {correct}, Incorrect: {incorrect}");
663 assert!(correct > incorrect);
664
665 let result = lr.evaluate_dataset(&dataset).unwrap();
666 println!("{result}");
667 println!("Accuracy: {:.2}", result.accuracy());
668 println!("Precision: {:.2}", result.precision());
669 println!("Recall: {:.2}", result.recall());
670 println!("F1: {:.2}", result.f1());
671 println!("Auc: {:.2}", result.auc());
672 }
673
674 #[test]
675 fn reduction() {
676 const BOGUS_LEN: usize = 6;
677
678 let mut dataset =
679 Dataset::from_csv_string(include_str!("../testdata/bogus.csv"), BOGUS_LEN).unwrap();
680 let mut lr = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
681 lr.set_features(dataset.features.clone()).unwrap();
682 lr.train(20, &mut dataset).unwrap();
683 let cm = lr.evaluate_dataset(&dataset).unwrap();
684 println!("{cm}");
685 println!("Weights before reduction: {:?}", lr.weights);
686 println!("Features before reduction: {:?}", lr.features);
687 lr.reduce(&dataset).expect("Failed to reduce weights");
688 println!("Weights after reduction: {:?}", lr.weights);
689 println!("Features after reduction: {:?}", lr.features);
690 println!("Weights from {BOGUS_LEN} to {}", lr.weights.len());
691 assert!(
692 lr.weights.len() < BOGUS_LEN,
693 "** If this assertion fails, re-run the test once or twice. **"
694 );
695 lr.test_performance = Some(cm.into());
696 println!("Model: {lr:?}");
697 }
698
699 #[test]
700 fn auc() {
701 let y_true = vec![1, 1, 0, 0];
702 let y_hat = vec![0.5, 0.2, 0.3, -1.0];
703
704 let dataset = Dataset {
705 data: vec![],
706 labels: y_true,
707 features: vec![],
708 ftype: FileType::DOCFILE, };
710
711 let confusion_matrix = ConfusionMatrix {
712 true_p: 0,
713 true_n: 0,
714 false_p: 0,
715 false_n: 0,
716 dataset: &dataset,
717 predictions: y_hat,
718 };
719
720 let auc = confusion_matrix.auc();
721 println!("Auc: {auc:.2}, expected 0.75");
722 assert!((0.73..0.78).contains(&auc));
723 }
724}