1use crate::ftype::FileType;
4use crate::model::LogisticRegression;
5use crate::ngram::NgramsFile;
6use crate::Bytes;
7
8use std::collections::HashMap;
9use std::io::{Read, Seek, SeekFrom, Write};
10use std::path::Path;
11use std::str::FromStr;
12
13use anyhow::{anyhow, bail, ensure, Result};
14use serde::de::IntoDeserializer;
15use serde::{Deserialize, Serialize};
16use walkdir::WalkDir;
17
/// First bytes (`#` and `%`) that mark a line as a comment in the text dataset formats.
const COMMENT_PREFIXES: [u8; 2] = [b'#', b'%'];
/// Marker in a comment line that is followed by the comma-separated, hex-encoded feature list.
const FEATURES_PREFIX: &str = "Features:";
/// Marker in a comment line that is followed by the file type name.
const FILE_TYPE_PREFIX: &str = "File type:";
21
22#[inline]
28pub(crate) fn featurize_file<P: AsRef<Path>, S: ::std::hash::BuildHasher>(
29 file: P,
30 n: usize,
31 features: &HashMap<Bytes, usize, S>,
32) -> Result<Vec<f32>> {
33 let file_size = std::fs::metadata(&file)?.len();
34 ensure!(
35 file_size > n as u64,
36 "File {} is too small.",
37 file.as_ref().display()
38 );
39
40 let mut feature_vector = vec![0.0; features.len()];
41
42 if file_size < 10_485_760u64
43 {
45 let contents = std::fs::read(file)?;
46 for window in contents.windows(n) {
47 if let Some(index) = features.get(window) {
48 feature_vector[*index] = 1.0;
49 }
50 }
51 } else {
52 let mut file = std::fs::File::open(file)?;
53 let mut buffer = [0u8; crate::ngram::NGRAM_BUFFER_SIZE];
54 while let Ok(bytes_read) = file.read(&mut buffer) {
55 if bytes_read < n {
56 break;
57 }
58 for index in 0..bytes_read - n {
59 if let Some(index) = features.get(&buffer[index..index + n]) {
60 feature_vector[*index] = 1.0;
61 }
62 }
63
64 #[allow(clippy::cast_possible_wrap)]
68 file.seek(SeekFrom::Current(n as i64 - 1))?;
69 }
70 }
71
72 Ok(feature_vector)
73}
74
/// Text formats a dataset can be exported to; parsed from the names `arff`, `csv` and `svm`.
#[derive(Copy, Clone, Deserialize, Serialize, Hash, Eq, PartialEq)]
pub enum DatasetFormat {
    /// ARFF: `@ATTRIBUTE` declarations followed by an `@DATA` section of comma-separated rows.
    ARFF,

    /// Comma-separated values, with features and file type carried in comment lines.
    CSV,

