pub use self::sin::SinSource;
mod sin;
pub use self::xor::XORSource;
mod xor;
use af::{Dim4, Array, DType};
use std::cell::{RefCell, Cell};
use utils;
/// A paired set of input and target arrays.
///
/// Each array sits behind a `RefCell`, so code holding only a shared
/// `&Data` can still obtain a mutable borrow at runtime. Presumably the
/// two arrays correspond sample-for-sample — TODO confirm against the
/// `DataSource` implementations.
#[derive(Clone)]
pub struct Data {
    /// Input array.
    pub input: RefCell<Box<Array>>,
    /// Target array paired with `input`.
    pub target: RefCell<Box<Array>>,
}
/// Metadata describing a data source, as returned by `DataSource::info`.
///
/// `current_epoch` is wrapped in a `Cell`, so it can be updated through a
/// shared reference.
#[derive(PartialEq, Clone, Debug)]
pub struct DataParams {
    /// Dimensions of the input array.
    pub input_dims: Dim4,
    /// Dimensions of the target array.
    pub target_dims: Dim4,
    /// Flag presumably requesting that the data be shuffled.
    pub shuffle: bool,
    /// Flag presumably requesting that the data be normalized.
    pub normalize: bool,
    /// Epoch counter; interior-mutable so it can change via `&self`.
    pub current_epoch: Cell<u64>,
    /// Element type of the arrays.
    pub dtype: DType,
    /// Total number of samples.
    pub num_samples: u64,
    /// Number of samples in the training split.
    pub num_train: u64,
    /// Number of samples in the test split.
    pub num_test: u64,
    /// Number of samples in the validation split, or `None` if there is
    /// no validation split (presumably — confirm with implementors).
    pub num_validation: Option<u64>,
}
/// A provider of training, test, and optional validation data.
pub trait DataSource {
    /// Returns the metadata describing this source.
    fn info(&self) -> DataParams;

    /// Returns `Data` drawn from the training split. `num_batch` presumably
    /// selects the batch size — TODO confirm against implementors.
    fn get_train_iter(&self, num_batch: u64) -> Data;

    /// Test-split counterpart of `get_train_iter`.
    fn get_test_iter(&self, num_batch: u64) -> Data;

    /// Validation-split counterpart of `get_train_iter`. May return `None`,
    /// presumably when the source has no validation split.
    fn get_validation_iter(&self, num_batch: u64) -> Option<Data>;
}
/// Types whose contents can be normalized in place.
pub trait Normalize {
    /// Normalizes `self` in place. How `num_std` is interpreted is left to
    /// the implementor — TODO document the expected meaning.
    fn normalize(&mut self, num_std: f32);
}
/// Types whose contents can be shuffled in place.
pub trait Shuffle {
    /// Shuffles `self` in place.
    fn shuffle(&mut self);
}
impl Shuffle for Data {
    /// Placeholder: shuffling is not implemented yet. This only prints a
    /// warning to stdout and leaves `input` and `target` unchanged.
    fn shuffle(&mut self) {
        let notice = "WARNING: shuffle not yet implemented";
        println!("{}", notice);
    }
}
impl Normalize for Data {
fn normalize(&mut self, num_std: f32){
let normalized_inputs = utils::normalize_array(&self.input.borrow(), num_std);
let normalized_target = utils::normalize_array(&self.target.borrow(), num_std);
self.input = RefCell::new(Box::new(normalized_inputs));
self.target = RefCell::new(Box::new(normalized_target));
}
}