use crate::Bytes;
use std::io::{Read, Write};
use std::path::Path;
use std::str::FromStr;
use std::sync::RwLock;
use anyhow::{bail, ensure, Result};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use walkdir::WalkDir;
/// First bytes that mark a comment line in the CSV and libsvm text formats.
const COMMENT_PREFIXES: [u8; 2] = *b"#%";
/// Marker inside a comment line that introduces the hex-encoded feature list.
const FEATURES_PREFIX: &str = "Features:";
#[inline]
pub fn featurize_file<P: AsRef<Path>>(file: P, n: usize, features: &[Bytes]) -> Result<Vec<f32>> {
let contents = std::fs::read(file)?;
let mut feature_vector = vec![0.0; features.len()];
for window in contents.windows(n) {
if let Some(position) = features.iter().position(|n| n == window) {
feature_vector[position] = 1.0;
}
}
Ok(feature_vector)
}
/// Text dataset formats that can be selected by name (see the `FromStr` impl).
#[derive(Copy, Clone, Debug, Deserialize, Serialize, Hash, Eq, PartialEq)]
pub enum DatasetFormat {
    /// Attribute-Relation File Format (Weka).
    ARFF,
    /// Comma-separated values, optionally with a trailing label column.
    CSV,
    /// Sparse libsvm-style `label index:value ...` lines.
    SVM,
}
impl FromStr for DatasetFormat {
    type Err = String;

