use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::Runtime;
use crate::tensor::Tensor;
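
/// Converts a possibly negative dimension index into a concrete index in `0..ndim`.
///
/// Negative values count from the end, so `normalize_dim(-1, 3)` yields `Some(2)`.
/// Returns `None` when `ndim` is zero or the index falls outside the valid range.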
#[inline]
pub fn normalize_dim(dim: isize, ndim: usize) -> Option<usize> {
if ndim == 0 {
return None;
}
let idx = if dim < 0 {
let adjusted = ndim as isize + dim;
if adjusted < 0 {
return None;
}
adjusted as usize
} else {
dim as usize
};
if idx < ndim { Some(idx) } else { None }
}
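
/// Normalizes a dimension index for `stack`, which inserts a new axis and therefore
/// accepts indices in `0..=ndim` (or `-(ndim + 1)..=-1` counting from the end).
///
/// For example, `normalize_stack_dim(2, 2)` yields `Some(2)` and
/// `normalize_stack_dim(-3, 2)` yields `Some(0)`.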
#[inline]
pub fn normalize_stack_dim(dim: isize, ndim: usize) -> Option<usize> {
let new_ndim = ndim + 1;
if dim < 0 {
let adjusted = new_ndim as isize + dim;
if adjusted < 0 || adjusted >= new_ndim as isize {
None
} else {
Some(adjusted as usize)
}
} else {
let idx = dim as usize;
if idx <= ndim { Some(idx) } else { None }
}
}
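
/// Precomputed parameters for a concatenation along `dim_idx`.
///
/// `cat_dim_total` is the summed size of the concatenation dimension, while
/// `outer_size` and `inner_size` are the products of the output dimensions before and
/// after `dim_idx` respectively, each clamped to at least 1.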
#[derive(Debug, Clone)]
pub struct CatParams {
pub dim_idx: usize,
pub dtype: DType,
pub cat_dim_total: usize,
pub out_shape: Vec<usize>,
pub outer_size: usize,
pub inner_size: usize,
}
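
/// Validates the inputs to a `cat` (concatenation) and derives its [`CatParams`].
///
/// Requires at least one non-scalar tensor; all tensors must share a dtype and rank and
/// match on every dimension except the concatenation dimension, whose sizes are summed.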
pub fn validate_cat<R: Runtime<DType = DType>>(
tensors: &[&Tensor<R>],
dim: isize,
) -> Result<CatParams> {
if tensors.is_empty() {
return Err(Error::InvalidArgument {
arg: "tensors",
reason: "cat requires at least one tensor".to_string(),
});
}
let first = tensors[0];
let dtype = first.dtype();
let ndim = first.ndim();
if ndim == 0 {
return Err(Error::InvalidArgument {
arg: "tensors",
reason: "cannot concatenate scalar tensors".to_string(),
});
}
let dim_idx = normalize_dim(dim, ndim).ok_or(Error::InvalidDimension { dim, ndim })?;
let mut cat_dim_total = first.shape()[dim_idx];
for &tensor in &tensors[1..] {
if tensor.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: tensor.dtype(),
});
}
if tensor.ndim() != ndim {
return Err(Error::ShapeMismatch {
expected: first.shape().to_vec(),
got: tensor.shape().to_vec(),
});
}
for (i, (&a, &b)) in first.shape().iter().zip(tensor.shape().iter()).enumerate() {
if i != dim_idx && a != b {
return Err(Error::ShapeMismatch {
expected: first.shape().to_vec(),
got: tensor.shape().to_vec(),
});
}
}
cat_dim_total += tensor.shape()[dim_idx];
}
let mut out_shape = first.shape().to_vec();
out_shape[dim_idx] = cat_dim_total;
let outer_size: usize = out_shape[..dim_idx].iter().product();
let outer_size = outer_size.max(1);
let inner_size: usize = out_shape[dim_idx + 1..].iter().product();
let inner_size = inner_size.max(1);
Ok(CatParams {
dim_idx,
dtype,
cat_dim_total,
out_shape,
outer_size,
inner_size,
})
}
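
/// Validates the inputs to a `stack` and returns the normalized index of the new axis.
///
/// All tensors must share a dtype and an identical shape; the new axis may be placed
/// anywhere in `0..=ndim`.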
pub fn validate_stack<R: Runtime<DType = DType>>(
tensors: &[&Tensor<R>],
dim: isize,
) -> Result<usize> {
if tensors.is_empty() {
return Err(Error::InvalidArgument {
arg: "tensors",
reason: "stack requires at least one tensor".to_string(),
});
}
let first = tensors[0];
let dtype = first.dtype();
let ndim = first.ndim();
let dim_idx = normalize_stack_dim(dim, ndim).ok_or(Error::InvalidDimension {
dim,
ndim: ndim + 1,
})?;
for &tensor in &tensors[1..] {
if tensor.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: tensor.dtype(),
});
}
if tensor.shape() != first.shape() {
return Err(Error::ShapeMismatch {
expected: first.shape().to_vec(),
got: tensor.shape().to_vec(),
});
}
}
Ok(dim_idx)
}
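
/// Splits `tensor` into consecutive chunks of `split_size` elements along `dim` using
/// `narrow`. The final chunk may be smaller; e.g. splitting a dimension of size 5 with
/// `split_size = 2` yields pieces of sizes 2, 2, and 1.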
pub fn split_impl<R: Runtime>(
tensor: &Tensor<R>,
split_size: usize,
dim: isize,
) -> Result<Vec<Tensor<R>>> {
if split_size == 0 {
return Err(Error::InvalidArgument {
arg: "split_size",
reason: "split_size must be greater than zero".to_string(),
});
}
let ndim = tensor.ndim();
if ndim == 0 {
return Err(Error::InvalidArgument {
arg: "tensor",
reason: "cannot split a scalar tensor".to_string(),
});
}
let dim_idx = normalize_dim(dim, ndim).ok_or(Error::InvalidDimension { dim, ndim })?;
let dim_size = tensor.shape()[dim_idx];
let mut result = Vec::new();
let mut start = 0;
while start < dim_size {
let length = (dim_size - start).min(split_size);
let chunk = tensor.narrow(dim, start, length)?;
result.push(chunk);
start += length;
}
Ok(result)
}
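
/// Splits `tensor` into at most `chunks` pieces along `dim` using `narrow`.
///
/// Sizes are distributed as evenly as possible: the first `dim_size % chunks` pieces get
/// one extra element, and pieces that would be empty are skipped, so fewer than `chunks`
/// tensors may be returned.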
pub fn chunk_impl<R: Runtime>(
tensor: &Tensor<R>,
chunks: usize,
dim: isize,
) -> Result<Vec<Tensor<R>>> {
if chunks == 0 {
return Err(Error::InvalidArgument {
arg: "chunks",
reason: "chunks must be greater than zero".to_string(),
});
}
let ndim = tensor.ndim();
if ndim == 0 {
return Err(Error::InvalidArgument {
arg: "tensor",
reason: "cannot chunk a scalar tensor".to_string(),
});
}
let dim_idx = normalize_dim(dim, ndim).ok_or(Error::InvalidDimension { dim, ndim })?;
let dim_size = tensor.shape()[dim_idx];
let base_size = dim_size / chunks;
let remainder = dim_size % chunks;
let mut result = Vec::new();
let mut start = 0;
for i in 0..chunks {
let length = if i < remainder {
base_size + 1
} else {
base_size
};
if length > 0 {
let chunk = tensor.narrow(dim, start, length)?;
result.push(chunk);
start += length;
}
}
Ok(result)
}
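
/// Output shape computed by [`validate_repeat`].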
#[derive(Debug, Clone)]
pub struct RepeatParams {
pub out_shape: Vec<usize>,
}
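
/// Validates a `repeat` operation: `repeats` must supply one non-zero count per
/// dimension, and each output dimension is the input dimension multiplied by its count.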
pub fn validate_repeat<R: Runtime>(tensor: &Tensor<R>, repeats: &[usize]) -> Result<RepeatParams> {
if repeats.len() != tensor.ndim() {
return Err(Error::InvalidArgument {
arg: "repeats",
reason: format!(
"repeats length ({}) must match tensor ndim ({})",
repeats.len(),
tensor.ndim()
),
});
}
for &r in repeats {
if r == 0 {
return Err(Error::InvalidArgument {
arg: "repeats",
reason: "repeat count cannot be zero".to_string(),
});
}
}
let out_shape: Vec<usize> = tensor
.shape()
.iter()
.zip(repeats.iter())
.map(|(&d, &r)| d * r)
.collect();
Ok(RepeatParams { out_shape })
}
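
/// Output shape and per-dimension `(before, after)` padding computed by [`validate_pad`].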
#[derive(Debug, Clone)]
pub struct PadParams {
pub out_shape: Vec<usize>,
pub pad_per_dim: Vec<(usize, usize)>,
}
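
/// Validates a `pad` operation. `padding` holds flat `(before, after)` pairs that apply
/// to the trailing dimensions in reverse order: the first pair pads the last dimension,
/// the second pair the second-to-last, and so on; remaining leading dimensions are left
/// unpadded.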
pub fn validate_pad<R: Runtime>(tensor: &Tensor<R>, padding: &[usize]) -> Result<PadParams> {
if !padding.len().is_multiple_of(2) {
return Err(Error::InvalidArgument {
arg: "padding",
reason: "padding must have even length (pairs of before/after)".to_string(),
});
}
let num_padded_dims = padding.len() / 2;
if num_padded_dims > tensor.ndim() {
return Err(Error::InvalidArgument {
arg: "padding",
reason: format!(
"padding specifies {} dimensions but tensor only has {}",
num_padded_dims,
tensor.ndim()
),
});
}
let ndim = tensor.ndim();
let mut pad_per_dim: Vec<(usize, usize)> = vec![(0, 0); ndim];
for i in 0..num_padded_dims {
let dim = ndim - 1 - i;
let before = padding[i * 2];
let after = padding[i * 2 + 1];
pad_per_dim[dim] = (before, after);
}
let out_shape: Vec<usize> = tensor
.shape()
.iter()
.zip(pad_per_dim.iter())
.map(|(&d, &(before, after))| d + before + after)
.collect();
Ok(PadParams {
out_shape,
pad_per_dim,
})
}
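
/// Normalized parameters for a `roll` along a single dimension.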
#[derive(Debug, Clone)]
pub struct RollParams {
pub dim_idx: usize,
pub shift: usize,
pub dim_size: usize,
}
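
/// Validates a `roll` and reduces `shift` to an equivalent offset in `0..dim_size`;
/// e.g. a shift of `-1` along a dimension of size 4 becomes `3`.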
pub fn validate_roll<R: Runtime>(
tensor: &Tensor<R>,
shift: isize,
dim: isize,
) -> Result<RollParams> {
let ndim = tensor.ndim();
if ndim == 0 {
return Err(Error::InvalidArgument {
arg: "tensor",
reason: "cannot roll a scalar tensor".to_string(),
});
}
let dim_idx = normalize_dim(dim, ndim).ok_or(Error::InvalidDimension { dim, ndim })?;
let dim_size = tensor.shape()[dim_idx];
if dim_size == 0 {
return Err(Error::InvalidArgument {
arg: "tensor",
reason: "cannot roll along dimension of size 0".to_string(),
});
}
let shift = if shift >= 0 {
(shift as usize) % dim_size
} else {
let neg_shift = (-shift) as usize % dim_size;
if neg_shift == 0 {
0
} else {
dim_size - neg_shift
}
};
Ok(RollParams {
dim_idx,
shift,
dim_size,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize_dim_positive() {
assert_eq!(normalize_dim(0, 3), Some(0));
assert_eq!(normalize_dim(1, 3), Some(1));
assert_eq!(normalize_dim(2, 3), Some(2));
        assert_eq!(normalize_dim(3, 3), None);
    }

#[test]
fn test_normalize_dim_negative() {
assert_eq!(normalize_dim(-1, 3), Some(2));
assert_eq!(normalize_dim(-2, 3), Some(1));
assert_eq!(normalize_dim(-3, 3), Some(0));
        assert_eq!(normalize_dim(-4, 3), None);
    }

#[test]
fn test_normalize_dim_zero_ndim() {
assert_eq!(normalize_dim(0, 0), None);
assert_eq!(normalize_dim(-1, 0), None);
}
#[test]
fn test_normalize_stack_dim() {
assert_eq!(normalize_stack_dim(0, 2), Some(0));
assert_eq!(normalize_stack_dim(1, 2), Some(1));
assert_eq!(normalize_stack_dim(2, 2), Some(2));
assert_eq!(normalize_stack_dim(3, 2), None);
assert_eq!(normalize_stack_dim(-1, 2), Some(2));
assert_eq!(normalize_stack_dim(-2, 2), Some(1));
assert_eq!(normalize_stack_dim(-3, 2), Some(0));
        assert_eq!(normalize_stack_dim(-4, 2), None);
    }
}