1use crate::Bytes;
4
5use std::io::{Read, Write};
6use std::path::Path;
7use std::str::FromStr;
8use std::sync::RwLock;
9
10use anyhow::{bail, ensure, Result};
11use rayon::prelude::*;
12use serde::{Deserialize, Serialize};
13use walkdir::WalkDir;
14
/// Leading bytes that mark a line as a comment in the CSV/ARFF/LibSVM text formats.
const COMMENT_PREFIXES: [u8; 2] = *b"#%";
/// Marker inside a comment line that introduces the comma-separated, hex-encoded feature list.
const FEATURES_PREFIX: &str = "Features:";
17
18#[inline]
24pub fn featurize_file<P: AsRef<Path>>(file: P, n: usize, features: &[Bytes]) -> Result<Vec<f32>> {
25 let contents = std::fs::read(file)?;
26 let mut feature_vector = vec![0.0; features.len()];
27
28 for window in contents.windows(n) {
29 if let Some(position) = features.iter().position(|n| n == window) {
30 feature_vector[position] = 1.0;
31 }
32 }
33
34 Ok(feature_vector)
35}
36
37#[derive(Copy, Clone, Deserialize, Serialize, Hash, Eq, PartialEq)]
39pub enum DatasetFormat {
40 ARFF,
42
43 CSV,
45
46 SVM,
48}
49
50impl FromStr for DatasetFormat {
51 type Err = String;
52
53 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
54 match value.to_lowercase().as_str() {
55 "arff" => Ok(Self::ARFF),
56 "csv" => Ok(Self::CSV),
57 "svm" => Ok(Self::SVM),
58 x => Err(format!("Unknown data format '{x}'")),
59 }
60 }
61}
62
/// An in-memory dataset: one numeric row per sample, optional per-row labels, and the
/// byte n-grams that name each column.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Dataset {
    /// One row per sample. The text loaders reject rows whose length differs from the first.
    pub data: Vec<Vec<f32>>,

    /// Either empty (unlabeled) or one label per row of `data`; see `Dataset::validate`.
    /// Defaults to empty when absent from serialized input.
    #[serde(default)]
    pub labels: Vec<f32>,

    /// Raw bytes of the n-gram represented by each column of `data`.
    pub features: Vec<Bytes>,
}
76
77impl PartialEq for Dataset {
78 fn eq(&self, other: &Self) -> bool {
79 if self.data.len() != other.data.len()
80 || self.labels.len() != other.labels.len()
81 || self.features.len() != other.features.len()
82 {
83 return false;
84 }
85
86 for this_data in &self.data {
87 if !other.data.contains(this_data) {
88 return false;
89 }
90 }
91
92 for other_data in &other.data {
93 if !self.data.contains(other_data) {
94 return false;
95 }
96 }
97
98 if !self.labels.is_empty() {
99 for this_label in &self.labels {
100 if !other.labels.contains(this_label) {
101 return false;
102 }
103 }
104
105 for other_label in &other.labels {
106 if !self.labels.contains(other_label) {
107 return false;
108 }
109 }
110 }
111
112 for this_features in &self.features {
113 if !other.features.contains(this_features) {
114 return false;
115 }
116 }
117
118 for other_feature in &other.features {
119 if !self.features.contains(other_feature) {
120 return false;
121 }
122 }
123
124 true
125 }
126}
127
128impl Dataset {
129 pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
136 if let Some(extension) = path.as_ref().extension() {
137 return match extension.to_str().unwrap_or_default() {
138 "arff" => Dataset::from_arff_file(path.as_ref()),
139 "csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
140 "svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
141 "json" => {
142 let contents = std::fs::read_to_string(path.as_ref())?;
143 serde_json::from_str(&contents).map_err(Into::into)
144 }
145 "toml" => {
146 let contents = std::fs::read_to_string(path.as_ref())?;
147 toml::from_str(&contents).map_err(Into::into)
148 }
149 ext => {
150 bail!("Unsupported/unknown data type '{ext}'");
151 }
152 };
153 }
154
155 bail!("No extension, can't determine file type.");
156 }
157
158 pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
168 let mut file = std::fs::File::open(path)?;
169 let mut contents = String::new();
170 file.read_to_string(&mut contents)?;
171
172 Self::from_csv_string(&contents, data_length)
173 }
174
175 pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
185 let mut file = std::fs::File::open(path)?;
186 let mut contents = String::new();
187 file.read_to_string(&mut contents)?;
188
189 let mut length = 0;
190 for line in contents.lines() {
191 if line.is_empty() || COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
192 continue;
193 }
194
195 length = line.split(',').collect::<Vec<&str>>().len();
196 break;
197 }
198
199 ensure!(length > 0, "Failed to determine data length.");
200 Self::from_csv_string(&contents, length - 1)
201 }
202
203 pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
212 let mut data: Vec<Vec<f32>> = Vec::new();
213 let mut labels = Vec::new();
214 let mut features = Vec::new();
215
216 for (row_number, line) in contents.lines().enumerate() {
217 if line.is_empty() {
218 continue;
219 }
220
221 if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) && line.contains(FEATURES_PREFIX) {
222 let offset = line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
223 let line = line[offset..].trim();
224 features = line
225 .split(',')
226 .filter_map(|f| hex::decode(f.trim()).ok())
227 .collect();
228 }
229
230 if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
231 continue;
232 }
233 let row = line.split(',').collect::<Vec<&str>>();
234 let mut row_float = Vec::with_capacity(data_length);
235 for r in row.iter().take(data_length) {
236 row_float.push(r.parse::<f32>().map_err(|_| {
237 anyhow::Error::msg(format!("Non-float encountered on CSV row {row_number}"))
238 })?);
239 }
240 if let Some(first_row) = data.first() {
241 ensure!(
242 first_row.len() == row_float.len(),
243 "CSV line {row_number} has invalid length {}, expected {}",
244 row_float.len(),
245 first_row.len()
246 );
247 }
248 data.push(row_float);
249 if row.len() == data_length + 1 {
250 let l = row[data_length].parse::<f32>().map_err(|_| {
251 anyhow::Error::msg(format!("Non-float encountered on CSV row {row_number}"))
252 })?;
253 labels.push(l);
254 } else if row.len() > data_length {
255 bail!(
256 "CSV row had more than one label on row {row_number}, which isn't supported."
257 );
258 }
259 }
260
261 ensure!(
262 features.len() == data[0].len(),
263 "Features need to be empty or the same size as the data length."
264 );
265
266 Ok(Self {
267 data,
268 labels,
269 features,
270 })
271 }
272
273 pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
283 let mut file = std::fs::File::open(path)?;
284 let mut contents = String::new();
285 file.read_to_string(&mut contents)?;
286
287 Self::from_arff_string(&contents)
288 }
289
290 pub fn from_arff_string(contents: &str) -> Result<Self> {
299 let mut data: Vec<Vec<f32>> = Vec::new();
300 let mut labels = Vec::new();
301 let mut features = Vec::new();
302 let mut passed_data = false;
303
304 for (row_number, line) in contents.lines().enumerate() {
305 if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
306 continue;
307 }
308
309 if line.contains("@ATTRIBUTE") {
310 let parts: Vec<&str> = line.split_ascii_whitespace().collect();
311 if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
312 match hex::decode(parts[1]) {
313 Ok(feat) => features.push(feat),
314 Err(e) => {
315 bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
316 }
317 }
318 }
319 }
320
321 if line.contains("@DATA") {
322 passed_data = true;
323 continue;
324 }
325
326 if passed_data {
328 let row = line.split(',').collect::<Vec<&str>>();
329 let data_length = row.len() - 1;
330 let mut row_float = Vec::with_capacity(data_length);
331 for r in row.iter().take(data_length) {
332 row_float.push(r.parse::<f32>().map_err(|_| {
333 anyhow::Error::msg(format!(
334 "Non-float encountered on ARFF row {row_number}"
335 ))
336 })?);
337 }
338 if let Some(first_row) = data.first() {
339 ensure!(
340 first_row.len() == row_float.len(),
341 "ARFF line {row_number} has invalid length {}, expected {}",
342 row_float.len(),
343 first_row.len()
344 );
345 }
346 data.push(row_float);
347 if row.len() == data_length + 1 {
348 let l = row[data_length].parse::<f32>().map_err(|_| {
349 anyhow::Error::msg(format!(
350 "Non-float encountered on ARFF row {row_number}"
351 ))
352 })?;
353 labels.push(l);
354 } else if row.len() > data_length {
355 bail!("Arff row had more than one label on row {row_number}, which isn't supported.");
356 }
357 }
358 }
359
360 ensure!(
361 features.len() == data[0].len(),
362 "Features need to be empty or the same size as the data length."
363 );
364
365 Ok(Self {
366 data,
367 labels,
368 features,
369 })
370 }
371
372 pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
382 let mut file = std::fs::File::open(path)?;
383 let mut contents = String::new();
384 file.read_to_string(&mut contents)?;
385
386 Self::from_libsvm_string(&contents)
387 }
388
389 pub fn from_libsvm_string(contents: &str) -> Result<Self> {
395 let mut data = Vec::new();
396 let mut labels = Vec::new();
397 let mut features = Vec::new();
398
399 for (row_number, line) in contents.lines().enumerate() {
400 if line.is_empty() {
401 continue;
402 }
403
404 if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) && line.contains(FEATURES_PREFIX) {
405 let offset = line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
406 let line = line[offset..].trim();
407 features = line
408 .split(',')
409 .filter_map(|f| hex::decode(f.trim()).ok())
410 .collect();
411 }
412
413 if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
414 continue;
415 }
416
417 let parts = line.split_whitespace().collect::<Vec<&str>>();
418 let label = parts[0].parse::<f32>()?;
419 let mut row = vec![0.0f32; features.len()];
420
421 for part in parts.iter().skip(1) {
422 let part_parts = part.split(':').collect::<Vec<&str>>();
423 let part_index = part_parts[0].parse::<usize>()?;
424 let part_value = part_parts[1].parse::<f32>()?;
425
426 if part_index > row.len() && !features.is_empty() {
427 bail!("Encountered a value at index {part_index} greater than expected size {} on line {row_number}", data.len());
428 }
429
430 if row.is_empty() {
431 row = vec![0.0; part_index + 1];
432 } else if part_index >= row.len() {
433 row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
434 }
435 row[part_index] = part_value;
436 }
437
438 data.push(row);
439 labels.push(label);
440 }
441
442 let data_len = data[0].len();
443 for row in &data {
444 if row.len() != data_len {
445 bail!(
446 "Encountered a row with length {} but expected length {data_len}",
447 row.len()
448 );
449 }
450 }
451
452 ensure!(
453 features.len() == data[0].len(),
454 "Features need to be empty or the same size as the data length."
455 );
456
457 Ok(Self {
458 data,
459 labels,
460 features,
461 })
462 }
463
464 pub fn create_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
471 malicious_dir: P,
472 benign_dir: P,
473 ngrams_file: P,
474 ) -> Result<Self> {
475 let ngram_contents = std::fs::read_to_string(&ngrams_file)?;
476 let mut n = 0;
477 let ngrams = ngram_contents
478 .lines()
479 .filter_map(|l| {
480 let line = if let Some(l) = l.split(',').collect::<Vec<&str>>().first() {
481 l
482 } else {
483 l
484 };
485 if !line.len().is_multiple_of(2) {
486 eprintln!("Line {line} has odd number of characters.");
487 return None;
488 }
489 if n == 0 {
490 n = line.len() / 2;
491 } else if line.len() / 2 != n {
492 eprintln!(
493 "Line {line} has unexpected length of {} bytes, expected {n}",
494 line.len() / 2
495 );
496 return None;
497 }
498 hex::decode(line).ok()
499 })
500 .collect::<Vec<_>>();
501
502 ensure!(
503 !ngrams.is_empty(),
504 "No n-grams read from {}.",
505 ngrams_file.as_ref().display()
506 );
507
508 let mut paths_labels = Vec::new();
509 for entry in WalkDir::new(malicious_dir)
510 .max_depth(crate::MAX_RECURSION_DEPTH)
511 .follow_links(true)
512 .into_iter()
513 .flatten()
514 {
515 if entry.file_type().is_file() {
516 paths_labels.push((entry, 1.0));
517 }
518 }
519
520 for entry in WalkDir::new(benign_dir)
521 .max_depth(crate::MAX_RECURSION_DEPTH)
522 .follow_links(true)
523 .into_iter()
524 .flatten()
525 {
526 if entry.file_type().is_file() {
527 paths_labels.push((entry, 0.0));
528 }
529 }
530
531 let found_files = paths_labels.len();
532 let dataset = Dataset::default();
533 let dataset_lock = RwLock::new(dataset);
534 paths_labels.into_par_iter().for_each(|(path, label)| {
535 match featurize_file(path.path(), n, &ngrams) {
536 Ok(features) => {
537 if let Ok(mut data) = dataset_lock.write() {
538 data.data.push(features);
539 data.labels.push(label);
540 }
541 }
542 Err(e) => eprintln!("Failed to featurized {}: {e}", path.path().display()),
543 }
544 });
545
546 let mut dataset = dataset_lock.into_inner()?;
547 dataset.features = ngrams;
548
549 if dataset.data.len() != found_files {
550 eprintln!(
551 "Warning: found {found_files} but only have features for {} files.",
552 dataset.data.len()
553 );
554 }
555
556 Ok(dataset)
557 }
558
559 pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
565 let mut file = std::fs::File::create(path)?;
566
567 let feature_string_vec = self
568 .features
569 .iter()
570 .map(hex::encode)
571 .collect::<Vec<String>>();
572 let features_string = format!("# {FEATURES_PREFIX} {}\n", feature_string_vec.join(", "));
573 file.write_all(features_string.as_bytes())?;
574
575 for index in 0..self.data.len() {
576 let mut line = self.data[index]
577 .iter()
578 .map(|p| format!("{p}"))
579 .collect::<Vec<String>>()
580 .join(",");
581
582 if !self.labels.is_empty() {
583 if self.labels[index] > 0.9 {
584 line.push_str(",1");
585 } else {
586 line.push_str(",0");
587 }
588 }
589 line.push('\n');
590
591 file.write_all(line.as_bytes())?;
592 }
593
594 file.sync_all().map_err(Into::into)
595 }
596
597 pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
603 let mut file = std::fs::File::create(path)?;
604
605 for feature in &self.features {
606 let feature_hex = hex::encode(feature);
607 file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
608 }
609
610 if !self.labels.is_empty() {
611 file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
612 }
613
614 file.write_all("\n@DATA\n".as_bytes())?;
615 for index in 0..self.data.len() {
616 let mut line = self.data[index]
617 .iter()
618 .map(|p| format!("{p}"))
619 .collect::<Vec<String>>()
620 .join(",");
621
622 if !self.labels.is_empty() {
623 if self.labels[index] > 0.9 {
624 line.push_str(",1");
625 } else {
626 line.push_str(",0");
627 }
628 }
629 line.push('\n');
630
631 file.write_all(line.as_bytes())?;
632 }
633
634 file.sync_all().map_err(Into::into)
635 }
636
637 pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
643 ensure!(
644 !self.labels.is_empty(),
645 "Labels are required to create an libsvm file."
646 );
647 let mut file = std::fs::File::create(path)?;
648
649 let feature_string_vec = self
650 .features
651 .iter()
652 .map(hex::encode)
653 .collect::<Vec<String>>();
654 let features_string = format!("# {FEATURES_PREFIX} {}\n", feature_string_vec.join(", "));
655 file.write_all(features_string.as_bytes())?;
656
657 for index in 0..self.data.len() {
658 file.write_all(format!("{}", self.labels[index]).as_bytes())?;
659 for (data_index, data) in self.data[index].iter().enumerate() {
660 if *data != 0.0000 {
661 file.write_all(format!(" {data_index}:{data}").as_bytes())?;
662 }
663 }
664
665 file.write_all(b"\n")?;
666 }
667
668 file.sync_all().map_err(Into::into)
669 }
670
671 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
677 if let Some(extension) = path.as_ref().extension() {
678 return match extension.to_str().unwrap_or_default() {
679 "arff" => self.save_arff(path),
680 "csv" => self.save_csv(path),
681 "svm" | "libsvm" => self.save_libsvm(path),
682 "json" => {
683 let contents = serde_json::to_string_pretty(self)?;
684 let mut file = std::fs::File::create(path)?;
685 file.write_all(contents.as_bytes())?;
686 file.sync_all().map_err(Into::into)
687 }
688 "toml" => {
689 let contents = toml::to_string_pretty(self)?;
690 let mut file = std::fs::File::create(path)?;
691 file.write_all(contents.as_bytes())?;
692 file.sync_all().map_err(Into::into)
693 }
694 ext => {
695 bail!("Unsupported/unknown data type '{ext}'");
696 }
697 };
698 }
699
700 bail!("No extension, can't determine file type.");
701 }
702
703 #[inline]
705 #[must_use]
706 pub fn len(&self) -> usize {
707 self.data.len()
708 }
709
710 #[inline]
712 #[must_use]
713 pub fn is_empty(&self) -> bool {
714 self.data.is_empty()
715 }
716
717 #[inline]
721 #[must_use]
722 pub fn validate(&self) -> bool {
723 let data_len = match self.data.first() {
724 Some(first) => first.len(),
725 None => return false,
726 };
727
728 for record in &self.data {
730 if record.len() != data_len {
731 #[cfg(debug_assertions)]
732 eprint!("Expected record size {data_len}, got {}", record.len());
733 return false;
734 }
735 }
736
737 let feature_len = if let Some(first) = self.features.first() {
738 first.len()
739 } else {
740 #[cfg(debug_assertions)]
741 eprintln!("Features data is missing");
742 return false;
743 };
744
745 for feature in &self.features {
746 if feature.len() != feature_len {
747 #[cfg(debug_assertions)]
748 eprint!("Expected feature size {feature_len}, got {}", feature.len());
749 return false;
750 }
751 }
752
753 (self.labels.is_empty() || self.labels.len() == self.data.len())
755 && self.features.len() == data_len
756 }
757
758 pub fn shuffle(&mut self) {
761 if !self.is_empty() {
763 let iterations = self.data.len().ilog10() * 10;
764 self.shuffle_iterations(iterations);
765 }
766 }
767
768 pub fn shuffle_iterations(&mut self, iterations: u32) {
771 use rand::Rng;
772
773 if !self.is_empty() {
774 let mut rng = rand::rng();
775
776 for _ in 0..iterations {
777 let a = rng.random_range(0..self.data.len());
778 let b = rng.random_range(0..self.data.len());
779 let b = if b == a {
780 rng.random_range(0..self.data.len())
781 } else {
782 b
783 };
784
785 self.data.swap(a, b);
786 if !self.labels.is_empty() {
787 self.labels.swap(a, b);
788 }
789 }
790 }
791 }
792
793 #[must_use]
796 #[allow(
797 clippy::cast_sign_loss,
798 clippy::cast_possible_truncation,
799 clippy::cast_precision_loss
800 )]
801 pub fn split(&mut self, ratio: f32) -> Self {
802 let ratio = ratio.abs();
803 let ratio = if ratio > 1.0 { 1.0 - ratio } else { ratio };
804 let new_size = (self.data.len() as f32 * ratio).ceil() as usize;
805
806 let new_data = self.data.drain(new_size..).collect();
807 let new_labels = if self.labels.is_empty() {
808 vec![]
809 } else {
810 self.labels.drain(new_size..).collect()
811 };
812
813 Self {
814 data: new_data,
815 labels: new_labels,
816 features: self.features.clone(),
817 }
818 }
819}
820
/// Round-trip and manipulation tests against the XOR fixtures in `testdata/`.
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;

    /// Per-process path in the system temp directory, so test output never lands in the
    /// working tree and concurrent test runs don't overwrite each other's files.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{name}", std::process::id()))
    }

    #[test]
    fn xor() {
        let csv_dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        assert!(csv_dataset.validate());

        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());

        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());

        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    #[test]
    fn xor_no_label() {
        assert!(Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6).is_err());
        assert!(Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm")).is_err());
    }

    #[test]
    fn shuffle() {
        let original_dataset =
            Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.shuffle();

        assert_eq!(original_dataset, dataset);
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    #[test]
    fn split() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);

        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    #[test]
    fn save() {
        let copy_csv = temp_path("xor_copy.csv");
        let copy_arff = temp_path("xor_copy.arff");
        let copy_svm = temp_path("xor_copy.svm");

        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.save_csv(&copy_csv).unwrap();
        dataset.save_arff(&copy_arff).unwrap();
        dataset.save_libsvm(&copy_svm).unwrap();

        let dataset2 = Dataset::from_csv_file(&copy_csv, 6).unwrap();
        assert_eq!(dataset, dataset2);

        let dataset3 = Dataset::from_arff_file(&copy_arff).unwrap();
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);

        let dataset4 = Dataset::from_libsvm_file(&copy_svm).unwrap();
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);

        std::fs::remove_file(&copy_csv).unwrap();
        std::fs::remove_file(&copy_arff).unwrap();
        std::fs::remove_file(&copy_svm).unwrap();
    }
}