use std::fmt;
use rand::Rng;
use rand_distr::StandardNormal;
use std::ops::Add;
use std::ops::Sub;
use std::ops::{Mul, Div};
use std::f64;
use std::ops::{Index, IndexMut};
/// A dense n-dimensional array of `f64` values stored in row-major order.
///
/// `shape` gives the extent of each axis. The constructors on this type
/// enforce `data.len() == shape.iter().product()`.
#[derive(Debug, Clone, PartialEq)]
pub struct NDArray {
    /// Flat element storage, row-major (the last axis varies fastest).
    data: Vec<f64>,
    /// Extent of each axis; its length is the number of dimensions.
    shape: Vec<usize>,
}
impl NDArray {
    /// Creates an array from flat, row-major `data` and the extent of each axis.
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape`.
    pub fn new(data: Vec<f64>, shape: Vec<usize>) -> Self {
        let total_size: usize = shape.iter().product();
        assert_eq!(data.len(), total_size, "Data length must match shape dimensions");
        NDArray { data, shape }
    }

    /// Creates a 1-D array that takes ownership of `data`.
    pub fn from_vec(data: Vec<f64>) -> Self {
        let len = data.len();
        Self::new(data, vec![len])
    }

    /// Creates a 2-D array of shape `[rows, cols]` from a vector of rows.
    ///
    /// An empty input yields shape `[0, 0]`.
    ///
    /// # Panics
    /// Panics if the rows do not all have the same length.
    pub fn from_matrix(data: Vec<Vec<f64>>) -> Self {
        let rows = data.len();
        let cols = data.first().map_or(0, Vec::len);
        // Checked explicitly: a ragged input whose total length happens to equal
        // rows * cols (e.g. row lengths 2, 3, 1) would otherwise pass `new`'s
        // size check and silently produce a scrambled matrix.
        assert!(
            data.iter().all(|row| row.len() == cols),
            "All rows must have the same length"
        );
        let flat_data: Vec<f64> = data.into_iter().flatten().collect();
        Self::new(flat_data, vec![rows, cols])
    }

    /// Returns the extent of each axis.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the number of dimensions (the length of the shape).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns all elements as a flat, row-major slice.
    pub fn data(&self) -> &[f64] {
        &self.data
    }

    /// Returns values from `start` up to (but excluding) `stop`, spaced `step`
    /// apart, as a 1-D array (mirrors NumPy's `arange`).
    ///
    /// The element count is `ceil((stop - start) / step)` and each element is
    /// computed as `start + i * step`, so rounding error does not accumulate.
    /// A negative `step` produces a descending sequence; a `step` pointing away
    /// from `stop` produces an empty array.
    ///
    /// # Panics
    /// Panics if `step` is zero or NaN, or if the element count is infinite.
    pub fn arange(start: f64, stop: f64, step: f64) -> Self {
        assert!(step != 0.0 && !step.is_nan(), "Step must be a non-zero number");
        let count = ((stop - start) / step).ceil();
        // `count > 0.0` is also false for NaN, so NaN bounds give an empty array.
        let len = if count > 0.0 {
            assert!(count.is_finite(), "arange would produce an unbounded number of elements");
            count as usize
        } else {
            0
        };
        // Deriving every element from `start` avoids the drift of repeatedly
        // adding `step`, which made `arange(0.0, 1.0, 0.1)` yield 11 elements.
        let data = (0..len).map(|i| start + i as f64 * step).collect();
        Self::from_vec(data)
    }

    /// Returns a 1-D array of `size` zeros.
    pub fn zeros(size: usize) -> Self {
        Self::from_vec(vec![0.0; size])
    }

    /// Returns a `rows x cols` array of zeros.
    pub fn zeros_2d(rows: usize, cols: usize) -> Self {
        Self::new(vec![0.0; rows * cols], vec![rows, cols])
    }

    /// Returns a 1-D array of `size` ones.
    pub fn ones(size: usize) -> Self {
        Self::from_vec(vec![1.0; size])
    }

    /// Returns a `rows x cols` array of ones.
    pub fn ones_2d(rows: usize, cols: usize) -> Self {
        Self::new(vec![1.0; rows * cols], vec![rows, cols])
    }

    /// Returns `num` evenly spaced samples over the closed interval `[start, end]`,
    /// each rounded to `precision` decimal places.
    ///
    /// `num == 0` yields an empty array and `num == 1` yields `[start]`, as in
    /// NumPy's `linspace`. The final sample is exactly `end` (before rounding).
    pub fn linspace(start: f64, end: f64, num: usize, precision: usize) -> Self {
        // Clamp before casting so a huge `precision` cannot wrap to a small or
        // negative exponent; `powi` then saturates to infinity, which `round_to`
        // treats as "no rounding needed".
        let factor = 10f64.powi(precision.min(i32::MAX as usize) as i32);
        let step = if num > 1 { (end - start) / (num - 1) as f64 } else { 0.0 };
        let data = (0..num)
            .map(|i| {
                // `start + step * (num - 1)` can miss `end` by an ulp, so pin it.
                let value = if num > 1 && i == num - 1 {
                    end
                } else {
                    start + step * i as f64
                };
                Self::round_to(value, factor)
            })
            .collect();
        Self::from_vec(data)
    }

    /// Returns the `n x n` identity matrix.
    pub fn eye(n: usize) -> Self {
        let mut data = vec![0.0; n * n];
        for i in 0..n {
            data[i * n + i] = 1.0;
        }
        Self::new(data, vec![n, n])
    }

    /// Returns a 1-D array of `size` samples from rand's `Standard` distribution
    /// for `f64` (uniform on `[0, 1)`).
    pub fn rand(size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..size).map(|_| rng.gen()).collect();
        Self::from_vec(data)
    }

    /// Returns a `rows x cols` array of samples uniform on `[0, 1)`.
    pub fn rand_2d(rows: usize, cols: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols).map(|_| rng.gen()).collect();
        Self::new(data, vec![rows, cols])
    }

    /// Returns a 1-D array of `size` samples from the standard normal distribution.
    pub fn randn(size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..size).map(|_| rng.sample(StandardNormal)).collect();
        Self::from_vec(data)
    }

    /// Returns a `rows x cols` array of standard-normal samples.
    pub fn randn_2d(rows: usize, cols: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols).map(|_| rng.sample(StandardNormal)).collect();
        Self::new(data, vec![rows, cols])
    }

    /// Returns `size` integers drawn uniformly from `[low, high)`, stored as `f64`.
    ///
    /// `gen_range` panics if `low >= high`.
    pub fn randint(low: i32, high: i32, size: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..size).map(|_| rng.gen_range(low..high) as f64).collect();
        Self::from_vec(data)
    }

    /// Returns a `rows x cols` array of integers drawn uniformly from `[low, high)`.
    pub fn randint_2d(low: i32, high: i32, rows: usize, cols: usize) -> Self {
        let mut rng = rand::thread_rng();
        let data: Vec<f64> = (0..rows * cols).map(|_| rng.gen_range(low..high) as f64).collect();
        Self::new(data, vec![rows, cols])
    }

    /// Returns a copy of the array with a new shape. Elements keep their
    /// row-major order.
    ///
    /// # Panics
    /// Panics if `new_shape` does not describe the same number of elements.
    pub fn reshape(&self, new_shape: Vec<usize>) -> Self {
        let new_size: usize = new_shape.iter().product();
        assert_eq!(self.data.len(), new_size, "New shape must have the same number of elements as the original array");
        Self::new(self.data.clone(), new_shape)
    }

    /// Returns the largest element, or NaN if any element is NaN.
    ///
    /// # Panics
    /// Panics if the array is empty.
    pub fn max(&self) -> f64 {
        self.data[self.argmax()]
    }

    /// Returns the flat index of the first occurrence of the largest element,
    /// or of the first NaN if there is one (matching NumPy's `argmax`).
    ///
    /// # Panics
    /// Panics if the array is empty.
    pub fn argmax(&self) -> usize {
        self.arg_extreme("argmax", |candidate, best| candidate > best)
    }

    /// Returns the smallest element, or NaN if any element is NaN.
    ///
    /// # Panics
    /// Panics if the array is empty.
    pub fn min(&self) -> f64 {
        self.data[self.argmin()]
    }

    /// Returns the flat index of the first occurrence of the smallest element,
    /// or of the first NaN if there is one (matching NumPy's `argmin`).
    ///
    /// # Panics
    /// Panics if the array is empty.
    pub fn argmin(&self) -> usize {
        self.arg_extreme("argmin", |candidate, best| candidate < best)
    }

    /// Element-wise square root (negative inputs give NaN).
    pub fn sqrt(&self) -> Self {
        self.map_elements(f64::sqrt)
    }

    /// Element-wise `e^x`.
    pub fn exp(&self) -> Self {
        self.map_elements(f64::exp)
    }

    /// Element-wise sine (radians).
    pub fn sin(&self) -> Self {
        self.map_elements(f64::sin)
    }

    /// Element-wise cosine (radians).
    pub fn cos(&self) -> Self {
        self.map_elements(f64::cos)
    }

    /// Element-wise natural logarithm.
    pub fn ln(&self) -> Self {
        self.map_elements(f64::ln)
    }

    /// Returns the element at flat (row-major) position `index`.
    ///
    /// # Panics
    /// Panics if `index` is out of range.
    pub fn get(&self, index: usize) -> f64 {
        self.data[index]
    }

    /// Copies flat elements `start..end` into a new 1-D array. The original
    /// shape is not preserved.
    ///
    /// # Panics
    /// Panics if the range is reversed or out of bounds.
    pub fn slice(&self, start: usize, end: usize) -> Self {
        let data = self.data[start..end].to_vec();
        Self::from_vec(data)
    }

    /// Overwrites the element at flat position `index`.
    ///
    /// # Panics
    /// Panics if `index` is out of range.
    pub fn set(&mut self, index: usize, value: f64) {
        self.data[index] = value;
    }

    /// Sets flat elements `start..end` to `value`. An empty or reversed range
    /// is a no-op.
    ///
    /// # Panics
    /// Panics if `end` exceeds the number of elements. The bounds are checked
    /// before anything is written.
    pub fn set_range(&mut self, start: usize, end: usize, value: f64) {
        if start < end {
            for slot in &mut self.data[start..end] {
                *slot = value;
            }
        }
    }

    /// Returns an independent deep copy (equivalent to `clone`).
    pub fn copy(&self) -> Self {
        Self::new(self.data.clone(), self.shape.clone())
    }

    /// Borrows flat elements `start..end` without copying.
    pub fn view(&self, start: usize, end: usize) -> &[f64] {
        &self.data[start..end]
    }

    /// Mutably borrows flat elements `start..end` without copying.
    pub fn view_mut(&mut self, start: usize, end: usize) -> &mut [f64] {
        &mut self.data[start..end]
    }

    /// Returns the element at `(row, col)` of a 2-D array.
    ///
    /// # Panics
    /// Panics if the array is not 2-D or either index is out of range.
    pub fn get_2d(&self, row: usize, col: usize) -> f64 {
        self.data[self.flat_index_2d("get_2d", row, col)]
    }

    /// Overwrites the element at `(row, col)` of a 2-D array.
    ///
    /// # Panics
    /// Panics if the array is not 2-D or either index is out of range.
    pub fn set_2d(&mut self, row: usize, col: usize, value: f64) {
        let index = self.flat_index_2d("set_2d", row, col);
        self.data[index] = value;
    }

    /// Copies rows `row_start..row_end` and columns `col_start..col_end` of a
    /// 2-D array into a new 2-D array.
    ///
    /// # Panics
    /// Panics if the array is not 2-D, or if either range is reversed or
    /// exceeds the matrix.
    pub fn sub_matrix(&self, row_start: usize, row_end: usize, col_start: usize, col_end: usize) -> Self {
        assert_eq!(self.ndim(), 2, "sub_matrix is only applicable to 2D arrays");
        let (rows, cols) = (self.shape[0], self.shape[1]);
        assert!(
            row_start <= row_end && row_end <= rows,
            "sub_matrix: row range {}..{} is out of bounds for {} rows",
            row_start, row_end, rows
        );
        // Without this check a column range past `cols` would read into the next row.
        assert!(
            col_start <= col_end && col_end <= cols,
            "sub_matrix: column range {}..{} is out of bounds for {} columns",
            col_start, col_end, cols
        );
        let mut data = Vec::with_capacity((row_end - row_start) * (col_end - col_start));
        for row in row_start..row_end {
            let offset = row * cols;
            data.extend_from_slice(&self.data[offset + col_start..offset + col_end]);
        }
        Self::new(data, vec![row_end - row_start, col_end - col_start])
    }

    /// Returns a flat mask that is `true` where the element exceeds `threshold`.
    pub fn greater_than(&self, threshold: f64) -> Vec<bool> {
        self.data.iter().map(|&x| x > threshold).collect()
    }

    /// Returns a 1-D array of the elements satisfying `condition`, in flat order.
    pub fn filter(&self, condition: impl Fn(&f64) -> bool) -> Self {
        let data: Vec<f64> = self.data.iter().copied().filter(condition).collect();
        Self::from_vec(data)
    }

    /// Returns the total number of elements.
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Returns the element type name, which is always `"f64"`.
    pub fn dtype(&self) -> &'static str {
        "f64"
    }

    /// Returns a copy with a length-1 axis inserted at position `axis`.
    ///
    /// # Panics
    /// Panics if `axis > self.ndim()` (from `Vec::insert`).
    pub fn new_axis(&self, axis: usize) -> Self {
        let mut new_shape = self.shape.clone();
        new_shape.insert(axis, 1);
        Self::new(self.data.clone(), new_shape)
    }

    /// Alias for [`NDArray::new_axis`].
    pub fn expand_dims(&self, axis: usize) -> Self {
        self.new_axis(axis)
    }

    /// Element-wise hyperbolic tangent.
    pub fn tanh(&self) -> Self {
        self.map_elements(f64::tanh)
    }

    /// Element-wise `max(x, 0)`.
    pub fn relu(&self) -> Self {
        self.map_elements(|x| x.max(0.0))
    }

    /// Element-wise leaky ReLU: `x` if `x > 0`, otherwise `alpha * x`.
    pub fn leaky_relu(&self, alpha: f64) -> Self {
        self.map_elements(|x| if x > 0.0 { x } else { alpha * x })
    }

    /// Element-wise logistic sigmoid `1 / (1 + e^-x)`.
    pub fn sigmoid(&self) -> Self {
        self.map_elements(|x| 1.0_f64 / (1.0_f64 + (-x).exp()))
    }

    /// Applies `f` to every element, keeping the shape.
    fn map_elements(&self, f: impl Fn(f64) -> f64) -> Self {
        let data: Vec<f64> = self.data.iter().map(|&x| f(x)).collect();
        Self::new(data, self.shape.clone())
    }

    /// Rounds `value` to the decimal grid given by `factor` (`10^precision`).
    ///
    /// If scaling overflows, `value` is returned unchanged. At that magnitude
    /// (or with an infinite `factor`) it has no representable digits beyond
    /// the requested precision anyway.
    fn round_to(value: f64, factor: f64) -> f64 {
        let scaled = value * factor;
        if scaled.is_finite() {
            scaled.round() / factor
        } else {
            value
        }
    }

    /// Shared scan for `argmax`/`argmin`. Returns the index of the first NaN
    /// if there is one. Otherwise returns the earliest index that no later
    /// element strictly beats according to `better(candidate, current_best)`.
    /// `op` names the caller in the panic message.
    fn arg_extreme(&self, op: &str, better: impl Fn(f64, f64) -> bool) -> usize {
        assert!(!self.data.is_empty(), "{} of an empty array is undefined", op);
        let mut best = 0;
        for (i, &x) in self.data.iter().enumerate() {
            // NaN propagates, as in NumPy, instead of panicking in `partial_cmp`.
            if x.is_nan() {
                return i;
            }
            // A strict comparison keeps the earliest index on ties.
            if better(x, self.data[best]) {
                best = i;
            }
        }
        best
    }

    /// Converts `(row, col)` to a flat row-major index for a 2-D array.
    /// `op` names the caller in the panic messages.
    ///
    /// # Panics
    /// Panics if the array is not 2-D or either index is out of range. Without
    /// this, an out-of-range column would silently alias into the next row.
    fn flat_index_2d(&self, op: &str, row: usize, col: usize) -> usize {
        assert_eq!(self.ndim(), 2, "{} is only applicable to 2D arrays", op);
        let (rows, cols) = (self.shape[0], self.shape[1]);
        assert!(
            row < rows && col < cols,
            "{}: index ({}, {}) is out of bounds for shape [{}, {}]",
            op, row, col, rows, cols
        );
        row * cols + col
    }
}
impl fmt::Display for NDArray {
    /// Formats the array in a NumPy-like style, for any number of dimensions.
    ///
    /// * 0-D: `array(5)`
    /// * 1-D: `array([1, 2, 3])`
    /// * 2-D: `array([[1, 2],\n [3, 4]])`
    /// * Higher ranks: sub-arrays along an axis are separated by one newline per
    ///   remaining inner axis, and each continuation line is indented by its
    ///   nesting depth.
    ///
    /// Elements use `f64`'s own `Display`, so `1.0` prints as `1`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        /// Renders row-major `data` laid out according to `shape` as nested
        /// brackets. `depth` is the number of brackets already enclosing it.
        fn render(data: &[f64], shape: &[usize], depth: usize) -> String {
            match shape.split_first() {
                // 0-D array: a single scalar, printed without brackets.
                None => data.first().map(|x| x.to_string()).unwrap_or_default(),
                // Innermost axis: a flat, comma-separated list.
                Some((_, rest)) if rest.is_empty() => format!(
                    "[{}]",
                    data.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(", ")
                ),
                Some((&len, rest)) => {
                    let chunk: usize = rest.iter().product();
                    let separator = format!(",{}{}", "\n".repeat(rest.len()), " ".repeat(depth + 1));
                    let parts: Vec<String> = (0..len)
                        .map(|i| render(&data[i * chunk..(i + 1) * chunk], rest, depth + 1))
                        .collect();
                    format!("[{}]", parts.join(separator.as_str()))
                }
            }
        }
        write!(f, "array({})", render(&self.data, &self.shape, 0))
    }
}
impl Add<f64> for NDArray {
    type Output = Self;

    /// Adds `scalar` to every element, keeping the shape.
    ///
    /// `self` is taken by value, so its buffer is updated in place and returned
    /// rather than allocating a new array and cloning the shape.
    fn add(mut self, scalar: f64) -> Self::Output {
        for x in &mut self.data {
            *x += scalar;
        }
        self
    }
}
impl Add for NDArray {
    type Output = Self;

    /// Element-wise sum of two arrays of identical shape.
    ///
    /// The result reuses `self`'s buffer (both operands are owned), so no new
    /// allocation or shape clone is needed.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    fn add(mut self, other: Self) -> Self::Output {
        assert_eq!(self.shape, other.shape, "Shapes must be the same for element-wise addition");
        for (a, b) in self.data.iter_mut().zip(&other.data) {
            *a += *b;
        }
        self
    }
}
impl Sub<f64> for NDArray {
    type Output = Self;

    /// Subtracts `scalar` from every element, keeping the shape.
    ///
    /// `self` is taken by value, so its buffer is updated in place and returned
    /// rather than allocating a new array and cloning the shape.
    fn sub(mut self, scalar: f64) -> Self::Output {
        for x in &mut self.data {
            *x -= scalar;
        }
        self
    }
}
impl Sub for NDArray {
    type Output = Self;

    /// Element-wise difference `self - other` of two arrays of identical shape.
    ///
    /// The result reuses `self`'s buffer (both operands are owned), so no new
    /// allocation or shape clone is needed.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    fn sub(mut self, other: Self) -> Self::Output {
        assert_eq!(self.shape, other.shape, "Shapes must be the same for element-wise subtraction");
        for (a, b) in self.data.iter_mut().zip(&other.data) {
            *a -= *b;
        }
        self
    }
}
impl Mul<f64> for NDArray {
    type Output = Self;

    /// Multiplies every element by `scalar`, keeping the shape.
    ///
    /// `self` is taken by value, so its buffer is updated in place and returned
    /// rather than allocating a new array and cloning the shape.
    fn mul(mut self, scalar: f64) -> Self::Output {
        for x in &mut self.data {
            *x *= scalar;
        }
        self
    }
}
impl Mul for NDArray {
    type Output = Self;

    /// Element-wise (Hadamard) product of two arrays of identical shape.
    ///
    /// The result reuses `self`'s buffer (both operands are owned), so no new
    /// allocation or shape clone is needed.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    fn mul(mut self, other: Self) -> Self::Output {
        assert_eq!(self.shape, other.shape, "Shapes must be the same for element-wise multiplication");
        for (a, b) in self.data.iter_mut().zip(&other.data) {
            *a *= *b;
        }
        self
    }
}
impl Div<f64> for NDArray {
    type Output = Self;

    /// Divides every element by `scalar`, keeping the shape (IEEE semantics,
    /// so dividing by zero yields infinities or NaN).
    ///
    /// `self` is taken by value, so its buffer is updated in place and returned
    /// rather than allocating a new array and cloning the shape.
    fn div(mut self, scalar: f64) -> Self::Output {
        for x in &mut self.data {
            *x /= scalar;
        }
        self
    }
}
impl Div for NDArray {
    type Output = Self;

    /// Element-wise quotient `self / other` of two arrays of identical shape
    /// (IEEE semantics, so `0 / 0` is NaN).
    ///
    /// The result reuses `self`'s buffer (both operands are owned), so no new
    /// allocation or shape clone is needed.
    ///
    /// # Panics
    /// Panics if the shapes differ.
    fn div(mut self, other: Self) -> Self::Output {
        assert_eq!(self.shape, other.shape, "Shapes must be the same for element-wise division");
        for (a, b) in self.data.iter_mut().zip(&other.data) {
            *a /= *b;
        }
        self
    }
}
impl Index<usize> for NDArray {
    type Output = f64;

    /// Borrows the element at flat (row-major) position `index`, ignoring the shape.
    ///
    /// # Panics
    /// Panics if `index` is not less than the number of elements.
    fn index(&self, position: usize) -> &Self::Output {
        let elements: &[f64] = self.data.as_slice();
        &elements[position]
    }
}
impl IndexMut<usize> for NDArray {
    /// Mutably borrows the element at flat (row-major) position `index`.
    ///
    /// # Panics
    /// Panics if `index` is not less than the number of elements.
    fn index_mut(&mut self, position: usize) -> &mut Self::Output {
        let elements: &mut [f64] = self.data.as_mut_slice();
        &mut elements[position]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` and `expected` have the same length and that each
    /// pair of elements differs by less than `tol`.
    ///
    /// The length check matters: `zip` stops at the shorter input, so without it
    /// a result that is missing elements would still pass.
    fn assert_all_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "Length mismatch: got {:?}, expected {:?}",
            actual,
            expected
        );
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < tol, "Value {} is not close to expected {}", a, e);
        }
    }

    /// 3x3 matrix shared by the 2-D indexing tests.
    fn matrix_3x3() -> NDArray {
        NDArray::from_matrix(vec![
            vec![5.0, 10.0, 15.0],
            vec![20.0, 25.0, 30.0],
            vec![35.0, 40.0, 45.0],
        ])
    }

    /// 3x4 matrix shared by the shape-metadata tests.
    fn matrix_3x4() -> NDArray {
        NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![5.0, 6.0, 7.0, 8.0],
            vec![9.0, 10.0, 11.0, 12.0],
        ])
    }

    /// `[0, 1, 2, 3]`, the operand used by most arithmetic and math tests.
    fn zero_to_three() -> NDArray {
        NDArray::from_vec(vec![0.0, 1.0, 2.0, 3.0])
    }

    /// `[0.69, 0.94, 0.66, 0.73, 0.83]`, used by the indexing/filtering tests.
    fn sample_values() -> NDArray {
        NDArray::from_vec(vec![0.69, 0.94, 0.66, 0.73, 0.83])
    }

    /// `[1, 2, 3, 4, 5, 6]`, used by the axis-insertion tests.
    fn one_to_six() -> NDArray {
        NDArray::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    }

    #[test]
    fn test_1d_array_creation() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(arr.data(), &[1.0, 2.0, 3.0]);
        assert_eq!(arr.shape(), &[3]);
        assert_eq!(arr.to_string(), "array([1, 2, 3])");
    }

    #[test]
    fn test_2d_array_creation() {
        let arr = NDArray::from_matrix(vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ]);
        assert_eq!(arr.shape(), &[3, 3]);
        assert_eq!(arr.to_string(), "array([[1, 2, 3],\n [4, 5, 6],\n [7, 8, 9]])");
    }

    #[test]
    fn test_arange() {
        let arr = NDArray::arange(0.0, 5.0, 1.0);
        assert_eq!(arr.data(), &[0.0, 1.0, 2.0, 3.0, 4.0]);
        assert_eq!(arr.to_string(), "array([0, 1, 2, 3, 4])");
        let arr = NDArray::arange(1.0, 11.0, 2.0);
        assert_eq!(arr.data(), &[1.0, 3.0, 5.0, 7.0, 9.0]);
        assert_eq!(arr.to_string(), "array([1, 3, 5, 7, 9])");
    }

    #[test]
    fn test_zeros() {
        let arr = NDArray::zeros(4);
        assert_eq!(arr.data(), &[0.0, 0.0, 0.0, 0.0]);
        assert_eq!(arr.to_string(), "array([0, 0, 0, 0])");
        let arr = NDArray::zeros_2d(2, 2);
        assert_eq!(arr.data(), &[0.0, 0.0, 0.0, 0.0]);
        assert_eq!(arr.to_string(), "array([[0, 0],\n [0, 0]])");
    }

    #[test]
    fn test_ones() {
        let arr = NDArray::ones(5);
        assert_eq!(arr.data(), &[1.0, 1.0, 1.0, 1.0, 1.0]);
        assert_eq!(arr.to_string(), "array([1, 1, 1, 1, 1])");
        let arr = NDArray::ones_2d(3, 3);
        assert_eq!(arr.data(), &[1.0; 9]);
        assert_eq!(arr.to_string(), "array([[1, 1, 1],\n [1, 1, 1],\n [1, 1, 1]])");
    }

    #[test]
    fn test_linspace() {
        let arr = NDArray::linspace(0.0, 1.0, 11, 1);
        assert_all_close(
            arr.data(),
            &[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
            1e-9,
        );
        assert_eq!(arr.to_string(), "array([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])");
    }

    #[test]
    fn test_eye() {
        let arr = NDArray::eye(1);
        assert_eq!(arr.data(), &[1.0]);
        assert_eq!(arr.to_string(), "array([[1]])");
        let arr = NDArray::eye(2);
        assert_eq!(arr.data(), &[1.0, 0.0, 0.0, 1.0]);
        assert_eq!(arr.to_string(), "array([[1, 0],\n [0, 1]])");
        let arr = NDArray::eye(3);
        assert_eq!(arr.data(), &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
        assert_eq!(arr.to_string(), "array([[1, 0, 0],\n [0, 1, 0],\n [0, 0, 1]])");
    }

    #[test]
    fn test_rand() {
        let arr = NDArray::rand(5);
        assert_eq!(arr.shape(), &[5]);
        assert!(arr.data().iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_rand_2d() {
        let arr = NDArray::rand_2d(2, 3);
        assert_eq!(arr.shape(), &[2, 3]);
        assert!(arr.data().iter().all(|&x| (0.0..1.0).contains(&x)));
    }

    #[test]
    fn test_randint() {
        let arr = NDArray::randint(1, 10, 5);
        assert_eq!(arr.shape(), &[5]);
        assert!(arr.data().iter().all(|&x| (1.0..10.0).contains(&x)));
    }

    #[test]
    fn test_randint_2d() {
        let arr = NDArray::randint_2d(1, 10, 2, 3);
        assert_eq!(arr.shape(), &[2, 3]);
        assert!(arr.data().iter().all(|&x| (1.0..10.0).contains(&x)));
    }

    #[test]
    fn test_reshape() {
        let arr = NDArray::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(arr.shape(), &[6]);
        let reshaped = arr.reshape(vec![2, 3]);
        assert_eq!(reshaped.shape(), &[2, 3]);
        assert_eq!(reshaped.to_string(), "array([[0, 1, 2],\n [3, 4, 5]])");
    }

    #[test]
    fn test_max_min() {
        let arr = NDArray::from_vec(vec![1.0, -2.0, 3.0, 4.0, 5.0]);
        assert_eq!(arr.max(), 5.0);
        assert_eq!(arr.argmax(), 4);
        assert_eq!(arr.min(), -2.0);
        assert_eq!(arr.argmin(), 1);
    }

    #[test]
    fn test_scalar_addition() {
        let result = zero_to_three() + 2.0;
        assert_eq!(result.data(), &[2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_element_wise_addition() {
        let result = zero_to_three() + zero_to_three();
        assert_eq!(result.data(), &[0.0, 2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_scalar_subtraction() {
        let result = zero_to_three() - 10.0;
        assert_eq!(result.data(), &[-10.0, -9.0, -8.0, -7.0]);
    }

    #[test]
    fn test_element_wise_subtraction() {
        let result = zero_to_three() - zero_to_three();
        assert_eq!(result.data(), &[0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_scalar_multiplication() {
        let result = zero_to_three() * 6.0;
        assert_eq!(result.data(), &[0.0, 6.0, 12.0, 18.0]);
    }

    #[test]
    fn test_element_wise_multiplication() {
        let result = zero_to_three() * zero_to_three();
        assert_eq!(result.data(), &[0.0, 1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_scalar_division() {
        let result = zero_to_three() / 2.0;
        assert_eq!(result.data(), &[0.0, 0.5, 1.0, 1.5]);
    }

    #[test]
    fn test_element_wise_division() {
        let result = zero_to_three() / zero_to_three();
        let data = result.data();
        assert_eq!(data.len(), 4);
        // 0 / 0 is NaN; every other element is divided by itself.
        assert!(data[0].is_nan());
        assert_all_close(&data[1..], &[1.0, 1.0, 1.0], 1e-9);
    }

    #[test]
    fn test_sqrt() {
        let result = zero_to_three().sqrt();
        assert_all_close(result.data(), &[0.0, 1.0, 1.41421356, 1.73205081], 1e-8);
    }

    #[test]
    fn test_exp() {
        let result = zero_to_three().exp();
        assert_all_close(result.data(), &[1.0, 2.71828183, 7.3890561, 20.08553692], 1e-8);
    }

    #[test]
    fn test_sin() {
        let result = zero_to_three().sin();
        assert_all_close(result.data(), &[0.0, 0.84147098, 0.90929743, 0.14112001], 1e-8);
    }

    #[test]
    fn test_cos() {
        let result = zero_to_three().cos();
        assert_all_close(result.data(), &[1.0, 0.54030231, -0.41614684, -0.9899925], 1e-8);
    }

    #[test]
    fn test_ln() {
        let result = NDArray::from_vec(vec![1.0, 2.0, 3.0]).ln();
        assert_all_close(result.data(), &[0.0, 0.69314718, 1.09861229], 1e-8);
    }

    #[test]
    fn test_get_element() {
        assert_eq!(sample_values().get(0), 0.69);
    }

    #[test]
    fn test_slice() {
        let sliced = sample_values().slice(1, 4);
        assert_eq!(sliced.data(), &[0.94, 0.66, 0.73]);
    }

    #[test]
    fn test_set_element() {
        let mut arr = sample_values();
        arr.set(0, 1.0);
        assert_eq!(arr.get(0), 1.0);
    }

    #[test]
    fn test_index_operator() {
        let arr = sample_values();
        assert_eq!(arr[0], 0.69);
    }

    #[test]
    fn test_index_mut_operator() {
        let mut arr = sample_values();
        arr[0] = 1.0;
        assert_eq!(arr[0], 1.0);
    }

    #[test]
    fn test_single_element_assignment() {
        let mut arr = NDArray::from_vec(vec![0.12, 0.94, 0.66, 0.73, 0.83]);
        arr.set(0, 0.0);
        assert_eq!(arr.data(), &[0.0, 0.94, 0.66, 0.73, 0.83]);
    }

    #[test]
    fn test_range_assignment() {
        let mut arr = NDArray::from_vec(vec![0.12, 0.94, 0.66, 0.73, 0.83]);
        let len = arr.size();
        arr.set_range(0, len, 0.0);
        assert_eq!(arr.data(), &[0.0, 0.0, 0.0, 0.0, 0.0]);
        arr.set_range(2, 5, 0.5);
        assert_eq!(arr.data(), &[0.0, 0.0, 0.5, 0.5, 0.5]);
    }

    #[test]
    fn test_array_referencing() {
        let mut arr = NDArray::from_vec(vec![6.0, 7.0, 8.0, 9.0]);
        {
            let view = arr.view_mut(0, 2);
            view[1] = 4.0;
        }
        // Writes through the view must land in the original array.
        assert_eq!(arr.data(), &[6.0, 4.0, 8.0, 9.0]);
    }

    #[test]
    fn test_array_copying() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        let mut copied = arr.copy();
        assert_eq!(copied.data(), &[1.0, 2.0, 3.0]);
        copied.set(0, 9.0);
        assert_eq!(copied.data(), &[9.0, 2.0, 3.0]);
        // The original must be unaffected by writes to the copy.
        assert_eq!(arr.data(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_2d_indexing() {
        let mat = matrix_3x3();
        assert_eq!(mat.get_2d(0, 0), 5.0);
        assert_eq!(mat.get_2d(0, 2), 15.0);
        assert_eq!(mat.get_2d(2, 2), 45.0);
    }

    #[test]
    fn test_2d_set() {
        let mut mat = matrix_3x3();
        mat.set_2d(0, 0, 50.0);
        assert_eq!(mat.get_2d(0, 0), 50.0);
    }

    #[test]
    fn test_sub_matrix() {
        let sub_mat = matrix_3x3().sub_matrix(1, 3, 0, 3);
        assert_eq!(sub_mat.shape(), &[2, 3]);
        assert_eq!(sub_mat.to_string(), "array([[20, 25, 30],\n [35, 40, 45]])");
    }

    #[test]
    fn test_greater_than() {
        let result = sample_values().greater_than(0.7);
        assert_eq!(result, vec![false, true, false, true, true]);
    }

    #[test]
    fn test_filter() {
        let filtered = sample_values().filter(|&x| x > 0.7);
        assert_eq!(filtered.data(), &[0.94, 0.73, 0.83]);
    }

    #[test]
    fn test_ndim() {
        assert_eq!(matrix_3x4().ndim(), 2);
    }

    #[test]
    fn test_shape() {
        assert_eq!(matrix_3x4().shape(), &[3, 4]);
    }

    #[test]
    fn test_size() {
        assert_eq!(matrix_3x4().size(), 12);
    }

    #[test]
    fn test_dtype() {
        let arr = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
        assert_eq!(arr.dtype(), "f64");
    }

    #[test]
    fn test_new_axis_row_vector() {
        assert_eq!(one_to_six().new_axis(0).shape(), &[1, 6]);
    }

    #[test]
    fn test_new_axis_col_vector() {
        assert_eq!(one_to_six().new_axis(1).shape(), &[6, 1]);
    }

    #[test]
    fn test_expand_dims_axis_0() {
        assert_eq!(one_to_six().expand_dims(0).shape(), &[1, 6]);
    }

    #[test]
    fn test_expand_dims_axis_1() {
        assert_eq!(one_to_six().expand_dims(1).shape(), &[6, 1]);
    }

    #[test]
    fn test_tanh() {
        let inputs = vec![0.0_f64, 1.0_f64, -1.0_f64, 0.5_f64, -0.5_f64];
        let expected: Vec<f64> = inputs.iter().map(|x| x.tanh()).collect();
        let result = NDArray::from_vec(inputs).tanh();
        assert_all_close(result.data(), &expected, 1e-9);
    }

    #[test]
    fn test_relu() {
        let result = NDArray::from_vec(vec![-1.0, 0.0, 1.0, -0.5, 2.0]).relu();
        assert_eq!(result.data(), &[0.0, 0.0, 1.0, 0.0, 2.0]);
    }

    #[test]
    fn test_leaky_relu() {
        let result = NDArray::from_vec(vec![-1.0, 0.0, 1.0, -0.5, 2.0]).leaky_relu(0.01);
        assert_eq!(result.data(), &[-0.01, 0.0, 1.0, -0.005, 2.0]);
    }

    #[test]
    fn test_sigmoid() {
        let result = NDArray::from_vec(vec![0.0_f64, 1.0_f64, -1.0_f64, 0.5_f64, -0.5_f64]).sigmoid();
        let expected = vec![
            0.5_f64,
            1.0_f64 / (1.0_f64 + (-1.0_f64).exp()),
            1.0_f64 / (1.0_f64 + 1.0_f64.exp()),
            1.0_f64 / (1.0_f64 + (-0.5_f64).exp()),
            1.0_f64 / (1.0_f64 + 0.5_f64.exp()),
        ];
        assert_all_close(result.data(), &expected, 1e-9);
    }
}