use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::{
ArrayView as NdArrayView, ArrayViewMut as NdArrayViewMut, Axis, IxDyn, Slice, SliceInfo,
SliceInfoElem,
};
use std::ops::{Add, Div, Mul, Sub};
/// Selector for one axis of an array: either a single position or a stepped range.
#[derive(Debug, Clone)]
pub enum SliceOrIndex {
/// A single position along the axis; `to_ndarray_slice` turns it into the
/// one-element range `idx..idx + 1`.
Index(usize),
/// A range given as `(start, end, step)`; `end` and `step` are optional and
/// `to_ndarray_slice` uses a step of 1 when none is given.
Slice(usize, Option<usize>, Option<usize>), }
impl SliceOrIndex {
    /// Converts this selector into an ndarray [`Slice`].
    ///
    /// * `Index(i)` becomes the one-element range `i..i + 1`.
    /// * `Slice(start, end, step)` becomes a stepped range. A missing `end`
    ///   extends the slice to the end of the axis and a missing `step`
    ///   defaults to 1.
    ///
    /// A `step` of `Some(0)` is forwarded unchanged; ndarray does not accept a
    /// zero step, so callers must not pass one.
    pub fn to_ndarray_slice(&self) -> Slice {
        match *self {
            SliceOrIndex::Index(idx) => Slice::from(idx..idx + 1),
            SliceOrIndex::Slice(start, end, step) => {
                // An open end has to stay `None`. Encoding it as `usize::MAX`
                // and casting to `isize` yields -1, which ndarray interprets
                // as "one before the last element" and silently drops the
                // final element of the axis.
                let end = end.map(|e| e as isize);
                let step = step.map_or(1, |s| s as isize);
                Slice::new(start as isize, end, step)
            }
        }
    }
}
/// Read-only view over array data, wrapping a dynamic-dimension ndarray view.
pub struct ArrayView<'a, T> {
// The wrapped ndarray view with a dynamic number of dimensions (`IxDyn`).
data: NdArrayView<'a, T, IxDyn>,
}
/// Mutable view over array data, wrapping a dynamic-dimension ndarray mutable view.
pub struct ArrayViewMut<'a, T> {
// The wrapped mutable ndarray view with a dynamic number of dimensions (`IxDyn`).
data: NdArrayViewMut<'a, T, IxDyn>,
}
impl<'a, T: 'a> ArrayView<'a, T> {
    /// Wraps an existing dynamic-dimension ndarray view.
    pub fn from_ndarray_view(view: NdArrayView<'a, T, IxDyn>) -> Self {
        Self { data: view }
    }

    /// Returns a reference to the wrapped ndarray view.
    pub fn view(&self) -> &NdArrayView<'a, T, IxDyn> {
        &self.data
    }

    /// Returns the length of every axis as a newly allocated vector.
    pub fn shape(&self) -> Vec<usize> {
        self.data.shape().to_vec()
    }

    /// Returns the number of axes.
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

    /// Returns the total number of elements in the view.
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Copies the viewed elements into an owned [`Array`].
    pub fn to_owned(&self) -> Array<T>
    where
        T: Clone,
    {
        Array::from_ndarray(self.data.to_owned())
    }

    /// Returns a view restricted to `indices` along `axis`.
    ///
    /// The result borrows from `self` only for as long as `self` is borrowed.
    /// It does not require `self` to be borrowed for the full data lifetime
    /// `'a`, so this can be called on a view received by value.
    pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, T> {
        ArrayView::from_ndarray_view(self.data.slice_axis(axis, indices))
    }

    /// Returns a transposed view (ndarray's `t()`), without copying data.
    ///
    /// As with [`slice_axis`](Self::slice_axis), the result is tied to the
    /// borrow of `self` rather than to `'a`.
    pub fn t(&self) -> ArrayView<'_, T> {
        ArrayView::from_ndarray_view(self.data.t().into_dyn())
    }

    /// Clones every element, in the view's iteration order, into a vector.
    pub fn to_vec(&self) -> Vec<T>
    where
        T: Clone,
    {
        self.data.iter().cloned().collect()
    }

    /// Returns a clone of the element at `indices`.
    ///
    /// # Errors
    /// * `DimensionMismatch` if `indices.len()` differs from [`ndim`](Self::ndim).
    /// * `IndexOutOfBounds` if the lookup in the underlying view fails.
    pub fn get(&self, indices: &[usize]) -> Result<T>
    where
        T: Clone,
    {
        if indices.len() != self.ndim() {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }
        self.data.get(indices).cloned().ok_or_else(|| {
            NumRs2Error::IndexOutOfBounds(format!(
                "Indices {:?} out of bounds for shape {:?}",
                indices,
                self.shape()
            ))
        })
    }

    /// Applies `f` to every element and collects the results into a new [`Array`].
    pub fn map<F, U>(&self, f: F) -> Array<U>
    where
        F: Fn(&T) -> U,
        U: Clone,
    {
        Array::from_ndarray(self.data.map(f).into_dyn())
    }
}
/// Element-wise addition of two views; the result is a newly allocated [`Array`].
impl<'a, T> Add for &ArrayView<'a, T>
where
    T: 'a + Clone + Add<Output = T>,
{
    type Output = Array<T>;

    /// Adds `rhs` to `self` using ndarray's `+` on the two wrapped views.
    fn add(self, rhs: &ArrayView<'a, T>) -> Self::Output {
        let (lhs_data, rhs_data) = (&self.data, &rhs.data);
        Array::from_ndarray((lhs_data + rhs_data).into_dyn())
    }
}
/// Element-wise subtraction of two views; the result is a newly allocated [`Array`].
impl<'a, T> Sub for &ArrayView<'a, T>
where
    T: 'a + Clone + Sub<Output = T>,
{
    type Output = Array<T>;

    /// Subtracts `rhs` from `self` using ndarray's `-` on the two wrapped views.
    fn sub(self, rhs: &ArrayView<'a, T>) -> Self::Output {
        let (lhs_data, rhs_data) = (&self.data, &rhs.data);
        Array::from_ndarray((lhs_data - rhs_data).into_dyn())
    }
}
/// Element-wise multiplication of two views; the result is a newly allocated [`Array`].
impl<'a, T> Mul for &ArrayView<'a, T>
where
    T: 'a + Clone + Mul<Output = T>,
{
    type Output = Array<T>;

    /// Multiplies `self` by `rhs` using ndarray's `*` on the two wrapped views.
    fn mul(self, rhs: &ArrayView<'a, T>) -> Self::Output {
        let (lhs_data, rhs_data) = (&self.data, &rhs.data);
        Array::from_ndarray((lhs_data * rhs_data).into_dyn())
    }
}
/// Element-wise division of two views; the result is a newly allocated [`Array`].
impl<'a, T> Div for &ArrayView<'a, T>
where
    T: 'a + Clone + Div<Output = T>,
{
    type Output = Array<T>;

    /// Divides `self` by `rhs` using ndarray's `/` on the two wrapped views.
    fn div(self, rhs: &ArrayView<'a, T>) -> Self::Output {
        let (lhs_data, rhs_data) = (&self.data, &rhs.data);
        Array::from_ndarray((lhs_data / rhs_data).into_dyn())
    }
}
impl<'a, T: 'a> ArrayViewMut<'a, T> {
    /// Wraps an existing dynamic-dimension mutable ndarray view.
    pub fn from_ndarray_view_mut(view: NdArrayViewMut<'a, T, IxDyn>) -> Self {
        Self { data: view }
    }

    /// Returns a shared reference to the wrapped mutable ndarray view.
    pub fn view_mut(&self) -> &NdArrayViewMut<'a, T, IxDyn> {
        &self.data
    }

    /// Returns the length of every axis as a newly allocated vector.
    pub fn shape(&self) -> Vec<usize> {
        self.data.shape().to_vec()
    }

    /// Returns the number of axes.
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

    /// Returns the total number of elements in the view.
    pub fn size(&self) -> usize {
        self.data.len()
    }

    /// Copies the viewed elements into an owned [`Array`].
    pub fn to_owned(&self) -> Array<T>
    where
        T: Clone,
    {
        Array::from_ndarray(self.data.to_owned())
    }

    /// Returns a mutable reference to the element at `indices`.
    ///
    /// # Errors
    /// * `DimensionMismatch` if `indices.len()` differs from [`ndim`](Self::ndim).
    /// * `IndexOutOfBounds` if the lookup in the underlying view fails.
    pub fn get_mut(&mut self, indices: &[usize]) -> Result<&mut T> {
        if indices.len() != self.ndim() {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }
        // The shape is captured before the mutable lookup so that the error
        // arm does not have to touch `self` while the lookup borrow is live.
        let shape = self.shape();
        match self.data.get_mut(indices) {
            Some(value) => Ok(value),
            None => Err(NumRs2Error::IndexOutOfBounds(format!(
                "Indices {:?} out of bounds for shape {:?}",
                indices, shape
            ))),
        }
    }

    /// Overwrites the element at `indices` with `value`.
    ///
    /// # Errors
    /// * `DimensionMismatch` if `indices.len()` differs from [`ndim`](Self::ndim),
    ///   matching the check done by [`get_mut`](Self::get_mut).
    /// * `IndexOutOfBounds` if the lookup in the underlying view fails.
    pub fn set(&mut self, indices: &[usize], value: T) -> Result<()>
    where
        T: Clone,
    {
        // Report a wrong number of indices as such instead of letting it
        // surface as a generic out-of-bounds failure.
        if indices.len() != self.ndim() {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Expected {} indices, got {}",
                self.ndim(),
                indices.len()
            )));
        }
        match self.data.get_mut(indices) {
            Some(slot) => {
                *slot = value;
                Ok(())
            }
            None => Err(NumRs2Error::IndexOutOfBounds(format!(
                "Failed to set element at indices {:?}",
                indices
            ))),
        }
    }

    /// Returns a mutable view restricted to `indices` along `axis`.
    ///
    /// The result mutably borrows `self` for as long as it is alive.
    pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, T> {
        ArrayViewMut::from_ndarray_view_mut(self.data.slice_axis_mut(axis, indices))
    }
}
/// View over a flat slice that addresses elements through a shape, per-axis
/// strides (in elements, may be negative) and a starting offset.
#[derive(Debug)]
pub struct StridedArrayView<'a, T> {
// Flat backing storage the view indexes into.
data: &'a [T],
// Extent of each axis of the view.
shape: Vec<usize>,
// Signed element step per axis, multiplied by the index along that axis.
strides: Vec<isize>,
// Flat position in `data` of the element whose indices are all zero.
offset: usize,
}
impl<'a, T: Clone> StridedArrayView<'a, T> {
    /// Creates a view over `data` with the given `shape`, `strides` and `offset`.
    ///
    /// No validation is performed here: the arguments are stored as given.
    pub fn new(data: &'a [T], shape: Vec<usize>, strides: Vec<isize>, offset: usize) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// Returns the extent of each axis.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the number of axes (the length of the shape).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the number of elements, i.e. the product of the shape.
    pub fn size(&self) -> usize {
        self.shape.iter().product()
    }

    /// Returns the per-axis strides, in elements.
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }

