pub mod const_generic;
pub mod core;
pub use core::Shape;
pub use const_generic::{
common, utils as const_utils, BroadcastCompatible, ConstShape, ConstShapeOps, MatMulCompatible,
Rank0, Rank1, Rank2, Rank3, Rank4, Rank5, ReshapeInto, ShapeRank, SqueezeOps, TransposeOps,
UnsqueezeOps,
};
pub type ShapeRef<'a> = &'a Shape;
pub mod patterns {
    //! Ready-made shape constructors for common ML tensor layouts.
    use super::Shape;
    use crate::error::Result;

    /// Shape builders for generic neural-network inputs.
    pub struct NeuralNetShapes;

    impl NeuralNetShapes {
        /// `[batch_size, features]` input for a dense / fully-connected layer.
        pub fn fully_connected_input(batch_size: usize, features: usize) -> Result<Shape> {
            Shape::from_2d(batch_size, features)
        }

        /// `[batch, channels, height, width]` (channels-first) input for a 2-D convolution.
        pub fn conv2d_input(
            batch_size: usize,
            channels: usize,
            height: usize,
            width: usize,
        ) -> Result<Shape> {
            Shape::from_4d(batch_size, channels, height, width)
        }

        /// `[batch, seq_len, features]` layout for sequence models.
        pub fn sequence(batch_size: usize, seq_len: usize, features: usize) -> Result<Shape> {
            Shape::from_3d(batch_size, seq_len, features)
        }

        /// `[batch, channels, height, width]` image batch; same layout as
        /// [`Self::conv2d_input`], to which it delegates.
        pub fn image_batch(
            batch_size: usize,
            channels: usize,
            height: usize,
            width: usize,
        ) -> Result<Shape> {
            Self::conv2d_input(batch_size, channels, height, width)
        }
    }

    /// Shape builders for computer-vision data.
    pub struct VisionShapes;

    impl VisionShapes {
        /// `[3, height, width]` single RGB image (channels-first).
        pub fn rgb_image(height: usize, width: usize) -> Result<Shape> {
            Shape::from_3d(3, height, width)
        }

        /// `[1, height, width]` single grayscale image (channels-first).
        pub fn grayscale_image(height: usize, width: usize) -> Result<Shape> {
            Shape::from_3d(1, height, width)
        }

        /// `[num_boxes, 4]` tensor: four coordinates per box.
        pub fn bounding_boxes(num_boxes: usize) -> Result<Shape> {
            Shape::from_2d(num_boxes, 4)
        }

        /// `[num_keypoints, 2]` tensor: two coordinates per keypoint.
        pub fn keypoints(num_keypoints: usize) -> Result<Shape> {
            Shape::from_2d(num_keypoints, 2)
        }
    }

    /// Shape builders for NLP data.
    pub struct NlpShapes;

    impl NlpShapes {
        /// `[batch, seq_len]` token-index matrix.
        pub fn token_indices(batch_size: usize, seq_len: usize) -> Result<Shape> {
            Shape::from_2d(batch_size, seq_len)
        }

        /// `[batch, seq_len]` attention mask; same layout as
        /// [`Self::token_indices`], to which it delegates.
        pub fn attention_mask(batch_size: usize, seq_len: usize) -> Result<Shape> {
            Self::token_indices(batch_size, seq_len)
        }

        /// `[vocab_size, embedding_dim]` embedding table.
        pub fn embeddings(vocab_size: usize, embedding_dim: usize) -> Result<Shape> {
            Shape::from_2d(vocab_size, embedding_dim)
        }

