gemm-common 0.19.0

Playground for matrix multiplication algorithms
Documentation
/// Generates a `pub unsafe fn $name` that computes an `$m` x `$n` tile of inner products
/// ("horizontal" kernel: both operands are walked along the depth dimension `k`).
///
/// Invocation: `horizontal_kernel!([target_feature]?, name, m, n)`. The optional bracketed
/// token is forwarded to `#[target_feature(enable = ...)]`.
///
/// The expansion refers to the following names unqualified, so they must be in scope where
/// the macro is invoked: the scalar type `T`, the lane count `N`, and the helpers `splat`,
/// `mul_add`, `partial_load`, `reduce_sum`, `scalar_mul`, `scalar_mul_add` and `scalar_add`.
/// The `seq_macro` crate must also be nameable at the invocation site.
#[macro_export]
macro_rules! horizontal_kernel {
    ($([$target: tt])?, $name: ident, $m: tt, $n: tt $(,)?) => {
        /// Accumulates, for every tile row `i` and tile column `j`, the products of the
        /// `k` elements starting at `lhs + i * lhs_rs` with the `k` elements starting at
        /// `rhs + j * rhs_cs`, then combines the result with `dst`.
        ///
        /// The output element for `(i, j)` lives at `dst + j * dst_cs + i * dst_rs` and is
        /// updated according to `alpha_status` (`acc` is the reduced accumulator):
        ///
        /// - `0`: `dst = beta * acc` (the previous value of `dst` is not read),
        /// - `1`: `dst = scalar_mul_add(beta, acc, dst)`, presumably `beta * acc + dst`
        ///   (depends on the helper in scope at the expansion site),
        /// - any other value: `dst = beta * acc + alpha * dst`.
        ///
        /// Note that `beta` scales the freshly computed product and `alpha` scales the
        /// existing destination. The three `_conj_*` flags are accepted but ignored.
        ///
        /// # Safety
        ///
        /// - Every row `lhs + i * lhs_rs` and every column `rhs + j * rhs_cs` of the tile
        ///   must be readable for `k` consecutive elements.
        /// - Every `dst + j * dst_cs + i * dst_rs` of the tile must be valid for writes,
        ///   and for reads as well unless `alpha_status == 0`.
        /// - If the kernel was generated with a target feature, the running CPU must
        ///   support that feature.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            k: usize,
            dst: *mut T,
            lhs: *const T,
            rhs: *const T,
            dst_cs: isize,
            dst_rs: isize,
            lhs_rs: isize,
            rhs_cs: isize,
            alpha: T,
            beta: T,
            alpha_status: u8,
            _conj_dst: bool,
            _conj_lhs: bool,
            _conj_rhs: bool,
        ) {
            // One vector accumulator per output element, indexed as [column][row].
            let mut acc = [[splat(::core::mem::zeroed()); $m]; $n];

            // Base pointers of the tile's lhs rows and rhs columns.
            seq_macro::seq!(M_ITER in 0..$m {
                let lhs_ptr~M_ITER = lhs.wrapping_offset(M_ITER * lhs_rs);
            });
            seq_macro::seq!(N_ITER in 0..$n {
                let rhs_ptr~N_ITER = rhs.wrapping_offset(N_ITER * rhs_cs);
            });

            // Largest multiple of the lane count that does not exceed `k`.
            let k_full = k / N * N;

            // Main loop: whole `[T; N]` chunks.
            let mut depth = 0;
            while depth < k_full {
                seq_macro::seq!(M_ITER in 0..$m {
                    let lhs_vec~M_ITER = *(lhs_ptr~M_ITER.add(depth) as *const [T; N]);
                });
                seq_macro::seq!(N_ITER in 0..$n {
                    let rhs_vec~N_ITER = *(rhs_ptr~N_ITER.add(depth) as *const [T; N]);
                });

                seq_macro::seq!(M_ITER in 0..$m {
                    seq_macro::seq!(N_ITER in 0..$n {
                        acc[N_ITER][M_ITER] = mul_add(lhs_vec~M_ITER, rhs_vec~N_ITER, acc[N_ITER][M_ITER]);
                    });
                });

                depth += N;
            }

            // Tail: fewer than `N` elements remain; `partial_load` receives the count.
            if depth < k {
                let tail = k - depth;
                seq_macro::seq!(M_ITER in 0..$m {
                    let lhs_vec~M_ITER = partial_load(lhs_ptr~M_ITER.add(depth), tail);
                });
                seq_macro::seq!(N_ITER in 0..$n {
                    let rhs_vec~N_ITER = partial_load(rhs_ptr~N_ITER.add(depth), tail);
                });

                seq_macro::seq!(M_ITER in 0..$m {
                    seq_macro::seq!(N_ITER in 0..$n {
                        acc[N_ITER][M_ITER] = mul_add(lhs_vec~M_ITER, rhs_vec~N_ITER, acc[N_ITER][M_ITER]);
                    });
                });
            }

            // Collapse every vector accumulator into a single scalar.
            let mut sums: [[T; $m]; $n] = ::core::mem::zeroed();
            seq_macro::seq!(M_ITER in 0..$m {
                seq_macro::seq!(N_ITER in 0..$n {
                    sums[N_ITER][M_ITER] = reduce_sum(acc[N_ITER][M_ITER]);
                });
            });

            // Write-back, specialised on how the existing destination is scaled.
            match alpha_status {
                0 => {
                    seq_macro::seq!(M_ITER in 0..$m {{
                        seq_macro::seq!(N_ITER in 0..$n {{
                            let out = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                            *out = scalar_mul(beta, sums[N_ITER][M_ITER]);
                        }});
                    }});
                }
                1 => {
                    seq_macro::seq!(M_ITER in 0..$m {{
                        seq_macro::seq!(N_ITER in 0..$n {{
                            let out = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                            *out = scalar_mul_add(
                                beta,
                                sums[N_ITER][M_ITER],
                                *out,
                            );
                        }});
                    }});
                }
                _ => {
                    seq_macro::seq!(M_ITER in 0..$m {{
                        seq_macro::seq!(N_ITER in 0..$n {{
                            let out = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                            *out = scalar_add(
                                scalar_mul(beta, sums[N_ITER][M_ITER]),
                                scalar_mul(alpha, *out),
                            );
                        }});
                    }});
                }
            }
        }
    };
}