    /// Returns the element at `indices`, or `None` if it cannot be addressed.
    ///
    /// `None` is returned when the number of indices differs from
    /// [`ndim`](Self::ndim), an index is outside its axis, a stride is missing
    /// for an indexed axis, the offset arithmetic overflows, or the resulting
    /// flat position lies outside the backing slice.
    pub fn get(&self, indices: &[usize]) -> Option<&T> {
        if indices.len() != self.ndim() {
            return None;
        }
        let mut flat_idx = self.offset as isize;
        for (axis, (&idx, &extent)) in indices.iter().zip(&self.shape).enumerate() {
            if idx >= extent {
                return None;
            }
            // `new` does not check that `strides` is as long as `shape`, so
            // look the stride up defensively instead of indexing.
            let stride = *self.strides.get(axis)?;
            let step = (idx as isize).checked_mul(stride)?;
            flat_idx = flat_idx.checked_add(step)?;
        }
        usize::try_from(flat_idx)
            .ok()
            .and_then(|pos| self.data.get(pos))
    }

    /// Clones all elements into a vector, with the last axis varying fastest.
    ///
    /// # Panics
    /// Panics if an addressed position lies outside the backing slice or if
    /// `strides` is shorter than `shape`.
    pub fn to_vec(&self) -> Vec<T> {
        let mut result = Vec::with_capacity(self.size());
        self.collect_recursive(0, self.offset, &mut result);
        result
    }