    /// libsvm-style sparse rows: a label followed by `index:value` pairs.
    SVM,
}
87
88impl FromStr for DatasetFormat {
89 type Err = anyhow::Error;
90
91 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
92 match value.to_lowercase().as_str() {
93 "arff" => Ok(Self::ARFF),
94 "csv" => Ok(Self::CSV),
95 "svm" => Ok(Self::SVM),
96 x => Err(anyhow!("Unknown data format '{x}'")),
97 }
98 }
99}
100
101impl TryFrom<&Path> for DatasetFormat {
102 type Error = anyhow::Error;
103
104 fn try_from(value: &Path) -> std::result::Result<Self, Self::Error> {
105 if let Some(extension) = value.extension() {
106 let ext = extension
107 .to_str()
108 .ok_or_else(|| anyhow!("Failed to get extension."))?;
109 DatasetFormat::from_str(ext)
110 } else {
111 Err(anyhow!("No extension, can't determine file type."))
112 }
113 }
114}
115
/// Feature rows with optional labels, the byte sequences naming each column, and the file
/// type the data belongs to.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Dataset {
    /// One vector of feature values per record.
    pub data: Vec<Vec<f32>>,

    /// One label per record; may be empty, and defaults to empty when missing on deserialization.
    #[serde(default)]
    pub labels: Vec<f32>,

    /// Byte sequences identifying the columns of `data`; (de)serialized through the crate's
    /// `serialize_hex_vec` / `deserialize_hex_vec` helpers.
    #[serde(
        serialize_with = "crate::serde::serialize_hex_vec",
        deserialize_with = "crate::serde::deserialize_hex_vec"
    )]
    pub features: Vec<Bytes>,

    /// File type of the dataset; the text parsers in this file reject `FileType::NotSet`.
    pub ftype: FileType,
}
136
137impl PartialEq for Dataset {
138 fn eq(&self, other: &Self) -> bool {
139 if self.data.len() != other.data.len()
140 || self.labels.len() != other.labels.len()
141 || self.features.len() != other.features.len()
142 || self.ftype != other.ftype
143 {
144 return false;
145 }
146
147 for this_data in &self.data {
148 if !other.data.contains(this_data) {
149 return false;
150 }
151 }
152
153 for other_data in &other.data {
154 if !self.data.contains(other_data) {
155 return false;
156 }
157 }
158
159 if !self.labels.is_empty() {
160 for this_label in &self.labels {
161 if !other.labels.contains(this_label) {
162 return false;
163 }
164 }
165
166 for other_label in &other.labels {
167 if !self.labels.contains(other_label) {
168 return false;
169 }
170 }
171 }
172
173 for this_features in &self.features {
174 if !other.features.contains(this_features) {
175 return false;
176 }
177 }
178
179 for other_feature in &other.features {
180 if !self.features.contains(other_feature) {
181 return false;
182 }
183 }
184
185 true
186 }
187}
188
189impl Dataset {
190 pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
197 if let Some(extension) = path.as_ref().extension() {
198 return match extension.to_str().unwrap_or_default() {
199 "arff" => Dataset::from_arff_file(path.as_ref()),
200 "csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
201 "svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
202 "json" => {
203 let contents = std::fs::read_to_string(path.as_ref())?;
204 serde_json::from_str(&contents).map_err(Into::into)
205 }
206 "toml" => {
207 let contents = std::fs::read_to_string(path.as_ref())?;
208 toml::from_str(&contents).map_err(Into::into)
209 }
210 ext => {
211 bail!("Unsupported/unknown data type '{ext}'");
212 }
213 };
214 }
215
216 bail!("No extension, can't determine file type.");
217 }
218
219 pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
229 let mut file = std::fs::File::open(path)?;
230 let mut contents = String::new();
231 file.read_to_string(&mut contents)?;
232
233 Self::from_csv_string(&contents, data_length)
234 }
235
236 pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
246 let mut file = std::fs::File::open(path)?;
247 let mut contents = String::new();
248 file.read_to_string(&mut contents)?;
249
250 let mut length = 0;
251 for line in contents.lines() {
252 if line.is_empty() || COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
253 continue;
254 }
255
256 length = line.split(',').collect::<Vec<&str>>().len();
257 break;
258 }
259
260 ensure!(length > 0, "Failed to determine data length.");
261 Self::from_csv_string(&contents, length - 1)
262 }
263
264 pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
273 let mut data: Vec<Vec<f32>> = Vec::new();
274 let mut labels = Vec::new();
275 let mut features = Vec::new();
276 let mut file_type = FileType::NotSet;
277
278 for (row_number, line) in contents.lines().enumerate() {
279 if line.is_empty() {
280 continue;
281 }
282
283 if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
284 if line.contains(FEATURES_PREFIX) {
285 let offset =
286 line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
287 let line = line[offset..].trim();
288 features = line
289 .split(',')
290 .filter_map(|f| hex::decode(f.trim()).ok())
291 .collect();
292 }
293
294 if line.contains(FILE_TYPE_PREFIX) {
295 if let Some(file_type_str) = line.split(':').nth(1) {
296 let file_type_str = file_type_str.trim();
297 let ftype: Result<_, serde::de::value::Error> =
298 FileType::deserialize(String::from(file_type_str).into_deserializer());
299 file_type = ftype?;
300 }
301 }
302 }
303
304 if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
305 continue;
306 }
307 let row = line.split(',').collect::<Vec<&str>>();
308 let mut row_float = Vec::with_capacity(data_length);
309 for r in row.iter().take(data_length) {
310 row_float.push(r.parse::<f32>().map_err(|_| {
311 anyhow::Error::msg(format!("Non-float {r} encountered on CSV row {row_number}"))
312 })?);
313 }
314 if let Some(first_row) = data.first() {
315 ensure!(
316 first_row.len() == row_float.len(),
317 "CSV line {row_number} has invalid length {}, expected {}",
318 row_float.len(),
319 first_row.len()
320 );
321 }
322 data.push(row_float);
323 if row.len() == data_length + 1 {
324 let l = row[data_length].parse::<f32>().map_err(|_| {
325 anyhow::Error::msg(format!(
326 "Non-float label {} encountered on CSV row {row_number}",
327 row[data_length]
328 ))
329 })?;
330 labels.push(l);
331 } else if row.len() > data_length {
332 bail!(
333 "CSV row had more than one label on row {row_number}, which isn't supported."
334 );
335 }
336 }
337
338 ensure!(
339 features.len() == data[0].len(),
340 "Features need to be empty or the same size as the data length."
341 );
342
343 ensure!(
344 file_type != FileType::NotSet,
345 "No file type specified in CSV file."
346 );
347 Ok(Self {
348 data,
349 labels,
350 features,
351 ftype: file_type,
352 })
353 }
354
355 #[inline]
357 pub(crate) fn file_type_from_line(line: &str) -> Result<FileType, serde::de::value::Error> {
358 let line = line.split(':').nth(1).unwrap_or_default();
359 let ftype: Result<_, serde::de::value::Error> =
360 FileType::deserialize(String::from(line.trim()).into_deserializer());
361 ftype
362 }
363
364 pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
374 let mut file = std::fs::File::open(path)?;
375 let mut contents = String::new();
376 file.read_to_string(&mut contents)?;
377
378 Self::from_arff_string(&contents)
379 }
380
381 pub fn from_arff_string(contents: &str) -> Result<Self> {
390 let mut data: Vec<Vec<f32>> = Vec::new();
391 let mut labels = Vec::new();
392 let mut features = Vec::new();
393 let mut file_type = FileType::NotSet;
394 let mut passed_data = false;
395
396 for (row_number, line) in contents.lines().enumerate() {
397 if line.is_empty() {
398 continue;
399 }
400
401 if (line.starts_with('%') || line.starts_with('#')) && line.contains(FILE_TYPE_PREFIX) {
402 file_type = Self::file_type_from_line(line)?;
403 continue;
404 }
405
406 if line.contains("@ATTRIBUTE") {
407 let parts: Vec<&str> = line.split_ascii_whitespace().collect();
408 if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
409 match hex::decode(parts[1]) {
410 Ok(feat) => features.push(feat),
411 Err(e) => {
412 bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
413 }
414 }
415 }
416 }
417
418 if line.contains("@DATA") {
419 passed_data = true;
420 continue;
421 }
422
423 if passed_data {
425 let row = line.split(',').collect::<Vec<&str>>();
426 let data_length = row.len() - 1;
427 let mut row_float = Vec::with_capacity(data_length);
428 for r in row.iter().take(data_length) {
429 row_float.push(r.parse::<f32>().map_err(|_| {
430 anyhow::Error::msg(format!(
431 "Non-float encountered on ARFF row {row_number}"
432 ))
433 })?);
434 }
435 if let Some(first_row) = data.first() {
436 ensure!(
437 first_row.len() == row_float.len(),
438 "ARFF line {row_number} has invalid length {}, expected {}",
439 row_float.len(),
440 first_row.len()
441 );
442 }
443 data.push(row_float);
444 if row.len() == data_length + 1 {
445 let l = row[data_length].parse::<f32>().map_err(|_| {
446 anyhow::Error::msg(format!(
447 "Non-float encountered on ARFF row {row_number}"
448 ))
449 })?;
450 labels.push(l);
451 } else if row.len() > data_length {
452 bail!("Arff row had more than one label on row {row_number}, which isn't supported.");
453 }
454 }
455 }
456
457 ensure!(
458 features.len() == data[0].len(),
459 "Features need to be empty or the same size as the data length."
460 );
461
462 ensure!(
463 file_type != FileType::NotSet,
464 "No file type specified in ARFF file."
465 );
466 Ok(Self {
467 data,
468 labels,
469 features,
470 ftype: file_type,
471 })
472 }
473
474 pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
484 let mut file = std::fs::File::open(path)?;
485 let mut contents = String::new();
486 file.read_to_string(&mut contents)?;
487
488 Self::from_libsvm_string(&contents)
489 }
490
491 pub fn from_libsvm_string(contents: &str) -> Result<Self> {
497 let mut data = Vec::new();
498 let mut labels = Vec::new();
499 let mut features = Vec::new();
500 let mut file_type = FileType::NotSet;
501
502 for (row_number, line) in contents.lines().enumerate() {
503 if line.is_empty() {
504 continue;
505 }
506
507 if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
508 if line.contains(FEATURES_PREFIX) {
509 let offset =
510 line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
511 let line = line[offset..].trim();
512 features = line
513 .split(',')
514 .filter_map(|f| hex::decode(f.trim()).ok())
515 .collect();
516 }
517
518 if line.contains(FILE_TYPE_PREFIX) {
519 file_type = Self::file_type_from_line(line)?;
520 }
521 }
522
523 if line.is_empty() || line.starts_with('%') || line.starts_with('#') {
524 continue;
525 }
526
527 let parts = line.split_whitespace().collect::<Vec<&str>>();
528 let Ok(label) = parts[0].trim().parse::<f32>() else {
529 bail!(
530 "Encountered a non-numeric label {} on line {row_number}",
531 parts[0]
532 );
533 };
534 let mut row = vec![0.0f32; features.len()];
535
536 for part in parts.iter().skip(1) {
537 let part_parts = part.split(':').collect::<Vec<&str>>();
538 let Ok(part_index) = part_parts[0].trim().parse::<usize>() else {
539 bail!(
540 "Encountered a non-numeric index {} on line {row_number}",
541 part_parts[0]
542 );
543 };
544 let Ok(part_value) = part_parts[1].trim().parse::<f32>() else {
545 bail!(
546 "Encountered a non-numeric value {} on line {row_number}",
547 part_parts[1]
548 );
549 };
550
551 if part_index > row.len() && !features.is_empty() {
552 bail!("Encountered a value at index {part_index} greater than expected size {} on line {row_number}", data.len());
553 }
554
555 if row.is_empty() {
556 row = vec![0.0; part_index + 1];
557 } else if part_index >= row.len() {
558 row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
559 }
560 row[part_index] = part_value;
561 }
562
563 data.push(row);
564 labels.push(label);
565 }
566
567 let data_len = data[0].len();
568 for row in &data {
569 if row.len() != data_len {
570 bail!(
571 "Encountered a row with length {} but expected length {data_len}",
572 row.len()
573 );
574 }
575 }
576
577 ensure!(
578 features.len() == data[0].len(),
579 "Features need to be empty or the same size as the data length."
580 );
581
582 ensure!(
583 file_type != FileType::NotSet,
584 "No file type specified in libsvm file."
585 );
586 Ok(Self {
587 data,
588 labels,
589 features,
590 ftype: file_type,
591 })
592 }
593
594 #[allow(clippy::too_many_lines)]
601 pub fn create_save_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
602 malicious_dir: P,
603 benign_dir: P,
604 ngrams_file: P,
605 output_file: P,
606 ) -> Result<()> {
607 const SUPPORTED_FORMATS: [DatasetFormat; 3] =
608 [DatasetFormat::CSV, DatasetFormat::ARFF, DatasetFormat::SVM];
609
610 let output_format = DatasetFormat::try_from(output_file.as_ref())?;
611 ensure!(
612 SUPPORTED_FORMATS.contains(&output_format),
613 "Only CSV, ARFF, or SVM formats are supported here."
614 );
615
616 let ngrams = NgramsFile::load(ngrams_file)?;
617 let mut output_file = std::fs::File::create(output_file)?;
618 writeln!(output_file, "# {FILE_TYPE_PREFIX} {:?}", ngrams.ftype)?;
619
620 match output_format {
621 DatasetFormat::SVM | DatasetFormat::CSV => {
622 let feature_string_vec = ngrams
623 .clone()
624 .into_vec()
625 .iter()
626 .map(hex::encode)
627 .collect::<Vec<String>>();
628 writeln!(
629 output_file,
630 "# {FEATURES_PREFIX} {}",
631 feature_string_vec.join(", ")
632 )?;
633 }
634
635 DatasetFormat::ARFF => {
636 let feature_string_vec = ngrams
637 .clone()
638 .into_vec()
639 .iter()
640 .map(hex::encode)
641 .collect::<Vec<String>>();
642 for feature in feature_string_vec {
643 let feature_hex = hex::encode(feature);
644 writeln!(output_file, "@ATTRIBUTE {feature_hex} NUMERIC")?;
645 }
646 }
647 }
648
649 for entry in WalkDir::new(malicious_dir)
650 .max_depth(crate::MAX_RECURSION_DEPTH)
651 .follow_links(true)
652 .into_iter()
653 .flatten()
654 {
655 if entry.file_type().is_file() {
656 match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
657 Ok(features) => match output_format {
658 DatasetFormat::CSV | DatasetFormat::ARFF => {
659 let mut line = features
660 .iter()
661 .map(|p| format!("{p}"))
662 .collect::<Vec<String>>()
663 .join(",");
664 line.push_str(",1\n");
665 output_file.write_all(line.as_bytes())?;
666 }
667
668 DatasetFormat::SVM => {
669 write!(output_file, "1")?;
670 for (data_index, data) in features.iter().enumerate() {
671 if *data != 0.0000 {
672 write!(output_file, " {data_index}:{data}")?;
673 }
674 }
675 writeln!(output_file)?;
676 }
677 },
678 Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
679 }
680 }
681 }
682
683 for entry in WalkDir::new(benign_dir)
684 .max_depth(crate::MAX_RECURSION_DEPTH)
685 .follow_links(true)
686 .into_iter()
687 .flatten()
688 {
689 if entry.file_type().is_file() {
690 match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
691 Ok(features) => match output_format {
692 DatasetFormat::CSV | DatasetFormat::ARFF => {
693 let mut line = features
694 .iter()
695 .map(|p| format!("{p}"))
696 .collect::<Vec<String>>()
697 .join(",");
698 line.push_str(",0\n");
699 output_file.write_all(line.as_bytes())?;
700 }
701
702 DatasetFormat::SVM => {
703 write!(output_file, "0")?;
704 for (data_index, data) in features.iter().enumerate() {
705 if *data != 0.0000 {
706 write!(output_file, " {data_index}:{data}")?;
707 }
708 }
709 writeln!(output_file)?;
710 }
711 },
712 Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
713 }
714 }
715 }
716
717 output_file.sync_all()?;
718 Ok(())
719 }
720
721 pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
727 let mut file = std::fs::File::create(path)?;
728
729 let feature_string_vec = self
730 .features
731 .iter()
732 .map(hex::encode)
733 .collect::<Vec<String>>();
734 writeln!(
735 file,
736 "# {FEATURES_PREFIX} {}",
737 feature_string_vec.join(", ")
738 )?;
739 writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
740
741 for index in 0..self.data.len() {
742 let mut line = self.data[index]
743 .iter()
744 .map(|p| format!("{p}"))
745 .collect::<Vec<String>>()
746 .join(",");
747
748 if !self.labels.is_empty() {
749 if self.labels[index] > 0.9 {
750 line.push_str(",1");
751 } else {
752 line.push_str(",0");
753 }
754 }
755 line.push('\n');
756
757 file.write_all(line.as_bytes())?;
758 }
759
760 file.sync_all().map_err(Into::into)
761 }
762
763 pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
769 let mut file = std::fs::File::create(path)?;
770 writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
771
772 for feature in &self.features {
773 let feature_hex = hex::encode(feature);
774 file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
775 }
776
777 if !self.labels.is_empty() {
778 file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
779 }
780
781 file.write_all("\n@DATA\n".as_bytes())?;
782 for index in 0..self.data.len() {
783 let mut line = self.data[index]
784 .iter()
785 .map(|p| format!("{p}"))
786 .collect::<Vec<String>>()
787 .join(",");
788
789 if !self.labels.is_empty() {
790 if self.labels[index] > 0.9 {
791 line.push_str(",1");
792 } else {
793 line.push_str(",0");
794 }
795 }
796 line.push('\n');
797
798 file.write_all(line.as_bytes())?;
799 }
800
801 file.sync_all().map_err(Into::into)
802 }
803
804 pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
810 ensure!(
811 !self.labels.is_empty(),
812 "Labels are required to create an libsvm file."
813 );
814 let mut file = std::fs::File::create(path)?;
815
816 let feature_string_vec = self
817 .features
818 .iter()
819 .map(hex::encode)
820 .collect::<Vec<String>>();
821 writeln!(
822 file,
823 "# {FEATURES_PREFIX} {}",
824 feature_string_vec.join(", ")
825 )?;
826 writeln!(file, "# {FILE_TYPE_PREFIX} {:?}", self.ftype)?;
827
828 for index in 0..self.data.len() {
829 file.write_all(format!("{}", self.labels[index]).as_bytes())?;
830 for (data_index, data) in self.data[index].iter().enumerate() {
831 if *data != 0.0000 {
832 file.write_all(format!(" {data_index}:{data}").as_bytes())?;
833 }
834 }
835
836 file.write_all(b"\n")?;
837 }
838
839 file.sync_all().map_err(Into::into)
840 }
841
842 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
848 if let Some(extension) = path.as_ref().extension() {
849 return match extension.to_str().unwrap_or_default() {
850 "arff" => self.save_arff(path),
851 "csv" => self.save_csv(path),
852 "svm" | "libsvm" => self.save_libsvm(path),
853 "json" => {
854 let contents = serde_json::to_string_pretty(self)?;
855 let mut file = std::fs::File::create(path)?;
856 file.write_all(contents.as_bytes())?;
857 file.sync_all().map_err(Into::into)
858 }
859 "toml" => {
860 let contents = toml::to_string_pretty(self)?;
861 let mut file = std::fs::File::create(path)?;
862 file.write_all(contents.as_bytes())?;
863 file.sync_all().map_err(Into::into)
864 }
865 ext => {
866 bail!("Unsupported/unknown data type '{ext}'");
867 }
868 };
869 }
870
871 bail!("No extension, can't determine file type.");
872 }
873
    /// Returns the number of records (rows of `data`) in the dataset.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.data.len()
    }
