use crate::limits::MattenLimits;
use crate::{MattenError, Tensor};
/// Ensures at least one tensor was supplied to a multi-tensor `operation`.
///
/// Returns `MattenError::InvalidArgument` (tagged with `operation` and the
/// argument name `"tensors"`) when the slice is empty, and `Ok(())` otherwise.
fn require_non_empty(tensors: &[&Tensor], operation: &'static str) -> Result<(), MattenError> {
    match tensors.first() {
        Some(_) => Ok(()),
        None => Err(MattenError::InvalidArgument {
            operation,
            argument: "tensors",
            message: "at least one tensor is required".to_string(),
        }),
    }
}
/// Refuses to compose shapes from dynamic tensors.
///
/// With the `dynamic` feature enabled, returns `MattenError::Unsupported` for
/// `operation` if any input reports `is_dynamic()`. Without the feature, the
/// check compiles away and this always returns `Ok(())`.
fn reject_dynamic(tensors: &[&Tensor], operation: &'static str) -> Result<(), MattenError> {
    #[cfg(feature = "dynamic")]
    {
        let has_dynamic = tensors.iter().any(|tensor| tensor.is_dynamic());
        if has_dynamic {
            return Err(MattenError::Unsupported {
                operation,
                message:
                    "dynamic tensors must be converted with try_numeric() before shape composition"
                        .to_string(),
            });
        }
    }
    // Without the feature both arguments are otherwise unused; touch them to
    // keep the signature identical across configurations without warnings.
    #[cfg(not(feature = "dynamic"))]
    let _ = (tensors, operation);
    Ok(())
}
impl Tensor {
    /// Concatenates `tensors` along the existing axis `axis`.
    ///
    /// Panicking counterpart of [`Tensor::try_concatenate`]; see that method for
    /// the shape rules and the layout of the result.
    ///
    /// # Panics
    ///
    /// Panics with the error's `Display` text in every case where
    /// [`Tensor::try_concatenate`] returns `Err`.
    #[must_use]
    pub fn concatenate(tensors: &[&Tensor], axis: usize) -> Tensor {
        Tensor::try_concatenate(tensors, axis).unwrap_or_else(|e| panic!("{e}"))
    }

