use crate::Tensor;
use std::fmt;
use std::sync::{Arc, RwLock};
use torsh_core::{
    dtype::TensorElement,
    error::{Result, TorshError},
    shape::Shape,
};
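/// A node in a deferred computation graph. Leaves are `Identity` nodes holding
/// an already materialized tensor; every other variant records an operation
/// and boxes its operand sub-graphs. Nothing is computed until the graph is
/// walked by `LazyTensor::eval`.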
#[derive(Clone)]
pub enum LazyOp<T: TensorElement> {
Identity(Arc<Tensor<T>>),
Add(Box<LazyOp<T>>, Box<LazyOp<T>>),
Mul(Box<LazyOp<T>>, Box<LazyOp<T>>),
Sub(Box<LazyOp<T>>, Box<LazyOp<T>>),
Div(Box<LazyOp<T>>, Box<LazyOp<T>>),
AddScalar(Box<LazyOp<T>>, T),
MulScalar(Box<LazyOp<T>>, T),
SubScalar(Box<LazyOp<T>>, T),
DivScalar(Box<LazyOp<T>>, T),
Pow(Box<LazyOp<T>>, T),
MatMul(Box<LazyOp<T>>, Box<LazyOp<T>>),
Transpose(Box<LazyOp<T>>, Option<(usize, usize)>),
Reshape(Box<LazyOp<T>>, Shape),
Sum(Box<LazyOp<T>>, Option<i32>),
Mean(Box<LazyOp<T>>, Option<i32>),
ReLU(Box<LazyOp<T>>),
Sigmoid(Box<LazyOp<T>>),
Tanh(Box<LazyOp<T>>),
Exp(Box<LazyOp<T>>),
Log(Box<LazyOp<T>>),
Sin(Box<LazyOp<T>>),
Cos(Box<LazyOp<T>>),
Custom(String, Box<LazyOp<T>>, Arc<dyn Fn(&Tensor<T>) -> Result<Tensor<T>> + Send + Sync>),
}
impl<T: TensorElement> fmt::Debug for LazyOp<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LazyOp::Identity(tensor) => write!(f, "Identity({:?})", tensor),
LazyOp::Add(lhs, rhs) => write!(f, "Add({:?}, {:?})", lhs, rhs),
LazyOp::Mul(lhs, rhs) => write!(f, "Mul({:?}, {:?})", lhs, rhs),
LazyOp::Sub(lhs, rhs) => write!(f, "Sub({:?}, {:?})", lhs, rhs),
LazyOp::Div(lhs, rhs) => write!(f, "Div({:?}, {:?})", lhs, rhs),
LazyOp::AddScalar(input, scalar) => write!(f, "AddScalar({:?}, {:?})", input, scalar),
LazyOp::MulScalar(input, scalar) => write!(f, "MulScalar({:?}, {:?})", input, scalar),
LazyOp::SubScalar(input, scalar) => write!(f, "SubScalar({:?}, {:?})", input, scalar),
LazyOp::DivScalar(input, scalar) => write!(f, "DivScalar({:?}, {:?})", input, scalar),
LazyOp::Pow(input, exp) => write!(f, "Pow({:?}, {:?})", input, exp),
LazyOp::MatMul(lhs, rhs) => write!(f, "MatMul({:?}, {:?})", lhs, rhs),
LazyOp::Transpose(input, dims) => write!(f, "Transpose({:?}, {:?})", input, dims),
LazyOp::Reshape(input, shape) => write!(f, "Reshape({:?}, {:?})", input, shape),
LazyOp::Sum(input, dim) => write!(f, "Sum({:?}, {:?})", input, dim),
LazyOp::Mean(input, dim) => write!(f, "Mean({:?}, {:?})", input, dim),
LazyOp::ReLU(input) => write!(f, "ReLU({:?})", input),
LazyOp::Sigmoid(input) => write!(f, "Sigmoid({:?})", input),
LazyOp::Tanh(input) => write!(f, "Tanh({:?})", input),
LazyOp::Exp(input) => write!(f, "Exp({:?})", input),
LazyOp::Log(input) => write!(f, "Log({:?})", input),
LazyOp::Sin(input) => write!(f, "Sin({:?})", input),
LazyOp::Cos(input) => write!(f, "Cos({:?})", input),
LazyOp::Custom(name, input, _) => write!(f, "Custom({}, {:?}, <fn>)", name, input),
}
}
}
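/// A tensor whose operations are recorded rather than executed. Builder calls
/// such as `add`, `mul_scalar`, or `relu` only extend the underlying `LazyOp`
/// graph; `eval` walks the graph once to produce a concrete `Tensor`, and
/// `shape` infers (and caches) the output shape without touching tensor data.
///
/// A minimal usage sketch, not compiled as a doctest; it uses only the
/// constructors and methods exercised in this module's tests:
///
/// ```ignore
/// use torsh_core::device::DeviceType;
///
/// let t = Tensor::from_data(vec![1.0_f32, -2.0, 3.0, -4.0], vec![2, 2], DeviceType::Cpu)?;
/// let out = t.lazy()      // wrap as a LazyTensor (via TensorLazyExt)
///     .mul_scalar(2.0)    // recorded, not executed
///     .relu()             // recorded, not executed
///     .eval()?;           // single recursive pass over the graph
/// ```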
pub struct LazyTensor<T: TensorElement> {
operation: LazyOp<T>,
cached_shape: RwLock<Option<Shape>>,
optimization_passes: Vec<OptimizationPass>,
}
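/// A graph-rewrite pass over an `f32` expression tree, such as the helpers in
/// the `optimizations` module. Passes can be registered on a `LazyTensor` via
/// `with_optimization`; see `apply_optimizations` for how they are applied.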
pub type OptimizationPass = Box<dyn Fn(&LazyOp<f32>) -> LazyOp<f32> + Send + Sync>;
impl<T: TensorElement + Into<f32> + std::iter::Sum + num_traits::FromPrimitive + torsh_core::dtype::FloatElement> LazyTensor<T> {
pub fn from_tensor(tensor: Tensor<T>) -> Self {
Self {
operation: LazyOp::Identity(Arc::new(tensor)),
cached_shape: RwLock::new(None),
optimization_passes: Vec::new(),
}
}
pub fn from_operation(operation: LazyOp<T>) -> Self {
Self {
operation,
cached_shape: RwLock::new(None),
optimization_passes: Vec::new(),
}
}
pub fn with_optimization<F>(mut self, pass: F) -> Self
where
F: Fn(&LazyOp<f32>) -> LazyOp<f32> + Send + Sync + 'static,
{
self.optimization_passes.push(Box::new(pass));
self
}
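    /// Infers the output shape of the deferred graph without materializing any
    /// tensors. The result is cached behind an `RwLock`, so repeated calls on
    /// the same graph are cheap.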
pub fn shape(&self) -> Result<Shape> {
{
let cached = self.cached_shape.read().expect("lock should not be poisoned");
if let Some(ref shape) = *cached {
return Ok(shape.clone());
}
}
let shape = self.compute_shape(&self.operation)?;
{
let mut cached = self.cached_shape.write().expect("lock should not be poisoned");
*cached = Some(shape.clone());
}
Ok(shape)
}
    pub fn eval(self) -> Result<Tensor<T>>
    where
        T: std::ops::Add<Output = T>
            + std::ops::Sub<Output = T>
            + std::ops::Mul<Output = T>
            + std::ops::Div<Output = T>
            + num_traits::Float
            + Copy,
    {
        // Run any applicable optimization passes over the graph, then walk it
        // recursively to produce a concrete tensor.
        let optimized_op = self.apply_optimizations(self.operation.clone());
        self.evaluate_operation(&optimized_op)
    }
    fn apply_optimizations(&self, op: LazyOp<T>) -> LazyOp<T> {
        // Registered passes are currently typed for `LazyOp<f32>` (see
        // `OptimizationPass`), so for a generic element type the graph is
        // returned unchanged; the functions in the `optimizations` module can
        // still be applied by hand to f32 graphs before evaluation.
        op
    }
fn evaluate_operation(&self, op: &LazyOp<T>) -> Result<Tensor<T>>
where
T: std::ops::Add<Output = T>
+ std::ops::Sub<Output = T>
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>
+ num_traits::Float
+ std::iter::Sum
+ torsh_core::dtype::FloatElement
+ num_traits::FromPrimitive
+ Into<f32>
+ Copy,
{
match op {
LazyOp::Identity(tensor) => Ok((**tensor).clone()),
LazyOp::Add(lhs, rhs) => {
let lhs_val = self.evaluate_operation(lhs)?;
let rhs_val = self.evaluate_operation(rhs)?;
lhs_val.add_op(&rhs_val)
}
LazyOp::Mul(lhs, rhs) => {
let lhs_val = self.evaluate_operation(lhs)?;
let rhs_val = self.evaluate_operation(rhs)?;
lhs_val.mul_op(&rhs_val)
}
LazyOp::Sub(lhs, rhs) => {
let lhs_val = self.evaluate_operation(lhs)?;
let rhs_val = self.evaluate_operation(rhs)?;
lhs_val.sub(&rhs_val)
}
LazyOp::Div(lhs, rhs) => {
let lhs_val = self.evaluate_operation(lhs)?;
let rhs_val = self.evaluate_operation(rhs)?;
lhs_val.div(&rhs_val)
}
LazyOp::AddScalar(input, scalar) => {
let input_val = self.evaluate_operation(input)?;
input_val.add_scalar(*scalar)
}
LazyOp::MulScalar(input, scalar) => {
let input_val = self.evaluate_operation(input)?;
input_val.mul_scalar(*scalar)
}
LazyOp::SubScalar(input, scalar) => {
let input_val = self.evaluate_operation(input)?;
input_val.sub_scalar(*scalar)
}
LazyOp::DivScalar(input, scalar) => {
let input_val = self.evaluate_operation(input)?;
input_val.div_scalar(*scalar)
}
LazyOp::Pow(input, exponent) => {
let input_val = self.evaluate_operation(input)?;
input_val.pow(*exponent)
}
LazyOp::MatMul(lhs, rhs) => {
let lhs_val = self.evaluate_operation(lhs)?;
let rhs_val = self.evaluate_operation(rhs)?;
lhs_val.matmul(&rhs_val)
}
LazyOp::Transpose(input, dims) => {
let input_val = self.evaluate_operation(input)?;
match dims {
Some((dim1, dim2)) => input_val.transpose(*dim1 as i32, *dim2 as i32),
None => input_val.t(),
}
}
LazyOp::Reshape(input, new_shape) => {
let input_val = self.evaluate_operation(input)?;
let dims_i32: Vec<i32> = new_shape.dims().iter().map(|&d| d as i32).collect();
input_val.reshape(&dims_i32)
}
LazyOp::Sum(input, dim) => {
let input_val = self.evaluate_operation(input)?;
match dim {
Some(d) => input_val.sum_dim(&[*d], false),
                    None => input_val.sum(),
}
}
LazyOp::Mean(input, dim) => {
let input_val = self.evaluate_operation(input)?;
match dim {
Some(d) => {
let dim_usize = if *d < 0 {
(input_val.shape().dims().len() as i32 + *d) as usize
} else {
*d as usize
};
input_val.mean(Some(&[dim_usize]), false)
},
None => input_val.mean(None, false),
}
}
LazyOp::ReLU(input) => {
let input_val = self.evaluate_operation(input)?;
input_val.relu()
}
LazyOp::Sigmoid(input) => {
let input_val = self.evaluate_operation(input)?;
input_val.sigmoid()
}
LazyOp::Tanh(input) => {
let input_val = self.evaluate_operation(input)?;
input_val.tanh()
}
LazyOp::Exp(input) => {
let input_val = self.evaluate_operation(input)?;
input_val.exp()
}
LazyOp::Log(input) => {
let input_val = self.evaluate_operation(input)?;
input_val.log()
}
LazyOp::Sin(input) => {
let input_val = self.evaluate_operation(input)?;
input_val.sin()
}
LazyOp::Cos(input) => {
let input_val = self.evaluate_operation(input)?;
input_val.cos()
}
LazyOp::Custom(_name, input, func) => {
let input_val = self.evaluate_operation(input)?;
func(&input_val)
}
}
}
fn compute_shape(&self, op: &LazyOp<T>) -> Result<Shape> {
match op {
LazyOp::Identity(tensor) => Ok(tensor.shape().clone()),
            LazyOp::Add(lhs, rhs) | LazyOp::Mul(lhs, rhs) | LazyOp::Sub(lhs, rhs) | LazyOp::Div(lhs, rhs) => {
                // Element-wise binary ops are assumed to act on same-shaped
                // operands; broadcasting is not modelled here, so the left-hand
                // shape is used as the result shape.
                let lhs_shape = self.compute_shape(lhs)?;
                let _rhs_shape = self.compute_shape(rhs)?;
                Ok(lhs_shape)
            }
LazyOp::AddScalar(input, _) | LazyOp::MulScalar(input, _) |
LazyOp::SubScalar(input, _) | LazyOp::DivScalar(input, _) => {
self.compute_shape(input)
}
LazyOp::Pow(input, _) => self.compute_shape(input),
LazyOp::MatMul(lhs, rhs) => {
let lhs_shape = self.compute_shape(lhs)?;
let rhs_shape = self.compute_shape(rhs)?;
if lhs_shape.dims().len() < 2 || rhs_shape.dims().len() < 2 {
return Err(TorshError::InvalidShape("MatMul requires at least 2D tensors".to_string()));
}
let lhs_dims = lhs_shape.dims();
let rhs_dims = rhs_shape.dims();
let m = lhs_dims[lhs_dims.len() - 2];
let k1 = lhs_dims[lhs_dims.len() - 1];
let k2 = rhs_dims[rhs_dims.len() - 2];
let n = rhs_dims[rhs_dims.len() - 1];
if k1 != k2 {
return Err(TorshError::ShapeMismatch {
expected: vec![k1],
got: vec![k2],
});
}
let mut result_dims = lhs_dims[..lhs_dims.len()-2].to_vec();
result_dims.extend_from_slice(&[m, n]);
Ok(Shape::new(result_dims))
}
LazyOp::Transpose(input, dims) => {
let input_shape = self.compute_shape(input)?;
let mut result_dims = input_shape.dims().to_vec();
match dims {
Some((dim1, dim2)) => {
if *dim1 < result_dims.len() && *dim2 < result_dims.len() {
result_dims.swap(*dim1, *dim2);
}
}
None => {
if result_dims.len() >= 2 {
let len = result_dims.len();
result_dims.swap(len - 2, len - 1);
}
}
}
Ok(Shape::new(result_dims))
}
LazyOp::Reshape(_, new_shape) => Ok(new_shape.clone()),
            LazyOp::Sum(input, dim) => {
                let input_shape = self.compute_shape(input)?;
                match dim {
                    Some(d) => {
                        let mut result_dims = input_shape.dims().to_vec();
                        // Resolve a possibly negative dimension index before removing it.
                        let dim_idx = if *d < 0 {
                            (result_dims.len() as i32 + *d) as usize
                        } else {
                            *d as usize
                        };
                        if dim_idx < result_dims.len() {
                            result_dims.remove(dim_idx);
                        }
                        Ok(Shape::new(result_dims))
                    }
                    // A full reduction collapses to a scalar (rank-0) shape.
                    None => Ok(Shape::new(vec![])),
                }
            }
LazyOp::Mean(input, dim) => self.compute_shape(&LazyOp::Sum(input.clone(), *dim)),
LazyOp::ReLU(input) | LazyOp::Sigmoid(input) | LazyOp::Tanh(input) |
LazyOp::Exp(input) | LazyOp::Log(input) | LazyOp::Sin(input) | LazyOp::Cos(input) => {
self.compute_shape(input)
}
LazyOp::Custom(_, input, _) => self.compute_shape(input),
}
}
}
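/// Graph-building combinators. Each method consumes the receiver and wraps its
/// operation in a new `LazyOp` node, so a chain such as
/// `x.lazy().add(y.lazy()).relu()` builds an expression tree without computing
/// anything until `eval` is called.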
impl<T: TensorElement + num_traits::Float + std::iter::Sum + torsh_core::dtype::FloatElement + num_traits::FromPrimitive + Into<f32>> LazyTensor<T> {
pub fn add(self, other: LazyTensor<T>) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Add(
Box::new(self.operation),
Box::new(other.operation),
))
}
pub fn mul(self, other: LazyTensor<T>) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Mul(
Box::new(self.operation),
Box::new(other.operation),
))
}
pub fn sub(self, other: LazyTensor<T>) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Sub(
Box::new(self.operation),
Box::new(other.operation),
))
}
pub fn div(self, other: LazyTensor<T>) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Div(
Box::new(self.operation),
Box::new(other.operation),
))
}
pub fn add_scalar(self, scalar: T) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::AddScalar(
Box::new(self.operation),
scalar,
))
}
pub fn mul_scalar(self, scalar: T) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::MulScalar(
Box::new(self.operation),
scalar,
))
}
pub fn sub_scalar(self, scalar: T) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::SubScalar(
Box::new(self.operation),
scalar,
))
}
pub fn div_scalar(self, scalar: T) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::DivScalar(
Box::new(self.operation),
scalar,
))
}
pub fn pow(self, exponent: T) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Pow(
Box::new(self.operation),
exponent,
))
}
pub fn matmul(self, other: LazyTensor<T>) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::MatMul(
Box::new(self.operation),
Box::new(other.operation),
))
}
pub fn transpose(self, dim1: usize, dim2: usize) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Transpose(
Box::new(self.operation),
Some((dim1, dim2)),
))
}
pub fn t(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Transpose(
Box::new(self.operation),
None,
))
}
pub fn reshape(self, shape: &[usize]) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Reshape(
Box::new(self.operation),
Shape::new(shape.to_vec()),
))
}
pub fn sum(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Sum(
Box::new(self.operation),
None,
))
}
pub fn sum_dim(self, dim: i32) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Sum(
Box::new(self.operation),
Some(dim),
))
}
pub fn mean(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Mean(
Box::new(self.operation),
None,
))
}
pub fn mean_dim(self, dim: i32) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Mean(
Box::new(self.operation),
Some(dim),
))
}
pub fn relu(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::ReLU(
Box::new(self.operation),
))
}
pub fn sigmoid(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Sigmoid(
Box::new(self.operation),
))
}
pub fn tanh(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Tanh(
Box::new(self.operation),
))
}
pub fn exp(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Exp(
Box::new(self.operation),
))
}
pub fn log(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Log(
Box::new(self.operation),
))
}
pub fn sin(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Sin(
Box::new(self.operation),
))
}
pub fn cos(self) -> LazyTensor<T> {
LazyTensor::from_operation(LazyOp::Cos(
Box::new(self.operation),
))
}
pub fn custom<F>(self, name: String, func: F) -> LazyTensor<T>
where
F: Fn(&Tensor<T>) -> Result<Tensor<T>> + Send + Sync + 'static,
{
LazyTensor::from_operation(LazyOp::Custom(
name,
Box::new(self.operation),
Arc::new(func),
))
}
}
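/// Extension trait converting an eager `Tensor` into a `LazyTensor`, the entry
/// point for building deferred computation graphs.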
pub trait TensorLazyExt<T: TensorElement> {
fn lazy(self) -> LazyTensor<T>;
}
impl<T: TensorElement + Into<f32> + std::iter::Sum + num_traits::FromPrimitive + torsh_core::dtype::FloatElement> TensorLazyExt<T> for Tensor<T> {
fn lazy(self) -> LazyTensor<T> {
LazyTensor::from_tensor(self)
}
}
pub mod optimizations {
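    //! Stand-alone rewrite passes over `LazyOp<f32>` graphs: `constant_folding`
    //! merges stacked scalar constants, `dead_code_elimination` strips
    //! arithmetic no-ops, and `operation_fusion` collapses the `AddScalar` +
    //! `ReLU` pattern into a single `Custom` node. Each pass takes a graph by
    //! reference and returns a rewritten copy.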
use super::*;
pub fn constant_folding(op: &LazyOp<f32>) -> LazyOp<f32> {
match op {
LazyOp::AddScalar(inner_box, s2) => {
if let LazyOp::AddScalar(inner, s1) = &**inner_box {
LazyOp::AddScalar(
Box::new(constant_folding(inner)),
s1 + s2,
)
} else {
op.clone()
}
}
LazyOp::MulScalar(inner_box, s2) => {
if let LazyOp::MulScalar(inner, s1) = &**inner_box {
LazyOp::MulScalar(
Box::new(constant_folding(inner)),
s1 * s2,
)
} else {
op.clone()
}
}
_ => op.clone(),
}
}
    /// Removes arithmetic no-ops (`+ 0.0`, `* 1.0`) from the graph. A `* 0.0`
    /// node is left unchanged for now: replacing it with a constant-zero tensor
    /// would require shape information that is not available in this pass.
    pub fn dead_code_elimination(op: &LazyOp<f32>) -> LazyOp<f32> {
        match op {
            LazyOp::MulScalar(_inner, scalar) if *scalar == 0.0 => op.clone(),
            LazyOp::AddScalar(inner, scalar) if *scalar == 0.0 => constant_folding(inner),
            LazyOp::MulScalar(inner, scalar) if *scalar == 1.0 => constant_folding(inner),
            _ => op.clone(),
        }
    }
pub fn operation_fusion(op: &LazyOp<f32>) -> LazyOp<f32> {
match op {
LazyOp::ReLU(inner_box) => {
if let LazyOp::AddScalar(inner, scalar) = &**inner_box {
LazyOp::Custom(
"fused_add_relu".to_string(),
Box::new(operation_fusion(inner)),
Arc::new({
let s = *scalar;
move |tensor: &Tensor<f32>| -> Result<Tensor<f32>> {
let added = tensor.add_scalar(s)?;
added.relu()
}
}),
)
} else {
op.clone()
}
}
_ => op.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Tensor;
use torsh_core::device::DeviceType;
fn create_test_tensor() -> Tensor<f32> {
Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![2, 2],
DeviceType::Cpu,
).expect("tensor creation should succeed")
}
#[test]
fn test_lazy_tensor_creation() {
let tensor = create_test_tensor();
let lazy = tensor.lazy();
let shape = lazy.shape().expect("shape should be available");
assert_eq!(shape.dims(), &[2, 2]);
}
#[test]
fn test_lazy_operation_chaining() {
let tensor1 = create_test_tensor();
let tensor2 = create_test_tensor();
let result = tensor1.lazy()
.add(tensor2.lazy())
.mul_scalar(2.0)
.relu()
.eval()
.expect("addition should succeed");
let expected_data = vec![4.0, 8.0, 12.0, 16.0];
let result_data = result.to_vec().expect("to_vec conversion should succeed");
for (expected, actual) in expected_data.iter().zip(result_data.iter()) {
assert!((expected - actual).abs() < f32::EPSILON);
}
}
#[test]
fn test_lazy_matmul() {
let a = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0],
vec![2, 2],
DeviceType::Cpu,
).expect("tensor creation should succeed");
let b = Tensor::from_data(
vec![2.0, 0.0, 1.0, 3.0],
vec![2, 2],
DeviceType::Cpu,
).expect("tensor creation should succeed");
let result = a.lazy()
.matmul(b.lazy())
.eval()
.expect("tensor creation should succeed");
let expected_data = vec![4.0, 6.0, 10.0, 12.0];
let result_data = result.to_vec().expect("to_vec conversion should succeed");
for (expected, actual) in expected_data.iter().zip(result_data.iter()) {
assert!((expected - actual).abs() < f32::EPSILON);
}
}
#[test]
fn test_lazy_reshape_and_transpose() {
let tensor = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![2, 3],
DeviceType::Cpu,
).expect("tensor creation should succeed");
let result = tensor.lazy()
.reshape(&[3, 2])
.t()
.eval()
.expect("tensor creation should succeed");
assert_eq!(result.shape().dims(), &[2, 3]);
let expected_data = vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0];
let result_data = result.to_vec().expect("to_vec conversion should succeed");
for (expected, actual) in expected_data.iter().zip(result_data.iter()) {
assert!((expected - actual).abs() < f32::EPSILON);
}
}
#[test]
fn test_lazy_reductions() {
let tensor = Tensor::from_data(
vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
vec![2, 3],
DeviceType::Cpu,
).expect("tensor creation should succeed");
let sum_result = tensor.clone().lazy()
.sum_dim(1)
.eval()
.expect("tensor creation should succeed");
let expected_sum = vec![6.0, 15.0];
let result_sum = sum_result.to_vec().expect("to_vec conversion should succeed");
for (expected, actual) in expected_sum.iter().zip(result_sum.iter()) {
assert!((expected - actual).abs() < f32::EPSILON);
}
let mean_result = tensor.lazy()
.mean()
.eval()
.expect("mean should succeed");
let result_mean = mean_result.to_vec().expect("to_vec conversion should succeed");
assert!((result_mean[0] - 3.5).abs() < f32::EPSILON);
}
#[test]
fn test_custom_operation() {
let tensor = create_test_tensor();
let result = tensor.lazy()
.custom(
"square".to_string(),
|t: &Tensor<f32>| -> Result<Tensor<f32>> {
t.pow(2.0)
},
)
.eval()
.expect("pow should succeed");
let expected_data = vec![1.0, 4.0, 9.0, 16.0];
let result_data = result.to_vec().expect("to_vec conversion should succeed");
for (expected, actual) in expected_data.iter().zip(result_data.iter()) {
assert!((expected - actual).abs() < f32::EPSILON);
}
}
#[test]
fn test_shape_inference() {
let tensor = create_test_tensor();
let lazy_reshaped = tensor.lazy().reshape(&[4, 1]);
let shape = lazy_reshaped.shape().expect("shape should be available");
assert_eq!(shape.dims(), &[4, 1]);
let lazy_transposed = lazy_reshaped.t();
let transposed_shape = lazy_transposed.shape().expect("shape should be available");
assert_eq!(transposed_shape.dims(), &[1, 4]);
}
#[test]
fn test_complex_chain() {
        let a = create_test_tensor();
        let b = create_test_tensor();
let result = a.lazy()
.add(b.lazy())
.mul_scalar(2.0)
.sub_scalar(1.0)
.exp()
.sum()
.eval()
.expect("addition should succeed");
let result_val = result.to_vec().expect("to_vec conversion should succeed")[0];
let expected = 3.0_f32.exp() + 7.0_f32.exp() + 11.0_f32.exp() + 15.0_f32.exp();
assert!((result_val - expected).abs() < 1e-4);
}
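    // The two tests below are illustrative sketches of the `optimizations`
    // passes on hand-built `LazyOp<f32>` graphs; they rely only on APIs
    // already exercised above (`Tensor::from_data`, `add_scalar`, `relu`,
    // `to_vec`).
    #[test]
    fn test_constant_folding_pass() {
        let tensor = create_test_tensor();
        // Two stacked AddScalar nodes (+1.0 then +2.0) should fold into one +3.0.
        let op = LazyOp::AddScalar(
            Box::new(LazyOp::AddScalar(
                Box::new(LazyOp::Identity(Arc::new(tensor))),
                1.0_f32,
            )),
            2.0_f32,
        );
        let folded = optimizations::constant_folding(&op);
        match folded {
            LazyOp::AddScalar(inner, scalar) => {
                assert!((scalar - 3.0).abs() < f32::EPSILON);
                assert!(matches!(*inner, LazyOp::Identity(_)));
            }
            other => panic!("expected a folded AddScalar node, got {:?}", other),
        }
    }
    #[test]
    fn test_operation_fusion_pass() {
        let tensor = create_test_tensor();
        // AddScalar followed by ReLU should fuse into a single Custom node whose
        // evaluation matches the unfused arithmetic: relu([1, 2, 3, 4] - 2.5).
        let op = LazyOp::ReLU(Box::new(LazyOp::AddScalar(
            Box::new(LazyOp::Identity(Arc::new(tensor))),
            -2.5_f32,
        )));
        let fused = optimizations::operation_fusion(&op);
        match &fused {
            LazyOp::Custom(name, _, _) => assert_eq!(name, "fused_add_relu"),
            other => panic!("expected a fused Custom node, got {:?}", other),
        }
        let result = LazyTensor::from_operation(fused)
            .eval()
            .expect("fused evaluation should succeed");
        let expected = vec![0.0_f32, 0.0, 0.5, 1.5];
        let actual = result.to_vec().expect("to_vec conversion should succeed");
        for (e, a) in expected.iter().zip(actual.iter()) {
            assert!((e - a).abs() < f32::EPSILON);
        }
    }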
}