use rand::prelude::SliceRandom;
use std::vec::Vec;
use crate::arrays::HasArrayData;
use crate::tensor::{Tensor1D, Tensor2D, TensorCreator};
/// Builds a 1-D tensor of length `N` whose element at position `i` holds `i as f32`,
/// i.e. `[0.0, 1.0, ..., (N - 1) as f32]`.
pub fn arange<const N: usize>() -> Tensor1D<N> {
    let mut ramp: Tensor1D<N> = Tensor1D::zeros();
    ramp.mut_data()
        .iter_mut()
        .enumerate()
        .for_each(|(idx, slot)| *slot = idx as f32);
    ramp
}
/// One-hot encodes a batch of `B` class labels into a `B x N` tensor.
///
/// Row `i` of the result is all zeros except for a `1.0` at column `class_labels[i]`.
///
/// # Panics
///
/// Panics with an index-out-of-bounds error if any label is not a valid column
/// index for a row of length `N`.
pub fn one_hot_encode<const B: usize, const N: usize>(
    class_labels: &[usize; B],
) -> Tensor2D<B, N> {
    let mut encoded: Tensor2D<B, N> = Tensor2D::zeros();
    for (row, &label) in encoded.mut_data().iter_mut().zip(class_labels.iter()) {
        row[label] = 1.0;
    }
    encoded
}
/// Iterator over fixed-size batches of `B` indices drawn from `0..n`.
///
/// Construct with [`SubsetIterator::in_order`] or [`SubsetIterator::shuffled`].
/// Only full batches are produced: when `n` is not a multiple of `B`, the trailing
/// `n % B` indices are never yielded.
#[derive(Debug, Clone)]
pub struct SubsetIterator<const B: usize> {
    /// Position in `indices` where the next batch starts.
    i: usize,
    /// The indices `0..n` (in order or shuffled) that are handed out batch by batch.
    indices: Vec<usize>,
}
impl<const B: usize> SubsetIterator<B> {
    /// Creates a sampler over the indices `0..n` in ascending order, handed out in
    /// batches of `B`.
    pub fn in_order(n: usize) -> Self {
        let indices: Vec<usize> = (0..n).collect();
        Self { i: 0, indices }
    }

    /// Creates a sampler over the indices `0..n` after shuffling them with `rng`,
    /// handed out in batches of `B`.
    pub fn shuffled<R: rand::Rng>(n: usize, rng: &mut R) -> Self {
        let mut indices: Vec<usize> = (0..n).collect();
        indices.shuffle(rng);
        Self { i: 0, indices }
    }
}
impl<const B: usize> Iterator for SubsetIterator<B> {
type Item = [usize; B];
fn next(&mut self) -> Option<Self::Item> {
if self.indices.len() < B || self.i + B > self.indices.len() {
None
} else {
let mut batch = [0; B];
batch.copy_from_slice(&self.indices[self.i..self.i + B]);
self.i += B;
Some(batch)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every index in `0..100` is yielded exactly once, in order, when the batch
    /// size divides `n` evenly.
    #[test]
    fn sampler_uses_all() {
        let mut seen: Vec<usize> = Vec::new();
        for batch in SubsetIterator::<5>::in_order(100) {
            seen.extend_from_slice(&batch);
        }
        assert_eq!(seen, (0..100).collect::<Vec<usize>>());
    }

    /// With `100 = 16 * 6 + 4`, sixteen full batches are yielded and the trailing
    /// four indices (96..100) are dropped.
    #[test]
    fn sampler_drops_last() {
        let mut seen: Vec<usize> = Vec::new();
        for batch in SubsetIterator::<6>::in_order(100) {
            seen.extend_from_slice(&batch);
        }
        assert_eq!(seen, (0..96).collect::<Vec<usize>>());
        for i in 96..100 {
            assert!(!seen.contains(&i), "index {} should have been dropped", i);
        }
    }
}