use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{One, Zero};
use std::ops::Mul;
/// Builds an `n` x `m` matrix with ones at and below the `k`-th diagonal and
/// zeros everywhere else.
///
/// # Arguments
///
/// * `n` - Number of rows.
/// * `m` - Number of columns; `None` means "same as `n`".
/// * `k` - Diagonal offset; `None` means `0` (the main diagonal). Positive
///   values select a diagonal above the main one, negative values one below.
///
/// # Returns
///
/// An array reshaped to `[n, m]` where element `(i, j)` is `T::one()` when
/// `j <= i + k` and `T::zero()` otherwise.
///
/// # Errors
///
/// This function has no failure path of its own and always returns `Ok`.
pub fn tri<T>(n: usize, m: Option<usize>, k: Option<isize>) -> Result<Array<T>>
where
    T: Clone + num_traits::Zero + num_traits::One,
{
    let m = m.unwrap_or(n);
    let k = k.unwrap_or(0);
    let mut data = Vec::with_capacity(n * m);
    for i in 0..n {
        // Saturate instead of overflowing: for a very large positive `k` the
        // plain sum `i + k` would exceed `isize::MAX`, while the saturated
        // value still compares correctly against every column index.
        let last_one = (i as isize).saturating_add(k);
        data.extend((0..m).map(|j| {
            if j as isize <= last_one {
                T::one()
            } else {
                T::zero()
            }
        }));
    }
    Ok(Array::from_vec(data).reshape(&[n, m]))
}
/// Builds a square matrix whose `k`-th diagonal holds the elements of `v`
/// and whose remaining entries are zero.
///
/// # Arguments
///
/// * `v` - Source array; its elements are taken in the order returned by
///   `to_vec()` (presumably flattened order for multi-dimensional input).
/// * `k` - Diagonal offset: `0` is the main diagonal, positive values shift
///   the diagonal to the right (above the main one), negative values shift
///   it down (below the main one).
///
/// # Returns
///
/// An array reshaped to `[size, size]` where `size = len + |k|` and `len` is
/// the number of elements obtained from `v`.
///
/// # Errors
///
/// This function has no failure path of its own and always returns `Ok`.
pub fn diagflat<T>(v: &Array<T>, k: i32) -> Result<Array<T>>
where
    T: Clone + Zero,
{
    let flat_data = v.to_vec();
    let n = flat_data.len();
    // Do the size arithmetic in `usize`: casting `n` down to `i32` would
    // truncate long inputs, and `k.abs()` overflows for `i32::MIN`, whereas
    // `unsigned_abs` is defined for every `i32`.
    let offset = k.unsigned_abs() as usize;
    let size = n + offset;
    // A non-negative offset shifts columns; a negative one shifts rows.
    let (row_shift, col_shift) = if k >= 0 { (0, offset) } else { (offset, 0) };
    let mut result = vec![T::zero(); size * size];
    // `i < n` and `size = n + offset`, so both shifted indices are always
    // below `size`; no extra bounds test is required.
    for (i, value) in flat_data.into_iter().enumerate() {
        result[(i + row_shift) * size + (i + col_shift)] = value;
    }
    Ok(Array::from_vec(result).reshape(&[size, size]))
}
/// Builds a Vandermonde matrix from a one-dimensional array.
///
/// Each row `i` contains successive powers of `x[i]`, from `x[i]^0` up to
/// `x[i]^(n_cols - 1)`.
///
/// # Arguments
///
/// * `x` - One-dimensional input array.
/// * `n` - Number of columns; `None` means "same as the length of `x`".
/// * `increasing` - When `true` the exponents grow from left to right
///   (first column is all ones); when `false` they shrink from left to right
///   (last column is all ones).
///
/// # Returns
///
/// An array reshaped to `[x.len(), n_cols]`. When `n_cols` is `0` the result
/// of `Array::zeros(&[x.len(), 0])` is returned.
///
/// # Errors
///
/// Returns `NumRs2Error::DimensionMismatch` when `x` is not one-dimensional.
pub fn vander<T>(x: &Array<T>, n: Option<usize>, increasing: bool) -> Result<Array<T>>
where
    T: Clone + Zero + One + Mul<Output = T>,
{
    if x.ndim() != 1 {
        return Err(NumRs2Error::DimensionMismatch(
            "vander requires a 1D array".to_string(),
        ));
    }
    let m = x.len();
    let n_cols = n.unwrap_or(m);
    if n_cols == 0 {
        return Ok(Array::zeros(&[m, 0]));
    }
    let x_data = x.to_vec();
    let mut result = vec![T::one(); m * n_cols];
    let last = n_cols - 1;
    // `n_cols` is non-zero here, so chunking by it is valid; each chunk is one
    // output row, paired with the input value it is generated from.
    for (row, base) in result.chunks_exact_mut(n_cols).zip(x_data.iter()) {
        // One running product per row: exponent `e` is written to column `e`
        // (increasing) or to column `last - e` (decreasing), so every power
        // is computed exactly once.
        let mut power = T::one();
        for exponent in 0..n_cols {
            let col = if increasing { exponent } else { last - exponent };
            row[col] = power.clone();
            // Skip the multiplication after the highest exponent so no
            // unused (and possibly overflowing) power is computed.
            if exponent < last {
                power = power * base.clone();
            }
        }
    }
    Ok(Array::from_vec(result).reshape(&[m, n_cols]))
}
/// Unit tests for the matrix-construction helpers in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks shape and contents of `tri` for default, positive, negative and
    /// larger-than-matrix diagonal offsets.
    #[test]
    fn test_tri() {
        // Defaults: square matrix, main diagonal.
        let result: Array<i32> = tri(3, None, None).expect("operation should succeed");
        assert_eq!(result.shape(), vec![3, 3]);
        assert_eq!(result.to_vec(), vec![1, 0, 0, 1, 1, 0, 1, 1, 1]);

        // Rectangular matrix with the boundary one diagonal above the main one.
        let result: Array<f64> = tri(3, Some(4), Some(1)).expect("operation should succeed");
        assert_eq!(result.shape(), vec![3, 4]);
        assert_eq!(
            result.to_vec(),
            vec![1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]
        );

        // Boundary one diagonal below the main one.
        let result: Array<i32> = tri(3, None, Some(-1)).expect("operation should succeed");
        assert_eq!(result.shape(), vec![3, 3]);
        assert_eq!(result.to_vec(), vec![0, 0, 0, 1, 0, 0, 1, 1, 0]);

        // An offset beyond the last column fills the whole matrix with ones.
        let result: Array<i32> = tri(2, Some(3), Some(5)).expect("operation should succeed");
        assert_eq!(result.shape(), vec![2, 3]);
        assert_eq!(result.to_vec(), vec![1, 1, 1, 1, 1, 1]);
    }

    /// Checks shape and contents of `diagflat` on the main diagonal and on
    /// the diagonals directly above and below it.
    #[test]
    fn test_diagflat() {
        let v = Array::from_vec(vec![1.0, 2.0, 3.0]);

        let result = diagflat(&v, 0).expect("operation should succeed");
        assert_eq!(result.shape(), vec![3, 3]);
        assert_eq!(
            result.to_vec(),
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]
        );

        // Positive offset: values sit one column to the right of the main diagonal.
        let result = diagflat(&v, 1).expect("operation should succeed");
        assert_eq!(result.shape(), vec![4, 4]);
        assert_eq!(
            result.to_vec(),
            vec![
                0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0
            ]
        );

        // Negative offset: values sit one row below the main diagonal.
        let result = diagflat(&v, -1).expect("operation should succeed");
        assert_eq!(result.shape(), vec![4, 4]);
        assert_eq!(
            result.to_vec(),
            vec![
                0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0
            ]
        );
    }

    /// Checks `vander` in both exponent orders, with explicit column counts,
    /// and its rejection of non-1D input.
    #[test]
    fn test_vander() {
        let x = Array::from_vec(vec![1.0, 2.0, 3.0]);

        let v = vander(&x, None, true).expect("operation should succeed");
        assert_eq!(v.shape(), vec![3, 3]);
        assert_eq!(
            v.to_vec(),
            vec![1.0, 1.0, 1.0, 1.0, 2.0, 4.0, 1.0, 3.0, 9.0]
        );

        let v = vander(&x, None, false).expect("operation should succeed");
        assert_eq!(v.shape(), vec![3, 3]);
        assert_eq!(
            v.to_vec(),
            vec![1.0, 1.0, 1.0, 4.0, 2.0, 1.0, 9.0, 3.0, 1.0]
        );

        let v = vander(&x, Some(2), false).expect("operation should succeed");
        assert_eq!(v.shape(), vec![3, 2]);
        assert_eq!(v.to_vec(), vec![1.0, 1.0, 2.0, 1.0, 3.0, 1.0]);

        // Explicit column count with increasing exponents.
        let v = vander(&x, Some(2), true).expect("operation should succeed");
        assert_eq!(v.shape(), vec![3, 2]);
        assert_eq!(v.to_vec(), vec![1.0, 1.0, 1.0, 2.0, 1.0, 3.0]);

        // More columns than input values, decreasing exponents.
        let v = vander(&x, Some(4), false).expect("operation should succeed");
        assert_eq!(v.shape(), vec![3, 4]);
        assert_eq!(
            v.to_vec(),
            vec![1.0, 1.0, 1.0, 1.0, 8.0, 4.0, 2.0, 1.0, 27.0, 9.0, 3.0, 1.0]
        );

        // Two-dimensional input must be rejected.
        let x2d = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        assert!(vander(&x2d, None, true).is_err());
    }
}