use crate::ftype::FileType;
use crate::model::LogisticRegression;
use crate::ngram::NgramsFile;
use crate::Bytes;
use std::collections::HashMap;
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;
use std::str::FromStr;
use anyhow::{anyhow, bail, ensure, Result};
use serde::de::IntoDeserializer;
use serde::{Deserialize, Serialize};
use walkdir::WalkDir;
/// First bytes that mark a line as a comment in the text dataset formats.
const COMMENT_PREFIXES: [u8; 2] = *b"#%";
/// Comment marker that introduces the comma-separated, hex-encoded n-gram feature list.
const FEATURES_PREFIX: &str = "Features:";
/// Comment marker that introduces the dataset's file type.
const FILE_TYPE_PREFIX: &str = "File type:";
/// Builds a binary feature vector marking which of `features`' n-grams occur in `file`.
///
/// `features` maps each n-gram (a byte string of length `n`) to its column in the returned
/// vector; a column is `1.0` when that n-gram occurs anywhere in the file and `0.0`
/// otherwise. Files below 10 MiB are read whole; larger files are scanned in
/// `NGRAM_BUFFER_SIZE` chunks that overlap by `n - 1` bytes, so n-grams straddling a chunk
/// boundary are still found.
///
/// # Errors
/// Fails if `n` is zero, the file is not larger than `n` bytes, or reading it fails.
#[inline]
pub(crate) fn featurize_file<P: AsRef<Path>, S: ::std::hash::BuildHasher>(
    file: P,
    n: usize,
    features: &HashMap<Bytes, usize, S>,
) -> Result<Vec<f32>> {
    // Files below this size (10 MiB) are read into memory in one go.
    const IN_MEMORY_LIMIT: u64 = 10_485_760;

    // `windows(0)` panics, and a zero-length n-gram carries no information.
    ensure!(n > 0, "N-gram size must be at least 1.");
    let path = file.as_ref();
    let file_size = std::fs::metadata(path)?.len();
    ensure!(file_size > n as u64, "File {} is too small.", path.display());

    let mut feature_vector = vec![0.0; features.len()];
    let mut mark = |window: &[u8]| {
        if let Some(&column) = features.get(window) {
            feature_vector[column] = 1.0;
        }
    };

    if file_size < IN_MEMORY_LIMIT {
        std::fs::read(path)?.windows(n).for_each(&mut mark);
    } else {
        let mut reader = std::fs::File::open(path)?;
        let mut buffer = [0u8; crate::ngram::NGRAM_BUFFER_SIZE];
        // After scanning a chunk, rewind n - 1 bytes so the next read starts at the first
        // window that did not fit completely inside this chunk.
        let overlap = i64::try_from(n - 1)?;
        loop {
            let bytes_read = match reader.read(&mut buffer) {
                Ok(count) => count,
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e.into()),
            };
            // Fewer than n bytes left means no complete window remains (this includes EOF).
            if bytes_read < n {
                break;
            }
            // `windows` covers every start offset 0..=bytes_read - n, including the last.
            buffer[..bytes_read].windows(n).for_each(&mut mark);
            reader.seek(SeekFrom::Current(-overlap))?;
        }
    }
    Ok(feature_vector)
}
/// On-disk dataset formats that can be selected by file extension.
#[derive(Copy, Clone, Debug, Deserialize, Serialize, Hash, Eq, PartialEq)]
pub enum DatasetFormat {
    /// Attribute-relation file format (`.arff`).
    ARFF,
    /// Comma-separated values (`.csv`).
    CSV,
    /// libsvm-style sparse `label index:value ...` rows (`.svm`).
    SVM,
}
impl FromStr for DatasetFormat {
    type Err = anyhow::Error;

    /// Parses a format name case-insensitively.
    ///
    /// `"libsvm"` is accepted as an alias of `"svm"`, matching the extensions that
    /// `Dataset::load` and `Dataset::save` already recognise, so a `.libsvm` output path
    /// resolves to [`DatasetFormat::SVM`] instead of being rejected.
    ///
    /// # Errors
    /// Returns an error naming the (lower-cased) value when it is not a known format.
    fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
        match value.to_lowercase().as_str() {
            "arff" => Ok(Self::ARFF),
            "csv" => Ok(Self::CSV),
            "svm" | "libsvm" => Ok(Self::SVM),
            x => Err(anyhow!("Unknown data format '{x}'")),
        }
    }
}
impl TryFrom<&Path> for DatasetFormat {
    type Error = anyhow::Error;

