vkml 0.0.3

High-level Vulkan-based machine learning library
use bytemuck::{try_cast_slice, try_cast_slice_mut};

/// Computes the element-wise maximum of two `f32` tensors on the CPU.
///
/// The destination is filled in order, with the last dimension of `dst_dims`
/// varying fastest. For each destination element the matching source
/// elements are located through `strides_a` / `strides_b`, which are counted
/// in `f32` elements (not bytes). A stride of `0` keeps re-reading the same
/// source element along that dimension, so broadcast operands can be
/// expressed without materialising them.
///
/// Values are combined with [`f32::max`], so when exactly one operand is NaN
/// the other operand is written.
///
/// # Arguments
///
/// * `strides_a` - per-dimension element strides into `src1_bytes`.
/// * `strides_b` - per-dimension element strides into `src2_bytes`.
/// * `dst_dims` - extents of the destination tensor.
/// * `src1_bytes` / `src2_bytes` - raw bytes of the two `f32` sources.
/// * `dst_ptr` - raw bytes of the `f32` destination; must hold exactly
///   `product(dst_dims)` elements.
///
/// # Panics
///
/// * if any byte slice cannot be reinterpreted as an `f32` slice
///   (misaligned, or length not a multiple of four);
/// * if a dimension is negative or does not fit in `usize`, or the element
///   count overflows `usize`;
/// * if the destination length differs from `product(dst_dims)`;
/// * if a stride vector has fewer entries than `dst_dims` (for a non-empty
///   destination);
/// * if a computed source offset falls outside its source slice.
pub fn f32_f32_f32_cpu(
    strides_a: Vec<usize>,
    strides_b: Vec<usize>,
    dst_dims: Vec<i64>,
    src1_bytes: &[u8],
    src2_bytes: &[u8],
    dst_ptr: &mut [u8],
) {
    // Convert the extents once, rejecting negative values instead of letting
    // an `as` cast wrap them into enormous sizes. The fully qualified path
    // keeps this independent of which edition's prelude is in effect.
    let dims: Vec<usize> = dst_dims
        .iter()
        .map(|&d| {
            <usize as std::convert::TryFrom<i64>>::try_from(d)
                .expect("dst dimension must be non-negative and fit in usize")
        })
        .collect();
    let num_elements = dims
        .iter()
        .try_fold(1usize, |acc, &d| acc.checked_mul(d))
        .expect("dst element count overflows usize");

    let src1_f32: &[f32] = try_cast_slice(src1_bytes)
        .expect("src1 byte slice cannot be cast to f32 slice (alignment/length mismatch)");
    let src2_f32: &[f32] = try_cast_slice(src2_bytes)
        .expect("src2 byte slice cannot be cast to f32 slice (alignment/length mismatch)");
    let dst_f32: &mut [f32] = try_cast_slice_mut(dst_ptr)
        .expect("dst byte slice cannot be cast to f32 slice (alignment/length mismatch)");

    assert_eq!(dst_f32.len(), num_elements, "dst buffer size mismatch");
    if num_elements == 0 {
        // Nothing to write; the strides are never consulted.
        return;
    }

    let rank = dims.len();
    assert!(
        strides_a.len() >= rank,
        "strides_a has fewer entries than dst_dims"
    );
    assert!(
        strides_b.len() >= rank,
        "strides_b has fewer entries than dst_dims"
    );
    // Only the first `rank` strides are ever used.
    let strides_a = &strides_a[..rank];
    let strides_b = &strides_b[..rank];

    // Multi-dimensional index of the element that will be written next.
    let mut idxs = vec![0usize; rank];
    let mut off_a: usize = 0;
    let mut off_b: usize = 0;

    for dst_slot in dst_f32.iter_mut() {
        *dst_slot = src1_f32[off_a].max(src2_f32[off_b]);

        // Advance the index like an odometer, innermost dimension first, and
        // keep both source offsets in step with it.
        for d in (0..rank).rev() {
            idxs[d] += 1;
            off_a = off_a.wrapping_add(strides_a[d]);
            off_b = off_b.wrapping_add(strides_b[d]);

            if idxs[d] < dims[d] {
                break;
            }

            // This dimension rolled over: rewind it to zero and carry into
            // the next outer dimension. Wrapping arithmetic is used
            // throughout so the add/rewind pair always cancels exactly.
            idxs[d] = 0;
            off_a = off_a.wrapping_sub(strides_a[d].wrapping_mul(dims[d]));
            off_b = off_b.wrapping_sub(strides_b[d].wrapping_mul(dims[d]));
        }
    }
}