use crate::shape::{coord_to_flat, flat_to_coord, validate_shape};
use crate::{MattenError, Tensor};
/// Checks that `new_shape` describes exactly `len` elements.
///
/// `new_shape` is first handed to `validate_shape` (tagged with `operation`); its `Ok`
/// value is taken as the number of elements the new shape requires.
///
/// On success, returns that element count, which equals `len`.
///
/// # Errors
///
/// Propagates any error from `validate_shape`. Returns [`MattenError::Shape`] when the
/// required element count differs from `len`.
fn validate_reshape(
    len: usize,
    new_shape: &[usize],
    operation: &'static str,
) -> Result<usize, MattenError> {
    let required = validate_shape(new_shape, operation)?;
    if required == len {
        return Ok(required);
    }
    let message = format!(
        "cannot reshape tensor with {len} elements into shape {new_shape:?} \
         requiring {required} elements"
    );
    Err(MattenError::Shape { operation, message })
}
/// Returns a copy of `t` reinterpreted with shape `new_shape`.
///
/// The element buffer is cloned unchanged (same element order); only the shape differs.
///
/// # Errors
///
/// Returns the error produced by reshape validation (operation name `"reshape"`) when
/// `new_shape` is rejected or does not describe exactly `t.len()` elements.
pub(crate) fn try_reshape_impl(t: &Tensor, new_shape: &[usize]) -> Result<Tensor, MattenError> {
    validate_reshape(t.len(), new_shape, "reshape").map(|_element_count| Tensor {
        shape: new_shape.to_vec(),
        data: t.data.clone(),
    })
}
/// Returns a new tensor whose axes are reordered according to `permutation`.
///
/// Axis `i` of the result is axis `permutation[i]` of `t`. The result shape is therefore
/// `[t.shape()[permutation[0]], t.shape()[permutation[1]], ...]`. Every element is copied
/// into a freshly allocated buffer.
///
/// # Panics
///
/// The caller must pass a permutation of `0..t.shape().len()` (each axis exactly once).
/// Debug builds check this with `debug_assert!`. In release builds:
/// - an out-of-range axis panics on indexing;
/// - a duplicated axis is not detected here and may produce a tensor whose data does not
///   match its shape.
pub(crate) fn permute_axes(t: &Tensor, permutation: &[usize]) -> Tensor {
    let src_shape = t.shape();
    debug_assert!(
        is_axis_permutation(permutation, src_shape.len()),
        "permute_axes: {permutation:?} is not a permutation of the axes of a rank-{} tensor",
        src_shape.len()
    );
    let result_shape: Vec<usize> = permutation.iter().map(|&p| src_shape[p]).collect();
    let len = t.len();
    let mut result_data = vec![0.0f64; len];
    // One scratch buffer for the permuted coordinate, reused across all elements instead of
    // allocating a new Vec per element.
    let mut result_coord = vec![0usize; permutation.len()];
    for src_flat in 0..len {
        let src_coord = flat_to_coord(src_flat, src_shape);
        for (dst, &axis) in result_coord.iter_mut().zip(permutation) {
            *dst = src_coord[axis];
        }
        let dst_flat = coord_to_flat(&result_coord, &result_shape)
            .expect("permuted coordinate is always valid by construction");
        result_data[dst_flat] = t.data[src_flat];
    }
    Tensor {
        data: result_data,
        shape: result_shape,
    }
}

/// Reports whether `axes` lists every axis in `0..ndim` exactly once.
fn is_axis_permutation(axes: &[usize], ndim: usize) -> bool {
    if axes.len() != ndim {
        return false;
    }
    let mut seen = vec![false; ndim];
    // `mem::replace` marks the axis as seen and yields whether it was already seen.
    axes.iter()
        .all(|&ax| ax < ndim && !std::mem::replace(&mut seen[ax], true))
}
/// Checks that both `axis1` and `axis2` are valid axis indices for a rank-`ndim` tensor.
///
/// # Errors
///
/// Returns [`MattenError::Shape`] tagged with `operation` for the first offending axis
/// (`axis1` is checked before `axis2`) that is `>= ndim`.
pub(crate) fn validate_axes(
    axis1: usize,
    axis2: usize,
    ndim: usize,
    operation: &'static str,
) -> Result<(), MattenError> {
    let offending = [axis1, axis2].iter().copied().find(|&axis| axis >= ndim);
    match offending {
        None => Ok(()),
        Some(ax) => Err(MattenError::Shape {
            operation,
            message: format!("axis {ax} is out of range for rank-{ndim} tensor"),
        }),
    }
}