use ndarray::prelude::*;
use rand::{Rng, ThreadRng, thread_rng};
use mentor::configs::Scheduling;
/// A single example pairing an input vector with its expected target vector.
#[derive(Debug, Clone)]
pub struct Sample {
    /// Input values of this sample.
    pub input: Array1<f32>,
    /// Expected output values associated with `input`.
    pub target: Array1<f32>,
}

impl Sample {
    /// Creates a sample from any two values convertible into `Vec<f32>`
    /// (for example a `Vec<f32>` or a `&[f32]` slice).
    pub fn new<A1, A2>(input: A1, target: A2) -> Sample
        where A1: Into<Vec<f32>>,
              A2: Into<Vec<f32>>
    {
        // Convert the input completely before touching the target.
        let input_array = Array1::from_vec(input.into());
        let target_array = Array1::from_vec(target.into());
        Sample {
            input: input_array,
            target: target_array,
        }
    }
}

/// Lets an `(input, target)` tuple be turned into a `Sample` via `From`/`Into`.
impl<A1, A2> From<(A1, A2)> for Sample
    where A1: Into<Vec<f32>>,
          A2: Into<Vec<f32>>
{
    fn from(from: (A1, A2)) -> Sample {
        let (input, target) = from;
        Sample::new(input, target)
    }
}
#[derive(Debug, Clone)]
pub struct SampleView<'a> {
pub input: ArrayView1<'a, f32>,
pub target: ArrayView1<'a, f32>,
}
impl<'a> From<&'a Sample> for SampleView<'a> {
fn from(from: &'a Sample) -> SampleView<'a> {
SampleView {
input: from.input.view(),
target: from.target.view(),
}
}
}
/// Builds a `Vec<Sample>` from a list of `input => target` pairs.
///
/// Each side of a pair is either a bracketed list of values (`[1.0, 2.0]`)
/// or, for a single value, a bare expression (`1.0`). All pairs within one
/// invocation must use the same bracket form on each side. A trailing comma
/// after the last pair is accepted.
///
/// The expansion calls `Sample::new` unqualified, so `Sample` must be in
/// scope wherever the macro is invoked.
///
/// ```ignore
/// let data = samples![
///     [0.0, 0.0] => 0.0,
///     [0.0, 1.0] => 1.0,
/// ];
/// ```
#[macro_export]
macro_rules! samples {
    // [inputs...] => [targets...]
    [ $( [ $($i:expr),+ ] => [ $($e:expr),+ ] ),+ $(,)* ] => {
        vec![$(
            Sample::new(
                vec![$($i),+],
                vec![$($e),+]
            )
        ),+]
    };
    // [inputs...] => target
    [ $( [ $($i:expr),+ ] => $e:expr ),+ $(,)* ] => {
        vec![$(
            Sample::new(
                vec![$($i),+],
                vec![$e]
            )
        ),+]
    };
    // input => [targets...]
    [ $( $i:expr => [ $($e:expr),+ ] ),+ $(,)* ] => {
        vec![$(
            Sample::new(
                vec![$i],
                vec![$($e),+]
            )
        ),+]
    };
    // input => target
    [ $( $i:expr => $e:expr ),+ $(,)* ] => {
        vec![$(
            Sample::new(
                vec![$i],
                vec![$e]
            )
        ),+]
    };
}
/// Internal index-selection strategy used by `SampleScheduler`.
#[derive(Clone)]
enum Scheduler {
    /// Picks each index with `gen_range` on a thread-local RNG.
    Random(ThreadRng),
    /// Cycles through indices in order; the payload counts indices handed out so far.
    Iterative(u64),
}

/// Hand-written because `ThreadRng` is not printed; the RNG state is shown as `_`.
impl ::std::fmt::Debug for Scheduler {
    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        use self::Scheduler::*;
        match *self {
            Random(_) => write!(f, "Scheduler::Random(_)"),
            Iterative(x) => write!(f, "Scheduler::Iterative({})", x),
        }
    }
}

impl Scheduler {
    /// Creates scheduler state for the configured `kind`.
    ///
    /// `Iterative` starts counting at 0; `Random` obtains its RNG from `thread_rng()`.
    fn from_kind(kind: Scheduling) -> Self {
        use mentor::configs::Scheduling::*;
        match kind {
            Random => Scheduler::Random(thread_rng()),
            Iterative => Scheduler::Iterative(0),
        }
    }

