use super::advanced_ops::{ArrayView, Shape};
use crate::error::{NumRs2Error, Result};
use crate::traits::NumericElement;
use std::collections::HashMap;
/// Memory layout strategy that [`ShapeEngine::compute_strides`] uses to derive element strides.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    /// Row-major: the last axis has stride 1.
    C,
    /// Column-major: the first axis has stride 1.
    Fortran,
    /// Axes ordered by decreasing extent: the largest axis gets stride 1, the next
    /// largest the following stride, and so on.
    Custom,
    /// Generic strided layout; `compute_strides` currently maps it to row-major strides.
    Strided,
}
/// Shape and stride utilities for [`ArrayView`]s.
///
/// It computes strides, creates zero-copy reshaped/permuted views, and performs copying
/// transforms (roll, flip, rot90) that return the result as a flat `Vec`. It also
/// analyses how efficiently a given stride set uses memory.
pub struct ShapeEngine {
    /// Memoised strides keyed by `(shape, layout)`. Filled by `compute_strides` and never evicted.
    stride_cache: HashMap<(Vec<usize>, MemoryLayout), Vec<usize>>,
}
impl ShapeEngine {
    /// Creates an engine with an empty stride cache.
    pub fn new() -> Self {
        Self {
            stride_cache: HashMap::new(),
        }
    }

    /// Returns the strides (in elements) of `shape` laid out according to `layout`.
    ///
    /// * [`MemoryLayout::C`]: row-major; the last axis has stride 1.
    /// * [`MemoryLayout::Fortran`]: column-major; the first axis has stride 1.
    /// * [`MemoryLayout::Custom`]: axes ordered by decreasing extent; the largest axis
    ///   gets stride 1 (equal extents keep their original axis order).
    /// * [`MemoryLayout::Strided`]: currently identical to `C`.
    ///
    /// Results are memoised per `(shape, layout)`. The cache is never evicted, so an engine
    /// that sees many distinct shapes grows accordingly.
    pub fn compute_strides(&mut self, shape: &[usize], layout: MemoryLayout) -> Vec<usize> {
        let cache_key = (shape.to_vec(), layout);
        if let Some(cached_strides) = self.stride_cache.get(&cache_key) {
            return cached_strides.clone();
        }
        let strides = match layout {
            MemoryLayout::C => self.compute_c_strides(shape),
            MemoryLayout::Fortran => self.compute_fortran_strides(shape),
            MemoryLayout::Custom => self.compute_optimal_strides(shape),
            MemoryLayout::Strided => self.compute_default_strides(shape),
        };
        self.stride_cache.insert(cache_key, strides.clone());
        strides
    }

    /// Row-major strides: `strides[i]` is the product of `shape[i + 1..]`.
    fn compute_c_strides(&self, shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
        strides
    }

    /// Column-major strides: `strides[i]` is the product of `shape[..i]`.
    fn compute_fortran_strides(&self, shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        for i in 1..shape.len() {
            strides[i] = strides[i - 1] * shape[i - 1];
        }
        strides
    }

    /// Strides that place the largest axis innermost (stride 1), then the next largest, etc.
    fn compute_optimal_strides(&self, shape: &[usize]) -> Vec<usize> {
        let mut dim_sizes: Vec<(usize, usize)> = shape.iter().copied().enumerate().collect();
        // `sort_by_key` is stable, so equally sized axes keep their relative order.
        dim_sizes.sort_by_key(|&(_, size)| std::cmp::Reverse(size));
        let mut strides = vec![0; shape.len()];
        let mut current_stride = 1;
        for &(dim_idx, dim_size) in &dim_sizes {
            strides[dim_idx] = current_stride;
            current_stride *= dim_size;
        }
        strides
    }

    /// Strides for [`MemoryLayout::Strided`]; currently row-major.
    fn compute_default_strides(&self, shape: &[usize]) -> Vec<usize> {
        self.compute_c_strides(shape)
    }

    /// Returns `true` when both shapes describe the same number of elements.
    ///
    /// Shapes whose element count does not fit in `usize` are never reshapeable. This is
    /// reported as `false` instead of overflowing.
    pub fn can_reshape(&self, current_shape: &[usize], new_shape: &[usize]) -> bool {
        match (Self::checked_size(current_shape), Self::checked_size(new_shape)) {
            (Some(current), Some(new)) => current == new,
            _ => false,
        }
    }

    /// Element count of `dims`, or `None` if it overflows `usize`.
    fn checked_size(dims: &[usize]) -> Option<usize> {
        // A zero extent makes the array empty regardless of the other extents.
        if dims.contains(&0) {
            return Some(0);
        }
        dims.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    }