    /// Walks axis `dim` and all inner axes, pushing each addressed element.
    fn collect_recursive(&self, dim: usize, current_offset: usize, result: &mut Vec<T>) {
        if dim == self.ndim() {
            result.push(self.data[current_offset].clone());
            return;
        }
        for i in 0..self.shape[dim] {
            // The offset travels as `usize`; the round trip through `isize`
            // is what lets negative strides move it backwards.
            let new_offset = (current_offset as isize + (i as isize) * self.strides[dim]) as usize;
            self.collect_recursive(dim + 1, new_offset, result);
        }
    }

    /// Copies the viewed elements into an owned [`Array`] with this view's shape.
    ///
    /// # Panics
    /// Panics under the same conditions as [`to_vec`](Self::to_vec).
    pub fn to_owned(&self) -> Array<T> {
        Array::from_vec(self.to_vec()).reshape(&self.shape)
    }

    /// Fixes `axis` at `index` and returns the view over the remaining axes.
    ///
    /// Returns `None` if `axis` is not a valid axis, has no stride, or `index`
    /// is outside that axis.
    pub fn subview(&self, axis: usize, index: usize) -> Option<StridedArrayView<'a, T>> {
        if axis >= self.ndim() || axis >= self.strides.len() || index >= self.shape[axis] {
            return None;
        }
        let mut new_shape = self.shape.clone();
        let mut new_strides = self.strides.clone();
        new_shape.remove(axis);
        let axis_stride = new_strides.remove(axis);
        let new_offset = (self.offset as isize + (index as isize) * axis_stride) as usize;
        Some(StridedArrayView {
            data: self.data,
            shape: new_shape,
            strides: new_strides,
            offset: new_offset,
        })
    }

    /// Returns an iterator yielding clones of the elements, last axis fastest.
    pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
        StridedViewIterOwned::new(self)
    }
}
/// Iterator state for walking a `StridedArrayView` and yielding cloned elements.
struct StridedViewIterOwned<'a, T> {
// The view being iterated.
view: &'a StridedArrayView<'a, T>,
// Multi-dimensional index of the element returned by the next call to `next`.
indices: Vec<usize>,
// Set once every index combination has been visited (or the view is empty).
done: bool,
}
impl<'a, T: Clone> StridedViewIterOwned<'a, T> {
    /// Starts an iteration at the all-zero index; a view of size zero is
    /// finished immediately.
    fn new(view: &'a StridedArrayView<'a, T>) -> Self {
        Self {
            indices: vec![0; view.ndim()],
            done: view.size() == 0,
            view,
        }
    }