    /// Concatenates `tensors` along the existing axis `axis`.
    ///
    /// All tensors must have the same rank. `axis` must satisfy `axis < rank`,
    /// and every axis other than `axis` must have identical sizes. The output
    /// shape is the first tensor's shape with `shape[axis]` replaced by the sum
    /// of the inputs' sizes along `axis`. The inputs' slabs appear in the order
    /// the tensors were passed.
    ///
    /// # Errors
    ///
    /// - [`MattenError::InvalidArgument`] if `tensors` is empty.
    /// - [`MattenError::Unsupported`] if any input is dynamic (only when the
    ///   `dynamic` feature is enabled).
    /// - [`MattenError::Shape`] if `axis` is out of range, the ranks differ, or a
    ///   non-concatenation axis has mismatched sizes.
    /// - [`MattenError::Allocation`] if the summed size along `axis` overflows
    ///   `usize`.
    /// - Any error produced by `MattenLimits::check_shape` for the output shape.
    pub fn try_concatenate(tensors: &[&Tensor], axis: usize) -> Result<Tensor, MattenError> {
        require_non_empty(tensors, "concatenate")?;
        reject_dynamic(tensors, "concatenate")?;
        let first = tensors[0];
        let rank = first.shape.len();
        if axis >= rank {
            return Err(MattenError::Shape {
                operation: "concatenate",
                message: format!(
                    "axis {axis} is out of range for concatenate on rank-{rank} tensors (valid 0..{rank})"
                ),
            });
        }
        // Tensor 0 trivially matches itself, so validation starts at index 1.
        // `enumerate` runs before `skip`, so `i` is still the caller's index.
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.shape.len() != rank {
                return Err(MattenError::Shape {
                    operation: "concatenate",
                    message: format!(
                        "tensor {i} has rank {} but tensor 0 has rank {rank}; \
                         concatenate requires equal ranks",
                        t.shape.len()
                    ),
                });
            }
            for (ax, (&d, &d0)) in t.shape.iter().zip(&first.shape).enumerate() {
                if ax != axis && d != d0 {
                    return Err(MattenError::Shape {
                        operation: "concatenate",
                        message: format!(
                            "tensor {i} has size {d} at axis {ax} but tensor 0 has {d0}; \
                             all non-concatenation axes must match"
                        ),
                    });
                }
            }
        }
        let axis_total = tensors
            .iter()
            .try_fold(0usize, |acc, t| acc.checked_add(t.shape[axis]))
            .ok_or_else(|| MattenError::Allocation {
                requested_elements: usize::MAX,
                message: "concatenated axis size overflowed".to_string(),
            })?;
        let mut out_shape = first.shape.clone();
        out_shape[axis] = axis_total;
        let total = MattenLimits::default().check_shape(&out_shape, "concatenate")?;
        let mut data = Vec::with_capacity(total);
        // An empty output needs no copying. Skipping the loop also avoids two
        // problems when some dimension is zero (e.g. shape `[huge, 0]`):
        // spinning `outer` times over empty slices, and computing dimension
        // products that `total` does not bound and that could overflow.
        // When `total > 0`, every output dimension is nonzero, so `inner` and
        // `outer` both divide `total`.
        if total > 0 {
            // The copy treats storage as row-major. For each combination of the
            // leading (`outer`) indices, each input contributes one contiguous
            // run of `shape[axis] * inner` elements.
            let inner: usize = first.shape[axis + 1..].iter().product();
            let outer: usize = first.shape[..axis].iter().product();
            for o in 0..outer {
                for t in tensors {
                    let block = t.shape[axis] * inner;
                    let start = o * block;
                    data.extend_from_slice(&t.data[start..start + block]);
                }
            }
        }
        Ok(Tensor {
            data,
            shape: out_shape,
            #[cfg(feature = "dynamic")]
            dynamic: None,
        })
    }

    /// Stacks `tensors` along a new axis inserted at position `axis`.
    ///
    /// Panicking counterpart of [`Tensor::try_stack`]; see that method for the
    /// shape rules and the layout of the result.
    ///
    /// # Panics
    ///
    /// Panics with the error's `Display` text in every case where
    /// [`Tensor::try_stack`] returns `Err`.
    #[must_use]
    pub fn stack(tensors: &[&Tensor], axis: usize) -> Tensor {
        Tensor::try_stack(tensors, axis).unwrap_or_else(|e| panic!("{e}"))
    }

    /// Stacks `tensors` along a new axis inserted at position `axis`.
    ///
    /// All tensors must have exactly the same shape. `axis` may range over
    /// `0..=rank`. The output shape is the common shape with
    /// `tensors.len()` inserted at index `axis`. Along the new axis, the
    /// inputs appear in the order they were passed.
    ///
    /// # Errors
    ///
    /// - [`MattenError::InvalidArgument`] if `tensors` is empty.
    /// - [`MattenError::Unsupported`] if any input is dynamic (only when the
    ///   `dynamic` feature is enabled).
    /// - [`MattenError::Shape`] if `axis > rank` or the shapes differ.
    /// - Any error produced by `MattenLimits::check_shape` for the output shape.
    pub fn try_stack(tensors: &[&Tensor], axis: usize) -> Result<Tensor, MattenError> {
        require_non_empty(tensors, "stack")?;
        reject_dynamic(tensors, "stack")?;
        let first = tensors[0];
        let rank = first.shape.len();
        if axis > rank {
            return Err(MattenError::Shape {
                operation: "stack",
                message: format!(
                    "axis {axis} is out of range for stack on rank-{rank} tensors (valid 0..={rank})"
                ),
            });
        }
        // Tensor 0 trivially matches itself; indices still refer to `tensors`.
        for (i, t) in tensors.iter().enumerate().skip(1) {
            if t.shape != first.shape {
                return Err(MattenError::Shape {
                    operation: "stack",
                    message: format!(
                        "tensor {i} has shape {:?} but tensor 0 has shape {:?}; \
                         stack requires identical shapes",
                        t.shape, first.shape
                    ),
                });
            }
        }
        let n = tensors.len();
        let mut out_shape = Vec::with_capacity(rank + 1);
        out_shape.extend_from_slice(&first.shape[..axis]);
        out_shape.push(n);
        out_shape.extend_from_slice(&first.shape[axis..]);
        let total = MattenLimits::default().check_shape(&out_shape, "stack")?;
        let mut data = Vec::with_capacity(total);
        // As in `try_concatenate`, an empty output skips the copy entirely.
        // When `total > 0`, every dimension is nonzero, so `inner` and `outer`
        // are bounded by `total`.
        if total > 0 {
            // Row-major copy: for each leading index, append one contiguous run
            // of `inner` elements from each input in turn.
            let inner: usize = first.shape[axis..].iter().product();
            let outer: usize = first.shape[..axis].iter().product();
            for o in 0..outer {
                let start = o * inner;
                for t in tensors {
                    data.extend_from_slice(&t.data[start..start + inner]);
                }
            }
        }
        Ok(Tensor {
            data,
            shape: out_shape,
            #[cfg(feature = "dynamic")]
            dynamic: None,
        })
    }
}