    /// Returns a zero-copy view of `view` with shape `new_shape`.
    ///
    /// # Errors
    /// * `DimensionMismatch` if the element counts differ.
    /// * `InvalidOperation` if `view` is not C-contiguous (copy it first).
    /// * Any error from `ArrayView::reshape`.
    pub fn reshape_view<'a, T>(
        &self,
        view: &ArrayView<'a, T>,
        new_shape: &[usize],
    ) -> Result<ArrayView<'a, T>>
    where
        T: NumericElement,
    {
        if !self.can_reshape(&view.shape().dims, new_shape) {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Cannot reshape array from {:?} to {:?}: incompatible sizes",
                view.shape().dims,
                new_shape
            )));
        }
        if !view.is_c_contiguous() {
            return Err(NumRs2Error::InvalidOperation(
                "Cannot reshape non-contiguous view. Use copy() first.".to_string(),
            ));
        }
        view.reshape(Shape::new(new_shape.to_vec()))
    }

    /// Permutes the axes of `view` by delegating to `ArrayView::transpose`.
    ///
    /// The meaning of `axes == None` is defined by `ArrayView::transpose` (presumably
    /// axis reversal, as in NumPy; confirm).
    pub fn transpose_view<'a, T>(
        &self,
        view: &ArrayView<'a, T>,
        axes: Option<Vec<usize>>,
    ) -> Result<ArrayView<'a, T>>
    where
        T: NumericElement,
    {
        view.transpose(axes)
    }

    /// Returns a view with `axis1` and `axis2` exchanged.
    ///
    /// # Errors
    /// `DimensionMismatch` if either axis is out of bounds, plus any transpose error.
    pub fn swapaxes_view<'a, T>(
        &self,
        view: &ArrayView<'a, T>,
        axis1: usize,
        axis2: usize,
    ) -> Result<ArrayView<'a, T>>
    where
        T: NumericElement,
    {
        let ndim = view.shape().ndim();
        if axis1 >= ndim || axis2 >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axes {} and {} are out of bounds for array of dimension {}",
                axis1, axis2, ndim
            )));
        }
        let mut axes: Vec<usize> = (0..ndim).collect();
        axes.swap(axis1, axis2);
        self.transpose_view(view, Some(axes))
    }

    /// Returns a view where axis `source` is moved to position `destination`; the other
    /// axes keep their relative order (NumPy `moveaxis` semantics).
    ///
    /// # Errors
    /// `DimensionMismatch` if either axis is out of bounds, plus any transpose error.
    pub fn moveaxis_view<'a, T>(
        &self,
        view: &ArrayView<'a, T>,
        source: usize,
        destination: usize,
    ) -> Result<ArrayView<'a, T>>
    where
        T: NumericElement,
    {
        let ndim = view.shape().ndim();
        if source >= ndim || destination >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Source axis {} or destination axis {} is out of bounds for array of dimension {}",
                source, destination, ndim
            )));
        }
        let mut axes: Vec<usize> = (0..ndim).collect();
        let removed = axes.remove(source);
        axes.insert(destination, removed);
        self.transpose_view(view, Some(axes))
    }

    /// Rolls elements by `shift` positions (NumPy `roll` semantics) and returns a copy in
    /// row-major order.
    ///
    /// A positive `shift` moves elements towards higher indices: `out[i] = in[i - shift]`,
    /// wrapping around. With `axis == None` the array is rolled as if flattened. Otherwise
    /// only the given axis is rolled and the shape is unchanged.
    ///
    /// # Errors
    /// `DimensionMismatch` if `axis` is out of bounds.
    pub fn roll_view<T>(
        &self,
        view: &ArrayView<'_, T>,
        shift: isize,
        axis: Option<usize>,
    ) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        match axis {
            Some(ax) => self.roll_along_axis(view, shift, ax),
            None => self.roll_flattened(view, shift),
        }
    }

    /// Rolls along a single axis. This matches the direction of `roll_flattened`:
    /// `out[.., i, ..] = in[.., i - shift, ..]`.
    fn roll_along_axis<T>(
        &self,
        view: &ArrayView<'_, T>,
        shift: isize,
        axis: usize,
    ) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        if axis >= view.shape().ndim() {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axis {} is out of bounds for array of dimension {}",
                axis,
                view.shape().ndim()
            )));
        }
        let dims = view.shape().dims.clone();
        let axis_size = dims[axis];
        // An empty axis means an empty array; also avoids a modulo-by-zero panic.
        if axis_size == 0 {
            return Ok(Vec::new());
        }
        let shift = shift.rem_euclid(axis_size as isize) as usize;
        Ok(self.gather(view, &dims, |out, src| {
            // `out[axis] + axis_size - shift` cannot underflow because `shift < axis_size`.
            src[axis] = (out[axis] + axis_size - shift) % axis_size;
        }))
    }

    /// Rolls the flattened (row-major) contents of `view`.
    fn roll_flattened<T>(&self, view: &ArrayView<'_, T>, shift: isize) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        let mut flat_data = view.to_vec();
        if flat_data.is_empty() {
            return Ok(flat_data);
        }
        let effective_shift = shift.rem_euclid(flat_data.len() as isize) as usize;
        // `rotate_right(k)` yields `out[i] = in[i - k]`, i.e. NumPy's roll direction.
        flat_data.rotate_right(effective_shift);
        Ok(flat_data)
    }

    /// Reverses the element order along `axes` (all axes when `None`) and returns a copy in
    /// row-major order. The shape is unchanged. Listing an axis more than once flips it once.
    ///
    /// # Errors
    /// `DimensionMismatch` for the first out-of-bounds axis.
    pub fn flip_view<T>(&self, view: &ArrayView<'_, T>, axes: Option<Vec<usize>>) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        let ndim = view.shape().ndim();
        let axes_to_flip = axes.unwrap_or_else(|| (0..ndim).collect());
        if let Some(&axis) = axes_to_flip.iter().find(|&&axis| axis >= ndim) {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axis {} is out of bounds for array of dimension {}",
                axis, ndim
            )));
        }
        let dims = view.shape().dims.clone();
        Ok(self.gather(view, &dims, |out, src| {
            for &axis in &axes_to_flip {
                // Safe: `gather` never calls this for arrays with a zero-length axis.
                src[axis] = dims[axis] - 1 - out[axis];
            }
        }))
    }

    /// Rotates by `k * 90` degrees in the plane of `axes` (default `(0, 1)`), rotating from
    /// the first axis towards the second (NumPy `rot90` semantics).
    ///
    /// Returns the rotated elements in row-major order. For odd `k` the rotated array's
    /// shape is the input shape with the two axes exchanged; the caller must apply that
    /// shape. `k` may be negative; it is reduced modulo 4.
    ///
    /// # Errors
    /// `DimensionMismatch` if the view has fewer than 2 dimensions, or if the axes are
    /// out of bounds or equal.
    pub fn rot90_view<T>(
        &self,
        view: &ArrayView<'_, T>,
        k: i32,
        axes: Option<(usize, usize)>,
    ) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        let ndim = view.shape().ndim();
        if ndim < 2 {
            return Err(NumRs2Error::DimensionMismatch(
                "rot90 requires at least 2 dimensions".to_string(),
            ));
        }
        let (axis1, axis2) = axes.unwrap_or((0, 1));
        if axis1 >= ndim || axis2 >= ndim || axis1 == axis2 {
            return Err(NumRs2Error::DimensionMismatch(
                "Invalid rotation axes".to_string(),
            ));
        }
        match k.rem_euclid(4) {
            0 => Ok(view.to_vec()),
            1 => self.rotate_90_once(view, axis1, axis2),
            2 => self.rotate_180(view, axis1, axis2),
            3 => self.rotate_270(view, axis1, axis2),
            _ => unreachable!("rem_euclid(4) is always in 0..4"),
        }
    }

    /// One quarter turn: `out[.., i@axis1, .., j@axis2, ..] = in[.., j, .., n2 - 1 - i, ..]`,
    /// where `n2` is the input extent of `axis2`. Output extents of the two axes are swapped.
    fn rotate_90_once<T>(
        &self,
        view: &ArrayView<'_, T>,
        axis1: usize,
        axis2: usize,
    ) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        let dims = view.shape().dims.clone();
        let n2 = dims[axis2];
        let mut out_dims = dims;
        out_dims.swap(axis1, axis2);
        Ok(self.gather(view, &out_dims, |out, src| {
            src[axis1] = out[axis2];
            // `out[axis1] < out_dims[axis1] == n2`, so this cannot underflow.
            src[axis2] = n2 - 1 - out[axis1];
        }))
    }

    /// Half turn: both axes are reversed; the shape is unchanged.
    fn rotate_180<T>(&self, view: &ArrayView<'_, T>, axis1: usize, axis2: usize) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        let dims = view.shape().dims.clone();
        Ok(self.gather(view, &dims, |out, src| {
            src[axis1] = dims[axis1] - 1 - out[axis1];
            src[axis2] = dims[axis2] - 1 - out[axis2];
        }))
    }

    /// Three quarter turns: `out[.., i@axis1, .., j@axis2, ..] = in[.., n1 - 1 - j, .., i, ..]`,
    /// where `n1` is the input extent of `axis1`. Output extents of the two axes are swapped.
    fn rotate_270<T>(&self, view: &ArrayView<'_, T>, axis1: usize, axis2: usize) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        let dims = view.shape().dims.clone();
        let n1 = dims[axis1];
        let mut out_dims = dims;
        out_dims.swap(axis1, axis2);
        Ok(self.gather(view, &out_dims, |out, src| {
            // `out[axis2] < out_dims[axis2] == n1`, so this cannot underflow.
            src[axis1] = n1 - 1 - out[axis2];
            src[axis2] = out[axis1];
        }))
    }

    /// Removes length-1 axes: the listed `axes`, or every length-1 axis when `None`.
    ///
    /// If every axis is removed the result has shape `[1]` (not a 0-d view).
    ///
    /// # Errors
    /// `DimensionMismatch` if a listed axis is out of bounds or does not have length 1,
    /// plus any error from `ArrayView::reshape`.
    pub fn squeeze_view<'a, T>(
        &self,
        view: &ArrayView<'a, T>,
        axes: Option<Vec<usize>>,
    ) -> Result<ArrayView<'a, T>>
    where
        T: NumericElement,
    {
        let axes_to_squeeze: Vec<usize> = match axes {
            Some(ax) => {
                for &axis in &ax {
                    if axis >= view.shape().ndim() {
                        return Err(NumRs2Error::DimensionMismatch(format!(
                            "Axis {} is out of bounds",
                            axis
                        )));
                    }
                    if view.shape().dims[axis] != 1 {
                        return Err(NumRs2Error::DimensionMismatch(format!(
                            "Cannot squeeze axis {} with size {}",
                            axis,
                            view.shape().dims[axis]
                        )));
                    }
                }
                ax
            }
            None => view
                .shape()
                .dims
                .iter()
                .enumerate()
                .filter_map(|(i, &size)| (size == 1).then_some(i))
                .collect(),
        };
        let new_dims: Vec<usize> = view
            .shape()
            .dims
            .iter()
            .enumerate()
            .filter(|(i, _)| !axes_to_squeeze.contains(i))
            .map(|(_, &size)| size)
            .collect();
        // Keep at least one dimension instead of producing a 0-d shape.
        let new_dims = if new_dims.is_empty() { vec![1] } else { new_dims };
        view.reshape(Shape::new(new_dims))
    }

    /// Inserts length-1 axes so that each entry of `axes` is the position of a new axis in
    /// the *output* shape (NumPy `expand_dims` semantics). For example, `[2, 2]` with axes
    /// `[0, 1]` becomes `[1, 1, 2, 2]`, and with axes `[0, 3]` becomes `[1, 2, 2, 1]`.
    ///
    /// # Errors
    /// `DimensionMismatch` if an axis is `>= ndim + axes.len()` or is repeated, plus any
    /// error from `ArrayView::reshape`.
    pub fn expand_dims_view<'a, T>(
        &self,
        view: &ArrayView<'a, T>,
        axes: Vec<usize>,
    ) -> Result<ArrayView<'a, T>>
    where
        T: NumericElement,
    {
        let mut new_dims = view.shape().dims.clone();
        let out_ndim = new_dims.len() + axes.len();
        let mut sorted_axes = axes;
        sorted_axes.sort_unstable();
        for &axis in &sorted_axes {
            if axis >= out_ndim {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} is out of bounds for expansion",
                    axis
                )));
            }
        }
        if let Some(pair) = sorted_axes.windows(2).find(|pair| pair[0] == pair[1]) {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Repeated axis {} in expand_dims",
                pair[0]
            )));
        }
        // Inserting in ascending order makes each position refer to the output shape.
        // For sorted, distinct axes `< out_ndim`, the k-th position is `<= ndim + k`, so
        // `insert` never goes past the end.
        for &axis in &sorted_axes {
            new_dims.insert(axis, 1);
        }
        view.reshape(Shape::new(new_dims))
    }

    /// Returns `true` if the two shapes are mutually broadcast-compatible: aligned from the
    /// trailing axis, each pair of extents is equal or one of them is 1. Missing leading
    /// axes count as 1. This is symmetric in its arguments.
    pub fn is_broadcastable(&self, source_shape: &[usize], target_shape: &[usize]) -> bool {
        let max_ndim = std::cmp::max(source_shape.len(), target_shape.len());
        (0..max_ndim).all(|i| {
            let src_dim = source_shape
                .len()
                .checked_sub(i + 1)
                .map_or(1, |idx| source_shape[idx]);
            let tgt_dim = target_shape
                .len()
                .checked_sub(i + 1)
                .map_or(1, |idx| target_shape[idx]);
            src_dim == tgt_dim || src_dim == 1 || tgt_dim == 1
        })
    }

    /// Classifies how `strides` lay out an array of `shape` in memory.
    ///
    /// `efficiency` is element count divided by memory span (0.0 for empty arrays).
    pub fn analyze_layout_efficiency(&self, shape: &[usize], strides: &[usize]) -> LayoutAnalysis {
        let c_strides = self.compute_c_strides(shape);
        let f_strides = self.compute_fortran_strides(shape);
        let is_c_contiguous = strides == c_strides;
        let is_f_contiguous = strides == f_strides;
        let is_contiguous = is_c_contiguous || is_f_contiguous;
        let memory_span = self.calculate_memory_span(shape, strides);
        let efficiency = self.layout_efficiency(shape, strides);
        let layout_pattern = if is_c_contiguous {
            LayoutPattern::CContiguous
        } else if is_f_contiguous {
            LayoutPattern::FortranContiguous
        } else if self.is_unit_stride_pattern(strides) {
            LayoutPattern::UnitStride
        } else if self.has_regular_pattern(strides) {
            LayoutPattern::Regular
        } else {
            LayoutPattern::Irregular
        };
        LayoutAnalysis {
            is_contiguous,
            is_c_contiguous,
            is_f_contiguous,
            efficiency,
            layout_pattern,
            memory_span,
            recommended_layout: self.recommend_layout(shape, strides),
        }
    }

    /// Number of element slots from the lowest to the highest reachable offset, inclusive.
    ///
    /// Returns 0 for an array with a zero-length axis and 1 for a 0-d (scalar) array.
    /// Strides are unsigned, so the lowest offset is always 0. Saturates instead of
    /// overflowing on absurd strides.
    fn calculate_memory_span(&self, shape: &[usize], strides: &[usize]) -> usize {
        if shape.contains(&0) {
            return 0;
        }
        shape
            .iter()
            .zip(strides)
            .fold(1usize, |span, (&extent, &stride)| {
                span.saturating_add((extent - 1).saturating_mul(stride))
            })
    }

    /// Element count divided by memory span, or 0.0 when the span is 0.
    fn layout_efficiency(&self, shape: &[usize], strides: &[usize]) -> f64 {
        let total_elements: usize = shape.iter().product();
        let memory_span = self.calculate_memory_span(shape, strides);
        if memory_span > 0 {
            total_elements as f64 / memory_span as f64
        } else {
            0.0
        }
    }

    /// `true` if any axis has stride 1.
    fn is_unit_stride_pattern(&self, strides: &[usize]) -> bool {
        strides.contains(&1)
    }

    /// `true` if each pair of adjacent non-zero strides differs by at most a factor of 10.
    /// Pairs involving a zero stride are ignored.
    fn has_regular_pattern(&self, strides: &[usize]) -> bool {
        strides.windows(2).all(|pair| {
            let (prev, next) = (pair[0], pair[1]);
            prev == 0
                || next == 0
                || (next <= prev.saturating_mul(10) && prev <= next.saturating_mul(10))
        })
    }

    /// Heuristic layout suggestion. For efficient layouts (> 0.9) it keeps the current
    /// contiguous order (or `Custom` if neither C nor Fortran). Otherwise it suggests `C`
    /// for arrays of rank <= 2 and `Custom` above.
    fn recommend_layout(&self, shape: &[usize], current_strides: &[usize]) -> MemoryLayout {
        if self.layout_efficiency(shape, current_strides) > 0.9 {
            if current_strides == self.compute_c_strides(shape) {
                MemoryLayout::C
            } else if current_strides == self.compute_fortran_strides(shape) {
                MemoryLayout::Fortran
            } else {
                MemoryLayout::Custom
            }
        } else if shape.len() <= 2 {
            MemoryLayout::C
        } else {
            MemoryLayout::Custom
        }
    }

    /// Advances `indices` to the next row-major multi-index within `shape`.
    /// Returns `false` (with `indices` reset to all zeros) after the last index.
    fn advance_indices(&self, indices: &mut [usize], shape: &[usize]) -> bool {
        for i in (0..indices.len()).rev() {
            indices[i] += 1;
            if indices[i] < shape[i] {
                return true;
            }
            indices[i] = 0;
        }
        false
    }

    /// Walks every multi-index of `out_dims` in row-major order and collects the element
    /// of `view` at the source index computed by `map_index`.
    ///
    /// `map_index(out, src)` receives the output index and a source index pre-filled with
    /// a copy of it, and rewrites the axes that differ. Returns an empty `Vec` if any
    /// output extent is zero. A 0-d `out_dims` yields exactly one element. Source indices
    /// that `view.get` rejects are skipped.
    fn gather<T, F>(&self, view: &ArrayView<'_, T>, out_dims: &[usize], mut map_index: F) -> Vec<T>
    where
        T: NumericElement + Copy,
        F: FnMut(&[usize], &mut [usize]),
    {
        let mut result = Vec::with_capacity(view.shape().size());
        if out_dims.contains(&0) {
            return result;
        }
        let mut out_idx = vec![0; out_dims.len()];
        let mut src_idx = vec![0; out_dims.len()];
        loop {
            src_idx.copy_from_slice(&out_idx);
            map_index(&out_idx[..], &mut src_idx[..]);
            if let Ok(element) = view.get(&src_idx) {
                result.push(*element);
            }
            if !self.advance_indices(&mut out_idx, out_dims) {
                break;
            }
        }
        result
    }
}
impl Default for ShapeEngine {
fn default() -> Self {
Self::new()
}
}
/// Result of [`ShapeEngine::analyze_layout_efficiency`] for one `(shape, strides)` pair.
#[derive(Debug, Clone)]
pub struct LayoutAnalysis {
    /// `true` if the strides are either C- or Fortran-contiguous.
    pub is_contiguous: bool,
    /// `true` if the strides equal the row-major strides for the shape.
    pub is_c_contiguous: bool,
    /// `true` if the strides equal the column-major strides for the shape.
    pub is_f_contiguous: bool,
    /// Element count divided by `memory_span`; 0.0 when the span is 0.
    pub efficiency: f64,
    /// Coarse classification of the stride pattern.
    pub layout_pattern: LayoutPattern,
    /// Number of element slots between the lowest and highest reachable offset, as
    /// computed by the engine's private `calculate_memory_span`.
    pub memory_span: usize,
    /// Heuristic layout suggestion for this shape.
    pub recommended_layout: MemoryLayout,
}
/// Classification assigned by [`ShapeEngine::analyze_layout_efficiency`], checked in
/// declaration order (the first matching variant wins).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayoutPattern {
    /// Strides equal the row-major strides for the shape.
    CContiguous,
    /// Strides equal the column-major strides for the shape.
    FortranContiguous,
    /// Not contiguous, but at least one axis has stride 1.
    UnitStride,
    /// Adjacent non-zero strides are within a factor of 10 of each other.
    Regular,
    /// None of the above.
    Irregular,
}
/// Builds array views whose layout is chosen from intended access patterns, and applies
/// chains of zero-copy view operations through an owned [`ShapeEngine`].
pub struct ViewSystem {
    /// Engine used for stride computation and view transforms; its stride cache persists
    /// across calls on this `ViewSystem`.
    shape_engine: ShapeEngine,
}
impl ViewSystem {
pub fn new() -> Self {
Self {
shape_engine: ShapeEngine::new(),
}
}
pub fn create_optimized_view<'a, T>(
&mut self,
data: &'a [T],
shape: &[usize],
intended_operations: &[ViewOperation],
) -> Result<ArrayView<'a, T>>
where
T: NumericElement,
{
let optimal_layout = self.determine_optimal_layout(shape, intended_operations);
let strides = self.shape_engine.compute_strides(shape, optimal_layout);
let shape_obj = Shape::new(shape.to_vec());
ArrayView::new(data, shape_obj, strides, 0)
}
fn determine_optimal_layout(
&self,
shape: &[usize],
operations: &[ViewOperation],
) -> MemoryLayout {
let mut score_c = 0;
let mut score_fortran = 0;
let mut score_custom = 0;
for op in operations {
match op {
ViewOperation::RowAccess => score_c += 2,
ViewOperation::ColumnAccess => score_fortran += 2,
ViewOperation::RandomAccess => score_custom += 1,
ViewOperation::SequentialScan => score_c += 1,
ViewOperation::Transpose => score_fortran += 1,
ViewOperation::MatrixMultiply => {
score_c += 1;
score_fortran += 1;
}
ViewOperation::Reduction => score_c += 1,
ViewOperation::Broadcasting => score_custom += 2,
}
}
if shape.len() > 2 {
score_custom += 1;
}
if score_custom > score_c && score_custom > score_fortran {
MemoryLayout::Custom
} else if score_fortran > score_c {
MemoryLayout::Fortran
} else {
MemoryLayout::C
}
}
pub fn create_view_chain<'a, T>(
&mut self,
initial_view: ArrayView<'a, T>,
operations: &[ViewChainOperation],
) -> Result<ArrayView<'a, T>>
where
T: NumericElement + Copy,
{
let mut current_view = initial_view;
for operation in operations {
current_view = match operation {
ViewChainOperation::Reshape(new_shape) => {
self.shape_engine.reshape_view(¤t_view, new_shape)?
}
ViewChainOperation::Transpose(axes) => self
.shape_engine
.transpose_view(¤t_view, axes.clone())?,
ViewChainOperation::SwapAxes(ax1, ax2) => {
self.shape_engine.swapaxes_view(¤t_view, *ax1, *ax2)?
}
ViewChainOperation::MoveAxis(src, dst) => {
self.shape_engine.moveaxis_view(¤t_view, *src, *dst)?
}
ViewChainOperation::Squeeze(axes) => self
.shape_engine
.squeeze_view(¤t_view, axes.clone())?,
ViewChainOperation::ExpandDims(axes) => self
.shape_engine
.expand_dims_view(¤t_view, axes.clone())?,
};
}
Ok(current_view)
}
}
impl Default for ViewSystem {
fn default() -> Self {
Self::new()
}
}
/// Access patterns a caller expects to perform. Used by `ViewSystem` to score layouts when
/// creating an optimised view.
#[derive(Debug, Clone)]
pub enum ViewOperation {
    /// Row-wise access; scores +2 for the C layout.
    RowAccess,
    /// Column-wise access; scores +2 for the Fortran layout.
    ColumnAccess,
    /// Random access; scores +1 for the Custom layout.
    RandomAccess,
    /// Sequential scan; scores +1 for the C layout.
    SequentialScan,
    /// Transposition; scores +1 for the Fortran layout.
    Transpose,
    /// Matrix multiplication; scores +1 for both C and Fortran.
    MatrixMultiply,
    /// Reduction; scores +1 for the C layout.
    Reduction,
    /// Broadcasting; scores +2 for the Custom layout.
    Broadcasting,
}
/// One step of [`ViewSystem::create_view_chain`]; each variant maps to a `ShapeEngine` view method.
#[derive(Debug, Clone)]
pub enum ViewChainOperation {
    /// Reshape to the given dimensions (`reshape_view`).
    Reshape(Vec<usize>),
    /// Permute axes (`transpose_view`).
    Transpose(Option<Vec<usize>>),
    /// Exchange two axes (`swapaxes_view`).
    SwapAxes(usize, usize),
    /// Move axis `.0` to position `.1` (`moveaxis_view`).
    MoveAxis(usize, usize),
    /// Remove length-1 axes (`squeeze_view`).
    Squeeze(Option<Vec<usize>>),
    /// Insert length-1 axes (`expand_dims_view`).
    ExpandDims(Vec<usize>),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrays::advanced_ops::{ArrayView, Shape};

