// axonml-tensor / lazy.rs
//! Lazy Tensor - Deferred Computation with Graph Optimization
//!
//! # File
//! `crates/axonml-tensor/src/lazy.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::tensor::Tensor;

19// =============================================================================
20// LazyOp - Deferred Operation Graph
21// =============================================================================
22
23/// Represents a deferred tensor operation in the computation graph.
24///
25/// Each variant captures the operation and its operands (as sub-graphs),
26/// forming a tree that can be optimized before execution.
27#[derive(Debug, Clone)]
28pub enum LazyOp {
29    /// Leaf: a materialized tensor, referenced by index into the tensor store.
30    Tensor(usize),
31
32    // Unary ops
33    /// Element-wise negation.
34    Neg(Box<LazyOp>),
35    /// ReLU activation: max(0, x).
36    Relu(Box<LazyOp>),
37    /// Sigmoid activation: 1 / (1 + exp(-x)).
38    Sigmoid(Box<LazyOp>),
39    /// Element-wise exponential.
40    Exp(Box<LazyOp>),
41    /// Element-wise natural logarithm.
42    Log(Box<LazyOp>),
43    /// Element-wise square root.
44    Sqrt(Box<LazyOp>),
45    /// Element-wise absolute value.
46    Abs(Box<LazyOp>),
47
48    // Binary ops
49    /// Element-wise addition.
50    Add(Box<LazyOp>, Box<LazyOp>),
51    /// Element-wise subtraction.
52    Sub(Box<LazyOp>, Box<LazyOp>),
53    /// Element-wise multiplication.
54    Mul(Box<LazyOp>, Box<LazyOp>),
55    /// Element-wise division.
56    Div(Box<LazyOp>, Box<LazyOp>),
57
58    // Reductions
59    /// Sum of all elements (reduces to scalar).
60    Sum(Box<LazyOp>),
61    /// Mean of all elements (reduces to scalar).
62    Mean(Box<LazyOp>),
63
64    // Shape ops
65    /// Reshape to a new shape.
66    Reshape(Box<LazyOp>, Vec<usize>),
67    /// Transpose two dimensions.
68    Transpose(Box<LazyOp>, usize, usize),
69
70    // Scalar ops
71    /// Add a scalar to every element.
72    AddScalar(Box<LazyOp>, f32),
73    /// Multiply every element by a scalar.
74    MulScalar(Box<LazyOp>, f32),
75}
76
77// =============================================================================
78// LazyTensor - Deferred Execution Tensor
79// =============================================================================
80
81/// A deferred-execution tensor that records operations without running them.
82///
83/// Operations on `LazyTensor` build up a computation graph (via `LazyOp`) that
84/// is only executed when `materialize()` is called. Before materialization, the
85/// graph can be optimized via `optimize()` to apply algebraic simplifications
86/// such as constant folding, identity elimination, and inverse cancellation.
87#[derive(Debug, Clone)]
88pub struct LazyTensor {
89    /// The root operation of this lazy computation graph.
90    op: LazyOp,
91    /// Inferred output shape of this tensor.
92    shape: Vec<usize>,
93    /// Concrete tensors referenced by `LazyOp::Tensor` indices.
94    tensors: Vec<Tensor<f32>>,
95}
96
97impl LazyTensor {
98    // =========================================================================
99    // Constructors
100    // =========================================================================
101
102    /// Wraps a concrete tensor into a lazy tensor.
103    ///
104    /// The tensor is stored internally and referenced by a `LazyOp::Tensor` node.
105    pub fn from_tensor(tensor: Tensor<f32>) -> Self {
106        let shape = tensor.shape().to_vec();
107        Self {
108            op: LazyOp::Tensor(0),
109            shape,
110            tensors: vec![tensor],
111        }
112    }
113
114    /// Creates a lazy tensor filled with zeros (materialized on demand).
115    pub fn zeros(shape: &[usize]) -> Self {
116        let tensor = Tensor::<f32>::zeros(shape);
117        Self::from_tensor(tensor)
118    }
119
120    /// Creates a lazy tensor filled with ones (materialized on demand).
121    pub fn ones(shape: &[usize]) -> Self {
122        let tensor = Tensor::<f32>::ones(shape);
123        Self::from_tensor(tensor)
124    }
125
126    // =========================================================================
127    // Unary Operations
128    // =========================================================================
129
130    /// Element-wise negation.
131    pub fn neg(&self) -> LazyTensor {
132        LazyTensor {
133            op: LazyOp::Neg(Box::new(self.op.clone())),
134            shape: self.shape.clone(),
135            tensors: self.tensors.clone(),
136        }
137    }
138
139    /// ReLU activation: max(0, x).
140    pub fn relu(&self) -> LazyTensor {
141        LazyTensor {
142            op: LazyOp::Relu(Box::new(self.op.clone())),
143            shape: self.shape.clone(),
144            tensors: self.tensors.clone(),
145        }
146    }
147
148    /// Sigmoid activation: 1 / (1 + exp(-x)).
149    pub fn sigmoid(&self) -> LazyTensor {
150        LazyTensor {
151            op: LazyOp::Sigmoid(Box::new(self.op.clone())),
152            shape: self.shape.clone(),
153            tensors: self.tensors.clone(),
154        }
155    }
156
157    /// Element-wise exponential.
158    pub fn exp(&self) -> LazyTensor {
159        LazyTensor {
160            op: LazyOp::Exp(Box::new(self.op.clone())),
161            shape: self.shape.clone(),
162            tensors: self.tensors.clone(),
163        }
164    }
165
166    /// Element-wise natural logarithm.
167    pub fn log(&self) -> LazyTensor {
168        LazyTensor {
169            op: LazyOp::Log(Box::new(self.op.clone())),
170            shape: self.shape.clone(),
171            tensors: self.tensors.clone(),
172        }
173    }
174
175    /// Element-wise square root.
176    pub fn sqrt(&self) -> LazyTensor {
177        LazyTensor {
178            op: LazyOp::Sqrt(Box::new(self.op.clone())),
179            shape: self.shape.clone(),
180            tensors: self.tensors.clone(),
181        }
182    }
183
184    /// Element-wise absolute value.
185    pub fn abs(&self) -> LazyTensor {
186        LazyTensor {
187            op: LazyOp::Abs(Box::new(self.op.clone())),
188            shape: self.shape.clone(),
189            tensors: self.tensors.clone(),
190        }
191    }
192
193    // =========================================================================
194    // Binary Operations
195    // =========================================================================
196
197    /// Merges two LazyTensors' tensor stores and remaps indices in the right operand.
198    ///
199    /// Returns `(merged_tensors, remapped_right_op)`.
200    fn merge_stores(
201        left_tensors: &[Tensor<f32>],
202        right_tensors: &[Tensor<f32>],
203        right_op: &LazyOp,
204    ) -> (Vec<Tensor<f32>>, LazyOp) {
205        let offset = left_tensors.len();
206        let mut merged = left_tensors.to_vec();
207        merged.extend(right_tensors.iter().cloned());
208        let remapped = Self::remap_indices(right_op, offset);
209        (merged, remapped)
210    }
211
212    /// Recursively remaps `LazyOp::Tensor` indices by adding an offset.
213    fn remap_indices(op: &LazyOp, offset: usize) -> LazyOp {
214        match op {
215            LazyOp::Tensor(idx) => LazyOp::Tensor(idx + offset),
216
217            LazyOp::Neg(a) => LazyOp::Neg(Box::new(Self::remap_indices(a, offset))),
218            LazyOp::Relu(a) => LazyOp::Relu(Box::new(Self::remap_indices(a, offset))),
219            LazyOp::Sigmoid(a) => LazyOp::Sigmoid(Box::new(Self::remap_indices(a, offset))),
220            LazyOp::Exp(a) => LazyOp::Exp(Box::new(Self::remap_indices(a, offset))),
221            LazyOp::Log(a) => LazyOp::Log(Box::new(Self::remap_indices(a, offset))),
222            LazyOp::Sqrt(a) => LazyOp::Sqrt(Box::new(Self::remap_indices(a, offset))),
223            LazyOp::Abs(a) => LazyOp::Abs(Box::new(Self::remap_indices(a, offset))),
224
225            LazyOp::Add(a, b) => LazyOp::Add(
226                Box::new(Self::remap_indices(a, offset)),
227                Box::new(Self::remap_indices(b, offset)),
228            ),
229            LazyOp::Sub(a, b) => LazyOp::Sub(
230                Box::new(Self::remap_indices(a, offset)),
231                Box::new(Self::remap_indices(b, offset)),
232            ),
233            LazyOp::Mul(a, b) => LazyOp::Mul(
234                Box::new(Self::remap_indices(a, offset)),
235                Box::new(Self::remap_indices(b, offset)),
236            ),
237            LazyOp::Div(a, b) => LazyOp::Div(
238                Box::new(Self::remap_indices(a, offset)),
239                Box::new(Self::remap_indices(b, offset)),
240            ),
241
242            LazyOp::Sum(a) => LazyOp::Sum(Box::new(Self::remap_indices(a, offset))),
243            LazyOp::Mean(a) => LazyOp::Mean(Box::new(Self::remap_indices(a, offset))),
244
245            LazyOp::Reshape(a, s) => {
246                LazyOp::Reshape(Box::new(Self::remap_indices(a, offset)), s.clone())
247            }
248            LazyOp::Transpose(a, d0, d1) => {
249                LazyOp::Transpose(Box::new(Self::remap_indices(a, offset)), *d0, *d1)
250            }
251
252            LazyOp::AddScalar(a, s) => {
253                LazyOp::AddScalar(Box::new(Self::remap_indices(a, offset)), *s)
254            }
255            LazyOp::MulScalar(a, s) => {
256                LazyOp::MulScalar(Box::new(Self::remap_indices(a, offset)), *s)
257            }
258        }
259    }
260
261    /// Creates a binary operation LazyTensor, merging tensor stores from both operands.
262    fn binary_op(
263        &self,
264        other: &LazyTensor,
265        make_op: impl FnOnce(Box<LazyOp>, Box<LazyOp>) -> LazyOp,
266        shape: Vec<usize>,
267    ) -> LazyTensor {
268        let (merged, remapped_right) = Self::merge_stores(&self.tensors, &other.tensors, &other.op);
269        LazyTensor {
270            op: make_op(Box::new(self.op.clone()), Box::new(remapped_right)),
271            shape,
272            tensors: merged,
273        }
274    }
275
276    /// Element-wise addition. Shapes must match.
277    pub fn add(&self, other: &LazyTensor) -> LazyTensor {
278        assert_eq!(self.shape, other.shape, "LazyTensor add: shapes must match");
279        self.binary_op(other, LazyOp::Add, self.shape.clone())
280    }
281
282    /// Element-wise subtraction. Shapes must match.
283    pub fn sub(&self, other: &LazyTensor) -> LazyTensor {
284        assert_eq!(self.shape, other.shape, "LazyTensor sub: shapes must match");
285        self.binary_op(other, LazyOp::Sub, self.shape.clone())
286    }
287
288    /// Element-wise multiplication. Shapes must match.
289    pub fn mul(&self, other: &LazyTensor) -> LazyTensor {
290        assert_eq!(self.shape, other.shape, "LazyTensor mul: shapes must match");
291        self.binary_op(other, LazyOp::Mul, self.shape.clone())
292    }
293
294    /// Element-wise division. Shapes must match.
295    pub fn div(&self, other: &LazyTensor) -> LazyTensor {
296        assert_eq!(self.shape, other.shape, "LazyTensor div: shapes must match");
297        self.binary_op(other, LazyOp::Div, self.shape.clone())
298    }
299
300    // =========================================================================
301    // Scalar Operations
302    // =========================================================================
303
304    /// Adds a scalar to every element.
305    pub fn add_scalar(&self, s: f32) -> LazyTensor {
306        LazyTensor {
307            op: LazyOp::AddScalar(Box::new(self.op.clone()), s),
308            shape: self.shape.clone(),
309            tensors: self.tensors.clone(),
310        }
311    }
312
313    /// Multiplies every element by a scalar.
314    pub fn mul_scalar(&self, s: f32) -> LazyTensor {
315        LazyTensor {
316            op: LazyOp::MulScalar(Box::new(self.op.clone()), s),
317            shape: self.shape.clone(),
318            tensors: self.tensors.clone(),
319        }
320    }
321
322    // =========================================================================
323    // Reductions
324    // =========================================================================
325
326    /// Sum of all elements (reduces to scalar shape).
327    pub fn sum(&self) -> LazyTensor {
328        LazyTensor {
329            op: LazyOp::Sum(Box::new(self.op.clone())),
330            shape: vec![],
331            tensors: self.tensors.clone(),
332        }
333    }
334
335    /// Mean of all elements (reduces to scalar shape).
336    pub fn mean(&self) -> LazyTensor {
337        LazyTensor {
338            op: LazyOp::Mean(Box::new(self.op.clone())),
339            shape: vec![],
340            tensors: self.tensors.clone(),
341        }
342    }
343
344    // =========================================================================
345    // Shape Operations
346    // =========================================================================
347
348    /// Reshapes the lazy tensor to a new shape.
349    ///
350    /// The total number of elements must remain the same.
351    pub fn reshape(&self, shape: &[usize]) -> LazyTensor {
352        let old_numel: usize = self.shape.iter().product();
353        let new_numel: usize = shape.iter().product();
354        assert_eq!(
355            old_numel, new_numel,
356            "LazyTensor reshape: element count mismatch ({old_numel} vs {new_numel})"
357        );
358        LazyTensor {
359            op: LazyOp::Reshape(Box::new(self.op.clone()), shape.to_vec()),
360            shape: shape.to_vec(),
361            tensors: self.tensors.clone(),
362        }
363    }
364
365    // =========================================================================
366    // Graph Inspection
367    // =========================================================================
368
369    /// Returns the inferred output shape.
370    pub fn shape(&self) -> &[usize] {
371        &self.shape
372    }
373
374    /// Counts the number of operations in the computation graph.
375    ///
376    /// Leaf `Tensor` nodes count as 0 operations; every other node counts as 1
377    /// plus the count of its children.
378    pub fn op_count(&self) -> usize {
379        Self::count_ops(&self.op)
380    }
381
382    fn count_ops(op: &LazyOp) -> usize {
383        match op {
384            LazyOp::Tensor(_) => 0,
385
386            LazyOp::Neg(a)
387            | LazyOp::Relu(a)
388            | LazyOp::Sigmoid(a)
389            | LazyOp::Exp(a)
390            | LazyOp::Log(a)
391            | LazyOp::Sqrt(a)
392            | LazyOp::Abs(a)
393            | LazyOp::Sum(a)
394            | LazyOp::Mean(a)
395            | LazyOp::AddScalar(a, _)
396            | LazyOp::MulScalar(a, _)
397            | LazyOp::Reshape(a, _)
398            | LazyOp::Transpose(a, _, _) => 1 + Self::count_ops(a),
399
400            LazyOp::Add(a, b) | LazyOp::Sub(a, b) | LazyOp::Mul(a, b) | LazyOp::Div(a, b) => {
401                1 + Self::count_ops(a) + Self::count_ops(b)
402            }
403        }
404    }
405
406    // =========================================================================
407    // Materialization
408    // =========================================================================
409
410    /// Executes the recorded computation graph and returns a concrete `Tensor<f32>`.
411    ///
412    /// This recursively evaluates every node in the `LazyOp` tree, starting from
413    /// leaves and combining results upward.
414    pub fn materialize(&self) -> Tensor<f32> {
415        self.eval_op(&self.op)
416    }
417
418    fn eval_op(&self, op: &LazyOp) -> Tensor<f32> {
419        match op {
420            LazyOp::Tensor(idx) => self.tensors[*idx].clone(),
421
422            // Unary
423            LazyOp::Neg(a) => self.eval_op(a).neg(),
424            LazyOp::Relu(a) => self.eval_op(a).relu(),
425            LazyOp::Sigmoid(a) => self.eval_op(a).sigmoid(),
426            LazyOp::Exp(a) => self.eval_op(a).exp(),
427            LazyOp::Log(a) => self.eval_op(a).ln(),
428            LazyOp::Sqrt(a) => self.eval_op(a).sqrt(),
429            LazyOp::Abs(a) => {
430                let t = self.eval_op(a);
431                let data: Vec<f32> = t.to_vec().iter().map(|x| x.abs()).collect();
432                Tensor::from_vec(data, t.shape()).unwrap()
433            }
434
435            // Binary
436            LazyOp::Add(a, b) => {
437                let ta = self.eval_op(a);
438                let tb = self.eval_op(b);
439                ta.add(&tb).unwrap()
440            }
441            LazyOp::Sub(a, b) => {
442                let ta = self.eval_op(a);
443                let tb = self.eval_op(b);
444                ta.sub(&tb).unwrap()
445            }
446            LazyOp::Mul(a, b) => {
447                let ta = self.eval_op(a);
448                let tb = self.eval_op(b);
449                ta.mul(&tb).unwrap()
450            }
451            LazyOp::Div(a, b) => {
452                let ta = self.eval_op(a);
453                let tb = self.eval_op(b);
454                ta.div(&tb).unwrap()
455            }
456
457            // Reductions
458            LazyOp::Sum(a) => self.eval_op(a).sum(),
459            LazyOp::Mean(a) => self.eval_op(a).mean().unwrap(),
460
461            // Shape ops
462            LazyOp::Reshape(a, shape) => {
463                let t = self.eval_op(a);
464                let isize_shape: Vec<isize> = shape.iter().map(|&s| s as isize).collect();
465                t.reshape(&isize_shape).unwrap()
466            }
467            LazyOp::Transpose(a, d0, d1) => {
468                let t = self.eval_op(a);
469                t.transpose(*d0 as i64, *d1 as i64).unwrap()
470            }
471
472            // Scalar ops
473            LazyOp::AddScalar(a, s) => self.eval_op(a).add_scalar(*s),
474            LazyOp::MulScalar(a, s) => self.eval_op(a).mul_scalar(*s),
475        }
476    }
477
478    // =========================================================================
479    // Optimization
480    // =========================================================================
481
482    /// Applies algebraic simplifications to the computation graph.
483    ///
484    /// Currently supported optimizations:
485    /// - **Identity elimination**: `x + 0 -> x`, `x * 1 -> x`
486    /// - **Zero multiplication**: `x * 0 -> zeros`
487    /// - **Double negation**: `neg(neg(x)) -> x`
488    /// - **Inverse cancellation**: `exp(log(x)) -> x`, `log(exp(x)) -> x`
489    /// - **Scalar folding**: `x * s1 * s2 -> x * (s1*s2)`, `x + s1 + s2 -> x + (s1+s2)`
490    pub fn optimize(&self) -> LazyTensor {
491        LazyTensor {
492            op: Self::optimize_op(&self.op),
493            shape: self.shape.clone(),
494            tensors: self.tensors.clone(),
495        }
496    }
497
498    fn optimize_op(op: &LazyOp) -> LazyOp {
499        // First, recursively optimize children
500        let op = Self::optimize_children(op);
501        // Then apply local simplifications
502        Self::simplify(&op)
503    }
504
505    fn optimize_children(op: &LazyOp) -> LazyOp {
506        match op {
507            LazyOp::Tensor(idx) => LazyOp::Tensor(*idx),
508
509            LazyOp::Neg(a) => LazyOp::Neg(Box::new(Self::optimize_op(a))),
510            LazyOp::Relu(a) => LazyOp::Relu(Box::new(Self::optimize_op(a))),
511            LazyOp::Sigmoid(a) => LazyOp::Sigmoid(Box::new(Self::optimize_op(a))),
512            LazyOp::Exp(a) => LazyOp::Exp(Box::new(Self::optimize_op(a))),
513            LazyOp::Log(a) => LazyOp::Log(Box::new(Self::optimize_op(a))),
514            LazyOp::Sqrt(a) => LazyOp::Sqrt(Box::new(Self::optimize_op(a))),
515            LazyOp::Abs(a) => LazyOp::Abs(Box::new(Self::optimize_op(a))),
516
517            LazyOp::Add(a, b) => LazyOp::Add(
518                Box::new(Self::optimize_op(a)),
519                Box::new(Self::optimize_op(b)),
520            ),
521            LazyOp::Sub(a, b) => LazyOp::Sub(
522                Box::new(Self::optimize_op(a)),
523                Box::new(Self::optimize_op(b)),
524            ),
525            LazyOp::Mul(a, b) => LazyOp::Mul(
526                Box::new(Self::optimize_op(a)),
527                Box::new(Self::optimize_op(b)),
528            ),
529            LazyOp::Div(a, b) => LazyOp::Div(
530                Box::new(Self::optimize_op(a)),
531                Box::new(Self::optimize_op(b)),
532            ),
533
534            LazyOp::Sum(a) => LazyOp::Sum(Box::new(Self::optimize_op(a))),
535            LazyOp::Mean(a) => LazyOp::Mean(Box::new(Self::optimize_op(a))),
536
537            LazyOp::Reshape(a, s) => LazyOp::Reshape(Box::new(Self::optimize_op(a)), s.clone()),
538            LazyOp::Transpose(a, d0, d1) => {
539                LazyOp::Transpose(Box::new(Self::optimize_op(a)), *d0, *d1)
540            }
541
542            LazyOp::AddScalar(a, s) => LazyOp::AddScalar(Box::new(Self::optimize_op(a)), *s),
543            LazyOp::MulScalar(a, s) => LazyOp::MulScalar(Box::new(Self::optimize_op(a)), *s),
544        }
545    }
546
547    fn simplify(op: &LazyOp) -> LazyOp {
548        match op {
549            // neg(neg(x)) -> x
550            LazyOp::Neg(inner) => {
551                if let LazyOp::Neg(x) = inner.as_ref() {
552                    return *x.clone();
553                }
554                op.clone()
555            }
556
557            // exp(log(x)) -> x
558            LazyOp::Exp(inner) => {
559                if let LazyOp::Log(x) = inner.as_ref() {
560                    return *x.clone();
561                }
562                op.clone()
563            }
564
565            // log(exp(x)) -> x
566            LazyOp::Log(inner) => {
567                if let LazyOp::Exp(x) = inner.as_ref() {
568                    return *x.clone();
569                }
570                op.clone()
571            }
572
573            // x + 0 -> x  (AddScalar with 0)
574            LazyOp::AddScalar(a, s) if *s == 0.0 => *a.clone(),
575
576            // x * 1 -> x  (MulScalar with 1)
577            LazyOp::MulScalar(a, s) if (*s - 1.0).abs() < f32::EPSILON => *a.clone(),
578
579            // x * 0 -> zeros (handled at materialize; here we keep MulScalar(x,0))
580            // We keep it as-is since we don't know the shape at this level easily.
581            // But we can still fold scalars:
582
583            // Scalar folding: (x + s1) + s2 -> x + (s1 + s2)
584            LazyOp::AddScalar(inner, s2) => {
585                if let LazyOp::AddScalar(x, s1) = inner.as_ref() {
586                    return LazyOp::AddScalar(x.clone(), s1 + s2);
587                }
588                op.clone()
589            }
590
591            // Scalar folding: (x * s1) * s2 -> x * (s1 * s2)
592            LazyOp::MulScalar(inner, s2) => {
593                if let LazyOp::MulScalar(x, s1) = inner.as_ref() {
594                    return LazyOp::MulScalar(x.clone(), s1 * s2);
595                }
596                op.clone()
597            }
598
599            _ => op.clone(),
600        }
601    }
602}
603
// =============================================================================
// Tests
// =============================================================================

