//! Tensor-manipulation utilities split by concern, with the commonly used
//! items re-exported here so callers get a single flat namespace.
//!
//! Sub-modules (see each for details):
//! - [`shape`] — `atleast_1d` / `atleast_2d` / `atleast_3d`
//! - [`construction`] — `block_diag`, `cartesian_prod`, `meshgrid`
//! - [`splitting`] — `chunk`, `dsplit`, `hsplit`, `split`, `tensor_split`,
//!   `vsplit` and their argument enums `SplitArg` / `TensorSplitArg`
//! - [`contraction`] — `tensordot` and its axis selector `TensorDotAxes`
//! - [`indexing`] — `compute_strides`, `ravel_multi_index`, `unravel_index`
pub mod construction;
pub mod contraction;
pub mod indexing;
pub mod shape;
pub mod splitting;
pub use shape::{atleast_1d, atleast_2d, atleast_3d};
pub use construction::{block_diag, cartesian_prod, meshgrid};
pub use splitting::{chunk, dsplit, hsplit, split, tensor_split, vsplit, SplitArg, TensorSplitArg};
pub use contraction::{tensordot, TensorDotAxes};
pub use indexing::{compute_strides, ravel_multi_index, unravel_index};
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;
    use torsh_core::DeviceType;
    use torsh_tensor::creation::ones;

    /// `shape` + `construction`: a 1-D tensor promoted via `atleast_2d`
    /// assembles block-diagonally into a `[6, 2]` result.
    #[test]
    fn test_module_integration_shape_construction() -> torsh_core::Result<()> {
        let v = ones(&[3])?;
        let promoted = atleast_2d(&v)?;
        let stacked = block_diag(&[promoted.clone(), promoted])?;
        assert_eq!(stacked.shape().dims(), &[6, 2]);
        Ok(())
    }

    /// `splitting` + `contraction`: halves of a random matrix contract
    /// against a compatible right-hand operand.
    #[test]
    fn test_module_integration_split_contraction() -> torsh_core::Result<()> {
        let big = randn(&[8, 6], None, None, None)?;
        let halves = split(&big, SplitArg::Sections(2), 0)?;
        assert_eq!(halves.len(), 2);
        assert_eq!(halves[0].shape().dims(), &[4, 6]);

        let rhs = randn(&[6, 4], None, None, None)?;
        let contracted = tensordot(&halves[0], &rhs, TensorDotAxes::Int(1))?;
        assert_eq!(contracted.shape().dims(), &[4, 4]);
        Ok(())
    }

    /// `indexing` + `construction`: coordinates recovered by
    /// `unravel_index` feed directly into `meshgrid`.
    #[test]
    fn test_module_integration_indexing_construction() -> torsh_core::Result<()> {
        let flat = torsh_tensor::Tensor::from_data(
            vec![0.0f32, 1.0, 2.0, 3.0],
            vec![4],
            DeviceType::Cpu,
        )?;
        let target_shape = vec![2, 2];
        let coords = unravel_index(&flat, &target_shape)?;
        assert_eq!(coords.len(), 2);

        let grid = meshgrid(&coords, "ij")?;
        assert_eq!(grid.len(), 2);
        assert_eq!(grid[0].shape().dims(), &[4, 4]);
        Ok(())
    }

    /// End-to-end chain: shape promotion -> block-diagonal assembly ->
    /// horizontal split -> meshgrid -> tensor contraction.
    #[test]
    fn test_module_integration_comprehensive_workflow() -> torsh_core::Result<()> {
        let a = atleast_2d(&ones(&[4])?)?;
        let b = atleast_2d(&ones(&[4])?)?;
        let assembled = block_diag(&[a, b])?;
        assert_eq!(assembled.shape().dims(), &[8, 2]);

        let columns = hsplit(&assembled, TensorSplitArg::Sections(2))?;
        assert_eq!(columns.len(), 2);
        assert_eq!(columns[0].shape().dims(), &[8, 1]);

        let xs = torsh_tensor::Tensor::from_data(vec![0.0, 1.0], vec![2], DeviceType::Cpu)?;
        let ys = torsh_tensor::Tensor::from_data(vec![0.0, 1.0], vec![2], DeviceType::Cpu)?;
        let grids = meshgrid(&[xs, ys], "ij")?;
        assert_eq!(grids[0].shape().dims(), &[2, 2]);

        let weights = randn(&[1, 4], None, None, None)?;
        let projected = tensordot(
            &columns[0],
            &weights,
            TensorDotAxes::Explicit(vec![1], vec![0]),
        )?;
        assert_eq!(projected.shape().dims(), &[8, 4]);
        Ok(())
    }

    /// Smoke test: the historical public surface re-exported at this
    /// module's root is still callable with its original signatures.
    #[test]
    fn test_backward_compatibility() -> torsh_core::Result<()> {
        // shape helpers
        let t = ones(&[5])?;
        let _d1 = atleast_1d(&t)?;
        let _d2 = atleast_2d(&t)?;
        let _d3 = atleast_3d(&t)?;

        // construction helpers
        let squares = vec![ones(&[2, 2])?, ones(&[3, 3])?];
        let _block = block_diag(&squares)?;
        let pair = vec![
            torsh_tensor::Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu)?,
            torsh_tensor::Tensor::from_data(vec![3.0, 4.0], vec![2], DeviceType::Cpu)?,
        ];
        let _cart = cartesian_prod(&pair)?;
        let _mesh = meshgrid(&pair, "xy")?;

        // splitting helpers (2-D, plus a 3-D tensor for dsplit)
        let t = ones(&[6, 4])?;
        let _s1 = split(&t, SplitArg::Sections(2), 0)?;
        let _s2 = chunk(&t, 3, 0)?;
        let _s3 = tensor_split(&t, TensorSplitArg::Sections(2), 0)?;
        let _s4 = hsplit(&t, TensorSplitArg::Sections(2))?;
        let _s5 = vsplit(&t, TensorSplitArg::Sections(2))?;
        let volume = ones(&[2, 3, 6])?;
        let _s6 = dsplit(&volume, TensorSplitArg::Sections(2))?;

        // contraction helper
        let lhs = ones(&[3, 4])?;
        let rhs = ones(&[4, 5])?;
        let _prod = tensordot(&lhs, &rhs, TensorDotAxes::Int(1))?;

        // indexing helpers
        let flat =
            torsh_tensor::Tensor::from_data(vec![0.0, 1.0, 2.0], vec![3], DeviceType::Cpu)?;
        let dims = vec![2, 2];
        let _coords = unravel_index(&flat, &dims)?;
        let _strides = compute_strides(&dims);
        Ok(())
    }

    /// Invalid inputs surface as `Err` from each sub-module rather than
    /// panicking.
    #[test]
    fn test_error_propagation() -> torsh_core::Result<()> {
        // hsplit on a 1-D tensor is rejected.
        let flat = ones(&[5])?;
        assert!(hsplit(&flat, TensorSplitArg::Sections(2)).is_err());

        // tensordot rejects incompatible contraction dimensions.
        let lhs = ones(&[3, 4])?;
        let rhs = ones(&[5, 6])?;
        assert!(tensordot(&lhs, &rhs, TensorDotAxes::Int(1)).is_err());

        // unravel_index rejects an index outside the target shape.
        let oob = torsh_tensor::Tensor::from_data(vec![10.0], vec![1], DeviceType::Cpu)?;
        let dims = vec![2, 2];
        assert!(unravel_index(&oob, &dims).is_err());
        Ok(())
    }

    /// Larger inputs: chunking, block assembly and contraction keep the
    /// expected shapes at scale.
    #[test]
    fn test_performance_patterns() -> torsh_core::Result<()> {
        let wide = randn(&[1000, 500], None, None, None)?;
        let pieces = chunk(&wide, 10, 0)?;
        assert_eq!(pieces.len(), 10);
        assert_eq!(pieces[0].shape().dims(), &[100, 500]);

        let parts: Vec<_> = (0..5)
            .map(|_| ones(&[100, 100]))
            .collect::<Result<Vec<_>, _>>()?;
        let assembled = block_diag(&parts)?;
        assert_eq!(assembled.shape().dims(), &[500, 500]);

        let lhs = randn(&[50, 100], None, None, None)?;
        let rhs = randn(&[100, 75], None, None, None)?;
        let prod = tensordot(&lhs, &rhs, TensorDotAxes::Int(1))?;
        assert_eq!(prod.shape().dims(), &[50, 75]);
        Ok(())
    }
}