use rand::StdRng;
use prelude::*;
/// Build a freshly seeded `StdRng` via `StdRng::new()`.
///
/// This is the serde `default` for `EncodableRng::rng`. That field is skipped
/// during (de)serialization, so every deserialized `EncodableRng` receives a
/// newly constructed generator here. The function must return a bare `StdRng`
/// (serde's `default = "..."` contract), so a seeding failure cannot be
/// propagated to the caller.
///
/// # Panics
///
/// Panics if `StdRng::new()` returns an error (presumably a failure to obtain
/// OS randomness for seeding — see the `rand` docs).
fn default_stdrng() -> StdRng {
    StdRng::new().expect("default_stdrng: StdRng::new() failed to seed the generator")
}
/// A `StdRng` wrapper that can be embedded in serde-serializable types.
///
/// The generator state itself is never encoded. The `rng` field is skipped on
/// serialization, and on deserialization it is filled in by `default_stdrng`.
/// A serialize/deserialize round trip therefore yields a newly constructed
/// generator, not a copy of the original state.
#[derive(Clone, Deserialize, Serialize)]
pub struct EncodableRng {
    /// The wrapped generator; excluded from the serialized form and rebuilt
    /// with `default_stdrng` when deserializing.
    #[serde(skip, default = "default_stdrng")]
    pub rng: StdRng,
}
impl EncodableRng {
pub fn new() -> EncodableRng {
EncodableRng {
rng: StdRng::new().unwrap(),
}
}
}
impl Default for EncodableRng {
fn default() -> Self {
EncodableRng::new()
}
}
/// Check that `y` is a single-column array whose entries are all `0.0` or `1.0`.
///
/// An empty single-column array passes.
///
/// # Errors
///
/// * `"Target array has more than one column."` when `y.cols() != 1`. Note this
///   message is also returned when `y` has zero columns.
/// * `"Invalid labels: target data is not either 0.0 or 1.0"` when any entry is
///   neither `0.0` nor `1.0` (this includes NaN).
pub fn check_valid_labels(y: &Array) -> Result<(), &'static str> {
    if y.cols() != 1 {
        return Err("Target array has more than one column.");
    }

    // De Morgan form of "every label is 0.0 or 1.0"; NaN compares unequal to
    // both, so it is reported as invalid.
    let has_invalid_label = y
        .data()
        .iter()
        .any(|&label| label != 0.0 && label != 1.0);

    if has_invalid_label {
        Err("Invalid labels: target data is not either 0.0 or 1.0")
    } else {
        Ok(())
    }
}
/// Check that the input matrix `X` has exactly `model_dim` columns.
///
/// # Errors
///
/// Returns `Err("Model input and model dimensionality differ.")` when
/// `X.cols()` is not equal to `model_dim`.
pub fn check_data_dimensionality<T: IndexableMatrix>(
    model_dim: usize,
    X: &T,
) -> Result<(), &'static str> {
    match X.cols() {
        cols if cols == model_dim => Ok(()),
        _ => Err("Model input and model dimensionality differ."),
    }
}
/// Check that the data matrix `X` and the target array `y` have the same
/// number of rows.
///
/// # Errors
///
/// Returns `Err("Data matrix and target array do not have the same number of rows")`
/// when `X.rows()` and `y.rows()` differ.
pub fn check_matched_dimensions<T: IndexableMatrix>(
    X: &T,
    y: &Array,
) -> Result<(), &'static str> {
    let data_rows = X.rows();
    let target_rows = y.rows();

    if data_rows != target_rows {
        return Err("Data matrix and target array do not have the same number of rows");
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::EncodableRng;
    use serde_json;

    /// Round-trip an `EncodableRng` through JSON.
    ///
    /// The `rng` field is `#[serde(skip)]`, so the encoded form must carry no
    /// generator state. Decoding that empty object must still succeed, with the
    /// field rebuilt by its serde default.
    #[test]
    fn test_encodable_rng_serialization() {
        let rng = EncodableRng::new();
        let serialized = serde_json::to_string(&rng).unwrap();
        // The only field is skipped, so nothing about the RNG is emitted.
        assert_eq!(serialized, "{}");
        let _: EncodableRng = serde_json::from_str(&serialized).unwrap();
    }
}