    /// Moves to the next index combination, last axis first.
    ///
    /// Returns `false` once every axis has wrapped back to zero, i.e. when
    /// there is nothing left to visit.
    fn advance(&mut self) -> bool {
        let extents = &self.view.shape;
        for (slot, &extent) in self.indices.iter_mut().zip(extents).rev() {
            *slot += 1;
            if *slot < extent {
                return true;
            }
            // This axis is exhausted: reset it and carry into the next one.
            *slot = 0;
        }
        false
    }
}
/// Yields cloned elements of the view in index order, last axis fastest.
impl<'a, T: Clone> Iterator for StridedViewIterOwned<'a, T> {
    type Item = T;

    /// Returns a clone of the element at the current index and then steps to
    /// the next index; `None` once the iteration is finished or if the view's
    /// `get` cannot address the current index.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        let item = self.view.get(&self.indices).cloned();
        // `done` is known to be false here, so it simply mirrors whether the
        // index could be advanced.
        self.done = !self.advance();
        item
    }
}
/// Sliding-window view over a flat slice that is indexed as a row-major array
/// of shape `source_shape`.
#[derive(Debug)]
pub struct WindowView<'a, T> {
// Flat backing storage.
data: &'a [T],
// Extent of each axis of the source array.
source_shape: Vec<usize>,
// Extent of a window along each axis.
window_shape: Vec<usize>,
// Distance between the starts of neighbouring windows along each axis.
step: Vec<usize>,
// Number of window positions along each axis, computed in `new`.
n_windows: Vec<usize>,
}
impl<'a, T: Clone> WindowView<'a, T> {
    /// Creates a sliding-window view over `data`, interpreted as a row-major
    /// array of shape `source_shape`.
    ///
    /// # Errors
    /// * `DimensionMismatch` if the three shape vectors differ in length.
    /// * `ValueError` if a window is larger than its axis, or if a step is zero.
    pub fn new(
        data: &'a [T],
        source_shape: Vec<usize>,
        window_shape: Vec<usize>,
        step: Vec<usize>,
    ) -> Result<Self> {
        if source_shape.len() != window_shape.len() || source_shape.len() != step.len() {
            return Err(NumRs2Error::DimensionMismatch(
                "Source shape, window shape, and step must have same dimensions".to_string(),
            ));
        }
        let mut n_windows = Vec::with_capacity(source_shape.len());
        for i in 0..source_shape.len() {
            if window_shape[i] > source_shape[i] {
                return Err(NumRs2Error::ValueError(format!(
                    "Window size {} exceeds dimension size {} at axis {}",
                    window_shape[i], source_shape[i], i
                )));
            }
            // A zero step would divide by zero below; report it instead.
            if step[i] == 0 {
                return Err(NumRs2Error::ValueError(format!(
                    "Step at axis {} cannot be zero",
                    i
                )));
            }
            n_windows.push((source_shape[i] - window_shape[i]) / step[i] + 1);
        }
        Ok(Self {
            data,
            source_shape,
            window_shape,
            step,
            n_windows,
        })
    }

