rusty_science 0.1.1

An easy-to-learn, easy-to-use ML toolkit for Rust
Documentation
use num_traits::{Num, ToPrimitive};
use rand::distributions::{Distribution, Uniform};
use rand::distributions::uniform::SampleUniform;

/// Generates a random labeled dataset: a feature matrix plus one label per row.
///
/// The feature matrix is produced by [`generate_unlabeled_data`] with the same `shape`.
/// Labels are drawn uniformly from the closed range `[L::zero(), L::one()]`. The upper
/// bound is inclusive on purpose: with a half-open `[0, 1)` range every integer label
/// would be `0`. Integer `L` therefore yields labels `0` or `1`; floating-point `L`
/// yields continuous values in `[0, 1]`.
///
/// # Arguments
/// * `shape` - `(rows, cols)`: the number of samples and the number of features per sample.
///
/// # Returns
/// `(data, labels)` where `data` has `rows` rows and `labels.len() == rows`.
pub fn generate_labeled_data<D, L>(shape: (usize, usize)) -> (Vec<Vec<D>>, Vec<L>)
where
    D: Num + Copy + Clone + PartialOrd + ToPrimitive + SampleUniform,
    L: Num + Copy + Clone + PartialOrd + ToPrimitive + SampleUniform,
{
    let (rows, _) = shape;
    let data = generate_unlabeled_data::<D>(shape);

    // `Uniform::new` is half-open, which would collapse integer labels to all zeros.
    let label_dist = Uniform::new_inclusive(L::zero(), L::one());
    let mut rng = rand::thread_rng();
    let labels: Vec<L> = (0..rows).map(|_| label_dist.sample(&mut rng)).collect();

    (data, labels)
}

/// Generates a `rows x cols` matrix of independently, uniformly sampled values.
///
/// Each entry is drawn from the closed range `[D::zero(), D::one()]`. The upper bound
/// is inclusive on purpose: with a half-open `[0, 1)` range every integer entry would
/// be `0`. Integer `D` therefore yields a mix of `0` and `1`; floating-point `D` yields
/// continuous values in `[0, 1]`.
///
/// # Arguments
/// * `shape` - `(rows, cols)`: the number of rows and the number of entries per row.
///
/// # Returns
/// A `Vec` of `rows` rows, each holding `cols` values. Either dimension may be zero,
/// giving an empty outer `Vec` or empty rows respectively.
pub fn generate_unlabeled_data<D>(shape: (usize, usize)) -> Vec<Vec<D>>
where
    D: Num + Copy + Clone + PartialOrd + ToPrimitive + SampleUniform,
{
    let (rows, cols) = shape;
    let mut rng = rand::thread_rng();

    // `Uniform::new` is half-open, which would collapse integer data to all zeros.
    let data_dist = Uniform::new_inclusive(D::zero(), D::one());

    (0..rows)
        .map(|_| (0..cols).map(|_| data_dist.sample(&mut rng)).collect())
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `data` has exactly `rows` rows of `cols` entries each, all within `[0, 1]`.
    fn assert_shape_and_range(data: &[Vec<f64>], rows: usize, cols: usize) {
        assert_eq!(data.len(), rows);
        for row in data {
            assert_eq!(row.len(), cols);
            assert!(row.iter().all(|&v| (0.0..=1.0).contains(&v)));
        }
    }

    #[test]
    fn test_generate_labeled_data() {
        let (data, labels) = generate_labeled_data::<f64, f64>((100, 100));
        assert_shape_and_range(&data, 100, 100);
        assert_eq!(labels.len(), 100);
        assert!(labels.iter().all(|&l| (0.0..=1.0).contains(&l)));
    }

    #[test]
    fn test_generate_unlabeled_data() {
        let data = generate_unlabeled_data::<f64>((100, 100));
        assert_shape_and_range(&data, 100, 100);
    }

    #[test]
    fn test_non_square_shape() {
        // Guards against rows/cols being swapped.
        let data = generate_unlabeled_data::<f64>((3, 7));
        assert_shape_and_range(&data, 3, 7);
    }

    #[test]
    fn test_empty_shapes() {
        assert!(generate_unlabeled_data::<f64>((0, 5)).is_empty());

        let data = generate_unlabeled_data::<f64>((4, 0));
        assert_shape_and_range(&data, 4, 0);

        let (data, labels) = generate_labeled_data::<f64, f64>((0, 3));
        assert!(data.is_empty());
        assert!(labels.is_empty());
    }

    #[test]
    fn test_integer_types_stay_in_unit_range() {
        let (data, labels) = generate_labeled_data::<i32, u8>((20, 10));
        assert_eq!(data.len(), 20);
        assert_eq!(labels.len(), 20);
        assert!(data.iter().flatten().all(|&v| v == 0 || v == 1));
        assert!(labels.iter().all(|&l| l <= 1));
    }
}