malware_modeler/
dataset.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::Bytes;
4
5use std::io::{Read, Write};
6use std::path::Path;
7use std::str::FromStr;
8use std::sync::RwLock;
9
10use anyhow::{bail, ensure, Result};
11use rayon::prelude::*;
12use serde::{Deserialize, Serialize};
13use walkdir::WalkDir;
14
/// Leading bytes that mark a line of a text dataset file as a comment.
const COMMENT_PREFIXES: [u8; 2] = *b"#%";
/// Marker inside a comment line that introduces the comma-separated, hex-encoded n-gram list.
const FEATURES_PREFIX: &str = "Features:";
17
18/// Given a file path, feature size, and collection of features, return a vector
19/// which indicates if each feature is present.
20///
21/// # Errors
22/// Returns an error if the file can't be read.
23#[inline]
24pub fn featurize_file<P: AsRef<Path>>(file: P, n: usize, features: &[Bytes]) -> Result<Vec<f32>> {
25    let contents = std::fs::read(file)?;
26    let mut feature_vector = vec![0.0; features.len()];
27
28    for window in contents.windows(n) {
29        if let Some(position) = features.iter().position(|n| n == window) {
30            feature_vector[position] = 1.0;
31        }
32    }
33
34    Ok(feature_vector)
35}
36
37/// File format for a dataset
38#[derive(Copy, Clone, Deserialize, Serialize, Hash, Eq, PartialEq)]
39pub enum DatasetFormat {
40    /// Attribute-relation format
41    ARFF,
42
43    /// Comma-separated values, the most common format
44    CSV,
45
46    /// Support vector machine format, ideal for sparse data
47    SVM,
48}
49
50impl FromStr for DatasetFormat {
51    type Err = String;
52
53    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
54        match value.to_lowercase().as_str() {
55            "arff" => Ok(Self::ARFF),
56            "csv" => Ok(Self::CSV),
57            "svm" => Ok(Self::SVM),
58            x => Err(format!("Unknown data format '{x}'")),
59        }
60    }
61}
62
/// A dataset contains data for training or inference; training requires labels.
///
/// Each entry of `data` is one sample (row). Column `i` of a row holds the value for
/// `features[i]`; `validate` and the text loaders require one feature per column.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Dataset {
    /// Feature values, one inner vector per sample (row).
    pub data: Vec<Vec<f32>>,

    /// One label per row of `data`. Can be empty if the dataset is only used for
    /// inference; a serialized dataset without this field deserializes it as empty.
    #[serde(default)]
    pub labels: Vec<f32>,

    /// N-gram features, one per data column.
    pub features: Vec<Bytes>,
}
76
77impl PartialEq for Dataset {
78    fn eq(&self, other: &Self) -> bool {
79        if self.data.len() != other.data.len()
80            || self.labels.len() != other.labels.len()
81            || self.features.len() != other.features.len()
82        {
83            return false;
84        }
85
86        for this_data in &self.data {
87            if !other.data.contains(this_data) {
88                return false;
89            }
90        }
91
92        for other_data in &other.data {
93            if !self.data.contains(other_data) {
94                return false;
95            }
96        }
97
98        if !self.labels.is_empty() {
99            for this_label in &self.labels {
100                if !other.labels.contains(this_label) {
101                    return false;
102                }
103            }
104
105            for other_label in &other.labels {
106                if !self.labels.contains(other_label) {
107                    return false;
108                }
109            }
110        }
111
112        for this_features in &self.features {
113            if !other.features.contains(this_features) {
114                return false;
115            }
116        }
117
118        for other_feature in &other.features {
119            if !self.features.contains(other_feature) {
120                return false;
121            }
122        }
123
124        true
125    }
126}
127
128impl Dataset {
129    /// Load a file
130    ///
131    /// # Errors
132    ///
133    /// An error results if the file type can't be determined, is incorrectly determined,
134    /// or if the file isn't a supported format.
135    pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
136        if let Some(extension) = path.as_ref().extension() {
137            return match extension.to_str().unwrap_or_default() {
138                "arff" => Dataset::from_arff_file(path.as_ref()),
139                "csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
140                "svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
141                "json" => {
142                    let contents = std::fs::read_to_string(path.as_ref())?;
143                    serde_json::from_str(&contents).map_err(Into::into)
144                }
145                "toml" => {
146                    let contents = std::fs::read_to_string(path.as_ref())?;
147                    toml::from_str(&contents).map_err(Into::into)
148                }
149                ext => {
150                    bail!("Unsupported/unknown data type '{ext}'");
151                }
152            };
153        }
154
155        bail!("No extension, can't determine file type.");
156    }
157
158    /// Create a dataset struct from a CSV file
159    ///
160    /// # Errors
161    ///
162    /// Returns an error if:
163    ///   * The file can't be read
164    ///   * The data contained isn't numeric
165    ///   * Feature data is missing
166    ///   * The expected amount of data isn't encountered
167    pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
168        let mut file = std::fs::File::open(path)?;
169        let mut contents = String::new();
170        file.read_to_string(&mut contents)?;
171
172        Self::from_csv_string(&contents, data_length)
173    }
174
175    /// Create a dataset struct from a CSV file
176    ///
177    /// # Errors
178    ///
179    /// Returns an error if:
180    ///   * The file can't be read
181    ///   * The data contained isn't a float
182    ///   * Feature data is missing
183    ///   * The amount of columns can't be determined
184    pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
185        let mut file = std::fs::File::open(path)?;
186        let mut contents = String::new();
187        file.read_to_string(&mut contents)?;
188
189        let mut length = 0;
190        for line in contents.lines() {
191            if line.is_empty() || COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
192                continue;
193            }
194
195            length = line.split(',').collect::<Vec<&str>>().len();
196            break;
197        }
198
199        ensure!(length > 0, "Failed to determine data length.");
200        Self::from_csv_string(&contents, length - 1)
201    }
202
203    /// Create a dataset struct from a CSV string
204    ///
205    /// # Errors
206    ///
207    /// Returns an error if:
208    ///   * The data contained isn't numeric
209    ///   * Feature data is missing
210    ///   * The expected amount of data isn't encountered
211    pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
212        let mut data: Vec<Vec<f32>> = Vec::new();
213        let mut labels = Vec::new();
214        let mut features = Vec::new();
215
216        for (row_number, line) in contents.lines().enumerate() {
217            if line.is_empty() {
218                continue;
219            }
220
221            if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) && line.contains(FEATURES_PREFIX) {
222                let offset = line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
223                let line = line[offset..].trim();
224                features = line
225                    .split(',')
226                    .filter_map(|f| hex::decode(f.trim()).ok())
227                    .collect();
228            }
229
230            if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
231                continue;
232            }
233            let row = line.split(',').collect::<Vec<&str>>();
234            let mut row_float = Vec::with_capacity(data_length);
235            for r in row.iter().take(data_length) {
236                row_float.push(r.parse::<f32>().map_err(|_| {
237                    anyhow::Error::msg(format!("Non-float encountered on CSV row {row_number}"))
238                })?);
239            }
240            if let Some(first_row) = data.first() {
241                ensure!(
242                    first_row.len() == row_float.len(),
243                    "CSV line {row_number} has invalid length {}, expected {}",
244                    row_float.len(),
245                    first_row.len()
246                );
247            }
248            data.push(row_float);
249            if row.len() == data_length + 1 {
250                let l = row[data_length].parse::<f32>().map_err(|_| {
251                    anyhow::Error::msg(format!("Non-float encountered on CSV row {row_number}"))
252                })?;
253                labels.push(l);
254            } else if row.len() > data_length {
255                bail!(
256                    "CSV row had more than one label on row {row_number}, which isn't supported."
257                );
258            }
259        }
260
261        ensure!(
262            features.len() == data[0].len(),
263            "Features need to be empty or the same size as the data length."
264        );
265
266        Ok(Self {
267            data,
268            labels,
269            features,
270        })
271    }
272
273    /// Create a dataset struct from an ARFF string
274    ///
275    /// # Errors
276    ///
277    /// Returns an error if:
278    ///   * The file can't be read
279    ///   * The data contained isn't numeric
280    ///   * Feature data is missing
281    ///   * The expected amount of data isn't encountered
282    pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
283        let mut file = std::fs::File::open(path)?;
284        let mut contents = String::new();
285        file.read_to_string(&mut contents)?;
286
287        Self::from_arff_string(&contents)
288    }
289
290    /// Create a dataset struct from an ARFF string
291    ///
292    /// # Errors
293    ///
294    /// Returns an error if:
295    ///   * The data contained isn't numeric
296    ///   * Feature data is missing
297    ///   * The expected amount of data isn't encountered
298    pub fn from_arff_string(contents: &str) -> Result<Self> {
299        let mut data: Vec<Vec<f32>> = Vec::new();
300        let mut labels = Vec::new();
301        let mut features = Vec::new();
302        let mut passed_data = false;
303
304        for (row_number, line) in contents.lines().enumerate() {
305            if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
306                continue;
307            }
308
309            if line.contains("@ATTRIBUTE") {
310                let parts: Vec<&str> = line.split_ascii_whitespace().collect();
311                if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
312                    match hex::decode(parts[1]) {
313                        Ok(feat) => features.push(feat),
314                        Err(e) => {
315                            bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
316                        }
317                    }
318                }
319            }
320
321            if line.contains("@DATA") {
322                passed_data = true;
323                continue;
324            }
325
326            // Basically a CSV at this point
327            if passed_data {
328                let row = line.split(',').collect::<Vec<&str>>();
329                let data_length = row.len() - 1;
330                let mut row_float = Vec::with_capacity(data_length);
331                for r in row.iter().take(data_length) {
332                    row_float.push(r.parse::<f32>().map_err(|_| {
333                        anyhow::Error::msg(format!(
334                            "Non-float encountered on ARFF row {row_number}"
335                        ))
336                    })?);
337                }
338                if let Some(first_row) = data.first() {
339                    ensure!(
340                        first_row.len() == row_float.len(),
341                        "ARFF line {row_number} has invalid length {}, expected {}",
342                        row_float.len(),
343                        first_row.len()
344                    );
345                }
346                data.push(row_float);
347                if row.len() == data_length + 1 {
348                    let l = row[data_length].parse::<f32>().map_err(|_| {
349                        anyhow::Error::msg(format!(
350                            "Non-float encountered on ARFF row {row_number}"
351                        ))
352                    })?;
353                    labels.push(l);
354                } else if row.len() > data_length {
355                    bail!("Arff row had more than one label on row {row_number}, which isn't supported.");
356                }
357            }
358        }
359
360        ensure!(
361            features.len() == data[0].len(),
362            "Features need to be empty or the same size as the data length."
363        );
364
365        Ok(Self {
366            data,
367            labels,
368            features,
369        })
370    }
371
372    /// Create a dataset struct from a libsvm file
373    ///
374    /// # Errors
375    ///
376    /// Returns an error if:
377    ///   * The file can't be read
378    ///   * Feature data is missing
379    ///   * The data isn't in the expected format
380    ///   * The expected amount of data isn't encountered
381    pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
382        let mut file = std::fs::File::open(path)?;
383        let mut contents = String::new();
384        file.read_to_string(&mut contents)?;
385
386        Self::from_libsvm_string(&contents)
387    }
388
389    /// Create a dataset from a libsvm string
390    ///
391    /// # Errors
392    ///
393    /// Returns an error if the file doesn't contain the expected format or is missing features
394    pub fn from_libsvm_string(contents: &str) -> Result<Self> {
395        let mut data = Vec::new();
396        let mut labels = Vec::new();
397        let mut features = Vec::new();
398
399        for (row_number, line) in contents.lines().enumerate() {
400            if line.is_empty() {
401                continue;
402            }
403
404            if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) && line.contains(FEATURES_PREFIX) {
405                let offset = line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
406                let line = line[offset..].trim();
407                features = line
408                    .split(',')
409                    .filter_map(|f| hex::decode(f.trim()).ok())
410                    .collect();
411            }
412
413            if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
414                continue;
415            }
416
417            let parts = line.split_whitespace().collect::<Vec<&str>>();
418            let label = parts[0].parse::<f32>()?;
419            let mut row = vec![0.0f32; features.len()];
420
421            for part in parts.iter().skip(1) {
422                let part_parts = part.split(':').collect::<Vec<&str>>();
423                let part_index = part_parts[0].parse::<usize>()?;
424                let part_value = part_parts[1].parse::<f32>()?;
425
426                if part_index > row.len() && !features.is_empty() {
427                    bail!("Encountered a value at index {part_index} greater than expected size {} on line {row_number}", data.len());
428                }
429
430                if row.is_empty() {
431                    row = vec![0.0; part_index + 1];
432                } else if part_index >= row.len() {
433                    row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
434                }
435                row[part_index] = part_value;
436            }
437
438            data.push(row);
439            labels.push(label);
440        }
441
442        let data_len = data[0].len();
443        for row in &data {
444            if row.len() != data_len {
445                bail!(
446                    "Encountered a row with length {} but expected length {data_len}",
447                    row.len()
448                );
449            }
450        }
451
452        ensure!(
453            features.len() == data[0].len(),
454            "Features need to be empty or the same size as the data length."
455        );
456
457        Ok(Self {
458            data,
459            labels,
460            features,
461        })
462    }
463
464    /// Given paths to malicious files, benign files, and n-grams (features), get a Dataset object.
465    ///
466    /// # Errors
467    /// This will fail if:
468    /// * The directories for benign or malicious files don't exist or are empty.
469    /// * The n-gram feature file doesn't exist, is empty, or doesn't have hexidecimal-encoded features
470    pub fn create_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
471        malicious_dir: P,
472        benign_dir: P,
473        ngrams_file: P,
474    ) -> Result<Self> {
475        let ngram_contents = std::fs::read_to_string(&ngrams_file)?;
476        let mut n = 0;
477        let ngrams = ngram_contents
478            .lines()
479            .filter_map(|l| {
480                let line = if let Some(l) = l.split(',').collect::<Vec<&str>>().first() {
481                    l
482                } else {
483                    l
484                };
485                if !line.len().is_multiple_of(2) {
486                    eprintln!("Line {line} has odd number of characters.");
487                    return None;
488                }
489                if n == 0 {
490                    n = line.len() / 2;
491                } else if line.len() / 2 != n {
492                    eprintln!(
493                        "Line {line} has unexpected length of {} bytes, expected {n}",
494                        line.len() / 2
495                    );
496                    return None;
497                }
498                hex::decode(line).ok()
499            })
500            .collect::<Vec<_>>();
501
502        ensure!(
503            !ngrams.is_empty(),
504            "No n-grams read from {}.",
505            ngrams_file.as_ref().display()
506        );
507
508        let mut paths_labels = Vec::new();
509        for entry in WalkDir::new(malicious_dir)
510            .max_depth(crate::MAX_RECURSION_DEPTH)
511            .follow_links(true)
512            .into_iter()
513            .flatten()
514        {
515            if entry.file_type().is_file() {
516                paths_labels.push((entry, 1.0));
517            }
518        }
519
520        for entry in WalkDir::new(benign_dir)
521            .max_depth(crate::MAX_RECURSION_DEPTH)
522            .follow_links(true)
523            .into_iter()
524            .flatten()
525        {
526            if entry.file_type().is_file() {
527                paths_labels.push((entry, 0.0));
528            }
529        }
530
531        let found_files = paths_labels.len();
532        let dataset = Dataset::default();
533        let dataset_lock = RwLock::new(dataset);
534        paths_labels.into_par_iter().for_each(|(path, label)| {
535            match featurize_file(path.path(), n, &ngrams) {
536                Ok(features) => {
537                    if let Ok(mut data) = dataset_lock.write() {
538                        data.data.push(features);
539                        data.labels.push(label);
540                    }
541                }
542                Err(e) => eprintln!("Failed to featurized {}: {e}", path.path().display()),
543            }
544        });
545
546        let mut dataset = dataset_lock.into_inner()?;
547        dataset.features = ngrams;
548
549        if dataset.data.len() != found_files {
550            eprintln!(
551                "Warning: found {found_files} but only have features for {} files.",
552                dataset.data.len()
553            );
554        }
555
556        Ok(dataset)
557    }
558
559    /// Save a dataset as a CSV
560    ///
561    /// # Errors
562    ///
563    /// An error will result if the file can't be opened for writing
564    pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
565        let mut file = std::fs::File::create(path)?;
566
567        let feature_string_vec = self
568            .features
569            .iter()
570            .map(hex::encode)
571            .collect::<Vec<String>>();
572        let features_string = format!("# {FEATURES_PREFIX} {}\n", feature_string_vec.join(", "));
573        file.write_all(features_string.as_bytes())?;
574
575        for index in 0..self.data.len() {
576            let mut line = self.data[index]
577                .iter()
578                .map(|p| format!("{p}"))
579                .collect::<Vec<String>>()
580                .join(",");
581
582            if !self.labels.is_empty() {
583                if self.labels[index] > 0.9 {
584                    line.push_str(",1");
585                } else {
586                    line.push_str(",0");
587                }
588            }
589            line.push('\n');
590
591            file.write_all(line.as_bytes())?;
592        }
593
594        file.sync_all().map_err(Into::into)
595    }
596
597    /// Save a dataset as an ARFF file
598    ///
599    /// # Errors
600    ///
601    /// An error will result if the file can't be opened for writing
602    pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
603        let mut file = std::fs::File::create(path)?;
604
605        for feature in &self.features {
606            let feature_hex = hex::encode(feature);
607            file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
608        }
609
610        if !self.labels.is_empty() {
611            file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
612        }
613
614        file.write_all("\n@DATA\n".as_bytes())?;
615        for index in 0..self.data.len() {
616            let mut line = self.data[index]
617                .iter()
618                .map(|p| format!("{p}"))
619                .collect::<Vec<String>>()
620                .join(",");
621
622            if !self.labels.is_empty() {
623                if self.labels[index] > 0.9 {
624                    line.push_str(",1");
625                } else {
626                    line.push_str(",0");
627                }
628            }
629            line.push('\n');
630
631            file.write_all(line.as_bytes())?;
632        }
633
634        file.sync_all().map_err(Into::into)
635    }
636
637    /// Save a dataset as a libsvm file
638    ///
639    /// # Errors
640    ///
641    /// An error will result if the file can't be opened for writing
642    pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
643        ensure!(
644            !self.labels.is_empty(),
645            "Labels are required to create an libsvm file."
646        );
647        let mut file = std::fs::File::create(path)?;
648
649        let feature_string_vec = self
650            .features
651            .iter()
652            .map(hex::encode)
653            .collect::<Vec<String>>();
654        let features_string = format!("# {FEATURES_PREFIX} {}\n", feature_string_vec.join(", "));
655        file.write_all(features_string.as_bytes())?;
656
657        for index in 0..self.data.len() {
658            file.write_all(format!("{}", self.labels[index]).as_bytes())?;
659            for (data_index, data) in self.data[index].iter().enumerate() {
660                if *data != 0.0000 {
661                    file.write_all(format!(" {data_index}:{data}").as_bytes())?;
662                }
663            }
664
665            file.write_all(b"\n")?;
666        }
667
668        file.sync_all().map_err(Into::into)
669    }
670
671    /// Save the dataset using the file extension to determine data format
672    ///
673    /// # Errors
674    ///
675    /// There's an error if the file can't be written or if the format can't be determined
676    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
677        if let Some(extension) = path.as_ref().extension() {
678            return match extension.to_str().unwrap_or_default() {
679                "arff" => self.save_arff(path),
680                "csv" => self.save_csv(path),
681                "svm" | "libsvm" => self.save_libsvm(path),
682                "json" => {
683                    let contents = serde_json::to_string_pretty(self)?;
684                    let mut file = std::fs::File::create(path)?;
685                    file.write_all(contents.as_bytes())?;
686                    file.sync_all().map_err(Into::into)
687                }
688                "toml" => {
689                    let contents = toml::to_string_pretty(self)?;
690                    let mut file = std::fs::File::create(path)?;
691                    file.write_all(contents.as_bytes())?;
692                    file.sync_all().map_err(Into::into)
693                }
694                ext => {
695                    bail!("Unsupported/unknown data type '{ext}'");
696                }
697            };
698        }
699
700        bail!("No extension, can't determine file type.");
701    }
702
703    /// Return dataset size
704    #[inline]
705    #[must_use]
706    pub fn len(&self) -> usize {
707        self.data.len()
708    }
709
710    /// Indicate if the dataset is empty
711    #[inline]
712    #[must_use]
713    pub fn is_empty(&self) -> bool {
714        self.data.is_empty()
715    }
716
717    /// Ensure the dataset is valid
718    /// * Same size data columns
719    /// * If present, the amount of data rows equals the amount of labels
720    #[inline]
721    #[must_use]
722    pub fn validate(&self) -> bool {
723        let data_len = match self.data.first() {
724            Some(first) => first.len(),
725            None => return false,
726        };
727
728        // Ensure data records are the same size
729        for record in &self.data {
730            if record.len() != data_len {
731                #[cfg(debug_assertions)]
732                eprint!("Expected record size {data_len}, got {}", record.len());
733                return false;
734            }
735        }
736
737        let feature_len = if let Some(first) = self.features.first() {
738            first.len()
739        } else {
740            #[cfg(debug_assertions)]
741            eprintln!("Features data is missing");
742            return false;
743        };
744
745        for feature in &self.features {
746            if feature.len() != feature_len {
747                #[cfg(debug_assertions)]
748                eprint!("Expected feature size {feature_len}, got {}", feature.len());
749                return false;
750            }
751        }
752
753        // If we have labels, ensure it's the same size as the data
754        (self.labels.is_empty() || self.labels.len() == self.data.len())
755            && self.features.len() == data_len
756    }
757
758    /// Shuffle the data, using roughly 10 X log10(size).
759    /// So 10 records = 10 iterations, 1,000 records gets 30 iterations
760    pub fn shuffle(&mut self) {
761        // Avoid a panic since `.ilog10()` panics on zero.
762        if !self.is_empty() {
763            let iterations = self.data.len().ilog10() * 10;
764            self.shuffle_iterations(iterations);
765        }
766    }
767
768    /// Shuffle the data with a specified amount of iterations, ensures
769    /// that the labels are swapped with the data, if present
770    pub fn shuffle_iterations(&mut self, iterations: u32) {
771        use rand::Rng;
772
773        if !self.is_empty() {
774            let mut rng = rand::rng();
775
776            for _ in 0..iterations {
777                let a = rng.random_range(0..self.data.len());
778                let b = rng.random_range(0..self.data.len());
779                let b = if b == a {
780                    rng.random_range(0..self.data.len())
781                } else {
782                    b
783                };
784
785                self.data.swap(a, b);
786                if !self.labels.is_empty() {
787                    self.labels.swap(a, b);
788                }
789            }
790        }
791    }
792
793    /// Split the dataset, ideally into train/test datasets.
794    /// The ratio indicates how much data is kept, the remaining size is shed and returned.
795    #[must_use]
796    #[allow(
797        clippy::cast_sign_loss,
798        clippy::cast_possible_truncation,
799        clippy::cast_precision_loss
800    )]
801    pub fn split(&mut self, ratio: f32) -> Self {
802        let ratio = ratio.abs();
803        let ratio = if ratio > 1.0 { 1.0 - ratio } else { ratio };
804        let new_size = (self.data.len() as f32 * ratio).ceil() as usize;
805
806        let new_data = self.data.drain(new_size..).collect();
807        let new_labels = if self.labels.is_empty() {
808            vec![]
809        } else {
810            self.labels.drain(new_size..).collect()
811        };
812
813        Self {
814            data: new_data,
815            labels: new_labels,
816            features: self.features.clone(),
817        }
818    }
819}
820
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;

    /// A path in the system temp directory that is removed when dropped. Test files
    /// therefore never land in the working directory, and they don't survive a failed
    /// assertion.
    struct TempPath(std::path::PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            // Prefix with the process id so concurrent test runs don't collide.
            Self(std::env::temp_dir().join(format!("{}_{name}", std::process::id())))
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the test failed before writing it.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn xor() {
        let csv_dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        assert!(csv_dataset.validate());

        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());

        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());

        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    #[test]
    fn xor_no_label() {
        assert!(Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6).is_err());
        assert!(Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm")).is_err());
    }

    #[test]
    fn shuffle() {
        let original_dataset =
            Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.shuffle();

        assert_eq!(original_dataset, dataset);
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    #[test]
    fn split() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);

        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    #[test]
    fn save() {
        let copy_csv = TempPath::new("xor_copy.csv");
        let copy_arff = TempPath::new("xor_copy.arff");
        let copy_svm = TempPath::new("xor_copy.svm");

        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.save_csv(&copy_csv.0).unwrap();
        dataset.save_arff(&copy_arff.0).unwrap();
        dataset.save_libsvm(&copy_svm.0).unwrap();

        let dataset2 = Dataset::from_csv_file(&copy_csv.0, 6).unwrap();
        assert_eq!(dataset, dataset2);

        let dataset3 = Dataset::from_arff_file(&copy_arff.0).unwrap();
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);

        let dataset4 = Dataset::from_libsvm_file(&copy_svm.0).unwrap();
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);
    }
}