    /// Infers the dataset format from the extension of `value`.
    ///
    /// # Errors
    /// Fails when the path has no extension, the extension is not valid UTF-8, or it does
    /// not name a known format.
    fn try_from(value: &Path) -> std::result::Result<Self, Self::Error> {
        let Some(extension) = value.extension() else {
            return Err(anyhow!("No extension, can't determine file type."));
        };
        extension
            .to_str()
            .ok_or_else(|| anyhow!("Failed to get extension."))
            .and_then(DatasetFormat::from_str)
    }
}
/// An in-memory feature matrix with optional labels.
///
/// Each entry of `data` is one sample (row). Column `i` of every row corresponds to the
/// n-gram `features[i]` (see `validate`, which requires `features.len()` to equal the row
/// length).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Dataset {
    /// Feature rows, one `Vec<f32>` per sample.
    pub data: Vec<Vec<f32>>,
    /// One label per row, or empty for unlabelled data; missing labels deserialize as an
    /// empty vector.
    #[serde(default)]
    pub labels: Vec<f32>,
    /// N-gram byte strings naming each column; (de)serialized through the crate's
    /// `serialize_hex_vec` / `deserialize_hex_vec` helpers.
    #[serde(
        serialize_with = "crate::serde::serialize_hex_vec",
        deserialize_with = "crate::serde::deserialize_hex_vec"
    )]
    pub features: Vec<Bytes>,
    /// File type this dataset describes; `FileType::NotSet` is rejected by the parsers.
    pub ftype: FileType,
}
impl PartialEq for Dataset {
    /// Order-insensitive comparison.
    ///
    /// Two datasets are equal when they have the same file type, the same number of rows,
    /// labels and features, and every row / label / feature of each appears somewhere in
    /// the other. Rows and labels are compared as independent collections, and membership
    /// is not a multiset check, so duplicate counts are not compared.
    fn eq(&self, other: &Self) -> bool {
        // True when each element of `left` is found in `right` and vice versa.
        fn same_members<T: PartialEq>(left: &[T], right: &[T]) -> bool {
            left.iter().all(|item| right.contains(item))
                && right.iter().all(|item| left.contains(item))
        }

        let same_shape = self.data.len() == other.data.len()
            && self.labels.len() == other.labels.len()
            && self.features.len() == other.features.len()
            && self.ftype == other.ftype;

        same_shape
            && same_members(&self.data, &other.data)
            && (self.labels.is_empty() || same_members(&self.labels, &other.labels))
            && same_members(&self.features, &other.features)
    }
}
impl Dataset {
pub fn load<P: AsRef<Path>>(path: P) -> Result<Dataset> {
if let Some(extension) = path.as_ref().extension() {
return match extension.to_str().unwrap_or_default() {
"arff" => Dataset::from_arff_file(path.as_ref()),
"csv" => Dataset::from_csv_file_assume_data_length(path.as_ref()),
"svm" | "libsvm" => Dataset::from_libsvm_file(path.as_ref()),
"json" => {
let contents = std::fs::read_to_string(path.as_ref())?;
serde_json::from_str(&contents).map_err(Into::into)
}
"toml" => {
let contents = std::fs::read_to_string(path.as_ref())?;
toml::from_str(&contents).map_err(Into::into)
}
ext => {
bail!("Unsupported/unknown data type '{ext}'");
}
};
}
bail!("No extension, can't determine file type.");
}
pub fn from_csv_file<P: AsRef<Path>>(path: P, data_length: usize) -> Result<Self> {
let mut file = std::fs::File::open(path)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
Self::from_csv_string(&contents, data_length)
}
pub fn from_csv_file_assume_data_length<P: AsRef<Path>>(path: P) -> Result<Self> {
let mut file = std::fs::File::open(path)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
let mut length = 0;
for line in contents.lines() {
if line.is_empty() || COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
continue;
}
length = line.split(',').collect::<Vec<&str>>().len();
break;
}
ensure!(length > 0, "Failed to determine data length.");
Self::from_csv_string(&contents, length - 1)
}
pub fn from_csv_string(contents: &str, data_length: usize) -> Result<Self> {
let mut data: Vec<Vec<f32>> = Vec::new();
let mut labels = Vec::new();
let mut features = Vec::new();
let mut file_type = FileType::NotSet;
for (row_number, line) in contents.lines().enumerate() {
if line.is_empty() {
continue;
}
if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
if line.contains(FEATURES_PREFIX) {
let offset =
line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
let line = line[offset..].trim();
features = line
.split(',')
.filter_map(|f| hex::decode(f.trim()).ok())
.collect();
}
if line.contains(FILE_TYPE_PREFIX) {
if let Some(file_type_str) = line.split(':').nth(1) {
let file_type_str = file_type_str.trim();
let ftype: Result<_, serde::de::value::Error> =
FileType::deserialize(String::from(file_type_str).into_deserializer());
file_type = ftype?;
}
}
}
if line.is_empty() || line.starts_with('%') | line.starts_with('#') {
continue;
}
let row = line.split(',').collect::<Vec<&str>>();
let mut row_float = Vec::with_capacity(data_length);
for r in row.iter().take(data_length) {
row_float.push(r.parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!("Non-float {r} encountered on CSV row {row_number}"))
})?);
}
if let Some(first_row) = data.first() {
ensure!(
first_row.len() == row_float.len(),
"CSV line {row_number} has invalid length {}, expected {}",
row_float.len(),
first_row.len()
);
}
data.push(row_float);
if row.len() == data_length + 1 {
let l = row[data_length].parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!(
"Non-float label {} encountered on CSV row {row_number}",
row[data_length]
))
})?;
labels.push(l);
} else if row.len() > data_length {
bail!(
"CSV row had more than one label on row {row_number}, which isn't supported."
);
}
}
ensure!(
features.len() == data[0].len(),
"Features need to be empty or the same size as the data length."
);
ensure!(
file_type != FileType::NotSet,
"No file type specified in CSV file."
);
Ok(Self {
data,
labels,
features,
ftype: file_type,
})
}
#[inline]
pub(crate) fn file_type_from_line(line: &str) -> Result<FileType, serde::de::value::Error> {
let line = line.split(':').nth(1).unwrap_or(line).to_uppercase();
let ftype: Result<_, serde::de::value::Error> =
FileType::deserialize(String::from(line.trim()).into_deserializer());
ftype
}
pub fn from_arff_file<P: AsRef<Path>>(path: P) -> Result<Self> {
let mut file = std::fs::File::open(path)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
Self::from_arff_string(&contents)
}
pub fn from_arff_string(contents: &str) -> Result<Self> {
let mut data: Vec<Vec<f32>> = Vec::new();
let mut labels = Vec::new();
let mut features = Vec::new();
let mut file_type = FileType::NotSet;
let mut passed_data = false;
for (row_number, line) in contents.lines().enumerate() {
if line.is_empty() {
continue;
}
if (line.starts_with('%') || line.starts_with('#')) && line.contains(FILE_TYPE_PREFIX) {
file_type = Self::file_type_from_line(line)?;
continue;
}
if line.contains("@ATTRIBUTE") {
let parts: Vec<&str> = line.split_ascii_whitespace().collect();
if parts.len() == 3 && !parts[1].eq_ignore_ascii_case("CLASS") {
match hex::decode(parts[1]) {
Ok(feat) => features.push(feat),
Err(e) => {
bail!("Invalid n-gram attribute on line {row_number}: {line}: {e}")
}
}
}
}
if line.contains("@DATA") {
passed_data = true;
continue;
}
if passed_data {
let row = line.split(',').collect::<Vec<&str>>();
let data_length = row.len() - 1;
let mut row_float = Vec::with_capacity(data_length);
for r in row.iter().take(data_length) {
row_float.push(r.parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!(
"Non-float encountered on ARFF row {row_number}"
))
})?);
}
if let Some(first_row) = data.first() {
ensure!(
first_row.len() == row_float.len(),
"ARFF line {row_number} has invalid length {}, expected {}",
row_float.len(),
first_row.len()
);
}
data.push(row_float);
if row.len() == data_length + 1 {
let l = row[data_length].parse::<f32>().map_err(|_| {
anyhow::Error::msg(format!(
"Non-float encountered on ARFF row {row_number}"
))
})?;
labels.push(l);
} else if row.len() > data_length {
bail!("Arff row had more than one label on row {row_number}, which isn't supported.");
}
}
}
ensure!(
features.len() == data[0].len(),
"Features need to be empty or the same size as the data length."
);
ensure!(
file_type != FileType::NotSet,
"No file type specified in ARFF file."
);
Ok(Self {
data,
labels,
features,
ftype: file_type,
})
}
pub fn from_libsvm_file<P: AsRef<Path>>(path: P) -> Result<Self> {
let mut file = std::fs::File::open(path)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
Self::from_libsvm_string(&contents)
}
pub fn from_libsvm_string(contents: &str) -> Result<Self> {
let mut data = Vec::new();
let mut labels = Vec::new();
let mut features = Vec::new();
let mut file_type = FileType::NotSet;
for (row_number, line) in contents.lines().enumerate() {
if line.is_empty() {
continue;
}
if COMMENT_PREFIXES.contains(&line.as_bytes()[0]) {
if line.contains(FEATURES_PREFIX) {
let offset =
line.find(FEATURES_PREFIX).unwrap_or_default() + FEATURES_PREFIX.len();
let line = line[offset..].trim();
features = line
.split(',')
.filter_map(|f| hex::decode(f.trim()).ok())
.collect();
}
if line.contains(FILE_TYPE_PREFIX) {
file_type = Self::file_type_from_line(line)?;
}
}
if line.is_empty() || line.starts_with('%') || line.starts_with('#') {
continue;
}
let parts = line.split_whitespace().collect::<Vec<&str>>();
let Ok(label) = parts[0].trim().parse::<f32>() else {
bail!(
"Encountered a non-numeric label {} on line {row_number}",
parts[0]
);
};
let mut row = vec![0.0f32; features.len()];
for part in parts.iter().skip(1) {
let part_parts = part.split(':').collect::<Vec<&str>>();
let Ok(part_index) = part_parts[0].trim().parse::<usize>() else {
bail!(
"Encountered a non-numeric index {} on line {row_number}",
part_parts[0]
);
};
let Ok(part_value) = part_parts[1].trim().parse::<f32>() else {
bail!(
"Encountered a non-numeric value {} on line {row_number}",
part_parts[1]
);
};
if part_index > row.len() && !features.is_empty() {
bail!("Encountered a value at index {part_index} greater than expected size {} on line {row_number}", data.len());
}
if row.is_empty() {
row = vec![0.0; part_index + 1];
} else if part_index >= row.len() {
row.extend_from_slice(&vec![0.0f32; row.len() - part_index + 1]);
}
row[part_index] = part_value;
}
data.push(row);
labels.push(label);
}
let data_len = data[0].len();
for row in &data {
if row.len() != data_len {
bail!(
"Encountered a row with length {} but expected length {data_len}",
row.len()
);
}
}
ensure!(
features.len() == data[0].len(),
"Features need to be empty or the same size as the data length."
);
ensure!(
file_type != FileType::NotSet,
"No file type specified in libsvm file."
);
Ok(Self {
data,
labels,
features,
ftype: file_type,
})
}
#[allow(clippy::too_many_lines)]
pub fn create_save_from_benign_malicious_files_and_ngrams<P: AsRef<Path>>(
malicious_dir: P,
benign_dir: P,
ngrams_file: P,
output_file: P,
) -> Result<()> {
const SUPPORTED_FORMATS: [DatasetFormat; 3] =
[DatasetFormat::CSV, DatasetFormat::ARFF, DatasetFormat::SVM];
let output_format = DatasetFormat::try_from(output_file.as_ref())?;
ensure!(
SUPPORTED_FORMATS.contains(&output_format),
"Only CSV, ARFF, or SVM formats are supported here."
);
let ngrams = NgramsFile::load(ngrams_file)?;
let mut output_file = std::fs::File::create(output_file)?;
writeln!(output_file, "# {FILE_TYPE_PREFIX} {:?}", ngrams.ftype)?;
match output_format {
DatasetFormat::SVM | DatasetFormat::CSV => {
let feature_string_vec = ngrams
.clone()
.into_vec()
.iter()
.map(hex::encode)
.collect::<Vec<String>>();
writeln!(
output_file,
"# {FEATURES_PREFIX} {}",
feature_string_vec.join(", ")
)?;
}
DatasetFormat::ARFF => {
let feature_string_vec = ngrams
.clone()
.into_vec()
.iter()
.map(hex::encode)
.collect::<Vec<String>>();
for feature in feature_string_vec {
let feature_hex = hex::encode(feature);
writeln!(output_file, "@ATTRIBUTE {feature_hex} NUMERIC")?;
}
}
}
for entry in WalkDir::new(malicious_dir)
.max_depth(crate::MAX_RECURSION_DEPTH)
.follow_links(true)
.into_iter()
.flatten()
{
if entry.file_type().is_file() {
match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
Ok(features) => match output_format {
DatasetFormat::CSV | DatasetFormat::ARFF => {
let mut line = features
.iter()
.map(|p| format!("{p}"))
.collect::<Vec<String>>()
.join(",");
line.push_str(",1\n");
output_file.write_all(line.as_bytes())?;
}
DatasetFormat::SVM => {
write!(output_file, "1")?;
for (data_index, data) in features.iter().enumerate() {
if *data != 0.0000 {
write!(output_file, " {data_index}:{data}")?;
}
}
writeln!(output_file)?;
}
},
Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
}
}
}
for entry in WalkDir::new(benign_dir)
.max_depth(crate::MAX_RECURSION_DEPTH)
.follow_links(true)
.into_iter()
.flatten()
{
if entry.file_type().is_file() {
match featurize_file(entry.path(), ngrams.n, &ngrams.ngrams) {
Ok(features) => match output_format {
DatasetFormat::CSV | DatasetFormat::ARFF => {
let mut line = features
.iter()
.map(|p| format!("{p}"))
.collect::<Vec<String>>()
.join(",");
line.push_str(",0\n");
output_file.write_all(line.as_bytes())?;
}
DatasetFormat::SVM => {
write!(output_file, "0")?;
for (data_index, data) in features.iter().enumerate() {
if *data != 0.0000 {
write!(output_file, " {data_index}:{data}")?;
}
}
writeln!(output_file)?;
}
},
Err(e) => eprintln!("Failed to featurize {}: {e}", entry.path().display()),
}
}
}
output_file.sync_all()?;
Ok(())
}
pub fn save_csv<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let mut file = std::fs::File::create(path)?;
let feature_string_vec = self
.features
.iter()
.map(hex::encode)
.collect::<Vec<String>>();
writeln!(
file,
"# {FEATURES_PREFIX} {}",
feature_string_vec.join(", ")
)?;
writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
for index in 0..self.data.len() {
let mut line = self.data[index]
.iter()
.map(|p| format!("{p}"))
.collect::<Vec<String>>()
.join(",");
if !self.labels.is_empty() {
if self.labels[index] > 0.9 {
line.push_str(",1");
} else {
line.push_str(",0");
}
}
line.push('\n');
file.write_all(line.as_bytes())?;
}
file.sync_all().map_err(Into::into)
}
pub fn save_arff<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let mut file = std::fs::File::create(path)?;
writeln!(file, "# {FILE_TYPE_PREFIX} {:?}\n", self.ftype)?;
for feature in &self.features {
let feature_hex = hex::encode(feature);
file.write_all(format!("@ATTRIBUTE {feature_hex} NUMERIC\n").as_bytes())?;
}
if !self.labels.is_empty() {
file.write_all("@ATTRIBUTE class NUMERIC\n".as_bytes())?;
}
file.write_all("\n@DATA\n".as_bytes())?;
for index in 0..self.data.len() {
let mut line = self.data[index]
.iter()
.map(|p| format!("{p}"))
.collect::<Vec<String>>()
.join(",");
if !self.labels.is_empty() {
if self.labels[index] > 0.9 {
line.push_str(",1");
} else {
line.push_str(",0");
}
}
line.push('\n');
file.write_all(line.as_bytes())?;
}
file.sync_all().map_err(Into::into)
}
pub fn save_libsvm<P: AsRef<Path>>(&self, path: P) -> Result<()> {
ensure!(
!self.labels.is_empty(),
"Labels are required to create an libsvm file."
);
let mut file = std::fs::File::create(path)?;
let feature_string_vec = self
.features
.iter()
.map(hex::encode)
.collect::<Vec<String>>();
writeln!(
file,
"# {FEATURES_PREFIX} {}",
feature_string_vec.join(", ")
)?;
writeln!(file, "# {FILE_TYPE_PREFIX} {:?}", self.ftype)?;
for index in 0..self.data.len() {
file.write_all(format!("{}", self.labels[index]).as_bytes())?;
for (data_index, data) in self.data[index].iter().enumerate() {
if *data != 0.0000 {
file.write_all(format!(" {data_index}:{data}").as_bytes())?;
}
}
file.write_all(b"\n")?;
}
file.sync_all().map_err(Into::into)
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
if let Some(extension) = path.as_ref().extension() {
return match extension.to_str().unwrap_or_default() {
"arff" => self.save_arff(path),
"csv" => self.save_csv(path),
"svm" | "libsvm" => self.save_libsvm(path),
"json" => {
let contents = serde_json::to_string_pretty(self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(contents.as_bytes())?;
file.sync_all().map_err(Into::into)
}
"toml" => {
let contents = toml::to_string_pretty(self)?;
let mut file = std::fs::File::create(path)?;
file.write_all(contents.as_bytes())?;
file.sync_all().map_err(Into::into)
}
ext => {
bail!("Unsupported/unknown data type '{ext}'");
}
};
}
bail!("No extension, can't determine file type.");
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.data.len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
#[inline]
#[must_use]
pub fn validate(&self) -> bool {
let data_len = match self.data.first() {
Some(first) => first.len(),
None => return false,
};
for record in &self.data {
if record.len() != data_len {
#[cfg(debug_assertions)]
eprint!("Expected record size {data_len}, got {}", record.len());
return false;
}
}
let feature_len = if let Some(first) = self.features.first() {
first.len()
} else {
#[cfg(debug_assertions)]
eprintln!("Features data is missing");
return false;
};
for feature in &self.features {
if feature.len() != feature_len {
#[cfg(debug_assertions)]
eprint!("Expected feature size {feature_len}, got {}", feature.len());
return false;
}
}
(self.labels.is_empty() || self.labels.len() == self.data.len())
&& self.features.len() == data_len
&& self.ftype != FileType::NotSet
}
pub fn shuffle(&mut self) {
if !self.is_empty() {
let iterations = self.data.len().ilog10() * 10;
self.shuffle_iterations(iterations);
}
}
pub fn shuffle_iterations(&mut self, iterations: u32) {
use rand::Rng;
if !self.is_empty() {
let mut rng = rand::rng();
for _ in 0..iterations {
let a = rng.random_range(0..self.data.len());
let b = rng.random_range(0..self.data.len());
let b = if b == a {
rng.random_range(0..self.data.len())
} else {
b
};
self.data.swap(a, b);
if !self.labels.is_empty() {
self.labels.swap(a, b);
}
}
}
}
#[must_use]
#[allow(
clippy::cast_sign_loss,
clippy::cast_possible_truncation,
clippy::cast_precision_loss
)]
pub fn split(&mut self, ratio: f32) -> Self {
let ratio = ratio.abs();
let ratio = if ratio > 1.0 { 1.0 - ratio } else { ratio };
let new_size = (self.data.len() as f32 * ratio).ceil() as usize;
let new_data = self.data.drain(new_size..).collect();
let new_labels = if self.labels.is_empty() {
vec![]
} else {
self.labels.drain(new_size..).collect()
};
Self {
data: new_data,
labels: new_labels,
features: self.features.clone(),
ftype: self.ftype,
}
}
pub fn reduce(&mut self, model: &LogisticRegression) -> Result<Vec<usize>> {
let mut removed = vec![];
for (index, feature) in self.features.iter().enumerate() {
if !model.features.contains_key(feature) {
removed.push(index);
}
}
if removed.len() == self.data[0].len() {
bail!("This dataset and model are probably not from the same data - this operation would delete all the data!");
}
removed.sort_unstable();
removed.reverse();
self.features
.retain(|feature| model.features.contains_key(feature));
for row in &mut self.data {
for removed in &removed {
row.remove(*removed);
}
}
Ok(removed)
}
#[must_use]
pub fn column_iter(&'_ self, index: usize) -> Option<ColumnIterator<'_>> {
if index < self.data[0].len() {
Some(ColumnIterator {
dataset: self,
column_index: index,
current_row_index: 0,
})
} else {
None
}
}
}
/// Iterator over a single column of a [`Dataset`], yielding that column's value from each
/// row in order.
pub struct ColumnIterator<'a> {
    /// Index of the row the next call to `next` reads.
    current_row_index: usize,
    /// Column being read from every row.
    column_index: usize,
    /// Dataset being walked.
    dataset: &'a Dataset,
}
impl Iterator for ColumnIterator<'_> {
    type Item = f32;

    /// Returns the column value from the current row and advances, or `None` once every
    /// row has been visited. Panics if a row is narrower than the column index.
    fn next(&mut self) -> Option<Self::Item> {
        let value = self.dataset.data.get(self.current_row_index)?[self.column_index];
        self.current_row_index += 1;
        Some(value)
    }
}
#[cfg(test)]
mod tests {
    use crate::dataset::Dataset;
    use std::path::PathBuf;

