1use crate::{dataset::Dataset, Bytes};
4
5use std::cmp::Ordering;
6use std::path::Path;
7
8use anyhow::{ensure, Result};
9use rand::Rng;
10use rayon::prelude::*;
11use serde::ser::Error;
12use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Logistic (sigmoid) function `1 / (1 + e^-x)`.
///
/// Maps any finite `x` into `[0, 1]`. In `f32` the result saturates to exactly
/// `0.0` for very negative inputs and `1.0` for very positive ones.
#[inline]
#[must_use]
fn sigmoid(x: f32) -> f32 {
    let denominator = 1.0 + f32::exp(-x);
    denominator.recip()
}
22
/// Binary logistic-regression classifier whose inputs are n-gram feature vectors.
///
/// Serializable with serde; the `features` field is stored as hex strings.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct LogisticRegression {
    /// Step size applied to every gradient update during training.
    pub learning_rate: f32,

    /// Intercept added to the weighted sum of the inputs before the sigmoid.
    pub bias: f32,

    /// One coefficient per input feature, paired with the input by position.
    pub weights: Vec<f32>,

    /// L1 regularization strength (scales an approximate `sign(w)` term in training).
    pub l1: f32,

    /// L2 regularization strength (scales a `w` term in training).
    pub l2: f32,

    /// N-gram byte sequences defining the feature space, intended to line up with
    /// `weights` by index (`set_features`/`with_features` enforce equal lengths).
    /// (De)serialized as hex strings; both directions fail if the list is empty.
    #[serde(
        serialize_with = "model_serialize_features",
        deserialize_with = "model_deserialize_features"
    )]
    pub features: Vec<Bytes>,

    /// Set to `true` once training finishes; `reduce` does nothing while it is `false`.
    pub trained: bool,

    /// Weight count the model was created with (`input_size` in `new`); it is not
    /// updated when `reduce` prunes weights.
    pub original_ngrams: u32,
}
54
55impl LogisticRegression {
56 #[must_use]
58 #[allow(clippy::cast_possible_truncation)]
59 pub fn new(input_size: usize, learning_rate: f32, l1: f32, l2: f32) -> Self {
60 let mut rng = rand::rng();
61
62 Self {
63 learning_rate,
64 weights: (0..input_size)
65 .map(|_| rng.random_range(-1.0..1.0))
66 .collect(),
67 l1,
68 l2,
69 features: vec![],
70 trained: false,
71 bias: rng.random(),
72 original_ngrams: input_size as u32,
73 }
74 }
75
76 #[must_use]
84 pub fn new_from_dataset_and_train(
85 dataset: &Dataset,
86 epochs: u32,
87 learning_rate: f32,
88 l1: f32,
89 l2: f32,
90 ) -> (Self, f32) {
91 let mut model = Self::new(dataset.data.len(), learning_rate, l1, l2);
92 model.features.clone_from(&dataset.features);
93 let result = model.train(epochs, dataset).unwrap();
94 (model, result)
95 }
96
97 #[inline]
100 #[must_use]
101 pub fn predict(&self, input: &[f32]) -> f32 {
102 let linear_model = input
103 .iter()
104 .zip(&self.weights)
105 .map(|(x, w)| x * w)
106 .sum::<f32>()
107 + self.bias;
108 sigmoid(linear_model)
109 }
110
111 #[allow(clippy::cast_precision_loss)]
117 pub fn train(&mut self, epochs: u32, dataset: &Dataset) -> Result<f32, &'static str> {
118 if dataset.labels.is_empty() {
119 return Err("Dataset must have labels");
120 }
121
122 if !dataset.validate() {
123 return Err("Dataset didn't pass validity check!");
124 }
125
126 if dataset.data[0].len() != self.weights.len() {
127 return Err("Dataset feature length must equal the number of model weights");
128 }
129
130 let mut loss = 0.0;
131 #[allow(unused)]
132 for epoch in 0..epochs {
133 loss = 0.0;
134 for (input, output) in dataset.data.iter().zip(&dataset.labels) {
135 let prediction = self.predict(input);
136 let error = prediction - output;
137 let p = prediction.clamp(1e-8, 1.0 - 1e-8);
138 loss += -output * p.ln() - (1.0 - output) * (1.0 - p).ln();
139
140 self.weights
141 .par_iter_mut()
142 .enumerate()
143 .for_each(|(i, weight)| {
144 let l1r = self.l1 * (*weight / (weight.abs() + 1e-8));
145 let l2r = self.l2 * *weight;
146 *weight -= self.learning_rate * (error * input[i] + l1r + l2r);
147 });
148 self.bias -= self.learning_rate * error;
149 }
150 loss /= self.weights.len() as f32;
151
152 #[cfg(debug_assertions)]
153 println!("Epoch: {epoch}, Log loss: {loss}");
154
155 if loss < 1e-6 {
156 break;
157 }
158 }
159
160 self.trained = true;
161 Ok(loss)
162 }
163
164 pub fn evaluate_dataset<'a>(&self, dataset: &'a Dataset) -> Result<ConfusionMatrix<'a>> {
170 ensure!(!dataset.is_empty(), "Dataset is empty");
171 ensure!(!dataset.labels.is_empty(), "Dataset labels is empty");
172 ensure!(
173 dataset.data[0].len() == self.weights.len(),
174 "Dataset length must equal the number of model weights"
175 );
176
177 let mut tp_ = 0;
178 let mut fp_ = 0;
179 let mut tn_ = 0;
180 let mut fn_ = 0;
181 let mut predictions = Vec::with_capacity(dataset.labels.len());
182
183 for index in 0..dataset.len() {
184 let prediction = self.predict(&dataset.data[index]);
185 if prediction >= 0.5 && dataset.labels[index] >= 0.9 {
186 tp_ += 1;
187 } else if prediction >= 0.5 && dataset.labels[index] < 0.5 {
188 fp_ += 1;
189 } else if prediction < 0.5 && dataset.labels[index] < 0.5 {
190 tn_ += 1;
191 } else {
192 fn_ += 1;
193 }
194 predictions.push(prediction);
195 }
196
197 Ok(ConfusionMatrix {
198 true_p: tp_,
199 true_n: tn_,
200 false_p: fp_,
201 false_n: fn_,
202 dataset,
203 predictions,
204 })
205 }
206
207 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
213 pub fn evaluate_file<P: AsRef<Path>>(&self, path: P) -> Result<(&'static str, f32, u32)> {
214 ensure!(
215 !self.features.is_empty(),
216 "Features are required for file evaluation"
217 );
218
219 let n = self.features[0].len();
220 let vector = crate::dataset::featurize_file(path, n, &self.features)?;
221 let result = self.predict(&vector);
222 let features = vector.iter().map(|v| *v as u32).sum();
223 if result > 0.5 {
224 Ok(("Malicious", result, features))
225 } else {
226 Ok(("Benign", result, features))
227 }
228 }
229
230 pub fn reduce(&mut self) {
234 if self.trained {
235 let mut removed = vec![];
236 self.weights = self
237 .weights
238 .iter()
239 .enumerate()
240 .filter_map(|(index, w)| {
241 if w.abs() > 0.01 {
242 Some(w)
243 } else {
244 removed.push(index);
245 None
246 }
247 })
248 .copied()
249 .collect();
250
251 if !self.features.is_empty() {
252 removed.sort_unstable();
253 removed.reverse();
254 for index in removed {
255 self.features.remove(index);
256 }
257 }
258 }
259 }
260
261 pub fn set_features(&mut self, features: Vec<Bytes>) -> Result<()> {
267 ensure!(
268 features.len() == self.weights.len(),
269 "Provided features length {} does not equal the number of model features length {}",
270 features.len(),
271 self.weights.len()
272 );
273 self.features = features;
274
275 Ok(())
276 }
277
278 pub fn with_features(self, features: Vec<Bytes>) -> Result<Self> {
284 ensure!(
285 features.len() == self.weights.len(),
286 "Provided features length {} does not equal the number of model features length {}",
287 features.len(),
288 self.weights.len()
289 );
290
291 Ok(Self {
292 learning_rate: self.learning_rate,
293 bias: self.bias,
294 weights: self.weights,
295 l1: self.l1,
296 l2: self.l2,
297 trained: self.trained,
298 original_ngrams: self.original_ngrams,
299 features,
300 })
301 }
302}
303
304fn model_serialize_features<S>(x: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error>
305where
306 S: Serializer,
307{
308 if x.is_empty() {
309 return Err(Error::custom("N-gram features not set!"));
310 }
311
312 let features = x.iter().map(hex::encode).collect::<Vec<String>>();
313 s.collect_seq(features)
314}
315
316fn model_deserialize_features<'de, D>(deserializer: D) -> Result<Vec<Vec<u8>>, D::Error>
317where
318 D: Deserializer<'de>,
319{
320 use serde::de::Error;
321 let features = Vec::<String>::deserialize(deserializer)?;
322 if features.is_empty() {
323 return Err(Error::custom("N-gram features were empty!"));
324 }
325
326 features
327 .into_iter()
328 .map(hex::decode)
329 .collect::<Result<Vec<Vec<u8>>, _>>()
330 .map_err(Error::custom)
331}
332
/// Binary-classification tallies plus the raw scores needed for ROC AUC, as
/// produced by `LogisticRegression::evaluate_dataset`.
///
/// Borrows the evaluated dataset for its labels.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfusionMatrix<'a> {
    /// Count of samples predicted positive that are actually positive.
    pub true_p: u32,

    /// Count of samples predicted negative that are actually negative.
    pub true_n: u32,

    /// Count of samples predicted positive that are actually negative.
    pub false_p: u32,

    /// Count of samples predicted negative that are actually positive.
    pub false_n: u32,

    /// Dataset that was evaluated; its `labels` are paired with `predictions` by index.
    dataset: &'a Dataset,

    /// Raw model outputs, one per evaluated sample, in dataset order.
    predictions: Vec<f32>,
}
354
355impl ConfusionMatrix<'_> {
356 #[inline]
358 #[must_use]
359 #[allow(clippy::cast_precision_loss)]
360 pub fn accuracy(&self) -> f32 {
361 (self.true_p + self.true_n) as f32 / self.total() as f32
362 }
363
364 #[must_use]
366 #[allow(clippy::cast_precision_loss)]
367 pub fn precision(&self) -> f32 {
368 self.true_p as f32 / (self.true_p + self.false_p) as f32
369 }
370
371 #[must_use]
373 #[allow(clippy::cast_precision_loss)]
374 pub fn recall(&self) -> f32 {
375 self.true_p as f32 / (self.true_p + self.false_n) as f32
376 }
377
378 #[must_use]
380 #[allow(clippy::cast_precision_loss)]
381 pub fn f1(&self) -> f32 {
382 2.0 * (self.precision() * self.recall()) / (self.precision() + self.recall())
383 }
384
385 #[inline]
387 #[must_use]
388 pub fn total(&self) -> u32 {
389 self.true_p + self.true_n + self.false_p + self.false_n
390 }
391
392 #[must_use]
394 #[allow(clippy::float_cmp)]
395 pub fn auc(&self) -> f32 {
396 let (mut true_positive_count, mut false_positive_count) = {
403 let mut pairs: Vec<_> = self
404 .predictions
405 .iter()
406 .copied()
407 .zip(self.dataset.labels.iter().copied())
408 .collect();
409
410 pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal));
412
413 let mut score_prev = f32::NAN;
414 let (mut tp, mut fp) = (0.0f32, 0.0f32);
416 let (mut tps, mut fps) = (vec![], vec![]);
417 for (score, label) in pairs {
418 if score != score_prev {
422 tps.push(tp);
423 fps.push(fp);
424 score_prev = score;
425 }
426 tp += label;
427 fp += 1.0 - label;
428 }
429 tps.push(tp);
431 fps.push(fp);
432 (tps, fps)
433 };
434
435 let true_positives = true_positive_count[true_positive_count.len() - 1];
436 let false_positives = false_positive_count[false_positive_count.len() - 1];
437
438 for (tp, fp) in true_positive_count
439 .iter_mut()
440 .zip(false_positive_count.iter_mut())
441 {
442 *tp /= true_positives;
443 *fp /= false_positives;
444 }
445
446 let mut prev_x = false_positive_count[0];
447 let mut prev_y = true_positive_count[0];
448 let mut integral = 0.0;
449
450 for (&x, &y) in false_positive_count
451 .iter()
452 .skip(1)
453 .zip(true_positive_count.iter().skip(1))
454 {
455 integral += (x - prev_x) * (prev_y + y) / 2.0;
456
457 prev_x = x;
458 prev_y = y;
459 }
460
461 integral
462 }
463}
464
465impl std::fmt::Display for ConfusionMatrix<'_> {
466 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
467 writeln!(f, "Result \\ Actual | Malicious | Benign")?;
468 writeln!(
469 f,
470 " Malicious | {} | {}",
471 self.true_p, self.false_p
472 )?;
473 writeln!(
474 f,
475 " Benign | {} | {}",
476 self.false_n, self.true_n
477 )
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::Dataset;

    /// Trains on the XOR fixture, requires more correct than incorrect
    /// classifications, then prints the evaluation metrics.
    #[test]
    fn xor() {
        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut lr = LogisticRegression::new(6, 0.2, 0.0, 0.0);
        lr.train(100, &dataset).unwrap();

        let (mut correct, mut incorrect) = (0u16, 0u16);

        // `train` succeeded, so the dataset passed validation; pair rows with labels directly.
        for (row, &expected) in dataset.data.iter().zip(&dataset.labels) {
            let predicted = lr.predict(row);
            println!("Predicted: {}, Expected: {}", predicted, expected);

            let hit_positive = predicted >= 0.5 && expected >= 0.99;
            let hit_negative = predicted < 0.5 && expected < 0.1;
            if hit_positive || hit_negative {
                correct += 1;
            } else {
                incorrect += 1;
            }
        }

        println!("Correct: {correct}, Incorrect: {incorrect}");
        assert!(correct > incorrect);

        let result = lr.evaluate_dataset(&dataset).unwrap();
        println!("{result}");
        println!("Accuracy: {:.2}", result.accuracy());
        println!("Precision: {:.2}", result.precision());
        println!("Recall: {:.2}", result.recall());
        println!("F1: {:.2}", result.f1());
        println!("Auc: {:.2}", result.auc());
    }

    /// With L1/L2 regularization, training on the bogus fixture should let
    /// `reduce` prune at least one weight. Random initialization makes this
    /// occasionally flaky.
    #[test]
    fn reduction() {
        const BOGUS_LEN: usize = 6;

        let dataset =
            Dataset::from_csv_string(include_str!("../testdata/bogus.csv"), BOGUS_LEN).unwrap();
        let mut model = LogisticRegression::new(BOGUS_LEN, 0.2, 0.1, 0.1);
        model.set_features(dataset.features.clone()).unwrap();
        model.train(20, &dataset).unwrap();

        println!("Weights before reduction: {:?}", model.weights);
        println!("Features before reduction: {:?}", model.features);
        model.reduce();
        println!("Weights after reduction: {:?}", model.weights);
        println!("Features after reduction: {:?}", model.features);
        println!("Weights from {BOGUS_LEN} to {}", model.weights.len());

        assert!(
            model.weights.len() < BOGUS_LEN,
            "** If this assertion fails, re-run the test once or twice. **"
        );
    }

    /// Known ROC example: scores [0.5, 0.2, 0.3, -1.0] against labels
    /// [1, 1, 0, 0] give an AUC of 0.75.
    #[test]
    fn auc() {
        let dataset = Dataset {
            data: vec![],
            labels: vec![1.0, 1.0, 0.0, 0.0],
            features: vec![],
        };

        let confusion_matrix = ConfusionMatrix {
            true_p: 0,
            true_n: 0,
            false_p: 0,
            false_n: 0,
            dataset: &dataset,
            predictions: vec![0.5, 0.2, 0.3, -1.0],
        };

        let auc = confusion_matrix.auc();
        println!("Auc: {auc:.2}, expected 0.75");
        assert!((0.73..0.78).contains(&auc));
    }
}