    #[test]
    fn test_shape_engine_creation() {
        let engine = ShapeEngine::new();
        assert!(engine.stride_cache.is_empty());
    }

    #[test]
    fn test_c_strides_computation() {
        let mut engine = ShapeEngine::new();
        let shape = [2, 3, 4];
        let strides = engine.compute_strides(&shape, MemoryLayout::C);
        assert_eq!(strides, vec![12, 4, 1]);
    }

    #[test]
    fn test_fortran_strides_computation() {
        let mut engine = ShapeEngine::new();
        let shape = [2, 3, 4];
        let strides = engine.compute_strides(&shape, MemoryLayout::Fortran);
        assert_eq!(strides, vec![1, 2, 6]);
    }

    #[test]
    fn test_reshape_validation() {
        let engine = ShapeEngine::new();
        assert!(engine.can_reshape(&[2, 3], &[6]));
        assert!(engine.can_reshape(&[2, 3], &[3, 2]));
        assert!(!engine.can_reshape(&[2, 3], &[7]));
    }

    #[test]
    fn test_broadcastability_check() {
        let engine = ShapeEngine::new();
        assert!(engine.is_broadcastable(&[3, 1, 4], &[2, 4]));
        assert!(engine.is_broadcastable(&[1, 4], &[3, 4]));
        assert!(!engine.is_broadcastable(&[3, 4], &[5, 4]));
    }

