1use crate::tensor::Tensor;
18
/// One node of a lazy computation graph.
///
/// Leaves are `Tensor(idx)` indices into the owning `LazyTensor`'s leaf
/// store; every other variant records a deferred operation over one or two
/// child expressions.
#[derive(Debug, Clone)]
pub enum LazyOp {
    /// Leaf: index into `LazyTensor::tensors`.
    Tensor(usize),

    // Element-wise unary operations.
    Neg(Box<LazyOp>),
    Relu(Box<LazyOp>),
    Sigmoid(Box<LazyOp>),
    Exp(Box<LazyOp>),
    Log(Box<LazyOp>),
    Sqrt(Box<LazyOp>),
    Abs(Box<LazyOp>),

    // Element-wise binary operations (operand shapes must match).
    Add(Box<LazyOp>, Box<LazyOp>),
    Sub(Box<LazyOp>, Box<LazyOp>),
    Mul(Box<LazyOp>, Box<LazyOp>),
    Div(Box<LazyOp>, Box<LazyOp>),

    // Full reductions to a scalar.
    Sum(Box<LazyOp>),
    Mean(Box<LazyOp>),

    // Shape manipulation.
    Reshape(Box<LazyOp>, Vec<usize>),
    /// Swap the two given dimensions.
    Transpose(Box<LazyOp>, usize, usize),

