use rand::Rng;
use rand::distributions::{Distribution, Standard};
use rand_distr::{Normal, StandardNormal, Uniform};
use axonml_core::dtype::{Float, Numeric, Scalar};
use crate::tensor::Tensor;
/// Creates a tensor of the given `shape` with every element set to `T::zeroed()`.
///
/// # Panics
///
/// Panics with "tensor creation failed" if `Tensor::from_vec` returns an error.
#[must_use]
pub fn zeros<T: Scalar>(shape: &[usize]) -> Tensor<T> {
    let element_count = shape.iter().product::<usize>();
    let zero = T::zeroed();
    Tensor::from_vec(vec![zero; element_count], shape).expect("tensor creation failed")
}
/// Creates a tensor of the given `shape` with every element set to `T::one()`.
///
/// Thin wrapper over [`full`].
#[must_use]
pub fn ones<T: Numeric>(shape: &[usize]) -> Tensor<T> {
    let one = T::one();
    full::<T>(shape, one)
}
/// Creates a tensor of the given `shape` with every element set to `value`.
///
/// # Panics
///
/// Panics with "tensor creation failed" if `Tensor::from_vec` returns an error.
pub fn full<T: Scalar>(shape: &[usize], value: T) -> Tensor<T> {
    let buffer = vec![value; shape.iter().product::<usize>()];
    Tensor::from_vec(buffer, shape).expect("tensor creation failed")
}
/// Creates a zero-filled tensor with the same shape as `other`.
#[must_use]
pub fn zeros_like<T: Scalar>(other: &Tensor<T>) -> Tensor<T> {
    let shape = other.shape();
    zeros::<T>(shape)
}
/// Creates a tensor of ones with the same shape as `other`.
#[must_use]
pub fn ones_like<T: Numeric>(other: &Tensor<T>) -> Tensor<T> {
    let shape = other.shape();
    ones::<T>(shape)
}
/// Creates a tensor with the same shape as `other`, with every element set to `value`.
pub fn full_like<T: Scalar>(other: &Tensor<T>, value: T) -> Tensor<T> {
    let shape = other.shape();
    full::<T>(shape, value)
}
/// Creates an `n x n` identity matrix.
///
/// Entries on the main diagonal are `T::one()`; all other entries are `T::zero()`.
#[must_use]
pub fn eye<T: Numeric>(n: usize) -> Tensor<T> {
    // Entry (i, i) is stored at flat index `i * n + i == i * (n + 1)`, so the diagonal
    // positions are exactly the flat indices divisible by `n + 1`.
    let data: Vec<T> = (0..n * n)
        .map(|flat| if flat % (n + 1) == 0 { T::one() } else { T::zero() })
        .collect();
    Tensor::from_vec(data, &[n, n]).expect("tensor creation failed")
}
/// Creates a square matrix of shape `[diag.len(), diag.len()]`.
///
/// The main diagonal holds the values of `diag` in order; every other entry is `T::zero()`.
pub fn diag<T: Numeric>(diag: &[T]) -> Tensor<T> {
    let size = diag.len();
    // Emit the matrix one row at a time: row `row` is zero except at column `row`.
    let data: Vec<T> = diag
        .iter()
        .enumerate()
        .flat_map(|(row, &value)| {
            (0..size).map(move |col| if col == row { value } else { T::zero() })
        })
        .collect();
    Tensor::from_vec(data, &[size, size]).expect("tensor creation failed")
}
#[must_use]
pub fn rand<T: Float>(shape: &[usize]) -> Tensor<T>
where
Standard: Distribution<T>,
{
let numel: usize = shape.iter().product();
let mut rng = rand::thread_rng();
let data: Vec<T> = (0..numel).map(|_| rng.r#gen()).collect();
Tensor::from_vec(data, shape).expect("tensor creation failed")
}
/// Creates a tensor of the given `shape` filled with samples from the standard normal
/// distribution (mean 0, standard deviation 1).
///
/// Samples are drawn from the thread-local RNG.
#[must_use]
pub fn randn<T: Float>(shape: &[usize]) -> Tensor<T>
where
    StandardNormal: Distribution<T>,
{
    let numel = shape.iter().product::<usize>();
    let data: Vec<T> = rand::thread_rng()
        .sample_iter(StandardNormal)
        .take(numel)
        .collect();
    Tensor::from_vec(data, shape).expect("tensor creation failed")
}
/// Creates a tensor of the given `shape` filled with samples drawn uniformly from the
/// half-open interval `[low, high)`.
///
/// # Panics
///
/// Panics if `low >= high` (raised by `Uniform::new`). Panics with "tensor creation failed"
/// if `Tensor::from_vec` returns an error.
pub fn uniform<T: Float>(shape: &[usize], low: T, high: T) -> Tensor<T>
where
    T: rand::distributions::uniform::SampleUniform,
{
    let numel = shape.iter().product::<usize>();
    let samples: Vec<T> = Uniform::new(low, high)
        .sample_iter(rand::thread_rng())
        .take(numel)
        .collect();
    Tensor::from_vec(samples, shape).expect("tensor creation failed")
}
/// Creates a tensor of the given `shape` filled with samples from a normal distribution
/// with the given `mean` and standard deviation `std`.
///
/// Samples are drawn from the thread-local RNG.
///
/// # Panics
///
/// - Panics if `Normal::new` rejects `std`, e.g. when it is not finite.
/// - Panics with "tensor creation failed" if `Tensor::from_vec` returns an error.
#[must_use]
pub fn normal<T: Float>(shape: &[usize], mean: T, std: T) -> Tensor<T>
where
    // Only `StandardNormal` is needed by `Normal`; the former `SampleUniform` bound was
    // unused and needlessly restricted which types could be sampled.
    StandardNormal: Distribution<T>,
{
    let numel: usize = shape.iter().product();
    let mut rng = rand::thread_rng();
    let dist = Normal::new(mean, std)
        .expect("normal: invalid standard deviation (rejected by Normal::new)");
    let data: Vec<T> = (0..numel).map(|_| dist.sample(&mut rng)).collect();
    Tensor::from_vec(data, shape).expect("tensor creation failed")
}
/// Creates a tensor of the given `shape` filled with integers drawn uniformly from the
/// half-open range `[low, high)`.
///
/// Each sample is converted to `T` via `NumCast`.
///
/// # Panics
///
/// - Panics if `low >= high` (raised by `Uniform::new`).
/// - Panics if `low` or `high - 1` cannot be represented in `T`.
/// - Panics with "tensor creation failed" if `Tensor::from_vec` returns an error.
#[must_use]
pub fn randint<T: Numeric>(shape: &[usize], low: i64, high: i64) -> Tensor<T>
where
    T: num_traits::NumCast,
{
    let numel: usize = shape.iter().product();
    let mut rng = rand::thread_rng();
    let dist = Uniform::new(low, high);
    // Check the whole range once, so an unrepresentable range fails deterministically.
    // Otherwise it would fail only when an out-of-range value happened to be drawn.
    // `high - 1` cannot overflow: `Uniform::new` above has already ensured `low < high`.
    assert!(
        T::from(low).is_some() && T::from(high - 1).is_some(),
        "randint: range [{low}, {high}) is not representable in the target type"
    );
    let data: Vec<T> = (0..numel)
        .map(|_| {
            T::from(dist.sample(&mut rng))
                .expect("randint: sampled value not representable in the target type")
        })
        .collect();
    Tensor::from_vec(data, shape).expect("tensor creation failed")
}
/// Creates a 1-D tensor of values starting at `start` (inclusive) and moving towards `end`
/// (exclusive) in increments of `step`.
///
/// Each element is computed from its index as `start + i * step` (evaluated in `f64` and
/// converted back to `T`), not by repeatedly adding `step` to a running value. Repeated
/// addition accumulates rounding error. It could add a spurious trailing element, e.g.
/// `arange::<f64>(0.0, 1.0, 0.1)` produced 11 values. It could also loop forever once
/// `current + step` rounded back to `current`.
///
/// A zero or NaN `step`, a NaN bound, or a `step` pointing away from `end` yields an empty
/// tensor. Because positions are computed in `f64`, integer ranges whose magnitudes exceed
/// 2^53 are approximated.
///
/// # Panics
///
/// - Panics if the range has no finite length, e.g. an infinite bound.
/// - Panics if a bound cannot be converted to `f64`.
/// - Panics with "tensor creation failed" if `Tensor::from_vec` returns an error.
#[must_use]
pub fn arange<T: Numeric>(start: T, end: T, step: T) -> Tensor<T>
where
    T: num_traits::NumCast + PartialOrd,
{
    let as_f64 = |v: &T| {
        num_traits::ToPrimitive::to_f64(v).expect("arange: bound is not representable as f64")
    };
    let (s, e, st) = (as_f64(&start), as_f64(&end), as_f64(&step));

    // Estimated element count. It is zero for a zero step, and for a NaN step or bound
    // (NaN comparisons are false). It is also zero when the step points away from `end`
    // (negative span).
    let count = if st == 0.0 {
        0
    } else {
        let span = (e - s) / st;
        if span > 0.0 {
            assert!(
                span.is_finite(),
                "arange: range from {s} to {e} with step {st} has no finite length"
            );
            // `span` is positive and finite, so the cast cannot lose the sign.
            span.ceil() as usize
        } else {
            0
        }
    };

    let data: Vec<T> = if count == 0 {
        Vec::new()
    } else {
        let ascending = st > 0.0;
        // Generate one candidate beyond the estimate and stop at the first value that
        // reaches `end`. This keeps `end` exclusive even if the f64 estimate is off by one
        // after rounding. A value that no longer fits in `T` necessarily lies past `end`,
        // so `map_while` stops there instead of overflowing.
        (0..=count)
            .map_while(|i| <T as num_traits::NumCast>::from(s + st * i as f64))
            .take_while(|v| if ascending { *v < end } else { *v > end })
            .collect()
    };
    let len = data.len();
    Tensor::from_vec(data, &[len]).expect("tensor creation failed")
}
/// Creates a 1-D tensor of `num` evenly spaced values from `start` to `end`, both inclusive.
///
/// - `num == 0` yields an empty tensor.
/// - `num == 1` yields `[start]`.
///
/// The last element is set to `end` exactly. Computing it as `start + step * (num - 1)` can
/// miss `end` by a rounding error: `linspace::<f64>(0.0, 1.0, 50)` ended at
/// `0.9999999999999999`.
///
/// # Panics
///
/// Panics with "tensor creation failed" if `Tensor::from_vec` returns an error.
#[must_use]
pub fn linspace<T: Float>(start: T, end: T, num: usize) -> Tensor<T> {
    if num == 0 {
        return Tensor::from_vec(vec![], &[0]).expect("tensor creation failed");
    }
    if num == 1 {
        return Tensor::from_vec(vec![start], &[1]).expect("tensor creation failed");
    }
    let last = num - 1;
    let step = (end - start)
        / T::from(last).expect("linspace: element count not representable in float type");
    let data: Vec<T> = (0..num)
        .map(|i| {
            if i == last {
                // Pin the endpoint so it matches `end` bit-for-bit.
                end
            } else {
                start
                    + step
                        * T::from(i).expect("linspace: index not representable in float type")
            }
        })
        .collect();
    Tensor::from_vec(data, &[num]).expect("tensor creation failed")
}
/// Creates a 1-D tensor of `num` values spaced evenly on a log scale.
///
/// Each element is `base.pow_value(x)` (presumably `base` raised to `x`) for `x` in
/// `linspace(start, end, num)`. `num == 0` yields an empty tensor.
pub fn logspace<T: Float>(start: T, end: T, num: usize, base: T) -> Tensor<T> {
    if num == 0 {
        return Tensor::from_vec(Vec::new(), &[0]).expect("tensor creation failed");
    }
    let exponents = linspace(start, end, num).to_vec();
    let data: Vec<T> = exponents
        .into_iter()
        .map(|exponent| base.pow_value(exponent))
        .collect();
    Tensor::from_vec(data, &[num]).expect("tensor creation failed")
}
/// Creates a tensor of the given `shape`.
///
/// The contents are not left uninitialised: this delegates to [`zeros`], so every element
/// is `T::zeroed()`.
#[must_use]
pub fn empty<T: Scalar>(shape: &[usize]) -> Tensor<T> {
    zeros::<T>(shape)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let zero_tensor = zeros::<f32>(&[2, 3]);
        assert_eq!(zero_tensor.shape(), &[2, 3]);
        assert_eq!(zero_tensor.numel(), 6);
        assert!(zero_tensor.to_vec().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_ones() {
        let one_tensor = ones::<f32>(&[2, 3]);
        assert!(one_tensor.to_vec().iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_full() {
        let filled = full::<f32>(&[2, 3], 42.0);
        assert!(filled.to_vec().iter().all(|&x| x == 42.0));
    }

    #[test]
    fn test_eye() {
        let identity = eye::<f32>(3);
        assert_eq!(identity.shape(), &[3, 3]);
        let entry = |row: usize, col: usize| identity.get(&[row, col]).unwrap();
        for i in 0..3 {
            assert_eq!(entry(i, i), 1.0);
        }
        assert_eq!(entry(0, 1), 0.0);
    }

    #[test]
    fn test_rand() {
        let samples = rand::<f32>(&[100]).to_vec();
        assert!(samples.iter().all(|v| (0.0..1.0).contains(v)));
    }

    #[test]
    fn test_arange() {
        assert_eq!(
            arange::<f32>(0.0, 5.0, 1.0).to_vec(),
            vec![0.0, 1.0, 2.0, 3.0, 4.0]
        );
        assert_eq!(arange::<f32>(0.0, 1.0, 0.2).numel(), 5);
    }

    #[test]
    fn test_linspace() {
        let values = linspace::<f32>(0.0, 1.0, 5).to_vec();
        assert_eq!(values.len(), 5);
        let (first, last) = (values[0], values[4]);
        assert!((first - 0.0).abs() < 1e-6);
        assert!((last - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_zeros_like() {
        let source = ones::<f32>(&[2, 3]);
        let zeroed = zeros_like(&source);
        assert_eq!(zeroed.shape(), &[2, 3]);
        assert!(zeroed.to_vec().into_iter().all(|x| x == 0.0));
    }
}