1use crate::ftype::FileType;
4use crate::{dataset::Dataset, Bytes};
5
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::path::Path;
9
10use anyhow::{anyhow, ensure, Result};
11use rand::Rng;
12use rayon::prelude::*;
13use serde::{Deserialize, Serialize};
14
/// Logistic sigmoid, evaluated as `1 / (1 + e^(-x))`.
///
/// Maps any real input to a value between 0 and 1; an input of `0.0` yields `0.5`.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
24
/// Logistic regression model together with the feature table used to featurize files.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
    /// Step size multiplied into every weight and bias update during training.
    pub learning_rate: f32,

    /// Constant term added to the weighted input sum before the sigmoid is applied.
    pub bias: f32,

    /// Model coefficients; prediction pairs them positionally with the input values.
    pub weights: Vec<f32>,

    /// L1 penalty factor, multiplied with the sign of each weight in the training update.
    pub l1: f32,

    /// L2 penalty factor, multiplied with the value of each weight in the training update.
    pub l2: f32,

    /// Feature bytes mapped to an index; serialized through the crate's hex-map helpers.
    /// The setters in this file store each feature's position in the supplied list.
    #[serde(
        serialize_with = "crate::serde::serialize_hex_map",
        deserialize_with = "crate::serde::deserialize_hex_map"
    )]
    pub features: HashMap<Bytes, usize>,

    /// Byte length of the dataset's first feature, recorded by training and handed to
    /// `featurize_file`. Presumably the n-gram size; confirm against `crate::dataset`.
    n: usize,

    /// Set to `true` once training has completed; required before weights can be reduced.
    pub trained: bool,

    /// The `input_size` the model was constructed with, narrowed to `u32`.
    pub original_ngrams: u32,

    /// File type of the training dataset; `FileType::NotSet` until the model is trained.
    pub file_type: FileType,
}
62
63impl LogisticRegression {
64 #[must_use]
66 #[allow(clippy::cast_possible_truncation)]
67 pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
68 let mut rng = rand::rng();
69
70 Self {
71 learning_rate,
72 weights: (0..input_size)
73 .map(|_| rng.random_range(-1.0..1.0))
74 .collect(),
75 l1,
76 l2,
77 n: 0,
78 features: HashMap::new(),
79 trained: false,
80 bias: rng.random(),
81 original_ngrams: input_size as u32,
82 file_type: FileType::NotSet,
83 }
84 }
85
86 #[must_use]
94 pub fn new_from_dataset_and_train(
95 dataset: &mut Dataset,
96 epochs: u32,
97 learning_rate: f32,
98 l1: f32,
99 l2: f32,
100 ) -> (Self, f32) {
101 let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
102 model.n = dataset.features[0].len();
103 model.features = dataset
104 .features
105 .iter()
106 .map(|f| (f.clone(), 0))
107 .collect::<HashMap<_, _>>();
108 let result = model.train(epochs, dataset).unwrap();
109 model.file_type = dataset.ftype;
110 (model, result)
111 }
112
113 #[inline]
116 #[must_use]
117 pub fn predict(&self, input: &[f32]) -> f32 {
118 let linear_model = input
119 .iter()
120 .zip(&self.weights)
121 .map(|(x, w)| x * w)
122 .sum::<f32>()
123 + self.bias;
124 sigmoid(linear_model)
125 }
126
127 #[allow(clippy::cast_precision_loss)]
133 pub fn train(&mut self, epochs: u32, dataset: &mut Dataset) -> Result<f32, &'static str> {
134 if dataset.labels.is_empty() {
135 return Err("Dataset must have labels");
136 }
137
138 if !dataset.validate() {
139 return Err("Dataset didn't pass validity check!");
140 }
141
142 if dataset.data[0].len() != self.weights.len() {
143 return Err("Dataset feature length must equal the number of model weights");
144 }
145
146 let mut loss = 0.0;
147 #[allow(unused)]
148 for epoch in 0..epochs {
149 loss = 0.0;
150 dataset.shuffle();
151 for (input, output) in dataset.data.iter().zip(&dataset.labels) {
152 let prediction = self.predict(input);
153 let error = prediction - output;
154 let p = prediction.clamp(1e-8, 1.0 - 1e-8);
155 loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
156
157 self.weights
158 .par_iter_mut()
159 .enumerate()
160 .for_each(|(i, weight)| {
161 let l1r = self.l1 * weight.signum();
162 let l2r = self.l2 * *weight;
163 *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
164 });
165 self.bias -= self.learning_rate * error;
166 }
167 loss /= self.weights.len() as f32;
168
169 #[cfg(debug_assertions)]
170 println!("Epoch: {epoch}, Log loss: {loss}");
171
172 if loss < 1e-6 {
173 break;
174 }
175 }
176
177 self.trained = true;
178 self.file_type = dataset.ftype;
179 self.n = dataset.features[0].len();
180 Ok(loss)
181 }
182
183 pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
189 ensure!(!dataset.is_empty(), "Dataset is empty");
190 ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
191 ensure!(
192 dataset.data[0].len() == self.weights.len(),
193 "Dataset length must equal the number of model weights"
194 );
195 ensure!(
196 self.file_type == dataset.ftype,
197 "Dataset file type must match model file type"
198 );
199
200 let mut tp_ = 0;
201 let mut fp_ = 0;
202 let mut tn_ = 0;
203 let mut fn_ = 0;
204 let mut predictions = Vec::with_capacity(dataset.labels.len());
205
206 for index in 0..dataset.len() {
207 let prediction = self.predict(&dataset.data[index]);
208 if prediction >= 0.5 && dataset.labels[index] >= 0.9 {
209 tp_ += 1;
210 } else if prediction >= 0.5 && dataset.labels[index] < 0.5 {
211 fp_ += 1;
212 } else if prediction < 0.5 && dataset.labels[index] < 0.5 {
213 tn_ += 1;
214 } else {
215 fn_ += 1;
216 }
217 predictions.push(prediction);
218 }
219
220 Ok(ConfusionMatrix {
221 true_p: tp_,
222 true_n: tn_,
223 false_p: fp_,
224 false_n: fn_,
225 dataset,
226 predictions,
227 })
228 }
229
230 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
236 pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
237 ensure!(
238 !self.features.is_empty(),
239 "Features are required for file evaluation"
240 );
241
242 ensure!(
243 self.file_type.matches_path(&path)?,
244 "File type doesn't match model type"
245 );
246
247 let vector = crate::dataset::featurize_file(path, self.n, &self.features)?;
248 let result = self.predict(&vector);
249 let features = vector.iter().map(|v| *v as u32).sum();
250 if result > 0.5 {
251 Ok(("Malicious", result, features))
252 } else {
253 Ok(("Benign", result, features))
254 }
255 }
256
257 #[allow(clippy::cast_precision_loss)]
263 pub fn reduce(&mut self, dataset: &Dataset) -> Result<usize> {
264 const THRESHOLD: f32 = 0.01;
265
266 ensure!(
267 self.trained,
268 "Model must be trained before reducing weights"
269 );
270 ensure!(
271 self.file_type == dataset.ftype,
272 "Dataset file type must match model file type"
273 );
274
275 let mut average_features = Vec::with_capacity(self.weights.len());
276 for (index, model_weight) in self.weights.iter().enumerate() {
277 let col_iter = dataset
278 .column_iter(index)
279 .ok_or_else(|| anyhow!("Column index out of bounds"))?;
280 average_features.push((
281 index,
282 (model_weight * col_iter.sum::<f32>() / dataset.len() as f32).abs(),
283 ));
284 }
285
286 average_features.sort_by(|(_, a), (_, b)| {
287 if a > b {
288 Ordering::Greater
289 } else if a < b {
290 Ordering::Less
291 } else {
292 Ordering::Equal
293 }
294 });
295
296 let mut removed = vec![];
297 let mut running_sum = 0.0;
298 for (index, weight) in average_features {
299 running_sum += weight;
300 if running_sum < THRESHOLD {
301 removed.push(index);
302 }
303 }
304
305 removed.sort_unstable();
306 removed.reverse();
307 let removed_len = removed.len();
308 for to_remove in &removed {
309 self.weights.remove(*to_remove);
310 }
311
312 if !self.features.is_empty() {
313 let mut removed_features = Vec::with_capacity(removed_len);
314 for index in &removed {
315 for (feat, feat_index) in &self.features {
316 if index == feat_index {
317 removed_features.push(feat.clone());
318 }
319 }
320 }
321
322 for removed_feature in removed_features {
323 self.features.remove(&removed_feature);
324 }
325 }
326
327 Ok(removed_len)
328 }
329
330 pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
336 ensure!(
337 features.len() == self.weights.len(),
338 "Provided features length {} does not equal the number of model features length {}",
339 features.len(),
340 self.weights.len()
341 );
342 self.features = features
343 .into_iter()
344 .enumerate()
345 .map(|(f, i)| (i, f))
346 .collect::<HashMap<_, _>>();
347
348 Ok(())
349 }
350
351 pub fn set_features_and_reduce(&mut self, dataset: &Dataset) -> Result<usize> {
359 ensure!(
360 self.file_type == dataset.ftype,
361 "Dataset file type must match model file type"
362 );
363 ensure!(
364 dataset.data[0].len() == self.weights.len(),
365 "Dataset length must equal the number of model weights"
366 );
367
368 self.features = dataset
369 .features
370 .iter()
371 .enumerate()
372 .map(|(f, i)| (i.clone(), f))
373 .collect::<HashMap<_, _>>();
374
375 self.reduce(dataset)
376 }
377
378 pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
384 ensure!(
385 features.len() == self.weights.len(),
386 "Provided features length {} does not equal the number of model features length {}",
387 features.len(),
388 self.weights.len()
389 );
390
391 Ok(Self {
392 learning_rate: self.learning_rate,
393 bias: self.bias,
394 weights: self.weights,
395 l1: self.l1,
396 l2: self.l2,
397 trained: self.trained,
398 original_ngrams: self.original_ngrams,
399 file_type: self.file_type,
400 n: self.n,
401 features: features
402 .into_iter()
403 .enumerate()
404 .map(|(f, i)| (i, f))
405 .collect::<HashMap<_, _>>(),
406 })
407 }
408}
409
/// Outcome counts of evaluating a model on a dataset, plus the data needed to compute AUC.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
    /// Number of true positives.
    pub true_p: u32,

    /// Number of true negatives.
    pub true_n: u32,

    /// Number of false positives.
    pub false_p: u32,

    /// Number of false negatives.
    pub false_n: u32,

    /// The evaluated dataset; its labels are read when computing the AUC.
    dataset: &'a Dataset,

    /// Raw model scores; the AUC computation pairs them positionally with the dataset labels.
    predictions: Vec<f32>,
}
431
432impl ConfusionMatrix<'_> {
433 #[inline]
435 #[must_use]
436 #[allow(clippy::cast_precision_loss)]
437 pub fn accuracy(&self) -> f32 {
438 (self.true_p + self.true_n) as f32 / self.total() as f32
439 }
440
441 #[must_use]
443 #[allow(clippy::cast_precision_loss)]
444 pub fn precision(&self) -> f32 {
445 self.true_p as f32 / (self.true_p + self.false_p) as f32
446 }
447
448 #[must_use]
450 #[allow(clippy::cast_precision_loss)]
451 pub fn recall(&self) -> f32 {
452 self.true_p as f32 / (self.true_p + self.false_n) as f32
453 }
454
455 #[must_use]
457 #[allow(clippy::cast_precision_loss)]
458 pub fn f1(&self) -> f32 {
459 2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
460 }
461
462 #[inline]
464 #[must_use]
465 pub fn total(&self) -> u32 {
466 self.true_p + self.true_n + self.false_p + self.false_n
467 }
468
469 #[must_use]
471 #[allow(clippy::float_cmp)]
472 pub fn auc(&self) -> f32 {
473 let (mut true_positive_count, mut false_positive_count) = {
480 let mut pairs: Vec<_> = self
481 .predictions
482 .iter()
483 .copied()
484 .zip(self.dataset.labels.iter().copied())
485 .collect();
486
487 pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
489
490 let mut score_prev = f32::NAN;
491 let (mut tp, mut fp) = (0.0f32, 0.0f32);
493 let (mut tps, mut fps) = (vec![], vec![]);
494 for (score, label) in pairs {
495 if score != score_prev {
499 tps.push(tp);
500 fps.push(fp);
501 score_prev = score;
502 }
503 tp += label;
504 fp += 1.0 - label;
505 }
506 tps.push(tp);
508 fps.push(fp);
509 (tps, fps)
510 };
511
512 let true_positives = true_positive_count[true_positive_count.len() - 1];
513 let false_positives = false_positive_count[false_positive_count.len() - 1];
514
515 for (tp, fp) in true_positive_count
516 .iter_mut()
517 .zip(false_positive_count.iter_mut())
518 {
519 *tp /= true_positives;
520 *fp /= false_positives;
521 }
522
523 let mut prev_x = false_positive_count[0];
524 let mut prev_y = true_positive_count[0];
525 let mut integral = 0.0;
526
527 for (&x, &y) in false_positive_count
528 .iter()
529 .skip(1)
530 .zip(true_positive_count.iter().skip(1))
531 {
532 integral += (x - prev_x) * (prev_y + y) / 2.0;
533
534 prev_x = x;
535 prev_y = y;
536 }
537
538 integral
539 }
540}
541
542impl std::fmt::Display for ConfusionMatrix<'_> {
543 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
544 const WIDTH: usize = 10;
545
546 writeln!(f, "Result \\ Actual | Malicious | Benign")?;
547 writeln!(
548 f,
549 " Malicious | {:<WIDTH$} | {:<WIDTH$}",
550 self.true_p, self.false_p
551 )?;
552 writeln!(
553 f,
554 " Benign | {:<WIDTH$} | {:<WIDTH$}",
555 self.false_n, self.true_n
556 )
557 }
558}
559
560#[cfg(test)]
561mod tests {
562 use super::*;
563 use crate::dataset::Dataset;
564
/// Trains on the XOR fixture and checks that more samples are classified correctly than
/// incorrectly, then prints the evaluation metrics.
#[test]
fn xor() {
    let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
    let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
    lr.train(100, &mut dataset).unwrap();

    let (mut correct, mut incorrect) = (0u16, 0u16);

    for index in 0..dataset.data.len() {
        let predicted = lr.predict(&dataset.data[index]);
        let expected = dataset.labels[index];
        println!("Predicted: {}, Expected: {}", predicted, expected);

        let positive_hit = predicted >= 0.5 && expected >= 0.99;
        let negative_hit = predicted < 0.5 && expected < 0.1;
        if positive_hit || negative_hit {
            correct += 1;
        } else {
            incorrect += 1;
        }
    }

    println!("Correct: {correct}, Incorrect: {incorrect}");
    assert!(correct > incorrect);

    let result = lr.evaluate_dataset(&dataset).unwrap();
    println!("{result}");
    println!("Accuracy: {:.2}", result.accuracy());
    println!("Precision: {:.2}", result.precision());
    println!("Recall: {:.2}", result.recall());
    println!("F1: {:.2}", result.f1());
    println!("Auc: {:.2}", result.auc());
}
600
/// Trains on the bogus fixture with both penalties enabled and checks that reducing the
/// model removes at least one weight.
#[test]
fn reduction() {
    const BOGUS_LEN: usize = 6;

    let csv = include_str!("../testdata/bogus.csv");
    let mut dataset = Dataset::from_csv_string(csv, BOGUS_LEN).unwrap();

    let mut lr = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
    lr.set_features(dataset.features.clone()).unwrap();
    lr.train(20, &mut dataset).unwrap();

    let cm = lr.evaluate_dataset(&dataset).unwrap();
    println!("{cm}");

    println!("Weights before reduction: {:?}", lr.weights);
    println!("Features before reduction: {:?}", lr.features);
    lr.reduce(&dataset).expect("Failed to reduce weights");
    println!("Weights after reduction: {:?}", lr.weights);
    println!("Features after reduction: {:?}", lr.features);

    let remaining = lr.weights.len();
    println!("Weights from {BOGUS_LEN} to {}", remaining);
    // Initial weights are random, so the outcome can vary between runs.
    assert!(
        remaining < BOGUS_LEN,
        "** If this assertion fails, re-run the test once or twice. **"
    );
}
623
/// Checks the AUC of four hand-picked scores against the expected value of 0.75.
#[test]
fn auc() {
    // Only the labels matter for the AUC; rows and features stay empty.
    let dataset = Dataset {
        data: vec![],
        labels: vec![1.0, 1.0, 0.0, 0.0],
        features: vec![],
        ftype: FileType::DOCFILE,
    };

    // The counts are not read by `auc`, so they are left at zero.
    let confusion_matrix = ConfusionMatrix {
        true_p: 0,
        true_n: 0,
        false_p: 0,
        false_n: 0,
        dataset: &dataset,
        predictions: vec![0.5, 0.2, 0.3, -1.0],
    };

    let auc = confusion_matrix.auc();
    println!("Auc: {auc:.2}, expected 0.75");
    assert!((0.73..0.78).contains(&auc));
}
649}