    /// Returns the shape of the full window array: the window counts per axis
    /// followed by the window extents per axis.
    pub fn shape(&self) -> Vec<usize> {
        let mut shape = self.n_windows.clone();
        shape.extend_from_slice(&self.window_shape);
        shape
    }

    /// Returns the elements of the window at `position`, last axis fastest.
    ///
    /// Returns `None` if `position` has the wrong length or addresses a
    /// window that does not exist. Elements whose flat index falls outside
    /// the backing slice are skipped, so the result can be shorter than the
    /// window size in that case.
    pub fn get_window(&self, position: &[usize]) -> Option<Vec<T>> {
        if position.len() != self.n_windows.len() {
            return None;
        }
        for (i, &pos) in position.iter().enumerate() {
            if pos >= self.n_windows[i] {
                return None;
            }
        }
        let mut result = Vec::with_capacity(self.window_shape.iter().product());
        // First source index covered by the window along each axis.
        let start: Vec<usize> = position
            .iter()
            .zip(&self.step)
            .map(|(&p, &s)| p * s)
            .collect();
        self.extract_window_recursive(
            0,
            &start,
            &mut vec![0; self.window_shape.len()],
            &mut result,
        );
        Some(result)
    }

    /// Visits every position inside one window, axis `dim` and inward, and
    /// pushes the corresponding source element onto `result`.
    fn extract_window_recursive(
        &self,
        dim: usize,
        start: &[usize],
        window_pos: &mut [usize],
        result: &mut Vec<T>,
    ) {
        if dim == self.window_shape.len() {
            // Row-major flat index of `start + window_pos` in the source.
            let mut idx = 0;
            let mut stride = 1;
            for i in (0..self.source_shape.len()).rev() {
                idx += (start[i] + window_pos[i]) * stride;
                stride *= self.source_shape[i];
            }
            if idx < self.data.len() {
                result.push(self.data[idx].clone());
            }
            return;
        }
        for i in 0..self.window_shape[dim] {
            window_pos[dim] = i;
            self.extract_window_recursive(dim + 1, start, window_pos, result);
        }
    }

    /// Returns the number of window positions along each axis.
    pub fn n_windows(&self) -> &[usize] {
        &self.n_windows
    }