    /// Parses a format name case-insensitively, ignoring surrounding whitespace.
    ///
    /// `"libsvm"` is accepted as an alias for [`DatasetFormat::SVM`], matching
    /// the file extensions that `Dataset::load` and `Dataset::save` recognise.
    ///
    /// # Errors
    ///
    /// Returns a message naming the (trimmed, lower-cased) input when it is not
    /// a known format.
    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
        match value.trim().to_lowercase().as_str() {
            "arff" => Ok(Self::ARFF),
            "csv" => Ok(Self::CSV),
            "svm" | "libsvm" => Ok(Self::SVM),
            x => Err(format!("Unknown data format '{x}'")),
        }
    }
}
/// An in-memory tabular dataset: one numeric row per example, an optional
/// label per row, and byte-string feature names describing the columns.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct Dataset {
/// Feature values, one inner vector per example (row).
pub data: Vec<Vec<f32>>,
/// Per-row labels; may be empty for unlabelled data. When absent from
/// serialized input (e.g. JSON/TOML read by `load`), it defaults to empty.
#[serde(default)]
pub labels: Vec<f32>,
/// Column names as raw bytes (the text savers write them hex-encoded);
/// `validate` expects exactly one entry per column of `data`.
pub features: Vec<Bytes>,
}
impl PartialEq for Dataset {
fn eq(&self, other: &Self) -> bool {
if self.data.len() != other.data.len()
|| self.labels.len() != other.labels.len()
|| self.features.len() != other.features.len()
{
return false;
}
for this_data in &self.data {
if !other.data.contains(this_data) {
return false;
}
}
for other_data in &other.data {
if !self.data.contains(other_data) {
return false;
}
}
if !self.labels.is_empty() {
for this_label in &self.labels {
if !other.labels.contains(this_label) {
return false;
}
}
for other_label in &other.labels {
if !self.labels.contains(other_label) {
return false;
}
}
}
for this_features in &self.features {
if !other.features.contains(this_features) {
return false;
}
}
for other_feature in &other.features {
if !self.features.contains(other_feature) {
return false;
}
}
true
}
}
impl Dataset {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
if let Some(extension) = path.as_ref().extension() {
return match extension.to_str().unwrap_or_default() {
"arff" => Dataset::from_arff_file(path.as_ref()),
"csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
"svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
"json" => {
let contents = std::fs::read_to_string(path.as_ref())?;
serde_json::from_str(&contents).map_err(Into::into)
}
"toml" => {
let contents = std::fs::read_to_string(path.as_ref())?;
toml::from_str(&contents).map_err(Into::into)
}
ext => {
bail!("Unsupported/unknown data type '{ext}'");
}
};
}
bail!("No extension, can't determine file type.");
}
/// Reads `path` as UTF-8 text and parses it with `Dataset::from_csv_string`.
///
/// `data_length` is the number of leading columns holding feature values; one
/// extra trailing column, if present, is parsed as the label.
///
/// # Errors
///
/// Propagates I/O errors from opening or reading the file, and any parse error.
pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
    let mut text = String::new();
    std::fs::File::open(path)?.read_to_string(&mut text)?;
    Self::from_csv_string(&text, data_length)
}
/// Reads a CSV file, inferring the feature-column count from its first data line.
///
/// The first line that is non-empty and does not start with a comment prefix
/// is split on commas. All of its columns except the last are treated as
/// feature data; the last column is assumed to be the label.
///
/// # Errors
///
/// Fails on I/O errors, when no data line exists, or when parsing fails.
pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
    let mut text = String::new();
    std::fs::File::open(path)?.read_to_string(&mut text)?;
    let column_count = text
        .lines()
        .find(|line| {
            line.as_bytes()
                .first()
                .is_some_and(|byte| !COMMENT_PREFIXES.contains(byte))
        })
        .map_or(0, |line| line.split(',').count());
    ensure!(column_count > 0, "Failed to determine data length.");
    // Everything but the final (label) column is feature data.
    Self::from_csv_string(&text, column_count - 1)
}
pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
let mut data: Vec<Vec<f32>> = Vec::new();
let mut labels = Vec::new();
let mut features = Vec::new();
for (row_number, line) in contents.lines().enumerate() {
if line.is_empty() {
continue;
}
if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) && line.contains(FEATURES_PREFIX) {
let offset = line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
let line = line[offset..].trim();
features = line
.split(',')
.filter_map(|f| hex::decode(f.trim()).ok())
.collect();
}
if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
continue;
}
let row = line.split(',').collect::<Vec<&str>>();
let mut row_float = Vec::with_capacity(data_length);
for r in row.iter().take(data_length) {
row_float.push(r.parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!("Non-float encountered on CSV row {row_number}"))
})?);
}
if let Some(first_row) = data.first() {
ensure!(
first_row.len() == row_float.len(),
"CSV line {row_number} has invalid length {}, expected {}",
row_float.len(),
first_row.len()
);
}
data.push(row_float);
if row.len() == data_length + 1 {
let l = row[data_length].parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!("Non-float encountered on CSV row {row_number}"))
})?;
labels.push(l);
} else if row.len() > data_length {
bail!(
"CSV row had more than one label on row {row_number}, which isn't supported."
);
}
}
ensure!(
features.len() == data[0].len(),
"Features need to be empty or the same size as the data length."
);
Ok(Self {
data,
labels,
features,
})
}
/// Reads `path` as UTF-8 text and parses it with `Dataset::from_arff_string`.
///
/// # Errors
///
/// Propagates I/O errors from opening or reading the file, and any parse error.
pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
    let mut text = String::new();
    std::fs::File::open(path)?.read_to_string(&mut text)?;
    Self::from_arff_string(&text)
}
pub fn from_arff_string(contents: &str) -> Result<Self> {
let mut data: Vec<Vec<f32>> = Vec::new();
let mut labels = Vec::new();
let mut features = Vec::new();
let mut passed_data = false;
for (row_number, line) in contents.lines().enumerate() {
if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
continue;
}
if line.contains("@ATTRIBUTE") {
let parts: Vec<&str> = line.split_ascii_whitespace().collect();
if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
match hex::decode(parts[1]) {
Ok(feat) => features.push(feat),
Err(e) => {
bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
}
}
}
}
if line.contains("@DATA") {
passed_data = true;
continue;
}
if passed_data {
let row = line.split(',').collect::<Vec<&str>>();
let data_length = row.len() - 1;
let mut row_float = Vec::with_capacity(data_length);
for r in row.iter().take(data_length) {
row_float.push(r.parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!(
"Non-float encountered on ARFF row {row_number}"
))
})?);
}
if let Some(first_row) = data.first() {
ensure!(
first_row.len() == row_float.len(),
"ARFF line {row_number} has invalid length {}, expected {}",
row_float.len(),
first_row.len()
);
}
data.push(row_float);
if row.len() == data_length + 1 {
let l = row[data_length].parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!(
"Non-float encountered on ARFF row {row_number}"
))
})?;
labels.push(l);
} else if row.len() > data_length {
bail!("Arff row had more than one label on row {row_number}, which isn't supported.");
}
}
}
ensure!(
features.len() == data[0].len(),
"Features need to be empty or the same size as the data length."
);
Ok(Self {
data,
labels,
features,
})
}
/// Reads `path` as UTF-8 text and parses it with `Dataset::from_libsvm_string`.
///
/// # Errors
///
/// Propagates I/O errors from opening or reading the file, and any parse error.
pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
    let mut text = String::new();
    std::fs::File::open(path)?.read_to_string(&mut text)?;
    Self::from_libsvm_string(&text)
}
pub fn from_libsvm_string(contents: &str) -> Result<Self> {
let mut data = Vec::new();
let mut labels = Vec::new();
let mut features = Vec::new();
for (row_number, line) in contents.lines().enumerate() {
if line.is_empty() {
continue;
}
if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) && line.contains(FEATURES_PREFIX) {
let offset = line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
let line = line[offset..].trim();
features = line
.split(',')
.filter_map(|f| hex::decode(f.trim()).ok())
.collect();
}
if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
continue;
}
let parts = line.split_whitespace().collect::<Vec<&str>>();
let label = parts[0].parse::<f32>()?;
let mut row = vec![0.0f32; features.len()];
for part in parts.iter().skip(1) {
let part_parts = part.split(':').collect::<Vec<&str>>();
let part_index = part_parts[0].parse::<usize>()?;
let part_value = part_parts[1].parse::<f32>()?;
if part_index > row.len() && !features.is_empty() {
bail!("Encountered a value at index {part_index} greater than expected size {} on line {row_number}", data.len());
}
if row.is_empty() {
row = vec![0.0; part_index + 1];
} else if part_index >= row.len() {
row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
}
row[part_index] = part_value;
}
data.push(row);
labels.push(label);
}
let data_len = data[0].len();
for row in &data {
if row.len() != data_len {
bail!(
"Encountered a row with length {} but expected length {data_len}",
row.len()
);
}
}
ensure!(
features.len() == data[0].len(),
"Features need to be empty or the same size as the data length."
);
Ok(Self {
data,
labels,
features,
})
}
pub fn create_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
malicious_dir: P,
benign_dir: P,
ngrams_file: P,
) -> Result<Self> {
let ngram_contents = std::fs::read_to_string(&ngrams_file)?;
let mut n = 0;
let ngrams = ngram_contents
.lines()
.filter_map(|l| {
let line = if let Some(l) = l.split(',').collect::<Vec<&str>>().first() {
l
} else {
l
};
if !line.len().is_multiple_of(2) {
eprintln!("Line {line} has odd number of characters.");
return None;
}
if n == 0 {
n = line.len() / 2;
} else if line.len() / 2 != n {
eprintln!(
"Line {line} has unexpected length of {} bytes, expected {n}",
line.len() / 2
);
return None;
}
hex::decode(line).ok()
})
.collect::<Vec<_>>();
ensure!(
!ngrams.is_empty(),
"No n-grams read from {}.",
ngrams_file.as_ref().display()
);
let mut paths_labels = Vec::new();
for entry in WalkDir::new(malicious_dir)
.max_depth(crate::MAX_RECURSION_DEPTH)
.follow_links(true)
.into_iter()
.flatten()
{
if entry.file_type().is_file() {
paths_labels.push((entry, 1.0));
}
}
for entry in WalkDir::new(benign_dir)
.max_depth(crate::MAX_RECURSION_DEPTH)
.follow_links(true)
.into_iter()
.flatten()
{
if entry.file_type().is_file() {
paths_labels.push((entry, 0.0));
}
}
let found_files = paths_labels.len();
let dataset = Dataset::default();
let dataset_lock = RwLock::new(dataset);
paths_labels.into_par_iter().for_each(|(path, label)| {
match featurize_file(path.path(), n, &ngrams) {
Ok(features) => {
if let Ok(mut data) = dataset_lock.write() {
data.data.push(features);
data.labels.push(label);
}
}
Err(e) => eprintln!("Failed to featurized {}: {e}", path.path().display()),
}
});
let mut dataset = dataset_lock.into_inner()?;
dataset.features = ngrams;
if dataset.data.len() != found_files {
eprintln!(
"Warning: found {found_files} but only have features for {} files.",
dataset.data.len()
);
}
Ok(dataset)
}
pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let mut file = std::fs::File::create(path)?;
let feature_string_vec = self
.features
.iter()
.map(hex::encode)
.collect::<Vec<String>>();
let features_string = format!("# {FEATURES_PREFIX} {}\n", feature_string_vec.join(", "));
file.write_all(features_string.as_bytes())?;
for index in 0..self.data.len() {
let mut line = self.data[index]
.iter()
.map(|p| format!("{p}"))
.collect::<Vec<String>>()
.join(",");
if !self.labels.is_empty() {
if self.labels[index] > 0.9 {
line.push_str(",1");
} else {
line.push_str(",0");
}
}
line.push('\n');
file.write_all(line.as_bytes())?;
}
file.sync_all().map_err(Into::into)
}
pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let mut file = std::fs::File::create(path)?;
for feature in &self.features {
let feature_hex = hex::encode(feature);
file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
}
if !self.labels.is_empty() {
file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
}
file.write_all("\n@DATA\n".as_bytes())?;
for index in 0..self.data.len() {
let mut line = self.data[index]
.iter()
.map(|p| format!("{p}"))
.collect::<Vec<String>>()
.join(",");
if !self.labels.is_empty() {
if self.labels[index] > 0.9 {
line.push_str(",1");
} else {
line.push_str(",0");
}
}
line.push('\n');
file.write_all(line.as_bytes())?;
}
file.sync_all().map_err(Into::into)
}
/// Writes the dataset in libsvm-style sparse format.
///
/// The first line is `# Features: <hex>, <hex>, ...`. Each following line is
/// the raw label and then ` index:value` for every non-zero value. Indices are
/// 0-based (the convention `from_libsvm_string` reads back), unlike classic
/// 1-based libsvm files. The file is flushed and synced before returning.
///
/// # Errors
///
/// Fails if the dataset has no labels, or not exactly one label per row, or on
/// any I/O error.
pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
    ensure!(
        !self.labels.is_empty(),
        "Labels are required to create a libsvm file."
    );
    ensure!(
        self.labels.len() == self.data.len(),
        "Expected one label per row, found {} labels for {} rows.",
        self.labels.len(),
        self.data.len()
    );
    let mut writer = std::io::BufWriter::new(std::fs::File::create(path)?);
    let feature_hex: Vec<String> = self.features.iter().map(hex::encode).collect();
    writeln!(writer, "# {FEATURES_PREFIX} {}", feature_hex.join(", "))?;
    for (row, label) in self.data.iter().zip(&self.labels) {
        write!(writer, "{label}")?;
        for (data_index, value) in row.iter().enumerate() {
            // Sparse format: zero entries are implied by omission.
            if *value != 0.0 {
                write!(writer, " {data_index}:{value}")?;
            }
        }
        writeln!(writer)?;
    }
    // Flush explicitly: dropping a BufWriter would swallow write errors.
    writer.flush()?;
    writer.get_ref().sync_all().map_err(Into::into)
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
if let Some(extension) = path.as_ref().extension() {
return match extension.to_str().unwrap_or_default() {
"arff" => self.save_arff(path),
"csv" => self.save_csv(path),
"svm" | "libsvm" => self.save_libsvm(path),
"json" => {
let contents = serde_json::to_string_pretty(self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(contents.as_bytes())?;
file.sync_all().map_err(Into::into)
}
"toml" => {
let contents = toml::to_string_pretty(self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(contents.as_bytes())?;
file.sync_all().map_err(Into::into)
}
ext => {
bail!("Unsupported/unknown data type '{ext}'");
}
};
}
bail!("No extension, can't determine file type.");
}
/// Number of rows (examples) stored in the dataset.
#[inline]
#[must_use]
pub fn len(&self) -> usize {
    self.data.len()
}
/// Returns `true` when the dataset holds no rows at all.
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
    self.data.is_empty()
}
/// Checks that the dataset is internally consistent.
///
/// Returns `true` only when all of the following hold:
/// - there is at least one row, and every row has the same length;
/// - at least one feature exists, and all features share one byte length;
/// - the feature count equals the row length;
/// - labels are either absent or one per row.
///
/// In debug builds, the first row-length or feature-length mismatch (or
/// missing features) is reported on stderr, one message per line.
#[inline]
#[must_use]
pub fn validate(&self) -> bool {
    let Some(first) = self.data.first() else {
        return false;
    };
    let data_len = first.len();
    for record in &self.data {
        if record.len() != data_len {
            #[cfg(debug_assertions)]
            eprintln!("Expected record size {data_len}, got {}", record.len());
            return false;
        }
    }
    let Some(first_feature) = self.features.first() else {
        #[cfg(debug_assertions)]
        eprintln!("Features data is missing");
        return false;
    };
    let feature_len = first_feature.len();
    for feature in &self.features {
        if feature.len() != feature_len {
            #[cfg(debug_assertions)]
            eprintln!("Expected feature size {feature_len}, got {}", feature.len());
            return false;
        }
    }
    (self.labels.is_empty() || self.labels.len() == self.data.len())
        && self.features.len() == data_len
}
pub fn shuffle(&mut self) {
if !self.is_empty() {
let iterations = self.data.len().ilog10() * 10;
self.shuffle_iterations(iterations);
}
}
/// Performs `iterations` random pairwise row swaps, moving labels alongside
/// their rows when labels are present.
///
/// Each swap draws two uniformly random indices and re-draws the second once
/// if it collides with the first. A second collision just makes that swap a
/// no-op. Unlike a full shuffle, a small iteration count leaves most rows in
/// place.
pub fn shuffle_iterations(&mut self, iterations: u32) {
    use rand::Rng;
    if self.is_empty() {
        return;
    }
    let len = self.data.len();
    let swap_labels = !self.labels.is_empty();
    let mut rng = rand::rng();
    for _ in 0..iterations {
        let first = rng.random_range(0..len);
        let mut second = rng.random_range(0..len);
        if second == first {
            second = rng.random_range(0..len);
        }
        self.data.swap(first, second);
        if swap_labels {
            self.labels.swap(first, second);
        }
    }
}
/// Splits off the tail of the dataset and returns it as a new dataset.
///
/// `self` keeps the first `ceil(len * ratio)` rows, and the remaining rows
/// (with their labels) are moved into the returned dataset, which shares a
/// copy of the features. Only the magnitude of `ratio` is used, capped at
/// `1.0` (keep every row). A NaN ratio keeps no rows.
#[must_use]
// Float <-> usize conversions are intentional: the row count comes from a
// fractional ratio and is clamped to `len` afterwards.
#[allow(
    clippy::cast_sign_loss,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss
)]
pub fn split(&mut self, ratio: f32) -> Self {
    let ratio = ratio.abs().clamp(0.0, 1.0);
    let len = self.data.len();
    // `as usize` saturates (NaN -> 0); `min` guards against f32 rounding
    // pushing the count past `len`, which would make `drain` panic.
    let keep = ((len as f32 * ratio).ceil() as usize).min(len);
    let new_data = self.data.drain(keep..).collect();
    let new_labels = if self.labels.is_empty() {
        vec![]
    } else {
        self.labels.drain(keep..).collect()
    };
    Self {
        data: new_data,
        labels: new_labels,
        features: self.features.clone(),
    }
}
}
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;
    use std::path::PathBuf;

    /// Returns a scratch path in the system temp directory, prefixed with the
    /// process id so artifacts never land in the working tree and concurrent
    /// test processes do not collide.
    fn scratch_path(file_name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}_{file_name}", std::process::id()))
    }

    #[test]
    fn xor() {
        let csv_dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        assert!(csv_dataset.validate());
        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());
        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());
        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    #[test]
    fn xor_no_label() {
        assert!(Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6).is_err());
        assert!(Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm")).is_err());
    }

    #[test]
    fn shuffle() {
        let original_dataset =
            Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.shuffle();
        assert_eq!(original_dataset, dataset);
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    #[test]
    fn split() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);
        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    #[test]
    fn save() {
        let copy_csv = scratch_path("xor_copy.csv");
        let copy_arff = scratch_path("xor_copy.arff");
        let copy_svm = scratch_path("xor_copy.svm");
        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.save_csv(&copy_csv).unwrap();
        dataset.save_arff(&copy_arff).unwrap();
        dataset.save_libsvm(&copy_svm).unwrap();
        let dataset2 = Dataset::from_csv_file(&copy_csv, 6).unwrap();
        assert_eq!(dataset, dataset2);
        let dataset3 = Dataset::from_arff_file(&copy_arff).unwrap();
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);
        let dataset4 = Dataset::from_libsvm_file(&copy_svm).unwrap();
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);
        std::fs::remove_file(&copy_csv).unwrap();
        std::fs::remove_file(&copy_arff).unwrap();
        std::fs::remove_file(&copy_svm).unwrap();
    }
}