// ha-ndarray 0.5.0
//
// A hardware-accelerated n-dimensional array
// Documentation
use memoize::memoize;
use ocl::Program;

use crate::Error;

use super::{build, Builder, ElementDual, TILE_SIZE, WG_SIZE};

/// Build the OpenCL program containing the `diagonal` kernel for elements of type `c_type`.
///
/// The kernel reads a batch of square matrices from `matrices` and writes the main diagonal
/// of each one into `diagonals`. Work item `(m, i)` copies element `(i, i)` of matrix `m`.
///
/// The kernel has no explicit size argument; it takes the matrix dimension from
/// `get_global_size(1)`. It therefore expects to be enqueued with a global work size of
/// exactly `(number of matrices, matrix dimension)` and indexes both buffers as contiguous
/// row-major data.
///
/// NOTE(review): the launch shape and memory layout above are assumptions built into the
/// index arithmetic -- confirm them against the code that enqueues this kernel.
///
/// # Errors
/// Returns the error reported by `build` if the generated source fails to build.
#[memoize]
pub fn diagonal(c_type: &'static str) -> Result<Program, Error> {
    // OpenCL C has no multi-dimensional subscript: `a[m, i]` is the comma operator and
    // evaluates to `a[i]`, so the flat offsets are spelled out explicitly below.
    // Kernel pointer arguments must also name an address space, hence `__global`.
    let src = format!(
        r#"
        __kernel void diagonal(
                __global const {c_type}* restrict matrices,
                __global {c_type}* restrict diagonals)
        {{
            const ulong m = get_global_id(0);       // matrix number
            const ulong i = get_global_id(1);       // index along the diagonal
            const ulong dim = get_global_size(1);   // number of rows (== columns) per matrix

            const ulong offset_in = (m * dim * dim) + (i * dim) + i;
            const ulong offset_out = (m * dim) + i;

            diagonals[offset_out] = matrices[offset_in];
        }}
        "#,
    );

    build(&src)
}

/// Build the OpenCL program containing the `pad_matrices` kernel for elements of type `c_type`.
///
/// Each work item `(w, x, y)` copies one element of matrix `w` at row `x`, column `y` from
/// `input` to `output`. The source offset is `w * stride_in.x + x * stride_in.y + y` and the
/// destination offset is `w * stride_out.x + x * stride_out.y + y`, so the two buffers may
/// use different matrix and row strides.
///
/// The kernel performs no bounds checks of its own.
///
/// # Errors
/// Returns the error reported by `build` if the generated source fails to build.
#[memoize]
pub fn pad_matrices(c_type: &'static str) -> Result<Program, Error> {
    // `c_type` is interpolated as the element type of both buffers
    let kernel_source = format!(
        r#"
        __kernel void pad_matrices(
                ulong2 const stride_in,
                ulong2 const stride_out,
                __global const {c_type}* restrict input,
                __global {c_type}* restrict output)
        {{
            const uint w = get_global_id(0);    // matrix number
            const uint x = get_global_id(1);    // row index
            const uint y = get_global_id(2);    // column index

            const ulong offset_in = (w * stride_in.x) + (x * stride_in.y) + y;

            const ulong offset_out = (w * stride_out.x) + (x * stride_out.y) + y;

            output[offset_out] = input[offset_in];
        }}
    "#
    );

    build(&kernel_source)
}