    /// Materialises every window into one owned [`Array`] of shape
    /// [`shape`](Self::shape), visiting window positions last axis fastest.
    pub fn to_owned(&self) -> Array<T> {
        let shape = self.shape();
        let mut data = Vec::with_capacity(shape.iter().product());
        let mut pos = vec![0; self.n_windows.len()];
        loop {
            if let Some(window) = self.get_window(&pos) {
                data.extend(window);
            }
            // Advance the window position like an odometer. With no axes at
            // all the loop body never runs and `advanced` stays false, so a
            // zero-dimensional view terminates after its single window.
            let mut advanced = false;
            for (slot, &limit) in pos.iter_mut().zip(&self.n_windows).rev() {
                *slot += 1;
                if *slot < limit {
                    advanced = true;
                    break;
                }
                *slot = 0;
            }
            if !advanced {
                break;
            }
        }
        Array::from_vec(data).reshape(&shape)
    }
}
/// View of one diagonal of a `rows x cols` matrix stored row-major in a flat slice.
#[derive(Debug)]
pub struct DiagonalView<'a, T> {
// Flat backing storage of the matrix.
data: &'a [T],
// `[rows, cols]` as passed to `new`; stored but not read in this file.
#[allow(dead_code)] shape: Vec<usize>,
// Diagonal offset as passed to `new`; stored but not read in this file.
#[allow(dead_code)] offset: isize,
// Number of elements on the diagonal.
length: usize,
// Flat distance between consecutive diagonal elements (`cols + 1`).
stride: usize,
// Flat index of the first diagonal element.
start: usize,
}
impl<'a, T: Clone> DiagonalView<'a, T> {
    /// Creates a view of the diagonal `offset` of a `rows x cols` matrix
    /// stored row-major in `data`.
    ///
    /// A non-negative `offset` selects a diagonal starting in row 0 at column
    /// `offset`; a negative one starts in column 0 at row `-offset`.
    ///
    /// # Errors
    /// `ValueError` if the offset reaches or exceeds the number of columns
    /// (non-negative offset) or rows (negative offset).
    pub fn new(data: &'a [T], rows: usize, cols: usize, offset: isize) -> Result<Self> {
        // `unsigned_abs` is used instead of negating so that `isize::MIN`
        // cannot overflow; it then fails the bounds check like any other
        // oversized offset.
        let k = offset.unsigned_abs();
        let (start, length) = if offset >= 0 {
            if k >= cols {
                return Err(NumRs2Error::ValueError(format!(
                    "Offset {} out of bounds for {} columns",
                    offset, cols
                )));
            }
            (k, rows.min(cols - k))
        } else {
            if k >= rows {
                return Err(NumRs2Error::ValueError(format!(
                    "Offset {} out of bounds for {} rows",
                    offset, rows
                )));
            }
            (k * cols, (rows - k).min(cols))
        };
        Ok(Self {
            data,
            shape: vec![rows, cols],
            offset,
            length,
            // Moving one row down and one column right in row-major storage.
            stride: cols + 1,
            start,
        })
    }

    /// Returns the number of elements on the diagonal.
    pub fn len(&self) -> usize {
        self.length
    }

    /// Returns `true` if the diagonal has no elements.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }

    /// Returns the `index`-th diagonal element, or `None` if `index` is past
    /// the diagonal or the computed position lies outside the backing slice.
    pub fn get(&self, index: usize) -> Option<&T> {
        if index >= self.length {
            return None;
        }
        self.data.get(self.start + index * self.stride)
    }

    /// Clones the diagonal elements into a vector; positions that
    /// [`get`](Self::get) cannot resolve are skipped.
    pub fn to_vec(&self) -> Vec<T> {
        (0..self.length)
            .filter_map(|i| self.get(i).cloned())
            .collect()
    }

