use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
/// Returns a copy of `array` in which every element above the `k`-th diagonal is zero.
///
/// The last two axes are treated as the rows and columns of a matrix; all
/// leading axes are treated as a batch and each matrix is processed on its own.
/// Within a matrix, element `(i, j)` is kept when `j <= i + k` and replaced by
/// `T::zero()` otherwise, so `k = 0` keeps the main diagonal, positive `k`
/// keeps additional super-diagonals and negative `k` clears sub-diagonals too.
///
/// Arrays with fewer than two dimensions are returned as an unmodified clone.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` when the storage of the cloned array
/// cannot be borrowed as one mutable slice.
pub fn tril<T: Clone + Zero>(array: &Array<T>, k: isize) -> Result<Array<T>> {
    let shape = array.shape();
    let ndim = shape.len();
    if ndim < 2 {
        return Ok(array.clone());
    }
    let n_rows = shape[ndim - 2];
    let n_cols = shape[ndim - 1];

    let mut result = array.clone();
    let result_data = result
        .array_mut()
        .as_slice_mut()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to get mutable slice".into()))?;

    // `chunks_exact_mut` rejects a chunk size of zero; a zero-width matrix has
    // nothing to clear anyway.
    if n_cols > 0 {
        // The slice is a sequence of rows of `n_cols` elements; rows of
        // successive batch matrices follow each other, so the row index inside
        // its matrix is the global row index modulo `n_rows`.
        for (row_index, row) in result_data.chunks_exact_mut(n_cols).enumerate() {
            let i = row_index % n_rows;
            // Saturate instead of overflowing when `k` is close to `isize::MAX`;
            // the saturated value still compares correctly against any column.
            let last_kept = (i as isize).saturating_add(k);
            let first_cleared = if last_kept < 0 {
                0
            } else {
                (last_kept as usize).saturating_add(1).min(n_cols)
            };
            row[first_cleared..].fill(T::zero());
        }
    }
    Ok(result)
}
/// Returns a copy of `array` in which every element below the `k`-th diagonal is zero.
///
/// The last two axes are treated as the rows and columns of a matrix; all
/// leading axes are treated as a batch and each matrix is processed on its own.
/// Within a matrix, element `(i, j)` is kept when `j >= i + k` and replaced by
/// `T::zero()` otherwise, so `k = 0` keeps the main diagonal, positive `k`
/// clears it (and further super-diagonals) and negative `k` keeps additional
/// sub-diagonals.
///
/// Arrays with fewer than two dimensions are returned as an unmodified clone.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` when the storage of the cloned array
/// cannot be borrowed as one mutable slice.
pub fn triu<T: Clone + Zero>(array: &Array<T>, k: isize) -> Result<Array<T>> {
    let shape = array.shape();
    let ndim = shape.len();
    if ndim < 2 {
        return Ok(array.clone());
    }
    let n_rows = shape[ndim - 2];
    let n_cols = shape[ndim - 1];

    let mut result = array.clone();
    let result_data = result
        .array_mut()
        .as_slice_mut()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to get mutable slice".into()))?;

    // `chunks_exact_mut` rejects a chunk size of zero; a zero-width matrix has
    // nothing to clear anyway.
    if n_cols > 0 {
        // The slice is a sequence of rows of `n_cols` elements; rows of
        // successive batch matrices follow each other, so the row index inside
        // its matrix is the global row index modulo `n_rows`.
        for (row_index, row) in result_data.chunks_exact_mut(n_cols).enumerate() {
            let i = row_index % n_rows;
            // Saturate instead of overflowing when `k` is close to `isize::MAX`;
            // the saturated value still exceeds every valid column index.
            let first_kept = (i as isize).saturating_add(k);
            let cleared_len = if first_kept <= 0 {
                0
            } else {
                (first_kept as usize).min(n_cols)
            };
            row[..cleared_len].fill(T::zero());
        }
    }
    Ok(result)
}
/// Appends `values` to `array`, either flattened or along an existing axis.
///
/// * `axis == None`: the result is built with `Array::from_vec` from the
///   elements of `array.to_vec()` followed by `values`.
/// * `axis == Some(ax)`: `values` is read as a flat block with the same extents
///   as `array` on every axis except `ax`; its extent along `ax` is inferred
///   from `values.len()`. The result has the shape of `array` with the extent
///   of `ax` grown by that amount.
///
/// The flat data returned by `array.to_vec()` and the data in `values` are both
/// indexed in row-major order (last axis varying fastest).
/// NOTE(review): this relies on `Array::to_vec`, `Array::from_vec` and
/// `Array::reshape` using row-major element order — confirm against `Array`.
///
/// When every slice orthogonal to `ax` is empty (some other axis has extent 0)
/// an empty `values` appends nothing and the result keeps the shape of `array`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` when `ax` is not a valid axis of `array`.
/// * `NumRs2Error::ShapeMismatch` when `values.len()` is not a whole multiple of
///   the number of elements in one slice orthogonal to `ax`.
pub fn append<T: Clone + Zero>(
    array: &Array<T>,
    values: &[T],
    axis: Option<usize>,
) -> Result<Array<T>> {
    let Some(ax) = axis else {
        let mut flat = array.to_vec();
        flat.extend_from_slice(values);
        return Ok(Array::from_vec(flat));
    };

    let shape = array.shape();
    if ax >= shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "axis {} is out of bounds for array of dimension {}",
            ax,
            shape.len()
        )));
    }

    let axis_size = shape[ax];
    let pre_axis_size: usize = shape[..ax].iter().product();
    let post_axis_size: usize = shape[ax + 1..].iter().product();
    // Number of elements in one slice taken orthogonally to `ax`.
    let expected_values_size = pre_axis_size * post_axis_size;

    if !values.len().is_multiple_of(expected_values_size) {
        return Err(NumRs2Error::ShapeMismatch {
            expected: vec![expected_values_size],
            actual: vec![values.len()],
        });
    }
    // `0.is_multiple_of(0)` is true, so an empty `values` gets here even when
    // the slice size is zero; treat that as zero appended slices instead of
    // dividing by zero.
    let values_axis_size = values
        .len()
        .checked_div(expected_values_size)
        .unwrap_or(0);

    let mut result_shape = shape.clone();
    result_shape[ax] = axis_size + values_axis_size;

    // For a fixed index over the leading axes, the elements along `ax` and all
    // trailing axes form one contiguous run in the flat data.
    let array_run = axis_size * post_axis_size;
    let values_run = values_axis_size * post_axis_size;

    let array_data = array.to_vec();
    let mut result_data = Vec::with_capacity(pre_axis_size * (array_run + values_run));
    for pre in 0..pre_axis_size {
        result_data.extend_from_slice(&array_data[pre * array_run..(pre + 1) * array_run]);
        result_data.extend_from_slice(&values[pre * values_run..(pre + 1) * values_run]);
    }

    Ok(Array::from_vec(result_data).reshape(&result_shape))
}