use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::{Array as NdArray, ArrayView, IxDyn};
use std::fmt;
/// Return type of the least-squares solver: a 4-tuple wrapped in the crate `Result`.
///
/// Tuple slots, in order:
/// 0. an array — presumably the least-squares solution (TODO confirm against the solver);
/// 1. the residuals, stored with the same element type as the matrix;
/// 2. a `usize` — presumably the effective rank of the matrix (TODO confirm);
/// 3. the singular values, stored with the same element type as the matrix.
#[cfg(all(feature = "matrix_decomp", feature = "lapack"))]
pub type LstsqResult<T> = Result<(super::Array<T>, super::Array<T>, usize, super::Array<T>)>;
/// Layout and ownership flags reported by [`Array::flags`].
///
/// The field names mirror those of NumPy's `ndarray.flags`. The struct is plain
/// data (five `bool`s), so it is `Copy` and can be compared and hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ArrayFlags {
    /// Elements are laid out contiguously in row-major (C) order.
    pub c_contiguous: bool,
    /// Elements are laid out contiguously in column-major (Fortran) order.
    pub f_contiguous: bool,
    /// The array's elements may be modified.
    pub writeable: bool,
    /// The element buffer is aligned for the element type.
    pub aligned: bool,
    /// The array owns its element buffer rather than borrowing it.
    pub owndata: bool,
}
/// An n-dimensional array whose rank is chosen at runtime.
///
/// A thin wrapper over an owned ndarray array with dynamic dimensionality
/// (`IxDyn`); cloning the wrapper clones the underlying elements.
#[derive(Clone)]
pub struct Array<T> {
    /// Owned backing storage; visible crate-wide via `pub(crate)`.
    pub(crate) data: NdArray<T, IxDyn>,
}
impl<T: Clone> Array<T> {
    /// Wraps an existing dynamically-dimensioned ndarray without copying it.
    pub fn from_ndarray(array: NdArray<T, IxDyn>) -> Self {
        Self { data: array }
    }

    /// Borrows the underlying ndarray.
    pub fn array(&self) -> &NdArray<T, IxDyn> {
        &self.data
    }

    /// Returns the per-axis strides in bytes (element stride × `size_of::<T>()`).
    ///
    /// A negative element stride cannot be represented in `usize`; it is
    /// returned in two's-complement form, so casting the entry back with
    /// `as isize` recovers the signed byte stride.
    pub fn byte_strides(&self) -> Vec<usize> {
        let elem_size = std::mem::size_of::<T>();
        self.data
            .strides()
            .iter()
            // `wrapping_mul` produces the two's-complement encoding for negative
            // strides instead of panicking on overflow in debug builds.
            .map(|&s| (s as usize).wrapping_mul(elem_size))
            .collect()
    }

    /// Mutably borrows the underlying ndarray.
    pub fn array_mut(&mut self) -> &mut NdArray<T, IxDyn> {
        &mut self.data
    }