    /// Iterates over references to the diagonal elements; positions that
    /// [`get`](Self::get) cannot resolve are skipped.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        (0..self.length).filter_map(|i| self.get(i))
    }
}
impl<T: Clone> Array<T> {
pub fn view(&self) -> ArrayView<'_, T> {
ArrayView::from_ndarray_view(self.array().view())
}
pub fn view_mut(&mut self) -> ArrayViewMut<'_, T> {
ArrayViewMut::from_ndarray_view_mut(self.array_mut().view_mut())
}
pub fn strided_view(&self, strides: &[isize]) -> Result<Array<T>>
where
T: Clone,
{
if strides.len() != self.ndim() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Expected {} strides, got {}",
self.ndim(),
strides.len()
)));
}
let view = self.array().view();
let shape = self.shape();
let mut slice_info = Vec::with_capacity(self.ndim());
for (i, &stride) in strides.iter().enumerate() {
let dim_size = shape[i];
if stride == 0 {
return Err(NumRs2Error::InvalidOperation(format!(
"Stride for dimension {} cannot be zero",
i
)));
}
let start = if stride > 0 { 0 } else { dim_size as isize - 1 };
let end = if stride > 0 { dim_size as isize } else { -1 };
slice_info.push(SliceInfoElem::Slice {
start,
end: Some(end),
step: stride,
});
}
let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_info).map_err(|_| {
NumRs2Error::InvalidOperation("Failed to create slice info".to_string())
})?;
let strided = view.slice(slice_info);
let result = Array::from_ndarray(strided.to_owned());
Ok(result)
}
pub fn sliced_view(&self, slices: &[SliceOrIndex]) -> Result<Array<T>>
where
T: Clone,
{
if slices.len() != self.ndim() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Expected {} slices, got {}",
self.ndim(),
slices.len()
)));
}
let ndarray_view = self.array().view();
let mut result_view = ndarray_view.to_owned();
for (i, slice) in slices.iter().enumerate() {
let slice_op = slice.to_ndarray_slice();
result_view = result_view.slice_axis(Axis(i), slice_op).to_owned();
}
Ok(Array::from_ndarray(result_view))
}
pub fn transposed_view(&self) -> ArrayView<'_, T> {
let transposed = self.array().view().reversed_axes();
ArrayView::from_ndarray_view(transposed)
}
pub fn broadcast_view(&self, shape: &[usize]) -> Result<Array<T>>
where
T: Clone,
{
let broadcasted = self.broadcast_to(shape)?;
Ok(broadcasted.clone())
}
pub fn strided_array_view(&self, strides: &[isize]) -> Result<StridedArrayView<'_, T>> {
if strides.len() != self.ndim() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Expected {} strides, got {}",
self.ndim(),
strides.len()
)));
}
for (i, &stride) in strides.iter().enumerate() {
if stride == 0 {
return Err(NumRs2Error::ValueError(format!(
"Stride at dimension {} cannot be zero",
i
)));
}
}
let shape = self.shape();
let mut new_shape = Vec::with_capacity(self.ndim());
for (i, &stride) in strides.iter().enumerate() {
let abs_stride = stride.unsigned_abs();
let new_dim = shape[i].div_ceil(abs_stride);
new_shape.push(new_dim);
}
let source_strides: Vec<isize> = self
.array()
.strides()
.iter()
.zip(strides.iter())
.map(|(&s, &user_stride)| s * user_stride)
.collect();
Ok(StridedArrayView::new(
self.to_vec().leak(), new_shape,
source_strides,
0,
))
}
pub fn window_view(
&self,
window_shape: &[usize],
step: Option<&[usize]>,
) -> Result<WindowView<'_, T>> {
let step_values = match step {
Some(s) => {
if s.len() != self.ndim() {
return Err(NumRs2Error::DimensionMismatch(
"Step must have same length as array dimensions".to_string(),
));
}
s.to_vec()
}
None => vec![1; self.ndim()],
};
WindowView::new(
self.to_vec().leak(), self.shape(),
window_shape.to_vec(),
step_values,
)
}
pub fn diagonal_view(&self, offset: isize) -> Result<DiagonalView<'_, T>> {
if self.ndim() != 2 {
return Err(NumRs2Error::ValueError(
"diagonal_view requires a 2D array".to_string(),
));
}
let shape = self.shape();
DiagonalView::new(
self.to_vec().leak(), shape[0],
shape[1],
offset,
)
}
pub fn create_strided_view(
&self,
shape: Vec<usize>,
strides: Vec<isize>,
) -> StridedArrayView<'_, T> {
let slice = self.array().as_slice().unwrap_or(&[]);
StridedArrayView::new(slice, shape, strides, 0)
}
}
// Unit tests for the strided, window and diagonal views defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Row-major strides [3, 1] over a 3x3 array address elements as usual.
#[test]
fn test_strided_array_view_basic() {
let array = Array::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(&[3, 3]);
let view = array.create_strided_view(vec![3, 3], vec![3, 1]);
assert_eq!(view.shape(), &[3, 3]);
assert_eq!(view.get(&[0, 0]), Some(&1));
assert_eq!(view.get(&[1, 1]), Some(&5));
assert_eq!(view.get(&[2, 2]), Some(&9));
}
// Strides [6, 2] pick every other row and column of the 3x3 array.
#[test]
fn test_strided_array_view_skip_elements() {
let array = Array::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(&[3, 3]);
let view = array.create_strided_view(vec![2, 2], vec![6, 2]);
assert_eq!(view.shape(), &[2, 2]);
assert_eq!(view.get(&[0, 0]), Some(&1));
assert_eq!(view.get(&[0, 1]), Some(&3));
}
// `to_vec` returns the elements with the last axis varying fastest.
#[test]
fn test_strided_view_to_vec() {
let array = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
let view = array.create_strided_view(vec![2, 3], vec![3, 1]);
let vec = view.to_vec();
assert_eq!(vec, vec![1, 2, 3, 4, 5, 6]);
}
// `iter` yields the same order as `to_vec`.
#[test]
fn test_strided_view_iterator() {
let array = Array::from_vec(vec![1, 2, 3, 4]).reshape(&[2, 2]);
let view = array.create_strided_view(vec![2, 2], vec![2, 1]);
let collected: Vec<_> = view.iter().collect();
assert_eq!(collected, vec![1, 2, 3, 4]);
}
// Fixing axis 0 at index 0 leaves the first row as a 1-D view.
#[test]
fn test_strided_view_subview() {
let array = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
let view = array.create_strided_view(vec![2, 3], vec![3, 1]);
let row_view = view.subview(0, 0).expect("test: subview should succeed");
assert_eq!(row_view.shape(), &[3]);
assert_eq!(row_view.get(&[0]), Some(&1));
assert_eq!(row_view.get(&[1]), Some(&2));
assert_eq!(row_view.get(&[2]), Some(&3));
}
// Windows of length 3 with step 1 over 5 elements: 3 window positions.
#[test]
fn test_window_view_1d() {
let data = vec![1, 2, 3, 4, 5];
let window_view = WindowView::new(&data, vec![5], vec![3], vec![1])
.expect("test: WindowView creation should succeed");
assert_eq!(window_view.shape(), vec![3, 3]);
assert_eq!(window_view.n_windows(), &[3]);
let win0 = window_view
.get_window(&[0])
.expect("test: get_window(0) should succeed");
assert_eq!(win0, vec![1, 2, 3]);
let win1 = window_view
.get_window(&[1])
.expect("test: get_window(1) should succeed");
assert_eq!(win1, vec![2, 3, 4]);
}
// 2x2 windows over a 3x3 source: 2x2 window positions, row-major contents.
#[test]
fn test_window_view_2d() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let window_view = WindowView::new(&data, vec![3, 3], vec![2, 2], vec![1, 1])
.expect("test: 2D WindowView creation should succeed");
assert_eq!(window_view.shape(), vec![2, 2, 2, 2]);
let win = window_view
.get_window(&[0, 0])
.expect("test: get_window([0,0]) should succeed");
assert_eq!(win, vec![1, 2, 4, 5]);
}
// Offset 0 selects the main diagonal.
#[test]
fn test_diagonal_view_main_diagonal() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let diag = DiagonalView::new(&data, 3, 3, 0)
.expect("test: DiagonalView main diagonal should succeed");
assert_eq!(diag.len(), 3);
assert_eq!(diag.get(0), Some(&1));
assert_eq!(diag.get(1), Some(&5));
assert_eq!(diag.get(2), Some(&9));
assert_eq!(diag.to_vec(), vec![1, 5, 9]);
}
// Offset 1 selects the diagonal starting at row 0, column 1.
#[test]
fn test_diagonal_view_upper_diagonal() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let diag = DiagonalView::new(&data, 3, 3, 1)
.expect("test: DiagonalView upper diagonal should succeed");
assert_eq!(diag.len(), 2);
assert_eq!(diag.to_vec(), vec![2, 6]);
}
// Offset -1 selects the diagonal starting at row 1, column 0.
#[test]
fn test_diagonal_view_lower_diagonal() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let diag = DiagonalView::new(&data, 3, 3, -1)
.expect("test: DiagonalView lower diagonal should succeed");
assert_eq!(diag.len(), 2);
assert_eq!(diag.to_vec(), vec![4, 8]);
}
// `iter` yields references to the same elements as `to_vec`.
#[test]
fn test_diagonal_view_iterator() {
let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let diag = DiagonalView::new(&data, 3, 3, 0)
.expect("test: DiagonalView for iterator test should succeed");
let collected: Vec<_> = diag.iter().copied().collect();
assert_eq!(collected, vec![1, 5, 9]);
}
// A window larger than its axis is rejected by `WindowView::new`.
#[test]
fn test_window_view_invalid_size() {
let data = vec![1, 2, 3];
let result = WindowView::new(
&data,
vec![3],
vec![5], vec![1],
);
assert!(result.is_err());
}
// `to_owned` keeps both the shape and the element order of the view.
#[test]
fn test_strided_view_to_owned() {
let array = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
let view = array.create_strided_view(vec![2, 3], vec![3, 1]);
let owned = view.to_owned();
assert_eq!(owned.shape(), vec![2, 3]);
assert_eq!(owned.to_vec(), vec![1, 2, 3, 4, 5, 6]);
}
}