1use crate::Bytes;
4use crate::ftype::FileType;
5use crate::model::LogisticRegression;
6use crate::ngram::NgramsFile;
7
8use std::collections::HashMap;
9use std::io::{Read, Seek, SeekFrom, Write};
10use std::path::Path;
11use std::str::FromStr;
12
13use anyhow::{Result, anyhow, bail, ensure};
14use serde::de::IntoDeserializer;
15use serde::{Deserialize, Serialize};
16use walkdir::WalkDir;
17
/// First bytes that mark a line as a comment in the CSV, ARFF and libsvm readers.
const COMMENT_PREFIXES: [u8; 2] = [b'#', b'%'];
/// Marker inside a comment line that introduces the comma-separated, hex-encoded feature list.
const FEATURES_PREFIX: &str = "Features:";
/// Marker inside a comment line that introduces the dataset's file type.
const FILE_TYPE_PREFIX: &str = "File type:";
21
22#[inline]
28pub(crate) fn featurize_file<P: AsRef<Path>, S: ::std::hash::BuildHasher>(
29 file: P,
30 n: usize,
31 features: &HashMap<Bytes, usize, S>,
32) -> Result<Vec<f32>> {
33 let file_size = std::fs::metadata(&file)?.len();
34 ensure!(
35 file_size > n as u64,
36 "File {} is too small.",
37 file.as_ref().display()
38 );
39
40 let mut feature_vector = vec![0.0; features.len()];
41
42 if file_size < 10_485_760u64
43 {
45 let contents = std::fs::read(file)?;
46 for window in contents.windows(n) {
47 if let Some(index) = features.get(window) {
48 feature_vector[*index] = 1.0;
49 }
50 }
51 } else {
52 let mut file = std::fs::File::open(file)?;
53 let mut buffer = [0u8; crate::ngram::NGRAM_BUFFER_SIZE];
54 while let Ok(bytes_read) = file.read(&mut buffer) {
55 if bytes_read < n {
56 break;
57 }
58 for index in 0..bytes_read - n {
59 if let Some(index) = features.get(&buffer[index..index + n]) {
60 feature_vector[*index] = 1.0;
61 }
62 }
63
64 #[allow(clippy::cast_possible_wrap)]
68 file.seek(SeekFrom::Current(n as i64 - 1))?;
69 }
70 }
71
72 Ok(feature_vector)
73}
74
75#[derive(Copy, Clone, Deserialize, Serialize, Hash, Eq, PartialEq)]
77pub enum DatasetFormat {
78 ARFF,
80
81 CSV,
83
84 SVM,
86}
87
88impl FromStr for DatasetFormat {
89 type Err = anyhow::Error;
90
91 fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
92 match value.to_lowercase().as_str() {
93 "arff" => Ok(Self::ARFF),
94 "csv" => Ok(Self::CSV),
95 "svm" => Ok(Self::SVM),
96 x => Err(anyhow!("Unknown data format '{x}'")),
97 }
98 }
99}
100
101impl TryFrom<&Path> for DatasetFormat {
102 type Error = anyhow::Error;
103
104 fn try_from(value: &Path) -> std::result::Result<Self, Self::Error> {
105 if let Some(extension) = value.extension() {
106 let ext = extension
107 .to_str()
108 .ok_or_else(|| anyhow!("Failed to get extension."))?;
109 DatasetFormat::from_str(ext)
110 } else {
111 Err(anyhow!("No extension, can't determine file type."))
112 }
113 }
114}
115
/// A feature matrix together with optional labels, the n-gram features that name its columns,
/// and the file type it describes.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Dataset {
    /// One record per row. `validate` expects every row to have one value per entry of
    /// `features`.
    pub data: Vec<Vec<f32>>,

    /// Per-record class labels. The vector is empty for an unlabelled dataset; otherwise it is
    /// expected to be the same length as `data`.
    #[serde(default)]
    pub labels: Vec<u8>,

    /// Raw n-gram bytes naming each column of `data`. Serialized through the crate's hex helpers
    /// (presumably as hex strings; see `crate::serde`).
    #[serde(
        serialize_with = "crate::serde::serialize_hex_vec",
        deserialize_with = "crate::serde::deserialize_hex_vec"
    )]
    pub features: Vec<Bytes>,

    /// File type this dataset was built for. `validate` rejects `FileType::NotSet`.
    pub ftype: FileType,
}
136
137impl PartialEq for Dataset {
138 fn eq(&self, other: &Self) -> bool {
139 if self.data.len() != other.data.len()
140 || self.labels.len() != other.labels.len()
141 || self.features.len() != other.features.len()
142 || self.ftype != other.ftype
143 {
144 return false;
145 }
146
147 for this_data in &self.data {
148 if !other.data.contains(this_data) {
149 return false;
150 }
151 }
152
153 for other_data in &other.data {
154 if !self.data.contains(other_data) {
155 return false;
156 }
157 }
158
159 if !self.labels.is_empty() {
160 for this_label in &self.labels {
161 if !other.labels.contains(this_label) {
162 return false;
163 }
164 }
165
166 for other_label in &other.labels {
167 if !self.labels.contains(other_label) {
168 return false;
169 }
170 }
171 }
172
173 for this_features in &self.features {
174 if !other.features.contains(this_features) {
175 return false;
176 }
177 }
178
179 for other_feature in &other.features {
180 if !self.features.contains(other_feature) {
181 return false;
182 }
183 }
184
185 true
186 }
187}
188
189impl Dataset {
190 pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
197 if let Some(extension) = path.as_ref().extension() {
198 return match extension.to_str().unwrap_or_default() {
199 "arff" => Dataset::from_arff_file(path.as_ref()),
200 "csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
201 "svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
202 "json" => {
203 let contents = std::fs::read_to_string(path.as_ref())?;
204 serde_json::from_str(&contents).map_err(Into::into)
205 }
206 "toml" => {
207 let contents = std::fs::read_to_string(path.as_ref())?;
208 toml::from_str(&contents).map_err(Into::into)
209 }
210 ext => {
211 bail!("Unsupported/unknown data type '{ext}'");
212 }
213 };
214 }
215
216 bail!("No extension, can't determine file type.");
217 }
218
219 pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
229 let mut file = std::fs::File::open(path)?;
230 let mut contents = String::new();
231 file.read_to_string(&mut contents)?;
232
233 Self::from_csv_string(&contents, data_length)
234 }
235
236 pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
246 let mut file = std::fs::File::open(path)?;
247 let mut contents = String::new();
248 file.read_to_string(&mut contents)?;
249
250 let mut length = 0;
251 for line in contents.lines() {
252 if line.is_empty() || COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
253 continue;
254 }
255
256 length = line.split(',').collect::<Vec<&str>>().len();
257 break;
258 }
259
260 ensure!(length > 0, "Failed to determine data length.");
261 Self::from_csv_string(&contents, length - 1)
262 }
263
264 pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
273 let mut data: Vec<Vec<f32>> = Vec::new();
274 let mut labels = Vec::new();
275 let mut features = Vec::new();
276 let mut file_type = FileType::NotSet;
277
278 for (row_number, line) in contents.lines().enumerate() {
279 if line.is_empty() {
280 continue;
281 }
282
283 if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
284 if line.contains(FEATURES_PREFIX) {
285 let offset =
286 line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
287 let line = line[offset..].trim();
288 features = line
289 .split(',')
290 .filter_map(|f| hex::decode(f.trim()).ok())
291 .collect();
292 }
293
294 if line.contains(FILE_TYPE_PREFIX)
295 && let Some(file_type_str) = line.split(':').nth(1)
296 {
297 let file_type_str = file_type_str.trim();
298 let ftype: Result<_, serde::de::value::Error> =
299 FileType::deserialize(String::from(file_type_str).into_deserializer());
300 file_type = ftype?;
301 }
302 }
303
304 if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
305 continue;
306 }
307 let row = line.split(',').collect::<Vec<&str>>();
308 let mut row_float = Vec::with_capacity(data_length);
309 for r in row.iter().take(data_length) {
310 row_float.push(r.parse::<f32>().map_err(|_| {
311 anyhow::Error::msg(format!("Non-float {r} encountered on CSV row {row_number}"))
312 })?);
313 }
314 if let Some(first_row) = data.first() {
315 ensure!(
316 first_row.len() == row_float.len(),
317 "CSV line {row_number} has invalid length {}, expected {}",
318 row_float.len(),
319 first_row.len()
320 );
321 }
322 data.push(row_float);
323 if row.len() == data_length + 1 {
324 let l = row[data_length].parse::<u8>().map_err(|_| {
325 anyhow::Error::msg(format!(
326 "Non-float label {} encountered on CSV row {row_number}",
327 row[data_length]
328 ))
329 })?;
330 labels.push(l);
331 } else if row.len() > data_length {
332 bail!(
333 "CSV row had more than one label on row {row_number}, which isn't supported."
334 );
335 }
336 }
337
338 ensure!(
339 features.len() == data[0].len(),
340 "Features need to be empty or the same size as the data length."
341 );
342
343 ensure!(
344 file_type != FileType::NotSet,
345 "No file type specified in CSV file."
346 );
347 Ok(Self {
348 data,
349 labels,
350 features,
351 ftype: file_type,
352 })
353 }
354
355 #[inline]
357 pub(crate) fn file_type_from_line(line: &str) -> Result<FileType, serde::de::value::Error> {
358 let line = line.split(':').nth(1).unwrap_or(line).to_uppercase();
359 let ftype: Result<_, serde::de::value::Error> =
360 FileType::deserialize(String::from(line.trim()).into_deserializer());
361 ftype
362 }
363
364 pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
374 let mut file = std::fs::File::open(path)?;
375 let mut contents = String::new();
376 file.read_to_string(&mut contents)?;
377
378 Self::from_arff_string(&contents)
379 }
380
381 pub fn from_arff_string(contents: &str) -> Result<Self> {
390 let mut data: Vec<Vec<f32>> = Vec::new();
391 let mut labels = Vec::new();
392 let mut features = Vec::new();
393 let mut file_type = FileType::NotSet;
394 let mut passed_data = false;
395
396 for (row_number, line) in contents.lines().enumerate() {
397 if line.is_empty() {
398 continue;
399 }
400
401 if (line.starts_with('%') || line.starts_with('#')) && line.contains(FILE_TYPE_PREFIX) {
402 file_type = Self::file_type_from_line(line)?;
403 continue;
404 }
405
406 if line.contains("@ATTRIBUTE") {
407 let parts: Vec<&str> = line.split_ascii_whitespace().collect();
408 if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
409 match hex::decode(parts[1]) {
410 Ok(feat) => features.push(feat),
411 Err(e) => {
412 bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
413 }
414 }
415 }
416 }
417
418 if line.contains("@DATA") {
419 passed_data = true;
420 continue;
421 }
422
423 if passed_data {
425 let row = line.split(',').collect::<Vec<&str>>();
426 let data_length = row.len() - 1;
427 let mut row_float = Vec::with_capacity(data_length);
428 for r in row.iter().take(data_length) {
429 row_float.push(r.parse::<f32>().map_err(|_| {
430 anyhow::Error::msg(format!(
431 "Non-float encountered on ARFF row {row_number}"
432 ))
433 })?);
434 }
435 if let Some(first_row) = data.first() {
436 ensure!(
437 first_row.len() == row_float.len(),
438 "ARFF line {row_number} has invalid length {}, expected {}",
439 row_float.len(),
440 first_row.len()
441 );
442 }
443 data.push(row_float);
444 if row.len() == data_length + 1 {
445 let l = row[data_length].parse::<u8>().map_err(|_| {
446 anyhow::Error::msg(format!(
447 "Non-float encountered on ARFF row {row_number}"
448 ))
449 })?;
450 labels.push(l);
451 } else if row.len() > data_length {
452 bail!(
453 "Arff row had more than one label on row {row_number}, which isn't supported."
454 );
455 }
456 }
457 }
458
459 ensure!(
460 features.len() == data[0].len(),
461 "Features need to be empty or the same size as the data length."
462 );
463
464 ensure!(
465 file_type != FileType::NotSet,
466 "No file type specified in ARFF file."
467 );
468 Ok(Self {
469 data,
470 labels,
471 features,
472 ftype: file_type,
473 })
474 }
475
476 pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
486 let mut file = std::fs::File::open(path)?;
487 let mut contents = String::new();
488 file.read_to_string(&mut contents)?;
489
490 Self::from_libsvm_string(&contents)
491 }
492
493 pub fn from_libsvm_string(contents: &str) -> Result<Self> {
499 let mut data = Vec::new();
500 let mut labels = Vec::new();
501 let mut features = Vec::new();
502 let mut file_type = FileType::NotSet;
503
504 for (row_number, line) in contents.lines().enumerate() {
505 if line.is_empty() {
506 continue;
507 }
508
509 if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
510 if line.contains(FEATURES_PREFIX) {
511 let offset =
512 line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
513 let line = line[offset..].trim();
514 features = line
515 .split(',')
516 .filter_map(|f| hex::decode(f.trim()).ok())
517 .collect();
518 }
519
520 if line.contains(FILE_TYPE_PREFIX) {
521 file_type = Self::file_type_from_line(line)?;
522 }
523 }
524
525 if line.is_empty() || line.starts_with('%') || line.starts_with('#') {
526 continue;
527 }
528
529 let parts = line.split_whitespace().collect::<Vec<&str>>();
530 let Ok(label) = parts[0].trim().parse::<u8>() else {
531 bail!(
532 "Encountered a non-numeric label {} on line {row_number}",
533 parts[0]
534 );
535 };
536 let mut row = vec![0.0f32; features.len()];
537
538 for part in parts.iter().skip(1) {
539 let part_parts = part.split(':').collect::<Vec<&str>>();
540 let Ok(part_index) = part_parts[0].trim().parse::<usize>() else {
541 bail!(
542 "Encountered a non-numeric index {} on line {row_number}",
543 part_parts[0]
544 );
545 };
546 let Ok(part_value) = part_parts[1].trim().parse::<f32>() else {
547 bail!(
548 "Encountered a non-numeric value {} on line {row_number}",
549 part_parts[1]
550 );
551 };
552
553 if part_index > row.len() && !features.is_empty() {
554 bail!(
555 "Encountered a value at index {part_index} greater than expected size {} on line {row_number}",
556 data.len()
557 );
558 }
559
560 if row.is_empty() {
561 row = vec![0.0; part_index + 1];
562 } else if part_index >= row.len() {
563 row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
564 }
565 row[part_index] = part_value;
566 }
567
568 data.push(row);
569 labels.push(label);
570 }
571
572 let data_len = data[0].len();
573 for row in &data {
574 if row.len() != data_len {
575 bail!(
576 "Encountered a row with length {} but expected length {data_len}",
577 row.len()
578 );
579 }
580 }
581
582 ensure!(
583 features.len() == data[0].len(),
584 "Features need to be empty or the same size as the data length."
585 );
586
587 ensure!(
588 file_type != FileType::NotSet,
589 "No file type specified in libsvm file."
590 );
591 Ok(Self {
592 data,
593 labels,
594 features,
595 ftype: file_type,
596 })
597 }
598
599 #[allow(clippy::too_many_lines)]
606 pub fn create_save_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
607 malicious_dir: P,
608 benign_dir: P,
609 ngrams_file: P,
610 output_file: P,
611 ) -> Result<()> {
612 const SUPPORTED_FORMATS: [DatasetFormat; 3] =
613 [DatasetFormat::CSV, DatasetFormat::ARFF, DatasetFormat::SVM];
614
615 let output_format = DatasetFormat::try_from(output_file.as_ref())?;
616 ensure!(
617 SUPPORTED_FORMATS.contains(&output_format),
618 "Only CSV, ARFF, or SVM formats are supported here."
619 );
620
621 let ngrams = NgramsFile::load(ngrams_file)?;
622 let mut output_file = std::fs::File::create(output_file)?;
623 writeln!(output_file, "# {FILE_TYPE_PREFIX} {:?}", ngrams.ftype)?;
624
625 match output_format {
626 DatasetFormat::SVM | DatasetFormat::CSV => {
627 let feature_string_vec = ngrams
628 .clone()
629 .into_vec()
630 .iter()
631 .map(hex::encode)
632 .collect::<Vec<String>>();
633 writeln!(
634 output_file,
635 "# {FEATURES_PREFIX} {}",
636 feature_string_vec.join(", ")
637 )?;
638 }
639
640 DatasetFormat::ARFF => {
641 let feature_string_vec = ngrams
642 .clone()
643 .into_vec()
644 .iter()
645 .map(hex::encode)
646 .collect::<Vec<String>>();
647 for feature in feature_string_vec {
648 let feature_hex = hex::encode(feature);
649 writeln!(output_file, "@ATTRIBUTE {feature_hex} NUMERIC")?;
650 }
651 }
652 }
653
654 for entry in WalkDir::new(malicious_dir)
655 .max_depth(crate::MAX_RECURSION_DEPTH)
656 .follow_links(true)
657 .into_iter()
658 .flatten()
659 {
660 if entry.file_type().is_file() {
661 match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
662 Ok(features) => match output_format {
663 DatasetFormat::CSV | DatasetFormat::ARFF => {
664 let mut line = features
665 .iter()
666 .map(|p| format!("{p}"))
667 .collect::<Vec<String>>()
668 .join(",");
669 line.push_str(",1\n");
670 output_file.write_all(line.as_bytes())?;
671 }
672
673 DatasetFormat::SVM => {
674 write!(output_file, "1")?;
675 for (data_index, data) in features.iter().enumerate() {
676 if *data != 0.0000 {
677 write!(output_file, " {data_index}:{data}")?;
678 }
679 }
680 writeln!(output_file)?;
681 }
682 },
683 Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
684 }
685 }
686 }
687
688 for entry in WalkDir::new(benign_dir)
689 .max_depth(crate::MAX_RECURSION_DEPTH)
690 .follow_links(true)
691 .into_iter()
692 .flatten()
693 {
694 if entry.file_type().is_file() {
695 match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
696 Ok(features) => match output_format {
697 DatasetFormat::CSV | DatasetFormat::ARFF => {
698 let mut line = features
699 .iter()
700 .map(|p| format!("{p}"))
701 .collect::<Vec<String>>()
702 .join(",");
703 line.push_str(",0\n");
704 output_file.write_all(line.as_bytes())?;
705 }
706
707 DatasetFormat::SVM => {
708 write!(output_file, "0")?;
709 for (data_index, data) in features.iter().enumerate() {
710 if *data != 0.0000 {
711 write!(output_file, " {data_index}:{data}")?;
712 }
713 }
714 writeln!(output_file)?;
715 }
716 },
717 Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
718 }
719 }
720 }
721
722 output_file.sync_all()?;
723 Ok(())
724 }
725
726 pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
732 let mut file = std::fs::File::create(path)?;
733
734 let feature_string_vec = self
735 .features
736 .iter()
737 .map(hex::encode)
738 .collect::<Vec<String>>();
739 writeln!(
740 file,
741 "# {FEATURES_PREFIX} {}",
742 feature_string_vec.join(", ")
743 )?;
744 writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
745
746 for index in 0..self.data.len() {
747 let mut line = self.data[index]
748 .iter()
749 .map(|p| format!("{p}"))
750 .collect::<Vec<String>>()
751 .join(",");
752
753 if !self.labels.is_empty() {
754 line = format!("{line},{}", self.labels[index]);
755 }
756 line.push('\n');
757
758 file.write_all(line.as_bytes())?;
759 }
760
761 file.sync_all().map_err(Into::into)
762 }
763
764 pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
770 let mut file = std::fs::File::create(path)?;
771 writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
772
773 for feature in &self.features {
774 let feature_hex = hex::encode(feature);
775 file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
776 }
777
778 if !self.labels.is_empty() {
779 file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
780 }
781
782 file.write_all("\n@DATA\n".as_bytes())?;
783 for index in 0..self.data.len() {
784 let mut line = self.data[index]
785 .iter()
786 .map(|p| format!("{p}"))
787 .collect::<Vec<String>>()
788 .join(",");
789
790 if !self.labels.is_empty() {
791 line = format!("{line},{}", self.labels[index]);
792 }
793 line.push('\n');
794
795 file.write_all(line.as_bytes())?;
796 }
797
798 file.sync_all().map_err(Into::into)
799 }
800
801 pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
807 ensure!(
808 !self.labels.is_empty(),
809 "Labels are required to create an libsvm file."
810 );
811 let mut file = std::fs::File::create(path)?;
812
813 let feature_string_vec = self
814 .features
815 .iter()
816 .map(hex::encode)
817 .collect::<Vec<String>>();
818 writeln!(
819 file,
820 "# {FEATURES_PREFIX} {}",
821 feature_string_vec.join(", ")
822 )?;
823 writeln!(file, "# {FILE_TYPE_PREFIX} {:?}", self.ftype)?;
824
825 for index in 0..self.data.len() {
826 file.write_all(format!("{}", self.labels[index]).as_bytes())?;
827 for (data_index, data) in self.data[index].iter().enumerate() {
828 if *data != 0.0000 {
829 file.write_all(format!(" {data_index}:{data}").as_bytes())?;
830 }
831 }
832
833 file.write_all(b"\n")?;
834 }
835
836 file.sync_all().map_err(Into::into)
837 }
838
839 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
845 if let Some(extension) = path.as_ref().extension() {
846 return match extension.to_str().unwrap_or_default() {
847 "arff" => self.save_arff(path),
848 "csv" => self.save_csv(path),
849 "svm" | "libsvm" => self.save_libsvm(path),
850 "json" => {
851 let contents = serde_json::to_string_pretty(self)?;
852 let mut file = std::fs::File::create(path)?;
853 file.write_all(contents.as_bytes())?;
854 file.sync_all().map_err(Into::into)
855 }
856 "toml" => {
857 let contents = toml::to_string_pretty(self)?;
858 let mut file = std::fs::File::create(path)?;
859 file.write_all(contents.as_bytes())?;
860 file.sync_all().map_err(Into::into)
861 }
862 ext => {
863 bail!("Unsupported/unknown data type '{ext}'");
864 }
865 };
866 }
867
868 bail!("No extension, can't determine file type.");
869 }
870
871 #[inline]
873 #[must_use]
874 pub fn len(&self) -> usize {
875 self.data.len()
876 }
877
878 #[inline]
880 #[must_use]
881 pub fn is_empty(&self) -> bool {
882 self.data.is_empty()
883 }
884
885 #[inline]
889 #[must_use]
890 pub fn validate(&self) -> bool {
891 let data_len = match self.data.first() {
892 Some(first) => first.len(),
893 None => return false,
894 };
895
896 for record in &self.data {
898 if record.len() != data_len {
899 #[cfg(debug_assertions)]
900 eprint!("Expected record size {data_len}, got {}", record.len());
901 return false;
902 }
903 }
904
905 let feature_len = if let Some(first) = self.features.first() {
906 first.len()
907 } else {
908 #[cfg(debug_assertions)]
909 eprintln!("Features data is missing");
910 return false;
911 };
912
913 for feature in &self.features {
914 if feature.len() != feature_len {
915 #[cfg(debug_assertions)]
916 eprint!("Expected feature size {feature_len}, got {}", feature.len());
917 return false;
918 }
919 }
920
921 (self.labels.is_empty() || self.labels.len() == self.data.len())
923 && self.features.len() == data_len
924 && self.ftype != FileType::NotSet
925 }
926
927 pub fn shuffle(&mut self) {
930 if !self.is_empty() {
932 let iterations = self.data.len().ilog10() * 10;
933 self.shuffle_iterations(iterations);
934 }
935 }
936
937 pub fn shuffle_iterations(&mut self, iterations: u32) {
940 use rand::RngExt;
941
942 if !self.is_empty() {
943 let mut rng = rand::rng();
944
945 for _ in 0..iterations {
946 let a = rng.random_range(0..self.data.len());
947 let b = rng.random_range(0..self.data.len());
948 let b = if b == a {
949 rng.random_range(0..self.data.len())
950 } else {
951 b
952 };
953
954 self.data.swap(a, b);
955 if !self.labels.is_empty() {
956 self.labels.swap(a, b);
957 }
958 }
959 }
960 }
961
962 #[must_use]
965 #[allow(
966 clippy::cast_sign_loss,
967 clippy::cast_possible_truncation,
968 clippy::cast_precision_loss
969 )]
970 pub fn split(&mut self, ratio: f32) -> Self {
971 let ratio = ratio.abs();
972 let ratio = if ratio > 1.0 { 1.0 - ratio } else { ratio };
973 let new_size = (self.data.len() as f32 * ratio).ceil() as usize;
974
975 let new_data = self.data.drain(new_size..).collect();
976 let new_labels = if self.labels.is_empty() {
977 vec![]
978 } else {
979 self.labels.drain(new_size..).collect()
980 };
981
982 Self {
983 data: new_data,
984 labels: new_labels,
985 features: self.features.clone(),
986 ftype: self.ftype,
987 }
988 }
989
990 pub fn reduce(&mut self, model: &LogisticRegression) -> Result<Vec<usize>> {
999 let mut removed = vec![];
1000
1001 for (index, feature) in self.features.iter().enumerate() {
1002 if !model.features.contains_key(feature) {
1003 removed.push(index);
1004 }
1005 }
1006
1007 if removed.len() == self.data[0].len() {
1008 bail!(
1009 "This dataset and model are probably not from the same data - this operation would delete all the data!"
1010 );
1011 }
1012
1013 removed.sort_unstable();
1014 removed.reverse();
1015
1016 self.features
1017 .retain(|feature| model.features.contains_key(feature));
1018
1019 for row in &mut self.data {
1020 for removed in &removed {
1021 row.remove(*removed);
1022 }
1023 }
1024
1025 Ok(removed)
1026 }
1027
1028 #[must_use]
1030 pub fn column_iter(&'_ self, index: usize) -> Option<ColumnIterator<'_>> {
1031 if index < self.data[0].len() {
1032 Some(ColumnIterator {
1033 dataset: self,
1034 column_index: index,
1035 current_row_index: 0,
1036 })
1037 } else {
1038 None
1039 }
1040 }
1041}
1042
/// Iterator over one column of a [`Dataset`], yielding that column's value from each record
/// in order.
pub struct ColumnIterator<'a> {
    /// Dataset whose records are read.
    dataset: &'a Dataset,

    /// Index of the column yielded from every record.
    column_index: usize,

    /// Index of the next record to read.
    current_row_index: usize,
}
1054
1055impl Iterator for ColumnIterator<'_> {
1056 type Item = f32;
1057
1058 fn next(&mut self) -> Option<Self::Item> {
1059 if self.current_row_index < self.dataset.data.len() {
1060 let val = self.dataset.data[self.current_row_index][self.column_index];
1061 self.current_row_index += 1;
1062 Some(val)
1063 } else {
1064 None
1065 }
1066 }
1067}
1068
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;
    use std::path::PathBuf;

    /// Removes the listed files on drop, so a failing assertion doesn't leave files behind.
    struct TempFiles(Vec<PathBuf>);

    impl Drop for TempFiles {
        fn drop(&mut self) {
            for path in &self.0 {
                let _ = std::fs::remove_file(path);
            }
        }
    }

    /// Process-unique path in the system temp dir, avoiding clashes between concurrent runs.
    fn temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}_{name}", std::process::id()))
    }

    #[test]
    fn xor() {
        let csv_dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        assert!(csv_dataset.validate());

        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());

        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());

        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    #[test]
    fn xor_no_label() {
        assert!(Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6).is_err());
        assert!(Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm")).is_err());
    }

    #[test]
    fn shuffle() {
        let original_dataset =
            Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.shuffle();

        assert_eq!(original_dataset, dataset);
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    #[test]
    fn split() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);

        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    #[test]
    fn save() {
        let copy_csv = temp_path("xor_copy.csv");
        let copy_arff = temp_path("xor_copy.arff");
        let copy_svm = temp_path("xor_copy.svm");
        let _cleanup = TempFiles(vec![copy_csv.clone(), copy_arff.clone(), copy_svm.clone()]);

        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.save_csv(&copy_csv).unwrap();
        dataset.save_arff(&copy_arff).unwrap();
        dataset.save_libsvm(&copy_svm).unwrap();

        let dataset2 = Dataset::from_csv_file(&copy_csv, 6).unwrap();
        assert_eq!(dataset, dataset2);

        let dataset3 = Dataset::from_arff_file(&copy_arff).unwrap();
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);

        let dataset4 = Dataset::from_libsvm_file(&copy_svm).unwrap();
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);
    }
}