use crate::error::{CoreError, CoreResult, ErrorContext};
use ::ndarray::{Array1, Array2, ArrayD, Dimension, IntoDimension, IxDyn, ShapeError};
use num_traits::{Float, One, Zero};
use std::fmt::Display;
use std::ops::MulAssign;
/// Namespace of constructors for two-dimensional [`Array2`] matrices.
///
/// The struct holds no data; the `PhantomData` only carries the element type
/// `T`. Its constructors are associated functions called on the type, e.g.
/// `MatrixBuilder::<f64>::zeros(3, 4)`.
pub struct MatrixBuilder<T>(std::marker::PhantomData<T>);
impl<T> MatrixBuilder<T>
where
    T: Clone + Zero,
{
    /// Returns a `rows × cols` matrix filled with `T::zero()`.
    pub fn zeros(rows: usize, cols: usize) -> Array2<T> {
        Array2::<T>::zeros((rows, cols))
    }

    /// Builds a `rows × cols` matrix from `data`, filled row by row
    /// (row-major order).
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidInput`] when `rows * cols` overflows
    /// `usize`, when `data.len()` differs from `rows * cols`, or when ndarray
    /// rejects the shape.
    pub fn from_vec(data: Vec<T>, rows: usize, cols: usize) -> CoreResult<Array2<T>> {
        // A plain `rows * cols` panics on overflow in debug builds (and wraps
        // in release builds); check it explicitly so an absurd shape becomes
        // an error instead of aborting the caller.
        let expected = rows.checked_mul(cols).ok_or_else(|| {
            CoreError::InvalidInput(ErrorContext::new(format!(
                "MatrixBuilder::from_vec: a {}×{} matrix has more elements than usize can represent",
                rows, cols
            )))
        })?;
        if data.len() != expected {
            return Err(CoreError::InvalidInput(ErrorContext::new(format!(
                "MatrixBuilder::from_vec: expected {} elements for a {}×{} matrix, got {}",
                expected,
                rows,
                cols,
                data.len()
            ))));
        }
        Array2::from_shape_vec((rows, cols), data).map_err(|e: ShapeError| {
            CoreError::InvalidInput(ErrorContext::new(format!(
                "MatrixBuilder::from_vec shape error: {e}"
            )))
        })
    }
}
impl<T> MatrixBuilder<T>
where
    T: Clone + Zero + One,
{
    /// Returns the `n × n` identity matrix: `T::one()` on the main diagonal
    /// and `T::zero()` everywhere else.
    pub fn eye(n: usize) -> Array2<T> {
        Array2::from_shape_fn((n, n), |(row, col)| {
            if row == col {
                T::one()
            } else {
                T::zero()
            }
        })
    }

    /// Returns a `rows × cols` matrix with every element set to `T::one()`.
    pub fn ones(rows: usize, cols: usize) -> Array2<T> {
        let fill = T::one();
        Array2::from_elem((rows, cols), fill)
    }
}
impl<T> MatrixBuilder<T>
where
    T: Clone,
{
    /// Returns a `rows × cols` matrix whose every element is a clone of `value`.
    pub fn full(rows: usize, cols: usize, value: T) -> Array2<T> {
        let shape = (rows, cols);
        Array2::from_elem(shape, value)
    }

    /// Builds a `rows × cols` matrix whose element at `(r, c)` is `f(r, c)`.
    pub fn from_fn<F>(rows: usize, cols: usize, mut f: F) -> Array2<T>
    where
        F: FnMut(usize, usize) -> T,
    {
        let shape = (rows, cols);
        // ndarray hands the 2-D index over as a tuple; unpack it for `f`.
        Array2::from_shape_fn(shape, move |(row, col)| f(row, col))
    }
}
impl<T> MatrixBuilder<T>
where
    T: Float + Clone,
{
    /// Returns a `rows × cols` matrix of samples drawn uniformly from `[0, 1)`.
    ///
    /// The generator is ChaCha8 seeded from `seed`, so equal seeds produce
    /// equal matrices.
    pub fn rand(rows: usize, cols: usize, seed: u64) -> Array2<T> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let one = T::one();
        // Largest binary floating-point value strictly below one (1 - ε/2).
        let below_one = one - T::epsilon() / (one + one);
        Array2::from_shape_fn((rows, cols), |_| {
            use rand::RngExt;
            let v: f64 = rng.random();
            let x = T::from(v).unwrap_or_else(T::zero);
            // Narrowing the f64 sample (e.g. to f32) can round values just
            // below 1 up to exactly 1; clamp so the range stays half-open.
            if x < one {
                x
            } else {
                below_one
            }
        })
    }

    /// Returns a `rows × cols` matrix of standard-normal samples
    /// (mean 0, standard deviation 1).
    ///
    /// The generator is ChaCha8 seeded from `seed`, so equal seeds produce
    /// equal matrices. Samples `T` cannot represent become `T::zero()`.
    pub fn randn(rows: usize, cols: usize, seed: u64) -> Array2<T> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::{Distribution, StandardNormal};
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        Array2::from_shape_fn((rows, cols), |_| {
            let v: f64 = StandardNormal.sample(&mut rng);
            T::from(v).unwrap_or_else(T::zero)
        })
    }
}
/// Namespace of constructors for one-dimensional [`Array1`] vectors.
///
/// The struct holds no data; the `PhantomData` only carries the element type
/// `T`. Its constructors are associated functions called on the type, e.g.
/// `VectorBuilder::<f64>::linspace(0.0, 1.0, 5)`.
pub struct VectorBuilder<T>(std::marker::PhantomData<T>);
impl<T> VectorBuilder<T>
where
    T: Clone + Zero,
{
    /// Returns a vector of length `n` filled with `T::zero()`.
    pub fn zeros(n: usize) -> Array1<T> {
        Array1::from_elem(n, T::zero())
    }

    /// Wraps `data` as a one-dimensional array, taking ownership of it.
    pub fn from_vec(data: Vec<T>) -> Array1<T> {
        data.into()
    }
}
impl<T> VectorBuilder<T>
where
    T: Clone + Zero + One,
{
    /// Returns a vector of length `n` filled with `T::one()`.
    pub fn ones(n: usize) -> Array1<T> {
        let one = T::one();
        Array1::from_elem(n, one)
    }
}
impl<T> VectorBuilder<T>
where
    T: Clone,
{
    /// Builds a vector of length `n` whose element at index `i` is `f(i)`.
    pub fn from_fn<F>(n: usize, f: F) -> Array1<T>
    where
        F: FnMut(usize) -> T,
    {
        // For a 1-D shape ndarray's index pattern is a bare `usize`, so `f`
        // can be passed straight through (no wrapper closure needed).
        Array1::from_shape_fn(n, f)
    }

    /// Returns a vector of length `n` whose every element is a clone of `value`.
    pub fn full(n: usize, value: T) -> Array1<T> {
        Array1::from_elem(n, value)
    }
}
impl<T> VectorBuilder<T>
where
    T: Float + Display + Clone + MulAssign,
{
    /// Returns `n` evenly spaced values over the closed interval `[start, stop]`.
    ///
    /// `n == 0` yields an empty vector and `n == 1` yields `[start]`. For
    /// `n >= 2`, element `i` is `start + (stop - start) * i / (n - 1)`, and the
    /// last element is exactly `stop`.
    pub fn linspace(start: T, stop: T, n: usize) -> Array1<T> {
        if n == 0 {
            return Array1::from(vec![]);
        }
        if n == 1 {
            return Array1::from(vec![start]);
        }
        let last = n - 1;
        let steps = T::from(last).unwrap_or_else(T::one);
        let span = stop - start;
        Array1::from_shape_fn(n, |i| {
            if i == last {
                // `start + (stop - start) * 1` can miss `stop` by an ulp after
                // rounding; pin the endpoint, as NumPy's `linspace` does.
                stop
            } else {
                let t = T::from(i).unwrap_or_else(T::zero);
                start + span * (t / steps)
            }
        })
    }

    /// Returns `start, start + step, start + 2*step, …` for the first
    /// `ceil((stop - start) / step)` terms (half-open, in the spirit of NumPy's
    /// `arange`).
    ///
    /// Yields an empty vector when `step` is zero, when `step` points away from
    /// `stop`, when any input is NaN, or when the term count is not
    /// representable as `usize`.
    pub fn arange(start: T, stop: T, step: T) -> Array1<T> {
        if step == T::zero() || (stop - start).signum() != step.signum() {
            return Array1::from(vec![]);
        }
        // Non-finite or out-of-range counts fall back to zero elements.
        let n_float = ((stop - start) / step).ceil();
        let n = n_float.to_usize().unwrap_or(0);
        Array1::from_shape_fn(n, |i| start + step * T::from(i).unwrap_or_else(T::zero))
    }

    /// Returns `10^x` for each `x` in `linspace(start, stop, n)`.
    pub fn logspace(start: T, stop: T, n: usize) -> Array1<T> {
        // The base is loop-invariant; build it once instead of per element.
        let base = T::from(10.0_f64).unwrap_or_else(T::one);
        Self::linspace(start, stop, n).mapv(|x| base.powf(x))
    }

    /// Returns `n` samples drawn uniformly from `[0, 1)`.
    ///
    /// The generator is ChaCha8 seeded from `seed`, so equal seeds produce
    /// equal vectors.
    pub fn rand(n: usize, seed: u64) -> Array1<T> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let one = T::one();
        // Largest binary floating-point value strictly below one (1 - ε/2).
        let below_one = one - T::epsilon() / (one + one);
        Array1::from_shape_fn(n, |_| {
            use rand::RngExt;
            let v: f64 = rng.random();
            let x = T::from(v).unwrap_or_else(T::zero);
            // Narrowing the f64 sample (e.g. to f32) can round values just
            // below 1 up to exactly 1; clamp so the range stays half-open.
            if x < one {
                x
            } else {
                below_one
            }
        })
    }

    /// Returns `n` standard-normal samples (mean 0, standard deviation 1).
    ///
    /// The generator is ChaCha8 seeded from `seed`, so equal seeds produce
    /// equal vectors. Samples `T` cannot represent become `T::zero()`.
    pub fn randn(n: usize, seed: u64) -> Array1<T> {
        use rand::SeedableRng;
        use rand_chacha::ChaCha8Rng;
        use rand_distr::{Distribution, StandardNormal};
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        Array1::from_shape_fn(n, |_| {
            let v: f64 = StandardNormal.sample(&mut rng);
            T::from(v).unwrap_or_else(T::zero)
        })
    }
}
/// Namespace of constructors for n-dimensional ndarray arrays with element
/// type `T` and dimension type `D` (e.g. `Ix2`, `Ix3`, or `IxDyn`).
///
/// The struct holds no data; the `PhantomData` only carries `T` and `D`. Its
/// constructors are associated functions called on the type.
pub struct ArrayBuilder<T, D>(std::marker::PhantomData<(T, D)>);
impl<T, D> ArrayBuilder<T, D>
where
    T: Clone + Zero,
    D: Dimension,
{
    /// Returns an array of the given `shape` filled with `T::zero()`.
    pub fn zeros<Sh>(shape: Sh) -> ::ndarray::Array<T, D>
    where
        Sh: IntoDimension<Dim = D>,
    {
        Self::full(shape, T::zero())
    }

    /// Returns an array of the given `shape` whose every element is a clone
    /// of `value`.
    pub fn full<Sh>(shape: Sh, value: T) -> ::ndarray::Array<T, D>
    where
        Sh: IntoDimension<Dim = D>,
    {
        ::ndarray::Array::<T, D>::from_elem(shape, value)
    }

    /// Builds an array of the given `shape`, computing each element from its
    /// index pattern via `f`.
    pub fn from_fn<Sh, F>(shape: Sh, f: F) -> ::ndarray::Array<T, D>
    where
        Sh: IntoDimension<Dim = D>,
        F: FnMut(D::Pattern) -> T,
    {
        ::ndarray::Array::<T, D>::from_shape_fn(shape, f)
    }

    /// Reshapes `data` into an array of the given `shape`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidInput`] wrapping the ndarray [`ShapeError`]
    /// when `data` cannot be laid out in `shape` (e.g. wrong element count).
    pub fn from_vec<Sh>(data: Vec<T>, shape: Sh) -> CoreResult<::ndarray::Array<T, D>>
    where
        Sh: IntoDimension<Dim = D>,
    {
        let to_core_error = |e: ShapeError| {
            CoreError::InvalidInput(ErrorContext::new(format!(
                "ArrayBuilder::from_vec shape error: {e}"
            )))
        };
        ::ndarray::Array::<T, D>::from_shape_vec(shape, data).map_err(to_core_error)
    }
}
impl<T> ArrayBuilder<T, IxDyn>
where
    T: Clone + Zero,
{
    /// Returns a dynamic-rank array of the given `shape` filled with `T::zero()`.
    pub fn zeros_dyn(shape: &[usize]) -> ArrayD<T> {
        Self::full_dyn(shape, T::zero())
    }

    /// Returns a dynamic-rank array of the given `shape` whose every element
    /// is a clone of `value`.
    pub fn full_dyn(shape: &[usize], value: T) -> ArrayD<T> {
        let dim = IxDyn(shape);
        ArrayD::from_elem(dim, value)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the MatrixBuilder / VectorBuilder / ArrayBuilder constructors.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_matrix_zeros() {
        let m = MatrixBuilder::<f64>::zeros(3, 4);
        assert_eq!(m.shape(), &[3, 4]);
        assert!(m.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_matrix_ones() {
        let m = MatrixBuilder::<f64>::ones(2, 5);
        assert_eq!(m.shape(), &[2, 5]);
        assert!(m.iter().all(|&v| v == 1.0));
    }

    #[test]
    fn test_matrix_eye() {
        let eye = MatrixBuilder::<f64>::eye(3);
        assert_eq!(eye.shape(), &[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(eye[[i, j]], expected);
            }
        }
    }

    #[test]
    fn test_matrix_eye_empty() {
        let eye = MatrixBuilder::<f64>::eye(0);
        assert_eq!(eye.shape(), &[0, 0]);
    }

    #[test]
    fn test_matrix_from_vec() {
        let m = MatrixBuilder::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], 2, 2)
            .expect("element count should match");
        assert_eq!(m[[0, 0]], 1.0);
        assert_eq!(m[[0, 1]], 2.0);
        assert_eq!(m[[1, 0]], 3.0);
        assert_eq!(m[[1, 1]], 4.0);
    }

    #[test]
    fn test_matrix_from_vec_error() {
        let result = MatrixBuilder::<f64>::from_vec(vec![1.0, 2.0, 3.0], 2, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_matrix_from_fn() {
        let m = MatrixBuilder::from_fn(3, 3, |r, c| (r * 3 + c) as f64);
        for r in 0..3 {
            for c in 0..3 {
                assert_abs_diff_eq!(m[[r, c]], (r * 3 + c) as f64);
            }
        }
    }

    #[test]
    fn test_matrix_full() {
        let m = MatrixBuilder::full(3, 3, 42_i32);
        assert!(m.iter().all(|&v| v == 42));
    }

    #[test]
    fn test_matrix_rand() {
        let m = MatrixBuilder::<f64>::rand(10, 10, 99);
        assert_eq!(m.shape(), &[10, 10]);
        assert!(m.iter().all(|&v| v >= 0.0 && v < 1.0));
        let m2 = MatrixBuilder::<f64>::rand(10, 10, 99);
        assert_eq!(m, m2);
    }

    #[test]
    fn test_matrix_randn() {
        let m = MatrixBuilder::<f64>::randn(100, 100, 0);
        let mean = m.mean().expect("non-empty");
        assert!(mean.abs() < 0.5, "mean={mean}");
    }

    #[test]
    fn test_vector_zeros() {
        let v = VectorBuilder::<f64>::zeros(5);
        assert_eq!(v.len(), 5);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_vector_ones() {
        let v = VectorBuilder::<f64>::ones(4);
        assert_eq!(v.len(), 4);
        assert!(v.iter().all(|&x| x == 1.0));
    }

    #[test]
    fn test_vector_from_vec() {
        let v = VectorBuilder::from_vec(vec![10.0_f64, 20.0, 30.0]);
        assert_eq!(v.len(), 3);
        assert_eq!(v[1], 20.0);
    }

    #[test]
    fn test_vector_from_fn() {
        let v = VectorBuilder::from_fn(5, |i| i as f64 * 2.0);
        assert_abs_diff_eq!(v[3], 6.0);
    }

    #[test]
    fn test_vector_full() {
        let v = VectorBuilder::full(4, 1.23_f64);
        assert!(v.iter().all(|&x| (x - 1.23).abs() < 1e-12));
    }

    #[test]
    fn test_vector_linspace() {
        let v = VectorBuilder::<f64>::linspace(0.0, 4.0, 5);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_linspace_descending() {
        let v = VectorBuilder::<f64>::linspace(4.0, 0.0, 5);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, 4.0 - i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_linspace_single() {
        let v = VectorBuilder::<f64>::linspace(3.0, 3.0, 1);
        assert_eq!(v.len(), 1);
        assert_abs_diff_eq!(v[0], 3.0);
    }

    #[test]
    fn test_vector_linspace_empty() {
        let v = VectorBuilder::<f64>::linspace(0.0, 1.0, 0);
        assert_eq!(v.len(), 0);
    }

    #[test]
    fn test_vector_arange() {
        let v = VectorBuilder::<f64>::arange(0.0, 5.0, 1.0);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_arange_fractional() {
        let v = VectorBuilder::<f64>::arange(0.0, 1.0, 0.5);
        assert_eq!(v.len(), 2);
        assert_abs_diff_eq!(v[0], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(v[1], 0.5, epsilon = 1e-12);
    }

    #[test]
    fn test_vector_arange_negative_step() {
        let v = VectorBuilder::<f64>::arange(5.0, 0.0, -1.0);
        assert_eq!(v.len(), 5);
        for (i, &val) in v.iter().enumerate() {
            assert_abs_diff_eq!(val, 5.0 - i as f64, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_vector_arange_empty() {
        let v = VectorBuilder::<f64>::arange(0.0, 5.0, 0.0);
        assert_eq!(v.len(), 0);
        let v2 = VectorBuilder::<f64>::arange(5.0, 0.0, 1.0);
        assert_eq!(v2.len(), 0);
    }

    #[test]
    fn test_vector_logspace() {
        let v = VectorBuilder::<f64>::logspace(0.0, 3.0, 4);
        assert_eq!(v.len(), 4);
        assert_abs_diff_eq!(v[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(v[1], 10.0, epsilon = 1e-8);
        assert_abs_diff_eq!(v[2], 100.0, epsilon = 1e-6);
        assert_abs_diff_eq!(v[3], 1000.0, epsilon = 1e-4);
    }

    #[test]
    fn test_vector_rand() {
        let v = VectorBuilder::<f64>::rand(20, 7);
        assert_eq!(v.len(), 20);
        assert!(v.iter().all(|&x| x >= 0.0 && x < 1.0));
        let v2 = VectorBuilder::<f64>::rand(20, 7);
        assert_eq!(v, v2);
    }

    #[test]
    fn test_vector_randn() {
        let v = VectorBuilder::<f64>::randn(1000, 123);
        assert_eq!(v.len(), 1000);
        let mean = v.mean().expect("non-empty");
        assert!(mean.abs() < 0.2, "mean={mean}");
    }

    #[test]
    fn test_array_builder_zeros_2d() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix2>::zeros(::ndarray::Ix2(3, 4));
        assert_eq!(a.shape(), &[3, 4]);
        assert!(a.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_array_builder_zeros_3d() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix3>::zeros(::ndarray::Ix3(2, 3, 4));
        assert_eq!(a.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_array_builder_zeros_dyn() {
        let a = ArrayBuilder::<f64, ::ndarray::IxDyn>::zeros_dyn(&[2, 3, 4]);
        assert_eq!(a.ndim(), 3);
        assert_eq!(a.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_array_builder_full_dyn() {
        let a = ArrayBuilder::<i32, ::ndarray::IxDyn>::full_dyn(&[2, 2], 5);
        assert_eq!(a.shape(), &[2, 2]);
        assert!(a.iter().all(|&v| v == 5));
    }

    #[test]
    fn test_array_builder_full() {
        let a = ArrayBuilder::<i32, ::ndarray::Ix2>::full(::ndarray::Ix2(3, 3), 7);
        assert!(a.iter().all(|&v| v == 7));
    }

    #[test]
    fn test_array_builder_from_fn() {
        let a = ArrayBuilder::<usize, ::ndarray::Ix2>::from_fn(::ndarray::Ix2(2, 3), |(r, c)| {
            r * 3 + c
        });
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a[[0, 0]], 0);
        assert_eq!(a[[1, 2]], 5);
    }

    #[test]
    fn test_array_builder_from_vec_ok() {
        let a = ArrayBuilder::<f64, ::ndarray::Ix2>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            ::ndarray::Ix2(2, 3),
        )
        .expect("valid shape");
        assert_eq!(a[[1, 2]], 6.0);
    }

    #[test]
    fn test_array_builder_from_vec_err() {
        let result = ArrayBuilder::<f64, ::ndarray::Ix2>::from_vec(
            vec![1.0, 2.0, 3.0],
            ::ndarray::Ix2(2, 3),
        );
        assert!(result.is_err());
    }
}