use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::Axis;
/// Axis selection accepted by the joining routines in this module.
///
/// Values are normally produced through the `From` conversions, so callers can
/// pass a bare `usize`, a slice of axes, or an owned `Vec` of axes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AxisArg {
    /// Exactly one axis index.
    Single(usize),
    /// Several axis indices, kept in the order they were supplied.
    Multiple(Vec<usize>),
}
impl From<usize> for AxisArg {
fn from(axis: usize) -> Self {
AxisArg::Single(axis)
}
}
impl From<&[usize]> for AxisArg {
fn from(axes: &[usize]) -> Self {
AxisArg::Multiple(axes.to_vec())
}
}
impl From<Vec<usize>> for AxisArg {
fn from(axes: Vec<usize>) -> Self {
AxisArg::Multiple(axes)
}
}
/// Joins `arrays` along the axis (or axes) described by `axis`.
///
/// `axis` may be anything convertible into [`AxisArg`]: a single `usize`, a
/// slice of axes, or a `Vec` of axes.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` when `arrays` is empty; otherwise
/// forwards whatever error the single-axis or multi-axis routine reports.
pub fn concatenate<T: Clone>(arrays: &[&Array<T>], axis: impl Into<AxisArg>) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "No arrays to concatenate".into(),
        ));
    }

    // The conversion happens only after the emptiness check, as before.
    let selection: AxisArg = axis.into();
    match selection {
        AxisArg::Multiple(axes) => concatenate_multiple_axes(arrays, axes.as_slice()),
        AxisArg::Single(single) => concatenate_single_axis(arrays, single),
    }
}
/// Joins `arrays` along one existing `axis`.
///
/// The first array is the reference: every other array must have the same
/// number of dimensions and the same extent on every axis except `axis`.
/// Callers must pass at least one array (the first element is indexed
/// directly).
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `axis` is not a valid axis of the
///   first array.
/// * `NumRs2Error::ShapeMismatch` if any array is incompatible with the first.
/// * `NumRs2Error::InvalidOperation` if the underlying ndarray concatenation
///   fails.
fn concatenate_single_axis<T: Clone>(arrays: &[&Array<T>], axis: usize) -> Result<Array<T>> {
    let first_shape = arrays[0].shape();
    if axis >= first_shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis,
            first_shape.len()
        )));
    }

    for arr in arrays.iter().skip(1) {
        let shape = arr.shape();
        // Same rank, and identical extents everywhere except the join axis.
        let compatible = shape.len() == first_shape.len()
            && first_shape
                .iter()
                .zip(shape.iter())
                .enumerate()
                .all(|(dim, (expected, actual))| dim == axis || expected == actual);
        if !compatible {
            return Err(NumRs2Error::ShapeMismatch {
                expected: first_shape.clone(),
                actual: shape,
            });
        }
    }

    // Building views is infallible, so no intermediate `Result` is needed.
    let views: Vec<_> = arrays.iter().map(|arr| arr.array().view()).collect();
    let result = scirs2_core::ndarray::concatenate(Axis(axis), &views).map_err(|e| {
        NumRs2Error::InvalidOperation(format!("Failed to concatenate arrays: {}", e))
    })?;
    Ok(Array::from_ndarray(result))
}
/// Applies concatenation along each axis in `axes`, in order.
///
/// The first axis joins all of `arrays`. Every later axis joins the running
/// result with the second input array only; this mirrors the existing
/// simplified behaviour and is not a general multi-axis tiling.
/// Callers must pass at least one array.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `axes` is empty.
/// * `NumRs2Error::InvalidOperation` if more than one axis is given but fewer
///   than two arrays were supplied (this used to panic on an out-of-bounds
///   index).
/// * Any error reported by the single-axis concatenation.
fn concatenate_multiple_axes<T: Clone>(arrays: &[&Array<T>], axes: &[usize]) -> Result<Array<T>> {
    let (first_axis, remaining_axes) = match axes.split_first() {
        Some((&first, rest)) => (first, rest),
        None => {
            return Err(NumRs2Error::InvalidOperation(
                "No axes provided for concatenation".into(),
            ));
        }
    };

    // The first pass consumes every input; no placeholder clone is needed.
    let mut result = concatenate_single_axis(arrays, first_axis)?;

    for &axis in remaining_axes {
        // Later passes need a second operand; report its absence instead of panicking.
        let second = arrays.get(1).copied().ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Concatenation along multiple axes requires at least two arrays".into(),
            )
        })?;
        result = concatenate_single_axis(&[&result, second], axis)?;
    }
    Ok(result)
}
/// Stacks equally-shaped arrays along a new axis inserted at position `axis`.
///
/// Each input is reshaped to carry a unit-length dimension at `axis`, and the
/// reshaped arrays are then concatenated along that axis.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `arrays` is empty.
/// * `NumRs2Error::DimensionMismatch` if `axis` is greater than the number of
///   dimensions of the first array.
/// * `NumRs2Error::ShapeMismatch` if any array's shape differs from the first.
/// * Any error reported by [`concatenate`].
pub fn stack<T: Clone>(arrays: &[&Array<T>], axis: usize) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation("No arrays to stack".into()));
    }

    let first_shape = arrays[0].shape();
    if axis > first_shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis,
            first_shape.len()
        )));
    }

    for arr in arrays.iter().skip(1) {
        let shape = arr.shape();
        if shape != first_shape {
            return Err(NumRs2Error::ShapeMismatch {
                expected: first_shape.clone(),
                actual: shape,
            });
        }
    }

    // All inputs share one shape, so the expanded shape is computed once.
    let mut expanded_shape = first_shape;
    expanded_shape.insert(axis, 1);

    let expanded: Vec<Array<T>> = arrays
        .iter()
        .map(|arr| arr.reshape(&expanded_shape))
        .collect();
    let expanded_refs: Vec<&Array<T>> = expanded.iter().collect();
    concatenate(&expanded_refs, axis)
}
/// Assembles an array from a two-level nested list of blocks.
///
/// Within each row, arrays with fewer dimensions than the row's maximum are
/// padded with unit-length dimensions: 1-D arrays get trailing ones, all other
/// arrays get leading ones. The arrays of a row are joined along their last
/// axis, and the resulting rows are joined along axis 0.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `blocks` is empty, if any row is
/// empty, if a row with several zero-dimensional arrays has no axis to join
/// on, or if dimension counts disagree after padding. Errors from
/// [`concatenate`] are forwarded.
pub fn block<T: Clone>(blocks: &[Vec<&Array<T>>]) -> Result<Array<T>> {
    if blocks.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Empty block structure".into(),
        ));
    }

    // Pass 1: validate every row and bring its arrays to a common rank.
    let mut processed_rows: Vec<Vec<Array<T>>> = Vec::with_capacity(blocks.len());
    for (row_idx, row) in blocks.iter().enumerate() {
        if row.is_empty() {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Empty row at index {} in block structure",
                row_idx
            )));
        }

        let max_ndim = row.iter().map(|arr| arr.ndim()).max().unwrap_or(1);
        let mut processed_row = Vec::with_capacity(row.len());
        for arr in row.iter() {
            let arr_ndim = arr.ndim();
            if arr_ndim >= max_ndim {
                processed_row.push((*arr).clone());
                continue;
            }
            let mut new_shape = arr.shape().to_vec();
            while new_shape.len() < max_ndim {
                if arr_ndim == 1 {
                    // 1-D inputs grow trailing unit dimensions.
                    new_shape.push(1);
                } else {
                    new_shape.insert(0, 1);
                }
            }
            processed_row.push(arr.reshape(&new_shape));
        }
        processed_rows.push(processed_row);
    }

    // Pass 2: join each row along its last axis; rows are consumed, not cloned.
    let mut rows_result: Vec<Array<T>> = Vec::with_capacity(processed_rows.len());
    for row in processed_rows {
        let ndim = row[0].ndim();
        if row.iter().any(|arr| arr.ndim() != ndim) {
            return Err(NumRs2Error::InvalidOperation(
                "Arrays in each row must have the same number of dimensions".into(),
            ));
        }
        if row.len() == 1 {
            // Moves the only element out of the row.
            rows_result.extend(row);
            continue;
        }
        // Zero-dimensional arrays have no last axis; `ndim - 1` would underflow.
        let axis = ndim.checked_sub(1).ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "Cannot concatenate zero-dimensional arrays within a block row".into(),
            )
        })?;
        let row_refs: Vec<&Array<T>> = row.iter().collect();
        rows_result.push(concatenate(&row_refs, axis)?);
    }

    if rows_result.len() == 1 {
        return Ok(rows_result.swap_remove(0));
    }

    let row_ndim = rows_result[0].ndim();
    if rows_result.iter().any(|arr| arr.ndim() != row_ndim) {
        return Err(NumRs2Error::InvalidOperation(
            "All rows must have the same number of dimensions after processing".into(),
        ));
    }
    let row_refs: Vec<&Array<T>> = rows_result.iter().collect();
    concatenate(&row_refs, 0)
}
/// Joins `arrays` along axis 0 by delegating to [`concatenate`].
///
/// # Errors
///
/// Forwards any error reported by [`concatenate`].
pub fn r_<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    let first_axis: usize = 0;
    concatenate(arrays, first_axis)
}
/// Joins `arrays` along axis 1, treating 1-D arrays as columns.
///
/// Every 1-D array of length `n` is reshaped to `[n, 1]` before joining, also
/// when it is mixed with higher-dimensional arrays (previously such mixed
/// input failed with a shape error). Arrays that are not 1-D are used as they
/// are, without being copied.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `arrays` is empty; otherwise
/// forwards any error reported by [`concatenate`].
pub fn c_<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "No arrays to concatenate".into(),
        ));
    }

    // `Some(column)` for promoted 1-D inputs, `None` where the original is reused.
    let columns: Vec<Option<Array<T>>> = arrays
        .iter()
        .map(|arr| {
            if arr.ndim() == 1 {
                Some(arr.reshape(&[arr.shape()[0], 1]))
            } else {
                None
            }
        })
        .collect();

    let operands: Vec<&Array<T>> = columns
        .iter()
        .zip(arrays.iter())
        .map(|(column, &original)| column.as_ref().unwrap_or(original))
        .collect();
    concatenate(&operands, 1)
}
/// Stacks `arrays` along axis 0, treating 1-D arrays as rows.
///
/// Every 1-D array of length `n` is reshaped to `[1, n]`. All other arrays are
/// borrowed as they are instead of being deep-cloned.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `arrays` is empty; otherwise
/// forwards any error reported by [`concatenate`].
pub fn vstack<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation("No arrays to stack".into()));
    }

    // `Some(row)` for promoted 1-D inputs, `None` where the original is reused.
    let rows: Vec<Option<Array<T>>> = arrays
        .iter()
        .map(|arr| {
            if arr.ndim() == 1 {
                Some(arr.reshape(&[1, arr.shape()[0]]))
            } else {
                None
            }
        })
        .collect();

    let operands: Vec<&Array<T>> = rows
        .iter()
        .zip(arrays.iter())
        .map(|(row, &original)| row.as_ref().unwrap_or(original))
        .collect();
    concatenate(&operands, 0)
}
/// Stacks `arrays` horizontally.
///
/// When every input is 1-D the arrays are joined along axis 0; in all other
/// cases they are joined along axis 1.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `arrays` is empty; otherwise
/// forwards any error reported by [`concatenate`].
pub fn hstack<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation("No arrays to stack".into()));
    }

    let has_non_vector = arrays.iter().any(|arr| arr.ndim() != 1);
    let axis: usize = if has_non_vector { 1 } else { 0 };
    concatenate(arrays, axis)
}
/// Stacks `arrays` along axis 2 (depth).
///
/// 1-D arrays of length `n` are reshaped to `[1, n, 1]` and 2-D arrays of
/// shape `[r, c]` to `[r, c, 1]`. Arrays of any other dimensionality are
/// borrowed unchanged instead of being deep-cloned.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `arrays` is empty; otherwise
/// forwards any error reported by [`concatenate`].
pub fn dstack<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation("No arrays to stack".into()));
    }

    // `Some(volume)` for promoted 1-D/2-D inputs, `None` where the original is reused.
    let volumes: Vec<Option<Array<T>>> = arrays
        .iter()
        .map(|arr| {
            let shape = arr.shape();
            match arr.ndim() {
                1 => Some(arr.reshape(&[1, shape[0], 1])),
                2 => Some(arr.reshape(&[shape[0], shape[1], 1])),
                _ => None,
            }
        })
        .collect();

    let operands: Vec<&Array<T>> = volumes
        .iter()
        .zip(arrays.iter())
        .map(|(volume, &original)| volume.as_ref().unwrap_or(original))
        .collect();
    concatenate(&operands, 2)
}
/// Alias for [`vstack`]: forwards `arrays` unchanged and returns its result.
///
/// # Errors
///
/// Forwards any error reported by [`vstack`].
pub fn row_stack<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    vstack::<T>(arrays)
}
/// Stacks `arrays` as columns, joining along axis 1.
///
/// Every 1-D array of length `n` is reshaped to `[n, 1]` first, also when it
/// is mixed with higher-dimensional arrays (previously such mixed input failed
/// with a shape error). Arrays that are not 1-D are used as they are, without
/// being copied.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `arrays` is empty; otherwise
/// forwards any error reported by [`concatenate`].
pub fn column_stack<T: Clone>(arrays: &[&Array<T>]) -> Result<Array<T>> {
    if arrays.is_empty() {
        return Err(NumRs2Error::InvalidOperation("No arrays to stack".into()));
    }

    // `Some(column)` for promoted 1-D inputs, `None` where the original is reused.
    let columns: Vec<Option<Array<T>>> = arrays
        .iter()
        .map(|arr| {
            if arr.ndim() == 1 {
                Some(arr.reshape(&[arr.shape()[0], 1]))
            } else {
                None
            }
        })
        .collect();

    let operands: Vec<&Array<T>> = columns
        .iter()
        .zip(arrays.iter())
        .map(|(column, &original)| column.as_ref().unwrap_or(original))
        .collect();
    concatenate(&operands, 1)
}
/// Placeholder for building a block matrix from a textual description.
///
/// The description is ignored.
///
/// # Errors
///
/// Always returns `NumRs2Error::InvalidOperation`, pointing callers at
/// `bmat_from_arrays`.
pub fn bmat_from_string<T: Clone>(_description: &str) -> Result<Array<T>> {
    let message =
        String::from("String-based bmat not yet implemented - use bmat_from_arrays instead");
    Err(NumRs2Error::InvalidOperation(message))
}
/// Builds a matrix from a rectangular grid of 2-D blocks.
///
/// `obj` is a list of rows, each a list of blocks. Validation runs in this
/// order: non-empty grid, non-empty rows with only 2-D blocks, equal block
/// count per row, equal block height within each row, equal block width within
/// each column. The blocks of each row are then joined along axis 1 and the
/// rows along axis 0.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` describing the first validation
/// failure; errors from [`concatenate`] are forwarded.
pub fn bmat_from_arrays<T: Clone>(obj: &[Vec<&Array<T>>]) -> Result<Array<T>> {
    if obj.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Empty block matrix specification".to_string(),
        ));
    }

    // Every row must be non-empty and contain only 2-D blocks.
    for (row_idx, row) in obj.iter().enumerate() {
        if row.is_empty() {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Empty row {} in block matrix",
                row_idx
            )));
        }
        for (col_idx, &arr) in row.iter().enumerate() {
            let ndim = arr.ndim();
            if ndim == 2 {
                continue;
            }
            // Zero-dimensional input is reported as "Scalar", anything else as "<n>D array".
            let kind = if ndim == 0 {
                "Scalar".to_string()
            } else {
                format!("{}D array", ndim)
            };
            return Err(NumRs2Error::InvalidOperation(format!(
                "{} at position ({}, {}) - bmat requires 2D arrays",
                kind, row_idx, col_idx
            )));
        }
    }

    // The grid must be rectangular.
    let num_cols = obj[0].len();
    for (row_idx, row) in obj.iter().enumerate() {
        if row.len() != num_cols {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Row {} has {} blocks, expected {}",
                row_idx,
                row.len(),
                num_cols
            )));
        }
    }

    // Blocks in the same row must share a height.
    for (row_idx, row) in obj.iter().enumerate() {
        let expected_height = row[0].shape()[0];
        for (col_idx, &arr) in row.iter().enumerate() {
            let height = arr.shape()[0];
            if height != expected_height {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Block at ({}, {}) has height {}, but row {} requires height {}",
                    row_idx, col_idx, height, row_idx, expected_height
                )));
            }
        }
    }

    // Blocks in the same column must share a width.
    for col_idx in 0..num_cols {
        let expected_width = obj[0][col_idx].shape()[1];
        for (row_idx, row) in obj.iter().enumerate() {
            let width = row[col_idx].shape()[1];
            if width != expected_width {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Block at ({}, {}) has width {}, but column {} requires width {}",
                    row_idx, col_idx, width, col_idx, expected_width
                )));
            }
        }
    }

    let mut concatenated_rows: Vec<Array<T>> = Vec::with_capacity(obj.len());
    for row in obj {
        let joined = if row.len() == 1 {
            row[0].clone()
        } else {
            concatenate(row, 1)?
        };
        concatenated_rows.push(joined);
    }

    if concatenated_rows.len() == 1 {
        // Hand back the only row by value rather than cloning it.
        return Ok(concatenated_rows.swap_remove(0));
    }
    let row_refs: Vec<&Array<T>> = concatenated_rows.iter().collect();
    concatenate(&row_refs, 0)
}
/// Convenience entry point that forwards `obj` to [`bmat_from_arrays`].
///
/// # Errors
///
/// Forwards any error reported by [`bmat_from_arrays`].
pub fn bmat<T: Clone>(obj: &[Vec<&Array<T>>]) -> Result<Array<T>> {
    bmat_from_arrays::<T>(obj)
}