/// Build the OpenCL program containing the tiled, batched `matmul` kernel.
///
/// `mul` supplies the element-wise operator: its generated source (`mul.build()`) is placed
/// ahead of the kernel, and the kernel calls it by `mul.name` on pairs of `mul.i_type`
/// inputs, accumulating the results with `+=` into `mul.o_type` outputs.
///
/// Kernel arguments:
/// - `dims`: `dims.x` is the number of rows of `left` and `output`, `dims.y` is the length of
///   the reduced axis, and `dims.z` is the number of columns of `right` and `output`.
///   `dims.w` is not read by the kernel.
/// - `reduce_tiles`: the number of `TILE_SIZE`-wide tiles to iterate over along the reduced axis.
/// - `left`, `right`, `output`: row-major batches of matrices.
///
/// Work item `(w, x_tile, z_tile)` computes one `TILE_SIZE` x `TILE_SIZE` tile of output
/// matrix `w` on its own, using private arrays as the tile cache.
///
/// NOTE(review): tile loads from `left` and `right` are not bounds-checked (only the final
/// store is), so the inputs presumably must be padded to a multiple of `TILE_SIZE` along
/// every matrix axis -- confirm against the caller.
///
/// # Errors
/// Returns the error reported by `build` if the generated source fails to build.
#[memoize]
pub fn matmul(mul: ElementDual) -> Result<Program, Error> {
    debug_assert_eq!(TILE_SIZE * TILE_SIZE, WG_SIZE);

    let i_type = mul.i_type;
    let o_type = mul.o_type;
    let name = mul.name;
    let op = mul.build();

    let src = format!(
        r#"
        {op}

        __kernel void matmul(
                ulong4 const dims,
                ulong const reduce_tiles,
                __global const {i_type}* restrict left,
                __global const {i_type}* restrict right,
                __global {o_type}* restrict output)
        {{
            // x := output axis 0
            // y := reduce axis
            // z := output axis 1
            // w := matrix number

            const ulong w = get_global_id(0);
            const ulong x_tile = get_global_id(1);
            const ulong z_tile = get_global_id(2);

            const ulong x_offset = x_tile * {TILE_SIZE};
            const ulong z_offset = z_tile * {TILE_SIZE};
            const ulong left_offset = w * dims.x * dims.y;
            const ulong right_offset = w * dims.y * dims.z;

            // accumulator for this work item's output tile
            {o_type} tile[{TILE_SIZE}][{TILE_SIZE}] = {{ 0 }};

            // private (per-work-item) cache for the left and right input tiles;
            // these hold raw inputs, so they use the input type rather than the output type
            {i_type} left_tile[{TILE_SIZE}][{TILE_SIZE}] = {{ 0 }};
            {i_type} right_tile[{TILE_SIZE}][{TILE_SIZE}] = {{ 0 }};

            // for each tile on the y axis
            for (ulong y_tile = 0; y_tile < reduce_tiles; y_tile++) {{
                const ulong y_offset = y_tile * {TILE_SIZE};

                // read the left and right tiles into the private cache
                #pragma unroll
                for (uint i = 0; i < {TILE_SIZE}; i++) {{
                    #pragma unroll
                    for (uint j = 0; j < {TILE_SIZE}; j++) {{
                        ulong offset_l = left_offset + ((x_offset + i) * dims.y) + (y_offset + j);
                        left_tile[i][j] = left[offset_l];

                        ulong offset_r = right_offset + ((y_offset + i) * dims.z) + (z_offset + j);
                        right_tile[i][j] = right[offset_r];
                    }}
                }}

                // tile += left tile @ right tile
                #pragma unroll
                for (uint i = 0; i < {TILE_SIZE}; i++) {{
                    #pragma unroll
                    for (uint j = 0; j < {TILE_SIZE}; j++) {{
                        #pragma unroll
                        for (uint k = 0; k < {TILE_SIZE}; k++) {{
                            tile[i][k] += {name}(left_tile[i][j], right_tile[j][k]);
                        }}
                    }}
                }}
            }}

            // write tile to output, skipping any element outside the output matrix
            ulong tile_offset = (w * dims.x * dims.z) + (x_offset * dims.z) + z_offset;

            #pragma unroll
            for (uint i = 0; i < {TILE_SIZE}; i++) {{
                #pragma unroll
                for (uint j = 0; j < {TILE_SIZE}; j++) {{
                    if ((x_offset + i) < dims.x && (z_offset + j) < dims.z) {{
                        ulong offset = tile_offset + (i * dims.z) + j;
                        output[offset] = tile[i][j];
                    }}
                }}
            }}
        }}
        "#,
    );

    build(&src)
}