    #[test]
    fn test_layout_analysis() {
        let mut engine = ShapeEngine::new();
        let shape = [3, 4];
        let c_strides = engine.compute_strides(&shape, MemoryLayout::C);
        let analysis = engine.analyze_layout_efficiency(&shape, &c_strides);
        assert!(analysis.is_c_contiguous);
        assert!(analysis.is_contiguous);
        assert!(analysis.efficiency > 0.9);
    }

    #[test]
    fn test_layout_analysis_fortran() {
        let mut engine = ShapeEngine::new();
        let shape = [3, 4];
        let f_strides = engine.compute_strides(&shape, MemoryLayout::Fortran);
        let analysis = engine.analyze_layout_efficiency(&shape, &f_strides);
        assert!(analysis.is_f_contiguous);
        assert!(!analysis.is_c_contiguous);
        assert_eq!(analysis.layout_pattern, LayoutPattern::FortranContiguous);
        assert_eq!(analysis.memory_span, 12);
    }

    #[test]
    fn test_view_system_creation() {
        let mut view_system = ViewSystem::new();
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = [2, 3];
        let operations = vec![ViewOperation::RowAccess];
        let view = view_system
            .create_optimized_view(&data, &shape, &operations)
            .expect("test: operation should succeed");
        assert_eq!(view.shape().dims, vec![2, 3]);
    }

