oxiblas_matrix/
lazy.rs

1//! Lazy evaluation for matrix operations.
2//!
3//! This module provides expression types that defer computation until explicitly evaluated.
4//! This enables operation fusion and optimization for chained matrix operations.
5//!
6//! # Expression Types
7//!
//! - [`ExprLeaf`]: Leaf node wrapping a matrix reference
//! - [`ExprAdd`]: Element-wise addition
//! - [`ExprSub`]: Element-wise subtraction
//! - [`ExprNeg`]: Element-wise negation
//! - [`ExprScale`]: Scalar multiplication
//! - [`ExprMul`]: Matrix-matrix multiplication
//! - [`ExprTranspose`]: Matrix transpose
//! - [`ExprConj`]: Complex conjugate
//! - [`ExprHermitian`]: Conjugate transpose
//! - [`ExprFma`]: Fused `alpha * A + beta * B`
//! - [`ExprGemm`]: Fused `alpha * A * B + beta * C`
16//!
17//! # Example
18//!
19//! ```
20//! use oxiblas_matrix::{Mat, lazy::*};
21//!
22//! let a: Mat<f64> = Mat::from_rows(&[
23//!     &[1.0, 2.0],
24//!     &[3.0, 4.0],
25//! ]);
26//! let b: Mat<f64> = Mat::from_rows(&[
27//!     &[5.0, 6.0],
28//!     &[7.0, 8.0],
29//! ]);
30//!
31//! // Build an expression tree (no computation yet)
32//! let expr = a.as_ref().lazy() + b.as_ref().lazy();
33//!
34//! // Evaluate the expression
35//! let result: Mat<f64> = expr.eval();
36//! assert_eq!(result[(0, 0)], 6.0); // 1 + 5
37//! assert_eq!(result[(1, 1)], 12.0); // 4 + 8
38//! ```
39//!
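//! # Optimization
//!
//! Lazy evaluation is intended to enable several optimizations:
//!
//! 1. **Fusion**: Multiple element-wise operations can be fused into a single pass
//!    (see [`ExprFma`] and [`ExprGemm`] for the explicitly fused nodes)
//! 2. **Transpose elimination**: `(A^T)^T` can be simplified to `A` by calling `simplify()`
//! 3. **Scale accumulation**: `a * (b * A)` becomes `(a*b) * A` by calling `simplify()`
//! 4. **Memory efficiency**: Building an expression tree allocates no matrices.
//!    Note that evaluation currently materializes each sub-expression into a
//!    temporary matrix before combining it, so intermediates are still allocated
//!    when `eval` / `eval_into` runs.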
48
49use crate::mat::Mat;
50use crate::mat_ref::MatRef;
51use core::marker::PhantomData;
52use num_complex::Complex;
53use num_traits::Zero;
54use oxiblas_core::scalar::Scalar;
55
56// =============================================================================
57// Expr trait - Base trait for lazy expressions
58// =============================================================================
59
60/// Trait for lazy matrix expressions.
61///
62/// All expression types implement this trait, providing common methods
63/// for querying dimensions and evaluating the expression.
64pub trait Expr: Sized {
65    /// The element type of the expression.
66    type Elem: Scalar + bytemuck::Zeroable + Zero;
67
68    /// Returns the number of rows in the result.
69    fn nrows(&self) -> usize;
70
71    /// Returns the number of columns in the result.
72    fn ncols(&self) -> usize;
73
74    /// Returns the shape as (nrows, ncols).
75    fn shape(&self) -> (usize, usize) {
76        (self.nrows(), self.ncols())
77    }
78
79    /// Evaluates the expression and returns the result as an owned matrix.
80    fn eval(&self) -> Mat<Self::Elem>;
81
82    /// Evaluates the expression into an existing matrix.
83    fn eval_into(&self, target: &mut Mat<Self::Elem>);
84
85    /// Wraps this expression in a transpose expression.
86    fn t(self) -> ExprTranspose<Self> {
87        ExprTranspose { inner: self }
88    }
89
90    /// Scales this expression by a scalar.
91    fn scale(self, alpha: Self::Elem) -> ExprScale<Self> {
92        ExprScale { inner: self, alpha }
93    }
94
95    /// Adds another expression to this one.
96    fn add<E: Expr<Elem = Self::Elem>>(self, other: E) -> ExprAdd<Self, E> {
97        debug_assert_eq!(
98            self.shape(),
99            other.shape(),
100            "Matrix dimensions must match for addition"
101        );
102        ExprAdd {
103            lhs: self,
104            rhs: other,
105        }
106    }
107
108    /// Subtracts another expression from this one.
109    fn sub<E: Expr<Elem = Self::Elem>>(self, other: E) -> ExprSub<Self, E> {
110        debug_assert_eq!(
111            self.shape(),
112            other.shape(),
113            "Matrix dimensions must match for subtraction"
114        );
115        ExprSub {
116            lhs: self,
117            rhs: other,
118        }
119    }
120
121    /// Negates this expression.
122    fn neg(self) -> ExprNeg<Self> {
123        ExprNeg { inner: self }
124    }
125
126    /// Matrix multiplies this expression with another.
127    fn matmul<E: Expr<Elem = Self::Elem>>(self, other: E) -> ExprMul<Self, E> {
128        debug_assert_eq!(
129            self.ncols(),
130            other.nrows(),
131            "Matrix dimensions must be compatible for multiplication"
132        );
133        ExprMul {
134            lhs: self,
135            rhs: other,
136        }
137    }
138}
139
140/// Extension trait for complex expressions.
141pub trait ComplexExpr: Expr
142where
143    Self::Elem: ComplexScalar,
144{
145    /// Returns the complex conjugate of this expression.
146    fn conj(self) -> ExprConj<Self> {
147        ExprConj { inner: self }
148    }
149
150    /// Returns the conjugate transpose (Hermitian) of this expression.
151    fn h(self) -> ExprHermitian<Self> {
152        ExprHermitian { inner: self }
153    }
154}
155
156// Blanket implementation for complex expressions
157impl<E: Expr> ComplexExpr for E where E::Elem: ComplexScalar {}
158
159/// Marker trait for complex scalar types.
160pub trait ComplexScalar: Scalar + bytemuck::Zeroable + Zero {
161    /// Returns the complex conjugate.
162    fn conj(&self) -> Self;
163}
164
165impl ComplexScalar for Complex<f32> {
166    fn conj(&self) -> Self {
167        Complex::conj(self)
168    }
169}
170
171impl ComplexScalar for Complex<f64> {
172    fn conj(&self) -> Self {
173        Complex::conj(self)
174    }
175}
176
177// Real numbers are their own conjugate
178impl ComplexScalar for f32 {
179    fn conj(&self) -> Self {
180        *self
181    }
182}
183
184impl ComplexScalar for f64 {
185    fn conj(&self) -> Self {
186        *self
187    }
188}
189
190// =============================================================================
191// ExprLeaf - Wraps a MatRef as a lazy expression
192// =============================================================================
193
194/// A leaf expression wrapping a matrix reference.
195#[derive(Clone, Copy)]
196pub struct ExprLeaf<'a, T: Scalar + bytemuck::Zeroable + Zero> {
197    mat: MatRef<'a, T>,
198}
199
200impl<'a, T: Scalar + bytemuck::Zeroable + Zero> ExprLeaf<'a, T> {
201    /// Creates a new leaf expression from a matrix reference.
202    pub fn new(mat: MatRef<'a, T>) -> Self {
203        Self { mat }
204    }
205}
206
207impl<'a, T: Scalar + bytemuck::Zeroable + Zero> Expr for ExprLeaf<'a, T> {
208    type Elem = T;
209
210    fn nrows(&self) -> usize {
211        self.mat.nrows()
212    }
213
214    fn ncols(&self) -> usize {
215        self.mat.ncols()
216    }
217
218    fn eval(&self) -> Mat<T> {
219        let mut result = Mat::zeros(self.nrows(), self.ncols());
220        self.eval_into(&mut result);
221        result
222    }
223
224    fn eval_into(&self, target: &mut Mat<T>) {
225        target.copy_from(&self.mat);
226    }
227}
228
229/// Trait extension for MatRef to create lazy expressions.
230pub trait LazyExt<'a, T: Scalar + bytemuck::Zeroable + Zero> {
231    /// Creates a lazy expression from this matrix reference.
232    fn lazy(self) -> ExprLeaf<'a, T>;
233}
234
235impl<'a, T: Scalar + bytemuck::Zeroable + Zero> LazyExt<'a, T> for MatRef<'a, T> {
236    fn lazy(self) -> ExprLeaf<'a, T> {
237        ExprLeaf::new(self)
238    }
239}
240
241// =============================================================================
242// ExprTranspose - Transpose expression
243// =============================================================================
244
245/// A lazy transpose expression.
246pub struct ExprTranspose<E: Expr> {
247    inner: E,
248}
249
250impl<E: Expr> Expr for ExprTranspose<E> {
251    type Elem = E::Elem;
252
253    fn nrows(&self) -> usize {
254        self.inner.ncols()
255    }
256
257    fn ncols(&self) -> usize {
258        self.inner.nrows()
259    }
260
261    fn eval(&self) -> Mat<Self::Elem> {
262        let mut result = Mat::zeros(self.nrows(), self.ncols());
263        self.eval_into(&mut result);
264        result
265    }
266
267    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
268        let inner = self.inner.eval();
269        for i in 0..self.nrows() {
270            for j in 0..self.ncols() {
271                target[(i, j)] = inner[(j, i)];
272            }
273        }
274    }
275}
276
277// Double transpose optimization: (A^T)^T = A
278impl<E: Expr> ExprTranspose<ExprTranspose<E>> {
279    /// Eliminates double transpose.
280    pub fn simplify(self) -> E {
281        self.inner.inner
282    }
283}
284
285// =============================================================================
286// ExprScale - Scalar multiplication expression
287// =============================================================================
288
289/// A lazy scalar multiplication expression.
290pub struct ExprScale<E: Expr> {
291    inner: E,
292    alpha: E::Elem,
293}
294
295impl<E: Expr> Expr for ExprScale<E> {
296    type Elem = E::Elem;
297
298    fn nrows(&self) -> usize {
299        self.inner.nrows()
300    }
301
302    fn ncols(&self) -> usize {
303        self.inner.ncols()
304    }
305
306    fn eval(&self) -> Mat<Self::Elem> {
307        let mut result = Mat::zeros(self.nrows(), self.ncols());
308        self.eval_into(&mut result);
309        result
310    }
311
312    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
313        let inner = self.inner.eval();
314        for i in 0..self.nrows() {
315            for j in 0..self.ncols() {
316                target[(i, j)] = inner[(i, j)] * self.alpha;
317            }
318        }
319    }
320}
321
322// Double scale optimization: a * (b * A) = (a*b) * A
323impl<E: Expr> ExprScale<ExprScale<E>> {
324    /// Combines nested scales.
325    pub fn simplify(self) -> ExprScale<E> {
326        ExprScale {
327            inner: self.inner.inner,
328            alpha: self.alpha * self.inner.alpha,
329        }
330    }
331}
332
333// =============================================================================
334// ExprNeg - Negation expression
335// =============================================================================
336
337/// A lazy negation expression.
338pub struct ExprNeg<E: Expr> {
339    inner: E,
340}
341
342impl<E: Expr> Expr for ExprNeg<E> {
343    type Elem = E::Elem;
344
345    fn nrows(&self) -> usize {
346        self.inner.nrows()
347    }
348
349    fn ncols(&self) -> usize {
350        self.inner.ncols()
351    }
352
353    fn eval(&self) -> Mat<Self::Elem> {
354        let mut result = Mat::zeros(self.nrows(), self.ncols());
355        self.eval_into(&mut result);
356        result
357    }
358
359    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
360        let inner = self.inner.eval();
361        for i in 0..self.nrows() {
362            for j in 0..self.ncols() {
363                target[(i, j)] = E::Elem::zero() - inner[(i, j)];
364            }
365        }
366    }
367}
368
369// Double negation optimization: --A = A
370impl<E: Expr> ExprNeg<ExprNeg<E>> {
371    /// Eliminates double negation.
372    pub fn simplify(self) -> E {
373        self.inner.inner
374    }
375}
376
377// =============================================================================
378// ExprAdd - Addition expression
379// =============================================================================
380
381/// A lazy addition expression.
382pub struct ExprAdd<L: Expr, R: Expr<Elem = L::Elem>> {
383    lhs: L,
384    rhs: R,
385}
386
387impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprAdd<L, R> {
388    type Elem = L::Elem;
389
390    fn nrows(&self) -> usize {
391        self.lhs.nrows()
392    }
393
394    fn ncols(&self) -> usize {
395        self.lhs.ncols()
396    }
397
398    fn eval(&self) -> Mat<Self::Elem> {
399        let mut result = Mat::zeros(self.nrows(), self.ncols());
400        self.eval_into(&mut result);
401        result
402    }
403
404    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
405        let lhs = self.lhs.eval();
406        let rhs = self.rhs.eval();
407        for i in 0..self.nrows() {
408            for j in 0..self.ncols() {
409                target[(i, j)] = lhs[(i, j)] + rhs[(i, j)];
410            }
411        }
412    }
413}
414
415// =============================================================================
416// ExprSub - Subtraction expression
417// =============================================================================
418
419/// A lazy subtraction expression.
420pub struct ExprSub<L: Expr, R: Expr<Elem = L::Elem>> {
421    lhs: L,
422    rhs: R,
423}
424
425impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprSub<L, R> {
426    type Elem = L::Elem;
427
428    fn nrows(&self) -> usize {
429        self.lhs.nrows()
430    }
431
432    fn ncols(&self) -> usize {
433        self.lhs.ncols()
434    }
435
436    fn eval(&self) -> Mat<Self::Elem> {
437        let mut result = Mat::zeros(self.nrows(), self.ncols());
438        self.eval_into(&mut result);
439        result
440    }
441
442    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
443        let lhs = self.lhs.eval();
444        let rhs = self.rhs.eval();
445        for i in 0..self.nrows() {
446            for j in 0..self.ncols() {
447                target[(i, j)] = lhs[(i, j)] - rhs[(i, j)];
448            }
449        }
450    }
451}
452
453// =============================================================================
454// ExprMul - Matrix multiplication expression
455// =============================================================================
456
457/// A lazy matrix multiplication expression.
458pub struct ExprMul<L: Expr, R: Expr<Elem = L::Elem>> {
459    lhs: L,
460    rhs: R,
461}
462
463impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprMul<L, R> {
464    type Elem = L::Elem;
465
466    fn nrows(&self) -> usize {
467        self.lhs.nrows()
468    }
469
470    fn ncols(&self) -> usize {
471        self.rhs.ncols()
472    }
473
474    fn eval(&self) -> Mat<Self::Elem> {
475        let mut result = Mat::zeros(self.nrows(), self.ncols());
476        self.eval_into(&mut result);
477        result
478    }
479
480    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
481        let lhs = self.lhs.eval();
482        let rhs = self.rhs.eval();
483        let k = self.lhs.ncols();
484
485        for i in 0..self.nrows() {
486            for j in 0..self.ncols() {
487                let mut sum = L::Elem::zero();
488                for kk in 0..k {
489                    sum += lhs[(i, kk)] * rhs[(kk, j)];
490                }
491                target[(i, j)] = sum;
492            }
493        }
494    }
495}
496
497// =============================================================================
498// ExprConj - Complex conjugate expression
499// =============================================================================
500
501/// A lazy complex conjugate expression.
502pub struct ExprConj<E: Expr>
503where
504    E::Elem: ComplexScalar,
505{
506    inner: E,
507}
508
509impl<E: Expr> Expr for ExprConj<E>
510where
511    E::Elem: ComplexScalar,
512{
513    type Elem = E::Elem;
514
515    fn nrows(&self) -> usize {
516        self.inner.nrows()
517    }
518
519    fn ncols(&self) -> usize {
520        self.inner.ncols()
521    }
522
523    fn eval(&self) -> Mat<Self::Elem> {
524        let mut result = Mat::zeros(self.nrows(), self.ncols());
525        self.eval_into(&mut result);
526        result
527    }
528
529    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
530        let inner = self.inner.eval();
531        for i in 0..self.nrows() {
532            for j in 0..self.ncols() {
533                target[(i, j)] = inner[(i, j)].conj();
534            }
535        }
536    }
537}
538
539// Double conjugate optimization: conj(conj(A)) = A
540impl<E: Expr> ExprConj<ExprConj<E>>
541where
542    E::Elem: ComplexScalar,
543{
544    /// Eliminates double conjugation.
545    pub fn simplify(self) -> E {
546        self.inner.inner
547    }
548}
549
550// =============================================================================
551// ExprHermitian - Conjugate transpose expression
552// =============================================================================
553
554/// A lazy conjugate transpose (Hermitian) expression.
555pub struct ExprHermitian<E: Expr>
556where
557    E::Elem: ComplexScalar,
558{
559    inner: E,
560}
561
562impl<E: Expr> Expr for ExprHermitian<E>
563where
564    E::Elem: ComplexScalar,
565{
566    type Elem = E::Elem;
567
568    fn nrows(&self) -> usize {
569        self.inner.ncols()
570    }
571
572    fn ncols(&self) -> usize {
573        self.inner.nrows()
574    }
575
576    fn eval(&self) -> Mat<Self::Elem> {
577        let mut result = Mat::zeros(self.nrows(), self.ncols());
578        self.eval_into(&mut result);
579        result
580    }
581
582    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
583        let inner = self.inner.eval();
584        for i in 0..self.nrows() {
585            for j in 0..self.ncols() {
586                target[(i, j)] = inner[(j, i)].conj();
587            }
588        }
589    }
590}
591
592// Double Hermitian optimization: (A^H)^H = A
593impl<E: Expr> ExprHermitian<ExprHermitian<E>>
594where
595    E::Elem: ComplexScalar,
596{
597    /// Eliminates double conjugate transpose.
598    pub fn simplify(self) -> E {
599        self.inner.inner
600    }
601}
602
603// =============================================================================
604// Operator overloading for expressions
605// =============================================================================
606
607impl<'a, T: Scalar + bytemuck::Zeroable + Zero, R: Expr<Elem = T>> core::ops::Add<R>
608    for ExprLeaf<'a, T>
609{
610    type Output = ExprAdd<Self, R>;
611
612    fn add(self, rhs: R) -> Self::Output {
613        Expr::add(self, rhs)
614    }
615}
616
617impl<'a, T: Scalar + bytemuck::Zeroable + Zero, R: Expr<Elem = T>> core::ops::Sub<R>
618    for ExprLeaf<'a, T>
619{
620    type Output = ExprSub<Self, R>;
621
622    fn sub(self, rhs: R) -> Self::Output {
623        Expr::sub(self, rhs)
624    }
625}
626
627impl<'a, T: Scalar + bytemuck::Zeroable + Zero> core::ops::Neg for ExprLeaf<'a, T> {
628    type Output = ExprNeg<Self>;
629
630    fn neg(self) -> Self::Output {
631        Expr::neg(self)
632    }
633}
634
635// Add for ExprAdd
636impl<L1, R1, L2: Expr<Elem = L1::Elem>> core::ops::Add<L2> for ExprAdd<L1, R1>
637where
638    L1: Expr,
639    R1: Expr<Elem = L1::Elem>,
640{
641    type Output = ExprAdd<Self, L2>;
642
643    fn add(self, rhs: L2) -> Self::Output {
644        Expr::add(self, rhs)
645    }
646}
647
648// Sub for ExprAdd
649impl<L1, R1, L2: Expr<Elem = L1::Elem>> core::ops::Sub<L2> for ExprAdd<L1, R1>
650where
651    L1: Expr,
652    R1: Expr<Elem = L1::Elem>,
653{
654    type Output = ExprSub<Self, L2>;
655
656    fn sub(self, rhs: L2) -> Self::Output {
657        Expr::sub(self, rhs)
658    }
659}
660
661// Neg for ExprAdd
662impl<L1, R1> core::ops::Neg for ExprAdd<L1, R1>
663where
664    L1: Expr,
665    R1: Expr<Elem = L1::Elem>,
666{
667    type Output = ExprNeg<Self>;
668
669    fn neg(self) -> Self::Output {
670        Expr::neg(self)
671    }
672}
673
674// Add for ExprScale
675impl<E, R: Expr<Elem = E::Elem>> core::ops::Add<R> for ExprScale<E>
676where
677    E: Expr,
678{
679    type Output = ExprAdd<Self, R>;
680
681    fn add(self, rhs: R) -> Self::Output {
682        Expr::add(self, rhs)
683    }
684}
685
686// Sub for ExprScale
687impl<E, R: Expr<Elem = E::Elem>> core::ops::Sub<R> for ExprScale<E>
688where
689    E: Expr,
690{
691    type Output = ExprSub<Self, R>;
692
693    fn sub(self, rhs: R) -> Self::Output {
694        Expr::sub(self, rhs)
695    }
696}
697
698// Neg for ExprScale
699impl<E> core::ops::Neg for ExprScale<E>
700where
701    E: Expr,
702{
703    type Output = ExprNeg<Self>;
704
705    fn neg(self) -> Self::Output {
706        Expr::neg(self)
707    }
708}
709
710// Add for ExprTranspose
711impl<E, R: Expr<Elem = E::Elem>> core::ops::Add<R> for ExprTranspose<E>
712where
713    E: Expr,
714{
715    type Output = ExprAdd<Self, R>;
716
717    fn add(self, rhs: R) -> Self::Output {
718        Expr::add(self, rhs)
719    }
720}
721
722// Sub for ExprTranspose
723impl<E, R: Expr<Elem = E::Elem>> core::ops::Sub<R> for ExprTranspose<E>
724where
725    E: Expr,
726{
727    type Output = ExprSub<Self, R>;
728
729    fn sub(self, rhs: R) -> Self::Output {
730        Expr::sub(self, rhs)
731    }
732}
733
734// Neg for ExprTranspose
735impl<E> core::ops::Neg for ExprTranspose<E>
736where
737    E: Expr,
738{
739    type Output = ExprNeg<Self>;
740
741    fn neg(self) -> Self::Output {
742        Expr::neg(self)
743    }
744}
745
746// =============================================================================
747// FusedExpr - Optimized fused operations
748// =============================================================================
749
750/// Fused multiply-add expression: alpha * A + beta * B
751pub struct ExprFma<L: Expr, R: Expr<Elem = L::Elem>> {
752    lhs: L,
753    rhs: R,
754    alpha: L::Elem,
755    beta: L::Elem,
756}
757
758impl<L: Expr, R: Expr<Elem = L::Elem>> ExprFma<L, R> {
759    /// Creates a new fused multiply-add expression.
760    pub fn new(lhs: L, rhs: R, alpha: L::Elem, beta: L::Elem) -> Self {
761        debug_assert_eq!(lhs.shape(), rhs.shape());
762        Self {
763            lhs,
764            rhs,
765            alpha,
766            beta,
767        }
768    }
769}
770
771impl<L: Expr, R: Expr<Elem = L::Elem>> Expr for ExprFma<L, R> {
772    type Elem = L::Elem;
773
774    fn nrows(&self) -> usize {
775        self.lhs.nrows()
776    }
777
778    fn ncols(&self) -> usize {
779        self.lhs.ncols()
780    }
781
782    fn eval(&self) -> Mat<Self::Elem> {
783        let mut result = Mat::zeros(self.nrows(), self.ncols());
784        self.eval_into(&mut result);
785        result
786    }
787
788    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
789        let lhs = self.lhs.eval();
790        let rhs = self.rhs.eval();
791        for i in 0..self.nrows() {
792            for j in 0..self.ncols() {
793                target[(i, j)] = self.alpha * lhs[(i, j)] + self.beta * rhs[(i, j)];
794            }
795        }
796    }
797}
798
799/// GEMM expression: alpha * A * B + beta * C
800pub struct ExprGemm<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>> {
801    a: A,
802    b: B,
803    c: C,
804    alpha: A::Elem,
805    beta: A::Elem,
806    _marker: PhantomData<A::Elem>,
807}
808
809impl<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>> ExprGemm<A, B, C> {
810    /// Creates a new GEMM expression.
811    pub fn new(a: A, b: B, c: C, alpha: A::Elem, beta: A::Elem) -> Self {
812        debug_assert_eq!(a.ncols(), b.nrows());
813        debug_assert_eq!(a.nrows(), c.nrows());
814        debug_assert_eq!(b.ncols(), c.ncols());
815        Self {
816            a,
817            b,
818            c,
819            alpha,
820            beta,
821            _marker: PhantomData,
822        }
823    }
824}
825
826impl<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>> Expr for ExprGemm<A, B, C> {
827    type Elem = A::Elem;
828
829    fn nrows(&self) -> usize {
830        self.a.nrows()
831    }
832
833    fn ncols(&self) -> usize {
834        self.b.ncols()
835    }
836
837    fn eval(&self) -> Mat<Self::Elem> {
838        let mut result = Mat::zeros(self.nrows(), self.ncols());
839        self.eval_into(&mut result);
840        result
841    }
842
843    fn eval_into(&self, target: &mut Mat<Self::Elem>) {
844        let a = self.a.eval();
845        let b = self.b.eval();
846        let c = self.c.eval();
847        let k = self.a.ncols();
848
849        for i in 0..self.nrows() {
850            for j in 0..self.ncols() {
851                let mut sum = Self::Elem::zero();
852                for kk in 0..k {
853                    sum += a[(i, kk)] * b[(kk, j)];
854                }
855                target[(i, j)] = self.alpha * sum + self.beta * c[(i, j)];
856            }
857        }
858    }
859}
860
861// =============================================================================
862// Expression builder helpers
863// =============================================================================
864
865/// Creates a fused multiply-add expression: alpha * A + beta * B
866pub fn fma<L: Expr, R: Expr<Elem = L::Elem>>(
867    alpha: L::Elem,
868    a: L,
869    beta: L::Elem,
870    b: R,
871) -> ExprFma<L, R> {
872    ExprFma::new(a, b, alpha, beta)
873}
874
875/// Creates a GEMM expression: alpha * A * B + beta * C
876pub fn gemm<A: Expr, B: Expr<Elem = A::Elem>, C: Expr<Elem = A::Elem>>(
877    alpha: A::Elem,
878    a: A,
879    b: B,
880    beta: A::Elem,
881    c: C,
882) -> ExprGemm<A, B, C> {
883    ExprGemm::new(a, b, c, alpha, beta)
884}
885
886// =============================================================================
887// Tests
888// =============================================================================
889
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    /// A position in a matrix paired with the value expected there.
    type Entry<T> = ((usize, usize), T);

    /// Fixture `[[1, 2], [3, 4]]`.
    fn m1234() -> Mat<f64> {
        Mat::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]])
    }

    /// Fixture `[[5, 6], [7, 8]]`.
    fn m5678() -> Mat<f64> {
        Mat::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]])
    }

    /// Complex fixture `[[1+2i, 3+4i], [5+6i, 7+8i]]`.
    fn complex_fixture() -> Mat<Complex64> {
        let mut m: Mat<Complex64> = Mat::filled(2, 2, Complex64::new(0.0, 0.0));
        m[(0, 0)] = Complex64::new(1.0, 2.0);
        m[(0, 1)] = Complex64::new(3.0, 4.0);
        m[(1, 0)] = Complex64::new(5.0, 6.0);
        m[(1, 1)] = Complex64::new(7.0, 8.0);
        m
    }

    #[test]
    fn test_lazy_leaf() {
        let a = m1234();

        let result = a.as_ref().lazy().eval();

        let expected: [Entry<f64>; 2] = [((0, 0), 1.0), ((1, 1), 4.0)];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_add() {
        let a = m1234();
        let b = m5678();

        let sum = a.as_ref().lazy() + b.as_ref().lazy();
        let result = sum.eval();

        // Entry-wise: 1 + 5, 2 + 6, 3 + 7, 4 + 8.
        let expected: [Entry<f64>; 4] = [
            ((0, 0), 6.0),
            ((0, 1), 8.0),
            ((1, 0), 10.0),
            ((1, 1), 12.0),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_sub() {
        let a = m5678();
        let b = m1234();

        let difference = a.as_ref().lazy() - b.as_ref().lazy();
        let result = difference.eval();

        // Entry-wise: 5 - 1, 6 - 2, 7 - 3, 8 - 4 — all equal 4.
        let expected: [Entry<f64>; 4] = [
            ((0, 0), 4.0),
            ((0, 1), 4.0),
            ((1, 0), 4.0),
            ((1, 1), 4.0),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_neg() {
        let a = m1234();

        let negated = -a.as_ref().lazy();
        let result = negated.eval();

        let expected: [Entry<f64>; 2] = [((0, 0), -1.0), ((1, 1), -4.0)];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_scale() {
        let a = m1234();

        let doubled = a.as_ref().lazy().scale(2.0);
        let result = doubled.eval();

        let expected: [Entry<f64>; 4] = [
            ((0, 0), 2.0),
            ((0, 1), 4.0),
            ((1, 0), 6.0),
            ((1, 1), 8.0),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_transpose() {
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);

        let transposed = a.as_ref().lazy().t();
        let result = transposed.eval();

        // A 2x3 input becomes 3x2, with rows turned into columns.
        assert_eq!(result.shape(), (3, 2));
        let expected: [Entry<f64>; 6] = [
            ((0, 0), 1.0),
            ((1, 0), 2.0),
            ((2, 0), 3.0),
            ((0, 1), 4.0),
            ((1, 1), 5.0),
            ((2, 1), 6.0),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_matmul() {
        let a = m1234();
        let b = m5678();

        let product = a.as_ref().lazy().matmul(b.as_ref().lazy());
        let result = product.eval();

        // [1 2] * [5 6] = [1*5+2*7  1*6+2*8] = [19 22]
        // [3 4]   [7 8]   [3*5+4*7  3*6+4*8]   [43 50]
        let expected: [Entry<f64>; 4] = [
            ((0, 0), 19.0),
            ((0, 1), 22.0),
            ((1, 0), 43.0),
            ((1, 1), 50.0),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_chained() {
        let a = m1234();
        let b = m5678();
        let c: Mat<f64> = Mat::from_rows(&[&[1.0, 1.0], &[1.0, 1.0]]);

        // (A + B) - C
        let chained = (a.as_ref().lazy() + b.as_ref().lazy()) - c.as_ref().lazy();
        let result = chained.eval();

        // Entry-wise: 1 + 5 - 1, 2 + 6 - 1, 3 + 7 - 1, 4 + 8 - 1.
        let expected: [Entry<f64>; 4] = [
            ((0, 0), 5.0),
            ((0, 1), 7.0),
            ((1, 0), 9.0),
            ((1, 1), 11.0),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_complex_conj() {
        let a = complex_fixture();

        let conjugated = a.as_ref().lazy().conj();
        let result = conjugated.eval();

        // Same positions, imaginary parts negated.
        let expected: [Entry<Complex64>; 4] = [
            ((0, 0), Complex64::new(1.0, -2.0)),
            ((0, 1), Complex64::new(3.0, -4.0)),
            ((1, 0), Complex64::new(5.0, -6.0)),
            ((1, 1), Complex64::new(7.0, -8.0)),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_lazy_hermitian() {
        let a = complex_fixture();

        let adjoint = a.as_ref().lazy().h();
        let result = adjoint.eval();

        // Transpose + conjugate
        assert_eq!(result.shape(), (2, 2));
        let expected: [Entry<Complex64>; 4] = [
            ((0, 0), Complex64::new(1.0, -2.0)),
            ((1, 0), Complex64::new(3.0, -4.0)),
            ((0, 1), Complex64::new(5.0, -6.0)),
            ((1, 1), Complex64::new(7.0, -8.0)),
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_double_transpose_simplify() {
        let a = m1234();

        let twice = a.as_ref().lazy().t().t();
        let result = twice.simplify().eval();

        let expected: [Entry<f64>; 2] = [((0, 0), 1.0), ((1, 1), 4.0)];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_double_scale_simplify() {
        let a = m1234();

        let twice = a.as_ref().lazy().scale(2.0).scale(3.0);
        let result = twice.simplify().eval();

        // 2.0 * 3.0 = 6.0
        let expected: [Entry<f64>; 2] = [((0, 0), 6.0), ((1, 1), 24.0)];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_fma() {
        let a = m1234();
        let b = m5678();

        let combined = fma(2.0, a.as_ref().lazy(), 3.0, b.as_ref().lazy());
        let result = combined.eval();

        // 2*A + 3*B
        let expected: [Entry<f64>; 4] = [
            ((0, 0), 2.0 * 1.0 + 3.0 * 5.0), // 17
            ((0, 1), 2.0 * 2.0 + 3.0 * 6.0), // 22
            ((1, 0), 2.0 * 3.0 + 3.0 * 7.0), // 27
            ((1, 1), 2.0 * 4.0 + 3.0 * 8.0), // 32
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_gemm() {
        let a = m1234();
        let identity: Mat<f64> = Mat::from_rows(&[&[1.0, 0.0], &[0.0, 1.0]]);
        let c: Mat<f64> = Mat::from_rows(&[&[10.0, 10.0], &[10.0, 10.0]]);

        let update = gemm(
            2.0,
            a.as_ref().lazy(),
            identity.as_ref().lazy(),
            1.0,
            c.as_ref().lazy(),
        );
        let result = update.eval();

        // 2 * A * I + 1 * C = 2 * A + C
        let expected: [Entry<f64>; 4] = [
            ((0, 0), 2.0 * 1.0 + 10.0), // 12
            ((0, 1), 2.0 * 2.0 + 10.0), // 14
            ((1, 0), 2.0 * 3.0 + 10.0), // 16
            ((1, 1), 2.0 * 4.0 + 10.0), // 18
        ];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }

    #[test]
    fn test_eval_into() {
        let a = m1234();
        let b = m5678();

        let sum = a.as_ref().lazy() + b.as_ref().lazy();
        let mut result: Mat<f64> = Mat::zeros(2, 2);
        sum.eval_into(&mut result);

        let expected: [Entry<f64>; 2] = [((0, 0), 6.0), ((1, 1), 12.0)];
        for &(at, want) in expected.iter() {
            assert_eq!(result[at], want);
        }
    }
}
1120}