    // Scalar broadcast operations.
    AddScalar(Box<LazyOp>, f32),
    MulScalar(Box<LazyOp>, f32),
}
76
/// A tensor expression recorded lazily and only computed when
/// `materialize` is called.
#[derive(Debug, Clone)]
pub struct LazyTensor {
    /// Root of the deferred-operation graph.
    op: LazyOp,
    /// Shape the expression will have once materialized.
    shape: Vec<usize>,
    /// Leaf tensor store; `LazyOp::Tensor(i)` refers to `tensors[i]`.
    tensors: Vec<Tensor<f32>>,
}
96
97impl LazyTensor {
98 pub fn from_tensor(tensor: Tensor<f32>) -> Self {
106 let shape = tensor.shape().to_vec();
107 Self {
108 op: LazyOp::Tensor(0),
109 shape,
110 tensors: vec![tensor],
111 }
112 }
113
114 pub fn zeros(shape: &[usize]) -> Self {
116 let tensor = Tensor::<f32>::zeros(shape);
117 Self::from_tensor(tensor)
118 }
119
120 pub fn ones(shape: &[usize]) -> Self {
122 let tensor = Tensor::<f32>::ones(shape);
123 Self::from_tensor(tensor)
124 }
125
126 pub fn neg(&self) -> LazyTensor {
132 LazyTensor {
133 op: LazyOp::Neg(Box::new(self.op.clone())),
134 shape: self.shape.clone(),
135 tensors: self.tensors.clone(),
136 }
137 }
138
139 pub fn relu(&self) -> LazyTensor {
141 LazyTensor {
142 op: LazyOp::Relu(Box::new(self.op.clone())),
143 shape: self.shape.clone(),
144 tensors: self.tensors.clone(),
145 }
146 }
147
148 pub fn sigmoid(&self) -> LazyTensor {
150 LazyTensor {
151 op: LazyOp::Sigmoid(Box::new(self.op.clone())),
152 shape: self.shape.clone(),
153 tensors: self.tensors.clone(),
154 }
155 }
156
157 pub fn exp(&self) -> LazyTensor {
159 LazyTensor {
160 op: LazyOp::Exp(Box::new(self.op.clone())),
161 shape: self.shape.clone(),
162 tensors: self.tensors.clone(),
163 }
164 }
165
166 pub fn log(&self) -> LazyTensor {
168 LazyTensor {
169 op: LazyOp::Log(Box::new(self.op.clone())),
170 shape: self.shape.clone(),
171 tensors: self.tensors.clone(),
172 }
173 }
174
175 pub fn sqrt(&self) -> LazyTensor {
177 LazyTensor {
178 op: LazyOp::Sqrt(Box::new(self.op.clone())),
179 shape: self.shape.clone(),
180 tensors: self.tensors.clone(),
181 }
182 }
183
184 pub fn abs(&self) -> LazyTensor {
186 LazyTensor {
187 op: LazyOp::Abs(Box::new(self.op.clone())),
188 shape: self.shape.clone(),
189 tensors: self.tensors.clone(),
190 }
191 }
192
193 fn merge_stores(
201 left_tensors: &[Tensor<f32>],
202 right_tensors: &[Tensor<f32>],
203 right_op: &LazyOp,
204 ) -> (Vec<Tensor<f32>>, LazyOp) {
205 let offset = left_tensors.len();
206 let mut merged = left_tensors.to_vec();
207 merged.extend(right_tensors.iter().cloned());
208 let remapped = Self::remap_indices(right_op, offset);
209 (merged, remapped)
210 }
211
    /// Rewrites every `Tensor(idx)` leaf in `op` to `Tensor(idx + offset)`.
    ///
    /// Used after two leaf stores are concatenated so that the right-hand
    /// expression's indices point into the merged store. The tree is
    /// rebuilt structurally; non-leaf payloads are copied unchanged.
    fn remap_indices(op: &LazyOp, offset: usize) -> LazyOp {
        match op {
            // Leaf: the only place an index is actually shifted.
            LazyOp::Tensor(idx) => LazyOp::Tensor(idx + offset),

            // Unary nodes: recurse into the single child.
            LazyOp::Neg(a) => LazyOp::Neg(Box::new(Self::remap_indices(a, offset))),
            LazyOp::Relu(a) => LazyOp::Relu(Box::new(Self::remap_indices(a, offset))),
            LazyOp::Sigmoid(a) => LazyOp::Sigmoid(Box::new(Self::remap_indices(a, offset))),
            LazyOp::Exp(a) => LazyOp::Exp(Box::new(Self::remap_indices(a, offset))),
            LazyOp::Log(a) => LazyOp::Log(Box::new(Self::remap_indices(a, offset))),
            LazyOp::Sqrt(a) => LazyOp::Sqrt(Box::new(Self::remap_indices(a, offset))),
            LazyOp::Abs(a) => LazyOp::Abs(Box::new(Self::remap_indices(a, offset))),

            // Binary nodes: recurse into both children.
            LazyOp::Add(a, b) => LazyOp::Add(
                Box::new(Self::remap_indices(a, offset)),
                Box::new(Self::remap_indices(b, offset)),
            ),
            LazyOp::Sub(a, b) => LazyOp::Sub(
                Box::new(Self::remap_indices(a, offset)),
                Box::new(Self::remap_indices(b, offset)),
            ),
            LazyOp::Mul(a, b) => LazyOp::Mul(
                Box::new(Self::remap_indices(a, offset)),
                Box::new(Self::remap_indices(b, offset)),
            ),
            LazyOp::Div(a, b) => LazyOp::Div(
                Box::new(Self::remap_indices(a, offset)),
                Box::new(Self::remap_indices(b, offset)),
            ),

            // Reductions: recurse into the child.
            LazyOp::Sum(a) => LazyOp::Sum(Box::new(Self::remap_indices(a, offset))),
            LazyOp::Mean(a) => LazyOp::Mean(Box::new(Self::remap_indices(a, offset))),

            // Shape ops: recurse and copy the shape/dim payload.
            LazyOp::Reshape(a, s) => {
                LazyOp::Reshape(Box::new(Self::remap_indices(a, offset)), s.clone())
            }
            LazyOp::Transpose(a, d0, d1) => {
                LazyOp::Transpose(Box::new(Self::remap_indices(a, offset)), *d0, *d1)
            }

            // Scalar ops: recurse and copy the scalar payload.
            LazyOp::AddScalar(a, s) => {
                LazyOp::AddScalar(Box::new(Self::remap_indices(a, offset)), *s)
            }
            LazyOp::MulScalar(a, s) => {
                LazyOp::MulScalar(Box::new(Self::remap_indices(a, offset)), *s)
            }
        }
    }
260
261 fn binary_op(
263 &self,
264 other: &LazyTensor,
265 make_op: impl FnOnce(Box<LazyOp>, Box<LazyOp>) -> LazyOp,
266 shape: Vec<usize>,
267 ) -> LazyTensor {
268 let (merged, remapped_right) = Self::merge_stores(&self.tensors, &other.tensors, &other.op);
269 LazyTensor {
270 op: make_op(Box::new(self.op.clone()), Box::new(remapped_right)),
271 shape,
272 tensors: merged,
273 }
274 }
275
276 pub fn add(&self, other: &LazyTensor) -> LazyTensor {
278 assert_eq!(self.shape, other.shape, "LazyTensor add: shapes must match");
279 self.binary_op(other, LazyOp::Add, self.shape.clone())
280 }
281
282 pub fn sub(&self, other: &LazyTensor) -> LazyTensor {
284 assert_eq!(self.shape, other.shape, "LazyTensor sub: shapes must match");
285 self.binary_op(other, LazyOp::Sub, self.shape.clone())
286 }
287
288 pub fn mul(&self, other: &LazyTensor) -> LazyTensor {
290 assert_eq!(self.shape, other.shape, "LazyTensor mul: shapes must match");
291 self.binary_op(other, LazyOp::Mul, self.shape.clone())
292 }
293
294 pub fn div(&self, other: &LazyTensor) -> LazyTensor {
296 assert_eq!(self.shape, other.shape, "LazyTensor div: shapes must match");
297 self.binary_op(other, LazyOp::Div, self.shape.clone())
298 }
299
300 pub fn add_scalar(&self, s: f32) -> LazyTensor {
306 LazyTensor {
307 op: LazyOp::AddScalar(Box::new(self.op.clone()), s),
308 shape: self.shape.clone(),
309 tensors: self.tensors.clone(),
310 }
311 }
312
313 pub fn mul_scalar(&self, s: f32) -> LazyTensor {
315 LazyTensor {
316 op: LazyOp::MulScalar(Box::new(self.op.clone()), s),
317 shape: self.shape.clone(),
318 tensors: self.tensors.clone(),
319 }
320 }
321
322 pub fn sum(&self) -> LazyTensor {
328 LazyTensor {
329 op: LazyOp::Sum(Box::new(self.op.clone())),
330 shape: vec![],
331 tensors: self.tensors.clone(),
332 }
333 }
334
335 pub fn mean(&self) -> LazyTensor {
337 LazyTensor {
338 op: LazyOp::Mean(Box::new(self.op.clone())),
339 shape: vec![],
340 tensors: self.tensors.clone(),
341 }
342 }
343
344 pub fn reshape(&self, shape: &[usize]) -> LazyTensor {
352 let old_numel: usize = self.shape.iter().product();
353 let new_numel: usize = shape.iter().product();
354 assert_eq!(
355 old_numel, new_numel,
356 "LazyTensor reshape: element count mismatch ({old_numel} vs {new_numel})"
357 );
358 LazyTensor {
359 op: LazyOp::Reshape(Box::new(self.op.clone()), shape.to_vec()),
360 shape: shape.to_vec(),
361 tensors: self.tensors.clone(),
362 }
363 }
364
365 pub fn shape(&self) -> &[usize] {
371 &self.shape
372 }
373
    /// Number of deferred operation nodes in the graph; a bare leaf
    /// counts as 0.
    pub fn op_count(&self) -> usize {
        Self::count_ops(&self.op)
    }
381
382 fn count_ops(op: &LazyOp) -> usize {
383 match op {
384 LazyOp::Tensor(_) => 0,
385
386 LazyOp::Neg(a)
387 | LazyOp::Relu(a)
388 | LazyOp::Sigmoid(a)
389 | LazyOp::Exp(a)
390 | LazyOp::Log(a)
391 | LazyOp::Sqrt(a)
392 | LazyOp::Abs(a)
393 | LazyOp::Sum(a)
394 | LazyOp::Mean(a)
395 | LazyOp::AddScalar(a, _)
396 | LazyOp::MulScalar(a, _)
397 | LazyOp::Reshape(a, _)
398 | LazyOp::Transpose(a, _, _) => 1 + Self::count_ops(a),
399
400 LazyOp::Add(a, b) | LazyOp::Sub(a, b) | LazyOp::Mul(a, b) | LazyOp::Div(a, b) => {
401 1 + Self::count_ops(a) + Self::count_ops(b)
402 }
403 }
404 }
405
    /// Eagerly evaluates the recorded graph and returns the result.
    ///
    /// Each call re-walks the whole expression tree; nothing is cached.
    pub fn materialize(&self) -> Tensor<f32> {
        self.eval_op(&self.op)
    }
417
    /// Recursively and eagerly evaluates one node of the expression tree.
    ///
    /// The `unwrap`s rely on invariants established at graph-build time:
    /// binary operand shapes were asserted equal in `add`/`sub`/`mul`/`div`
    /// and reshape element counts were validated in `reshape`, so the
    /// underlying tensor operations should not fail here.
    fn eval_op(&self, op: &LazyOp) -> Tensor<f32> {
        match op {
            // Leaf: clone the stored tensor out of the leaf store.
            LazyOp::Tensor(idx) => self.tensors[*idx].clone(),

            // Unary ops delegate to the eager Tensor API.
            LazyOp::Neg(a) => self.eval_op(a).neg(),
            LazyOp::Relu(a) => self.eval_op(a).relu(),
            LazyOp::Sigmoid(a) => self.eval_op(a).sigmoid(),
            LazyOp::Exp(a) => self.eval_op(a).exp(),
            LazyOp::Log(a) => self.eval_op(a).ln(),
            LazyOp::Sqrt(a) => self.eval_op(a).sqrt(),
            // No eager abs is used here; computed element-wise via
            // to_vec/from_vec instead (presumably Tensor lacks an `abs`).
            LazyOp::Abs(a) => {
                let t = self.eval_op(a);
                let data: Vec<f32> = t.to_vec().iter().map(|x| x.abs()).collect();
                Tensor::from_vec(data, t.shape()).unwrap()
            }

            // Binary ops: evaluate both children, then combine.
            LazyOp::Add(a, b) => {
                let ta = self.eval_op(a);
                let tb = self.eval_op(b);
                ta.add(&tb).unwrap()
            }
            LazyOp::Sub(a, b) => {
                let ta = self.eval_op(a);
                let tb = self.eval_op(b);
                ta.sub(&tb).unwrap()
            }
            LazyOp::Mul(a, b) => {
                let ta = self.eval_op(a);
                let tb = self.eval_op(b);
                ta.mul(&tb).unwrap()
            }
            LazyOp::Div(a, b) => {
                let ta = self.eval_op(a);
                let tb = self.eval_op(b);
                ta.div(&tb).unwrap()
            }

            // Full reductions to a scalar tensor.
            LazyOp::Sum(a) => self.eval_op(a).sum(),
            LazyOp::Mean(a) => self.eval_op(a).mean().unwrap(),

            // Tensor::reshape takes signed dims (presumably to allow -1
            // inference), so convert from the stored usize shape.
            LazyOp::Reshape(a, shape) => {
                let t = self.eval_op(a);
                let isize_shape: Vec<isize> = shape.iter().map(|&s| s as isize).collect();
                t.reshape(&isize_shape).unwrap()
            }
            LazyOp::Transpose(a, d0, d1) => {
                let t = self.eval_op(a);
                t.transpose(*d0 as i64, *d1 as i64).unwrap()
            }

            // Scalar broadcast operations.
            LazyOp::AddScalar(a, s) => self.eval_op(a).add_scalar(*s),
            LazyOp::MulScalar(a, s) => self.eval_op(a).mul_scalar(*s),
        }
    }
477
478 pub fn optimize(&self) -> LazyTensor {
491 LazyTensor {
492 op: Self::optimize_op(&self.op),
493 shape: self.shape.clone(),
494 tensors: self.tensors.clone(),
495 }
496 }
497
498 fn optimize_op(op: &LazyOp) -> LazyOp {
499 let op = Self::optimize_children(op);
501 Self::simplify(&op)
503 }
504
    /// Rebuilds `op` with each child replaced by its fully optimized form.
    /// No rewriting happens at the current node itself — that is
    /// `simplify`'s job.
    fn optimize_children(op: &LazyOp) -> LazyOp {
        match op {
            // Leaf: nothing below it to optimize.
            LazyOp::Tensor(idx) => LazyOp::Tensor(*idx),

            LazyOp::Neg(a) => LazyOp::Neg(Box::new(Self::optimize_op(a))),
            LazyOp::Relu(a) => LazyOp::Relu(Box::new(Self::optimize_op(a))),
            LazyOp::Sigmoid(a) => LazyOp::Sigmoid(Box::new(Self::optimize_op(a))),
            LazyOp::Exp(a) => LazyOp::Exp(Box::new(Self::optimize_op(a))),
            LazyOp::Log(a) => LazyOp::Log(Box::new(Self::optimize_op(a))),
            LazyOp::Sqrt(a) => LazyOp::Sqrt(Box::new(Self::optimize_op(a))),
            LazyOp::Abs(a) => LazyOp::Abs(Box::new(Self::optimize_op(a))),

            LazyOp::Add(a, b) => LazyOp::Add(
                Box::new(Self::optimize_op(a)),
                Box::new(Self::optimize_op(b)),
            ),
            LazyOp::Sub(a, b) => LazyOp::Sub(
                Box::new(Self::optimize_op(a)),
                Box::new(Self::optimize_op(b)),
            ),
            LazyOp::Mul(a, b) => LazyOp::Mul(
                Box::new(Self::optimize_op(a)),
                Box::new(Self::optimize_op(b)),
            ),
            LazyOp::Div(a, b) => LazyOp::Div(
                Box::new(Self::optimize_op(a)),
                Box::new(Self::optimize_op(b)),
            ),

            LazyOp::Sum(a) => LazyOp::Sum(Box::new(Self::optimize_op(a))),
            LazyOp::Mean(a) => LazyOp::Mean(Box::new(Self::optimize_op(a))),

            LazyOp::Reshape(a, s) => LazyOp::Reshape(Box::new(Self::optimize_op(a)), s.clone()),
            LazyOp::Transpose(a, d0, d1) => {
                LazyOp::Transpose(Box::new(Self::optimize_op(a)), *d0, *d1)
            }

            LazyOp::AddScalar(a, s) => LazyOp::AddScalar(Box::new(Self::optimize_op(a)), *s),
            LazyOp::MulScalar(a, s) => LazyOp::MulScalar(Box::new(Self::optimize_op(a)), *s),
        }
    }
546
547 fn simplify(op: &LazyOp) -> LazyOp {
548 match op {
549 LazyOp::Neg(inner) => {
551 if let LazyOp::Neg(x) = inner.as_ref() {
552 return *x.clone();
553 }
554 op.clone()
555 }
556
557 LazyOp::Exp(inner) => {
559 if let LazyOp::Log(x) = inner.as_ref() {
560 return *x.clone();
561 }
562 op.clone()
563 }
564
565 LazyOp::Log(inner) => {
567 if let LazyOp::Exp(x) = inner.as_ref() {
568 return *x.clone();
569 }
570 op.clone()
571 }
572
573 LazyOp::AddScalar(a, s) if *s == 0.0 => *a.clone(),
575
576 LazyOp::MulScalar(a, s) if (*s - 1.0).abs() < f32::EPSILON => *a.clone(),
578
579 LazyOp::AddScalar(inner, s2) => {
585 if let LazyOp::AddScalar(x, s1) = inner.as_ref() {
586 return LazyOp::AddScalar(x.clone(), s1 + s2);
587 }
588 op.clone()
589 }
590
591 LazyOp::MulScalar(inner, s2) => {
593 if let LazyOp::MulScalar(x, s1) = inner.as_ref() {
594 return LazyOp::MulScalar(x.clone(), s1 * s2);
595 }
596 op.clone()
597 }
598
599 _ => op.clone(),
600 }
601 }
602}
603
#[cfg(test)]
mod tests {
    use super::*;

