use csv::{ReaderBuilder, StringRecord};
use itertools::Itertools;
use ndarray::{
iter::AxisChunksIter, Array, ArrayView, Axis, Dimension, IntoDimension, Ix, RemoveAxis, Zip,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
use serde::de::DeserializeOwned;
use std::{fs::File, io::Read};
/// Prepends `rows` as a new leading axis onto `shape`, producing the
/// shape of `rows` stacked copies of a `shape`-sized array.
fn stacked_shape<D: Dimension>(rows: usize, shape: D) -> D::Larger {
    let mut stacked = D::Larger::zeros(shape.ndim() + 1);
    stacked[0] = rows;
    // Copy the original extents into the trailing positions.
    for (dst, src) in stacked.slice_mut()[1..].iter_mut().zip(shape.slice()) {
        *dst = *src;
    }
    stacked
}
/// An in-memory, unlabeled dataset: a single `f32` array whose outermost
/// axis (axis 0) indexes the individual records.
pub struct Dataset<D> {
    records: Array<f32, D>,
}
impl<D: RemoveAxis> Dataset<D> {
    /// Wraps an already-built record array.
    fn new(records: Array<f32, D>) -> Self {
        Self { records }
    }
    /// Returns a reference to the underlying record array.
    pub fn records(&self) -> &Array<f32, D> {
        &self.records
    }
    /// Number of records, i.e. the length of axis 0.
    pub fn len(&self) -> usize {
        self.records.len_of(Axis(0))
    }
    /// Whether the dataset contains no records.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Returns a k-fold cross-validation iterator over the records.
    pub fn kfold(&self, k: usize) -> KFold<D> {
        KFold::new(self.records.view(), k)
    }
    /// Returns an iterator over batches of at most `batch_size` records.
    pub fn batch(&self, batch_size: usize) -> Batch<D> {
        Batch::new(&self.records, batch_size)
    }
    /// Splits the dataset into consecutive chunks whose record counts are
    /// given by `lengths`, consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics if the sum of `lengths` differs from `self.len()`.
    pub fn split(self, lengths: &[usize]) -> Vec<Dataset<D>> {
        if self.len() != lengths.iter().sum::<usize>() {
            panic!("error: input lengths do not cover the whole dataset.");
        }
        let mut shape = self.records.raw_dim();
        // Scalar elements per record: product of the trailing axes.
        let elems: Ix = shape.slice().iter().skip(1).product();
        // NOTE(review): `into_raw_vec` yields data in memory order; rows
        // are contiguous here because all arrays in this module are built
        // row-major via `from_shape_vec` — confirm if constructors change.
        let mut records = self.records.into_raw_vec();
        let mut datasets = Vec::with_capacity(lengths.len());
        for length in lengths {
            shape[0] = *length;
            datasets.push(Dataset::new(
                Array::from_shape_vec(shape.clone(), records.drain(..length * elems).collect())
                    .unwrap(),
            ));
        }
        datasets
    }
    /// Shuffles the records in place with a random seed.
    pub fn shuffle(&mut self) -> &mut Self {
        self.shuffle_with_seed(rand::thread_rng().gen())
    }
    /// Shuffles the records in place, reproducibly, using `seed`.
    ///
    /// Uses the Fisher-Yates algorithm, so every permutation (including
    /// the identity) is equally likely. The previous implementation drew
    /// the swap partner strictly *after* position `i` (Sattolo's
    /// algorithm), which can only produce cyclic permutations and is
    /// therefore biased.
    pub fn shuffle_with_seed(&mut self, seed: u64) -> &mut Self {
        let len = self.records.len_of(Axis(0));
        if len == 0 {
            return self;
        }
        let mut rng = StdRng::seed_from_u64(seed);
        for i in 0..len - 1 {
            // Pick a partner uniformly from the not-yet-fixed suffix,
            // including `i` itself (a no-op swap).
            let j = rng.gen_range(i..len);
            if j == i {
                continue;
            }
            let mut iter = self.records.outer_iter_mut();
            // `nth(i)` consumes rows 0..=i, so the second `nth` is
            // indexed relative to row `i + 1`.
            Zip::from(iter.nth(i).unwrap())
                .and(iter.nth(j - i - 1).unwrap())
                .for_each(std::mem::swap);
        }
        self
    }
}
/// Builder-style loader that reads unlabeled `f32` datasets from CSV
/// sources.
pub struct DataLoader {
    r_builder: ReaderBuilder,
}
impl DataLoader {
    /// Converts this loader into one that extracts the columns at
    /// `labels` (0-based indices) as the label set.
    pub fn with_labels(self, labels: &[usize]) -> LabeledDataLoader {
        LabeledDataLoader::new(self, labels)
    }
    /// Treats the first CSV row as data instead of a header.
    pub fn without_headers(&mut self) -> &mut Self {
        self.r_builder.has_headers(false);
        self
    }
    /// Sets the CSV field delimiter.
    ///
    /// # Panics
    ///
    /// Panics if `delimiter` is not ASCII: the CSV reader operates on
    /// single bytes, and a bare `as u8` cast would otherwise silently
    /// truncate a multi-byte character to a wrong delimiter.
    pub fn with_delimiter(&mut self, delimiter: char) -> &mut Self {
        assert!(
            delimiter.is_ascii(),
            "error: delimiter must be an ASCII character."
        );
        self.r_builder.delimiter(delimiter as u8);
        self
    }
    /// Loads a dataset from the CSV file at `src`; each row becomes one
    /// record of the given `shape`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened or the data does not parse
    /// into the requested shape.
    pub fn from_csv<S>(&mut self, src: &str, shape: S) -> Dataset<<S::Dim as Dimension>::Larger>
    where
        S: IntoDimension,
    {
        self.from_reader_fn(File::open(src).unwrap(), shape, |r| r)
    }
    /// Loads a dataset from any `Read` source of CSV data.
    pub fn from_reader<R, S>(&mut self, src: R, shape: S) -> Dataset<<S::Dim as Dimension>::Larger>
    where
        R: Read,
        S: IntoDimension,
    {
        self.from_reader_fn(src, shape, |r| r)
    }
    /// Like [`DataLoader::from_csv`], but maps each deserialized row
    /// through `f` before storing it.
    pub fn from_csv_fn<S, T, F>(
        &mut self,
        src: &str,
        shape: S,
        f: F,
    ) -> Dataset<<S::Dim as Dimension>::Larger>
    where
        S: IntoDimension,
        T: DeserializeOwned,
        F: Fn(T) -> Vec<f32>,
    {
        self.from_reader_fn(File::open(src).unwrap(), shape, f)
    }
    /// Core loader: deserializes every CSV row from `src`, maps it
    /// through `f`, and stacks the resulting rows into one array of
    /// shape `(rows,) + shape`.
    ///
    /// # Panics
    ///
    /// Panics if `shape` has zero elements, a row fails to deserialize,
    /// or the collected data does not fill the stacked shape.
    pub fn from_reader_fn<R, S, T, F>(
        &mut self,
        src: R,
        shape: S,
        f: F,
    ) -> Dataset<<S::Dim as Dimension>::Larger>
    where
        R: Read,
        S: IntoDimension,
        T: DeserializeOwned,
        F: Fn(T) -> Vec<f32>,
    {
        let shape = shape.into_dimension();
        if shape.size() == 0 {
            panic!("error: cannot handle empty records.")
        }
        let mut records = Vec::new();
        let mut rows = 0;
        for record in self.r_builder.from_reader(src).deserialize() {
            let record = f(record.unwrap());
            records.extend(record);
            rows += 1;
        }
        Dataset::new(Array::from_shape_vec(stacked_shape(rows, shape), records).unwrap())
    }
}
impl Default for DataLoader {
    /// A loader over csv's default reader configuration
    /// (comma-delimited, headers expected).
    fn default() -> Self {
        let r_builder = ReaderBuilder::new();
        Self { r_builder }
    }
}
/// Loader that reads CSV data and routes each row's fields into either
/// the record set or the label set, by column index.
pub struct LabeledDataLoader {
    r_builder: ReaderBuilder,
    // Sorted, duplicate-free column indices extracted as labels.
    labels: Vec<usize>,
}
impl LabeledDataLoader {
    /// Splits a raw CSV record into `(input, label)` by routing each
    /// field to the label set when its column index appears in
    /// `self.labels`, and to the input set otherwise.
    fn deserialize_record<T, U>(&self, record: StringRecord) -> (T, U)
    where
        T: DeserializeOwned,
        U: DeserializeOwned,
    {
        let mut input = StringRecord::new();
        let mut label = StringRecord::new();
        for (id, value) in record.iter().enumerate() {
            // `labels` is kept sorted by `new`, so binary search is valid.
            match self.labels.binary_search(&id) {
                Ok(_) => label.push_field(value),
                Err(_) => input.push_field(value),
            }
        }
        (
            input.deserialize(None).unwrap(),
            label.deserialize(None).unwrap(),
        )
    }
    /// Builds a labeled loader from `builder`, validating and sorting
    /// `labels`.
    ///
    /// # Panics
    ///
    /// Panics if `labels` is empty or contains duplicates.
    fn new(builder: DataLoader, labels: &[usize]) -> Self {
        if labels.is_empty() {
            panic!("error: labels were not provided.");
        }
        let mut labels = labels.to_vec();
        labels.sort_unstable();
        if labels.windows(2).any(|w| w[0] == w[1]) {
            panic!("error: duplicated labels.");
        }
        Self {
            r_builder: builder.r_builder,
            labels,
        }
    }
    /// Treats the first CSV row as data instead of a header.
    pub fn without_headers(&mut self) -> &mut Self {
        self.r_builder.has_headers(false);
        self
    }
    /// Sets the CSV field delimiter.
    ///
    /// # Panics
    ///
    /// Panics if `delimiter` is not ASCII: the CSV reader operates on
    /// single bytes, and a bare `as u8` cast would otherwise silently
    /// truncate a multi-byte character to a wrong delimiter.
    pub fn with_delimiter(&mut self, delimiter: char) -> &mut Self {
        assert!(
            delimiter.is_ascii(),
            "error: delimiter must be an ASCII character."
        );
        self.r_builder.delimiter(delimiter as u8);
        self
    }
    /// Loads a labeled dataset from the CSV file at `src`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened or the data does not parse
    /// into the requested shapes.
    pub fn from_csv<S1, S2>(
        &mut self,
        src: &str,
        record_shape: S1,
        label_shape: S2,
    ) -> LabeledDataset<<S1::Dim as Dimension>::Larger, <S2::Dim as Dimension>::Larger>
    where
        S1: IntoDimension,
        S2: IntoDimension,
    {
        self.from_reader_fn(File::open(src).unwrap(), record_shape, label_shape, |r| r)
    }
    /// Loads a labeled dataset from any `Read` source of CSV data.
    pub fn from_reader<R, S1, S2>(
        &mut self,
        src: R,
        record_shape: S1,
        label_shape: S2,
    ) -> LabeledDataset<<S1::Dim as Dimension>::Larger, <S2::Dim as Dimension>::Larger>
    where
        R: Read,
        S1: IntoDimension,
        S2: IntoDimension,
    {
        self.from_reader_fn(src, record_shape, label_shape, |r| r)
    }
    /// Like [`LabeledDataLoader::from_csv`], but maps each deserialized
    /// `(record, label)` pair through `f` before storing it.
    pub fn from_csv_fn<S1, S2, T, U, F>(
        &mut self,
        src: &str,
        record_shape: S1,
        label_shape: S2,
        f: F,
    ) -> LabeledDataset<<S1::Dim as Dimension>::Larger, <S2::Dim as Dimension>::Larger>
    where
        S1: IntoDimension,
        S2: IntoDimension,
        T: DeserializeOwned,
        U: DeserializeOwned,
        F: Fn((T, U)) -> (Vec<f32>, Vec<f32>),
    {
        self.from_reader_fn(File::open(src).unwrap(), record_shape, label_shape, f)
    }
    /// Core loader: reads every CSV row from `src`, splits it into
    /// record and label fields, maps the pair through `f`, and stacks
    /// the results into two arrays sharing the leading (row) axis.
    ///
    /// # Panics
    ///
    /// Panics if either target shape has zero elements, a row fails to
    /// parse, or the collected data does not fill the stacked shapes.
    pub fn from_reader_fn<R, S1, S2, T, U, F>(
        &mut self,
        src: R,
        record_shape: S1,
        label_shape: S2,
        f: F,
    ) -> LabeledDataset<<S1::Dim as Dimension>::Larger, <S2::Dim as Dimension>::Larger>
    where
        R: Read,
        S1: IntoDimension,
        S2: IntoDimension,
        T: DeserializeOwned,
        U: DeserializeOwned,
        F: Fn((T, U)) -> (Vec<f32>, Vec<f32>),
    {
        let record_shape = record_shape.into_dimension();
        let label_shape = label_shape.into_dimension();
        if record_shape.size() == 0 || label_shape.size() == 0 {
            // Message punctuation kept consistent with `DataLoader`.
            panic!("error: cannot handle empty records.")
        }
        let mut records = Vec::new();
        let mut labels = Vec::new();
        let mut rows = 0;
        for record in self.r_builder.from_reader(src).records() {
            let (record, label) = f(self.deserialize_record(record.unwrap()));
            records.extend(record);
            labels.extend(label);
            rows += 1;
        }
        LabeledDataset::new(
            Array::from_shape_vec(stacked_shape(rows, record_shape), records).unwrap(),
            Array::from_shape_vec(stacked_shape(rows, label_shape), labels).unwrap(),
        )
    }
}
/// An in-memory dataset of `f32` records paired with `f32` labels; the
/// two arrays share the same length along axis 0, which indexes rows.
pub struct LabeledDataset<D1, D2> {
    records: Array<f32, D1>,
    labels: Array<f32, D2>,
}
impl<D1: RemoveAxis, D2: RemoveAxis> LabeledDataset<D1, D2> {
    /// Wraps already-built record and label arrays.
    fn new(records: Array<f32, D1>, labels: Array<f32, D2>) -> Self {
        Self { records, labels }
    }
    /// Returns a reference to the record array.
    pub fn records(&self) -> &Array<f32, D1> {
        &self.records
    }
    /// Returns a reference to the label array.
    pub fn labels(&self) -> &Array<f32, D2> {
        &self.labels
    }
    /// Number of rows, i.e. the length of the records' axis 0.
    pub fn len(&self) -> usize {
        self.records.len_of(Axis(0))
    }
    /// Whether the dataset contains no rows.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Returns a k-fold cross-validation iterator over paired
    /// record/label rows.
    pub fn kfold(&self, k: usize) -> LabeledKFold<D1, D2> {
        LabeledKFold::new(self.records.view(), self.labels.view(), k)
    }
    /// Returns an iterator over paired batches of at most `size` rows.
    pub fn batch(&self, size: usize) -> LabeledBatch<D1, D2> {
        LabeledBatch::new(&self.records, &self.labels, size)
    }
    /// Splits the dataset into consecutive chunks whose row counts are
    /// given by `lengths`, consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics if the sum of `lengths` differs from `self.len()`.
    pub fn split(self, lengths: &[usize]) -> Vec<LabeledDataset<D1, D2>> {
        if self.len() != lengths.iter().sum::<usize>() {
            panic!("error: input lengths do not cover the whole dataset.");
        }
        let mut r_shape = self.records.raw_dim();
        // Scalar elements per record / per label row.
        let r_elems: Ix = r_shape.slice().iter().skip(1).product();
        let mut records = self.records.into_raw_vec();
        let mut l_shape = self.labels.raw_dim();
        let l_elems: Ix = l_shape.slice().iter().skip(1).product();
        let mut labels = self.labels.into_raw_vec();
        let mut datasets = Vec::with_capacity(lengths.len());
        for length in lengths {
            r_shape[0] = *length;
            l_shape[0] = *length;
            datasets.push(LabeledDataset::new(
                Array::from_shape_vec(r_shape.clone(), records.drain(..length * r_elems).collect())
                    .unwrap(),
                Array::from_shape_vec(l_shape.clone(), labels.drain(..length * l_elems).collect())
                    .unwrap(),
            ));
        }
        datasets
    }
    /// Shuffles rows (records together with their labels) in place with
    /// a random seed.
    pub fn shuffle(&mut self) -> &mut Self {
        self.shuffle_with_seed(rand::thread_rng().gen())
    }
    /// Shuffles rows in place, reproducibly, using `seed`; records and
    /// labels are permuted identically so pairs stay aligned.
    ///
    /// Uses the Fisher-Yates algorithm, so every permutation (including
    /// the identity) is equally likely. The previous implementation drew
    /// the swap partner strictly *after* position `i` (Sattolo's
    /// algorithm), which can only produce cyclic permutations and is
    /// therefore biased.
    pub fn shuffle_with_seed(&mut self, seed: u64) -> &mut Self {
        let len = self.records.len_of(Axis(0));
        if len == 0 {
            return self;
        }
        let mut rng = StdRng::seed_from_u64(seed);
        for i in 0..len - 1 {
            // Pick a partner uniformly from the not-yet-fixed suffix,
            // including `i` itself (a no-op swap).
            let j = rng.gen_range(i..len);
            if j == i {
                continue;
            }
            // `nth(i)` consumes rows 0..=i, so the second `nth` is
            // indexed relative to row `i + 1`. Apply the same (i, j)
            // swap to both arrays to keep them aligned.
            let mut iter = self.records.outer_iter_mut();
            Zip::from(iter.nth(i).unwrap())
                .and(iter.nth(j - i - 1).unwrap())
                .for_each(std::mem::swap);
            let mut iter = self.labels.outer_iter_mut();
            Zip::from(iter.nth(i).unwrap())
                .and(iter.nth(j - i - 1).unwrap())
                .for_each(std::mem::swap);
        }
        self
    }
}
/// Iterator over fixed-size chunks (batches) of a dataset's records
/// along axis 0; the final batch may be shorter unless `drop_last`
/// removed it.
pub struct Batch<'a, D> {
    iter: AxisChunksIter<'a, f32, D>,
}
impl<'a, D: RemoveAxis> Batch<'a, D> {
    /// Builds a batch iterator yielding chunks of `size` records along
    /// axis 0.
    fn new(source: &'a Array<f32, D>, size: usize) -> Self {
        let iter = source.axis_chunks_iter(Axis(0), size);
        Self { iter }
    }
    /// Discards the trailing batch when it is shorter than the first
    /// one (i.e. when the record count is not a multiple of the batch
    /// size); a lone batch or evenly divided batches are kept as-is.
    pub fn drop_last(mut self) -> Self {
        let mut probe = self.iter.clone();
        let first = probe.next();
        let last = probe.last();
        if let (Some(first), Some(last)) = (first, last) {
            if first.len_of(Axis(0)) != last.len_of(Axis(0)) {
                self.iter = self.iter.dropping_back(1);
            }
        }
        self
    }
}
impl<'a, D: RemoveAxis> Iterator for Batch<'a, D> {
    type Item = <AxisChunksIter<'a, f32, D> as Iterator>::Item;
    // Thin delegation to the underlying axis-chunk iterator.
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }
}
/// Internal helper computing train/test splits for k-fold
/// cross-validation over a single array view.
struct SetKFold<'a, D> {
    source: ArrayView<'a, f32, D>,
    // Records per fold: ceiling of axis_len / k.
    step: usize,
    // Cached length of `source` along axis 0.
    axis_len: usize,
}
impl<'a, D: RemoveAxis> SetKFold<'a, D> {
    /// Builds a k-fold splitter over `source`.
    ///
    /// # Panics
    ///
    /// Panics if `k < 2`: a single fold cannot be separated into train
    /// and test parts. (The previous message claimed folds "must be
    /// > 2" even though `k == 2` is accepted.)
    pub fn new(source: ArrayView<'a, f32, D>, k: usize) -> Self {
        if k < 2 {
            panic!("error: folds must be >= 2.");
        }
        let axis_len = source.len_of(Axis(0));
        debug_assert_ne!(axis_len, 0, "no record provided");
        Self {
            source,
            // Ceiling division: every fold holds `step` records except
            // possibly a shorter final one.
            step: 1 + (axis_len - 1) / k,
            axis_len,
        }
    }
    /// Returns `(train, test)` arrays for fold `i`: rows
    /// `[step * i, min(step * (i + 1), axis_len))` form the test set and
    /// all remaining rows form the training set.
    pub fn compute_fold(&mut self, i: usize) -> (Array<f32, D>, Array<f32, D>) {
        let start = self.step * i;
        let stop = self.axis_len.min(start + self.step);
        let train_ids: Vec<usize> = (0..start).chain(stop..self.axis_len).collect();
        let test_ids: Vec<usize> = (start..stop).collect();
        (
            self.source.select(Axis(0), &train_ids),
            self.source.select(Axis(0), &test_ids),
        )
    }
}
/// K-fold cross-validation iterator over paired record/label views;
/// yields one `(train, test)` dataset pair per fold.
pub struct LabeledKFold<'a, D1, D2> {
    records: SetKFold<'a, D1>,
    labels: SetKFold<'a, D2>,
    // Next fold index to yield.
    iteration: usize,
    // Total number of folds.
    k: usize,
}
impl<'a, D1, D2> LabeledKFold<'a, D1, D2>
where
    D1: RemoveAxis,
    D2: RemoveAxis,
{
    /// Builds a labeled k-fold iterator; `records` and `labels` must
    /// have the same number of rows along axis 0.
    pub fn new(records: ArrayView<'a, f32, D1>, labels: ArrayView<'a, f32, D2>, k: usize) -> Self {
        assert_eq!(records.len_of(Axis(0)), labels.len_of(Axis(0)));
        let records = SetKFold::new(records, k);
        let labels = SetKFold::new(labels, k);
        Self {
            records,
            labels,
            iteration: 0,
            k,
        }
    }
}
impl<'a, D1, D2> Iterator for LabeledKFold<'a, D1, D2>
where
    D1: RemoveAxis,
    D2: RemoveAxis,
{
    type Item = (LabeledDataset<D1, D2>, LabeledDataset<D1, D2>);
    /// Yields `(train, test)` dataset pairs, one per fold, then `None`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.iteration >= self.k {
            return None;
        }
        let fold = self.iteration;
        self.iteration += 1;
        let (train_records, test_records) = self.records.compute_fold(fold);
        let (train_labels, test_labels) = self.labels.compute_fold(fold);
        Some((
            LabeledDataset::new(train_records, train_labels),
            LabeledDataset::new(test_records, test_labels),
        ))
    }
}
/// Iterator over paired `(records, labels)` batches drawn from two
/// arrays with equal axis-0 lengths.
pub struct LabeledBatch<'a, D1, D2> {
    records: Batch<'a, D1>,
    labels: Batch<'a, D2>,
}
impl<'a, D1: RemoveAxis, D2: RemoveAxis> LabeledBatch<'a, D1, D2> {
    /// Pairs a record batcher with a label batcher of the same chunk
    /// size; both inputs must have the same number of rows.
    fn new(records: &'a Array<f32, D1>, labels: &'a Array<f32, D2>, size: usize) -> Self {
        assert_eq!(records.len_of(Axis(0)), labels.len_of(Axis(0)));
        let records = Batch::new(records, size);
        let labels = Batch::new(labels, size);
        Self { records, labels }
    }
    /// Discards the trailing batch of both iterators when it is shorter
    /// than the others.
    pub fn drop_last(self) -> Self {
        Self {
            records: self.records.drop_last(),
            labels: self.labels.drop_last(),
        }
    }
}
impl<'a, D1: RemoveAxis, D2: RemoveAxis> Iterator for LabeledBatch<'a, D1, D2> {
    type Item = (
        <Batch<'a, D1> as Iterator>::Item,
        <Batch<'a, D2> as Iterator>::Item,
    );
    /// Yields paired `(records, labels)` batches. Both inner iterators
    /// were built with the same chunk size over arrays of equal axis-0
    /// length (asserted in `new`), so they exhaust together.
    fn next(&mut self) -> Option<Self::Item> {
        let records = self.records.next()?;
        let labels = self
            .labels
            .next()
            .expect("label batches exhausted before record batches");
        Some((records, labels))
    }
}
/// K-fold cross-validation iterator over an unlabeled record view;
/// yields one `(train, test)` dataset pair per fold.
pub struct KFold<'a, D> {
    records: SetKFold<'a, D>,
    // Next fold index to yield.
    iteration: usize,
    // Total number of folds.
    k: usize,
}
impl<'a, D: RemoveAxis> KFold<'a, D> {
    /// Builds a k-fold iterator over unlabeled records.
    pub fn new(records: ArrayView<'a, f32, D>, k: usize) -> Self {
        let records = SetKFold::new(records, k);
        Self {
            records,
            iteration: 0,
            k,
        }
    }
}
impl<'a, D: RemoveAxis> Iterator for KFold<'a, D> {
    type Item = (Dataset<D>, Dataset<D>);
    /// Yields `(train, test)` dataset pairs, one per fold, then `None`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.iteration >= self.k {
            return None;
        }
        let fold = self.iteration;
        self.iteration += 1;
        let (train, test) = self.records.compute_fold(fold);
        Some((Dataset::new(train), Dataset::new(test)))
    }
}
#[cfg(test)]
mod tests {
use super::*;
mod dataset {
    use super::*;
    use ndarray::Array;
    // Five rows of ten values, alternating ascending/descending digits.
    static DATASET: &str = "\
        0,1,2,3,4,5,6,7,8,9\n\
        9,8,7,6,5,4,3,2,1,0\n\
        0,1,2,3,4,5,6,7,8,9\n\
        9,8,7,6,5,4,3,2,1,0\n\
        0,1,2,3,4,5,6,7,8,9";
    #[test]
    fn from_reader() {
        let dataset = DataLoader::default()
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10);
        assert_eq!(
            dataset.records(),
            Array::from_shape_vec(
                (5, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2.,
                    1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
    }
    #[test]
    fn kfold() {
        let dataset = DataLoader::default()
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10);
        let mut kfold = dataset.kfold(2);
        // Fold 0: rows 0..3 are the test set, rows 3..5 the train set.
        let (train, test) = kfold.next().unwrap();
        assert_eq!(
            train.records(),
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            test.records(),
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        // Fold 1: the remaining rows swap roles.
        let (train, test) = kfold.next().unwrap();
        assert_eq!(
            train.records(),
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            test.records(),
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
        assert!(kfold.next().is_none());
    }
    #[test]
    fn batch() {
        let dataset = DataLoader::default()
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10);
        let mut batch = dataset.batch(3);
        assert_eq!(
            batch.next().unwrap(),
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        // 5 rows with batch size 3 leave a short trailing batch of 2.
        assert_eq!(
            batch.next().unwrap(),
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
        assert!(batch.next().is_none());
    }
    #[test]
    fn split() {
        let dataset = DataLoader::default()
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10);
        let datasets = dataset.split(&[1, 1, 1, 2]);
        assert_eq!(
            datasets[0].records(),
            Array::from_shape_vec((1, 10), vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
                .unwrap()
        );
        assert_eq!(
            datasets[1].records(),
            Array::from_shape_vec((1, 10), vec![9., 8., 7., 6., 5., 4., 3., 2., 1., 0.])
                .unwrap()
        );
        assert_eq!(
            datasets[2].records(),
            Array::from_shape_vec((1, 10), vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
                .unwrap()
        );
        assert_eq!(
            datasets[3].records(),
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
    }
    #[test]
    fn shuffle() {
        // NOTE: the old test asserted the shuffled array differs from the
        // original, but a correct uniform shuffle may legitimately return
        // the starting order, making that assertion flaky. Instead check
        // the invariants every shuffle must satisfy: whole rows are
        // permuted (multiset of rows preserved) and a fixed seed is
        // reproducible.
        let mut dataset = DataLoader::default()
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10);
        let rows_before: Vec<Vec<f32>> = dataset
            .records()
            .outer_iter()
            .map(|row| row.to_vec())
            .collect();
        dataset.shuffle_with_seed(42);
        let mut rows_after: Vec<Vec<f32>> = dataset
            .records()
            .outer_iter()
            .map(|row| row.to_vec())
            .collect();
        let mut rows_sorted = rows_before;
        rows_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        rows_after.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(rows_sorted, rows_after);
        // Same seed, same starting data => identical result.
        let mut twin = DataLoader::default()
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10);
        twin.shuffle_with_seed(42);
        assert_eq!(dataset.records(), twin.records());
    }
    #[test]
    fn drop_last() {
        let dataset = DataLoader::default()
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10);
        let mut batch = dataset.batch(3).drop_last();
        assert_eq!(
            batch.next().unwrap(),
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        // The short trailing batch of 2 rows was dropped; the iterator
        // stays fused afterwards.
        assert!(batch.next().is_none());
        assert!(batch.next().is_none());
    }
}
mod labeled_dataset {
    use super::*;
    use ndarray::Array;
    // Same digit patterns as the unlabeled fixture, with label columns
    // spliced in at indices 3 and 8.
    static DATASET: &str = "\
        0,1,2,1,3,4,5,6,0,7,8,9\n\
        9,8,7,0,6,5,4,3,1,2,1,0\n\
        0,1,2,1,3,4,5,6,0,7,8,9\n\
        9,8,7,0,6,5,4,3,1,2,1,0\n\
        0,1,2,1,3,4,5,6,0,7,8,9";
    #[test]
    fn from_reader() {
        let dataset = DataLoader::default()
            .with_labels(&[3, 8])
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10, 2);
        assert_eq!(
            dataset.records(),
            Array::from_shape_vec(
                (5, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2.,
                    1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            dataset.labels(),
            Array::from_shape_vec((5, 2), vec![1., 0., 0., 1., 1., 0., 0., 1., 1., 0.])
                .unwrap()
        );
    }
    #[test]
    fn kfold() {
        let dataset = DataLoader::default()
            .with_labels(&[3, 8])
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10, 2);
        let mut kfold = dataset.kfold(2);
        let (train, test) = kfold.next().unwrap();
        assert_eq!(
            train.records(),
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            train.labels(),
            Array::from_shape_vec((2, 2), vec![0., 1., 1., 0.]).unwrap()
        );
        assert_eq!(
            test.records(),
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            test.labels(),
            Array::from_shape_vec((3, 2), vec![1., 0., 0., 1., 1., 0.]).unwrap()
        );
        let (train, test) = kfold.next().unwrap();
        assert_eq!(
            train.records(),
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            train.labels(),
            Array::from_shape_vec((3, 2), vec![1., 0., 0., 1., 1., 0.]).unwrap()
        );
        assert_eq!(
            test.records(),
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            test.labels(),
            Array::from_shape_vec((2, 2), vec![0., 1., 1., 0.]).unwrap()
        );
        assert!(kfold.next().is_none());
    }
    #[test]
    fn batch() {
        let dataset = DataLoader::default()
            .with_labels(&[3, 8])
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10, 2);
        let mut batch = dataset.batch(3);
        let (records, labels) = batch.next().unwrap();
        assert_eq!(
            records,
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            labels,
            Array::from_shape_vec((3, 2), vec![1., 0., 0., 1., 1., 0.]).unwrap()
        );
        let (records, labels) = batch.next().unwrap();
        assert_eq!(
            records,
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            labels,
            Array::from_shape_vec((2, 2), vec![0., 1., 1., 0.]).unwrap()
        );
        assert!(batch.next().is_none());
    }
    #[test]
    fn split() {
        let dataset = DataLoader::default()
            .with_labels(&[3, 8])
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10, 2);
        let datasets = dataset.split(&[1, 1, 1, 2]);
        assert_eq!(
            datasets[0].records(),
            Array::from_shape_vec((1, 10), vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
                .unwrap()
        );
        assert_eq!(
            datasets[0].labels(),
            Array::from_shape_vec((1, 2), vec![1., 0.]).unwrap()
        );
        assert_eq!(
            datasets[1].records(),
            Array::from_shape_vec((1, 10), vec![9., 8., 7., 6., 5., 4., 3., 2., 1., 0.])
                .unwrap()
        );
        assert_eq!(
            datasets[1].labels(),
            Array::from_shape_vec((1, 2), vec![0., 1.]).unwrap()
        );
        assert_eq!(
            datasets[2].records(),
            Array::from_shape_vec((1, 10), vec![0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
                .unwrap()
        );
        assert_eq!(
            datasets[2].labels(),
            Array::from_shape_vec((1, 2), vec![1., 0.]).unwrap()
        );
        assert_eq!(
            datasets[3].records(),
            Array::from_shape_vec(
                (2, 10),
                vec![
                    9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8.,
                    9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            datasets[3].labels(),
            Array::from_shape_vec((2, 2), vec![0., 1., 1., 0.]).unwrap()
        );
    }
    #[test]
    fn shuffle() {
        // NOTE: the old test asserted the shuffled records differ from the
        // original, but a correct uniform shuffle may legitimately return
        // the starting order, making that assertion flaky. Instead check
        // the invariants every labeled shuffle must satisfy: (record,
        // label) pairs are permuted as aligned units, and a fixed seed is
        // reproducible.
        let mut dataset = DataLoader::default()
            .with_labels(&[3, 8])
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10, 2);
        let pairs_before: Vec<(Vec<f32>, Vec<f32>)> = dataset
            .records()
            .outer_iter()
            .zip(dataset.labels().outer_iter())
            .map(|(r, l)| (r.to_vec(), l.to_vec()))
            .collect();
        dataset.shuffle_with_seed(42);
        let mut pairs_after: Vec<(Vec<f32>, Vec<f32>)> = dataset
            .records()
            .outer_iter()
            .zip(dataset.labels().outer_iter())
            .map(|(r, l)| (r.to_vec(), l.to_vec()))
            .collect();
        let mut pairs_sorted = pairs_before;
        pairs_sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        pairs_after.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(pairs_sorted, pairs_after);
        // Same seed, same starting data => identical result.
        let mut twin = DataLoader::default()
            .with_labels(&[3, 8])
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10, 2);
        twin.shuffle_with_seed(42);
        assert_eq!(dataset.records(), twin.records());
        assert_eq!(dataset.labels(), twin.labels());
    }
    #[test]
    fn drop_last() {
        let dataset = DataLoader::default()
            .with_labels(&[3, 8])
            .without_headers()
            .from_reader(DATASET.as_bytes(), 10, 2);
        let mut batch = dataset.batch(3).drop_last();
        let (records, labels) = batch.next().unwrap();
        assert_eq!(
            records,
            Array::from_shape_vec(
                (3, 10),
                vec![
                    0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 9., 8., 7., 6., 5., 4., 3., 2., 1.,
                    0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.,
                ]
            )
            .unwrap()
        );
        assert_eq!(
            labels,
            Array::from_shape_vec((3, 2), vec![1., 0., 0., 1., 1., 0.]).unwrap()
        );
        // The short trailing batch was dropped; the iterator stays fused.
        assert!(batch.next().is_none());
        assert!(batch.next().is_none());
    }
}
}