use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::Axis;
use std::cmp;
use super::core::AxisArg;
/// Construct an array by repeating `array` along each axis the number of times
/// given by `reps`.
///
/// Both shapes are first brought to a common rank by left-padding with 1s:
/// - if `reps` is longer than the array's rank, the array is treated as having
///   extra leading length-1 axes;
/// - if `reps` is shorter, the missing leading repetition counts are 1.
///
/// The output shape is the element-wise product of the padded input shape and
/// the padded `reps`. Output element `[i0, i1, ...]` equals input element
/// `[i0 % d0, i1 % d1, ...]`, where `dk` are the padded input extents.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `array` is empty, or if the newly
/// allocated result does not expose a contiguous mutable slice.
pub fn tile<T: Clone>(array: &Array<T>, reps: &[usize]) -> Result<Array<T>> {
    let a_shape = array.shape();
    let ndim = cmp::max(a_shape.len(), reps.len());

    // Left-pad the input shape and the repetition counts with 1s so both have rank `ndim`.
    let mut in_shape = vec![1usize; ndim - a_shape.len()];
    in_shape.extend_from_slice(&a_shape);
    let mut full_reps = vec![1usize; ndim - reps.len()];
    full_reps.extend_from_slice(reps);

    let output_shape: Vec<usize> = in_shape
        .iter()
        .zip(&full_reps)
        .map(|(&dim, &rep)| dim * rep)
        .collect();

    // Row-major flattened copy of the input; also serves as the emptiness check.
    let input_vec = array.to_vec();
    let first_elem = input_vec
        .first()
        .cloned()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Cannot tile an empty array".into()))?;

    let mut result = Array::full(&output_shape, first_elem);
    let result_vec = result
        .array_mut()
        .as_slice_mut()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to get mutable slice".into()))?;

    for (flat, item) in result_vec.iter_mut().enumerate() {
        // Decompose the row-major output index axis by axis (last axis fastest),
        // wrap each coordinate into the input's extent, and re-linearise it
        // against the padded input shape. `out_dim` is non-zero here because the
        // loop body only runs when the output is non-empty, and `in_dim` is
        // non-zero because the input was verified to be non-empty.
        let mut rem = flat;
        let mut src = 0usize;
        let mut src_stride = 1usize;
        for (&out_dim, &in_dim) in output_shape.iter().zip(&in_shape).rev() {
            let coord = rem % out_dim;
            rem /= out_dim;
            src += (coord % in_dim) * src_stride;
            src_stride *= in_dim;
        }
        *item = input_vec[src].clone();
    }
    Ok(result)
}
/// Repeat the elements of `array` `repeats` times.
///
/// * `axis == None`: the array is flattened (via `to_vec`) and each element is
///   emitted `repeats` times in a row, producing a 1-D array.
/// * `axis == Some(ax)`: every slice along axis `ax` is duplicated `repeats`
///   times in place, so the output shape equals the input shape with dimension
///   `ax` multiplied by `repeats`.
///
/// # Errors
///
/// For `Some(ax)`: `DimensionMismatch` if `ax` is out of range, and
/// `InvalidOperation` if `array` is empty or the result buffer is not available
/// as a contiguous mutable slice.
pub fn repeat<T: Clone>(array: &Array<T>, repeats: usize, axis: Option<usize>) -> Result<Array<T>> {
    let a_shape = array.shape();

    let ax = match axis {
        None => {
            let flat: Vec<T> = array
                .to_vec()
                .into_iter()
                .flat_map(|value| std::iter::repeat(value).take(repeats))
                .collect();
            return Ok(Array::from_vec(flat));
        }
        Some(ax) => ax,
    };

    let ndim = a_shape.len();
    if ax >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            ax, ndim
        )));
    }

    let mut output_shape = a_shape.clone();
    output_shape[ax] *= repeats;

    let seed = array
        .array()
        .first()
        .cloned()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Cannot repeat an empty array".into()))?;
    let mut result = Array::full(&output_shape, seed);
    let out = result
        .array_mut()
        .as_slice_mut()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to get mutable slice".into()))?;

    let src = array.to_vec();
    if src.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Cannot repeat an empty array".into(),
        ));
    }

    // Treat both buffers as row-major [outer, axis, inner] blocks.
    let outer: usize = a_shape[..ax].iter().product();
    let axis_len = a_shape[ax];
    let inner: usize = a_shape[ax + 1..].iter().product();
    let out_axis_len = output_shape[ax];

    for o in 0..outer {
        let src_base = o * axis_len * inner;
        let dst_base = o * out_axis_len * inner;
        for k in 0..axis_len {
            let src_row = src_base + k * inner;
            for r in 0..repeats {
                let dst_row = dst_base + (k * repeats + r) * inner;
                for i in 0..inner {
                    out[dst_row + i] = src[src_row + i].clone();
                }
            }
        }
    }
    Ok(result)
}
/// Join a sequence of arrays along one or more existing axes.
///
/// `axis` accepts anything convertible into `AxisArg`. A single axis joins
/// every array along that axis. Multiple axes are delegated to
/// `concatenate_multiple_axes`.
///
/// # Errors
///
/// Returns `InvalidOperation` when `arrays` is empty; otherwise propagates the
/// helper's error.
pub fn concatenate<T: Clone>(arrays: &[&Array<T>], axis: impl Into<AxisArg>) -> Result<Array<T>> {
    // The helpers index `arrays[0]`, so reject an empty slice up front.
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation("No arrays to concatenate".into()));
    }

    let axis_spec: AxisArg = axis.into();
    match axis_spec {
        AxisArg::Multiple(axes) => concatenate_multiple_axes(arrays, &axes),
        AxisArg::Single(ax) => concatenate_single_axis(arrays, ax),
    }
}
/// Concatenate `arrays` along the single existing axis `axis`.
///
/// Every input must have the same rank as the first array and identical extents
/// on every axis except `axis`.
///
/// # Errors
///
/// - `InvalidOperation` if `arrays` is empty, or if the underlying ndarray
///   concatenation fails.
/// - `DimensionMismatch` if `axis` is out of range for the first array.
/// - `ShapeMismatch` if any array's rank or non-`axis` extents differ from the
///   first array's.
fn concatenate_single_axis<T: Clone>(arrays: &[&Array<T>], axis: usize) -> Result<Array<T>> {
    let first = arrays
        .first()
        .ok_or_else(|| NumRs2Error::InvalidOperation("No arrays to concatenate".into()))?;
    let first_shape = first.shape();
    if axis >= first_shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis,
            first_shape.len()
        )));
    }

    for arr in arrays.iter().skip(1) {
        let shape = arr.shape();
        // Same rank, and equal extents on every axis other than `axis`.
        let compatible = shape.len() == first_shape.len()
            && first_shape
                .iter()
                .zip(&shape)
                .enumerate()
                .all(|(j, (s1, s2))| j == axis || s1 == s2);
        if !compatible {
            return Err(NumRs2Error::ShapeMismatch {
                expected: first_shape.clone(),
                actual: shape,
            });
        }
    }

    let views: Vec<_> = arrays.iter().map(|arr| arr.array().view()).collect();
    let result = scirs2_core::ndarray::concatenate(Axis(axis), &views).map_err(|e| {
        NumRs2Error::InvalidOperation(format!("Failed to concatenate arrays: {}", e))
    })?;
    Ok(Array::from_ndarray(result))
}
/// Concatenate along several axes in sequence.
///
/// All of `arrays` are first concatenated along `axes[0]`. The running result is
/// then concatenated with `arrays[1]` along each remaining axis, in order.
///
/// NOTE(review): only `arrays[1]` participates in the later axes, and
/// `arrays[2..]` are used only for the first axis. Confirm that these are the
/// intended semantics.
///
/// # Errors
///
/// - `InvalidOperation` if `axes` is empty, or if more than one axis is given
///   but fewer than two arrays are supplied.
/// - Any error from the single-axis concatenation.
fn concatenate_multiple_axes<T: Clone>(arrays: &[&Array<T>], axes: &[usize]) -> Result<Array<T>> {
    let (&first_axis, rest_axes) = axes.split_first().ok_or_else(|| {
        NumRs2Error::InvalidOperation("No axes provided for concatenation".into())
    })?;

    let mut result = concatenate_single_axis(arrays, first_axis)?;
    if rest_axes.is_empty() {
        return Ok(result);
    }

    // Previously `arrays[1]` panicked when only one array was supplied.
    let second = arrays.get(1).ok_or_else(|| {
        NumRs2Error::InvalidOperation(
            "Concatenating along multiple axes requires at least two arrays".into(),
        )
    })?;
    for &axis in rest_axes {
        result = concatenate_single_axis(&[&result, *second], axis)?;
    }
    Ok(result)
}
/// Join equally-shaped arrays along a new axis inserted at position `axis`.
///
/// Each input is reshaped to insert a length-1 axis at `axis`, and the results
/// are concatenated along that axis. `axis` may equal the input rank, in which
/// case the new axis is appended.
///
/// # Errors
///
/// - `InvalidOperation` if `arrays` is empty.
/// - `DimensionMismatch` if `axis` exceeds the input rank.
/// - `ShapeMismatch` if any array's shape differs from the first.
/// - Any error from `concatenate`.
pub fn stack<T: Clone>(arrays: &[&Array<T>], axis: usize) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation("No arrays to stack".into()));
    }
    let first_shape = arrays[0].shape();
    // `axis == ndim` is valid: the new axis goes at the end.
    if axis > first_shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis,
            first_shape.len()
        )));
    }
    for arr in arrays.iter().skip(1) {
        let shape = arr.shape();
        if shape != first_shape {
            return Err(NumRs2Error::ShapeMismatch {
                expected: first_shape.clone(),
                actual: shape,
            });
        }
    }

    // All inputs share `first_shape`, so the expanded shape is computed once.
    let mut expanded_shape = first_shape;
    expanded_shape.insert(axis, 1);

    let expanded: Vec<Array<T>> = arrays
        .iter()
        .map(|arr| arr.reshape(&expanded_shape))
        .collect();
    let expanded_refs: Vec<&Array<T>> = expanded.iter().collect();
    concatenate(&expanded_refs, axis)
}
/// Assemble an array from a nested structure of blocks.
///
/// Each inner `Vec` of `blocks` is one row. Within a row, arrays with fewer
/// dimensions than the row's highest-rank array are promoted:
/// - 1-D arrays gain trailing length-1 axes, so a length-`n` vector becomes
///   `[n, 1]` in a 2-D row;
/// - all other lower-rank arrays gain leading length-1 axes.
///
/// The promoted arrays in each row are concatenated along their last axis, and
/// the resulting rows are then concatenated along axis 0.
///
/// # Errors
///
/// - `InvalidOperation` if `blocks` or any row is empty, or if rows end up with
///   differing ranks.
/// - `DimensionMismatch` if a row of two or more arrays consists only of
///   0-dimensional arrays, since there is no axis to join them along.
/// - Any error from `concatenate`, such as incompatible extents.
pub fn block<T: Clone>(blocks: &[Vec<&Array<T>>]) -> Result<Array<T>> {
    if blocks.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Empty block structure".into(),
        ));
    }

    // Pass 1: promote every array in a row to that row's maximum rank.
    let mut processed_rows = Vec::with_capacity(blocks.len());
    for (row_idx, row) in blocks.iter().enumerate() {
        if row.is_empty() {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Empty row at index {} in block structure",
                row_idx
            )));
        }
        let mut processed_row = Vec::with_capacity(row.len());
        let max_ndim = row.iter().map(|arr| arr.ndim()).max().unwrap_or(1);
        for arr in row.iter() {
            let arr_ndim = arr.ndim();
            if arr_ndim < max_ndim {
                let mut new_shape = arr.shape().to_vec();
                while new_shape.len() < max_ndim {
                    if arr_ndim == 1 {
                        // Vectors are promoted into columns (trailing axes).
                        new_shape.push(1);
                    } else {
                        new_shape.insert(0, 1);
                    }
                }
                processed_row.push(arr.reshape(&new_shape));
            } else {
                processed_row.push((*arr).clone());
            }
        }
        processed_rows.push(processed_row);
    }

    // Pass 2: join each row horizontally along its last axis.
    let mut rows_result = Vec::with_capacity(processed_rows.len());
    for row in &processed_rows {
        let ndim = row[0].ndim();
        if !row.iter().all(|arr| arr.ndim() == ndim) {
            return Err(NumRs2Error::InvalidOperation(
                "Arrays in each row must have the same number of dimensions".into(),
            ));
        }
        if row.len() == 1 {
            rows_result.push(row[0].clone());
            continue;
        }
        let row_refs: Vec<&Array<T>> = row.iter().collect();
        // `ndim - 1` underflowed (panicking in debug builds) for rows of 0-d arrays.
        let axis = ndim.checked_sub(1).ok_or_else(|| {
            NumRs2Error::DimensionMismatch(
                "Cannot concatenate 0-dimensional arrays within a block row".into(),
            )
        })?;
        let concatenated_row = concatenate(&row_refs, axis)?;
        rows_result.push(concatenated_row);
    }

    // Pass 3: stack the rows vertically along axis 0.
    if rows_result.len() == 1 {
        return Ok(rows_result[0].clone());
    }
    let row_ndim = rows_result[0].ndim();
    if !rows_result.iter().all(|arr| arr.ndim() == row_ndim) {
        return Err(NumRs2Error::InvalidOperation(
            "All rows must have the same number of dimensions after processing".into(),
        ));
    }
    let row_refs: Vec<&Array<T>> = rows_result.iter().collect();
    concatenate(&row_refs, 0)
}
/// Concatenate arrays row-wise.
///
/// When there are at least two arrays and every one is 1-D, they are joined
/// end-to-end into a 1-D array. Otherwise each 1-D input is reshaped into a
/// `[1, n]` row and everything is concatenated along axis 0. In particular, a
/// single 1-D input comes back with shape `[1, n]`.
///
/// # Errors
///
/// Propagates any error from `concatenate`, including the error for an empty
/// `arrays` slice.
pub fn r_<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    let join_flat = arrays.len() > 1 && arrays.iter().all(|arr| arr.ndim() == 1);
    if join_flat {
        return concatenate(arrays, 0);
    }

    let as_rows: Vec<Array<T>> = arrays
        .iter()
        .map(|&arr| match arr.ndim() {
            1 => arr.reshape(&[1, arr.size()]),
            _ => arr.clone(),
        })
        .collect();
    let row_refs: Vec<&Array<T>> = as_rows.iter().collect();
    concatenate(&row_refs, 0)
}
/// Concatenate arrays column-wise.
///
/// Each 1-D input of length `n` is reshaped into an `[n, 1]` column. Other
/// inputs are used as-is, and everything is concatenated along axis 1.
///
/// # Errors
///
/// Propagates any error from `concatenate`, for example an empty `arrays`
/// slice or inputs without an axis 1.
pub fn c_<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    let mut columns: Vec<Array<T>> = Vec::with_capacity(arrays.len());
    for &arr in arrays {
        let column = if arr.ndim() == 1 {
            arr.reshape(&[arr.size(), 1])
        } else {
            arr.clone()
        };
        columns.push(column);
    }
    let column_refs: Vec<&Array<T>> = columns.iter().collect();
    concatenate(&column_refs, 1)
}
/// Return a copy of `array` that satisfies the requested layout flags.
///
/// - `C_LAYOUT`: the result is converted with `to_c_layout` if not already
///   C-contiguous.
/// - `F_LAYOUT`: the result is converted with `to_f_layout` if not already
///   F-contiguous. The F conversion is applied after the C conversion, so when
///   both are requested and both are missing, the F conversion is applied last.
/// - `CONTIGUOUS`: if the array is not contiguous and neither explicit layout
///   was requested, the result is converted with `to_c_layout`.
///
/// `OWNDATA` and `WRITEABLE` are accepted but do not affect the result here.
/// With no requirements (or `None`), a plain clone is returned.
pub fn require<T: Clone>(
    array: &Array<T>,
    requirements: Option<super::core::ArrayRequirements>,
) -> Result<Array<T>> {
    use super::core::ArrayRequirements;

    let flags = requirements.unwrap_or(ArrayRequirements::empty());
    if flags.is_empty() {
        return Ok(array.clone());
    }

    let wants_c = flags.contains(ArrayRequirements::C_LAYOUT);
    let wants_f = flags.contains(ArrayRequirements::F_LAYOUT);
    let wants_contiguous = flags.contains(ArrayRequirements::CONTIGUOUS);

    // A layout is "missing" only if it was requested and the array lacks it.
    let c_missing = wants_c && !array.is_c_contiguous();
    let f_missing = wants_f && !array.is_f_contiguous();
    let contiguity_missing = wants_contiguous && !array.is_contiguous();

    let mut out = array.clone();
    if c_missing {
        out = out.to_c_layout();
    }
    if f_missing {
        out = out.to_f_layout();
    }
    // Plain contiguity defaults to C order unless an explicit layout was requested.
    if contiguity_missing && !(wants_c || wants_f) {
        out = out.to_c_layout();
    }
    Ok(out)
}