use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{
creation::{ones, zeros},
Tensor,
};
/// Requested border-filling strategy for [`pad`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PaddingMode {
    /// Pad with a caller-supplied constant value.
    Constant,
    /// Pad by mirroring across the edge (currently a zero-filled placeholder in [`pad`]).
    Reflect,
    /// Pad by repeating the edge values (currently a zero-filled placeholder in [`pad`]).
    Replicate,
    /// Pad by wrapping around to the opposite edge (currently a zero-filled placeholder in [`pad`]).
    Circular,
}
/// Pads `input` according to `pad`, which lists `(left, right)` amounts per
/// dimension starting from the LAST dimension (PyTorch `F.pad` convention).
///
/// NOTE(review): all modes currently produce shape-correct placeholder
/// outputs; only the fill value of `Constant` mode is honored. Copying the
/// input interior (and reflect/replicate/circular borders) is still TODO.
///
/// # Errors
/// Returns an error when `pad` has odd length or specifies more dimension
/// pairs than the input has dimensions.
pub fn pad(input: &Tensor, pad: &[usize], mode: PaddingMode, value: f32) -> TorshResult<Tensor> {
    let input_shape_binding = input.shape();
    let input_shape = input_shape_binding.dims();
    let ndim = input_shape.len();

    // `pad` must be a flat list of (left, right) pairs.
    if pad.len() % 2 != 0 {
        return Err(TorshError::invalid_argument_with_context(
            "Padding specification must have even length",
            "pad",
        ));
    }
    if pad.len() / 2 > ndim {
        return Err(TorshError::invalid_argument_with_context(
            "Padding specification exceeds tensor dimensions",
            "pad",
        ));
    }

    // Pairs apply to trailing dimensions first: pair i pads dim `ndim - 1 - i`.
    let mut output_shape = input_shape.to_vec();
    let pad_pairs = pad.len() / 2;
    for i in 0..pad_pairs {
        let dim_idx = ndim - 1 - i;
        let pad_left = pad[2 * i];
        let pad_right = pad[2 * i + 1];
        output_shape[dim_idx] += pad_left + pad_right;
    }

    let output = match mode {
        PaddingMode::Constant => {
            // Fill the entire output with the constant value. Previously the
            // value-filled tensor was unconditionally discarded (the
            // `input_volume <= output_volume` branch always taken with
            // non-negative padding) and an all-zeros tensor was returned,
            // so `value` was silently ignored.
            let mut result = zeros(&output_shape)?;
            if value != 0.0 {
                result = result.add_scalar(value)?;
            }
            // TODO: copy `input` into the interior region of `result`.
            result
        }
        // TODO: implement the border semantics for these modes; for now they
        // all return a shape-correct zero-filled placeholder.
        PaddingMode::Reflect | PaddingMode::Replicate | PaddingMode::Circular => {
            zeros(&output_shape)?
        }
    };
    Ok(output)
}
/// Computes the shape of `input` sliced along `dim` as `start..end` with the
/// given `step`, returning a zero-filled tensor of that shape.
///
/// Negative `start`/`end` count from the end of the dimension, as in Python
/// slicing; `end = None` means "to the end". NOTE(review): the output is a
/// shape-only placeholder — copying the selected elements is not implemented.
///
/// # Errors
/// Returns an error when `dim` is out of bounds or `step` is zero.
pub fn slice_with_step(
    input: &Tensor,
    dim: usize,
    start: i32,
    end: Option<i32>,
    step: usize,
) -> TorshResult<Tensor> {
    let dims_binding = input.shape();
    let dims = dims_binding.dims();

    if dim >= dims.len() {
        return Err(TorshError::invalid_argument_with_context(
            "Dimension index out of bounds",
            "slice_with_step",
        ));
    }
    if step == 0 {
        return Err(TorshError::invalid_argument_with_context(
            "Step size must be positive",
            "slice_with_step",
        ));
    }

    let extent = dims[dim] as i32;

    // Normalize a slice index: negative values wrap from the end; the result
    // is clamped into [0, extent].
    let normalize = |idx: i32| -> i32 {
        if idx < 0 {
            (extent + idx).max(0)
        } else {
            idx.min(extent)
        }
    };

    let lo = normalize(start);
    let hi = end.map_or(extent, normalize);

    // Ceiling division: how many indices of lo..hi the step selects.
    let step_i = step as i32;
    let slice_len = if hi > lo {
        ((hi - lo + step_i - 1) / step_i) as usize
    } else {
        0
    };

    let mut out_dims = dims.to_vec();
    out_dims[dim] = slice_len;
    zeros(&out_dims)
}
/// Selects the elements of `input` where `mask` is non-zero.
///
/// NOTE(review): only the output length is computed (by summing the mask,
/// which therefore must contain 0.0/1.0 entries); the result is a
/// zero-filled placeholder rather than the gathered values.
///
/// # Errors
/// Returns an error when `input` and `mask` differ in shape.
pub fn boolean_index(input: &Tensor, mask: &Tensor) -> TorshResult<Tensor> {
    if input.shape().dims() != mask.shape().dims() {
        return Err(TorshError::invalid_argument_with_context(
            "Input and mask must have same shape",
            "boolean_index",
        ));
    }
    // A 0/1 mask sums to the number of selected elements.
    let summed = mask.sum()?.data()?;
    let selected = summed.first().copied().unwrap_or(0.0) as usize;
    zeros(&[selected])
}
/// Returns a copy of `input` where positions with `mask == 1` are replaced
/// by `fill_value` and positions with `mask == 0` keep the input value.
///
/// Implemented arithmetically as `input * (1 - mask) + fill_value * mask`,
/// so `mask` is expected to contain only 0.0 / 1.0 entries.
///
/// # Errors
/// Returns an error when `input` and `mask` differ in shape.
pub fn masked_fill(input: &Tensor, mask: &Tensor, fill_value: f32) -> TorshResult<Tensor> {
    if input.shape().dims() != mask.shape().dims() {
        return Err(TorshError::invalid_argument_with_context(
            "Input and mask must have same shape",
            "masked_fill",
        ));
    }
    // (1 - mask): 1 at positions where the input value is kept.
    let keep_mask = ones(&mask.shape().dims())?.sub(mask)?;
    let kept = input.mul_op(&keep_mask)?;
    // `fill_value` broadcast over the masked positions only.
    let fill = ones(&input.shape().dims())?.mul_scalar(fill_value)?;
    let replaced = fill.mul_op(mask)?;
    kept.add_op(&replaced)
}
/// Element-wise select: takes `input` where `condition` is 1 and `other`
/// where it is 0.
///
/// Implemented arithmetically as `condition * input + (1 - condition) * other`,
/// so `condition` is expected to contain only 0.0 / 1.0 entries.
///
/// # Errors
/// Returns an error when `input` and `other` differ in shape.
pub fn where_tensor(condition: &Tensor, input: &Tensor, other: &Tensor) -> TorshResult<Tensor> {
    if input.shape().dims() != other.shape().dims() {
        return Err(TorshError::invalid_argument_with_context(
            "Input and other tensors must have same shape",
            "where_tensor",
        ));
    }
    // (1 - condition): 1 at positions where `other` is picked.
    let not_condition = ones(&condition.shape().dims())?.sub(condition)?;
    let from_input = condition.mul_op(input)?;
    let from_other = not_condition.mul_op(other)?;
    from_input.add_op(&from_other)
}
/// Concatenates `tensors` along dimension `dim`.
///
/// All tensors must share the same rank and the same size on every
/// dimension except `dim`. NOTE(review): the return value is a zero-filled
/// placeholder of the concatenated shape; the data copy is not implemented.
///
/// # Errors
/// Returns an error when `tensors` is empty, `dim` is out of bounds, or any
/// tensor's shape is incompatible with the first tensor's.
pub fn cat(tensors: &[Tensor], dim: usize) -> TorshResult<Tensor> {
    if tensors.is_empty() {
        return Err(TorshError::invalid_argument_with_context(
            "Cannot concatenate empty list of tensors",
            "cat",
        ));
    }
    let reference_binding = tensors[0].shape();
    let reference = reference_binding.dims();
    if dim >= reference.len() {
        return Err(TorshError::invalid_argument_with_context(
            "Concatenation dimension out of bounds",
            "cat",
        ));
    }
    // Every tensor after the first must match the reference shape on all
    // dimensions except the concatenation dimension itself.
    for (idx, tensor) in tensors.iter().enumerate().skip(1) {
        let candidate_binding = tensor.shape();
        let candidate = candidate_binding.dims();
        if candidate.len() != reference.len() {
            return Err(TorshError::invalid_argument_with_context(
                &format!("Tensor {} has incompatible number of dimensions", idx),
                "cat",
            ));
        }
        for axis in 0..reference.len() {
            if axis != dim && reference[axis] != candidate[axis] {
                return Err(TorshError::invalid_argument_with_context(
                    &format!("Tensor {} has incompatible shape at dimension {}", idx, axis),
                    "cat",
                ));
            }
        }
    }
    // The output size along `dim` is the sum of each tensor's size there.
    let concat_size: usize = tensors.iter().map(|t| t.shape().dims()[dim]).sum();
    let mut out_shape = reference.to_vec();
    out_shape[dim] = concat_size;
    zeros(&out_shape)
}
/// Splits `input` along `dim`.
///
/// When `split_size_or_sections` has exactly one element it is treated as a
/// chunk size (the last chunk may be smaller); otherwise it lists explicit
/// section sizes, which must sum to the size of dimension `dim`.
/// NOTE(review): the returned chunks are zero-filled placeholders of the
/// correct shapes; the data copy is not implemented here.
///
/// # Errors
/// Returns an error when `dim` is out of bounds, when a chunk size of zero
/// is given, or when explicit sections do not sum to the dimension size.
pub fn split(
    input: &Tensor,
    split_size_or_sections: &[usize],
    dim: usize,
) -> TorshResult<Vec<Tensor>> {
    let shape_binding = input.shape();
    let shape = shape_binding.dims();
    if dim >= shape.len() {
        return Err(TorshError::invalid_argument_with_context(
            "Split dimension out of bounds",
            "split",
        ));
    }
    let dim_size = shape[dim];
    let split_points: Vec<usize> = if split_size_or_sections.len() == 1 {
        let chunk_size = split_size_or_sections[0];
        // Guard against a zero chunk size, which previously caused a
        // divide-by-zero panic in the chunk-count computation below.
        if chunk_size == 0 {
            return Err(TorshError::invalid_argument_with_context(
                "Split size must be positive",
                "split",
            ));
        }
        // Ceiling division: number of chunks, last one possibly partial.
        let num_chunks = (dim_size + chunk_size - 1) / chunk_size;
        (0..num_chunks)
            .map(|i| chunk_size.min(dim_size - i * chunk_size))
            .collect()
    } else {
        split_size_or_sections.to_vec()
    };
    let total_size: usize = split_points.iter().sum();
    if total_size != dim_size {
        return Err(TorshError::invalid_argument_with_context(
            "Split sizes do not sum to dimension size",
            "split",
        ));
    }
    // One placeholder chunk per section, differing only along `dim`.
    let mut results = Vec::with_capacity(split_points.len());
    for &split_size in &split_points {
        let mut chunk_shape = shape.to_vec();
        chunk_shape[dim] = split_size;
        results.push(zeros(&chunk_shape)?);
    }
    Ok(results)
}
/// Reshapes `input` to `shape`, where at most one entry may be `-1` to have
/// its size inferred from the input's element count.
///
/// # Errors
/// Returns an error when more than one `-1` is given, when any dimension is
/// negative (other than `-1`), when inference is impossible, or when the
/// requested shape does not match the input's element count.
pub fn reshape(input: &Tensor, shape: &[i32]) -> TorshResult<Tensor> {
    let input_numel = input.numel();
    let mut new_shape = shape.to_vec();

    let neg_one_count = shape.iter().filter(|&&x| x == -1).count();
    if neg_one_count > 1 {
        return Err(TorshError::invalid_argument_with_context(
            "Can only infer one dimension (use at most one -1)",
            "reshape",
        ));
    }
    // Reject other negative sizes up front. Previously a shape such as
    // [-2, -12] for a 24-element tensor passed the element-count check
    // (the product of two negatives is positive) and reached `view` with
    // invalid dimensions.
    if shape.iter().any(|&x| x < -1) {
        return Err(TorshError::invalid_argument_with_context(
            "Shape dimensions must be non-negative (or -1 to infer)",
            "reshape",
        ));
    }

    if neg_one_count == 1 {
        // Product of the explicitly given (non -1) dimensions.
        let known_size: i32 = shape.iter().filter(|&&x| x != -1).product();
        if known_size == 0 {
            return Err(TorshError::invalid_argument_with_context(
                "Cannot infer dimension when other dimensions are zero",
                "reshape",
            ));
        }
        let inferred_size = input_numel as i32 / known_size;
        // The inferred dimension must divide the element count evenly.
        if inferred_size * known_size != input_numel as i32 {
            return Err(TorshError::invalid_argument_with_context(
                "Cannot reshape tensor to requested shape",
                "reshape",
            ));
        }
        for dim in new_shape.iter_mut() {
            if *dim == -1 {
                *dim = inferred_size;
                break;
            }
        }
    }

    // Final sanity check: total element count must be preserved.
    let new_numel: i32 = new_shape.iter().product();
    if new_numel != input_numel as i32 {
        return Err(TorshError::invalid_argument_with_context(
            "New shape is not compatible with input shape",
            "reshape",
        ));
    }
    input.view(&new_shape)
}
/// Removes size-1 dimensions from `input`.
///
/// With `dim = Some(d)` only dimension `d` is removed (it must have size 1);
/// with `None`, every size-1 dimension is dropped. Squeezing away all
/// dimensions yields a 0-dimensional view.
///
/// # Errors
/// Returns an error when `dim` is out of bounds or names a dimension whose
/// size is not 1.
pub fn squeeze(input: &Tensor, dim: Option<usize>) -> TorshResult<Tensor> {
    let shape_binding = input.shape();
    let dims = shape_binding.dims();

    let squeezed: Vec<i32> = match dim {
        Some(d) => {
            if d >= dims.len() {
                return Err(TorshError::invalid_argument_with_context(
                    "Dimension index out of bounds",
                    "squeeze",
                ));
            }
            if dims[d] != 1 {
                return Err(TorshError::invalid_argument_with_context(
                    "Cannot squeeze dimension that is not size 1",
                    "squeeze",
                ));
            }
            // Keep every dimension except `d`.
            let mut kept = Vec::with_capacity(dims.len().saturating_sub(1));
            for (i, &s) in dims.iter().enumerate() {
                if i != d {
                    kept.push(s as i32);
                }
            }
            kept
        }
        // Drop all unit dimensions.
        None => dims.iter().filter(|&&s| s != 1).map(|&s| s as i32).collect(),
    };

    // An empty `squeezed` derefs to `&[]`, producing a 0-d view.
    input.view(&squeezed)
}
/// Inserts a size-1 dimension at position `dim`.
///
/// `dim` may equal the current rank, in which case the unit dimension is
/// appended at the end.
///
/// # Errors
/// Returns an error when `dim` exceeds the current rank.
pub fn unsqueeze(input: &Tensor, dim: usize) -> TorshResult<Tensor> {
    let shape_binding = input.shape();
    let dims = shape_binding.dims();
    if dim > dims.len() {
        return Err(TorshError::invalid_argument_with_context(
            "Dimension index out of bounds",
            "unsqueeze",
        ));
    }
    // Copy the existing shape, then splice in the new unit dimension.
    let mut expanded: Vec<i32> = dims.iter().map(|&s| s as i32).collect();
    expanded.insert(dim, 1);
    input.view(&expanded)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    #[test]
    fn test_pad_constant() {
        let input = randn(&[2, 3, 4], None, None, None).unwrap();
        let padded = pad(&input, &[1, 1, 2, 2], PaddingMode::Constant, 0.0).unwrap();
        // Pairs pad trailing dims first: last dim 4 -> 6, middle dim 3 -> 7.
        assert_eq!(padded.shape().dims(), &[2, 7, 6]);
    }

    #[test]
    fn test_slice_with_step() {
        let input = randn(&[10, 5], None, None, None).unwrap();
        let sliced = slice_with_step(&input, 0, 1, Some(8), 2).unwrap();
        // Selects indices 1, 3, 5, 7 -> length 4; other dims untouched.
        assert_eq!(sliced.shape().dims()[0], 4);
        assert_eq!(sliced.shape().dims()[1], 5);
    }

    #[test]
    fn test_masked_fill() {
        let input = randn(&[3, 3], None, None, None).unwrap();
        let mask: Tensor<f32> = zeros(&[3, 3]).unwrap();
        let filled = masked_fill(&input, &mask, 99.0).unwrap();
        // Shape is preserved by masked filling.
        assert_eq!(filled.shape().dims(), input.shape().dims());
    }

    #[test]
    fn test_cat() {
        let t1 = randn(&[2, 3, 4], None, None, None).unwrap();
        let t2 = randn(&[2, 3, 4], None, None, None).unwrap();
        let t3 = randn(&[2, 3, 4], None, None, None).unwrap();
        // Three [2, 3, 4] tensors along dim 0 -> [6, 3, 4].
        let result = cat(&[t1, t2, t3], 0).unwrap();
        assert_eq!(result.shape().dims(), &[6, 3, 4]);
    }

    #[test]
    fn test_split() {
        let input = randn(&[6, 3, 4], None, None, None).unwrap();
        // Single-element sections slice = chunk size 2 -> three chunks.
        let chunks = split(&input, &[2], 0).unwrap();
        assert_eq!(chunks.len(), 3);
        for chunk in chunks {
            assert_eq!(chunk.shape().dims(), &[2, 3, 4]);
        }
    }

    #[test]
    fn test_reshape() {
        let input = randn(&[2, 3, 4], None, None, None).unwrap();
        // 24 elements / 6 -> the -1 dimension is inferred as 4.
        let reshaped = reshape(&input, &[6, -1]).unwrap();
        assert_eq!(reshaped.shape().dims(), &[6, 4]);
    }

    #[test]
    fn test_squeeze_unsqueeze() {
        let input = randn(&[2, 1, 3, 1], None, None, None).unwrap();
        // Drop every unit dimension, then insert one back at position 1.
        let squeezed = squeeze(&input, None).unwrap();
        assert_eq!(squeezed.shape().dims(), &[2, 3]);
        let unsqueezed = unsqueeze(&squeezed, 1).unwrap();
        assert_eq!(unsqueezed.shape().dims(), &[2, 1, 3]);
    }

    #[test]
    fn test_where_tensor() {
        let condition = ones(&[2, 3]).unwrap();
        let input = randn(&[2, 3], None, None, None).unwrap();
        let other = zeros(&[2, 3]).unwrap();
        let result = where_tensor(&condition, &input, &other).unwrap();
        assert_eq!(result.shape().dims(), &[2, 3]);
    }
}