Skip to main content

malware_modeler/
dataset.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use crate::Bytes;
4use crate::ftype::FileType;
5use crate::model::LogisticRegression;
6use crate::ngram::NgramsFile;
7
8use std::collections::HashMap;
9use std::io::{Read, Seek, SeekFrom, Write};
10use std::path::Path;
11use std::str::FromStr;
12
13use anyhow::{Result, anyhow, bail, ensure};
14use serde::de::IntoDeserializer;
15use serde::{Deserialize, Serialize};
16use walkdir::WalkDir;
17
/// Bytes that mark a line as a comment when they appear as its first byte.
const COMMENT_PREFIXES: [u8; 2] = [b'#', b'%'];
/// Marker inside a comment line that introduces the comma-separated, hex-encoded feature list.
const FEATURES_PREFIX: &str = "Features:";
/// Marker inside a comment line that introduces the dataset's file type.
const FILE_TYPE_PREFIX: &str = "File type:";
21
22/// Given a file path, feature size, and collection of features, return a vector
23/// which indicates if each feature is present.
24///
25/// # Errors
26/// Returns an error if the file can't be read.
27#[inline]
28pub(crate) fn featurize_file<P: AsRef<Path>, S: ::std::hash::BuildHasher>(
29    file: P,
30    n: usize,
31    features: &HashMap<Bytes, usize, S>,
32) -> Result<Vec<f32>> {
33    let file_size = std::fs::metadata(&file)?.len();
34    ensure!(
35        file_size > n as u64,
36        "File {} is too small.",
37        file.as_ref().display()
38    );
39
40    let mut feature_vector = vec![0.0; features.len()];
41
42    if file_size < 10_485_760u64
43    /* 10MB */
44    {
45        let contents = std::fs::read(file)?;
46        for window in contents.windows(n) {
47            if let Some(index) = features.get(window) {
48                feature_vector[*index] = 1.0;
49            }
50        }
51    } else {
52        let mut file = std::fs::File::open(file)?;
53        let mut buffer = [0u8; crate::ngram::NGRAM_BUFFER_SIZE];
54        while let Ok(bytes_read) = file.read(&mut buffer) {
55            if bytes_read < n {
56                break;
57            }
58            for index in 0..bytes_read - n {
59                if let Some(index) = features.get(&buffer[index..index + n]) {
60                    feature_vector[*index] = 1.0;
61                }
62            }
63
64            // A wraparound isn't possible here since n-grams are expected to be single-digit numbers.
65            // `usize` is just a convenience as we have to index into arrays
66            // Go back n-1 bytes to get the next n-gram.
67            #[allow(clippy::cast_possible_wrap)]
68            file.seek(SeekFrom::Current(n as i64 - 1))?;
69        }
70    }
71
72    Ok(feature_vector)
73}
74
75/// File format for a dataset
76#[derive(Copy, Clone, Deserialize, Serialize, Hash, Eq, PartialEq)]
77pub enum DatasetFormat {
78    /// Attribute-relation format
79    ARFF,
80
81    /// Comma-separated values, the most common format
82    CSV,
83
84    /// Support vector machine format, ideal for sparse data
85    SVM,
86}
87
88impl FromStr for DatasetFormat {
89    type Err = anyhow::Error;
90
91    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
92        match value.to_lowercase().as_str() {
93            "arff" => Ok(Self::ARFF),
94            "csv" => Ok(Self::CSV),
95            "svm" => Ok(Self::SVM),
96            x => Err(anyhow!("Unknown data format '{x}'")),
97        }
98    }
99}
100
101impl TryFrom<&Path> for DatasetFormat {
102    type Error = anyhow::Error;
103
104    fn try_from(value: &Path) -> std::result::Result<Self, Self::Error> {
105        if let Some(extension) = value.extension() {
106            let ext = extension
107                .to_str()
108                .ok_or_else(|| anyhow!("Failed to get extension."))?;
109            DatasetFormat::from_str(ext)
110        } else {
111            Err(anyhow!("No extension, can't determine file type."))
112        }
113    }
114}
115
116/// A dataset contains data for training or inference, training requires labels
117#[derive(Debug, Clone, Deserialize, Serialize)]
118pub struct Dataset {
119    /// Data used for training a model or calculating predictions
120    pub data: Vec<Vec<f32>>,
121
122    /// Data labels, can be empty if only used for inference
123    #[serde(default)]
124    pub labels: Vec<u8>,
125
126    /// N-gram byte features
127    #[serde(
128        serialize_with = "crate::serde::serialize_hex_vec",
129        deserialize_with = "crate::serde::deserialize_hex_vec"
130    )]
131    pub features: Vec<Bytes>,
132
133    /// The type of file represented
134    pub ftype: FileType,
135}
136
137impl PartialEq for Dataset {
138    fn eq(&self, other: &Self) -> bool {
139        if self.data.len() != other.data.len()
140            || self.labels.len() != other.labels.len()
141            || self.features.len() != other.features.len()
142            || self.ftype != other.ftype
143        {
144            return false;
145        }
146
147        for this_data in &self.data {
148            if !other.data.contains(this_data) {
149                return false;
150            }
151        }
152
153        for other_data in &other.data {
154            if !self.data.contains(other_data) {
155                return false;
156            }
157        }
158
159        if !self.labels.is_empty() {
160            for this_label in &self.labels {
161                if !other.labels.contains(this_label) {
162                    return false;
163                }
164            }
165
166            for other_label in &other.labels {
167                if !self.labels.contains(other_label) {
168                    return false;
169                }
170            }
171        }
172
173        for this_features in &self.features {
174            if !other.features.contains(this_features) {
175                return false;
176            }
177        }
178
179        for other_feature in &other.features {
180            if !self.features.contains(other_feature) {
181                return false;
182            }
183        }
184
185        true
186    }
187}
188
189impl Dataset {
190    /// Load a file
191    ///
192    /// # Errors
193    ///
194    /// An error results if the file type can't be determined, is incorrectly determined,
195    /// or if the file isn't a supported format.
196    pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
197        if let Some(extension) = path.as_ref().extension() {
198            return match extension.to_str().unwrap_or_default() {
199                "arff" => Dataset::from_arff_file(path.as_ref()),
200                "csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
201                "svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
202                "json" => {
203                    let contents = std::fs::read_to_string(path.as_ref())?;
204                    serde_json::from_str(&contents).map_err(Into::into)
205                }
206                "toml" => {
207                    let contents = std::fs::read_to_string(path.as_ref())?;
208                    toml::from_str(&contents).map_err(Into::into)
209                }
210                ext => {
211                    bail!("Unsupported/unknown data type '{ext}'");
212                }
213            };
214        }
215
216        bail!("No extension, can't determine file type.");
217    }
218
219    /// Create a dataset struct from a CSV file
220    ///
221    /// # Errors
222    ///
223    /// Returns an error if:
224    ///   * The file can't be read
225    ///   * The data contained isn't numeric
226    ///   * Feature data is missing
227    ///   * The expected amount of data isn't encountered
228    pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
229        let mut file = std::fs::File::open(path)?;
230        let mut contents = String::new();
231        file.read_to_string(&mut contents)?;
232
233        Self::from_csv_string(&contents, data_length)
234    }
235
236    /// Create a dataset struct from a CSV file
237    ///
238    /// # Errors
239    ///
240    /// Returns an error if:
241    ///   * The file can't be read
242    ///   * The data contained isn't a float
243    ///   * Feature data is missing
244    ///   * The amount of columns can't be determined
245    pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
246        let mut file = std::fs::File::open(path)?;
247        let mut contents = String::new();
248        file.read_to_string(&mut contents)?;
249
250        let mut length = 0;
251        for line in contents.lines() {
252            if line.is_empty() || COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
253                continue;
254            }
255
256            length = line.split(',').collect::<Vec<&str>>().len();
257            break;
258        }
259
260        ensure!(length > 0, "Failed to determine data length.");
261        Self::from_csv_string(&contents, length - 1)
262    }
263
264    /// Create a dataset struct from a CSV string
265    ///
266    /// # Errors
267    ///
268    /// Returns an error if:
269    ///   * The data contained isn't numeric
270    ///   * Feature data is missing
271    ///   * The expected amount of data isn't encountered
272    pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
273        let mut data: Vec<Vec<f32>> = Vec::new();
274        let mut labels = Vec::new();
275        let mut features = Vec::new();
276        let mut file_type = FileType::NotSet;
277
278        for (row_number, line) in contents.lines().enumerate() {
279            if line.is_empty() {
280                continue;
281            }
282
283            if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
284                if line.contains(FEATURES_PREFIX) {
285                    let offset =
286                        line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
287                    let line = line[offset..].trim();
288                    features = line
289                        .split(',')
290                        .filter_map(|f| hex::decode(f.trim()).ok())
291                        .collect();
292                }
293
294                if line.contains(FILE_TYPE_PREFIX)
295                    && let Some(file_type_str) = line.split(':').nth(1)
296                {
297                    let file_type_str = file_type_str.trim();
298                    let ftype: Result<_, serde::de::value::Error> =
299                        FileType::deserialize(String::from(file_type_str).into_deserializer());
300                    file_type = ftype?;
301                }
302            }
303
304            if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
305                continue;
306            }
307            let row = line.split(',').collect::<Vec<&str>>();
308            let mut row_float = Vec::with_capacity(data_length);
309            for r in row.iter().take(data_length) {
310                row_float.push(r.parse::<f32>().map_err(|_| {
311                    anyhow::Error::msg(format!("Non-float {r} encountered on CSV row {row_number}"))
312                })?);
313            }
314            if let Some(first_row) = data.first() {
315                ensure!(
316                    first_row.len() == row_float.len(),
317                    "CSV line {row_number} has invalid length {}, expected {}",
318                    row_float.len(),
319                    first_row.len()
320                );
321            }
322            data.push(row_float);
323            if row.len() == data_length + 1 {
324                let l = row[data_length].parse::<u8>().map_err(|_| {
325                    anyhow::Error::msg(format!(
326                        "Non-float label {} encountered on CSV row {row_number}",
327                        row[data_length]
328                    ))
329                })?;
330                labels.push(l);
331            } else if row.len() > data_length {
332                bail!(
333                    "CSV row had more than one label on row {row_number}, which isn't supported."
334                );
335            }
336        }
337
338        ensure!(
339            features.len() == data[0].len(),
340            "Features need to be empty or the same size as the data length."
341        );
342
343        ensure!(
344            file_type != FileType::NotSet,
345            "No file type specified in CSV file."
346        );
347        Ok(Self {
348            data,
349            labels,
350            features,
351            ftype: file_type,
352        })
353    }
354
355    /// Get a file type object from a line in a dataset file
356    #[inline]
357    pub(crate) fn file_type_from_line(line: &str) -> Result<FileType, serde::de::value::Error> {
358        let line = line.split(':').nth(1).unwrap_or(line).to_uppercase();
359        let ftype: Result<_, serde::de::value::Error> =
360            FileType::deserialize(String::from(line.trim()).into_deserializer());
361        ftype
362    }
363
364    /// Create a dataset struct from an ARFF string
365    ///
366    /// # Errors
367    ///
368    /// Returns an error if:
369    ///   * The file can't be read
370    ///   * The data contained isn't numeric
371    ///   * Feature data is missing
372    ///   * The expected amount of data isn't encountered
373    pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
374        let mut file = std::fs::File::open(path)?;
375        let mut contents = String::new();
376        file.read_to_string(&mut contents)?;
377
378        Self::from_arff_string(&contents)
379    }
380
381    /// Create a dataset struct from an ARFF string
382    ///
383    /// # Errors
384    ///
385    /// Returns an error if:
386    ///   * The data contained isn't numeric
387    ///   * Feature data is missing
388    ///   * The expected amount of data isn't encountered
389    pub fn from_arff_string(contents: &str) -> Result<Self> {
390        let mut data: Vec<Vec<f32>> = Vec::new();
391        let mut labels = Vec::new();
392        let mut features = Vec::new();
393        let mut file_type = FileType::NotSet;
394        let mut passed_data = false;
395
396        for (row_number, line) in contents.lines().enumerate() {
397            if line.is_empty() {
398                continue;
399            }
400
401            if (line.starts_with('%') || line.starts_with('#')) && line.contains(FILE_TYPE_PREFIX) {
402                file_type = Self::file_type_from_line(line)?;
403                continue;
404            }
405
406            if line.contains("@ATTRIBUTE") {
407                let parts: Vec<&str> = line.split_ascii_whitespace().collect();
408                if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
409                    match hex::decode(parts[1]) {
410                        Ok(feat) => features.push(feat),
411                        Err(e) => {
412                            bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
413                        }
414                    }
415                }
416            }
417
418            if line.contains("@DATA") {
419                passed_data = true;
420                continue;
421            }
422
423            // Basically a CSV at this point
424            if passed_data {
425                let row = line.split(',').collect::<Vec<&str>>();
426                let data_length = row.len() - 1;
427                let mut row_float = Vec::with_capacity(data_length);
428                for r in row.iter().take(data_length) {
429                    row_float.push(r.parse::<f32>().map_err(|_| {
430                        anyhow::Error::msg(format!(
431                            "Non-float encountered on ARFF row {row_number}"
432                        ))
433                    })?);
434                }
435                if let Some(first_row) = data.first() {
436                    ensure!(
437                        first_row.len() == row_float.len(),
438                        "ARFF line {row_number} has invalid length {}, expected {}",
439                        row_float.len(),
440                        first_row.len()
441                    );
442                }
443                data.push(row_float);
444                if row.len() == data_length + 1 {
445                    let l = row[data_length].parse::<u8>().map_err(|_| {
446                        anyhow::Error::msg(format!(
447                            "Non-float encountered on ARFF row {row_number}"
448                        ))
449                    })?;
450                    labels.push(l);
451                } else if row.len() > data_length {
452                    bail!(
453                        "Arff row had more than one label on row {row_number}, which isn't supported."
454                    );
455                }
456            }
457        }
458
459        ensure!(
460            features.len() == data[0].len(),
461            "Features need to be empty or the same size as the data length."
462        );
463
464        ensure!(
465            file_type != FileType::NotSet,
466            "No file type specified in ARFF file."
467        );
468        Ok(Self {
469            data,
470            labels,
471            features,
472            ftype: file_type,
473        })
474    }
475
476    /// Create a dataset struct from a libsvm file
477    ///
478    /// # Errors
479    ///
480    /// Returns an error if:
481    ///   * The file can't be read
482    ///   * Feature data is missing
483    ///   * The data isn't in the expected format
484    ///   * The expected amount of data isn't encountered
485    pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
486        let mut file = std::fs::File::open(path)?;
487        let mut contents = String::new();
488        file.read_to_string(&mut contents)?;
489
490        Self::from_libsvm_string(&contents)
491    }
492
493    /// Create a dataset from a libsvm string
494    ///
495    /// # Errors
496    ///
497    /// Returns an error if the file doesn't contain the expected format or is missing features
498    pub fn from_libsvm_string(contents: &str) -> Result<Self> {
499        let mut data = Vec::new();
500        let mut labels = Vec::new();
501        let mut features = Vec::new();
502        let mut file_type = FileType::NotSet;
503
504        for (row_number, line) in contents.lines().enumerate() {
505            if line.is_empty() {
506                continue;
507            }
508
509            if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
510                if line.contains(FEATURES_PREFIX) {
511                    let offset =
512                        line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
513                    let line = line[offset..].trim();
514                    features = line
515                        .split(',')
516                        .filter_map(|f| hex::decode(f.trim()).ok())
517                        .collect();
518                }
519
520                if line.contains(FILE_TYPE_PREFIX) {
521                    file_type = Self::file_type_from_line(line)?;
522                }
523            }
524
525            if line.is_empty() || line.starts_with('%') || line.starts_with('#') {
526                continue;
527            }
528
529            let parts = line.split_whitespace().collect::<Vec<&str>>();
530            let Ok(label) = parts[0].trim().parse::<u8>() else {
531                bail!(
532                    "Encountered a non-numeric label {} on line {row_number}",
533                    parts[0]
534                );
535            };
536            let mut row = vec![0.0f32; features.len()];
537
538            for part in parts.iter().skip(1) {
539                let part_parts = part.split(':').collect::<Vec<&str>>();
540                let Ok(part_index) = part_parts[0].trim().parse::<usize>() else {
541                    bail!(
542                        "Encountered a non-numeric index {} on line {row_number}",
543                        part_parts[0]
544                    );
545                };
546                let Ok(part_value) = part_parts[1].trim().parse::<f32>() else {
547                    bail!(
548                        "Encountered a non-numeric value {} on line {row_number}",
549                        part_parts[1]
550                    );
551                };
552
553                if part_index > row.len() && !features.is_empty() {
554                    bail!(
555                        "Encountered a value at index {part_index} greater than expected size {} on line {row_number}",
556                        data.len()
557                    );
558                }
559
560                if row.is_empty() {
561                    row = vec![0.0; part_index + 1];
562                } else if part_index >= row.len() {
563                    row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
564                }
565                row[part_index] = part_value;
566            }
567
568            data.push(row);
569            labels.push(label);
570        }
571
572        let data_len = data[0].len();
573        for row in &data {
574            if row.len() != data_len {
575                bail!(
576                    "Encountered a row with length {} but expected length {data_len}",
577                    row.len()
578                );
579            }
580        }
581
582        ensure!(
583            features.len() == data[0].len(),
584            "Features need to be empty or the same size as the data length."
585        );
586
587        ensure!(
588            file_type != FileType::NotSet,
589            "No file type specified in libsvm file."
590        );
591        Ok(Self {
592            data,
593            labels,
594            features,
595            ftype: file_type,
596        })
597    }
598
599    /// Given paths to malicious files, benign files, and n-grams (features), get a Dataset object.
600    ///
601    /// # Errors
602    /// This will fail if:
603    /// * The directories for benign or malicious files don't exist or are empty.
604    /// * The n-gram feature file doesn't exist, is empty, or doesn't have hexidecimal-encoded features
605    #[allow(clippy::too_many_lines)]
606    pub fn create_save_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
607        malicious_dir: P,
608        benign_dir: P,
609        ngrams_file: P,
610        output_file: P,
611    ) -> Result<()> {
612        const SUPPORTED_FORMATS: [DatasetFormat; 3] =
613            [DatasetFormat::CSV, DatasetFormat::ARFF, DatasetFormat::SVM];
614
615        let output_format = DatasetFormat::try_from(output_file.as_ref())?;
616        ensure!(
617            SUPPORTED_FORMATS.contains(&output_format),
618            "Only CSV, ARFF, or SVM formats are supported here."
619        );
620
621        let ngrams = NgramsFile::load(ngrams_file)?;
622        let mut output_file = std::fs::File::create(output_file)?;
623        writeln!(output_file, "# {FILE_TYPE_PREFIX} {:?}", ngrams.ftype)?;
624
625        match output_format {
626            DatasetFormat::SVM | DatasetFormat::CSV => {
627                let feature_string_vec = ngrams
628                    .clone()
629                    .into_vec()
630                    .iter()
631                    .map(hex::encode)
632                    .collect::<Vec<String>>();
633                writeln!(
634                    output_file,
635                    "# {FEATURES_PREFIX} {}",
636                    feature_string_vec.join(", ")
637                )?;
638            }
639
640            DatasetFormat::ARFF => {
641                let feature_string_vec = ngrams
642                    .clone()
643                    .into_vec()
644                    .iter()
645                    .map(hex::encode)
646                    .collect::<Vec<String>>();
647                for feature in feature_string_vec {
648                    let feature_hex = hex::encode(feature);
649                    writeln!(output_file, "@ATTRIBUTE {feature_hex} NUMERIC")?;
650                }
651            }
652        }
653
654        for entry in WalkDir::new(malicious_dir)
655            .max_depth(crate::MAX_RECURSION_DEPTH)
656            .follow_links(true)
657            .into_iter()
658            .flatten()
659        {
660            if entry.file_type().is_file() {
661                match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
662                    Ok(features) => match output_format {
663                        DatasetFormat::CSV | DatasetFormat::ARFF => {
664                            let mut line = features
665                                .iter()
666                                .map(|p| format!("{p}"))
667                                .collect::<Vec<String>>()
668                                .join(",");
669                            line.push_str(",1\n");
670                            output_file.write_all(line.as_bytes())?;
671                        }
672
673                        DatasetFormat::SVM => {
674                            write!(output_file, "1")?;
675                            for (data_index, data) in features.iter().enumerate() {
676                                if *data != 0.0000 {
677                                    write!(output_file, " {data_index}:{data}")?;
678                                }
679                            }
680                            writeln!(output_file)?;
681                        }
682                    },
683                    Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
684                }
685            }
686        }
687
688        for entry in WalkDir::new(benign_dir)
689            .max_depth(crate::MAX_RECURSION_DEPTH)
690            .follow_links(true)
691            .into_iter()
692            .flatten()
693        {
694            if entry.file_type().is_file() {
695                match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
696                    Ok(features) => match output_format {
697                        DatasetFormat::CSV | DatasetFormat::ARFF => {
698                            let mut line = features
699                                .iter()
700                                .map(|p| format!("{p}"))
701                                .collect::<Vec<String>>()
702                                .join(",");
703                            line.push_str(",0\n");
704                            output_file.write_all(line.as_bytes())?;
705                        }
706
707                        DatasetFormat::SVM => {
708                            write!(output_file, "0")?;
709                            for (data_index, data) in features.iter().enumerate() {
710                                if *data != 0.0000 {
711                                    write!(output_file, " {data_index}:{data}")?;
712                                }
713                            }
714                            writeln!(output_file)?;
715                        }
716                    },
717                    Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
718                }
719            }
720        }
721
722        output_file.sync_all()?;
723        Ok(())
724    }
725
726    /// Save a dataset as a CSV
727    ///
728    /// # Errors
729    ///
730    /// An error will result if the file can't be opened for writing
731    pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
732        let mut file = std::fs::File::create(path)?;
733
734        let feature_string_vec = self
735            .features
736            .iter()
737            .map(hex::encode)
738            .collect::<Vec<String>>();
739        writeln!(
740            file,
741            "# {FEATURES_PREFIX} {}",
742            feature_string_vec.join(", ")
743        )?;
744        writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
745
746        for index in 0..self.data.len() {
747            let mut line = self.data[index]
748                .iter()
749                .map(|p| format!("{p}"))
750                .collect::<Vec<String>>()
751                .join(",");
752
753            if !self.labels.is_empty() {
754                line = format!("{line},{}", self.labels[index]);
755            }
756            line.push('\n');
757
758            file.write_all(line.as_bytes())?;
759        }
760
761        file.sync_all().map_err(Into::into)
762    }
763
764    /// Save a dataset as an ARFF file
765    ///
766    /// # Errors
767    ///
768    /// An error will result if the file can't be opened for writing
769    pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
770        let mut file = std::fs::File::create(path)?;
771        writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
772
773        for feature in &self.features {
774            let feature_hex = hex::encode(feature);
775            file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
776        }
777
778        if !self.labels.is_empty() {
779            file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
780        }
781
782        file.write_all("\n@DATA\n".as_bytes())?;
783        for index in 0..self.data.len() {
784            let mut line = self.data[index]
785                .iter()
786                .map(|p| format!("{p}"))
787                .collect::<Vec<String>>()
788                .join(",");
789
790            if !self.labels.is_empty() {
791                line = format!("{line},{}", self.labels[index]);
792            }
793            line.push('\n');
794
795            file.write_all(line.as_bytes())?;
796        }
797
798        file.sync_all().map_err(Into::into)
799    }
800
801    /// Save a dataset as a libsvm file
802    ///
803    /// # Errors
804    ///
805    /// An error will result if the file can't be opened for writing
806    pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
807        ensure!(
808            !self.labels.is_empty(),
809            "Labels are required to create an libsvm file."
810        );
811        let mut file = std::fs::File::create(path)?;
812
813        let feature_string_vec = self
814            .features
815            .iter()
816            .map(hex::encode)
817            .collect::<Vec<String>>();
818        writeln!(
819            file,
820            "# {FEATURES_PREFIX} {}",
821            feature_string_vec.join(", ")
822        )?;
823        writeln!(file, "# {FILE_TYPE_PREFIX} {:?}", self.ftype)?;
824
825        for index in 0..self.data.len() {
826            file.write_all(format!("{}", self.labels[index]).as_bytes())?;
827            for (data_index, data) in self.data[index].iter().enumerate() {
828                if *data != 0.0000 {
829                    file.write_all(format!(" {data_index}:{data}").as_bytes())?;
830                }
831            }
832
833            file.write_all(b"\n")?;
834        }
835
836        file.sync_all().map_err(Into::into)
837    }
838
839    /// Save the dataset using the file extension to determine data format
840    ///
841    /// # Errors
842    ///
843    /// There's an error if the file can't be written or if the format can't be determined
844    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
845        if let Some(extension) = path.as_ref().extension() {
846            return match extension.to_str().unwrap_or_default() {
847                "arff" => self.save_arff(path),
848                "csv" => self.save_csv(path),
849                "svm" | "libsvm" => self.save_libsvm(path),
850                "json" => {
851                    let contents = serde_json::to_string_pretty(self)?;
852                    let mut file = std::fs::File::create(path)?;
853                    file.write_all(contents.as_bytes())?;
854                    file.sync_all().map_err(Into::into)
855                }
856                "toml" => {
857                    let contents = toml::to_string_pretty(self)?;
858                    let mut file = std::fs::File::create(path)?;
859                    file.write_all(contents.as_bytes())?;
860                    file.sync_all().map_err(Into::into)
861                }
862                ext => {
863                    bail!("Unsupported/unknown data type '{ext}'");
864                }
865            };
866        }
867
868        bail!("No extension, can't determine file type.");
869    }
870
871    /// Return dataset size
872    #[inline]
873    #[must_use]
874    pub fn len(&self) -> usize {
875        self.data.len()
876    }
877
878    /// Indicate if the dataset is empty
879    #[inline]
880    #[must_use]
881    pub fn is_empty(&self) -> bool {
882        self.data.is_empty()
883    }
884
885    /// Ensure the dataset is valid
886    /// * Same size data columns
887    /// * If present, the amount of data rows equals the amount of labels
888    #[inline]
889    #[must_use]
890    pub fn validate(&self) -> bool {
891        let data_len = match self.data.first() {
892            Some(first) => first.len(),
893            None => return false,
894        };
895
896        // Ensure data records are the same size
897        for record in &self.data {
898            if record.len() != data_len {
899                #[cfg(debug_assertions)]
900                eprint!("Expected record size {data_len}, got {}", record.len());
901                return false;
902            }
903        }
904
905        let feature_len = if let Some(first) = self.features.first() {
906            first.len()
907        } else {
908            #[cfg(debug_assertions)]
909            eprintln!("Features data is missing");
910            return false;
911        };
912
913        for feature in &self.features {
914            if feature.len() != feature_len {
915                #[cfg(debug_assertions)]
916                eprint!("Expected feature size {feature_len}, got {}", feature.len());
917                return false;
918            }
919        }
920
921        // If we have labels, ensure it's the same size as the data
922        (self.labels.is_empty() || self.labels.len() == self.data.len())
923            && self.features.len() == data_len
924            && self.ftype != FileType::NotSet
925    }
926
927    /// Shuffle the data, using roughly 10 X log10(size).
928    /// So 10 records = 10 iterations, 1,000 records gets 30 iterations
929    pub fn shuffle(&mut self) {
930        // Avoid a panic since `.ilog10()` panics on zero.
931        if !self.is_empty() {
932            let iterations = self.data.len().ilog10() * 10;
933            self.shuffle_iterations(iterations);
934        }
935    }
936
937    /// Shuffle the data with a specified amount of iterations, ensures
938    /// that the labels are swapped with the data, if present
939    pub fn shuffle_iterations(&mut self, iterations: u32) {
940        use rand::RngExt;
941
942        if !self.is_empty() {
943            let mut rng = rand::rng();
944
945            for _ in 0..iterations {
946                let a = rng.random_range(0..self.data.len());
947                let b = rng.random_range(0..self.data.len());
948                let b = if b == a {
949                    rng.random_range(0..self.data.len())
950                } else {
951                    b
952                };
953
954                self.data.swap(a, b);
955                if !self.labels.is_empty() {
956                    self.labels.swap(a, b);
957                }
958            }
959        }
960    }
961
962    /// Split the dataset, ideally into train/test datasets.
963    /// The ratio indicates how much data is kept, the remaining size is shed and returned.
964    #[must_use]
965    #[allow(
966        clippy::cast_sign_loss,
967        clippy::cast_possible_truncation,
968        clippy::cast_precision_loss
969    )]
970    pub fn split(&mut self, ratio: f32) -> Self {
971        let ratio = ratio.abs();
972        let ratio = if ratio > 1.0 { 1.0 - ratio } else { ratio };
973        let new_size = (self.data.len() as f32 * ratio).ceil() as usize;
974
975        let new_data = self.data.drain(new_size..).collect();
976        let new_labels = if self.labels.is_empty() {
977            vec![]
978        } else {
979            self.labels.drain(new_size..).collect()
980        };
981
982        Self {
983            data: new_data,
984            labels: new_labels,
985            features: self.features.clone(),
986            ftype: self.ftype,
987        }
988    }
989
990    /// The model training allows for the algorithm to not only train a model but determine the
991    /// features most useful for determining benign vs. malicious. This action removes the features
992    /// deemed unneeded.
993    ///
994    /// # Errors
995    ///
996    /// If the model would remove all features, an error is returned as an empty dataset isn't useful,
997    /// and it's instead likely the modal and dataset weren't for the same data collection.
998    pub fn reduce(&mut self, model: &LogisticRegression) -> Result<Vec<usize>> {
999        let mut removed = vec![];
1000
1001        for (index, feature) in self.features.iter().enumerate() {
1002            if !model.features.contains_key(feature) {
1003                removed.push(index);
1004            }
1005        }
1006
1007        if removed.len() == self.data[0].len() {
1008            bail!(
1009                "This dataset and model are probably not from the same data - this operation would delete all the data!"
1010            );
1011        }
1012
1013        removed.sort_unstable();
1014        removed.reverse();
1015
1016        self.features
1017            .retain(|feature| model.features.contains_key(feature));
1018
1019        for row in &mut self.data {
1020            for removed in &removed {
1021                row.remove(*removed);
1022            }
1023        }
1024
1025        Ok(removed)
1026    }
1027
1028    /// Returns an iterator over a column
1029    #[must_use]
1030    pub fn column_iter(&'_ self, index: usize) -> Option<ColumnIterator<'_>> {
1031        if index < self.data[0].len() {
1032            Some(ColumnIterator {
1033                dataset: self,
1034                column_index: index,
1035                current_row_index: 0,
1036            })
1037        } else {
1038            None
1039        }
1040    }
1041}
1042
1043/// Iterator for accessing the data by a given column without copying the data
1044pub struct ColumnIterator<'a> {
1045    /// Reference to a dataset
1046    dataset: &'a Dataset,
1047
1048    /// The column we're looking at
1049    column_index: usize,
1050
1051    /// The current row we're looking at
1052    current_row_index: usize,
1053}
1054
1055impl Iterator for ColumnIterator<'_> {
1056    type Item = f32;
1057
1058    fn next(&mut self) -> Option<Self::Item> {
1059        if self.current_row_index < self.dataset.data.len() {
1060            let val = self.dataset.data[self.current_row_index][self.column_index];
1061            self.current_row_index += 1;
1062            Some(val)
1063        } else {
1064            None
1065        }
1066    }
1067}
1068
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;

    /// Load the XOR fixture from its CSV form, the baseline most tests start from.
    fn xor_csv() -> Dataset {
        Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap()
    }

    /// The same XOR data parsed from CSV, ARFF, and libsvm must validate and compare equal.
    #[test]
    fn xor() {
        let csv_dataset = xor_csv();
        assert!(csv_dataset.validate());

        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());

        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());

        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    /// The CSV and libsvm parsers reject the fixtures that have no labels.
    #[test]
    fn xor_no_label() {
        let csv_result = Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6);
        assert!(csv_result.is_err());

        let svm_result = Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm"));
        assert!(svm_result.is_err());
    }

    /// After shuffling, the datasets still compare equal while the raw row and label
    /// order differs; the features are untouched.
    #[test]
    fn shuffle() {
        let original_dataset = xor_csv();
        let mut dataset = xor_csv();
        dataset.shuffle();

        assert_eq!(original_dataset, dataset);
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    /// Splitting at 0.8 sheds the smaller share, loses no records, and both halves
    /// keep the same features.
    #[test]
    fn split() {
        let mut dataset = xor_csv();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);

        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    /// Round trip: save as CSV, ARFF, and libsvm, load each back, and compare.
    #[test]
    fn save() {
        const COPY_CSV: &str = "xor_copy.csv";
        const COPY_ARFF: &str = "xor_copy.arff";
        const COPY_SVM: &str = "xor_copy.svm";

        let dataset = xor_csv();
        dataset.save_csv(COPY_CSV).unwrap();
        dataset.save_arff(COPY_ARFF).unwrap();
        dataset.save_libsvm(COPY_SVM).unwrap();

        let dataset2 = Dataset::from_csv_file(COPY_CSV, 6).unwrap();
        assert_eq!(dataset, dataset2);

        let dataset3 = Dataset::from_arff_file(COPY_ARFF).unwrap();
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);

        let dataset4 = Dataset::from_libsvm_file(COPY_SVM).unwrap();
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);

        // Remove the copies in the order they were written.
        for copy in [COPY_CSV, COPY_ARFF, COPY_SVM] {
            std::fs::remove_file(copy).unwrap();
        }
    }
}