1use crate::mat::Mat;
50use crate::mat_ref::MatRef;
51use core::marker::PhantomData;
52use num_complex::Complex;
53use num_traits::Zero;
54use oxiblas_core::scalar::Scalar;
55
56pub trait Expr: Sized {
65 type Elem: Scalar + bytemuck::Zeroable + Zero;
67
68 fn nrows(&self) -> usize;
70
71 fn ncols(&self) -> usize;
73
74 fn shape(&self) -> (usize, usize) {
76 (self.nrows(), self.ncols())
77 }
78
79 fn eval(&self) -> Mat<Self::Elem>;
81
82 fn eval_into(&self, target: &mut Mat<Self::Elem>);
84
85 fn t(self) -> ExprTranspose<Self> {
87 ExprTranspose { inner: self }
88 }
89
90 fn scale(self, alpha: Self::Elem) -> ExprScale<Self> {
92 ExprScale { inner: self, alpha }
93 }
94
95 fn add<E: Expr<Elem = Self::Elem>>(self, other: E) -> ExprAdd<Self, E> {
97 debug_assert_eq!(
98 self.shape(),
99 other.shape(),
100 "Matrix dimensions must match for addition"
101 );
102 ExprAdd {
103 lhs: self,
104 rhs: other,
105 }
106 }
107
108 fn sub<E: Expr<Elem = Self::Elem>>(self, other: E) -> ExprSub<Self, E> {
110 debug_assert_eq!(
111 self.shape(),
112 other.shape(),
113 "Matrix dimensions must match for subtraction"
114 );
115 ExprSub {
116 lhs: self,
117 rhs: other,
118 }
119 }
120
121 fn neg(self) -> ExprNeg<Self> {
123 ExprNeg { inner: self }
124 }
125
126 fn matmul<E: Expr<Elem = Self::Elem>>(self, other: E) -> ExprMul<Self, E> {
128 debug_assert_eq!(
129 self.ncols(),
130 other.nrows(),
131 "Matrix dimensions must be compatible for multiplication"
132 );
133 ExprMul {
134 lhs: self,
135 rhs: other,
136 }
137 }
138}
139
140pub trait ComplexExpr: Expr
142where
143 Self::Elem: ComplexScalar,
144{
145 fn conj(self) -> ExprConj<Self> {
147 ExprConj { inner: self }
148 }
149
150 fn h(self) -> ExprHermitian<Self> {
152 ExprHermitian { inner: self }
153 }
154}
155
156impl<E: Expr> ComplexExpr for E where E::Elem: ComplexScalar {}
158
159pub trait ComplexScalar: Scalar + bytemuck::Zeroable + Zero {
161 fn conj(&self) -> Self;
163}
164
165impl ComplexScalar for Complex<f32> {
166 fn conj(&self) -> Self {
167 Complex::conj(self)
168 }
169}
170
171impl ComplexScalar for Complex<f64> {
172 fn conj(&self) -> Self {
173 Complex::conj(self)
174 }
175}
176
177impl ComplexScalar for f32 {
179 fn conj(&self) -> Self {
180 *self
181 }
182}
183
184impl ComplexScalar for f64 {
185 fn conj(&self) -> Self {
186 *self
187 }
188}
189
190#[derive(Clone, Copy)]
196pub struct ExprLeaf<'a, T: Scalar + bytemuck::Zeroable + Zero> {
197 mat: MatRef<'a, T>,
198}
199
200impl<'a, T: Scalar + bytemuck::Zeroable + Zero> ExprLeaf<'a, T> {
201 pub fn new(mat: MatRef<'a, T>) -> Self {
203 Self { mat }
204 }
205}
206
207impl<'a, T: Scalar + bytemuck::Zeroable + Zero> Expr for ExprLeaf<'a, T> {
208 type Elem = T;
209
210 fn nrows(&self) -> usize {
211 self.mat.nrows()
212 }
213
214 fn ncols(&self) -> usize {
215 self.mat.ncols()
216 }
217
218 fn eval(&self) -> Mat<T> {
219 let mut result = Mat::zeros(self.nrows(), self.ncols());
220 self.eval_into(&mut result);
221 result
222 }
223
224 fn eval_into(&self, target: &mut Mat<T>) {
225 target.copy_from(&self.mat);
226 }
227}
228
229pub trait LazyExt<'a, T: Scalar + bytemuck::Zeroable + Zero> {
231 fn lazy(self) -> ExprLeaf<'a, T>;
233}
234
235impl<'a, T: Scalar + bytemuck::Zeroable + Zero> LazyExt<'a, T> for MatRef<'a, T> {
236 fn lazy(self) -> ExprLeaf<'a, T> {
237 ExprLeaf::new(self)
238 }
239}
240
241pub struct ExprTranspose<E: Expr> {
247 inner: E,
248}
249
250impl<E: Expr> Expr for ExprTranspose<E> {
251 type Elem = E::Elem;
252
253 fn nrows(&self) -> usize {
254 self.inner.ncols()
255 }
256
257 fn ncols(&self) -> usize {
258 self.inner.nrows()
259 }
260
261 fn eval(&self) -> Mat<Self::Elem> {
262 let mut result = Mat::zeros(self.nrows(), self.ncols());
263 self.eval_into(&mut result);
264 result
265 }
266
267 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
268 let inner = self.inner.eval();
269 for i in 0..self.nrows() {
270 for j in 0..self.ncols() {
271 target[(i, j)] = inner[(j, i)];
272 }
273 }
274 }
275}
276
277impl<E: Expr> ExprTranspose<ExprTranspose<E>> {
279 pub fn simplify(self) -> E {
281 self.inner.inner
282 }
283}
284
285pub struct ExprScale<E: Expr> {
291 inner: E,
292 alpha: E::Elem,
293}
294
295impl<E: Expr> Expr for ExprScale<E> {
296 type Elem = E::Elem;
297
298 fn nrows(&self) -> usize {
299 self.inner.nrows()
300 }
301
302 fn ncols(&self) -> usize {
303 self.inner.ncols()
304 }
305
306 fn eval(&self) -> Mat<Self::Elem> {
307 let mut result = Mat::zeros(self.nrows(), self.ncols());
308 self.eval_into(&mut result);
309 result
310 }
311
312 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
313 let inner = self.inner.eval();
314 for i in 0..self.nrows() {
315 for j in 0..self.ncols() {
316 target[(i, j)] = inner[(i, j)] * self.alpha;
317 }
318 }
319 }
320}
321
322impl<E: Expr> ExprScale<ExprScale<E>> {
324 pub fn simplify(self) -> ExprScale<E> {
326 ExprScale {
327 inner: self.inner.inner,
328 alpha: self.alpha * self.inner.alpha,
329 }
330 }
331}
332
333pub struct ExprNeg<E: Expr> {
339 inner: E,
340}
341
342impl<E: Expr> Expr for ExprNeg<E> {
343 type Elem = E::Elem;
344
345 fn nrows(&self) -> usize {
346 self.inner.nrows()
347 }
348
349 fn ncols(&self) -> usize {
350 self.inner.ncols()
351 }
352
353 fn eval(&self) -> Mat<Self::Elem> {
354 let mut result = Mat::zeros(self.nrows(), self.ncols());
355 self.eval_into(&mut result);
356 result
357 }
358
359 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
360 let inner = self.inner.eval();
361 for i in 0..self.nrows() {
362 for j in 0..self.ncols() {
363 target[(i, j)] = E::Elem::zero() - inner[(i, j)];
364 }
365 }
366 }
367}
368
369impl<E: Expr> ExprNeg<ExprNeg<E>> {
371 pub fn simplify(self) -> E {
373 self.inner.inner
374 }
375}
376
377pub struct ExprAdd<L: Expr, R: Expr<Elem = L::Elem>> {
383 lhs: L,
384 rhs: R,
385}
386
387impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprAdd<L, R> {
388 type Elem = L::Elem;
389
390 fn nrows(&self) -> usize {
391 self.lhs.nrows()
392 }
393
394 fn ncols(&self) -> usize {
395 self.lhs.ncols()
396 }
397
398 fn eval(&self) -> Mat<Self::Elem> {
399 let mut result = Mat::zeros(self.nrows(), self.ncols());
400 self.eval_into(&mut result);
401 result
402 }
403
404 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
405 let lhs = self.lhs.eval();
406 let rhs = self.rhs.eval();
407 for i in 0..self.nrows() {
408 for j in 0..self.ncols() {
409 target[(i, j)] = lhs[(i, j)] + rhs[(i, j)];
410 }
411 }
412 }
413}
414
415pub struct ExprSub<L: Expr, R: Expr<Elem = L::Elem>> {
421 lhs: L,
422 rhs: R,
423}
424
425impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprSub<L, R> {
426 type Elem = L::Elem;
427
428 fn nrows(&self) -> usize {
429 self.lhs.nrows()
430 }
431
432 fn ncols(&self) -> usize {
433 self.lhs.ncols()
434 }
435
436 fn eval(&self) -> Mat<Self::Elem> {
437 let mut result = Mat::zeros(self.nrows(), self.ncols());
438 self.eval_into(&mut result);
439 result
440 }
441
442 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
443 let lhs = self.lhs.eval();
444 let rhs = self.rhs.eval();
445 for i in 0..self.nrows() {
446 for j in 0..self.ncols() {
447 target[(i, j)] = lhs[(i, j)] - rhs[(i, j)];
448 }
449 }
450 }
451}
452
453pub struct ExprMul<L: Expr, R: Expr<Elem = L::Elem>> {
459 lhs: L,
460 rhs: R,
461}
462
463impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprMul<L, R> {
464 type Elem = L::Elem;
465
466 fn nrows(&self) -> usize {
467 self.lhs.nrows()
468 }
469
470 fn ncols(&self) -> usize {
471 self.rhs.ncols()
472 }
473
474 fn eval(&self) -> Mat<Self::Elem> {
475 let mut result = Mat::zeros(self.nrows(), self.ncols());
476 self.eval_into(&mut result);
477 result
478 }
479
480 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
481 let lhs = self.lhs.eval();
482 let rhs = self.rhs.eval();
483 let k = self.lhs.ncols();
484
485 for i in 0..self.nrows() {
486 for j in 0..self.ncols() {
487 let mut sum = L::Elem::zero();
488 for kk in 0..k {
489 sum += lhs[(i, kk)] * rhs[(kk, j)];
490 }
491 target[(i, j)] = sum;
492 }
493 }
494 }
495}
496
497pub struct ExprConj<E: Expr>
503where
504 E::Elem: ComplexScalar,
505{
506 inner: E,
507}
508
509impl<E: Expr> Expr for ExprConj<E>
510where
511 E::Elem: ComplexScalar,
512{
513 type Elem = E::Elem;
514
515 fn nrows(&self) -> usize {
516 self.inner.nrows()
517 }
518
519 fn ncols(&self) -> usize {
520 self.inner.ncols()
521 }
522
523 fn eval(&self) -> Mat<Self::Elem> {
524 let mut result = Mat::zeros(self.nrows(), self.ncols());
525 self.eval_into(&mut result);
526 result
527 }
528
529 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
530 let inner = self.inner.eval();
531 for i in 0..self.nrows() {
532 for j in 0..self.ncols() {
533 target[(i, j)] = inner[(i, j)].conj();
534 }
535 }
536 }
537}
538
539impl<E: Expr> ExprConj<ExprConj<E>>
541where
542 E::Elem: ComplexScalar,
543{
544 pub fn simplify(self) -> E {
546 self.inner.inner
547 }
548}
549
550pub struct ExprHermitian<E: Expr>
556where
557 E::Elem: ComplexScalar,
558{
559 inner: E,
560}
561
562impl<E: Expr> Expr for ExprHermitian<E>
563where
564 E::Elem: ComplexScalar,
565{
566 type Elem = E::Elem;
567
568 fn nrows(&self) -> usize {
569 self.inner.ncols()
570 }
571
572 fn ncols(&self) -> usize {
573 self.inner.nrows()
574 }
575
576 fn eval(&self) -> Mat<Self::Elem> {
577 let mut result = Mat::zeros(self.nrows(), self.ncols());
578 self.eval_into(&mut result);
579 result
580 }
581
582 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
583 let inner = self.inner.eval();
584 for i in 0..self.nrows() {
585 for j in 0..self.ncols() {
586 target[(i, j)] = inner[(j, i)].conj();
587 }
588 }
589 }
590}
591
592impl<E: Expr> ExprHermitian<ExprHermitian<E>>
594where
595 E::Elem: ComplexScalar,
596{
597 pub fn simplify(self) -> E {
599 self.inner.inner
600 }
601}
602
603impl<'a, T: Scalar + bytemuck::Zeroable + Zero, R: Expr<Elem = T>> core::ops::Add<R>
608 for ExprLeaf<'a, T>
609{
610 type Output = ExprAdd<Self, R>;
611
612 fn add(self, rhs: R) -> Self::Output {
613 Expr::add(self, rhs)
614 }
615}
616
617impl<'a, T: Scalar + bytemuck::Zeroable + Zero, R: Expr<Elem = T>> core::ops::Sub<R>
618 for ExprLeaf<'a, T>
619{
620 type Output = ExprSub<Self, R>;
621
622 fn sub(self, rhs: R) -> Self::Output {
623 Expr::sub(self, rhs)
624 }
625}
626
627impl<'a, T: Scalar + bytemuck::Zeroable + Zero> core::ops::Neg for ExprLeaf<'a, T> {
628 type Output = ExprNeg<Self>;
629
630 fn neg(self) -> Self::Output {
631 Expr::neg(self)
632 }
633}
634
635impl<L1, R1, L2: Expr<Elem = L1::Elem>> core::ops::Add<L2> for ExprAdd<L1, R1>
637where
638 L1: Expr,
639 R1: Expr<Elem = L1::Elem>,
640{
641 type Output = ExprAdd<Self, L2>;
642
643 fn add(self, rhs: L2) -> Self::Output {
644 Expr::add(self, rhs)
645 }
646}
647
648impl<L1, R1, L2: Expr<Elem = L1::Elem>> core::ops::Sub<L2> for ExprAdd<L1, R1>
650where
651 L1: Expr,
652 R1: Expr<Elem = L1::Elem>,
653{
654 type Output = ExprSub<Self, L2>;
655
656 fn sub(self, rhs: L2) -> Self::Output {
657 Expr::sub(self, rhs)
658 }
659}
660
661impl<L1, R1> core::ops::Neg for ExprAdd<L1, R1>
663where
664 L1: Expr,
665 R1: Expr<Elem = L1::Elem>,
666{
667 type Output = ExprNeg<Self>;
668
669 fn neg(self) -> Self::Output {
670 Expr::neg(self)
671 }
672}
673
674impl<E, R: Expr<Elem = E::Elem>> core::ops::Add<R> for ExprScale<E>
676where
677 E: Expr,
678{
679 type Output = ExprAdd<Self, R>;
680
681 fn add(self, rhs: R) -> Self::Output {
682 Expr::add(self, rhs)
683 }
684}
685
686impl<E, R: Expr<Elem = E::Elem>> core::ops::Sub<R> for ExprScale<E>
688where
689 E: Expr,
690{
691 type Output = ExprSub<Self, R>;
692
693 fn sub(self, rhs: R) -> Self::Output {
694 Expr::sub(self, rhs)
695 }
696}
697
698impl<E> core::ops::Neg for ExprScale<E>
700where
701 E: Expr,
702{
703 type Output = ExprNeg<Self>;
704
705 fn neg(self) -> Self::Output {
706 Expr::neg(self)
707 }
708}
709
710impl<E, R: Expr<Elem = E::Elem>> core::ops::Add<R> for ExprTranspose<E>
712where
713 E: Expr,
714{
715 type Output = ExprAdd<Self, R>;
716
717 fn add(self, rhs: R) -> Self::Output {
718 Expr::add(self, rhs)
719 }
720}
721
722impl<E, R: Expr<Elem = E::Elem>> core::ops::Sub<R> for ExprTranspose<E>
724where
725 E: Expr,
726{
727 type Output = ExprSub<Self, R>;
728
729 fn sub(self, rhs: R) -> Self::Output {
730 Expr::sub(self, rhs)
731 }
732}
733
734impl<E> core::ops::Neg for ExprTranspose<E>
736where
737 E: Expr,
738{
739 type Output = ExprNeg<Self>;
740
741 fn neg(self) -> Self::Output {
742 Expr::neg(self)
743 }
744}
745
746pub struct ExprFma<L: Expr, R: Expr<Elem = L::Elem>> {
752 lhs: L,
753 rhs: R,
754 alpha: L::Elem,
755 beta: L::Elem,
756}
757
758impl<L: Expr, R: Expr<Elem = L::Elem>> ExprFma<L, R> {
759 pub fn new(lhs: L, rhs: R, alpha: L::Elem, beta: L::Elem) -> Self {
761 debug_assert_eq!(lhs.shape(), rhs.shape());
762 Self {
763 lhs,
764 rhs,
765 alpha,
766 beta,
767 }
768 }
769}
770
771impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprFma<L, R> {
772 type Elem = L::Elem;
773
774 fn nrows(&self) -> usize {
775 self.lhs.nrows()
776 }
777
778 fn ncols(&self) -> usize {
779 self.lhs.ncols()
780 }
781
782 fn eval(&self) -> Mat<Self::Elem> {
783 let mut result = Mat::zeros(self.nrows(), self.ncols());
784 self.eval_into(&mut result);
785 result
786 }
787
788 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
789 let lhs = self.lhs.eval();
790 let rhs = self.rhs.eval();
791 for i in 0..self.nrows() {
792 for j in 0..self.ncols() {
793 target[(i, j)] = self.alpha * lhs[(i, j)] + self.beta * rhs[(i, j)];
794 }
795 }
796 }
797}
798
799pub struct ExprGemm<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>> {
801 a: A,
802 b: B,
803 c: C,
804 alpha: A::Elem,
805 beta: A::Elem,
806 _marker: PhantomData<A::Elem>,
807}
808
809impl<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>> ExprGemm<A, B, C> {
810 pub fn new(a: A, b: B, c: C, alpha: A::Elem, beta: A::Elem) -> Self {
812 debug_assert_eq!(a.ncols(), b.nrows());
813 debug_assert_eq!(a.nrows(), c.nrows());
814 debug_assert_eq!(b.ncols(), c.ncols());
815 Self {
816 a,
817 b,
818 c,
819 alpha,
820 beta,
821 _marker: PhantomData,
822 }
823 }
824}
825
826impl<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>> Expr for ExprGemm<A, B, C> {
827 type Elem = A::Elem;
828
829 fn nrows(&self) -> usize {
830 self.a.nrows()
831 }
832
833 fn ncols(&self) -> usize {
834 self.b.ncols()
835 }
836
837 fn eval(&self) -> Mat<Self::Elem> {
838 let mut result = Mat::zeros(self.nrows(), self.ncols());
839 self.eval_into(&mut result);
840 result
841 }
842
843 fn eval_into(&self, target: &mut Mat<Self::Elem>) {
844 let a = self.a.eval();
845 let b = self.b.eval();
846 let c = self.c.eval();
847 let k = self.a.ncols();
848
849 for i in 0..self.nrows() {
850 for j in 0..self.ncols() {
851 let mut sum = Self::Elem::zero();
852 for kk in 0..k {
853 sum += a[(i, kk)] * b[(kk, j)];
854 }
855 target[(i, j)] = self.alpha * sum + self.beta * c[(i, j)];
856 }
857 }
858 }
859}
860
861pub fn fma<L: Expr, R: Expr<Elem = L::Elem>>(
867 alpha: L::Elem,
868 a: L,
869 beta: L::Elem,
870 b: R,
871) -> ExprFma<L, R> {
872 ExprFma::new(a, b, alpha, beta)
873}
874
875pub fn gemm<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>>(
877 alpha: A::Elem,
878 a: A,
879 b: B,
880 beta: A::Elem,
881 c: C,
882) -> ExprGemm<A, B, C> {
883 ExprGemm::new(a, b, c, alpha, beta)
884}
885
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lazy_leaf() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let expr = a.as_ref().lazy();
        let result = expr.eval();

        assert_eq!(result[(0, 0)], 1.0);
        assert_eq!(result[(1, 1)], 4.0);
    }

    #[test]
    fn test_lazy_add() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);

        let expr = a.as_ref().lazy() + b.as_ref().lazy();
        let result = expr.eval();

        assert_eq!(result[(0, 0)], 6.0);
        assert_eq!(result[(0, 1)], 8.0);
        assert_eq!(result[(1, 0)], 10.0);
        assert_eq!(result[(1, 1)], 12.0);
    }

    #[test]
    fn test_lazy_sub() {
        let a: Mat<f64> = Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let expr = a.as_ref().lazy() - b.as_ref().lazy();
        let result = expr.eval();

        assert_eq!(result[(0, 0)], 4.0);
        assert_eq!(result[(0, 1)], 4.0);
        assert_eq!(result[(1, 0)], 4.0);
        assert_eq!(result[(1, 1)], 4.0);
    }

    #[test]
    fn test_lazy_neg() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let expr = -a.as_ref().lazy();
        let result = expr.eval();

        assert_eq!(result[(0, 0)], -1.0);
        assert_eq!(result[(1, 1)], -4.0);
    }

    #[test]
    fn test_lazy_scale() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let expr = a.as_ref().lazy().scale(2.0);
        let result = expr.eval();

        assert_eq!(result[(0, 0)], 2.0);
        assert_eq!(result[(0, 1)], 4.0);
        assert_eq!(result[(1, 0)], 6.0);
        assert_eq!(result[(1, 1)], 8.0);
    }

    #[test]
    fn test_lazy_transpose() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);

        let expr = a.as_ref().lazy().t();
        let result = expr.eval();

        assert_eq!(result.shape(), (3, 2));
        assert_eq!(result[(0, 0)], 1.0);
        assert_eq!(result[(1, 0)], 2.0);
        assert_eq!(result[(2, 0)], 3.0);
        assert_eq!(result[(0, 1)], 4.0);
        assert_eq!(result[(1, 1)], 5.0);
        assert_eq!(result[(2, 1)], 6.0);
    }

    #[test]
    fn test_lazy_matmul() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);

        let expr = a.as_ref().lazy().matmul(b.as_ref().lazy());
        let result = expr.eval();

        // [1 2] [5 6]   [1*5 + 2*7  1*6 + 2*8]   [19 22]
        // [3 4] [7 8] = [3*5 + 4*7  3*6 + 4*8] = [43 50]
        assert_eq!(result[(0, 0)], 19.0);
        assert_eq!(result[(0, 1)], 22.0);
        assert_eq!(result[(1, 0)], 43.0);
        assert_eq!(result[(1, 1)], 50.0);
    }

    #[test]
    fn test_lazy_matmul_non_square() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[7.0, 8.0], &[9.0, 10.0], &[11.0, 12.0]]);

        let result = a.as_ref().lazy().matmul(b.as_ref().lazy()).eval();

        // (2x3) * (3x2) -> 2x2
        assert_eq!(result.shape(), (2, 2));
        assert_eq!(result[(0, 0)], 58.0);
        assert_eq!(result[(0, 1)], 64.0);
        assert_eq!(result[(1, 0)], 139.0);
        assert_eq!(result[(1, 1)], 154.0);
    }

    #[test]
    fn test_lazy_chained() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let c: Mat<f64> = Mat::from_rows(&[&[1.0, 1.0], &[1.0, 1.0]]);

        // (a + b) - c is built as a single expression tree and evaluated once.
        let expr = (a.as_ref().lazy() + b.as_ref().lazy()) - c.as_ref().lazy();
        let result = expr.eval();

        assert_eq!(result[(0, 0)], 5.0);
        assert_eq!(result[(0, 1)], 7.0);
        assert_eq!(result[(1, 0)], 9.0);
        assert_eq!(result[(1, 1)], 11.0);
    }

    #[test]
    fn test_lazy_scaled_transpose_plus() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);

        // 2 * a^T + b = 2 * [[1, 3], [2, 4]] + [[5, 6], [7, 8]]
        let expr = a.as_ref().lazy().t().scale(2.0) + b.as_ref().lazy();
        let result = expr.eval();

        assert_eq!(result[(0, 0)], 7.0);
        assert_eq!(result[(0, 1)], 12.0);
        assert_eq!(result[(1, 0)], 11.0);
        assert_eq!(result[(1, 1)], 16.0);
    }

    #[test]
    fn test_lazy_complex_conj() {
        use num_complex::Complex64;

        let mut a: Mat<Complex64> = Mat::filled(2, 2, Complex64::new(0.0, 0.0));
        a[(0, 0)] = Complex64::new(1.0, 2.0);
        a[(0, 1)] = Complex64::new(3.0, 4.0);
        a[(1, 0)] = Complex64::new(5.0, 6.0);
        a[(1, 1)] = Complex64::new(7.0, 8.0);

        let expr = a.as_ref().lazy().conj();
        let result = expr.eval();

        assert_eq!(result[(0, 0)], Complex64::new(1.0, -2.0));
        assert_eq!(result[(0, 1)], Complex64::new(3.0, -4.0));
        assert_eq!(result[(1, 0)], Complex64::new(5.0, -6.0));
        assert_eq!(result[(1, 1)], Complex64::new(7.0, -8.0));
    }

    #[test]
    fn test_lazy_hermitian() {
        use num_complex::Complex64;

        let mut a: Mat<Complex64> = Mat::filled(2, 2, Complex64::new(0.0, 0.0));
        a[(0, 0)] = Complex64::new(1.0, 2.0);
        a[(0, 1)] = Complex64::new(3.0, 4.0);
        a[(1, 0)] = Complex64::new(5.0, 6.0);
        a[(1, 1)] = Complex64::new(7.0, 8.0);

        let expr = a.as_ref().lazy().h();
        let result = expr.eval();

        // result[(i, j)] == conj(a[(j, i)])
        assert_eq!(result.shape(), (2, 2));
        assert_eq!(result[(0, 0)], Complex64::new(1.0, -2.0));
        assert_eq!(result[(1, 0)], Complex64::new(3.0, -4.0));
        assert_eq!(result[(0, 1)], Complex64::new(5.0, -6.0));
        assert_eq!(result[(1, 1)], Complex64::new(7.0, -8.0));
    }

    #[test]
    fn test_double_transpose_simplify() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let expr = a.as_ref().lazy().t().t();
        let simplified = expr.simplify();
        let result = simplified.eval();

        assert_eq!(result[(0, 0)], 1.0);
        assert_eq!(result[(1, 1)], 4.0);
    }

    #[test]
    fn test_double_scale_simplify() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let expr = a.as_ref().lazy().scale(2.0).scale(3.0);
        let simplified = expr.simplify();
        let result = simplified.eval();

        // Combined factor is 2 * 3 = 6.
        assert_eq!(result[(0, 0)], 6.0);
        assert_eq!(result[(1, 1)], 24.0);
    }

    #[test]
    fn test_double_neg_simplify() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);

        let result = a.as_ref().lazy().neg().neg().simplify().eval();

        assert_eq!(result[(0, 1)], 2.0);
        assert_eq!(result[(1, 0)], 3.0);
    }

    #[test]
    fn test_double_conj_simplify() {
        use num_complex::Complex64;

        let mut a: Mat<Complex64> = Mat::filled(1, 2, Complex64::new(0.0, 0.0));
        a[(0, 0)] = Complex64::new(1.0, 2.0);
        a[(0, 1)] = Complex64::new(3.0, -4.0);

        let result = a.as_ref().lazy().conj().conj().simplify().eval();

        assert_eq!(result[(0, 0)], Complex64::new(1.0, 2.0));
        assert_eq!(result[(0, 1)], Complex64::new(3.0, -4.0));
    }

    #[test]
    fn test_double_hermitian_simplify() {
        use num_complex::Complex64;

        let mut a: Mat<Complex64> = Mat::filled(1, 2, Complex64::new(0.0, 0.0));
        a[(0, 0)] = Complex64::new(1.0, 2.0);
        a[(0, 1)] = Complex64::new(3.0, -4.0);

        let result = a.as_ref().lazy().h().h().simplify().eval();

        assert_eq!(result.shape(), (1, 2));
        assert_eq!(result[(0, 0)], Complex64::new(1.0, 2.0));
        assert_eq!(result[(0, 1)], Complex64::new(3.0, -4.0));
    }

    #[test]
    fn test_fma() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);

        let expr = fma(2.0, a.as_ref().lazy(), 3.0, b.as_ref().lazy());
        let result = expr.eval();

        assert_eq!(result[(0, 0)], 2.0 * 1.0 + 3.0 * 5.0);
        assert_eq!(result[(0, 1)], 2.0 * 2.0 + 3.0 * 6.0);
        assert_eq!(result[(1, 0)], 2.0 * 3.0 + 3.0 * 7.0);
        assert_eq!(result[(1, 1)], 2.0 * 4.0 + 3.0 * 8.0);
    }

    #[test]
    fn test_gemm() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[1.0, 0.0], &[0.0, 1.0]]);
        let c: Mat<f64> = Mat::from_rows(&[&[10.0, 10.0], &[10.0, 10.0]]);

        let expr = gemm(
            2.0,
            a.as_ref().lazy(),
            b.as_ref().lazy(),
            1.0,
            c.as_ref().lazy(),
        );
        let result = expr.eval();

        // b is the identity, so the result is 2 * a + c.
        assert_eq!(result[(0, 0)], 2.0 * 1.0 + 10.0);
        assert_eq!(result[(0, 1)], 2.0 * 2.0 + 10.0);
        assert_eq!(result[(1, 0)], 2.0 * 3.0 + 10.0);
        assert_eq!(result[(1, 1)], 2.0 * 4.0 + 10.0);
    }

    #[test]
    fn test_eval_into() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b: Mat<f64> = Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);

        let expr = a.as_ref().lazy() + b.as_ref().lazy();
        let mut result: Mat<f64> = Mat::zeros(2, 2);
        expr.eval_into(&mut result);

        assert_eq!(result[(0, 0)], 6.0);
        assert_eq!(result[(1, 1)], 12.0);
    }
}
1120}