use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
use std::cmp;
/// Builds a square array from a 1-D input, or extracts a diagonal from a 2-D input.
///
/// * 1-D input of length `n`: returns an `(n + |k|) x (n + |k|)` array created by
///   `Array::zeros`, with the input elements written onto the `k`-th diagonal.
/// * 2-D input: returns the elements of the `k`-th diagonal via `Array::from_vec`,
///   or `Array::zeros(&[0])` when that diagonal has no elements.
///
/// `k` defaults to `0` (the main diagonal). A positive `k` shifts the diagonal
/// towards higher column indices, a negative `k` towards higher row indices.
///
/// Flat indices are computed as `row * cols + col`, i.e. this assumes that
/// `Array::to_vec` / `Array::from_vec` use row-major element order
/// (NOTE(review): confirm against the `Array` implementation).
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if the input is neither 1-D nor 2-D, or if
///   the element count of the 1-D result would overflow `usize`.
/// * `NumRs2Error::DimensionMismatch` if `ndim()` reports 2 but `shape()` does
///   not have exactly two entries.
pub fn diag<T: Clone + Zero>(array: &Array<T>, k: impl Into<Option<isize>>) -> Result<Array<T>> {
    let k = k.into().unwrap_or(0);
    // `unsigned_abs` is well defined for every `isize`, including `isize::MIN`,
    // where negating `k` would overflow.
    let shift = k.unsigned_abs();
    // Coordinates of the first element of the requested diagonal.
    let (row_start, col_start) = if k >= 0 { (0, shift) } else { (shift, 0) };
    let ndim = array.ndim();
    match ndim {
        1 => {
            let size = array.size();
            // Reject offsets for which the side length or the total element
            // count cannot be represented instead of overflowing.
            let diag_size = match size.checked_add(shift) {
                Some(n) if n.checked_mul(n).is_some() => n,
                _ => {
                    return Err(NumRs2Error::InvalidOperation(format!(
                        "Diagonal offset {} is too large for an input of length {}",
                        k, size
                    )))
                }
            };
            let mut result_vec = Array::<T>::zeros(&[diag_size, diag_size]).to_vec();
            // Element `i` lands at (row_start + i, col_start + i); both coordinates
            // are below `diag_size == size + shift`, so the flat index is in range.
            for (i, value) in array.to_vec().into_iter().take(size).enumerate() {
                result_vec[(row_start + i) * diag_size + col_start + i] = value;
            }
            Ok(Array::from_vec(result_vec).reshape(&[diag_size, diag_size]))
        }
        2 => {
            let shape = array.shape();
            // Defensive consistency check between `ndim()` and `shape()`.
            if shape.len() != 2 {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Expected a 2D array, got shape {:?}",
                    shape
                )));
            }
            let rows = shape[0];
            let cols = shape[1];
            // Number of in-bounds elements on the diagonal starting at
            // (row_start, col_start); zero when the start lies outside the array.
            let diag_len = cmp::min(
                rows.saturating_sub(row_start),
                cols.saturating_sub(col_start),
            );
            if diag_len == 0 {
                return Ok(Array::zeros(&[0]));
            }
            let array_vec = array.to_vec();
            // Consecutive diagonal elements are `cols + 1` apart in the flat buffer.
            let first = row_start * cols + col_start;
            let result: Vec<T> = array_vec
                .iter()
                .skip(first)
                .step_by(cols + 1)
                .take(diag_len)
                .cloned()
                .collect();
            Ok(Array::from_vec(result))
        }
        _ => Err(NumRs2Error::InvalidOperation(format!(
            "Input must be 1D or 2D array, got {}D array",
            ndim
        ))),
    }
}
/// Extracts the diagonals of the 2-D planes spanned by `axis1` and `axis2`.
///
/// `offset` defaults to `0`, `axis1` to `0` and `axis2` to `1`. A positive
/// `offset` starts the diagonal at index `offset` along `axis2`; a negative
/// one starts it at index `|offset|` along `axis1`.
///
/// The result shape is the input shape with `axis1` and `axis2` removed and
/// the diagonal length appended as the last dimension. When the diagonal is
/// empty the result comes from `Array::zeros` with that shape.
///
/// Flat indices are computed from row-major strides derived from `shape()`,
/// i.e. this assumes `Array::to_vec` yields elements in row-major order
/// (NOTE(review): confirm against the `Array` implementation).
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if the array has fewer than two
///   dimensions or if `axis1 == axis2`.
/// * `NumRs2Error::DimensionMismatch` if either axis is out of bounds.
pub fn diagonal<T: Clone + num_traits::Zero>(
    array: &Array<T>,
    offset: impl Into<Option<isize>>,
    axis1: impl Into<Option<usize>>,
    axis2: impl Into<Option<usize>>,
) -> Result<Array<T>> {
    let offset = offset.into().unwrap_or(0);
    let axis1 = axis1.into().unwrap_or(0);
    let axis2 = axis2.into().unwrap_or(1);
    let ndim = array.ndim();
    if ndim < 2 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Array must be at least 2D, got {}D array",
            ndim
        )));
    }
    if axis1 == axis2 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "axis1 and axis2 cannot be the same: {}",
            axis1
        )));
    }
    if axis1 >= ndim || axis2 >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axes ({}, {}) out of bounds for array of dimension {}",
            axis1, axis2, ndim
        )));
    }

    let shape = array.shape();
    // `unsigned_abs` is well defined for every `isize`, including `isize::MIN`,
    // where negating `offset` would overflow.
    let shift = offset.unsigned_abs();
    // Coordinates (along axis1, axis2) of the first diagonal element.
    let (row_start, col_start) = if offset >= 0 { (0, shift) } else { (shift, 0) };
    let diag_len = cmp::min(
        shape[axis1].saturating_sub(row_start),
        shape[axis2].saturating_sub(col_start),
    );

    // Axes that survive in the output, in their original order.
    let outer_axes: Vec<usize> = (0..ndim)
        .filter(|&axis| axis != axis1 && axis != axis2)
        .collect();
    let mut result_shape: Vec<usize> = outer_axes.iter().map(|&axis| shape[axis]).collect();
    result_shape.push(diag_len);
    if diag_len == 0 {
        return Ok(Array::zeros(&result_shape));
    }

    // Row-major strides: the last axis is contiguous.
    let mut strides = vec![1usize; ndim];
    for axis in (0..ndim - 1).rev() {
        strides[axis] = strides[axis + 1] * shape[axis + 1];
    }
    // Offset of the first diagonal element within a plane, and the distance
    // between consecutive diagonal elements.
    let diag_start = row_start * strides[axis1] + col_start * strides[axis2];
    let diag_step = strides[axis1] + strides[axis2];

    let outer_count: usize = outer_axes.iter().map(|&axis| shape[axis]).product();
    let array_vec = array.to_vec();
    let mut result_vec = Vec::with_capacity(outer_count * diag_len);
    // Odometer over the surviving axes; the last surviving axis varies fastest.
    let mut counter = vec![0usize; outer_axes.len()];
    for _ in 0..outer_count {
        let base_idx: usize = outer_axes
            .iter()
            .zip(counter.iter())
            .map(|(&axis, &position)| position * strides[axis])
            .sum();
        let first = base_idx + diag_start;
        result_vec.extend((0..diag_len).map(|i| array_vec[first + i * diag_step].clone()));

        // Advance to the next combination of outer indices, carrying leftwards.
        for (position, &axis) in counter.iter_mut().zip(outer_axes.iter()).rev() {
            *position += 1;
            if *position < shape[axis] {
                break;
            }
            *position = 0;
        }
    }
    Ok(Array::from_vec(result_vec).reshape(&result_shape))
}
/// Overwrites the main diagonal of `array` in place with clones of `val`.
///
/// * 2-D arrays, `wrap == false`: sets element `(i, i)` for every
///   `i < min(rows, cols)`.
/// * 2-D arrays, `wrap == true`: sets element `(i, i % cols)` for every row
///   `i`, so the diagonal restarts at column 0 every `cols` rows.
///   NOTE(review): NumPy's `fill_diagonal(..., wrap=True)` leaves one blank
///   row between repetitions; confirm which layout is intended here.
/// * Arrays with more than two dimensions: sets the elements whose indices
///   are all equal, `[i, i, ..., i]`, for every `i` below the smallest
///   dimension. `wrap` is ignored in this case.
///
/// Flat indices are computed from row-major strides derived from `shape()`,
/// i.e. this assumes the mutable slice is in row-major order
/// (NOTE(review): confirm against the `Array` implementation).
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if the array has fewer than two
/// dimensions or if its storage cannot be borrowed as a mutable slice.
pub fn fill_diagonal<T: Clone>(array: &mut Array<T>, val: T, wrap: bool) -> Result<()> {
    let ndim = array.ndim();
    if ndim < 2 {
        return Err(NumRs2Error::InvalidOperation(
            "Array must be at least 2D".to_string(),
        ));
    }
    let shape = array.shape();
    let array_data = array
        .array_mut()
        .as_slice_mut()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to get mutable slice".into()))?;

    if ndim == 2 {
        let n_rows = shape[0];
        let n_cols = shape[1];
        // No columns means no diagonal elements; returning here also avoids
        // the remainder-by-zero that `row % n_cols` would hit when wrapping.
        if n_cols == 0 {
            return Ok(());
        }
        let diag_len = if wrap { n_rows } else { n_rows.min(n_cols) };
        for row in 0..diag_len {
            // Without wrapping `row < n_cols` already holds, so the column is
            // in range in both modes.
            let col = if wrap { row % n_cols } else { row };
            array_data[row * n_cols + col] = val.clone();
        }
    } else {
        let min_dim = shape.iter().min().copied().unwrap_or(0);
        // Row-major strides: the last axis is contiguous.
        let mut strides = vec![1usize; ndim];
        for axis in (0..ndim - 1).rev() {
            strides[axis] = strides[axis + 1] * shape[axis + 1];
        }
        // Moving one step along every axis at once advances the flat index by
        // the sum of all strides.
        let diag_step: usize = strides.iter().sum();
        for i in 0..min_dim {
            array_data[i * diag_step] = val.clone();
        }
    }
    Ok(())
}