/// Generates a `pub unsafe fn $name` that computes an `$m` x `$n` tile of complex inner
/// products, with optional conjugation of either operand.
///
/// Invocation: `horizontal_cplx_kernel!([target_feature]?, name, m, n)`. The optional
/// bracketed token is forwarded to `#[target_feature(enable = ...)]`.
///
/// The expansion refers to the following names unqualified, so they must be in scope where
/// the macro is invoked: the real scalar type `T`, the complex lane count `CPLX_N`, the
/// packed vector type `Pack`, and the helpers `splat`, `mul_add`, `conj_mul_add`, `conj`,
/// `partial_load` and `reduce_sum`. The `seq_macro` and `num_complex` crates must also be
/// nameable at the invocation site.
#[macro_export]
macro_rules! horizontal_cplx_kernel {
    ($([$target: tt])?, $name: ident, $m: tt, $n: tt $(,)?) => {
        /// Accumulates, for every tile row `i` and tile column `j`, the products of the
        /// `k` complex elements starting at `lhs + i * lhs_rs` with the `k` complex
        /// elements starting at `rhs + j * rhs_cs`, honouring `conj_lhs` / `conj_rhs`,
        /// then combines the result with `dst`.
        ///
        /// The output element for `(i, j)` lives at `dst + j * dst_cs + i * dst_rs` and is
        /// updated according to `alpha_status` (`acc` is the reduced accumulator):
        ///
        /// - `0`: `dst = beta * acc` (the previous value of `dst` is not read),
        /// - `1`: `dst = beta * acc + dst`,
        /// - any other value: `dst = beta * acc + alpha * dst`.
        ///
        /// Note that `beta` scales the freshly computed product and `alpha` scales the
        /// existing destination. `_conj_dst` is accepted but ignored.
        ///
        /// # Safety
        ///
        /// - Every row `lhs + i * lhs_rs` and every column `rhs + j * rhs_cs` of the tile
        ///   must be readable for `k` consecutive complex elements.
        /// - Every `dst + j * dst_cs + i * dst_rs` of the tile must be valid for writes,
        ///   and for reads as well unless `alpha_status == 0`.
        /// - If the kernel was generated with a target feature, the running CPU must
        ///   support that feature.
        $(#[target_feature(enable = $target)])?
        pub unsafe fn $name(
            k: usize,
            dst: *mut num_complex::Complex<T>,
            lhs: *const num_complex::Complex<T>,
            rhs: *const num_complex::Complex<T>,
            dst_cs: isize,
            dst_rs: isize,
            lhs_rs: isize,
            rhs_cs: isize,
            alpha: num_complex::Complex<T>,
            beta: num_complex::Complex<T>,
            alpha_status: u8,
            _conj_dst: bool,
            conj_lhs: bool,
            conj_rhs: bool,
        ) {
            // One vector accumulator per output element, indexed as [column][row].
            let mut acc = [[splat(::core::mem::zeroed()); $m]; $n];

            // Base pointers of the tile's lhs rows and rhs columns.
            seq_macro::seq!(M_ITER in 0..$m {
                let lhs_ptr~M_ITER = lhs.wrapping_offset(M_ITER * lhs_rs);
            });
            seq_macro::seq!(N_ITER in 0..$n {
                let rhs_ptr~N_ITER = rhs.wrapping_offset(N_ITER * rhs_cs);
            });

            // Only "conjugate lhs inside the product" and "conjugate the final sum" are
            // implemented, so a conjugated rhs is rewritten using
            //   conj(a) * conj(b) == conj(a * b)   and   a * conj(b) == conj(conj(a) * b).
            // Hence the sum is conjugated exactly when rhs is, and lhs is conjugated in
            // the product exactly when the two flags differ.
            let conj_result = conj_rhs;
            let conj_lhs_in_product = conj_lhs != conj_rhs;

            // Largest multiple of the lane count that does not exceed `k`.
            let k_full = k / CPLX_N * CPLX_N;

            if conj_lhs_in_product {
                // Main loop: whole `Pack` chunks, via `conj_mul_add`.
                let mut depth = 0;
                while depth < k_full {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs_vec~M_ITER = *(lhs_ptr~M_ITER.add(depth) as *const Pack);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs_vec~N_ITER = *(rhs_ptr~N_ITER.add(depth) as *const Pack);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            acc[N_ITER][M_ITER] = conj_mul_add(lhs_vec~M_ITER, rhs_vec~N_ITER, acc[N_ITER][M_ITER]);
                        });
                    });

                    depth += CPLX_N;
                }

                // Tail: fewer than `CPLX_N` elements remain.
                if depth < k {
                    let tail = k - depth;
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs_vec~M_ITER = partial_load(lhs_ptr~M_ITER.add(depth), tail);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs_vec~N_ITER = partial_load(rhs_ptr~N_ITER.add(depth), tail);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            acc[N_ITER][M_ITER] = conj_mul_add(lhs_vec~M_ITER, rhs_vec~N_ITER, acc[N_ITER][M_ITER]);
                        });
                    });
                }
            } else {
                // Main loop: whole `Pack` chunks, via plain `mul_add`.
                let mut depth = 0;
                while depth < k_full {
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs_vec~M_ITER = *(lhs_ptr~M_ITER.add(depth) as *const Pack);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs_vec~N_ITER = *(rhs_ptr~N_ITER.add(depth) as *const Pack);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            acc[N_ITER][M_ITER] = mul_add(lhs_vec~M_ITER, rhs_vec~N_ITER, acc[N_ITER][M_ITER]);
                        });
                    });

                    depth += CPLX_N;
                }

                // Tail: fewer than `CPLX_N` elements remain.
                if depth < k {
                    let tail = k - depth;
                    seq_macro::seq!(M_ITER in 0..$m {
                        let lhs_vec~M_ITER = partial_load(lhs_ptr~M_ITER.add(depth), tail);
                    });
                    seq_macro::seq!(N_ITER in 0..$n {
                        let rhs_vec~N_ITER = partial_load(rhs_ptr~N_ITER.add(depth), tail);
                    });

                    seq_macro::seq!(M_ITER in 0..$m {
                        seq_macro::seq!(N_ITER in 0..$n {
                            acc[N_ITER][M_ITER] = mul_add(lhs_vec~M_ITER, rhs_vec~N_ITER, acc[N_ITER][M_ITER]);
                        });
                    });
                }
            }

            // Apply the deferred conjugation to every accumulator.
            if conj_result {
                seq_macro::seq!(M_ITER in 0..$m {{
                    seq_macro::seq!(N_ITER in 0..$n {{
                        acc[N_ITER][M_ITER] = conj(acc[N_ITER][M_ITER]);
                    }});
                }});
            }

            // Collapse every vector accumulator into a single complex scalar.
            let mut sums: [[num_complex::Complex<T>; $m]; $n] = ::core::mem::zeroed();
            seq_macro::seq!(M_ITER in 0..$m {
                seq_macro::seq!(N_ITER in 0..$n {
                    sums[N_ITER][M_ITER] = reduce_sum(acc[N_ITER][M_ITER]);
                });
            });

            // Write-back, specialised on how the existing destination is scaled.
            match alpha_status {
                0 => {
                    seq_macro::seq!(M_ITER in 0..$m {{
                        seq_macro::seq!(N_ITER in 0..$n {{
                            let out = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                            *out = beta * sums[N_ITER][M_ITER];
                        }});
                    }});
                }
                1 => {
                    seq_macro::seq!(M_ITER in 0..$m {{
                        seq_macro::seq!(N_ITER in 0..$n {{
                            let out = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                            *out = (beta * sums[N_ITER][M_ITER]) + *out;
                        }});
                    }});
                }
                _ => {
                    seq_macro::seq!(M_ITER in 0..$m {{
                        seq_macro::seq!(N_ITER in 0..$n {{
                            let out = dst.offset(dst_cs * N_ITER + dst_rs * M_ITER);
                            *out = (beta * sums[N_ITER][M_ITER]) + (alpha * *out);
                        }});
                    }});
                }
            }
        }
    };
}