    // Element-wise float comparison with an absolute tolerance; fails
    // with the offending index on the first mismatch.
    fn approx_eq(a: &[f32], b: &[f32], tol: f32) {
        assert_eq!(
            a.len(),
            b.len(),
            "length mismatch: {} vs {}",
            a.len(),
            b.len()
        );
        for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
            assert!(
                (x - y).abs() < tol,
                "element {i}: {x} vs {y} (diff = {})",
                (x - y).abs()
            );
        }
    }

    // --- construction ---

    #[test]
    fn test_from_tensor_preserves_shape() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let lazy = LazyTensor::from_tensor(t.clone());
        assert_eq!(lazy.shape(), &[2, 3]);
        let result = lazy.materialize();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), t.to_vec());
    }

    #[test]
    fn test_zeros_creation() {
        let lazy = LazyTensor::zeros(&[3, 4]);
        assert_eq!(lazy.shape(), &[3, 4]);
        let result = lazy.materialize();
        assert_eq!(result.to_vec(), vec![0.0; 12]);
    }

    #[test]
    fn test_ones_creation() {
        let lazy = LazyTensor::ones(&[2, 3]);
        assert_eq!(lazy.shape(), &[2, 3]);
        let result = lazy.materialize();
        assert_eq!(result.to_vec(), vec![1.0; 6]);
    }

    // --- element-wise binary ops ---

    #[test]
    fn test_add_two_lazy_tensors() {
        let a = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
        );
        let b = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2]).unwrap(),
        );
        let c = a.add(&b);
        assert_eq!(c.shape(), &[2, 2]);
        let result = c.materialize();
        assert_eq!(result.to_vec(), vec![11.0, 22.0, 33.0, 44.0]);
    }

    #[test]
    fn test_sub_two_lazy_tensors() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap());
        let b =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let c = a.sub(&b);
        assert_eq!(c.materialize().to_vec(), vec![9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul_two_lazy_tensors() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0, 3.0, 4.0], &[3]).unwrap());
        let b =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![5.0, 6.0, 7.0], &[3]).unwrap());
        let c = a.mul(&b);
        assert_eq!(c.materialize().to_vec(), vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_div_two_lazy_tensors() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0], &[3]).unwrap());
        let b =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0, 4.0, 5.0], &[3]).unwrap());
        let c = a.div(&b);
        assert_eq!(c.materialize().to_vec(), vec![5.0, 5.0, 6.0]);
    }

    // --- element-wise unary ops ---

    #[test]
    fn test_neg_lazy_tensor() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, -2.0, 3.0], &[3]).unwrap());
        let result = a.neg().materialize();
        assert_eq!(result.to_vec(), vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_relu_correctness() {
        let a = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![-3.0, -1.0, 0.0, 1.0, 3.0], &[5]).unwrap(),
        );
        let result = a.relu().materialize();
        assert_eq!(result.to_vec(), vec![0.0, 0.0, 0.0, 1.0, 3.0]);
    }

    #[test]
    fn test_sigmoid_correctness() {
        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![0.0], &[1]).unwrap());
        let result = a.sigmoid().materialize();
        approx_eq(&result.to_vec(), &[0.5], 1e-6);
    }

    #[test]
    fn test_exp_correctness() {
        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2]).unwrap());
        let result = a.exp().materialize();
        approx_eq(&result.to_vec(), &[1.0, std::f32::consts::E], 1e-5);
    }

    #[test]
    fn test_log_correctness() {
        let a = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![1.0, std::f32::consts::E], &[2]).unwrap(),
        );
        let result = a.log().materialize();
        approx_eq(&result.to_vec(), &[0.0, 1.0], 1e-5);
    }

    // --- scalar broadcast ops ---

    #[test]
    fn test_add_scalar_correctness() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let result = a.add_scalar(10.0).materialize();
        assert_eq!(result.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_mul_scalar_correctness() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let result = a.mul_scalar(3.0).materialize();
        assert_eq!(result.to_vec(), vec![3.0, 6.0, 9.0]);
    }

    // --- reductions and shape ops ---

    #[test]
    fn test_sum_reduction() {
        let a = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
        );
        let result = a.sum().materialize();
        assert_eq!(result.shape(), &[] as &[usize]);
        approx_eq(&result.to_vec(), &[10.0], 1e-6);
    }

    #[test]
    fn test_mean_reduction() {
        let a = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![2.0, 4.0, 6.0, 8.0], &[4]).unwrap(),
        );
        let result = a.mean().materialize();
        approx_eq(&result.to_vec(), &[5.0], 1e-6);
    }

    #[test]
    fn test_reshape() {
        let a = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
        );
        let reshaped = a.reshape(&[3, 2]);
        assert_eq!(reshaped.shape(), &[3, 2]);
        let result = reshaped.materialize();
        assert_eq!(result.shape(), &[3, 2]);
        assert_eq!(result.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_chained_operations() {
        let x = LazyTensor::from_tensor(
            Tensor::<f32>::from_vec(vec![-1.0, 0.0, 1.0, 2.0], &[4]).unwrap(),
        );
        let result = x.relu().add_scalar(1.0).mul_scalar(2.0).materialize();
        assert_eq!(result.to_vec(), vec![2.0, 2.0, 4.0, 6.0]);
    }

    // --- graph size accounting ---

    #[test]
    fn test_op_count_leaf() {
        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
        assert_eq!(x.op_count(), 0);
    }

    #[test]
    fn test_op_count_unary() {
        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
        assert_eq!(x.relu().op_count(), 1);
        assert_eq!(x.relu().neg().op_count(), 2);
    }

    #[test]
    fn test_op_count_binary() {
        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
        let b = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0], &[1]).unwrap());
        assert_eq!(a.add(&b).op_count(), 1);
    }

    // --- optimizer rewrites ---

    #[test]
    fn test_optimize_add_zero() {
        let x =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let y = x.add_scalar(0.0);
        assert_eq!(y.op_count(), 1);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 0);
        assert_eq!(opt.materialize().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_optimize_mul_one() {
        let x =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap());
        let y = x.mul_scalar(1.0);
        assert_eq!(y.op_count(), 1);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 0);
        assert_eq!(opt.materialize().to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_optimize_neg_neg() {
        let x =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, -2.0, 3.0], &[3]).unwrap());
        let y = x.neg().neg();
        assert_eq!(y.op_count(), 2);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 0);
        assert_eq!(opt.materialize().to_vec(), vec![1.0, -2.0, 3.0]);
    }

    #[test]
    fn test_optimize_scalar_folding_mul() {
        let x =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let y = x.mul_scalar(2.0).mul_scalar(3.0);
        assert_eq!(y.op_count(), 2);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 1);
        assert_eq!(opt.materialize().to_vec(), vec![6.0, 12.0, 18.0]);
    }

    #[test]
    fn test_optimize_scalar_folding_add() {
        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap());
        let y = x.add_scalar(3.0).add_scalar(7.0);
        assert_eq!(y.op_count(), 2);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 1);
        assert_eq!(opt.materialize().to_vec(), vec![11.0, 12.0]);
    }

    #[test]
    fn test_optimize_exp_log() {
        let x =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let y = x.log().exp();
        assert_eq!(y.op_count(), 2);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 0);
        assert_eq!(opt.materialize().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_optimize_log_exp() {
        let x =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap());
        let y = x.exp().log();
        assert_eq!(y.op_count(), 2);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 0);
        assert_eq!(opt.materialize().to_vec(), vec![1.0, 2.0, 3.0]);
    }

    // Lazy evaluation must match the eager Tensor API step for step.
    #[test]
    fn test_materialize_matches_eager() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let t = Tensor::<f32>::from_vec(data.clone(), &[2, 2]).unwrap();

        let eager = t.relu().add_scalar(1.0).mul_scalar(2.0).sum();

        let lazy = LazyTensor::from_tensor(Tensor::<f32>::from_vec(data, &[2, 2]).unwrap());
        let lazy_result = lazy
            .relu()
            .add_scalar(1.0)
            .mul_scalar(2.0)
            .sum()
            .materialize();

        approx_eq(&eager.to_vec(), &lazy_result.to_vec(), 1e-6);
    }

    // Three stacked MulScalars and three stacked AddScalars each fold
    // into a single node: ((5*24)+6) = 126.
    #[test]
    fn test_large_chain_optimization() {
        let x = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![5.0], &[1]).unwrap());
        let y = x
            .mul_scalar(2.0)
            .mul_scalar(3.0)
            .mul_scalar(4.0)
            .add_scalar(1.0)
            .add_scalar(2.0)
            .add_scalar(3.0);
        assert_eq!(y.op_count(), 6);
        let opt = y.optimize();
        assert_eq!(opt.op_count(), 2);
        approx_eq(&opt.materialize().to_vec(), &[126.0], 1e-6);
    }

    // --- leaf-store merging for binary ops ---

    #[test]
    fn test_binary_ops_tensor_merging() {
        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap());
        let b = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap());
        let c = a.add(&b);
        assert_eq!(c.tensors.len(), 2);
        let result = c.materialize();
        assert_eq!(result.to_vec(), vec![4.0, 6.0]);
    }

    #[test]
    fn test_binary_ops_chain_merging() {
        let a = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![1.0], &[1]).unwrap());
        let b = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![2.0], &[1]).unwrap());
        let c = LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![3.0], &[1]).unwrap());
        let ab = a.add(&b);
        let abc = ab.add(&c);
        assert_eq!(abc.tensors.len(), 3);
        approx_eq(&abc.materialize().to_vec(), &[6.0], 1e-6);
    }

    #[test]
    fn test_sqrt_correctness() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![4.0, 9.0, 16.0], &[3]).unwrap());
        let result = a.sqrt().materialize();
        approx_eq(&result.to_vec(), &[2.0, 3.0, 4.0], 1e-6);
    }

    #[test]
    fn test_abs_correctness() {
        let a =
            LazyTensor::from_tensor(Tensor::<f32>::from_vec(vec![-3.0, 0.0, 5.0], &[3]).unwrap());
        let result = a.abs().materialize();
        assert_eq!(result.to_vec(), vec![3.0, 0.0, 5.0]);
    }
}