#![cfg(feature = "datasets")]
use dataset_core::datasets::iris::*;
use dataset_core::utils::{download_to, file_sha256_matches};
use std::fs::{File, create_dir_all, remove_dir_all};
use std::io::Write;
use std::path::Path;
/// Loads Iris into a fresh directory and checks the expected dimensions
/// (150 samples × 4 features, 150 labels).
///
/// It also verifies that mutating *owned copies* of the returned data does not
/// affect what the dataset hands out on a later `data()` call.
#[test]
fn test_load_iris() {
    let download_dir = "./test_load_iris";
    let dataset = Iris::new(download_dir);

    let features = dataset.features().unwrap();
    let labels = dataset.labels().unwrap();
    assert_eq!(features.shape(), &[150, 4]);
    assert_eq!(labels.len(), 150);

    let (features, labels) = dataset.data().unwrap();
    // Remember the first sample so we can detect any leak from the copies below.
    let original_feature = features[[0, 0]];
    let original_label = labels[0];

    let mut features_owned = features.to_owned();
    let mut labels_owned = labels.to_owned();
    features_owned[[0, 0]] = 5.5;
    labels_owned[0] = "setosa-modified";
    assert_eq!(features_owned[[0, 0]], 5.5);
    assert_eq!(labels_owned[0], "setosa-modified");

    // The dataset's own data must be untouched by edits to the owned copies.
    let (features_again, labels_again) = dataset.data().unwrap();
    assert_eq!(features_again[[0, 0]], original_feature);
    assert_eq!(labels_again[0], original_label);

    remove_dir_all(download_dir).unwrap();
}
/// Pre-populates the download directory with the Iris CSV via `download_to`,
/// then loads the dataset from that directory and checks its dimensions.
///
/// NOTE(review): the name implies the dataset should skip its own download
/// when the file is already present; this test does not directly assert that.
#[test]
fn test_iris_no_need_download() {
    const IRIS_CSV_URL: &str = "https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv";

    let download_dir = "./test_load_iris_no_need_download";
    let download_dir_path = Path::new(download_dir);
    create_dir_all(download_dir_path).unwrap();
    download_to(IRIS_CSV_URL, download_dir_path, None).unwrap();

    let dataset = Iris::new(download_dir);
    let (features, labels) = dataset.data().unwrap();
    // Same dimensions as a fresh load: the pre-placed file must be the full dataset.
    assert_eq!(features.shape(), &[150, 4]);
    assert_eq!(labels.len(), 150);

    remove_dir_all(download_dir).unwrap();
}
/// Plants a bogus `iris.csv` in the download directory, then loads the dataset.
///
/// Checks that the data loads with the expected dimensions and that the file
/// on disk afterwards matches the expected SHA-256, i.e. the bogus file was
/// replaced.
#[test]
fn test_iris_overwrite() {
    let download_dir = "./test_load_iris_overwrite";
    let download_dir_path = Path::new(download_dir);
    let iris_path = download_dir_path.join("iris.csv");
    create_dir_all(download_dir_path).unwrap();

    // The inner scope drops (and closes) the handle before the dataset touches the path.
    {
        let mut fake_iris = File::create(&iris_path).unwrap();
        fake_iris.write_all(b"fake data").unwrap();
    }

    let dataset = Iris::new(download_dir);
    let (features, labels) = dataset.data().unwrap();
    // The loaded data must come from the genuine file, not the planted one.
    assert_eq!(features.shape(), &[150, 4]);
    assert_eq!(labels.len(), 150);

    assert!(
        file_sha256_matches(
            &iris_path,
            "c52742e50315a99f956a383faedf7575552675f6409ef0f9a47076dd08479930"
        )
        .unwrap()
    );

    remove_dir_all(download_dir).unwrap();
}