malware_modeler/
dataset.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::ftype::FileType;
4use crate::model::LogisticRegression;
5use crate::ngram::NgramsFile;
6use crate::Bytes;
7
8use std::collections::HashMap;
9use std::io::{Read, Seek, SeekFrom, Write};
10use std::path::Path;
11use std::str::FromStr;
12
13use anyhow::{anyhow, bail, ensure, Result};
14use serde::de::IntoDeserializer;
15use serde::{Deserialize, Serialize};
16use walkdir::WalkDir;
17
/// First bytes that mark a line as a comment/metadata line in the text dataset formats.
const COMMENT_PREFIXES: [u8; 2] = [b'#', b'%'];
/// Marker on a comment line that is followed by the comma-separated, hex-encoded features.
const FEATURES_PREFIX: &str = "Features:";
/// Marker on a comment line that is followed by the dataset's file type.
const FILE_TYPE_PREFIX: &str = "File type:";
21
22/// Given a file path, feature size, and collection of features, return a vector
23/// which indicates if each feature is present.
24///
25/// # Errors
26/// Returns an error if the file can't be read.
27#[inline]
28pub(crate) fn featurize_file<P: AsRef<Path>, S: ::std::hash::BuildHasher>(
29    file: P,
30    n: usize,
31    features: &HashMap<Bytes, usize, S>,
32) -> Result<Vec<f32>> {
33    let file_size = std::fs::metadata(&file)?.len();
34    ensure!(
35        file_size > n as u64,
36        "File {} is too small.",
37        file.as_ref().display()
38    );
39
40    let mut feature_vector = vec![0.0; features.len()];
41
42    if file_size < 10_485_760u64
43    /* 10MB */
44    {
45        let contents = std::fs::read(file)?;
46        for window in contents.windows(n) {
47            if let Some(index) = features.get(window) {
48                feature_vector[*index] = 1.0;
49            }
50        }
51    } else {
52        let mut file = std::fs::File::open(file)?;
53        let mut buffer = [0u8; crate::ngram::NGRAM_BUFFER_SIZE];
54        while let Ok(bytes_read) = file.read(&mut buffer) {
55            if bytes_read < n {
56                break;
57            }
58            for index in 0..bytes_read - n {
59                if let Some(index) = features.get(&buffer[index..index + n]) {
60                    feature_vector[*index] = 1.0;
61                }
62            }
63
64            // A wraparound isn't possible here since n-grams are expected to be single-digit numbers.
65            // `usize` is just a convenience as we have to index into arrays
66            // Go back n-1 bytes to get the next n-gram.
67            #[allow(clippy::cast_possible_wrap)]
68            file.seek(SeekFrom::Current(n as i64 - 1))?;
69        }
70    }
71
72    Ok(feature_vector)
73}
74
/// File format for a dataset.
///
/// Parsed from a format name via `FromStr` or from a path's extension via
/// `TryFrom<&Path>`.
#[derive(Copy, Clone, Deserialize, Serialize, Hash, Eq, PartialEq)]
pub enum DatasetFormat {
    /// Attribute-relation format; parsed from the name `arff`
    ARFF,

    /// Comma-separated values, the most common format; parsed from the name `csv`
    CSV,