880
    /// Returns `true` when the dataset holds no records; labels and features are not considered.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
887
888 #[inline]
892 #[must_use]
893 pub fn validate(&self) -> bool {
894 let data_len = match self.data.first() {
895 Some(first) => first.len(),
896 None => return false,
897 };
898
899 for record in &self.data {
901 if record.len() != data_len {
902 #[cfg(debug_assertions)]
903 eprint!("Expected record size {data_len}, got {}", record.len());
904 return false;
905 }
906 }
907
908 let feature_len = if let Some(first) = self.features.first() {
909 first.len()
910 } else {
911 #[cfg(debug_assertions)]
912 eprintln!("Features data is missing");
913 return false;
914 };
915
916 for feature in &self.features {
917 if feature.len() != feature_len {
918 #[cfg(debug_assertions)]
919 eprint!("Expected feature size {feature_len}, got {}", feature.len());
920 return false;
921 }
922 }
923
924 (self.labels.is_empty() || self.labels.len() == self.data.len())
926 && self.features.len() == data_len
927 && self.ftype != FileType::NotSet
928 }
929
930 pub fn shuffle(&mut self) {
933 if !self.is_empty() {
935 let iterations = self.data.len().ilog10() * 10;
936 self.shuffle_iterations(iterations);
937 }
938 }
939
940 pub fn shuffle_iterations(&mut self, iterations: u32) {
943 use rand::Rng;
944
945 if !self.is_empty() {
946 let mut rng = rand::rng();
947
948 for _ in 0..iterations {
949 let a = rng.random_range(0..self.data.len());
950 let b = rng.random_range(0..self.data.len());
951 let b = if b == a {
952 rng.random_range(0..self.data.len())
953 } else {
954 b
955 };
956
957 self.data.swap(a, b);
958 if !self.labels.is_empty() {
959 self.labels.swap(a, b);
960 }
961 }
962 }
963 }
964
965 #[must_use]
968 #[allow(
969 clippy::cast_sign_loss,
970 clippy::cast_possible_truncation,
971 clippy::cast_precision_loss
972 )]
973 pub fn split(&mut self, ratio: f32) -> Self {
974 let ratio = ratio.abs();
975 let ratio = if ratio > 1.0 { 1.0 - ratio } else { ratio };
976 let new_size = (self.data.len() as f32 * ratio).ceil() as usize;
977
978 let new_data = self.data.drain(new_size..).collect();
979 let new_labels = if self.labels.is_empty() {
980 vec![]
981 } else {
982 self.labels.drain(new_size..).collect()
983 };
984
985 Self {
986 data: new_data,
987 labels: new_labels,
988 features: self.features.clone(),
989 ftype: self.ftype,
990 }
991 }
992
993 pub fn reduce(&mut self, model: &LogisticRegression) -> Result<Vec<usize>> {
1002 let mut removed = vec![];
1003
1004 for (index, feature) in self.features.iter().enumerate() {
1005 if !model.features.contains_key(feature) {
1006 removed.push(index);
1007 }
1008 }
1009
1010 if removed.len() == self.data[0].len() {
1011 bail!("This dataset and model are probably not from the same data - this operation would delete all the data!");
1012 }
1013
1014 removed.sort_unstable();
1015 removed.reverse();
1016
1017 self.features
1018 .retain(|feature| model.features.contains_key(feature));
1019
1020 for row in &mut self.data {
1021 for removed in &removed {
1022 row.remove(*removed);
1023 }
1024 }
1025
1026 Ok(removed)
1027 }
1028}
1029
// Unit tests for `Dataset`; they rely on the XOR fixture files under `../testdata/`.
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;

    /// The same XOR data in CSV, ARFF and libsvm form must parse, validate and compare equal.
    #[test]
    fn xor() {
        let csv_dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        assert!(csv_dataset.validate());

        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());

        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());

        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    /// The `xor_no_label` fixtures must be rejected by the CSV and libsvm parsers.
    #[test]
    fn xor_no_label() {
        assert!(Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6).is_err());
        assert!(Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm")).is_err());
    }

    /// Shuffling changes the row and label order but not the (order-insensitive) equality
    /// or the feature list.
    #[test]
    fn shuffle() {
        let original_dataset =
            Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.shuffle();

        // `Dataset` equality ignores ordering, so this holds even after the shuffle.
        assert_eq!(original_dataset, dataset);
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    /// Splitting at 0.8 leaves the larger part in place, returns the smaller remainder, and
    /// loses no records.
    #[test]
    fn split() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);

        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    /// Round-trips the XOR dataset through the CSV, ARFF and libsvm writers and readers.
    #[test]
    fn save() {
        // Relative paths: the copies are written to the current working directory.
        const COPY_CSV: &str = "xor_copy.csv";
        const COPY_ARFF: &str = "xor_copy.arff";
        const COPY_SVM: &str = "xor_copy.svm";

        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.save_csv(COPY_CSV).unwrap();
        dataset.save_arff(COPY_ARFF).unwrap();
        dataset.save_libsvm(COPY_SVM).unwrap();

        let dataset2 = Dataset::from_csv_file(COPY_CSV, 6).unwrap();
        assert_eq!(dataset, dataset2);

        let dataset3 = Dataset::from_arff_file(COPY_ARFF).unwrap();
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);

        let dataset4 = Dataset::from_libsvm_file(COPY_SVM).unwrap();
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);

        // Clean up the temporary copies.
        std::fs::remove_file(COPY_CSV).unwrap();
        std::fs::remove_file(COPY_ARFF).unwrap();
        std::fs::remove_file(COPY_SVM).unwrap();
    }
}