    /// Overwrites the element at the multi-dimensional position `indices`.
    ///
    /// # Errors
    /// - `DimensionMismatch` if `indices.len()` differs from the array's rank.
    /// - `IndexOutOfBounds` if any index is not smaller than its axis length.
    pub fn set(&mut self, indices: &[usize], value: T) -> Result<()> {
        let ndim = self.ndim();
        if indices.len() != ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Expected {} indices, got {}",
                ndim,
                indices.len()
            )));
        }
        // Validate against the borrowed shape slice; no per-axis allocation.
        for (axis, (&idx, &len)) in indices.iter().zip(self.data.shape()).enumerate() {
            if idx >= len {
                return Err(NumRs2Error::IndexOutOfBounds(format!(
                    "Index {} out of bounds for dimension {} with size {}",
                    idx, axis, len
                )));
            }
        }
        match self.data.get_mut(indices) {
            Some(elem) => {
                *elem = value;
                Ok(())
            }
            // Defensive: the checks above should make this unreachable.
            None => Err(NumRs2Error::IndexOutOfBounds(format!(
                "Failed to set element at indices {:?}",
                indices
            ))),
        }
    }

    /// Returns the axis lengths as an owned vector.
    pub fn shape(&self) -> Vec<usize> {
        self.data.shape().to_vec()
    }

    /// Returns the number of axes.
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

    /// Returns the total number of elements.
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Returns `size() * size_of::<T>()`; heap memory owned by the elements
    /// themselves is not counted.
    pub fn nbytes(&self) -> usize {
        self.size() * std::mem::size_of::<T>()
    }

    /// Returns the size of one element in bytes (`size_of::<T>()`).
    pub fn itemsize(&self) -> usize {
        std::mem::size_of::<T>()
    }

    /// Always `true`: the wrapper stores an owned ndarray.
    pub fn owns_data(&self) -> bool {
        true
    }

    /// Reports layout and ownership flags for this array.
    pub fn flags(&self) -> ArrayFlags {
        ArrayFlags {
            c_contiguous: self.is_c_contiguous(),
            f_contiguous: self.is_f_contiguous(),
            writeable: true,
            aligned: true,
            owndata: self.owns_data(),
        }
    }

    /// Returns the per-axis strides in elements (may be negative).
    pub fn strides(&self) -> Vec<isize> {
        self.data.strides().to_vec()
    }

    /// Always `None`; this type never records a base array.
    pub fn base(&self) -> Option<&Array<T>> {
        None
    }

    /// Copies the elements into a `Vec` in logical row-major (C) order.
    ///
    /// The result has exactly `size()` elements and `to_vec()[i]` equals
    /// `get_flat(i)`, regardless of the array's memory layout.
    pub fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        // Iterate logically rather than dumping the raw buffer: the raw buffer
        // is in memory order and may hold elements outside the logical array.
        self.data.iter().cloned().collect()
    }

    /// Alias for [`Self::size`].
    pub fn len(&self) -> usize {
        self.size()
    }

    /// Returns `true` when the array has no elements.
    pub fn is_empty(&self) -> bool {
        self.size() == 0
    }

    /// Returns `true` if the array is in standard (row-major, C) layout.
    pub fn is_c_contiguous(&self) -> bool {
        self.data.is_standard_layout()
    }

    /// Returns a clone of the element at `index` in row-major (C) order.
    ///
    /// # Errors
    /// `IndexOutOfBounds` if `index >= size()`.
    pub fn get_flat(&self, index: usize) -> Result<T>
    where
        T: Clone,
    {
        let size = self.size();
        if index >= size {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "Flat index {} out of bounds for array of size {}",
                index, size
            )));
        }
        // Decompose `index` in mixed radix, last axis varying fastest. No axis
        // has length 0 here, since that would make `size` 0 and fail above.
        let shape = self.data.shape();
        let mut indices = vec![0usize; shape.len()];
        let mut remainder = index;
        for (slot, &len) in indices.iter_mut().zip(shape).rev() {
            *slot = remainder % len;
            remainder /= len;
        }
        self.data.get(&indices[..]).cloned().ok_or_else(|| {
            NumRs2Error::IndexOutOfBounds(format!("Failed to get element at flat index {}", index))
        })
    }

    /// Returns `true` if the array is contiguous in column-major (Fortran) order.
    ///
    /// Mirrors [`Self::is_c_contiguous`]: empty arrays count as contiguous, and
    /// the strides of length-1 axes are ignored.
    pub fn is_f_contiguous(&self) -> bool {
        // An array is F-contiguous exactly when its axis-reversed view is in
        // standard (C) layout.
        self.data.t().is_standard_layout()
    }

    /// Returns `true` if the array is C- or F-contiguous.
    pub fn is_contiguous(&self) -> bool {
        self.is_c_contiguous() || self.is_f_contiguous()
    }

    /// Returns a copy in standard (row-major, C) layout with the same shape.
    pub fn to_c_layout(&self) -> Self {
        if self.is_c_contiguous() {
            self.clone()
        } else {
            Self {
                data: self.data.as_standard_layout().into_owned(),
            }
        }
    }

    /// Returns a copy with the same shape and elements, stored in
    /// column-major (Fortran) order.
    pub fn to_f_layout(&self) -> Self {
        if self.is_f_contiguous() {
            return self.clone();
        }
        // Copy the transposed view into C order, then reverse the axes back:
        // the logical shape and element positions are restored while the
        // memory now runs column-major.
        let data = self
            .data
            .t()
            .as_standard_layout()
            .into_owned()
            .reversed_axes();
        Self { data }
    }

    /// Returns a read-only ndarray view of the elements.
    pub fn ndarray_view(&self) -> ArrayView<'_, T, IxDyn> {
        self.data.view()
    }

    /// Returns a mutable reference to this wrapper itself.
    ///
    /// Despite its name, this is not an ndarray view; use [`Self::array_mut`]
    /// for mutable access to the underlying ndarray.
    pub fn ndarray_view_mut(&mut self) -> &mut Self
    where
        T: Clone,
    {
        self
    }
}
impl<T: fmt::Display> fmt::Display for Array<T> {
    /// Formats the array via the underlying ndarray's `Display`.
    ///
    /// The caller's formatter is passed through unchanged, so options such as
    /// precision (`{:.2}`) and the alternate flag (`{:#}`) reach ndarray.
    /// Re-formatting with `write!(f, "{}", ..)` would discard them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.data, f)
    }
}
impl<T: fmt::Debug> fmt::Debug for Array<T> {
    /// Prints `Array { shape: [..], data: .. }`.
    ///
    /// The shape is read straight from the backing ndarray as a slice, so no
    /// `Clone` bound on `T` or temporary `Vec` is needed. A slice and a `Vec`
    /// have identical `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Array")
            .field("shape", &self.data.shape())
            .field("data", &self.data)
            .finish()
    }
}