faer-core 0.13.4

Basic linear algebra routines
Documentation
use std::{env, fmt::Write, fs, path::Path};

/// Renders `render(i)` for every operand index `i` in `0..count` and joins the
/// resulting fragments with `sep`.
fn join_operands(count: usize, sep: &str, render: impl Fn(usize) -> String) -> String {
    (0..count).map(render).collect::<Vec<_>>().join(sep)
}

/// Build-script entry point: generates `zip.rs` inside the directory named by
/// the `OUT_DIR` environment variable.
///
/// For every arity `count` in `1..=16` the generated file contains:
/// * an `unsafe fn contiguous_impl{count}` that calls `op` once per index of a
///   set of raw slices, and
/// * an `impl` block for `Zip<(M0, .., M{count-1})>` providing `for_each`,
///   `for_each_triangular_lower`, `for_each_triangular_upper`, `zip_unchecked`
///   and `zip`.
///
/// The generated code refers to `Mat`, `Zip` and `Diag`, none of which are
/// defined here; they are expected to be in scope wherever the file is
/// included (presumably via `include!` in the crate — confirm against the
/// consumer).
///
/// # Errors
///
/// Returns an error if `OUT_DIR` is not set, if formatting into the in-memory
/// buffer fails, or if the output file cannot be written.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `OUT_DIR` is provided by Cargo; report a readable error instead of
    // panicking when the script is executed outside of a Cargo build.
    let out_dir = env::var_os("OUT_DIR")
        .ok_or("OUT_DIR is not set; this build script must be run by Cargo")?;
    let dest_path = Path::new(&out_dir).join("zip.rs");
    let mut code = String::new();

    for count in 1..=16 {
        // NOTE: the binding names below are captured by name inside the
        // template passed to `write!`, so they must match its placeholders.

        // Generic parameter declarations: `M0: for<'short> Mat<'short>,M1: ...,`
        let generics_decl = join_operands(count, "", |i| {
            format!("M{i}: for<'short> Mat<'short>,")
        });
        // Generic parameter uses: `M0,M1,`
        let generics_use = join_operands(count, "", |i| format!("M{i},"));
        // Argument types of the user callback, one item per operand.
        let generics_items = join_operands(count, "", |i| format!("<M{i} as Mat<'_>>::Item,"));
        // Parameters of `contiguous_impl{count}`: one mutable raw slice per operand.
        let slice_decl = join_operands(count, "", |i| {
            format!("mut slice{i}: <M{i} as Mat<'_>>::RawSlice,")
        });
        // Statements fetching the column slice of each operand; they use the
        // `start`, `j` and `len` bindings of the surrounding generated code.
        let slice_def = join_operands(count, "", |i| {
            format!("let slice{i} = this.tuple.{i}.get_column_slice(start, j, len);")
        });
        // Slice bindings forwarded to `contiguous_impl{count}`.
        let slice_use = join_operands(count, "", |i| format!("slice{i},"));
        // Per-operand element access at index `i` inside `contiguous_impl{count}`.
        let slice_get_elem = join_operands(count, "", |i| {
            format!("M{i}::get_slice_elem(&mut slice{i}, i),")
        });
        // Statements replacing every operand by its transpose.
        let do_transpose = join_operands(count, "", |i| {
            format!("this.tuple.{i} = this.tuple.{i}.transpose();")
        });
        // Statements replacing every operand by its row-reversed view.
        let do_reverse = join_operands(count, "", |i| {
            format!("this.tuple.{i} = this.tuple.{i}.reverse_rows();")
        });
        // Condition selecting the slice-based path: every operand has row stride 1.
        let args_is_col_major = join_operands(count, "&&", |i| {
            format!("this.tuple.{i}.row_stride() == 1")
        });
        // Per-operand element access at `(i, j)` for the fallback path.
        let get = join_operands(count, "", |i| format!("this.tuple.{i}.get(i, j),"));
        // Existing tuple fields, used to build the tuple extended by one operand.
        let args = join_operands(count, ",", |i| format!("this.tuple.{i}"));

        write!(
            code,
            r#"
#[inline(always)]
unsafe fn contiguous_impl{count}<{generics_decl}>(
    len: usize,
    op: &mut impl FnMut({generics_items}),
    {slice_decl}
) {{
    for i in 0..len {{
        (*op)({slice_get_elem});
    }}
}}

impl <{generics_decl}> Zip<({generics_use})> {{
    #[inline(always)]
    pub fn for_each(self, op: impl FnMut({generics_items})) {{
        let mut this = self;
        let mut op = op;

        if this.tuple.0.row_stride().unsigned_abs() > this.tuple.0.col_stride().unsigned_abs() {{
            {do_transpose}
        }}
        if this.tuple.0.row_stride() < 0 {{
            {do_reverse}
        }}

        let nrows = this.tuple.0.nrows();
        let ncols = this.tuple.0.ncols();

        unsafe {{
            if {args_is_col_major} {{
                let start = 0;
                let len = nrows;
                for j in 0..ncols {{
                    {slice_def}

                    contiguous_impl{count}(nrows, &mut op, {slice_use});
                }}
            }} else {{
                for j in 0..ncols {{
                    for i in 0..nrows {{
                        op({get});
                    }}
                }}
            }}
        }}
    }}

    #[inline(always)]
    pub fn for_each_triangular_lower(self, diag: Diag, op: impl FnMut({generics_items})) {{
        let strict = match diag {{
            Diag::Skip => true,
            Diag::Include => false,
        }};

        let mut this = self;
        let mut op = op;

        let mut transpose = false;
        let mut reverse_rows = false;

        if this.tuple.0.row_stride().unsigned_abs() > this.tuple.0.col_stride().unsigned_abs() {{
            {do_transpose}
            transpose = true;
        }}
        if this.tuple.0.row_stride() < 0 {{
            {do_reverse}
            reverse_rows = true;
        }}

        let nrows = this.tuple.0.nrows();
        let ncols = this.tuple.0.ncols();

        let ncols = if transpose {{ ncols }} else {{ ncols.min(nrows) }};

        unsafe {{
            if {args_is_col_major} {{
                for j in 0..ncols {{
                    let (start, end) = match (transpose, reverse_rows) {{
                        (false, false) => (j + strict as usize, nrows),
                        (false, true) => (0, (nrows - (j + strict as usize))),
                        (true, false) => (0, (j + !strict as usize).min(nrows)),
                        (true, true) => (nrows - ((j + !strict as usize).min(nrows)), nrows),
                    }};

                    let len = end - start;

                    {slice_def}

                    contiguous_impl{count}(len, &mut op, {slice_use});
                }}
            }} else {{
                for j in 0..ncols {{
                    let (start, end) = match (transpose, reverse_rows) {{
                        (false, false) => (j + strict as usize, nrows),
                        (false, true) => (0, (nrows - (j + strict as usize))),
                        (true, false) => (0, (j + !strict as usize).min(nrows)),
                        (true, true) => (nrows - ((j + !strict as usize).min(nrows)), nrows),
                    }};

                    for i in start..end {{
                        op({get});
                    }}
                }}
            }}
        }}
    }}

    #[inline(always)]
    pub fn for_each_triangular_upper(self, diag: Diag, op: impl FnMut({generics_items})) {{
        let mut this = self;
        {do_transpose}
        this.for_each_triangular_lower(diag, op);
    }}

    #[inline]
    #[track_caller]
    pub fn zip_unchecked<M{count}: for<'short> Mat<'short>>(self, last: M{count}) -> Zip<({generics_use} M{count})> {{
        let this = self;
        debug_assert!((last.nrows(), last.ncols()) == (this.tuple.0.nrows(), this.tuple.0.ncols()));
        Zip {{ tuple: ({args}, last) }}
    }}
    #[inline]
    #[track_caller]
    pub fn zip<M{count}: for<'short> Mat<'short>>(self, last: M{count}) -> Zip<({generics_use} M{count})> {{
        let this = self;
        assert!((last.nrows(), last.ncols()) == (this.tuple.0.nrows(), this.tuple.0.ncols()));
        Zip {{ tuple: ({args}, last) }}
    }}
}}
"#
        )?;
    }

    fs::write(dest_path, &code)?;
    // Ask Cargo to re-run this script only when the script itself changes.
    println!("cargo:rerun-if-changed=build.rs");

    Ok(())
}