use crate::error::{Result, TorshError};
use crate::shape::Shape;
/// Builds the shape of a 0-dimensional (scalar) tensor, i.e. a shape with no
/// dimensions at all.
pub fn scalar_shape() -> Shape {
    Shape::new(Vec::new())
}
/// Builds a 1-D shape `[size]`.
pub fn vector_shape(size: usize) -> Shape {
    let single_axis = [size];
    Shape::new(single_axis.to_vec())
}
/// Builds a 2-D shape `[rows, cols]`.
pub fn matrix_shape(rows: usize, cols: usize) -> Shape {
    let axes = [rows, cols];
    Shape::new(axes.to_vec())
}
/// Builds a 3-D image shape laid out as `[height, width, channels]`
/// (channels last).
pub fn image_shape(height: usize, width: usize, channels: usize) -> Shape {
    let axes = [height, width, channels];
    Shape::new(axes.to_vec())
}
/// Prepends a leading batch axis of length `batch_size` to `base_shape`.
///
/// For example, a batch of 32 `[224, 224, 3]` images has shape
/// `[32, 224, 224, 3]`.
pub fn batch_shape(batch_size: usize, base_shape: &Shape) -> Shape {
    let per_item = base_shape.dims();
    let mut batched = Vec::with_capacity(per_item.len() + 1);
    batched.push(batch_size);
    batched.extend_from_slice(per_item);
    Shape::new(batched)
}
/// Builds a 2-D sequence shape laid out as `[seq_len, features]`.
pub fn sequence_shape(seq_len: usize, features: usize) -> Shape {
    let axes = [seq_len, features];
    Shape::new(axes.to_vec())
}
/// Flattens every dimension from `start_dim` onward into a single dimension.
///
/// The first `start_dim` dimensions are kept unchanged and the remaining ones
/// are replaced by their product. For example, `[32, 3, 224, 224]` flattened
/// from `1` becomes `[32, 150528]`. When `start_dim` equals the rank there is
/// nothing to flatten, and a clone of `shape` is returned.
///
/// # Errors
///
/// Returns [`TorshError::InvalidShape`] if `start_dim` is greater than the
/// rank of `shape`, or if the product of the flattened dimensions does not
/// fit in a `usize`.
pub fn flatten_from(shape: &Shape, start_dim: usize) -> Result<Shape> {
    let dims = shape.dims();
    if start_dim > dims.len() {
        return Err(TorshError::InvalidShape(format!(
            "start_dim {} is out of bounds for shape with {} dimensions",
            start_dim,
            dims.len()
        )));
    }
    if start_dim == dims.len() {
        return Ok(shape.clone());
    }
    let (kept, tail) = dims.split_at(start_dim);
    // `Iterator::product` panics on overflow in debug builds and silently wraps
    // in release builds, which would produce a bogus shape. Multiply with
    // overflow checks instead. A zero-sized axis makes the true product 0
    // regardless of the other axes, so handle that case first.
    let checked_size = if tail.contains(&0) {
        Some(0)
    } else {
        tail.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
    };
    let flattened_size = checked_size.ok_or_else(|| {
        TorshError::InvalidShape(format!(
            "flattened size of dimensions {:?} overflows usize",
            tail
        ))
    })?;
    let mut new_dims = Vec::with_capacity(start_dim + 1);
    new_dims.extend_from_slice(kept);
    new_dims.push(flattened_size);
    Ok(Shape::new(new_dims))
}
/// Inserts a new axis of length 1 at position `dim`.
///
/// `dim` may range from `0` (new leading axis) to the current rank inclusive
/// (new trailing axis).
///
/// # Errors
///
/// Returns `TorshError::InvalidShape` when `dim` exceeds the current rank.
pub fn unsqueeze_at(shape: &Shape, dim: usize) -> Result<Shape> {
    let current = shape.dims();
    let rank = current.len();
    if dim <= rank {
        let mut widened = current.to_vec();
        widened.insert(dim, 1);
        Ok(Shape::new(widened))
    } else {
        Err(TorshError::InvalidShape(format!(
            "dim {} is out of bounds for unsqueeze (max: {})",
            dim, rank
        )))
    }
}
/// Removes length-1 axes from `shape`.
///
/// With `dim = None`, every axis of length 1 is dropped. With
/// `dim = Some(d)`, only axis `d` is dropped, and it must have length 1.
///
/// # Errors
///
/// Returns `TorshError::InvalidShape` when `Some(d)` is out of range, or when
/// axis `d` does not have length 1.
pub fn squeeze(shape: &Shape, dim: Option<usize>) -> Result<Shape> {
    let dims = shape.dims();
    match dim {
        None => {
            let remaining: Vec<usize> = dims.iter().copied().filter(|&len| len != 1).collect();
            Ok(Shape::new(remaining))
        }
        Some(axis) if axis >= dims.len() => Err(TorshError::InvalidShape(format!(
            "dim {} is out of bounds for shape with {} dimensions",
            axis,
            dims.len()
        ))),
        Some(axis) if dims[axis] != 1 => Err(TorshError::InvalidShape(format!(
            "Cannot squeeze dimension {} with size {}",
            axis, dims[axis]
        ))),
        Some(axis) => {
            let mut remaining = dims.to_vec();
            remaining.remove(axis);
            Ok(Shape::new(remaining))
        }
    }
}
/// Left-pads `shape` with length-1 axes until it has `target_rank` axes.
///
/// A shape that already has `target_rank` axes is returned unchanged (as a
/// clone).
///
/// # Errors
///
/// Returns `TorshError::InvalidShape` when the shape's rank already exceeds
/// `target_rank`.
pub fn expand_to_rank(shape: &Shape, target_rank: usize) -> Result<Shape> {
    let dims = shape.dims();
    let current_rank = dims.len();
    if current_rank == target_rank {
        return Ok(shape.clone());
    }
    match target_rank.checked_sub(current_rank) {
        Some(padding) => {
            let expanded: Vec<usize> = std::iter::repeat(1)
                .take(padding)
                .chain(dims.iter().copied())
                .collect();
            Ok(Shape::new(expanded))
        }
        None => Err(TorshError::InvalidShape(format!(
            "Shape rank {} is already greater than target rank {}",
            current_rank, target_rank
        ))),
    }
}
/// Reorders the axes of `shape`: output axis `i` is input axis
/// `permutation[i]`.
///
/// # Errors
///
/// Returns `TorshError::InvalidShape` when `permutation` has the wrong
/// length, contains an index outside `0..rank`, or repeats an index. Entries
/// are checked in order, and the first offending entry determines the
/// error.
pub fn permute(shape: &Shape, permutation: &[usize]) -> Result<Shape> {
    let dims = shape.dims();
    let rank = dims.len();
    if permutation.len() != rank {
        return Err(TorshError::InvalidShape(format!(
            "Permutation length {} doesn't match shape rank {}",
            permutation.len(),
            rank
        )));
    }
    let mut used = vec![false; rank];
    let mut reordered = Vec::with_capacity(rank);
    for &axis in permutation {
        match used.get_mut(axis) {
            None => {
                return Err(TorshError::InvalidShape(format!(
                    "Permutation index {} is out of bounds for shape with {} dimensions",
                    axis, rank
                )));
            }
            Some(flag) if *flag => {
                return Err(TorshError::InvalidShape(format!(
                    "Permutation index {} appears multiple times",
                    axis
                )));
            }
            Some(flag) => *flag = true,
        }
        reordered.push(dims[axis]);
    }
    Ok(Shape::new(reordered))
}
/// Returns `true` if `shape1` and `shape2` can be broadcast together.
///
/// Axes are aligned from the trailing (right-most) end and compared pairwise.
/// Two sizes are compatible when they are equal or when either one is `1`.
/// When the ranks differ, the shorter shape behaves as if it had leading
/// axes of size `1`, which never conflict. For example, `[3]` is compatible
/// with `[2, 3]`, and `[4, 1]` is compatible with `[2, 3, 1, 5]`.
pub fn are_compatible(shape1: &Shape, shape2: &Shape) -> bool {
    // Zipping the reversed dims aligns axes from the right and stops at the
    // end of the shorter shape. The remaining leading axes of the longer shape
    // are implicitly broadcast against size 1, so they need no check.
    shape1
        .dims()
        .iter()
        .rev()
        .zip(shape2.dims().iter().rev())
        .all(|(&a, &b)| a == b || a == 1 || b == 1)
}
pub fn numel(shape: &Shape) -> usize {
shape.numel()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_shape() {
        let s = scalar_shape();
        let empty: &[usize] = &[];
        assert_eq!(s.dims(), empty);
        assert!(s.is_scalar());
    }

    #[test]
    fn test_vector_shape() {
        let s = vector_shape(128);
        assert_eq!(s.dims(), &[128]);
    }

    #[test]
    fn test_matrix_shape() {
        let s = matrix_shape(64, 128);
        assert_eq!(s.dims(), &[64, 128]);
    }

    #[test]
    fn test_image_shape() {
        let rgb = image_shape(224, 224, 3);
        assert_eq!(rgb.dims(), &[224, 224, 3]);
        let grayscale = image_shape(28, 28, 1);
        assert_eq!(grayscale.dims(), &[28, 28, 1]);
    }

    #[test]
    fn test_batch_shape() {
        let img = image_shape(224, 224, 3);
        let batch = batch_shape(32, &img);
        assert_eq!(batch.dims(), &[32, 224, 224, 3]);
    }

    #[test]
    fn test_sequence_shape() {
        let seq = sequence_shape(100, 512);
        assert_eq!(seq.dims(), &[100, 512]);
    }

    #[test]
    fn test_flatten_from() {
        let shape = Shape::new(vec![32, 3, 224, 224]);
        let flattened = flatten_from(&shape, 1).expect("flatten_from should succeed");
        assert_eq!(flattened.dims(), &[32, 150528]);
        let flattened_all = flatten_from(&shape, 0).expect("flatten_from should succeed");
        assert_eq!(flattened_all.dims(), &[4816896]);
        // start_dim == rank is a no-op; anything past the rank is rejected.
        let unchanged = flatten_from(&shape, 4).expect("flatten_from should succeed");
        assert_eq!(unchanged.dims(), &[32, 3, 224, 224]);
        assert!(flatten_from(&shape, 5).is_err());
    }

    #[test]
    fn test_unsqueeze_at() {
        let shape = Shape::new(vec![3, 224, 224]);
        let unsqueezed = unsqueeze_at(&shape, 0).expect("unsqueeze_at should succeed");
        assert_eq!(unsqueezed.dims(), &[1, 3, 224, 224]);
        let unsqueezed_mid = unsqueeze_at(&shape, 1).expect("unsqueeze_at should succeed");
        assert_eq!(unsqueezed_mid.dims(), &[3, 1, 224, 224]);
        let unsqueezed_end = unsqueeze_at(&shape, 3).expect("unsqueeze_at should succeed");
        assert_eq!(unsqueezed_end.dims(), &[3, 224, 224, 1]);
        assert!(unsqueeze_at(&shape, 10).is_err());
    }

    #[test]
    fn test_squeeze() {
        let shape = Shape::new(vec![1, 3, 1, 224, 224]);
        let squeezed = squeeze(&shape, None).expect("squeeze should succeed");
        assert_eq!(squeezed.dims(), &[3, 224, 224]);
        let shape2 = Shape::new(vec![1, 3, 1, 224]);
        let squeezed_dim = squeeze(&shape2, Some(2)).expect("squeeze should succeed");
        assert_eq!(squeezed_dim.dims(), &[1, 3, 224]);
        // Non-unit axis and out-of-range axis are both rejected.
        assert!(squeeze(&shape2, Some(1)).is_err());
        assert!(squeeze(&shape2, Some(4)).is_err());
        // Nothing to squeeze leaves the shape unchanged.
        let no_units = Shape::new(vec![2, 3]);
        let untouched = squeeze(&no_units, None).expect("squeeze should succeed");
        assert_eq!(untouched.dims(), &[2, 3]);
    }

    #[test]
    fn test_expand_to_rank() {
        let shape = Shape::new(vec![224, 224]);
        let expanded = expand_to_rank(&shape, 4).expect("expand_to_rank should succeed");
        assert_eq!(expanded.dims(), &[1, 1, 224, 224]);
        let same = expand_to_rank(&shape, 2).expect("expand_to_rank should succeed");
        assert_eq!(same.dims(), &[224, 224]);
        assert!(expand_to_rank(&shape, 1).is_err());
        let from_scalar =
            expand_to_rank(&scalar_shape(), 2).expect("expand_to_rank should succeed");
        assert_eq!(from_scalar.dims(), &[1, 1]);
    }

    #[test]
    fn test_permute() {
        let shape = Shape::new(vec![32, 3, 224, 224]);
        let permuted = permute(&shape, &[0, 2, 3, 1]).expect("permute should succeed");
        assert_eq!(permuted.dims(), &[32, 224, 224, 3]);
        let identity = permute(&shape, &[0, 1, 2, 3]).expect("permute should succeed");
        assert_eq!(identity.dims(), &[32, 3, 224, 224]);
        // Wrong length, out-of-range index, duplicated index.
        assert!(permute(&shape, &[0, 1]).is_err());
        assert!(permute(&shape, &[0, 1, 2, 10]).is_err());
        assert!(permute(&shape, &[0, 1, 1, 2]).is_err());
    }

    #[test]
    fn test_are_compatible() {
        let s1 = Shape::new(vec![32, 3, 224, 224]);
        let s2 = Shape::new(vec![32, 3, 224, 224]);
        assert!(are_compatible(&s1, &s2));
        let s3 = Shape::new(vec![1, 3, 1, 1]);
        assert!(are_compatible(&s1, &s3));
        let s4 = Shape::new(vec![32, 5, 224, 224]);
        assert!(!are_compatible(&s1, &s4));
        // A scalar broadcasts against any shape, in either argument position.
        assert!(are_compatible(&s1, &scalar_shape()));
        assert!(are_compatible(&scalar_shape(), &s1));
    }

    #[test]
    fn test_numel() {
        let shape = Shape::new(vec![32, 3, 224, 224]);
        assert_eq!(numel(&shape), 4816896);
        let scalar = scalar_shape();
        assert_eq!(numel(&scalar), 1);
    }
}