    #[test]
    fn test_squeeze_operation() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4];
        let shape = Shape::new(vec![1, 2, 1, 2]);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let squeezed = engine
            .squeeze_view(&view, None)
            .expect("test: operation should succeed");
        assert_eq!(squeezed.shape().dims, vec![2, 2]);
    }

    #[test]
    fn test_expand_dims_operation() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4];
        let shape = Shape::new(vec![2, 2]);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let expanded = engine
            .expand_dims_view(&view, vec![1])
            .expect("test: operation should succeed");
        assert_eq!(expanded.shape().dims, vec![2, 1, 2]);
    }

    #[test]
    fn test_swapaxes_shape() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4, 5, 6];
        let view = ArrayView::from_data(&data, Shape::new(vec![2, 3]))
            .expect("test: operation should succeed");
        let swapped = engine
            .swapaxes_view(&view, 0, 1)
            .expect("test: operation should succeed");
        assert_eq!(swapped.shape().dims, vec![3, 2]);
        assert!(engine.swapaxes_view(&view, 0, 2).is_err());
    }

    #[test]
    fn test_roll_flattened_matches_numpy_direction() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4];
        let view = ArrayView::from_data(&data, Shape::new(vec![4]))
            .expect("test: operation should succeed");
        let forward = engine
            .roll_view(&view, 1, None)
            .expect("test: operation should succeed");
        assert_eq!(forward, vec![4, 1, 2, 3]);
        let backward = engine
            .roll_view(&view, -1, None)
            .expect("test: operation should succeed");
        assert_eq!(backward, vec![2, 3, 4, 1]);
    }

    #[test]
    fn test_roll_along_axis_length_two() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4];
        let view = ArrayView::from_data(&data, Shape::new(vec![2, 2]))
            .expect("test: operation should succeed");
        let rolled = engine
            .roll_view(&view, 1, Some(1))
            .expect("test: operation should succeed");
        assert_eq!(rolled, vec![2, 1, 4, 3]);
        assert!(engine.roll_view(&view, 1, Some(2)).is_err());
    }

    #[test]
    fn test_flip_all_axes() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4];
        let view = ArrayView::from_data(&data, Shape::new(vec![2, 2]))
            .expect("test: operation should succeed");
        let flipped = engine
            .flip_view(&view, None)
            .expect("test: operation should succeed");
        assert_eq!(flipped, vec![4, 3, 2, 1]);
        let rows_flipped = engine
            .flip_view(&view, Some(vec![0]))
            .expect("test: operation should succeed");
        assert_eq!(rows_flipped, vec![3, 4, 1, 2]);
    }

    #[test]
    fn test_rot90_square() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4];
        let view = ArrayView::from_data(&data, Shape::new(vec![2, 2]))
            .expect("test: operation should succeed");
        let quarter = engine
            .rot90_view(&view, 1, None)
            .expect("test: operation should succeed");
        assert_eq!(quarter, vec![2, 4, 1, 3]);
        let half = engine
            .rot90_view(&view, 2, None)
            .expect("test: operation should succeed");
        assert_eq!(half, vec![4, 3, 2, 1]);
        let three_quarters = engine
            .rot90_view(&view, -1, None)
            .expect("test: operation should succeed");
        assert_eq!(three_quarters, vec![3, 1, 4, 2]);
    }

    #[test]
    fn test_rot90_rejects_bad_input() {
        let engine = ShapeEngine::new();
        let data = vec![1, 2, 3, 4];
        let flat = ArrayView::from_data(&data, Shape::new(vec![4]))
            .expect("test: operation should succeed");
        assert!(engine.rot90_view(&flat, 1, None).is_err());
        let square = ArrayView::from_data(&data, Shape::new(vec![2, 2]))
            .expect("test: operation should succeed");
        assert!(engine.rot90_view(&square, 1, Some((0, 0))).is_err());
    }
}