    /// Support vector machine format, ideal for sparse data; parsed from the name `svm`
    SVM,
}
87
88impl FromStr for DatasetFormat {
89    type Err = anyhow::Error;
90
91    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
92        match value.to_lowercase().as_str() {
93            "arff" => Ok(Self::ARFF),
94            "csv" => Ok(Self::CSV),
95            "svm" => Ok(Self::SVM),
96            x => Err(anyhow!("Unknown data format '{x}'")),
97        }
98    }
99}
100
101impl TryFrom<&Path> for DatasetFormat {
102    type Error = anyhow::Error;
103
104    fn try_from(value: &Path) -> std::result::Result<Self, Self::Error> {
105        if let Some(extension) = value.extension() {
106            let ext = extension
107                .to_str()
108                .ok_or_else(|| anyhow!("Failed to get extension."))?;
109            DatasetFormat::from_str(ext)
110        } else {
111            Err(anyhow!("No extension, can't determine file type."))
112        }
113    }
114}
115
/// A dataset contains data for training or inference, training requires labels.
///
/// Each entry of `data` is one record (row); the loaders in this file require
/// the number of `features` to equal the length of the first record.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Dataset {
    /// Data used for training a model or calculating predictions, one inner vector per record
    pub data: Vec<Vec<f32>>,

    /// Data labels, can be empty if only used for inference.
    /// Defaults to an empty vector when absent from serialized input.
    #[serde(default)]
    pub labels: Vec<f32>,

    /// N-gram byte features, (de)serialized through the crate's hex helpers
    #[serde(
        serialize_with = "crate::serde::serialize_hex_vec",
        deserialize_with = "crate::serde::deserialize_hex_vec"
    )]
    pub features: Vec<Bytes>,

    /// The type of file represented
    pub ftype: FileType,
}
136
137impl PartialEq for Dataset {
138    fn eq(&self, other: &Self) -> bool {
139        if self.data.len() != other.data.len()
140            || self.labels.len() != other.labels.len()
141            || self.features.len() != other.features.len()
142            || self.ftype != other.ftype
143        {
144            return false;
145        }
146
147        for this_data in &self.data {
148            if !other.data.contains(this_data) {
149                return false;
150            }
151        }
152
153        for other_data in &other.data {
154            if !self.data.contains(other_data) {
155                return false;
156            }
157        }
158
159        if !self.labels.is_empty() {
160            for this_label in &self.labels {
161                if !other.labels.contains(this_label) {
162                    return false;
163                }
164            }
165
166            for other_label in &other.labels {
167                if !self.labels.contains(other_label) {
168                    return false;
169                }
170            }
171        }
172
173        for this_features in &self.features {
174            if !other.features.contains(this_features) {
175                return false;
176            }
177        }
178
179        for other_feature in &other.features {
180            if !self.features.contains(other_feature) {
181                return false;
182            }
183        }
184
185        true
186    }
187}
188
189impl Dataset {
190    /// Load a file
191    ///
192    /// # Errors
193    ///
194    /// An error results if the file type can't be determined, is incorrectly determined,
195    /// or if the file isn't a supported format.
196    pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
197        if let Some(extension) = path.as_ref().extension() {
198            return match extension.to_str().unwrap_or_default() {
199                "arff" => Dataset::from_arff_file(path.as_ref()),
200                "csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
201                "svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
202                "json" => {
203                    let contents = std::fs::read_to_string(path.as_ref())?;
204                    serde_json::from_str(&contents).map_err(Into::into)
205                }
206                "toml" => {
207                    let contents = std::fs::read_to_string(path.as_ref())?;
208                    toml::from_str(&contents).map_err(Into::into)
209                }
210                ext => {
211                    bail!("Unsupported/unknown data type '{ext}'");
212                }
213            };
214        }
215
216        bail!("No extension, can't determine file type.");
217    }
218
219    /// Create a dataset struct from a CSV file
220    ///
221    /// # Errors
222    ///
223    /// Returns an error if:
224    ///   * The file can't be read
225    ///   * The data contained isn't numeric
226    ///   * Feature data is missing
227    ///   * The expected amount of data isn't encountered
228    pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
229        let mut file = std::fs::File::open(path)?;
230        let mut contents = String::new();
231        file.read_to_string(&mut contents)?;
232
233        Self::from_csv_string(&contents, data_length)
234    }
235
236    /// Create a dataset struct from a CSV file
237    ///
238    /// # Errors
239    ///
240    /// Returns an error if:
241    ///   * The file can't be read
242    ///   * The data contained isn't a float
243    ///   * Feature data is missing
244    ///   * The amount of columns can't be determined
245    pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
246        let mut file = std::fs::File::open(path)?;
247        let mut contents = String::new();
248        file.read_to_string(&mut contents)?;
249
250        let mut length = 0;
251        for line in contents.lines() {
252            if line.is_empty() || COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
253                continue;
254            }
255
256            length = line.split(',').collect::<Vec<&str>>().len();
257            break;
258        }
259
260        ensure!(length > 0, "Failed to determine data length.");
261        Self::from_csv_string(&contents, length - 1)
262    }
263
264    /// Create a dataset struct from a CSV string
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if:
269    ///   * The data contained isn't numeric
270    ///   * Feature data is missing
271    ///   * The expected amount of data isn't encountered
272    pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
273        let mut data: Vec<Vec<f32>> = Vec::new();
274        let mut labels = Vec::new();
275        let mut features = Vec::new();
276        let mut file_type = FileType::NotSet;
277
278        for (row_number, line) in contents.lines().enumerate() {
279            if line.is_empty() {
280                continue;
281            }
282
283            if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
284                if line.contains(FEATURES_PREFIX) {
285                    let offset =
286                        line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
287                    let line = line[offset..].trim();
288                    features = line
289                        .split(',')
290                        .filter_map(|f| hex::decode(f.trim()).ok())
291                        .collect();
292                }
293
294                if line.contains(FILE_TYPE_PREFIX) {
295                    if let Some(file_type_str) = line.split(':').nth(1) {
296                        let file_type_str = file_type_str.trim();
297                        let ftype: Result<_, serde::de::value::Error> =
298                            FileType::deserialize(String::from(file_type_str).into_deserializer());
299                        file_type = ftype?;
300                    }
301                }
302            }
303
304            if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
305                continue;
306            }
307            let row = line.split(',').collect::<Vec<&str>>();
308            let mut row_float = Vec::with_capacity(data_length);
309            for r in row.iter().take(data_length) {
310                row_float.push(r.parse::<f32>().map_err(|_| {
311                    anyhow::Error::msg(format!("Non-float {r} encountered on CSV row {row_number}"))
312                })?);
313            }
314            if let Some(first_row) = data.first() {
315                ensure!(
316                    first_row.len() == row_float.len(),
317                    "CSV line {row_number} has invalid length {}, expected {}",
318                    row_float.len(),
319                    first_row.len()
320                );
321            }
322            data.push(row_float);
323            if row.len() == data_length + 1 {
324                let l = row[data_length].parse::<f32>().map_err(|_| {
325                    anyhow::Error::msg(format!(
326                        "Non-float label {} encountered on CSV row {row_number}",
327                        row[data_length]
328                    ))
329                })?;
330                labels.push(l);
331            } else if row.len() > data_length {
332                bail!(
333                    "CSV row had more than one label on row {row_number}, which isn't supported."
334                );
335            }
336        }
337
338        ensure!(
339            features.len() == data[0].len(),
340            "Features need to be empty or the same size as the data length."
341        );
342
343        ensure!(
344            file_type != FileType::NotSet,
345            "No file type specified in CSV file."
346        );
347        Ok(Self {
348            data,
349            labels,
350            features,
351            ftype: file_type,
352        })
353    }
354
355    /// Get a file type object from a line in a dataset file
356    #[inline]
357    pub(crate) fn file_type_from_line(line: &str) -> Result<FileType, serde::de::value::Error> {
358        let line = line.split(':').nth(1).unwrap_or_default();
359        let ftype: Result<_, serde::de::value::Error> =
360            FileType::deserialize(String::from(line.trim()).into_deserializer());
361        ftype
362    }
363
364    /// Create a dataset struct from an ARFF string
365    ///
366    /// # Errors
367    ///
368    /// Returns an error if:
369    ///   * The file can't be read
370    ///   * The data contained isn't numeric
371    ///   * Feature data is missing
372    ///   * The expected amount of data isn't encountered
373    pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
374        let mut file = std::fs::File::open(path)?;
375        let mut contents = String::new();
376        file.read_to_string(&mut contents)?;
377
378        Self::from_arff_string(&contents)
379    }
380
381    /// Create a dataset struct from an ARFF string
382    ///
383    /// # Errors
384    ///
385    /// Returns an error if:
386    ///   * The data contained isn't numeric
387    ///   * Feature data is missing
388    ///   * The expected amount of data isn't encountered
389    pub fn from_arff_string(contents: &str) -> Result<Self> {
390        let mut data: Vec<Vec<f32>> = Vec::new();
391        let mut labels = Vec::new();
392        let mut features = Vec::new();
393        let mut file_type = FileType::NotSet;
394        let mut passed_data = false;
395
396        for (row_number, line) in contents.lines().enumerate() {
397            if line.is_empty() {
398                continue;
399            }
400
401            if (line.starts_with('%') || line.starts_with('#')) && line.contains(FILE_TYPE_PREFIX) {
402                file_type = Self::file_type_from_line(line)?;
403                continue;
404            }
405
406            if line.contains("@ATTRIBUTE") {
407                let parts: Vec<&str> = line.split_ascii_whitespace().collect();
408                if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
409                    match hex::decode(parts[1]) {
410                        Ok(feat) => features.push(feat),
411                        Err(e) => {
412                            bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
413                        }
414                    }
415                }
416            }
417
418            if line.contains("@DATA") {
419                passed_data = true;
420                continue;
421            }
422
423            // Basically a CSV at this point
424            if passed_data {
425                let row = line.split(',').collect::<Vec<&str>>();
426                let data_length = row.len() - 1;
427                let mut row_float = Vec::with_capacity(data_length);
428                for r in row.iter().take(data_length) {
429                    row_float.push(r.parse::<f32>().map_err(|_| {
430                        anyhow::Error::msg(format!(
431                            "Non-float encountered on ARFF row {row_number}"
432                        ))
433                    })?);
434                }
435                if let Some(first_row) = data.first() {
436                    ensure!(
437                        first_row.len() == row_float.len(),
438                        "ARFF line {row_number} has invalid length {}, expected {}",
439                        row_float.len(),
440                        first_row.len()
441                    );
442                }
443                data.push(row_float);
444                if row.len() == data_length + 1 {
445                    let l = row[data_length].parse::<f32>().map_err(|_| {
446                        anyhow::Error::msg(format!(
447                            "Non-float encountered on ARFF row {row_number}"
448                        ))
449                    })?;
450                    labels.push(l);
451                } else if row.len() > data_length {
452                    bail!("Arff row had more than one label on row {row_number}, which isn't supported.");
453                }
454            }
455        }
456
457        ensure!(
458            features.len() == data[0].len(),
459            "Features need to be empty or the same size as the data length."
460        );
461
462        ensure!(
463            file_type != FileType::NotSet,
464            "No file type specified in ARFF file."
465        );
466        Ok(Self {
467            data,
468            labels,
469            features,
470            ftype: file_type,
471        })
472    }
473
474    /// Create a dataset struct from a libsvm file
475    ///
476    /// # Errors
477    ///
478    /// Returns an error if:
479    ///   * The file can't be read
480    ///   * Feature data is missing
481    ///   * The data isn't in the expected format
482    ///   * The expected amount of data isn't encountered
483    pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
484        let mut file = std::fs::File::open(path)?;
485        let mut contents = String::new();
486        file.read_to_string(&mut contents)?;
487
488        Self::from_libsvm_string(&contents)
489    }
490
491    /// Create a dataset from a libsvm string
492    ///
493    /// # Errors
494    ///
495    /// Returns an error if the file doesn't contain the expected format or is missing features
496    pub fn from_libsvm_string(contents: &str) -> Result<Self> {
497        let mut data = Vec::new();
498        let mut labels = Vec::new();
499        let mut features = Vec::new();
500        let mut file_type = FileType::NotSet;
501
502        for (row_number, line) in contents.lines().enumerate() {
503            if line.is_empty() {
504                continue;
505            }
506
507            if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
508                if line.contains(FEATURES_PREFIX) {
509                    let offset =
510                        line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
511                    let line = line[offset..].trim();
512                    features = line
513                        .split(',')
514                        .filter_map(|f| hex::decode(f.trim()).ok())
515                        .collect();
516                }
517
518                if line.contains(FILE_TYPE_PREFIX) {
519                    file_type = Self::file_type_from_line(line)?;
520                }
521            }
522
523            if line.is_empty() || line.starts_with('%') || line.starts_with('#') {
524                continue;
525            }
526
527            let parts = line.split_whitespace().collect::<Vec<&str>>();
528            let Ok(label) = parts[0].trim().parse::<f32>() else {
529                bail!(
530                    "Encountered a non-numeric label {} on line {row_number}",
531                    parts[0]
532                );
533            };
534            let mut row = vec![0.0f32; features.len()];
535
536            for part in parts.iter().skip(1) {
537                let part_parts = part.split(':').collect::<Vec<&str>>();
538                let Ok(part_index) = part_parts[0].trim().parse::<usize>() else {
539                    bail!(
540                        "Encountered a non-numeric index {} on line {row_number}",
541                        part_parts[0]
542                    );
543                };
544                let Ok(part_value) = part_parts[1].trim().parse::<f32>() else {
545                    bail!(
546                        "Encountered a non-numeric value {} on line {row_number}",
547                        part_parts[1]
548                    );
549                };
550
551                if part_index > row.len() && !features.is_empty() {
552                    bail!("Encountered a value at index {part_index} greater than expected size {} on line {row_number}", data.len());
553                }
554
555                if row.is_empty() {
556                    row = vec![0.0; part_index + 1];
557                } else if part_index >= row.len() {
558                    row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
559                }
560                row[part_index] = part_value;
561            }
562
563            data.push(row);
564            labels.push(label);
565        }
566
567        let data_len = data[0].len();
568        for row in &data {
569            if row.len() != data_len {
570                bail!(
571                    "Encountered a row with length {} but expected length {data_len}",
572                    row.len()
573                );
574            }
575        }
576
577        ensure!(
578            features.len() == data[0].len(),
579            "Features need to be empty or the same size as the data length."
580        );
581
582        ensure!(
583            file_type != FileType::NotSet,
584            "No file type specified in libsvm file."
585        );
586        Ok(Self {
587            data,
588            labels,
589            features,
590            ftype: file_type,
591        })
592    }
593
594    /// Given paths to malicious files, benign files, and n-grams (features), get a Dataset object.
595    ///
596    /// # Errors
597    /// This will fail if:
598    /// * The directories for benign or malicious files don't exist or are empty.
599    /// * The n-gram feature file doesn't exist, is empty, or doesn't have hexidecimal-encoded features
600    #[allow(clippy::too_many_lines)]
601    pub fn create_save_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
602        malicious_dir: P,
603        benign_dir: P,
604        ngrams_file: P,
605        output_file: P,
606    ) -> Result<()> {
607        const SUPPORTED_FORMATS: [DatasetFormat; 3] =
608            [DatasetFormat::CSV, DatasetFormat::ARFF, DatasetFormat::SVM];
609
610        let output_format = DatasetFormat::try_from(output_file.as_ref())?;
611        ensure!(
612            SUPPORTED_FORMATS.contains(&output_format),
613            "Only CSV, ARFF, or SVM formats are supported here."
614        );
615
616        let ngrams = NgramsFile::load(ngrams_file)?;
617        let mut output_file = std::fs::File::create(output_file)?;
618        writeln!(output_file, "# {FILE_TYPE_PREFIX} {:?}", ngrams.ftype)?;
619
620        match output_format {
621            DatasetFormat::SVM | DatasetFormat::CSV => {
622                let feature_string_vec = ngrams
623                    .clone()
624                    .into_vec()
625                    .iter()
626                    .map(hex::encode)
627                    .collect::<Vec<String>>();
628                writeln!(
629                    output_file,
630                    "# {FEATURES_PREFIX} {}",
631                    feature_string_vec.join(", ")
632                )?;
633            }
634
635            DatasetFormat::ARFF => {
636                let feature_string_vec = ngrams
637                    .clone()
638                    .into_vec()
639                    .iter()
640                    .map(hex::encode)
641                    .collect::<Vec<String>>();
642                for feature in feature_string_vec {
643                    let feature_hex = hex::encode(feature);
644                    writeln!(output_file, "@ATTRIBUTE {feature_hex} NUMERIC")?;
645                }
646            }
647        }
648
649        for entry in WalkDir::new(malicious_dir)
650            .max_depth(crate::MAX_RECURSION_DEPTH)
651            .follow_links(true)
652            .into_iter()
653            .flatten()
654        {
655            if entry.file_type().is_file() {
656                match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
657                    Ok(features) => match output_format {
658                        DatasetFormat::CSV | DatasetFormat::ARFF => {
659                            let mut line = features
660                                .iter()
661                                .map(|p| format!("{p}"))
662                                .collect::<Vec<String>>()
663                                .join(",");
664                            line.push_str(",1\n");
665                            output_file.write_all(line.as_bytes())?;
666                        }
667
668                        DatasetFormat::SVM => {
669                            write!(output_file, "1")?;
670                            for (data_index, data) in features.iter().enumerate() {
671                                if *data != 0.0000 {
672                                    write!(output_file, " {data_index}:{data}")?;
673                                }
674                            }
675                            writeln!(output_file)?;
676                        }
677                    },
678                    Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
679                }
680            }
681        }
682
683        for entry in WalkDir::new(benign_dir)
684            .max_depth(crate::MAX_RECURSION_DEPTH)
685            .follow_links(true)
686            .into_iter()
687            .flatten()
688        {
689            if entry.file_type().is_file() {
690                match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
691                    Ok(features) => match output_format {
692                        DatasetFormat::CSV | DatasetFormat::ARFF => {
693                            let mut line = features
694                                .iter()
695                                .map(|p| format!("{p}"))
696                                .collect::<Vec<String>>()
697                                .join(",");
698                            line.push_str(",0\n");
699                            output_file.write_all(line.as_bytes())?;
700                        }
701
702                        DatasetFormat::SVM => {
703                            write!(output_file, "0")?;
704                            for (data_index, data) in features.iter().enumerate() {
705                                if *data != 0.0000 {
706                                    write!(output_file, " {data_index}:{data}")?;
707                                }
708                            }
709                            writeln!(output_file)?;
710                        }
711                    },
712                    Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
713                }
714            }
715        }
716
717        output_file.sync_all()?;
718        Ok(())
719    }
720
721    /// Save a dataset as a CSV
722    ///
723    /// # Errors
724    ///
725    /// An error will result if the file can't be opened for writing
726    pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
727        let mut file = std::fs::File::create(path)?;
728
729        let feature_string_vec = self
730            .features
731            .iter()
732            .map(hex::encode)
733            .collect::<Vec<String>>();
734        writeln!(
735            file,
736            "# {FEATURES_PREFIX} {}",
737            feature_string_vec.join(", ")
738        )?;
739        writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
740
741        for index in 0..self.data.len() {
742            let mut line = self.data[index]
743                .iter()
744                .map(|p| format!("{p}"))
745                .collect::<Vec<String>>()
746                .join(",");
747
748            if !self.labels.is_empty() {
749                if self.labels[index] > 0.9 {
750                    line.push_str(",1");
751                } else {
752                    line.push_str(",0");
753                }
754            }
755            line.push('\n');
756
757            file.write_all(line.as_bytes())?;
758        }
759
760        file.sync_all().map_err(Into::into)
761    }
762
763    /// Save a dataset as an ARFF file
764    ///
765    /// # Errors
766    ///
767    /// An error will result if the file can't be opened for writing
768    pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
769        let mut file = std::fs::File::create(path)?;
770        writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
771
772        for feature in &self.features {
773            let feature_hex = hex::encode(feature);
774            file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
775        }
776
777        if !self.labels.is_empty() {
778            file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
779        }
780
781        file.write_all("\n@DATA\n".as_bytes())?;
782        for index in 0..self.data.len() {
783            let mut line = self.data[index]
784                .iter()
785                .map(|p| format!("{p}"))
786                .collect::<Vec<String>>()
787                .join(",");
788
789            if !self.labels.is_empty() {
790                if self.labels[index] > 0.9 {
791                    line.push_str(",1");
792                } else {
793                    line.push_str(",0");
794                }
795            }
796            line.push('\n');
797
798            file.write_all(line.as_bytes())?;
799        }
800
801        file.sync_all().map_err(Into::into)
802    }
803
804    /// Save a dataset as a libsvm file
805    ///
806    /// # Errors
807    ///
808    /// An error will result if the file can't be opened for writing
809    pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
810        ensure!(
811            !self.labels.is_empty(),
812            "Labels are required to create an libsvm file."
813        );
814        let mut file = std::fs::File::create(path)?;
815
816        let feature_string_vec = self
817            .features
818            .iter()
819            .map(hex::encode)
820            .collect::<Vec<String>>();
821        writeln!(
822            file,
823            "# {FEATURES_PREFIX} {}",
824            feature_string_vec.join(", ")
825        )?;
826        writeln!(file, "# {FILE_TYPE_PREFIX} {:?}", self.ftype)?;
827
828        for index in 0..self.data.len() {
829            file.write_all(format!("{}", self.labels[index]).as_bytes())?;
830            for (data_index, data) in self.data[index].iter().enumerate() {
831                if *data != 0.0000 {
832                    file.write_all(format!(" {data_index}:{data}").as_bytes())?;
833                }
834            }
835
836            file.write_all(b"\n")?;
837        }
838
839        file.sync_all().map_err(Into::into)
840    }
841
842    /// Save the dataset using the file extension to determine data format
843    ///
844    /// # Errors
845    ///
846    /// There's an error if the file can't be written or if the format can't be determined
847    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
848        if let Some(extension) = path.as_ref().extension() {
849            return match extension.to_str().unwrap_or_default() {
850                "arff" => self.save_arff(path),
851                "csv" => self.save_csv(path),
852                "svm" | "libsvm" => self.save_libsvm(path),
853                "json" => {
854                    let contents = serde_json::to_string_pretty(self)?;
855                    let mut file = std::fs::File::create(path)?;
856                    file.write_all(contents.as_bytes())?;
857                    file.sync_all().map_err(Into::into)
858                }
859                "toml" => {
860                    let contents = toml::to_string_pretty(self)?;
861                    let mut file = std::fs::File::create(path)?;
862                    file.write_all(contents.as_bytes())?;
863                    file.sync_all().map_err(Into::into)
864                }
865                ext => {
866                    bail!("Unsupported/unknown data type '{ext}'");
867                }
868            };
869        }
870
871        bail!("No extension, can't determine file type.");
872    }
873
874    /// Return dataset size
875    #[inline]
876    #[must_use]
877    pub fn len(&self) -> usize {
878        self.data.len()
879    }
880
881    /// Indicate if the dataset is empty
882    #[inline]
883    #[must_use]
884    pub fn is_empty(&self) -> bool {
885        self.data.is_empty()
886    }
887
888    /// Ensure the dataset is valid
889    /// * Same size data columns
890    /// * If present, the amount of data rows equals the amount of labels
891    #[inline]
892    #[must_use]
893    pub fn validate(&self) -> bool {
894        let data_len = match self.data.first() {
895            Some(first) => first.len(),
896            None => return false,
897        };
898
899        // Ensure data records are the same size
900        for record in &self.data {
901            if record.len() != data_len {
902                #[cfg(debug_assertions)]
903                eprint!("Expected record size {data_len}, got {}", record.len());
904                return false;
905            }
906        }
907
908        let feature_len = if let Some(first) = self.features.first() {
909            first.len()
910        } else {
911            #[cfg(debug_assertions)]
912            eprintln!("Features data is missing");
913            return false;
914        };
915
916        for feature in &self.features {
917            if feature.len() != feature_len {
918                #[cfg(debug_assertions)]
919                eprint!("Expected feature size {feature_len}, got {}", feature.len());
920                return false;
921            }
922        }
923
924        // If we have labels, ensure it's the same size as the data
925        (self.labels.is_empty() || self.labels.len() == self.data.len())
926            && self.features.len() == data_len
927            && self.ftype != FileType::NotSet
928    }
929
930    /// Shuffle the data, using roughly 10 X log10(size).
931    /// So 10 records = 10 iterations, 1,000 records gets 30 iterations
932    pub fn shuffle(&mut self) {
933        // Avoid a panic since `.ilog10()` panics on zero.
934        if !self.is_empty() {
935            let iterations = self.data.len().ilog10() * 10;
936            self.shuffle_iterations(iterations);
937        }
938    }
939
940    /// Shuffle the data with a specified amount of iterations, ensures
941    /// that the labels are swapped with the data, if present
942    pub fn shuffle_iterations(&mut self, iterations: u32) {
943        use rand::Rng;
944
945        if !self.is_empty() {
946            let mut rng = rand::rng();
947
948            for _ in 0..iterations {
949                let a = rng.random_range(0..self.data.len());
950                let b = rng.random_range(0..self.data.len());
951                let b = if b == a {
952                    rng.random_range(0..self.data.len())
953                } else {
954                    b
955                };
956
957                self.data.swap(a, b);
958                if !self.labels.is_empty() {
959                    self.labels.swap(a, b);
960                }
961            }
962        }
963    }
964
965    /// Split the dataset, ideally into train/test datasets.
966    /// The ratio indicates how much data is kept, the remaining size is shed and returned.
967    #[must_use]
968    #[allow(
969        clippy::cast_sign_loss,
970        clippy::cast_possible_truncation,
971        clippy::cast_precision_loss
972    )]
973    pub fn split(&mut self, ratio: f32) -> Self {
974        let ratio = ratio.abs();
975        let ratio = if ratio > 1.0 { 1.0 - ratio } else { ratio };
976        let new_size = (self.data.len() as f32 * ratio).ceil() as usize;
977
978        let new_data = self.data.drain(new_size..).collect();
979        let new_labels = if self.labels.is_empty() {
980            vec![]
981        } else {
982            self.labels.drain(new_size..).collect()
983        };
984
985        Self {
986            data: new_data,
987            labels: new_labels,
988            features: self.features.clone(),
989            ftype: self.ftype,
990        }
991    }
992
993    /// The model training allows for the algorithm to not only train a model but determine the
994    /// features most useful for determining benign vs. malicious. This action removes the features
995    /// deemed unneeded.
996    ///
997    /// # Errors
998    ///
999    /// If the model would remove all features, an error is returned as an empty dataset isn't useful,
1000    /// and it's instead likely the modal and dataset weren't for the same data collection.
1001    pub fn reduce(&mut self, model: &LogisticRegression) -> Result<Vec<usize>> {
1002        let mut removed = vec![];
1003
1004        for (index, feature) in self.features.iter().enumerate() {
1005            if !model.features.contains_key(feature) {
1006                removed.push(index);
1007            }
1008        }
1009
1010        if removed.len() == self.data[0].len() {
1011            bail!("This dataset and model are probably not from the same data - this operation would delete all the data!");
1012        }
1013
1014        removed.sort_unstable();
1015        removed.reverse();
1016
1017        self.features
1018            .retain(|feature| model.features.contains_key(feature));
1019
1020        for row in &mut self.data {
1021            for removed in &removed {
1022                row.remove(*removed);
1023            }
1024        }
1025
1026        Ok(removed)
1027    }
1028}
1029
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;

    /// Load the XOR CSV fixture shared by most of the tests.
    fn load_xor_csv() -> Dataset {
        Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap()
    }

    /// Removes the listed files when dropped, so files written by a test are
    /// cleaned up even when an assertion fails part-way through.
    struct RemoveOnDrop<'a>(&'a [&'a str]);

    impl Drop for RemoveOnDrop<'_> {
        fn drop(&mut self) {
            for path in self.0 {
                // Best effort: the file may never have been created.
                let _ = std::fs::remove_file(path);
            }
        }
    }

    /// The same XOR data parsed from CSV, ARFF and libsvm must validate and compare equal.
    #[test]
    fn xor() {
        let csv_dataset = load_xor_csv();
        assert!(csv_dataset.validate());

        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());

        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());

        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    /// The label-free fixtures are expected to be rejected by the CSV and libsvm parsers.
    #[test]
    fn xor_no_label() {
        assert!(Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6).is_err());
        assert!(Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm")).is_err());
    }

    /// Shuffling reorders the records and labels but leaves the features alone.
    #[test]
    fn shuffle() {
        let original_dataset = load_xor_csv();
        let mut dataset = load_xor_csv();
        dataset.shuffle();

        // NOTE(review): presumably `Dataset` equality ignores record order,
        // since the raw `data` is expected to differ below -- confirm against
        // the `PartialEq` implementation.
        assert_eq!(original_dataset, dataset);
        // NOTE(review): these two checks depend on the random swaps actually
        // changing the order, so they can fail spuriously if the swaps happen
        // to cancel out or only exchange rows with equal labels.
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    /// Splitting keeps the larger share in place and returns the remainder.
    #[test]
    fn split() {
        let mut dataset = load_xor_csv();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);

        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    /// A dataset written as CSV, ARFF and libsvm reads back equal to the original.
    #[test]
    fn save() {
        const COPY_CSV: &str = "xor_copy.csv";
        const COPY_ARFF: &str = "xor_copy.arff";
        const COPY_SVM: &str = "xor_copy.svm";

        // Declared first so it is dropped last, after every read has finished.
        let _cleanup = RemoveOnDrop(&[COPY_CSV, COPY_ARFF, COPY_SVM]);

        let dataset = load_xor_csv();
        dataset.save_csv(COPY_CSV).unwrap();
        dataset.save_arff(COPY_ARFF).unwrap();
        dataset.save_libsvm(COPY_SVM).unwrap();

        let dataset2 = Dataset::from_csv_file(COPY_CSV, 6).unwrap();
        assert_eq!(dataset, dataset2);

        let dataset3 = Dataset::from_arff_file(COPY_ARFF).unwrap();
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);

        let dataset4 = Dataset::from_libsvm_file(COPY_SVM).unwrap();
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);
    }
}