    /// Returns the next sample index in `0..num_samples`.
    ///
    /// # Panics
    /// Panics if `num_samples` is zero, because there is no index to return.
    fn next(&mut self, num_samples: usize) -> usize {
        use self::Scheduler::*;
        assert!(num_samples > 0, "Scheduler::next called with zero samples");
        match *self {
            Random(ref mut rng) => {
                rng.gen_range(0, num_samples)
            },
            Iterative(ref mut cur) => {
                // Reduce in u64 before narrowing: casting the raw counter to
                // `usize` first would truncate it on 32-bit targets and break
                // the round-robin order after 2^32 draws.
                let next = (*cur % num_samples as u64) as usize;
                *cur += 1;
                next
            }
        }
    }
}
#[derive(Debug, Clone)]
pub struct SampleScheduler {
samples : Vec<Sample>,
scheduler: Scheduler,
}
impl SampleScheduler {
pub fn from_samples(kind: Scheduling, samples: Vec<Sample>) -> Self {
SampleScheduler {
samples: samples,
scheduler: Scheduler::from_kind(kind),
}
}
pub fn next_sample(&mut self) -> SampleView {
let len_samples = self.samples.len();
let id = self.scheduler.next(len_samples);
(&self.samples[id]).into()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that both slices hold the same number of samples and that each
    /// pair of samples has element-wise equal inputs AND targets.
    fn assert_samples_eq(expansion: &[Sample], target: &[Sample]) {
        assert_eq!(expansion.len(), target.len());
        for (fst, snd) in expansion.iter().zip(target.iter()) {
            assert_eq!(fst.input.len(), snd.input.len());
            assert_eq!(fst.target.len(), snd.target.len());
            for (i, t) in fst.input.iter().zip(snd.input.iter()) {
                assert_eq!(i, t);
            }
            // Compare target values too, not just their lengths, so a macro
            // arm that produces wrong targets is caught.
            for (a, b) in fst.target.iter().zip(snd.target.iter()) {
                assert_eq!(a, b);
            }
        }
    }

    #[test]
    fn sample_and_vec_equal() {
        let s1: Vec<Sample> = samples![
            [1.0, 2.0] => [3.0],
            [4.0, 5.0] => [5.0]
        ];
        let a1: Vec<Sample> = vec![
            Sample::new(vec![1.0, 2.0], vec![3.0]),
            Sample::new(vec![4.0, 5.0], vec![5.0]),
        ];
        assert_samples_eq(&s1, &a1);
    }

    #[test]
    fn missing_right_brackets() {
        let s1: Vec<Sample> = samples![
            [1.0, 2.0] => [3.0],
            [4.0, 5.0] => [5.0],
            [6.0, 7.0] => [8.0]
        ];
        let s2: Vec<Sample> = samples![
            [1.0, 2.0] => 3.0,
            [4.0, 5.0] => 5.0,
            [6.0, 7.0] => 8.0
        ];
        assert_samples_eq(&s1, &s2);
    }

    #[test]
    fn missing_left_brackets() {
        let s1: Vec<Sample> = samples![
            [1.0] => [2.0, 3.0],
            [4.0] => [5.0, 6.0],
            [7.0] => [8.0, 9.0]
        ];
        let s2: Vec<Sample> = samples![
            1.0 => [2.0, 3.0],
            4.0 => [5.0, 6.0],
            7.0 => [8.0, 9.0]
        ];
        assert_samples_eq(&s1, &s2);
    }

    #[test]
    fn missing_both_brackets() {
        let s1: Vec<Sample> = samples![
            [1.0] => [2.0],
            [3.0] => [4.0],
            [5.0] => [6.0]
        ];
        let s2: Vec<Sample> = samples![
            1.0 => 2.0,
            3.0 => 4.0,
            5.0 => 6.0
        ];
        assert_samples_eq(&s1, &s2);
    }

    #[test]
    fn iterative_scheduler_cycles_in_order() {
        let mut scheduler = SampleScheduler::from_samples(
            Scheduling::Iterative,
            samples![1.0 => 10.0, 2.0 => 20.0, 3.0 => 30.0]
        );
        let expected = [1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0];
        for &want in expected.iter() {
            let view = scheduler.next_sample();
            assert_eq!(view.input[0], want);
        }
    }
}