608#[cfg(test)]
609mod tests {
610    use super::*;
611
612    fn approx_eq(a: &[f32], b: &[f32], tol: f32) {
613        assert_eq!(
614            a.len(),
615            b.len(),
616            "length mismatch: {} vs {}",
617            a.len(),
618            b.len()
619        );
620        for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
621            assert!(
622                (x - y).abs() < tol,
623                "element {i}: {x} vs {y} (diff = {})",
624                (x - y).abs()
625            );
626        }
627    }
628
629    #[test]
630    fn test_from_tensor_preserves_shape() {
631        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
632        let lazy = LazyTensor::from_tensor(t.clone());
633        assert_eq!(lazy.shape(), &[2, 3]);
634        let result = lazy.materialize();
635        assert_eq!(result.shape(), &[2, 3]);
636        assert_eq!(result.to_vec(), t.to_vec());
637    }
638
639    #[test]
640    fn test_zeros_creation() {
641        let lazy = LazyTensor::zeros(&[3, 4]);
642        assert_eq!(lazy.shape(), &[3, 4]);
643        let result = lazy.materialize();
644        assert_eq!(result.to_vec(), vec![0.0; 12]);
645    }
646
647    #[test]
648    fn test_ones_creation() {
649        let lazy = LazyTensor::ones(&[2, 3]);
650        assert_eq!(lazy.shape(), &[2, 3]);
651        let result = lazy.materialize();
652        assert_eq!(result.to_vec(), vec![1.0; 6]);
653    }
654
655    #[test]
656    fn test_add_two_lazy_tensors() {
657        let a = LazyTensor::from_tensor(
658            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
659        );
660        let b = LazyTensor::from_tensor(
661            Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2]).unwrap(),
662        );
663        let c = a.add(&b);
664        assert_eq!(c.shape(), &[2, 2]);
665        let result = c.materialize();
666        assert_eq!(result.to_vec(), vec![11.0, 22.0, 33.0, 44.0]);
667    }
668
669    #[test]
670    fn test_sub_two_lazy_tensors() {
671        let a =
672            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap());
673        let b =
674            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
675        let c = a.sub(&b);
676        assert_eq!(c.materialize().to_vec(), vec![9.0, 18.0, 27.0]);
677    }
678
679    #[test]
680    fn test_mul_two_lazy_tensors() {
681        let a =
682            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap());
683        let b =
684            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0], &[3]).unwrap());
685        let c = a.mul(&b);
686        assert_eq!(c.materialize().to_vec(), vec![10.0, 18.0, 28.0]);
687    }
688
689    #[test]
690    fn test_div_two_lazy_tensors() {
691        let a =
692            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap());
693        let b =
694            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0, 4.0, 5.0], &[3]).unwrap());
695        let c = a.div(&b);
696        assert_eq!(c.materialize().to_vec(), vec![5.0, 5.0, 6.0]);
697    }
698
699    #[test]
700    fn test_neg_lazy_tensor() {
701        let a =
702            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, -2.0, 3.0], &[3]).unwrap());
703        let result = a.neg().materialize();
704        assert_eq!(result.to_vec(), vec![-1.0, 2.0, -3.0]);
705    }
706
707    #[test]
708    fn test_relu_correctness() {
709        let a = LazyTensor::from_tensor(
710            Tensor::<f32>::from_vec(vec![-3.0, -1.0, 0.0, 1.0, 3.0], &[5]).unwrap(),
711        );
712        let result = a.relu().materialize();
713        assert_eq!(result.to_vec(), vec![0.0, 0.0, 0.0, 1.0, 3.0]);
714    }
715
716    #[test]
717    fn test_sigmoid_correctness() {
718        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![0.0], &[1]).unwrap());
719        let result = a.sigmoid().materialize();
720        approx_eq(&result.to_vec(), &[0.5], 1e-6);
721    }
722
723    #[test]
724    fn test_exp_correctness() {
725        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2]).unwrap());
726        let result = a.exp().materialize();
727        approx_eq(&result.to_vec(), &[1.0, std::f32::consts::E], 1e-5);
728    }
729
730    #[test]
731    fn test_log_correctness() {
732        let a = LazyTensor::from_tensor(
733            Tensor::<f32>::from_vec(vec![1.0, std::f32::consts::E], &[2]).unwrap(),
734        );
735        let result = a.log().materialize();
736        approx_eq(&result.to_vec(), &[0.0, 1.0], 1e-5);
737    }
738
739    #[test]
740    fn test_add_scalar_correctness() {
741        let a =
742            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
743        let result = a.add_scalar(10.0).materialize();
744        assert_eq!(result.to_vec(), vec![11.0, 12.0, 13.0]);
745    }
746
747    #[test]
748    fn test_mul_scalar_correctness() {
749        let a =
750            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
751        let result = a.mul_scalar(3.0).materialize();
752        assert_eq!(result.to_vec(), vec![3.0, 6.0, 9.0]);
753    }
754
755    #[test]
756    fn test_sum_reduction() {
757        let a = LazyTensor::from_tensor(
758            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
759        );
760        let result = a.sum().materialize();
761        assert_eq!(result.shape(), &[] as &[usize]);
762        approx_eq(&result.to_vec(), &[10.0], 1e-6);
763    }
764
765    #[test]
766    fn test_mean_reduction() {
767        let a = LazyTensor::from_tensor(
768            Tensor::<f32>::from_vec(vec![2.0, 4.0, 6.0, 8.0], &[4]).unwrap(),
769        );
770        let result = a.mean().materialize();
771        approx_eq(&result.to_vec(), &[5.0], 1e-6);
772    }
773
774    #[test]
775    fn test_reshape() {
776        let a = LazyTensor::from_tensor(
777            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
778        );
779        let reshaped = a.reshape(&[3, 2]);
780        assert_eq!(reshaped.shape(), &[3, 2]);
781        let result = reshaped.materialize();
782        assert_eq!(result.shape(), &[3, 2]);
783        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
784    }
785
786    #[test]
787    fn test_chained_operations() {
788        // x.relu().add_scalar(1.0).mul_scalar(2.0)
789        let x = LazyTensor::from_tensor(
790            Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
791        );
792        let result = x.relu().add_scalar(1.0).mul_scalar(2.0).materialize();
793        // relu: [0, 0, 1, 2] -> +1: [1, 1, 2, 3] -> *2: [2, 2, 4, 6]
794        assert_eq!(result.to_vec(), vec![2.0, 2.0, 4.0, 6.0]);
795    }
796
797    #[test]
798    fn test_op_count_leaf() {
799        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
800        assert_eq!(x.op_count(), 0);
801    }
802
803    #[test]
804    fn test_op_count_unary() {
805        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
806        assert_eq!(x.relu().op_count(), 1);
807        assert_eq!(x.relu().neg().op_count(), 2);
808    }
809
810    #[test]
811    fn test_op_count_binary() {
812        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
813        let b = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0], &[1]).unwrap());
814        // add(leaf, leaf) = 1 op
815        assert_eq!(a.add(&b).op_count(), 1);
816    }
817
818    #[test]
819    fn test_optimize_add_zero() {
820        let x =
821            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
822        let y = x.add_scalar(0.0);
823        assert_eq!(y.op_count(), 1); // AddScalar
824        let opt = y.optimize();
825        assert_eq!(opt.op_count(), 0); // Eliminated
826        assert_eq!(opt.materialize().to_vec(), vec![1.0, 2.0, 3.0]);
827    }
828
829    #[test]
830    fn test_optimize_mul_one() {
831        let x =
832            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap());
833        let y = x.mul_scalar(1.0);
834        assert_eq!(y.op_count(), 1);
835        let opt = y.optimize();
836        assert_eq!(opt.op_count(), 0);
837        assert_eq!(opt.materialize().to_vec(), vec![4.0, 5.0, 6.0]);
838    }
839
840    #[test]
841    fn test_optimize_neg_neg() {
842        let x =
843            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, -2.0, 3.0], &[3]).unwrap());
844        let y = x.neg().neg();
845        assert_eq!(y.op_count(), 2);
846        let opt = y.optimize();
847        assert_eq!(opt.op_count(), 0);
848        assert_eq!(opt.materialize().to_vec(), vec![1.0, -2.0, 3.0]);
849    }
850
851    #[test]
852    fn test_optimize_scalar_folding_mul() {
853        let x =
854            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
855        // x * 2 * 3 should fold to x * 6
856        let y = x.mul_scalar(2.0).mul_scalar(3.0);
857        assert_eq!(y.op_count(), 2);
858        let opt = y.optimize();
859        assert_eq!(opt.op_count(), 1);
860        assert_eq!(opt.materialize().to_vec(), vec![6.0, 12.0, 18.0]);
861    }
862
863    #[test]
864    fn test_optimize_scalar_folding_add() {
865        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap());
866        // x + 3 + 7 should fold to x + 10
867        let y = x.add_scalar(3.0).add_scalar(7.0);
868        assert_eq!(y.op_count(), 2);
869        let opt = y.optimize();
870        assert_eq!(opt.op_count(), 1);
871        assert_eq!(opt.materialize().to_vec(), vec![11.0, 12.0]);
872    }
873
874    #[test]
875    fn test_optimize_exp_log() {
876        let x =
877            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
878        // exp(log(x)) -> x
879        let y = x.log().exp();
880        assert_eq!(y.op_count(), 2);
881        let opt = y.optimize();
882        assert_eq!(opt.op_count(), 0);
883        assert_eq!(opt.materialize().to_vec(), vec![1.0, 2.0, 3.0]);
884    }
885
886    #[test]
887    fn test_optimize_log_exp() {
888        let x =
889            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
890        // log(exp(x)) -> x
891        let y = x.exp().log();
892        assert_eq!(y.op_count(), 2);
893        let opt = y.optimize();
894        assert_eq!(opt.op_count(), 0);
895        assert_eq!(opt.materialize().to_vec(), vec![1.0, 2.0, 3.0]);
896    }
897
898    #[test]
899    fn test_materialize_matches_eager() {
900        let data = vec![1.0, 2.0, 3.0, 4.0];
901        let t = Tensor::<f32>::from_vec(data.clone(), &[2, 2]).unwrap();
902
903        // Eager: relu -> add_scalar(1) -> mul_scalar(2) -> sum
904        let eager = t.relu().add_scalar(1.0).mul_scalar(2.0).sum();
905
906        // Lazy
907        let lazy = LazyTensor::from_tensor(Tensor::<f32>::from_vec(data, &[2, 2]).unwrap());
908        let lazy_result = lazy
909            .relu()
910            .add_scalar(1.0)
911            .mul_scalar(2.0)
912            .sum()
913            .materialize();
914
915        approx_eq(&eager.to_vec(), &lazy_result.to_vec(), 1e-6);
916    }
917
918    #[test]
919    fn test_large_chain_optimization() {
920        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![5.0], &[1]).unwrap());
921        // x * 2 * 3 * 4 + 1 + 2 + 3
922        let y = x
923            .mul_scalar(2.0)
924            .mul_scalar(3.0)
925            .mul_scalar(4.0)
926            .add_scalar(1.0)
927            .add_scalar(2.0)
928            .add_scalar(3.0);
929        assert_eq!(y.op_count(), 6);
930        let opt = y.optimize();
931        // mul chain: 3 -> 1 (folded), add chain: 3 -> 1 (folded) = 2 total
932        assert_eq!(opt.op_count(), 2);
933        // 5 * 24 + 6 = 126
934        approx_eq(&opt.materialize().to_vec(), &[126.0], 1e-6);
935    }
936
937    #[test]
938    fn test_binary_ops_tensor_merging() {
939        // Two independent LazyTensors sharing no tensor stores
940        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap());
941        let b = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap());
942        // a has tensor[0], b has tensor[0] -> merged should have tensor[0], tensor[1]
943        let c = a.add(&b);
944        assert_eq!(c.tensors.len(), 2);
945        let result = c.materialize();
946        assert_eq!(result.to_vec(), vec![4.0, 6.0]);
947    }
948
949    #[test]
950    fn test_binary_ops_chain_merging() {
951        // (a + b) + c — should merge all three tensor stores
952        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
953        let b = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0], &[1]).unwrap());
954        let c = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![3.0], &[1]).unwrap());
955        let ab = a.add(&b);
956        let abc = ab.add(&c);
957        assert_eq!(abc.tensors.len(), 3);
958        approx_eq(&abc.materialize().to_vec(), &[6.0], 1e-6);
959    }
960
961    #[test]
962    fn test_sqrt_correctness() {
963        let a =
964            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![4.0, 9.0, 16.0], &[3]).unwrap());
965        let result = a.sqrt().materialize();
966        approx_eq(&result.to_vec(), &[2.0, 3.0, 4.0], 1e-6);
967    }
968
969    #[test]
970    fn test_abs_correctness() {
971        let a =
972            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![-3.0, 0.0, 5.0], &[3]).unwrap());
973        let result = a.abs().materialize();
974        assert_eq!(result.to_vec(), vec![3.0, 0.0, 5.0]);
975    }
976}