    /// Returns a path in the system temp directory, prefixed with the process id, so test
    /// artifacts stay out of the working tree and concurrent runs don't collide.
    fn temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}_{name}", std::process::id()))
    }

    #[test]
    fn xor() {
        let csv_dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        assert!(csv_dataset.validate());
        let arff_dataset = Dataset::from_arff_string(include_str!("../testdata/xor.arff")).unwrap();
        assert!(arff_dataset.validate());
        let svm_dataset = Dataset::from_libsvm_string(include_str!("../testdata/xor.svm")).unwrap();
        assert!(svm_dataset.validate());
        assert_eq!(csv_dataset, arff_dataset);
        assert_eq!(csv_dataset, svm_dataset);
        assert_eq!(arff_dataset, svm_dataset);
    }

    #[test]
    fn xor_no_label() {
        assert!(Dataset::from_csv_string(include_str!("../testdata/xor_no_label.csv"), 6).is_err());
        assert!(Dataset::from_libsvm_string(include_str!("../testdata/xor_no_label.svm")).is_err());
    }

    #[test]
    fn shuffle() {
        let original_dataset =
            Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.shuffle();
        assert_eq!(original_dataset, dataset);
        assert_ne!(original_dataset.data, dataset.data);
        assert_ne!(original_dataset.labels, dataset.labels);
        assert_eq!(original_dataset.features, dataset.features);
    }

    #[test]
    fn split() {
        let mut dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        let original_size = dataset.len();
        let smaller = dataset.split(0.8);
        println!(
            "Original: {original_size}, New size: {}, Smaller dataset: {}",
            dataset.len(),
            smaller.len()
        );
        assert!(smaller.len() < dataset.len());
        assert_eq!(original_size, dataset.len() + smaller.len());
        assert_ne!(dataset, smaller);
        assert_eq!(dataset.features, smaller.features);
    }

    #[test]
    fn save() {
        let copy_csv = temp_path("xor_copy.csv");
        let copy_arff = temp_path("xor_copy.arff");
        let copy_svm = temp_path("xor_copy.svm");
        let dataset = Dataset::from_csv_string(include_str!("../testdata/xor.csv"), 6).unwrap();
        dataset.save_csv(&copy_csv).unwrap();
        dataset.save_arff(&copy_arff).unwrap();
        dataset.save_libsvm(&copy_svm).unwrap();
        let dataset2 = Dataset::from_csv_file(&copy_csv, 6).unwrap();
        let dataset3 = Dataset::from_arff_file(&copy_arff).unwrap();
        let dataset4 = Dataset::from_libsvm_file(&copy_svm).unwrap();
        // Remove the artifacts before asserting so a failed comparison doesn't leave them.
        for path in [&copy_csv, &copy_arff, &copy_svm] {
            std::fs::remove_file(path).unwrap();
        }
        assert_eq!(dataset, dataset2);
        assert_eq!(dataset, dataset3);
        assert_eq!(dataset2, dataset3);
        assert_eq!(dataset, dataset4);
        assert_eq!(dataset3, dataset4);
    }
}