        /// `[batch, seq_len, hidden_size]` transformer hidden-state tensor.
        pub fn transformer_hidden(
            batch_size: usize,
            seq_len: usize,
            hidden_size: usize,
        ) -> Result<Shape> {
            Shape::from_3d(batch_size, seq_len, hidden_size)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_neural_net_shapes() {
            let dense = NeuralNetShapes::fully_connected_input(32, 784)
                .expect("shape creation should succeed");
            assert_eq!(dense.dims(), &[32, 784]);

            let conv = NeuralNetShapes::conv2d_input(16, 3, 224, 224)
                .expect("shape creation should succeed");
            assert_eq!(conv.dims(), &[16, 3, 224, 224]);

            let seq = NeuralNetShapes::sequence(8, 50, 512).expect("shape creation should succeed");
            assert_eq!(seq.dims(), &[8, 50, 512]);
        }

        #[test]
        fn test_vision_shapes() {
            let rgb = VisionShapes::rgb_image(224, 224).expect("shape creation should succeed");
            assert_eq!(rgb.dims(), &[3, 224, 224]);

            let gray =
                VisionShapes::grayscale_image(128, 128).expect("shape creation should succeed");
            assert_eq!(gray.dims(), &[1, 128, 128]);

            let boxes = VisionShapes::bounding_boxes(10).expect("shape creation should succeed");
            assert_eq!(boxes.dims(), &[10, 4]);
        }

        #[test]
        fn test_nlp_shapes() {
            let ids = NlpShapes::token_indices(4, 128).expect("shape creation should succeed");
            assert_eq!(ids.dims(), &[4, 128]);

            let table = NlpShapes::embeddings(30000, 768).expect("shape creation should succeed");
            assert_eq!(table.dims(), &[30000, 768]);

            let hidden =
                NlpShapes::transformer_hidden(2, 50, 512).expect("shape creation should succeed");
            assert_eq!(hidden.dims(), &[2, 50, 512]);
        }
    }
}
pub mod utils {
use super::Shape;
use crate::error::Result;
pub fn calculate_reshape_numel(current_shape: &Shape, new_dims: &[i32]) -> Result<usize> {
let current_numel = current_shape.numel();
let inferred_count = new_dims.iter().filter(|&&d| d == -1).count();
if inferred_count > 1 {
return Err(crate::error::TorshError::InvalidOperation(
"Cannot infer more than one dimension".to_string(),
));
}
if inferred_count == 0 {
let new_numel: usize = new_dims
.iter()
.map(|&d| d as usize)
.try_fold(1usize, |acc, dim| acc.checked_mul(dim))
.ok_or_else(|| {
crate::error::TorshError::InvalidOperation(
"Shape dimensions would overflow".to_string(),
)
})?;
if new_numel != current_numel {
return Err(crate::error::TorshError::InvalidOperation(format!(
"Cannot reshape tensor with {} elements to shape with {} elements",
current_numel, new_numel
)));
}
Ok(new_numel)
} else {
let known_product: usize = new_dims
.iter()
.filter(|&&d| d != -1)
.map(|&d| d as usize)
.try_fold(1usize, |acc, dim| acc.checked_mul(dim))
.ok_or_else(|| {
crate::error::TorshError::InvalidOperation(
"Shape dimensions would overflow".to_string(),
)
})?;
if !current_numel.is_multiple_of(known_product) {
return Err(crate::error::TorshError::InvalidOperation(format!(
"Cannot infer dimension: {} elements cannot be evenly divided by {}",
current_numel, known_product
)));
}
Ok(current_numel)
}
}
pub fn infer_reshape_dimension(current_shape: &Shape, new_dims: &[i32]) -> Result<Vec<usize>> {
let current_numel = current_shape.numel();
let inferred_indices: Vec<usize> = new_dims
.iter()
.enumerate()
.filter(|(_, &d)| d == -1)
.map(|(i, _)| i)
.collect();
if inferred_indices.len() > 1 {
return Err(crate::error::TorshError::InvalidOperation(
"Cannot infer more than one dimension".to_string(),
));
}
let mut result_dims = Vec::with_capacity(new_dims.len());
if inferred_indices.is_empty() {
for &dim in new_dims {
result_dims.push(dim as usize);
}
} else {
let known_product: usize = new_dims
.iter()
.filter(|&&d| d != -1)
.map(|&d| d as usize)
.try_fold(1usize, |acc, dim| acc.checked_mul(dim))
.ok_or_else(|| {
crate::error::TorshError::InvalidOperation(
"Shape dimensions would overflow".to_string(),
)
})?;
if !current_numel.is_multiple_of(known_product) {
return Err(crate::error::TorshError::InvalidOperation(format!(
"Cannot infer dimension: {} elements cannot be evenly divided by {}",
current_numel, known_product
)));
}
let inferred_dim = current_numel / known_product;
for &dim in new_dims.iter() {
if dim == -1 {
result_dims.push(inferred_dim);
} else {
result_dims.push(dim as usize);
}
}
}
Ok(result_dims)
}
pub fn can_view_as(from_shape: &Shape, to_shape: &Shape) -> bool {
from_shape.numel() == to_shape.numel()
}
pub fn calculate_contiguous_strides(shape: &Shape) -> Vec<usize> {
let dims = shape.dims();
if dims.is_empty() {
return vec![];
}
let mut strides = vec![1; dims.len()];
for i in (0..dims.len() - 1).rev() {
strides[i] = strides[i + 1] * dims[i + 1];
}
strides
}
pub fn is_contiguous_strides(shape: &Shape, strides: &[usize]) -> bool {
if shape.dims().len() != strides.len() {
return false;
}
let expected_strides = calculate_contiguous_strides(shape);
strides == expected_strides
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_reshape_calculation() {
let shape = Shape::new(vec![2, 3, 4]);
let numel = calculate_reshape_numel(&shape, &[6, 4])
.expect("reshape calculation should succeed");
assert_eq!(numel, 24);
let inferred = infer_reshape_dimension(&shape, &[-1, 4])
.expect("reshape inference should succeed");
assert_eq!(inferred, vec![6, 4]);
let inferred2 = infer_reshape_dimension(&shape, &[2, -1])
.expect("reshape inference should succeed");
assert_eq!(inferred2, vec![2, 12]);
}
#[test]
fn test_view_compatibility() {
let shape1 = Shape::new(vec![2, 3, 4]);
let shape2 = Shape::new(vec![6, 4]);
assert!(can_view_as(&shape1, &shape2));
let shape3 = Shape::new(vec![2, 5]);
assert!(!can_view_as(&shape1, &shape3));
}
#[test]
fn test_strides() {
let shape = Shape::new(vec![2, 3, 4]);
let strides = calculate_contiguous_strides(&shape);
assert_eq!(strides, vec![12, 4, 1]);
assert!(is_contiguous_strides(&shape, &strides));
assert!(!is_contiguous_strides(&shape, &[1, 2, 3]));
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test that the re-exported `Shape`, the `patterns` constructors,
    /// and the `utils` stride helper all work together.
    #[test]
    fn test_module_integration() {
        let base = Shape::new(vec![2, 3, 4]);
        assert_eq!(base.numel(), 24);
        assert_eq!(base.ndim(), 3);

        let conv = patterns::NeuralNetShapes::conv2d_input(1, 3, 224, 224)
            .expect("shape creation should succeed");
        assert_eq!(conv.dims(), &[1, 3, 224, 224]);

        assert_eq!(utils::calculate_contiguous_strides(&base), vec![12, 4, 1]);
    }
}